// codemem_graph/traversal.rs

1use crate::GraphEngine;
2use codemem_core::{CodememError, Edge, GraphBackend, GraphNode, GraphStats};
3use petgraph::graph::NodeIndex;
4use petgraph::visit::Bfs;
5use std::collections::{HashMap, HashSet};
6
7impl Default for GraphEngine {
8    fn default() -> Self {
9        Self::new()
10    }
11}
12
13impl GraphBackend for GraphEngine {
14    fn add_node(&mut self, node: GraphNode) -> Result<(), CodememError> {
15        let id = node.id.clone();
16
17        if !self.id_to_index.contains_key(&id) {
18            let idx = self.graph.add_node(id.clone());
19            self.id_to_index.insert(id.clone(), idx);
20        }
21
22        self.nodes.insert(id, node);
23        Ok(())
24    }
25
26    fn get_node(&self, id: &str) -> Result<Option<GraphNode>, CodememError> {
27        Ok(self.nodes.get(id).cloned())
28    }
29
30    fn remove_node(&mut self, id: &str) -> Result<bool, CodememError> {
31        if let Some(idx) = self.id_to_index.remove(id) {
32            self.graph.remove_node(idx);
33            self.nodes.remove(id);
34
35            // Remove associated edges
36            let edge_ids: Vec<String> = self
37                .edges
38                .iter()
39                .filter(|(_, e)| e.src == id || e.dst == id)
40                .map(|(eid, _)| eid.clone())
41                .collect();
42            for eid in edge_ids {
43                self.edges.remove(&eid);
44            }
45
46            Ok(true)
47        } else {
48            Ok(false)
49        }
50    }
51
52    fn add_edge(&mut self, edge: Edge) -> Result<(), CodememError> {
53        let src_idx = self
54            .id_to_index
55            .get(&edge.src)
56            .ok_or_else(|| CodememError::NotFound(format!("Source node {}", edge.src)))?;
57        let dst_idx = self
58            .id_to_index
59            .get(&edge.dst)
60            .ok_or_else(|| CodememError::NotFound(format!("Destination node {}", edge.dst)))?;
61
62        self.graph.add_edge(*src_idx, *dst_idx, edge.weight);
63        self.edges.insert(edge.id.clone(), edge);
64        Ok(())
65    }
66
67    fn get_edges(&self, node_id: &str) -> Result<Vec<Edge>, CodememError> {
68        let edges: Vec<Edge> = self
69            .edges
70            .values()
71            .filter(|e| e.src == node_id || e.dst == node_id)
72            .cloned()
73            .collect();
74        Ok(edges)
75    }
76
77    fn remove_edge(&mut self, id: &str) -> Result<bool, CodememError> {
78        if let Some(edge) = self.edges.remove(id) {
79            // Also remove from petgraph
80            if let (Some(&src_idx), Some(&dst_idx)) = (
81                self.id_to_index.get(&edge.src),
82                self.id_to_index.get(&edge.dst),
83            ) {
84                if let Some(edge_idx) = self.graph.find_edge(src_idx, dst_idx) {
85                    self.graph.remove_edge(edge_idx);
86                }
87            }
88            Ok(true)
89        } else {
90            Ok(false)
91        }
92    }
93
94    fn bfs(&self, start_id: &str, max_depth: usize) -> Result<Vec<GraphNode>, CodememError> {
95        let start_idx = self
96            .id_to_index
97            .get(start_id)
98            .ok_or_else(|| CodememError::NotFound(format!("Node {start_id}")))?;
99
100        let mut visited = HashSet::new();
101        let mut result = Vec::new();
102        let mut bfs = Bfs::new(&self.graph, *start_idx);
103        let mut depth_map: HashMap<NodeIndex, usize> = HashMap::new();
104        depth_map.insert(*start_idx, 0);
105
106        while let Some(node_idx) = bfs.next(&self.graph) {
107            let depth = depth_map.get(&node_idx).copied().unwrap_or(0);
108            if depth > max_depth {
109                continue;
110            }
111
112            if visited.insert(node_idx) {
113                if let Some(node_id) = self.graph.node_weight(node_idx) {
114                    if let Some(node) = self.nodes.get(node_id) {
115                        result.push(node.clone());
116                    }
117                }
118            }
119
120            // Set depth for neighbors
121            for neighbor in self.graph.neighbors(node_idx) {
122                depth_map.entry(neighbor).or_insert(depth + 1);
123            }
124        }
125
126        Ok(result)
127    }
128
129    fn dfs(&self, start_id: &str, max_depth: usize) -> Result<Vec<GraphNode>, CodememError> {
130        let start_idx = self
131            .id_to_index
132            .get(start_id)
133            .ok_or_else(|| CodememError::NotFound(format!("Node {start_id}")))?;
134
135        let mut visited = HashSet::new();
136        let mut result = Vec::new();
137        let mut stack: Vec<(NodeIndex, usize)> = vec![(*start_idx, 0)];
138
139        while let Some((node_idx, depth)) = stack.pop() {
140            if depth > max_depth || !visited.insert(node_idx) {
141                continue;
142            }
143
144            if let Some(node_id) = self.graph.node_weight(node_idx) {
145                if let Some(node) = self.nodes.get(node_id) {
146                    result.push(node.clone());
147                }
148            }
149
150            for neighbor in self.graph.neighbors(node_idx) {
151                if !visited.contains(&neighbor) {
152                    stack.push((neighbor, depth + 1));
153                }
154            }
155        }
156
157        Ok(result)
158    }
159
160    fn shortest_path(&self, from: &str, to: &str) -> Result<Vec<String>, CodememError> {
161        let from_idx = self
162            .id_to_index
163            .get(from)
164            .ok_or_else(|| CodememError::NotFound(format!("Node {from}")))?;
165        let to_idx = self
166            .id_to_index
167            .get(to)
168            .ok_or_else(|| CodememError::NotFound(format!("Node {to}")))?;
169
170        // BFS shortest path (unweighted)
171        use petgraph::algo::astar;
172        let path = astar(
173            &self.graph,
174            *from_idx,
175            |finish| finish == *to_idx,
176            |_| 1.0f64,
177            |_| 0.0f64,
178        );
179
180        match path {
181            Some((_cost, nodes)) => {
182                let ids: Vec<String> = nodes
183                    .iter()
184                    .filter_map(|idx| self.graph.node_weight(*idx).cloned())
185                    .collect();
186                Ok(ids)
187            }
188            None => Ok(vec![]),
189        }
190    }
191
192    fn stats(&self) -> GraphStats {
193        let mut node_kind_counts = HashMap::new();
194        for node in self.nodes.values() {
195            *node_kind_counts.entry(node.kind.to_string()).or_insert(0) += 1;
196        }
197
198        let mut relationship_type_counts = HashMap::new();
199        for edge in self.edges.values() {
200            *relationship_type_counts
201                .entry(edge.relationship.to_string())
202                .or_insert(0) += 1;
203        }
204
205        GraphStats {
206            node_count: self.nodes.len(),
207            edge_count: self.edges.len(),
208            node_kind_counts,
209            relationship_type_counts,
210        }
211    }
212}
213
#[cfg(test)]
mod tests {
    use crate::GraphEngine;
    use codemem_core::{Edge, GraphBackend, GraphNode, NodeKind, RelationshipType};
    use std::collections::HashMap;

    /// Builds a `NodeKind::File` node with the given id and label; all other
    /// fields are empty or zero.
    fn file_node(id: &str, label: &str) -> GraphNode {
        GraphNode {
            id: id.to_owned(),
            kind: NodeKind::File,
            label: label.to_owned(),
            payload: HashMap::new(),
            centrality: 0.0,
            memory_id: None,
            namespace: None,
        }
    }

    /// Builds a `Contains` edge of weight 1.0 from `src` to `dst`, with an id
    /// derived from the two endpoints.
    fn test_edge(src: &str, dst: &str) -> Edge {
        Edge {
            id: format!("{src}->{dst}"),
            src: src.to_owned(),
            dst: dst.to_owned(),
            relationship: RelationshipType::Contains,
            weight: 1.0,
            properties: HashMap::new(),
            created_at: chrono::Utc::now(),
        }
    }

    /// Engine holding the three-node chain `a -> b -> c`.
    fn abc_chain() -> GraphEngine {
        let mut engine = GraphEngine::new();
        for (id, label) in [("a", "a.rs"), ("b", "b.rs"), ("c", "c.rs")] {
            engine.add_node(file_node(id, label)).unwrap();
        }
        for (src, dst) in [("a", "b"), ("b", "c")] {
            engine.add_edge(test_edge(src, dst)).unwrap();
        }
        engine
    }

    #[test]
    fn add_nodes_and_edges() {
        let mut engine = GraphEngine::new();
        engine.add_node(file_node("a", "a.rs")).unwrap();
        engine.add_node(file_node("b", "b.rs")).unwrap();
        engine.add_edge(test_edge("a", "b")).unwrap();

        assert_eq!(engine.node_count(), 2);
        assert_eq!(engine.edge_count(), 1);
    }

    #[test]
    fn bfs_traversal() {
        let engine = abc_chain();

        // Depth limit 1 reaches a and b; c sits two hops away.
        let reached = engine.bfs("a", 1).unwrap();
        assert_eq!(reached.len(), 2);
    }

    #[test]
    fn shortest_path() {
        let engine = abc_chain();

        let route = engine.shortest_path("a", "c").unwrap();
        assert_eq!(route, vec!["a", "b", "c"]);
    }
}