Skip to main content

codemem_storage/graph/
mod.rs

//! codemem-graph: Graph engine with petgraph algorithms and SQLite persistence.
//!
//! Provides BFS, DFS, shortest path, and connected components over
//! a knowledge graph with 13 node kinds and 24 relationship types.
5
6mod algorithms;
7mod traversal;
8
9use codemem_core::{CodememError, Edge, GraphBackend, GraphNode, NodeKind};
10use petgraph::graph::{DiGraph, NodeIndex};
11use petgraph::Direction;
12use std::collections::{HashMap, HashSet, VecDeque};
13
/// Raw graph metrics for a memory node, collected from its direct neighbors.
///
/// Returned by `GraphEngine::raw_graph_metrics_for_memory()` so that the
/// scoring formula can live in the engine crate.
///
/// `Default` yields an all-zero value, and `PartialEq` allows direct
/// comparison of two metric snapshots (useful in tests).
#[derive(Debug, Clone, PartialEq, Default)]
pub struct RawGraphMetrics {
    /// Highest cached PageRank score among code-graph neighbors.
    pub max_pagerank: f64,
    /// Highest cached betweenness centrality among code-graph neighbors.
    pub max_betweenness: f64,
    /// Number of code-graph neighbors (neighbors whose kind is not `NodeKind::Memory`).
    pub code_neighbor_count: usize,
    /// Sum of edge weights connecting this memory to code-graph neighbors.
    pub total_edge_weight: f64,
    /// Number of memory-to-memory neighbors.
    pub memory_neighbor_count: usize,
    /// Sum of edge weights connecting this memory to other memory nodes.
    pub memory_edge_weight: f64,
}
33
/// In-memory graph engine backed by petgraph, synced to SQLite via codemem-storage.
///
/// # Design: intentional in-memory architecture
///
/// All graph data (nodes, edges, adjacency) is held entirely in memory using
/// `HashMap`-based structures plus a petgraph `DiGraph`. This is deliberate:
/// graph traversals, centrality algorithms, and multi-hop expansions benefit
/// enormously from avoiding disk I/O on every edge follow. The trade-off is
/// higher memory usage, which is acceptable for the typical code-graph sizes
/// this engine targets.
///
/// # Memory characteristics
///
/// The per-item sizes below are rough estimates, not measured guarantees.
///
/// - **`nodes`**: `HashMap<String, GraphNode>` — ~200 bytes per node (ID, kind,
///   label, namespace, metadata, centrality).
/// - **`edges`**: `HashMap<String, Edge>` — ~150 bytes per edge (ID, src, dst,
///   relationship, weight, properties, timestamps).
/// - **`edge_adj`**: `HashMap<String, Vec<String>>` — adjacency index mapping
///   node IDs to incident edge IDs for O(degree) lookups.
/// - **`id_to_index`**: maps string IDs to petgraph `NodeIndex` values.
/// - **`cached_pagerank` / `cached_betweenness`**: centrality caches populated
///   by [`recompute_centrality()`](Self::recompute_centrality).
///
/// Use [`CodememEngine::graph_memory_estimate()`](../../codemem_engine) for a
/// byte-level sizing estimate based on current node and edge counts.
///
/// # Thread safety
///
/// `GraphEngine` performs **no internal locking** — it stores plain mutable
/// graph state, so concurrent mutation must be serialized by the caller.
/// Callers in codemem-engine wrap it in `Mutex<GraphEngine>` (via
/// `lock_graph()`) to ensure exclusive access. All public `&mut self`
/// methods (e.g., `add_node`, `recompute_centrality`) require the caller to
/// hold the lock.
pub struct GraphEngine {
    /// Directed petgraph structure. Node weights are the string node IDs
    /// (they are used to look nodes up in `nodes`); edge weights are `f64`.
    pub(crate) graph: DiGraph<String, f64>,
    /// Map from string node IDs to petgraph `NodeIndex`.
    pub(crate) id_to_index: HashMap<String, NodeIndex>,
    /// Node data by ID.
    pub(crate) nodes: HashMap<String, GraphNode>,
    /// Edge data by ID.
    pub(crate) edges: HashMap<String, Edge>,
    /// Edge adjacency index: maps node IDs to the IDs of edges incident on that node.
    ///
    /// Maintained alongside `edges` to allow O(degree) edge lookups instead of O(E).
    /// The string duplication (~40 bytes/edge for source+target node ID copies) is
    /// intentional: using `Arc<str>` for shared ownership would be too invasive for
    /// the marginal memory savings, and the adjacency index is critical for
    /// performance in `get_edges()`, `bfs_filtered()`, and `raw_graph_metrics_for_memory()`.
    pub(crate) edge_adj: HashMap<String, Vec<String>>,
    /// Cached PageRank scores (populated by [`recompute_centrality()`](Self::recompute_centrality)).
    pub(crate) cached_pagerank: HashMap<String, f64>,
    /// Cached betweenness centrality scores (populated by [`recompute_centrality()`](Self::recompute_centrality)).
    ///
    /// May be empty after `recompute_centrality_with_options(false)`; see
    /// `ensure_betweenness_computed()` for on-demand population.
    pub(crate) cached_betweenness: HashMap<String, f64>,
}
87
88impl GraphEngine {
89    /// Create a new empty graph.
90    pub fn new() -> Self {
91        Self {
92            graph: DiGraph::new(),
93            id_to_index: HashMap::new(),
94            nodes: HashMap::new(),
95            edges: HashMap::new(),
96            edge_adj: HashMap::new(),
97            cached_pagerank: HashMap::new(),
98            cached_betweenness: HashMap::new(),
99        }
100    }
101
102    /// Load graph from storage.
103    pub fn from_storage(storage: &dyn codemem_core::StorageBackend) -> Result<Self, CodememError> {
104        let mut engine = Self::new();
105
106        // Load all nodes
107        let nodes = storage.all_graph_nodes()?;
108        for node in nodes {
109            engine.add_node(node)?;
110        }
111
112        // Load all edges
113        let edges = storage.all_graph_edges()?;
114        for edge in edges {
115            engine.add_edge(edge)?;
116        }
117
118        // Compute degree centrality so subgraph queries can rank nodes
119        engine.compute_centrality();
120
121        Ok(engine)
122    }
123
124    /// Get the number of nodes.
125    pub fn node_count(&self) -> usize {
126        self.nodes.len()
127    }
128
129    /// Get the number of edges.
130    pub fn edge_count(&self) -> usize {
131        self.edges.len()
132    }
133
134    /// Multi-hop expansion: given a set of node IDs, expand N hops to find related nodes.
135    pub fn expand(
136        &self,
137        start_ids: &[String],
138        max_hops: usize,
139    ) -> Result<Vec<GraphNode>, CodememError> {
140        let mut visited = std::collections::HashSet::new();
141        let mut result = Vec::new();
142
143        for start_id in start_ids {
144            let nodes = self.bfs(start_id, max_hops)?;
145            for node in nodes {
146                if visited.insert(node.id.clone()) {
147                    result.push(node);
148                }
149            }
150        }
151
152        Ok(result)
153    }
154
155    /// Get neighbors of a node (1-hop).
156    pub fn neighbors(&self, node_id: &str) -> Result<Vec<GraphNode>, CodememError> {
157        let idx = self
158            .id_to_index
159            .get(node_id)
160            .ok_or_else(|| CodememError::NotFound(format!("Node {node_id}")))?;
161
162        let mut result = Vec::new();
163        for neighbor_idx in self.graph.neighbors(*idx) {
164            if let Some(neighbor_id) = self.graph.node_weight(neighbor_idx) {
165                if let Some(node) = self.nodes.get(neighbor_id) {
166                    result.push(node.clone());
167                }
168            }
169        }
170
171        Ok(result)
172    }
173
174    /// Return groups of connected node IDs.
175    ///
176    /// Treats the directed graph as undirected: two nodes are in the same
177    /// component if there is a path between them in either direction.
178    /// Each inner `Vec<String>` is one connected component.
179    pub fn connected_components(&self) -> Vec<Vec<String>> {
180        let mut visited: HashSet<NodeIndex> = HashSet::new();
181        let mut components: Vec<Vec<String>> = Vec::new();
182
183        for &start_idx in self.id_to_index.values() {
184            if visited.contains(&start_idx) {
185                continue;
186            }
187
188            // BFS treating edges as undirected
189            let mut component: Vec<String> = Vec::new();
190            let mut queue: VecDeque<NodeIndex> = VecDeque::new();
191            queue.push_back(start_idx);
192            visited.insert(start_idx);
193
194            while let Some(current) = queue.pop_front() {
195                if let Some(node_id) = self.graph.node_weight(current) {
196                    component.push(node_id.clone());
197                }
198
199                // Follow outgoing edges
200                for neighbor in self.graph.neighbors_directed(current, Direction::Outgoing) {
201                    if visited.insert(neighbor) {
202                        queue.push_back(neighbor);
203                    }
204                }
205
206                // Follow incoming edges (treat as undirected)
207                for neighbor in self.graph.neighbors_directed(current, Direction::Incoming) {
208                    if visited.insert(neighbor) {
209                        queue.push_back(neighbor);
210                    }
211                }
212            }
213
214            component.sort();
215            components.push(component);
216        }
217
218        components.sort();
219        components
220    }
221
222    /// Compute degree centrality for every node and update their `centrality` field.
223    ///
224    /// Degree centrality for node *v* is defined as:
225    ///   `(in_degree(v) + out_degree(v)) / (N - 1)`
226    /// where *N* is the total number of nodes.  When N <= 1, centrality is 0.
227    pub fn compute_centrality(&mut self) {
228        let n = self.nodes.len();
229        if n <= 1 {
230            for node in self.nodes.values_mut() {
231                node.centrality = 0.0;
232            }
233            return;
234        }
235
236        let denominator = (n - 1) as f64;
237
238        // Pre-compute centrality values by node ID.
239        let centrality_map: HashMap<String, f64> = self
240            .id_to_index
241            .iter()
242            .map(|(id, &idx)| {
243                let in_deg = self
244                    .graph
245                    .neighbors_directed(idx, Direction::Incoming)
246                    .count();
247                let out_deg = self
248                    .graph
249                    .neighbors_directed(idx, Direction::Outgoing)
250                    .count();
251                let centrality = (in_deg + out_deg) as f64 / denominator;
252                (id.clone(), centrality)
253            })
254            .collect();
255
256        // Apply centrality values to the stored nodes.
257        for (id, centrality) in &centrality_map {
258            if let Some(node) = self.nodes.get_mut(id) {
259                node.centrality = *centrality;
260            }
261        }
262    }
263
264    /// Return all nodes currently in the graph.
265    pub fn get_all_nodes(&self) -> Vec<GraphNode> {
266        self.nodes.values().cloned().collect()
267    }
268
269    /// Return a reference to a node without cloning. Returns `None` if not found.
270    pub fn get_node_ref(&self, id: &str) -> Option<&GraphNode> {
271        self.nodes.get(id)
272    }
273
274    /// Return references to edges incident on a node without cloning.
275    ///
276    /// This is the zero-copy variant of [`GraphBackend::get_edges()`] — same
277    /// lookup logic via `edge_adj`, but returns `&Edge` instead of owned `Edge`.
278    pub fn get_edges_ref(&self, node_id: &str) -> Vec<&Edge> {
279        self.edge_adj
280            .get(node_id)
281            .map(|edge_ids| {
282                edge_ids
283                    .iter()
284                    .filter_map(|eid| self.edges.get(eid))
285                    .collect()
286            })
287            .unwrap_or_default()
288    }
289
290    /// Recompute and cache PageRank and betweenness centrality scores.
291    ///
292    /// This should be called after loading the graph (e.g., on server start)
293    /// and periodically when the graph changes significantly.
294    pub fn recompute_centrality(&mut self) {
295        self.recompute_centrality_with_options(true);
296    }
297
298    /// Recompute centrality caches with control over which algorithms run.
299    ///
300    /// PageRank is always computed. Betweenness centrality is only computed
301    /// when `include_betweenness` is true, since it is O(V * E) and can be
302    /// expensive on large graphs.
303    pub fn recompute_centrality_with_options(&mut self, include_betweenness: bool) {
304        self.cached_pagerank = self.pagerank(0.85, 100, 1e-6);
305        if include_betweenness {
306            self.cached_betweenness = self.betweenness_centrality();
307        } else {
308            // L1: Clear stale betweenness cache so ensure_betweenness_computed()
309            // knows it needs to recompute when lazily invoked.
310            self.cached_betweenness.clear();
311        }
312    }
313
314    /// Lazily ensure betweenness centrality has been computed.
315    ///
316    /// If `cached_betweenness` is empty (e.g., after `recompute_centrality_with_options(false)`),
317    /// this method computes and caches betweenness centrality on demand. If the
318    /// cache is already populated, this is a no-op.
319    pub fn ensure_betweenness_computed(&mut self) {
320        if self.cached_betweenness.is_empty() && self.graph.node_count() > 0 {
321            self.cached_betweenness = self.betweenness_centrality();
322        }
323    }
324
325    /// Get the cached PageRank score for a node. Returns 0.0 if not found.
326    pub fn get_pagerank(&self, node_id: &str) -> f64 {
327        self.cached_pagerank.get(node_id).copied().unwrap_or(0.0)
328    }
329
330    /// Get the cached betweenness centrality score for a node. Returns 0.0 if not found.
331    pub fn get_betweenness(&self, node_id: &str) -> f64 {
332        self.cached_betweenness.get(node_id).copied().unwrap_or(0.0)
333    }
334
335    /// Collect raw graph metrics for a memory node from both code-graph and
336    /// memory-graph neighbors.
337    ///
338    /// Code neighbors (`sym:`, `file:`, `chunk:`, `pkg:`) contribute centrality
339    /// data (PageRank, betweenness). Memory neighbors (UUID-based) contribute
340    /// connectivity data (count + edge weights). A function with many linked
341    /// memories is more important; a memory with many linked memories has richer
342    /// context.
343    ///
344    /// Returns `None` if the memory node is not in the graph or has no neighbors.
345    pub fn raw_graph_metrics_for_memory(&self, memory_id: &str) -> Option<RawGraphMetrics> {
346        let idx = *self.id_to_index.get(memory_id)?;
347
348        let mut max_pagerank = 0.0_f64;
349        let mut max_betweenness = 0.0_f64;
350        let mut code_neighbor_count = 0_usize;
351        let mut total_edge_weight = 0.0_f64;
352        let mut memory_neighbor_count = 0_usize;
353        let mut memory_edge_weight = 0.0_f64;
354
355        // Iterate both outgoing and incoming neighbors
356        for direction in &[Direction::Outgoing, Direction::Incoming] {
357            for neighbor_idx in self.graph.neighbors_directed(idx, *direction) {
358                if let Some(neighbor_id) = self.graph.node_weight(neighbor_idx) {
359                    let is_code_node = self
360                        .nodes
361                        .get(neighbor_id.as_str())
362                        .map(|n| n.kind != NodeKind::Memory)
363                        .unwrap_or(false);
364
365                    // Collect edge weight from our edge adjacency index
366                    let mut edge_w = 0.0_f64;
367                    if let Some(edge_ids) = self.edge_adj.get(memory_id) {
368                        for eid in edge_ids {
369                            if let Some(edge) = self.edges.get(eid) {
370                                if (edge.src == memory_id && edge.dst == *neighbor_id)
371                                    || (edge.dst == memory_id && edge.src == *neighbor_id)
372                                {
373                                    edge_w = edge.weight;
374                                    break;
375                                }
376                            }
377                        }
378                    }
379
380                    if is_code_node {
381                        code_neighbor_count += 1;
382                        total_edge_weight += edge_w;
383
384                        let pr = self
385                            .cached_pagerank
386                            .get(neighbor_id)
387                            .copied()
388                            .unwrap_or(0.0);
389                        let bt = self
390                            .cached_betweenness
391                            .get(neighbor_id)
392                            .copied()
393                            .unwrap_or(0.0);
394                        max_pagerank = max_pagerank.max(pr);
395                        max_betweenness = max_betweenness.max(bt);
396                    } else {
397                        memory_neighbor_count += 1;
398                        memory_edge_weight += edge_w;
399                    }
400                }
401            }
402        }
403
404        if code_neighbor_count == 0 && memory_neighbor_count == 0 {
405            return None;
406        }
407
408        Some(RawGraphMetrics {
409            max_pagerank,
410            max_betweenness,
411            code_neighbor_count,
412            total_edge_weight,
413            memory_neighbor_count,
414            memory_edge_weight,
415        })
416    }
417
418    /// Get the maximum degree (in + out) across all nodes in the graph.
419    /// Returns 1.0 if the graph has fewer than 2 nodes to avoid division by zero.
420    pub fn max_degree(&self) -> f64 {
421        if self.nodes.len() <= 1 {
422            return 1.0;
423        }
424        self.id_to_index
425            .values()
426            .map(|&idx| {
427                let in_deg = self
428                    .graph
429                    .neighbors_directed(idx, Direction::Incoming)
430                    .count();
431                let out_deg = self
432                    .graph
433                    .neighbors_directed(idx, Direction::Outgoing)
434                    .count();
435                (in_deg + out_deg) as f64
436            })
437            .fold(1.0f64, f64::max)
438    }
439}
440
441#[cfg(test)]
442#[path = "../tests/graph_tests.rs"]
443mod tests;