Skip to main content

nodedb_graph/
traversal.rs

1// SPDX-License-Identifier: Apache-2.0
2
//! Graph traversal algorithms on the CSR index.
//!
//! BFS (plain and depth-annotated), bidirectional shortest path, and subgraph
//! materialization. All algorithms respect a max-visited cap to prevent
//! supernode fan-out from consuming unbounded memory.
//!
//! Access tracking and prefetch hints are integrated: traversals record node
//! access for hot/cold partition decisions, and BFS issues prefetch hints for
//! newly discovered frontier nodes for cache efficiency.
12
13use std::collections::{HashMap, HashSet, VecDeque, hash_map::Entry};
14
15pub use nodedb_types::config::tuning::DEFAULT_MAX_VISITED;
16
17use crate::csr::{CsrIndex, Direction};
18
19impl CsrIndex {
20    /// BFS traversal. Returns all reachable node IDs within max_depth hops.
21    ///
22    /// `max_visited` caps the number of nodes visited to prevent supernode fan-out
23    /// explosion. Pass [`DEFAULT_MAX_VISITED`] for the standard limit.
24    ///
25    /// `frontier_bitmap`: when `Some`, only nodes whose surrogate is present in the
26    /// bitmap are eligible as traversal targets. Start nodes are not gated — only
27    /// newly discovered frontier nodes are checked.
28    pub fn traverse_bfs(
29        &self,
30        start_nodes: &[&str],
31        label_filter: Option<&str>,
32        direction: Direction,
33        max_depth: usize,
34        max_visited: usize,
35        frontier_bitmap: Option<&nodedb_types::SurrogateBitmap>,
36    ) -> Vec<String> {
37        let label_id = label_filter.and_then(|l| self.label_id(l));
38        let mut visited: HashSet<u32> = HashSet::new();
39        let mut queue: VecDeque<(u32, usize)> = VecDeque::new();
40
41        for &node in start_nodes {
42            if let Some(&id) = self.node_to_id.get(node)
43                && visited.insert(id)
44            {
45                queue.push_back((id, 0));
46            }
47        }
48
49        while let Some((node_id, depth)) = queue.pop_front() {
50            if depth >= max_depth || visited.len() >= max_visited {
51                continue;
52            }
53
54            // Track access for hot/cold partition decisions.
55            self.record_access(node_id);
56
57            if matches!(direction, Direction::Out | Direction::Both) {
58                for (lid, dst) in self.dense_iter_out(node_id) {
59                    if label_id.is_none_or(|f| f == lid)
60                        && visited.len() < max_visited
61                        && frontier_bitmap.is_none_or(|bm| {
62                            bm.contains(nodedb_types::Surrogate::new(self.node_surrogate_raw(dst)))
63                        })
64                        && visited.insert(dst)
65                    {
66                        self.prefetch_node(dst);
67                        queue.push_back((dst, depth + 1));
68                    }
69                }
70            }
71            if matches!(direction, Direction::In | Direction::Both) {
72                for (lid, src) in self.dense_iter_in(node_id) {
73                    if label_id.is_none_or(|f| f == lid)
74                        && visited.len() < max_visited
75                        && frontier_bitmap.is_none_or(|bm| {
76                            bm.contains(nodedb_types::Surrogate::new(self.node_surrogate_raw(src)))
77                        })
78                        && visited.insert(src)
79                    {
80                        self.prefetch_node(src);
81                        queue.push_back((src, depth + 1));
82                    }
83                }
84            }
85        }
86
87        visited
88            .into_iter()
89            .map(|id| self.id_to_node[id as usize].clone())
90            .collect()
91    }
92
93    /// BFS traversal returning nodes with depth information.
94    ///
95    /// `max_visited` caps the number of nodes visited to prevent supernode fan-out
96    /// explosion. Pass [`DEFAULT_MAX_VISITED`] for the standard limit.
97    pub fn traverse_bfs_with_depth(
98        &self,
99        start_nodes: &[&str],
100        label_filter: Option<&str>,
101        direction: Direction,
102        max_depth: usize,
103        max_visited: usize,
104    ) -> Vec<(String, u8)> {
105        let filters: Vec<&str> = label_filter.into_iter().collect();
106        self.traverse_bfs_with_depth_multi(start_nodes, &filters, direction, max_depth, max_visited)
107    }
108
109    /// BFS traversal with multi-label filter. Empty labels = all edges.
110    ///
111    /// `max_visited` caps the number of nodes visited to prevent supernode fan-out
112    /// explosion. Pass [`DEFAULT_MAX_VISITED`] for the standard limit.
113    pub fn traverse_bfs_with_depth_multi(
114        &self,
115        start_nodes: &[&str],
116        label_filters: &[&str],
117        direction: Direction,
118        max_depth: usize,
119        max_visited: usize,
120    ) -> Vec<(String, u8)> {
121        let label_ids: Vec<u32> = label_filters
122            .iter()
123            .filter_map(|l| self.label_id(l))
124            .collect();
125        let match_label = |lid: u32| label_ids.is_empty() || label_ids.contains(&lid);
126        let mut visited: HashMap<u32, u8> = HashMap::new();
127        let mut queue: VecDeque<(u32, u8)> = VecDeque::new();
128
129        for &node in start_nodes {
130            if let Some(&id) = self.node_to_id.get(node) {
131                visited.insert(id, 0);
132                queue.push_back((id, 0));
133            }
134        }
135
136        while let Some((node_id, depth)) = queue.pop_front() {
137            if depth as usize >= max_depth || visited.len() >= max_visited {
138                continue;
139            }
140
141            let next_depth = depth + 1;
142
143            if matches!(direction, Direction::Out | Direction::Both) {
144                for (lid, dst) in self.dense_iter_out(node_id) {
145                    if match_label(lid)
146                        && visited.len() < max_visited
147                        && !visited.contains_key(&dst)
148                    {
149                        visited.insert(dst, next_depth);
150                        queue.push_back((dst, next_depth));
151                    }
152                }
153            }
154            if matches!(direction, Direction::In | Direction::Both) {
155                for (lid, src) in self.dense_iter_in(node_id) {
156                    if match_label(lid)
157                        && visited.len() < max_visited
158                        && !visited.contains_key(&src)
159                    {
160                        visited.insert(src, next_depth);
161                        queue.push_back((src, next_depth));
162                    }
163                }
164            }
165        }
166
167        visited
168            .into_iter()
169            .map(|(id, depth)| (self.id_to_node[id as usize].clone(), depth))
170            .collect()
171    }
172
173    /// Shortest path via bidirectional BFS.
174    ///
175    /// `max_visited` caps the combined forward+backward visited set to prevent
176    /// supernode fan-out explosion. Pass [`DEFAULT_MAX_VISITED`] for the standard limit.
177    ///
178    /// `frontier_bitmap`: when `Some`, only nodes whose surrogate is present in the
179    /// bitmap are eligible for expansion. Start and end nodes are not gated.
180    pub fn shortest_path(
181        &self,
182        src: &str,
183        dst: &str,
184        label_filter: Option<&str>,
185        max_depth: usize,
186        max_visited: usize,
187        frontier_bitmap: Option<&nodedb_types::SurrogateBitmap>,
188    ) -> Option<Vec<String>> {
189        let src_id = *self.node_to_id.get(src)?;
190        let dst_id = *self.node_to_id.get(dst)?;
191        if src_id == dst_id {
192            return Some(vec![src.to_string()]);
193        }
194
195        let label_id = label_filter.and_then(|l| self.label_id(l));
196        let mut fwd_parent: HashMap<u32, u32> = HashMap::new();
197        let mut bwd_parent: HashMap<u32, u32> = HashMap::new();
198        fwd_parent.insert(src_id, src_id);
199        bwd_parent.insert(dst_id, dst_id);
200
201        let mut fwd_frontier: Vec<u32> = vec![src_id];
202        let mut bwd_frontier: Vec<u32> = vec![dst_id];
203
204        for _depth in 0..max_depth {
205            if fwd_parent.len() + bwd_parent.len() >= max_visited {
206                break;
207            }
208
209            let mut next_fwd = Vec::new();
210            for &node in &fwd_frontier {
211                self.record_access(node);
212                for (lid, neighbor) in self.dense_iter_out(node) {
213                    if label_id.is_none_or(|f| f == lid)
214                        && frontier_bitmap.is_none_or(|bm| {
215                            bm.contains(nodedb_types::Surrogate::new(
216                                self.node_surrogate_raw(neighbor),
217                            ))
218                        })
219                    {
220                        if let Entry::Vacant(e) = fwd_parent.entry(neighbor) {
221                            e.insert(node);
222                            next_fwd.push(neighbor);
223                        }
224                        if bwd_parent.contains_key(&neighbor) {
225                            return Some(self.reconstruct_path(neighbor, &fwd_parent, &bwd_parent));
226                        }
227                    }
228                }
229            }
230            fwd_frontier = next_fwd;
231
232            let mut next_bwd = Vec::new();
233            for &node in &bwd_frontier {
234                self.record_access(node);
235                for (lid, neighbor) in self.dense_iter_in(node) {
236                    if label_id.is_none_or(|f| f == lid)
237                        && frontier_bitmap.is_none_or(|bm| {
238                            bm.contains(nodedb_types::Surrogate::new(
239                                self.node_surrogate_raw(neighbor),
240                            ))
241                        })
242                    {
243                        if let Entry::Vacant(e) = bwd_parent.entry(neighbor) {
244                            e.insert(node);
245                            next_bwd.push(neighbor);
246                        }
247                        if fwd_parent.contains_key(&neighbor) {
248                            return Some(self.reconstruct_path(neighbor, &fwd_parent, &bwd_parent));
249                        }
250                    }
251                }
252            }
253            bwd_frontier = next_bwd;
254
255            if fwd_frontier.is_empty() && bwd_frontier.is_empty() {
256                break;
257            }
258        }
259        None
260    }
261
262    fn reconstruct_path(
263        &self,
264        meeting: u32,
265        fwd_parent: &HashMap<u32, u32>,
266        bwd_parent: &HashMap<u32, u32>,
267    ) -> Vec<String> {
268        let mut fwd_path = Vec::new();
269        let mut current = meeting;
270        loop {
271            fwd_path.push(current);
272            let parent = fwd_parent[&current];
273            if parent == current {
274                break;
275            }
276            current = parent;
277        }
278        fwd_path.reverse();
279
280        current = bwd_parent[&meeting];
281        if current != meeting {
282            loop {
283                fwd_path.push(current);
284                let parent = bwd_parent[&current];
285                if parent == current {
286                    break;
287                }
288                current = parent;
289            }
290        }
291
292        fwd_path
293            .into_iter()
294            .map(|id| self.id_to_node[id as usize].clone())
295            .collect()
296    }
297
298    /// Materialize a subgraph as edge tuples within max_depth.
299    ///
300    /// `max_visited` caps the number of nodes visited to prevent supernode fan-out
301    /// explosion. Pass [`DEFAULT_MAX_VISITED`] for the standard limit.
302    pub fn subgraph(
303        &self,
304        start_nodes: &[&str],
305        label_filter: Option<&str>,
306        max_depth: usize,
307        max_visited: usize,
308    ) -> Vec<(String, String, String)> {
309        let label_id = label_filter.and_then(|l| self.label_id(l));
310        let mut visited: HashSet<u32> = HashSet::new();
311        let mut queue: VecDeque<(u32, usize)> = VecDeque::new();
312        let mut edges = Vec::new();
313
314        for &node in start_nodes {
315            if let Some(&id) = self.node_to_id.get(node)
316                && visited.insert(id)
317            {
318                queue.push_back((id, 0));
319            }
320        }
321
322        while let Some((node_id, depth)) = queue.pop_front() {
323            if depth >= max_depth || visited.len() >= max_visited {
324                continue;
325            }
326            self.record_access(node_id);
327            for (lid, dst) in self.dense_iter_out(node_id) {
328                if label_id.is_none_or(|f| f == lid) {
329                    edges.push((
330                        self.id_to_node[node_id as usize].clone(),
331                        self.label_name(lid).to_string(),
332                        self.id_to_node[dst as usize].clone(),
333                    ));
334                    if visited.len() < max_visited && visited.insert(dst) {
335                        queue.push_back((dst, depth + 1));
336                    }
337                }
338            }
339        }
340
341        edges
342    }
343}
344
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture: a -KNOWS-> b -KNOWS-> c -KNOWS-> d, plus a -WORKS-> e.
    fn make_csr() -> CsrIndex {
        let mut graph = CsrIndex::new();
        let fixture = [
            ("a", "KNOWS", "b"),
            ("b", "KNOWS", "c"),
            ("c", "KNOWS", "d"),
            ("a", "WORKS", "e"),
        ];
        for (from, label, to) in fixture {
            graph.add_edge(from, label, to).unwrap();
        }
        graph
    }

    /// Sorts traversal output so assertions do not depend on hash order.
    fn sorted(mut names: Vec<String>) -> Vec<String> {
        names.sort();
        names
    }

    #[test]
    fn bfs_traversal() {
        let graph = make_csr();
        let reached = sorted(graph.traverse_bfs(
            &["a"],
            Some("KNOWS"),
            Direction::Out,
            2,
            DEFAULT_MAX_VISITED,
            None,
        ));
        assert_eq!(reached, vec!["a", "b", "c"]);
    }

    #[test]
    fn bfs_all_labels() {
        let graph = make_csr();
        let reached = sorted(graph.traverse_bfs(
            &["a"],
            None,
            Direction::Out,
            1,
            DEFAULT_MAX_VISITED,
            None,
        ));
        assert_eq!(reached, vec!["a", "b", "e"]);
    }

    #[test]
    fn bfs_cycle() {
        let mut ring = CsrIndex::new();
        for (from, to) in [("a", "b"), ("b", "c"), ("c", "a")] {
            ring.add_edge(from, "L", to).unwrap();
        }
        let reached = sorted(ring.traverse_bfs(
            &["a"],
            None,
            Direction::Out,
            10,
            DEFAULT_MAX_VISITED,
            None,
        ));
        assert_eq!(reached, vec!["a", "b", "c"]);
    }

    #[test]
    fn bfs_with_depth() {
        let graph = make_csr();
        let depths: HashMap<String, u8> = graph
            .traverse_bfs_with_depth(&["a"], Some("KNOWS"), Direction::Out, 3, DEFAULT_MAX_VISITED)
            .into_iter()
            .collect();
        for (name, expected) in [("a", 0u8), ("b", 1), ("c", 2), ("d", 3)] {
            assert_eq!(depths[name], expected);
        }
    }

    #[test]
    fn shortest_path_direct() {
        let graph = make_csr();
        let route = graph
            .shortest_path("a", "c", Some("KNOWS"), 5, DEFAULT_MAX_VISITED, None)
            .unwrap();
        assert_eq!(route, vec!["a", "b", "c"]);
    }

    #[test]
    fn shortest_path_same_node() {
        let graph = make_csr();
        let route = graph
            .shortest_path("a", "a", None, 5, DEFAULT_MAX_VISITED, None)
            .unwrap();
        assert_eq!(route, vec!["a"]);
    }

    #[test]
    fn shortest_path_unreachable() {
        let graph = make_csr();
        let route = graph.shortest_path("d", "a", Some("KNOWS"), 5, DEFAULT_MAX_VISITED, None);
        assert!(route.is_none());
    }

    #[test]
    fn shortest_path_depth_limit() {
        let graph = make_csr();
        let route = graph.shortest_path("a", "d", Some("KNOWS"), 1, DEFAULT_MAX_VISITED, None);
        assert!(route.is_none());
    }

    #[test]
    fn subgraph_materialization() {
        let graph = make_csr();
        let edges = graph.subgraph(&["a"], None, 2, DEFAULT_MAX_VISITED);
        assert_eq!(edges.len(), 3);
        for (from, label, to) in [("a", "KNOWS", "b"), ("a", "WORKS", "e"), ("b", "KNOWS", "c")] {
            assert!(edges.contains(&(from.into(), label.into(), to.into())));
        }
    }

    #[test]
    fn large_graph_bfs() {
        let mut chain = CsrIndex::new();
        for i in 0..999 {
            let from = format!("n{i}");
            let to = format!("n{}", i + 1);
            chain.add_edge(&from, "NEXT", &to).unwrap();
        }
        chain.compact().expect("no governor, cannot fail");

        let reached = chain.traverse_bfs(
            &["n0"],
            Some("NEXT"),
            Direction::Out,
            100,
            DEFAULT_MAX_VISITED,
            None,
        );
        assert_eq!(reached.len(), 101);

        let route = chain
            .shortest_path("n0", "n50", Some("NEXT"), 100, DEFAULT_MAX_VISITED, None)
            .unwrap();
        assert_eq!(route.len(), 51);
    }

    /// BFS with a frontier bitmap holding only "b": starting from "a", "b" is
    /// reachable but "c" is blocked because its surrogate is not a member.
    #[test]
    fn bfs_frontier_bitmap_excludes_non_members() {
        use nodedb_types::{Surrogate, SurrogateBitmap};

        let mut graph = make_csr();
        // Surrogates: b=10, c=20, d=30; "a" and "e" get none.
        for (name, raw) in [("b", 10), ("c", 20), ("d", 30)] {
            graph.set_node_surrogate(name, Surrogate::new(raw));
        }

        let mut allowed = SurrogateBitmap::new();
        allowed.insert(Surrogate::new(10));

        let reached = sorted(graph.traverse_bfs(
            &["a"],
            Some("KNOWS"),
            Direction::Out,
            10,
            DEFAULT_MAX_VISITED,
            Some(&allowed),
        ));
        // "a" is the ungated start node, "b" passes the bitmap, and "c"
        // (surrogate 20) is rejected, so the walk stops there.
        assert_eq!(reached, vec!["a", "b"]);
    }

    /// shortest_path with a bitmap that excludes the only intermediate node:
    /// "b" is the sole KNOWS route from "a" to "c", so blocking it leaves no path.
    #[test]
    fn shortest_path_frontier_bitmap_blocks_intermediate() {
        use nodedb_types::{Surrogate, SurrogateBitmap};

        let mut graph = make_csr();
        graph.set_node_surrogate("b", Surrogate::new(10));
        graph.set_node_surrogate("c", Surrogate::new(20));

        // Only "c" is a member; "b" is not.
        let mut allowed = SurrogateBitmap::new();
        allowed.insert(Surrogate::new(20));

        let route = graph.shortest_path(
            "a",
            "c",
            Some("KNOWS"),
            5,
            DEFAULT_MAX_VISITED,
            Some(&allowed),
        );
        // Expansion through "b" (surrogate 10) is blocked, so "c" is unreachable.
        assert!(route.is_none());
    }
}