use alloc::vec::Vec;
use core::ops::ControlFlow;
use oxgraph_algo::{
BfsBounds, BfsEpochScratch, breadth_first_search_bounded, reverse_breadth_first_search_bounded,
};
use super::{profile::TraverseMode, session::TraverseSession};
use crate::{
engine::Engine,
error::QueryError,
topology::{OverlayInView, OverlayOutView, OverlayViewFlags},
traverse::TraversalDirection,
};
/// Result produced by a traversal kernel.
///
/// Which variant is produced is decided by the session's traverse mode:
/// collect mode yields the visited node ids, count mode yields only how many
/// nodes were visited.
#[derive(Debug, PartialEq, Eq)]
pub(super) enum KernelOutcome {
    /// Node ids in the order the traversal reported them.
    Nodes(Vec<u32>),
    /// Number of nodes the traversal reported.
    Count(usize),
}
impl KernelOutcome {
    /// Outcome for a traversal that visited nothing.
    ///
    /// Count mode reports zero; collect mode reports an empty node list.
    /// `Vec::new` does not allocate, so this is usable in `const` context.
    const fn empty(mode: TraverseMode) -> Self {
        match mode {
            TraverseMode::Count => Self::Count(0),
            TraverseMode::Collect => Self::Nodes(Vec::new()),
        }
    }
}
/// Accumulates the nodes reported by a BFS callback.
///
/// The visitor always counts visited nodes; it additionally records their
/// ids only when `collect` is set.
struct OutcomeVisitor {
    /// When true, every observed node id is appended to `nodes`.
    collect: bool,
    /// Number of nodes observed so far.
    count: usize,
    /// Observed node ids, in observation order (filled only when `collect`).
    nodes: Vec<u32>,
}
impl OutcomeVisitor {
    /// Records one node reported by the traversal.
    ///
    /// Always bumps the counter and, in collect mode, stores the node id.
    /// Never asks the traversal to stop early: result limits are enforced by
    /// the BFS bounds, not here.
    fn observe(&mut self, node: u32) -> ControlFlow<()> {
        self.count += 1;
        if !self.collect {
            return ControlFlow::Continue(());
        }
        self.nodes.push(node);
        ControlFlow::Continue(())
    }

    /// Converts the accumulated state into the outcome requested by `mode`.
    fn finish(self, mode: TraverseMode) -> KernelOutcome {
        let Self { nodes, count, .. } = self;
        match mode {
            TraverseMode::Count => KernelOutcome::Count(count),
            TraverseMode::Collect => KernelOutcome::Nodes(nodes),
        }
    }
}
/// Runs a bounded, multi-seed breadth-first traversal for `session`.
///
/// The seeds themselves are always part of the result (`include_seeds`), and
/// the search honours the session's `max_depth` and `result_limit`. The
/// traversal walks outgoing edges over an `OverlayOutView` for
/// `TraversalDirection::Out`, and walks in reverse over an `OverlayInView`
/// for `TraversalDirection::In`. Visited nodes are either collected or
/// counted, according to `session.mode`.
///
/// An empty seed list short-circuits to an empty outcome without touching
/// the engine's workspace.
///
/// # Errors
///
/// Returns `QueryError::InternalInvariant` (converted into the crate error)
/// when the bounded BFS reports an error for the engine-provided scratch.
pub(super) fn run_bfs_multi(
    engine: &mut Engine,
    session: &TraverseSession,
) -> Result<KernelOutcome, crate::error::PostgresGraphError> {
    let mode = session.mode;
    if session.seeds.is_empty() {
        return Ok(KernelOutcome::empty(mode));
    }

    let node_count = session.node_count as usize;
    let flags = OverlayViewFlags {
        use_unique: session.profile.use_unique,
        merge_overlay: session.profile.merge_overlay,
        check_nodes: session.check_nodes,
    };
    let bounds = BfsBounds {
        max_depth: session.max_depth,
        result_limit: session.result_limit,
        include_seeds: true,
    };
    let mut visitor = OutcomeVisitor {
        nodes: Vec::new(),
        count: 0,
        collect: mode == TraverseMode::Collect,
    };

    // Split the engine workspace once; the same scratch buffers serve either
    // direction, so the epoch scratch is built before dispatching.
    let (topology, unique, overlay, workspace) = engine.traverse_workspace_mut(flags.use_unique);
    let (marks, queue) = workspace.bounded_slices();
    let mut frontier = BfsEpochScratch::new(marks, queue);

    let search = match session.profile.direction {
        TraversalDirection::Out => breadth_first_search_bounded(
            &OverlayOutView::new(topology, overlay, unique, flags, node_count),
            &session.seeds,
            bounds,
            &mut frontier,
            &mut |visited, _level| visitor.observe(visited),
        ),
        TraversalDirection::In => reverse_breadth_first_search_bounded(
            &OverlayInView::new(topology, overlay, unique, flags, node_count),
            &session.seeds,
            bounds,
            &mut frontier,
            &mut |visited, _level| visitor.observe(visited),
        ),
    };

    match search {
        Ok(_) => Ok(visitor.finish(mode)),
        Err(_) => Err(QueryError::InternalInvariant("bounded BFS rejected engine scratch").into()),
    }
}