1mod algorithms;
7mod traversal;
8
9#[cfg(test)]
10use codemem_core::NodeKind;
11use codemem_core::{CodememError, Edge, GraphBackend, GraphNode};
12use petgraph::graph::{DiGraph, NodeIndex};
13use petgraph::Direction;
14use std::collections::{HashMap, HashSet, VecDeque};
15
16pub struct GraphEngine {
18 pub(crate) graph: DiGraph<String, f64>,
19 pub(crate) id_to_index: HashMap<String, NodeIndex>,
21 pub(crate) nodes: HashMap<String, GraphNode>,
23 pub(crate) edges: HashMap<String, Edge>,
25 pub(crate) cached_pagerank: HashMap<String, f64>,
27 pub(crate) cached_betweenness: HashMap<String, f64>,
29}
30
31impl GraphEngine {
32 pub fn new() -> Self {
34 Self {
35 graph: DiGraph::new(),
36 id_to_index: HashMap::new(),
37 nodes: HashMap::new(),
38 edges: HashMap::new(),
39 cached_pagerank: HashMap::new(),
40 cached_betweenness: HashMap::new(),
41 }
42 }
43
44 pub fn from_storage(storage: &dyn codemem_core::StorageBackend) -> Result<Self, CodememError> {
46 let mut engine = Self::new();
47
48 let nodes = storage.all_graph_nodes()?;
50 for node in nodes {
51 engine.add_node(node)?;
52 }
53
54 let edges = storage.all_graph_edges()?;
56 for edge in edges {
57 engine.add_edge(edge)?;
58 }
59
60 Ok(engine)
61 }
62
63 pub fn node_count(&self) -> usize {
65 self.nodes.len()
66 }
67
68 pub fn edge_count(&self) -> usize {
70 self.edges.len()
71 }
72
73 pub fn expand(
75 &self,
76 start_ids: &[String],
77 max_hops: usize,
78 ) -> Result<Vec<GraphNode>, CodememError> {
79 let mut visited = std::collections::HashSet::new();
80 let mut result = Vec::new();
81
82 for start_id in start_ids {
83 let nodes = self.bfs(start_id, max_hops)?;
84 for node in nodes {
85 if visited.insert(node.id.clone()) {
86 result.push(node);
87 }
88 }
89 }
90
91 Ok(result)
92 }
93
94 pub fn neighbors(&self, node_id: &str) -> Result<Vec<GraphNode>, CodememError> {
96 let idx = self
97 .id_to_index
98 .get(node_id)
99 .ok_or_else(|| CodememError::NotFound(format!("Node {node_id}")))?;
100
101 let mut result = Vec::new();
102 for neighbor_idx in self.graph.neighbors(*idx) {
103 if let Some(neighbor_id) = self.graph.node_weight(neighbor_idx) {
104 if let Some(node) = self.nodes.get(neighbor_id) {
105 result.push(node.clone());
106 }
107 }
108 }
109
110 Ok(result)
111 }
112
113 pub fn connected_components(&self) -> Vec<Vec<String>> {
119 let mut visited: HashSet<NodeIndex> = HashSet::new();
120 let mut components: Vec<Vec<String>> = Vec::new();
121
122 for &start_idx in self.id_to_index.values() {
123 if visited.contains(&start_idx) {
124 continue;
125 }
126
127 let mut component: Vec<String> = Vec::new();
129 let mut queue: VecDeque<NodeIndex> = VecDeque::new();
130 queue.push_back(start_idx);
131 visited.insert(start_idx);
132
133 while let Some(current) = queue.pop_front() {
134 if let Some(node_id) = self.graph.node_weight(current) {
135 component.push(node_id.clone());
136 }
137
138 for neighbor in self.graph.neighbors_directed(current, Direction::Outgoing) {
140 if visited.insert(neighbor) {
141 queue.push_back(neighbor);
142 }
143 }
144
145 for neighbor in self.graph.neighbors_directed(current, Direction::Incoming) {
147 if visited.insert(neighbor) {
148 queue.push_back(neighbor);
149 }
150 }
151 }
152
153 component.sort();
154 components.push(component);
155 }
156
157 components.sort();
158 components
159 }
160
161 pub fn compute_centrality(&mut self) {
167 let n = self.nodes.len();
168 if n <= 1 {
169 for node in self.nodes.values_mut() {
170 node.centrality = 0.0;
171 }
172 return;
173 }
174
175 let denominator = (n - 1) as f64;
176
177 let centrality_map: HashMap<String, f64> = self
179 .id_to_index
180 .iter()
181 .map(|(id, &idx)| {
182 let in_deg = self
183 .graph
184 .neighbors_directed(idx, Direction::Incoming)
185 .count();
186 let out_deg = self
187 .graph
188 .neighbors_directed(idx, Direction::Outgoing)
189 .count();
190 let centrality = (in_deg + out_deg) as f64 / denominator;
191 (id.clone(), centrality)
192 })
193 .collect();
194
195 for (id, centrality) in ¢rality_map {
197 if let Some(node) = self.nodes.get_mut(id) {
198 node.centrality = *centrality;
199 }
200 }
201 }
202
203 pub fn get_all_nodes(&self) -> Vec<GraphNode> {
205 self.nodes.values().cloned().collect()
206 }
207
208 pub fn recompute_centrality(&mut self) {
213 self.cached_pagerank = self.pagerank(0.85, 100, 1e-6);
214 self.cached_betweenness = self.betweenness_centrality();
215 }
216
217 pub fn get_pagerank(&self, node_id: &str) -> f64 {
219 self.cached_pagerank.get(node_id).copied().unwrap_or(0.0)
220 }
221
222 pub fn get_betweenness(&self, node_id: &str) -> f64 {
224 self.cached_betweenness.get(node_id).copied().unwrap_or(0.0)
225 }
226
227 pub fn max_degree(&self) -> f64 {
230 if self.nodes.len() <= 1 {
231 return 1.0;
232 }
233 self.id_to_index
234 .values()
235 .map(|&idx| {
236 let in_deg = self
237 .graph
238 .neighbors_directed(idx, Direction::Incoming)
239 .count();
240 let out_deg = self
241 .graph
242 .neighbors_directed(idx, Direction::Outgoing)
243 .count();
244 (in_deg + out_deg) as f64
245 })
246 .fold(1.0f64, f64::max)
247 }
248}
249
#[cfg(test)]
mod tests {
    use super::*;
    use codemem_core::RelationshipType;

    /// Builds a `File` node with the given id and label and zero centrality.
    fn file_node(id: &str, label: &str) -> GraphNode {
        GraphNode {
            id: id.to_string(),
            kind: NodeKind::File,
            label: label.to_string(),
            payload: HashMap::new(),
            centrality: 0.0,
            memory_id: None,
            namespace: None,
        }
    }

    /// Builds a weight-1.0 `Contains` edge whose id is `"{src}->{dst}"`.
    fn test_edge(src: &str, dst: &str) -> Edge {
        Edge {
            id: format!("{src}->{dst}"),
            src: src.to_string(),
            dst: dst.to_string(),
            relationship: RelationshipType::Contains,
            weight: 1.0,
            properties: HashMap::new(),
            created_at: chrono::Utc::now(),
        }
    }

    /// Builds a graph with one `<id>.rs` file node per id, then adds `links`
    /// as edges in order.
    fn build_graph(ids: &[&str], links: &[(&str, &str)]) -> GraphEngine {
        let mut engine = GraphEngine::new();
        for id in ids {
            engine.add_node(file_node(id, &format!("{id}.rs"))).unwrap();
        }
        for (src, dst) in links {
            engine.add_edge(test_edge(src, dst)).unwrap();
        }
        engine
    }

    /// Asserts that two floats differ by less than `f64::EPSILON`.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < f64::EPSILON,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn connected_components_single_component() {
        let engine = build_graph(&["a", "b", "c"], &[("a", "b"), ("b", "c")]);

        let groups = engine.connected_components();
        assert_eq!(groups.len(), 1);
        assert_eq!(groups[0], vec!["a", "b", "c"]);
    }

    #[test]
    fn connected_components_multiple() {
        let engine = build_graph(&["a", "b", "c", "d"], &[("a", "b"), ("c", "d")]);

        let groups = engine.connected_components();
        assert_eq!(groups.len(), 2);
        assert_eq!(groups[0], vec!["a", "b"]);
        assert_eq!(groups[1], vec!["c", "d"]);
    }

    #[test]
    fn connected_components_isolated_node() {
        let engine = build_graph(&["a", "b", "c"], &[("a", "b")]);

        // "c" has no edges, so it forms a singleton component.
        let groups = engine.connected_components();
        assert_eq!(groups.len(), 2);
        assert_eq!(groups[0], vec!["a", "b"]);
        assert_eq!(groups[1], vec!["c"]);
    }

    #[test]
    fn connected_components_reverse_edge_connects() {
        // c -> a points *into* the a/b cluster; weak connectivity still joins it.
        let engine = build_graph(&["a", "b", "c"], &[("a", "b"), ("c", "a")]);

        let groups = engine.connected_components();
        assert_eq!(groups.len(), 1);
        assert_eq!(groups[0], vec!["a", "b", "c"]);
    }

    #[test]
    fn connected_components_empty_graph() {
        assert!(GraphEngine::new().connected_components().is_empty());
    }

    #[test]
    fn compute_centrality_simple() {
        // Chain a -> b -> c: degrees 1, 2, 1 over n - 1 = 2.
        let mut engine = build_graph(&["a", "b", "c"], &[("a", "b"), ("b", "c")]);
        engine.compute_centrality();

        let score = |id: &str| engine.get_node(id).unwrap().unwrap().centrality;
        assert_close(score("a"), 0.5);
        assert_close(score("b"), 1.0);
        assert_close(score("c"), 0.5);
    }

    #[test]
    fn compute_centrality_star() {
        // Hub a connects to every leaf: degree 3 / 3 for a, 1 / 3 for each leaf.
        let mut engine = build_graph(
            &["a", "b", "c", "d"],
            &[("a", "b"), ("a", "c"), ("a", "d")],
        );
        engine.compute_centrality();

        let score = |id: &str| engine.get_node(id).unwrap().unwrap().centrality;
        assert_close(score("a"), 1.0);
        assert_close(score("b"), 1.0 / 3.0);
    }

    #[test]
    fn compute_centrality_single_node() {
        let mut engine = build_graph(&["a"], &[]);
        engine.compute_centrality();

        assert_close(engine.get_node("a").unwrap().unwrap().centrality, 0.0);
    }

    #[test]
    fn compute_centrality_no_edges() {
        let mut engine = build_graph(&["a", "b"], &[]);
        engine.compute_centrality();

        for id in ["a", "b"] {
            assert_close(engine.get_node(id).unwrap().unwrap().centrality, 0.0);
        }
    }

    #[test]
    fn get_all_nodes_returns_all() {
        let engine = build_graph(&["a", "b", "c"], &[]);

        let mut ids: Vec<String> = engine.get_all_nodes().into_iter().map(|n| n.id).collect();
        ids.sort();
        assert_eq!(ids, vec!["a", "b", "c"]);
    }

    #[test]
    fn recompute_centrality_caches_pagerank() {
        let mut engine = build_graph(&["a", "b", "c"], &[("a", "b"), ("b", "c")]);

        // Nothing is cached until recompute_centrality runs.
        assert_eq!(engine.get_pagerank("a"), 0.0);
        assert_eq!(engine.get_betweenness("a"), 0.0);

        engine.recompute_centrality();

        for id in ["a", "b", "c"] {
            assert!(engine.get_pagerank(id) > 0.0);
        }

        let (pr_a, pr_c) = (engine.get_pagerank("a"), engine.get_pagerank("c"));
        assert!(
            pr_c > pr_a,
            "c ({}) should have higher PageRank than a ({})",
            pr_c,
            pr_a
        );

        let (bt_a, bt_b) = (engine.get_betweenness("a"), engine.get_betweenness("b"));
        assert!(
            bt_b > bt_a,
            "b ({}) should have higher betweenness than a ({})",
            bt_b,
            bt_a
        );
    }

    #[test]
    fn get_pagerank_returns_zero_for_unknown_node() {
        let engine = GraphEngine::new();
        assert_eq!(engine.get_pagerank("nonexistent"), 0.0);
        assert_eq!(engine.get_betweenness("nonexistent"), 0.0);
    }

    #[test]
    fn max_degree_returns_correct_value() {
        let engine = build_graph(
            &["a", "b", "c", "d"],
            &[("a", "b"), ("a", "c"), ("a", "d")],
        );

        assert_close(engine.max_degree(), 3.0);
    }

    #[test]
    fn enhanced_graph_strength_differs_from_simple_edge_count() {
        // Path a -> b -> c -> d: b and c have the same edge count, but their
        // positions differ, so at least one centrality measure should differ.
        let mut engine = build_graph(
            &["a", "b", "c", "d"],
            &[("a", "b"), ("b", "c"), ("c", "d")],
        );
        engine.recompute_centrality();

        let edges_b = engine.get_edges("b").unwrap().len();
        let edges_c = engine.get_edges("c").unwrap().len();
        assert_eq!(edges_b, edges_c, "b and c should have same edge count");

        let pr_b = engine.get_pagerank("b");
        let pr_c = engine.get_pagerank("c");
        let bt_b = engine.get_betweenness("b");
        let bt_c = engine.get_betweenness("c");

        let centrality_differs = (pr_b - pr_c).abs() > 1e-6 || (bt_b - bt_c).abs() > 1e-6;
        assert!(
            centrality_differs,
            "Centrality should differ: b(pr={pr_b}, bt={bt_b}) vs c(pr={pr_c}, bt={bt_c})"
        );
    }
}