Skip to main content

nodedb_graph/
traversal.rs

1//! Graph traversal algorithms on the CSR index.
2//!
3//! BFS, bidirectional shortest path, and subgraph materialization.
4//! All algorithms respect a max-visited cap to prevent supernode fan-out
5//! explosion from consuming unbounded memory.
6//!
7//! Access tracking and prefetch hints are integrated: each traversal records
8//! node access for hot/cold partition decisions, and prefetches frontier
9//! neighbors for cache efficiency.
10
11use std::collections::{HashMap, HashSet, VecDeque, hash_map::Entry};
12
13pub use nodedb_types::config::tuning::DEFAULT_MAX_VISITED;
14
15use crate::csr::{CsrIndex, Direction};
16
17impl CsrIndex {
18    /// BFS traversal. Returns all reachable node IDs within max_depth hops.
19    ///
20    /// `max_visited` caps the number of nodes visited to prevent supernode fan-out
21    /// explosion. Pass [`DEFAULT_MAX_VISITED`] for the standard limit.
22    pub fn traverse_bfs(
23        &self,
24        start_nodes: &[&str],
25        label_filter: Option<&str>,
26        direction: Direction,
27        max_depth: usize,
28        max_visited: usize,
29    ) -> Vec<String> {
30        let label_id = label_filter.and_then(|l| self.label_id(l));
31        let mut visited: HashSet<u32> = HashSet::new();
32        let mut queue: VecDeque<(u32, usize)> = VecDeque::new();
33
34        for &node in start_nodes {
35            if let Some(id) = self.node_id(node)
36                && visited.insert(id)
37            {
38                queue.push_back((id, 0));
39            }
40        }
41
42        while let Some((node_id, depth)) = queue.pop_front() {
43            if depth >= max_depth || visited.len() >= max_visited {
44                continue;
45            }
46
47            // Track access for hot/cold partition decisions.
48            self.record_access(node_id);
49
50            if matches!(direction, Direction::Out | Direction::Both) {
51                for (lid, dst) in self.iter_out_edges(node_id) {
52                    if label_id.is_none_or(|f| f == lid)
53                        && visited.len() < max_visited
54                        && visited.insert(dst)
55                    {
56                        self.prefetch_node(dst);
57                        queue.push_back((dst, depth + 1));
58                    }
59                }
60            }
61            if matches!(direction, Direction::In | Direction::Both) {
62                for (lid, src) in self.iter_in_edges(node_id) {
63                    if label_id.is_none_or(|f| f == lid)
64                        && visited.len() < max_visited
65                        && visited.insert(src)
66                    {
67                        self.prefetch_node(src);
68                        queue.push_back((src, depth + 1));
69                    }
70                }
71            }
72        }
73
74        visited
75            .into_iter()
76            .map(|id| self.node_name(id).to_string())
77            .collect()
78    }
79
80    /// BFS traversal returning nodes with depth information.
81    ///
82    /// `max_visited` caps the number of nodes visited to prevent supernode fan-out
83    /// explosion. Pass [`DEFAULT_MAX_VISITED`] for the standard limit.
84    pub fn traverse_bfs_with_depth(
85        &self,
86        start_nodes: &[&str],
87        label_filter: Option<&str>,
88        direction: Direction,
89        max_depth: usize,
90        max_visited: usize,
91    ) -> Vec<(String, u8)> {
92        let filters: Vec<&str> = label_filter.into_iter().collect();
93        self.traverse_bfs_with_depth_multi(start_nodes, &filters, direction, max_depth, max_visited)
94    }
95
96    /// BFS traversal with multi-label filter. Empty labels = all edges.
97    ///
98    /// `max_visited` caps the number of nodes visited to prevent supernode fan-out
99    /// explosion. Pass [`DEFAULT_MAX_VISITED`] for the standard limit.
100    pub fn traverse_bfs_with_depth_multi(
101        &self,
102        start_nodes: &[&str],
103        label_filters: &[&str],
104        direction: Direction,
105        max_depth: usize,
106        max_visited: usize,
107    ) -> Vec<(String, u8)> {
108        let label_ids: Vec<u32> = label_filters
109            .iter()
110            .filter_map(|l| self.label_id(l))
111            .collect();
112        let match_label = |lid: u32| label_ids.is_empty() || label_ids.contains(&lid);
113        let mut visited: HashMap<u32, u8> = HashMap::new();
114        let mut queue: VecDeque<(u32, u8)> = VecDeque::new();
115
116        for &node in start_nodes {
117            if let Some(id) = self.node_id(node) {
118                visited.insert(id, 0);
119                queue.push_back((id, 0));
120            }
121        }
122
123        while let Some((node_id, depth)) = queue.pop_front() {
124            if depth as usize >= max_depth || visited.len() >= max_visited {
125                continue;
126            }
127
128            let next_depth = depth + 1;
129
130            if matches!(direction, Direction::Out | Direction::Both) {
131                for (lid, dst) in self.iter_out_edges(node_id) {
132                    if match_label(lid)
133                        && visited.len() < max_visited
134                        && !visited.contains_key(&dst)
135                    {
136                        visited.insert(dst, next_depth);
137                        queue.push_back((dst, next_depth));
138                    }
139                }
140            }
141            if matches!(direction, Direction::In | Direction::Both) {
142                for (lid, src) in self.iter_in_edges(node_id) {
143                    if match_label(lid)
144                        && visited.len() < max_visited
145                        && !visited.contains_key(&src)
146                    {
147                        visited.insert(src, next_depth);
148                        queue.push_back((src, next_depth));
149                    }
150                }
151            }
152        }
153
154        visited
155            .into_iter()
156            .map(|(id, depth)| (self.node_name(id).to_string(), depth))
157            .collect()
158    }
159
160    /// Shortest path via bidirectional BFS.
161    ///
162    /// `max_visited` caps the combined forward+backward visited set to prevent
163    /// supernode fan-out explosion. Pass [`DEFAULT_MAX_VISITED`] for the standard limit.
164    pub fn shortest_path(
165        &self,
166        src: &str,
167        dst: &str,
168        label_filter: Option<&str>,
169        max_depth: usize,
170        max_visited: usize,
171    ) -> Option<Vec<String>> {
172        let src_id = self.node_id(src)?;
173        let dst_id = self.node_id(dst)?;
174        if src_id == dst_id {
175            return Some(vec![src.to_string()]);
176        }
177
178        let label_id = label_filter.and_then(|l| self.label_id(l));
179        let mut fwd_parent: HashMap<u32, u32> = HashMap::new();
180        let mut bwd_parent: HashMap<u32, u32> = HashMap::new();
181        fwd_parent.insert(src_id, src_id);
182        bwd_parent.insert(dst_id, dst_id);
183
184        let mut fwd_frontier: Vec<u32> = vec![src_id];
185        let mut bwd_frontier: Vec<u32> = vec![dst_id];
186
187        for _depth in 0..max_depth {
188            if fwd_parent.len() + bwd_parent.len() >= max_visited {
189                break;
190            }
191
192            let mut next_fwd = Vec::new();
193            for &node in &fwd_frontier {
194                self.record_access(node);
195                for (lid, neighbor) in self.iter_out_edges(node) {
196                    if label_id.is_none_or(|f| f == lid) {
197                        if let Entry::Vacant(e) = fwd_parent.entry(neighbor) {
198                            e.insert(node);
199                            next_fwd.push(neighbor);
200                        }
201                        if bwd_parent.contains_key(&neighbor) {
202                            return Some(self.reconstruct_path(neighbor, &fwd_parent, &bwd_parent));
203                        }
204                    }
205                }
206            }
207            fwd_frontier = next_fwd;
208
209            let mut next_bwd = Vec::new();
210            for &node in &bwd_frontier {
211                self.record_access(node);
212                for (lid, neighbor) in self.iter_in_edges(node) {
213                    if label_id.is_none_or(|f| f == lid) {
214                        if let Entry::Vacant(e) = bwd_parent.entry(neighbor) {
215                            e.insert(node);
216                            next_bwd.push(neighbor);
217                        }
218                        if fwd_parent.contains_key(&neighbor) {
219                            return Some(self.reconstruct_path(neighbor, &fwd_parent, &bwd_parent));
220                        }
221                    }
222                }
223            }
224            bwd_frontier = next_bwd;
225
226            if fwd_frontier.is_empty() && bwd_frontier.is_empty() {
227                break;
228            }
229        }
230        None
231    }
232
233    fn reconstruct_path(
234        &self,
235        meeting: u32,
236        fwd_parent: &HashMap<u32, u32>,
237        bwd_parent: &HashMap<u32, u32>,
238    ) -> Vec<String> {
239        let mut fwd_path = Vec::new();
240        let mut current = meeting;
241        loop {
242            fwd_path.push(current);
243            let parent = fwd_parent[&current];
244            if parent == current {
245                break;
246            }
247            current = parent;
248        }
249        fwd_path.reverse();
250
251        current = bwd_parent[&meeting];
252        if current != meeting {
253            loop {
254                fwd_path.push(current);
255                let parent = bwd_parent[&current];
256                if parent == current {
257                    break;
258                }
259                current = parent;
260            }
261        }
262
263        fwd_path
264            .into_iter()
265            .map(|id| self.node_name(id).to_string())
266            .collect()
267    }
268
269    /// Materialize a subgraph as edge tuples within max_depth.
270    ///
271    /// `max_visited` caps the number of nodes visited to prevent supernode fan-out
272    /// explosion. Pass [`DEFAULT_MAX_VISITED`] for the standard limit.
273    pub fn subgraph(
274        &self,
275        start_nodes: &[&str],
276        label_filter: Option<&str>,
277        max_depth: usize,
278        max_visited: usize,
279    ) -> Vec<(String, String, String)> {
280        let label_id = label_filter.and_then(|l| self.label_id(l));
281        let mut visited: HashSet<u32> = HashSet::new();
282        let mut queue: VecDeque<(u32, usize)> = VecDeque::new();
283        let mut edges = Vec::new();
284
285        for &node in start_nodes {
286            if let Some(id) = self.node_id(node)
287                && visited.insert(id)
288            {
289                queue.push_back((id, 0));
290            }
291        }
292
293        while let Some((node_id, depth)) = queue.pop_front() {
294            if depth >= max_depth || visited.len() >= max_visited {
295                continue;
296            }
297            self.record_access(node_id);
298            for (lid, dst) in self.iter_out_edges(node_id) {
299                if label_id.is_none_or(|f| f == lid) {
300                    edges.push((
301                        self.node_name(node_id).to_string(),
302                        self.label_name(lid).to_string(),
303                        self.node_name(dst).to_string(),
304                    ));
305                    if visited.len() < max_visited && visited.insert(dst) {
306                        queue.push_back((dst, depth + 1));
307                    }
308                }
309            }
310        }
311
312        edges
313    }
314}
315
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture graph: a KNOWS chain `a -> b -> c -> d` plus one WORKS edge `a -> e`.
    fn make_csr() -> CsrIndex {
        let fixture = [
            ("a", "KNOWS", "b"),
            ("b", "KNOWS", "c"),
            ("c", "KNOWS", "d"),
            ("a", "WORKS", "e"),
        ];
        let mut csr = CsrIndex::new();
        for (src, label, dst) in fixture {
            csr.add_edge(src, label, dst).unwrap();
        }
        csr
    }

    /// Traversal results are collected from a hash set; sort them so the
    /// assertions do not depend on iteration order.
    fn sorted(mut names: Vec<String>) -> Vec<String> {
        names.sort();
        names
    }

    /// Two KNOWS hops from `a` reach `b` and `c` but not `d`.
    #[test]
    fn bfs_traversal() {
        let reached = make_csr().traverse_bfs(
            &["a"],
            Some("KNOWS"),
            Direction::Out,
            2,
            DEFAULT_MAX_VISITED,
        );
        assert_eq!(sorted(reached), vec!["a", "b", "c"]);
    }

    /// Without a label filter, one hop from `a` follows both KNOWS and WORKS.
    #[test]
    fn bfs_all_labels() {
        let reached =
            make_csr().traverse_bfs(&["a"], None, Direction::Out, 1, DEFAULT_MAX_VISITED);
        assert_eq!(sorted(reached), vec!["a", "b", "e"]);
    }

    /// A three-node cycle terminates and reports each node once.
    #[test]
    fn bfs_cycle() {
        let mut csr = CsrIndex::new();
        for (src, dst) in [("a", "b"), ("b", "c"), ("c", "a")] {
            csr.add_edge(src, "L", dst).unwrap();
        }
        let reached = csr.traverse_bfs(&["a"], None, Direction::Out, 10, DEFAULT_MAX_VISITED);
        assert_eq!(sorted(reached), vec!["a", "b", "c"]);
    }

    /// Each node along the KNOWS chain is reported with its hop count.
    #[test]
    fn bfs_with_depth() {
        let depths: HashMap<String, u8> = make_csr()
            .traverse_bfs_with_depth(
                &["a"],
                Some("KNOWS"),
                Direction::Out,
                3,
                DEFAULT_MAX_VISITED,
            )
            .into_iter()
            .collect();
        for (name, expected) in [("a", 0u8), ("b", 1), ("c", 2), ("d", 3)] {
            assert_eq!(depths[name], expected);
        }
    }

    /// The path `a -> b -> c` is returned in source-to-destination order.
    #[test]
    fn shortest_path_direct() {
        let found = make_csr().shortest_path("a", "c", Some("KNOWS"), 5, DEFAULT_MAX_VISITED);
        assert_eq!(found.unwrap(), vec!["a", "b", "c"]);
    }

    /// Identical endpoints yield a single-node path.
    #[test]
    fn shortest_path_same_node() {
        let found = make_csr().shortest_path("a", "a", None, 5, DEFAULT_MAX_VISITED);
        assert_eq!(found.unwrap(), vec!["a"]);
    }

    /// Edges are directed, so `d` cannot reach `a`.
    #[test]
    fn shortest_path_unreachable() {
        let found = make_csr().shortest_path("d", "a", Some("KNOWS"), 5, DEFAULT_MAX_VISITED);
        assert!(found.is_none());
    }

    /// A three-hop path is not found with `max_depth` of 1.
    #[test]
    fn shortest_path_depth_limit() {
        let found = make_csr().shortest_path("a", "d", Some("KNOWS"), 1, DEFAULT_MAX_VISITED);
        assert!(found.is_none());
    }

    /// Depth 2 from `a` records the edges out of `a` and `b` only.
    #[test]
    fn subgraph_materialization() {
        let edge = |src: &str, label: &str, dst: &str| {
            (src.to_string(), label.to_string(), dst.to_string())
        };
        let edges = make_csr().subgraph(&["a"], None, 2, DEFAULT_MAX_VISITED);
        assert_eq!(edges.len(), 3);
        assert!(edges.contains(&edge("a", "KNOWS", "b")));
        assert!(edges.contains(&edge("a", "WORKS", "e")));
        assert!(edges.contains(&edge("b", "KNOWS", "c")));
    }

    /// A 1000-node chain, compacted, exercises BFS and shortest path at scale.
    #[test]
    fn large_graph_bfs() {
        let mut csr = CsrIndex::new();
        for i in 0..999 {
            let from = format!("n{i}");
            let to = format!("n{}", i + 1);
            csr.add_edge(&from, "NEXT", &to).unwrap();
        }
        csr.compact();

        // 100 hops from n0 reach n0..=n100.
        let reached = csr.traverse_bfs(
            &["n0"],
            Some("NEXT"),
            Direction::Out,
            100,
            DEFAULT_MAX_VISITED,
        );
        assert_eq!(reached.len(), 101);

        // n0..=n50 inclusive is 51 nodes.
        let found = csr.shortest_path("n0", "n50", Some("NEXT"), 100, DEFAULT_MAX_VISITED);
        assert_eq!(found.unwrap().len(), 51);
    }
}
444}