Skip to main content

nodedb_graph/
traversal.rs

1//! Graph traversal algorithms on the CSR index.
2//!
3//! BFS, bidirectional shortest path, and subgraph materialization.
4//! Max-visited cap prevents supernode fan-out explosion.
5
6use std::collections::{HashMap, HashSet, VecDeque, hash_map::Entry};
7
8pub use nodedb_types::config::tuning::DEFAULT_MAX_VISITED;
9
10use crate::csr::{CsrIndex, Direction};
11
12impl CsrIndex {
13    /// BFS traversal. Returns all reachable node IDs within max_depth hops.
14    ///
15    /// `max_visited` caps the number of nodes visited to prevent supernode fan-out
16    /// explosion. Pass [`DEFAULT_MAX_VISITED`] for the standard limit.
17    pub fn traverse_bfs(
18        &self,
19        start_nodes: &[&str],
20        label_filter: Option<&str>,
21        direction: Direction,
22        max_depth: usize,
23        max_visited: usize,
24    ) -> Vec<String> {
25        let label_id = label_filter.and_then(|l| self.label_id(l));
26        let mut visited: HashSet<u32> = HashSet::new();
27        let mut queue: VecDeque<(u32, usize)> = VecDeque::new();
28
29        for &node in start_nodes {
30            if let Some(id) = self.node_id(node)
31                && visited.insert(id)
32            {
33                queue.push_back((id, 0));
34            }
35        }
36
37        while let Some((node_id, depth)) = queue.pop_front() {
38            if depth >= max_depth || visited.len() >= max_visited {
39                continue;
40            }
41
42            if matches!(direction, Direction::Out | Direction::Both) {
43                for (lid, dst) in self.iter_out_edges(node_id) {
44                    if label_id.is_none_or(|f| f == lid)
45                        && visited.len() < max_visited
46                        && visited.insert(dst)
47                    {
48                        queue.push_back((dst, depth + 1));
49                    }
50                }
51            }
52            if matches!(direction, Direction::In | Direction::Both) {
53                for (lid, src) in self.iter_in_edges(node_id) {
54                    if label_id.is_none_or(|f| f == lid)
55                        && visited.len() < max_visited
56                        && visited.insert(src)
57                    {
58                        queue.push_back((src, depth + 1));
59                    }
60                }
61            }
62        }
63
64        visited
65            .into_iter()
66            .map(|id| self.node_name(id).to_string())
67            .collect()
68    }
69
70    /// BFS traversal returning nodes with depth information.
71    ///
72    /// `max_visited` caps the number of nodes visited to prevent supernode fan-out
73    /// explosion. Pass [`DEFAULT_MAX_VISITED`] for the standard limit.
74    pub fn traverse_bfs_with_depth(
75        &self,
76        start_nodes: &[&str],
77        label_filter: Option<&str>,
78        direction: Direction,
79        max_depth: usize,
80        max_visited: usize,
81    ) -> Vec<(String, u8)> {
82        let filters: Vec<&str> = label_filter.into_iter().collect();
83        self.traverse_bfs_with_depth_multi(start_nodes, &filters, direction, max_depth, max_visited)
84    }
85
86    /// BFS traversal with multi-label filter. Empty labels = all edges.
87    ///
88    /// `max_visited` caps the number of nodes visited to prevent supernode fan-out
89    /// explosion. Pass [`DEFAULT_MAX_VISITED`] for the standard limit.
90    pub fn traverse_bfs_with_depth_multi(
91        &self,
92        start_nodes: &[&str],
93        label_filters: &[&str],
94        direction: Direction,
95        max_depth: usize,
96        max_visited: usize,
97    ) -> Vec<(String, u8)> {
98        let label_ids: Vec<u16> = label_filters
99            .iter()
100            .filter_map(|l| self.label_id(l))
101            .collect();
102        let match_label = |lid: u16| label_ids.is_empty() || label_ids.contains(&lid);
103        let mut visited: HashMap<u32, u8> = HashMap::new();
104        let mut queue: VecDeque<(u32, u8)> = VecDeque::new();
105
106        for &node in start_nodes {
107            if let Some(id) = self.node_id(node) {
108                visited.insert(id, 0);
109                queue.push_back((id, 0));
110            }
111        }
112
113        while let Some((node_id, depth)) = queue.pop_front() {
114            if depth as usize >= max_depth || visited.len() >= max_visited {
115                continue;
116            }
117
118            let next_depth = depth + 1;
119
120            if matches!(direction, Direction::Out | Direction::Both) {
121                for (lid, dst) in self.iter_out_edges(node_id) {
122                    if match_label(lid)
123                        && visited.len() < max_visited
124                        && !visited.contains_key(&dst)
125                    {
126                        visited.insert(dst, next_depth);
127                        queue.push_back((dst, next_depth));
128                    }
129                }
130            }
131            if matches!(direction, Direction::In | Direction::Both) {
132                for (lid, src) in self.iter_in_edges(node_id) {
133                    if match_label(lid)
134                        && visited.len() < max_visited
135                        && !visited.contains_key(&src)
136                    {
137                        visited.insert(src, next_depth);
138                        queue.push_back((src, next_depth));
139                    }
140                }
141            }
142        }
143
144        visited
145            .into_iter()
146            .map(|(id, depth)| (self.node_name(id).to_string(), depth))
147            .collect()
148    }
149
150    /// Shortest path via bidirectional BFS.
151    ///
152    /// `max_visited` caps the combined forward+backward visited set to prevent
153    /// supernode fan-out explosion. Pass [`DEFAULT_MAX_VISITED`] for the standard limit.
154    pub fn shortest_path(
155        &self,
156        src: &str,
157        dst: &str,
158        label_filter: Option<&str>,
159        max_depth: usize,
160        max_visited: usize,
161    ) -> Option<Vec<String>> {
162        let src_id = self.node_id(src)?;
163        let dst_id = self.node_id(dst)?;
164        if src_id == dst_id {
165            return Some(vec![src.to_string()]);
166        }
167
168        let label_id = label_filter.and_then(|l| self.label_id(l));
169        let mut fwd_parent: HashMap<u32, u32> = HashMap::new();
170        let mut bwd_parent: HashMap<u32, u32> = HashMap::new();
171        fwd_parent.insert(src_id, src_id);
172        bwd_parent.insert(dst_id, dst_id);
173
174        let mut fwd_frontier: Vec<u32> = vec![src_id];
175        let mut bwd_frontier: Vec<u32> = vec![dst_id];
176
177        for _depth in 0..max_depth {
178            if fwd_parent.len() + bwd_parent.len() >= max_visited {
179                break;
180            }
181
182            let mut next_fwd = Vec::new();
183            for &node in &fwd_frontier {
184                for (lid, neighbor) in self.iter_out_edges(node) {
185                    if label_id.is_none_or(|f| f == lid) {
186                        if let Entry::Vacant(e) = fwd_parent.entry(neighbor) {
187                            e.insert(node);
188                            next_fwd.push(neighbor);
189                        }
190                        if bwd_parent.contains_key(&neighbor) {
191                            return Some(self.reconstruct_path(neighbor, &fwd_parent, &bwd_parent));
192                        }
193                    }
194                }
195            }
196            fwd_frontier = next_fwd;
197
198            let mut next_bwd = Vec::new();
199            for &node in &bwd_frontier {
200                for (lid, neighbor) in self.iter_in_edges(node) {
201                    if label_id.is_none_or(|f| f == lid) {
202                        if let Entry::Vacant(e) = bwd_parent.entry(neighbor) {
203                            e.insert(node);
204                            next_bwd.push(neighbor);
205                        }
206                        if fwd_parent.contains_key(&neighbor) {
207                            return Some(self.reconstruct_path(neighbor, &fwd_parent, &bwd_parent));
208                        }
209                    }
210                }
211            }
212            bwd_frontier = next_bwd;
213
214            if fwd_frontier.is_empty() && bwd_frontier.is_empty() {
215                break;
216            }
217        }
218        None
219    }
220
221    fn reconstruct_path(
222        &self,
223        meeting: u32,
224        fwd_parent: &HashMap<u32, u32>,
225        bwd_parent: &HashMap<u32, u32>,
226    ) -> Vec<String> {
227        let mut fwd_path = Vec::new();
228        let mut current = meeting;
229        loop {
230            fwd_path.push(current);
231            let parent = fwd_parent[&current];
232            if parent == current {
233                break;
234            }
235            current = parent;
236        }
237        fwd_path.reverse();
238
239        current = bwd_parent[&meeting];
240        if current != meeting {
241            loop {
242                fwd_path.push(current);
243                let parent = bwd_parent[&current];
244                if parent == current {
245                    break;
246                }
247                current = parent;
248            }
249        }
250
251        fwd_path
252            .into_iter()
253            .map(|id| self.node_name(id).to_string())
254            .collect()
255    }
256
257    /// Materialize a subgraph as edge tuples within max_depth.
258    ///
259    /// `max_visited` caps the number of nodes visited to prevent supernode fan-out
260    /// explosion. Pass [`DEFAULT_MAX_VISITED`] for the standard limit.
261    pub fn subgraph(
262        &self,
263        start_nodes: &[&str],
264        label_filter: Option<&str>,
265        max_depth: usize,
266        max_visited: usize,
267    ) -> Vec<(String, String, String)> {
268        let label_id = label_filter.and_then(|l| self.label_id(l));
269        let mut visited: HashSet<u32> = HashSet::new();
270        let mut queue: VecDeque<(u32, usize)> = VecDeque::new();
271        let mut edges = Vec::new();
272
273        for &node in start_nodes {
274            if let Some(id) = self.node_id(node)
275                && visited.insert(id)
276            {
277                queue.push_back((id, 0));
278            }
279        }
280
281        while let Some((node_id, depth)) = queue.pop_front() {
282            if depth >= max_depth || visited.len() >= max_visited {
283                continue;
284            }
285            for (lid, dst) in self.iter_out_edges(node_id) {
286                if label_id.is_none_or(|f| f == lid) {
287                    edges.push((
288                        self.node_name(node_id).to_string(),
289                        self.label_name(lid).to_string(),
290                        self.node_name(dst).to_string(),
291                    ));
292                    if visited.len() < max_visited && visited.insert(dst) {
293                        queue.push_back((dst, depth + 1));
294                    }
295                }
296            }
297        }
298
299        edges
300    }
301}
302
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture graph: a KNOWS chain `a -> b -> c -> d` plus one WORKS edge `a -> e`.
    fn make_csr() -> CsrIndex {
        let mut index = CsrIndex::new();
        for (from, label, to) in [
            ("a", "KNOWS", "b"),
            ("b", "KNOWS", "c"),
            ("c", "KNOWS", "d"),
            ("a", "WORKS", "e"),
        ] {
            index.add_edge(from, label, to);
        }
        index
    }

    /// Sort traversal output so assertions do not depend on hash-set order.
    fn sorted(mut names: Vec<String>) -> Vec<String> {
        names.sort();
        names
    }

    #[test]
    fn bfs_traversal() {
        let index = make_csr();
        let reached = sorted(index.traverse_bfs(
            &["a"],
            Some("KNOWS"),
            Direction::Out,
            2,
            DEFAULT_MAX_VISITED,
        ));
        assert_eq!(reached, vec!["a", "b", "c"]);
    }

    #[test]
    fn bfs_all_labels() {
        let index = make_csr();
        let reached =
            sorted(index.traverse_bfs(&["a"], None, Direction::Out, 1, DEFAULT_MAX_VISITED));
        assert_eq!(reached, vec!["a", "b", "e"]);
    }

    #[test]
    fn bfs_cycle() {
        // Three-node ring: the traversal must terminate and visit each node once.
        let mut ring = CsrIndex::new();
        for (from, to) in [("a", "b"), ("b", "c"), ("c", "a")] {
            ring.add_edge(from, "L", to);
        }
        let reached =
            sorted(ring.traverse_bfs(&["a"], None, Direction::Out, 10, DEFAULT_MAX_VISITED));
        assert_eq!(reached, vec!["a", "b", "c"]);
    }

    #[test]
    fn bfs_with_depth() {
        let index = make_csr();
        let depth_of: HashMap<String, u8> = index
            .traverse_bfs_with_depth(
                &["a"],
                Some("KNOWS"),
                Direction::Out,
                3,
                DEFAULT_MAX_VISITED,
            )
            .into_iter()
            .collect();
        for (name, expected) in [("a", 0), ("b", 1), ("c", 2), ("d", 3)] {
            assert_eq!(depth_of[name], expected);
        }
    }

    #[test]
    fn shortest_path_direct() {
        let index = make_csr();
        let found = index.shortest_path("a", "c", Some("KNOWS"), 5, DEFAULT_MAX_VISITED);
        assert_eq!(found.unwrap(), vec!["a", "b", "c"]);
    }

    #[test]
    fn shortest_path_same_node() {
        let index = make_csr();
        let found = index.shortest_path("a", "a", None, 5, DEFAULT_MAX_VISITED);
        assert_eq!(found.unwrap(), vec!["a"]);
    }

    #[test]
    fn shortest_path_unreachable() {
        // Edges only point away from "a", so nothing leads from "d" back to it.
        let index = make_csr();
        assert!(
            index
                .shortest_path("d", "a", Some("KNOWS"), 5, DEFAULT_MAX_VISITED)
                .is_none()
        );
    }

    #[test]
    fn shortest_path_depth_limit() {
        // "d" is three hops from "a"; a single round is not enough to connect them.
        let index = make_csr();
        assert!(
            index
                .shortest_path("a", "d", Some("KNOWS"), 1, DEFAULT_MAX_VISITED)
                .is_none()
        );
    }

    #[test]
    fn subgraph_materialization() {
        let index = make_csr();
        let edges = index.subgraph(&["a"], None, 2, DEFAULT_MAX_VISITED);
        let triple = |from: &str, label: &str, to: &str| {
            (from.to_string(), label.to_string(), to.to_string())
        };

        assert_eq!(edges.len(), 3);
        assert!(edges.contains(&triple("a", "KNOWS", "b")));
        assert!(edges.contains(&triple("a", "WORKS", "e")));
        assert!(edges.contains(&triple("b", "KNOWS", "c")));
    }

    #[test]
    fn large_graph_bfs() {
        // Chain of 1000 nodes: n0 -> n1 -> ... -> n999.
        let mut chain = CsrIndex::new();
        for i in 0..999 {
            let from = format!("n{i}");
            let to = format!("n{}", i + 1);
            chain.add_edge(&from, "NEXT", &to);
        }
        chain.compact();

        // 100 hops from n0 reach n0..=n100, i.e. 101 nodes.
        let reached = chain.traverse_bfs(
            &["n0"],
            Some("NEXT"),
            Direction::Out,
            100,
            DEFAULT_MAX_VISITED,
        );
        assert_eq!(reached.len(), 101);

        // The path n0, n1, ..., n50 has 51 nodes.
        let path = chain
            .shortest_path("n0", "n50", Some("NEXT"), 100, DEFAULT_MAX_VISITED)
            .unwrap();
        assert_eq!(path.len(), 51);
    }
}