Skip to main content

codemem_storage/graph/
traversal.rs

1use super::GraphEngine;
2use codemem_core::{
3    CodememError, Edge, GraphBackend, GraphNode, GraphStats, NodeKind, RelationshipType,
4};
5use petgraph::graph::NodeIndex;
6use petgraph::visit::EdgeRef;
7use petgraph::Direction;
8use std::collections::{HashMap, HashSet, VecDeque};
9
10impl Default for GraphEngine {
11    fn default() -> Self {
12        Self::new()
13    }
14}
15
16impl GraphBackend for GraphEngine {
17    fn add_node(&mut self, node: GraphNode) -> Result<(), CodememError> {
18        let id = node.id.clone();
19
20        if !self.id_to_index.contains_key(&id) {
21            let idx = self.graph.add_node(id.clone());
22            self.id_to_index.insert(id.clone(), idx);
23        }
24
25        self.nodes.insert(id, node);
26        Ok(())
27    }
28
29    fn get_node(&self, id: &str) -> Result<Option<GraphNode>, CodememError> {
30        Ok(self.nodes.get(id).cloned())
31    }
32
33    fn remove_node(&mut self, id: &str) -> Result<bool, CodememError> {
34        if let Some(idx) = self.id_to_index.remove(id) {
35            // petgraph::DiGraph::remove_node swaps the last node into the removed
36            // slot, invalidating the last node's NodeIndex. We must fix id_to_index.
37            let last_idx = NodeIndex::new(self.graph.node_count() - 1);
38            self.graph.remove_node(idx);
39            // After removal, the node that was at `last_idx` is now at `idx`
40            // (unless we removed the last node itself).
41            if idx != last_idx {
42                if let Some(swapped_id) = self.graph.node_weight(idx) {
43                    self.id_to_index.insert(swapped_id.clone(), idx);
44                }
45            }
46            self.nodes.remove(id);
47
48            // Remove associated edges using edge adjacency index
49            if let Some(edge_ids) = self.edge_adj.remove(id) {
50                for eid in &edge_ids {
51                    if let Some(edge) = self.edges.remove(eid) {
52                        // Also remove from the other endpoint's adjacency list
53                        let other = if edge.src == id { &edge.dst } else { &edge.src };
54                        if let Some(other_edges) = self.edge_adj.get_mut(other) {
55                            other_edges.retain(|e| e != eid);
56                        }
57                    }
58                }
59            }
60
61            // Clean up centrality caches
62            self.cached_pagerank.remove(id);
63            self.cached_betweenness.remove(id);
64
65            Ok(true)
66        } else {
67            Ok(false)
68        }
69    }
70
71    fn add_edge(&mut self, edge: Edge) -> Result<(), CodememError> {
72        let src_idx = self
73            .id_to_index
74            .get(&edge.src)
75            .ok_or_else(|| CodememError::NotFound(format!("Source node {}", edge.src)))?;
76        let dst_idx = self
77            .id_to_index
78            .get(&edge.dst)
79            .ok_or_else(|| CodememError::NotFound(format!("Destination node {}", edge.dst)))?;
80
81        self.graph.add_edge(*src_idx, *dst_idx, edge.weight);
82        // Maintain edge adjacency index
83        self.edge_adj
84            .entry(edge.src.clone())
85            .or_default()
86            .push(edge.id.clone());
87        self.edge_adj
88            .entry(edge.dst.clone())
89            .or_default()
90            .push(edge.id.clone());
91        self.edges.insert(edge.id.clone(), edge);
92        Ok(())
93    }
94
95    fn get_edges(&self, node_id: &str) -> Result<Vec<Edge>, CodememError> {
96        let edges: Vec<Edge> = self
97            .edge_adj
98            .get(node_id)
99            .map(|edge_ids| {
100                edge_ids
101                    .iter()
102                    .filter_map(|eid| self.edges.get(eid).cloned())
103                    .collect()
104            })
105            .unwrap_or_default();
106        Ok(edges)
107    }
108
109    fn remove_edge(&mut self, id: &str) -> Result<bool, CodememError> {
110        if let Some(edge) = self.edges.remove(id) {
111            // Remove from petgraph — match by weight to handle parallel edges
112            if let (Some(&src_idx), Some(&dst_idx)) = (
113                self.id_to_index.get(&edge.src),
114                self.id_to_index.get(&edge.dst),
115            ) {
116                // Iterate edges_connecting to find the correct one by weight
117                let target_weight = edge.weight;
118                let petgraph_edge_idx = self
119                    .graph
120                    .edges_connecting(src_idx, dst_idx)
121                    .find(|e| (*e.weight() - target_weight).abs() < f64::EPSILON)
122                    .map(|e| e.id());
123                if let Some(eidx) = petgraph_edge_idx {
124                    self.graph.remove_edge(eidx);
125                }
126            }
127            // Maintain edge adjacency index
128            if let Some(src_edges) = self.edge_adj.get_mut(&edge.src) {
129                src_edges.retain(|e| e != id);
130            }
131            if let Some(dst_edges) = self.edge_adj.get_mut(&edge.dst) {
132                dst_edges.retain(|e| e != id);
133            }
134            Ok(true)
135        } else {
136            Ok(false)
137        }
138    }
139
140    fn bfs(&self, start_id: &str, max_depth: usize) -> Result<Vec<GraphNode>, CodememError> {
141        let start_idx = self
142            .id_to_index
143            .get(start_id)
144            .ok_or_else(|| CodememError::NotFound(format!("Node {start_id}")))?;
145
146        let mut visited = HashSet::new();
147        let mut result = Vec::new();
148        let mut queue: VecDeque<(NodeIndex, usize)> = VecDeque::new();
149        queue.push_back((*start_idx, 0));
150        visited.insert(*start_idx);
151
152        while let Some((node_idx, depth)) = queue.pop_front() {
153            if let Some(node_id) = self.graph.node_weight(node_idx) {
154                if let Some(node) = self.nodes.get(node_id) {
155                    result.push(node.clone());
156                }
157            }
158
159            if depth >= max_depth {
160                continue;
161            }
162
163            // Traverse edges in both directions so we find parents and children
164            for neighbor in self.graph.neighbors_undirected(node_idx) {
165                if visited.insert(neighbor) {
166                    queue.push_back((neighbor, depth + 1));
167                }
168            }
169        }
170
171        Ok(result)
172    }
173
174    fn dfs(&self, start_id: &str, max_depth: usize) -> Result<Vec<GraphNode>, CodememError> {
175        let start_idx = self
176            .id_to_index
177            .get(start_id)
178            .ok_or_else(|| CodememError::NotFound(format!("Node {start_id}")))?;
179
180        let mut visited = HashSet::new();
181        let mut result = Vec::new();
182        let mut stack: Vec<(NodeIndex, usize)> = vec![(*start_idx, 0)];
183
184        while let Some((node_idx, depth)) = stack.pop() {
185            if depth > max_depth || !visited.insert(node_idx) {
186                continue;
187            }
188
189            if let Some(node_id) = self.graph.node_weight(node_idx) {
190                if let Some(node) = self.nodes.get(node_id) {
191                    result.push(node.clone());
192                }
193            }
194
195            for neighbor in self.graph.neighbors_undirected(node_idx) {
196                if !visited.contains(&neighbor) {
197                    stack.push((neighbor, depth + 1));
198                }
199            }
200        }
201
202        Ok(result)
203    }
204
205    fn bfs_filtered(
206        &self,
207        start_id: &str,
208        max_depth: usize,
209        exclude_kinds: &[NodeKind],
210        include_relationships: Option<&[RelationshipType]>,
211    ) -> Result<Vec<GraphNode>, CodememError> {
212        let start_idx = self
213            .id_to_index
214            .get(start_id)
215            .ok_or_else(|| CodememError::NotFound(format!("Node {start_id}")))?;
216
217        let mut visited = HashSet::new();
218        let mut result = Vec::new();
219        let mut queue: VecDeque<(NodeIndex, usize)> = VecDeque::new();
220        queue.push_back((*start_idx, 0));
221        visited.insert(*start_idx);
222
223        while let Some((node_idx, depth)) = queue.pop_front() {
224            // Add current node to results if not excluded
225            if let Some(node_id) = self.graph.node_weight(node_idx) {
226                if let Some(node) = self.nodes.get(node_id) {
227                    if !exclude_kinds.contains(&node.kind) {
228                        result.push(node.clone());
229                    }
230                }
231            }
232
233            if depth >= max_depth {
234                continue;
235            }
236
237            // Explore outgoing edges with filtering
238            for neighbor_idx in self.graph.neighbors_directed(node_idx, Direction::Outgoing) {
239                if visited.contains(&neighbor_idx) {
240                    continue;
241                }
242
243                // Check relationship filter if set, using edge adjacency index
244                if let Some(allowed_rels) = include_relationships {
245                    let src_id = self
246                        .graph
247                        .node_weight(node_idx)
248                        .cloned()
249                        .unwrap_or_default();
250                    let dst_id = self
251                        .graph
252                        .node_weight(neighbor_idx)
253                        .cloned()
254                        .unwrap_or_default();
255                    let edge_matches = self
256                        .edge_adj
257                        .get(&src_id)
258                        .map(|edge_ids| {
259                            edge_ids.iter().any(|eid| {
260                                self.edges.get(eid).is_some_and(|e| {
261                                    e.src == src_id
262                                        && e.dst == dst_id
263                                        && allowed_rels.contains(&e.relationship)
264                                })
265                            })
266                        })
267                        .unwrap_or(false);
268                    if !edge_matches {
269                        continue;
270                    }
271                }
272
273                // Always traverse through excluded-kind nodes but don't include
274                // them in results (handled above when popped from queue).
275                visited.insert(neighbor_idx);
276                queue.push_back((neighbor_idx, depth + 1));
277            }
278        }
279
280        Ok(result)
281    }
282
283    fn dfs_filtered(
284        &self,
285        start_id: &str,
286        max_depth: usize,
287        exclude_kinds: &[NodeKind],
288        include_relationships: Option<&[RelationshipType]>,
289    ) -> Result<Vec<GraphNode>, CodememError> {
290        let start_idx = self
291            .id_to_index
292            .get(start_id)
293            .ok_or_else(|| CodememError::NotFound(format!("Node {start_id}")))?;
294
295        let mut visited = HashSet::new();
296        let mut result = Vec::new();
297        let mut stack: Vec<(NodeIndex, usize)> = vec![(*start_idx, 0)];
298
299        while let Some((node_idx, depth)) = stack.pop() {
300            if !visited.insert(node_idx) {
301                continue;
302            }
303
304            // Add current node to results if not excluded
305            if let Some(node_id) = self.graph.node_weight(node_idx) {
306                if let Some(node) = self.nodes.get(node_id) {
307                    if !exclude_kinds.contains(&node.kind) {
308                        result.push(node.clone());
309                    }
310                }
311            }
312
313            if depth >= max_depth {
314                continue;
315            }
316
317            // Explore outgoing edges with filtering
318            for neighbor_idx in self.graph.neighbors_directed(node_idx, Direction::Outgoing) {
319                if visited.contains(&neighbor_idx) {
320                    continue;
321                }
322
323                // Check relationship filter if set, using edge adjacency index
324                if let Some(allowed_rels) = include_relationships {
325                    let src_id = self
326                        .graph
327                        .node_weight(node_idx)
328                        .cloned()
329                        .unwrap_or_default();
330                    let dst_id = self
331                        .graph
332                        .node_weight(neighbor_idx)
333                        .cloned()
334                        .unwrap_or_default();
335                    let edge_matches = self
336                        .edge_adj
337                        .get(&src_id)
338                        .map(|edge_ids| {
339                            edge_ids.iter().any(|eid| {
340                                self.edges.get(eid).is_some_and(|e| {
341                                    e.src == src_id
342                                        && e.dst == dst_id
343                                        && allowed_rels.contains(&e.relationship)
344                                })
345                            })
346                        })
347                        .unwrap_or(false);
348                    if !edge_matches {
349                        continue;
350                    }
351                }
352
353                // Always traverse through excluded-kind nodes but don't include
354                // them in results (handled above when popped from stack).
355                stack.push((neighbor_idx, depth + 1));
356            }
357        }
358
359        Ok(result)
360    }
361
362    fn shortest_path(&self, from: &str, to: &str) -> Result<Vec<String>, CodememError> {
363        let from_idx = self
364            .id_to_index
365            .get(from)
366            .ok_or_else(|| CodememError::NotFound(format!("Node {from}")))?;
367        let to_idx = self
368            .id_to_index
369            .get(to)
370            .ok_or_else(|| CodememError::NotFound(format!("Node {to}")))?;
371
372        // BFS shortest path (unweighted)
373        use petgraph::algo::astar;
374        let path = astar(
375            &self.graph,
376            *from_idx,
377            |finish| finish == *to_idx,
378            |_| 1.0f64,
379            |_| 0.0f64,
380        );
381
382        match path {
383            Some((_cost, nodes)) => {
384                let ids: Vec<String> = nodes
385                    .iter()
386                    .filter_map(|idx| self.graph.node_weight(*idx).cloned())
387                    .collect();
388                Ok(ids)
389            }
390            None => Ok(vec![]),
391        }
392    }
393
394    // Note: O(n+e) per call. Could be cached if this becomes a hot path.
395    fn stats(&self) -> GraphStats {
396        let mut node_kind_counts = HashMap::new();
397        for node in self.nodes.values() {
398            *node_kind_counts.entry(node.kind.to_string()).or_insert(0) += 1;
399        }
400
401        let mut relationship_type_counts = HashMap::new();
402        for edge in self.edges.values() {
403            *relationship_type_counts
404                .entry(edge.relationship.to_string())
405                .or_insert(0) += 1;
406        }
407
408        GraphStats {
409            node_count: self.nodes.len(),
410            edge_count: self.edges.len(),
411            node_kind_counts,
412            relationship_type_counts,
413        }
414    }
415}
416
417#[cfg(test)]
418#[path = "../tests/graph_traversal_tests.rs"]
419mod tests;