use crate::core::{Entity, EntityId, KnowledgeGraph, Relationship, Result};
use std::collections::{HashMap, HashSet, VecDeque};
/// Tunable limits shared by every traversal in this module.
#[derive(Debug, Clone)]
pub struct TraversalConfig {
    /// Depth bound applied by the traversals; also the default hop count of
    /// `ego_network` when the caller passes no explicit value.
    pub max_depth: usize,
    /// Upper bound on the number of paths collected by `find_all_paths`.
    pub max_paths: usize,
    /// Edge-weight switch. NOTE(review): no traversal in this file reads this
    /// flag; confirm whether it is consumed elsewhere.
    pub use_edge_weights: bool,
    /// Relationships whose `confidence` is below this value are ignored.
    pub min_relationship_strength: f32,
}

impl Default for TraversalConfig {
    /// Depth 3, at most 100 paths, edge weights enabled, confidence floor 0.5.
    fn default() -> Self {
        let max_depth = 3;
        let max_paths = 100;
        let use_edge_weights = true;
        let min_relationship_strength = 0.5;
        TraversalConfig {
            max_depth,
            max_paths,
            use_edge_weights,
            min_relationship_strength,
        }
    }
}
/// Outcome of a traversal: the sub-graph that was reached plus bookkeeping.
#[derive(Debug, Clone)]
pub struct TraversalResult {
/// Entities collected by the traversal (cloned out of the graph).
pub entities: Vec<Entity>,
/// Relationships collected by the traversal (cloned out of the graph).
pub relationships: Vec<Relationship>,
/// Node sequences from source to target. Only `find_all_paths` fills this;
/// every other traversal in this file returns it empty.
pub paths: Vec<Vec<EntityId>>,
/// Depth / hop count recorded per entity id. `find_all_paths` returns this
/// map empty.
pub distances: HashMap<EntityId, usize>,
}
/// Runs graph traversals over a `KnowledgeGraph` under a fixed
/// `TraversalConfig`.
pub struct GraphTraversal {
// Limits (depth, path count, confidence floor) applied by every method.
config: TraversalConfig,
}
impl Default for GraphTraversal {
fn default() -> Self {
Self::new(TraversalConfig::default())
}
}
impl GraphTraversal {
pub fn new(config: TraversalConfig) -> Self {
Self { config }
}
/// Breadth-first traversal starting at `source`.
///
/// A node is visited (and its entity returned) only when it is reached at a
/// depth strictly below `config.max_depth`; relationships whose `confidence`
/// is below `config.min_relationship_strength` are never followed.
///
/// The returned `relationships` and `distances` only mention visited nodes,
/// so the result never contains an edge or a distance for an entity that is
/// missing from the traversal itself. `paths` is always empty. The current
/// implementation never returns `Err`.
pub fn bfs(&self, graph: &KnowledgeGraph, source: &EntityId) -> Result<TraversalResult> {
    let mut visited = HashSet::new();
    let mut queue = VecDeque::new();
    let mut distances = HashMap::new();
    let mut discovered_entities = Vec::new();
    let mut discovered_relationships = Vec::new();

    queue.push_back((source.clone(), 0));
    distances.insert(source.clone(), 0);

    while let Some((current_id, depth)) = queue.pop_front() {
        if depth >= self.config.max_depth {
            continue;
        }
        // `insert` returns false when the id was already visited.
        if !visited.insert(current_id.clone()) {
            continue;
        }
        if let Some(entity) = graph.get_entity(&current_id) {
            discovered_entities.push(entity.clone());
        }

        for (neighbor_id, relationship) in self.get_neighbors(graph, &current_id) {
            if relationship.confidence < self.config.min_relationship_strength {
                continue;
            }
            if !visited.contains(&neighbor_id) {
                // First discovery wins, which in BFS order is the shortest depth.
                distances.entry(neighbor_id.clone()).or_insert(depth + 1);
                queue.push_back((neighbor_id, depth + 1));
                discovered_relationships.push(relationship);
            }
        }
    }

    // Nodes sitting exactly at the depth limit are enqueued but never visited.
    // Drop the edges and distances that refer to them so the result is
    // self-consistent.
    discovered_relationships
        .retain(|rel| visited.contains(&rel.source) && visited.contains(&rel.target));
    distances.retain(|id, _| visited.contains(id));

    Ok(TraversalResult {
        entities: discovered_entities,
        relationships: discovered_relationships,
        paths: Vec::new(),
        distances,
    })
}
pub fn dfs(&self, graph: &KnowledgeGraph, source: &EntityId) -> Result<TraversalResult> {
let mut visited = HashSet::new();
let mut distances = HashMap::new();
let mut discovered_entities = Vec::new();
let mut discovered_relationships = Vec::new();
self.dfs_recursive(
graph,
source,
0,
&mut visited,
&mut distances,
&mut discovered_entities,
&mut discovered_relationships,
)?;
Ok(TraversalResult {
entities: discovered_entities,
relationships: discovered_relationships,
paths: Vec::new(), distances,
})
}
// The traversal state is threaded through as separate accumulators, hence the
// long parameter list.
#[allow(clippy::too_many_arguments)]
/// One step of the depth-first walk.
///
/// Returns immediately when `depth` has reached `config.max_depth` or when
/// `current_id` was already visited. Otherwise marks the node visited,
/// records its depth and entity, then descends into every neighbour that is
/// still unvisited and whose relationship passes the confidence floor. The
/// relationship is recorded before the descent.
fn dfs_recursive(
    &self,
    graph: &KnowledgeGraph,
    current_id: &EntityId,
    depth: usize,
    visited: &mut HashSet<EntityId>,
    distances: &mut HashMap<EntityId, usize>,
    discovered_entities: &mut Vec<Entity>,
    discovered_relationships: &mut Vec<Relationship>,
) -> Result<()> {
    // The depth test short-circuits first, so a node at the limit is not
    // marked as visited.
    if depth >= self.config.max_depth || !visited.insert(current_id.clone()) {
        return Ok(());
    }

    distances.insert(current_id.clone(), depth);
    discovered_entities.extend(graph.get_entity(current_id).cloned());

    let floor = self.config.min_relationship_strength;
    for (next_id, edge) in self.get_neighbors(graph, current_id) {
        if edge.confidence < floor || visited.contains(&next_id) {
            continue;
        }
        discovered_relationships.push(edge);
        self.dfs_recursive(
            graph,
            &next_id,
            depth + 1,
            visited,
            distances,
            discovered_entities,
            discovered_relationships,
        )?;
    }

    Ok(())
}
/// Extracts the neighbourhood of `entity_id` layer by layer.
///
/// `k_hops` is the number of hops to expand; when `None`,
/// `config.max_depth` is used. The centre entity is always included (at
/// distance 0) if it exists in the graph, and entities at exactly `k_hops`
/// hops are included too. Relationships whose `confidence` is below
/// `config.min_relationship_strength` are ignored.
///
/// Each relationship is returned once, identified by
/// `(source, target, relation_type)` — the same key `query_focused_subgraph`
/// uses when merging ego networks. `paths` is always empty. The current
/// implementation never returns `Err`.
pub fn ego_network(
    &self,
    graph: &KnowledgeGraph,
    entity_id: &EntityId,
    k_hops: Option<usize>,
) -> Result<TraversalResult> {
    let hops = k_hops.unwrap_or(self.config.max_depth);
    let mut subgraph_entities = Vec::new();
    let mut subgraph_relationships = Vec::new();
    let mut seen_relationships = HashSet::new();
    let mut visited = HashSet::new();
    let mut distances = HashMap::new();

    visited.insert(entity_id.clone());
    distances.insert(entity_id.clone(), 0);
    if let Some(entity) = graph.get_entity(entity_id) {
        subgraph_entities.push(entity.clone());
    }

    let mut current_layer = vec![entity_id.clone()];
    for hop in 1..=hops {
        // Nothing left to expand: avoid spinning through the remaining hops
        // (matters when the caller passes a very large `k_hops`).
        if current_layer.is_empty() {
            break;
        }

        let mut next_layer = Vec::new();
        for current_id in &current_layer {
            for (neighbor_id, relationship) in self.get_neighbors(graph, current_id) {
                if relationship.confidence < self.config.min_relationship_strength {
                    continue;
                }

                // An edge is seen from both of its endpoints once both have
                // been expanded; record it only the first time.
                let rel_key = (
                    relationship.source.clone(),
                    relationship.target.clone(),
                    relationship.relation_type.clone(),
                );
                if seen_relationships.insert(rel_key) {
                    subgraph_relationships.push(relationship);
                }

                // `insert` returns true only for ids not seen before.
                if visited.insert(neighbor_id.clone()) {
                    distances.insert(neighbor_id.clone(), hop);
                    if let Some(entity) = graph.get_entity(&neighbor_id) {
                        subgraph_entities.push(entity.clone());
                    }
                    next_layer.push(neighbor_id);
                }
            }
        }
        current_layer = next_layer;
    }

    Ok(TraversalResult {
        entities: subgraph_entities,
        relationships: subgraph_relationships,
        paths: Vec::new(),
        distances,
    })
}
/// Breadth-first traversal seeded from several sources at once.
///
/// Every id in `sources` starts at depth 0. A node is visited (and its
/// entity returned) only when it is reached at a depth strictly below
/// `config.max_depth`; relationships whose `confidence` is below
/// `config.min_relationship_strength` are never followed. An empty `sources`
/// slice yields an empty result.
///
/// The returned `relationships` and `distances` only mention visited nodes.
/// `paths` is always empty. The current implementation never returns `Err`.
pub fn multi_source_bfs(
    &self,
    graph: &KnowledgeGraph,
    sources: &[EntityId],
) -> Result<TraversalResult> {
    let mut visited = HashSet::new();
    let mut queue = VecDeque::new();
    let mut distances = HashMap::new();
    let mut discovered_entities = Vec::new();
    let mut discovered_relationships = Vec::new();

    for source in sources {
        queue.push_back((source.clone(), 0));
        distances.insert(source.clone(), 0);
    }

    while let Some((current_id, depth)) = queue.pop_front() {
        if depth >= self.config.max_depth {
            continue;
        }
        // `insert` returns false when the id was already visited.
        if !visited.insert(current_id.clone()) {
            continue;
        }
        if let Some(entity) = graph.get_entity(&current_id) {
            discovered_entities.push(entity.clone());
        }

        for (neighbor_id, relationship) in self.get_neighbors(graph, &current_id) {
            if relationship.confidence < self.config.min_relationship_strength {
                continue;
            }
            if !visited.contains(&neighbor_id) {
                // Keep the first (smallest) depth seen for this neighbour.
                distances.entry(neighbor_id.clone()).or_insert(depth + 1);
                queue.push_back((neighbor_id, depth + 1));
                discovered_relationships.push(relationship);
            }
        }
    }

    // Nodes sitting exactly at the depth limit are enqueued but never visited.
    // Drop the edges and distances that refer to them so the result is
    // self-consistent.
    discovered_relationships
        .retain(|rel| visited.contains(&rel.source) && visited.contains(&rel.target));
    distances.retain(|id, _| visited.contains(id));

    Ok(TraversalResult {
        entities: discovered_entities,
        relationships: discovered_relationships,
        paths: Vec::new(),
        distances,
    })
}
pub fn find_all_paths(
&self,
graph: &KnowledgeGraph,
source: &EntityId,
target: &EntityId,
) -> Result<TraversalResult> {
let mut all_paths = Vec::new();
let mut current_path = vec![source.clone()];
let mut visited = HashSet::new();
let mut discovered_relationships = Vec::new();
self.find_paths_recursive(
graph,
source,
target,
&mut current_path,
&mut visited,
&mut all_paths,
&mut discovered_relationships,
0,
)?;
let mut unique_entities = HashSet::new();
for path in &all_paths {
unique_entities.extend(path.iter().cloned());
}
let discovered_entities: Vec<Entity> = unique_entities
.iter()
.filter_map(|id| graph.get_entity(id).cloned())
.collect();
Ok(TraversalResult {
entities: discovered_entities,
relationships: discovered_relationships,
paths: all_paths,
distances: HashMap::new(),
})
}
#[allow(clippy::too_many_arguments)]
fn find_paths_recursive(
&self,
graph: &KnowledgeGraph,
current: &EntityId,
target: &EntityId,
current_path: &mut Vec<EntityId>,
visited: &mut HashSet<EntityId>,
all_paths: &mut Vec<Vec<EntityId>>,
discovered_relationships: &mut Vec<Relationship>,
depth: usize,
) -> Result<()> {
if depth >= self.config.max_depth || all_paths.len() >= self.config.max_paths {
return Ok(());
}
if current == target {
all_paths.push(current_path.clone());
return Ok(());
}
visited.insert(current.clone());
let neighbors = self.get_neighbors(graph, current);
for (neighbor_id, relationship) in neighbors {
if relationship.confidence < self.config.min_relationship_strength {
continue;
}
if !visited.contains(&neighbor_id) {
current_path.push(neighbor_id.clone());
discovered_relationships.push(relationship);
self.find_paths_recursive(
graph,
&neighbor_id,
target,
current_path,
visited,
all_paths,
discovered_relationships,
depth + 1,
)?;
current_path.pop();
}
}
visited.remove(current);
Ok(())
}
/// Lists the entities adjacent to `entity_id` together with the connecting
/// relationship.
///
/// Relationships are treated as undirected: an edge contributes the opposite
/// endpoint whether `entity_id` is its source or its target. A relationship
/// whose source and target are both `entity_id` is reported once.
///
/// Every call scans all relationships returned by
/// `graph.get_all_relationships()`.
fn get_neighbors(
    &self,
    graph: &KnowledgeGraph,
    entity_id: &EntityId,
) -> Vec<(EntityId, Relationship)> {
    let mut neighbors = Vec::new();
    for relationship in graph.get_all_relationships() {
        if &relationship.source == entity_id {
            neighbors.push((relationship.target.clone(), relationship.clone()));
        } else if &relationship.target == entity_id {
            // `else` keeps a self-referential relationship from being
            // emitted a second time.
            neighbors.push((relationship.source.clone(), relationship.clone()));
        }
    }
    neighbors
}
/// Merges the ego networks of all `seed_entities` into a single sub-graph.
///
/// Each seed is expanded with `ego_network` for `expansion_hops` hops.
/// Entities are kept once per id, relationships once per
/// `(source, target, relation_type)`, and for every entity the smallest
/// distance to any seed is retained. `paths` is always empty. An error from
/// `ego_network` is propagated.
pub fn query_focused_subgraph(
    &self,
    graph: &KnowledgeGraph,
    seed_entities: &[EntityId],
    expansion_hops: usize,
) -> Result<TraversalResult> {
    let mut merged_entities = Vec::new();
    let mut merged_relationships = Vec::new();
    let mut merged_distances: HashMap<EntityId, usize> = HashMap::new();
    let mut known_entity_ids = HashSet::new();
    let mut known_relationship_keys = HashSet::new();

    for seed in seed_entities {
        let TraversalResult {
            entities,
            relationships,
            distances,
            ..
        } = self.ego_network(graph, seed, Some(expansion_hops))?;

        for entity in entities {
            // `insert` is true only the first time an id shows up.
            if known_entity_ids.insert(entity.id.clone()) {
                merged_entities.push(entity);
            }
        }

        for rel in relationships {
            let key = (
                rel.source.clone(),
                rel.target.clone(),
                rel.relation_type.clone(),
            );
            if known_relationship_keys.insert(key) {
                merged_relationships.push(rel);
            }
        }

        for (entity_id, distance) in distances {
            // Keep the shortest distance observed from any seed.
            let best = merged_distances.entry(entity_id).or_insert(distance);
            *best = (*best).min(distance);
        }
    }

    Ok(TraversalResult {
        entities: merged_entities,
        relationships: merged_relationships,
        paths: Vec::new(),
        distances: merged_distances,
    })
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::core::Entity;
/// Builds the fixture shared by the tests: entities A, B, C and D with the
/// relationships A-B (0.8), B-C (0.9) and A-D (0.7), all of type
/// `RELATED_TO`.
fn create_test_graph() -> KnowledgeGraph {
    let mut graph = KnowledgeGraph::new();

    for label in ["A", "B", "C", "D"] {
        let entity = Entity::new(
            EntityId::new(label.to_string()),
            format!("Entity {}", label),
            "CONCEPT".to_string(),
            0.9,
        );
        let _ = graph.add_entity(entity);
    }

    let edges = [("A", "B", 0.8), ("B", "C", 0.9), ("A", "D", 0.7)];
    for (from, to, confidence) in edges {
        let _ = graph.add_relationship(Relationship {
            source: EntityId::new(from.to_string()),
            target: EntityId::new(to.to_string()),
            relation_type: "RELATED_TO".to_string(),
            confidence,
            context: Vec::new(),
            embedding: None,
            temporal_type: None,
            temporal_range: None,
            causal_strength: None,
        });
    }

    graph
}
/// BFS from A must return at least one entity and a distance for A itself.
#[test]
fn test_bfs_traversal() {
    let start = EntityId::new(String::from("A"));
    let fixture = create_test_graph();
    let walker = GraphTraversal::default();

    let outcome = walker.bfs(&fixture, &start).unwrap();

    assert!(!outcome.entities.is_empty());
    assert!(outcome.distances.contains_key(&start));
}
/// DFS from A must return at least one entity and a distance for A itself.
#[test]
fn test_dfs_traversal() {
    let start = EntityId::new(String::from("A"));
    let fixture = create_test_graph();
    let walker = GraphTraversal::default();

    let outcome = walker.dfs(&fixture, &start).unwrap();

    assert!(!outcome.entities.is_empty());
    assert!(outcome.distances.contains_key(&start));
}
/// A one-hop ego network around A holds at least two entities and places A
/// at distance 0.
#[test]
fn test_ego_network() {
    let center = EntityId::new(String::from("A"));
    let fixture = create_test_graph();
    let walker = GraphTraversal::default();

    let outcome = walker.ego_network(&fixture, &center, Some(1)).unwrap();

    assert!(outcome.entities.len() >= 2);
    assert_eq!(outcome.distances.get(&center).copied(), Some(0));
}
/// Seeding BFS from both A and C must reach at least two entities.
#[test]
fn test_multi_source_bfs() {
    let fixture = create_test_graph();
    let walker = GraphTraversal::default();
    let starts: Vec<EntityId> = ["A", "C"]
        .iter()
        .map(|label| EntityId::new(label.to_string()))
        .collect();

    let outcome = walker.multi_source_bfs(&fixture, &starts).unwrap();

    assert!(outcome.entities.len() >= 2);
}
/// At least one A-to-C path exists, and the first one contains both ends.
#[test]
fn test_find_all_paths() {
    let from = EntityId::new(String::from("A"));
    let to = EntityId::new(String::from("C"));
    let fixture = create_test_graph();
    let walker = GraphTraversal::default();

    let outcome = walker.find_all_paths(&fixture, &from, &to).unwrap();

    assert!(!outcome.paths.is_empty());
    let first_path = &outcome.paths[0];
    assert!(first_path.contains(&from));
    assert!(first_path.contains(&to));
}
/// A two-hop focused sub-graph seeded at A has entities and relationships.
#[test]
fn test_query_focused_subgraph() {
    let fixture = create_test_graph();
    let walker = GraphTraversal::default();
    let seed_ids = [EntityId::new(String::from("A"))];

    let outcome = walker
        .query_focused_subgraph(&fixture, &seed_ids, 2)
        .unwrap();

    assert!(!outcome.entities.is_empty());
    assert!(!outcome.relationships.is_empty());
}
}