1use crate::GraphEngine;
2use codemem_core::{CodememError, Edge, GraphBackend, GraphNode, GraphStats};
3use petgraph::graph::NodeIndex;
4use petgraph::visit::Bfs;
5use std::collections::{HashMap, HashSet};
6
7impl Default for GraphEngine {
8 fn default() -> Self {
9 Self::new()
10 }
11}
12
13impl GraphBackend for GraphEngine {
14 fn add_node(&mut self, node: GraphNode) -> Result<(), CodememError> {
15 let id = node.id.clone();
16
17 if !self.id_to_index.contains_key(&id) {
18 let idx = self.graph.add_node(id.clone());
19 self.id_to_index.insert(id.clone(), idx);
20 }
21
22 self.nodes.insert(id, node);
23 Ok(())
24 }
25
26 fn get_node(&self, id: &str) -> Result<Option<GraphNode>, CodememError> {
27 Ok(self.nodes.get(id).cloned())
28 }
29
30 fn remove_node(&mut self, id: &str) -> Result<bool, CodememError> {
31 if let Some(idx) = self.id_to_index.remove(id) {
32 self.graph.remove_node(idx);
33 self.nodes.remove(id);
34
35 let edge_ids: Vec<String> = self
37 .edges
38 .iter()
39 .filter(|(_, e)| e.src == id || e.dst == id)
40 .map(|(eid, _)| eid.clone())
41 .collect();
42 for eid in edge_ids {
43 self.edges.remove(&eid);
44 }
45
46 Ok(true)
47 } else {
48 Ok(false)
49 }
50 }
51
52 fn add_edge(&mut self, edge: Edge) -> Result<(), CodememError> {
53 let src_idx = self
54 .id_to_index
55 .get(&edge.src)
56 .ok_or_else(|| CodememError::NotFound(format!("Source node {}", edge.src)))?;
57 let dst_idx = self
58 .id_to_index
59 .get(&edge.dst)
60 .ok_or_else(|| CodememError::NotFound(format!("Destination node {}", edge.dst)))?;
61
62 self.graph.add_edge(*src_idx, *dst_idx, edge.weight);
63 self.edges.insert(edge.id.clone(), edge);
64 Ok(())
65 }
66
67 fn get_edges(&self, node_id: &str) -> Result<Vec<Edge>, CodememError> {
68 let edges: Vec<Edge> = self
69 .edges
70 .values()
71 .filter(|e| e.src == node_id || e.dst == node_id)
72 .cloned()
73 .collect();
74 Ok(edges)
75 }
76
77 fn remove_edge(&mut self, id: &str) -> Result<bool, CodememError> {
78 if let Some(edge) = self.edges.remove(id) {
79 if let (Some(&src_idx), Some(&dst_idx)) = (
81 self.id_to_index.get(&edge.src),
82 self.id_to_index.get(&edge.dst),
83 ) {
84 if let Some(edge_idx) = self.graph.find_edge(src_idx, dst_idx) {
85 self.graph.remove_edge(edge_idx);
86 }
87 }
88 Ok(true)
89 } else {
90 Ok(false)
91 }
92 }
93
94 fn bfs(&self, start_id: &str, max_depth: usize) -> Result<Vec<GraphNode>, CodememError> {
95 let start_idx = self
96 .id_to_index
97 .get(start_id)
98 .ok_or_else(|| CodememError::NotFound(format!("Node {start_id}")))?;
99
100 let mut visited = HashSet::new();
101 let mut result = Vec::new();
102 let mut bfs = Bfs::new(&self.graph, *start_idx);
103 let mut depth_map: HashMap<NodeIndex, usize> = HashMap::new();
104 depth_map.insert(*start_idx, 0);
105
106 while let Some(node_idx) = bfs.next(&self.graph) {
107 let depth = depth_map.get(&node_idx).copied().unwrap_or(0);
108 if depth > max_depth {
109 continue;
110 }
111
112 if visited.insert(node_idx) {
113 if let Some(node_id) = self.graph.node_weight(node_idx) {
114 if let Some(node) = self.nodes.get(node_id) {
115 result.push(node.clone());
116 }
117 }
118 }
119
120 for neighbor in self.graph.neighbors(node_idx) {
122 depth_map.entry(neighbor).or_insert(depth + 1);
123 }
124 }
125
126 Ok(result)
127 }
128
129 fn dfs(&self, start_id: &str, max_depth: usize) -> Result<Vec<GraphNode>, CodememError> {
130 let start_idx = self
131 .id_to_index
132 .get(start_id)
133 .ok_or_else(|| CodememError::NotFound(format!("Node {start_id}")))?;
134
135 let mut visited = HashSet::new();
136 let mut result = Vec::new();
137 let mut stack: Vec<(NodeIndex, usize)> = vec![(*start_idx, 0)];
138
139 while let Some((node_idx, depth)) = stack.pop() {
140 if depth > max_depth || !visited.insert(node_idx) {
141 continue;
142 }
143
144 if let Some(node_id) = self.graph.node_weight(node_idx) {
145 if let Some(node) = self.nodes.get(node_id) {
146 result.push(node.clone());
147 }
148 }
149
150 for neighbor in self.graph.neighbors(node_idx) {
151 if !visited.contains(&neighbor) {
152 stack.push((neighbor, depth + 1));
153 }
154 }
155 }
156
157 Ok(result)
158 }
159
160 fn shortest_path(&self, from: &str, to: &str) -> Result<Vec<String>, CodememError> {
161 let from_idx = self
162 .id_to_index
163 .get(from)
164 .ok_or_else(|| CodememError::NotFound(format!("Node {from}")))?;
165 let to_idx = self
166 .id_to_index
167 .get(to)
168 .ok_or_else(|| CodememError::NotFound(format!("Node {to}")))?;
169
170 use petgraph::algo::astar;
172 let path = astar(
173 &self.graph,
174 *from_idx,
175 |finish| finish == *to_idx,
176 |_| 1.0f64,
177 |_| 0.0f64,
178 );
179
180 match path {
181 Some((_cost, nodes)) => {
182 let ids: Vec<String> = nodes
183 .iter()
184 .filter_map(|idx| self.graph.node_weight(*idx).cloned())
185 .collect();
186 Ok(ids)
187 }
188 None => Ok(vec![]),
189 }
190 }
191
192 fn stats(&self) -> GraphStats {
193 let mut node_kind_counts = HashMap::new();
194 for node in self.nodes.values() {
195 *node_kind_counts.entry(node.kind.to_string()).or_insert(0) += 1;
196 }
197
198 let mut relationship_type_counts = HashMap::new();
199 for edge in self.edges.values() {
200 *relationship_type_counts
201 .entry(edge.relationship.to_string())
202 .or_insert(0) += 1;
203 }
204
205 GraphStats {
206 node_count: self.nodes.len(),
207 edge_count: self.edges.len(),
208 node_kind_counts,
209 relationship_type_counts,
210 }
211 }
212}
213
#[cfg(test)]
mod tests {
    use crate::GraphEngine;
    use codemem_core::{Edge, GraphBackend, GraphNode, NodeKind, RelationshipType};
    use std::collections::HashMap;

    /// Builds a `File` node with the given id and label and an empty payload.
    fn file_node(id: &str, label: &str) -> GraphNode {
        GraphNode {
            id: id.to_string(),
            kind: NodeKind::File,
            label: label.to_string(),
            payload: HashMap::new(),
            centrality: 0.0,
            memory_id: None,
            namespace: None,
        }
    }

    /// Builds a `Contains` edge with id `"{src}->{dst}"` and weight 1.0.
    fn test_edge(src: &str, dst: &str) -> Edge {
        Edge {
            id: format!("{src}->{dst}"),
            src: src.to_string(),
            dst: dst.to_string(),
            relationship: RelationshipType::Contains,
            weight: 1.0,
            properties: HashMap::new(),
            created_at: chrono::Utc::now(),
        }
    }

    /// Builds an engine holding the chain `a -> b -> c`.
    fn chain_abc() -> GraphEngine {
        let mut graph = GraphEngine::new();
        graph.add_node(file_node("a", "a.rs")).unwrap();
        graph.add_node(file_node("b", "b.rs")).unwrap();
        graph.add_node(file_node("c", "c.rs")).unwrap();
        graph.add_edge(test_edge("a", "b")).unwrap();
        graph.add_edge(test_edge("b", "c")).unwrap();
        graph
    }

    #[test]
    fn add_nodes_and_edges() {
        let mut graph = GraphEngine::new();
        graph.add_node(file_node("a", "a.rs")).unwrap();
        graph.add_node(file_node("b", "b.rs")).unwrap();
        graph.add_edge(test_edge("a", "b")).unwrap();

        assert_eq!(graph.node_count(), 2);
        assert_eq!(graph.edge_count(), 1);
    }

    #[test]
    fn add_node_replaces_existing_data() {
        let mut graph = GraphEngine::new();
        graph.add_node(file_node("a", "a.rs")).unwrap();
        graph.add_node(file_node("a", "renamed.rs")).unwrap();

        assert_eq!(graph.node_count(), 1);
        assert_eq!(graph.get_node("a").unwrap().unwrap().label, "renamed.rs");
    }

    #[test]
    fn add_edge_requires_both_endpoints() {
        let mut graph = GraphEngine::new();
        graph.add_node(file_node("a", "a.rs")).unwrap();

        assert!(graph.add_edge(test_edge("a", "missing")).is_err());
        assert!(graph.add_edge(test_edge("missing", "a")).is_err());
    }

    #[test]
    fn bfs_traversal() {
        let graph = chain_abc();
        let nodes = graph.bfs("a", 1).unwrap();
        assert_eq!(nodes.len(), 2);
    }

    #[test]
    fn dfs_respects_max_depth() {
        let graph = chain_abc();
        assert_eq!(graph.dfs("a", 0).unwrap().len(), 1);
        assert_eq!(graph.dfs("a", 1).unwrap().len(), 2);
        assert_eq!(graph.dfs("a", 5).unwrap().len(), 3);
    }

    #[test]
    fn shortest_path() {
        let graph = chain_abc();
        let path = graph.shortest_path("a", "c").unwrap();
        assert_eq!(path, vec!["a", "b", "c"]);
    }

    #[test]
    fn shortest_path_without_route_is_empty() {
        let mut graph = chain_abc();
        graph.add_node(file_node("d", "d.rs")).unwrap();
        assert!(graph.shortest_path("a", "d").unwrap().is_empty());
    }

    #[test]
    fn remove_node_drops_incident_edges() {
        let mut graph = chain_abc();

        assert!(graph.remove_node("b").unwrap());
        assert!(graph.get_node("b").unwrap().is_none());
        assert_eq!(graph.node_count(), 2);
        assert!(graph.get_edges("a").unwrap().is_empty());
        assert!(graph.get_edges("c").unwrap().is_empty());
        assert!(!graph.remove_node("b").unwrap());
    }

    #[test]
    fn remove_edge_reports_presence() {
        let mut graph = chain_abc();

        assert!(graph.remove_edge("a->b").unwrap());
        assert!(!graph.remove_edge("a->b").unwrap());
        assert_eq!(graph.get_edges("a").unwrap().len(), 0);
        assert_eq!(graph.get_edges("b").unwrap().len(), 1);
    }

    #[test]
    fn stats_counts_nodes_and_edges() {
        let graph = chain_abc();
        let stats = graph.stats();
        assert_eq!(stats.node_count, 3);
        assert_eq!(stats.edge_count, 2);
    }
}