codemem_storage/graph/
traversal.rs1use super::GraphEngine;
2use codemem_core::{
3 CodememError, Edge, GraphBackend, GraphNode, GraphStats, NodeKind, RelationshipType,
4};
5use petgraph::graph::NodeIndex;
6use petgraph::visit::EdgeRef;
7use petgraph::Direction;
8use std::collections::{HashMap, HashSet, VecDeque};
9
10impl Default for GraphEngine {
11 fn default() -> Self {
12 Self::new()
13 }
14}
15
16impl GraphBackend for GraphEngine {
17 fn add_node(&mut self, node: GraphNode) -> Result<(), CodememError> {
18 let id = node.id.clone();
19
20 if !self.id_to_index.contains_key(&id) {
21 let idx = self.graph.add_node(id.clone());
22 self.id_to_index.insert(id.clone(), idx);
23 }
24
25 self.nodes.insert(id, node);
26 Ok(())
27 }
28
29 fn get_node(&self, id: &str) -> Result<Option<GraphNode>, CodememError> {
30 Ok(self.nodes.get(id).cloned())
31 }
32
33 fn remove_node(&mut self, id: &str) -> Result<bool, CodememError> {
34 if let Some(idx) = self.id_to_index.remove(id) {
35 let last_idx = NodeIndex::new(self.graph.node_count() - 1);
38 self.graph.remove_node(idx);
39 if idx != last_idx {
42 if let Some(swapped_id) = self.graph.node_weight(idx) {
43 self.id_to_index.insert(swapped_id.clone(), idx);
44 }
45 }
46 self.nodes.remove(id);
47
48 if let Some(edge_ids) = self.edge_adj.remove(id) {
50 for eid in &edge_ids {
51 if let Some(edge) = self.edges.remove(eid) {
52 let other = if edge.src == id { &edge.dst } else { &edge.src };
54 if let Some(other_edges) = self.edge_adj.get_mut(other) {
55 other_edges.retain(|e| e != eid);
56 }
57 }
58 }
59 }
60
61 self.cached_pagerank.remove(id);
63 self.cached_betweenness.remove(id);
64
65 Ok(true)
66 } else {
67 Ok(false)
68 }
69 }
70
71 fn add_edge(&mut self, edge: Edge) -> Result<(), CodememError> {
72 let src_idx = self
73 .id_to_index
74 .get(&edge.src)
75 .ok_or_else(|| CodememError::NotFound(format!("Source node {}", edge.src)))?;
76 let dst_idx = self
77 .id_to_index
78 .get(&edge.dst)
79 .ok_or_else(|| CodememError::NotFound(format!("Destination node {}", edge.dst)))?;
80
81 self.graph.add_edge(*src_idx, *dst_idx, edge.weight);
82 self.edge_adj
84 .entry(edge.src.clone())
85 .or_default()
86 .push(edge.id.clone());
87 self.edge_adj
88 .entry(edge.dst.clone())
89 .or_default()
90 .push(edge.id.clone());
91 self.edges.insert(edge.id.clone(), edge);
92 Ok(())
93 }
94
95 fn get_edges(&self, node_id: &str) -> Result<Vec<Edge>, CodememError> {
96 let edges: Vec<Edge> = self
97 .edge_adj
98 .get(node_id)
99 .map(|edge_ids| {
100 edge_ids
101 .iter()
102 .filter_map(|eid| self.edges.get(eid).cloned())
103 .collect()
104 })
105 .unwrap_or_default();
106 Ok(edges)
107 }
108
109 fn remove_edge(&mut self, id: &str) -> Result<bool, CodememError> {
110 if let Some(edge) = self.edges.remove(id) {
111 if let (Some(&src_idx), Some(&dst_idx)) = (
113 self.id_to_index.get(&edge.src),
114 self.id_to_index.get(&edge.dst),
115 ) {
116 let target_weight = edge.weight;
118 let petgraph_edge_idx = self
119 .graph
120 .edges_connecting(src_idx, dst_idx)
121 .find(|e| (*e.weight() - target_weight).abs() < f64::EPSILON)
122 .map(|e| e.id());
123 if let Some(eidx) = petgraph_edge_idx {
124 self.graph.remove_edge(eidx);
125 }
126 }
127 if let Some(src_edges) = self.edge_adj.get_mut(&edge.src) {
129 src_edges.retain(|e| e != id);
130 }
131 if let Some(dst_edges) = self.edge_adj.get_mut(&edge.dst) {
132 dst_edges.retain(|e| e != id);
133 }
134 Ok(true)
135 } else {
136 Ok(false)
137 }
138 }
139
140 fn bfs(&self, start_id: &str, max_depth: usize) -> Result<Vec<GraphNode>, CodememError> {
141 let start_idx = self
142 .id_to_index
143 .get(start_id)
144 .ok_or_else(|| CodememError::NotFound(format!("Node {start_id}")))?;
145
146 let mut visited = HashSet::new();
147 let mut result = Vec::new();
148 let mut queue: VecDeque<(NodeIndex, usize)> = VecDeque::new();
149 queue.push_back((*start_idx, 0));
150 visited.insert(*start_idx);
151
152 while let Some((node_idx, depth)) = queue.pop_front() {
153 if let Some(node_id) = self.graph.node_weight(node_idx) {
154 if let Some(node) = self.nodes.get(node_id) {
155 result.push(node.clone());
156 }
157 }
158
159 if depth >= max_depth {
160 continue;
161 }
162
163 for neighbor in self.graph.neighbors_undirected(node_idx) {
165 if visited.insert(neighbor) {
166 queue.push_back((neighbor, depth + 1));
167 }
168 }
169 }
170
171 Ok(result)
172 }
173
174 fn dfs(&self, start_id: &str, max_depth: usize) -> Result<Vec<GraphNode>, CodememError> {
175 let start_idx = self
176 .id_to_index
177 .get(start_id)
178 .ok_or_else(|| CodememError::NotFound(format!("Node {start_id}")))?;
179
180 let mut visited = HashSet::new();
181 let mut result = Vec::new();
182 let mut stack: Vec<(NodeIndex, usize)> = vec![(*start_idx, 0)];
183
184 while let Some((node_idx, depth)) = stack.pop() {
185 if depth > max_depth || !visited.insert(node_idx) {
186 continue;
187 }
188
189 if let Some(node_id) = self.graph.node_weight(node_idx) {
190 if let Some(node) = self.nodes.get(node_id) {
191 result.push(node.clone());
192 }
193 }
194
195 for neighbor in self.graph.neighbors_undirected(node_idx) {
196 if !visited.contains(&neighbor) {
197 stack.push((neighbor, depth + 1));
198 }
199 }
200 }
201
202 Ok(result)
203 }
204
205 fn bfs_filtered(
206 &self,
207 start_id: &str,
208 max_depth: usize,
209 exclude_kinds: &[NodeKind],
210 include_relationships: Option<&[RelationshipType]>,
211 ) -> Result<Vec<GraphNode>, CodememError> {
212 let start_idx = self
213 .id_to_index
214 .get(start_id)
215 .ok_or_else(|| CodememError::NotFound(format!("Node {start_id}")))?;
216
217 let mut visited = HashSet::new();
218 let mut result = Vec::new();
219 let mut queue: VecDeque<(NodeIndex, usize)> = VecDeque::new();
220 queue.push_back((*start_idx, 0));
221 visited.insert(*start_idx);
222
223 while let Some((node_idx, depth)) = queue.pop_front() {
224 if let Some(node_id) = self.graph.node_weight(node_idx) {
226 if let Some(node) = self.nodes.get(node_id) {
227 if !exclude_kinds.contains(&node.kind) {
228 result.push(node.clone());
229 }
230 }
231 }
232
233 if depth >= max_depth {
234 continue;
235 }
236
237 for neighbor_idx in self.graph.neighbors_directed(node_idx, Direction::Outgoing) {
239 if visited.contains(&neighbor_idx) {
240 continue;
241 }
242
243 if let Some(allowed_rels) = include_relationships {
245 let src_id = self
246 .graph
247 .node_weight(node_idx)
248 .cloned()
249 .unwrap_or_default();
250 let dst_id = self
251 .graph
252 .node_weight(neighbor_idx)
253 .cloned()
254 .unwrap_or_default();
255 let edge_matches = self
256 .edge_adj
257 .get(&src_id)
258 .map(|edge_ids| {
259 edge_ids.iter().any(|eid| {
260 self.edges.get(eid).is_some_and(|e| {
261 e.src == src_id
262 && e.dst == dst_id
263 && allowed_rels.contains(&e.relationship)
264 })
265 })
266 })
267 .unwrap_or(false);
268 if !edge_matches {
269 continue;
270 }
271 }
272
273 visited.insert(neighbor_idx);
276 queue.push_back((neighbor_idx, depth + 1));
277 }
278 }
279
280 Ok(result)
281 }
282
283 fn dfs_filtered(
284 &self,
285 start_id: &str,
286 max_depth: usize,
287 exclude_kinds: &[NodeKind],
288 include_relationships: Option<&[RelationshipType]>,
289 ) -> Result<Vec<GraphNode>, CodememError> {
290 let start_idx = self
291 .id_to_index
292 .get(start_id)
293 .ok_or_else(|| CodememError::NotFound(format!("Node {start_id}")))?;
294
295 let mut visited = HashSet::new();
296 let mut result = Vec::new();
297 let mut stack: Vec<(NodeIndex, usize)> = vec![(*start_idx, 0)];
298
299 while let Some((node_idx, depth)) = stack.pop() {
300 if !visited.insert(node_idx) {
301 continue;
302 }
303
304 if let Some(node_id) = self.graph.node_weight(node_idx) {
306 if let Some(node) = self.nodes.get(node_id) {
307 if !exclude_kinds.contains(&node.kind) {
308 result.push(node.clone());
309 }
310 }
311 }
312
313 if depth >= max_depth {
314 continue;
315 }
316
317 for neighbor_idx in self.graph.neighbors_directed(node_idx, Direction::Outgoing) {
319 if visited.contains(&neighbor_idx) {
320 continue;
321 }
322
323 if let Some(allowed_rels) = include_relationships {
325 let src_id = self
326 .graph
327 .node_weight(node_idx)
328 .cloned()
329 .unwrap_or_default();
330 let dst_id = self
331 .graph
332 .node_weight(neighbor_idx)
333 .cloned()
334 .unwrap_or_default();
335 let edge_matches = self
336 .edge_adj
337 .get(&src_id)
338 .map(|edge_ids| {
339 edge_ids.iter().any(|eid| {
340 self.edges.get(eid).is_some_and(|e| {
341 e.src == src_id
342 && e.dst == dst_id
343 && allowed_rels.contains(&e.relationship)
344 })
345 })
346 })
347 .unwrap_or(false);
348 if !edge_matches {
349 continue;
350 }
351 }
352
353 stack.push((neighbor_idx, depth + 1));
356 }
357 }
358
359 Ok(result)
360 }
361
362 fn shortest_path(&self, from: &str, to: &str) -> Result<Vec<String>, CodememError> {
363 let from_idx = self
364 .id_to_index
365 .get(from)
366 .ok_or_else(|| CodememError::NotFound(format!("Node {from}")))?;
367 let to_idx = self
368 .id_to_index
369 .get(to)
370 .ok_or_else(|| CodememError::NotFound(format!("Node {to}")))?;
371
372 use petgraph::algo::astar;
374 let path = astar(
375 &self.graph,
376 *from_idx,
377 |finish| finish == *to_idx,
378 |_| 1.0f64,
379 |_| 0.0f64,
380 );
381
382 match path {
383 Some((_cost, nodes)) => {
384 let ids: Vec<String> = nodes
385 .iter()
386 .filter_map(|idx| self.graph.node_weight(*idx).cloned())
387 .collect();
388 Ok(ids)
389 }
390 None => Ok(vec![]),
391 }
392 }
393
394 fn stats(&self) -> GraphStats {
396 let mut node_kind_counts = HashMap::new();
397 for node in self.nodes.values() {
398 *node_kind_counts.entry(node.kind.to_string()).or_insert(0) += 1;
399 }
400
401 let mut relationship_type_counts = HashMap::new();
402 for edge in self.edges.values() {
403 *relationship_type_counts
404 .entry(edge.relationship.to_string())
405 .or_insert(0) += 1;
406 }
407
408 GraphStats {
409 node_count: self.nodes.len(),
410 edge_count: self.edges.len(),
411 node_kind_counts,
412 relationship_type_counts,
413 }
414 }
415}
416
417#[cfg(test)]
418#[path = "../tests/graph_traversal_tests.rs"]
419mod tests;