Skip to main content

graphrag_core/graph/
traversal.rs

1//! Graph Traversal Algorithms for GraphRAG
2//!
3//! This module implements deterministic graph traversal algorithms that don't require
4//! machine learning, following NLP best practices for knowledge graph exploration:
5//!
6//! - **BFS (Breadth-First Search)**: Level-by-level exploration for shortest paths
7//! - **DFS (Depth-First Search)**: Deep exploration for discovering all paths
8//! - **Ego-Network Extraction**: K-hop neighborhoods around entities
9//! - **Multi-Source Path Finding**: Simultaneous search from multiple entities
10//! - **Query-Focused Subgraph Extraction**: Context-aware subgraph retrieval
11//!
12//! These algorithms are essential for the query phase of GraphRAG, enabling:
13//! - Efficient entity-centric retrieval
14//! - Relationship path discovery
15//! - Context-aware information gathering
16//! - Multi-hop reasoning without neural networks
17
18use crate::core::{Entity, EntityId, KnowledgeGraph, Relationship, Result};
19use std::collections::{HashMap, HashSet, VecDeque};
20
/// Tunable settings shared by every algorithm on [`GraphTraversal`].
///
/// [`TraversalConfig::default`] yields `max_depth = 3`, `max_paths = 100`,
/// `use_edge_weights = true` and `min_relationship_strength = 0.5`; override
/// individual fields with struct-update syntax
/// (`TraversalConfig { max_depth: 1, ..TraversalConfig::default() }`).
#[derive(Clone, Debug)]
pub struct TraversalConfig {
    /// Depth limit, in hops from the start entity, for BFS/DFS-style traversal.
    pub max_depth: usize,
    /// Upper bound on how many paths a path-finding call may return.
    pub max_paths: usize,
    /// Whether edge weights should factor into path scoring.
    ///
    /// NOTE(review): no traversal in this module currently reads this flag.
    pub use_edge_weights: bool,
    /// Relationships whose `confidence` is below this value are not traversed.
    pub min_relationship_strength: f32,
}

impl Default for TraversalConfig {
    /// Stock settings: three hops, at most 100 paths, weights on, strength >= 0.5.
    fn default() -> Self {
        TraversalConfig {
            min_relationship_strength: 0.5,
            use_edge_weights: true,
            max_paths: 100,
            max_depth: 3,
        }
    }
}
44
45/// Result of a graph traversal operation
46#[derive(Debug, Clone)]
47pub struct TraversalResult {
48    /// Entities discovered during traversal
49    pub entities: Vec<Entity>,
50    /// Relationships traversed
51    pub relationships: Vec<Relationship>,
52    /// Paths found (for path-finding operations)
53    pub paths: Vec<Vec<EntityId>>,
54    /// Distance/depth of each entity from source
55    pub distances: HashMap<EntityId, usize>,
56}
57
58/// Graph traversal system implementing various search algorithms
59pub struct GraphTraversal {
60    config: TraversalConfig,
61}
62
63impl GraphTraversal {
64    /// Create a new graph traversal system
65    pub fn new(config: TraversalConfig) -> Self {
66        Self { config }
67    }
68
69    /// Create with default configuration
70    pub fn default() -> Self {
71        Self::new(TraversalConfig::default())
72    }
73
74    /// Breadth-First Search (BFS) from a source entity
75    ///
76    /// BFS explores the graph level by level, guaranteeing shortest paths.
77    /// Ideal for finding entities within a certain distance from the source.
78    ///
79    /// # Arguments
80    /// * `graph` - The knowledge graph to traverse
81    /// * `source` - Starting entity ID
82    ///
83    /// # Returns
84    /// TraversalResult with entities, relationships, and distances
85    pub fn bfs(
86        &self,
87        graph: &KnowledgeGraph,
88        source: &EntityId,
89    ) -> Result<TraversalResult> {
90        let mut visited = HashSet::new();
91        let mut queue = VecDeque::new();
92        let mut distances = HashMap::new();
93        let mut discovered_entities = Vec::new();
94        let mut discovered_relationships = Vec::new();
95
96        // Initialize with source entity
97        queue.push_back((source.clone(), 0));
98        distances.insert(source.clone(), 0);
99
100        while let Some((current_id, depth)) = queue.pop_front() {
101            // Stop if we've reached max depth
102            if depth >= self.config.max_depth {
103                continue;
104            }
105
106            // Skip if already visited
107            if visited.contains(&current_id) {
108                continue;
109            }
110            visited.insert(current_id.clone());
111
112            // Add current entity to results
113            if let Some(entity) = graph.get_entity(&current_id) {
114                discovered_entities.push(entity.clone());
115            }
116
117            // Get all neighbors (entities connected by relationships)
118            let neighbors = self.get_neighbors(graph, &current_id);
119
120            for (neighbor_id, relationship) in neighbors {
121                // Filter by relationship confidence
122                if relationship.confidence < self.config.min_relationship_strength {
123                    continue;
124                }
125
126                // Add to queue if not visited
127                if !visited.contains(&neighbor_id) {
128                    queue.push_back((neighbor_id.clone(), depth + 1));
129                    distances.entry(neighbor_id.clone()).or_insert(depth + 1);
130                    discovered_relationships.push(relationship);
131                }
132            }
133        }
134
135        Ok(TraversalResult {
136            entities: discovered_entities,
137            relationships: discovered_relationships,
138            paths: Vec::new(), // BFS doesn't track individual paths
139            distances,
140        })
141    }
142
143    /// Depth-First Search (DFS) from a source entity
144    ///
145    /// DFS explores as far as possible along each branch before backtracking.
146    /// Useful for finding all possible paths and deep exploration.
147    ///
148    /// # Arguments
149    /// * `graph` - The knowledge graph to traverse
150    /// * `source` - Starting entity ID
151    ///
152    /// # Returns
153    /// TraversalResult with entities, relationships, and discovered paths
154    pub fn dfs(
155        &self,
156        graph: &KnowledgeGraph,
157        source: &EntityId,
158    ) -> Result<TraversalResult> {
159        let mut visited = HashSet::new();
160        let mut distances = HashMap::new();
161        let mut discovered_entities = Vec::new();
162        let mut discovered_relationships = Vec::new();
163
164        self.dfs_recursive(
165            graph,
166            source,
167            0,
168            &mut visited,
169            &mut distances,
170            &mut discovered_entities,
171            &mut discovered_relationships,
172        )?;
173
174        Ok(TraversalResult {
175            entities: discovered_entities,
176            relationships: discovered_relationships,
177            paths: Vec::new(), // Basic DFS doesn't track paths
178            distances,
179        })
180    }
181
182    /// Recursive DFS helper
183    fn dfs_recursive(
184        &self,
185        graph: &KnowledgeGraph,
186        current_id: &EntityId,
187        depth: usize,
188        visited: &mut HashSet<EntityId>,
189        distances: &mut HashMap<EntityId, usize>,
190        discovered_entities: &mut Vec<Entity>,
191        discovered_relationships: &mut Vec<Relationship>,
192    ) -> Result<()> {
193        // Stop if max depth reached
194        if depth >= self.config.max_depth {
195            return Ok(());
196        }
197
198        // Skip if already visited (avoid cycles)
199        if visited.contains(current_id) {
200            return Ok(());
201        }
202
203        visited.insert(current_id.clone());
204        distances.insert(current_id.clone(), depth);
205
206        // Add current entity
207        if let Some(entity) = graph.get_entity(current_id) {
208            discovered_entities.push(entity.clone());
209        }
210
211        // Recursively visit neighbors
212        let neighbors = self.get_neighbors(graph, current_id);
213
214        for (neighbor_id, relationship) in neighbors {
215            if relationship.confidence < self.config.min_relationship_strength {
216                continue;
217            }
218
219            if !visited.contains(&neighbor_id) {
220                discovered_relationships.push(relationship);
221                self.dfs_recursive(
222                    graph,
223                    &neighbor_id,
224                    depth + 1,
225                    visited,
226                    distances,
227                    discovered_entities,
228                    discovered_relationships,
229                )?;
230            }
231        }
232
233        Ok(())
234    }
235
236    /// Extract K-hop ego-network around an entity
237    ///
238    /// An ego-network is a subgraph containing all entities within K hops
239    /// of the source entity. This is useful for context-aware retrieval.
240    ///
241    /// # Arguments
242    /// * `graph` - The knowledge graph
243    /// * `entity_id` - Center entity for the ego-network
244    /// * `k_hops` - Number of hops to include (defaults to config.max_depth)
245    ///
246    /// # Returns
247    /// TraversalResult with the ego-network subgraph
248    pub fn ego_network(
249        &self,
250        graph: &KnowledgeGraph,
251        entity_id: &EntityId,
252        k_hops: Option<usize>,
253    ) -> Result<TraversalResult> {
254        let hops = k_hops.unwrap_or(self.config.max_depth);
255
256        let mut subgraph_entities = Vec::new();
257        let mut subgraph_relationships = Vec::new();
258        let mut visited = HashSet::new();
259        let mut distances = HashMap::new();
260
261        // Start with the ego entity
262        visited.insert(entity_id.clone());
263        distances.insert(entity_id.clone(), 0);
264
265        if let Some(entity) = graph.get_entity(entity_id) {
266            subgraph_entities.push(entity.clone());
267        }
268
269        // Use BFS to expand outward for k hops
270        let mut current_layer = vec![entity_id.clone()];
271
272        for hop in 1..=hops {
273            let mut next_layer = Vec::new();
274
275            for current_id in &current_layer {
276                let neighbors = self.get_neighbors(graph, current_id);
277
278                for (neighbor_id, relationship) in neighbors {
279                    if relationship.confidence < self.config.min_relationship_strength {
280                        continue;
281                    }
282
283                    // Add relationship
284                    subgraph_relationships.push(relationship);
285
286                    // Add neighbor if not visited
287                    if !visited.contains(&neighbor_id) {
288                        visited.insert(neighbor_id.clone());
289                        distances.insert(neighbor_id.clone(), hop);
290
291                        if let Some(entity) = graph.get_entity(&neighbor_id) {
292                            subgraph_entities.push(entity.clone());
293                        }
294
295                        next_layer.push(neighbor_id);
296                    }
297                }
298            }
299
300            current_layer = next_layer;
301        }
302
303        Ok(TraversalResult {
304            entities: subgraph_entities,
305            relationships: subgraph_relationships,
306            paths: Vec::new(),
307            distances,
308        })
309    }
310
311    /// Multi-source BFS pathfinding
312    ///
313    /// Performs simultaneous BFS from multiple source entities to find
314    /// intersections and common neighbors efficiently.
315    ///
316    /// # Arguments
317    /// * `graph` - The knowledge graph
318    /// * `sources` - Multiple starting entity IDs
319    ///
320    /// # Returns
321    /// TraversalResult with entities reachable from any source
322    pub fn multi_source_bfs(
323        &self,
324        graph: &KnowledgeGraph,
325        sources: &[EntityId],
326    ) -> Result<TraversalResult> {
327        let mut visited = HashSet::new();
328        let mut queue = VecDeque::new();
329        let mut distances = HashMap::new();
330        let mut discovered_entities = Vec::new();
331        let mut discovered_relationships = Vec::new();
332
333        // Initialize queue with all sources
334        for source in sources {
335            queue.push_back((source.clone(), 0));
336            distances.insert(source.clone(), 0);
337        }
338
339        while let Some((current_id, depth)) = queue.pop_front() {
340            if depth >= self.config.max_depth {
341                continue;
342            }
343
344            if visited.contains(&current_id) {
345                continue;
346            }
347            visited.insert(current_id.clone());
348
349            if let Some(entity) = graph.get_entity(&current_id) {
350                discovered_entities.push(entity.clone());
351            }
352
353            let neighbors = self.get_neighbors(graph, &current_id);
354
355            for (neighbor_id, relationship) in neighbors {
356                if relationship.confidence < self.config.min_relationship_strength {
357                    continue;
358                }
359
360                if !visited.contains(&neighbor_id) {
361                    queue.push_back((neighbor_id.clone(), depth + 1));
362                    distances.entry(neighbor_id.clone()).or_insert(depth + 1);
363                    discovered_relationships.push(relationship);
364                }
365            }
366        }
367
368        Ok(TraversalResult {
369            entities: discovered_entities,
370            relationships: discovered_relationships,
371            paths: Vec::new(),
372            distances,
373        })
374    }
375
376    /// Find all paths between two entities
377    ///
378    /// Uses DFS to discover all possible paths from source to target
379    /// within the maximum depth limit.
380    ///
381    /// # Arguments
382    /// * `graph` - The knowledge graph
383    /// * `source` - Starting entity
384    /// * `target` - Target entity
385    ///
386    /// # Returns
387    /// TraversalResult with all discovered paths
388    pub fn find_all_paths(
389        &self,
390        graph: &KnowledgeGraph,
391        source: &EntityId,
392        target: &EntityId,
393    ) -> Result<TraversalResult> {
394        let mut all_paths = Vec::new();
395        let mut current_path = vec![source.clone()];
396        let mut visited = HashSet::new();
397        let mut discovered_relationships = Vec::new();
398
399        self.find_paths_recursive(
400            graph,
401            source,
402            target,
403            &mut current_path,
404            &mut visited,
405            &mut all_paths,
406            &mut discovered_relationships,
407            0,
408        )?;
409
410        // Collect all unique entities from paths
411        let mut unique_entities = HashSet::new();
412        for path in &all_paths {
413            unique_entities.extend(path.iter().cloned());
414        }
415
416        let discovered_entities: Vec<Entity> = unique_entities
417            .iter()
418            .filter_map(|id| graph.get_entity(id).cloned())
419            .collect();
420
421        Ok(TraversalResult {
422            entities: discovered_entities,
423            relationships: discovered_relationships,
424            paths: all_paths,
425            distances: HashMap::new(),
426        })
427    }
428
429    /// Recursive helper for find_all_paths
430    fn find_paths_recursive(
431        &self,
432        graph: &KnowledgeGraph,
433        current: &EntityId,
434        target: &EntityId,
435        current_path: &mut Vec<EntityId>,
436        visited: &mut HashSet<EntityId>,
437        all_paths: &mut Vec<Vec<EntityId>>,
438        discovered_relationships: &mut Vec<Relationship>,
439        depth: usize,
440    ) -> Result<()> {
441        // Stop if max depth or max paths reached
442        if depth >= self.config.max_depth || all_paths.len() >= self.config.max_paths {
443            return Ok(());
444        }
445
446        // Found target - save path
447        if current == target {
448            all_paths.push(current_path.clone());
449            return Ok(());
450        }
451
452        visited.insert(current.clone());
453
454        let neighbors = self.get_neighbors(graph, current);
455
456        for (neighbor_id, relationship) in neighbors {
457            if relationship.confidence < self.config.min_relationship_strength {
458                continue;
459            }
460
461            if !visited.contains(&neighbor_id) {
462                current_path.push(neighbor_id.clone());
463                discovered_relationships.push(relationship);
464
465                self.find_paths_recursive(
466                    graph,
467                    &neighbor_id,
468                    target,
469                    current_path,
470                    visited,
471                    all_paths,
472                    discovered_relationships,
473                    depth + 1,
474                )?;
475
476                current_path.pop();
477            }
478        }
479
480        visited.remove(current);
481
482        Ok(())
483    }
484
485    /// Get neighbors of an entity with their connecting relationships
486    fn get_neighbors(
487        &self,
488        graph: &KnowledgeGraph,
489        entity_id: &EntityId,
490    ) -> Vec<(EntityId, Relationship)> {
491        let mut neighbors = Vec::new();
492
493        // Get all relationships where this entity is the source
494        for relationship in graph.get_all_relationships() {
495            if &relationship.source == entity_id {
496                neighbors.push((relationship.target.clone(), relationship.clone()));
497            }
498            // Also consider bidirectional traversal
499            if &relationship.target == entity_id {
500                neighbors.push((relationship.source.clone(), relationship.clone()));
501            }
502        }
503
504        neighbors
505    }
506
507    /// Extract query-focused subgraph
508    ///
509    /// Extracts a subgraph relevant to a specific query by:
510    /// 1. Identifying seed entities from query
511    /// 2. Expanding via ego-networks
512    /// 3. Filtering by relevance
513    ///
514    /// # Arguments
515    /// * `graph` - The knowledge graph
516    /// * `seed_entities` - Starting entities identified in query
517    /// * `expansion_hops` - How many hops to expand
518    ///
519    /// # Returns
520    /// TraversalResult with query-relevant subgraph
521    pub fn query_focused_subgraph(
522        &self,
523        graph: &KnowledgeGraph,
524        seed_entities: &[EntityId],
525        expansion_hops: usize,
526    ) -> Result<TraversalResult> {
527        let mut combined_entities = Vec::new();
528        let mut combined_relationships = Vec::new();
529        let mut combined_distances = HashMap::new();
530        let mut seen_entities = HashSet::new();
531        let mut seen_relationships = HashSet::new();
532
533        // Extract ego-network for each seed entity
534        for seed in seed_entities {
535            let ego_result = self.ego_network(graph, seed, Some(expansion_hops))?;
536
537            for entity in ego_result.entities {
538                if !seen_entities.contains(&entity.id) {
539                    seen_entities.insert(entity.id.clone());
540                    combined_entities.push(entity);
541                }
542            }
543
544            for rel in ego_result.relationships {
545                let rel_key = (rel.source.clone(), rel.target.clone(), rel.relation_type.clone());
546                if !seen_relationships.contains(&rel_key) {
547                    seen_relationships.insert(rel_key);
548                    combined_relationships.push(rel);
549                }
550            }
551
552            for (entity_id, distance) in ego_result.distances {
553                combined_distances
554                    .entry(entity_id)
555                    .and_modify(|d: &mut usize| *d = (*d).min(distance))
556                    .or_insert(distance);
557            }
558        }
559
560        Ok(TraversalResult {
561            entities: combined_entities,
562            relationships: combined_relationships,
563            paths: Vec::new(),
564            distances: combined_distances,
565        })
566    }
567}
568
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building an `EntityId` from a string.
    fn id(name: &str) -> EntityId {
        EntityId::new(name.to_string())
    }

    /// Set of ids, for order-independent comparisons.
    fn id_set(names: &[&str]) -> HashSet<EntityId> {
        names.iter().map(|name| id(name)).collect()
    }

    /// Ids of the entities reported in a traversal result.
    fn entity_ids(result: &TraversalResult) -> HashSet<EntityId> {
        result.entities.iter().map(|entity| entity.id.clone()).collect()
    }

    /// A `CONCEPT` entity named "Entity <name>".
    fn concept(name: &str) -> Entity {
        Entity::new(
            id(name),
            format!("Entity {}", name),
            "CONCEPT".to_string(),
            0.9,
        )
    }

    /// A `RELATED_TO` relationship with no context.
    fn related(source: &str, target: &str, confidence: f32) -> Relationship {
        Relationship {
            source: id(source),
            target: id(target),
            relation_type: "RELATED_TO".to_string(),
            confidence,
            context: Vec::new(),
        }
    }

    /// Fixture graph:
    ///
    /// ```text
    /// A -(0.8)-> B -(0.9)-> C
    /// A -(0.7)-> D
    /// ```
    fn create_test_graph() -> KnowledgeGraph {
        let mut graph = KnowledgeGraph::new();

        for name in ["A", "B", "C", "D"] {
            graph.add_entity(concept(name));
        }

        let _ = graph.add_relationship(related("A", "B", 0.8));
        let _ = graph.add_relationship(related("B", "C", 0.9));
        let _ = graph.add_relationship(related("A", "D", 0.7));

        graph
    }

    #[test]
    fn test_bfs_traversal() {
        let graph = create_test_graph();
        let traversal = GraphTraversal::default();
        let source = id("A");

        let result = traversal.bfs(&graph, &source).unwrap();

        // Every entity is within two hops of A, inside the default depth limit.
        assert_eq!(entity_ids(&result), id_set(&["A", "B", "C", "D"]));
        assert_eq!(result.distances.get(&source), Some(&0));
        assert_eq!(result.distances.get(&id("B")), Some(&1));
        assert_eq!(result.distances.get(&id("D")), Some(&1));
        assert_eq!(result.distances.get(&id("C")), Some(&2));
        assert_eq!(result.relationships.len(), 3);
        assert!(result.paths.is_empty());
    }

    #[test]
    fn test_bfs_respects_min_relationship_strength() {
        let graph = create_test_graph();
        let traversal = GraphTraversal::new(TraversalConfig {
            min_relationship_strength: 0.85,
            ..TraversalConfig::default()
        });

        let result = traversal.bfs(&graph, &id("A")).unwrap();

        // Only B-C (0.9) clears the threshold, and it is not incident to A.
        assert_eq!(entity_ids(&result), id_set(&["A"]));
        assert!(result.relationships.is_empty());
    }

    #[test]
    fn test_dfs_traversal() {
        let graph = create_test_graph();
        let traversal = GraphTraversal::default();
        let source = id("A");

        let result = traversal.dfs(&graph, &source).unwrap();

        assert_eq!(entity_ids(&result), id_set(&["A", "B", "C", "D"]));
        assert_eq!(result.distances.get(&source), Some(&0));
        assert_eq!(result.distances.get(&id("C")), Some(&2));
        assert_eq!(result.relationships.len(), 3);
    }

    #[test]
    fn test_ego_network() {
        let graph = create_test_graph();
        let traversal = GraphTraversal::default();
        let entity_id = id("A");

        let result = traversal.ego_network(&graph, &entity_id, Some(1)).unwrap();

        // 1-hop ego network of A includes A, B and D, but not C.
        assert_eq!(entity_ids(&result), id_set(&["A", "B", "D"]));
        assert_eq!(result.distances.get(&entity_id), Some(&0));
        assert_eq!(result.distances.get(&id("B")), Some(&1));
        assert!(!result.distances.contains_key(&id("C")));
        assert_eq!(result.relationships.len(), 2);
    }

    #[test]
    fn test_multi_source_bfs() {
        let graph = create_test_graph();
        let traversal = GraphTraversal::default();
        let sources = vec![id("A"), id("C")];

        let result = traversal.multi_source_bfs(&graph, &sources).unwrap();

        // Both sources sit at distance 0; B is one hop from either.
        assert_eq!(entity_ids(&result), id_set(&["A", "B", "C", "D"]));
        assert_eq!(result.distances.get(&id("A")), Some(&0));
        assert_eq!(result.distances.get(&id("C")), Some(&0));
        assert_eq!(result.distances.get(&id("B")), Some(&1));
    }

    #[test]
    fn test_find_all_paths() {
        let graph = create_test_graph();
        let traversal = GraphTraversal::default();
        let source = id("A");
        let target = id("C");

        let result = traversal.find_all_paths(&graph, &source, &target).unwrap();

        // The only path from A to C is A -> B -> C.
        assert_eq!(result.paths, vec![vec![id("A"), id("B"), id("C")]]);
        assert_eq!(entity_ids(&result), id_set(&["A", "B", "C"]));
    }

    #[test]
    fn test_find_all_paths_unreachable_target() {
        let graph = create_test_graph();
        let traversal = GraphTraversal::default();

        let result = traversal
            .find_all_paths(&graph, &id("A"), &id("Z"))
            .unwrap();

        assert!(result.paths.is_empty());
        assert!(result.entities.is_empty());
    }

    #[test]
    fn test_query_focused_subgraph() {
        let graph = create_test_graph();
        let traversal = GraphTraversal::default();
        let seeds = vec![id("A")];

        let result = traversal
            .query_focused_subgraph(&graph, &seeds, 2)
            .unwrap();

        // Two hops from A cover the whole fixture; each relationship appears once.
        assert_eq!(entity_ids(&result), id_set(&["A", "B", "C", "D"]));
        assert_eq!(result.relationships.len(), 3);
        assert_eq!(result.distances.get(&id("C")), Some(&2));
    }
}