1use std::collections::VecDeque;
4
5use ahash::{HashMap, HashSet};
6use mentedb_core::edge::{EdgeType, MemoryEdge};
7use mentedb_core::types::MemoryId;
8
9use crate::csr::CsrGraph;
10
11pub fn bfs(graph: &CsrGraph, start: MemoryId, max_depth: usize) -> Vec<(MemoryId, usize)> {
13 let Some(_) = graph.get_idx(start) else {
14 return Vec::new();
15 };
16
17 let mut visited = HashSet::default();
18 let mut queue = VecDeque::new();
19 let mut result = Vec::new();
20
21 visited.insert(start);
22 queue.push_back((start, 0usize));
23
24 while let Some((node, depth)) = queue.pop_front() {
25 result.push((node, depth));
26 if depth >= max_depth {
27 continue;
28 }
29 for (neighbor, _edge) in graph.outgoing(node) {
30 if visited.insert(neighbor) {
31 queue.push_back((neighbor, depth + 1));
32 }
33 }
34 }
35
36 result
37}
38
39pub fn dfs(graph: &CsrGraph, start: MemoryId, max_depth: usize) -> Vec<(MemoryId, usize)> {
41 let Some(_) = graph.get_idx(start) else {
42 return Vec::new();
43 };
44
45 let mut visited = HashSet::default();
46 let mut stack = vec![(start, 0usize)];
47 let mut result = Vec::new();
48
49 while let Some((node, depth)) = stack.pop() {
50 if !visited.insert(node) {
51 continue;
52 }
53 result.push((node, depth));
54 if depth >= max_depth {
55 continue;
56 }
57 for (neighbor, _edge) in graph.outgoing(node) {
58 if !visited.contains(&neighbor) {
59 stack.push((neighbor, depth + 1));
60 }
61 }
62 }
63
64 result
65}
66
67pub fn bfs_filtered(
69 graph: &CsrGraph,
70 start: MemoryId,
71 max_depth: usize,
72 edge_filter: &[EdgeType],
73) -> Vec<(MemoryId, usize)> {
74 let Some(_) = graph.get_idx(start) else {
75 return Vec::new();
76 };
77
78 let filter_set: HashSet<EdgeType> = edge_filter.iter().copied().collect();
79 let mut visited = HashSet::default();
80 let mut queue = VecDeque::new();
81 let mut result = Vec::new();
82
83 visited.insert(start);
84 queue.push_back((start, 0usize));
85
86 while let Some((node, depth)) = queue.pop_front() {
87 result.push((node, depth));
88 if depth >= max_depth {
89 continue;
90 }
91 for (neighbor, edge) in graph.outgoing(node) {
92 if filter_set.contains(&edge.edge_type) && visited.insert(neighbor) {
93 queue.push_back((neighbor, depth + 1));
94 }
95 }
96 }
97
98 result
99}
100
101pub fn extract_subgraph(
103 graph: &CsrGraph,
104 center: MemoryId,
105 radius: usize,
106) -> (Vec<MemoryId>, Vec<MemoryEdge>) {
107 let nodes_with_depth = bfs(graph, center, radius);
108 let node_set: HashSet<MemoryId> = nodes_with_depth.iter().map(|&(id, _)| id).collect();
109
110 let nodes: Vec<MemoryId> = nodes_with_depth.into_iter().map(|(id, _)| id).collect();
111 let mut edges = Vec::new();
112
113 for &node in &nodes {
114 for (neighbor, stored) in graph.outgoing(node) {
115 if node_set.contains(&neighbor) {
116 edges.push(MemoryEdge {
117 source: node,
118 target: neighbor,
119 edge_type: stored.edge_type,
120 weight: stored.weight,
121 created_at: stored.created_at,
122 valid_from: stored.valid_from,
123 valid_until: stored.valid_until,
124 label: stored.label.clone(),
125 });
126 }
127 }
128 }
129
130 (nodes, edges)
131}
132
133pub fn shortest_path(graph: &CsrGraph, from: MemoryId, to: MemoryId) -> Option<Vec<MemoryId>> {
135 if from == to {
136 return Some(vec![from]);
137 }
138
139 let _ = graph.get_idx(from)?;
140 let _ = graph.get_idx(to)?;
141
142 let mut visited = HashSet::default();
143 let mut parent: HashMap<MemoryId, MemoryId> = HashMap::default();
144 let mut queue = VecDeque::new();
145
146 visited.insert(from);
147 queue.push_back(from);
148
149 while let Some(node) = queue.pop_front() {
150 for (neighbor, _) in graph.outgoing(node) {
151 if visited.insert(neighbor) {
152 parent.insert(neighbor, node);
153 if neighbor == to {
154 let mut path = vec![to];
156 let mut cur = to;
157 while let Some(&prev) = parent.get(&cur) {
158 path.push(prev);
159 cur = prev;
160 if cur == from {
161 break;
162 }
163 }
164 path.reverse();
165 return Some(path);
166 }
167 queue.push_back(neighbor);
168 }
169 }
170 }
171
172 None
173}
174
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an edge from `src` to `tgt` with fixed weight and timestamp, and no validity
    /// window or label. Only the endpoints and the type vary between tests.
    fn make_edge(src: MemoryId, tgt: MemoryId, etype: EdgeType) -> MemoryEdge {
        MemoryEdge {
            source: src,
            target: tgt,
            edge_type: etype,
            weight: 1.0,
            created_at: 1000,
            valid_from: None,
            valid_until: None,
            label: None,
        }
    }

    /// Builds the chain `ids[0] -Caused-> ids[1] -Caused-> ids[2] -Related-> ids[3]`.
    fn build_chain() -> (CsrGraph, Vec<MemoryId>) {
        let mut g = CsrGraph::new();
        let ids: Vec<MemoryId> = (0..4).map(|_| MemoryId::new()).collect();
        g.add_edge(&make_edge(ids[0], ids[1], EdgeType::Caused));
        g.add_edge(&make_edge(ids[1], ids[2], EdgeType::Caused));
        g.add_edge(&make_edge(ids[2], ids[3], EdgeType::Related));
        (g, ids)
    }

    #[test]
    fn test_bfs_chain() {
        let (g, ids) = build_chain();
        let result = bfs(&g, ids[0], 10);
        assert_eq!(
            result,
            vec![(ids[0], 0), (ids[1], 1), (ids[2], 2), (ids[3], 3)]
        );
    }

    #[test]
    fn test_bfs_max_depth() {
        let (g, ids) = build_chain();
        let result = bfs(&g, ids[0], 1);
        assert_eq!(result, vec![(ids[0], 0), (ids[1], 1)]);
    }

    #[test]
    fn test_bfs_zero_depth_returns_only_start() {
        let (g, ids) = build_chain();
        assert_eq!(bfs(&g, ids[0], 0), vec![(ids[0], 0)]);
    }

    #[test]
    fn test_dfs_chain() {
        let (g, ids) = build_chain();
        let result = dfs(&g, ids[0], 10);
        // A chain has exactly one path, so each node's depth equals its position.
        assert_eq!(
            result,
            vec![(ids[0], 0), (ids[1], 1), (ids[2], 2), (ids[3], 3)]
        );
    }

    #[test]
    fn test_bfs_filtered() {
        let (g, ids) = build_chain();
        let result = bfs_filtered(&g, ids[0], 10, &[EdgeType::Caused]);
        // The last hop is `Related`, so `ids[3]` must not be reached.
        assert_eq!(result, vec![(ids[0], 0), (ids[1], 1), (ids[2], 2)]);
    }

    #[test]
    fn test_bfs_filtered_empty_filter_returns_only_start() {
        let (g, ids) = build_chain();
        assert_eq!(bfs_filtered(&g, ids[0], 10, &[]), vec![(ids[0], 0)]);
    }

    #[test]
    fn test_unknown_start_yields_nothing() {
        let (g, ids) = build_chain();
        // NOTE(review): assumes `MemoryId::new()` yields an id not already in the graph,
        // which `build_chain` also relies on to get four distinct ids.
        let stranger = MemoryId::new();
        assert!(bfs(&g, stranger, 10).is_empty());
        assert!(dfs(&g, stranger, 10).is_empty());
        assert!(bfs_filtered(&g, stranger, 10, &[EdgeType::Caused]).is_empty());
        let (nodes, edges) = extract_subgraph(&g, stranger, 2);
        assert!(nodes.is_empty());
        assert!(edges.is_empty());
        assert!(shortest_path(&g, stranger, ids[0]).is_none());
    }

    #[test]
    fn test_shortest_path() {
        let (g, ids) = build_chain();
        let path = shortest_path(&g, ids[0], ids[3]);
        assert!(path.is_some());
        let path = path.unwrap();
        assert_eq!(path.len(), 4);
        assert_eq!(path, ids);
    }

    #[test]
    fn test_shortest_path_same_node() {
        let (g, ids) = build_chain();
        assert_eq!(shortest_path(&g, ids[2], ids[2]), Some(vec![ids[2]]));
    }

    #[test]
    fn test_shortest_path_no_path() {
        let (g, ids) = build_chain();
        // Edges are directed, so the chain cannot be walked backwards.
        let path = shortest_path(&g, ids[3], ids[0]);
        assert!(path.is_none());
    }

    #[test]
    fn test_extract_subgraph() {
        let (g, ids) = build_chain();
        let (nodes, edges) = extract_subgraph(&g, ids[0], 2);
        assert_eq!(nodes, vec![ids[0], ids[1], ids[2]]);
        assert_eq!(edges.len(), 2);
        // Every extracted edge must stay inside the extracted node set.
        assert!(edges
            .iter()
            .all(|e| nodes.contains(&e.source) && nodes.contains(&e.target)));
    }
}