// sparrowdb_execution/parallel_bfs.rs
use std::collections::HashSet;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};

use rayon::prelude::*;
/// Result of a reachability traversal: every node reachable from the
/// start set within the hop budget (the start nodes themselves included).
pub struct ReachabilityResult {
    /// IDs of all visited (reachable) nodes.
    pub visited: HashSet<u64>,
}
17
18pub fn parallel_reachability_bfs<F>(
31 start_nodes: Vec<u64>,
32 _min_hops: usize,
33 max_hops: usize,
34 get_neighbors: F,
35) -> ReachabilityResult
36where
37 F: Fn(u64) -> Vec<u64> + Send + Sync,
38{
39 let visited = Arc::new(Mutex::new(
40 start_nodes.iter().copied().collect::<HashSet<_>>(),
41 ));
42
43 let mut frontier = start_nodes;
44 let mut hop = 0usize;
45
46 while !frontier.is_empty() && hop < max_hops {
47 let next_nodes: Vec<u64> = frontier
49 .par_iter()
50 .flat_map(|&node| get_neighbors(node))
51 .collect();
52
53 let mut v = visited.lock().unwrap();
55 frontier = next_nodes.into_iter().filter(|n| v.insert(*n)).collect();
56 hop += 1;
57 }
58
59 let v = visited
60 .lock()
61 .expect("visited mutex should not be poisoned")
62 .clone();
63 ReachabilityResult { visited: v }
64}
65
/// Per-branch DFS traversal state: the node path in visit order, plus a
/// set mirror of the same nodes for O(1) cycle (revisit) checks.
#[derive(Clone)]
struct PathState {
    // Nodes in visit order; the last element is the current node.
    path: Vec<u64>,
    // Same nodes as `path`, kept in sync to reject revisits quickly.
    path_set: HashSet<u64>,
}
72
/// Shared, read-only configuration for the parallel path-enumeration DFS.
struct DfsContext<'a, F> {
    /// Minimum number of hops a path must have before it is recorded.
    min_hops: usize,
    /// Maximum number of hops a path may be extended to.
    max_hops: usize,
    /// Stop the whole enumeration once this many paths are collected.
    limit: usize,
    /// Neighbor lookup shared by all worker threads.
    get_neighbors: &'a F,
    /// Collected paths. NOTE(review): `&Arc<Mutex<..>>` is double
    /// indirection — a plain `&'a Mutex<..>` would suffice here.
    results: &'a Arc<Mutex<Vec<Vec<u64>>>>,
    /// Cooperative cancellation flag, set once `limit` is reached.
    done: &'a Arc<AtomicBool>,
}
83
84pub fn parallel_path_enumeration_dfs<F>(
103 start_nodes: Vec<u64>,
104 min_hops: usize,
105 max_hops: usize,
106 limit: usize,
107 get_neighbors: F,
108) -> Vec<Vec<u64>>
109where
110 F: Fn(u64) -> Vec<u64> + Send + Sync,
111{
112 if limit == 0 {
113 return Vec::new();
114 }
115
116 let results = Arc::new(Mutex::new(Vec::<Vec<u64>>::new()));
117 let done = Arc::new(AtomicBool::new(false));
118
119 start_nodes.par_iter().for_each(|&start| {
120 if done.load(Ordering::Relaxed) {
121 return;
122 }
123 let mut initial_path_set = HashSet::new();
124 initial_path_set.insert(start);
125 let initial = PathState {
126 path: vec![start],
127 path_set: initial_path_set,
128 };
129 let ctx = DfsContext {
130 min_hops,
131 max_hops,
132 limit,
133 get_neighbors: &get_neighbors,
134 results: &results,
135 done: &done,
136 };
137 dfs_enumerate(initial, 0, &ctx);
138 });
139
140 Arc::try_unwrap(results)
141 .expect("results Arc should be uniquely owned after parallel traversal")
142 .into_inner()
143 .expect("results Mutex should not be poisoned")
144}
145
146fn dfs_enumerate<F>(state: PathState, depth: usize, ctx: &DfsContext<'_, F>)
147where
148 F: Fn(u64) -> Vec<u64> + Send + Sync,
149{
150 if ctx.done.load(Ordering::Relaxed) {
151 return;
152 }
153
154 if depth >= ctx.min_hops {
155 let mut r = ctx
156 .results
157 .lock()
158 .expect("results Mutex should not be poisoned");
159 if r.len() >= ctx.limit {
162 ctx.done.store(true, Ordering::Relaxed);
163 return;
164 }
165 r.push(state.path.clone());
166 if r.len() >= ctx.limit {
167 ctx.done.store(true, Ordering::Relaxed);
168 return;
169 }
170 }
171
172 if depth >= ctx.max_hops {
173 return;
174 }
175
176 let current = *state.path.last().unwrap();
177 for neighbor in (ctx.get_neighbors)(current) {
178 if !state.path_set.contains(&neighbor) {
179 let mut next_state = state.clone();
180 next_state.path.push(neighbor);
181 next_state.path_set.insert(neighbor);
182 dfs_enumerate(next_state, depth + 1, ctx);
183 }
184 }
185}