use std::collections::{HashMap, HashSet, VecDeque};
use uuid::Uuid;
use khive_storage::types::{Direction, Edge, LinkId, NeighborQuery};
use khive_storage::EdgeRelation;
use crate::error::{RuntimeError, RuntimeResult};
use crate::runtime::KhiveRuntime;
/// A single entity produced by a traversal, plus how the traversal got there.
#[derive(Clone, Debug)]
pub struct PathNode {
    /// Id of the entity that was visited.
    pub entity_id: Uuid,
    /// Hop count from the traversal's starting entity; the start itself is at depth 0.
    pub depth: usize,
    /// The edge this entity was reached through; `None` for the starting entity.
    pub via_edge: Option<Edge>,
}
/// Settings for [`KhiveRuntime::bfs_traverse`].
///
/// By default the traversal follows outgoing edges of any relation up to three hops away
/// and does not cap the number of returned nodes.
#[derive(Clone, Debug)]
pub struct TraversalOptions {
    /// Deepest hop count that is reported; nodes at this depth are not expanded further.
    pub max_depth: usize,
    /// Which edges of the current node to follow.
    pub direction: Direction,
    /// When `Some`, only edges whose relation is listed are followed.
    pub relations: Option<Vec<EdgeRelation>>,
    /// Cap on the number of returned nodes, counting the start node; `None` means unbounded.
    pub max_results: Option<usize>,
}

impl Default for TraversalOptions {
    fn default() -> Self {
        TraversalOptions {
            direction: Direction::Out,
            max_depth: 3,
            max_results: None,
            relations: None,
        }
    }
}
impl KhiveRuntime {
pub async fn bfs_traverse(
&self,
namespace: Option<&str>,
start: Uuid,
options: TraversalOptions,
) -> RuntimeResult<Vec<PathNode>> {
let graph = self.graph(namespace)?;
let limit = options.max_results.unwrap_or(usize::MAX);
let mut visited: HashSet<Uuid> = HashSet::new();
let mut results: Vec<PathNode> = Vec::new();
let mut queue: VecDeque<(Uuid, usize)> = VecDeque::new();
visited.insert(start);
results.push(PathNode {
entity_id: start,
depth: 0,
via_edge: None,
});
queue.push_back((start, 0));
'bfs: while let Some((current, depth)) = queue.pop_front() {
if depth >= options.max_depth {
continue;
}
let query = NeighborQuery {
direction: options.direction.clone(),
relations: options.relations.clone(),
limit: None,
min_weight: None,
};
let hits = graph.neighbors(current, query).await?;
for hit in hits {
if visited.contains(&hit.node_id) {
continue;
}
let edge = graph
.get_edge(LinkId::from(hit.edge_id))
.await?
.ok_or_else(|| {
RuntimeError::NotFound(format!("edge {} missing", hit.edge_id))
})?;
visited.insert(hit.node_id);
results.push(PathNode {
entity_id: hit.node_id,
depth: depth + 1,
via_edge: Some(edge),
});
if results.len() >= limit {
break 'bfs;
}
queue.push_back((hit.node_id, depth + 1));
}
}
Ok(results)
}
pub async fn shortest_path(
&self,
namespace: Option<&str>,
from: Uuid,
to: Uuid,
max_depth: usize,
) -> RuntimeResult<Option<Vec<PathNode>>> {
if from == to {
return Ok(Some(vec![PathNode {
entity_id: from,
depth: 0,
via_edge: None,
}]));
}
let graph = self.graph(namespace)?;
let mut fwd: HashMap<Uuid, (usize, Option<Uuid>, Option<Uuid>)> = HashMap::new();
let mut fwd_q: VecDeque<Uuid> = VecDeque::new();
fwd.insert(from, (0, None, None));
fwd_q.push_back(from);
let mut bwd: HashMap<Uuid, (usize, Option<Uuid>, Option<Uuid>)> = HashMap::new();
let mut bwd_q: VecDeque<Uuid> = VecDeque::new();
bwd.insert(to, (0, None, None));
bwd_q.push_back(to);
let mut meeting: Option<(Uuid, usize)> = None;
let mut current_depth = 0usize;
while (!fwd_q.is_empty() || !bwd_q.is_empty()) && current_depth <= max_depth {
let fwd_level = fwd_q.len();
for _ in 0..fwd_level {
let Some(node) = fwd_q.pop_front() else { break };
let fwd_depth = fwd[&node].0;
let hits = graph
.neighbors(
node,
NeighborQuery {
direction: Direction::Out,
relations: None,
limit: None,
min_weight: None,
},
)
.await?;
for hit in hits {
if fwd.contains_key(&hit.node_id) {
continue;
}
let new_depth = fwd_depth + 1;
fwd.insert(hit.node_id, (new_depth, Some(node), Some(hit.edge_id)));
fwd_q.push_back(hit.node_id);
if let Some(&(bwd_depth, _, _)) = bwd.get(&hit.node_id) {
let total = new_depth + bwd_depth;
if total <= max_depth
&& meeting.as_ref().is_none_or(|&(_, best)| total < best)
{
meeting = Some((hit.node_id, total));
}
}
}
}
if meeting.is_some() {
break;
}
let bwd_level = bwd_q.len();
for _ in 0..bwd_level {
let Some(node) = bwd_q.pop_front() else { break };
let bwd_depth = bwd[&node].0;
let hits = graph
.neighbors(
node,
NeighborQuery {
direction: Direction::In,
relations: None,
limit: None,
min_weight: None,
},
)
.await?;
for hit in hits {
if bwd.contains_key(&hit.node_id) {
continue;
}
let new_depth = bwd_depth + 1;
bwd.insert(hit.node_id, (new_depth, Some(node), Some(hit.edge_id)));
bwd_q.push_back(hit.node_id);
if let Some(&(fwd_depth, _, _)) = fwd.get(&hit.node_id) {
let total = fwd_depth + new_depth;
if total <= max_depth
&& meeting.as_ref().is_none_or(|&(_, best)| total < best)
{
meeting = Some((hit.node_id, total));
}
}
}
}
if meeting.is_some() {
break;
}
current_depth += 1;
}
let (mid, _) = match meeting {
None => return Ok(None),
Some(m) => m,
};
let mut fwd_chain: Vec<(Uuid, Option<Uuid>)> = Vec::new();
{
let mut cur = mid;
loop {
let (_, parent, edge_id) = fwd[&cur];
fwd_chain.push((cur, edge_id));
match parent {
Some(p) => cur = p,
None => break,
}
}
}
fwd_chain.reverse();
let mut bwd_chain: Vec<(Uuid, Option<Uuid>)> = Vec::new();
{
let mut cur = mid;
while let Some(&(_, Some(child), edge_id)) = bwd.get(&cur) {
bwd_chain.push((child, edge_id));
cur = child;
}
}
let mut path: Vec<PathNode> = Vec::new();
for (i, (node_id, edge_id)) in fwd_chain.iter().enumerate() {
let via_edge = if i == 0 {
None } else if let Some(eid) = edge_id {
graph.get_edge(LinkId::from(*eid)).await?.or(None)
} else {
None
};
path.push(PathNode {
entity_id: *node_id,
depth: i,
via_edge,
});
}
let base = path.len();
for (i, (node_id, edge_id)) in bwd_chain.iter().enumerate() {
let via_edge = if let Some(eid) = edge_id {
graph.get_edge(LinkId::from(*eid)).await?.or(None)
} else {
None
};
path.push(PathNode {
entity_id: *node_id,
depth: base + i,
via_edge,
});
}
Ok(Some(path))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runtime::KhiveRuntime;
    use khive_storage::EdgeRelation;

    /// Builds an in-memory runtime for a single test.
    async fn rt() -> KhiveRuntime {
        KhiveRuntime::memory().expect("memory runtime")
    }

    /// Creates a `"concept"` entity named `name` (namespace `None`) and returns its id.
    async fn concept(rt: &KhiveRuntime, name: &'static str) -> Uuid {
        rt.create_entity(None, "concept", name, None, None, vec![])
            .await
            .unwrap()
            .id
    }

    /// Links `from -> to` with `relation` at weight 1.0 (namespace `None`).
    async fn connect(rt: &KhiveRuntime, from: Uuid, to: Uuid, relation: EdgeRelation) {
        rt.link(None, from, to, relation, 1.0).await.unwrap();
    }

    #[tokio::test]
    async fn bfs_max_depth_zero_returns_only_root() {
        let rt = rt().await;
        let a = concept(&rt, "A").await;
        let b = concept(&rt, "B").await;
        connect(&rt, a, b, EdgeRelation::Extends).await;

        let opts = TraversalOptions {
            max_depth: 0,
            ..Default::default()
        };
        let nodes = rt.bfs_traverse(None, a, opts).await.unwrap();

        assert_eq!(nodes.len(), 1);
        let root = &nodes[0];
        assert_eq!(root.entity_id, a);
        assert_eq!(root.depth, 0);
        assert!(root.via_edge.is_none());
    }

    #[tokio::test]
    async fn bfs_depth_one_returns_root_and_neighbors() {
        let rt = rt().await;
        let a = concept(&rt, "A").await;
        let b = concept(&rt, "B").await;
        let c = concept(&rt, "C").await;
        connect(&rt, a, b, EdgeRelation::Extends).await;
        connect(&rt, a, c, EdgeRelation::Extends).await;
        let d = concept(&rt, "D").await;
        connect(&rt, b, d, EdgeRelation::Extends).await;

        let opts = TraversalOptions {
            max_depth: 1,
            ..Default::default()
        };
        let nodes = rt.bfs_traverse(None, a, opts).await.unwrap();
        let ids: HashSet<Uuid> = nodes.iter().map(|n| n.entity_id).collect();

        assert!([a, b, c].iter().all(|id| ids.contains(id)));
        assert!(!ids.contains(&d));
        assert!(nodes
            .iter()
            .filter(|n| n.entity_id != a)
            .all(|n| n.depth == 1));
    }

    #[tokio::test]
    async fn bfs_direction_out_only() {
        let rt = rt().await;
        let a = concept(&rt, "A").await;
        let b = concept(&rt, "B").await;
        connect(&rt, b, a, EdgeRelation::Extends).await;

        let opts = TraversalOptions {
            max_depth: 2,
            direction: Direction::Out,
            ..Default::default()
        };
        let nodes = rt.bfs_traverse(None, a, opts).await.unwrap();

        assert_eq!(
            nodes.len(),
            1,
            "only root should be returned when traversing Out with no outgoing edges"
        );
    }

    #[tokio::test]
    async fn bfs_direction_in_only() {
        let rt = rt().await;
        let a = concept(&rt, "A").await;
        let b = concept(&rt, "B").await;
        connect(&rt, b, a, EdgeRelation::Extends).await;

        let opts = TraversalOptions {
            max_depth: 2,
            direction: Direction::In,
            ..Default::default()
        };
        let nodes = rt.bfs_traverse(None, a, opts).await.unwrap();
        let reached = nodes.iter().any(|n| n.entity_id == b);

        assert!(reached, "B should be reachable via incoming edge");
    }

    #[tokio::test]
    async fn bfs_relation_filter() {
        let rt = rt().await;
        let a = concept(&rt, "A").await;
        let b = concept(&rt, "B").await;
        let c = concept(&rt, "C").await;
        connect(&rt, a, b, EdgeRelation::Extends).await;
        connect(&rt, a, c, EdgeRelation::DependsOn).await;

        let opts = TraversalOptions {
            max_depth: 2,
            relations: Some(vec![EdgeRelation::Extends]),
            ..Default::default()
        };
        let nodes = rt.bfs_traverse(None, a, opts).await.unwrap();
        let ids: HashSet<Uuid> = nodes.iter().map(|n| n.entity_id).collect();

        assert!(ids.contains(&b), "B reachable via 'extends'");
        assert!(
            !ids.contains(&c),
            "C not reachable when filtering to 'extends'"
        );
    }

    #[tokio::test]
    async fn shortest_path_connected_nodes() {
        let rt = rt().await;
        let a = concept(&rt, "A").await;
        let b = concept(&rt, "B").await;
        let c = concept(&rt, "C").await;
        connect(&rt, a, b, EdgeRelation::Extends).await;
        connect(&rt, b, c, EdgeRelation::Extends).await;

        let path = rt
            .shortest_path(None, a, c, 10)
            .await
            .unwrap()
            .expect("path should exist");

        assert_eq!(path.len(), 3, "A -> B -> C = 3 nodes");
        assert_eq!(path[0].entity_id, a);
        assert_eq!(path[2].entity_id, c);
    }

    #[tokio::test]
    async fn shortest_path_unreachable_returns_none() {
        let rt = rt().await;
        let a = concept(&rt, "A").await;
        let b = concept(&rt, "B").await;

        let path = rt.shortest_path(None, a, b, 5).await.unwrap();

        assert!(path.is_none());
    }

    #[tokio::test]
    async fn shortest_path_same_node() {
        let rt = rt().await;
        let a = concept(&rt, "A").await;

        let path = rt
            .shortest_path(None, a, a, 5)
            .await
            .unwrap()
            .expect("trivial path should always exist");

        assert_eq!(path.len(), 1);
        let only = &path[0];
        assert_eq!(only.entity_id, a);
        assert!(only.via_edge.is_none());
    }

    #[tokio::test]
    async fn shortest_path_max_depth_zero_adjacent() {
        let rt = rt().await;
        let a = concept(&rt, "A").await;
        let b = concept(&rt, "B").await;
        connect(&rt, a, b, EdgeRelation::Extends).await;

        let path = rt.shortest_path(None, a, b, 0).await.unwrap();

        assert!(
            path.is_none(),
            "1-hop path should not be returned at max_depth=0"
        );
    }

    #[tokio::test]
    async fn shortest_path_max_depth_one_two_hop_chain() {
        let rt = rt().await;
        let a = concept(&rt, "A").await;
        let b = concept(&rt, "B").await;
        let c = concept(&rt, "C").await;
        connect(&rt, a, b, EdgeRelation::Extends).await;
        connect(&rt, b, c, EdgeRelation::Extends).await;

        let one_hop = rt.shortest_path(None, a, b, 1).await.unwrap();
        assert!(
            one_hop.is_some(),
            "1-hop path should be found at max_depth=1"
        );

        let two_hop = rt.shortest_path(None, a, c, 1).await.unwrap();
        assert!(
            two_hop.is_none(),
            "2-hop path should not be returned at max_depth=1"
        );
    }
}