use super::parallel_traversal::*;
use std::collections::HashMap;
/// Builds the small six-node fixture graph shared by the traversal tests.
///
/// Keys are node ids; each value lists that node's outgoing pairs
/// (presumably `(neighbor, edge id)` — the layout is dictated by the
/// traversal API). Nodes 4, 5 and 6 have no outgoing entries.
fn create_test_graph() -> HashMap<u64, Vec<(u64, u64)>> {
    HashMap::from([
        (1, vec![(2, 100), (3, 101)]),
        (2, vec![(4, 102), (5, 103)]),
        (3, vec![(5, 104), (6, 105)]),
        (4, Vec::new()),
        (5, Vec::new()),
        (6, Vec::new()),
    ])
}
/// A freshly constructed result exposes the endpoints and depth it was built
/// with and has no score attached.
#[test]
fn test_traversal_result_new() {
    let fresh = TraversalResult::new(1, 5, vec![100, 103], 2);

    assert_eq!(fresh.start_node, 1);
    assert_eq!(fresh.end_node, 5);
    assert_eq!(fresh.depth, 2);
    assert_eq!(fresh.score, None);
}
/// `with_score` stores the supplied score on the result.
#[test]
fn test_traversal_result_with_score() {
    let unscored = TraversalResult::new(1, 5, vec![100], 1);
    let scored = unscored.with_score(0.9);

    assert_eq!(scored.score, Some(0.9));
}
/// Results with different paths get different signatures, while results built
/// from identical arguments get equal signatures.
#[test]
fn test_path_signature_uniqueness() {
    let baseline = TraversalResult::new(1, 5, vec![100, 101], 2);
    let other_path = TraversalResult::new(1, 5, vec![100, 102], 2);
    let same_path = TraversalResult::new(1, 5, vec![100, 101], 2);

    // Same endpoints, one differing path element: signatures must differ.
    assert_ne!(baseline.path_signature(), other_path.path_signature());
    // Identical construction arguments: signatures must match.
    assert_eq!(baseline.path_signature(), same_path.path_signature());
}
/// Pins the default values of `ParallelConfig`.
#[test]
fn test_parallel_config_default() {
    let defaults = ParallelConfig::default();

    assert_eq!(defaults.max_depth, 5);
    assert_eq!(defaults.parallel_threshold, 100);
    assert_eq!(defaults.limit, 1_000);
}
/// Counts added to a new `TraversalStats` are reported back by the totals.
#[test]
fn test_traversal_stats() {
    let counters = TraversalStats::new();

    counters.add_nodes_visited(10);
    counters.add_edges_traversed(20);

    assert_eq!(counters.total_nodes_visited(), 10);
    assert_eq!(counters.total_edges_traversed(), 20);
}
/// BFS from node 1 of the fixture graph with max depth 3 yields six results
/// and reports one start node.
#[test]
fn test_bfs_single_start() {
    let graph = create_test_graph();
    let config = ParallelConfig::new()
        .with_max_depth(3)
        .with_parallel_threshold(1)
        .with_limit(100);
    let traverser = ParallelTraverser::with_config(config);

    let (results, stats) = traverser.bfs_parallel(&[1], |node: u64| {
        graph.get(&node).cloned().unwrap_or_default()
    });

    assert_eq!(results.len(), 6);
    assert_eq!(stats.start_nodes_count, 1);
    assert!(stats.total_nodes_visited() >= 6);
}
/// BFS from two start nodes records both starts and returns at least two
/// results.
#[test]
fn test_bfs_multiple_starts() {
    let graph = create_test_graph();
    let neighbors_of = |node: u64| graph.get(&node).cloned().unwrap_or_default();
    let config = ParallelConfig::new()
        .with_max_depth(2)
        .with_parallel_threshold(1)
        .with_limit(100);
    let traverser = ParallelTraverser::with_config(config);

    let start_nodes = [1, 3];
    let (results, stats) = traverser.bfs_parallel(&start_nodes, neighbors_of);

    assert_eq!(stats.start_nodes_count, 2);
    assert!(results.len() >= 2);
}
/// With max depth 1, BFS from node 1 yields three results, none deeper than 1.
#[test]
fn test_bfs_depth_limit() {
    let graph = create_test_graph();
    let config = ParallelConfig::new()
        .with_max_depth(1)
        .with_parallel_threshold(1)
        .with_limit(100);
    let traverser = ParallelTraverser::with_config(config);

    let (results, _stats) = traverser.bfs_parallel(&[1], |node: u64| {
        graph.get(&node).cloned().unwrap_or_default()
    });

    assert_eq!(results.len(), 3);
    // No result may exceed the configured depth.
    assert!(!results.iter().any(|hit| hit.depth > 1));
}
/// DFS from node 1 of the fixture graph with max depth 3 yields six results
/// and reports one start node.
#[test]
fn test_dfs_single_start() {
    let graph = create_test_graph();
    let neighbors_of = |node: u64| graph.get(&node).cloned().unwrap_or_default();
    let config = ParallelConfig::new()
        .with_max_depth(3)
        .with_parallel_threshold(1)
        .with_limit(100);
    let traverser = ParallelTraverser::with_config(config);

    let (results, stats) = traverser.dfs_parallel(&[1], neighbors_of);

    assert_eq!(results.len(), 6);
    assert_eq!(stats.start_nodes_count, 1);
}
/// Merging three results, two of them built from identical arguments, leaves
/// two entries.
#[test]
fn test_merge_deduplication() {
    let merger = ParallelTraverser::new();

    let first = TraversalResult::new(1, 2, vec![100], 1);
    let repeat_of_first = TraversalResult::new(1, 2, vec![100], 1);
    let distinct = TraversalResult::new(1, 3, vec![101], 1);

    let merged = merger.merge_and_deduplicate(vec![first, repeat_of_first, distinct]);

    assert_eq!(merged.len(), 2);
}
/// Merged results come back ordered by descending score.
#[test]
fn test_merge_sorting_by_score() {
    let merger = ParallelTraverser::new();
    let unsorted = vec![
        TraversalResult::new(1, 2, vec![100], 1).with_score(0.5),
        TraversalResult::new(1, 3, vec![101], 1).with_score(0.9),
        TraversalResult::new(1, 4, vec![102], 1).with_score(0.7),
    ];

    let merged = merger.merge_and_deduplicate(unsorted);

    // Highest score first; indexing panics if fewer than three entries remain.
    let expected_order = [0.9, 0.7, 0.5];
    for (position, expected_score) in expected_order.iter().enumerate() {
        assert_eq!(merged[position].score, Some(*expected_score));
    }
}
/// BFS over a synthetic graph where every node below 100 has two successors
/// never returns more results than the configured limit of 3.
#[test]
fn test_result_limit() {
    let config = ParallelConfig::new()
        .with_max_depth(10)
        .with_parallel_threshold(1)
        .with_limit(3);
    let traverser = ParallelTraverser::with_config(config);

    // Nodes at or above 100 are leaves; everything else links to the next two ids.
    let two_successors = |node: u64| -> Vec<(u64, u64)> {
        if node >= 100 {
            return Vec::new();
        }
        vec![(node + 1, node * 10), (node + 2, node * 10 + 1)]
    };

    let (results, _stats) = traverser.bfs_parallel(&[1], two_successors);

    assert!(results.len() <= 3);
}
/// A default-configured frontier BFS from node 1 returns at least six results
/// and reports one start node.
#[test]
fn test_frontier_parallel_bfs_basic() {
    let graph = create_test_graph();
    let frontier_bfs = FrontierParallelBFS::new();

    let (results, stats) = frontier_bfs.traverse(1, |node: u64| {
        graph.get(&node).cloned().unwrap_or_default()
    });

    assert!(results.len() >= 6);
    assert_eq!(stats.start_nodes_count, 1);
}
/// Frontier BFS reports every end node at most once, even though node 5 of
/// the fixture graph has two incoming entries.
#[test]
fn test_frontier_parallel_bfs_no_duplicates() {
    let graph = create_test_graph();
    let neighbors_of = |node: u64| graph.get(&node).cloned().unwrap_or_default();
    let frontier_bfs = FrontierParallelBFS::new();

    let (results, _stats) = frontier_bfs.traverse(1, neighbors_of);

    let mut visited_ends = std::collections::HashSet::<u64>::new();
    for hit in results.iter() {
        // `insert` returns false when the value was already present.
        let first_time = visited_ends.insert(hit.end_node);
        assert!(first_time, "Duplicate end node: {}", hit.end_node);
    }
}
/// Frontier BFS results arrive in non-decreasing depth order; a depth-0 entry
/// is tolerated at any position.
#[test]
fn test_frontier_parallel_bfs_depth_order() {
    let graph = create_test_graph();
    let frontier_bfs = FrontierParallelBFS::new();

    let (results, _stats) = frontier_bfs.traverse(1, |node: u64| {
        graph.get(&node).cloned().unwrap_or_default()
    });

    let mut previous_depth = 0;
    for hit in results.iter() {
        // A violation is a depth that drops below its predecessor without being 0.
        let went_backwards = hit.depth < previous_depth && hit.depth != 0;
        assert!(!went_backwards, "Results not in BFS order");
        previous_depth = hit.depth;
    }
}
/// Frontier BFS over a synthetic graph where every node below 100 has two
/// successors never returns more results than the configured limit of 3.
#[test]
fn test_frontier_parallel_bfs_with_limit() {
    let config = ParallelConfig::new()
        .with_max_depth(10)
        .with_parallel_threshold(1)
        .with_limit(3);
    let frontier_bfs = FrontierParallelBFS::with_config(config);

    // Nodes at or above 100 are leaves; everything else links to the next two ids.
    let two_successors = |node: u64| -> Vec<(u64, u64)> {
        if node >= 100 {
            return Vec::new();
        }
        vec![(node + 1, node * 10), (node + 2, node * 10 + 1)]
    };

    let (results, _stats) = frontier_bfs.traverse(1, two_successors);

    assert!(results.len() <= 3);
}
/// When no node has neighbors, frontier BFS returns exactly one result whose
/// end node is the start node.
#[test]
fn test_frontier_parallel_bfs_empty_graph() {
    let frontier_bfs = FrontierParallelBFS::new();
    let no_neighbors = |_: u64| Vec::<(u64, u64)>::new();

    let (results, stats) = frontier_bfs.traverse(1, no_neighbors);

    assert_eq!(results.len(), 1);
    assert_eq!(results[0].end_node, 1);
    assert_eq!(stats.start_nodes_count, 1);
}
/// `ThreadConfig::Auto` resolves to at least one thread and to no more than
/// the parallelism the standard library reports (1 if that query fails).
#[test]
fn test_thread_config_auto() {
    let threads = ThreadConfig::Auto.effective_threads();
    assert!(threads >= 1);

    let cpu_count = match std::thread::available_parallelism() {
        Ok(cores) => cores.get(),
        Err(_) => 1,
    };
    assert!(threads <= cpu_count);
}
/// `ThreadConfig::Fixed(n)` resolves to exactly `n` threads.
#[test]
fn test_thread_config_fixed() {
    let pinned = ThreadConfig::Fixed(4);
    let threads = pinned.effective_threads();

    assert_eq!(threads, 4);
}
/// Each builder method stores its argument in the corresponding config field.
#[test]
fn test_parallel_config_builder() {
    let built = ParallelConfig::new()
        .with_max_depth(10)
        .with_parallel_threshold(50)
        .with_min_frontier(25)
        .with_limit(500)
        .with_fixed_threads(8);

    assert_eq!(built.max_depth, 10);
    assert_eq!(built.parallel_threshold, 50);
    // `with_min_frontier` writes to `min_frontier_for_parallel`.
    assert_eq!(built.min_frontier_for_parallel, 25);
    assert_eq!(built.limit, 500);
    assert_eq!(built.threads, ThreadConfig::Fixed(8));
}
/// `should_parallelize` is false below the configured threshold and true at
/// or above it.
#[test]
fn test_should_parallelize() {
    let config = ParallelConfig::new().with_parallel_threshold(100);

    // (input size, expected decision) around the threshold of 100.
    let cases = [(50, false), (99, false), (100, true), (200, true)];
    for (size, expected) in cases {
        assert_eq!(config.should_parallelize(size), expected);
    }
}
/// `should_parallelize_frontier` is false below the configured minimum
/// frontier size and true at or above it.
#[test]
fn test_should_parallelize_frontier() {
    let config = ParallelConfig::new().with_min_frontier(50);

    for too_small in [25, 49] {
        assert!(!config.should_parallelize_frontier(too_small));
    }
    for large_enough in [50, 100] {
        assert!(config.should_parallelize_frontier(large_enough));
    }
}
/// `ParallelConfig::effective_threads` returns at least 1 for a default
/// config and the exact count when a fixed thread count is set.
#[test]
fn test_effective_threads_from_config() {
    let default_threads = ParallelConfig::new().effective_threads();
    assert!(default_threads >= 1);

    let pinned_threads = ParallelConfig::new()
        .with_fixed_threads(16)
        .effective_threads();
    assert_eq!(pinned_threads, 16);
}
/// With four shards, the expected shard for each sampled node id equals the
/// id modulo 4.
#[test]
fn test_sharded_traverser_shard_assignment() {
    let traverser = ShardedTraverser::new(4);

    // (node id, expected shard)
    let expectations = [(0, 0), (1, 1), (2, 2), (3, 3), (4, 0), (100, 0)];
    for (node, shard) in expectations {
        assert_eq!(traverser.shard_for_node(node), shard);
    }
}
/// Partitioning ids 0..=8 across four shards yields four buckets, each
/// holding its ids in input order.
#[test]
fn test_sharded_traverser_partition() {
    let traverser = ShardedTraverser::new(4);
    let node_ids = vec![0, 1, 2, 3, 4, 5, 6, 7, 8];

    let buckets = traverser.partition_by_shard(&node_ids);

    assert_eq!(buckets.len(), 4);
    assert_eq!(buckets[0], vec![0, 4, 8]);
    assert_eq!(buckets[1], vec![1, 5]);
    assert_eq!(buckets[2], vec![2, 6]);
    assert_eq!(buckets[3], vec![3, 7]);
}
/// A two-shard traversal from node 1 returns a non-empty result set and
/// reports one start node.
#[test]
fn test_sharded_traverser_basic() {
    let graph = create_test_graph();
    let config = ParallelConfig::new()
        .with_max_depth(3)
        .with_parallel_threshold(1);
    let traverser = ShardedTraverser::with_config(2, config);

    let (results, stats) = traverser.traverse_parallel(&[1], |node: u64| {
        graph.get(&node).cloned().unwrap_or_default()
    });

    assert!(!results.is_empty());
    assert_eq!(stats.start_nodes_count, 1);
}
/// A chain 0 -> 1 -> 2 traversed with two shards reaches all three nodes.
/// NOTE(review): given the modulo-style shard assignment tested elsewhere in
/// this file, consecutive ids presumably land on alternating shards, so each
/// hop crosses a shard boundary.
#[test]
fn test_sharded_traverser_cross_shard_edges() {
    let graph: HashMap<u64, Vec<(u64, u64)>> = HashMap::from([
        (0, vec![(1, 100)]),
        (1, vec![(2, 101)]),
        (2, Vec::new()),
    ]);
    let config = ParallelConfig::new()
        .with_max_depth(5)
        .with_parallel_threshold(1);
    let traverser = ShardedTraverser::with_config(2, config);

    let (results, _stats) = traverser.traverse_parallel(&[0], |node: u64| {
        graph.get(&node).cloned().unwrap_or_default()
    });

    assert!(results.len() >= 3);

    let reached: std::collections::HashSet<u64> =
        results.iter().map(|hit| hit.end_node).collect();
    for expected in [0u64, 1, 2] {
        assert!(reached.contains(&expected));
    }
}
/// A four-shard traversal from three start nodes records all three starts
/// and returns at least three results.
#[test]
fn test_sharded_traverser_multiple_start_nodes() {
    let graph = create_test_graph();
    let neighbors_of = |node: u64| graph.get(&node).cloned().unwrap_or_default();
    let config = ParallelConfig::new()
        .with_max_depth(2)
        .with_parallel_threshold(1);
    let traverser = ShardedTraverser::with_config(4, config);

    let start_nodes = [1, 2, 3];
    let (results, stats) = traverser.traverse_parallel(&start_nodes, neighbors_of);

    assert_eq!(stats.start_nodes_count, 3);
    assert!(results.len() >= 3);
}
/// `num_shards` reports the shard count passed to the constructor.
#[test]
fn test_sharded_traverser_num_shards() {
    let shard_count = ShardedTraverser::new(8).num_shards();

    assert_eq!(shard_count, 8);
}
use std::sync::atomic::{AtomicUsize, Ordering};
/// Returns an adjacency closure for a synthetic "wide" graph that also counts
/// how many times it is invoked.
///
/// * Node `1` has `fan_out` children with ids `1_000..1_000 + fan_out`; the
///   second tuple element for child `i` is `10_000 + i`.
/// * Each of those children has exactly one successor
///   `(node + 100_000, node + 200_000)`.
/// * Every other node has no successors.
///
/// Every call, whatever the node, increments `adjacency_calls` by one using
/// relaxed ordering.
fn counting_wide_graph(
    fan_out: u64,
    adjacency_calls: &AtomicUsize,
) -> impl Fn(u64) -> Vec<(u64, u64)> + '_ {
    move |node: u64| {
        adjacency_calls.fetch_add(1, Ordering::Relaxed);
        match node {
            1 => (0..fan_out)
                .map(|offset| (1_000 + offset, 10_000 + offset))
                .collect(),
            child if (1_000..1_000 + fan_out).contains(&child) => {
                vec![(child + 100_000, child + 200_000)]
            }
            _ => Vec::new(),
        }
    }
}
/// With a limit of 10 and a 50,000-wide first level, frontier BFS returns
/// exactly 10 results and makes fewer than 1,000 adjacency calls.
#[test]
fn test_frontier_bfs_bounds_wide_level() {
    let adjacency_calls = AtomicUsize::new(0);
    let config = ParallelConfig::new()
        .with_max_depth(3)
        .with_min_frontier(1)
        .with_limit(10);
    let frontier_bfs = FrontierParallelBFS::with_config(config);

    let (results, _stats) =
        frontier_bfs.traverse(1, counting_wide_graph(50_000, &adjacency_calls));

    assert_eq!(results.len(), 10);

    let observed_calls = adjacency_calls.load(Ordering::Relaxed);
    assert!(
        observed_calls < 1_000,
        "expanded too many nodes: {}",
        observed_calls
    );
}
/// BFS from a single start node with 50,000 children and a limit of 8
/// returns exactly 8 results.
///
/// NOTE(review): the adjacency-call counter is wired in but never asserted
/// here, unlike the sibling "bounds" tests — confirm whether an upper bound
/// on calls was intended.
#[test]
fn test_traverse_single_bounds_high_fanout_node() {
    let adjacency_calls = AtomicUsize::new(0);
    let config = ParallelConfig::new()
        .with_max_depth(3)
        .with_parallel_threshold(100)
        .with_limit(8);
    let traverser = ParallelTraverser::with_config(config);

    let (results, _stats) =
        traverser.bfs_parallel(&[1], counting_wide_graph(50_000, &adjacency_calls));

    assert_eq!(results.len(), 8);
}
/// With a limit of 12 and a 50,000-wide first level, the four-shard traversal
/// returns at most 12 results and makes fewer than 1,000 adjacency calls.
#[test]
fn test_sharded_bounds_wide_level() {
    let adjacency_calls = AtomicUsize::new(0);
    let config = ParallelConfig::new().with_max_depth(3).with_limit(12);
    let traverser = ShardedTraverser::with_config(4, config);

    let (results, _stats) =
        traverser.traverse_parallel(&[1], counting_wide_graph(50_000, &adjacency_calls));

    assert!(results.len() <= 12, "sharded result exceeded limit");

    let observed_calls = adjacency_calls.load(Ordering::Relaxed);
    assert!(
        observed_calls < 1_000,
        "expanded too many nodes: {}",
        observed_calls
    );
}
/// With a limit (1000) far above the fixture graph's size, frontier BFS
/// returns exactly six results.
#[test]
fn test_frontier_bfs_unbounded_matches_full_graph() {
    let graph = create_test_graph();
    let config = ParallelConfig::new().with_max_depth(5).with_limit(1000);
    let frontier_bfs = FrontierParallelBFS::with_config(config);

    let (results, _stats) = frontier_bfs.traverse(1, |node: u64| {
        graph.get(&node).cloned().unwrap_or_default()
    });

    assert_eq!(results.len(), 6);
}
/// Builds a seven-node fixture whose adjacency lists include entries pointing
/// at already-listed nodes and at the node itself.
///
/// Assuming each tuple is `(neighbor, edge id)`: node 1 links to 2 and 3;
/// nodes 2 and 3 each link back to 1, to each other and to themselves before
/// listing two new leaves (4, 5 and 6, 7 respectively).
fn back_edge_graph() -> HashMap<u64, Vec<(u64, u64)>> {
    HashMap::from([
        (1, vec![(2, 12), (3, 13)]),
        (2, vec![(1, 21), (3, 23), (2, 22), (4, 24), (5, 25)]),
        (3, vec![(1, 31), (2, 32), (3, 33), (6, 36), (7, 37)]),
        (4, Vec::new()),
        (5, Vec::new()),
        (6, Vec::new()),
        (7, Vec::new()),
    ])
}
/// On the back-edge fixture, frontier BFS with limit 5 returns exactly five
/// results rather than stopping short of the limit.
#[test]
fn test_frontier_bfs_back_edges_fill_to_limit() {
    let graph = back_edge_graph();
    let config = ParallelConfig::new()
        .with_max_depth(5)
        .with_min_frontier(1)
        .with_limit(5);
    let frontier_bfs = FrontierParallelBFS::with_config(config);

    let (results, _stats) = frontier_bfs.traverse(1, |node: u64| {
        graph.get(&node).cloned().unwrap_or_default()
    });

    assert_eq!(
        results.len(),
        5,
        "frontier BFS under-filled below the limit"
    );
}
/// Repeating the limited frontier BFS on the back-edge fixture 32 times
/// always yields five results.
#[test]
fn test_frontier_bfs_back_edges_deterministic_count() {
    let graph = back_edge_graph();
    let neighbors_of = |node: u64| graph.get(&node).cloned().unwrap_or_default();

    // Collect the distinct result counts seen across all runs.
    let observed_counts: std::collections::HashSet<_> = (0..32)
        .map(|_| {
            let config = ParallelConfig::new()
                .with_max_depth(5)
                .with_min_frontier(1)
                .with_limit(5);
            let frontier_bfs = FrontierParallelBFS::with_config(config);
            let (results, _stats) = frontier_bfs.traverse(1, neighbors_of);
            results.len()
        })
        .collect();

    assert_eq!(observed_counts, std::collections::HashSet::from([5]));
}
/// On the back-edge fixture, the four-shard traversal with limit 5 returns
/// exactly five results rather than stopping short of the limit.
#[test]
fn test_sharded_back_edges_fill_to_limit() {
    let graph = back_edge_graph();
    let config = ParallelConfig::new().with_max_depth(5).with_limit(5);
    let traverser = ShardedTraverser::with_config(4, config);

    let (results, _stats) = traverser.traverse_parallel(&[1], |node: u64| {
        graph.get(&node).cloned().unwrap_or_default()
    });

    assert_eq!(
        results.len(),
        5,
        "sharded traversal under-filled below the limit"
    );
}
/// Repeating the limited four-shard traversal on the back-edge fixture 32
/// times always yields five results.
#[test]
fn test_sharded_back_edges_deterministic_count() {
    let graph = back_edge_graph();
    let neighbors_of = |node: u64| graph.get(&node).cloned().unwrap_or_default();

    let mut seen_lengths = std::collections::HashSet::new();
    for _round in 0..32 {
        let config = ParallelConfig::new().with_max_depth(5).with_limit(5);
        let traverser = ShardedTraverser::with_config(4, config);
        let (results, _stats) = traverser.traverse_parallel(&[1], neighbors_of);
        seen_lengths.insert(results.len());
    }

    // Only one distinct length, and it must be the limit.
    assert_eq!(seen_lengths, std::collections::HashSet::from([5]));
}
/// With a limit (1000) far above the fixture graph's size, the four-shard
/// traversal returns exactly six results.
#[test]
fn test_sharded_unbounded_matches_full_graph() {
    let graph = create_test_graph();
    let config = ParallelConfig::new().with_max_depth(5).with_limit(1000);
    let traverser = ShardedTraverser::with_config(4, config);

    let (results, _stats) = traverser.traverse_parallel(&[1], |node: u64| {
        graph.get(&node).cloned().unwrap_or_default()
    });

    assert_eq!(results.len(), 6);
}