Skip to main content

khive_runtime/
graph_traversal.rs

1// Copyright 2024-2025 khive contributors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{HashMap, HashSet, VecDeque};
16
17use uuid::Uuid;
18
19use khive_storage::types::{Direction, Edge, LinkId, NeighborQuery};
20use khive_storage::EdgeRelation;
21
22use crate::error::{RuntimeError, RuntimeResult};
23use crate::runtime::KhiveRuntime;
24
25/// A node in a traversal path.
26#[derive(Debug, Clone)]
27pub struct PathNode {
28    /// Entity at this position.
29    pub entity_id: Uuid,
30    /// Distance from the start node (0 = start node).
31    pub depth: usize,
32    /// Edge that led to this node (`None` for the start node).
33    pub via_edge: Option<Edge>,
34}
35
36/// Options for BFS traversal and shortest-path search.
37#[derive(Debug, Clone)]
38pub struct TraversalOptions {
39    /// Maximum hops to follow.
40    pub max_depth: usize,
41    /// Which edge directions to follow.
42    pub direction: Direction,
43    /// Restrict traversal to these relation types (`None` = all).
44    pub relations: Option<Vec<EdgeRelation>>,
45    /// Stop after collecting this many nodes (start node counts as one).
46    pub max_results: Option<usize>,
47}
48
49impl Default for TraversalOptions {
50    fn default() -> Self {
51        Self {
52            max_depth: 3,
53            direction: Direction::Out,
54            relations: None,
55            max_results: None,
56        }
57    }
58}
59
60impl KhiveRuntime {
61    /// BFS traversal from `start`, returning nodes in level order.
62    ///
63    /// The first element is always the start node (`via_edge = None`, `depth = 0`).
64    /// Nodes already visited are skipped so the result set is deduplicated.
65    pub async fn bfs_traverse(
66        &self,
67        namespace: Option<&str>,
68        start: Uuid,
69        options: TraversalOptions,
70    ) -> RuntimeResult<Vec<PathNode>> {
71        let graph = self.graph(namespace)?;
72        let limit = options.max_results.unwrap_or(usize::MAX);
73
74        let mut visited: HashSet<Uuid> = HashSet::new();
75        let mut results: Vec<PathNode> = Vec::new();
76        // queue: (node_id, current_depth)
77        let mut queue: VecDeque<(Uuid, usize)> = VecDeque::new();
78
79        visited.insert(start);
80        results.push(PathNode {
81            entity_id: start,
82            depth: 0,
83            via_edge: None,
84        });
85        queue.push_back((start, 0));
86
87        'bfs: while let Some((current, depth)) = queue.pop_front() {
88            if depth >= options.max_depth {
89                continue;
90            }
91
92            let query = NeighborQuery {
93                direction: options.direction.clone(),
94                relations: options.relations.clone(),
95                limit: None,
96                min_weight: None,
97            };
98            let hits = graph.neighbors(current, query).await?;
99
100            for hit in hits {
101                if visited.contains(&hit.node_id) {
102                    continue;
103                }
104
105                let edge = graph
106                    .get_edge(LinkId::from(hit.edge_id))
107                    .await?
108                    .ok_or_else(|| {
109                        RuntimeError::NotFound(format!("edge {} missing", hit.edge_id))
110                    })?;
111
112                visited.insert(hit.node_id);
113                results.push(PathNode {
114                    entity_id: hit.node_id,
115                    depth: depth + 1,
116                    via_edge: Some(edge),
117                });
118
119                if results.len() >= limit {
120                    break 'bfs;
121                }
122
123                queue.push_back((hit.node_id, depth + 1));
124            }
125        }
126
127        Ok(results)
128    }
129
130    /// Bidirectional BFS shortest path from `from` to `to`.
131    ///
132    /// Returns `Some(path)` where `path[0]` is `from` and `path.last()` is `to`,
133    /// or `None` if no path exists within `max_depth` hops.
134    /// For `from == to` returns `Some` with a single-node path immediately.
135    pub async fn shortest_path(
136        &self,
137        namespace: Option<&str>,
138        from: Uuid,
139        to: Uuid,
140        max_depth: usize,
141    ) -> RuntimeResult<Option<Vec<PathNode>>> {
142        if from == to {
143            return Ok(Some(vec![PathNode {
144                entity_id: from,
145                depth: 0,
146                via_edge: None,
147            }]));
148        }
149
150        let graph = self.graph(namespace)?;
151
152        // Forward map: node -> (depth, parent, edge_id that reached this node)
153        let mut fwd: HashMap<Uuid, (usize, Option<Uuid>, Option<Uuid>)> = HashMap::new();
154        let mut fwd_q: VecDeque<Uuid> = VecDeque::new();
155        fwd.insert(from, (0, None, None));
156        fwd_q.push_back(from);
157
158        // Backward map: node -> (depth, child, edge_id that reached this node from `to` side)
159        let mut bwd: HashMap<Uuid, (usize, Option<Uuid>, Option<Uuid>)> = HashMap::new();
160        let mut bwd_q: VecDeque<Uuid> = VecDeque::new();
161        bwd.insert(to, (0, None, None));
162        bwd_q.push_back(to);
163
164        let mut meeting: Option<(Uuid, usize)> = None;
165        let mut current_depth = 0usize;
166
167        while (!fwd_q.is_empty() || !bwd_q.is_empty()) && current_depth <= max_depth {
168            // Expand the forward frontier one level.
169            let fwd_level = fwd_q.len();
170            for _ in 0..fwd_level {
171                let Some(node) = fwd_q.pop_front() else { break };
172                let fwd_depth = fwd[&node].0;
173
174                let hits = graph
175                    .neighbors(
176                        node,
177                        NeighborQuery {
178                            direction: Direction::Out,
179                            relations: None,
180                            limit: None,
181                            min_weight: None,
182                        },
183                    )
184                    .await?;
185
186                for hit in hits {
187                    if fwd.contains_key(&hit.node_id) {
188                        continue;
189                    }
190                    let new_depth = fwd_depth + 1;
191                    fwd.insert(hit.node_id, (new_depth, Some(node), Some(hit.edge_id)));
192                    fwd_q.push_back(hit.node_id);
193
194                    if let Some(&(bwd_depth, _, _)) = bwd.get(&hit.node_id) {
195                        let total = new_depth + bwd_depth;
196                        if total <= max_depth
197                            && meeting.as_ref().is_none_or(|&(_, best)| total < best)
198                        {
199                            meeting = Some((hit.node_id, total));
200                        }
201                    }
202                }
203            }
204
205            if meeting.is_some() {
206                break;
207            }
208
209            // Expand the backward frontier one level (following incoming edges).
210            let bwd_level = bwd_q.len();
211            for _ in 0..bwd_level {
212                let Some(node) = bwd_q.pop_front() else { break };
213                let bwd_depth = bwd[&node].0;
214
215                let hits = graph
216                    .neighbors(
217                        node,
218                        NeighborQuery {
219                            direction: Direction::In,
220                            relations: None,
221                            limit: None,
222                            min_weight: None,
223                        },
224                    )
225                    .await?;
226
227                for hit in hits {
228                    if bwd.contains_key(&hit.node_id) {
229                        continue;
230                    }
231                    let new_depth = bwd_depth + 1;
232                    bwd.insert(hit.node_id, (new_depth, Some(node), Some(hit.edge_id)));
233                    bwd_q.push_back(hit.node_id);
234
235                    if let Some(&(fwd_depth, _, _)) = fwd.get(&hit.node_id) {
236                        let total = fwd_depth + new_depth;
237                        if total <= max_depth
238                            && meeting.as_ref().is_none_or(|&(_, best)| total < best)
239                        {
240                            meeting = Some((hit.node_id, total));
241                        }
242                    }
243                }
244            }
245
246            if meeting.is_some() {
247                break;
248            }
249
250            current_depth += 1;
251        }
252
253        let (mid, _) = match meeting {
254            None => return Ok(None),
255            Some(m) => m,
256        };
257
258        // Reconstruct path: walk fwd map back from mid to `from`, then walk bwd map forward to `to`.
259        let mut fwd_chain: Vec<(Uuid, Option<Uuid>)> = Vec::new();
260        {
261            let mut cur = mid;
262            loop {
263                let (_, parent, edge_id) = fwd[&cur];
264                fwd_chain.push((cur, edge_id));
265                match parent {
266                    Some(p) => cur = p,
267                    None => break,
268                }
269            }
270        }
271        fwd_chain.reverse();
272
273        let mut bwd_chain: Vec<(Uuid, Option<Uuid>)> = Vec::new();
274        {
275            let mut cur = mid;
276            // Walk toward `to` using the backward map's child pointers.
277            while let Some(&(_, Some(child), edge_id)) = bwd.get(&cur) {
278                bwd_chain.push((child, edge_id));
279                cur = child;
280            }
281        }
282
283        // Build PathNode slice — fetch edges lazily.
284        let mut path: Vec<PathNode> = Vec::new();
285        for (i, (node_id, edge_id)) in fwd_chain.iter().enumerate() {
286            let via_edge = if i == 0 {
287                None // start node
288            } else if let Some(eid) = edge_id {
289                graph.get_edge(LinkId::from(*eid)).await?.or(None)
290            } else {
291                None
292            };
293            path.push(PathNode {
294                entity_id: *node_id,
295                depth: i,
296                via_edge,
297            });
298        }
299
300        let base = path.len();
301        for (i, (node_id, edge_id)) in bwd_chain.iter().enumerate() {
302            let via_edge = if let Some(eid) = edge_id {
303                graph.get_edge(LinkId::from(*eid)).await?.or(None)
304            } else {
305                None
306            };
307            path.push(PathNode {
308                entity_id: *node_id,
309                depth: base + i,
310                via_edge,
311            });
312        }
313
314        Ok(Some(path))
315    }
316}
317
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runtime::KhiveRuntime;
    use khive_storage::EdgeRelation;

    async fn rt() -> KhiveRuntime {
        KhiveRuntime::memory().expect("memory runtime")
    }

    /// Create a `concept` entity called `name` and return its id.
    async fn concept(rt: &KhiveRuntime, name: &str) -> Uuid {
        rt.create_entity(None, "concept", name, None, None, vec![])
            .await
            .expect("create entity")
            .id
    }

    /// Link `from -> to` with `relation` and unit weight.
    async fn link(rt: &KhiveRuntime, from: Uuid, to: Uuid, relation: EdgeRelation) {
        rt.link(None, from, to, relation, 1.0)
            .await
            .expect("link entities");
    }

    #[tokio::test]
    async fn bfs_max_depth_zero_returns_only_root() {
        let rt = rt().await;
        let a = concept(&rt, "A").await;
        let b = concept(&rt, "B").await;
        link(&rt, a, b, EdgeRelation::Extends).await;

        let opts = TraversalOptions {
            max_depth: 0,
            ..Default::default()
        };
        let nodes = rt.bfs_traverse(None, a, opts).await.unwrap();

        assert_eq!(nodes.len(), 1);
        assert_eq!(nodes[0].entity_id, a);
        assert_eq!(nodes[0].depth, 0);
        assert!(nodes[0].via_edge.is_none());
    }

    #[tokio::test]
    async fn bfs_depth_one_returns_root_and_neighbors() {
        let rt = rt().await;
        let a = concept(&rt, "A").await;
        let b = concept(&rt, "B").await;
        let c = concept(&rt, "C").await;
        link(&rt, a, b, EdgeRelation::Extends).await;
        link(&rt, a, c, EdgeRelation::Extends).await;
        // Add a node two hops away — it must NOT appear.
        let d = concept(&rt, "D").await;
        link(&rt, b, d, EdgeRelation::Extends).await;

        let opts = TraversalOptions {
            max_depth: 1,
            ..Default::default()
        };
        let nodes = rt.bfs_traverse(None, a, opts).await.unwrap();

        let ids: HashSet<Uuid> = nodes.iter().map(|n| n.entity_id).collect();
        assert!(ids.contains(&a));
        assert!(ids.contains(&b));
        assert!(ids.contains(&c));
        assert!(!ids.contains(&d));
        // Every non-root node should be at depth 1.
        for node in nodes.iter().filter(|n| n.entity_id != a) {
            assert_eq!(node.depth, 1);
        }
    }

    #[tokio::test]
    async fn bfs_direction_out_only() {
        let rt = rt().await;
        let a = concept(&rt, "A").await;
        let b = concept(&rt, "B").await;
        // Edge goes B -> A; traversing Out from A should find nothing.
        link(&rt, b, a, EdgeRelation::Extends).await;

        let opts = TraversalOptions {
            max_depth: 2,
            direction: Direction::Out,
            ..Default::default()
        };
        let nodes = rt.bfs_traverse(None, a, opts).await.unwrap();
        assert_eq!(
            nodes.len(),
            1,
            "only root should be returned when traversing Out with no outgoing edges"
        );
    }

    #[tokio::test]
    async fn bfs_direction_in_only() {
        let rt = rt().await;
        let a = concept(&rt, "A").await;
        let b = concept(&rt, "B").await;
        // Edge goes B -> A; traversing In from A should find B.
        link(&rt, b, a, EdgeRelation::Extends).await;

        let opts = TraversalOptions {
            max_depth: 2,
            direction: Direction::In,
            ..Default::default()
        };
        let nodes = rt.bfs_traverse(None, a, opts).await.unwrap();
        let ids: HashSet<Uuid> = nodes.iter().map(|n| n.entity_id).collect();
        assert!(ids.contains(&b), "B should be reachable via incoming edge");
    }

    #[tokio::test]
    async fn bfs_relation_filter() {
        let rt = rt().await;
        let a = concept(&rt, "A").await;
        let b = concept(&rt, "B").await;
        let c = concept(&rt, "C").await;
        link(&rt, a, b, EdgeRelation::Extends).await;
        link(&rt, a, c, EdgeRelation::DependsOn).await;

        let opts = TraversalOptions {
            max_depth: 2,
            relations: Some(vec![EdgeRelation::Extends]),
            ..Default::default()
        };
        let nodes = rt.bfs_traverse(None, a, opts).await.unwrap();
        let ids: HashSet<Uuid> = nodes.iter().map(|n| n.entity_id).collect();
        assert!(ids.contains(&b), "B reachable via 'extends'");
        assert!(!ids.contains(&c), "C not reachable when filtering to 'extends'");
    }

    #[tokio::test]
    async fn bfs_max_results_caps_collected_nodes() {
        let rt = rt().await;
        let a = concept(&rt, "A").await;
        for name in ["B", "C", "D"] {
            let n = concept(&rt, name).await;
            link(&rt, a, n, EdgeRelation::Extends).await;
        }

        let opts = TraversalOptions {
            max_results: Some(3),
            ..Default::default()
        };
        let nodes = rt.bfs_traverse(None, a, opts).await.unwrap();
        assert_eq!(nodes.len(), 3, "start node plus two neighbours");
        assert_eq!(nodes[0].entity_id, a);
    }

    #[tokio::test]
    async fn shortest_path_connected_nodes() {
        let rt = rt().await;
        let a = concept(&rt, "A").await;
        let b = concept(&rt, "B").await;
        let c = concept(&rt, "C").await;
        link(&rt, a, b, EdgeRelation::Extends).await;
        link(&rt, b, c, EdgeRelation::Extends).await;

        let path = rt.shortest_path(None, a, c, 10).await.unwrap();
        let path = path.expect("path should exist");
        assert_eq!(path.len(), 3, "A -> B -> C = 3 nodes");
        assert_eq!(path[0].entity_id, a);
        assert_eq!(path[2].entity_id, c);
    }

    #[tokio::test]
    async fn shortest_path_populates_edges_and_depths() {
        let rt = rt().await;
        let a = concept(&rt, "A").await;
        let b = concept(&rt, "B").await;
        let c = concept(&rt, "C").await;
        let d = concept(&rt, "D").await;
        link(&rt, a, b, EdgeRelation::Extends).await;
        link(&rt, b, c, EdgeRelation::Extends).await;
        link(&rt, c, d, EdgeRelation::Extends).await;

        // A 3-hop chain forces the forward and backward searches to meet in
        // the middle, exercising both halves of path reconstruction.
        let path = rt
            .shortest_path(None, a, d, 10)
            .await
            .unwrap()
            .expect("path should exist");
        let ids: Vec<Uuid> = path.iter().map(|n| n.entity_id).collect();
        assert_eq!(ids, vec![a, b, c, d]);
        for (i, node) in path.iter().enumerate() {
            assert_eq!(node.depth, i);
            assert_eq!(node.via_edge.is_none(), i == 0, "edge presence at {i}");
        }
    }

    #[tokio::test]
    async fn shortest_path_prefers_shorter_route() {
        let rt = rt().await;
        let a = concept(&rt, "A").await;
        let b = concept(&rt, "B").await;
        let c = concept(&rt, "C").await;
        link(&rt, a, b, EdgeRelation::Extends).await;
        link(&rt, b, c, EdgeRelation::Extends).await;
        link(&rt, a, c, EdgeRelation::DependsOn).await;

        let path = rt
            .shortest_path(None, a, c, 10)
            .await
            .unwrap()
            .expect("path should exist");
        assert_eq!(path.len(), 2, "direct A -> C edge is shorter");
        assert_eq!(path[1].entity_id, c);
    }

    #[tokio::test]
    async fn shortest_path_follows_edge_direction() {
        let rt = rt().await;
        let a = concept(&rt, "A").await;
        let b = concept(&rt, "B").await;
        // Only B -> A exists, so there is no directed path A -> B.
        link(&rt, b, a, EdgeRelation::Extends).await;

        let path = rt.shortest_path(None, a, b, 5).await.unwrap();
        assert!(path.is_none());
    }

    #[tokio::test]
    async fn shortest_path_unreachable_returns_none() {
        let rt = rt().await;
        let a = concept(&rt, "A").await;
        let b = concept(&rt, "B").await;
        // No edges between them.

        let path = rt.shortest_path(None, a, b, 5).await.unwrap();
        assert!(path.is_none());
    }

    #[tokio::test]
    async fn shortest_path_same_node() {
        let rt = rt().await;
        let a = concept(&rt, "A").await;

        let path = rt.shortest_path(None, a, a, 5).await.unwrap();
        let path = path.expect("trivial path should always exist");
        assert_eq!(path.len(), 1);
        assert_eq!(path[0].entity_id, a);
        assert!(path[0].via_edge.is_none());
    }

    #[tokio::test]
    async fn shortest_path_max_depth_zero_adjacent() {
        let rt = rt().await;
        let a = concept(&rt, "A").await;
        let b = concept(&rt, "B").await;
        link(&rt, a, b, EdgeRelation::Extends).await;

        // max_depth=0 means only the trivial from==to case succeeds.
        let path = rt.shortest_path(None, a, b, 0).await.unwrap();
        assert!(
            path.is_none(),
            "1-hop path should not be returned at max_depth=0"
        );
    }

    #[tokio::test]
    async fn shortest_path_max_depth_one_two_hop_chain() {
        let rt = rt().await;
        let a = concept(&rt, "A").await;
        let b = concept(&rt, "B").await;
        let c = concept(&rt, "C").await;
        link(&rt, a, b, EdgeRelation::Extends).await;
        link(&rt, b, c, EdgeRelation::Extends).await;

        // max_depth=1 should find A->B but not A->B->C.
        let one_hop = rt.shortest_path(None, a, b, 1).await.unwrap();
        assert!(
            one_hop.is_some(),
            "1-hop path should be found at max_depth=1"
        );

        let two_hop = rt.shortest_path(None, a, c, 1).await.unwrap();
        assert!(
            two_hop.is_none(),
            "2-hop path should not be returned at max_depth=1"
        );
    }
}