use super::edge_concurrent::ConcurrentEdgeStore;
use super::traversal::{reconstruct_path, BfsState};
use super::{EdgeStore, TraversalResult, DEFAULT_MAX_DEPTH};
use rustc_hash::{FxHashMap, FxHashSet};
use std::collections::VecDeque;
/// Limits and filters shared by the streaming BFS iterators in this module.
#[derive(Debug, Clone)]
pub struct StreamingConfig {
    /// Maximum number of hops from the start node. Nodes exactly at this
    /// depth are still yielded, but they are not expanded further.
    pub max_depth: u32,
    /// Upper bound on the number of yielded results; `None` means no bound.
    pub limit: Option<usize>,
    /// Visited-set size at which the set is discarded and duplicate
    /// suppression is switched off for the rest of the traversal.
    pub max_visited_size: usize,
    /// Relationship labels an edge must carry to be followed; an empty list
    /// disables label filtering.
    pub rel_types: Vec<String>,
}
impl StreamingConfig {
    /// Visited-set cap used by [`StreamingConfig::default`].
    pub const DEFAULT_MAX_VISITED_SIZE: usize = 100_000;
}

impl Default for StreamingConfig {
    /// Depth limited to `DEFAULT_MAX_DEPTH`, no result limit, a visited-set
    /// cap of [`StreamingConfig::DEFAULT_MAX_VISITED_SIZE`], and no
    /// relationship-type filter.
    fn default() -> Self {
        Self {
            max_depth: DEFAULT_MAX_DEPTH,
            limit: None,
            max_visited_size: Self::DEFAULT_MAX_VISITED_SIZE,
            rel_types: Vec::new(),
        }
    }
}
impl StreamingConfig {
    /// Returns a copy of this config that stops after `limit` results.
    #[must_use]
    pub fn with_limit(self, limit: usize) -> Self {
        Self {
            limit: Some(limit),
            ..self
        }
    }

    /// Returns a copy of this config with the hop limit set to `max_depth`.
    #[must_use]
    pub fn with_max_depth(self, max_depth: u32) -> Self {
        Self { max_depth, ..self }
    }

    /// Returns a copy of this config whose visited set overflows at `max_visited`.
    #[must_use]
    pub fn with_max_visited(self, max_visited: usize) -> Self {
        Self {
            max_visited_size: max_visited,
            ..self
        }
    }

    /// Returns a copy of this config that only follows edges labelled with one
    /// of `types` (an empty list follows every edge).
    #[must_use]
    pub fn with_rel_types(self, types: Vec<String>) -> Self {
        Self {
            rel_types: types,
            ..self
        }
    }
}
/// Lazy breadth-first traversal over an [`EdgeStore`].
///
/// Built by [`bfs_stream`] or [`BfsIterator::new`]. Each `next` call returns
/// one node reached from the start node, with its path and depth, expanding
/// queued nodes only when the buffer of pending results runs dry.
pub struct BfsIterator<'a> {
    /// Graph being traversed.
    edge_store: &'a EdgeStore,
    /// Frontier of nodes still to expand, in BFS order.
    queue: VecDeque<BfsState>,
    /// Nodes already accepted; cleared and no longer consulted after overflow.
    visited: FxHashSet<u64>,
    /// Limits and filters for this traversal.
    config: StreamingConfig,
    /// `config.rel_types` collected into a set for label lookups.
    rel_types_set: FxHashSet<String>,
    /// Number of results returned so far (checked against `config.limit`).
    yielded: usize,
    /// Set once `visited` reached `config.max_visited_size`.
    visited_overflow: bool,
    /// Results produced by the latest expansion but not yet returned.
    pending_results: VecDeque<TraversalResult>,
    /// `target -> (parent node, edge id)`, used to rebuild result paths.
    parent_map: FxHashMap<u64, (u64, u64)>,
    /// Start node of the traversal.
    source_id: u64,
}
impl<'a> BfsIterator<'a> {
    /// Creates a breadth-first iterator over `edge_store` rooted at `start_id`.
    ///
    /// The start node is marked visited and queued at depth 0; nothing is
    /// expanded until the first call to `next`.
    #[must_use]
    pub fn new(edge_store: &'a EdgeStore, start_id: u64, config: StreamingConfig) -> Self {
        let rel_types_set: FxHashSet<String> = config.rel_types.iter().cloned().collect();
        let mut iter = Self {
            edge_store,
            queue: VecDeque::new(),
            visited: FxHashSet::default(),
            config,
            rel_types_set,
            yielded: 0,
            visited_overflow: false,
            pending_results: VecDeque::new(),
            parent_map: FxHashMap::default(),
            source_id: start_id,
        };
        iter.init_first_level(start_id);
        iter
    }

    /// Seeds the traversal: marks `start_id` visited and enqueues it at depth 0.
    fn init_first_level(&mut self, start_id: u64) {
        self.visited.insert(start_id);
        self.queue.push_back(BfsState {
            node_id: start_id,
            depth: 0,
        });
    }

    /// Number of results returned by `next` so far.
    #[must_use]
    pub fn yielded_count(&self) -> usize {
        self.yielded
    }

    /// Whether the visited set reached `max_visited_size` and was discarded.
    ///
    /// After overflow every reached node is accepted again, so nodes reachable
    /// along several routes may be yielded more than once (bounded by `max_depth`).
    #[must_use]
    pub fn is_visited_overflow(&self) -> bool {
        self.visited_overflow
    }

    /// Current size of the visited set (0 once it has overflowed).
    #[must_use]
    pub fn visited_size(&self) -> usize {
        self.visited.len()
    }

    /// True when `label` is allowed by the relationship-type filter
    /// (an empty filter allows every label).
    #[inline]
    fn label_passes_filter(&self, label: &str) -> bool {
        self.rel_types_set.is_empty() || self.rel_types_set.contains(label)
    }

    /// Decides whether `target` should be processed.
    ///
    /// Before overflow, returns `true` only the first time a node is seen. When
    /// the visited set is already at `max_visited_size`, the set is cleared and
    /// overflow mode starts; from then on every target is accepted.
    #[inline]
    fn try_visit(&mut self, target: u64) -> bool {
        if self.visited_overflow {
            return true;
        }
        if self.visited.contains(&target) {
            return false;
        }
        if self.visited.len() >= self.config.max_visited_size {
            self.visited_overflow = true;
            self.visited.clear();
            return true;
        }
        self.visited.insert(target);
        true
    }

    /// Records the edge through which `target` was first reached.
    ///
    /// Before overflow each node passes `try_visit` once, so this is a plain
    /// insert. After overflow nodes (including the source) can be re-reached via
    /// their own descendants; overwriting the entry then could create a parent
    /// cycle such as `A -> B -> A`, which path reconstruction might never leave.
    /// Keeping only the first parent, and never giving the source one, keeps
    /// `parent_map` a tree rooted at `source_id`.
    fn record_parent(&mut self, target: u64, parent: u64, edge_id: u64) {
        if target != self.source_id {
            self.parent_map.entry(target).or_insert((parent, edge_id));
        }
    }

    /// Processes one edge `parent -> target` reached at `depth`: dedups the
    /// target, records its parent, queues it for expansion if it is still
    /// below `max_depth`, and buffers a result carrying its path.
    fn discover(&mut self, parent: u64, target: u64, edge_id: u64, depth: u32) {
        if !self.try_visit(target) {
            return;
        }
        self.record_parent(target, parent, edge_id);
        if depth < self.config.max_depth {
            self.queue.push_back(BfsState {
                node_id: target,
                depth,
            });
        }
        let path = reconstruct_path(target, self.source_id, &self.parent_map);
        self.pending_results
            .push_back(TraversalResult::new(target, path, depth));
    }

    /// Expands `state` using the edge store's CSR snapshot.
    ///
    /// # Panics
    /// Panics if the edge store has no CSR snapshot; `next` only calls this
    /// after `has_csr_snapshot()` returned `true`.
    fn expand_node_csr(&mut self, state: &BfsState) {
        let snapshot = self
            .edge_store
            .csr_snapshot()
            .expect("invariant: CSR snapshot checked before calling expand_node_csr");
        let new_depth = state.depth + 1;
        if new_depth > self.config.max_depth {
            return;
        }
        let targets = snapshot.neighbors(state.node_id);
        let edge_ids = snapshot.edge_ids(state.node_id);
        for (i, (&target, &eid)) in targets.iter().zip(edge_ids.iter()).enumerate() {
            // Edges for which the snapshot reports no label are followed even
            // when a rel-type filter is set. NOTE(review): confirm intended.
            if let Some(label) = snapshot.label_at(state.node_id, i) {
                if !self.label_passes_filter(label) {
                    continue;
                }
            }
            self.discover(state.node_id, target, eid, new_depth);
        }
    }

    /// Expands `state` via `EdgeStore::get_outgoing` (no CSR snapshot available).
    fn expand_node_legacy(&mut self, state: &BfsState) {
        let new_depth = state.depth + 1;
        if new_depth > self.config.max_depth {
            return;
        }
        let edges = self.edge_store.get_outgoing(state.node_id);
        for edge in edges {
            if self.label_passes_filter(edge.label()) {
                self.discover(state.node_id, edge.target(), edge.id(), new_depth);
            }
        }
    }
}
impl Iterator for BfsIterator<'_> {
    type Item = TraversalResult;

    /// Returns the next discovered node, expanding queued nodes on demand.
    ///
    /// Yields `None` once `limit` results have been returned, or when both the
    /// pending buffer and the frontier are empty.
    fn next(&mut self) -> Option<Self::Item> {
        let limit_reached = matches!(self.config.limit, Some(cap) if self.yielded >= cap);
        if limit_reached {
            return None;
        }
        loop {
            if let Some(ready) = self.pending_results.pop_front() {
                self.yielded += 1;
                return Some(ready);
            }
            // Frontier exhausted and nothing buffered: the traversal is over.
            let frontier_node = self.queue.pop_front()?;
            if self.edge_store.has_csr_snapshot() {
                self.expand_node_csr(&frontier_node);
            } else {
                self.expand_node_legacy(&frontier_node);
            }
        }
    }
}
/// Starts a lazy breadth-first traversal of `edge_store` from `start_id`.
///
/// Convenience wrapper around [`BfsIterator::new`]; results are produced as
/// the returned iterator is consumed.
#[must_use]
pub fn bfs_stream(
    edge_store: &EdgeStore,
    start_id: u64,
    config: StreamingConfig,
) -> BfsIterator<'_> {
    BfsIterator::new(edge_store, start_id, config)
}
/// Lazy breadth-first traversal over a [`ConcurrentEdgeStore`].
///
/// Built by [`concurrent_bfs_stream`] or [`ConcurrentBfsIterator::new`];
/// mirrors [`BfsIterator`] but reads edges through `ConcurrentEdgeStore`.
pub struct ConcurrentBfsIterator<'a> {
    /// Graph being traversed.
    edge_store: &'a ConcurrentEdgeStore,
    /// Frontier of nodes still to expand, in BFS order.
    queue: VecDeque<BfsState>,
    /// Nodes already accepted; cleared and no longer consulted after overflow.
    visited: FxHashSet<u64>,
    /// Limits and filters for this traversal.
    config: StreamingConfig,
    /// `config.rel_types` collected into a set for label lookups.
    rel_types_set: FxHashSet<String>,
    /// Number of results returned so far (checked against `config.limit`).
    yielded: usize,
    /// Set once `visited` reached `config.max_visited_size`.
    visited_overflow: bool,
    /// Results produced by the latest expansion but not yet returned.
    pending_results: VecDeque<TraversalResult>,
    /// `target -> (parent node, edge id)`, used to rebuild result paths.
    parent_map: FxHashMap<u64, (u64, u64)>,
    /// Start node of the traversal.
    source_id: u64,
}
impl<'a> ConcurrentBfsIterator<'a> {
    /// Creates a breadth-first iterator over `edge_store` rooted at `start_id`.
    ///
    /// The start node is marked visited and queued at depth 0; nothing is
    /// expanded until the first call to `next`.
    #[must_use]
    pub fn new(
        edge_store: &'a ConcurrentEdgeStore,
        start_id: u64,
        config: StreamingConfig,
    ) -> Self {
        let mut visited = FxHashSet::default();
        visited.insert(start_id);
        let mut queue = VecDeque::new();
        queue.push_back(BfsState {
            node_id: start_id,
            depth: 0,
        });
        let rel_types_set: FxHashSet<String> = config.rel_types.iter().cloned().collect();
        Self {
            edge_store,
            queue,
            visited,
            config,
            rel_types_set,
            yielded: 0,
            visited_overflow: false,
            pending_results: VecDeque::new(),
            parent_map: FxHashMap::default(),
            source_id: start_id,
        }
    }

    /// Number of results returned by `next` so far.
    #[must_use]
    pub fn yielded_count(&self) -> usize {
        self.yielded
    }

    /// Whether the visited set reached `max_visited_size` and was discarded.
    ///
    /// After overflow every reached node is accepted again, so nodes reachable
    /// along several routes may be yielded more than once (bounded by `max_depth`).
    #[must_use]
    pub fn is_visited_overflow(&self) -> bool {
        self.visited_overflow
    }

    /// Current size of the visited set (0 once it has overflowed).
    #[must_use]
    pub fn visited_size(&self) -> usize {
        self.visited.len()
    }
}
impl ConcurrentBfsIterator<'_> {
    /// True when `label` is allowed by the relationship-type filter
    /// (an empty filter allows every label).
    #[inline]
    fn label_passes_filter(&self, label: &str) -> bool {
        self.rel_types_set.is_empty() || self.rel_types_set.contains(label)
    }

    /// Decides whether `target` should be processed.
    ///
    /// Before overflow, returns `true` only the first time a node is seen. When
    /// the visited set is already at `max_visited_size`, the set is cleared and
    /// overflow mode starts; from then on every target is accepted.
    #[inline]
    fn try_visit(&mut self, target: u64) -> bool {
        if self.visited_overflow {
            return true;
        }
        if self.visited.contains(&target) {
            return false;
        }
        if self.visited.len() >= self.config.max_visited_size {
            self.visited_overflow = true;
            self.visited.clear();
            return true;
        }
        self.visited.insert(target);
        true
    }

    /// Records the edge through which `target` was first reached.
    ///
    /// Before overflow each node passes `try_visit` once, so this is a plain
    /// insert. After overflow nodes (including the source) can be re-reached via
    /// their own descendants; overwriting the entry then could create a parent
    /// cycle such as `A -> B -> A`, which path reconstruction might never leave.
    /// Keeping only the first parent, and never giving the source one, keeps
    /// `parent_map` a tree rooted at `source_id`.
    fn record_parent(&mut self, target: u64, parent: u64, edge_id: u64) {
        if target != self.source_id {
            self.parent_map.entry(target).or_insert((parent, edge_id));
        }
    }

    /// Processes one edge `parent -> target` reached at `depth`: dedups the
    /// target, records its parent, queues it for expansion if it is still
    /// below `max_depth`, and buffers a result carrying its path.
    fn discover(&mut self, parent: u64, target: u64, edge_id: u64, depth: u32) {
        if !self.try_visit(target) {
            return;
        }
        self.record_parent(target, parent, edge_id);
        if depth < self.config.max_depth {
            self.queue.push_back(BfsState {
                node_id: target,
                depth,
            });
        }
        let path = reconstruct_path(target, self.source_id, &self.parent_map);
        self.pending_results
            .push_back(TraversalResult::new(target, path, depth));
    }

    /// Expands `state` by following its outgoing edges that pass the label filter.
    fn expand_node(&mut self, state: &BfsState) {
        let new_depth = state.depth + 1;
        // Nothing reachable from here can be within the depth limit, so skip
        // fetching the edges at all.
        if new_depth > self.config.max_depth {
            return;
        }
        let edges = self.edge_store.get_outgoing(state.node_id);
        for edge in &edges {
            if self.label_passes_filter(edge.label()) {
                self.discover(state.node_id, edge.target(), edge.id(), new_depth);
            }
        }
    }
}
impl Iterator for ConcurrentBfsIterator<'_> {
    type Item = TraversalResult;

    /// Returns the next discovered node, expanding queued nodes on demand.
    ///
    /// Yields `None` once `limit` results have been returned, or when both the
    /// pending buffer and the frontier are empty.
    fn next(&mut self) -> Option<Self::Item> {
        let limit_reached = matches!(self.config.limit, Some(cap) if self.yielded >= cap);
        if limit_reached {
            return None;
        }
        loop {
            if let Some(ready) = self.pending_results.pop_front() {
                self.yielded += 1;
                return Some(ready);
            }
            // Frontier exhausted and nothing buffered: the traversal is over.
            let frontier_node = self.queue.pop_front()?;
            self.expand_node(&frontier_node);
        }
    }
}
/// Starts a lazy breadth-first traversal of a [`ConcurrentEdgeStore`] from `start_id`.
///
/// Convenience wrapper around [`ConcurrentBfsIterator::new`]; results are
/// produced as the returned iterator is consumed.
#[must_use]
pub fn concurrent_bfs_stream(
    edge_store: &ConcurrentEdgeStore,
    start_id: u64,
    config: StreamingConfig,
) -> ConcurrentBfsIterator<'_> {
    ConcurrentBfsIterator::new(edge_store, start_id, config)
}