Skip to main content

graphrag_core/graph/
traversal.rs

//! Graph Traversal Algorithms for GraphRAG
//!
//! This module implements deterministic graph traversal algorithms that don't require
//! machine learning, following NLP best practices for knowledge graph exploration:
//!
//! - **BFS (Breadth-First Search)**: Level-by-level exploration for shortest paths
//! - **DFS (Depth-First Search)**: Deep exploration for discovering all paths
//! - **Ego-Network Extraction**: K-hop neighborhoods around entities
//! - **Multi-Source Path Finding**: Simultaneous search from multiple entities
//! - **Query-Focused Subgraph Extraction**: Context-aware subgraph retrieval
//!
//! These algorithms are essential for the query phase of GraphRAG, enabling:
//! - Efficient entity-centric retrieval
//! - Relationship path discovery
//! - Context-aware information gathering
//! - Multi-hop reasoning without neural networks
17
18use crate::core::{Entity, EntityId, KnowledgeGraph, Relationship, Result};
19use std::collections::{HashMap, HashSet, VecDeque};
20
/// Tunable limits shared by the traversal routines in this module.
#[derive(Debug, Clone)]
pub struct TraversalConfig {
    /// Depth bound applied by the BFS, DFS and path-search routines
    pub max_depth: usize,
    /// Upper bound on the number of paths a path search collects
    pub max_paths: usize,
    /// Whether edge weights should contribute to path scoring
    /// (NOTE(review): the `GraphTraversal` methods visible in this file never read it)
    pub use_edge_weights: bool,
    /// Relationships whose confidence is below this value are not followed
    pub min_relationship_strength: f32,
}

impl Default for TraversalConfig {
    /// Defaults: depth 3, at most 100 paths, edge weights enabled, strength threshold 0.5.
    fn default() -> Self {
        TraversalConfig {
            max_depth: 3,
            max_paths: 100,
            use_edge_weights: true,
            min_relationship_strength: 0.5,
        }
    }
}
44
45/// Result of a graph traversal operation
46#[derive(Debug, Clone)]
47pub struct TraversalResult {
48    /// Entities discovered during traversal
49    pub entities: Vec<Entity>,
50    /// Relationships traversed
51    pub relationships: Vec<Relationship>,
52    /// Paths found (for path-finding operations)
53    pub paths: Vec<Vec<EntityId>>,
54    /// Distance/depth of each entity from source
55    pub distances: HashMap<EntityId, usize>,
56}
57
58/// Graph traversal system implementing various search algorithms
59pub struct GraphTraversal {
60    config: TraversalConfig,
61}
62
63impl Default for GraphTraversal {
64    fn default() -> Self {
65        Self::new(TraversalConfig::default())
66    }
67}
68
69impl GraphTraversal {
70    /// Create a new graph traversal system
71    pub fn new(config: TraversalConfig) -> Self {
72        Self { config }
73    }
74
75    /// Breadth-First Search (BFS) from a source entity
76    ///
77    /// BFS explores the graph level by level, guaranteeing shortest paths.
78    /// Ideal for finding entities within a certain distance from the source.
79    ///
80    /// # Arguments
81    /// * `graph` - The knowledge graph to traverse
82    /// * `source` - Starting entity ID
83    ///
84    /// # Returns
85    /// TraversalResult with entities, relationships, and distances
86    pub fn bfs(&self, graph: &KnowledgeGraph, source: &EntityId) -> Result<TraversalResult> {
87        let mut visited = HashSet::new();
88        let mut queue = VecDeque::new();
89        let mut distances = HashMap::new();
90        let mut discovered_entities = Vec::new();
91        let mut discovered_relationships = Vec::new();
92
93        // Initialize with source entity
94        queue.push_back((source.clone(), 0));
95        distances.insert(source.clone(), 0);
96
97        while let Some((current_id, depth)) = queue.pop_front() {
98            // Stop if we've reached max depth
99            if depth >= self.config.max_depth {
100                continue;
101            }
102
103            // Skip if already visited
104            if visited.contains(&current_id) {
105                continue;
106            }
107            visited.insert(current_id.clone());
108
109            // Add current entity to results
110            if let Some(entity) = graph.get_entity(&current_id) {
111                discovered_entities.push(entity.clone());
112            }
113
114            // Get all neighbors (entities connected by relationships)
115            let neighbors = self.get_neighbors(graph, &current_id);
116
117            for (neighbor_id, relationship) in neighbors {
118                // Filter by relationship confidence
119                if relationship.confidence < self.config.min_relationship_strength {
120                    continue;
121                }
122
123                // Add to queue if not visited
124                if !visited.contains(&neighbor_id) {
125                    queue.push_back((neighbor_id.clone(), depth + 1));
126                    distances.entry(neighbor_id.clone()).or_insert(depth + 1);
127                    discovered_relationships.push(relationship);
128                }
129            }
130        }
131
132        Ok(TraversalResult {
133            entities: discovered_entities,
134            relationships: discovered_relationships,
135            paths: Vec::new(), // BFS doesn't track individual paths
136            distances,
137        })
138    }
139
140    /// Depth-First Search (DFS) from a source entity
141    ///
142    /// DFS explores as far as possible along each branch before backtracking.
143    /// Useful for finding all possible paths and deep exploration.
144    ///
145    /// # Arguments
146    /// * `graph` - The knowledge graph to traverse
147    /// * `source` - Starting entity ID
148    ///
149    /// # Returns
150    /// TraversalResult with entities, relationships, and discovered paths
151    pub fn dfs(&self, graph: &KnowledgeGraph, source: &EntityId) -> Result<TraversalResult> {
152        let mut visited = HashSet::new();
153        let mut distances = HashMap::new();
154        let mut discovered_entities = Vec::new();
155        let mut discovered_relationships = Vec::new();
156
157        self.dfs_recursive(
158            graph,
159            source,
160            0,
161            &mut visited,
162            &mut distances,
163            &mut discovered_entities,
164            &mut discovered_relationships,
165        )?;
166
167        Ok(TraversalResult {
168            entities: discovered_entities,
169            relationships: discovered_relationships,
170            paths: Vec::new(), // Basic DFS doesn't track paths
171            distances,
172        })
173    }
174
175    /// Recursive DFS helper
176    #[allow(clippy::too_many_arguments)]
177    fn dfs_recursive(
178        &self,
179        graph: &KnowledgeGraph,
180        current_id: &EntityId,
181        depth: usize,
182        visited: &mut HashSet<EntityId>,
183        distances: &mut HashMap<EntityId, usize>,
184        discovered_entities: &mut Vec<Entity>,
185        discovered_relationships: &mut Vec<Relationship>,
186    ) -> Result<()> {
187        // Stop if max depth reached
188        if depth >= self.config.max_depth {
189            return Ok(());
190        }
191
192        // Skip if already visited (avoid cycles)
193        if visited.contains(current_id) {
194            return Ok(());
195        }
196
197        visited.insert(current_id.clone());
198        distances.insert(current_id.clone(), depth);
199
200        // Add current entity
201        if let Some(entity) = graph.get_entity(current_id) {
202            discovered_entities.push(entity.clone());
203        }
204
205        // Recursively visit neighbors
206        let neighbors = self.get_neighbors(graph, current_id);
207
208        for (neighbor_id, relationship) in neighbors {
209            if relationship.confidence < self.config.min_relationship_strength {
210                continue;
211            }
212
213            if !visited.contains(&neighbor_id) {
214                discovered_relationships.push(relationship);
215                self.dfs_recursive(
216                    graph,
217                    &neighbor_id,
218                    depth + 1,
219                    visited,
220                    distances,
221                    discovered_entities,
222                    discovered_relationships,
223                )?;
224            }
225        }
226
227        Ok(())
228    }
229
230    /// Extract K-hop ego-network around an entity
231    ///
232    /// An ego-network is a subgraph containing all entities within K hops
233    /// of the source entity. This is useful for context-aware retrieval.
234    ///
235    /// # Arguments
236    /// * `graph` - The knowledge graph
237    /// * `entity_id` - Center entity for the ego-network
238    /// * `k_hops` - Number of hops to include (defaults to config.max_depth)
239    ///
240    /// # Returns
241    /// TraversalResult with the ego-network subgraph
242    pub fn ego_network(
243        &self,
244        graph: &KnowledgeGraph,
245        entity_id: &EntityId,
246        k_hops: Option<usize>,
247    ) -> Result<TraversalResult> {
248        let hops = k_hops.unwrap_or(self.config.max_depth);
249
250        let mut subgraph_entities = Vec::new();
251        let mut subgraph_relationships = Vec::new();
252        let mut visited = HashSet::new();
253        let mut distances = HashMap::new();
254
255        // Start with the ego entity
256        visited.insert(entity_id.clone());
257        distances.insert(entity_id.clone(), 0);
258
259        if let Some(entity) = graph.get_entity(entity_id) {
260            subgraph_entities.push(entity.clone());
261        }
262
263        // Use BFS to expand outward for k hops
264        let mut current_layer = vec![entity_id.clone()];
265
266        for hop in 1..=hops {
267            let mut next_layer = Vec::new();
268
269            for current_id in &current_layer {
270                let neighbors = self.get_neighbors(graph, current_id);
271
272                for (neighbor_id, relationship) in neighbors {
273                    if relationship.confidence < self.config.min_relationship_strength {
274                        continue;
275                    }
276
277                    // Add relationship
278                    subgraph_relationships.push(relationship);
279
280                    // Add neighbor if not visited
281                    if !visited.contains(&neighbor_id) {
282                        visited.insert(neighbor_id.clone());
283                        distances.insert(neighbor_id.clone(), hop);
284
285                        if let Some(entity) = graph.get_entity(&neighbor_id) {
286                            subgraph_entities.push(entity.clone());
287                        }
288
289                        next_layer.push(neighbor_id);
290                    }
291                }
292            }
293
294            current_layer = next_layer;
295        }
296
297        Ok(TraversalResult {
298            entities: subgraph_entities,
299            relationships: subgraph_relationships,
300            paths: Vec::new(),
301            distances,
302        })
303    }
304
305    /// Multi-source BFS pathfinding
306    ///
307    /// Performs simultaneous BFS from multiple source entities to find
308    /// intersections and common neighbors efficiently.
309    ///
310    /// # Arguments
311    /// * `graph` - The knowledge graph
312    /// * `sources` - Multiple starting entity IDs
313    ///
314    /// # Returns
315    /// TraversalResult with entities reachable from any source
316    pub fn multi_source_bfs(
317        &self,
318        graph: &KnowledgeGraph,
319        sources: &[EntityId],
320    ) -> Result<TraversalResult> {
321        let mut visited = HashSet::new();
322        let mut queue = VecDeque::new();
323        let mut distances = HashMap::new();
324        let mut discovered_entities = Vec::new();
325        let mut discovered_relationships = Vec::new();
326
327        // Initialize queue with all sources
328        for source in sources {
329            queue.push_back((source.clone(), 0));
330            distances.insert(source.clone(), 0);
331        }
332
333        while let Some((current_id, depth)) = queue.pop_front() {
334            if depth >= self.config.max_depth {
335                continue;
336            }
337
338            if visited.contains(&current_id) {
339                continue;
340            }
341            visited.insert(current_id.clone());
342
343            if let Some(entity) = graph.get_entity(&current_id) {
344                discovered_entities.push(entity.clone());
345            }
346
347            let neighbors = self.get_neighbors(graph, &current_id);
348
349            for (neighbor_id, relationship) in neighbors {
350                if relationship.confidence < self.config.min_relationship_strength {
351                    continue;
352                }
353
354                if !visited.contains(&neighbor_id) {
355                    queue.push_back((neighbor_id.clone(), depth + 1));
356                    distances.entry(neighbor_id.clone()).or_insert(depth + 1);
357                    discovered_relationships.push(relationship);
358                }
359            }
360        }
361
362        Ok(TraversalResult {
363            entities: discovered_entities,
364            relationships: discovered_relationships,
365            paths: Vec::new(),
366            distances,
367        })
368    }
369
370    /// Find all paths between two entities
371    ///
372    /// Uses DFS to discover all possible paths from source to target
373    /// within the maximum depth limit.
374    ///
375    /// # Arguments
376    /// * `graph` - The knowledge graph
377    /// * `source` - Starting entity
378    /// * `target` - Target entity
379    ///
380    /// # Returns
381    /// TraversalResult with all discovered paths
382    pub fn find_all_paths(
383        &self,
384        graph: &KnowledgeGraph,
385        source: &EntityId,
386        target: &EntityId,
387    ) -> Result<TraversalResult> {
388        let mut all_paths = Vec::new();
389        let mut current_path = vec![source.clone()];
390        let mut visited = HashSet::new();
391        let mut discovered_relationships = Vec::new();
392
393        self.find_paths_recursive(
394            graph,
395            source,
396            target,
397            &mut current_path,
398            &mut visited,
399            &mut all_paths,
400            &mut discovered_relationships,
401            0,
402        )?;
403
404        // Collect all unique entities from paths
405        let mut unique_entities = HashSet::new();
406        for path in &all_paths {
407            unique_entities.extend(path.iter().cloned());
408        }
409
410        let discovered_entities: Vec<Entity> = unique_entities
411            .iter()
412            .filter_map(|id| graph.get_entity(id).cloned())
413            .collect();
414
415        Ok(TraversalResult {
416            entities: discovered_entities,
417            relationships: discovered_relationships,
418            paths: all_paths,
419            distances: HashMap::new(),
420        })
421    }
422
423    /// Recursive helper for find_all_paths
424    #[allow(clippy::too_many_arguments)]
425    fn find_paths_recursive(
426        &self,
427        graph: &KnowledgeGraph,
428        current: &EntityId,
429        target: &EntityId,
430        current_path: &mut Vec<EntityId>,
431        visited: &mut HashSet<EntityId>,
432        all_paths: &mut Vec<Vec<EntityId>>,
433        discovered_relationships: &mut Vec<Relationship>,
434        depth: usize,
435    ) -> Result<()> {
436        // Stop if max depth or max paths reached
437        if depth >= self.config.max_depth || all_paths.len() >= self.config.max_paths {
438            return Ok(());
439        }
440
441        // Found target - save path
442        if current == target {
443            all_paths.push(current_path.clone());
444            return Ok(());
445        }
446
447        visited.insert(current.clone());
448
449        let neighbors = self.get_neighbors(graph, current);
450
451        for (neighbor_id, relationship) in neighbors {
452            if relationship.confidence < self.config.min_relationship_strength {
453                continue;
454            }
455
456            if !visited.contains(&neighbor_id) {
457                current_path.push(neighbor_id.clone());
458                discovered_relationships.push(relationship);
459
460                self.find_paths_recursive(
461                    graph,
462                    &neighbor_id,
463                    target,
464                    current_path,
465                    visited,
466                    all_paths,
467                    discovered_relationships,
468                    depth + 1,
469                )?;
470
471                current_path.pop();
472            }
473        }
474
475        visited.remove(current);
476
477        Ok(())
478    }
479
480    /// Get neighbors of an entity with their connecting relationships
481    fn get_neighbors(
482        &self,
483        graph: &KnowledgeGraph,
484        entity_id: &EntityId,
485    ) -> Vec<(EntityId, Relationship)> {
486        let mut neighbors = Vec::new();
487
488        // Get all relationships where this entity is the source
489        for relationship in graph.get_all_relationships() {
490            if &relationship.source == entity_id {
491                neighbors.push((relationship.target.clone(), relationship.clone()));
492            }
493            // Also consider bidirectional traversal
494            if &relationship.target == entity_id {
495                neighbors.push((relationship.source.clone(), relationship.clone()));
496            }
497        }
498
499        neighbors
500    }
501
502    /// Extract query-focused subgraph
503    ///
504    /// Extracts a subgraph relevant to a specific query by:
505    /// 1. Identifying seed entities from query
506    /// 2. Expanding via ego-networks
507    /// 3. Filtering by relevance
508    ///
509    /// # Arguments
510    /// * `graph` - The knowledge graph
511    /// * `seed_entities` - Starting entities identified in query
512    /// * `expansion_hops` - How many hops to expand
513    ///
514    /// # Returns
515    /// TraversalResult with query-relevant subgraph
516    pub fn query_focused_subgraph(
517        &self,
518        graph: &KnowledgeGraph,
519        seed_entities: &[EntityId],
520        expansion_hops: usize,
521    ) -> Result<TraversalResult> {
522        let mut combined_entities = Vec::new();
523        let mut combined_relationships = Vec::new();
524        let mut combined_distances = HashMap::new();
525        let mut seen_entities = HashSet::new();
526        let mut seen_relationships = HashSet::new();
527
528        // Extract ego-network for each seed entity
529        for seed in seed_entities {
530            let ego_result = self.ego_network(graph, seed, Some(expansion_hops))?;
531
532            for entity in ego_result.entities {
533                if !seen_entities.contains(&entity.id) {
534                    seen_entities.insert(entity.id.clone());
535                    combined_entities.push(entity);
536                }
537            }
538
539            for rel in ego_result.relationships {
540                let rel_key = (
541                    rel.source.clone(),
542                    rel.target.clone(),
543                    rel.relation_type.clone(),
544                );
545                if !seen_relationships.contains(&rel_key) {
546                    seen_relationships.insert(rel_key);
547                    combined_relationships.push(rel);
548                }
549            }
550
551            for (entity_id, distance) in ego_result.distances {
552                combined_distances
553                    .entry(entity_id)
554                    .and_modify(|d: &mut usize| *d = (*d).min(distance))
555                    .or_insert(distance);
556            }
557        }
558
559        Ok(TraversalResult {
560            entities: combined_entities,
561            relationships: combined_relationships,
562            paths: Vec::new(),
563            distances: combined_distances,
564        })
565    }
566}
567
568#[cfg(test)]
569mod tests {
570    use super::*;
571    use crate::core::Entity;
572
573    fn create_test_graph() -> KnowledgeGraph {
574        let mut graph = KnowledgeGraph::new();
575
576        // Create entities: A -> B -> C
577        //                  A -> D
578        let entity_a = Entity::new(
579            EntityId::new("A".to_string()),
580            "Entity A".to_string(),
581            "CONCEPT".to_string(),
582            0.9,
583        );
584        let entity_b = Entity::new(
585            EntityId::new("B".to_string()),
586            "Entity B".to_string(),
587            "CONCEPT".to_string(),
588            0.9,
589        );
590        let entity_c = Entity::new(
591            EntityId::new("C".to_string()),
592            "Entity C".to_string(),
593            "CONCEPT".to_string(),
594            0.9,
595        );
596        let entity_d = Entity::new(
597            EntityId::new("D".to_string()),
598            "Entity D".to_string(),
599            "CONCEPT".to_string(),
600            0.9,
601        );
602
603        graph.add_entity(entity_a);
604        graph.add_entity(entity_b);
605        graph.add_entity(entity_c);
606        graph.add_entity(entity_d);
607
608        // Add relationships
609        let _ = graph.add_relationship(Relationship {
610            source: EntityId::new("A".to_string()),
611            target: EntityId::new("B".to_string()),
612            relation_type: "RELATED_TO".to_string(),
613            confidence: 0.8,
614            context: Vec::new(),
615            embedding: None,
616            temporal_type: None,
617            temporal_range: None,
618            causal_strength: None,
619        });
620
621        let _ = graph.add_relationship(Relationship {
622            source: EntityId::new("B".to_string()),
623            target: EntityId::new("C".to_string()),
624            relation_type: "RELATED_TO".to_string(),
625            confidence: 0.9,
626            context: Vec::new(),
627            embedding: None,
628            temporal_type: None,
629            temporal_range: None,
630            causal_strength: None,
631        });
632
633        let _ = graph.add_relationship(Relationship {
634            source: EntityId::new("A".to_string()),
635            target: EntityId::new("D".to_string()),
636            relation_type: "RELATED_TO".to_string(),
637            confidence: 0.7,
638            context: Vec::new(),
639            embedding: None,
640            temporal_type: None,
641            temporal_range: None,
642            causal_strength: None,
643        });
644
645        graph
646    }
647
648    #[test]
649    fn test_bfs_traversal() {
650        let graph = create_test_graph();
651        let traversal = GraphTraversal::default();
652        let source = EntityId::new("A".to_string());
653
654        let result = traversal.bfs(&graph, &source).unwrap();
655
656        // Should discover all connected entities
657        assert!(result.entities.len() >= 1);
658        assert!(result.distances.contains_key(&source));
659    }
660
661    #[test]
662    fn test_dfs_traversal() {
663        let graph = create_test_graph();
664        let traversal = GraphTraversal::default();
665        let source = EntityId::new("A".to_string());
666
667        let result = traversal.dfs(&graph, &source).unwrap();
668
669        // Should discover entities through DFS
670        assert!(result.entities.len() >= 1);
671        assert!(result.distances.contains_key(&source));
672    }
673
674    #[test]
675    fn test_ego_network() {
676        let graph = create_test_graph();
677        let traversal = GraphTraversal::default();
678        let entity_id = EntityId::new("A".to_string());
679
680        let result = traversal.ego_network(&graph, &entity_id, Some(1)).unwrap();
681
682        // 1-hop ego network of A should include A, B, and D
683        assert!(result.entities.len() >= 2); // At least A and one neighbor
684        assert_eq!(*result.distances.get(&entity_id).unwrap(), 0);
685    }
686
687    #[test]
688    fn test_multi_source_bfs() {
689        let graph = create_test_graph();
690        let traversal = GraphTraversal::default();
691        let sources = vec![
692            EntityId::new("A".to_string()),
693            EntityId::new("C".to_string()),
694        ];
695
696        let result = traversal.multi_source_bfs(&graph, &sources).unwrap();
697
698        // Should discover entities from both sources
699        assert!(result.entities.len() >= 2);
700    }
701
702    #[test]
703    fn test_find_all_paths() {
704        let graph = create_test_graph();
705        let traversal = GraphTraversal::default();
706        let source = EntityId::new("A".to_string());
707        let target = EntityId::new("C".to_string());
708
709        let result = traversal.find_all_paths(&graph, &source, &target).unwrap();
710
711        // Should find at least one path from A to C (A -> B -> C)
712        assert!(!result.paths.is_empty());
713        assert!(result.paths[0].contains(&source));
714        assert!(result.paths[0].contains(&target));
715    }
716
717    #[test]
718    fn test_query_focused_subgraph() {
719        let graph = create_test_graph();
720        let traversal = GraphTraversal::default();
721        let seeds = vec![EntityId::new("A".to_string())];
722
723        let result = traversal.query_focused_subgraph(&graph, &seeds, 2).unwrap();
724
725        // Should extract subgraph around seed entity
726        assert!(!result.entities.is_empty());
727        assert!(!result.relationships.is_empty());
728    }
729}