Skip to main content

codemem_graph/
lib.rs

//! codemem-graph: Graph engine with petgraph algorithms and SQLite persistence.
//!
//! Provides BFS, DFS, shortest path, connected components, and centrality
//! measures (degree, PageRank, betweenness) over a knowledge graph with
//! 6 node types and 15 relationship types.
5
6mod algorithms;
7mod traversal;
8
9#[cfg(test)]
10use codemem_core::NodeKind;
11use codemem_core::{CodememError, Edge, GraphBackend, GraphNode};
12use petgraph::graph::{DiGraph, NodeIndex};
13use petgraph::Direction;
14use std::collections::{HashMap, HashSet, VecDeque};
15
16/// In-memory graph backed by petgraph, synced to SQLite via codemem-storage.
17pub struct GraphEngine {
18    pub(crate) graph: DiGraph<String, f64>,
19    /// Map from string node IDs to petgraph NodeIndex.
20    pub(crate) id_to_index: HashMap<String, NodeIndex>,
21    /// Node data by ID.
22    pub(crate) nodes: HashMap<String, GraphNode>,
23    /// Edge data by ID.
24    pub(crate) edges: HashMap<String, Edge>,
25    /// Cached PageRank scores (populated by `recompute_centrality()`).
26    pub(crate) cached_pagerank: HashMap<String, f64>,
27    /// Cached betweenness centrality scores (populated by `recompute_centrality()`).
28    pub(crate) cached_betweenness: HashMap<String, f64>,
29}
30
31impl GraphEngine {
32    /// Create a new empty graph.
33    pub fn new() -> Self {
34        Self {
35            graph: DiGraph::new(),
36            id_to_index: HashMap::new(),
37            nodes: HashMap::new(),
38            edges: HashMap::new(),
39            cached_pagerank: HashMap::new(),
40            cached_betweenness: HashMap::new(),
41        }
42    }
43
44    /// Load graph from storage.
45    pub fn from_storage(storage: &dyn codemem_core::StorageBackend) -> Result<Self, CodememError> {
46        let mut engine = Self::new();
47
48        // Load all nodes
49        let nodes = storage.all_graph_nodes()?;
50        for node in nodes {
51            engine.add_node(node)?;
52        }
53
54        // Load all edges
55        let edges = storage.all_graph_edges()?;
56        for edge in edges {
57            engine.add_edge(edge)?;
58        }
59
60        // Compute degree centrality so subgraph queries can rank nodes
61        engine.compute_centrality();
62
63        Ok(engine)
64    }
65
66    /// Get the number of nodes.
67    pub fn node_count(&self) -> usize {
68        self.nodes.len()
69    }
70
71    /// Get the number of edges.
72    pub fn edge_count(&self) -> usize {
73        self.edges.len()
74    }
75
76    /// Multi-hop expansion: given a set of node IDs, expand N hops to find related nodes.
77    pub fn expand(
78        &self,
79        start_ids: &[String],
80        max_hops: usize,
81    ) -> Result<Vec<GraphNode>, CodememError> {
82        let mut visited = std::collections::HashSet::new();
83        let mut result = Vec::new();
84
85        for start_id in start_ids {
86            let nodes = self.bfs(start_id, max_hops)?;
87            for node in nodes {
88                if visited.insert(node.id.clone()) {
89                    result.push(node);
90                }
91            }
92        }
93
94        Ok(result)
95    }
96
97    /// Get neighbors of a node (1-hop).
98    pub fn neighbors(&self, node_id: &str) -> Result<Vec<GraphNode>, CodememError> {
99        let idx = self
100            .id_to_index
101            .get(node_id)
102            .ok_or_else(|| CodememError::NotFound(format!("Node {node_id}")))?;
103
104        let mut result = Vec::new();
105        for neighbor_idx in self.graph.neighbors(*idx) {
106            if let Some(neighbor_id) = self.graph.node_weight(neighbor_idx) {
107                if let Some(node) = self.nodes.get(neighbor_id) {
108                    result.push(node.clone());
109                }
110            }
111        }
112
113        Ok(result)
114    }
115
116    /// Return groups of connected node IDs.
117    ///
118    /// Treats the directed graph as undirected: two nodes are in the same
119    /// component if there is a path between them in either direction.
120    /// Each inner `Vec<String>` is one connected component.
121    pub fn connected_components(&self) -> Vec<Vec<String>> {
122        let mut visited: HashSet<NodeIndex> = HashSet::new();
123        let mut components: Vec<Vec<String>> = Vec::new();
124
125        for &start_idx in self.id_to_index.values() {
126            if visited.contains(&start_idx) {
127                continue;
128            }
129
130            // BFS treating edges as undirected
131            let mut component: Vec<String> = Vec::new();
132            let mut queue: VecDeque<NodeIndex> = VecDeque::new();
133            queue.push_back(start_idx);
134            visited.insert(start_idx);
135
136            while let Some(current) = queue.pop_front() {
137                if let Some(node_id) = self.graph.node_weight(current) {
138                    component.push(node_id.clone());
139                }
140
141                // Follow outgoing edges
142                for neighbor in self.graph.neighbors_directed(current, Direction::Outgoing) {
143                    if visited.insert(neighbor) {
144                        queue.push_back(neighbor);
145                    }
146                }
147
148                // Follow incoming edges (treat as undirected)
149                for neighbor in self.graph.neighbors_directed(current, Direction::Incoming) {
150                    if visited.insert(neighbor) {
151                        queue.push_back(neighbor);
152                    }
153                }
154            }
155
156            component.sort();
157            components.push(component);
158        }
159
160        components.sort();
161        components
162    }
163
164    /// Compute degree centrality for every node and update their `centrality` field.
165    ///
166    /// Degree centrality for node *v* is defined as:
167    ///   `(in_degree(v) + out_degree(v)) / (N - 1)`
168    /// where *N* is the total number of nodes.  When N <= 1, centrality is 0.
169    pub fn compute_centrality(&mut self) {
170        let n = self.nodes.len();
171        if n <= 1 {
172            for node in self.nodes.values_mut() {
173                node.centrality = 0.0;
174            }
175            return;
176        }
177
178        let denominator = (n - 1) as f64;
179
180        // Pre-compute centrality values by node ID.
181        let centrality_map: HashMap<String, f64> = self
182            .id_to_index
183            .iter()
184            .map(|(id, &idx)| {
185                let in_deg = self
186                    .graph
187                    .neighbors_directed(idx, Direction::Incoming)
188                    .count();
189                let out_deg = self
190                    .graph
191                    .neighbors_directed(idx, Direction::Outgoing)
192                    .count();
193                let centrality = (in_deg + out_deg) as f64 / denominator;
194                (id.clone(), centrality)
195            })
196            .collect();
197
198        // Apply centrality values to the stored nodes.
199        for (id, centrality) in &centrality_map {
200            if let Some(node) = self.nodes.get_mut(id) {
201                node.centrality = *centrality;
202            }
203        }
204    }
205
206    /// Return all nodes currently in the graph.
207    pub fn get_all_nodes(&self) -> Vec<GraphNode> {
208        self.nodes.values().cloned().collect()
209    }
210
211    /// Recompute and cache PageRank and betweenness centrality scores.
212    ///
213    /// This should be called after loading the graph (e.g., on server start)
214    /// and periodically when the graph changes significantly.
215    pub fn recompute_centrality(&mut self) {
216        self.cached_pagerank = self.pagerank(0.85, 100, 1e-6);
217        self.cached_betweenness = self.betweenness_centrality();
218    }
219
220    /// Get the cached PageRank score for a node. Returns 0.0 if not found.
221    pub fn get_pagerank(&self, node_id: &str) -> f64 {
222        self.cached_pagerank.get(node_id).copied().unwrap_or(0.0)
223    }
224
225    /// Get the cached betweenness centrality score for a node. Returns 0.0 if not found.
226    pub fn get_betweenness(&self, node_id: &str) -> f64 {
227        self.cached_betweenness.get(node_id).copied().unwrap_or(0.0)
228    }
229
230    /// Compute graph strength for a memory node by bridging to code-graph centrality.
231    ///
232    /// Memory nodes (UUIDs) and code nodes (`sym:`, `file:`) exist in disconnected
233    /// ID spaces. This method looks up a memory node's neighbors and collects
234    /// centrality data from any connected code-graph nodes to produce a meaningful
235    /// graph_strength score.
236    pub fn graph_strength_for_memory(&self, memory_id: &str) -> f64 {
237        let idx = match self.id_to_index.get(memory_id) {
238            Some(idx) => *idx,
239            None => return 0.0,
240        };
241
242        let mut max_pagerank = 0.0_f64;
243        let mut max_betweenness = 0.0_f64;
244        let mut code_neighbor_count = 0_usize;
245        let mut total_edge_weight = 0.0_f64;
246
247        // Iterate both outgoing and incoming neighbors
248        for direction in &[Direction::Outgoing, Direction::Incoming] {
249            for neighbor_idx in self.graph.neighbors_directed(idx, *direction) {
250                if let Some(neighbor_id) = self.graph.node_weight(neighbor_idx) {
251                    // Only consider code-graph nodes (sym:, file:, chunk:, pkg:)
252                    if neighbor_id.starts_with("sym:")
253                        || neighbor_id.starts_with("file:")
254                        || neighbor_id.starts_with("chunk:")
255                        || neighbor_id.starts_with("pkg:")
256                    {
257                        code_neighbor_count += 1;
258                        let pr = self
259                            .cached_pagerank
260                            .get(neighbor_id)
261                            .copied()
262                            .unwrap_or(0.0);
263                        let bt = self
264                            .cached_betweenness
265                            .get(neighbor_id)
266                            .copied()
267                            .unwrap_or(0.0);
268                        max_pagerank = max_pagerank.max(pr);
269                        max_betweenness = max_betweenness.max(bt);
270
271                        // Collect edge weight from our edge metadata
272                        for edge in self.edges.values() {
273                            if (edge.src == memory_id && edge.dst == *neighbor_id)
274                                || (edge.dst == memory_id && edge.src == *neighbor_id)
275                            {
276                                total_edge_weight += edge.weight;
277                                break;
278                            }
279                        }
280                    }
281                }
282            }
283        }
284
285        if code_neighbor_count == 0 {
286            return 0.0;
287        }
288
289        let connectivity_bonus = (code_neighbor_count as f64 / 5.0).min(1.0);
290        let edge_weight_bonus = (total_edge_weight / code_neighbor_count as f64).min(1.0);
291
292        (0.4 * max_pagerank
293            + 0.3 * max_betweenness
294            + 0.2 * connectivity_bonus
295            + 0.1 * edge_weight_bonus)
296            .min(1.0)
297    }
298
299    /// Get the maximum degree (in + out) across all nodes in the graph.
300    /// Returns 1.0 if the graph has fewer than 2 nodes to avoid division by zero.
301    pub fn max_degree(&self) -> f64 {
302        if self.nodes.len() <= 1 {
303            return 1.0;
304        }
305        self.id_to_index
306            .values()
307            .map(|&idx| {
308                let in_deg = self
309                    .graph
310                    .neighbors_directed(idx, Direction::Incoming)
311                    .count();
312                let out_deg = self
313                    .graph
314                    .neighbors_directed(idx, Direction::Outgoing)
315                    .count();
316                (in_deg + out_deg) as f64
317            })
318            .fold(1.0f64, f64::max)
319    }
320}
321
322#[cfg(test)]
323#[path = "tests/lib_tests.rs"]
324mod tests;