Skip to main content

codemem_graph/
traversal.rs

1use crate::GraphEngine;
2use codemem_core::{
3    CodememError, Edge, GraphBackend, GraphNode, GraphStats, NodeKind, RelationshipType,
4};
5use petgraph::graph::NodeIndex;
6use petgraph::visit::Bfs;
7use petgraph::Direction;
8use std::collections::{HashMap, HashSet, VecDeque};
9
10impl Default for GraphEngine {
11    fn default() -> Self {
12        Self::new()
13    }
14}
15
16impl GraphBackend for GraphEngine {
17    fn add_node(&mut self, node: GraphNode) -> Result<(), CodememError> {
18        let id = node.id.clone();
19
20        if !self.id_to_index.contains_key(&id) {
21            let idx = self.graph.add_node(id.clone());
22            self.id_to_index.insert(id.clone(), idx);
23        }
24
25        self.nodes.insert(id, node);
26        Ok(())
27    }
28
29    fn get_node(&self, id: &str) -> Result<Option<GraphNode>, CodememError> {
30        Ok(self.nodes.get(id).cloned())
31    }
32
33    fn remove_node(&mut self, id: &str) -> Result<bool, CodememError> {
34        if let Some(idx) = self.id_to_index.remove(id) {
35            self.graph.remove_node(idx);
36            self.nodes.remove(id);
37
38            // Remove associated edges
39            let edge_ids: Vec<String> = self
40                .edges
41                .iter()
42                .filter(|(_, e)| e.src == id || e.dst == id)
43                .map(|(eid, _)| eid.clone())
44                .collect();
45            for eid in edge_ids {
46                self.edges.remove(&eid);
47            }
48
49            Ok(true)
50        } else {
51            Ok(false)
52        }
53    }
54
55    fn add_edge(&mut self, edge: Edge) -> Result<(), CodememError> {
56        let src_idx = self
57            .id_to_index
58            .get(&edge.src)
59            .ok_or_else(|| CodememError::NotFound(format!("Source node {}", edge.src)))?;
60        let dst_idx = self
61            .id_to_index
62            .get(&edge.dst)
63            .ok_or_else(|| CodememError::NotFound(format!("Destination node {}", edge.dst)))?;
64
65        self.graph.add_edge(*src_idx, *dst_idx, edge.weight);
66        self.edges.insert(edge.id.clone(), edge);
67        Ok(())
68    }
69
70    fn get_edges(&self, node_id: &str) -> Result<Vec<Edge>, CodememError> {
71        let edges: Vec<Edge> = self
72            .edges
73            .values()
74            .filter(|e| e.src == node_id || e.dst == node_id)
75            .cloned()
76            .collect();
77        Ok(edges)
78    }
79
80    fn remove_edge(&mut self, id: &str) -> Result<bool, CodememError> {
81        if let Some(edge) = self.edges.remove(id) {
82            // Also remove from petgraph
83            if let (Some(&src_idx), Some(&dst_idx)) = (
84                self.id_to_index.get(&edge.src),
85                self.id_to_index.get(&edge.dst),
86            ) {
87                if let Some(edge_idx) = self.graph.find_edge(src_idx, dst_idx) {
88                    self.graph.remove_edge(edge_idx);
89                }
90            }
91            Ok(true)
92        } else {
93            Ok(false)
94        }
95    }
96
97    fn bfs(&self, start_id: &str, max_depth: usize) -> Result<Vec<GraphNode>, CodememError> {
98        let start_idx = self
99            .id_to_index
100            .get(start_id)
101            .ok_or_else(|| CodememError::NotFound(format!("Node {start_id}")))?;
102
103        let mut visited = HashSet::new();
104        let mut result = Vec::new();
105        let mut bfs = Bfs::new(&self.graph, *start_idx);
106        let mut depth_map: HashMap<NodeIndex, usize> = HashMap::new();
107        depth_map.insert(*start_idx, 0);
108
109        while let Some(node_idx) = bfs.next(&self.graph) {
110            let depth = depth_map.get(&node_idx).copied().unwrap_or(0);
111            if depth > max_depth {
112                continue;
113            }
114
115            if visited.insert(node_idx) {
116                if let Some(node_id) = self.graph.node_weight(node_idx) {
117                    if let Some(node) = self.nodes.get(node_id) {
118                        result.push(node.clone());
119                    }
120                }
121            }
122
123            // Set depth for neighbors
124            for neighbor in self.graph.neighbors(node_idx) {
125                depth_map.entry(neighbor).or_insert(depth + 1);
126            }
127        }
128
129        Ok(result)
130    }
131
132    fn dfs(&self, start_id: &str, max_depth: usize) -> Result<Vec<GraphNode>, CodememError> {
133        let start_idx = self
134            .id_to_index
135            .get(start_id)
136            .ok_or_else(|| CodememError::NotFound(format!("Node {start_id}")))?;
137
138        let mut visited = HashSet::new();
139        let mut result = Vec::new();
140        let mut stack: Vec<(NodeIndex, usize)> = vec![(*start_idx, 0)];
141
142        while let Some((node_idx, depth)) = stack.pop() {
143            if depth > max_depth || !visited.insert(node_idx) {
144                continue;
145            }
146
147            if let Some(node_id) = self.graph.node_weight(node_idx) {
148                if let Some(node) = self.nodes.get(node_id) {
149                    result.push(node.clone());
150                }
151            }
152
153            for neighbor in self.graph.neighbors(node_idx) {
154                if !visited.contains(&neighbor) {
155                    stack.push((neighbor, depth + 1));
156                }
157            }
158        }
159
160        Ok(result)
161    }
162
163    fn bfs_filtered(
164        &self,
165        start_id: &str,
166        max_depth: usize,
167        exclude_kinds: &[NodeKind],
168        include_relationships: Option<&[RelationshipType]>,
169    ) -> Result<Vec<GraphNode>, CodememError> {
170        let start_idx = self
171            .id_to_index
172            .get(start_id)
173            .ok_or_else(|| CodememError::NotFound(format!("Node {start_id}")))?;
174
175        let mut visited = HashSet::new();
176        let mut result = Vec::new();
177        let mut queue: VecDeque<(NodeIndex, usize)> = VecDeque::new();
178        queue.push_back((*start_idx, 0));
179        visited.insert(*start_idx);
180
181        while let Some((node_idx, depth)) = queue.pop_front() {
182            // Add current node to results if not excluded
183            if let Some(node_id) = self.graph.node_weight(node_idx) {
184                if let Some(node) = self.nodes.get(node_id) {
185                    if !exclude_kinds.contains(&node.kind) {
186                        result.push(node.clone());
187                    }
188                }
189            }
190
191            if depth >= max_depth {
192                continue;
193            }
194
195            // Explore outgoing edges with filtering
196            for neighbor_idx in self.graph.neighbors_directed(node_idx, Direction::Outgoing) {
197                if visited.contains(&neighbor_idx) {
198                    continue;
199                }
200
201                // Check relationship filter if set
202                if let Some(allowed_rels) = include_relationships {
203                    let src_id = self
204                        .graph
205                        .node_weight(node_idx)
206                        .cloned()
207                        .unwrap_or_default();
208                    let dst_id = self
209                        .graph
210                        .node_weight(neighbor_idx)
211                        .cloned()
212                        .unwrap_or_default();
213                    let edge_matches = self.edges.values().any(|e| {
214                        e.src == src_id && e.dst == dst_id && allowed_rels.contains(&e.relationship)
215                    });
216                    if !edge_matches {
217                        continue;
218                    }
219                }
220
221                // Check if neighbor's kind is excluded
222                if let Some(neighbor_id) = self.graph.node_weight(neighbor_idx) {
223                    if let Some(neighbor_node) = self.nodes.get(neighbor_id) {
224                        if exclude_kinds.contains(&neighbor_node.kind) {
225                            continue;
226                        }
227                    }
228                }
229
230                visited.insert(neighbor_idx);
231                queue.push_back((neighbor_idx, depth + 1));
232            }
233        }
234
235        Ok(result)
236    }
237
238    fn dfs_filtered(
239        &self,
240        start_id: &str,
241        max_depth: usize,
242        exclude_kinds: &[NodeKind],
243        include_relationships: Option<&[RelationshipType]>,
244    ) -> Result<Vec<GraphNode>, CodememError> {
245        let start_idx = self
246            .id_to_index
247            .get(start_id)
248            .ok_or_else(|| CodememError::NotFound(format!("Node {start_id}")))?;
249
250        let mut visited = HashSet::new();
251        let mut result = Vec::new();
252        let mut stack: Vec<(NodeIndex, usize)> = vec![(*start_idx, 0)];
253
254        while let Some((node_idx, depth)) = stack.pop() {
255            if !visited.insert(node_idx) {
256                continue;
257            }
258
259            // Add current node to results if not excluded
260            if let Some(node_id) = self.graph.node_weight(node_idx) {
261                if let Some(node) = self.nodes.get(node_id) {
262                    if !exclude_kinds.contains(&node.kind) {
263                        result.push(node.clone());
264                    }
265                }
266            }
267
268            if depth >= max_depth {
269                continue;
270            }
271
272            // Explore outgoing edges with filtering
273            for neighbor_idx in self.graph.neighbors_directed(node_idx, Direction::Outgoing) {
274                if visited.contains(&neighbor_idx) {
275                    continue;
276                }
277
278                // Check relationship filter if set
279                if let Some(allowed_rels) = include_relationships {
280                    let src_id = self
281                        .graph
282                        .node_weight(node_idx)
283                        .cloned()
284                        .unwrap_or_default();
285                    let dst_id = self
286                        .graph
287                        .node_weight(neighbor_idx)
288                        .cloned()
289                        .unwrap_or_default();
290                    let edge_matches = self.edges.values().any(|e| {
291                        e.src == src_id && e.dst == dst_id && allowed_rels.contains(&e.relationship)
292                    });
293                    if !edge_matches {
294                        continue;
295                    }
296                }
297
298                // Check if neighbor's kind is excluded
299                if let Some(neighbor_id) = self.graph.node_weight(neighbor_idx) {
300                    if let Some(neighbor_node) = self.nodes.get(neighbor_id) {
301                        if exclude_kinds.contains(&neighbor_node.kind) {
302                            continue;
303                        }
304                    }
305                }
306
307                stack.push((neighbor_idx, depth + 1));
308            }
309        }
310
311        Ok(result)
312    }
313
314    fn shortest_path(&self, from: &str, to: &str) -> Result<Vec<String>, CodememError> {
315        let from_idx = self
316            .id_to_index
317            .get(from)
318            .ok_or_else(|| CodememError::NotFound(format!("Node {from}")))?;
319        let to_idx = self
320            .id_to_index
321            .get(to)
322            .ok_or_else(|| CodememError::NotFound(format!("Node {to}")))?;
323
324        // BFS shortest path (unweighted)
325        use petgraph::algo::astar;
326        let path = astar(
327            &self.graph,
328            *from_idx,
329            |finish| finish == *to_idx,
330            |_| 1.0f64,
331            |_| 0.0f64,
332        );
333
334        match path {
335            Some((_cost, nodes)) => {
336                let ids: Vec<String> = nodes
337                    .iter()
338                    .filter_map(|idx| self.graph.node_weight(*idx).cloned())
339                    .collect();
340                Ok(ids)
341            }
342            None => Ok(vec![]),
343        }
344    }
345
346    fn stats(&self) -> GraphStats {
347        let mut node_kind_counts = HashMap::new();
348        for node in self.nodes.values() {
349            *node_kind_counts.entry(node.kind.to_string()).or_insert(0) += 1;
350        }
351
352        let mut relationship_type_counts = HashMap::new();
353        for edge in self.edges.values() {
354            *relationship_type_counts
355                .entry(edge.relationship.to_string())
356                .or_insert(0) += 1;
357        }
358
359        GraphStats {
360            node_count: self.nodes.len(),
361            edge_count: self.edges.len(),
362            node_kind_counts,
363            relationship_type_counts,
364        }
365    }
366}
367
// Unit tests for this module: compiled only in test builds and loaded from
// `tests/traversal_tests.rs` via the `#[path]` attribute.
#[cfg(test)]
#[path = "tests/traversal_tests.rs"]
mod tests;