Skip to main content

manifoldb_graph/traversal/
shortest_path.rs

//! Shortest path finding algorithms.
//!
//! This module provides breadth-first-search (BFS) based algorithms for finding
//! the shortest path(s) between two nodes in a graph.
5
// Allow `expect`: each use relies on an invariant guaranteed by the data structure.
7#![allow(clippy::expect_used)]
8
9use std::collections::{HashMap, HashSet, VecDeque};
10
11use manifoldb_core::{EdgeId, EdgeType, EntityId};
12use manifoldb_storage::Transaction;
13
14use super::{Direction, TraversalFilter};
15use crate::index::AdjacencyIndex;
16use crate::store::{EdgeStore, GraphResult};
17
18/// A path through the graph.
19///
20/// Represents a sequence of nodes and edges from a source to a target.
21#[derive(Debug, Clone, PartialEq, Eq)]
22pub struct PathResult {
23    /// The nodes in the path, from source to target.
24    pub nodes: Vec<EntityId>,
25    /// The edges connecting the nodes.
26    /// Length is `nodes.len() - 1`.
27    pub edges: Vec<EdgeId>,
28    /// The total length of the path (number of edges).
29    pub length: usize,
30}
31
32impl PathResult {
33    /// Create a new path result.
34    fn new(nodes: Vec<EntityId>, edges: Vec<EdgeId>) -> Self {
35        let length = edges.len();
36        Self { nodes, edges, length }
37    }
38
39    /// Create a path for a single node (source == target).
40    fn single_node(node: EntityId) -> Self {
41        Self { nodes: vec![node], edges: Vec::new(), length: 0 }
42    }
43
44    /// Get the source node.
45    pub fn source(&self) -> EntityId {
46        self.nodes[0]
47    }
48
49    /// Get the target node.
50    pub fn target(&self) -> EntityId {
51        *self.nodes.last().expect("path has at least one node")
52    }
53
54    /// Check if the path is empty (source == target).
55    pub const fn is_empty(&self) -> bool {
56        self.length == 0
57    }
58}
59
60/// BFS-based shortest path finder.
61///
62/// Finds the shortest unweighted path between two nodes using
63/// breadth-first search.
64///
65/// # Features
66///
67/// - Unweighted shortest path (all edges have weight 1)
68/// - Configurable traversal direction
69/// - Optional edge type filtering
70/// - Maximum depth limit for bounded searches
71///
72/// # Example
73///
74/// ```ignore
75/// // Find shortest path between two users
76/// let path = ShortestPath::find(&tx, user_a, user_b, Direction::Both)?;
77///
78/// if let Some(result) = path {
79///     println!("Path length: {}", result.length);
80///     println!("Path: {:?}", result.nodes);
81/// }
82///
83/// // Find path following only FRIEND edges
84/// let path = ShortestPath::new(user_a, user_b, Direction::Both)
85///     .with_edge_type("FRIEND")
86///     .find(&tx)?;
87/// ```
88pub struct ShortestPath {
89    /// Source node.
90    source: EntityId,
91    /// Target node.
92    target: EntityId,
93    /// Traversal direction.
94    direction: Direction,
95    /// Maximum path length to search.
96    max_depth: Option<usize>,
97    /// Filter for traversal.
98    filter: TraversalFilter,
99}
100
101impl ShortestPath {
102    /// Create a new shortest path finder.
103    ///
104    /// # Arguments
105    ///
106    /// * `source` - The starting node
107    /// * `target` - The destination node
108    /// * `direction` - Which direction to traverse edges
109    pub fn new(source: EntityId, target: EntityId, direction: Direction) -> Self {
110        Self { source, target, direction, max_depth: None, filter: TraversalFilter::new() }
111    }
112
113    /// Set the maximum path length to search.
114    ///
115    /// If no path of this length or shorter is found, returns None.
116    pub const fn with_max_depth(mut self, max_depth: usize) -> Self {
117        self.max_depth = Some(max_depth);
118        self
119    }
120
121    /// Filter to only traverse edges of the specified type.
122    pub fn with_edge_type(mut self, edge_type: impl Into<EdgeType>) -> Self {
123        self.filter = self.filter.with_edge_type(edge_type);
124        self
125    }
126
127    /// Filter to only traverse edges of the specified types.
128    pub fn with_edge_types(mut self, edge_types: impl IntoIterator<Item = EdgeType>) -> Self {
129        self.filter = self.filter.with_edge_types(edge_types);
130        self
131    }
132
133    /// Exclude specific nodes from the path.
134    pub fn exclude_nodes(mut self, nodes: impl IntoIterator<Item = EntityId>) -> Self {
135        self.filter = self.filter.exclude_nodes(nodes);
136        self
137    }
138
139    /// Find the shortest path.
140    ///
141    /// # Returns
142    ///
143    /// - `Some(PathResult)` if a path exists
144    /// - `None` if no path exists within the constraints
145    pub fn find<T: Transaction>(self, tx: &T) -> GraphResult<Option<PathResult>> {
146        // Handle same source and target
147        if self.source == self.target {
148            return Ok(Some(PathResult::single_node(self.source)));
149        }
150
151        // BFS from source
152        let mut visited: HashSet<EntityId> = HashSet::new();
153        // Maps each node to (previous_node, edge_used)
154        let mut parent: HashMap<EntityId, (EntityId, EdgeId)> = HashMap::new();
155        let mut queue: VecDeque<(EntityId, usize)> = VecDeque::new();
156
157        visited.insert(self.source);
158        queue.push_back((self.source, 0));
159
160        while let Some((current, depth)) = queue.pop_front() {
161            // Check depth limit
162            if let Some(max) = self.max_depth {
163                if depth >= max {
164                    continue;
165                }
166            }
167
168            // Get neighbors
169            let neighbors = self.get_neighbors(tx, current)?;
170
171            for (neighbor, edge_id) in neighbors {
172                if visited.contains(&neighbor) {
173                    continue;
174                }
175
176                // Check node filter
177                if neighbor != self.target && !self.filter.should_include_node(neighbor) {
178                    continue;
179                }
180
181                visited.insert(neighbor);
182                parent.insert(neighbor, (current, edge_id));
183
184                // Found target
185                if neighbor == self.target {
186                    return Ok(Some(self.reconstruct_path(&parent)));
187                }
188
189                queue.push_back((neighbor, depth + 1));
190            }
191        }
192
193        Ok(None)
194    }
195
196    /// Get neighbors considering direction and edge type filters.
197    fn get_neighbors<T: Transaction>(
198        &self,
199        tx: &T,
200        node: EntityId,
201    ) -> GraphResult<Vec<(EntityId, EdgeId)>> {
202        let mut neighbors = Vec::new();
203
204        // Get outgoing neighbors
205        if self.direction.includes_outgoing() {
206            self.add_neighbors_outgoing(tx, node, &mut neighbors)?;
207        }
208
209        // Get incoming neighbors
210        if self.direction.includes_incoming() {
211            self.add_neighbors_incoming(tx, node, &mut neighbors)?;
212        }
213
214        Ok(neighbors)
215    }
216
217    fn add_neighbors_outgoing<T: Transaction>(
218        &self,
219        tx: &T,
220        node: EntityId,
221        neighbors: &mut Vec<(EntityId, EdgeId)>,
222    ) -> GraphResult<()> {
223        match &self.filter.edge_types {
224            Some(types) => {
225                for edge_type in types {
226                    AdjacencyIndex::for_each_outgoing_by_type(tx, node, edge_type, |edge_id| {
227                        if let Some(edge) = EdgeStore::get(tx, edge_id)? {
228                            neighbors.push((edge.target, edge_id));
229                        }
230                        Ok(true)
231                    })?;
232                }
233            }
234            None => {
235                AdjacencyIndex::for_each_outgoing(tx, node, |edge_id| {
236                    if let Some(edge) = EdgeStore::get(tx, edge_id)? {
237                        neighbors.push((edge.target, edge_id));
238                    }
239                    Ok(true)
240                })?;
241            }
242        }
243        Ok(())
244    }
245
246    fn add_neighbors_incoming<T: Transaction>(
247        &self,
248        tx: &T,
249        node: EntityId,
250        neighbors: &mut Vec<(EntityId, EdgeId)>,
251    ) -> GraphResult<()> {
252        match &self.filter.edge_types {
253            Some(types) => {
254                for edge_type in types {
255                    AdjacencyIndex::for_each_incoming_by_type(tx, node, edge_type, |edge_id| {
256                        if let Some(edge) = EdgeStore::get(tx, edge_id)? {
257                            neighbors.push((edge.source, edge_id));
258                        }
259                        Ok(true)
260                    })?;
261                }
262            }
263            None => {
264                AdjacencyIndex::for_each_incoming(tx, node, |edge_id| {
265                    if let Some(edge) = EdgeStore::get(tx, edge_id)? {
266                        neighbors.push((edge.source, edge_id));
267                    }
268                    Ok(true)
269                })?;
270            }
271        }
272        Ok(())
273    }
274
275    /// Reconstruct the path from source to target using the parent map.
276    fn reconstruct_path(&self, parent: &HashMap<EntityId, (EntityId, EdgeId)>) -> PathResult {
277        let mut nodes = Vec::new();
278        let mut edges = Vec::new();
279        let mut current = self.target;
280
281        // Trace back from target to source
282        while let Some(&(prev, edge_id)) = parent.get(&current) {
283            nodes.push(current);
284            edges.push(edge_id);
285            current = prev;
286        }
287
288        // Add source
289        nodes.push(self.source);
290
291        // Reverse to get source -> target order
292        nodes.reverse();
293        edges.reverse();
294
295        PathResult::new(nodes, edges)
296    }
297
298    /// Convenience method: find shortest path with default settings.
299    ///
300    /// # Arguments
301    ///
302    /// * `tx` - The transaction to use
303    /// * `source` - The starting node
304    /// * `target` - The destination node
305    /// * `direction` - Which direction to traverse
306    pub fn find_path<T: Transaction>(
307        tx: &T,
308        source: EntityId,
309        target: EntityId,
310        direction: Direction,
311    ) -> GraphResult<Option<PathResult>> {
312        Self::new(source, target, direction).find(tx)
313    }
314
315    /// Check if a path exists between two nodes.
316    ///
317    /// This is more efficient than `find()` when you only need to know
318    /// if a path exists, not what it is.
319    pub fn exists<T: Transaction>(self, tx: &T) -> GraphResult<bool> {
320        // Handle same source and target
321        if self.source == self.target {
322            return Ok(true);
323        }
324
325        let mut visited: HashSet<EntityId> = HashSet::new();
326        let mut queue: VecDeque<(EntityId, usize)> = VecDeque::new();
327
328        visited.insert(self.source);
329        queue.push_back((self.source, 0));
330
331        while let Some((current, depth)) = queue.pop_front() {
332            if let Some(max) = self.max_depth {
333                if depth >= max {
334                    continue;
335                }
336            }
337
338            let neighbors = self.get_neighbors(tx, current)?;
339
340            for (neighbor, _) in neighbors {
341                if neighbor == self.target {
342                    return Ok(true);
343                }
344
345                if visited.contains(&neighbor) {
346                    continue;
347                }
348
349                if !self.filter.should_include_node(neighbor) {
350                    continue;
351                }
352
353                visited.insert(neighbor);
354                queue.push_back((neighbor, depth + 1));
355            }
356        }
357
358        Ok(false)
359    }
360
361    /// Find the distance between two nodes (path length).
362    ///
363    /// This is more efficient than `find()` when you only need the distance.
364    pub fn distance<T: Transaction>(self, tx: &T) -> GraphResult<Option<usize>> {
365        if self.source == self.target {
366            return Ok(Some(0));
367        }
368
369        let mut visited: HashSet<EntityId> = HashSet::new();
370        let mut queue: VecDeque<(EntityId, usize)> = VecDeque::new();
371
372        visited.insert(self.source);
373        queue.push_back((self.source, 0));
374
375        while let Some((current, depth)) = queue.pop_front() {
376            if let Some(max) = self.max_depth {
377                if depth >= max {
378                    continue;
379                }
380            }
381
382            let neighbors = self.get_neighbors(tx, current)?;
383
384            for (neighbor, _) in neighbors {
385                if neighbor == self.target {
386                    return Ok(Some(depth + 1));
387                }
388
389                if visited.contains(&neighbor) {
390                    continue;
391                }
392
393                if !self.filter.should_include_node(neighbor) {
394                    continue;
395                }
396
397                visited.insert(neighbor);
398                queue.push_back((neighbor, depth + 1));
399            }
400        }
401
402        Ok(None)
403    }
404}
405
406/// Find all shortest paths between two nodes.
407///
408/// When multiple paths of the same shortest length exist, this
409/// function returns all of them.
410pub struct AllShortestPaths {
411    /// Source node.
412    source: EntityId,
413    /// Target node.
414    target: EntityId,
415    /// Traversal direction.
416    direction: Direction,
417    /// Maximum path length to search.
418    max_depth: Option<usize>,
419    /// Filter for traversal.
420    filter: TraversalFilter,
421}
422
423impl AllShortestPaths {
424    /// Create a new finder for all shortest paths.
425    pub fn new(source: EntityId, target: EntityId, direction: Direction) -> Self {
426        Self { source, target, direction, max_depth: None, filter: TraversalFilter::new() }
427    }
428
429    /// Set the maximum path length to search.
430    pub const fn with_max_depth(mut self, max_depth: usize) -> Self {
431        self.max_depth = Some(max_depth);
432        self
433    }
434
435    /// Filter to only traverse edges of the specified type.
436    pub fn with_edge_type(mut self, edge_type: impl Into<EdgeType>) -> Self {
437        self.filter = self.filter.with_edge_type(edge_type);
438        self
439    }
440
441    /// Find all shortest paths.
442    ///
443    /// # Returns
444    ///
445    /// A vector of all paths with the shortest length.
446    /// Empty if no path exists.
447    pub fn find<T: Transaction>(self, tx: &T) -> GraphResult<Vec<PathResult>> {
448        if self.source == self.target {
449            return Ok(vec![PathResult::single_node(self.source)]);
450        }
451
452        // BFS with tracking all parents at shortest distance
453        let mut visited_at_depth: HashMap<EntityId, usize> = HashMap::new();
454        // Maps node -> list of (parent, edge)
455        let mut parents: HashMap<EntityId, Vec<(EntityId, EdgeId)>> = HashMap::new();
456        let mut queue: VecDeque<(EntityId, usize)> = VecDeque::new();
457        let mut target_depth: Option<usize> = None;
458
459        visited_at_depth.insert(self.source, 0);
460        queue.push_back((self.source, 0));
461
462        while let Some((current, depth)) = queue.pop_front() {
463            // If we've found target and current depth exceeds target depth, stop
464            if let Some(td) = target_depth {
465                if depth >= td {
466                    continue;
467                }
468            }
469
470            // Check max depth
471            if let Some(max) = self.max_depth {
472                if depth >= max {
473                    continue;
474                }
475            }
476
477            let neighbors = self.get_neighbors(tx, current)?;
478
479            for (neighbor, edge_id) in neighbors {
480                let next_depth = depth + 1;
481
482                // Check if we've seen this node before
483                if let Some(&prev_depth) = visited_at_depth.get(&neighbor) {
484                    // Only add parent if at same depth (for multiple shortest paths)
485                    if prev_depth == next_depth {
486                        parents.entry(neighbor).or_default().push((current, edge_id));
487                    }
488                    continue;
489                }
490
491                // Check node filter
492                if neighbor != self.target && !self.filter.should_include_node(neighbor) {
493                    continue;
494                }
495
496                visited_at_depth.insert(neighbor, next_depth);
497                parents.entry(neighbor).or_default().push((current, edge_id));
498
499                if neighbor == self.target {
500                    target_depth = Some(next_depth);
501                } else {
502                    queue.push_back((neighbor, next_depth));
503                }
504            }
505        }
506
507        // Reconstruct all paths
508        if target_depth.is_some() {
509            Ok(self.reconstruct_all_paths(&parents))
510        } else {
511            Ok(Vec::new())
512        }
513    }
514
515    fn get_neighbors<T: Transaction>(
516        &self,
517        tx: &T,
518        node: EntityId,
519    ) -> GraphResult<Vec<(EntityId, EdgeId)>> {
520        let mut neighbors = Vec::new();
521
522        if self.direction.includes_outgoing() {
523            match &self.filter.edge_types {
524                Some(types) => {
525                    for edge_type in types {
526                        AdjacencyIndex::for_each_outgoing_by_type(
527                            tx,
528                            node,
529                            edge_type,
530                            |edge_id| {
531                                if let Some(edge) = EdgeStore::get(tx, edge_id)? {
532                                    neighbors.push((edge.target, edge_id));
533                                }
534                                Ok(true)
535                            },
536                        )?;
537                    }
538                }
539                None => {
540                    AdjacencyIndex::for_each_outgoing(tx, node, |edge_id| {
541                        if let Some(edge) = EdgeStore::get(tx, edge_id)? {
542                            neighbors.push((edge.target, edge_id));
543                        }
544                        Ok(true)
545                    })?;
546                }
547            }
548        }
549
550        if self.direction.includes_incoming() {
551            match &self.filter.edge_types {
552                Some(types) => {
553                    for edge_type in types {
554                        AdjacencyIndex::for_each_incoming_by_type(
555                            tx,
556                            node,
557                            edge_type,
558                            |edge_id| {
559                                if let Some(edge) = EdgeStore::get(tx, edge_id)? {
560                                    neighbors.push((edge.source, edge_id));
561                                }
562                                Ok(true)
563                            },
564                        )?;
565                    }
566                }
567                None => {
568                    AdjacencyIndex::for_each_incoming(tx, node, |edge_id| {
569                        if let Some(edge) = EdgeStore::get(tx, edge_id)? {
570                            neighbors.push((edge.source, edge_id));
571                        }
572                        Ok(true)
573                    })?;
574                }
575            }
576        }
577
578        Ok(neighbors)
579    }
580
581    fn reconstruct_all_paths(
582        &self,
583        parents: &HashMap<EntityId, Vec<(EntityId, EdgeId)>>,
584    ) -> Vec<PathResult> {
585        let mut paths = Vec::new();
586        let mut current_path_nodes = vec![self.target];
587        let mut current_path_edges = Vec::new();
588
589        self.backtrack_paths(
590            parents,
591            self.target,
592            &mut current_path_nodes,
593            &mut current_path_edges,
594            &mut paths,
595        );
596
597        paths
598    }
599
600    fn backtrack_paths(
601        &self,
602        parents: &HashMap<EntityId, Vec<(EntityId, EdgeId)>>,
603        current: EntityId,
604        path_nodes: &mut Vec<EntityId>,
605        path_edges: &mut Vec<EdgeId>,
606        results: &mut Vec<PathResult>,
607    ) {
608        if current == self.source {
609            // We've reached the source - build path by iterating in reverse
610            // without cloning and then reversing
611            let path_len = path_edges.len();
612            let mut nodes = Vec::with_capacity(path_nodes.len());
613            let mut edges = Vec::with_capacity(path_len);
614
615            // Iterate in reverse order to build forward path
616            for &node in path_nodes.iter().rev() {
617                nodes.push(node);
618            }
619            for &edge in path_edges.iter().rev() {
620                edges.push(edge);
621            }
622
623            results.push(PathResult::new(nodes, edges));
624            return;
625        }
626
627        if let Some(parent_list) = parents.get(&current) {
628            for &(parent, edge_id) in parent_list {
629                path_nodes.push(parent);
630                path_edges.push(edge_id);
631
632                self.backtrack_paths(parents, parent, path_nodes, path_edges, results);
633
634                path_nodes.pop();
635                path_edges.pop();
636            }
637        }
638    }
639}
640
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn path_result_single_node() {
        let only = EntityId::new(1);
        let path = PathResult::single_node(only);

        assert!(path.is_empty());
        assert_eq!(path.length, 0);
        assert_eq!(path.source(), only);
        assert_eq!(path.target(), only);
    }

    #[test]
    fn path_result_multi_node() {
        let (first, middle, last) = (EntityId::new(1), EntityId::new(2), EntityId::new(3));
        let path = PathResult::new(
            vec![first, middle, last],
            vec![EdgeId::new(10), EdgeId::new(20)],
        );

        assert!(!path.is_empty());
        assert_eq!(path.length, 2);
        assert_eq!(path.source(), first);
        assert_eq!(path.target(), last);
    }

    #[test]
    fn shortest_path_builder() {
        let finder = ShortestPath::new(EntityId::new(1), EntityId::new(10), Direction::Both)
            .with_edge_type("FRIEND")
            .with_max_depth(5);

        let ShortestPath { source, target, direction, max_depth, .. } = &finder;
        assert_eq!(*source, EntityId::new(1));
        assert_eq!(*target, EntityId::new(10));
        assert_eq!(*direction, Direction::Both);
        assert_eq!(*max_depth, Some(5));
    }

    #[test]
    fn all_shortest_paths_builder() {
        let finder =
            AllShortestPaths::new(EntityId::new(1), EntityId::new(10), Direction::Outgoing)
                .with_edge_type("FOLLOWS")
                .with_max_depth(3);

        let AllShortestPaths { source, target, max_depth, .. } = &finder;
        assert_eq!(*source, EntityId::new(1));
        assert_eq!(*target, EntityId::new(10));
        assert_eq!(*max_depth, Some(3));
    }
}