1use std::collections::VecDeque;
4
5use ahash::{HashMap, HashSet};
6use mentedb_core::edge::{EdgeType, MemoryEdge};
7use mentedb_core::types::MemoryId;
8
9use crate::csr::CsrGraph;
10
11pub fn bfs(graph: &CsrGraph, start: MemoryId, max_depth: usize) -> Vec<(MemoryId, usize)> {
13 let Some(_) = graph.get_idx(start) else {
14 return Vec::new();
15 };
16
17 let mut visited = HashSet::default();
18 let mut queue = VecDeque::new();
19 let mut result = Vec::new();
20
21 visited.insert(start);
22 queue.push_back((start, 0usize));
23
24 while let Some((node, depth)) = queue.pop_front() {
25 result.push((node, depth));
26 if depth >= max_depth {
27 continue;
28 }
29 for (neighbor, _edge) in graph.outgoing(node) {
30 if visited.insert(neighbor) {
31 queue.push_back((neighbor, depth + 1));
32 }
33 }
34 }
35
36 result
37}
38
39pub fn dfs(graph: &CsrGraph, start: MemoryId, max_depth: usize) -> Vec<(MemoryId, usize)> {
41 let Some(_) = graph.get_idx(start) else {
42 return Vec::new();
43 };
44
45 let mut visited = HashSet::default();
46 let mut stack = vec![(start, 0usize)];
47 let mut result = Vec::new();
48
49 while let Some((node, depth)) = stack.pop() {
50 if !visited.insert(node) {
51 continue;
52 }
53 result.push((node, depth));
54 if depth >= max_depth {
55 continue;
56 }
57 for (neighbor, _edge) in graph.outgoing(node) {
58 if !visited.contains(&neighbor) {
59 stack.push((neighbor, depth + 1));
60 }
61 }
62 }
63
64 result
65}
66
67pub fn bfs_filtered(
69 graph: &CsrGraph,
70 start: MemoryId,
71 max_depth: usize,
72 edge_filter: &[EdgeType],
73) -> Vec<(MemoryId, usize)> {
74 let Some(_) = graph.get_idx(start) else {
75 return Vec::new();
76 };
77
78 let filter_set: HashSet<EdgeType> = edge_filter.iter().copied().collect();
79 let mut visited = HashSet::default();
80 let mut queue = VecDeque::new();
81 let mut result = Vec::new();
82
83 visited.insert(start);
84 queue.push_back((start, 0usize));
85
86 while let Some((node, depth)) = queue.pop_front() {
87 result.push((node, depth));
88 if depth >= max_depth {
89 continue;
90 }
91 for (neighbor, edge) in graph.outgoing(node) {
92 if filter_set.contains(&edge.edge_type) && visited.insert(neighbor) {
93 queue.push_back((neighbor, depth + 1));
94 }
95 }
96 }
97
98 result
99}
100
101pub fn extract_subgraph(
103 graph: &CsrGraph,
104 center: MemoryId,
105 radius: usize,
106) -> (Vec<MemoryId>, Vec<MemoryEdge>) {
107 let nodes_with_depth = bfs(graph, center, radius);
108 let node_set: HashSet<MemoryId> = nodes_with_depth.iter().map(|&(id, _)| id).collect();
109
110 let nodes: Vec<MemoryId> = nodes_with_depth.into_iter().map(|(id, _)| id).collect();
111 let mut edges = Vec::new();
112
113 for &node in &nodes {
114 for (neighbor, stored) in graph.outgoing(node) {
115 if node_set.contains(&neighbor) {
116 edges.push(MemoryEdge {
117 source: node,
118 target: neighbor,
119 edge_type: stored.edge_type,
120 weight: stored.weight,
121 created_at: stored.created_at,
122 });
123 }
124 }
125 }
126
127 (nodes, edges)
128}
129
130pub fn shortest_path(graph: &CsrGraph, from: MemoryId, to: MemoryId) -> Option<Vec<MemoryId>> {
132 if from == to {
133 return Some(vec![from]);
134 }
135
136 let _ = graph.get_idx(from)?;
137 let _ = graph.get_idx(to)?;
138
139 let mut visited = HashSet::default();
140 let mut parent: HashMap<MemoryId, MemoryId> = HashMap::default();
141 let mut queue = VecDeque::new();
142
143 visited.insert(from);
144 queue.push_back(from);
145
146 while let Some(node) = queue.pop_front() {
147 for (neighbor, _) in graph.outgoing(node) {
148 if visited.insert(neighbor) {
149 parent.insert(neighbor, node);
150 if neighbor == to {
151 let mut path = vec![to];
153 let mut cur = to;
154 while let Some(&prev) = parent.get(&cur) {
155 path.push(prev);
156 cur = prev;
157 if cur == from {
158 break;
159 }
160 }
161 path.reverse();
162 return Some(path);
163 }
164 queue.push_back(neighbor);
165 }
166 }
167 }
168
169 None
170}
171
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a unit-weight edge with a fixed timestamp.
    fn make_edge(src: MemoryId, tgt: MemoryId, etype: EdgeType) -> MemoryEdge {
        MemoryEdge {
            source: src,
            target: tgt,
            edge_type: etype,
            weight: 1.0,
            created_at: 1000,
        }
    }

    /// Builds the chain `ids[0] -Caused-> ids[1] -Caused-> ids[2] -Related-> ids[3]`.
    fn build_chain() -> (CsrGraph, Vec<MemoryId>) {
        let mut g = CsrGraph::new();
        let ids: Vec<MemoryId> = (0..4).map(|_| MemoryId::new()).collect();
        g.add_edge(&make_edge(ids[0], ids[1], EdgeType::Caused));
        g.add_edge(&make_edge(ids[1], ids[2], EdgeType::Caused));
        g.add_edge(&make_edge(ids[2], ids[3], EdgeType::Related));
        (g, ids)
    }

    #[test]
    fn test_bfs_chain() {
        let (g, ids) = build_chain();
        let result = bfs(&g, ids[0], 10);
        assert_eq!(result.len(), 4);
        assert_eq!(result[0], (ids[0], 0));
        assert_eq!(result[1], (ids[1], 1));
    }

    #[test]
    fn test_bfs_chain_depths() {
        let (g, ids) = build_chain();
        let expected: Vec<(MemoryId, usize)> = ids.iter().copied().zip(0..).collect();
        assert_eq!(bfs(&g, ids[0], 10), expected);
    }

    #[test]
    fn test_bfs_max_depth() {
        let (g, ids) = build_chain();
        let result = bfs(&g, ids[0], 1);
        assert_eq!(result.len(), 2);
    }

    #[test]
    fn test_dfs_chain() {
        let (g, ids) = build_chain();
        let result = dfs(&g, ids[0], 10);
        assert_eq!(result.len(), 4);
        assert_eq!(result[0].0, ids[0]);
    }

    #[test]
    fn test_dfs_max_depth() {
        let (g, ids) = build_chain();
        assert_eq!(dfs(&g, ids[0], 1), vec![(ids[0], 0), (ids[1], 1)]);
    }

    #[test]
    fn test_bfs_filtered() {
        let (g, ids) = build_chain();
        let result = bfs_filtered(&g, ids[0], 10, &[EdgeType::Caused]);
        // The Related edge ids[2] -> ids[3] is not followed.
        assert_eq!(result.len(), 3);
    }

    #[test]
    fn test_bfs_filtered_empty_filter_follows_no_edges() {
        let (g, ids) = build_chain();
        assert_eq!(bfs_filtered(&g, ids[0], 10, &[]), vec![(ids[0], 0)]);
    }

    #[test]
    fn test_unknown_start_yields_nothing() {
        let (g, ids) = build_chain();
        // A freshly generated id was never added to the graph, so `get_idx` is
        // expected to reject it and every traversal should bail out early.
        let stranger = MemoryId::new();
        assert!(bfs(&g, stranger, 3).is_empty());
        assert!(dfs(&g, stranger, 3).is_empty());
        assert!(bfs_filtered(&g, stranger, 3, &[EdgeType::Caused]).is_empty());
        let (nodes, edges) = extract_subgraph(&g, stranger, 3);
        assert!(nodes.is_empty());
        assert!(edges.is_empty());
        assert!(shortest_path(&g, stranger, ids[0]).is_none());
        assert!(shortest_path(&g, ids[0], stranger).is_none());
    }

    #[test]
    fn test_shortest_path() {
        let (g, ids) = build_chain();
        let path = shortest_path(&g, ids[0], ids[3]);
        assert!(path.is_some());
        let path = path.unwrap();
        assert_eq!(path.len(), 4);
        assert_eq!(path[0], ids[0]);
        assert_eq!(path[3], ids[3]);
    }

    #[test]
    fn test_shortest_path_mid_chain() {
        let (g, ids) = build_chain();
        assert_eq!(
            shortest_path(&g, ids[1], ids[3]),
            Some(vec![ids[1], ids[2], ids[3]])
        );
    }

    #[test]
    fn test_shortest_path_same_node() {
        let (g, ids) = build_chain();
        assert_eq!(shortest_path(&g, ids[2], ids[2]), Some(vec![ids[2]]));
    }

    #[test]
    fn test_shortest_path_no_path() {
        let (g, ids) = build_chain();
        // Edges are directed, so the chain cannot be walked backwards.
        let path = shortest_path(&g, ids[3], ids[0]);
        assert!(path.is_none());
    }

    #[test]
    fn test_extract_subgraph() {
        let (g, ids) = build_chain();
        let (nodes, edges) = extract_subgraph(&g, ids[0], 2);
        assert_eq!(nodes.len(), 3);
        assert_eq!(edges.len(), 2);
    }

    #[test]
    fn test_extract_subgraph_edges_stay_inside() {
        let (g, ids) = build_chain();
        let (nodes, edges) = extract_subgraph(&g, ids[0], 2);
        assert!(edges
            .iter()
            .all(|e| nodes.contains(&e.source) && nodes.contains(&e.target)));
    }

    #[test]
    fn test_extract_subgraph_radius_zero() {
        let (g, ids) = build_chain();
        let (nodes, edges) = extract_subgraph(&g, ids[0], 0);
        assert_eq!(nodes, vec![ids[0]]);
        assert!(edges.is_empty());
    }
}