codemem_graph/
traversal.rs1use crate::GraphEngine;
2use codemem_core::{
3 CodememError, Edge, GraphBackend, GraphNode, GraphStats, NodeKind, RelationshipType,
4};
5use petgraph::graph::NodeIndex;
6use petgraph::visit::Bfs;
7use petgraph::Direction;
8use std::collections::{HashMap, HashSet, VecDeque};
9
10impl Default for GraphEngine {
11 fn default() -> Self {
12 Self::new()
13 }
14}
15
16impl GraphBackend for GraphEngine {
17 fn add_node(&mut self, node: GraphNode) -> Result<(), CodememError> {
18 let id = node.id.clone();
19
20 if !self.id_to_index.contains_key(&id) {
21 let idx = self.graph.add_node(id.clone());
22 self.id_to_index.insert(id.clone(), idx);
23 }
24
25 self.nodes.insert(id, node);
26 Ok(())
27 }
28
29 fn get_node(&self, id: &str) -> Result<Option<GraphNode>, CodememError> {
30 Ok(self.nodes.get(id).cloned())
31 }
32
33 fn remove_node(&mut self, id: &str) -> Result<bool, CodememError> {
34 if let Some(idx) = self.id_to_index.remove(id) {
35 self.graph.remove_node(idx);
36 self.nodes.remove(id);
37
38 let edge_ids: Vec<String> = self
40 .edges
41 .iter()
42 .filter(|(_, e)| e.src == id || e.dst == id)
43 .map(|(eid, _)| eid.clone())
44 .collect();
45 for eid in edge_ids {
46 self.edges.remove(&eid);
47 }
48
49 Ok(true)
50 } else {
51 Ok(false)
52 }
53 }
54
55 fn add_edge(&mut self, edge: Edge) -> Result<(), CodememError> {
56 let src_idx = self
57 .id_to_index
58 .get(&edge.src)
59 .ok_or_else(|| CodememError::NotFound(format!("Source node {}", edge.src)))?;
60 let dst_idx = self
61 .id_to_index
62 .get(&edge.dst)
63 .ok_or_else(|| CodememError::NotFound(format!("Destination node {}", edge.dst)))?;
64
65 self.graph.add_edge(*src_idx, *dst_idx, edge.weight);
66 self.edges.insert(edge.id.clone(), edge);
67 Ok(())
68 }
69
70 fn get_edges(&self, node_id: &str) -> Result<Vec<Edge>, CodememError> {
71 let edges: Vec<Edge> = self
72 .edges
73 .values()
74 .filter(|e| e.src == node_id || e.dst == node_id)
75 .cloned()
76 .collect();
77 Ok(edges)
78 }
79
80 fn remove_edge(&mut self, id: &str) -> Result<bool, CodememError> {
81 if let Some(edge) = self.edges.remove(id) {
82 if let (Some(&src_idx), Some(&dst_idx)) = (
84 self.id_to_index.get(&edge.src),
85 self.id_to_index.get(&edge.dst),
86 ) {
87 if let Some(edge_idx) = self.graph.find_edge(src_idx, dst_idx) {
88 self.graph.remove_edge(edge_idx);
89 }
90 }
91 Ok(true)
92 } else {
93 Ok(false)
94 }
95 }
96
97 fn bfs(&self, start_id: &str, max_depth: usize) -> Result<Vec<GraphNode>, CodememError> {
98 let start_idx = self
99 .id_to_index
100 .get(start_id)
101 .ok_or_else(|| CodememError::NotFound(format!("Node {start_id}")))?;
102
103 let mut visited = HashSet::new();
104 let mut result = Vec::new();
105 let mut bfs = Bfs::new(&self.graph, *start_idx);
106 let mut depth_map: HashMap<NodeIndex, usize> = HashMap::new();
107 depth_map.insert(*start_idx, 0);
108
109 while let Some(node_idx) = bfs.next(&self.graph) {
110 let depth = depth_map.get(&node_idx).copied().unwrap_or(0);
111 if depth > max_depth {
112 continue;
113 }
114
115 if visited.insert(node_idx) {
116 if let Some(node_id) = self.graph.node_weight(node_idx) {
117 if let Some(node) = self.nodes.get(node_id) {
118 result.push(node.clone());
119 }
120 }
121 }
122
123 for neighbor in self.graph.neighbors(node_idx) {
125 depth_map.entry(neighbor).or_insert(depth + 1);
126 }
127 }
128
129 Ok(result)
130 }
131
132 fn dfs(&self, start_id: &str, max_depth: usize) -> Result<Vec<GraphNode>, CodememError> {
133 let start_idx = self
134 .id_to_index
135 .get(start_id)
136 .ok_or_else(|| CodememError::NotFound(format!("Node {start_id}")))?;
137
138 let mut visited = HashSet::new();
139 let mut result = Vec::new();
140 let mut stack: Vec<(NodeIndex, usize)> = vec![(*start_idx, 0)];
141
142 while let Some((node_idx, depth)) = stack.pop() {
143 if depth > max_depth || !visited.insert(node_idx) {
144 continue;
145 }
146
147 if let Some(node_id) = self.graph.node_weight(node_idx) {
148 if let Some(node) = self.nodes.get(node_id) {
149 result.push(node.clone());
150 }
151 }
152
153 for neighbor in self.graph.neighbors(node_idx) {
154 if !visited.contains(&neighbor) {
155 stack.push((neighbor, depth + 1));
156 }
157 }
158 }
159
160 Ok(result)
161 }
162
163 fn bfs_filtered(
164 &self,
165 start_id: &str,
166 max_depth: usize,
167 exclude_kinds: &[NodeKind],
168 include_relationships: Option<&[RelationshipType]>,
169 ) -> Result<Vec<GraphNode>, CodememError> {
170 let start_idx = self
171 .id_to_index
172 .get(start_id)
173 .ok_or_else(|| CodememError::NotFound(format!("Node {start_id}")))?;
174
175 let mut visited = HashSet::new();
176 let mut result = Vec::new();
177 let mut queue: VecDeque<(NodeIndex, usize)> = VecDeque::new();
178 queue.push_back((*start_idx, 0));
179 visited.insert(*start_idx);
180
181 while let Some((node_idx, depth)) = queue.pop_front() {
182 if let Some(node_id) = self.graph.node_weight(node_idx) {
184 if let Some(node) = self.nodes.get(node_id) {
185 if !exclude_kinds.contains(&node.kind) {
186 result.push(node.clone());
187 }
188 }
189 }
190
191 if depth >= max_depth {
192 continue;
193 }
194
195 for neighbor_idx in self.graph.neighbors_directed(node_idx, Direction::Outgoing) {
197 if visited.contains(&neighbor_idx) {
198 continue;
199 }
200
201 if let Some(allowed_rels) = include_relationships {
203 let src_id = self
204 .graph
205 .node_weight(node_idx)
206 .cloned()
207 .unwrap_or_default();
208 let dst_id = self
209 .graph
210 .node_weight(neighbor_idx)
211 .cloned()
212 .unwrap_or_default();
213 let edge_matches = self.edges.values().any(|e| {
214 e.src == src_id && e.dst == dst_id && allowed_rels.contains(&e.relationship)
215 });
216 if !edge_matches {
217 continue;
218 }
219 }
220
221 if let Some(neighbor_id) = self.graph.node_weight(neighbor_idx) {
223 if let Some(neighbor_node) = self.nodes.get(neighbor_id) {
224 if exclude_kinds.contains(&neighbor_node.kind) {
225 continue;
226 }
227 }
228 }
229
230 visited.insert(neighbor_idx);
231 queue.push_back((neighbor_idx, depth + 1));
232 }
233 }
234
235 Ok(result)
236 }
237
238 fn dfs_filtered(
239 &self,
240 start_id: &str,
241 max_depth: usize,
242 exclude_kinds: &[NodeKind],
243 include_relationships: Option<&[RelationshipType]>,
244 ) -> Result<Vec<GraphNode>, CodememError> {
245 let start_idx = self
246 .id_to_index
247 .get(start_id)
248 .ok_or_else(|| CodememError::NotFound(format!("Node {start_id}")))?;
249
250 let mut visited = HashSet::new();
251 let mut result = Vec::new();
252 let mut stack: Vec<(NodeIndex, usize)> = vec![(*start_idx, 0)];
253
254 while let Some((node_idx, depth)) = stack.pop() {
255 if !visited.insert(node_idx) {
256 continue;
257 }
258
259 if let Some(node_id) = self.graph.node_weight(node_idx) {
261 if let Some(node) = self.nodes.get(node_id) {
262 if !exclude_kinds.contains(&node.kind) {
263 result.push(node.clone());
264 }
265 }
266 }
267
268 if depth >= max_depth {
269 continue;
270 }
271
272 for neighbor_idx in self.graph.neighbors_directed(node_idx, Direction::Outgoing) {
274 if visited.contains(&neighbor_idx) {
275 continue;
276 }
277
278 if let Some(allowed_rels) = include_relationships {
280 let src_id = self
281 .graph
282 .node_weight(node_idx)
283 .cloned()
284 .unwrap_or_default();
285 let dst_id = self
286 .graph
287 .node_weight(neighbor_idx)
288 .cloned()
289 .unwrap_or_default();
290 let edge_matches = self.edges.values().any(|e| {
291 e.src == src_id && e.dst == dst_id && allowed_rels.contains(&e.relationship)
292 });
293 if !edge_matches {
294 continue;
295 }
296 }
297
298 if let Some(neighbor_id) = self.graph.node_weight(neighbor_idx) {
300 if let Some(neighbor_node) = self.nodes.get(neighbor_id) {
301 if exclude_kinds.contains(&neighbor_node.kind) {
302 continue;
303 }
304 }
305 }
306
307 stack.push((neighbor_idx, depth + 1));
308 }
309 }
310
311 Ok(result)
312 }
313
314 fn shortest_path(&self, from: &str, to: &str) -> Result<Vec<String>, CodememError> {
315 let from_idx = self
316 .id_to_index
317 .get(from)
318 .ok_or_else(|| CodememError::NotFound(format!("Node {from}")))?;
319 let to_idx = self
320 .id_to_index
321 .get(to)
322 .ok_or_else(|| CodememError::NotFound(format!("Node {to}")))?;
323
324 use petgraph::algo::astar;
326 let path = astar(
327 &self.graph,
328 *from_idx,
329 |finish| finish == *to_idx,
330 |_| 1.0f64,
331 |_| 0.0f64,
332 );
333
334 match path {
335 Some((_cost, nodes)) => {
336 let ids: Vec<String> = nodes
337 .iter()
338 .filter_map(|idx| self.graph.node_weight(*idx).cloned())
339 .collect();
340 Ok(ids)
341 }
342 None => Ok(vec![]),
343 }
344 }
345
346 fn stats(&self) -> GraphStats {
347 let mut node_kind_counts = HashMap::new();
348 for node in self.nodes.values() {
349 *node_kind_counts.entry(node.kind.to_string()).or_insert(0) += 1;
350 }
351
352 let mut relationship_type_counts = HashMap::new();
353 for edge in self.edges.values() {
354 *relationship_type_counts
355 .entry(edge.relationship.to_string())
356 .or_insert(0) += 1;
357 }
358
359 GraphStats {
360 node_count: self.nodes.len(),
361 edge_count: self.edges.len(),
362 node_kind_counts,
363 relationship_type_counts,
364 }
365 }
366}
367
// Unit tests for this module, compiled only in test builds. They are loaded from
// `tests/traversal_tests.rs`, relative to this file's directory.
#[cfg(test)]
#[path = "tests/traversal_tests.rs"]
mod tests;