#![allow(clippy::cast_possible_truncation)]
mod frontier;
mod sharded;
mod traverser;
pub use frontier::FrontierParallelBFS;
pub use sharded::ShardedTraverser;
pub use traverser::ParallelTraverser;
use std::sync::atomic::{AtomicUsize, Ordering as AtomicOrdering};
/// A single node reached by a traversal, together with how it was reached.
#[derive(Debug, Clone)]
pub struct TraversalResult {
    /// Node the traversal started from.
    pub start_node: u64,
    /// Node this result ends at.
    pub end_node: u64,
    /// Ids recorded on the way from `start_node` to `end_node`.
    ///
    /// The BFS in this module pushes one edge id per hop here.
    pub path: Vec<u64>,
    /// Depth at which `end_node` was reached (the BFS in this module uses 0
    /// for the start node itself).
    pub depth: u32,
    /// Optional score; `None` until set via [`TraversalResult::with_score`].
    pub score: Option<f32>,
}

impl TraversalResult {
    /// Creates a result without a score.
    #[must_use]
    pub fn new(start_node: u64, end_node: u64, path: Vec<u64>, depth: u32) -> Self {
        let score = None;
        Self {
            start_node,
            end_node,
            path,
            depth,
            score,
        }
    }

    /// Consumes the result and returns it with `score` attached.
    #[must_use]
    pub fn with_score(self, score: f32) -> Self {
        Self {
            score: Some(score),
            ..self
        }
    }

    /// Hashes the start node, the end node and the path (in that order) into
    /// a single `u64`.
    ///
    /// `depth` and `score` do not take part in the signature, so two results
    /// that differ only in those fields produce the same value.
    #[must_use]
    pub fn path_signature(&self) -> u64 {
        use std::hash::{Hash, Hasher};

        let mut state = rustc_hash::FxHasher::default();
        // Feed order matters: changing it changes every signature.
        Hash::hash(&self.start_node, &mut state);
        Hash::hash(&self.end_node, &mut state);
        Hash::hash(&self.path, &mut state);
        Hasher::finish(&state)
    }
}
/// Policy for choosing how many worker threads a traversal uses.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
#[non_exhaustive]
pub enum ThreadConfig {
    /// Derive the count from the machine: available parallelism minus one,
    /// never less than one.
    #[default]
    Auto,
    /// Use the given number of threads. A value of zero is treated as one.
    Fixed(usize),
}

impl ThreadConfig {
    /// Resolves the policy to a concrete thread count.
    ///
    /// The returned value is always at least 1:
    /// - `Auto` takes `std::thread::available_parallelism()` (falling back to
    ///   1 when it cannot be queried) and leaves one unit of parallelism
    ///   unused.
    /// - `Fixed(n)` returns `n`, except that `Fixed(0)` yields 1 so callers
    ///   never receive a zero worker count.
    #[must_use]
    pub fn effective_threads(&self) -> usize {
        let requested = match self {
            Self::Auto => std::thread::available_parallelism()
                .map_or(1, std::num::NonZeroUsize::get)
                .saturating_sub(1),
            Self::Fixed(count) => *count,
        };
        // Shared floor: both policies must resolve to at least one thread.
        requested.max(1)
    }
}
/// Tuning parameters for a parallel graph traversal.
#[derive(Debug, Clone)]
pub struct ParallelConfig {
    /// Depth at which the BFS stops expanding nodes (default 5).
    pub max_depth: u32,
    /// Node count at or above which [`ParallelConfig::should_parallelize`]
    /// answers `true` (default 100).
    pub parallel_threshold: usize,
    /// Frontier size at or above which
    /// [`ParallelConfig::should_parallelize_frontier`] answers `true`
    /// (default 50).
    pub min_frontier_for_parallel: usize,
    /// Result-count cap consulted by the traversal (default 1000).
    pub limit: usize,
    /// Relationship type names; empty by default. No builder method in this
    /// type sets it, and how it is interpreted is up to the consumers.
    pub relationship_types: Vec<String>,
    /// Thread-count policy (default `ThreadConfig::Auto`).
    pub threads: ThreadConfig,
}

impl Default for ParallelConfig {
    fn default() -> Self {
        Self {
            threads: ThreadConfig::Auto,
            relationship_types: Vec::new(),
            limit: 1000,
            min_frontier_for_parallel: 50,
            parallel_threshold: 100,
            max_depth: 5,
        }
    }
}

impl ParallelConfig {
    /// Returns the default configuration; same as `ParallelConfig::default()`.
    #[must_use]
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Returns the configuration with `max_depth` replaced by `depth`.
    #[must_use]
    pub fn with_max_depth(self, depth: u32) -> Self {
        Self {
            max_depth: depth,
            ..self
        }
    }

    /// Returns the configuration with `parallel_threshold` replaced.
    #[must_use]
    pub fn with_parallel_threshold(self, threshold: usize) -> Self {
        Self {
            parallel_threshold: threshold,
            ..self
        }
    }

    /// Returns the configuration with `min_frontier_for_parallel` replaced.
    #[must_use]
    pub fn with_min_frontier(self, min_frontier: usize) -> Self {
        Self {
            min_frontier_for_parallel: min_frontier,
            ..self
        }
    }

    /// Returns the configuration with `limit` replaced.
    #[must_use]
    pub fn with_limit(self, limit: usize) -> Self {
        Self { limit, ..self }
    }

    /// Returns the configuration with the thread policy replaced.
    #[must_use]
    pub fn with_threads(self, threads: ThreadConfig) -> Self {
        Self { threads, ..self }
    }

    /// Returns the configuration with the thread policy set to
    /// `ThreadConfig::Fixed(count)`.
    #[must_use]
    pub fn with_fixed_threads(self, count: usize) -> Self {
        Self {
            threads: ThreadConfig::Fixed(count),
            ..self
        }
    }

    /// `true` when `node_count` has reached `parallel_threshold` (inclusive).
    #[must_use]
    pub fn should_parallelize(&self, node_count: usize) -> bool {
        self.parallel_threshold <= node_count
    }

    /// `true` when `frontier_size` has reached `min_frontier_for_parallel`
    /// (inclusive).
    #[must_use]
    pub fn should_parallelize_frontier(&self, frontier_size: usize) -> bool {
        self.min_frontier_for_parallel <= frontier_size
    }

    /// Thread count resolved from the configured policy; delegates to
    /// `ThreadConfig::effective_threads`.
    #[must_use]
    pub fn effective_threads(&self) -> usize {
        let policy = &self.threads;
        policy.effective_threads()
    }
}
/// Counters describing a traversal run.
///
/// `nodes_visited` and `edges_traversed` are atomics and can be bumped through
/// a shared reference; the remaining fields are plain integers.
#[derive(Debug, Default)]
pub struct TraversalStats {
    /// Plain counter; nothing in this file writes it (presumably the number
    /// of start nodes, filled in by the traversers — confirm at call sites).
    pub start_nodes_count: usize,
    /// Running total fed by [`TraversalStats::add_nodes_visited`].
    pub nodes_visited: AtomicUsize,
    /// Running total fed by [`TraversalStats::add_edges_traversed`].
    pub edges_traversed: AtomicUsize,
    /// Plain counter; nothing in this file writes it (presumably the result
    /// count before deduplication — confirm at call sites).
    pub raw_results: usize,
    /// Plain counter; nothing in this file writes it (presumably the result
    /// count after deduplication — confirm at call sites).
    pub deduplicated_results: usize,
}

impl TraversalStats {
    /// Creates a stats object with every counter at zero.
    #[must_use]
    pub fn new() -> Self {
        Self {
            start_nodes_count: 0,
            nodes_visited: AtomicUsize::new(0),
            edges_traversed: AtomicUsize::new(0),
            raw_results: 0,
            deduplicated_results: 0,
        }
    }

    /// Adds `count` to the visited-node total (relaxed atomic add).
    pub fn add_nodes_visited(&self, count: usize) {
        let _previous = self.nodes_visited.fetch_add(count, AtomicOrdering::Relaxed);
    }

    /// Adds `count` to the traversed-edge total (relaxed atomic add).
    pub fn add_edges_traversed(&self, count: usize) {
        let _previous = self
            .edges_traversed
            .fetch_add(count, AtomicOrdering::Relaxed);
    }

    /// Current visited-node total (relaxed atomic load).
    #[must_use]
    pub fn total_nodes_visited(&self) -> usize {
        let counter = &self.nodes_visited;
        counter.load(AtomicOrdering::Relaxed)
    }

    /// Current traversed-edge total (relaxed atomic load).
    #[must_use]
    pub fn total_edges_traversed(&self) -> usize {
        let counter = &self.edges_traversed;
        counter.load(AtomicOrdering::Relaxed)
    }
}
/// Breadth-first search from a single start node.
///
/// # Parameters
/// - `start`: node to search from; always reported as the first result with
///   an empty path and depth 0.
/// - `adjacency`: returns the outgoing `(neighbor, edge_id)` pairs of a node.
/// - `stats`: receives one "node visited" per newly discovered node
///   (including `start`) and, for every expanded node, the number of pairs
///   `adjacency` returned for it.
/// - `config`: only `max_depth` and `limit` are read here.
/// - `neighbor_filter`: when present, a neighbor is followed only if the
///   predicate returns `true`. Rejected neighbors are not marked as visited,
///   and the filter is never applied to `start`.
///
/// # Returns
/// One [`TraversalResult`] per discovered node in discovery order. Each
/// result's `path` is the list of edge ids leading from `start` to that node.
///
/// # Stopping conditions
/// The loop ends when the queue is empty, or when the next dequeued node is
/// already at `max_depth`, or when the result count has reached `limit`.
/// `limit` is checked once per dequeued node, not per neighbor, so expanding
/// a node with many neighbors can leave more than `limit` results; the start
/// result is emitted even when `limit` is 0.
pub(super) fn bfs_core<F, P>(
    start: u64,
    adjacency: &F,
    stats: &TraversalStats,
    config: &ParallelConfig,
    neighbor_filter: Option<&P>,
) -> Vec<TraversalResult>
where
    F: Fn(u64) -> Vec<(u64, u64)> + Send + Sync,
    P: Fn(u64) -> bool,
{
    let mut seen = rustc_hash::FxHashSet::default();
    seen.insert(start);
    stats.add_nodes_visited(1);

    // The start node is reported unconditionally.
    let mut found = vec![TraversalResult::new(start, start, Vec::new(), 0)];

    // Queue entries: (node, edge ids leading to it, depth of the node).
    let mut pending = std::collections::VecDeque::new();
    pending.push_back((start, Vec::<u64>::new(), 0u32));

    loop {
        let Some((current, trail, level)) = pending.pop_front() else {
            break;
        };

        // Entries are queued in non-decreasing depth order, so the first node
        // at `max_depth` means nothing shallower remains to expand.
        let depth_exhausted = level >= config.max_depth;
        let limit_reached = found.len() >= config.limit;
        if depth_exhausted || limit_reached {
            break;
        }

        let outgoing = adjacency(current);
        stats.add_edges_traversed(outgoing.len());

        for (target, edge) in outgoing {
            let admitted = match neighbor_filter {
                Some(accepts) => accepts(target),
                None => true,
            };
            // Short-circuit keeps rejected targets out of `seen`.
            if !admitted || !seen.insert(target) {
                continue;
            }
            stats.add_nodes_visited(1);

            let reached_at = level + 1;
            let mut extended = trail.clone();
            extended.push(edge);

            found.push(TraversalResult::new(
                start,
                target,
                extended.clone(),
                reached_at,
            ));
            pending.push_back((target, extended, reached_at));
        }
    }

    found
}