1use std::collections::{HashMap, HashSet, VecDeque, hash_map::Entry};
12
13pub use nodedb_types::config::tuning::DEFAULT_MAX_VISITED;
14
15use crate::csr::{CsrIndex, Direction};
16
17impl CsrIndex {
18 pub fn traverse_bfs(
23 &self,
24 start_nodes: &[&str],
25 label_filter: Option<&str>,
26 direction: Direction,
27 max_depth: usize,
28 max_visited: usize,
29 ) -> Vec<String> {
30 let label_id = label_filter.and_then(|l| self.label_id(l));
31 let mut visited: HashSet<u32> = HashSet::new();
32 let mut queue: VecDeque<(u32, usize)> = VecDeque::new();
33
34 for &node in start_nodes {
35 if let Some(id) = self.node_id(node)
36 && visited.insert(id)
37 {
38 queue.push_back((id, 0));
39 }
40 }
41
42 while let Some((node_id, depth)) = queue.pop_front() {
43 if depth >= max_depth || visited.len() >= max_visited {
44 continue;
45 }
46
47 self.record_access(node_id);
49
50 if matches!(direction, Direction::Out | Direction::Both) {
51 for (lid, dst) in self.iter_out_edges(node_id) {
52 if label_id.is_none_or(|f| f == lid)
53 && visited.len() < max_visited
54 && visited.insert(dst)
55 {
56 self.prefetch_node(dst);
57 queue.push_back((dst, depth + 1));
58 }
59 }
60 }
61 if matches!(direction, Direction::In | Direction::Both) {
62 for (lid, src) in self.iter_in_edges(node_id) {
63 if label_id.is_none_or(|f| f == lid)
64 && visited.len() < max_visited
65 && visited.insert(src)
66 {
67 self.prefetch_node(src);
68 queue.push_back((src, depth + 1));
69 }
70 }
71 }
72 }
73
74 visited
75 .into_iter()
76 .map(|id| self.node_name(id).to_string())
77 .collect()
78 }
79
80 pub fn traverse_bfs_with_depth(
85 &self,
86 start_nodes: &[&str],
87 label_filter: Option<&str>,
88 direction: Direction,
89 max_depth: usize,
90 max_visited: usize,
91 ) -> Vec<(String, u8)> {
92 let filters: Vec<&str> = label_filter.into_iter().collect();
93 self.traverse_bfs_with_depth_multi(start_nodes, &filters, direction, max_depth, max_visited)
94 }
95
96 pub fn traverse_bfs_with_depth_multi(
101 &self,
102 start_nodes: &[&str],
103 label_filters: &[&str],
104 direction: Direction,
105 max_depth: usize,
106 max_visited: usize,
107 ) -> Vec<(String, u8)> {
108 let label_ids: Vec<u32> = label_filters
109 .iter()
110 .filter_map(|l| self.label_id(l))
111 .collect();
112 let match_label = |lid: u32| label_ids.is_empty() || label_ids.contains(&lid);
113 let mut visited: HashMap<u32, u8> = HashMap::new();
114 let mut queue: VecDeque<(u32, u8)> = VecDeque::new();
115
116 for &node in start_nodes {
117 if let Some(id) = self.node_id(node) {
118 visited.insert(id, 0);
119 queue.push_back((id, 0));
120 }
121 }
122
123 while let Some((node_id, depth)) = queue.pop_front() {
124 if depth as usize >= max_depth || visited.len() >= max_visited {
125 continue;
126 }
127
128 let next_depth = depth + 1;
129
130 if matches!(direction, Direction::Out | Direction::Both) {
131 for (lid, dst) in self.iter_out_edges(node_id) {
132 if match_label(lid)
133 && visited.len() < max_visited
134 && !visited.contains_key(&dst)
135 {
136 visited.insert(dst, next_depth);
137 queue.push_back((dst, next_depth));
138 }
139 }
140 }
141 if matches!(direction, Direction::In | Direction::Both) {
142 for (lid, src) in self.iter_in_edges(node_id) {
143 if match_label(lid)
144 && visited.len() < max_visited
145 && !visited.contains_key(&src)
146 {
147 visited.insert(src, next_depth);
148 queue.push_back((src, next_depth));
149 }
150 }
151 }
152 }
153
154 visited
155 .into_iter()
156 .map(|(id, depth)| (self.node_name(id).to_string(), depth))
157 .collect()
158 }
159
160 pub fn shortest_path(
165 &self,
166 src: &str,
167 dst: &str,
168 label_filter: Option<&str>,
169 max_depth: usize,
170 max_visited: usize,
171 ) -> Option<Vec<String>> {
172 let src_id = self.node_id(src)?;
173 let dst_id = self.node_id(dst)?;
174 if src_id == dst_id {
175 return Some(vec![src.to_string()]);
176 }
177
178 let label_id = label_filter.and_then(|l| self.label_id(l));
179 let mut fwd_parent: HashMap<u32, u32> = HashMap::new();
180 let mut bwd_parent: HashMap<u32, u32> = HashMap::new();
181 fwd_parent.insert(src_id, src_id);
182 bwd_parent.insert(dst_id, dst_id);
183
184 let mut fwd_frontier: Vec<u32> = vec![src_id];
185 let mut bwd_frontier: Vec<u32> = vec![dst_id];
186
187 for _depth in 0..max_depth {
188 if fwd_parent.len() + bwd_parent.len() >= max_visited {
189 break;
190 }
191
192 let mut next_fwd = Vec::new();
193 for &node in &fwd_frontier {
194 self.record_access(node);
195 for (lid, neighbor) in self.iter_out_edges(node) {
196 if label_id.is_none_or(|f| f == lid) {
197 if let Entry::Vacant(e) = fwd_parent.entry(neighbor) {
198 e.insert(node);
199 next_fwd.push(neighbor);
200 }
201 if bwd_parent.contains_key(&neighbor) {
202 return Some(self.reconstruct_path(neighbor, &fwd_parent, &bwd_parent));
203 }
204 }
205 }
206 }
207 fwd_frontier = next_fwd;
208
209 let mut next_bwd = Vec::new();
210 for &node in &bwd_frontier {
211 self.record_access(node);
212 for (lid, neighbor) in self.iter_in_edges(node) {
213 if label_id.is_none_or(|f| f == lid) {
214 if let Entry::Vacant(e) = bwd_parent.entry(neighbor) {
215 e.insert(node);
216 next_bwd.push(neighbor);
217 }
218 if fwd_parent.contains_key(&neighbor) {
219 return Some(self.reconstruct_path(neighbor, &fwd_parent, &bwd_parent));
220 }
221 }
222 }
223 }
224 bwd_frontier = next_bwd;
225
226 if fwd_frontier.is_empty() && bwd_frontier.is_empty() {
227 break;
228 }
229 }
230 None
231 }
232
233 fn reconstruct_path(
234 &self,
235 meeting: u32,
236 fwd_parent: &HashMap<u32, u32>,
237 bwd_parent: &HashMap<u32, u32>,
238 ) -> Vec<String> {
239 let mut fwd_path = Vec::new();
240 let mut current = meeting;
241 loop {
242 fwd_path.push(current);
243 let parent = fwd_parent[¤t];
244 if parent == current {
245 break;
246 }
247 current = parent;
248 }
249 fwd_path.reverse();
250
251 current = bwd_parent[&meeting];
252 if current != meeting {
253 loop {
254 fwd_path.push(current);
255 let parent = bwd_parent[¤t];
256 if parent == current {
257 break;
258 }
259 current = parent;
260 }
261 }
262
263 fwd_path
264 .into_iter()
265 .map(|id| self.node_name(id).to_string())
266 .collect()
267 }
268
269 pub fn subgraph(
274 &self,
275 start_nodes: &[&str],
276 label_filter: Option<&str>,
277 max_depth: usize,
278 max_visited: usize,
279 ) -> Vec<(String, String, String)> {
280 let label_id = label_filter.and_then(|l| self.label_id(l));
281 let mut visited: HashSet<u32> = HashSet::new();
282 let mut queue: VecDeque<(u32, usize)> = VecDeque::new();
283 let mut edges = Vec::new();
284
285 for &node in start_nodes {
286 if let Some(id) = self.node_id(node)
287 && visited.insert(id)
288 {
289 queue.push_back((id, 0));
290 }
291 }
292
293 while let Some((node_id, depth)) = queue.pop_front() {
294 if depth >= max_depth || visited.len() >= max_visited {
295 continue;
296 }
297 self.record_access(node_id);
298 for (lid, dst) in self.iter_out_edges(node_id) {
299 if label_id.is_none_or(|f| f == lid) {
300 edges.push((
301 self.node_name(node_id).to_string(),
302 self.label_name(lid).to_string(),
303 self.node_name(dst).to_string(),
304 ));
305 if visited.len() < max_visited && visited.insert(dst) {
306 queue.push_back((dst, depth + 1));
307 }
308 }
309 }
310 }
311
312 edges
313 }
314}
315
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture graph: `a -KNOWS-> b -KNOWS-> c -KNOWS-> d` plus `a -WORKS-> e`.
    fn make_csr() -> CsrIndex {
        let mut graph = CsrIndex::new();
        for (src, label, dst) in [
            ("a", "KNOWS", "b"),
            ("b", "KNOWS", "c"),
            ("c", "KNOWS", "d"),
            ("a", "WORKS", "e"),
        ] {
            graph.add_edge(src, label, dst).unwrap();
        }
        graph
    }

    // Two KNOWS hops from `a` reach `b` and `c`, but not `d`.
    #[test]
    fn bfs_traversal() {
        let graph = make_csr();
        let mut reached = graph.traverse_bfs(
            &["a"],
            Some("KNOWS"),
            Direction::Out,
            2,
            DEFAULT_MAX_VISITED,
        );
        reached.sort();
        assert_eq!(reached, ["a", "b", "c"]);
    }

    // Without a label filter one hop from `a` follows both KNOWS and WORKS.
    #[test]
    fn bfs_all_labels() {
        let graph = make_csr();
        let mut reached = graph.traverse_bfs(&["a"], None, Direction::Out, 1, DEFAULT_MAX_VISITED);
        reached.sort();
        assert_eq!(reached, ["a", "b", "e"]);
    }

    // A three-node cycle terminates and reports each node once.
    #[test]
    fn bfs_cycle() {
        let mut graph = CsrIndex::new();
        for (src, dst) in [("a", "b"), ("b", "c"), ("c", "a")] {
            graph.add_edge(src, "L", dst).unwrap();
        }
        let mut reached =
            graph.traverse_bfs(&["a"], None, Direction::Out, 10, DEFAULT_MAX_VISITED);
        reached.sort();
        assert_eq!(reached, ["a", "b", "c"]);
    }

    // Depths along the KNOWS chain count up from the start node.
    #[test]
    fn bfs_with_depth() {
        let graph = make_csr();
        let depths: HashMap<String, u8> = graph
            .traverse_bfs_with_depth(
                &["a"],
                Some("KNOWS"),
                Direction::Out,
                3,
                DEFAULT_MAX_VISITED,
            )
            .into_iter()
            .collect();
        for (name, expected) in [("a", 0), ("b", 1), ("c", 2), ("d", 3)] {
            assert_eq!(depths[name], expected);
        }
    }

    #[test]
    fn shortest_path_direct() {
        let graph = make_csr();
        let found = graph.shortest_path("a", "c", Some("KNOWS"), 5, DEFAULT_MAX_VISITED);
        assert_eq!(found.unwrap(), ["a", "b", "c"]);
    }

    #[test]
    fn shortest_path_same_node() {
        let graph = make_csr();
        let found = graph.shortest_path("a", "a", None, 5, DEFAULT_MAX_VISITED);
        assert_eq!(found.unwrap(), ["a"]);
    }

    // Edges are directed, so `d` cannot reach `a`.
    #[test]
    fn shortest_path_unreachable() {
        let graph = make_csr();
        let found = graph.shortest_path("d", "a", Some("KNOWS"), 5, DEFAULT_MAX_VISITED);
        assert!(found.is_none());
    }

    // One round is not enough to connect `a` and `d`.
    #[test]
    fn shortest_path_depth_limit() {
        let graph = make_csr();
        let found = graph.shortest_path("a", "d", Some("KNOWS"), 1, DEFAULT_MAX_VISITED);
        assert!(found.is_none());
    }

    #[test]
    fn subgraph_materialization() {
        let graph = make_csr();
        let edges = graph.subgraph(&["a"], None, 2, DEFAULT_MAX_VISITED);
        assert_eq!(edges.len(), 3);
        for (src, label, dst) in [
            ("a", "KNOWS", "b"),
            ("a", "WORKS", "e"),
            ("b", "KNOWS", "c"),
        ] {
            assert!(edges.contains(&(src.to_string(), label.to_string(), dst.to_string())));
        }
    }

    // 1000-node chain, compacted before querying.
    #[test]
    fn large_graph_bfs() {
        let mut graph = CsrIndex::new();
        for i in 0..999 {
            let from = format!("n{i}");
            let to = format!("n{}", i + 1);
            graph.add_edge(&from, "NEXT", &to).unwrap();
        }
        graph.compact();

        let reached = graph.traverse_bfs(
            &["n0"],
            Some("NEXT"),
            Direction::Out,
            100,
            DEFAULT_MAX_VISITED,
        );
        assert_eq!(reached.len(), 101);

        let path = graph
            .shortest_path("n0", "n50", Some("NEXT"), 100, DEFAULT_MAX_VISITED)
            .unwrap();
        assert_eq!(path.len(), 51);
    }
}