Skip to main content

codemem_graph/
lib.rs

//! codemem-graph: Graph engine with petgraph algorithms and SQLite persistence.
//!
//! Provides BFS, DFS, shortest path, connected components, and centrality
//! metrics (degree, PageRank, betweenness) over a knowledge graph with 6 node
//! types and 15 relationship types.
5
6mod algorithms;
7mod traversal;
8
9#[cfg(test)]
10use codemem_core::NodeKind;
11use codemem_core::{CodememError, Edge, GraphBackend, GraphNode};
12use petgraph::graph::{DiGraph, NodeIndex};
13use petgraph::Direction;
14use std::collections::{HashMap, HashSet, VecDeque};
15
/// In-memory graph backed by petgraph, synced to SQLite via codemem-storage.
///
/// The petgraph [`DiGraph`] holds only topology: each node weight is the
/// node's string ID and each edge weight is an `f64`. The full [`GraphNode`]
/// and [`Edge`] records live in the side maps below, keyed by string ID.
pub struct GraphEngine {
    /// Directed topology graph. Node weights are node IDs (keys into `nodes`);
    /// edge weights are `f64` (presumably mirroring `Edge::weight` — confirm
    /// against `add_edge`).
    pub(crate) graph: DiGraph<String, f64>,
    /// Map from string node IDs to petgraph NodeIndex.
    pub(crate) id_to_index: HashMap<String, NodeIndex>,
    /// Node data by ID. Its length is what `node_count()` reports.
    pub(crate) nodes: HashMap<String, GraphNode>,
    /// Edge data by ID. Its length is what `edge_count()` reports.
    pub(crate) edges: HashMap<String, Edge>,
    /// Cached PageRank scores (populated by `recompute_centrality()`).
    pub(crate) cached_pagerank: HashMap<String, f64>,
    /// Cached betweenness centrality scores (populated by `recompute_centrality()`).
    pub(crate) cached_betweenness: HashMap<String, f64>,
}
30
31impl GraphEngine {
32    /// Create a new empty graph.
33    pub fn new() -> Self {
34        Self {
35            graph: DiGraph::new(),
36            id_to_index: HashMap::new(),
37            nodes: HashMap::new(),
38            edges: HashMap::new(),
39            cached_pagerank: HashMap::new(),
40            cached_betweenness: HashMap::new(),
41        }
42    }
43
44    /// Load graph from storage.
45    pub fn from_storage(storage: &dyn codemem_core::StorageBackend) -> Result<Self, CodememError> {
46        let mut engine = Self::new();
47
48        // Load all nodes
49        let nodes = storage.all_graph_nodes()?;
50        for node in nodes {
51            engine.add_node(node)?;
52        }
53
54        // Load all edges
55        let edges = storage.all_graph_edges()?;
56        for edge in edges {
57            engine.add_edge(edge)?;
58        }
59
60        Ok(engine)
61    }
62
63    /// Get the number of nodes.
64    pub fn node_count(&self) -> usize {
65        self.nodes.len()
66    }
67
68    /// Get the number of edges.
69    pub fn edge_count(&self) -> usize {
70        self.edges.len()
71    }
72
73    /// Multi-hop expansion: given a set of node IDs, expand N hops to find related nodes.
74    pub fn expand(
75        &self,
76        start_ids: &[String],
77        max_hops: usize,
78    ) -> Result<Vec<GraphNode>, CodememError> {
79        let mut visited = std::collections::HashSet::new();
80        let mut result = Vec::new();
81
82        for start_id in start_ids {
83            let nodes = self.bfs(start_id, max_hops)?;
84            for node in nodes {
85                if visited.insert(node.id.clone()) {
86                    result.push(node);
87                }
88            }
89        }
90
91        Ok(result)
92    }
93
94    /// Get neighbors of a node (1-hop).
95    pub fn neighbors(&self, node_id: &str) -> Result<Vec<GraphNode>, CodememError> {
96        let idx = self
97            .id_to_index
98            .get(node_id)
99            .ok_or_else(|| CodememError::NotFound(format!("Node {node_id}")))?;
100
101        let mut result = Vec::new();
102        for neighbor_idx in self.graph.neighbors(*idx) {
103            if let Some(neighbor_id) = self.graph.node_weight(neighbor_idx) {
104                if let Some(node) = self.nodes.get(neighbor_id) {
105                    result.push(node.clone());
106                }
107            }
108        }
109
110        Ok(result)
111    }
112
113    /// Return groups of connected node IDs.
114    ///
115    /// Treats the directed graph as undirected: two nodes are in the same
116    /// component if there is a path between them in either direction.
117    /// Each inner `Vec<String>` is one connected component.
118    pub fn connected_components(&self) -> Vec<Vec<String>> {
119        let mut visited: HashSet<NodeIndex> = HashSet::new();
120        let mut components: Vec<Vec<String>> = Vec::new();
121
122        for &start_idx in self.id_to_index.values() {
123            if visited.contains(&start_idx) {
124                continue;
125            }
126
127            // BFS treating edges as undirected
128            let mut component: Vec<String> = Vec::new();
129            let mut queue: VecDeque<NodeIndex> = VecDeque::new();
130            queue.push_back(start_idx);
131            visited.insert(start_idx);
132
133            while let Some(current) = queue.pop_front() {
134                if let Some(node_id) = self.graph.node_weight(current) {
135                    component.push(node_id.clone());
136                }
137
138                // Follow outgoing edges
139                for neighbor in self.graph.neighbors_directed(current, Direction::Outgoing) {
140                    if visited.insert(neighbor) {
141                        queue.push_back(neighbor);
142                    }
143                }
144
145                // Follow incoming edges (treat as undirected)
146                for neighbor in self.graph.neighbors_directed(current, Direction::Incoming) {
147                    if visited.insert(neighbor) {
148                        queue.push_back(neighbor);
149                    }
150                }
151            }
152
153            component.sort();
154            components.push(component);
155        }
156
157        components.sort();
158        components
159    }
160
161    /// Compute degree centrality for every node and update their `centrality` field.
162    ///
163    /// Degree centrality for node *v* is defined as:
164    ///   `(in_degree(v) + out_degree(v)) / (N - 1)`
165    /// where *N* is the total number of nodes.  When N <= 1, centrality is 0.
166    pub fn compute_centrality(&mut self) {
167        let n = self.nodes.len();
168        if n <= 1 {
169            for node in self.nodes.values_mut() {
170                node.centrality = 0.0;
171            }
172            return;
173        }
174
175        let denominator = (n - 1) as f64;
176
177        // Pre-compute centrality values by node ID.
178        let centrality_map: HashMap<String, f64> = self
179            .id_to_index
180            .iter()
181            .map(|(id, &idx)| {
182                let in_deg = self
183                    .graph
184                    .neighbors_directed(idx, Direction::Incoming)
185                    .count();
186                let out_deg = self
187                    .graph
188                    .neighbors_directed(idx, Direction::Outgoing)
189                    .count();
190                let centrality = (in_deg + out_deg) as f64 / denominator;
191                (id.clone(), centrality)
192            })
193            .collect();
194
195        // Apply centrality values to the stored nodes.
196        for (id, centrality) in &centrality_map {
197            if let Some(node) = self.nodes.get_mut(id) {
198                node.centrality = *centrality;
199            }
200        }
201    }
202
203    /// Return all nodes currently in the graph.
204    pub fn get_all_nodes(&self) -> Vec<GraphNode> {
205        self.nodes.values().cloned().collect()
206    }
207
208    /// Recompute and cache PageRank and betweenness centrality scores.
209    ///
210    /// This should be called after loading the graph (e.g., on server start)
211    /// and periodically when the graph changes significantly.
212    pub fn recompute_centrality(&mut self) {
213        self.cached_pagerank = self.pagerank(0.85, 100, 1e-6);
214        self.cached_betweenness = self.betweenness_centrality();
215    }
216
217    /// Get the cached PageRank score for a node. Returns 0.0 if not found.
218    pub fn get_pagerank(&self, node_id: &str) -> f64 {
219        self.cached_pagerank.get(node_id).copied().unwrap_or(0.0)
220    }
221
222    /// Get the cached betweenness centrality score for a node. Returns 0.0 if not found.
223    pub fn get_betweenness(&self, node_id: &str) -> f64 {
224        self.cached_betweenness.get(node_id).copied().unwrap_or(0.0)
225    }
226
227    /// Get the maximum degree (in + out) across all nodes in the graph.
228    /// Returns 1.0 if the graph has fewer than 2 nodes to avoid division by zero.
229    pub fn max_degree(&self) -> f64 {
230        if self.nodes.len() <= 1 {
231            return 1.0;
232        }
233        self.id_to_index
234            .values()
235            .map(|&idx| {
236                let in_deg = self
237                    .graph
238                    .neighbors_directed(idx, Direction::Incoming)
239                    .count();
240                let out_deg = self
241                    .graph
242                    .neighbors_directed(idx, Direction::Outgoing)
243                    .count();
244                (in_deg + out_deg) as f64
245            })
246            .fold(1.0f64, f64::max)
247    }
248}
249
#[cfg(test)]
mod tests {
    use super::*;
    use codemem_core::RelationshipType;

    /// Build a `File` node with the given ID and label and no extra data.
    fn file_node(id: &str, label: &str) -> GraphNode {
        GraphNode {
            id: id.to_string(),
            kind: NodeKind::File,
            label: label.to_string(),
            payload: HashMap::new(),
            centrality: 0.0,
            memory_id: None,
            namespace: None,
        }
    }

    /// Build a `Contains` edge `src -> dst` with ID `"{src}->{dst}"` and weight 1.0.
    fn test_edge(src: &str, dst: &str) -> Edge {
        Edge {
            id: format!("{src}->{dst}"),
            src: src.to_string(),
            dst: dst.to_string(),
            relationship: RelationshipType::Contains,
            weight: 1.0,
            properties: HashMap::new(),
            created_at: chrono::Utc::now(),
        }
    }

    #[test]
    fn connected_components_single_component() {
        let mut graph = GraphEngine::new();
        graph.add_node(file_node("a", "a.rs")).unwrap();
        graph.add_node(file_node("b", "b.rs")).unwrap();
        graph.add_node(file_node("c", "c.rs")).unwrap();
        graph.add_edge(test_edge("a", "b")).unwrap();
        graph.add_edge(test_edge("b", "c")).unwrap();

        let components = graph.connected_components();
        assert_eq!(components.len(), 1);
        assert_eq!(components[0], vec!["a", "b", "c"]);
    }

    #[test]
    fn connected_components_multiple() {
        let mut graph = GraphEngine::new();
        graph.add_node(file_node("a", "a.rs")).unwrap();
        graph.add_node(file_node("b", "b.rs")).unwrap();
        graph.add_node(file_node("c", "c.rs")).unwrap();
        graph.add_node(file_node("d", "d.rs")).unwrap();
        graph.add_edge(test_edge("a", "b")).unwrap();
        graph.add_edge(test_edge("c", "d")).unwrap();

        let components = graph.connected_components();
        assert_eq!(components.len(), 2);
        assert_eq!(components[0], vec!["a", "b"]);
        assert_eq!(components[1], vec!["c", "d"]);
    }

    #[test]
    fn connected_components_isolated_node() {
        let mut graph = GraphEngine::new();
        graph.add_node(file_node("a", "a.rs")).unwrap();
        graph.add_node(file_node("b", "b.rs")).unwrap();
        graph.add_node(file_node("c", "c.rs")).unwrap();
        graph.add_edge(test_edge("a", "b")).unwrap();
        // "c" is isolated

        let components = graph.connected_components();
        assert_eq!(components.len(), 2);
        // Sorted: ["a","b"] comes before ["c"]
        assert_eq!(components[0], vec!["a", "b"]);
        assert_eq!(components[1], vec!["c"]);
    }

    #[test]
    fn connected_components_reverse_edge_connects() {
        // Directed edge c->a should still put a and c in the same component
        // when treated as undirected.
        let mut graph = GraphEngine::new();
        graph.add_node(file_node("a", "a.rs")).unwrap();
        graph.add_node(file_node("b", "b.rs")).unwrap();
        graph.add_node(file_node("c", "c.rs")).unwrap();
        graph.add_edge(test_edge("a", "b")).unwrap();
        graph.add_edge(test_edge("c", "a")).unwrap();

        let components = graph.connected_components();
        assert_eq!(components.len(), 1);
        assert_eq!(components[0], vec!["a", "b", "c"]);
    }

    #[test]
    fn connected_components_empty_graph() {
        let graph = GraphEngine::new();
        let components = graph.connected_components();
        assert!(components.is_empty());
    }

    #[test]
    fn compute_centrality_simple() {
        // Graph: a -> b -> c
        // Node a: out=1, in=0 => centrality = 1/2 = 0.5
        // Node b: out=1, in=1 => centrality = 2/2 = 1.0
        // Node c: out=0, in=1 => centrality = 1/2 = 0.5
        let mut graph = GraphEngine::new();
        graph.add_node(file_node("a", "a.rs")).unwrap();
        graph.add_node(file_node("b", "b.rs")).unwrap();
        graph.add_node(file_node("c", "c.rs")).unwrap();
        graph.add_edge(test_edge("a", "b")).unwrap();
        graph.add_edge(test_edge("b", "c")).unwrap();

        graph.compute_centrality();

        let a = graph.get_node("a").unwrap().unwrap();
        let b = graph.get_node("b").unwrap().unwrap();
        let c = graph.get_node("c").unwrap().unwrap();

        assert!((a.centrality - 0.5).abs() < f64::EPSILON);
        assert!((b.centrality - 1.0).abs() < f64::EPSILON);
        assert!((c.centrality - 0.5).abs() < f64::EPSILON);
    }

    #[test]
    fn compute_centrality_star() {
        // Graph: a -> b, a -> c, a -> d (star topology)
        // Node a: out=3, in=0 => centrality = 3/3 = 1.0
        // Node b: out=0, in=1 => centrality = 1/3
        // Node c: out=0, in=1 => centrality = 1/3
        // Node d: out=0, in=1 => centrality = 1/3
        let mut graph = GraphEngine::new();
        graph.add_node(file_node("a", "a.rs")).unwrap();
        graph.add_node(file_node("b", "b.rs")).unwrap();
        graph.add_node(file_node("c", "c.rs")).unwrap();
        graph.add_node(file_node("d", "d.rs")).unwrap();
        graph.add_edge(test_edge("a", "b")).unwrap();
        graph.add_edge(test_edge("a", "c")).unwrap();
        graph.add_edge(test_edge("a", "d")).unwrap();

        graph.compute_centrality();

        let a = graph.get_node("a").unwrap().unwrap();
        let b = graph.get_node("b").unwrap().unwrap();

        assert!((a.centrality - 1.0).abs() < f64::EPSILON);
        assert!((b.centrality - 1.0 / 3.0).abs() < f64::EPSILON);
    }

    #[test]
    fn compute_centrality_single_node() {
        let mut graph = GraphEngine::new();
        graph.add_node(file_node("a", "a.rs")).unwrap();

        graph.compute_centrality();

        let a = graph.get_node("a").unwrap().unwrap();
        assert!((a.centrality - 0.0).abs() < f64::EPSILON);
    }

    #[test]
    fn compute_centrality_no_edges() {
        let mut graph = GraphEngine::new();
        graph.add_node(file_node("a", "a.rs")).unwrap();
        graph.add_node(file_node("b", "b.rs")).unwrap();

        graph.compute_centrality();

        let a = graph.get_node("a").unwrap().unwrap();
        let b = graph.get_node("b").unwrap().unwrap();
        assert!((a.centrality - 0.0).abs() < f64::EPSILON);
        assert!((b.centrality - 0.0).abs() < f64::EPSILON);
    }

    #[test]
    fn compute_centrality_empty_graph_is_noop() {
        let mut graph = GraphEngine::new();
        graph.compute_centrality();
        assert_eq!(graph.node_count(), 0);
    }

    #[test]
    fn get_all_nodes_returns_all() {
        let mut graph = GraphEngine::new();
        graph.add_node(file_node("a", "a.rs")).unwrap();
        graph.add_node(file_node("b", "b.rs")).unwrap();
        graph.add_node(file_node("c", "c.rs")).unwrap();

        let mut all = graph.get_all_nodes();
        all.sort_by(|x, y| x.id.cmp(&y.id));
        assert_eq!(all.len(), 3);
        assert_eq!(all[0].id, "a");
        assert_eq!(all[1].id, "b");
        assert_eq!(all[2].id, "c");
    }

    // ── Neighbor Tests ──────────────────────────────────────────────────────

    #[test]
    fn neighbors_follows_outgoing_edges_only() {
        // a -> b and c -> a: only b is a successor of a.
        let mut graph = GraphEngine::new();
        graph.add_node(file_node("a", "a.rs")).unwrap();
        graph.add_node(file_node("b", "b.rs")).unwrap();
        graph.add_node(file_node("c", "c.rs")).unwrap();
        graph.add_edge(test_edge("a", "b")).unwrap();
        graph.add_edge(test_edge("c", "a")).unwrap();

        let ids: Vec<String> = graph
            .neighbors("a")
            .unwrap()
            .into_iter()
            .map(|n| n.id)
            .collect();
        assert_eq!(ids, vec!["b".to_string()]);
    }

    #[test]
    fn neighbors_unknown_node_is_error() {
        let graph = GraphEngine::new();
        assert!(graph.neighbors("missing").is_err());
    }

    // ── Centrality Caching Tests ────────────────────────────────────────────

    #[test]
    fn recompute_centrality_caches_pagerank() {
        let mut graph = GraphEngine::new();
        graph.add_node(file_node("a", "a.rs")).unwrap();
        graph.add_node(file_node("b", "b.rs")).unwrap();
        graph.add_node(file_node("c", "c.rs")).unwrap();
        graph.add_edge(test_edge("a", "b")).unwrap();
        graph.add_edge(test_edge("b", "c")).unwrap();

        // Before recompute, cached values should be 0.0
        assert_eq!(graph.get_pagerank("a"), 0.0);
        assert_eq!(graph.get_betweenness("a"), 0.0);

        graph.recompute_centrality();

        // After recompute, cached PageRank values should be non-zero
        assert!(graph.get_pagerank("a") > 0.0);
        assert!(graph.get_pagerank("b") > 0.0);
        assert!(graph.get_pagerank("c") > 0.0);

        // c should have highest PageRank (sink node in a -> b -> c)
        assert!(
            graph.get_pagerank("c") > graph.get_pagerank("a"),
            "c ({}) should have higher PageRank than a ({})",
            graph.get_pagerank("c"),
            graph.get_pagerank("a")
        );

        // b should have highest betweenness (middle of chain)
        assert!(
            graph.get_betweenness("b") > graph.get_betweenness("a"),
            "b ({}) should have higher betweenness than a ({})",
            graph.get_betweenness("b"),
            graph.get_betweenness("a")
        );
    }

    #[test]
    fn get_pagerank_returns_zero_for_unknown_node() {
        let graph = GraphEngine::new();
        assert_eq!(graph.get_pagerank("nonexistent"), 0.0);
        assert_eq!(graph.get_betweenness("nonexistent"), 0.0);
    }

    #[test]
    fn max_degree_returns_correct_value() {
        let mut graph = GraphEngine::new();
        graph.add_node(file_node("a", "a.rs")).unwrap();
        graph.add_node(file_node("b", "b.rs")).unwrap();
        graph.add_node(file_node("c", "c.rs")).unwrap();
        graph.add_node(file_node("d", "d.rs")).unwrap();
        // a -> b, a -> c, a -> d (star: a has degree 3)
        graph.add_edge(test_edge("a", "b")).unwrap();
        graph.add_edge(test_edge("a", "c")).unwrap();
        graph.add_edge(test_edge("a", "d")).unwrap();

        assert!((graph.max_degree() - 3.0).abs() < f64::EPSILON);
    }

    #[test]
    fn max_degree_is_at_least_one() {
        let mut graph = GraphEngine::new();
        // Empty graph: fewer than 2 nodes.
        assert!((graph.max_degree() - 1.0).abs() < f64::EPSILON);

        // Two nodes but no edges: the 1.0 floor still applies.
        graph.add_node(file_node("a", "a.rs")).unwrap();
        graph.add_node(file_node("b", "b.rs")).unwrap();
        assert!((graph.max_degree() - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn enhanced_graph_strength_differs_from_simple_edge_count() {
        // Build a graph where PageRank/betweenness differ from simple edge count.
        // a -> b -> c -> d (chain)
        // b and c have the same degree, but PageRank accumulates downstream,
        // so c is expected to receive more PageRank flow than b.
        let mut graph = GraphEngine::new();
        for id in &["a", "b", "c", "d"] {
            graph.add_node(file_node(id, &format!("{id}.rs"))).unwrap();
        }
        graph.add_edge(test_edge("a", "b")).unwrap();
        graph.add_edge(test_edge("b", "c")).unwrap();
        graph.add_edge(test_edge("c", "d")).unwrap();
        graph.recompute_centrality();

        // Nodes b and c both have 2 edges (in+out), so simple edge count
        // would give them equal scores.
        let edges_b = graph.get_edges("b").unwrap().len();
        let edges_c = graph.get_edges("c").unwrap().len();
        assert_eq!(edges_b, edges_c, "b and c should have same edge count");

        // Betweenness is symmetric here: b is interior to a->c and a->d, and
        // c is interior to a->d and b->d. The distinguishing signal is
        // therefore expected to come from PageRank.
        let pr_b = graph.get_pagerank("b");
        let pr_c = graph.get_pagerank("c");
        let bt_b = graph.get_betweenness("b");
        let bt_c = graph.get_betweenness("c");

        // At least one centrality metric should differ between b and c
        let centrality_differs = (pr_b - pr_c).abs() > 1e-6 || (bt_b - bt_c).abs() > 1e-6;
        assert!(
            centrality_differs,
            "Centrality should differ: b(pr={pr_b}, bt={bt_b}) vs c(pr={pr_c}, bt={bt_c})"
        );
    }
}