Skip to main content

parsnip_core/
traversal.rs

//! Graph traversal types and algorithms
2
3use crate::entity::Entity;
4use crate::relation::{Direction, Relation};
5use serde::{Deserialize, Serialize};
6use std::cmp::Ordering;
7use std::collections::{BinaryHeap, HashMap, HashSet, VecDeque};
8
9/// Traversal query builder (follows SearchQuery pattern)
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct TraversalQuery {
12    /// Starting entity name
13    pub start: String,
14
15    /// Target entity name (for path finding, None for general traversal)
16    #[serde(skip_serializing_if = "Option::is_none")]
17    pub target: Option<String>,
18
19    /// Maximum traversal depth
20    #[serde(default = "default_depth")]
21    pub max_depth: u32,
22
23    /// Traversal direction
24    #[serde(default)]
25    pub direction: Direction,
26
27    /// Filter by entity types (empty = all types)
28    #[serde(default)]
29    pub entity_type_filter: Vec<String>,
30
31    /// Filter by relation types (empty = all types)
32    #[serde(default)]
33    pub relation_type_filter: Vec<String>,
34
35    /// Use weighted shortest path (Dijkstra)
36    #[serde(default)]
37    pub use_weights: bool,
38
39    /// Return all paths (not just shortest)
40    #[serde(default)]
41    pub all_paths: bool,
42
43    /// Maximum paths to return
44    #[serde(default = "default_max_paths")]
45    pub max_paths: usize,
46}
47
/// Default for `TraversalQuery::max_depth`, used both by serde and by `Default`.
fn default_depth() -> u32 {
    const DEFAULT_MAX_DEPTH: u32 = 10;
    DEFAULT_MAX_DEPTH
}
51
/// Default for `TraversalQuery::max_paths`, used both by serde and by `Default`.
fn default_max_paths() -> usize {
    const DEFAULT_MAX_PATHS: usize = 5;
    DEFAULT_MAX_PATHS
}
55
56impl Default for TraversalQuery {
57    fn default() -> Self {
58        Self {
59            start: String::new(),
60            target: None,
61            max_depth: default_depth(),
62            direction: Direction::Both,
63            entity_type_filter: Vec::new(),
64            relation_type_filter: Vec::new(),
65            use_weights: false,
66            all_paths: false,
67            max_paths: default_max_paths(),
68        }
69    }
70}
71
72impl TraversalQuery {
73    /// Create a new traversal query starting from an entity
74    pub fn new(start: impl Into<String>) -> Self {
75        Self {
76            start: start.into(),
77            ..Default::default()
78        }
79    }
80
81    /// Set target for path finding
82    pub fn find_path_to(mut self, target: impl Into<String>) -> Self {
83        self.target = Some(target.into());
84        self
85    }
86
87    /// Set maximum traversal depth
88    pub fn with_depth(mut self, depth: u32) -> Self {
89        self.max_depth = depth;
90        self
91    }
92
93    /// Set traversal direction
94    pub fn with_direction(mut self, direction: Direction) -> Self {
95        self.direction = direction;
96        self
97    }
98
99    /// Filter by entity types during traversal
100    pub fn filter_entity_types(mut self, types: Vec<String>) -> Self {
101        self.entity_type_filter = types;
102        self
103    }
104
105    /// Filter by relation types during traversal
106    pub fn filter_relation_types(mut self, types: Vec<String>) -> Self {
107        self.relation_type_filter = types;
108        self
109    }
110
111    /// Use weighted shortest path (Dijkstra algorithm)
112    pub fn weighted(mut self) -> Self {
113        self.use_weights = true;
114        self
115    }
116
117    /// Return all paths up to max
118    pub fn all_paths(mut self, max: usize) -> Self {
119        self.all_paths = true;
120        self.max_paths = max;
121        self
122    }
123}
124
125/// A single path through the graph
126#[derive(Debug, Clone, Serialize, Deserialize)]
127pub struct GraphPath {
128    /// Ordered list of entity names in the path
129    pub nodes: Vec<String>,
130
131    /// Relations connecting the nodes
132    pub edges: Vec<PathEdge>,
133
134    /// Total path weight (sum of relation weights, 1.0 for unweighted)
135    pub total_weight: f64,
136
137    /// Path length (number of edges)
138    pub length: usize,
139}
140
141/// Edge in a path
142#[derive(Debug, Clone, Serialize, Deserialize)]
143pub struct PathEdge {
144    pub from: String,
145    pub to: String,
146    pub relation_type: String,
147    pub weight: Option<f64>,
148}
149
150/// Result of a traversal operation
151#[derive(Debug, Clone, Serialize, Deserialize)]
152pub struct TraversalResult {
153    /// Starting entity
154    pub start: String,
155
156    /// Target entity (if path finding)
157    #[serde(skip_serializing_if = "Option::is_none")]
158    pub target: Option<String>,
159
160    /// Found paths (for path finding)
161    pub paths: Vec<GraphPath>,
162
163    /// Visited entities (for general traversal)
164    pub visited_entities: Vec<String>,
165
166    /// All entities in result
167    pub entities: Vec<Entity>,
168
169    /// All relations in result
170    pub relations: Vec<Relation>,
171
172    /// Statistics
173    pub stats: TraversalStats,
174}
175
176/// Traversal statistics
177#[derive(Debug, Clone, Default, Serialize, Deserialize)]
178pub struct TraversalStats {
179    pub nodes_visited: usize,
180    pub edges_traversed: usize,
181    pub max_depth_reached: u32,
182    pub path_found: bool,
183}
184
/// Entry in the Dijkstra priority queue: a node plus the path cost it was
/// queued with.
///
/// The ordering is reversed so that `BinaryHeap` (a max-heap) pops the
/// cheapest entry first. Entries with equal cost are ordered by node name so
/// that pop order is deterministic.
#[derive(Clone)]
struct DijkstraState {
    cost: f64,
    node: String,
}

// Equality is defined through `cmp` so that `==` and `Ord` can never
// disagree, and so that an entry with a NaN cost is still equal to itself
// (a derived `PartialEq` on `f64` would break the `Eq` contract there).
impl PartialEq for DijkstraState {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for DijkstraState {}

impl Ord for DijkstraState {
    fn cmp(&self, other: &Self) -> Ordering {
        // Operands are swapped on purpose: lower cost must compare as
        // "greater" to be popped first. `total_cmp` gives every f64,
        // including NaN, a fixed position in the order.
        other
            .cost
            .total_cmp(&self.cost)
            .then_with(|| other.node.cmp(&self.node))
    }
}

impl PartialOrd for DijkstraState {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
209
/// Graph traversal engine.
///
/// A stateless unit type: the traversal algorithms are exposed as associated
/// functions, so no instance is needed to call them.
#[derive(Debug, Clone, Copy, Default)]
pub struct TraversalEngine;
212
213impl TraversalEngine {
214    /// Execute a traversal query
215    pub fn execute(
216        query: &TraversalQuery,
217        entities: &HashMap<String, Entity>,
218        relations: &[Relation],
219    ) -> TraversalResult {
220        tracing::debug!(
221            "Executing traversal: start={}, target={:?}, depth={}, direction={:?}",
222            query.start,
223            query.target,
224            query.max_depth,
225            query.direction
226        );
227
228        if query.target.is_some() {
229            if query.use_weights {
230                Self::dijkstra_path(query, entities, relations)
231            } else {
232                Self::bfs_path(query, entities, relations)
233            }
234        } else {
235            Self::filtered_bfs(query, entities, relations)
236        }
237    }
238
239    /// BFS for unweighted shortest path
240    fn bfs_path(
241        query: &TraversalQuery,
242        entities: &HashMap<String, Entity>,
243        relations: &[Relation],
244    ) -> TraversalResult {
245        let target = query.target.as_ref().unwrap();
246        let mut visited: HashSet<String> = HashSet::new();
247        let mut parent: HashMap<String, (String, PathEdge)> = HashMap::new();
248        let mut queue: VecDeque<(String, u32)> = VecDeque::new();
249        let mut stats = TraversalStats::default();
250
251        queue.push_back((query.start.clone(), 0));
252        visited.insert(query.start.clone());
253
254        while let Some((current, depth)) = queue.pop_front() {
255            stats.nodes_visited += 1;
256            stats.max_depth_reached = stats.max_depth_reached.max(depth);
257
258            if &current == target {
259                stats.path_found = true;
260                tracing::debug!("BFS found path at depth {}", depth);
261                break;
262            }
263
264            if depth >= query.max_depth {
265                continue;
266            }
267
268            for rel in Self::get_neighbors(&current, &query.direction, relations) {
269                stats.edges_traversed += 1;
270
271                // Apply relation type filter
272                if !query.relation_type_filter.is_empty()
273                    && !query.relation_type_filter.contains(&rel.relation_type)
274                {
275                    continue;
276                }
277
278                let next = if rel.from_name == current {
279                    &rel.to_name
280                } else {
281                    &rel.from_name
282                };
283
284                // Apply entity type filter
285                if let Some(entity) = entities.get(next) {
286                    if !query.entity_type_filter.is_empty()
287                        && !query.entity_type_filter.contains(&entity.entity_type.0)
288                    {
289                        continue;
290                    }
291                }
292
293                if !visited.contains(next) {
294                    visited.insert(next.clone());
295                    parent.insert(
296                        next.clone(),
297                        (
298                            current.clone(),
299                            PathEdge {
300                                from: rel.from_name.clone(),
301                                to: rel.to_name.clone(),
302                                relation_type: rel.relation_type.clone(),
303                                weight: rel.weight,
304                            },
305                        ),
306                    );
307                    queue.push_back((next.clone(), depth + 1));
308                }
309            }
310        }
311
312        // Reconstruct path
313        let paths = if stats.path_found {
314            vec![Self::reconstruct_path(&query.start, target, &parent)]
315        } else {
316            vec![]
317        };
318
319        Self::build_result(query, paths, &visited, entities, relations, stats)
320    }
321
322    /// Dijkstra's algorithm for weighted shortest path
323    fn dijkstra_path(
324        query: &TraversalQuery,
325        entities: &HashMap<String, Entity>,
326        relations: &[Relation],
327    ) -> TraversalResult {
328        let target = query.target.as_ref().unwrap();
329        let mut dist: HashMap<String, f64> = HashMap::new();
330        let mut parent: HashMap<String, (String, PathEdge)> = HashMap::new();
331        let mut heap = BinaryHeap::new();
332        let mut stats = TraversalStats::default();
333
334        dist.insert(query.start.clone(), 0.0);
335        heap.push(DijkstraState {
336            cost: 0.0,
337            node: query.start.clone(),
338        });
339
340        while let Some(DijkstraState { cost, node }) = heap.pop() {
341            stats.nodes_visited += 1;
342
343            if &node == target {
344                stats.path_found = true;
345                tracing::debug!("Dijkstra found path with cost {}", cost);
346                break;
347            }
348
349            // Skip if we already found a better path
350            if cost > *dist.get(&node).unwrap_or(&f64::INFINITY) {
351                continue;
352            }
353
354            for rel in Self::get_neighbors(&node, &query.direction, relations) {
355                stats.edges_traversed += 1;
356
357                // Apply relation type filter
358                if !query.relation_type_filter.is_empty()
359                    && !query.relation_type_filter.contains(&rel.relation_type)
360                {
361                    continue;
362                }
363
364                let next = if rel.from_name == node {
365                    &rel.to_name
366                } else {
367                    &rel.from_name
368                };
369
370                // Apply entity type filter
371                if let Some(entity) = entities.get(next) {
372                    if !query.entity_type_filter.is_empty()
373                        && !query.entity_type_filter.contains(&entity.entity_type.0)
374                    {
375                        continue;
376                    }
377                }
378
379                let edge_weight = rel.weight.unwrap_or(1.0);
380                let new_cost = cost + edge_weight;
381
382                if new_cost < *dist.get(next).unwrap_or(&f64::INFINITY) {
383                    dist.insert(next.clone(), new_cost);
384                    parent.insert(
385                        next.clone(),
386                        (
387                            node.clone(),
388                            PathEdge {
389                                from: rel.from_name.clone(),
390                                to: rel.to_name.clone(),
391                                relation_type: rel.relation_type.clone(),
392                                weight: rel.weight,
393                            },
394                        ),
395                    );
396                    heap.push(DijkstraState {
397                        cost: new_cost,
398                        node: next.clone(),
399                    });
400                }
401            }
402        }
403
404        let paths = if stats.path_found {
405            vec![Self::reconstruct_path(&query.start, target, &parent)]
406        } else {
407            vec![]
408        };
409
410        let visited: HashSet<String> = dist.keys().cloned().collect();
411        Self::build_result(query, paths, &visited, entities, relations, stats)
412    }
413
414    /// Filtered BFS traversal (no target)
415    fn filtered_bfs(
416        query: &TraversalQuery,
417        entities: &HashMap<String, Entity>,
418        relations: &[Relation],
419    ) -> TraversalResult {
420        let mut visited: HashSet<String> = HashSet::new();
421        let mut queue: VecDeque<(String, u32)> = VecDeque::new();
422        let mut stats = TraversalStats::default();
423
424        queue.push_back((query.start.clone(), 0));
425        visited.insert(query.start.clone());
426
427        while let Some((current, depth)) = queue.pop_front() {
428            stats.nodes_visited += 1;
429            stats.max_depth_reached = stats.max_depth_reached.max(depth);
430
431            if depth >= query.max_depth {
432                continue;
433            }
434
435            for rel in Self::get_neighbors(&current, &query.direction, relations) {
436                stats.edges_traversed += 1;
437
438                // Apply relation type filter
439                if !query.relation_type_filter.is_empty()
440                    && !query.relation_type_filter.contains(&rel.relation_type)
441                {
442                    continue;
443                }
444
445                let next = if rel.from_name == current {
446                    &rel.to_name
447                } else {
448                    &rel.from_name
449                };
450
451                // Apply entity type filter
452                if let Some(entity) = entities.get(next) {
453                    if !query.entity_type_filter.is_empty()
454                        && !query.entity_type_filter.contains(&entity.entity_type.0)
455                    {
456                        continue;
457                    }
458                }
459
460                if !visited.contains(next) {
461                    visited.insert(next.clone());
462                    queue.push_back((next.clone(), depth + 1));
463                }
464            }
465        }
466
467        tracing::debug!(
468            "Filtered BFS visited {} nodes, traversed {} edges",
469            stats.nodes_visited,
470            stats.edges_traversed
471        );
472
473        Self::build_result(query, vec![], &visited, entities, relations, stats)
474    }
475
476    /// Get neighboring relations for a node based on direction
477    fn get_neighbors<'a>(
478        node: &str,
479        direction: &Direction,
480        relations: &'a [Relation],
481    ) -> Vec<&'a Relation> {
482        relations
483            .iter()
484            .filter(|rel| match direction {
485                Direction::Outgoing => rel.from_name == node,
486                Direction::Incoming => rel.to_name == node,
487                Direction::Both => rel.from_name == node || rel.to_name == node,
488            })
489            .collect()
490    }
491
492    /// Reconstruct path from parent map
493    fn reconstruct_path(
494        start: &str,
495        end: &str,
496        parent: &HashMap<String, (String, PathEdge)>,
497    ) -> GraphPath {
498        let mut nodes = vec![end.to_string()];
499        let mut edges = Vec::new();
500        let mut current = end.to_string();
501        let mut total_weight = 0.0;
502
503        while &current != start {
504            if let Some((prev, edge)) = parent.get(&current) {
505                total_weight += edge.weight.unwrap_or(1.0);
506                edges.push(edge.clone());
507                nodes.push(prev.clone());
508                current = prev.clone();
509            } else {
510                break;
511            }
512        }
513
514        nodes.reverse();
515        edges.reverse();
516
517        GraphPath {
518            length: edges.len(),
519            nodes,
520            edges,
521            total_weight,
522        }
523    }
524
525    /// Build result from traversal data
526    fn build_result(
527        query: &TraversalQuery,
528        paths: Vec<GraphPath>,
529        visited: &HashSet<String>,
530        entities: &HashMap<String, Entity>,
531        relations: &[Relation],
532        stats: TraversalStats,
533    ) -> TraversalResult {
534        let visited_entities: Vec<String> = visited.iter().cloned().collect();
535
536        let result_entities: Vec<Entity> = visited_entities
537            .iter()
538            .filter_map(|name| entities.get(name).cloned())
539            .collect();
540
541        let result_relations: Vec<Relation> = relations
542            .iter()
543            .filter(|r| visited.contains(&r.from_name) && visited.contains(&r.to_name))
544            .cloned()
545            .collect();
546
547        TraversalResult {
548            start: query.start.clone(),
549            target: query.target.clone(),
550            paths,
551            visited_entities,
552            entities: result_entities,
553            relations: result_relations,
554            stats,
555        }
556    }
557}
558
#[cfg(test)]
mod tests {
    use super::*;
    use crate::project::ProjectId;

    /// True when the traversal recorded `name` among its visited entities.
    fn saw(result: &TraversalResult, name: &str) -> bool {
        result.visited_entities.iter().any(|visited| visited == name)
    }

    /// Six-node fixture shared by most tests:
    ///
    /// ```text
    /// A --1.0--> B --2.0--> C --1.0--> D
    ///            |          |
    ///           1.0        1.0
    ///            v          v
    ///            E --3.0--> F
    /// ```
    fn create_test_graph() -> (HashMap<String, Entity>, Vec<Relation>) {
        let project_id = ProjectId::new();

        let entities: HashMap<String, Entity> = ["A", "B", "C", "D", "E", "F"]
            .iter()
            .map(|&name| {
                (
                    name.to_string(),
                    Entity::new(project_id.clone(), name, "node"),
                )
            })
            .collect();

        let edge_table = [
            ("A", "B", 1.0),
            ("B", "C", 2.0),
            ("C", "D", 1.0),
            ("B", "E", 1.0),
            ("C", "F", 1.0),
            ("E", "F", 3.0),
        ];
        let relations: Vec<Relation> = edge_table
            .iter()
            .map(|&(from, to, weight)| {
                Relation::from_names(project_id.clone(), from, to, "connects").with_weight(weight)
            })
            .collect();

        (entities, relations)
    }

    #[test]
    fn test_bfs_shortest_path() {
        let (entities, relations) = create_test_graph();

        let query = TraversalQuery::new("A").find_path_to("D");
        let result = TraversalEngine::execute(&query, &entities, &relations);

        assert!(result.stats.path_found);
        assert_eq!(result.paths.len(), 1);

        let path = &result.paths[0];
        assert_eq!(path.nodes, vec!["A", "B", "C", "D"]);
        assert_eq!(path.length, 3);
    }

    #[test]
    fn test_dijkstra_weighted_path() {
        let (entities, relations) = create_test_graph();

        // Two routes from A to F:
        //   via C: A->B->C->F = 1 + 2 + 1 = 4
        //   via E: A->B->E->F = 1 + 1 + 3 = 5
        // The weighted search must pick the route via C.
        let query = TraversalQuery::new("A").find_path_to("F").weighted();
        let result = TraversalEngine::execute(&query, &entities, &relations);

        assert!(result.stats.path_found);

        let path = &result.paths[0];
        assert_eq!(path.nodes, vec!["A", "B", "C", "F"]);
        assert!((path.total_weight - 4.0).abs() < 0.001);
    }

    #[test]
    fn test_filtered_traversal() {
        let (entities, relations) = create_test_graph();

        let query = TraversalQuery::new("A").with_depth(2);
        let result = TraversalEngine::execute(&query, &entities, &relations);

        // Within two hops of A: A (0), B (1), C (2), E (2).
        for name in ["A", "B", "C", "E"].iter() {
            assert!(saw(&result, name));
        }
    }

    #[test]
    fn test_no_path_found() {
        let project_id = ProjectId::new();

        // Two entities and no relations: the graph is disconnected.
        let mut entities = HashMap::new();
        for name in ["A", "B"].iter() {
            entities.insert(
                name.to_string(),
                Entity::new(project_id.clone(), *name, "node"),
            );
        }

        let query = TraversalQuery::new("A").find_path_to("B");
        let result = TraversalEngine::execute(&query, &entities, &[]);

        assert!(!result.stats.path_found);
        assert!(result.paths.is_empty());
    }

    #[test]
    fn test_direction_filtering() {
        let (entities, relations) = create_test_graph();

        // Following only outgoing relations from B reaches C and E, not A.
        let outgoing = TraversalQuery::new("B")
            .with_direction(Direction::Outgoing)
            .with_depth(1);
        let result = TraversalEngine::execute(&outgoing, &entities, &relations);
        assert!(saw(&result, "C"));
        assert!(saw(&result, "E"));
        assert!(!saw(&result, "A"));

        // Following only incoming relations to B reaches A, not C.
        let incoming = TraversalQuery::new("B")
            .with_direction(Direction::Incoming)
            .with_depth(1);
        let result = TraversalEngine::execute(&incoming, &entities, &relations);
        assert!(saw(&result, "A"));
        assert!(!saw(&result, "C"));
    }

    #[test]
    fn test_relation_type_filter() {
        let project_id = ProjectId::new();

        let mut entities = HashMap::new();
        for name in ["A", "B", "C"].iter() {
            entities.insert(
                name.to_string(),
                Entity::new(project_id.clone(), *name, "node"),
            );
        }

        let relations = vec![
            Relation::from_names(project_id.clone(), "A", "B", "works_at"),
            Relation::from_names(project_id.clone(), "B", "C", "knows"),
        ];

        // Only "works_at" relations may be followed.
        let query = TraversalQuery::new("A")
            .with_depth(2)
            .filter_relation_types(vec!["works_at".to_string()]);
        let result = TraversalEngine::execute(&query, &entities, &relations);

        // B is one "works_at" hop away; C is only reachable through "knows".
        assert!(saw(&result, "B"));
        assert!(!saw(&result, "C"));
    }
}