// codemem_storage/graph/traversal.rs

1use super::GraphEngine;
2use codemem_core::{
3    CodememError, Edge, GraphBackend, GraphNode, GraphStats, NodeKind, RelationshipType,
4};
5use petgraph::graph::NodeIndex;
6use petgraph::visit::{Bfs, EdgeRef};
7use petgraph::Direction;
8use std::collections::{HashMap, HashSet, VecDeque};
9
10impl Default for GraphEngine {
11    fn default() -> Self {
12        Self::new()
13    }
14}
15
16impl GraphBackend for GraphEngine {
17    fn add_node(&mut self, node: GraphNode) -> Result<(), CodememError> {
18        let id = node.id.clone();
19
20        if !self.id_to_index.contains_key(&id) {
21            let idx = self.graph.add_node(id.clone());
22            self.id_to_index.insert(id.clone(), idx);
23        }
24
25        self.nodes.insert(id, node);
26        Ok(())
27    }
28
29    fn get_node(&self, id: &str) -> Result<Option<GraphNode>, CodememError> {
30        Ok(self.nodes.get(id).cloned())
31    }
32
33    fn remove_node(&mut self, id: &str) -> Result<bool, CodememError> {
34        if let Some(idx) = self.id_to_index.remove(id) {
35            // petgraph::DiGraph::remove_node swaps the last node into the removed
36            // slot, invalidating the last node's NodeIndex. We must fix id_to_index.
37            let last_idx = NodeIndex::new(self.graph.node_count() - 1);
38            self.graph.remove_node(idx);
39            // After removal, the node that was at `last_idx` is now at `idx`
40            // (unless we removed the last node itself).
41            if idx != last_idx {
42                if let Some(swapped_id) = self.graph.node_weight(idx) {
43                    self.id_to_index.insert(swapped_id.clone(), idx);
44                }
45            }
46            self.nodes.remove(id);
47
48            // Remove associated edges using edge adjacency index
49            if let Some(edge_ids) = self.edge_adj.remove(id) {
50                for eid in &edge_ids {
51                    if let Some(edge) = self.edges.remove(eid) {
52                        // Also remove from the other endpoint's adjacency list
53                        let other = if edge.src == id { &edge.dst } else { &edge.src };
54                        if let Some(other_edges) = self.edge_adj.get_mut(other) {
55                            other_edges.retain(|e| e != eid);
56                        }
57                    }
58                }
59            }
60
61            // Clean up centrality caches
62            self.cached_pagerank.remove(id);
63            self.cached_betweenness.remove(id);
64
65            Ok(true)
66        } else {
67            Ok(false)
68        }
69    }
70
71    fn add_edge(&mut self, edge: Edge) -> Result<(), CodememError> {
72        let src_idx = self
73            .id_to_index
74            .get(&edge.src)
75            .ok_or_else(|| CodememError::NotFound(format!("Source node {}", edge.src)))?;
76        let dst_idx = self
77            .id_to_index
78            .get(&edge.dst)
79            .ok_or_else(|| CodememError::NotFound(format!("Destination node {}", edge.dst)))?;
80
81        self.graph.add_edge(*src_idx, *dst_idx, edge.weight);
82        // Maintain edge adjacency index
83        self.edge_adj
84            .entry(edge.src.clone())
85            .or_default()
86            .push(edge.id.clone());
87        self.edge_adj
88            .entry(edge.dst.clone())
89            .or_default()
90            .push(edge.id.clone());
91        self.edges.insert(edge.id.clone(), edge);
92        Ok(())
93    }
94
95    fn get_edges(&self, node_id: &str) -> Result<Vec<Edge>, CodememError> {
96        let edges: Vec<Edge> = self
97            .edge_adj
98            .get(node_id)
99            .map(|edge_ids| {
100                edge_ids
101                    .iter()
102                    .filter_map(|eid| self.edges.get(eid).cloned())
103                    .collect()
104            })
105            .unwrap_or_default();
106        Ok(edges)
107    }
108
109    fn remove_edge(&mut self, id: &str) -> Result<bool, CodememError> {
110        if let Some(edge) = self.edges.remove(id) {
111            // Remove from petgraph — match by weight to handle parallel edges
112            if let (Some(&src_idx), Some(&dst_idx)) = (
113                self.id_to_index.get(&edge.src),
114                self.id_to_index.get(&edge.dst),
115            ) {
116                // Iterate edges_connecting to find the correct one by weight
117                let target_weight = edge.weight;
118                let petgraph_edge_idx = self
119                    .graph
120                    .edges_connecting(src_idx, dst_idx)
121                    .find(|e| (*e.weight() - target_weight).abs() < f64::EPSILON)
122                    .map(|e| e.id());
123                if let Some(eidx) = petgraph_edge_idx {
124                    self.graph.remove_edge(eidx);
125                }
126            }
127            // Maintain edge adjacency index
128            if let Some(src_edges) = self.edge_adj.get_mut(&edge.src) {
129                src_edges.retain(|e| e != id);
130            }
131            if let Some(dst_edges) = self.edge_adj.get_mut(&edge.dst) {
132                dst_edges.retain(|e| e != id);
133            }
134            Ok(true)
135        } else {
136            Ok(false)
137        }
138    }
139
140    fn bfs(&self, start_id: &str, max_depth: usize) -> Result<Vec<GraphNode>, CodememError> {
141        let start_idx = self
142            .id_to_index
143            .get(start_id)
144            .ok_or_else(|| CodememError::NotFound(format!("Node {start_id}")))?;
145
146        let mut visited = HashSet::new();
147        let mut result = Vec::new();
148        let mut bfs = Bfs::new(&self.graph, *start_idx);
149        let mut depth_map: HashMap<NodeIndex, usize> = HashMap::new();
150        depth_map.insert(*start_idx, 0);
151
152        while let Some(node_idx) = bfs.next(&self.graph) {
153            let depth = depth_map.get(&node_idx).copied().unwrap_or(0);
154            if depth > max_depth {
155                continue;
156            }
157
158            if visited.insert(node_idx) {
159                if let Some(node_id) = self.graph.node_weight(node_idx) {
160                    if let Some(node) = self.nodes.get(node_id) {
161                        result.push(node.clone());
162                    }
163                }
164            }
165
166            // Set depth for neighbors
167            for neighbor in self.graph.neighbors(node_idx) {
168                depth_map.entry(neighbor).or_insert(depth + 1);
169            }
170        }
171
172        Ok(result)
173    }
174
175    fn dfs(&self, start_id: &str, max_depth: usize) -> Result<Vec<GraphNode>, CodememError> {
176        let start_idx = self
177            .id_to_index
178            .get(start_id)
179            .ok_or_else(|| CodememError::NotFound(format!("Node {start_id}")))?;
180
181        let mut visited = HashSet::new();
182        let mut result = Vec::new();
183        let mut stack: Vec<(NodeIndex, usize)> = vec![(*start_idx, 0)];
184
185        while let Some((node_idx, depth)) = stack.pop() {
186            if depth > max_depth || !visited.insert(node_idx) {
187                continue;
188            }
189
190            if let Some(node_id) = self.graph.node_weight(node_idx) {
191                if let Some(node) = self.nodes.get(node_id) {
192                    result.push(node.clone());
193                }
194            }
195
196            for neighbor in self.graph.neighbors(node_idx) {
197                if !visited.contains(&neighbor) {
198                    stack.push((neighbor, depth + 1));
199                }
200            }
201        }
202
203        Ok(result)
204    }
205
206    fn bfs_filtered(
207        &self,
208        start_id: &str,
209        max_depth: usize,
210        exclude_kinds: &[NodeKind],
211        include_relationships: Option<&[RelationshipType]>,
212    ) -> Result<Vec<GraphNode>, CodememError> {
213        let start_idx = self
214            .id_to_index
215            .get(start_id)
216            .ok_or_else(|| CodememError::NotFound(format!("Node {start_id}")))?;
217
218        let mut visited = HashSet::new();
219        let mut result = Vec::new();
220        let mut queue: VecDeque<(NodeIndex, usize)> = VecDeque::new();
221        queue.push_back((*start_idx, 0));
222        visited.insert(*start_idx);
223
224        while let Some((node_idx, depth)) = queue.pop_front() {
225            // Add current node to results if not excluded
226            if let Some(node_id) = self.graph.node_weight(node_idx) {
227                if let Some(node) = self.nodes.get(node_id) {
228                    if !exclude_kinds.contains(&node.kind) {
229                        result.push(node.clone());
230                    }
231                }
232            }
233
234            if depth >= max_depth {
235                continue;
236            }
237
238            // Explore outgoing edges with filtering
239            for neighbor_idx in self.graph.neighbors_directed(node_idx, Direction::Outgoing) {
240                if visited.contains(&neighbor_idx) {
241                    continue;
242                }
243
244                // Check relationship filter if set, using edge adjacency index
245                if let Some(allowed_rels) = include_relationships {
246                    let src_id = self
247                        .graph
248                        .node_weight(node_idx)
249                        .cloned()
250                        .unwrap_or_default();
251                    let dst_id = self
252                        .graph
253                        .node_weight(neighbor_idx)
254                        .cloned()
255                        .unwrap_or_default();
256                    let edge_matches = self
257                        .edge_adj
258                        .get(&src_id)
259                        .map(|edge_ids| {
260                            edge_ids.iter().any(|eid| {
261                                self.edges.get(eid).is_some_and(|e| {
262                                    e.src == src_id
263                                        && e.dst == dst_id
264                                        && allowed_rels.contains(&e.relationship)
265                                })
266                            })
267                        })
268                        .unwrap_or(false);
269                    if !edge_matches {
270                        continue;
271                    }
272                }
273
274                // Always traverse through excluded-kind nodes but don't include
275                // them in results (handled above when popped from queue).
276                visited.insert(neighbor_idx);
277                queue.push_back((neighbor_idx, depth + 1));
278            }
279        }
280
281        Ok(result)
282    }
283
284    fn dfs_filtered(
285        &self,
286        start_id: &str,
287        max_depth: usize,
288        exclude_kinds: &[NodeKind],
289        include_relationships: Option<&[RelationshipType]>,
290    ) -> Result<Vec<GraphNode>, CodememError> {
291        let start_idx = self
292            .id_to_index
293            .get(start_id)
294            .ok_or_else(|| CodememError::NotFound(format!("Node {start_id}")))?;
295
296        let mut visited = HashSet::new();
297        let mut result = Vec::new();
298        let mut stack: Vec<(NodeIndex, usize)> = vec![(*start_idx, 0)];
299
300        while let Some((node_idx, depth)) = stack.pop() {
301            if !visited.insert(node_idx) {
302                continue;
303            }
304
305            // Add current node to results if not excluded
306            if let Some(node_id) = self.graph.node_weight(node_idx) {
307                if let Some(node) = self.nodes.get(node_id) {
308                    if !exclude_kinds.contains(&node.kind) {
309                        result.push(node.clone());
310                    }
311                }
312            }
313
314            if depth >= max_depth {
315                continue;
316            }
317
318            // Explore outgoing edges with filtering
319            for neighbor_idx in self.graph.neighbors_directed(node_idx, Direction::Outgoing) {
320                if visited.contains(&neighbor_idx) {
321                    continue;
322                }
323
324                // Check relationship filter if set, using edge adjacency index
325                if let Some(allowed_rels) = include_relationships {
326                    let src_id = self
327                        .graph
328                        .node_weight(node_idx)
329                        .cloned()
330                        .unwrap_or_default();
331                    let dst_id = self
332                        .graph
333                        .node_weight(neighbor_idx)
334                        .cloned()
335                        .unwrap_or_default();
336                    let edge_matches = self
337                        .edge_adj
338                        .get(&src_id)
339                        .map(|edge_ids| {
340                            edge_ids.iter().any(|eid| {
341                                self.edges.get(eid).is_some_and(|e| {
342                                    e.src == src_id
343                                        && e.dst == dst_id
344                                        && allowed_rels.contains(&e.relationship)
345                                })
346                            })
347                        })
348                        .unwrap_or(false);
349                    if !edge_matches {
350                        continue;
351                    }
352                }
353
354                // Always traverse through excluded-kind nodes but don't include
355                // them in results (handled above when popped from stack).
356                stack.push((neighbor_idx, depth + 1));
357            }
358        }
359
360        Ok(result)
361    }
362
363    fn shortest_path(&self, from: &str, to: &str) -> Result<Vec<String>, CodememError> {
364        let from_idx = self
365            .id_to_index
366            .get(from)
367            .ok_or_else(|| CodememError::NotFound(format!("Node {from}")))?;
368        let to_idx = self
369            .id_to_index
370            .get(to)
371            .ok_or_else(|| CodememError::NotFound(format!("Node {to}")))?;
372
373        // BFS shortest path (unweighted)
374        use petgraph::algo::astar;
375        let path = astar(
376            &self.graph,
377            *from_idx,
378            |finish| finish == *to_idx,
379            |_| 1.0f64,
380            |_| 0.0f64,
381        );
382
383        match path {
384            Some((_cost, nodes)) => {
385                let ids: Vec<String> = nodes
386                    .iter()
387                    .filter_map(|idx| self.graph.node_weight(*idx).cloned())
388                    .collect();
389                Ok(ids)
390            }
391            None => Ok(vec![]),
392        }
393    }
394
395    // Note: O(n+e) per call. Could be cached if this becomes a hot path.
396    fn stats(&self) -> GraphStats {
397        let mut node_kind_counts = HashMap::new();
398        for node in self.nodes.values() {
399            *node_kind_counts.entry(node.kind.to_string()).or_insert(0) += 1;
400        }
401
402        let mut relationship_type_counts = HashMap::new();
403        for edge in self.edges.values() {
404            *relationship_type_counts
405                .entry(edge.relationship.to_string())
406                .or_insert(0) += 1;
407        }
408
409        GraphStats {
410            node_count: self.nodes.len(),
411            edge_count: self.edges.len(),
412            node_kind_counts,
413            relationship_type_counts,
414        }
415    }
416}
417
// Unit tests are kept in a separate file. For a non-inline module declared in a
// non-mod.rs file, the `#[path]` is resolved relative to this file's directory.
#[cfg(test)]
#[path = "../tests/graph_traversal_tests.rs"]
mod tests;