use super::traversal::BfsState;
use super::{EdgeStore, TraversalResult, DEFAULT_MAX_DEPTH};
use std::collections::{HashSet, VecDeque};
/// Limits and filters that control a streaming breadth-first traversal
/// (see [`BfsIterator`] and [`bfs_stream`]).
#[derive(Debug, Clone)]
pub struct StreamingConfig {
    /// Largest hop count (number of edges from the start node) that is reported.
    pub max_depth: u32,
    /// Upper bound on the number of results yielded; `None` means no bound.
    pub limit: Option<usize>,
    /// Visited-set size at which node de-duplication is abandoned.
    pub max_visited_size: usize,
    /// Edge labels to follow; an empty list follows every edge.
    pub rel_types: Vec<String>,
}
impl Default for StreamingConfig {
    /// Depth capped at [`DEFAULT_MAX_DEPTH`], no result limit, a visited set of
    /// up to 100 000 node ids, and no relationship-type filter.
    fn default() -> Self {
        let max_visited_size = 100_000;
        let rel_types = Vec::new();
        Self {
            rel_types,
            max_visited_size,
            limit: None,
            max_depth: DEFAULT_MAX_DEPTH,
        }
    }
}
impl StreamingConfig {
    /// Returns the config with the result cap set to `limit`.
    #[must_use]
    pub fn with_limit(self, limit: usize) -> Self {
        Self {
            limit: Some(limit),
            ..self
        }
    }

    /// Returns the config with the maximum traversal depth replaced.
    #[must_use]
    pub fn with_max_depth(self, max_depth: u32) -> Self {
        Self { max_depth, ..self }
    }

    /// Returns the config with the visited-set size threshold replaced.
    #[must_use]
    pub fn with_max_visited(self, max_visited: usize) -> Self {
        Self {
            max_visited_size: max_visited,
            ..self
        }
    }

    /// Returns the config with the relationship-type filter replaced by `types`
    /// (an empty list disables filtering).
    #[must_use]
    pub fn with_rel_types(self, types: Vec<String>) -> Self {
        Self {
            rel_types: types,
            ..self
        }
    }
}
/// Lazy breadth-first traversal over an [`EdgeStore`]; nodes are expanded only
/// as results are requested. Build one with [`bfs_stream`] or [`BfsIterator::new`].
pub struct BfsIterator<'a> {
    /// Graph whose outgoing edges are followed.
    edge_store: &'a EdgeStore,
    /// Frontier of nodes still to be expanded, popped front-first.
    queue: VecDeque<BfsState>,
    /// Node ids already seen; no longer consulted once `visited_overflow` is set.
    visited: HashSet<u64>,
    /// Depth, result-count, visited-size and label limits for this traversal.
    config: StreamingConfig,
    /// Number of results returned by `next()` so far.
    yielded: usize,
    /// Set once `visited` reached `config.max_visited_size` and de-duplication stopped.
    visited_overflow: bool,
    /// Results produced by expanding a node but not yet handed out.
    pending_results: VecDeque<TraversalResult>,
}
impl<'a> BfsIterator<'a> {
#[must_use]
pub fn new(edge_store: &'a EdgeStore, start_id: u64, config: StreamingConfig) -> Self {
let mut visited = HashSet::new();
visited.insert(start_id);
let mut queue = VecDeque::new();
queue.push_back(BfsState {
node_id: start_id,
path: Vec::new(),
depth: 0,
});
Self {
edge_store,
queue,
visited,
config,
yielded: 0,
visited_overflow: false,
pending_results: VecDeque::new(),
}
}
#[must_use]
pub fn yielded_count(&self) -> usize {
self.yielded
}
#[must_use]
pub fn is_visited_overflow(&self) -> bool {
self.visited_overflow
}
#[must_use]
pub fn visited_size(&self) -> usize {
self.visited.len()
}
}
impl Iterator for BfsIterator<'_> {
    type Item = TraversalResult;

    /// Returns the next node reached by the breadth-first traversal.
    ///
    /// Each call first enforces `config.limit`, returning `None` once `yielded`
    /// has reached it. It then returns any result already buffered in
    /// `pending_results`. When the buffer is empty, it pops nodes from the queue
    /// and expands all outgoing edges of one node at a time. Each accepted edge
    /// buffers one `TraversalResult`, and this continues until a result is
    /// available or the queue is exhausted.
    ///
    /// An edge is skipped in three cases:
    /// - its label is not in `config.rel_types` (when that list is non-empty);
    /// - it would go past `config.max_depth`;
    /// - its target has already been visited.
    ///
    /// When the visited set reaches `config.max_visited_size`, it is cleared and
    /// de-duplication is switched off (`visited_overflow`). From then on the same
    /// target may be reported more than once. Only the depth limit still bounds
    /// the search.
    fn next(&mut self) -> Option<Self::Item> {
        // Stop as soon as the configured result cap has been reached.
        if let Some(limit) = self.config.limit {
            if self.yielded >= limit {
                return None;
            }
        }
        // Results buffered by an earlier node expansion are handed out first.
        if let Some(result) = self.pending_results.pop_front() {
            self.yielded += 1;
            return Some(result);
        }
        // Expand queued nodes in FIFO order until one produces at least one result.
        while let Some(state) = self.queue.pop_front() {
            let edges = self.edge_store.get_outgoing(state.node_id);
            for edge in edges {
                // Relationship-type filter; an empty list accepts every label.
                // NOTE(review): `to_string()` allocates once per edge. If `label()`
                // returns `&str`, comparing without allocating would be cheaper;
                // confirm its type before changing this.
                if !self.config.rel_types.is_empty()
                    && !self.config.rel_types.contains(&edge.label().to_string())
                {
                    continue;
                }
                let target = edge.target();
                let new_depth = state.depth + 1;
                // Never report a node beyond the configured depth. This only
                // triggers when max_depth == 0, because deeper nodes are never
                // enqueued (see the `< max_depth` check below).
                if new_depth > self.config.max_depth {
                    continue;
                }
                // While de-duplication is active, each node is reported at most once.
                if !self.visited_overflow && self.visited.contains(&target) {
                    continue;
                }
                if !self.visited_overflow {
                    if self.visited.len() >= self.config.max_visited_size {
                        // Cap reached: empty the set and stop de-duplicating for
                        // the rest of the traversal. The target that tripped the
                        // cap is not recorded, but it is still reported and
                        // (depth permitting) enqueued below.
                        // NOTE(review): `HashSet::clear` keeps the allocated
                        // capacity, so this does not return memory. Confirm
                        // whether that is intended.
                        self.visited_overflow = true;
                        self.visited.clear();
                    } else {
                        self.visited.insert(target);
                    }
                }
                // Path to the target = parent's edge path + this edge's id.
                let mut new_path = state.path.clone();
                new_path.push(edge.id());
                // Only enqueue nodes whose children would still be within max_depth.
                if new_depth < self.config.max_depth {
                    self.queue.push_back(BfsState {
                        node_id: target,
                        path: new_path.clone(),
                        depth: new_depth,
                    });
                }
                self.pending_results
                    .push_back(TraversalResult::new(target, new_path, new_depth));
            }
            // Hand out the first result produced by this node's expansion, if any.
            if let Some(result) = self.pending_results.pop_front() {
                self.yielded += 1;
                return Some(result);
            }
        }
        // Queue exhausted and nothing buffered: the traversal is complete.
        None
    }
}
/// Starts a streaming breadth-first traversal from `start_id` over `edge_store`.
///
/// This is a convenience wrapper around [`BfsIterator::new`]. No graph work
/// happens until the returned iterator is advanced.
#[must_use]
pub fn bfs_stream(
    edge_store: &EdgeStore,
    start_id: u64,
    config: StreamingConfig,
) -> BfsIterator<'_> {
    BfsIterator::new(edge_store, start_id, config)
}