use super::EdgeStore;
use rustc_hash::{FxHashMap, FxHashSet};
use smallvec::SmallVec;
use std::collections::VecDeque;
/// Upper depth bound used by `TraversalConfig::default()`.
pub const DEFAULT_MAX_DEPTH: u32 = 3;
/// Depth cap substituted for an "unbounded" range by `TraversalConfig::with_unbounded_range`.
pub const SAFETY_MAX_DEPTH: u32 = 100;
/// Path of ids stored inline for up to 4 entries before spilling to the heap.
/// Converted to a `Vec<u64>` by `TraversalResult::from_smallvec`.
pub type TraversalPath = SmallVec<[u64; 4]>;
/// A single node reached by a traversal, together with how it was reached.
#[derive(Debug, Clone)]
pub struct TraversalResult {
    /// Id of the node that was reached.
    pub target_id: u64,
    /// Ids leading from the traversal source to `target_id`, in walk order
    /// (the BFS in this module fills it with edge ids).
    pub path: Vec<u64>,
    /// Number of hops from the traversal source to `target_id`.
    pub depth: u32,
}

impl TraversalResult {
    /// Creates a result from its parts.
    #[must_use]
    pub fn new(target_id: u64, path: Vec<u64>, depth: u32) -> Self {
        Self {
            target_id,
            path,
            depth,
        }
    }

    /// Creates a result from an inline [`TraversalPath`].
    ///
    /// The path is consumed with `into_vec`. When the `SmallVec` has already
    /// spilled to the heap, its buffer is reused instead of being copied. Taking
    /// `path` by value is therefore meaningful, and no `needless_pass_by_value`
    /// suppression is needed.
    #[must_use]
    #[allow(dead_code)] // crate-internal constructor that may have no caller in every build
    pub(crate) fn from_smallvec(target_id: u64, path: TraversalPath, depth: u32) -> Self {
        Self::new(target_id, path.into_vec(), depth)
    }
}
/// Options controlling a breadth-first traversal.
#[derive(Debug, Clone)]
pub struct TraversalConfig {
    /// Smallest hop count at which a reached node is reported.
    pub min_depth: u32,
    /// Largest hop count the traversal will step to.
    pub max_depth: u32,
    /// Maximum number of results to collect.
    pub limit: usize,
    /// Relationship labels that may be followed; an empty list accepts any label.
    pub rel_types: Vec<String>,
}

impl Default for TraversalConfig {
    /// Depths `1..=DEFAULT_MAX_DEPTH`, at most 100 results, no label filter.
    fn default() -> Self {
        let rel_types = Vec::new();
        Self {
            rel_types,
            limit: 100,
            max_depth: DEFAULT_MAX_DEPTH,
            min_depth: 1,
        }
    }
}

impl TraversalConfig {
    /// Default configuration restricted to depths `min..=max`.
    #[must_use]
    pub fn with_range(min: u32, max: u32) -> Self {
        let base = Self::default();
        Self {
            max_depth: max,
            min_depth: min,
            ..base
        }
    }

    /// Default configuration starting at depth `min` and capped at `SAFETY_MAX_DEPTH`.
    #[must_use]
    pub fn with_unbounded_range(min: u32) -> Self {
        Self::with_range(min, SAFETY_MAX_DEPTH)
    }

    /// Replaces the result limit.
    #[must_use]
    pub fn with_limit(self, limit: usize) -> Self {
        Self { limit, ..self }
    }

    /// Replaces the relationship-label filter.
    #[must_use]
    pub fn with_rel_types(self, types: Vec<String>) -> Self {
        Self {
            rel_types: types,
            ..self
        }
    }

    /// Replaces the maximum traversal depth.
    #[must_use]
    pub fn with_max_depth(self, max_depth: u32) -> Self {
        Self { max_depth, ..self }
    }
}
/// One entry of the BFS frontier: a node waiting to be expanded and its hop
/// distance from the traversal source.
#[derive(Debug)]
pub(super) struct BfsState {
    /// Node whose neighbours will be examined when this entry is dequeued.
    pub(super) node_id: u64,
    /// Hops from the traversal source to `node_id`; the source itself is queued with 0.
    pub(super) depth: u32,
}
/// Rebuilds the path from `source` to `target` by walking parent links backwards.
///
/// `parent_map` maps each discovered node to `(parent_node, edge_id)`, where
/// `edge_id` is the edge through which the node was first reached from its
/// parent. The returned edge ids are ordered from `source` towards `target`.
///
/// Returns an empty path when `target == source`. If the parent chain ends
/// before reaching `source`, the loop stops instead of panicking, and only the
/// trailing part of the path that could be recovered is returned.
#[must_use]
pub(super) fn reconstruct_path(
    target: u64,
    source: u64,
    parent_map: &FxHashMap<u64, (u64, u64)>,
) -> Vec<u64> {
    let mut path = Vec::new();
    let mut current = target;
    while current != source {
        // The original text read `¤t`: `&curren` had been decoded as the HTML
        // entity `&curren;`. The lookup needs a reference to `current`.
        if let Some(&(parent, edge_id)) = parent_map.get(&current) {
            path.push(edge_id);
            current = parent;
        } else {
            break;
        }
    }
    // The ids were collected target -> source, so flip them into walk order.
    path.reverse();
    path
}
/// Which side of each edge a BFS follows when expanding a node.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum BfsDirection {
    /// Follow outgoing edges (`get_outgoing`, or the CSR snapshot when one exists) to their targets.
    Forward,
    /// Follow incoming edges (`get_incoming`) back to their sources.
    Reverse,
}
/// Shared breadth-first search behind `bfs_traverse` and `bfs_traverse_reverse`.
///
/// Starting from `source_id`, nodes are expanded level by level along the
/// edges selected by `direction`. Every newly discovered node whose depth lies
/// in `config.min_depth..=config.max_depth` is reported once. The search stops
/// when `config.limit` results have been collected or the frontier runs out.
/// Forward searches read adjacency from the edge store's CSR snapshot when one
/// is present; all other searches call `get_outgoing` / `get_incoming`.
#[must_use]
fn bfs_traverse_directed(
    edge_store: &EdgeStore,
    source_id: u64,
    config: &TraversalConfig,
    direction: BfsDirection,
) -> Vec<TraversalResult> {
    // An empty filter means every relationship label is accepted.
    let rel_filter: FxHashSet<&str> = config.rel_types.iter().map(String::as_str).collect();

    let mut results = Vec::new();
    let mut parent_map: FxHashMap<u64, (u64, u64)> = FxHashMap::default();
    let mut visited = FxHashSet::default();
    let mut queue = VecDeque::new();

    // Pre-marking the source keeps it from ever being reported as its own result.
    visited.insert(source_id);
    queue.push_back(BfsState {
        node_id: source_id,
        depth: 0,
    });

    let csr_enabled =
        matches!(direction, BfsDirection::Forward) && edge_store.has_csr_snapshot();

    while results.len() < config.limit {
        let state = match queue.pop_front() {
            Some(next) => next,
            None => break,
        };

        if csr_enabled {
            process_bfs_csr(
                edge_store,
                &state,
                config,
                source_id,
                &rel_filter,
                &mut results,
                &mut visited,
                &mut queue,
                &mut parent_map,
            );
            continue;
        }

        let edges = match direction {
            BfsDirection::Forward => edge_store.get_outgoing(state.node_id),
            BfsDirection::Reverse => edge_store.get_incoming(state.node_id),
        };
        process_bfs_neighbors(
            &edges,
            &state,
            config,
            source_id,
            &rel_filter,
            direction,
            &mut results,
            &mut visited,
            &mut queue,
            &mut parent_map,
        );
    }

    results
}
/// Expands `state` using the edge store's CSR snapshot.
///
/// Neighbour ids and edge ids are read side by side from the snapshot. When
/// `rel_filter` is non-empty, an edge is skipped unless it has a label and
/// that label is in the filter. Every accepted edge is handed to
/// `process_bfs_candidate`. Expansion stops early once `config.limit` results
/// exist.
///
/// # Panics
/// Panics if the edge store has no CSR snapshot, so callers must check
/// `has_csr_snapshot` first.
#[inline]
#[allow(clippy::too_many_arguments)] // BFS bookkeeping is threaded through explicitly
fn process_bfs_csr(
    edge_store: &EdgeStore,
    state: &BfsState,
    config: &TraversalConfig,
    source_id: u64,
    rel_filter: &FxHashSet<&str>,
    results: &mut Vec<TraversalResult>,
    visited: &mut FxHashSet<u64>,
    queue: &mut VecDeque<BfsState>,
    parent_map: &mut FxHashMap<u64, (u64, u64)>,
) {
    let snapshot = edge_store
        .csr_snapshot()
        .expect("invariant: CSR snapshot checked before calling process_bfs_csr");
    let node = state.node_id;
    let neighbour_ids = snapshot.neighbors(node);
    let neighbour_edges = snapshot.edge_ids(node);

    for (slot, (&target, &edge_id)) in neighbour_ids.iter().zip(neighbour_edges.iter()).enumerate()
    {
        if results.len() >= config.limit {
            break;
        }

        // The label is always looked up. Without a filter every edge passes;
        // with a filter the edge needs a label that the filter contains.
        let label = snapshot.label_at(node, slot);
        let accepted =
            rel_filter.is_empty() || matches!(label, Some(l) if rel_filter.contains(l));
        if !accepted {
            continue;
        }

        process_bfs_candidate(
            target,
            edge_id,
            node,
            state.depth,
            config,
            source_id,
            results,
            visited,
            queue,
            parent_map,
        );
    }
}
/// Expands `state` using an edge list fetched from the edge store.
///
/// For each edge whose label passes `rel_filter` (an empty filter accepts all
/// labels), the far endpoint is offered to `process_bfs_candidate`. The far
/// endpoint is the edge's target for `BfsDirection::Forward` and its source for
/// `BfsDirection::Reverse`. Expansion stops early once `config.limit` results
/// exist.
#[inline]
#[allow(clippy::too_many_arguments)] // BFS bookkeeping is threaded through explicitly
fn process_bfs_neighbors(
    edges: &[&super::GraphEdge],
    state: &BfsState,
    config: &TraversalConfig,
    source_id: u64,
    rel_filter: &FxHashSet<&str>,
    direction: BfsDirection,
    results: &mut Vec<TraversalResult>,
    visited: &mut FxHashSet<u64>,
    queue: &mut VecDeque<BfsState>,
    parent_map: &mut FxHashMap<u64, (u64, u64)>,
) {
    let filtering = !rel_filter.is_empty();

    for &edge in edges {
        if results.len() >= config.limit {
            break;
        }
        if filtering && !rel_filter.contains(edge.label()) {
            continue;
        }

        let (next_node, edge_id) = match direction {
            BfsDirection::Forward => (edge.target(), edge.id()),
            BfsDirection::Reverse => (edge.source(), edge.id()),
        };

        process_bfs_candidate(
            next_node,
            edge_id,
            state.node_id,
            state.depth,
            config,
            source_id,
            results,
            visited,
            queue,
            parent_map,
        );
    }
}
/// Considers `target`, reached from `parent_node` via `edge_id`, as the next BFS step.
///
/// Nothing happens when the step would exceed `config.max_depth` or when
/// `target` was already visited. In the first case the node is not marked as
/// visited. Otherwise the node is marked visited and its parent link is
/// recorded. It is reported, with its reconstructed edge path, once its depth
/// reaches `config.min_depth`. It is queued for further expansion only while
/// its depth is still below `config.max_depth`.
#[inline]
#[allow(clippy::too_many_arguments)] // BFS bookkeeping is threaded through explicitly
fn process_bfs_candidate(
    target: u64,
    edge_id: u64,
    parent_node: u64,
    current_depth: u32,
    config: &TraversalConfig,
    source_id: u64,
    results: &mut Vec<TraversalResult>,
    visited: &mut FxHashSet<u64>,
    queue: &mut VecDeque<BfsState>,
    parent_map: &mut FxHashMap<u64, (u64, u64)>,
) {
    let depth = current_depth + 1;
    // Short-circuit order matters: the depth check runs before `insert`, so a
    // too-deep candidate is never marked as visited.
    if depth > config.max_depth || !visited.insert(target) {
        return;
    }

    parent_map.insert(target, (parent_node, edge_id));

    if depth >= config.min_depth {
        results.push(TraversalResult::new(
            target,
            reconstruct_path(target, source_id, parent_map),
            depth,
        ));
    }

    // Nodes sitting exactly on the depth limit are reported but never expanded.
    if depth < config.max_depth {
        queue.push_back(BfsState {
            node_id: target,
            depth,
        });
    }
}
/// Breadth-first traversal from `source_id` along outgoing edges.
///
/// Each node reachable from `source_id` is reported at most once, at the depth
/// where BFS first discovers it, provided that depth lies within
/// `config.min_depth..=config.max_depth`. The source itself is never part of
/// the output. Each result's `path` lists the edge ids walked from the source.
/// At most `config.limit` results are returned, in discovery order. A non-empty
/// `config.rel_types` restricts which edge labels may be followed. When the
/// edge store holds a CSR snapshot, adjacency is read from it.
#[must_use]
pub fn bfs_traverse(
    edge_store: &EdgeStore,
    source_id: u64,
    config: &TraversalConfig,
) -> Vec<TraversalResult> {
    let direction = BfsDirection::Forward;
    bfs_traverse_directed(edge_store, source_id, config, direction)
}
/// Breadth-first traversal from `source_id` against edge direction, following
/// incoming edges back to their sources.
///
/// Reporting rules match the forward traversal: each node is reported at most
/// once, at its BFS discovery depth, if that depth is within
/// `config.min_depth..=config.max_depth`. The result count is capped at
/// `config.limit`, and edges are filtered by `config.rel_types` when it is
/// non-empty. Adjacency always comes from `get_incoming`; the CSR snapshot is
/// not consulted.
#[must_use]
pub fn bfs_traverse_reverse(
    edge_store: &EdgeStore,
    source_id: u64,
    config: &TraversalConfig,
) -> Vec<TraversalResult> {
    let direction = BfsDirection::Reverse;
    bfs_traverse_directed(edge_store, source_id, config, direction)
}