1use std::collections::{HashMap, HashSet, VecDeque, hash_map::Entry};
7
8pub use nodedb_types::config::tuning::DEFAULT_MAX_VISITED;
9
10use crate::csr::{CsrIndex, Direction};
11
12impl CsrIndex {
13 pub fn traverse_bfs(
18 &self,
19 start_nodes: &[&str],
20 label_filter: Option<&str>,
21 direction: Direction,
22 max_depth: usize,
23 max_visited: usize,
24 ) -> Vec<String> {
25 let label_id = label_filter.and_then(|l| self.label_id(l));
26 let mut visited: HashSet<u32> = HashSet::new();
27 let mut queue: VecDeque<(u32, usize)> = VecDeque::new();
28
29 for &node in start_nodes {
30 if let Some(id) = self.node_id(node)
31 && visited.insert(id)
32 {
33 queue.push_back((id, 0));
34 }
35 }
36
37 while let Some((node_id, depth)) = queue.pop_front() {
38 if depth >= max_depth || visited.len() >= max_visited {
39 continue;
40 }
41
42 if matches!(direction, Direction::Out | Direction::Both) {
43 for (lid, dst) in self.iter_out_edges(node_id) {
44 if label_id.is_none_or(|f| f == lid)
45 && visited.len() < max_visited
46 && visited.insert(dst)
47 {
48 queue.push_back((dst, depth + 1));
49 }
50 }
51 }
52 if matches!(direction, Direction::In | Direction::Both) {
53 for (lid, src) in self.iter_in_edges(node_id) {
54 if label_id.is_none_or(|f| f == lid)
55 && visited.len() < max_visited
56 && visited.insert(src)
57 {
58 queue.push_back((src, depth + 1));
59 }
60 }
61 }
62 }
63
64 visited
65 .into_iter()
66 .map(|id| self.node_name(id).to_string())
67 .collect()
68 }
69
70 pub fn traverse_bfs_with_depth(
75 &self,
76 start_nodes: &[&str],
77 label_filter: Option<&str>,
78 direction: Direction,
79 max_depth: usize,
80 max_visited: usize,
81 ) -> Vec<(String, u8)> {
82 let filters: Vec<&str> = label_filter.into_iter().collect();
83 self.traverse_bfs_with_depth_multi(start_nodes, &filters, direction, max_depth, max_visited)
84 }
85
86 pub fn traverse_bfs_with_depth_multi(
91 &self,
92 start_nodes: &[&str],
93 label_filters: &[&str],
94 direction: Direction,
95 max_depth: usize,
96 max_visited: usize,
97 ) -> Vec<(String, u8)> {
98 let label_ids: Vec<u16> = label_filters
99 .iter()
100 .filter_map(|l| self.label_id(l))
101 .collect();
102 let match_label = |lid: u16| label_ids.is_empty() || label_ids.contains(&lid);
103 let mut visited: HashMap<u32, u8> = HashMap::new();
104 let mut queue: VecDeque<(u32, u8)> = VecDeque::new();
105
106 for &node in start_nodes {
107 if let Some(id) = self.node_id(node) {
108 visited.insert(id, 0);
109 queue.push_back((id, 0));
110 }
111 }
112
113 while let Some((node_id, depth)) = queue.pop_front() {
114 if depth as usize >= max_depth || visited.len() >= max_visited {
115 continue;
116 }
117
118 let next_depth = depth + 1;
119
120 if matches!(direction, Direction::Out | Direction::Both) {
121 for (lid, dst) in self.iter_out_edges(node_id) {
122 if match_label(lid)
123 && visited.len() < max_visited
124 && !visited.contains_key(&dst)
125 {
126 visited.insert(dst, next_depth);
127 queue.push_back((dst, next_depth));
128 }
129 }
130 }
131 if matches!(direction, Direction::In | Direction::Both) {
132 for (lid, src) in self.iter_in_edges(node_id) {
133 if match_label(lid)
134 && visited.len() < max_visited
135 && !visited.contains_key(&src)
136 {
137 visited.insert(src, next_depth);
138 queue.push_back((src, next_depth));
139 }
140 }
141 }
142 }
143
144 visited
145 .into_iter()
146 .map(|(id, depth)| (self.node_name(id).to_string(), depth))
147 .collect()
148 }
149
150 pub fn shortest_path(
155 &self,
156 src: &str,
157 dst: &str,
158 label_filter: Option<&str>,
159 max_depth: usize,
160 max_visited: usize,
161 ) -> Option<Vec<String>> {
162 let src_id = self.node_id(src)?;
163 let dst_id = self.node_id(dst)?;
164 if src_id == dst_id {
165 return Some(vec![src.to_string()]);
166 }
167
168 let label_id = label_filter.and_then(|l| self.label_id(l));
169 let mut fwd_parent: HashMap<u32, u32> = HashMap::new();
170 let mut bwd_parent: HashMap<u32, u32> = HashMap::new();
171 fwd_parent.insert(src_id, src_id);
172 bwd_parent.insert(dst_id, dst_id);
173
174 let mut fwd_frontier: Vec<u32> = vec![src_id];
175 let mut bwd_frontier: Vec<u32> = vec![dst_id];
176
177 for _depth in 0..max_depth {
178 if fwd_parent.len() + bwd_parent.len() >= max_visited {
179 break;
180 }
181
182 let mut next_fwd = Vec::new();
183 for &node in &fwd_frontier {
184 for (lid, neighbor) in self.iter_out_edges(node) {
185 if label_id.is_none_or(|f| f == lid) {
186 if let Entry::Vacant(e) = fwd_parent.entry(neighbor) {
187 e.insert(node);
188 next_fwd.push(neighbor);
189 }
190 if bwd_parent.contains_key(&neighbor) {
191 return Some(self.reconstruct_path(neighbor, &fwd_parent, &bwd_parent));
192 }
193 }
194 }
195 }
196 fwd_frontier = next_fwd;
197
198 let mut next_bwd = Vec::new();
199 for &node in &bwd_frontier {
200 for (lid, neighbor) in self.iter_in_edges(node) {
201 if label_id.is_none_or(|f| f == lid) {
202 if let Entry::Vacant(e) = bwd_parent.entry(neighbor) {
203 e.insert(node);
204 next_bwd.push(neighbor);
205 }
206 if fwd_parent.contains_key(&neighbor) {
207 return Some(self.reconstruct_path(neighbor, &fwd_parent, &bwd_parent));
208 }
209 }
210 }
211 }
212 bwd_frontier = next_bwd;
213
214 if fwd_frontier.is_empty() && bwd_frontier.is_empty() {
215 break;
216 }
217 }
218 None
219 }
220
221 fn reconstruct_path(
222 &self,
223 meeting: u32,
224 fwd_parent: &HashMap<u32, u32>,
225 bwd_parent: &HashMap<u32, u32>,
226 ) -> Vec<String> {
227 let mut fwd_path = Vec::new();
228 let mut current = meeting;
229 loop {
230 fwd_path.push(current);
231 let parent = fwd_parent[¤t];
232 if parent == current {
233 break;
234 }
235 current = parent;
236 }
237 fwd_path.reverse();
238
239 current = bwd_parent[&meeting];
240 if current != meeting {
241 loop {
242 fwd_path.push(current);
243 let parent = bwd_parent[¤t];
244 if parent == current {
245 break;
246 }
247 current = parent;
248 }
249 }
250
251 fwd_path
252 .into_iter()
253 .map(|id| self.node_name(id).to_string())
254 .collect()
255 }
256
257 pub fn subgraph(
262 &self,
263 start_nodes: &[&str],
264 label_filter: Option<&str>,
265 max_depth: usize,
266 max_visited: usize,
267 ) -> Vec<(String, String, String)> {
268 let label_id = label_filter.and_then(|l| self.label_id(l));
269 let mut visited: HashSet<u32> = HashSet::new();
270 let mut queue: VecDeque<(u32, usize)> = VecDeque::new();
271 let mut edges = Vec::new();
272
273 for &node in start_nodes {
274 if let Some(id) = self.node_id(node)
275 && visited.insert(id)
276 {
277 queue.push_back((id, 0));
278 }
279 }
280
281 while let Some((node_id, depth)) = queue.pop_front() {
282 if depth >= max_depth || visited.len() >= max_visited {
283 continue;
284 }
285 for (lid, dst) in self.iter_out_edges(node_id) {
286 if label_id.is_none_or(|f| f == lid) {
287 edges.push((
288 self.node_name(node_id).to_string(),
289 self.label_name(lid).to_string(),
290 self.node_name(dst).to_string(),
291 ));
292 if visited.len() < max_visited && visited.insert(dst) {
293 queue.push_back((dst, depth + 1));
294 }
295 }
296 }
297 }
298
299 edges
300 }
301}
302
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture: a `KNOWS` chain a -> b -> c -> d plus a single `WORKS` edge a -> e.
    fn make_csr() -> CsrIndex {
        let mut graph = CsrIndex::new();
        graph.add_edge("a", "KNOWS", "b");
        graph.add_edge("b", "KNOWS", "c");
        graph.add_edge("c", "KNOWS", "d");
        graph.add_edge("a", "WORKS", "e");
        graph
    }

    #[test]
    fn bfs_traversal() {
        let graph = make_csr();
        let mut names = graph.traverse_bfs(
            &["a"],
            Some("KNOWS"),
            Direction::Out,
            2,
            DEFAULT_MAX_VISITED,
        );
        // Traversal order is unspecified, so compare in sorted order.
        names.sort();
        assert_eq!(names, vec!["a", "b", "c"]);
    }

    #[test]
    fn bfs_all_labels() {
        let graph = make_csr();
        let mut names = graph.traverse_bfs(&["a"], None, Direction::Out, 1, DEFAULT_MAX_VISITED);
        names.sort();
        assert_eq!(names, vec!["a", "b", "e"]);
    }

    #[test]
    fn bfs_cycle() {
        // a -> b -> c -> a: the traversal must terminate and list each node once.
        let mut graph = CsrIndex::new();
        graph.add_edge("a", "L", "b");
        graph.add_edge("b", "L", "c");
        graph.add_edge("c", "L", "a");
        let mut names = graph.traverse_bfs(&["a"], None, Direction::Out, 10, DEFAULT_MAX_VISITED);
        names.sort();
        assert_eq!(names, vec!["a", "b", "c"]);
    }

    #[test]
    fn bfs_with_depth() {
        let graph = make_csr();
        let depth_of: HashMap<String, u8> = graph
            .traverse_bfs_with_depth(
                &["a"],
                Some("KNOWS"),
                Direction::Out,
                3,
                DEFAULT_MAX_VISITED,
            )
            .into_iter()
            .collect();
        assert_eq!(depth_of["a"], 0);
        assert_eq!(depth_of["b"], 1);
        assert_eq!(depth_of["c"], 2);
        assert_eq!(depth_of["d"], 3);
    }

    #[test]
    fn shortest_path_direct() {
        let graph = make_csr();
        let found = graph.shortest_path("a", "c", Some("KNOWS"), 5, DEFAULT_MAX_VISITED);
        let path = found.unwrap();
        assert_eq!(path, vec!["a", "b", "c"]);
    }

    #[test]
    fn shortest_path_same_node() {
        let graph = make_csr();
        let found = graph.shortest_path("a", "a", None, 5, DEFAULT_MAX_VISITED);
        let path = found.unwrap();
        assert_eq!(path, vec!["a"]);
    }

    #[test]
    fn shortest_path_unreachable() {
        // Edges are directed a -> ... -> d, so nothing leads from d back to a.
        let graph = make_csr();
        let found = graph.shortest_path("d", "a", Some("KNOWS"), 5, DEFAULT_MAX_VISITED);
        assert!(found.is_none());
    }

    #[test]
    fn shortest_path_depth_limit() {
        // a -> d needs three hops; a depth limit of one must not find it.
        let graph = make_csr();
        let found = graph.shortest_path("a", "d", Some("KNOWS"), 1, DEFAULT_MAX_VISITED);
        assert!(found.is_none());
    }

    #[test]
    fn subgraph_materialization() {
        let graph = make_csr();
        let edges = graph.subgraph(&["a"], None, 2, DEFAULT_MAX_VISITED);
        assert_eq!(edges.len(), 3);
        assert!(edges.contains(&("a".into(), "KNOWS".into(), "b".into())));
        assert!(edges.contains(&("a".into(), "WORKS".into(), "e".into())));
        assert!(edges.contains(&("b".into(), "KNOWS".into(), "c".into())));
    }

    #[test]
    fn large_graph_bfs() {
        // Linear chain n0 -> n1 -> ... -> n999.
        let mut graph = CsrIndex::new();
        for i in 0..999 {
            let from = format!("n{i}");
            let to = format!("n{}", i + 1);
            graph.add_edge(&from, "NEXT", &to);
        }
        graph.compact();

        let reached = graph.traverse_bfs(
            &["n0"],
            Some("NEXT"),
            Direction::Out,
            100,
            DEFAULT_MAX_VISITED,
        );
        // The start node plus one node per hop up to depth 100.
        assert_eq!(reached.len(), 101);

        let found = graph.shortest_path("n0", "n50", Some("NEXT"), 100, DEFAULT_MAX_VISITED);
        let path = found.unwrap();
        // n0 through n50 inclusive.
        assert_eq!(path.len(), 51);
    }
}