use std::collections::HashSet;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
use rayon::prelude::*;
/// Outcome of [`parallel_reachability_bfs`].
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ReachabilityResult {
    /// Every node reached during the search, including the start nodes.
    pub visited: HashSet<u64>,
}
/// Computes the set of nodes reachable from `start_nodes` in at most `max_hops` hops.
///
/// The graph is expanded one level at a time (breadth-first). Within a level,
/// neighbour lookups run in parallel via rayon. Deduplication against the visited
/// set happens on the calling thread between levels, so no locking is required.
///
/// * `start_nodes` – seed nodes. They are always part of the result, and duplicates are ignored.
/// * `_min_hops` – currently unused. Nodes closer than this are *not* excluded.
/// * `max_hops` – maximum number of expansion rounds. `0` returns only the seeds.
/// * `get_neighbors` – adjacency callback. It is invoked once per distinct node in
///   each frontier that gets expanded.
pub fn parallel_reachability_bfs<F>(
    start_nodes: Vec<u64>,
    _min_hops: usize,
    max_hops: usize,
    get_neighbors: F,
) -> ReachabilityResult
where
    F: Fn(u64) -> Vec<u64> + Send + Sync,
{
    // `visited` is only mutated on this thread between parallel phases, so a plain
    // `HashSet` suffices. No `Arc<Mutex<_>>` and no final clone are needed.
    let mut visited: HashSet<u64> = start_nodes.iter().copied().collect();
    // Expand each distinct seed once, even if `start_nodes` contains duplicates.
    let mut frontier: Vec<u64> = visited.iter().copied().collect();

    for _ in 0..max_hops {
        if frontier.is_empty() {
            break;
        }
        let candidates: Vec<u64> = frontier
            .par_iter()
            .flat_map(|&node| get_neighbors(node))
            // Discard already-visited nodes in parallel; `visited` is only read here.
            .filter(|n| !visited.contains(n))
            .collect();
        // `insert` returns false for repeats within this level, so the next
        // frontier contains each newly discovered node exactly once.
        frontier = candidates
            .into_iter()
            .filter(|&n| visited.insert(n))
            .collect();
    }

    ReachabilityResult { visited }
}
/// A simple (cycle-free) path under construction during DFS enumeration.
#[derive(Clone, Debug)]
struct PathState {
    /// Nodes of the path in visiting order; the first element is the start node.
    path: Vec<u64>,
    /// The same nodes as `path`, kept for O(1) "already on the path?" checks.
    path_set: HashSet<u64>,
}
/// Search parameters and shared output sinks for one path-enumeration run.
///
/// Only plain references are stored, so a single instance can be shared by every
/// worker. Callers holding `Arc`s can still pass `&arc`, because `&Arc<T>`
/// deref-coerces to `&T` when the struct is instantiated.
struct DfsContext<'a, F> {
    /// Minimum number of edges a path must have before it is recorded.
    min_hops: usize,
    /// Maximum number of edges a path may have; the search does not extend beyond it.
    max_hops: usize,
    /// Maximum total number of paths to collect across all workers.
    limit: usize,
    /// Adjacency callback returning the successors of a node.
    get_neighbors: &'a F,
    /// Collected paths, shared between workers.
    results: &'a Mutex<Vec<Vec<u64>>>,
    /// Set once `limit` has been reached so that every worker stops early.
    done: &'a AtomicBool,
}
/// Enumerates simple paths starting at each of `start_nodes`, in parallel.
///
/// A path is recorded when it has between `min_hops` and `max_hops` edges
/// (inclusive). A path never revisits a node. Collection stops once `limit` paths
/// have been gathered across all workers.
///
/// With more than one start node, the order of the returned paths depends on
/// thread scheduling. Duplicate start nodes yield duplicate paths.
///
/// * `start_nodes` – nodes at which paths begin.
/// * `min_hops` / `max_hops` – inclusive bounds on the number of edges in a recorded path.
/// * `limit` – maximum number of paths returned. `0` returns an empty result immediately.
/// * `get_neighbors` – adjacency callback.
///
/// # Panics
///
/// Panics if the results mutex was poisoned by a panic inside `get_neighbors`.
pub fn parallel_path_enumeration_dfs<F>(
    start_nodes: Vec<u64>,
    min_hops: usize,
    max_hops: usize,
    limit: usize,
    get_neighbors: F,
) -> Vec<Vec<u64>>
where
    F: Fn(u64) -> Vec<u64> + Send + Sync,
{
    // Nothing can ever be recorded in either case. The search never goes deeper than
    // `max_hops`, but only records paths with at least `min_hops` edges. Return early
    // instead of running a (potentially exponential) traversal that yields nothing.
    if limit == 0 || min_hops > max_hops {
        return Vec::new();
    }

    let results = Arc::new(Mutex::new(Vec::<Vec<u64>>::new()));
    let done = Arc::new(AtomicBool::new(false));

    // The context only holds shared references, so one instance serves every worker.
    let ctx = DfsContext {
        min_hops,
        max_hops,
        limit,
        get_neighbors: &get_neighbors,
        results: &results,
        done: &done,
    };

    start_nodes.par_iter().for_each(|&start| {
        // Skip remaining start nodes once another worker has hit the limit.
        if done.load(Ordering::Relaxed) {
            return;
        }
        let mut path_set = HashSet::new();
        path_set.insert(start);
        let initial = PathState {
            path: vec![start],
            path_set,
        };
        dfs_enumerate(initial, 0, &ctx);
    });

    // All borrows of `results` (via `ctx`) have ended and the Arc was never cloned.
    Arc::try_unwrap(results)
        .expect("results Arc should be uniquely owned after parallel traversal")
        .into_inner()
        .expect("results Mutex should not be poisoned")
}
/// Depth-first enumeration of simple paths extending `state`.
///
/// `depth` is the number of edges in `state.path` (that is, `path.len() - 1`).
/// `state.path` is recorded when `depth >= ctx.min_hops`. The path is then extended
/// through every neighbour not already on it, until `ctx.max_hops` is reached or
/// `ctx.limit` results have been collected.
fn dfs_enumerate<F>(mut state: PathState, depth: usize, ctx: &DfsContext<'_, F>)
where
    F: Fn(u64) -> Vec<u64> + Send + Sync,
{
    // Reuse one mutable state with push/pop backtracking instead of cloning the
    // whole path and path set for every explored edge.
    dfs_backtrack(&mut state, depth, ctx);
}

/// Backtracking worker for [`dfs_enumerate`]. It extends `state` in place and
/// restores it to its entry value before returning.
fn dfs_backtrack<F>(state: &mut PathState, depth: usize, ctx: &DfsContext<'_, F>)
where
    F: Fn(u64) -> Vec<u64> + Send + Sync,
{
    // `Relaxed` is enough: the flag is only a best-effort stop signal. The limit
    // itself is enforced under the results mutex below.
    if ctx.done.load(Ordering::Relaxed) {
        return;
    }

    if depth >= ctx.min_hops {
        // The guard is dropped at the end of this block, so the lock is never
        // held across recursion.
        let mut results = ctx
            .results
            .lock()
            .expect("results Mutex should not be poisoned");
        if results.len() >= ctx.limit {
            ctx.done.store(true, Ordering::Relaxed);
            return;
        }
        results.push(state.path.clone());
        if results.len() >= ctx.limit {
            ctx.done.store(true, Ordering::Relaxed);
            return;
        }
    }

    if depth >= ctx.max_hops {
        return;
    }

    let current = *state
        .path
        .last()
        .expect("path always contains at least the start node");
    for neighbor in (ctx.get_neighbors)(current) {
        // Once the limit is hit, further neighbours could not contribute results.
        if ctx.done.load(Ordering::Relaxed) {
            break;
        }
        // `insert` returns false when `neighbor` is already on the path (a cycle).
        if state.path_set.insert(neighbor) {
            state.path.push(neighbor);
            dfs_backtrack(state, depth + 1, ctx);
            state.path.pop();
            state.path_set.remove(&neighbor);
        }
    }
}