codemem_storage/graph/
traversal.rs1use super::GraphEngine;
2use codemem_core::{
3 CodememError, Edge, GraphBackend, GraphNode, GraphStats, NodeKind, RelationshipType,
4};
5use petgraph::graph::NodeIndex;
6use petgraph::visit::{Bfs, EdgeRef};
7use petgraph::Direction;
8use std::collections::{HashMap, HashSet, VecDeque};
9
10impl Default for GraphEngine {
11 fn default() -> Self {
12 Self::new()
13 }
14}
15
16impl GraphBackend for GraphEngine {
17 fn add_node(&mut self, node: GraphNode) -> Result<(), CodememError> {
18 let id = node.id.clone();
19
20 if !self.id_to_index.contains_key(&id) {
21 let idx = self.graph.add_node(id.clone());
22 self.id_to_index.insert(id.clone(), idx);
23 }
24
25 self.nodes.insert(id, node);
26 Ok(())
27 }
28
29 fn get_node(&self, id: &str) -> Result<Option<GraphNode>, CodememError> {
30 Ok(self.nodes.get(id).cloned())
31 }
32
33 fn remove_node(&mut self, id: &str) -> Result<bool, CodememError> {
34 if let Some(idx) = self.id_to_index.remove(id) {
35 let last_idx = NodeIndex::new(self.graph.node_count() - 1);
38 self.graph.remove_node(idx);
39 if idx != last_idx {
42 if let Some(swapped_id) = self.graph.node_weight(idx) {
43 self.id_to_index.insert(swapped_id.clone(), idx);
44 }
45 }
46 self.nodes.remove(id);
47
48 if let Some(edge_ids) = self.edge_adj.remove(id) {
50 for eid in &edge_ids {
51 if let Some(edge) = self.edges.remove(eid) {
52 let other = if edge.src == id { &edge.dst } else { &edge.src };
54 if let Some(other_edges) = self.edge_adj.get_mut(other) {
55 other_edges.retain(|e| e != eid);
56 }
57 }
58 }
59 }
60
61 self.cached_pagerank.remove(id);
63 self.cached_betweenness.remove(id);
64
65 Ok(true)
66 } else {
67 Ok(false)
68 }
69 }
70
71 fn add_edge(&mut self, edge: Edge) -> Result<(), CodememError> {
72 let src_idx = self
73 .id_to_index
74 .get(&edge.src)
75 .ok_or_else(|| CodememError::NotFound(format!("Source node {}", edge.src)))?;
76 let dst_idx = self
77 .id_to_index
78 .get(&edge.dst)
79 .ok_or_else(|| CodememError::NotFound(format!("Destination node {}", edge.dst)))?;
80
81 self.graph.add_edge(*src_idx, *dst_idx, edge.weight);
82 self.edge_adj
84 .entry(edge.src.clone())
85 .or_default()
86 .push(edge.id.clone());
87 self.edge_adj
88 .entry(edge.dst.clone())
89 .or_default()
90 .push(edge.id.clone());
91 self.edges.insert(edge.id.clone(), edge);
92 Ok(())
93 }
94
95 fn get_edges(&self, node_id: &str) -> Result<Vec<Edge>, CodememError> {
96 let edges: Vec<Edge> = self
97 .edge_adj
98 .get(node_id)
99 .map(|edge_ids| {
100 edge_ids
101 .iter()
102 .filter_map(|eid| self.edges.get(eid).cloned())
103 .collect()
104 })
105 .unwrap_or_default();
106 Ok(edges)
107 }
108
109 fn remove_edge(&mut self, id: &str) -> Result<bool, CodememError> {
110 if let Some(edge) = self.edges.remove(id) {
111 if let (Some(&src_idx), Some(&dst_idx)) = (
113 self.id_to_index.get(&edge.src),
114 self.id_to_index.get(&edge.dst),
115 ) {
116 let target_weight = edge.weight;
118 let petgraph_edge_idx = self
119 .graph
120 .edges_connecting(src_idx, dst_idx)
121 .find(|e| (*e.weight() - target_weight).abs() < f64::EPSILON)
122 .map(|e| e.id());
123 if let Some(eidx) = petgraph_edge_idx {
124 self.graph.remove_edge(eidx);
125 }
126 }
127 if let Some(src_edges) = self.edge_adj.get_mut(&edge.src) {
129 src_edges.retain(|e| e != id);
130 }
131 if let Some(dst_edges) = self.edge_adj.get_mut(&edge.dst) {
132 dst_edges.retain(|e| e != id);
133 }
134 Ok(true)
135 } else {
136 Ok(false)
137 }
138 }
139
140 fn bfs(&self, start_id: &str, max_depth: usize) -> Result<Vec<GraphNode>, CodememError> {
141 let start_idx = self
142 .id_to_index
143 .get(start_id)
144 .ok_or_else(|| CodememError::NotFound(format!("Node {start_id}")))?;
145
146 let mut visited = HashSet::new();
147 let mut result = Vec::new();
148 let mut bfs = Bfs::new(&self.graph, *start_idx);
149 let mut depth_map: HashMap<NodeIndex, usize> = HashMap::new();
150 depth_map.insert(*start_idx, 0);
151
152 while let Some(node_idx) = bfs.next(&self.graph) {
153 let depth = depth_map.get(&node_idx).copied().unwrap_or(0);
154 if depth > max_depth {
155 continue;
156 }
157
158 if visited.insert(node_idx) {
159 if let Some(node_id) = self.graph.node_weight(node_idx) {
160 if let Some(node) = self.nodes.get(node_id) {
161 result.push(node.clone());
162 }
163 }
164 }
165
166 for neighbor in self.graph.neighbors(node_idx) {
168 depth_map.entry(neighbor).or_insert(depth + 1);
169 }
170 }
171
172 Ok(result)
173 }
174
175 fn dfs(&self, start_id: &str, max_depth: usize) -> Result<Vec<GraphNode>, CodememError> {
176 let start_idx = self
177 .id_to_index
178 .get(start_id)
179 .ok_or_else(|| CodememError::NotFound(format!("Node {start_id}")))?;
180
181 let mut visited = HashSet::new();
182 let mut result = Vec::new();
183 let mut stack: Vec<(NodeIndex, usize)> = vec![(*start_idx, 0)];
184
185 while let Some((node_idx, depth)) = stack.pop() {
186 if depth > max_depth || !visited.insert(node_idx) {
187 continue;
188 }
189
190 if let Some(node_id) = self.graph.node_weight(node_idx) {
191 if let Some(node) = self.nodes.get(node_id) {
192 result.push(node.clone());
193 }
194 }
195
196 for neighbor in self.graph.neighbors(node_idx) {
197 if !visited.contains(&neighbor) {
198 stack.push((neighbor, depth + 1));
199 }
200 }
201 }
202
203 Ok(result)
204 }
205
206 fn bfs_filtered(
207 &self,
208 start_id: &str,
209 max_depth: usize,
210 exclude_kinds: &[NodeKind],
211 include_relationships: Option<&[RelationshipType]>,
212 ) -> Result<Vec<GraphNode>, CodememError> {
213 let start_idx = self
214 .id_to_index
215 .get(start_id)
216 .ok_or_else(|| CodememError::NotFound(format!("Node {start_id}")))?;
217
218 let mut visited = HashSet::new();
219 let mut result = Vec::new();
220 let mut queue: VecDeque<(NodeIndex, usize)> = VecDeque::new();
221 queue.push_back((*start_idx, 0));
222 visited.insert(*start_idx);
223
224 while let Some((node_idx, depth)) = queue.pop_front() {
225 if let Some(node_id) = self.graph.node_weight(node_idx) {
227 if let Some(node) = self.nodes.get(node_id) {
228 if !exclude_kinds.contains(&node.kind) {
229 result.push(node.clone());
230 }
231 }
232 }
233
234 if depth >= max_depth {
235 continue;
236 }
237
238 for neighbor_idx in self.graph.neighbors_directed(node_idx, Direction::Outgoing) {
240 if visited.contains(&neighbor_idx) {
241 continue;
242 }
243
244 if let Some(allowed_rels) = include_relationships {
246 let src_id = self
247 .graph
248 .node_weight(node_idx)
249 .cloned()
250 .unwrap_or_default();
251 let dst_id = self
252 .graph
253 .node_weight(neighbor_idx)
254 .cloned()
255 .unwrap_or_default();
256 let edge_matches = self
257 .edge_adj
258 .get(&src_id)
259 .map(|edge_ids| {
260 edge_ids.iter().any(|eid| {
261 self.edges.get(eid).is_some_and(|e| {
262 e.src == src_id
263 && e.dst == dst_id
264 && allowed_rels.contains(&e.relationship)
265 })
266 })
267 })
268 .unwrap_or(false);
269 if !edge_matches {
270 continue;
271 }
272 }
273
274 visited.insert(neighbor_idx);
277 queue.push_back((neighbor_idx, depth + 1));
278 }
279 }
280
281 Ok(result)
282 }
283
284 fn dfs_filtered(
285 &self,
286 start_id: &str,
287 max_depth: usize,
288 exclude_kinds: &[NodeKind],
289 include_relationships: Option<&[RelationshipType]>,
290 ) -> Result<Vec<GraphNode>, CodememError> {
291 let start_idx = self
292 .id_to_index
293 .get(start_id)
294 .ok_or_else(|| CodememError::NotFound(format!("Node {start_id}")))?;
295
296 let mut visited = HashSet::new();
297 let mut result = Vec::new();
298 let mut stack: Vec<(NodeIndex, usize)> = vec![(*start_idx, 0)];
299
300 while let Some((node_idx, depth)) = stack.pop() {
301 if !visited.insert(node_idx) {
302 continue;
303 }
304
305 if let Some(node_id) = self.graph.node_weight(node_idx) {
307 if let Some(node) = self.nodes.get(node_id) {
308 if !exclude_kinds.contains(&node.kind) {
309 result.push(node.clone());
310 }
311 }
312 }
313
314 if depth >= max_depth {
315 continue;
316 }
317
318 for neighbor_idx in self.graph.neighbors_directed(node_idx, Direction::Outgoing) {
320 if visited.contains(&neighbor_idx) {
321 continue;
322 }
323
324 if let Some(allowed_rels) = include_relationships {
326 let src_id = self
327 .graph
328 .node_weight(node_idx)
329 .cloned()
330 .unwrap_or_default();
331 let dst_id = self
332 .graph
333 .node_weight(neighbor_idx)
334 .cloned()
335 .unwrap_or_default();
336 let edge_matches = self
337 .edge_adj
338 .get(&src_id)
339 .map(|edge_ids| {
340 edge_ids.iter().any(|eid| {
341 self.edges.get(eid).is_some_and(|e| {
342 e.src == src_id
343 && e.dst == dst_id
344 && allowed_rels.contains(&e.relationship)
345 })
346 })
347 })
348 .unwrap_or(false);
349 if !edge_matches {
350 continue;
351 }
352 }
353
354 stack.push((neighbor_idx, depth + 1));
357 }
358 }
359
360 Ok(result)
361 }
362
363 fn shortest_path(&self, from: &str, to: &str) -> Result<Vec<String>, CodememError> {
364 let from_idx = self
365 .id_to_index
366 .get(from)
367 .ok_or_else(|| CodememError::NotFound(format!("Node {from}")))?;
368 let to_idx = self
369 .id_to_index
370 .get(to)
371 .ok_or_else(|| CodememError::NotFound(format!("Node {to}")))?;
372
373 use petgraph::algo::astar;
375 let path = astar(
376 &self.graph,
377 *from_idx,
378 |finish| finish == *to_idx,
379 |_| 1.0f64,
380 |_| 0.0f64,
381 );
382
383 match path {
384 Some((_cost, nodes)) => {
385 let ids: Vec<String> = nodes
386 .iter()
387 .filter_map(|idx| self.graph.node_weight(*idx).cloned())
388 .collect();
389 Ok(ids)
390 }
391 None => Ok(vec![]),
392 }
393 }
394
395 fn stats(&self) -> GraphStats {
397 let mut node_kind_counts = HashMap::new();
398 for node in self.nodes.values() {
399 *node_kind_counts.entry(node.kind.to_string()).or_insert(0) += 1;
400 }
401
402 let mut relationship_type_counts = HashMap::new();
403 for edge in self.edges.values() {
404 *relationship_type_counts
405 .entry(edge.relationship.to_string())
406 .or_insert(0) += 1;
407 }
408
409 GraphStats {
410 node_count: self.nodes.len(),
411 edge_count: self.edges.len(),
412 node_kind_counts,
413 relationship_type_counts,
414 }
415 }
416}
417
418#[cfg(test)]
419#[path = "../tests/graph_traversal_tests.rs"]
420mod tests;