1use std::collections::{HashMap, HashSet, VecDeque};
16
17use uuid::Uuid;
18
19use khive_storage::types::{Direction, Edge, LinkId, NeighborQuery};
20use khive_storage::EdgeRelation;
21
22use crate::error::{RuntimeError, RuntimeResult};
23use crate::runtime::{KhiveRuntime, NamespaceToken};
24
/// One entity in a traversal result, together with how it was reached.
#[derive(Debug, Clone)]
pub struct PathNode {
    /// Id of the reached entity.
    pub entity_id: Uuid,
    /// Number of hops from the traversal's starting entity (0 for the start itself).
    pub depth: usize,
    /// Edge followed to reach this entity, if known; the starting entity always has `None`.
    pub via_edge: Option<Edge>,
}
35
/// Knobs for `KhiveRuntime::bfs_traverse`.
#[derive(Debug, Clone)]
pub struct TraversalOptions {
    /// Maximum hop distance from the start entity. Entities at exactly this
    /// depth are included in the result but their neighbours are not queried.
    pub max_depth: usize,
    /// Edge direction passed to every neighbour query.
    pub direction: Direction,
    /// Relation filter passed unchanged to every neighbour query
    /// (`None` means no explicit relation filter is requested).
    pub relations: Option<Vec<EdgeRelation>>,
    /// Result-count limit applied by `bfs_traverse` (the start entity is part
    /// of the result list); `None` disables the limit.
    pub max_results: Option<usize>,
}
48
49impl Default for TraversalOptions {
50 fn default() -> Self {
51 Self {
52 max_depth: 3,
53 direction: Direction::Out,
54 relations: None,
55 max_results: None,
56 }
57 }
58}
59
60impl KhiveRuntime {
61 pub async fn bfs_traverse(
66 &self,
67 token: &NamespaceToken,
68 start: Uuid,
69 options: TraversalOptions,
70 ) -> RuntimeResult<Vec<PathNode>> {
71 let graph = self.graph(token)?;
72 let limit = options.max_results.unwrap_or(usize::MAX);
73
74 let mut visited: HashSet<Uuid> = HashSet::new();
75 let mut results: Vec<PathNode> = Vec::new();
76 let mut queue: VecDeque<(Uuid, usize)> = VecDeque::new();
78
79 visited.insert(start);
80 results.push(PathNode {
81 entity_id: start,
82 depth: 0,
83 via_edge: None,
84 });
85 queue.push_back((start, 0));
86
87 'bfs: while let Some((current, depth)) = queue.pop_front() {
88 if depth >= options.max_depth {
89 continue;
90 }
91
92 let query = NeighborQuery {
93 direction: options.direction.clone(),
94 relations: options.relations.clone(),
95 limit: None,
96 min_weight: None,
97 };
98 let hits = graph.neighbors(current, query).await?;
99
100 for hit in hits {
101 if visited.contains(&hit.node_id) {
102 continue;
103 }
104
105 let edge = graph
106 .get_edge(LinkId::from(hit.edge_id))
107 .await?
108 .ok_or_else(|| {
109 RuntimeError::NotFound(format!("edge {} missing", hit.edge_id))
110 })?;
111
112 visited.insert(hit.node_id);
113 results.push(PathNode {
114 entity_id: hit.node_id,
115 depth: depth + 1,
116 via_edge: Some(edge),
117 });
118
119 if results.len() >= limit {
120 break 'bfs;
121 }
122
123 queue.push_back((hit.node_id, depth + 1));
124 }
125 }
126
127 Ok(results)
128 }
129
130 pub async fn shortest_path(
136 &self,
137 token: &NamespaceToken,
138 from: Uuid,
139 to: Uuid,
140 max_depth: usize,
141 ) -> RuntimeResult<Option<Vec<PathNode>>> {
142 if from == to {
143 return Ok(Some(vec![PathNode {
144 entity_id: from,
145 depth: 0,
146 via_edge: None,
147 }]));
148 }
149
150 let graph = self.graph(token)?;
151
152 let mut fwd: HashMap<Uuid, (usize, Option<Uuid>, Option<Uuid>)> = HashMap::new();
154 let mut fwd_q: VecDeque<Uuid> = VecDeque::new();
155 fwd.insert(from, (0, None, None));
156 fwd_q.push_back(from);
157
158 let mut bwd: HashMap<Uuid, (usize, Option<Uuid>, Option<Uuid>)> = HashMap::new();
160 let mut bwd_q: VecDeque<Uuid> = VecDeque::new();
161 bwd.insert(to, (0, None, None));
162 bwd_q.push_back(to);
163
164 let mut meeting: Option<(Uuid, usize)> = None;
165 let mut current_depth = 0usize;
166
167 while (!fwd_q.is_empty() || !bwd_q.is_empty()) && current_depth <= max_depth {
168 let fwd_level = fwd_q.len();
170 for _ in 0..fwd_level {
171 let Some(node) = fwd_q.pop_front() else { break };
172 let fwd_depth = fwd[&node].0;
173
174 let hits = graph
175 .neighbors(
176 node,
177 NeighborQuery {
178 direction: Direction::Out,
179 relations: None,
180 limit: None,
181 min_weight: None,
182 },
183 )
184 .await?;
185
186 for hit in hits {
187 if fwd.contains_key(&hit.node_id) {
188 continue;
189 }
190 let new_depth = fwd_depth + 1;
191 fwd.insert(hit.node_id, (new_depth, Some(node), Some(hit.edge_id)));
192 fwd_q.push_back(hit.node_id);
193
194 if let Some(&(bwd_depth, _, _)) = bwd.get(&hit.node_id) {
195 let total = new_depth + bwd_depth;
196 if total <= max_depth
197 && meeting.as_ref().is_none_or(|&(_, best)| total < best)
198 {
199 meeting = Some((hit.node_id, total));
200 }
201 }
202 }
203 }
204
205 if meeting.is_some() {
206 break;
207 }
208
209 let bwd_level = bwd_q.len();
211 for _ in 0..bwd_level {
212 let Some(node) = bwd_q.pop_front() else { break };
213 let bwd_depth = bwd[&node].0;
214
215 let hits = graph
216 .neighbors(
217 node,
218 NeighborQuery {
219 direction: Direction::In,
220 relations: None,
221 limit: None,
222 min_weight: None,
223 },
224 )
225 .await?;
226
227 for hit in hits {
228 if bwd.contains_key(&hit.node_id) {
229 continue;
230 }
231 let new_depth = bwd_depth + 1;
232 bwd.insert(hit.node_id, (new_depth, Some(node), Some(hit.edge_id)));
233 bwd_q.push_back(hit.node_id);
234
235 if let Some(&(fwd_depth, _, _)) = fwd.get(&hit.node_id) {
236 let total = fwd_depth + new_depth;
237 if total <= max_depth
238 && meeting.as_ref().is_none_or(|&(_, best)| total < best)
239 {
240 meeting = Some((hit.node_id, total));
241 }
242 }
243 }
244 }
245
246 if meeting.is_some() {
247 break;
248 }
249
250 current_depth += 1;
251 }
252
253 let (mid, _) = match meeting {
254 None => return Ok(None),
255 Some(m) => m,
256 };
257
258 let mut fwd_chain: Vec<(Uuid, Option<Uuid>)> = Vec::new();
260 {
261 let mut cur = mid;
262 loop {
263 let (_, parent, edge_id) = fwd[&cur];
264 fwd_chain.push((cur, edge_id));
265 match parent {
266 Some(p) => cur = p,
267 None => break,
268 }
269 }
270 }
271 fwd_chain.reverse();
272
273 let mut bwd_chain: Vec<(Uuid, Option<Uuid>)> = Vec::new();
274 {
275 let mut cur = mid;
276 while let Some(&(_, Some(child), edge_id)) = bwd.get(&cur) {
278 bwd_chain.push((child, edge_id));
279 cur = child;
280 }
281 }
282
283 let mut path: Vec<PathNode> = Vec::new();
285 for (i, (node_id, edge_id)) in fwd_chain.iter().enumerate() {
286 let via_edge = if i == 0 {
287 None } else if let Some(eid) = edge_id {
289 graph.get_edge(LinkId::from(*eid)).await?.or(None)
290 } else {
291 None
292 };
293 path.push(PathNode {
294 entity_id: *node_id,
295 depth: i,
296 via_edge,
297 });
298 }
299
300 let base = path.len();
301 for (i, (node_id, edge_id)) in bwd_chain.iter().enumerate() {
302 let via_edge = if let Some(eid) = edge_id {
303 graph.get_edge(LinkId::from(*eid)).await?.or(None)
304 } else {
305 None
306 };
307 path.push(PathNode {
308 entity_id: *node_id,
309 depth: base + i,
310 via_edge,
311 });
312 }
313
314 Ok(Some(path))
315 }
316}
317
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runtime::{KhiveRuntime, NamespaceToken};
    use khive_storage::EdgeRelation;

    /// Builds a runtime via `KhiveRuntime::memory()` for a single test.
    async fn rt() -> KhiveRuntime {
        KhiveRuntime::memory().expect("memory runtime")
    }

    /// Creates a `concept` entity named `name` and returns its id.
    async fn concept(rt: &KhiveRuntime, tok: &NamespaceToken, name: &str) -> Uuid {
        let entity = rt
            .create_entity(tok, "concept", None, name, None, None, vec![])
            .await
            .unwrap();
        entity.id
    }

    /// Links `from -[rel]-> to` with weight 1.0 and no extra data.
    async fn connect(
        rt: &KhiveRuntime,
        tok: &NamespaceToken,
        from: Uuid,
        to: Uuid,
        rel: EdgeRelation,
    ) {
        rt.link(tok, from, to, rel, 1.0, None).await.unwrap();
    }

    #[tokio::test]
    async fn bfs_max_depth_zero_returns_only_root() {
        let rt = rt().await;
        let tok = NamespaceToken::local();
        let a = concept(&rt, &tok, "A").await;
        let b = concept(&rt, &tok, "B").await;
        connect(&rt, &tok, a, b, EdgeRelation::Extends).await;

        let opts = TraversalOptions {
            max_depth: 0,
            ..Default::default()
        };
        let nodes = rt.bfs_traverse(&tok, a, opts).await.unwrap();

        assert_eq!(nodes.len(), 1);
        let root = &nodes[0];
        assert_eq!(root.entity_id, a);
        assert_eq!(root.depth, 0);
        assert!(root.via_edge.is_none());
    }

    #[tokio::test]
    async fn bfs_depth_one_returns_root_and_neighbors() {
        let rt = rt().await;
        let tok = NamespaceToken::local();
        let a = concept(&rt, &tok, "A").await;
        let b = concept(&rt, &tok, "B").await;
        let c = concept(&rt, &tok, "C").await;
        connect(&rt, &tok, a, b, EdgeRelation::Extends).await;
        connect(&rt, &tok, a, c, EdgeRelation::Extends).await;
        // D hangs off B, one hop beyond the depth limit.
        let d = concept(&rt, &tok, "D").await;
        connect(&rt, &tok, b, d, EdgeRelation::Extends).await;

        let opts = TraversalOptions {
            max_depth: 1,
            ..Default::default()
        };
        let nodes = rt.bfs_traverse(&tok, a, opts).await.unwrap();

        let reached: HashSet<Uuid> = nodes.iter().map(|n| n.entity_id).collect();
        assert!(reached.contains(&a));
        assert!(reached.contains(&b));
        assert!(reached.contains(&c));
        assert!(!reached.contains(&d));
        for neighbour in nodes.iter().filter(|n| n.entity_id != a) {
            assert_eq!(neighbour.depth, 1);
        }
    }

    #[tokio::test]
    async fn bfs_direction_out_only() {
        let rt = rt().await;
        let tok = NamespaceToken::local();
        let a = concept(&rt, &tok, "A").await;
        let b = concept(&rt, &tok, "B").await;
        // The only edge points into A, so an outbound walk from A goes nowhere.
        connect(&rt, &tok, b, a, EdgeRelation::Extends).await;

        let opts = TraversalOptions {
            max_depth: 2,
            direction: Direction::Out,
            ..Default::default()
        };
        let nodes = rt.bfs_traverse(&tok, a, opts).await.unwrap();
        assert_eq!(
            nodes.len(),
            1,
            "only root should be returned when traversing Out with no outgoing edges"
        );
    }

    #[tokio::test]
    async fn bfs_direction_in_only() {
        let rt = rt().await;
        let tok = NamespaceToken::local();
        let a = concept(&rt, &tok, "A").await;
        let b = concept(&rt, &tok, "B").await;
        // B -> A, so B is reachable from A only by following incoming edges.
        connect(&rt, &tok, b, a, EdgeRelation::Extends).await;

        let opts = TraversalOptions {
            max_depth: 2,
            direction: Direction::In,
            ..Default::default()
        };
        let reached: HashSet<Uuid> = rt
            .bfs_traverse(&tok, a, opts)
            .await
            .unwrap()
            .into_iter()
            .map(|n| n.entity_id)
            .collect();
        assert!(
            reached.contains(&b),
            "B should be reachable via incoming edge"
        );
    }

    #[tokio::test]
    async fn bfs_relation_filter() {
        let rt = rt().await;
        let tok = NamespaceToken::local();
        let a = concept(&rt, &tok, "A").await;
        let b = concept(&rt, &tok, "B").await;
        let c = concept(&rt, &tok, "C").await;
        connect(&rt, &tok, a, b, EdgeRelation::Extends).await;
        connect(&rt, &tok, a, c, EdgeRelation::Enables).await;

        let opts = TraversalOptions {
            max_depth: 2,
            relations: Some(vec![EdgeRelation::Extends]),
            ..Default::default()
        };
        let reached: HashSet<Uuid> = rt
            .bfs_traverse(&tok, a, opts)
            .await
            .unwrap()
            .into_iter()
            .map(|n| n.entity_id)
            .collect();
        assert!(reached.contains(&b), "B reachable via 'extends'");
        assert!(
            !reached.contains(&c),
            "C not reachable when filtering to 'extends'"
        );
    }

    #[tokio::test]
    async fn shortest_path_connected_nodes() {
        let rt = rt().await;
        let tok = NamespaceToken::local();
        let a = concept(&rt, &tok, "A").await;
        let b = concept(&rt, &tok, "B").await;
        let c = concept(&rt, &tok, "C").await;
        connect(&rt, &tok, a, b, EdgeRelation::Extends).await;
        connect(&rt, &tok, b, c, EdgeRelation::Extends).await;

        let path = rt
            .shortest_path(&tok, a, c, 10)
            .await
            .unwrap()
            .expect("path should exist");
        assert_eq!(path.len(), 3, "A -> B -> C = 3 nodes");
        assert_eq!(path[0].entity_id, a);
        assert_eq!(path[2].entity_id, c);
    }

    #[tokio::test]
    async fn shortest_path_unreachable_returns_none() {
        let rt = rt().await;
        let tok = NamespaceToken::local();
        let a = concept(&rt, &tok, "A").await;
        let b = concept(&rt, &tok, "B").await;

        // No edges at all between A and B.
        let path = rt.shortest_path(&tok, a, b, 5).await.unwrap();
        assert!(path.is_none());
    }

    #[tokio::test]
    async fn shortest_path_same_node() {
        let rt = rt().await;
        let tok = NamespaceToken::local();
        let a = concept(&rt, &tok, "A").await;

        let path = rt
            .shortest_path(&tok, a, a, 5)
            .await
            .unwrap()
            .expect("trivial path should always exist");
        assert_eq!(path.len(), 1);
        let only = &path[0];
        assert_eq!(only.entity_id, a);
        assert!(only.via_edge.is_none());
    }

    #[tokio::test]
    async fn shortest_path_max_depth_zero_adjacent() {
        let rt = rt().await;
        let tok = NamespaceToken::local();
        let a = concept(&rt, &tok, "A").await;
        let b = concept(&rt, &tok, "B").await;
        connect(&rt, &tok, a, b, EdgeRelation::Extends).await;

        let path = rt.shortest_path(&tok, a, b, 0).await.unwrap();
        assert!(
            path.is_none(),
            "1-hop path should not be returned at max_depth=0"
        );
    }

    #[tokio::test]
    async fn shortest_path_max_depth_one_two_hop_chain() {
        let rt = rt().await;
        let tok = NamespaceToken::local();
        let a = concept(&rt, &tok, "A").await;
        let b = concept(&rt, &tok, "B").await;
        let c = concept(&rt, &tok, "C").await;
        connect(&rt, &tok, a, b, EdgeRelation::Extends).await;
        connect(&rt, &tok, b, c, EdgeRelation::Extends).await;

        let one_hop = rt.shortest_path(&tok, a, b, 1).await.unwrap();
        assert!(
            one_hop.is_some(),
            "1-hop path should be found at max_depth=1"
        );

        let two_hop = rt.shortest_path(&tok, a, c, 1).await.unwrap();
        assert!(
            two_hop.is_none(),
            "2-hop path should not be returned at max_depth=1"
        );
    }
}