Skip to main content

codemem_graph/
lib.rs

1//! codemem-graph: Graph engine with petgraph algorithms and SQLite persistence.
2//!
3//! Provides BFS, DFS, shortest path, and connected components over
4//! a knowledge graph with 6 node types and 15 relationship types.
5
6mod algorithms;
7mod traversal;
8
9#[cfg(test)]
10use codemem_core::NodeKind;
11use codemem_core::{CodememError, Edge, GraphBackend, GraphNode};
12use petgraph::graph::{DiGraph, NodeIndex};
13use petgraph::Direction;
14use std::collections::{HashMap, HashSet, VecDeque};
15
16/// In-memory graph backed by petgraph, synced to SQLite via codemem-storage.
17pub struct GraphEngine {
18    pub(crate) graph: DiGraph<String, f64>,
19    /// Map from string node IDs to petgraph NodeIndex.
20    pub(crate) id_to_index: HashMap<String, NodeIndex>,
21    /// Node data by ID.
22    pub(crate) nodes: HashMap<String, GraphNode>,
23    /// Edge data by ID.
24    pub(crate) edges: HashMap<String, Edge>,
25    /// Cached PageRank scores (populated by `recompute_centrality()`).
26    pub(crate) cached_pagerank: HashMap<String, f64>,
27    /// Cached betweenness centrality scores (populated by `recompute_centrality()`).
28    pub(crate) cached_betweenness: HashMap<String, f64>,
29}
30
31impl GraphEngine {
32    /// Create a new empty graph.
33    pub fn new() -> Self {
34        Self {
35            graph: DiGraph::new(),
36            id_to_index: HashMap::new(),
37            nodes: HashMap::new(),
38            edges: HashMap::new(),
39            cached_pagerank: HashMap::new(),
40            cached_betweenness: HashMap::new(),
41        }
42    }
43
44    /// Load graph from storage.
45    pub fn from_storage(storage: &dyn codemem_core::StorageBackend) -> Result<Self, CodememError> {
46        let mut engine = Self::new();
47
48        // Load all nodes
49        let nodes = storage.all_graph_nodes()?;
50        for node in nodes {
51            engine.add_node(node)?;
52        }
53
54        // Load all edges
55        let edges = storage.all_graph_edges()?;
56        for edge in edges {
57            engine.add_edge(edge)?;
58        }
59
60        Ok(engine)
61    }
62
63    /// Get the number of nodes.
64    pub fn node_count(&self) -> usize {
65        self.nodes.len()
66    }
67
68    /// Get the number of edges.
69    pub fn edge_count(&self) -> usize {
70        self.edges.len()
71    }
72
73    /// Multi-hop expansion: given a set of node IDs, expand N hops to find related nodes.
74    pub fn expand(
75        &self,
76        start_ids: &[String],
77        max_hops: usize,
78    ) -> Result<Vec<GraphNode>, CodememError> {
79        let mut visited = std::collections::HashSet::new();
80        let mut result = Vec::new();
81
82        for start_id in start_ids {
83            let nodes = self.bfs(start_id, max_hops)?;
84            for node in nodes {
85                if visited.insert(node.id.clone()) {
86                    result.push(node);
87                }
88            }
89        }
90
91        Ok(result)
92    }
93
94    /// Get neighbors of a node (1-hop).
95    pub fn neighbors(&self, node_id: &str) -> Result<Vec<GraphNode>, CodememError> {
96        let idx = self
97            .id_to_index
98            .get(node_id)
99            .ok_or_else(|| CodememError::NotFound(format!("Node {node_id}")))?;
100
101        let mut result = Vec::new();
102        for neighbor_idx in self.graph.neighbors(*idx) {
103            if let Some(neighbor_id) = self.graph.node_weight(neighbor_idx) {
104                if let Some(node) = self.nodes.get(neighbor_id) {
105                    result.push(node.clone());
106                }
107            }
108        }
109
110        Ok(result)
111    }
112
113    /// Return groups of connected node IDs.
114    ///
115    /// Treats the directed graph as undirected: two nodes are in the same
116    /// component if there is a path between them in either direction.
117    /// Each inner `Vec<String>` is one connected component.
118    pub fn connected_components(&self) -> Vec<Vec<String>> {
119        let mut visited: HashSet<NodeIndex> = HashSet::new();
120        let mut components: Vec<Vec<String>> = Vec::new();
121
122        for &start_idx in self.id_to_index.values() {
123            if visited.contains(&start_idx) {
124                continue;
125            }
126
127            // BFS treating edges as undirected
128            let mut component: Vec<String> = Vec::new();
129            let mut queue: VecDeque<NodeIndex> = VecDeque::new();
130            queue.push_back(start_idx);
131            visited.insert(start_idx);
132
133            while let Some(current) = queue.pop_front() {
134                if let Some(node_id) = self.graph.node_weight(current) {
135                    component.push(node_id.clone());
136                }
137
138                // Follow outgoing edges
139                for neighbor in self.graph.neighbors_directed(current, Direction::Outgoing) {
140                    if visited.insert(neighbor) {
141                        queue.push_back(neighbor);
142                    }
143                }
144
145                // Follow incoming edges (treat as undirected)
146                for neighbor in self.graph.neighbors_directed(current, Direction::Incoming) {
147                    if visited.insert(neighbor) {
148                        queue.push_back(neighbor);
149                    }
150                }
151            }
152
153            component.sort();
154            components.push(component);
155        }
156
157        components.sort();
158        components
159    }
160
161    /// Compute degree centrality for every node and update their `centrality` field.
162    ///
163    /// Degree centrality for node *v* is defined as:
164    ///   `(in_degree(v) + out_degree(v)) / (N - 1)`
165    /// where *N* is the total number of nodes.  When N <= 1, centrality is 0.
166    pub fn compute_centrality(&mut self) {
167        let n = self.nodes.len();
168        if n <= 1 {
169            for node in self.nodes.values_mut() {
170                node.centrality = 0.0;
171            }
172            return;
173        }
174
175        let denominator = (n - 1) as f64;
176
177        // Pre-compute centrality values by node ID.
178        let centrality_map: HashMap<String, f64> = self
179            .id_to_index
180            .iter()
181            .map(|(id, &idx)| {
182                let in_deg = self
183                    .graph
184                    .neighbors_directed(idx, Direction::Incoming)
185                    .count();
186                let out_deg = self
187                    .graph
188                    .neighbors_directed(idx, Direction::Outgoing)
189                    .count();
190                let centrality = (in_deg + out_deg) as f64 / denominator;
191                (id.clone(), centrality)
192            })
193            .collect();
194
195        // Apply centrality values to the stored nodes.
196        for (id, centrality) in &centrality_map {
197            if let Some(node) = self.nodes.get_mut(id) {
198                node.centrality = *centrality;
199            }
200        }
201    }
202
203    /// Return all nodes currently in the graph.
204    pub fn get_all_nodes(&self) -> Vec<GraphNode> {
205        self.nodes.values().cloned().collect()
206    }
207
208    /// Recompute and cache PageRank and betweenness centrality scores.
209    ///
210    /// This should be called after loading the graph (e.g., on server start)
211    /// and periodically when the graph changes significantly.
212    pub fn recompute_centrality(&mut self) {
213        self.cached_pagerank = self.pagerank(0.85, 100, 1e-6);
214        self.cached_betweenness = self.betweenness_centrality();
215    }
216
217    /// Get the cached PageRank score for a node. Returns 0.0 if not found.
218    pub fn get_pagerank(&self, node_id: &str) -> f64 {
219        self.cached_pagerank.get(node_id).copied().unwrap_or(0.0)
220    }
221
222    /// Get the cached betweenness centrality score for a node. Returns 0.0 if not found.
223    pub fn get_betweenness(&self, node_id: &str) -> f64 {
224        self.cached_betweenness.get(node_id).copied().unwrap_or(0.0)
225    }
226
227    /// Get the maximum degree (in + out) across all nodes in the graph.
228    /// Returns 1.0 if the graph has fewer than 2 nodes to avoid division by zero.
229    pub fn max_degree(&self) -> f64 {
230        if self.nodes.len() <= 1 {
231            return 1.0;
232        }
233        self.id_to_index
234            .values()
235            .map(|&idx| {
236                let in_deg = self
237                    .graph
238                    .neighbors_directed(idx, Direction::Incoming)
239                    .count();
240                let out_deg = self
241                    .graph
242                    .neighbors_directed(idx, Direction::Outgoing)
243                    .count();
244                (in_deg + out_deg) as f64
245            })
246            .fold(1.0f64, f64::max)
247    }
248}
249
#[cfg(test)]
mod tests {
    use super::*;
    use codemem_core::RelationshipType;

    fn file_node(id: &str, label: &str) -> GraphNode {
        GraphNode {
            id: id.to_string(),
            kind: NodeKind::File,
            label: label.to_string(),
            payload: HashMap::new(),
            centrality: 0.0,
            memory_id: None,
            namespace: None,
        }
    }

    fn test_edge(src: &str, dst: &str) -> Edge {
        Edge {
            id: format!("{src}->{dst}"),
            src: src.to_string(),
            dst: dst.to_string(),
            relationship: RelationshipType::Contains,
            weight: 1.0,
            properties: HashMap::new(),
            created_at: chrono::Utc::now(),
            valid_from: None,
            valid_to: None,
        }
    }

    #[test]
    fn connected_components_single_component() {
        let mut graph = GraphEngine::new();
        graph.add_node(file_node("a", "a.rs")).unwrap();
        graph.add_node(file_node("b", "b.rs")).unwrap();
        graph.add_node(file_node("c", "c.rs")).unwrap();
        graph.add_edge(test_edge("a", "b")).unwrap();
        graph.add_edge(test_edge("b", "c")).unwrap();

        let components = graph.connected_components();
        assert_eq!(components.len(), 1);
        assert_eq!(components[0], vec!["a", "b", "c"]);
    }

    #[test]
    fn connected_components_multiple() {
        let mut graph = GraphEngine::new();
        graph.add_node(file_node("a", "a.rs")).unwrap();
        graph.add_node(file_node("b", "b.rs")).unwrap();
        graph.add_node(file_node("c", "c.rs")).unwrap();
        graph.add_node(file_node("d", "d.rs")).unwrap();
        graph.add_edge(test_edge("a", "b")).unwrap();
        graph.add_edge(test_edge("c", "d")).unwrap();

        let components = graph.connected_components();
        assert_eq!(components.len(), 2);
        assert_eq!(components[0], vec!["a", "b"]);
        assert_eq!(components[1], vec!["c", "d"]);
    }

    #[test]
    fn connected_components_isolated_node() {
        let mut graph = GraphEngine::new();
        graph.add_node(file_node("a", "a.rs")).unwrap();
        graph.add_node(file_node("b", "b.rs")).unwrap();
        graph.add_node(file_node("c", "c.rs")).unwrap();
        graph.add_edge(test_edge("a", "b")).unwrap();
        // "c" is isolated

        let components = graph.connected_components();
        assert_eq!(components.len(), 2);
        // Sorted: ["a","b"] comes before ["c"]
        assert_eq!(components[0], vec!["a", "b"]);
        assert_eq!(components[1], vec!["c"]);
    }

    #[test]
    fn connected_components_reverse_edge_connects() {
        // Directed edge c->a should still put a and c in the same component
        // when treated as undirected.
        let mut graph = GraphEngine::new();
        graph.add_node(file_node("a", "a.rs")).unwrap();
        graph.add_node(file_node("b", "b.rs")).unwrap();
        graph.add_node(file_node("c", "c.rs")).unwrap();
        graph.add_edge(test_edge("a", "b")).unwrap();
        graph.add_edge(test_edge("c", "a")).unwrap();

        let components = graph.connected_components();
        assert_eq!(components.len(), 1);
        assert_eq!(components[0], vec!["a", "b", "c"]);
    }

    #[test]
    fn connected_components_empty_graph() {
        let graph = GraphEngine::new();
        let components = graph.connected_components();
        assert!(components.is_empty());
    }

    #[test]
    fn neighbors_returns_outgoing_targets_only() {
        let mut graph = GraphEngine::new();
        graph.add_node(file_node("a", "a.rs")).unwrap();
        graph.add_node(file_node("b", "b.rs")).unwrap();
        graph.add_node(file_node("c", "c.rs")).unwrap();
        graph.add_edge(test_edge("a", "b")).unwrap();
        graph.add_edge(test_edge("a", "c")).unwrap();

        let mut ids: Vec<String> = graph
            .neighbors("a")
            .unwrap()
            .into_iter()
            .map(|node| node.id)
            .collect();
        ids.sort();
        assert_eq!(ids, vec!["b", "c"]);

        // Edges are directed: "b" has no outgoing edges.
        assert!(graph.neighbors("b").unwrap().is_empty());
        // Unknown IDs are reported as errors rather than an empty list.
        assert!(graph.neighbors("missing").is_err());
    }

    #[test]
    fn compute_centrality_simple() {
        // Graph: a -> b -> c
        // Node a: out=1, in=0 => centrality = 1/2 = 0.5
        // Node b: out=1, in=1 => centrality = 2/2 = 1.0
        // Node c: out=0, in=1 => centrality = 1/2 = 0.5
        let mut graph = GraphEngine::new();
        graph.add_node(file_node("a", "a.rs")).unwrap();
        graph.add_node(file_node("b", "b.rs")).unwrap();
        graph.add_node(file_node("c", "c.rs")).unwrap();
        graph.add_edge(test_edge("a", "b")).unwrap();
        graph.add_edge(test_edge("b", "c")).unwrap();

        graph.compute_centrality();

        let a = graph.get_node("a").unwrap().unwrap();
        let b = graph.get_node("b").unwrap().unwrap();
        let c = graph.get_node("c").unwrap().unwrap();

        assert!((a.centrality - 0.5).abs() < f64::EPSILON);
        assert!((b.centrality - 1.0).abs() < f64::EPSILON);
        assert!((c.centrality - 0.5).abs() < f64::EPSILON);
    }

    #[test]
    fn compute_centrality_star() {
        // Graph: a -> b, a -> c, a -> d (star topology)
        // Node a: out=3, in=0 => centrality = 3/3 = 1.0
        // Node b: out=0, in=1 => centrality = 1/3
        // Node c: out=0, in=1 => centrality = 1/3
        // Node d: out=0, in=1 => centrality = 1/3
        let mut graph = GraphEngine::new();
        graph.add_node(file_node("a", "a.rs")).unwrap();
        graph.add_node(file_node("b", "b.rs")).unwrap();
        graph.add_node(file_node("c", "c.rs")).unwrap();
        graph.add_node(file_node("d", "d.rs")).unwrap();
        graph.add_edge(test_edge("a", "b")).unwrap();
        graph.add_edge(test_edge("a", "c")).unwrap();
        graph.add_edge(test_edge("a", "d")).unwrap();

        graph.compute_centrality();

        let a = graph.get_node("a").unwrap().unwrap();
        let b = graph.get_node("b").unwrap().unwrap();

        assert!((a.centrality - 1.0).abs() < f64::EPSILON);
        assert!((b.centrality - 1.0 / 3.0).abs() < f64::EPSILON);
    }

    #[test]
    fn compute_centrality_single_node() {
        let mut graph = GraphEngine::new();
        graph.add_node(file_node("a", "a.rs")).unwrap();

        graph.compute_centrality();

        let a = graph.get_node("a").unwrap().unwrap();
        assert!((a.centrality - 0.0).abs() < f64::EPSILON);
    }

    #[test]
    fn compute_centrality_no_edges() {
        let mut graph = GraphEngine::new();
        graph.add_node(file_node("a", "a.rs")).unwrap();
        graph.add_node(file_node("b", "b.rs")).unwrap();

        graph.compute_centrality();

        let a = graph.get_node("a").unwrap().unwrap();
        let b = graph.get_node("b").unwrap().unwrap();
        assert!((a.centrality - 0.0).abs() < f64::EPSILON);
        assert!((b.centrality - 0.0).abs() < f64::EPSILON);
    }

    #[test]
    fn get_all_nodes_returns_all() {
        let mut graph = GraphEngine::new();
        graph.add_node(file_node("a", "a.rs")).unwrap();
        graph.add_node(file_node("b", "b.rs")).unwrap();
        graph.add_node(file_node("c", "c.rs")).unwrap();

        let mut all = graph.get_all_nodes();
        all.sort_by(|x, y| x.id.cmp(&y.id));
        assert_eq!(all.len(), 3);
        assert_eq!(all[0].id, "a");
        assert_eq!(all[1].id, "b");
        assert_eq!(all[2].id, "c");
    }

    // ── Centrality Caching Tests ────────────────────────────────────────────

    #[test]
    fn recompute_centrality_caches_pagerank() {
        let mut graph = GraphEngine::new();
        graph.add_node(file_node("a", "a.rs")).unwrap();
        graph.add_node(file_node("b", "b.rs")).unwrap();
        graph.add_node(file_node("c", "c.rs")).unwrap();
        graph.add_edge(test_edge("a", "b")).unwrap();
        graph.add_edge(test_edge("b", "c")).unwrap();

        // Before recompute, cached values should be 0.0
        assert_eq!(graph.get_pagerank("a"), 0.0);
        assert_eq!(graph.get_betweenness("a"), 0.0);

        graph.recompute_centrality();

        // After recompute, cached PageRank values should be non-zero
        assert!(graph.get_pagerank("a") > 0.0);
        assert!(graph.get_pagerank("b") > 0.0);
        assert!(graph.get_pagerank("c") > 0.0);

        // c, the sink of a -> b -> c, should outrank the source a.
        assert!(
            graph.get_pagerank("c") > graph.get_pagerank("a"),
            "c ({}) should have higher PageRank than a ({})",
            graph.get_pagerank("c"),
            graph.get_pagerank("a")
        );

        // b is the only intermediate node (on the a -> c path), so it should
        // have higher betweenness than the endpoint a.
        assert!(
            graph.get_betweenness("b") > graph.get_betweenness("a"),
            "b ({}) should have higher betweenness than a ({})",
            graph.get_betweenness("b"),
            graph.get_betweenness("a")
        );
    }

    #[test]
    fn get_pagerank_returns_zero_for_unknown_node() {
        let graph = GraphEngine::new();
        assert_eq!(graph.get_pagerank("nonexistent"), 0.0);
        assert_eq!(graph.get_betweenness("nonexistent"), 0.0);
    }

    #[test]
    fn max_degree_returns_correct_value() {
        let mut graph = GraphEngine::new();
        graph.add_node(file_node("a", "a.rs")).unwrap();
        graph.add_node(file_node("b", "b.rs")).unwrap();
        graph.add_node(file_node("c", "c.rs")).unwrap();
        graph.add_node(file_node("d", "d.rs")).unwrap();
        // a -> b, a -> c, a -> d (star: a has degree 3)
        graph.add_edge(test_edge("a", "b")).unwrap();
        graph.add_edge(test_edge("a", "c")).unwrap();
        graph.add_edge(test_edge("a", "d")).unwrap();

        assert!((graph.max_degree() - 3.0).abs() < f64::EPSILON);
    }

    #[test]
    fn max_degree_is_never_below_one() {
        // Empty graph: guarded early return.
        let empty = GraphEngine::new();
        assert!((empty.max_degree() - 1.0).abs() < f64::EPSILON);

        // Two nodes but no edges: the 1.0 floor still applies.
        let mut graph = GraphEngine::new();
        graph.add_node(file_node("a", "a.rs")).unwrap();
        graph.add_node(file_node("b", "b.rs")).unwrap();
        assert!((graph.max_degree() - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn enhanced_graph_strength_differs_from_simple_edge_count() {
        // Build a graph where PageRank/betweenness differ from simple edge count.
        // a -> b -> c -> d (chain)
        // b and c sit at different depths of the chain, so the PageRank flow
        // reaching them differs even though their degrees are equal.
        let mut graph = GraphEngine::new();
        for id in &["a", "b", "c", "d"] {
            graph.add_node(file_node(id, &format!("{id}.rs"))).unwrap();
        }
        graph.add_edge(test_edge("a", "b")).unwrap();
        graph.add_edge(test_edge("b", "c")).unwrap();
        graph.add_edge(test_edge("c", "d")).unwrap();
        graph.recompute_centrality();

        // Nodes b and c both have 2 edges (in+out), so a simple edge count
        // scores them equally. Their directed betweenness is also equal:
        // b lies on a->c and a->d, c lies on a->d and b->d. PageRank is the
        // metric expected to separate them.
        let edges_b = graph.get_edges("b").unwrap().len();
        let edges_c = graph.get_edges("c").unwrap().len();
        assert_eq!(edges_b, edges_c, "b and c should have same edge count");

        // But their centrality profiles should differ
        let pr_b = graph.get_pagerank("b");
        let pr_c = graph.get_pagerank("c");
        let bt_b = graph.get_betweenness("b");
        let bt_c = graph.get_betweenness("c");

        // At least one centrality metric should differ between b and c
        let centrality_differs = (pr_b - pr_c).abs() > 1e-6 || (bt_b - bt_c).abs() > 1e-6;
        assert!(
            centrality_differs,
            "Centrality should differ: b(pr={pr_b}, bt={bt_b}) vs c(pr={pr_c}, bt={bt_c})"
        );
    }
}