#![cfg(feature = "graph")]
use std::collections::{HashMap, HashSet};
use std::sync::Arc;

use anyhow::Result;
#[cfg(feature = "graph")]
pub use graphalgs;
#[cfg(feature = "graph")]
pub use petgraph;
use petgraph::algo::{astar, dijkstra};
use petgraph::graph::{DiGraph, NodeIndex};
use petgraph::visit::EdgeRef;
use petgraph::Direction;
use serde::{Deserialize, Serialize};
/// A node of a [`KnowledgeGraph`], addressed by its unique string `id`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct KnowledgeNode {
    /// Unique identifier; the graph indexes nodes by this value.
    pub id: String,
    /// Category label, matched exactly by `KnowledgeGraph::find_by_label`.
    pub label: String,
    /// Free-form textual content of the node.
    pub content: String,
    /// Optional embedding vector attached to the node.
    pub embedding: Option<Vec<f32>>,
    /// Arbitrary JSON metadata keyed by name.
    pub metadata: HashMap<String, serde_json::Value>,
}
/// A directed, weighted relation between two [`KnowledgeNode`]s.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct KnowledgeEdge {
    /// Relation name, e.g. `"related_to"`; matched exactly by `find_related`.
    pub relation: String,
    /// Traversal cost used by the shortest-path and distance queries.
    /// Must be non-negative: Dijkstra/A* give incorrect results otherwise.
    pub weight: f32,
    /// Arbitrary JSON properties keyed by name.
    pub properties: HashMap<String, serde_json::Value>,
}
/// A directed knowledge graph whose nodes are addressed by their string id.
#[derive(Debug, Clone)]
pub struct KnowledgeGraph {
    /// Underlying petgraph storage holding node and edge payloads.
    graph: DiGraph<KnowledgeNode, KnowledgeEdge>,
    /// Maps each node id to its index in `graph`.
    node_index: HashMap<String, NodeIndex>,
}
impl KnowledgeGraph {
    /// Creates an empty knowledge graph.
    pub fn new() -> Self {
        Self {
            graph: DiGraph::new(),
            node_index: HashMap::new(),
        }
    }

    /// Adds `node` to the graph, or updates it in place if its id is already present.
    ///
    /// Returns the index of the new or existing node. Updating in place keeps the
    /// node's existing edges. It also guarantees that every stored node is
    /// reachable through its id. Previously a duplicate id left the old node behind
    /// as an orphan that label lookups and traversals could still surface.
    pub fn add_node(&mut self, node: KnowledgeNode) -> NodeIndex {
        if let Some(idx) = self.node_index.get(&node.id).copied() {
            self.graph[idx] = node;
            return idx;
        }
        let id = node.id.clone();
        let idx = self.graph.add_node(node);
        self.node_index.insert(id, idx);
        idx
    }

    /// Adds a directed edge `from_id -> to_id` carrying `edge`.
    ///
    /// Parallel edges are allowed: adding the same pair twice creates two edges.
    ///
    /// # Errors
    ///
    /// Returns an error if either endpoint id has not been added with `add_node`.
    pub fn add_edge(&mut self, from_id: &str, to_id: &str, edge: KnowledgeEdge) -> Result<()> {
        let from_idx = *self
            .node_index
            .get(from_id)
            .ok_or_else(|| anyhow::anyhow!("Source node not found: {}", from_id))?;
        let to_idx = *self
            .node_index
            .get(to_id)
            .ok_or_else(|| anyhow::anyhow!("Target node not found: {}", to_id))?;
        self.graph.add_edge(from_idx, to_idx, edge);
        Ok(())
    }

    /// Returns the node with the given id, if any.
    pub fn get_node(&self, id: &str) -> Option<&KnowledgeNode> {
        self.node_index
            .get(id)
            .and_then(|idx| self.graph.node_weight(*idx))
    }

    /// Returns the nodes adjacent to `id` in the given direction.
    ///
    /// Returns an empty vector for an unknown id. With parallel edges, a neighbour
    /// may be listed more than once.
    pub fn neighbors(&self, id: &str, direction: Direction) -> Vec<&KnowledgeNode> {
        let idx = match self.node_index.get(id) {
            Some(idx) => *idx,
            None => return Vec::new(),
        };
        self.graph
            .neighbors_directed(idx, direction)
            .filter_map(|n| self.graph.node_weight(n))
            .collect()
    }

    /// Returns the ids along the minimum-total-weight path from `from_id` to `to_id`.
    ///
    /// The returned ids include both endpoints. Returns `None` if either id is
    /// unknown or `to_id` is unreachable. Edge weights should be non-negative.
    pub fn shortest_path(&self, from_id: &str, to_id: &str) -> Option<Vec<String>> {
        let from_idx = *self.node_index.get(from_id)?;
        let to_idx = *self.node_index.get(to_id)?;
        // Use the f32 weights directly. Casting them to an integer truncated
        // fractional costs (0.5 became 0) and produced wrong shortest paths.
        // The zero heuristic makes A* behave like Dijkstra.
        let (_cost, path) = astar(
            &self.graph,
            from_idx,
            |finish| finish == to_idx,
            |e| e.weight().weight,
            |_| 0.0_f32,
        )?;
        Some(
            path.into_iter()
                .filter_map(|idx| self.graph.node_weight(idx))
                .map(|n| n.id.clone())
                .collect(),
        )
    }

    /// Returns every node whose shortest-path distance from `from_id` is at most `max_distance`.
    ///
    /// The start node itself is included at distance `0.0`. Results are sorted by
    /// ascending distance, with ties broken by id, so the output is deterministic.
    /// Returns an empty vector for an unknown id. Edge weights should be non-negative.
    pub fn nodes_within_distance(
        &self,
        from_id: &str,
        max_distance: f32,
    ) -> Vec<(&KnowledgeNode, f32)> {
        let from_idx = match self.node_index.get(from_id) {
            Some(idx) => *idx,
            None => return Vec::new(),
        };
        let distances = dijkstra(&self.graph, from_idx, None, |e| e.weight().weight);
        let mut within: Vec<(&KnowledgeNode, f32)> = distances
            .into_iter()
            .filter(|&(_, dist)| dist <= max_distance)
            .filter_map(|(idx, dist)| self.graph.node_weight(idx).map(|n| (n, dist)))
            .collect();
        // HashMap iteration order is arbitrary; sort so callers see a stable order.
        within.sort_by(|a, b| a.1.total_cmp(&b.1).then_with(|| a.0.id.cmp(&b.0.id)));
        within
    }

    /// Returns all edges directed from `from_id` to `to_id`.
    ///
    /// Returns an empty vector if either id is unknown.
    pub fn edges_between(&self, from_id: &str, to_id: &str) -> Vec<&KnowledgeEdge> {
        let from_idx = match self.node_index.get(from_id) {
            Some(idx) => *idx,
            None => return Vec::new(),
        };
        let to_idx = match self.node_index.get(to_id) {
            Some(idx) => *idx,
            None => return Vec::new(),
        };
        self.graph
            .edges_connecting(from_idx, to_idx)
            .map(|e| e.weight())
            .collect()
    }

    /// Number of nodes in the graph.
    pub fn node_count(&self) -> usize {
        self.graph.node_count()
    }

    /// Number of edges in the graph.
    pub fn edge_count(&self) -> usize {
        self.graph.edge_count()
    }

    /// Returns all nodes whose label equals `label` exactly (linear scan).
    pub fn find_by_label(&self, label: &str) -> Vec<&KnowledgeNode> {
        self.graph
            .node_weights()
            .filter(|n| n.label == label)
            .collect()
    }

    /// Returns the targets of outgoing edges from `id` whose relation equals `relation`.
    ///
    /// Returns an empty vector for an unknown id.
    pub fn find_related(&self, id: &str, relation: &str) -> Vec<&KnowledgeNode> {
        let idx = match self.node_index.get(id) {
            Some(idx) => *idx,
            None => return Vec::new(),
        };
        self.graph
            .edges_directed(idx, Direction::Outgoing)
            .filter(|e| e.weight().relation == relation)
            .filter_map(|e| self.graph.node_weight(e.target()))
            .collect()
    }
}
impl Default for KnowledgeGraph {
fn default() -> Self {
Self::new()
}
}
/// A chain of nodes linking two concepts, as produced by [`GraphReasoner`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ReasoningPath {
    /// Node ids from the start concept to the target concept, inclusive.
    pub nodes: Vec<String>,
    /// Relation names of the traversed edges, one per hop when recorded by the producer.
    pub relations: Vec<String>,
    /// Accumulated weight of the path as reported by the producer.
    pub total_weight: f32,
    /// Confidence score assigned to the path by the producer.
    pub confidence: f32,
}
/// Path-finding and concept-expansion queries over a shared [`KnowledgeGraph`].
///
/// Cloning is cheap: the graph is shared through an `Arc`.
#[derive(Clone)]
pub struct GraphReasoner {
    /// The graph being queried.
    graph: Arc<KnowledgeGraph>,
    /// Upper bound on the depth used by concept expansion.
    max_depth: usize,
}
impl GraphReasoner {
pub fn new(graph: Arc<KnowledgeGraph>, max_depth: usize) -> Self {
Self { graph, max_depth }
}
pub fn find_reasoning_paths(
&self,
from: &str,
to: &str,
max_paths: usize,
) -> Vec<ReasoningPath> {
let mut paths = Vec::new();
if let Some(path) = self.graph.shortest_path(from, to) {
paths.push(ReasoningPath {
nodes: path,
relations: Vec::new(), total_weight: 1.0,
confidence: 0.8,
});
}
paths.truncate(max_paths);
paths
}
pub fn expand_concept(&self, concept: &str, depth: usize) -> ConceptExpansion {
let neighbors = self.graph.neighbors(concept, Direction::Outgoing);
ConceptExpansion {
root: concept.to_string(),
related: neighbors.iter().map(|n| n.id.clone()).collect(),
depth_reached: depth.min(self.max_depth),
}
}
}
/// Result of expanding a concept into its related concepts.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ConceptExpansion {
    /// Id of the concept that was expanded.
    pub root: String,
    /// Ids of concepts found to be related to `root`.
    pub related: Vec<String>,
    /// Depth reported for the expansion (the requested depth capped by the reasoner).
    pub depth_reached: usize,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `concept`-labelled node with content `"Concept <id>"`.
    fn concept(id: &str) -> KnowledgeNode {
        KnowledgeNode {
            id: id.to_string(),
            label: "concept".to_string(),
            content: format!("Concept {}", id),
            embedding: None,
            metadata: HashMap::new(),
        }
    }

    /// Builds an edge with the given relation and weight and no properties.
    fn edge(relation: &str, weight: f32) -> KnowledgeEdge {
        KnowledgeEdge {
            relation: relation.to_string(),
            weight,
            properties: HashMap::new(),
        }
    }

    /// Two-node graph: `A --related_to(1.0)--> B`.
    fn create_test_graph() -> KnowledgeGraph {
        let mut graph = KnowledgeGraph::new();
        graph.add_node(concept("A"));
        graph.add_node(concept("B"));
        graph.add_edge("A", "B", edge("related_to", 1.0)).unwrap();
        graph
    }

    #[test]
    fn test_add_and_get_node() {
        let graph = create_test_graph();
        let node = graph.get_node("A").unwrap();
        assert_eq!(node.content, "Concept A");
        assert!(graph.get_node("missing").is_none());
    }

    #[test]
    fn test_neighbors() {
        let graph = create_test_graph();
        let neighbors = graph.neighbors("A", Direction::Outgoing);
        assert_eq!(neighbors.len(), 1);
        assert_eq!(neighbors[0].id, "B");
        let incoming = graph.neighbors("B", Direction::Incoming);
        assert_eq!(incoming.len(), 1);
        assert_eq!(incoming[0].id, "A");
        assert!(graph.neighbors("missing", Direction::Outgoing).is_empty());
    }

    #[test]
    fn test_node_count() {
        let graph = create_test_graph();
        assert_eq!(graph.node_count(), 2);
        assert_eq!(graph.edge_count(), 1);
    }

    #[test]
    fn test_add_edge_rejects_unknown_nodes() {
        let mut graph = create_test_graph();
        assert!(graph.add_edge("A", "missing", edge("x", 1.0)).is_err());
        assert!(graph.add_edge("missing", "B", edge("x", 1.0)).is_err());
        assert_eq!(graph.edge_count(), 1);
    }

    #[test]
    fn test_find_related_and_edges_between() {
        let graph = create_test_graph();
        let related = graph.find_related("A", "related_to");
        assert_eq!(related.len(), 1);
        assert_eq!(related[0].id, "B");
        assert!(graph.find_related("A", "unrelated").is_empty());
        assert_eq!(graph.edges_between("A", "B").len(), 1);
        assert!(graph.edges_between("B", "A").is_empty());
        assert!(graph.edges_between("A", "missing").is_empty());
    }

    #[test]
    fn test_find_by_label() {
        let graph = create_test_graph();
        assert_eq!(graph.find_by_label("concept").len(), 2);
        assert!(graph.find_by_label("other").is_empty());
    }

    #[test]
    fn test_shortest_path() {
        let graph = create_test_graph();
        assert_eq!(graph.shortest_path("A", "B").unwrap(), vec!["A", "B"]);
        assert!(graph.shortest_path("B", "A").is_none());
        assert!(graph.shortest_path("A", "missing").is_none());
    }

    #[test]
    fn test_nodes_within_distance() {
        let graph = create_test_graph();
        let mut ids: Vec<&str> = graph
            .nodes_within_distance("A", 1.0)
            .iter()
            .map(|(n, _)| n.id.as_str())
            .collect();
        ids.sort_unstable();
        assert_eq!(ids, vec!["A", "B"]);
        let near = graph.nodes_within_distance("A", 0.5);
        assert_eq!(near.len(), 1);
        assert_eq!(near[0].0.id, "A");
        assert!(graph.nodes_within_distance("missing", 10.0).is_empty());
    }

    #[test]
    fn test_reasoner() {
        let reasoner = GraphReasoner::new(Arc::new(create_test_graph()), 3);
        let paths = reasoner.find_reasoning_paths("A", "B", 5);
        assert_eq!(paths.len(), 1);
        assert_eq!(paths[0].nodes, vec!["A", "B"]);
        assert!(reasoner.find_reasoning_paths("B", "A", 5).is_empty());
        assert!(reasoner.find_reasoning_paths("A", "B", 0).is_empty());

        let expansion = reasoner.expand_concept("A", 1);
        assert_eq!(expansion.root, "A");
        assert_eq!(expansion.related, vec!["B"]);
        assert_eq!(expansion.depth_reached, 1);
    }
}