use core::ops::ControlFlow;
use oxgraph_topology::{
ContainsElement, ElementId, ElementIndex, ElementPredecessors, ElementSuccessors, TopologyBase,
};
use crate::bfs::{BfsError, epoch::BfsEpochScratch};
/// Limits that shape a bounded breadth-first search.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct BfsBounds {
    /// Depth cap for expansion. Elements whose depth is strictly below the cap have their
    /// neighbors explored. Elements at exactly the cap are still reported but not expanded.
    /// `None` removes the cap.
    pub max_depth: Option<u32>,
    /// Emission budget for the visitor. The search stops once the number of reported
    /// elements reaches this value.
    pub result_limit: usize,
    /// When `true`, each distinct seed is reported to the visitor at depth 0.
    pub include_seeds: bool,
}

impl BfsBounds {
    /// Reports whether an element sitting at `depth` may have its neighbors explored.
    ///
    /// This holds when `depth` is strictly below `max_depth`, and always when there is no cap.
    const fn expands(self, depth: u32) -> bool {
        if let Some(cap) = self.max_depth {
            depth < cap
        } else {
            true
        }
    }
}
/// Callback that receives every element a bounded BFS reports.
///
/// Each call gets the element and its BFS depth. Returning `ControlFlow::Break(())` ends
/// the search.
pub trait BfsVisitor<G: TopologyBase> {
    /// Handles one reported `element` found at `depth`. Return `Break` to stop the search.
    fn visit(&mut self, element: ElementId<G>, depth: u32) -> ControlFlow<()>;
}

/// Lets any `FnMut(element, depth) -> ControlFlow<()>` closure act as a visitor.
impl<G: TopologyBase, F: FnMut(ElementId<G>, u32) -> ControlFlow<()>> BfsVisitor<G> for F {
    fn visit(&mut self, element: ElementId<G>, depth: u32) -> ControlFlow<()> {
        (*self)(element, depth)
    }
}
/// Mutable state of one bounded BFS, layered over borrowed epoch-stamped scratch buffers.
struct BoundedRun<'run, G, V>
where
    G: ContainsElement + ElementIndex,
    V: BfsVisitor<G>,
{
    // --- configuration fixed at start ---
    /// Caller-supplied traversal limits.
    bounds: BfsBounds,
    /// Exclusive upper bound on valid element indices.
    element_bound: usize,
    /// Visitor that receives every reported element and its depth.
    visitor: &'run mut V,

    // --- visited set ---
    /// Per-index stamps. An entry equal to `epoch` means the index was already discovered.
    marks: &'run mut [u32],
    /// Stamp value written into `marks` during this run.
    epoch: u32,

    // --- FIFO frontier ---
    /// Queue storage. `queue[head..tail]` holds elements not yet dequeued.
    queue: &'run mut [ElementId<G>],
    /// Position of the next element to dequeue.
    head: usize,
    /// One past the most recently enqueued element.
    tail: usize,
    /// Exclusive end of the depth level currently being dequeued.
    wave_end: usize,
    /// Depth of the level currently being dequeued.
    depth: u32,

    // --- accounting ---
    /// Number of elements successfully reported to the visitor so far.
    emitted: usize,
}
impl<G, V> BoundedRun<'_, G, V>
where
    G: ContainsElement + ElementIndex,
    V: BfsVisitor<G>,
{
    /// Dequeues the next element together with its BFS depth, or `None` once the queue is empty.
    ///
    /// The queue is consumed level by level. `wave_end` marks the end of the current level,
    /// so reaching it means every element at the current depth has been dequeued. Everything
    /// queued up to `tail` then belongs to the next depth.
    fn pop(&mut self) -> Option<(ElementId<G>, u32)> {
        if self.head == self.tail {
            return None;
        }
        if self.head == self.wave_end {
            // Crossing a level boundary: the elements queued so far form the next level.
            self.depth = self.depth.saturating_add(1);
            self.wave_end = self.tail;
        }
        let element = self.queue[self.head];
        self.head += 1;
        Some((element, self.depth))
    }

    /// Marks, enqueues and reports `target` at `depth`.
    ///
    /// If `index` is already stamped with this run's epoch, nothing is marked or reported
    /// and `Continue` is returned. Otherwise the result of `emit` is returned.
    ///
    /// # Errors
    ///
    /// Returns `BfsError::NeighborIndexOutOfBounds` when `index` is not below `element_bound`.
    fn discover(
        &mut self,
        target: ElementId<G>,
        index: usize,
        depth: u32,
    ) -> Result<ControlFlow<()>, BfsError> {
        if index >= self.element_bound {
            return Err(BfsError::NeighborIndexOutOfBounds {
                index,
                bound: self.element_bound,
            });
        }
        if self.marks[index] == self.epoch {
            return Ok(ControlFlow::Continue(()));
        }
        self.marks[index] = self.epoch;
        self.queue[self.tail] = target;
        self.tail += 1;
        Ok(self.emit(target, depth))
    }

    /// Reports `element` to the visitor while enforcing `bounds.result_limit`.
    ///
    /// Returns `Break` in three cases:
    /// - the budget is already exhausted, in which case the visitor is not called;
    /// - the visitor asks to stop, which does not count as an emission;
    /// - this emission uses up the last slot of the budget.
    fn emit(&mut self, element: ElementId<G>, depth: u32) -> ControlFlow<()> {
        // Without this guard a `result_limit` of 0 would still report one element.
        if self.emitted >= self.bounds.result_limit {
            return ControlFlow::Break(());
        }
        if self.visitor.visit(element, depth).is_break() {
            return ControlFlow::Break(());
        }
        self.emitted += 1;
        if self.emitted >= self.bounds.result_limit {
            return ControlFlow::Break(());
        }
        ControlFlow::Continue(())
    }
}
/// Validates the scratch buffers and seeds, then returns a run with every distinct seed
/// queued at depth 0.
///
/// Seeds are reported to the visitor (at depth 0) only when `bounds.include_seeds` is set.
/// Seeds whose index is already marked are skipped.
///
/// Returns `Ok(None)` when reporting a seed already ended the search, either because the
/// visitor returned `Break` or because the result limit was reached. In that case the
/// remaining seeds are neither queued nor validated.
///
/// # Errors
///
/// - `BfsError::VisitedTooSmall` if the mark buffer is shorter than `graph.element_bound()`.
/// - `BfsError::QueueTooSmall` if the queue buffer is shorter than `graph.element_bound()`.
/// - `BfsError::StartElementNotContained` for a seed the graph does not contain.
/// - `BfsError::StartIndexOutOfBounds` for a seed whose index is not below the element bound.
fn start_run<'run, G, V>(
    graph: &G,
    scratch: &'run mut BfsEpochScratch<'_, G>,
    seeds: &[ElementId<G>],
    bounds: BfsBounds,
    visitor: &'run mut V,
) -> Result<Option<BoundedRun<'run, G, V>>, BfsError>
where
    G: ContainsElement + ElementIndex,
    V: BfsVisitor<G>,
{
    let bound = graph.element_bound();

    let mark_capacity = scratch.mark_capacity();
    if mark_capacity < bound {
        return Err(BfsError::VisitedTooSmall {
            needed: bound,
            actual: mark_capacity,
        });
    }
    let queue_capacity = scratch.queue_capacity();
    if queue_capacity < bound {
        return Err(BfsError::QueueTooSmall {
            needed: bound,
            actual: queue_capacity,
        });
    }

    let (marks, queue, epoch) = scratch.bounded_parts();
    let mut run = BoundedRun {
        bounds,
        element_bound: bound,
        visitor,
        marks,
        epoch,
        queue,
        head: 0,
        tail: 0,
        wave_end: 0,
        depth: 0,
        emitted: 0,
    };

    for &seed in seeds {
        if !graph.contains_element(seed) {
            return Err(BfsError::StartElementNotContained);
        }
        let slot = graph.element_index(seed);
        if slot >= bound {
            return Err(BfsError::StartIndexOutOfBounds { index: slot, bound });
        }
        // Only the first seed landing on a given slot is queued and reported.
        if run.marks[slot] != run.epoch {
            run.marks[slot] = run.epoch;
            run.queue[run.tail] = seed;
            run.tail += 1;
            if bounds.include_seeds && run.emit(seed, 0).is_break() {
                return Ok(None);
            }
        }
    }

    // Everything queued so far forms depth level 0.
    run.wave_end = run.tail;
    Ok(Some(run))
}
pub fn breadth_first_search_bounded<G, V>(
graph: &G,
seeds: &[ElementId<G>],
bounds: BfsBounds,
scratch: &mut BfsEpochScratch<'_, G>,
visitor: &mut V,
) -> Result<(), BfsError>
where
G: ContainsElement + ElementSuccessors + ElementIndex,
V: BfsVisitor<G>,
{
let Some(mut run) = start_run(graph, scratch, seeds, bounds, visitor)? else {
return Ok(());
};
while let Some((element, depth)) = run.pop() {
if !bounds.expands(depth) {
continue;
}
for target in graph.element_successors(element) {
let index = graph.element_index(target);
if run.discover(target, index, depth + 1)?.is_break() {
return Ok(());
}
}
}
Ok(())
}
pub fn reverse_breadth_first_search_bounded<G, V>(
graph: &G,
seeds: &[ElementId<G>],
bounds: BfsBounds,
scratch: &mut BfsEpochScratch<'_, G>,
visitor: &mut V,
) -> Result<(), BfsError>
where
G: ContainsElement + ElementPredecessors + ElementIndex,
V: BfsVisitor<G>,
{
let Some(mut run) = start_run(graph, scratch, seeds, bounds, visitor)? else {
return Ok(());
};
while let Some((element, depth)) = run.pop() {
if !bounds.expands(depth) {
continue;
}
for target in graph.element_predecessors(element) {
let index = graph.element_index(target);
if run.discover(target, index, depth + 1)?.is_break() {
return Ok(());
}
}
}
Ok(())
}
pub fn breadth_first_search_bounded_both<G, V>(
graph: &G,
seeds: &[ElementId<G>],
bounds: BfsBounds,
scratch: &mut BfsEpochScratch<'_, G>,
visitor: &mut V,
) -> Result<(), BfsError>
where
G: ContainsElement + ElementSuccessors + ElementPredecessors + ElementIndex,
V: BfsVisitor<G>,
{
let Some(mut run) = start_run(graph, scratch, seeds, bounds, visitor)? else {
return Ok(());
};
while let Some((element, depth)) = run.pop() {
if !bounds.expands(depth) {
continue;
}
for target in graph.element_successors(element) {
let index = graph.element_index(target);
if run.discover(target, index, depth + 1)?.is_break() {
return Ok(());
}
}
for target in graph.element_predecessors(element) {
let index = graph.element_index(target);
if run.discover(target, index, depth + 1)?.is_break() {
return Ok(());
}
}
}
Ok(())
}