Skip to main content

graphos_adapters/plugins/algorithms/
traversal.rs

1//! Graph traversal algorithms: BFS and DFS.
2//!
3//! These algorithms use the visitor pattern to allow flexible customization
4//! of traversal behavior, including early termination and edge filtering.
5
6use std::collections::VecDeque;
7use std::sync::OnceLock;
8
9use graphos_common::types::{NodeId, Value};
10use graphos_common::utils::error::Result;
11use graphos_common::utils::hash::{FxHashMap, FxHashSet};
12use graphos_core::graph::Direction;
13use graphos_core::graph::lpg::LpgStore;
14
15use super::super::{AlgorithmResult, ParameterDef, ParameterType, Parameters};
16use super::traits::{Control, GraphAlgorithm, NodeValueResultBuilder, TraversalEvent};
17
18// ============================================================================
19// BFS Implementation
20// ============================================================================
21
22/// Performs breadth-first search from a starting node.
23///
24/// Returns the set of visited nodes in BFS order.
25///
26/// # Arguments
27///
28/// * `store` - The graph store to traverse
29/// * `start` - The starting node ID
30///
31/// # Returns
32///
33/// A vector of node IDs in the order they were discovered.
34pub fn bfs(store: &LpgStore, start: NodeId) -> Vec<NodeId> {
35    let mut visited = Vec::new();
36    bfs_with_visitor(store, start, |event| -> Control<()> {
37        if let TraversalEvent::Discover(node) = event {
38            visited.push(node);
39        }
40        Control::Continue
41    });
42    visited
43}
44
45/// Performs breadth-first search with a visitor callback.
46///
47/// The visitor is called for each traversal event, allowing custom
48/// behavior such as early termination or path recording.
49///
50/// # Arguments
51///
52/// * `store` - The graph store to traverse
53/// * `start` - The starting node ID
54/// * `visitor` - Callback function receiving traversal events
55///
56/// # Returns
57///
58/// `Some(B)` if the visitor returned `Control::Break(B)`, otherwise `None`.
59pub fn bfs_with_visitor<B, F>(store: &LpgStore, start: NodeId, mut visitor: F) -> Option<B>
60where
61    F: FnMut(TraversalEvent) -> Control<B>,
62{
63    let mut discovered: FxHashSet<NodeId> = FxHashSet::default();
64    let mut queue: VecDeque<NodeId> = VecDeque::new();
65
66    // Check if start node exists
67    if store.get_node(start).is_none() {
68        return None;
69    }
70
71    // Discover the start node
72    discovered.insert(start);
73    queue.push_back(start);
74
75    match visitor(TraversalEvent::Discover(start)) {
76        Control::Break(b) => return Some(b),
77        Control::Prune => {
78            // Prune means don't explore neighbors, but we still finish the node
79            match visitor(TraversalEvent::Finish(start)) {
80                Control::Break(b) => return Some(b),
81                _ => return None,
82            }
83        }
84        Control::Continue => {}
85    }
86
87    while let Some(node) = queue.pop_front() {
88        // Iterate over outgoing edges
89        for (neighbor, edge_id) in store.edges_from(node, Direction::Outgoing) {
90            if discovered.insert(neighbor) {
91                // Tree edge - neighbor not yet discovered
92                match visitor(TraversalEvent::TreeEdge {
93                    source: node,
94                    target: neighbor,
95                    edge: edge_id,
96                }) {
97                    Control::Break(b) => return Some(b),
98                    Control::Prune => continue, // Don't add to queue
99                    Control::Continue => {}
100                }
101
102                match visitor(TraversalEvent::Discover(neighbor)) {
103                    Control::Break(b) => return Some(b),
104                    Control::Prune => continue, // Don't explore neighbors
105                    Control::Continue => {}
106                }
107
108                queue.push_back(neighbor);
109            } else {
110                // Non-tree edge - neighbor already discovered
111                match visitor(TraversalEvent::NonTreeEdge {
112                    source: node,
113                    target: neighbor,
114                    edge: edge_id,
115                }) {
116                    Control::Break(b) => return Some(b),
117                    _ => {}
118                }
119            }
120        }
121
122        // Node processing complete
123        match visitor(TraversalEvent::Finish(node)) {
124            Control::Break(b) => return Some(b),
125            _ => {}
126        }
127    }
128
129    None
130}
131
132/// BFS layers - returns nodes grouped by their distance from the start.
133///
134/// # Arguments
135///
136/// * `store` - The graph store to traverse
137/// * `start` - The starting node ID
138///
139/// # Returns
140///
141/// A vector of vectors, where `result[i]` contains all nodes at distance `i` from start.
142pub fn bfs_layers(store: &LpgStore, start: NodeId) -> Vec<Vec<NodeId>> {
143    let mut layers: Vec<Vec<NodeId>> = Vec::new();
144    let mut discovered: FxHashSet<NodeId> = FxHashSet::default();
145    let mut current_layer: Vec<NodeId> = Vec::new();
146    let mut next_layer: Vec<NodeId> = Vec::new();
147
148    if store.get_node(start).is_none() {
149        return layers;
150    }
151
152    discovered.insert(start);
153    current_layer.push(start);
154
155    while !current_layer.is_empty() {
156        layers.push(current_layer.clone());
157
158        for &node in &current_layer {
159            for (neighbor, _) in store.edges_from(node, Direction::Outgoing) {
160                if discovered.insert(neighbor) {
161                    next_layer.push(neighbor);
162                }
163            }
164        }
165
166        current_layer.clear();
167        std::mem::swap(&mut current_layer, &mut next_layer);
168    }
169
170    layers
171}
172
173// ============================================================================
174// DFS Implementation
175// ============================================================================
176
/// Node state during DFS traversal (classic white/gray/black coloring).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum NodeColor {
    /// Not yet discovered
    White,
    /// Discovered but not finished (on stack)
    Gray,
    /// Finished processing
    Black,
}
187
188/// Performs depth-first search from a starting node.
189///
190/// Returns nodes in the order they were finished (post-order).
191///
192/// # Arguments
193///
194/// * `store` - The graph store to traverse
195/// * `start` - The starting node ID
196///
197/// # Returns
198///
199/// A vector of node IDs in post-order (finished order).
200pub fn dfs(store: &LpgStore, start: NodeId) -> Vec<NodeId> {
201    let mut finished = Vec::new();
202    dfs_with_visitor(store, start, |event| -> Control<()> {
203        if let TraversalEvent::Finish(node) = event {
204            finished.push(node);
205        }
206        Control::Continue
207    });
208    finished
209}
210
211/// Performs depth-first search with a visitor callback.
212///
213/// Uses an explicit stack to avoid stack overflow on deep graphs.
214///
215/// # Arguments
216///
217/// * `store` - The graph store to traverse
218/// * `start` - The starting node ID
219/// * `visitor` - Callback function receiving traversal events
220///
221/// # Returns
222///
223/// `Some(B)` if the visitor returned `Control::Break(B)`, otherwise `None`.
224pub fn dfs_with_visitor<B, F>(store: &LpgStore, start: NodeId, mut visitor: F) -> Option<B>
225where
226    F: FnMut(TraversalEvent) -> Control<B>,
227{
228    let mut color: FxHashMap<NodeId, NodeColor> = FxHashMap::default();
229
230    // Stack entries: (node, edge_iterator_index, is_first_visit)
231    // We use indices to track progress through neighbors
232    let mut stack: Vec<(NodeId, Vec<(NodeId, graphos_common::types::EdgeId)>, usize)> = Vec::new();
233
234    // Check if start node exists
235    if store.get_node(start).is_none() {
236        return None;
237    }
238
239    // Discover start node
240    color.insert(start, NodeColor::Gray);
241    match visitor(TraversalEvent::Discover(start)) {
242        Control::Break(b) => return Some(b),
243        Control::Prune => {
244            color.insert(start, NodeColor::Black);
245            match visitor(TraversalEvent::Finish(start)) {
246                Control::Break(b) => return Some(b),
247                _ => return None,
248            }
249        }
250        Control::Continue => {}
251    }
252
253    let neighbors: Vec<_> = store.edges_from(start, Direction::Outgoing).collect();
254    stack.push((start, neighbors, 0));
255
256    while let Some((node, neighbors, idx)) = stack.last_mut() {
257        if *idx >= neighbors.len() {
258            // All neighbors processed, finish this node
259            let node = *node;
260            stack.pop();
261            color.insert(node, NodeColor::Black);
262            match visitor(TraversalEvent::Finish(node)) {
263                Control::Break(b) => return Some(b),
264                _ => {}
265            }
266            continue;
267        }
268
269        let (neighbor, edge_id) = neighbors[*idx];
270        *idx += 1;
271
272        match color.get(&neighbor).copied().unwrap_or(NodeColor::White) {
273            NodeColor::White => {
274                // Tree edge - undiscovered node
275                match visitor(TraversalEvent::TreeEdge {
276                    source: *node,
277                    target: neighbor,
278                    edge: edge_id,
279                }) {
280                    Control::Break(b) => return Some(b),
281                    Control::Prune => continue,
282                    Control::Continue => {}
283                }
284
285                color.insert(neighbor, NodeColor::Gray);
286                match visitor(TraversalEvent::Discover(neighbor)) {
287                    Control::Break(b) => return Some(b),
288                    Control::Prune => {
289                        color.insert(neighbor, NodeColor::Black);
290                        match visitor(TraversalEvent::Finish(neighbor)) {
291                            Control::Break(b) => return Some(b),
292                            _ => {}
293                        }
294                        continue;
295                    }
296                    Control::Continue => {}
297                }
298
299                let neighbor_neighbors: Vec<_> =
300                    store.edges_from(neighbor, Direction::Outgoing).collect();
301                stack.push((neighbor, neighbor_neighbors, 0));
302            }
303            NodeColor::Gray => {
304                // Back edge - node is on the stack (ancestor)
305                match visitor(TraversalEvent::BackEdge {
306                    source: *node,
307                    target: neighbor,
308                    edge: edge_id,
309                }) {
310                    Control::Break(b) => return Some(b),
311                    _ => {}
312                }
313            }
314            NodeColor::Black => {
315                // Non-tree edge (cross/forward) - already finished
316                match visitor(TraversalEvent::NonTreeEdge {
317                    source: *node,
318                    target: neighbor,
319                    edge: edge_id,
320                }) {
321                    Control::Break(b) => return Some(b),
322                    _ => {}
323                }
324            }
325        }
326    }
327
328    None
329}
330
331/// Performs DFS on all nodes, visiting each connected component.
332///
333/// Returns nodes in reverse post-order (useful for topological sort).
334pub fn dfs_all(store: &LpgStore) -> Vec<NodeId> {
335    let mut finished = Vec::new();
336    let mut visited: FxHashSet<NodeId> = FxHashSet::default();
337
338    for node_id in store.node_ids() {
339        if visited.contains(&node_id) {
340            continue;
341        }
342
343        dfs_with_visitor(store, node_id, |event| -> Control<()> {
344            match event {
345                TraversalEvent::Discover(n) => {
346                    visited.insert(n);
347                }
348                TraversalEvent::Finish(n) => {
349                    finished.push(n);
350                }
351                _ => {}
352            }
353            Control::Continue
354        });
355    }
356
357    finished
358}
359
360// ============================================================================
361// Algorithm Wrappers for Plugin Registry
362// ============================================================================
363
364/// Static parameter definitions for BFS algorithm.
365static BFS_PARAMS: OnceLock<Vec<ParameterDef>> = OnceLock::new();
366
367fn bfs_params() -> &'static [ParameterDef] {
368    BFS_PARAMS.get_or_init(|| {
369        vec![ParameterDef {
370            name: "start".to_string(),
371            description: "Starting node ID".to_string(),
372            param_type: ParameterType::NodeId,
373            required: true,
374            default: None,
375        }]
376    })
377}
378
379/// BFS algorithm wrapper for the plugin registry.
380pub struct BfsAlgorithm;
381
382impl GraphAlgorithm for BfsAlgorithm {
383    fn name(&self) -> &str {
384        "bfs"
385    }
386
387    fn description(&self) -> &str {
388        "Breadth-first search traversal from a starting node"
389    }
390
391    fn parameters(&self) -> &[ParameterDef] {
392        bfs_params()
393    }
394
395    fn execute(&self, store: &LpgStore, params: &Parameters) -> Result<AlgorithmResult> {
396        let start_id = params.get_int("start").ok_or_else(|| {
397            graphos_common::utils::error::Error::InvalidValue(
398                "start parameter required".to_string(),
399            )
400        })?;
401
402        let start = NodeId::new(start_id as u64);
403        let layers = bfs_layers(store, start);
404
405        let mut result = AlgorithmResult::new(vec!["node_id".to_string(), "distance".to_string()]);
406
407        for (distance, layer) in layers.iter().enumerate() {
408            for &node in layer {
409                result.add_row(vec![
410                    Value::Int64(node.0 as i64),
411                    Value::Int64(distance as i64),
412                ]);
413            }
414        }
415
416        Ok(result)
417    }
418}
419
420/// Static parameter definitions for DFS algorithm.
421static DFS_PARAMS: OnceLock<Vec<ParameterDef>> = OnceLock::new();
422
423fn dfs_params() -> &'static [ParameterDef] {
424    DFS_PARAMS.get_or_init(|| {
425        vec![ParameterDef {
426            name: "start".to_string(),
427            description: "Starting node ID".to_string(),
428            param_type: ParameterType::NodeId,
429            required: true,
430            default: None,
431        }]
432    })
433}
434
435/// DFS algorithm wrapper for the plugin registry.
436pub struct DfsAlgorithm;
437
438impl GraphAlgorithm for DfsAlgorithm {
439    fn name(&self) -> &str {
440        "dfs"
441    }
442
443    fn description(&self) -> &str {
444        "Depth-first search traversal from a starting node"
445    }
446
447    fn parameters(&self) -> &[ParameterDef] {
448        dfs_params()
449    }
450
451    fn execute(&self, store: &LpgStore, params: &Parameters) -> Result<AlgorithmResult> {
452        let start_id = params.get_int("start").ok_or_else(|| {
453            graphos_common::utils::error::Error::InvalidValue(
454                "start parameter required".to_string(),
455            )
456        })?;
457
458        let start = NodeId::new(start_id as u64);
459        let finished = dfs(store, start);
460
461        let mut builder = NodeValueResultBuilder::with_capacity("finish_order", finished.len());
462        for (order, node) in finished.iter().enumerate() {
463            builder.push(*node, Value::Int64(order as i64));
464        }
465
466        Ok(builder.build())
467    }
468}
469
#[cfg(test)]
mod tests {
    use super::*;

    fn create_test_graph() -> LpgStore {
        let store = LpgStore::new();

        // Create a simple graph:
        //   0 -> 1 -> 2
        //   |    |
        //   v    v
        //   3 -> 4
        let n0 = store.create_node(&["Node"]);
        let n1 = store.create_node(&["Node"]);
        let n2 = store.create_node(&["Node"]);
        let n3 = store.create_node(&["Node"]);
        let n4 = store.create_node(&["Node"]);

        store.create_edge(n0, n1, "EDGE");
        store.create_edge(n0, n3, "EDGE");
        store.create_edge(n1, n2, "EDGE");
        store.create_edge(n1, n4, "EDGE");
        store.create_edge(n3, n4, "EDGE");

        store
    }

    /// Index of node `id` within `order`; panics if the node is missing.
    fn position(order: &[NodeId], id: u64) -> usize {
        order
            .iter()
            .position(|&n| n == NodeId::new(id))
            .unwrap_or_else(|| panic!("node {id} missing from traversal of {} nodes", order.len()))
    }

    #[test]
    fn test_bfs_simple() {
        let store = create_test_graph();
        let visited = bfs(&store, NodeId::new(0));

        // Node 0 comes first and every node is reachable from it.
        assert_eq!(visited.len(), 5);
        assert_eq!(visited[0], NodeId::new(0));
        // All distance-1 nodes are discovered before any distance-2 node.
        for near in [1, 3] {
            for far in [2, 4] {
                assert!(position(&visited, near) < position(&visited, far));
            }
        }
    }

    #[test]
    fn test_bfs_layers() {
        let store = create_test_graph();
        let layers = bfs_layers(&store, NodeId::new(0));

        assert_eq!(layers.len(), 3);
        // Distance 0: just the start node
        assert_eq!(layers[0], vec![NodeId::new(0)]);
        // Distance 1: {1, 3}
        assert_eq!(layers[1].len(), 2);
        assert!(layers[1].contains(&NodeId::new(1)));
        assert!(layers[1].contains(&NodeId::new(3)));
        // Distance 2: {2, 4}
        assert_eq!(layers[2].len(), 2);
        assert!(layers[2].contains(&NodeId::new(2)));
        assert!(layers[2].contains(&NodeId::new(4)));
    }

    #[test]
    fn test_dfs_simple() {
        let store = create_test_graph();
        let finished = dfs(&store, NodeId::new(0));

        assert_eq!(finished.len(), 5);
        // Post-order: the root finishes last...
        assert_eq!(finished.last(), Some(&NodeId::new(0)));
        // ...and in a DAG every edge target finishes before its source.
        for (src, dst) in [(0, 1), (0, 3), (1, 2), (1, 4), (3, 4)] {
            assert!(position(&finished, dst) < position(&finished, src));
        }
    }

    #[test]
    fn test_bfs_nonexistent_start() {
        let store = LpgStore::new();
        let visited = bfs(&store, NodeId::new(999));
        assert!(visited.is_empty());
    }

    #[test]
    fn test_dfs_nonexistent_start() {
        let store = LpgStore::new();
        let finished = dfs(&store, NodeId::new(999));
        assert!(finished.is_empty());
    }

    #[test]
    fn test_bfs_early_termination() {
        let store = create_test_graph();
        let target = NodeId::new(2);

        let found = bfs_with_visitor(&store, NodeId::new(0), |event| {
            if let TraversalEvent::Discover(node) = event {
                if node == target {
                    return Control::Break(true);
                }
            }
            Control::Continue
        });

        assert_eq!(found, Some(true));
    }

    #[test]
    fn test_dfs_early_termination() {
        let store = create_test_graph();
        let target = NodeId::new(4);

        let found = dfs_with_visitor(&store, NodeId::new(0), |event| {
            if let TraversalEvent::Discover(node) = event {
                if node == target {
                    return Control::Break(node);
                }
            }
            Control::Continue
        });

        assert_eq!(found, Some(target));
    }

    #[test]
    fn test_bfs_prune_start() {
        let store = create_test_graph();
        let mut discovered = Vec::new();
        let mut finished = Vec::new();

        let result: Option<()> = bfs_with_visitor(&store, NodeId::new(0), |event| match event {
            TraversalEvent::Discover(n) => {
                discovered.push(n);
                Control::Prune
            }
            TraversalEvent::Finish(n) => {
                finished.push(n);
                Control::Continue
            }
            _ => Control::Continue,
        });

        // Pruning the start node stops expansion but still finishes it.
        assert_eq!(result, None);
        assert_eq!(discovered, vec![NodeId::new(0)]);
        assert_eq!(finished, vec![NodeId::new(0)]);
    }
}