1use std::collections::VecDeque;
4
5use ahash::{HashMap, HashSet};
6use mentedb_core::edge::{EdgeType, MemoryEdge};
7use mentedb_core::types::MemoryId;
8
9use crate::csr::CsrGraph;
10
11pub fn bfs(graph: &CsrGraph, start: MemoryId, max_depth: usize) -> Vec<(MemoryId, usize)> {
13 let Some(_) = graph.get_idx(start) else {
14 return Vec::new();
15 };
16
17 let mut visited = HashSet::default();
18 let mut queue = VecDeque::new();
19 let mut result = Vec::new();
20
21 visited.insert(start);
22 queue.push_back((start, 0usize));
23
24 while let Some((node, depth)) = queue.pop_front() {
25 result.push((node, depth));
26 if depth >= max_depth {
27 continue;
28 }
29 for (neighbor, _edge) in graph.outgoing(node) {
30 if visited.insert(neighbor) {
31 queue.push_back((neighbor, depth + 1));
32 }
33 }
34 }
35
36 result
37}
38
39pub fn dfs(graph: &CsrGraph, start: MemoryId, max_depth: usize) -> Vec<(MemoryId, usize)> {
41 let Some(_) = graph.get_idx(start) else {
42 return Vec::new();
43 };
44
45 let mut visited = HashSet::default();
46 let mut stack = vec![(start, 0usize)];
47 let mut result = Vec::new();
48
49 while let Some((node, depth)) = stack.pop() {
50 if !visited.insert(node) {
51 continue;
52 }
53 result.push((node, depth));
54 if depth >= max_depth {
55 continue;
56 }
57 for (neighbor, _edge) in graph.outgoing(node) {
58 if !visited.contains(&neighbor) {
59 stack.push((neighbor, depth + 1));
60 }
61 }
62 }
63
64 result
65}
66
67pub fn bfs_filtered(
69 graph: &CsrGraph,
70 start: MemoryId,
71 max_depth: usize,
72 edge_filter: &[EdgeType],
73) -> Vec<(MemoryId, usize)> {
74 let Some(_) = graph.get_idx(start) else {
75 return Vec::new();
76 };
77
78 let filter_set: HashSet<EdgeType> = edge_filter.iter().copied().collect();
79 let mut visited = HashSet::default();
80 let mut queue = VecDeque::new();
81 let mut result = Vec::new();
82
83 visited.insert(start);
84 queue.push_back((start, 0usize));
85
86 while let Some((node, depth)) = queue.pop_front() {
87 result.push((node, depth));
88 if depth >= max_depth {
89 continue;
90 }
91 for (neighbor, edge) in graph.outgoing(node) {
92 if filter_set.contains(&edge.edge_type) && visited.insert(neighbor) {
93 queue.push_back((neighbor, depth + 1));
94 }
95 }
96 }
97
98 result
99}
100
101pub fn extract_subgraph(
103 graph: &CsrGraph,
104 center: MemoryId,
105 radius: usize,
106) -> (Vec<MemoryId>, Vec<MemoryEdge>) {
107 let nodes_with_depth = bfs(graph, center, radius);
108 let node_set: HashSet<MemoryId> = nodes_with_depth.iter().map(|&(id, _)| id).collect();
109
110 let nodes: Vec<MemoryId> = nodes_with_depth.into_iter().map(|(id, _)| id).collect();
111 let mut edges = Vec::new();
112
113 for &node in &nodes {
114 for (neighbor, stored) in graph.outgoing(node) {
115 if node_set.contains(&neighbor) {
116 edges.push(MemoryEdge {
117 source: node,
118 target: neighbor,
119 edge_type: stored.edge_type,
120 weight: stored.weight,
121 created_at: stored.created_at,
122 valid_from: stored.valid_from,
123 valid_until: stored.valid_until,
124 });
125 }
126 }
127 }
128
129 (nodes, edges)
130}
131
132pub fn shortest_path(graph: &CsrGraph, from: MemoryId, to: MemoryId) -> Option<Vec<MemoryId>> {
134 if from == to {
135 return Some(vec![from]);
136 }
137
138 let _ = graph.get_idx(from)?;
139 let _ = graph.get_idx(to)?;
140
141 let mut visited = HashSet::default();
142 let mut parent: HashMap<MemoryId, MemoryId> = HashMap::default();
143 let mut queue = VecDeque::new();
144
145 visited.insert(from);
146 queue.push_back(from);
147
148 while let Some(node) = queue.pop_front() {
149 for (neighbor, _) in graph.outgoing(node) {
150 if visited.insert(neighbor) {
151 parent.insert(neighbor, node);
152 if neighbor == to {
153 let mut path = vec![to];
155 let mut cur = to;
156 while let Some(&prev) = parent.get(&cur) {
157 path.push(prev);
158 cur = prev;
159 if cur == from {
160 break;
161 }
162 }
163 path.reverse();
164 return Some(path);
165 }
166 queue.push_back(neighbor);
167 }
168 }
169 }
170
171 None
172}
173
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an edge of the given type with fixed weight and timestamp and
    /// no validity window.
    fn edge(source: MemoryId, target: MemoryId, edge_type: EdgeType) -> MemoryEdge {
        MemoryEdge {
            source,
            target,
            edge_type,
            weight: 1.0,
            created_at: 1000,
            valid_from: None,
            valid_until: None,
        }
    }

    /// Builds the chain `ids[0] -> ids[1] -> ids[2] -> ids[3]`, where the
    /// first two edges are `Caused` and the last one is `Related`.
    fn chain_fixture() -> (CsrGraph, Vec<MemoryId>) {
        let ids: Vec<MemoryId> = (0..4).map(|_| MemoryId::new()).collect();
        let kinds = [EdgeType::Caused, EdgeType::Caused, EdgeType::Related];

        let mut graph = CsrGraph::new();
        for (pair, kind) in ids.windows(2).zip(kinds) {
            graph.add_edge(&edge(pair[0], pair[1], kind));
        }
        (graph, ids)
    }

    #[test]
    fn test_bfs_chain() {
        let (graph, ids) = chain_fixture();
        let visited = bfs(&graph, ids[0], 10);

        assert_eq!(visited.len(), 4);
        assert_eq!(&visited[..2], &[(ids[0], 0), (ids[1], 1)]);
    }

    #[test]
    fn test_bfs_max_depth() {
        let (graph, ids) = chain_fixture();
        // Depth limit 1: the start node plus its direct successor.
        assert_eq!(bfs(&graph, ids[0], 1).len(), 2);
    }

    #[test]
    fn test_dfs_chain() {
        let (graph, ids) = chain_fixture();
        let visited = dfs(&graph, ids[0], 10);

        assert_eq!(visited.len(), 4);
        assert_eq!(visited[0].0, ids[0]);
    }

    #[test]
    fn test_bfs_filtered() {
        let (graph, ids) = chain_fixture();
        let visited = bfs_filtered(&graph, ids[0], 10, &[EdgeType::Caused]);

        // The final `Related` edge is not followed, so ids[3] is not reached.
        assert_eq!(visited.len(), 3);
    }

    #[test]
    fn test_shortest_path() {
        let (graph, ids) = chain_fixture();
        let found = shortest_path(&graph, ids[0], ids[3]);
        assert!(found.is_some());

        let path = found.unwrap();
        assert_eq!(path.len(), 4);
        assert_eq!(path.first(), Some(&ids[0]));
        assert_eq!(path.last(), Some(&ids[3]));
    }

    #[test]
    fn test_shortest_path_no_path() {
        let (graph, ids) = chain_fixture();
        // Edges only point forward along the chain, so the reverse direction
        // has no path.
        assert!(shortest_path(&graph, ids[3], ids[0]).is_none());
    }

    #[test]
    fn test_extract_subgraph() {
        let (graph, ids) = chain_fixture();
        let (nodes, edges) = extract_subgraph(&graph, ids[0], 2);

        // Radius 2 covers ids[0..=2] and the two edges between them.
        assert_eq!(nodes.len(), 3);
        assert_eq!(edges.len(), 2);
    }
}