use core::ops::ControlFlow;
use std::collections::{BTreeMap, BTreeSet};
use oxgraph_algo::{
BfsBounds, BfsEpochScratch, breadth_first_search_bounded, breadth_first_search_bounded_both,
reverse_breadth_first_search_bounded,
};
use oxgraph_graph::{
CanonicalElementIdentity, CanonicalRelationIdentity, EdgeSourceGraph, EdgeTargetGraph,
ElementIndex, LocalElementIdentity, OutgoingGraph,
};
use crate::{
DbError, ElementId, RelationId,
projection::{GraphProjection, ProjectionElementId},
};
/// Which way a [`Walk`] follows edges.
///
/// Variant order is significant: the derived `PartialOrd`/`Ord` compare
/// variants in declaration order.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[non_exhaustive]
pub enum Direction {
    /// Follow outgoing edges. This is the default direction.
    #[default]
    Outgoing,
    /// Follow incoming edges, walking the graph in reverse.
    Incoming,
    /// Follow edges in both directions.
    Both,
}
/// Parameters for a bounded breadth-first walk over a graph projection.
///
/// The [`Default`] walk goes one hop along outgoing edges, has no practical
/// result limit (`usize::MAX`), and leaves the start elements out of the nodes.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct Walk {
    /// Maximum traversal depth; values larger than `u32::MAX` are saturated to
    /// `u32::MAX` before being handed to the traversal.
    pub max_depth: usize,
    /// Which way edges are followed.
    pub direction: Direction,
    /// Forwarded as the traversal's `result_limit`; `0` short-circuits to an
    /// empty result.
    pub limit: usize,
    /// Forwarded as the traversal's `include_seeds` flag, i.e. whether the start
    /// elements themselves may be reported as nodes.
    pub include_start: bool,
}
impl Default for Walk {
    /// A single outgoing hop with no practical limit that excludes the start
    /// elements from the reported nodes.
    fn default() -> Self {
        Self {
            direction: Direction::Outgoing,
            max_depth: 1,
            include_start: false,
            limit: usize::MAX,
        }
    }
}
/// An element reached by a walk, together with the depth the traversal
/// reported for it (converted from `u32`, saturating at `usize::MAX`).
///
/// The derived ordering compares `element` first, then `depth`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct TraversedNode {
    /// Canonical identifier of the reached element.
    pub element: ElementId,
    /// Depth at which the traversal reported the element.
    pub depth: usize,
}
/// A relation between two elements of a walk's result, expressed with
/// canonical identifiers.
///
/// The derived ordering compares `relation`, then `source`, then `target`;
/// walk results sort their edges with it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct TraversedEdge {
    /// Canonical identifier of the relation.
    pub relation: RelationId,
    /// Canonical identifier of the relation's source element.
    pub source: ElementId,
    /// Canonical identifier of the relation's target element.
    pub target: ElementId,
}
/// The outcome of a walk: the elements it reached and the relations among the
/// elements it discovered.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct Subgraph {
    /// Reached elements; `walk_graph_projection` fills this in traversal order.
    pub nodes: Vec<TraversedNode>,
    /// Relations between discovered elements; `walk_graph_projection` fills
    /// this in sorted order.
    pub edges: Vec<TraversedEdge>,
}
impl Subgraph {
    /// Borrows the reached elements.
    #[must_use]
    pub fn nodes(&self) -> &[TraversedNode] {
        self.nodes.as_slice()
    }

    /// Borrows the relations between discovered elements.
    #[must_use]
    pub fn edges(&self) -> &[TraversedEdge] {
        self.edges.as_slice()
    }
}
/// Runs a bounded breadth-first walk over `graph` from `seeds`. Returns the
/// reached elements plus every relation connecting two discovered elements.
///
/// `walk.direction` selects `breadth_first_search_bounded`,
/// `reverse_breadth_first_search_bounded` or `breadth_first_search_bounded_both`.
/// `walk.max_depth` (saturated to `u32::MAX`), `walk.limit` and
/// `walk.include_start` are forwarded as the [`BfsBounds`].
///
/// Nodes are returned in the order the traversal reported them. Edges come from
/// `collect_internal_edges` and are sorted. Seeds are always counted as
/// discovered for edge collection. As a result, edges touching a seed are
/// reported even when `walk.include_start` is `false` and the seed is absent
/// from the nodes.
///
/// An empty `seeds` slice or a zero `walk.limit` returns an empty [`Subgraph`]
/// immediately, without validating any seed.
///
/// # Errors
///
/// * `DbError::UnknownElement` for the first seed the projection does not contain.
/// * The error built by `DbError::traversal` if the bounded traversal fails; the
///   underlying traversal error is discarded.
pub(crate) fn walk_graph_projection(
    graph: &GraphProjection,
    seeds: &[ElementId],
    walk: Walk,
) -> Result<Subgraph, DbError> {
    // Checked before seed resolution, so an empty request never reports unknown seeds.
    if walk.limit == 0 || seeds.is_empty() {
        return Ok(Subgraph::default());
    }

    // Translate canonical seeds into projection-local ids, failing on the first unknown one.
    let mut starts: Vec<ProjectionElementId> = Vec::with_capacity(seeds.len());
    for &seed in seeds {
        match graph.local_element_id(seed) {
            Some(local) => starts.push(local),
            None => return Err(DbError::UnknownElement { id: seed }),
        }
    }

    // Depths that do not fit in `u32` saturate instead of failing.
    let depth_cap = u32::try_from(walk.max_depth).unwrap_or(u32::MAX);
    let bounds = BfsBounds {
        max_depth: Some(depth_cap),
        result_limit: walk.limit,
        include_seeds: walk.include_start,
    };

    // Traversal scratch buffers, sized to the projection's element bound.
    let element_count = graph.element_bound();
    let mut epoch_marks = vec![0_u32; element_count];
    let mut frontier = vec![ProjectionElementId::new(0); element_count];
    let mut scratch = BfsEpochScratch::for_graph(graph, &mut epoch_marks, &mut frontier);

    // Seeds are registered up front so edge collection treats them as discovered
    // even when they are not reported as nodes.
    let mut discovered: BTreeMap<ElementId, ProjectionElementId> = BTreeMap::new();
    discovered.extend(
        starts
            .iter()
            .map(|&local| (graph.canonical_element_id(local), local)),
    );

    let mut visited = Vec::new();
    let mut on_visit = |element: ProjectionElementId, depth: u32| {
        let canonical = graph.canonical_element_id(element);
        visited.push(TraversedNode {
            element: canonical,
            depth: usize::try_from(depth).unwrap_or(usize::MAX),
        });
        discovered.insert(canonical, element);
        ControlFlow::Continue(())
    };
    let outcome = match walk.direction {
        Direction::Outgoing => breadth_first_search_bounded(
            graph,
            &starts,
            bounds,
            &mut scratch,
            &mut on_visit,
        ),
        Direction::Incoming => reverse_breadth_first_search_bounded(
            graph,
            &starts,
            bounds,
            &mut scratch,
            &mut on_visit,
        ),
        Direction::Both => breadth_first_search_bounded_both(
            graph,
            &starts,
            bounds,
            &mut scratch,
            &mut on_visit,
        ),
    };
    outcome.map_err(|_error| DbError::traversal("bounded traversal failed"))?;

    // `on_visit` is no longer used, so its borrows of `visited`/`discovered` have ended.
    let edges = collect_internal_edges(graph, &discovered);
    Ok(Subgraph {
        nodes: visited,
        edges,
    })
}
/// Collects every relation whose source and target are both keys of `discovered`.
///
/// The outgoing edges of each discovered element are scanned, and an edge is
/// kept only when its target is also discovered. Each relation id is emitted at
/// most once. The result is sorted by the derived `TraversedEdge` ordering
/// (relation, then source, then target).
fn collect_internal_edges(
    graph: &GraphProjection,
    discovered: &BTreeMap<ElementId, ProjectionElementId>,
) -> Vec<TraversedEdge> {
    let mut emitted_relations = BTreeSet::new();
    let mut internal = Vec::new();
    for &origin in discovered.values() {
        for edge in graph.outgoing_edges(origin) {
            let target = graph.canonical_element_id(graph.target(edge));
            // Skip edges leaving the discovered set; only then record the relation as seen.
            if !discovered.contains_key(&target) {
                continue;
            }
            let relation = graph.canonical_relation_id(edge);
            if emitted_relations.insert(relation) {
                internal.push(TraversedEdge {
                    relation,
                    source: graph.canonical_element_id(graph.source(edge)),
                    target,
                });
            }
        }
    }
    internal.sort_unstable();
    internal
}