use super::{ParallelConfig, TraversalResult, TraversalStats};
use rayon::prelude::*;
use rustc_hash::FxHashSet;
use std::collections::VecDeque;
/// Order in which a single-start traversal pops nodes from its frontier.
///
/// `Bfs` pops from the front of the frontier (FIFO); `Dfs` pops from the back (LIFO).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub(crate) enum TraversalOrder {
    /// Breadth-first: oldest discovered node is expanded next.
    Bfs,
    /// Depth-first: most recently discovered node is expanded next.
    Dfs,
}
/// Traverses a graph from many start nodes (optionally in parallel with rayon)
/// and merges the per-start results into one ranked, deduplicated list.
///
/// Every tunable — depth cap, result limit, and the decision whether to
/// parallelize — is read from the wrapped [`ParallelConfig`].
#[derive(Default, Debug)]
pub struct ParallelTraverser {
    /// Settings consulted by every traversal run by this instance.
    config: ParallelConfig,
}
impl ParallelTraverser {
    /// Creates a traverser using the default [`ParallelConfig`].
    #[must_use]
    pub fn new() -> Self {
        Self {
            config: ParallelConfig::default(),
        }
    }

    /// Creates a traverser with an explicit configuration.
    #[must_use]
    pub fn with_config(config: ParallelConfig) -> Self {
        Self { config }
    }

    /// Runs a breadth-first traversal from every node in `start_nodes`.
    ///
    /// `adjacency(node)` must return the outgoing `(neighbor, edge_id)` pairs of `node`.
    /// Each start node is traversed independently (in parallel when the configuration
    /// allows it), and the combined results are deduplicated, ranked and truncated by
    /// [`Self::merge_and_deduplicate`].
    ///
    /// Returns the final results together with the aggregated traversal statistics.
    #[must_use]
    pub fn bfs_parallel<F>(
        &self,
        start_nodes: &[u64],
        adjacency: F,
    ) -> (Vec<TraversalResult>, TraversalStats)
    where
        F: Fn(u64) -> Vec<(u64, u64)> + Send + Sync,
    {
        self.run_parallel(start_nodes, &adjacency, TraversalOrder::Bfs)
    }

    /// Runs a depth-first traversal from every node in `start_nodes`.
    ///
    /// Same contract as [`Self::bfs_parallel`], but each per-start traversal expands the
    /// most recently discovered node first.
    #[must_use]
    pub fn dfs_parallel<F>(
        &self,
        start_nodes: &[u64],
        adjacency: F,
    ) -> (Vec<TraversalResult>, TraversalStats)
    where
        F: Fn(u64) -> Vec<(u64, u64)> + Send + Sync,
    {
        self.run_parallel(start_nodes, &adjacency, TraversalOrder::Dfs)
    }

    /// Removes duplicate paths, ranks the survivors and keeps at most `config.limit`.
    ///
    /// Two results are duplicates when their `path_signature()` values are equal; the
    /// first occurrence in `results` is kept. Survivors are ordered by descending
    /// `score` (a missing score is treated as `f32::NEG_INFINITY`), then by ascending
    /// `depth`. The sort is stable, so fully tied results keep their input order.
    #[must_use]
    pub fn merge_and_deduplicate(&self, results: Vec<TraversalResult>) -> Vec<TraversalResult> {
        let mut seen = FxHashSet::default();
        let mut unique: Vec<TraversalResult> = results
            .into_iter()
            .filter(|r| seen.insert(r.path_signature()))
            .collect();
        unique.sort_by(|a, b| {
            let score_cmp = b
                .score
                .unwrap_or(f32::NEG_INFINITY)
                .total_cmp(&a.score.unwrap_or(f32::NEG_INFINITY));
            score_cmp.then_with(|| a.depth.cmp(&b.depth))
        });
        unique.truncate(self.config.limit);
        unique
    }

    /// Shared driver for [`Self::bfs_parallel`] and [`Self::dfs_parallel`].
    ///
    /// Traverses from each start node, using rayon only when
    /// `config.should_parallelize(start_nodes.len())` approves the batch. It then merges
    /// the results and fills in the aggregate counters of the returned stats.
    fn run_parallel<F>(
        &self,
        start_nodes: &[u64],
        adjacency: &F,
        order: TraversalOrder,
    ) -> (Vec<TraversalResult>, TraversalStats)
    where
        F: Fn(u64) -> Vec<(u64, u64)> + Send + Sync,
    {
        // Shared by reference across workers; per-node counters are bumped through `&self`.
        let stats = TraversalStats::new();
        let results: Vec<Vec<TraversalResult>> =
            if self.config.should_parallelize(start_nodes.len()) {
                start_nodes
                    .par_iter()
                    .map(|&start| self.traverse_single(start, adjacency, &stats, order))
                    .collect()
            } else {
                start_nodes
                    .iter()
                    .map(|&start| self.traverse_single(start, adjacency, &stats, order))
                    .collect()
            };
        let all_results: Vec<TraversalResult> = results.into_iter().flatten().collect();
        let raw_count = all_results.len();
        let deduplicated = self.merge_and_deduplicate(all_results);
        let mut stats = stats;
        stats.start_nodes_count = start_nodes.len();
        stats.raw_results = raw_count;
        stats.deduplicated_results = deduplicated.len();
        (deduplicated, stats)
    }

    /// Traverses from a single `start` node and returns one result per reached node.
    ///
    /// The start node itself is always reported first, with an empty path and depth 0.
    /// Every other reached node is reported with the edge-id path from `start` and its
    /// hop depth. A node is expanded only while its depth is below `config.max_depth`
    /// and fewer than `config.limit` results exist. The limit is checked before each
    /// expansion, so one expansion may push the list past `limit`; the final truncation
    /// happens in [`Self::merge_and_deduplicate`].
    ///
    /// NOTE(review): nodes are marked visited on first discovery. In BFS order this
    /// records the shortest hop depth. In DFS order the recorded depth/path is that of
    /// the first discovery, which may be longer than necessary. With a depth cap, DFS can
    /// therefore miss nodes that are reachable within `max_depth` by a shorter route.
    fn traverse_single<F>(
        &self,
        start: u64,
        adjacency: &F,
        stats: &TraversalStats,
        order: TraversalOrder,
    ) -> Vec<TraversalResult>
    where
        F: Fn(u64) -> Vec<(u64, u64)> + Send + Sync,
    {
        let max_depth = self.config.max_depth;
        let limit = self.config.limit;

        let mut results = Vec::new();
        let mut visited = FxHashSet::default();
        let mut frontier: VecDeque<(u64, Vec<u64>, u32)> = VecDeque::new();

        visited.insert(start);
        stats.add_nodes_visited(1);
        results.push(TraversalResult::new(start, start, Vec::new(), 0));
        frontier.push_back((start, Vec::new(), 0));

        while let Some((node, path, depth)) = Self::pop_next(&mut frontier, order) {
            // `results` only grows when a node is expanded. Once the limit is reached, no
            // later frontier entry could be expanded either, so stop instead of draining.
            if results.len() >= limit {
                break;
            }
            // Only the start node can get here at the depth cap (when `max_depth == 0`);
            // nodes at the cap are never enqueued below.
            if depth >= max_depth {
                continue;
            }

            let neighbors = adjacency(node);
            stats.add_edges_traversed(neighbors.len());
            let new_depth = depth + 1;
            // Children at the depth cap are reported but would never be expanded, so
            // queuing them is wasted work (and a wasted path clone per leaf).
            let enqueue_children = new_depth < max_depth;

            for (neighbor, edge_id) in neighbors {
                if !visited.insert(neighbor) {
                    continue;
                }
                stats.add_nodes_visited(1);

                let mut new_path = Vec::with_capacity(path.len() + 1);
                new_path.extend_from_slice(&path);
                new_path.push(edge_id);

                if enqueue_children {
                    results.push(TraversalResult::new(
                        start,
                        neighbor,
                        new_path.clone(),
                        new_depth,
                    ));
                    frontier.push_back((neighbor, new_path, new_depth));
                } else {
                    results.push(TraversalResult::new(start, neighbor, new_path, new_depth));
                }
            }
        }
        results
    }

    /// Pops the next frontier entry: from the front (FIFO) for BFS, from the back
    /// (LIFO) for DFS.
    fn pop_next(
        queue: &mut VecDeque<(u64, Vec<u64>, u32)>,
        order: TraversalOrder,
    ) -> Option<(u64, Vec<u64>, u32)> {
        match order {
            TraversalOrder::Bfs => queue.pop_front(),
            TraversalOrder::Dfs => queue.pop_back(),
        }
    }
}