use super::edge_concurrent::ConcurrentEdgeStore;
use super::traversal::{deadline_reached, reconstruct_path, BfsState, DEADLINE_CHECK_INTERVAL};
use super::{EdgeStore, TraversalResult, DEFAULT_MAX_DEPTH};
use rustc_hash::{FxHashMap, FxHashSet};
use std::collections::VecDeque;
use std::time::Instant;
/// Default cap on the number of node ids kept in the BFS visited set.
///
/// [`StreamingConfig::default`] uses this value for `max_visited_size`.
pub const MAX_VISITED_SIZE: usize = 100_000;
/// Limits and filters applied to a streaming BFS traversal.
#[derive(Debug, Clone)]
pub struct StreamingConfig {
/// Deepest level that is reported. Candidates that would land beyond this
/// depth are dropped, and nodes found exactly at this depth are yielded but
/// not expanded further.
pub max_depth: u32,
/// Maximum number of results to yield; `None` means no cap.
pub limit: Option<usize>,
/// Size at which the visited set is abandoned. Once it holds this many ids
/// the traversal stops deduplicating nodes.
pub max_visited_size: usize,
/// Edge labels to follow. An empty list lets every label through.
pub rel_types: Vec<String>,
/// Optional point in time after which the traversal stops producing results.
pub deadline: Option<Instant>,
}
/// Baseline configuration: `DEFAULT_MAX_DEPTH` levels, a visited set capped at
/// `MAX_VISITED_SIZE`, and no result limit, label filter or deadline.
impl Default for StreamingConfig {
    fn default() -> Self {
        // No label restriction by default.
        let rel_types = Vec::new();
        Self {
            deadline: None,
            limit: None,
            rel_types,
            max_visited_size: MAX_VISITED_SIZE,
            max_depth: DEFAULT_MAX_DEPTH,
        }
    }
}
/// Chainable setters; each consumes the config and returns it with exactly one
/// field replaced.
impl StreamingConfig {
    /// Caps the number of results the traversal will yield.
    #[must_use]
    pub fn with_limit(self, limit: usize) -> Self {
        Self {
            limit: Some(limit),
            ..self
        }
    }

    /// Sets the deepest level that will be reported.
    #[must_use]
    pub fn with_max_depth(self, max_depth: u32) -> Self {
        Self { max_depth, ..self }
    }

    /// Sets the visited-set size at which deduplication is abandoned.
    #[must_use]
    pub fn with_max_visited(self, max_visited: usize) -> Self {
        Self {
            max_visited_size: max_visited,
            ..self
        }
    }

    /// Restricts the traversal to the given edge labels (empty = all labels).
    #[must_use]
    pub fn with_rel_types(self, types: Vec<String>) -> Self {
        Self {
            rel_types: types,
            ..self
        }
    }

    /// Sets the instant after which the traversal stops.
    #[must_use]
    pub fn with_deadline(self, deadline: Instant) -> Self {
        Self {
            deadline: Some(deadline),
            ..self
        }
    }
}
/// Mutable traversal state shared by the streaming BFS iterators in this file.
struct BfsBookkeeping {
/// Caller-supplied limits and filters.
config: StreamingConfig,
/// Nodes waiting to be expanded, consumed front-to-back.
queue: VecDeque<BfsState>,
/// Ids of nodes already accepted; emptied when the size cap is hit.
visited: FxHashSet<u64>,
/// `config.rel_types` collected into a set for label lookups.
rel_types_set: FxHashSet<String>,
/// Set once `visited` reached `config.max_visited_size`; from then on every
/// candidate is accepted without a visited check.
visited_overflow: bool,
/// Results produced by an expansion that have not been handed out yet.
pending_results: VecDeque<TraversalResult>,
/// For each accepted target: the node it was reached from and the edge id used.
parent_map: FxHashMap<u64, (u64, u64)>,
/// Id of the node the traversal started from.
source_id: u64,
/// Number of results handed out so far.
yielded: usize,
/// Counter handed to `deadline_reached` on every dequeued node.
nodes_since_check: u32,
}
/// Traversal engine behind the streaming BFS iterators: it owns the frontier,
/// the visited set and the buffer of results waiting to be yielded.
impl BfsBookkeeping {
    /// Builds the state for a traversal rooted at `start_id`.
    ///
    /// The start node is marked visited and queued at depth 0. Results are only
    /// produced for edge targets, so nothing is emitted until the first expansion.
    fn new(start_id: u64, config: StreamingConfig) -> Self {
        let rel_types_set: FxHashSet<String> = config.rel_types.iter().cloned().collect();

        let mut visited = FxHashSet::default();
        visited.insert(start_id);

        let mut queue = VecDeque::new();
        queue.push_back(BfsState {
            node_id: start_id,
            depth: 0,
        });

        Self {
            config,
            queue,
            visited,
            rel_types_set,
            visited_overflow: false,
            pending_results: VecDeque::new(),
            parent_map: FxHashMap::default(),
            source_id: start_id,
            yielded: 0,
            // NOTE(review): presumably seeded with the full interval so the very
            // first `deadline_reached` call samples the clock; confirm in `traversal`.
            nodes_since_check: DEADLINE_CHECK_INTERVAL,
        }
    }

    /// Returns `true` when `label` is allowed by the configured relationship
    /// types. An empty filter allows every label.
    #[inline]
    fn label_passes_filter(&self, label: &str) -> bool {
        self.rel_types_set.is_empty() || self.rel_types_set.contains(label)
    }

    /// Decides whether `target` should be accepted, recording it as visited
    /// while the visited set is still in use.
    ///
    /// Returns `false` only for a node that is already in the visited set.
    /// Once the set has reached `config.max_visited_size` the traversal switches
    /// to overflow mode: the set is dropped and every later candidate is
    /// accepted, so the same node may be accepted more than once.
    #[inline]
    fn try_visit(&mut self, target: u64) -> bool {
        if self.visited_overflow {
            return true;
        }
        if self.visited.contains(&target) {
            return false;
        }
        if self.visited.len() >= self.config.max_visited_size {
            self.visited_overflow = true;
            // Replace the set instead of calling `clear()`: `clear()` keeps the
            // table's capacity, so the memory the cap is meant to bound would
            // stay allocated even though the set is never consulted again.
            self.visited = FxHashSet::default();
            return true;
        }
        self.visited.insert(target);
        true
    }

    /// Considers the edge `parent_id --edge_id--> target` found while expanding
    /// a node at `parent_depth`.
    ///
    /// A candidate is dropped when its label is rejected by the filter, when it
    /// would exceed `config.max_depth`, or when `try_visit` refuses it.
    /// Otherwise its parent link is recorded, it is queued for expansion if it
    /// sits above the depth limit, and a result carrying the reconstructed path
    /// is buffered.
    fn process_candidate(
        &mut self,
        parent_id: u64,
        target: u64,
        edge_id: u64,
        parent_depth: u32,
        label: Option<&str>,
    ) {
        // NOTE(review): a candidate without a label (`None`) bypasses the
        // relationship-type filter entirely; confirm that is intended.
        if let Some(label) = label {
            if !self.label_passes_filter(label) {
                return;
            }
        }

        let new_depth = parent_depth + 1;
        if new_depth > self.config.max_depth {
            return;
        }
        if !self.try_visit(target) {
            return;
        }

        // NOTE(review): in overflow mode a node can be accepted again, which
        // overwrites its earlier parent entry; verify `reconstruct_path`
        // tolerates the resulting links.
        self.parent_map.insert(target, (parent_id, edge_id));

        // Nodes exactly at the depth limit are reported but never expanded.
        if new_depth < self.config.max_depth {
            self.queue.push_back(BfsState {
                node_id: target,
                depth: new_depth,
            });
        }

        let path = reconstruct_path(target, self.source_id, &self.parent_map);
        self.pending_results
            .push_back(TraversalResult::new(target, path, new_depth));
    }

    /// Pops the oldest buffered result, counting it as yielded.
    #[inline]
    fn next_pending(&mut self) -> Option<TraversalResult> {
        let result = self.pending_results.pop_front()?;
        self.yielded += 1;
        Some(result)
    }

    /// Produces the next traversal result, expanding queued nodes on demand.
    ///
    /// `expand` is called with this state and the node being expanded and is
    /// expected to feed that node's outgoing edges to `process_candidate`.
    ///
    /// Returns `None` when the result limit has been reached, when
    /// `deadline_reached` reports the deadline as passed (the remaining
    /// frontier is discarded), or when the frontier is exhausted.
    fn drive(
        &mut self,
        mut expand: impl FnMut(&mut Self, &BfsState),
    ) -> Option<TraversalResult> {
        if self.config.limit.is_some_and(|limit| self.yielded >= limit) {
            return None;
        }
        // Drain results buffered by an earlier expansion before doing new work.
        if let Some(result) = self.next_pending() {
            return Some(result);
        }

        let deadline = self.config.deadline;
        while let Some(state) = self.queue.pop_front() {
            if deadline_reached(deadline, &mut self.nodes_since_check) {
                // Drop the frontier so later calls also return `None`.
                self.queue.clear();
                return None;
            }
            expand(self, &state);
            if let Some(result) = self.next_pending() {
                return Some(result);
            }
        }
        None
    }
}
/// Expands `state` using the store's CSR snapshot.
///
/// Each neighbor of `state.node_id` is paired with the edge id at the same
/// position and the label looked up by that position, and the triple is offered
/// to `core.process_candidate`. Does nothing when `csr_snapshot()` returns `None`.
fn expand_csr(edge_store: &EdgeStore, core: &mut BfsBookkeeping, state: &BfsState) {
    if let Some(snapshot) = edge_store.csr_snapshot() {
        let node = state.node_id;
        let neighbor_ids = snapshot.neighbors(node);
        let neighbor_edges = snapshot.edge_ids(node);
        let pairs = neighbor_ids.iter().zip(neighbor_edges.iter());
        for (slot, (neighbor, edge_id)) in pairs.enumerate() {
            let label = snapshot.label_at(node, slot);
            core.process_candidate(node, *neighbor, *edge_id, state.depth, label);
        }
    }
}
/// Expands `state` by walking the store's outgoing edges for `state.node_id`
/// and offering each one, with its label, to `core.process_candidate`.
fn expand_legacy(edge_store: &EdgeStore, core: &mut BfsBookkeeping, state: &BfsState) {
    let origin = state.node_id;
    let depth = state.depth;
    for edge in edge_store.get_outgoing(origin) {
        core.process_candidate(origin, edge.target(), edge.id(), depth, Some(edge.label()));
    }
}
/// Expands `state` against a [`ConcurrentEdgeStore`]: fetches the outgoing edges
/// of `state.node_id` and offers each one, with its label, to
/// `core.process_candidate`.
fn expand_concurrent(
    edge_store: &ConcurrentEdgeStore,
    core: &mut BfsBookkeeping,
    state: &BfsState,
) {
    let origin = state.node_id;
    let depth = state.depth;
    let outgoing = edge_store.get_outgoing(origin);
    for edge in &outgoing {
        core.process_candidate(origin, edge.target(), edge.id(), depth, Some(edge.label()));
    }
}
/// Lazy breadth-first traversal over an [`EdgeStore`], yielding one
/// `TraversalResult` per accepted edge target.
pub struct BfsIterator<'a> {
/// Store the edges are read from.
edge_store: &'a EdgeStore,
/// Frontier, visited set and buffered results.
core: BfsBookkeeping,
}
impl<'a> BfsIterator<'a> {
    /// Creates a traversal over `edge_store` rooted at `start_id`.
    #[must_use]
    pub fn new(edge_store: &'a EdgeStore, start_id: u64, config: StreamingConfig) -> Self {
        let core = BfsBookkeeping::new(start_id, config);
        Self { edge_store, core }
    }

    /// Number of results handed out so far.
    #[must_use]
    pub fn yielded_count(&self) -> usize {
        let bookkeeping = &self.core;
        bookkeeping.yielded
    }

    /// `true` once the visited set hit its size cap and deduplication was
    /// switched off; from then on a node may be yielded more than once.
    #[must_use]
    pub fn is_visited_overflow(&self) -> bool {
        let bookkeeping = &self.core;
        bookkeeping.visited_overflow
    }

    /// Number of ids currently held in the visited set (0 after overflow,
    /// because the set is emptied at that point).
    #[must_use]
    pub fn visited_size(&self) -> usize {
        let seen = &self.core.visited;
        seen.len()
    }
}
/// Yields traversal results one at a time. Each expanded node reads its edges
/// from the CSR snapshot when the store reports one, otherwise from the store's
/// regular outgoing-edge lookup.
impl Iterator for BfsIterator<'_> {
    type Item = TraversalResult;

    fn next(&mut self) -> Option<Self::Item> {
        // Copy the shared reference out so the closure does not borrow `self`
        // while `self.core` is mutably borrowed.
        let store = self.edge_store;
        self.core.drive(|core, state| {
            if !store.has_csr_snapshot() {
                expand_legacy(store, core, state);
                return;
            }
            expand_csr(store, core, state);
        })
    }
}
#[must_use]
pub fn bfs_stream(
edge_store: &EdgeStore,
start_id: u64,
config: StreamingConfig,
) -> BfsIterator<'_> {
BfsIterator::new(edge_store, start_id, config)
}
/// Lazy breadth-first traversal over a [`ConcurrentEdgeStore`], yielding one
/// `TraversalResult` per accepted edge target.
pub struct ConcurrentBfsIterator<'a> {
/// Store the edges are read from.
edge_store: &'a ConcurrentEdgeStore,
/// Frontier, visited set and buffered results.
core: BfsBookkeeping,
}
impl<'a> ConcurrentBfsIterator<'a> {
    /// Creates a traversal over `edge_store` rooted at `start_id`.
    #[must_use]
    pub fn new(
        edge_store: &'a ConcurrentEdgeStore,
        start_id: u64,
        config: StreamingConfig,
    ) -> Self {
        let core = BfsBookkeeping::new(start_id, config);
        Self { edge_store, core }
    }
}
/// Yields traversal results one at a time, reading each expanded node's
/// outgoing edges from the concurrent store.
impl Iterator for ConcurrentBfsIterator<'_> {
    type Item = TraversalResult;

    fn next(&mut self) -> Option<Self::Item> {
        // Copy the shared reference out so the closure does not borrow `self`
        // while `self.core` is mutably borrowed.
        let store = self.edge_store;
        self.core.drive(|core, state| {
            expand_concurrent(store, core, state);
        })
    }
}
#[must_use]
pub fn concurrent_bfs_stream(
edge_store: &ConcurrentEdgeStore,
start_id: u64,
config: StreamingConfig,
) -> ConcurrentBfsIterator<'_> {
ConcurrentBfsIterator::new(edge_store, start_id, config)
}