1use crate::core::{Entity, EntityId, KnowledgeGraph, Relationship, Result};
19use std::collections::{HashMap, HashSet, VecDeque};
20
/// Tunable limits and filters shared by the traversals in this module.
#[derive(Debug, Clone)]
pub struct TraversalConfig {
    /// Depth limit applied by the BFS, DFS and path-finding traversals; also the
    /// default hop count of `GraphTraversal::ego_network` when none is given.
    pub max_depth: usize,
    /// Upper bound on the number of paths collected by `GraphTraversal::find_all_paths`.
    pub max_paths: usize,
    /// Whether edge weights should influence a traversal.
    /// NOTE(review): none of the traversals visible in this file read this flag —
    /// confirm where (or whether) it is consumed.
    pub use_edge_weights: bool,
    /// Relationships whose `confidence` is below this value are skipped.
    pub min_relationship_strength: f32,
}

impl Default for TraversalConfig {
    /// Defaults: depth 3, at most 100 paths, edge weights enabled and a
    /// confidence floor of 0.5.
    fn default() -> Self {
        TraversalConfig {
            min_relationship_strength: 0.5,
            use_edge_weights: true,
            max_paths: 100,
            max_depth: 3,
        }
    }
}
44
/// Data collected by one traversal run.
#[derive(Debug, Clone)]
pub struct TraversalResult {
    /// Entities collected by the traversal, cloned out of the graph.
    pub entities: Vec<Entity>,
    /// Relationships collected by the traversal, cloned out of the graph.
    pub relationships: Vec<Relationship>,
    /// Paths expressed as sequences of entity ids. Only `find_all_paths` fills
    /// this; the other traversals in this file leave it empty.
    pub paths: Vec<Vec<EntityId>>,
    /// Depth / hop count at which each entity id was reached. `find_all_paths`
    /// leaves this empty.
    pub distances: HashMap<EntityId, usize>,
}
57
/// Runs traversals (BFS, DFS, ego networks, path search) over a `KnowledgeGraph`
/// using one fixed `TraversalConfig`.
pub struct GraphTraversal {
    /// Limits and filters consulted by the traversal methods.
    config: TraversalConfig,
}
62
63impl Default for GraphTraversal {
64 fn default() -> Self {
65 Self::new(TraversalConfig::default())
66 }
67}
68
69impl GraphTraversal {
70 pub fn new(config: TraversalConfig) -> Self {
72 Self { config }
73 }
74
75 pub fn bfs(&self, graph: &KnowledgeGraph, source: &EntityId) -> Result<TraversalResult> {
87 let mut visited = HashSet::new();
88 let mut queue = VecDeque::new();
89 let mut distances = HashMap::new();
90 let mut discovered_entities = Vec::new();
91 let mut discovered_relationships = Vec::new();
92
93 queue.push_back((source.clone(), 0));
95 distances.insert(source.clone(), 0);
96
97 while let Some((current_id, depth)) = queue.pop_front() {
98 if depth >= self.config.max_depth {
100 continue;
101 }
102
103 if visited.contains(¤t_id) {
105 continue;
106 }
107 visited.insert(current_id.clone());
108
109 if let Some(entity) = graph.get_entity(¤t_id) {
111 discovered_entities.push(entity.clone());
112 }
113
114 let neighbors = self.get_neighbors(graph, ¤t_id);
116
117 for (neighbor_id, relationship) in neighbors {
118 if relationship.confidence < self.config.min_relationship_strength {
120 continue;
121 }
122
123 if !visited.contains(&neighbor_id) {
125 queue.push_back((neighbor_id.clone(), depth + 1));
126 distances.entry(neighbor_id.clone()).or_insert(depth + 1);
127 discovered_relationships.push(relationship);
128 }
129 }
130 }
131
132 Ok(TraversalResult {
133 entities: discovered_entities,
134 relationships: discovered_relationships,
135 paths: Vec::new(), distances,
137 })
138 }
139
140 pub fn dfs(&self, graph: &KnowledgeGraph, source: &EntityId) -> Result<TraversalResult> {
152 let mut visited = HashSet::new();
153 let mut distances = HashMap::new();
154 let mut discovered_entities = Vec::new();
155 let mut discovered_relationships = Vec::new();
156
157 self.dfs_recursive(
158 graph,
159 source,
160 0,
161 &mut visited,
162 &mut distances,
163 &mut discovered_entities,
164 &mut discovered_relationships,
165 )?;
166
167 Ok(TraversalResult {
168 entities: discovered_entities,
169 relationships: discovered_relationships,
170 paths: Vec::new(), distances,
172 })
173 }
174
175 #[allow(clippy::too_many_arguments)]
177 fn dfs_recursive(
178 &self,
179 graph: &KnowledgeGraph,
180 current_id: &EntityId,
181 depth: usize,
182 visited: &mut HashSet<EntityId>,
183 distances: &mut HashMap<EntityId, usize>,
184 discovered_entities: &mut Vec<Entity>,
185 discovered_relationships: &mut Vec<Relationship>,
186 ) -> Result<()> {
187 if depth >= self.config.max_depth {
189 return Ok(());
190 }
191
192 if visited.contains(current_id) {
194 return Ok(());
195 }
196
197 visited.insert(current_id.clone());
198 distances.insert(current_id.clone(), depth);
199
200 if let Some(entity) = graph.get_entity(current_id) {
202 discovered_entities.push(entity.clone());
203 }
204
205 let neighbors = self.get_neighbors(graph, current_id);
207
208 for (neighbor_id, relationship) in neighbors {
209 if relationship.confidence < self.config.min_relationship_strength {
210 continue;
211 }
212
213 if !visited.contains(&neighbor_id) {
214 discovered_relationships.push(relationship);
215 self.dfs_recursive(
216 graph,
217 &neighbor_id,
218 depth + 1,
219 visited,
220 distances,
221 discovered_entities,
222 discovered_relationships,
223 )?;
224 }
225 }
226
227 Ok(())
228 }
229
230 pub fn ego_network(
243 &self,
244 graph: &KnowledgeGraph,
245 entity_id: &EntityId,
246 k_hops: Option<usize>,
247 ) -> Result<TraversalResult> {
248 let hops = k_hops.unwrap_or(self.config.max_depth);
249
250 let mut subgraph_entities = Vec::new();
251 let mut subgraph_relationships = Vec::new();
252 let mut visited = HashSet::new();
253 let mut distances = HashMap::new();
254
255 visited.insert(entity_id.clone());
257 distances.insert(entity_id.clone(), 0);
258
259 if let Some(entity) = graph.get_entity(entity_id) {
260 subgraph_entities.push(entity.clone());
261 }
262
263 let mut current_layer = vec![entity_id.clone()];
265
266 for hop in 1..=hops {
267 let mut next_layer = Vec::new();
268
269 for current_id in ¤t_layer {
270 let neighbors = self.get_neighbors(graph, current_id);
271
272 for (neighbor_id, relationship) in neighbors {
273 if relationship.confidence < self.config.min_relationship_strength {
274 continue;
275 }
276
277 subgraph_relationships.push(relationship);
279
280 if !visited.contains(&neighbor_id) {
282 visited.insert(neighbor_id.clone());
283 distances.insert(neighbor_id.clone(), hop);
284
285 if let Some(entity) = graph.get_entity(&neighbor_id) {
286 subgraph_entities.push(entity.clone());
287 }
288
289 next_layer.push(neighbor_id);
290 }
291 }
292 }
293
294 current_layer = next_layer;
295 }
296
297 Ok(TraversalResult {
298 entities: subgraph_entities,
299 relationships: subgraph_relationships,
300 paths: Vec::new(),
301 distances,
302 })
303 }
304
305 pub fn multi_source_bfs(
317 &self,
318 graph: &KnowledgeGraph,
319 sources: &[EntityId],
320 ) -> Result<TraversalResult> {
321 let mut visited = HashSet::new();
322 let mut queue = VecDeque::new();
323 let mut distances = HashMap::new();
324 let mut discovered_entities = Vec::new();
325 let mut discovered_relationships = Vec::new();
326
327 for source in sources {
329 queue.push_back((source.clone(), 0));
330 distances.insert(source.clone(), 0);
331 }
332
333 while let Some((current_id, depth)) = queue.pop_front() {
334 if depth >= self.config.max_depth {
335 continue;
336 }
337
338 if visited.contains(¤t_id) {
339 continue;
340 }
341 visited.insert(current_id.clone());
342
343 if let Some(entity) = graph.get_entity(¤t_id) {
344 discovered_entities.push(entity.clone());
345 }
346
347 let neighbors = self.get_neighbors(graph, ¤t_id);
348
349 for (neighbor_id, relationship) in neighbors {
350 if relationship.confidence < self.config.min_relationship_strength {
351 continue;
352 }
353
354 if !visited.contains(&neighbor_id) {
355 queue.push_back((neighbor_id.clone(), depth + 1));
356 distances.entry(neighbor_id.clone()).or_insert(depth + 1);
357 discovered_relationships.push(relationship);
358 }
359 }
360 }
361
362 Ok(TraversalResult {
363 entities: discovered_entities,
364 relationships: discovered_relationships,
365 paths: Vec::new(),
366 distances,
367 })
368 }
369
370 pub fn find_all_paths(
383 &self,
384 graph: &KnowledgeGraph,
385 source: &EntityId,
386 target: &EntityId,
387 ) -> Result<TraversalResult> {
388 let mut all_paths = Vec::new();
389 let mut current_path = vec![source.clone()];
390 let mut visited = HashSet::new();
391 let mut discovered_relationships = Vec::new();
392
393 self.find_paths_recursive(
394 graph,
395 source,
396 target,
397 &mut current_path,
398 &mut visited,
399 &mut all_paths,
400 &mut discovered_relationships,
401 0,
402 )?;
403
404 let mut unique_entities = HashSet::new();
406 for path in &all_paths {
407 unique_entities.extend(path.iter().cloned());
408 }
409
410 let discovered_entities: Vec<Entity> = unique_entities
411 .iter()
412 .filter_map(|id| graph.get_entity(id).cloned())
413 .collect();
414
415 Ok(TraversalResult {
416 entities: discovered_entities,
417 relationships: discovered_relationships,
418 paths: all_paths,
419 distances: HashMap::new(),
420 })
421 }
422
423 #[allow(clippy::too_many_arguments)]
425 fn find_paths_recursive(
426 &self,
427 graph: &KnowledgeGraph,
428 current: &EntityId,
429 target: &EntityId,
430 current_path: &mut Vec<EntityId>,
431 visited: &mut HashSet<EntityId>,
432 all_paths: &mut Vec<Vec<EntityId>>,
433 discovered_relationships: &mut Vec<Relationship>,
434 depth: usize,
435 ) -> Result<()> {
436 if depth >= self.config.max_depth || all_paths.len() >= self.config.max_paths {
438 return Ok(());
439 }
440
441 if current == target {
443 all_paths.push(current_path.clone());
444 return Ok(());
445 }
446
447 visited.insert(current.clone());
448
449 let neighbors = self.get_neighbors(graph, current);
450
451 for (neighbor_id, relationship) in neighbors {
452 if relationship.confidence < self.config.min_relationship_strength {
453 continue;
454 }
455
456 if !visited.contains(&neighbor_id) {
457 current_path.push(neighbor_id.clone());
458 discovered_relationships.push(relationship);
459
460 self.find_paths_recursive(
461 graph,
462 &neighbor_id,
463 target,
464 current_path,
465 visited,
466 all_paths,
467 discovered_relationships,
468 depth + 1,
469 )?;
470
471 current_path.pop();
472 }
473 }
474
475 visited.remove(current);
476
477 Ok(())
478 }
479
480 fn get_neighbors(
482 &self,
483 graph: &KnowledgeGraph,
484 entity_id: &EntityId,
485 ) -> Vec<(EntityId, Relationship)> {
486 let mut neighbors = Vec::new();
487
488 for relationship in graph.get_all_relationships() {
490 if &relationship.source == entity_id {
491 neighbors.push((relationship.target.clone(), relationship.clone()));
492 }
493 if &relationship.target == entity_id {
495 neighbors.push((relationship.source.clone(), relationship.clone()));
496 }
497 }
498
499 neighbors
500 }
501
502 pub fn query_focused_subgraph(
517 &self,
518 graph: &KnowledgeGraph,
519 seed_entities: &[EntityId],
520 expansion_hops: usize,
521 ) -> Result<TraversalResult> {
522 let mut combined_entities = Vec::new();
523 let mut combined_relationships = Vec::new();
524 let mut combined_distances = HashMap::new();
525 let mut seen_entities = HashSet::new();
526 let mut seen_relationships = HashSet::new();
527
528 for seed in seed_entities {
530 let ego_result = self.ego_network(graph, seed, Some(expansion_hops))?;
531
532 for entity in ego_result.entities {
533 if !seen_entities.contains(&entity.id) {
534 seen_entities.insert(entity.id.clone());
535 combined_entities.push(entity);
536 }
537 }
538
539 for rel in ego_result.relationships {
540 let rel_key = (
541 rel.source.clone(),
542 rel.target.clone(),
543 rel.relation_type.clone(),
544 );
545 if !seen_relationships.contains(&rel_key) {
546 seen_relationships.insert(rel_key);
547 combined_relationships.push(rel);
548 }
549 }
550
551 for (entity_id, distance) in ego_result.distances {
552 combined_distances
553 .entry(entity_id)
554 .and_modify(|d: &mut usize| *d = (*d).min(distance))
555 .or_insert(distance);
556 }
557 }
558
559 Ok(TraversalResult {
560 entities: combined_entities,
561 relationships: combined_relationships,
562 paths: Vec::new(),
563 distances: combined_distances,
564 })
565 }
566}
567
568#[cfg(test)]
569mod tests {
570 use super::*;
571 use crate::core::Entity;
572
573 fn create_test_graph() -> KnowledgeGraph {
574 let mut graph = KnowledgeGraph::new();
575
576 let entity_a = Entity::new(
579 EntityId::new("A".to_string()),
580 "Entity A".to_string(),
581 "CONCEPT".to_string(),
582 0.9,
583 );
584 let entity_b = Entity::new(
585 EntityId::new("B".to_string()),
586 "Entity B".to_string(),
587 "CONCEPT".to_string(),
588 0.9,
589 );
590 let entity_c = Entity::new(
591 EntityId::new("C".to_string()),
592 "Entity C".to_string(),
593 "CONCEPT".to_string(),
594 0.9,
595 );
596 let entity_d = Entity::new(
597 EntityId::new("D".to_string()),
598 "Entity D".to_string(),
599 "CONCEPT".to_string(),
600 0.9,
601 );
602
603 graph.add_entity(entity_a);
604 graph.add_entity(entity_b);
605 graph.add_entity(entity_c);
606 graph.add_entity(entity_d);
607
608 let _ = graph.add_relationship(Relationship {
610 source: EntityId::new("A".to_string()),
611 target: EntityId::new("B".to_string()),
612 relation_type: "RELATED_TO".to_string(),
613 confidence: 0.8,
614 context: Vec::new(),
615 embedding: None,
616 temporal_type: None,
617 temporal_range: None,
618 causal_strength: None,
619 });
620
621 let _ = graph.add_relationship(Relationship {
622 source: EntityId::new("B".to_string()),
623 target: EntityId::new("C".to_string()),
624 relation_type: "RELATED_TO".to_string(),
625 confidence: 0.9,
626 context: Vec::new(),
627 embedding: None,
628 temporal_type: None,
629 temporal_range: None,
630 causal_strength: None,
631 });
632
633 let _ = graph.add_relationship(Relationship {
634 source: EntityId::new("A".to_string()),
635 target: EntityId::new("D".to_string()),
636 relation_type: "RELATED_TO".to_string(),
637 confidence: 0.7,
638 context: Vec::new(),
639 embedding: None,
640 temporal_type: None,
641 temporal_range: None,
642 causal_strength: None,
643 });
644
645 graph
646 }
647
648 #[test]
649 fn test_bfs_traversal() {
650 let graph = create_test_graph();
651 let traversal = GraphTraversal::default();
652 let source = EntityId::new("A".to_string());
653
654 let result = traversal.bfs(&graph, &source).unwrap();
655
656 assert!(result.entities.len() >= 1);
658 assert!(result.distances.contains_key(&source));
659 }
660
661 #[test]
662 fn test_dfs_traversal() {
663 let graph = create_test_graph();
664 let traversal = GraphTraversal::default();
665 let source = EntityId::new("A".to_string());
666
667 let result = traversal.dfs(&graph, &source).unwrap();
668
669 assert!(result.entities.len() >= 1);
671 assert!(result.distances.contains_key(&source));
672 }
673
674 #[test]
675 fn test_ego_network() {
676 let graph = create_test_graph();
677 let traversal = GraphTraversal::default();
678 let entity_id = EntityId::new("A".to_string());
679
680 let result = traversal.ego_network(&graph, &entity_id, Some(1)).unwrap();
681
682 assert!(result.entities.len() >= 2); assert_eq!(*result.distances.get(&entity_id).unwrap(), 0);
685 }
686
687 #[test]
688 fn test_multi_source_bfs() {
689 let graph = create_test_graph();
690 let traversal = GraphTraversal::default();
691 let sources = vec![
692 EntityId::new("A".to_string()),
693 EntityId::new("C".to_string()),
694 ];
695
696 let result = traversal.multi_source_bfs(&graph, &sources).unwrap();
697
698 assert!(result.entities.len() >= 2);
700 }
701
702 #[test]
703 fn test_find_all_paths() {
704 let graph = create_test_graph();
705 let traversal = GraphTraversal::default();
706 let source = EntityId::new("A".to_string());
707 let target = EntityId::new("C".to_string());
708
709 let result = traversal.find_all_paths(&graph, &source, &target).unwrap();
710
711 assert!(!result.paths.is_empty());
713 assert!(result.paths[0].contains(&source));
714 assert!(result.paths[0].contains(&target));
715 }
716
717 #[test]
718 fn test_query_focused_subgraph() {
719 let graph = create_test_graph();
720 let traversal = GraphTraversal::default();
721 let seeds = vec![EntityId::new("A".to_string())];
722
723 let result = traversal.query_focused_subgraph(&graph, &seeds, 2).unwrap();
724
725 assert!(!result.entities.is_empty());
727 assert!(!result.relationships.is_empty());
728 }
729}