Skip to main content

khive_runtime/
graph_traversal.rs

1// Copyright 2024-2025 khive contributors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{HashMap, HashSet, VecDeque};
16
17use uuid::Uuid;
18
19use khive_storage::types::{Direction, Edge, LinkId, NeighborQuery};
20use khive_storage::EdgeRelation;
21
22use crate::error::{RuntimeError, RuntimeResult};
23use crate::runtime::{KhiveRuntime, NamespaceToken};
24
25/// A node in a traversal path.
26#[derive(Debug, Clone)]
27pub struct PathNode {
28    /// Entity at this position.
29    pub entity_id: Uuid,
30    /// Distance from the start node (0 = start node).
31    pub depth: usize,
32    /// Edge that led to this node (`None` for the start node).
33    pub via_edge: Option<Edge>,
34}
35
36/// Options for BFS traversal and shortest-path search.
37#[derive(Debug, Clone)]
38pub struct TraversalOptions {
39    /// Maximum hops to follow.
40    pub max_depth: usize,
41    /// Which edge directions to follow.
42    pub direction: Direction,
43    /// Restrict traversal to these relation types (`None` = all).
44    pub relations: Option<Vec<EdgeRelation>>,
45    /// Stop after collecting this many nodes (start node counts as one).
46    pub max_results: Option<usize>,
47}
48
49impl Default for TraversalOptions {
50    fn default() -> Self {
51        Self {
52            max_depth: 3,
53            direction: Direction::Out,
54            relations: None,
55            max_results: None,
56        }
57    }
58}
59
60impl KhiveRuntime {
61    /// BFS traversal from `start`, returning nodes in level order.
62    ///
63    /// The first element is always the start node (`via_edge = None`, `depth = 0`).
64    /// Nodes already visited are skipped so the result set is deduplicated.
65    pub async fn bfs_traverse(
66        &self,
67        token: &NamespaceToken,
68        start: Uuid,
69        options: TraversalOptions,
70    ) -> RuntimeResult<Vec<PathNode>> {
71        let graph = self.graph(token)?;
72        let limit = options.max_results.unwrap_or(usize::MAX);
73
74        let mut visited: HashSet<Uuid> = HashSet::new();
75        let mut results: Vec<PathNode> = Vec::new();
76        // queue: (node_id, current_depth)
77        let mut queue: VecDeque<(Uuid, usize)> = VecDeque::new();
78
79        visited.insert(start);
80        results.push(PathNode {
81            entity_id: start,
82            depth: 0,
83            via_edge: None,
84        });
85        queue.push_back((start, 0));
86
87        'bfs: while let Some((current, depth)) = queue.pop_front() {
88            if depth >= options.max_depth {
89                continue;
90            }
91
92            let query = NeighborQuery {
93                direction: options.direction.clone(),
94                relations: options.relations.clone(),
95                limit: None,
96                min_weight: None,
97            };
98            let hits = graph.neighbors(current, query).await?;
99
100            for hit in hits {
101                if visited.contains(&hit.node_id) {
102                    continue;
103                }
104
105                let edge = graph
106                    .get_edge(LinkId::from(hit.edge_id))
107                    .await?
108                    .ok_or_else(|| {
109                        RuntimeError::NotFound(format!("edge {} missing", hit.edge_id))
110                    })?;
111
112                visited.insert(hit.node_id);
113                results.push(PathNode {
114                    entity_id: hit.node_id,
115                    depth: depth + 1,
116                    via_edge: Some(edge),
117                });
118
119                if results.len() >= limit {
120                    break 'bfs;
121                }
122
123                queue.push_back((hit.node_id, depth + 1));
124            }
125        }
126
127        Ok(results)
128    }
129
130    /// Bidirectional BFS shortest path from `from` to `to`.
131    ///
132    /// Returns `Some(path)` where `path[0]` is `from` and `path.last()` is `to`,
133    /// or `None` if no path exists within `max_depth` hops.
134    /// For `from == to` returns `Some` with a single-node path immediately.
135    pub async fn shortest_path(
136        &self,
137        token: &NamespaceToken,
138        from: Uuid,
139        to: Uuid,
140        max_depth: usize,
141    ) -> RuntimeResult<Option<Vec<PathNode>>> {
142        if from == to {
143            return Ok(Some(vec![PathNode {
144                entity_id: from,
145                depth: 0,
146                via_edge: None,
147            }]));
148        }
149
150        let graph = self.graph(token)?;
151
152        // Forward map: node -> (depth, parent, edge_id that reached this node)
153        let mut fwd: HashMap<Uuid, (usize, Option<Uuid>, Option<Uuid>)> = HashMap::new();
154        let mut fwd_q: VecDeque<Uuid> = VecDeque::new();
155        fwd.insert(from, (0, None, None));
156        fwd_q.push_back(from);
157
158        // Backward map: node -> (depth, child, edge_id that reached this node from `to` side)
159        let mut bwd: HashMap<Uuid, (usize, Option<Uuid>, Option<Uuid>)> = HashMap::new();
160        let mut bwd_q: VecDeque<Uuid> = VecDeque::new();
161        bwd.insert(to, (0, None, None));
162        bwd_q.push_back(to);
163
164        let mut meeting: Option<(Uuid, usize)> = None;
165        let mut current_depth = 0usize;
166
167        while (!fwd_q.is_empty() || !bwd_q.is_empty()) && current_depth <= max_depth {
168            // Expand the forward frontier one level.
169            let fwd_level = fwd_q.len();
170            for _ in 0..fwd_level {
171                let Some(node) = fwd_q.pop_front() else { break };
172                let fwd_depth = fwd[&node].0;
173
174                let hits = graph
175                    .neighbors(
176                        node,
177                        NeighborQuery {
178                            direction: Direction::Out,
179                            relations: None,
180                            limit: None,
181                            min_weight: None,
182                        },
183                    )
184                    .await?;
185
186                for hit in hits {
187                    if fwd.contains_key(&hit.node_id) {
188                        continue;
189                    }
190                    let new_depth = fwd_depth + 1;
191                    fwd.insert(hit.node_id, (new_depth, Some(node), Some(hit.edge_id)));
192                    fwd_q.push_back(hit.node_id);
193
194                    if let Some(&(bwd_depth, _, _)) = bwd.get(&hit.node_id) {
195                        let total = new_depth + bwd_depth;
196                        if total <= max_depth
197                            && meeting.as_ref().is_none_or(|&(_, best)| total < best)
198                        {
199                            meeting = Some((hit.node_id, total));
200                        }
201                    }
202                }
203            }
204
205            if meeting.is_some() {
206                break;
207            }
208
209            // Expand the backward frontier one level (following incoming edges).
210            let bwd_level = bwd_q.len();
211            for _ in 0..bwd_level {
212                let Some(node) = bwd_q.pop_front() else { break };
213                let bwd_depth = bwd[&node].0;
214
215                let hits = graph
216                    .neighbors(
217                        node,
218                        NeighborQuery {
219                            direction: Direction::In,
220                            relations: None,
221                            limit: None,
222                            min_weight: None,
223                        },
224                    )
225                    .await?;
226
227                for hit in hits {
228                    if bwd.contains_key(&hit.node_id) {
229                        continue;
230                    }
231                    let new_depth = bwd_depth + 1;
232                    bwd.insert(hit.node_id, (new_depth, Some(node), Some(hit.edge_id)));
233                    bwd_q.push_back(hit.node_id);
234
235                    if let Some(&(fwd_depth, _, _)) = fwd.get(&hit.node_id) {
236                        let total = fwd_depth + new_depth;
237                        if total <= max_depth
238                            && meeting.as_ref().is_none_or(|&(_, best)| total < best)
239                        {
240                            meeting = Some((hit.node_id, total));
241                        }
242                    }
243                }
244            }
245
246            if meeting.is_some() {
247                break;
248            }
249
250            current_depth += 1;
251        }
252
253        let (mid, _) = match meeting {
254            None => return Ok(None),
255            Some(m) => m,
256        };
257
258        // Reconstruct path: walk fwd map back from mid to `from`, then walk bwd map forward to `to`.
259        let mut fwd_chain: Vec<(Uuid, Option<Uuid>)> = Vec::new();
260        {
261            let mut cur = mid;
262            loop {
263                let (_, parent, edge_id) = fwd[&cur];
264                fwd_chain.push((cur, edge_id));
265                match parent {
266                    Some(p) => cur = p,
267                    None => break,
268                }
269            }
270        }
271        fwd_chain.reverse();
272
273        let mut bwd_chain: Vec<(Uuid, Option<Uuid>)> = Vec::new();
274        {
275            let mut cur = mid;
276            // Walk toward `to` using the backward map's child pointers.
277            while let Some(&(_, Some(child), edge_id)) = bwd.get(&cur) {
278                bwd_chain.push((child, edge_id));
279                cur = child;
280            }
281        }
282
283        // Build PathNode slice — fetch edges lazily.
284        let mut path: Vec<PathNode> = Vec::new();
285        for (i, (node_id, edge_id)) in fwd_chain.iter().enumerate() {
286            let via_edge = if i == 0 {
287                None // start node
288            } else if let Some(eid) = edge_id {
289                graph.get_edge(LinkId::from(*eid)).await?.or(None)
290            } else {
291                None
292            };
293            path.push(PathNode {
294                entity_id: *node_id,
295                depth: i,
296                via_edge,
297            });
298        }
299
300        let base = path.len();
301        for (i, (node_id, edge_id)) in bwd_chain.iter().enumerate() {
302            let via_edge = if let Some(eid) = edge_id {
303                graph.get_edge(LinkId::from(*eid)).await?.or(None)
304            } else {
305                None
306            };
307            path.push(PathNode {
308                entity_id: *node_id,
309                depth: base + i,
310                via_edge,
311            });
312        }
313
314        Ok(Some(path))
315    }
316}
317
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runtime::{KhiveRuntime, NamespaceToken};
    use khive_storage::EdgeRelation;

    /// Builds a fresh in-memory runtime for a single test.
    async fn rt() -> KhiveRuntime {
        KhiveRuntime::memory().expect("memory runtime")
    }

    /// Creates a `concept` entity called `name` and hands back its id.
    async fn concept(rt: &KhiveRuntime, tok: &NamespaceToken, name: &str) -> Uuid {
        rt.create_entity(tok, "concept", None, name, None, None, vec![])
            .await
            .unwrap()
            .id
    }

    /// Adds a weight-1.0 edge `from -> to` carrying relation `rel`.
    async fn link(
        rt: &KhiveRuntime,
        tok: &NamespaceToken,
        from: Uuid,
        to: Uuid,
        rel: EdgeRelation,
    ) {
        rt.link(tok, from, to, rel, 1.0, None).await.unwrap();
    }

    #[tokio::test]
    async fn bfs_max_depth_zero_returns_only_root() {
        let rt = rt().await;
        let tok = NamespaceToken::local();
        let a = concept(&rt, &tok, "A").await;
        let b = concept(&rt, &tok, "B").await;
        link(&rt, &tok, a, b, EdgeRelation::Extends).await;

        let opts = TraversalOptions {
            max_depth: 0,
            ..Default::default()
        };
        let nodes = rt.bfs_traverse(&tok, a, opts).await.unwrap();

        assert_eq!(nodes.len(), 1);
        let root = &nodes[0];
        assert_eq!(root.entity_id, a);
        assert_eq!(root.depth, 0);
        assert!(root.via_edge.is_none());
    }

    #[tokio::test]
    async fn bfs_depth_one_returns_root_and_neighbors() {
        let rt = rt().await;
        let tok = NamespaceToken::local();
        let a = concept(&rt, &tok, "A").await;
        let b = concept(&rt, &tok, "B").await;
        let c = concept(&rt, &tok, "C").await;
        link(&rt, &tok, a, b, EdgeRelation::Extends).await;
        link(&rt, &tok, a, c, EdgeRelation::Extends).await;
        // D hangs two hops away from A, so it must stay out of a depth-1 traversal.
        let d = concept(&rt, &tok, "D").await;
        link(&rt, &tok, b, d, EdgeRelation::Extends).await;

        let opts = TraversalOptions {
            max_depth: 1,
            ..Default::default()
        };
        let nodes = rt.bfs_traverse(&tok, a, opts).await.unwrap();

        let ids: HashSet<Uuid> = nodes.iter().map(|n| n.entity_id).collect();
        for expected in [a, b, c] {
            assert!(ids.contains(&expected));
        }
        assert!(!ids.contains(&d));
        // Everything except the root sits exactly one hop out.
        for node in nodes.iter().filter(|n| n.entity_id != a) {
            assert_eq!(node.depth, 1);
        }
    }

    #[tokio::test]
    async fn bfs_direction_out_only() {
        let rt = rt().await;
        let tok = NamespaceToken::local();
        let a = concept(&rt, &tok, "A").await;
        let b = concept(&rt, &tok, "B").await;
        // The only edge points B -> A, so walking outgoing edges from A finds nothing.
        link(&rt, &tok, b, a, EdgeRelation::Extends).await;

        let opts = TraversalOptions {
            max_depth: 2,
            direction: Direction::Out,
            ..Default::default()
        };
        let nodes = rt.bfs_traverse(&tok, a, opts).await.unwrap();
        assert_eq!(
            nodes.len(),
            1,
            "only root should be returned when traversing Out with no outgoing edges"
        );
    }

    #[tokio::test]
    async fn bfs_direction_in_only() {
        let rt = rt().await;
        let tok = NamespaceToken::local();
        let a = concept(&rt, &tok, "A").await;
        let b = concept(&rt, &tok, "B").await;
        // The only edge points B -> A, so walking incoming edges from A reaches B.
        link(&rt, &tok, b, a, EdgeRelation::Extends).await;

        let opts = TraversalOptions {
            max_depth: 2,
            direction: Direction::In,
            ..Default::default()
        };
        let nodes = rt.bfs_traverse(&tok, a, opts).await.unwrap();
        let ids: HashSet<Uuid> = nodes.iter().map(|n| n.entity_id).collect();
        assert!(
            ids.contains(&b),
            "B should be reachable via incoming edge"
        );
    }

    #[tokio::test]
    async fn bfs_relation_filter() {
        let rt = rt().await;
        let tok = NamespaceToken::local();
        let a = concept(&rt, &tok, "A").await;
        let b = concept(&rt, &tok, "B").await;
        let c = concept(&rt, &tok, "C").await;
        link(&rt, &tok, a, b, EdgeRelation::Extends).await;
        link(&rt, &tok, a, c, EdgeRelation::Enables).await;

        let opts = TraversalOptions {
            max_depth: 2,
            relations: Some(vec![EdgeRelation::Extends]),
            ..Default::default()
        };
        let nodes = rt.bfs_traverse(&tok, a, opts).await.unwrap();
        let ids: HashSet<Uuid> = nodes.iter().map(|n| n.entity_id).collect();
        assert!(ids.contains(&b), "B reachable via 'extends'");
        assert!(
            !ids.contains(&c),
            "C not reachable when filtering to 'extends'"
        );
    }

    #[tokio::test]
    async fn shortest_path_connected_nodes() {
        let rt = rt().await;
        let tok = NamespaceToken::local();
        let a = concept(&rt, &tok, "A").await;
        let b = concept(&rt, &tok, "B").await;
        let c = concept(&rt, &tok, "C").await;
        link(&rt, &tok, a, b, EdgeRelation::Extends).await;
        link(&rt, &tok, b, c, EdgeRelation::Extends).await;

        let path = rt
            .shortest_path(&tok, a, c, 10)
            .await
            .unwrap()
            .expect("path should exist");
        assert_eq!(path.len(), 3, "A -> B -> C = 3 nodes");
        assert_eq!(path[0].entity_id, a);
        assert_eq!(path[2].entity_id, c);
    }

    #[tokio::test]
    async fn shortest_path_unreachable_returns_none() {
        let rt = rt().await;
        let tok = NamespaceToken::local();
        let a = concept(&rt, &tok, "A").await;
        let b = concept(&rt, &tok, "B").await;
        // Deliberately no edges between A and B.

        let path = rt.shortest_path(&tok, a, b, 5).await.unwrap();
        assert!(path.is_none());
    }

    #[tokio::test]
    async fn shortest_path_same_node() {
        let rt = rt().await;
        let tok = NamespaceToken::local();
        let a = concept(&rt, &tok, "A").await;

        let path = rt
            .shortest_path(&tok, a, a, 5)
            .await
            .unwrap()
            .expect("trivial path should always exist");
        assert_eq!(path.len(), 1);
        assert_eq!(path[0].entity_id, a);
        assert!(path[0].via_edge.is_none());
    }

    #[tokio::test]
    async fn shortest_path_max_depth_zero_adjacent() {
        let rt = rt().await;
        let tok = NamespaceToken::local();
        let a = concept(&rt, &tok, "A").await;
        let b = concept(&rt, &tok, "B").await;
        link(&rt, &tok, a, b, EdgeRelation::Extends).await;

        // With max_depth = 0 only the trivial from == to case can succeed.
        let path = rt.shortest_path(&tok, a, b, 0).await.unwrap();
        assert!(
            path.is_none(),
            "1-hop path should not be returned at max_depth=0"
        );
    }

    #[tokio::test]
    async fn shortest_path_max_depth_one_two_hop_chain() {
        let rt = rt().await;
        let tok = NamespaceToken::local();
        let a = concept(&rt, &tok, "A").await;
        let b = concept(&rt, &tok, "B").await;
        let c = concept(&rt, &tok, "C").await;
        link(&rt, &tok, a, b, EdgeRelation::Extends).await;
        link(&rt, &tok, b, c, EdgeRelation::Extends).await;

        // A one-hop budget admits A -> B but not A -> B -> C.
        let one_hop = rt.shortest_path(&tok, a, b, 1).await.unwrap();
        assert!(
            one_hop.is_some(),
            "1-hop path should be found at max_depth=1"
        );

        let two_hop = rt.shortest_path(&tok, a, c, 1).await.unwrap();
        assert!(
            two_hop.is_none(),
            "2-hop path should not be returned at max_depth=1"
        );
    }
}