Skip to main content

khive_runtime/
graph_traversal.rs

1// Copyright 2024-2025 khive contributors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{HashMap, HashSet, VecDeque};
16
17use uuid::Uuid;
18
19use khive_storage::types::{Direction, Edge, LinkId, NeighborQuery};
20use khive_storage::EdgeRelation;
21
22use crate::error::{RuntimeError, RuntimeResult};
23use crate::runtime::{KhiveRuntime, NamespaceToken};
24
25/// A node in a traversal path.
26#[derive(Debug, Clone)]
27pub struct PathNode {
28    /// Entity at this position.
29    pub entity_id: Uuid,
30    /// Distance from the start node (0 = start node).
31    pub depth: usize,
32    /// Edge that led to this node (`None` for the start node).
33    pub via_edge: Option<Edge>,
34}
35
36/// Options for BFS traversal and shortest-path search.
37#[derive(Debug, Clone)]
38pub struct TraversalOptions {
39    /// Maximum hops to follow.
40    pub max_depth: usize,
41    /// Which edge directions to follow.
42    pub direction: Direction,
43    /// Restrict traversal to these relation types (`None` = all).
44    pub relations: Option<Vec<EdgeRelation>>,
45    /// Stop after collecting this many nodes (start node counts as one).
46    pub max_results: Option<usize>,
47}
48
49impl Default for TraversalOptions {
50    fn default() -> Self {
51        Self {
52            max_depth: 3,
53            direction: Direction::Out,
54            relations: None,
55            max_results: None,
56        }
57    }
58}
59
60impl KhiveRuntime {
61    /// BFS traversal from `start`, returning nodes in level order.
62    ///
63    /// The first element is always the start node (`via_edge = None`, `depth = 0`).
64    /// Nodes already visited are skipped so the result set is deduplicated.
65    pub async fn bfs_traverse(
66        &self,
67        token: &NamespaceToken,
68        start: Uuid,
69        options: TraversalOptions,
70    ) -> RuntimeResult<Vec<PathNode>> {
71        if !self.substrate_exists_in_ns(token, start).await? {
72            return Ok(Vec::new());
73        }
74
75        let graph = self.graph(token)?;
76        let limit = options.max_results.unwrap_or(usize::MAX);
77
78        let mut visited: HashSet<Uuid> = HashSet::new();
79        let mut results: Vec<PathNode> = Vec::new();
80        // queue: (node_id, current_depth)
81        let mut queue: VecDeque<(Uuid, usize)> = VecDeque::new();
82
83        visited.insert(start);
84        results.push(PathNode {
85            entity_id: start,
86            depth: 0,
87            via_edge: None,
88        });
89        queue.push_back((start, 0));
90
91        'bfs: while let Some((current, depth)) = queue.pop_front() {
92            if depth >= options.max_depth {
93                continue;
94            }
95
96            let query = NeighborQuery {
97                direction: options.direction.clone(),
98                relations: options.relations.clone(),
99                limit: None,
100                min_weight: None,
101            };
102            let hits = graph.neighbors(current, query).await?;
103
104            for hit in hits {
105                if visited.contains(&hit.node_id) {
106                    continue;
107                }
108
109                let edge = graph
110                    .get_edge(LinkId::from(hit.edge_id))
111                    .await?
112                    .ok_or_else(|| {
113                        RuntimeError::NotFound(format!("edge {} missing", hit.edge_id))
114                    })?;
115
116                visited.insert(hit.node_id);
117                results.push(PathNode {
118                    entity_id: hit.node_id,
119                    depth: depth + 1,
120                    via_edge: Some(edge),
121                });
122
123                if results.len() >= limit {
124                    break 'bfs;
125                }
126
127                queue.push_back((hit.node_id, depth + 1));
128            }
129        }
130
131        Ok(results)
132    }
133
134    /// Bidirectional BFS shortest path from `from` to `to`.
135    ///
136    /// Returns `Some(path)` where `path[0]` is `from` and `path.last()` is `to`,
137    /// or `None` if no path exists within `max_depth` hops.
138    /// For `from == to` returns `Some` with a single-node path immediately.
139    pub async fn shortest_path(
140        &self,
141        token: &NamespaceToken,
142        from: Uuid,
143        to: Uuid,
144        max_depth: usize,
145    ) -> RuntimeResult<Option<Vec<PathNode>>> {
146        if !self.substrate_exists_in_ns(token, from).await?
147            || !self.substrate_exists_in_ns(token, to).await?
148        {
149            return Ok(None);
150        }
151
152        if from == to {
153            return Ok(Some(vec![PathNode {
154                entity_id: from,
155                depth: 0,
156                via_edge: None,
157            }]));
158        }
159
160        let graph = self.graph(token)?;
161
162        // Forward map: node -> (depth, parent, edge_id that reached this node)
163        let mut fwd: HashMap<Uuid, (usize, Option<Uuid>, Option<Uuid>)> = HashMap::new();
164        let mut fwd_q: VecDeque<Uuid> = VecDeque::new();
165        fwd.insert(from, (0, None, None));
166        fwd_q.push_back(from);
167
168        // Backward map: node -> (depth, child, edge_id that reached this node from `to` side)
169        let mut bwd: HashMap<Uuid, (usize, Option<Uuid>, Option<Uuid>)> = HashMap::new();
170        let mut bwd_q: VecDeque<Uuid> = VecDeque::new();
171        bwd.insert(to, (0, None, None));
172        bwd_q.push_back(to);
173
174        let mut meeting: Option<(Uuid, usize)> = None;
175        let mut current_depth = 0usize;
176
177        while (!fwd_q.is_empty() || !bwd_q.is_empty()) && current_depth <= max_depth {
178            // Expand the forward frontier one level.
179            let fwd_level = fwd_q.len();
180            for _ in 0..fwd_level {
181                let Some(node) = fwd_q.pop_front() else { break };
182                let fwd_depth = fwd[&node].0;
183
184                let hits = graph
185                    .neighbors(
186                        node,
187                        NeighborQuery {
188                            direction: Direction::Out,
189                            relations: None,
190                            limit: None,
191                            min_weight: None,
192                        },
193                    )
194                    .await?;
195
196                for hit in hits {
197                    if fwd.contains_key(&hit.node_id) {
198                        continue;
199                    }
200                    let new_depth = fwd_depth + 1;
201                    fwd.insert(hit.node_id, (new_depth, Some(node), Some(hit.edge_id)));
202                    fwd_q.push_back(hit.node_id);
203
204                    if let Some(&(bwd_depth, _, _)) = bwd.get(&hit.node_id) {
205                        let total = new_depth + bwd_depth;
206                        if total <= max_depth
207                            && meeting.as_ref().is_none_or(|&(_, best)| total < best)
208                        {
209                            meeting = Some((hit.node_id, total));
210                        }
211                    }
212                }
213            }
214
215            if meeting.is_some() {
216                break;
217            }
218
219            // Expand the backward frontier one level (following incoming edges).
220            let bwd_level = bwd_q.len();
221            for _ in 0..bwd_level {
222                let Some(node) = bwd_q.pop_front() else { break };
223                let bwd_depth = bwd[&node].0;
224
225                let hits = graph
226                    .neighbors(
227                        node,
228                        NeighborQuery {
229                            direction: Direction::In,
230                            relations: None,
231                            limit: None,
232                            min_weight: None,
233                        },
234                    )
235                    .await?;
236
237                for hit in hits {
238                    if bwd.contains_key(&hit.node_id) {
239                        continue;
240                    }
241                    let new_depth = bwd_depth + 1;
242                    bwd.insert(hit.node_id, (new_depth, Some(node), Some(hit.edge_id)));
243                    bwd_q.push_back(hit.node_id);
244
245                    if let Some(&(fwd_depth, _, _)) = fwd.get(&hit.node_id) {
246                        let total = fwd_depth + new_depth;
247                        if total <= max_depth
248                            && meeting.as_ref().is_none_or(|&(_, best)| total < best)
249                        {
250                            meeting = Some((hit.node_id, total));
251                        }
252                    }
253                }
254            }
255
256            if meeting.is_some() {
257                break;
258            }
259
260            current_depth += 1;
261        }
262
263        let (mid, _) = match meeting {
264            None => return Ok(None),
265            Some(m) => m,
266        };
267
268        // Reconstruct path: walk fwd map back from mid to `from`, then walk bwd map forward to `to`.
269        let mut fwd_chain: Vec<(Uuid, Option<Uuid>)> = Vec::new();
270        {
271            let mut cur = mid;
272            loop {
273                let (_, parent, edge_id) = fwd[&cur];
274                fwd_chain.push((cur, edge_id));
275                match parent {
276                    Some(p) => cur = p,
277                    None => break,
278                }
279            }
280        }
281        fwd_chain.reverse();
282
283        let mut bwd_chain: Vec<(Uuid, Option<Uuid>)> = Vec::new();
284        {
285            let mut cur = mid;
286            // Walk toward `to` using the backward map's child pointers.
287            while let Some(&(_, Some(child), edge_id)) = bwd.get(&cur) {
288                bwd_chain.push((child, edge_id));
289                cur = child;
290            }
291        }
292
293        // Build PathNode slice — fetch edges lazily.
294        let mut path: Vec<PathNode> = Vec::new();
295        for (i, (node_id, edge_id)) in fwd_chain.iter().enumerate() {
296            let via_edge = if i == 0 {
297                None // start node
298            } else if let Some(eid) = edge_id {
299                graph.get_edge(LinkId::from(*eid)).await?.or(None)
300            } else {
301                None
302            };
303            path.push(PathNode {
304                entity_id: *node_id,
305                depth: i,
306                via_edge,
307            });
308        }
309
310        let base = path.len();
311        for (i, (node_id, edge_id)) in bwd_chain.iter().enumerate() {
312            let via_edge = if let Some(eid) = edge_id {
313                graph.get_edge(LinkId::from(*eid)).await?.or(None)
314            } else {
315                None
316            };
317            path.push(PathNode {
318                entity_id: *node_id,
319                depth: base + i,
320                via_edge,
321            });
322        }
323
324        Ok(Some(path))
325    }
326}
327
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runtime::{KhiveRuntime, NamespaceToken};
    use khive_storage::EdgeRelation;

    async fn rt() -> KhiveRuntime {
        KhiveRuntime::memory().expect("memory runtime")
    }

    /// Creates a `concept` entity named `name` and returns its id.
    async fn concept(runtime: &KhiveRuntime, tok: &NamespaceToken, name: &str) -> Uuid {
        runtime
            .create_entity(tok, "concept", None, name, None, None, vec![])
            .await
            .unwrap()
            .id
    }

    /// Adds a weight-1.0 edge `src -> dst` carrying `relation`.
    async fn connect(
        runtime: &KhiveRuntime,
        tok: &NamespaceToken,
        src: Uuid,
        dst: Uuid,
        relation: EdgeRelation,
    ) {
        runtime
            .link(tok, src, dst, relation, 1.0, None)
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn bfs_max_depth_zero_returns_only_root() {
        let rt = rt().await;
        let tok = NamespaceToken::local();
        let a = concept(&rt, &tok, "A").await;
        let b = concept(&rt, &tok, "B").await;
        connect(&rt, &tok, a, b, EdgeRelation::Extends).await;

        let opts = TraversalOptions {
            max_depth: 0,
            ..Default::default()
        };
        let nodes = rt.bfs_traverse(&tok, a, opts).await.unwrap();

        assert_eq!(nodes.len(), 1);
        let root = &nodes[0];
        assert_eq!(root.entity_id, a);
        assert_eq!(root.depth, 0);
        assert!(root.via_edge.is_none());
    }

    #[tokio::test]
    async fn bfs_depth_one_returns_root_and_neighbors() {
        let rt = rt().await;
        let tok = NamespaceToken::local();
        let a = concept(&rt, &tok, "A").await;
        let b = concept(&rt, &tok, "B").await;
        let c = concept(&rt, &tok, "C").await;
        connect(&rt, &tok, a, b, EdgeRelation::Extends).await;
        connect(&rt, &tok, a, c, EdgeRelation::Extends).await;
        // D sits two hops away and must stay out of a depth-1 traversal.
        let d = concept(&rt, &tok, "D").await;
        connect(&rt, &tok, b, d, EdgeRelation::Extends).await;

        let opts = TraversalOptions {
            max_depth: 1,
            ..Default::default()
        };
        let nodes = rt.bfs_traverse(&tok, a, opts).await.unwrap();

        let ids: HashSet<Uuid> = nodes.iter().map(|n| n.entity_id).collect();
        assert!(ids.contains(&a));
        assert!(ids.contains(&b));
        assert!(ids.contains(&c));
        assert!(!ids.contains(&d));
        // Everything except the root was reached in exactly one hop.
        for node in nodes.iter().filter(|n| n.entity_id != a) {
            assert_eq!(node.depth, 1);
        }
    }

    #[tokio::test]
    async fn bfs_direction_out_only() {
        let rt = rt().await;
        let tok = NamespaceToken::local();
        let a = concept(&rt, &tok, "A").await;
        let b = concept(&rt, &tok, "B").await;
        // The only edge points B -> A, so nothing is reachable Out of A.
        connect(&rt, &tok, b, a, EdgeRelation::Extends).await;

        let opts = TraversalOptions {
            max_depth: 2,
            direction: Direction::Out,
            ..Default::default()
        };
        let nodes = rt.bfs_traverse(&tok, a, opts).await.unwrap();
        assert_eq!(
            nodes.len(),
            1,
            "only root should be returned when traversing Out with no outgoing edges"
        );
    }

    #[tokio::test]
    async fn bfs_direction_in_only() {
        let rt = rt().await;
        let tok = NamespaceToken::local();
        let a = concept(&rt, &tok, "A").await;
        let b = concept(&rt, &tok, "B").await;
        // The only edge points B -> A, so B is reachable In to A.
        connect(&rt, &tok, b, a, EdgeRelation::Extends).await;

        let opts = TraversalOptions {
            max_depth: 2,
            direction: Direction::In,
            ..Default::default()
        };
        let nodes = rt.bfs_traverse(&tok, a, opts).await.unwrap();
        let ids: HashSet<Uuid> = nodes.iter().map(|n| n.entity_id).collect();
        assert!(ids.contains(&b), "B should be reachable via incoming edge");
    }

    #[tokio::test]
    async fn bfs_relation_filter() {
        let rt = rt().await;
        let tok = NamespaceToken::local();
        let a = concept(&rt, &tok, "A").await;
        let b = concept(&rt, &tok, "B").await;
        let c = concept(&rt, &tok, "C").await;
        connect(&rt, &tok, a, b, EdgeRelation::Extends).await;
        connect(&rt, &tok, a, c, EdgeRelation::Enables).await;

        let opts = TraversalOptions {
            max_depth: 2,
            relations: Some(vec![EdgeRelation::Extends]),
            ..Default::default()
        };
        let nodes = rt.bfs_traverse(&tok, a, opts).await.unwrap();
        let ids: HashSet<Uuid> = nodes.iter().map(|n| n.entity_id).collect();
        assert!(ids.contains(&b), "B reachable via 'extends'");
        assert!(!ids.contains(&c), "C not reachable when filtering to 'extends'");
    }

    #[tokio::test]
    async fn shortest_path_connected_nodes() {
        let rt = rt().await;
        let tok = NamespaceToken::local();
        let a = concept(&rt, &tok, "A").await;
        let b = concept(&rt, &tok, "B").await;
        let c = concept(&rt, &tok, "C").await;
        connect(&rt, &tok, a, b, EdgeRelation::Extends).await;
        connect(&rt, &tok, b, c, EdgeRelation::Extends).await;

        let path = rt
            .shortest_path(&tok, a, c, 10)
            .await
            .unwrap()
            .expect("path should exist");
        assert_eq!(path.len(), 3, "A -> B -> C = 3 nodes");
        assert_eq!(path[0].entity_id, a);
        assert_eq!(path[2].entity_id, c);
    }

    #[tokio::test]
    async fn shortest_path_unreachable_returns_none() {
        let rt = rt().await;
        let tok = NamespaceToken::local();
        let a = concept(&rt, &tok, "A").await;
        let b = concept(&rt, &tok, "B").await;
        // Deliberately no edges between A and B.

        let path = rt.shortest_path(&tok, a, b, 5).await.unwrap();
        assert!(path.is_none());
    }

    #[tokio::test]
    async fn shortest_path_same_node() {
        let rt = rt().await;
        let tok = NamespaceToken::local();
        let a = concept(&rt, &tok, "A").await;

        let path = rt
            .shortest_path(&tok, a, a, 5)
            .await
            .unwrap()
            .expect("trivial path should always exist");
        assert_eq!(path.len(), 1);
        assert_eq!(path[0].entity_id, a);
        assert!(path[0].via_edge.is_none());
    }

    #[tokio::test]
    async fn shortest_path_max_depth_zero_adjacent() {
        let rt = rt().await;
        let tok = NamespaceToken::local();
        let a = concept(&rt, &tok, "A").await;
        let b = concept(&rt, &tok, "B").await;
        connect(&rt, &tok, a, b, EdgeRelation::Extends).await;

        // With max_depth = 0, only the trivial from == to case can succeed.
        let path = rt.shortest_path(&tok, a, b, 0).await.unwrap();
        assert!(
            path.is_none(),
            "1-hop path should not be returned at max_depth=0"
        );
    }

    #[tokio::test]
    async fn shortest_path_max_depth_one_two_hop_chain() {
        let rt = rt().await;
        let tok = NamespaceToken::local();
        let a = concept(&rt, &tok, "A").await;
        let b = concept(&rt, &tok, "B").await;
        let c = concept(&rt, &tok, "C").await;
        connect(&rt, &tok, a, b, EdgeRelation::Extends).await;
        connect(&rt, &tok, b, c, EdgeRelation::Extends).await;

        // A one-hop budget reaches B but not C.
        let one_hop = rt.shortest_path(&tok, a, b, 1).await.unwrap();
        assert!(
            one_hop.is_some(),
            "1-hop path should be found at max_depth=1"
        );

        let two_hop = rt.shortest_path(&tok, a, c, 1).await.unwrap();
        assert!(
            two_hop.is_none(),
            "2-hop path should not be returned at max_depth=1"
        );
    }
}