Skip to main content

sparrowdb_execution/
parallel_bfs.rs

1//! Parallel BFS traversal primitives using Rayon.
2//!
3//! Two primitives:
4//! - `parallel_reachability_bfs`: existential queries (RETURN DISTINCT, shortestPath)
5//! - `parallel_path_enumeration_dfs`: openCypher enumerative *M..N (all simple paths)
6
7use std::collections::HashSet;
8use std::sync::atomic::{AtomicBool, Ordering};
9use std::sync::{Arc, Mutex};
10
11use rayon::prelude::*;
12
/// Result of a reachability BFS: the set of node IDs reached.
///
/// Produced by [`parallel_reachability_bfs`]; the seed nodes are included.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ReachabilityResult {
    /// Every node ID seen during the traversal, seed nodes included.
    pub visited: HashSet<u64>,
}
17
18/// Parallel BFS for existential/reachability queries.
19///
20/// Uses a global visited set (mutex-protected) shared across Rayon tasks.
21/// Correct for: RETURN DISTINCT queries, shortestPath(), existential checks.
22/// NOT correct for: enumerative path listing (use `parallel_path_enumeration_dfs` instead).
23///
24/// # Arguments
25/// - `start_nodes`: seed nodes for the BFS frontier
26/// - `min_hops`: minimum hop depth (inclusive) — nodes reachable before this depth are
27///   still tracked in `visited` but the semantics of which hops "count" is caller's concern
28/// - `max_hops`: maximum hop depth (inclusive); BFS stops after this many expansions
29/// - `get_neighbors`: closure returning outgoing neighbor IDs for a given node ID
30pub fn parallel_reachability_bfs<F>(
31    start_nodes: Vec<u64>,
32    _min_hops: usize,
33    max_hops: usize,
34    get_neighbors: F,
35) -> ReachabilityResult
36where
37    F: Fn(u64) -> Vec<u64> + Send + Sync,
38{
39    let visited = Arc::new(Mutex::new(
40        start_nodes.iter().copied().collect::<HashSet<_>>(),
41    ));
42
43    let mut frontier = start_nodes;
44    let mut hop = 0usize;
45
46    while !frontier.is_empty() && hop < max_hops {
47        // Parallel expand frontier — each node expands independently
48        let next_nodes: Vec<u64> = frontier
49            .par_iter()
50            .flat_map(|&node| get_neighbors(node))
51            .collect();
52
53        // Deduplicate against visited (serial — single mutex lock per frontier wave)
54        let mut v = visited.lock().unwrap();
55        frontier = next_nodes.into_iter().filter(|n| v.insert(*n)).collect();
56        hop += 1;
57    }
58
59    let v = visited
60        .lock()
61        .expect("visited mutex should not be poisoned")
62        .clone();
63    ReachabilityResult { visited: v }
64}
65
/// Mutable DFS state for one path under construction.
///
/// `path` keeps the nodes in visit order. `path_set` mirrors the same node IDs,
/// so a cycle check is a hash lookup instead of a linear scan over `path`.
#[derive(Clone)]
struct PathState {
    /// Nodes on the current path, in the order they were walked.
    path: Vec<u64>,
    /// The node IDs from `path`, kept for O(1) "already on this path?" checks.
    path_set: HashSet<u64>,
}
72
/// Traversal parameters and shared output handles passed to every recursive
/// `dfs_enumerate` call.
///
/// Bundling them keeps `dfs_enumerate` below clippy's `too_many_arguments` limit.
struct DfsContext<'a, F> {
    /// Neighbor lookup: outgoing node IDs for a given node ID.
    get_neighbors: &'a F,
    /// Paths with fewer edges than this are not emitted.
    min_hops: usize,
    /// The DFS does not extend a path beyond this many edges.
    max_hops: usize,
    /// Cap on the total number of paths collected across all tasks.
    limit: usize,
    /// Output buffer shared by all tasks; emitted paths are pushed here.
    results: &'a Arc<Mutex<Vec<Vec<u64>>>>,
    /// Set once `limit` is reached so every task can stop early.
    done: &'a Arc<AtomicBool>,
}
83
84/// Parallel path enumeration DFS.
85///
86/// Each Rayon task has its own path state (no shared visited set), preserving
87/// openCypher simple-path semantics: a node may appear on multiple distinct paths
88/// but not more than once within a single path.
89///
90/// Correct for: `MATCH (a)-[*M..N]->(b)` enumerative semantics where the diamond
91/// graph A→B→D, A→C→D should yield D twice (two distinct simple paths).
92///
93/// WARNING: can produce exponential results on dense graphs. Caller should pass a
94/// reasonable `limit` and enforce `LIMIT` in the Cypher query.
95///
96/// # Arguments
97/// - `start_nodes`: seed nodes for path exploration
98/// - `min_hops`: minimum path length (paths shorter than this are not emitted)
99/// - `max_hops`: maximum path length (DFS does not recurse deeper)
100/// - `limit`: early-termination cap on total results collected; `0` returns immediately
101/// - `get_neighbors`: closure returning outgoing neighbor IDs for a given node ID
102pub fn parallel_path_enumeration_dfs<F>(
103    start_nodes: Vec<u64>,
104    min_hops: usize,
105    max_hops: usize,
106    limit: usize,
107    get_neighbors: F,
108) -> Vec<Vec<u64>>
109where
110    F: Fn(u64) -> Vec<u64> + Send + Sync,
111{
112    if limit == 0 {
113        return Vec::new();
114    }
115
116    let results = Arc::new(Mutex::new(Vec::<Vec<u64>>::new()));
117    let done = Arc::new(AtomicBool::new(false));
118
119    start_nodes.par_iter().for_each(|&start| {
120        if done.load(Ordering::Relaxed) {
121            return;
122        }
123        let mut initial_path_set = HashSet::new();
124        initial_path_set.insert(start);
125        let initial = PathState {
126            path: vec![start],
127            path_set: initial_path_set,
128        };
129        let ctx = DfsContext {
130            min_hops,
131            max_hops,
132            limit,
133            get_neighbors: &get_neighbors,
134            results: &results,
135            done: &done,
136        };
137        dfs_enumerate(initial, 0, &ctx);
138    });
139
140    Arc::try_unwrap(results)
141        .expect("results Arc should be uniquely owned after parallel traversal")
142        .into_inner()
143        .expect("results Mutex should not be poisoned")
144}
145
146fn dfs_enumerate<F>(state: PathState, depth: usize, ctx: &DfsContext<'_, F>)
147where
148    F: Fn(u64) -> Vec<u64> + Send + Sync,
149{
150    if ctx.done.load(Ordering::Relaxed) {
151        return;
152    }
153
154    if depth >= ctx.min_hops {
155        let mut r = ctx
156            .results
157            .lock()
158            .expect("results Mutex should not be poisoned");
159        // Re-check limit under the lock: concurrent tasks may have filled the
160        // buffer between the `done` pre-check above and acquiring the lock here.
161        if r.len() >= ctx.limit {
162            ctx.done.store(true, Ordering::Relaxed);
163            return;
164        }
165        r.push(state.path.clone());
166        if r.len() >= ctx.limit {
167            ctx.done.store(true, Ordering::Relaxed);
168            return;
169        }
170    }
171
172    if depth >= ctx.max_hops {
173        return;
174    }
175
176    let current = *state.path.last().unwrap();
177    for neighbor in (ctx.get_neighbors)(current) {
178        if !state.path_set.contains(&neighbor) {
179            let mut next_state = state.clone();
180            next_state.path.push(neighbor);
181            next_state.path_set.insert(neighbor);
182            dfs_enumerate(next_state, depth + 1, ctx);
183        }
184    }
185}