1use crate::core::{Entity, EntityId, KnowledgeGraph, Relationship, Result};
19use std::collections::{HashMap, HashSet, VecDeque};
20
/// Tunable limits read by the traversal methods of [`GraphTraversal`].
#[derive(Debug, Clone)]
pub struct TraversalConfig {
    /// Depth limit, counted in hops from the starting entity, applied by the traversals.
    pub max_depth: usize,
    /// Upper bound on the number of paths collected by `find_all_paths`.
    pub max_paths: usize,
    /// Flag for weight-aware traversal.
    /// NOTE(review): no method in this file reads this field — confirm whether it is still needed.
    pub use_edge_weights: bool,
    /// Minimum `confidence` a relationship needs in order to be followed; weaker ones are skipped.
    pub min_relationship_strength: f32,
}
33
34impl Default for TraversalConfig {
35 fn default() -> Self {
36 Self {
37 max_depth: 3,
38 max_paths: 100,
39 use_edge_weights: true,
40 min_relationship_strength: 0.5,
41 }
42 }
43}
44
/// Everything a traversal collected, returned by each public method of [`GraphTraversal`].
#[derive(Debug, Clone)]
pub struct TraversalResult {
    /// Entities collected by the traversal (only those the graph could resolve).
    pub entities: Vec<Entity>,
    /// Relationships collected while walking the graph.
    pub relationships: Vec<Relationship>,
    /// Entity-id paths; filled only by `find_all_paths`, empty for the other traversals.
    pub paths: Vec<Vec<EntityId>>,
    /// Hop count recorded for each reached entity; left empty by `find_all_paths`.
    pub distances: HashMap<EntityId, usize>,
}
57
/// Runs traversals over a `KnowledgeGraph` using a fixed [`TraversalConfig`].
pub struct GraphTraversal {
    /// Limits and filters applied by every traversal method.
    config: TraversalConfig,
}
62
63impl GraphTraversal {
64 pub fn new(config: TraversalConfig) -> Self {
66 Self { config }
67 }
68
69 pub fn default() -> Self {
71 Self::new(TraversalConfig::default())
72 }
73
74 pub fn bfs(
86 &self,
87 graph: &KnowledgeGraph,
88 source: &EntityId,
89 ) -> Result<TraversalResult> {
90 let mut visited = HashSet::new();
91 let mut queue = VecDeque::new();
92 let mut distances = HashMap::new();
93 let mut discovered_entities = Vec::new();
94 let mut discovered_relationships = Vec::new();
95
96 queue.push_back((source.clone(), 0));
98 distances.insert(source.clone(), 0);
99
100 while let Some((current_id, depth)) = queue.pop_front() {
101 if depth >= self.config.max_depth {
103 continue;
104 }
105
106 if visited.contains(¤t_id) {
108 continue;
109 }
110 visited.insert(current_id.clone());
111
112 if let Some(entity) = graph.get_entity(¤t_id) {
114 discovered_entities.push(entity.clone());
115 }
116
117 let neighbors = self.get_neighbors(graph, ¤t_id);
119
120 for (neighbor_id, relationship) in neighbors {
121 if relationship.confidence < self.config.min_relationship_strength {
123 continue;
124 }
125
126 if !visited.contains(&neighbor_id) {
128 queue.push_back((neighbor_id.clone(), depth + 1));
129 distances.entry(neighbor_id.clone()).or_insert(depth + 1);
130 discovered_relationships.push(relationship);
131 }
132 }
133 }
134
135 Ok(TraversalResult {
136 entities: discovered_entities,
137 relationships: discovered_relationships,
138 paths: Vec::new(), distances,
140 })
141 }
142
143 pub fn dfs(
155 &self,
156 graph: &KnowledgeGraph,
157 source: &EntityId,
158 ) -> Result<TraversalResult> {
159 let mut visited = HashSet::new();
160 let mut distances = HashMap::new();
161 let mut discovered_entities = Vec::new();
162 let mut discovered_relationships = Vec::new();
163
164 self.dfs_recursive(
165 graph,
166 source,
167 0,
168 &mut visited,
169 &mut distances,
170 &mut discovered_entities,
171 &mut discovered_relationships,
172 )?;
173
174 Ok(TraversalResult {
175 entities: discovered_entities,
176 relationships: discovered_relationships,
177 paths: Vec::new(), distances,
179 })
180 }
181
182 fn dfs_recursive(
184 &self,
185 graph: &KnowledgeGraph,
186 current_id: &EntityId,
187 depth: usize,
188 visited: &mut HashSet<EntityId>,
189 distances: &mut HashMap<EntityId, usize>,
190 discovered_entities: &mut Vec<Entity>,
191 discovered_relationships: &mut Vec<Relationship>,
192 ) -> Result<()> {
193 if depth >= self.config.max_depth {
195 return Ok(());
196 }
197
198 if visited.contains(current_id) {
200 return Ok(());
201 }
202
203 visited.insert(current_id.clone());
204 distances.insert(current_id.clone(), depth);
205
206 if let Some(entity) = graph.get_entity(current_id) {
208 discovered_entities.push(entity.clone());
209 }
210
211 let neighbors = self.get_neighbors(graph, current_id);
213
214 for (neighbor_id, relationship) in neighbors {
215 if relationship.confidence < self.config.min_relationship_strength {
216 continue;
217 }
218
219 if !visited.contains(&neighbor_id) {
220 discovered_relationships.push(relationship);
221 self.dfs_recursive(
222 graph,
223 &neighbor_id,
224 depth + 1,
225 visited,
226 distances,
227 discovered_entities,
228 discovered_relationships,
229 )?;
230 }
231 }
232
233 Ok(())
234 }
235
236 pub fn ego_network(
249 &self,
250 graph: &KnowledgeGraph,
251 entity_id: &EntityId,
252 k_hops: Option<usize>,
253 ) -> Result<TraversalResult> {
254 let hops = k_hops.unwrap_or(self.config.max_depth);
255
256 let mut subgraph_entities = Vec::new();
257 let mut subgraph_relationships = Vec::new();
258 let mut visited = HashSet::new();
259 let mut distances = HashMap::new();
260
261 visited.insert(entity_id.clone());
263 distances.insert(entity_id.clone(), 0);
264
265 if let Some(entity) = graph.get_entity(entity_id) {
266 subgraph_entities.push(entity.clone());
267 }
268
269 let mut current_layer = vec![entity_id.clone()];
271
272 for hop in 1..=hops {
273 let mut next_layer = Vec::new();
274
275 for current_id in ¤t_layer {
276 let neighbors = self.get_neighbors(graph, current_id);
277
278 for (neighbor_id, relationship) in neighbors {
279 if relationship.confidence < self.config.min_relationship_strength {
280 continue;
281 }
282
283 subgraph_relationships.push(relationship);
285
286 if !visited.contains(&neighbor_id) {
288 visited.insert(neighbor_id.clone());
289 distances.insert(neighbor_id.clone(), hop);
290
291 if let Some(entity) = graph.get_entity(&neighbor_id) {
292 subgraph_entities.push(entity.clone());
293 }
294
295 next_layer.push(neighbor_id);
296 }
297 }
298 }
299
300 current_layer = next_layer;
301 }
302
303 Ok(TraversalResult {
304 entities: subgraph_entities,
305 relationships: subgraph_relationships,
306 paths: Vec::new(),
307 distances,
308 })
309 }
310
311 pub fn multi_source_bfs(
323 &self,
324 graph: &KnowledgeGraph,
325 sources: &[EntityId],
326 ) -> Result<TraversalResult> {
327 let mut visited = HashSet::new();
328 let mut queue = VecDeque::new();
329 let mut distances = HashMap::new();
330 let mut discovered_entities = Vec::new();
331 let mut discovered_relationships = Vec::new();
332
333 for source in sources {
335 queue.push_back((source.clone(), 0));
336 distances.insert(source.clone(), 0);
337 }
338
339 while let Some((current_id, depth)) = queue.pop_front() {
340 if depth >= self.config.max_depth {
341 continue;
342 }
343
344 if visited.contains(¤t_id) {
345 continue;
346 }
347 visited.insert(current_id.clone());
348
349 if let Some(entity) = graph.get_entity(¤t_id) {
350 discovered_entities.push(entity.clone());
351 }
352
353 let neighbors = self.get_neighbors(graph, ¤t_id);
354
355 for (neighbor_id, relationship) in neighbors {
356 if relationship.confidence < self.config.min_relationship_strength {
357 continue;
358 }
359
360 if !visited.contains(&neighbor_id) {
361 queue.push_back((neighbor_id.clone(), depth + 1));
362 distances.entry(neighbor_id.clone()).or_insert(depth + 1);
363 discovered_relationships.push(relationship);
364 }
365 }
366 }
367
368 Ok(TraversalResult {
369 entities: discovered_entities,
370 relationships: discovered_relationships,
371 paths: Vec::new(),
372 distances,
373 })
374 }
375
376 pub fn find_all_paths(
389 &self,
390 graph: &KnowledgeGraph,
391 source: &EntityId,
392 target: &EntityId,
393 ) -> Result<TraversalResult> {
394 let mut all_paths = Vec::new();
395 let mut current_path = vec![source.clone()];
396 let mut visited = HashSet::new();
397 let mut discovered_relationships = Vec::new();
398
399 self.find_paths_recursive(
400 graph,
401 source,
402 target,
403 &mut current_path,
404 &mut visited,
405 &mut all_paths,
406 &mut discovered_relationships,
407 0,
408 )?;
409
410 let mut unique_entities = HashSet::new();
412 for path in &all_paths {
413 unique_entities.extend(path.iter().cloned());
414 }
415
416 let discovered_entities: Vec<Entity> = unique_entities
417 .iter()
418 .filter_map(|id| graph.get_entity(id).cloned())
419 .collect();
420
421 Ok(TraversalResult {
422 entities: discovered_entities,
423 relationships: discovered_relationships,
424 paths: all_paths,
425 distances: HashMap::new(),
426 })
427 }
428
429 fn find_paths_recursive(
431 &self,
432 graph: &KnowledgeGraph,
433 current: &EntityId,
434 target: &EntityId,
435 current_path: &mut Vec<EntityId>,
436 visited: &mut HashSet<EntityId>,
437 all_paths: &mut Vec<Vec<EntityId>>,
438 discovered_relationships: &mut Vec<Relationship>,
439 depth: usize,
440 ) -> Result<()> {
441 if depth >= self.config.max_depth || all_paths.len() >= self.config.max_paths {
443 return Ok(());
444 }
445
446 if current == target {
448 all_paths.push(current_path.clone());
449 return Ok(());
450 }
451
452 visited.insert(current.clone());
453
454 let neighbors = self.get_neighbors(graph, current);
455
456 for (neighbor_id, relationship) in neighbors {
457 if relationship.confidence < self.config.min_relationship_strength {
458 continue;
459 }
460
461 if !visited.contains(&neighbor_id) {
462 current_path.push(neighbor_id.clone());
463 discovered_relationships.push(relationship);
464
465 self.find_paths_recursive(
466 graph,
467 &neighbor_id,
468 target,
469 current_path,
470 visited,
471 all_paths,
472 discovered_relationships,
473 depth + 1,
474 )?;
475
476 current_path.pop();
477 }
478 }
479
480 visited.remove(current);
481
482 Ok(())
483 }
484
485 fn get_neighbors(
487 &self,
488 graph: &KnowledgeGraph,
489 entity_id: &EntityId,
490 ) -> Vec<(EntityId, Relationship)> {
491 let mut neighbors = Vec::new();
492
493 for relationship in graph.get_all_relationships() {
495 if &relationship.source == entity_id {
496 neighbors.push((relationship.target.clone(), relationship.clone()));
497 }
498 if &relationship.target == entity_id {
500 neighbors.push((relationship.source.clone(), relationship.clone()));
501 }
502 }
503
504 neighbors
505 }
506
507 pub fn query_focused_subgraph(
522 &self,
523 graph: &KnowledgeGraph,
524 seed_entities: &[EntityId],
525 expansion_hops: usize,
526 ) -> Result<TraversalResult> {
527 let mut combined_entities = Vec::new();
528 let mut combined_relationships = Vec::new();
529 let mut combined_distances = HashMap::new();
530 let mut seen_entities = HashSet::new();
531 let mut seen_relationships = HashSet::new();
532
533 for seed in seed_entities {
535 let ego_result = self.ego_network(graph, seed, Some(expansion_hops))?;
536
537 for entity in ego_result.entities {
538 if !seen_entities.contains(&entity.id) {
539 seen_entities.insert(entity.id.clone());
540 combined_entities.push(entity);
541 }
542 }
543
544 for rel in ego_result.relationships {
545 let rel_key = (rel.source.clone(), rel.target.clone(), rel.relation_type.clone());
546 if !seen_relationships.contains(&rel_key) {
547 seen_relationships.insert(rel_key);
548 combined_relationships.push(rel);
549 }
550 }
551
552 for (entity_id, distance) in ego_result.distances {
553 combined_distances
554 .entry(entity_id)
555 .and_modify(|d: &mut usize| *d = (*d).min(distance))
556 .or_insert(distance);
557 }
558 }
559
560 Ok(TraversalResult {
561 entities: combined_entities,
562 relationships: combined_relationships,
563 paths: Vec::new(),
564 distances: combined_distances,
565 })
566 }
567}
568
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{Entity, EntityMention, Relationship};

    /// Shorthand for building an entity id from a string literal.
    fn id(name: &str) -> EntityId {
        EntityId::new(name.to_string())
    }

    /// Builds the shared fixture: entities A, B, C, D and the relationships
    /// A-B (0.8), B-C (0.9) and A-D (0.7), all of type `RELATED_TO`.
    fn create_test_graph() -> KnowledgeGraph {
        let mut graph = KnowledgeGraph::new();

        for name in ["A", "B", "C", "D"] {
            graph.add_entity(Entity::new(
                id(name),
                format!("Entity {}", name),
                "CONCEPT".to_string(),
                0.9,
            ));
        }

        for (from, to, confidence) in [("A", "B", 0.8), ("B", "C", 0.9), ("A", "D", 0.7)] {
            let _ = graph.add_relationship(Relationship {
                source: id(from),
                target: id(to),
                relation_type: "RELATED_TO".to_string(),
                confidence,
                context: Vec::new(),
            });
        }

        graph
    }

    #[test]
    fn test_bfs_traversal() {
        let graph = create_test_graph();
        let start = id("A");

        let outcome = GraphTraversal::default().bfs(&graph, &start).unwrap();

        assert!(!outcome.entities.is_empty());
        assert!(outcome.distances.contains_key(&start));
    }

    #[test]
    fn test_dfs_traversal() {
        let graph = create_test_graph();
        let start = id("A");

        let outcome = GraphTraversal::default().dfs(&graph, &start).unwrap();

        assert!(!outcome.entities.is_empty());
        assert!(outcome.distances.contains_key(&start));
    }

    #[test]
    fn test_ego_network() {
        let graph = create_test_graph();
        let center = id("A");

        let outcome = GraphTraversal::default()
            .ego_network(&graph, &center, Some(1))
            .unwrap();

        assert!(outcome.entities.len() >= 2);
        assert_eq!(outcome.distances.get(&center).copied().unwrap(), 0);
    }

    #[test]
    fn test_multi_source_bfs() {
        let graph = create_test_graph();
        let starts = vec![id("A"), id("C")];

        let outcome = GraphTraversal::default()
            .multi_source_bfs(&graph, &starts)
            .unwrap();

        assert!(outcome.entities.len() >= 2);
    }

    #[test]
    fn test_find_all_paths() {
        let graph = create_test_graph();
        let from = id("A");
        let to = id("C");

        let outcome = GraphTraversal::default()
            .find_all_paths(&graph, &from, &to)
            .unwrap();

        assert!(!outcome.paths.is_empty());
        let first_path = &outcome.paths[0];
        assert!(first_path.contains(&from));
        assert!(first_path.contains(&to));
    }

    #[test]
    fn test_query_focused_subgraph() {
        let graph = create_test_graph();
        let seeds = vec![id("A")];

        let outcome = GraphTraversal::default()
            .query_focused_subgraph(&graph, &seeds, 2)
            .unwrap();

        assert!(!outcome.entities.is_empty());
        assert!(!outcome.relationships.is_empty());
    }
}