grafeo_adapters/plugins/algorithms/
traversal.rs1use std::collections::VecDeque;
7use std::sync::OnceLock;
8
9use grafeo_common::types::{NodeId, Value};
10use grafeo_common::utils::error::Result;
11use grafeo_common::utils::hash::{FxHashMap, FxHashSet};
12use grafeo_core::graph::Direction;
13use grafeo_core::graph::lpg::LpgStore;
14
15use super::super::{AlgorithmResult, ParameterDef, ParameterType, Parameters};
16use super::traits::{Control, GraphAlgorithm, NodeValueResultBuilder, TraversalEvent};
17
18pub fn bfs(store: &LpgStore, start: NodeId) -> Vec<NodeId> {
35 let mut visited = Vec::new();
36 bfs_with_visitor(store, start, |event| -> Control<()> {
37 if let TraversalEvent::Discover(node) = event {
38 visited.push(node);
39 }
40 Control::Continue
41 });
42 visited
43}
44
45pub fn bfs_with_visitor<B, F>(store: &LpgStore, start: NodeId, mut visitor: F) -> Option<B>
60where
61 F: FnMut(TraversalEvent) -> Control<B>,
62{
63 let mut discovered: FxHashSet<NodeId> = FxHashSet::default();
64 let mut queue: VecDeque<NodeId> = VecDeque::new();
65
66 if store.get_node(start).is_none() {
68 return None;
69 }
70
71 discovered.insert(start);
73 queue.push_back(start);
74
75 match visitor(TraversalEvent::Discover(start)) {
76 Control::Break(b) => return Some(b),
77 Control::Prune => {
78 match visitor(TraversalEvent::Finish(start)) {
80 Control::Break(b) => return Some(b),
81 _ => return None,
82 }
83 }
84 Control::Continue => {}
85 }
86
87 while let Some(node) = queue.pop_front() {
88 for (neighbor, edge_id) in store.edges_from(node, Direction::Outgoing) {
90 if discovered.insert(neighbor) {
91 match visitor(TraversalEvent::TreeEdge {
93 source: node,
94 target: neighbor,
95 edge: edge_id,
96 }) {
97 Control::Break(b) => return Some(b),
98 Control::Prune => continue, Control::Continue => {}
100 }
101
102 match visitor(TraversalEvent::Discover(neighbor)) {
103 Control::Break(b) => return Some(b),
104 Control::Prune => continue, Control::Continue => {}
106 }
107
108 queue.push_back(neighbor);
109 } else {
110 match visitor(TraversalEvent::NonTreeEdge {
112 source: node,
113 target: neighbor,
114 edge: edge_id,
115 }) {
116 Control::Break(b) => return Some(b),
117 _ => {}
118 }
119 }
120 }
121
122 match visitor(TraversalEvent::Finish(node)) {
124 Control::Break(b) => return Some(b),
125 _ => {}
126 }
127 }
128
129 None
130}
131
132pub fn bfs_layers(store: &LpgStore, start: NodeId) -> Vec<Vec<NodeId>> {
143 let mut layers: Vec<Vec<NodeId>> = Vec::new();
144 let mut discovered: FxHashSet<NodeId> = FxHashSet::default();
145 let mut current_layer: Vec<NodeId> = Vec::new();
146 let mut next_layer: Vec<NodeId> = Vec::new();
147
148 if store.get_node(start).is_none() {
149 return layers;
150 }
151
152 discovered.insert(start);
153 current_layer.push(start);
154
155 while !current_layer.is_empty() {
156 layers.push(current_layer.clone());
157
158 for &node in ¤t_layer {
159 for (neighbor, _) in store.edges_from(node, Direction::Outgoing) {
160 if discovered.insert(neighbor) {
161 next_layer.push(neighbor);
162 }
163 }
164 }
165
166 current_layer.clear();
167 std::mem::swap(&mut current_layer, &mut next_layer);
168 }
169
170 layers
171}
172
/// Per-node state used by `dfs_with_visitor` to classify edges.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum NodeColor {
    /// Not yet discovered. Nodes missing from the color map count as white.
    White,
    /// Discovered but not yet finished; an edge into a gray node is a back edge.
    Gray,
    /// Finished: all outgoing edges were examined, or the node was pruned.
    Black,
}
187
188pub fn dfs(store: &LpgStore, start: NodeId) -> Vec<NodeId> {
201 let mut finished = Vec::new();
202 dfs_with_visitor(store, start, |event| -> Control<()> {
203 if let TraversalEvent::Finish(node) = event {
204 finished.push(node);
205 }
206 Control::Continue
207 });
208 finished
209}
210
211pub fn dfs_with_visitor<B, F>(store: &LpgStore, start: NodeId, mut visitor: F) -> Option<B>
225where
226 F: FnMut(TraversalEvent) -> Control<B>,
227{
228 let mut color: FxHashMap<NodeId, NodeColor> = FxHashMap::default();
229
230 let mut stack: Vec<(NodeId, Vec<(NodeId, grafeo_common::types::EdgeId)>, usize)> = Vec::new();
233
234 if store.get_node(start).is_none() {
236 return None;
237 }
238
239 color.insert(start, NodeColor::Gray);
241 match visitor(TraversalEvent::Discover(start)) {
242 Control::Break(b) => return Some(b),
243 Control::Prune => {
244 color.insert(start, NodeColor::Black);
245 match visitor(TraversalEvent::Finish(start)) {
246 Control::Break(b) => return Some(b),
247 _ => return None,
248 }
249 }
250 Control::Continue => {}
251 }
252
253 let neighbors: Vec<_> = store.edges_from(start, Direction::Outgoing).collect();
254 stack.push((start, neighbors, 0));
255
256 while let Some((node, neighbors, idx)) = stack.last_mut() {
257 if *idx >= neighbors.len() {
258 let node = *node;
260 stack.pop();
261 color.insert(node, NodeColor::Black);
262 match visitor(TraversalEvent::Finish(node)) {
263 Control::Break(b) => return Some(b),
264 _ => {}
265 }
266 continue;
267 }
268
269 let (neighbor, edge_id) = neighbors[*idx];
270 *idx += 1;
271
272 match color.get(&neighbor).copied().unwrap_or(NodeColor::White) {
273 NodeColor::White => {
274 match visitor(TraversalEvent::TreeEdge {
276 source: *node,
277 target: neighbor,
278 edge: edge_id,
279 }) {
280 Control::Break(b) => return Some(b),
281 Control::Prune => continue,
282 Control::Continue => {}
283 }
284
285 color.insert(neighbor, NodeColor::Gray);
286 match visitor(TraversalEvent::Discover(neighbor)) {
287 Control::Break(b) => return Some(b),
288 Control::Prune => {
289 color.insert(neighbor, NodeColor::Black);
290 match visitor(TraversalEvent::Finish(neighbor)) {
291 Control::Break(b) => return Some(b),
292 _ => {}
293 }
294 continue;
295 }
296 Control::Continue => {}
297 }
298
299 let neighbor_neighbors: Vec<_> =
300 store.edges_from(neighbor, Direction::Outgoing).collect();
301 stack.push((neighbor, neighbor_neighbors, 0));
302 }
303 NodeColor::Gray => {
304 match visitor(TraversalEvent::BackEdge {
306 source: *node,
307 target: neighbor,
308 edge: edge_id,
309 }) {
310 Control::Break(b) => return Some(b),
311 _ => {}
312 }
313 }
314 NodeColor::Black => {
315 match visitor(TraversalEvent::NonTreeEdge {
317 source: *node,
318 target: neighbor,
319 edge: edge_id,
320 }) {
321 Control::Break(b) => return Some(b),
322 _ => {}
323 }
324 }
325 }
326 }
327
328 None
329}
330
331pub fn dfs_all(store: &LpgStore) -> Vec<NodeId> {
335 let mut finished = Vec::new();
336 let mut visited: FxHashSet<NodeId> = FxHashSet::default();
337
338 for node_id in store.node_ids() {
339 if visited.contains(&node_id) {
340 continue;
341 }
342
343 dfs_with_visitor(store, node_id, |event| -> Control<()> {
344 match event {
345 TraversalEvent::Discover(n) => {
346 visited.insert(n);
347 }
348 TraversalEvent::Finish(n) => {
349 finished.push(n);
350 }
351 _ => {}
352 }
353 Control::Continue
354 });
355 }
356
357 finished
358}
359
360static BFS_PARAMS: OnceLock<Vec<ParameterDef>> = OnceLock::new();
366
367fn bfs_params() -> &'static [ParameterDef] {
368 BFS_PARAMS.get_or_init(|| {
369 vec![ParameterDef {
370 name: "start".to_string(),
371 description: "Starting node ID".to_string(),
372 param_type: ParameterType::NodeId,
373 required: true,
374 default: None,
375 }]
376 })
377}
378
379pub struct BfsAlgorithm;
381
382impl GraphAlgorithm for BfsAlgorithm {
383 fn name(&self) -> &str {
384 "bfs"
385 }
386
387 fn description(&self) -> &str {
388 "Breadth-first search traversal from a starting node"
389 }
390
391 fn parameters(&self) -> &[ParameterDef] {
392 bfs_params()
393 }
394
395 fn execute(&self, store: &LpgStore, params: &Parameters) -> Result<AlgorithmResult> {
396 let start_id = params.get_int("start").ok_or_else(|| {
397 grafeo_common::utils::error::Error::InvalidValue("start parameter required".to_string())
398 })?;
399
400 let start = NodeId::new(start_id as u64);
401 let layers = bfs_layers(store, start);
402
403 let mut result = AlgorithmResult::new(vec!["node_id".to_string(), "distance".to_string()]);
404
405 for (distance, layer) in layers.iter().enumerate() {
406 for &node in layer {
407 result.add_row(vec![
408 Value::Int64(node.0 as i64),
409 Value::Int64(distance as i64),
410 ]);
411 }
412 }
413
414 Ok(result)
415 }
416}
417
418static DFS_PARAMS: OnceLock<Vec<ParameterDef>> = OnceLock::new();
420
421fn dfs_params() -> &'static [ParameterDef] {
422 DFS_PARAMS.get_or_init(|| {
423 vec![ParameterDef {
424 name: "start".to_string(),
425 description: "Starting node ID".to_string(),
426 param_type: ParameterType::NodeId,
427 required: true,
428 default: None,
429 }]
430 })
431}
432
433pub struct DfsAlgorithm;
435
436impl GraphAlgorithm for DfsAlgorithm {
437 fn name(&self) -> &str {
438 "dfs"
439 }
440
441 fn description(&self) -> &str {
442 "Depth-first search traversal from a starting node"
443 }
444
445 fn parameters(&self) -> &[ParameterDef] {
446 dfs_params()
447 }
448
449 fn execute(&self, store: &LpgStore, params: &Parameters) -> Result<AlgorithmResult> {
450 let start_id = params.get_int("start").ok_or_else(|| {
451 grafeo_common::utils::error::Error::InvalidValue("start parameter required".to_string())
452 })?;
453
454 let start = NodeId::new(start_id as u64);
455 let finished = dfs(store, start);
456
457 let mut builder = NodeValueResultBuilder::with_capacity("finish_order", finished.len());
458 for (order, node) in finished.iter().enumerate() {
459 builder.push(*node, Value::Int64(order as i64));
460 }
461
462 Ok(builder.build())
463 }
464}
465
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the graph 0→1, 0→3, 1→2, 1→4, 3→4.
    ///
    /// The tests below assume the store assigns ids in creation order starting
    /// at 0 (first node is `NodeId::new(0)`, and so on).
    fn create_test_graph() -> LpgStore {
        let store = LpgStore::new();

        let n0 = store.create_node(&["Node"]);
        let n1 = store.create_node(&["Node"]);
        let n2 = store.create_node(&["Node"]);
        let n3 = store.create_node(&["Node"]);
        let n4 = store.create_node(&["Node"]);

        store.create_edge(n0, n1, "EDGE");
        store.create_edge(n0, n3, "EDGE");
        store.create_edge(n1, n2, "EDGE");
        store.create_edge(n1, n4, "EDGE");
        store.create_edge(n3, n4, "EDGE");

        store
    }

    /// Sorts node ids so assertions do not depend on adjacency order.
    fn sorted(mut nodes: Vec<NodeId>) -> Vec<NodeId> {
        nodes.sort_unstable_by_key(|n| n.0);
        nodes
    }

    fn ids(raw: &[u64]) -> Vec<NodeId> {
        raw.iter().map(|&i| NodeId::new(i)).collect()
    }

    #[test]
    fn test_bfs_simple() {
        let store = create_test_graph();
        let visited = bfs(&store, NodeId::new(0));

        assert_eq!(visited.len(), 5);
        assert_eq!(visited[0], NodeId::new(0));
        assert_eq!(sorted(visited), ids(&[0, 1, 2, 3, 4]));
    }

    #[test]
    fn test_bfs_layers() {
        let store = create_test_graph();
        let layers = bfs_layers(&store, NodeId::new(0));

        assert_eq!(layers.len(), 3);
        assert_eq!(layers[0], vec![NodeId::new(0)]);
        assert_eq!(sorted(layers[1].clone()), ids(&[1, 3]));
        assert_eq!(sorted(layers[2].clone()), ids(&[2, 4]));
    }

    #[test]
    fn test_dfs_simple() {
        let store = create_test_graph();
        let finished = dfs(&store, NodeId::new(0));

        assert_eq!(finished.len(), 5);
        // The start node sits at the bottom of the stack, so it finishes last.
        assert_eq!(finished.last(), Some(&NodeId::new(0)));
        assert_eq!(sorted(finished), ids(&[0, 1, 2, 3, 4]));
    }

    #[test]
    fn test_bfs_nonexistent_start() {
        let store = LpgStore::new();
        let visited = bfs(&store, NodeId::new(999));
        assert!(visited.is_empty());
    }

    #[test]
    fn test_dfs_nonexistent_start() {
        let store = LpgStore::new();
        let finished = dfs(&store, NodeId::new(999));
        assert!(finished.is_empty());
    }

    #[test]
    fn test_bfs_early_termination() {
        let store = create_test_graph();
        let target = NodeId::new(2);

        let found = bfs_with_visitor(&store, NodeId::new(0), |event| {
            if let TraversalEvent::Discover(node) = event {
                if node == target {
                    return Control::Break(true);
                }
            }
            Control::Continue
        });

        assert_eq!(found, Some(true));
    }

    #[test]
    fn test_bfs_visitor_without_break_returns_none() {
        let store = create_test_graph();
        let missing = NodeId::new(999);

        let found = bfs_with_visitor(&store, NodeId::new(0), |event| {
            if let TraversalEvent::Discover(node) = event {
                if node == missing {
                    return Control::Break(true);
                }
            }
            Control::Continue
        });

        assert_eq!(found, None);
    }

    #[test]
    fn test_dfs_early_termination() {
        let store = create_test_graph();
        let target = NodeId::new(4);

        let found = dfs_with_visitor(&store, NodeId::new(0), |event| {
            if let TraversalEvent::Discover(node) = event {
                if node == target {
                    return Control::Break(true);
                }
            }
            Control::Continue
        });

        assert_eq!(found, Some(true));
    }
}