use crate::SqliteGraphError;
use crate::backend::native::v3::V3Backend;
use crate::backend::{BackendDirection, GraphBackend, NeighborQuery};
use crate::snapshot::SnapshotId;
use rayon::prelude::*;
use std::collections::{HashMap, HashSet, VecDeque};
/// Tuning knobs for [`parallel_bfs`].
#[derive(Debug, Clone)]
pub struct BfsConfig {
    /// Maximum rayon worker threads; `None` uses the global rayon pool as-is.
    pub max_threads: Option<usize>,
    /// Graphs with fewer nodes than this fall back to the sequential BFS,
    /// where parallel scheduling overhead would outweigh any benefit.
    pub min_parallel_size: usize,
}
impl Default for BfsConfig {
fn default() -> Self {
Self {
max_threads: None, min_parallel_size: 1000, }
}
}
/// Outcome of a BFS traversal.
#[derive(Debug, Clone)]
pub struct BfsResult {
    /// Node ids in the order they were first visited (start node first).
    pub visited_order: Vec<i64>,
    /// Hop distance from the start node for every visited node.
    pub distances: HashMap<i64, usize>,
    /// Total number of distinct nodes visited; always equals
    /// `visited_order.len()` by construction.
    pub total_visited: usize,
}
impl BfsResult {
fn new() -> Self {
Self {
visited_order: Vec::new(),
distances: HashMap::new(),
total_visited: 0,
}
}
fn add_visit(&mut self, node: i64, distance: usize) {
self.visited_order.push(node);
self.distances.insert(node, distance);
self.total_visited += 1;
}
}
/// Per-worker accumulator: nodes discovered while expanding one chunk of the
/// current BFS frontier, merged sequentially after the parallel step.
#[derive(Debug)]
struct ChunkResult {
    /// Newly discovered nodes in the order this chunk found them.
    new_nodes: Vec<i64>,
    /// Hop distance recorded for each node in `new_nodes`.
    distances: HashMap<i64, usize>,
}
impl ChunkResult {
fn new() -> Self {
Self {
new_nodes: Vec::new(),
distances: HashMap::new(),
}
}
fn add_node(&mut self, node: i64, distance: usize) {
self.new_nodes.push(node);
self.distances.insert(node, distance);
}
}
/// Breadth-first traversal from `start`, parallelized per frontier level.
///
/// Small graphs (below `min_parallel_size` nodes) use a plain sequential
/// BFS instead. Returns an error if the start node does not exist, or if a
/// dedicated thread pool was requested but could not be built.
pub fn parallel_bfs(
    graph: &V3Backend,
    start: i64,
    config: Option<BfsConfig>,
) -> Result<BfsResult, SqliteGraphError> {
    let cfg = config.unwrap_or_default();
    let snapshot = SnapshotId::current();

    // Validate the start node up front so callers get a clear error rather
    // than an empty traversal.
    if graph.get_node(snapshot, start).is_err() {
        return Err(SqliteGraphError::not_found(format!(
            "Start node {} not found",
            start
        )));
    }

    // For small graphs the parallel machinery costs more than it saves.
    if graph.header().node_count < cfg.min_parallel_size as u64 {
        return sequential_bfs(graph, start);
    }

    match cfg.max_threads {
        Some(threads) => {
            // Run inside a dedicated pool so the caller-requested thread cap
            // applies to all rayon work spawned by the traversal.
            let pool = rayon::ThreadPoolBuilder::new()
                .num_threads(threads)
                .build()
                .map_err(|e| {
                    SqliteGraphError::connection(format!("Failed to create thread pool: {}", e))
                })?;
            pool.install(|| parallel_bfs_impl(graph, start, &cfg))
        }
        None => parallel_bfs_impl(graph, start, &cfg),
    }
}
/// Splits `nodes` into at most `num_chunks` contiguous, roughly equal-sized
/// slices for parallel processing.
///
/// Degenerate inputs (zero chunks, an empty slice, or no more nodes than
/// chunks) yield a single chunk covering the whole slice, so callers always
/// receive at least one (possibly empty) chunk.
fn partition_nodes(nodes: &[i64], num_chunks: usize) -> Vec<&[i64]> {
    if num_chunks == 0 || nodes.is_empty() || nodes.len() <= num_chunks {
        return vec![nodes];
    }
    // Ceiling division: every chunk is `chunk_size` long except possibly the
    // last, which holds the remainder.
    let chunk_size = (nodes.len() + num_chunks - 1) / num_chunks;
    nodes.chunks(chunk_size).collect()
}
fn parallel_bfs_impl(
graph: &V3Backend,
start: i64,
_config: &BfsConfig,
) -> Result<BfsResult, SqliteGraphError> {
let snapshot = SnapshotId::current();
let mut result = BfsResult::new();
let mut visited: HashSet<i64> = HashSet::new();
let mut current_level: Vec<i64> = vec![start];
let mut distance = 0;
visited.insert(start);
result.add_visit(start, distance);
let num_cpus = std::cmp::min(
std::thread::available_parallelism()
.map(|n| n.get())
.unwrap_or(4),
4, );
while !current_level.is_empty() {
distance += 1;
let chunks = partition_nodes(¤t_level, num_cpus);
let chunk_results: Vec<ChunkResult> = chunks
.into_par_iter() .map(|chunk| {
let mut local_result = ChunkResult::new();
let mut local_visited: HashSet<i64> = HashSet::new();
for &node in chunk {
let query = NeighborQuery {
direction: BackendDirection::Outgoing,
edge_type: None,
};
if let Ok(neighbors) = graph.neighbors(snapshot, node, query) {
for neighbor in neighbors {
if !visited.contains(&neighbor) {
if local_visited.insert(neighbor) {
local_result.add_node(neighbor, distance);
}
}
}
}
}
local_result })
.collect();
let mut next_level: Vec<i64> = Vec::new();
for chunk_result in chunk_results {
for (node, dist) in chunk_result.distances {
if visited.insert(node) {
result.add_visit(node, dist);
next_level.push(node);
}
}
}
current_level = next_level;
}
Ok(result)
}
/// Single-threaded queue-based BFS from `start`; used when the graph is too
/// small for the parallel path to pay off.
fn sequential_bfs(graph: &V3Backend, start: i64) -> Result<BfsResult, SqliteGraphError> {
    let snapshot = SnapshotId::current();
    let mut result = BfsResult::new();
    let mut seen: HashSet<i64> = HashSet::new();
    let mut frontier: VecDeque<(i64, usize)> = VecDeque::new();

    seen.insert(start);
    result.add_visit(start, 0);
    frontier.push_back((start, 0));

    while let Some((node, dist)) = frontier.pop_front() {
        let query = NeighborQuery {
            direction: BackendDirection::Outgoing,
            edge_type: None,
        };
        // Best-effort: a failed neighbor lookup skips this node rather than
        // aborting the whole traversal.
        match graph.neighbors(snapshot, node, query) {
            Ok(neighbors) => {
                for neighbor in neighbors {
                    // `insert` returns true only on first sight, so each node
                    // is enqueued and recorded exactly once.
                    if seen.insert(neighbor) {
                        result.add_visit(neighbor, dist + 1);
                        frontier.push_back((neighbor, dist + 1));
                    }
                }
            }
            Err(_) => {}
        }
    }
    Ok(result)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::native::v3::V3Backend;
    use crate::backend::{EdgeSpec, NodeSpec};
    use tempfile::TempDir;

    // Fresh backend over a temp-dir sqlite file; TempDir must be kept alive
    // by the caller for the backend's lifetime.
    fn create_test_backend() -> (V3Backend, TempDir) {
        let temp_dir = TempDir::new().unwrap();
        let db_path = temp_dir.path().join("test.db");
        let backend = V3Backend::create(&db_path).unwrap();
        (backend, temp_dir)
    }

    // Builds a linear chain node_1 -> node_2 -> ... -> node_n and returns the
    // node ids in chain order.
    // NOTE(review): `node_ids.len() - 1` underflows (panics) for n == 0;
    // all callers in this module pass n >= 5.
    fn create_chain_graph(backend: &V3Backend, n: i64) -> Vec<i64> {
        let mut node_ids = Vec::new();
        for i in 1..=n {
            let node = NodeSpec {
                kind: "test_node".to_string(),
                name: format!("node_{}", i),
                file_path: None,
                data: serde_json::json!(null),
            };
            let id = backend.insert_node(node).unwrap();
            node_ids.push(id);
        }
        for i in 0..node_ids.len() - 1 {
            let edge = EdgeSpec {
                from: node_ids[i],
                to: node_ids[i + 1],
                edge_type: "test_edge".to_string(),
                data: serde_json::json!(null),
            };
            backend.insert_edge(edge).unwrap();
        }
        node_ids
    }

    // On a 10-node chain, BFS must visit every node exactly once, in chain
    // order, with distance equal to chain position.
    #[test]
    fn test_parallel_bfs_chain_graph() {
        let (backend, _temp_dir) = create_test_backend();
        let node_ids = create_chain_graph(&backend, 10);
        let result = parallel_bfs(&backend, node_ids[0], None).unwrap();
        assert_eq!(result.total_visited, 10);
        assert_eq!(result.visited_order.len(), 10);
        assert_eq!(result.distances[&node_ids[0]], 0);
        assert_eq!(result.distances[&node_ids[1]], 1);
        assert_eq!(result.distances[&node_ids[9]], 9);
        for (i, &node_id) in result.visited_order.iter().enumerate() {
            assert_eq!(node_id, node_ids[i]);
        }
    }

    // A nonexistent start node must be reported as an error, not an empty
    // traversal.
    #[test]
    fn test_parallel_bfs_nonexistent_start() {
        let (backend, _temp_dir) = create_test_backend();
        let result = parallel_bfs(&backend, 9999, None);
        assert!(result.is_err());
    }

    // 5 nodes < min_parallel_size(1000) exercises the sequential fallback
    // path inside parallel_bfs.
    #[test]
    fn test_parallel_bfs_sequential_fallback() {
        let (backend, _temp_dir) = create_test_backend();
        let node_ids = create_chain_graph(&backend, 5);
        let config = BfsConfig {
            max_threads: None,
            min_parallel_size: 1000,
        };
        let result = parallel_bfs(&backend, node_ids[0], Some(config)).unwrap();
        assert_eq!(result.total_visited, 5);
        assert_eq!(result.visited_order.len(), 5);
    }

    // Pins the documented defaults of BfsConfig.
    #[test]
    fn test_bfs_config_default() {
        let config = BfsConfig::default();
        assert_eq!(config.max_threads, None);
        assert_eq!(config.min_parallel_size, 1000);
    }

    // A fresh BfsResult has all three fields consistently empty/zero.
    #[test]
    fn test_bfs_result_empty() {
        let result = BfsResult::new();
        assert_eq!(result.total_visited, 0);
        assert!(result.visited_order.is_empty());
        assert!(result.distances.is_empty());
    }

    // Diamond 1 -> {2, 3} -> 4: node 4 must be visited once at distance 2,
    // and nodes 2/3 may appear in either order within level 1.
    #[test]
    fn test_parallel_bfs_diamond_graph() {
        let (backend, _temp_dir) = create_test_backend();
        let node1 = backend
            .insert_node(NodeSpec {
                kind: "test".to_string(),
                name: "1".to_string(),
                file_path: None,
                data: serde_json::json!(null),
            })
            .unwrap();
        let node2 = backend
            .insert_node(NodeSpec {
                kind: "test".to_string(),
                name: "2".to_string(),
                file_path: None,
                data: serde_json::json!(null),
            })
            .unwrap();
        let node3 = backend
            .insert_node(NodeSpec {
                kind: "test".to_string(),
                name: "3".to_string(),
                file_path: None,
                data: serde_json::json!(null),
            })
            .unwrap();
        let node4 = backend
            .insert_node(NodeSpec {
                kind: "test".to_string(),
                name: "4".to_string(),
                file_path: None,
                data: serde_json::json!(null),
            })
            .unwrap();
        backend
            .insert_edge(EdgeSpec {
                from: node1,
                to: node2,
                edge_type: "edge".to_string(),
                data: serde_json::json!(null),
            })
            .unwrap();
        backend
            .insert_edge(EdgeSpec {
                from: node1,
                to: node3,
                edge_type: "edge".to_string(),
                data: serde_json::json!(null),
            })
            .unwrap();
        backend
            .insert_edge(EdgeSpec {
                from: node2,
                to: node4,
                edge_type: "edge".to_string(),
                data: serde_json::json!(null),
            })
            .unwrap();
        backend
            .insert_edge(EdgeSpec {
                from: node3,
                to: node4,
                edge_type: "edge".to_string(),
                data: serde_json::json!(null),
            })
            .unwrap();
        let result = parallel_bfs(&backend, node1, None).unwrap();
        assert_eq!(result.total_visited, 4);
        assert_eq!(result.distances[&node1], 0);
        assert_eq!(result.distances[&node2], 1);
        assert_eq!(result.distances[&node3], 1);
        assert_eq!(result.distances[&node4], 2);
        assert_eq!(result.visited_order[0], node1);
        // Level-1 order is unspecified; accept either 2,3 or 3,2.
        assert!(result.visited_order[1..3].contains(&node2));
        assert!(result.visited_order[1..3].contains(&node3));
        assert_eq!(result.visited_order[3], node4);
    }

    // An empty input still yields exactly one (empty) chunk.
    #[test]
    fn test_partition_nodes_empty() {
        let nodes: Vec<i64> = vec![];
        let chunks = partition_nodes(&nodes, 4);
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].len(), 0);
    }

    // Fewer nodes than requested chunks collapses to a single chunk.
    #[test]
    fn test_partition_nodes_single() {
        let nodes = vec![1, 2, 3];
        let chunks = partition_nodes(&nodes, 4);
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0], &[1, 2, 3]);
    }

    // 6 nodes over 3 chunks splits evenly into pairs.
    #[test]
    fn test_partition_nodes_even() {
        let nodes = vec![1, 2, 3, 4, 5, 6];
        let chunks = partition_nodes(&nodes, 3);
        assert_eq!(chunks.len(), 3);
        assert_eq!(chunks[0], &[1, 2]);
        assert_eq!(chunks[1], &[3, 4]);
        assert_eq!(chunks[2], &[5, 6]);
    }

    // 5 nodes over 3 chunks: ceiling division puts the remainder in the
    // final, shorter chunk.
    #[test]
    fn test_partition_nodes_uneven() {
        let nodes = vec![1, 2, 3, 4, 5];
        let chunks = partition_nodes(&nodes, 3);
        assert_eq!(chunks.len(), 3);
        assert_eq!(chunks[0], &[1, 2]);
        assert_eq!(chunks[1], &[3, 4]);
        assert_eq!(chunks[2], &[5]);
    }
}