#![allow(clippy::cast_possible_truncation)]
use super::{bfs_core, ParallelConfig, TraversalResult, TraversalStats};
use rayon::prelude::*;
use rustc_hash::FxHashSet;
/// Breadth-first graph traverser that splits every frontier into shards and
/// expands the shards in parallel.
///
/// A node belongs to shard `node_id % num_shards`.
#[derive(Debug)]
pub struct ShardedTraverser {
    /// Traversal settings; `max_depth` and `limit` bound [`Self::traverse_parallel`].
    config: ParallelConfig,
    /// Number of shards; the constructors clamp it to at least 1.
    num_shards: usize,
}

impl ShardedTraverser {
    /// Creates a traverser with `num_shards` shards and the default
    /// [`ParallelConfig`].
    ///
    /// A shard count of 0 is raised to 1.
    #[must_use]
    pub fn new(num_shards: usize) -> Self {
        Self::with_config(num_shards, ParallelConfig::default())
    }

    /// Creates a traverser with `num_shards` shards and an explicit `config`.
    ///
    /// A shard count of 0 is raised to 1.
    #[must_use]
    pub fn with_config(num_shards: usize, config: ParallelConfig) -> Self {
        Self {
            config,
            // A shard count of 0 would make the modulo in `shard_for_node` panic.
            num_shards: num_shards.max(1),
        }
    }

    /// Returns the number of shards (always at least 1).
    #[must_use]
    pub fn num_shards(&self) -> usize {
        self.num_shards
    }

    /// Returns the index of the shard that owns `node_id`, in `0..num_shards`.
    #[must_use]
    pub fn shard_for_node(&self, node_id: u64) -> usize {
        // Reduce in `u64` first so ids above `usize::MAX` are not truncated on
        // 32-bit targets. The remainder is below `num_shards`, so converting it
        // back to `usize` is lossless.
        (node_id % self.num_shards as u64) as usize
    }

    /// Groups `nodes` by owning shard.
    ///
    /// The returned vector has exactly `num_shards` entries; entry `i` holds
    /// the nodes of shard `i` in their input order.
    #[must_use]
    pub fn partition_by_shard(&self, nodes: &[u64]) -> Vec<Vec<u64>> {
        let mut partitions = vec![Vec::new(); self.num_shards];
        for &node in nodes {
            partitions[self.shard_for_node(node)].push(node);
        }
        partitions
    }

    /// Runs a level-by-level breadth-first traversal from `start_nodes`.
    ///
    /// `adjacency` maps a node id to its outgoing `(neighbor, edge_id)` pairs.
    ///
    /// The result list starts with one depth-0 entry per distinct start node,
    /// followed by one entry for each node the first time it is reached; the
    /// entry's path is the list of edge ids leading to it. Traversal stops
    /// after `config.max_depth` levels, when no frontier is left, or as soon
    /// as `config.limit` results have been collected. The limit is a hard cap
    /// that also applies to the start nodes.
    ///
    /// In the returned stats `start_nodes_count` is the length of
    /// `start_nodes` as passed in, and both `raw_results` and
    /// `deduplicated_results` are the number of returned results.
    pub fn traverse_parallel<F>(
        &self,
        start_nodes: &[u64],
        adjacency: F,
    ) -> (Vec<TraversalResult>, TraversalStats)
    where
        F: Fn(u64) -> Vec<(u64, u64)> + Send + Sync,
    {
        let stats = TraversalStats::new();
        let mut all_results: Vec<TraversalResult> = Vec::new();
        let mut global_visited: FxHashSet<u64> = FxHashSet::default();

        // Seed with each distinct start node once, never exceeding the limit.
        // Repeated start nodes would otherwise produce duplicate results and
        // be expanded more than once.
        let mut seeds = Vec::with_capacity(start_nodes.len());
        for &start in start_nodes {
            if all_results.len() >= self.config.limit {
                break;
            }
            if !global_visited.insert(start) {
                continue;
            }
            stats.add_nodes_visited(1);
            all_results.push(TraversalResult::new(start, start, Vec::new(), 0));
            seeds.push(start);
        }

        let mut shard_frontiers = self.initialize_frontiers(&seeds);
        for depth in 1..=self.config.max_depth {
            if all_results.len() >= self.config.limit {
                break;
            }
            if shard_frontiers.iter().all(Vec::is_empty) {
                break;
            }
            let shard_results = self.expand_shards(&shard_frontiers, &adjacency, &stats, depth);
            shard_frontiers = self.merge_shard_results(
                shard_results,
                &mut global_visited,
                &stats,
                &mut all_results,
            );
        }

        let mut final_stats = stats;
        final_stats.start_nodes_count = start_nodes.len();
        final_stats.raw_results = all_results.len();
        final_stats.deduplicated_results = all_results.len();
        (all_results, final_stats)
    }

    /// Builds the depth-0 frontiers: one `(start, current, path)` entry per
    /// start node, with `current == start` and an empty path, placed in the
    /// shard that owns the node.
    fn initialize_frontiers(&self, start_nodes: &[u64]) -> Vec<Vec<(u64, u64, Vec<u64>)>> {
        let mut frontiers = vec![Vec::new(); self.num_shards];
        for &start in start_nodes {
            frontiers[self.shard_for_node(start)].push((start, start, Vec::new()));
        }
        frontiers
    }

    /// Expands every shard frontier by one hop, processing shards in parallel.
    ///
    /// For each shard, returns the candidate results together with the
    /// matching candidate frontier entries. Both lists are pushed in lockstep
    /// (entry `i` of one describes the same edge as entry `i` of the other)
    /// and are not yet de-duplicated; that happens in `merge_shard_results`.
    /// Each expanded node adds its neighbor count to the edges-traversed stat.
    // The nested tuple types are internal plumbing between two private
    // helpers; `self` is kept so the helper reads like its siblings.
    #[allow(clippy::type_complexity, clippy::unused_self)]
    fn expand_shards<F>(
        &self,
        shard_frontiers: &[Vec<(u64, u64, Vec<u64>)>],
        adjacency: &F,
        stats: &TraversalStats,
        depth: u32,
    ) -> Vec<(Vec<TraversalResult>, Vec<(u64, u64, Vec<u64>)>)>
    where
        F: Fn(u64) -> Vec<(u64, u64)> + Send + Sync,
    {
        shard_frontiers
            .par_iter()
            .map(|frontier| {
                let mut results = Vec::new();
                let mut next_frontier = Vec::new();
                for (start_node, current_node, path) in frontier {
                    let neighbors = adjacency(*current_node);
                    stats.add_edges_traversed(neighbors.len());
                    for (neighbor, edge_id) in neighbors {
                        // Allocate the extended path once at its final size.
                        let mut new_path = Vec::with_capacity(path.len() + 1);
                        new_path.extend_from_slice(path);
                        new_path.push(edge_id);
                        results.push(TraversalResult::new(
                            *start_node,
                            neighbor,
                            new_path.clone(),
                            depth,
                        ));
                        next_frontier.push((*start_node, neighbor, new_path));
                    }
                }
                (results, next_frontier)
            })
            .collect()
    }

    /// Folds the per-shard expansion output into the global traversal state
    /// and returns the frontiers for the next level.
    ///
    /// A candidate result is accepted only if its end node has not been seen
    /// before; accepted results are appended to `all_results` and counted in
    /// `stats`. Acceptance stops for every shard once `all_results` holds
    /// `config.limit` entries. Each accepted node contributes exactly one
    /// entry to the next frontier, so no node is expanded twice.
    // The nested tuple types are internal plumbing between two private helpers.
    #[allow(clippy::type_complexity)]
    fn merge_shard_results(
        &self,
        shard_results: Vec<(Vec<TraversalResult>, Vec<(u64, u64, Vec<u64>)>)>,
        global_visited: &mut FxHashSet<u64>,
        stats: &TraversalStats,
        all_results: &mut Vec<TraversalResult>,
    ) -> Vec<Vec<(u64, u64, Vec<u64>)>> {
        let limit = self.config.limit;
        let mut new_frontiers = vec![Vec::new(); self.num_shards];
        // Nodes accepted in this call that still need their frontier entry.
        let mut awaiting_frontier = FxHashSet::default();
        let mut limit_reached = all_results.len() >= limit;

        for (results, next_frontier) in shard_results {
            if !limit_reached {
                for result in results {
                    if global_visited.insert(result.end_node) {
                        stats.add_nodes_visited(1);
                        awaiting_frontier.insert(result.end_node);
                        all_results.push(result);
                        if all_results.len() >= limit {
                            // Remember the cut-off so later shards add nothing.
                            limit_reached = true;
                            break;
                        }
                    }
                }
            }
            for (start, node, path) in next_frontier {
                // `remove` succeeds only for the first entry of a node. Since
                // results and frontier entries were produced in lockstep, that
                // first entry carries the same path as the accepted result.
                if awaiting_frontier.remove(&node) {
                    new_frontiers[self.shard_for_node(node)].push((start, node, path));
                }
            }
        }
        new_frontiers
    }

    /// Runs `bfs_core` from `start` with this traverser's configuration and a
    /// filter that accepts only nodes owned by the same shard as `start`.
    ///
    /// How the filter is applied is decided by `bfs_core`; presumably it
    /// keeps the search inside `start`'s shard (confirm against `bfs_core`).
    pub fn bfs_single_shard<F>(
        &self,
        start: u64,
        adjacency: &F,
        stats: &TraversalStats,
    ) -> Vec<TraversalResult>
    where
        F: Fn(u64) -> Vec<(u64, u64)> + Send + Sync,
    {
        let target_shard = self.shard_for_node(start);
        let shard_filter = |neighbor: u64| self.shard_for_node(neighbor) == target_shard;
        bfs_core(start, adjacency, stats, &self.config, Some(&shard_filter))
    }
}