Skip to main content

codemem_storage/graph/
mod.rs

1//! codemem-graph: Graph engine with petgraph algorithms and SQLite persistence.
2//!
3//! Provides BFS, DFS, shortest path, and connected components over
4//! a knowledge graph with 13 node kinds and 24 relationship types.
5
6mod algorithms;
7mod traversal;
8
9#[cfg(test)]
10use codemem_core::NodeKind;
11use codemem_core::{CodememError, Edge, GraphBackend, GraphNode};
12use petgraph::graph::{DiGraph, NodeIndex};
13use petgraph::Direction;
14use std::collections::{HashMap, HashSet, VecDeque};
15
/// Raw graph metrics for a memory node, collected from its code-graph neighbors.
///
/// Returned by [`GraphEngine::raw_graph_metrics_for_memory()`] so that the
/// scoring formula can live in the engine crate; this struct only carries the
/// raw inputs to that formula.
///
/// A "code-graph neighbor" is an adjacent node (reached through an edge in either
/// direction) whose ID starts with `sym:`, `file:`, `chunk:`, or `pkg:`.
///
/// `Default` yields all-zero metrics.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct RawGraphMetrics {
    /// Highest cached PageRank score among code-graph neighbors
    /// (0.0 when no neighbor has a cached score).
    pub max_pagerank: f64,
    /// Highest cached betweenness centrality among code-graph neighbors
    /// (0.0 when no neighbor has a cached score).
    pub max_betweenness: f64,
    /// Number of code-graph neighbors (`sym:`, `file:`, `chunk:`, `pkg:`).
    pub code_neighbor_count: usize,
    /// Sum of edge weights connecting this memory to code-graph neighbors.
    pub total_edge_weight: f64,
}
31
32/// In-memory graph engine backed by petgraph, synced to SQLite via codemem-storage.
33///
34/// # Design: intentional in-memory architecture
35///
36/// All graph data (nodes, edges, adjacency) is held entirely in memory using
37/// `HashMap`-based structures. This is deliberate: graph traversals, centrality
38/// algorithms, and multi-hop expansions benefit enormously from avoiding disk
39/// I/O on every edge follow. The trade-off is higher memory usage, which is
40/// acceptable for the typical code-graph sizes this engine targets.
41///
42/// # Memory characteristics
43///
44/// - **`nodes`**: `HashMap<String, GraphNode>` — ~200 bytes per node (ID, kind,
45///   label, namespace, metadata, centrality).
46/// - **`edges`**: `HashMap<String, Edge>` — ~150 bytes per edge (ID, src, dst,
47///   relationship, weight, properties, timestamps).
48/// - **`edge_adj`**: `HashMap<String, Vec<String>>` — adjacency index mapping
49///   node IDs to incident edge IDs for O(degree) lookups.
50/// - **`id_to_index`**: maps string IDs to petgraph `NodeIndex` values.
51/// - **`cached_pagerank` / `cached_betweenness`**: centrality caches populated
52///   by [`recompute_centrality()`](Self::recompute_centrality).
53///
54/// Use [`CodememEngine::graph_memory_estimate()`](../../codemem_engine) for a
55/// byte-level sizing estimate based on current node and edge counts.
56///
57/// # Thread safety
58///
59/// `GraphEngine` is **not** `Sync` — it stores mutable graph state without
60/// internal locking. Callers in codemem-engine wrap it in `Mutex<GraphEngine>`
61/// (via `lock_graph()`) to ensure exclusive access. All public `&mut self`
62/// methods (e.g., `add_node`, `recompute_centrality`) require the caller to
63/// hold the lock.
64pub struct GraphEngine {
65    pub(crate) graph: DiGraph<String, f64>,
66    /// Map from string node IDs to petgraph `NodeIndex`.
67    pub(crate) id_to_index: HashMap<String, NodeIndex>,
68    /// Node data by ID.
69    pub(crate) nodes: HashMap<String, GraphNode>,
70    /// Edge data by ID.
71    pub(crate) edges: HashMap<String, Edge>,
72    /// Edge adjacency index: maps node IDs to the IDs of edges incident on that node.
73    ///
74    /// Maintained alongside `edges` to allow O(degree) edge lookups instead of O(E).
75    /// The string duplication (~40 bytes/edge for source+target node ID copies) is
76    /// intentional: using `Arc<str>` for shared ownership would be too invasive for
77    /// the marginal memory savings, and the adjacency index is critical for
78    /// performance in `get_edges()`, `bfs_filtered()`, and `raw_graph_metrics_for_memory()`.
79    pub(crate) edge_adj: HashMap<String, Vec<String>>,
80    /// Cached PageRank scores (populated by [`recompute_centrality()`](Self::recompute_centrality)).
81    pub(crate) cached_pagerank: HashMap<String, f64>,
82    /// Cached betweenness centrality scores (populated by [`recompute_centrality()`](Self::recompute_centrality)).
83    pub(crate) cached_betweenness: HashMap<String, f64>,
84}
85
86impl GraphEngine {
87    /// Create a new empty graph.
88    pub fn new() -> Self {
89        Self {
90            graph: DiGraph::new(),
91            id_to_index: HashMap::new(),
92            nodes: HashMap::new(),
93            edges: HashMap::new(),
94            edge_adj: HashMap::new(),
95            cached_pagerank: HashMap::new(),
96            cached_betweenness: HashMap::new(),
97        }
98    }
99
100    /// Load graph from storage.
101    pub fn from_storage(storage: &dyn codemem_core::StorageBackend) -> Result<Self, CodememError> {
102        let mut engine = Self::new();
103
104        // Load all nodes
105        let nodes = storage.all_graph_nodes()?;
106        for node in nodes {
107            engine.add_node(node)?;
108        }
109
110        // Load all edges
111        let edges = storage.all_graph_edges()?;
112        for edge in edges {
113            engine.add_edge(edge)?;
114        }
115
116        // Compute degree centrality so subgraph queries can rank nodes
117        engine.compute_centrality();
118
119        Ok(engine)
120    }
121
122    /// Get the number of nodes.
123    pub fn node_count(&self) -> usize {
124        self.nodes.len()
125    }
126
127    /// Get the number of edges.
128    pub fn edge_count(&self) -> usize {
129        self.edges.len()
130    }
131
132    /// Multi-hop expansion: given a set of node IDs, expand N hops to find related nodes.
133    pub fn expand(
134        &self,
135        start_ids: &[String],
136        max_hops: usize,
137    ) -> Result<Vec<GraphNode>, CodememError> {
138        let mut visited = std::collections::HashSet::new();
139        let mut result = Vec::new();
140
141        for start_id in start_ids {
142            let nodes = self.bfs(start_id, max_hops)?;
143            for node in nodes {
144                if visited.insert(node.id.clone()) {
145                    result.push(node);
146                }
147            }
148        }
149
150        Ok(result)
151    }
152
153    /// Get neighbors of a node (1-hop).
154    pub fn neighbors(&self, node_id: &str) -> Result<Vec<GraphNode>, CodememError> {
155        let idx = self
156            .id_to_index
157            .get(node_id)
158            .ok_or_else(|| CodememError::NotFound(format!("Node {node_id}")))?;
159
160        let mut result = Vec::new();
161        for neighbor_idx in self.graph.neighbors(*idx) {
162            if let Some(neighbor_id) = self.graph.node_weight(neighbor_idx) {
163                if let Some(node) = self.nodes.get(neighbor_id) {
164                    result.push(node.clone());
165                }
166            }
167        }
168
169        Ok(result)
170    }
171
172    /// Return groups of connected node IDs.
173    ///
174    /// Treats the directed graph as undirected: two nodes are in the same
175    /// component if there is a path between them in either direction.
176    /// Each inner `Vec<String>` is one connected component.
177    pub fn connected_components(&self) -> Vec<Vec<String>> {
178        let mut visited: HashSet<NodeIndex> = HashSet::new();
179        let mut components: Vec<Vec<String>> = Vec::new();
180
181        for &start_idx in self.id_to_index.values() {
182            if visited.contains(&start_idx) {
183                continue;
184            }
185
186            // BFS treating edges as undirected
187            let mut component: Vec<String> = Vec::new();
188            let mut queue: VecDeque<NodeIndex> = VecDeque::new();
189            queue.push_back(start_idx);
190            visited.insert(start_idx);
191
192            while let Some(current) = queue.pop_front() {
193                if let Some(node_id) = self.graph.node_weight(current) {
194                    component.push(node_id.clone());
195                }
196
197                // Follow outgoing edges
198                for neighbor in self.graph.neighbors_directed(current, Direction::Outgoing) {
199                    if visited.insert(neighbor) {
200                        queue.push_back(neighbor);
201                    }
202                }
203
204                // Follow incoming edges (treat as undirected)
205                for neighbor in self.graph.neighbors_directed(current, Direction::Incoming) {
206                    if visited.insert(neighbor) {
207                        queue.push_back(neighbor);
208                    }
209                }
210            }
211
212            component.sort();
213            components.push(component);
214        }
215
216        components.sort();
217        components
218    }
219
220    /// Compute degree centrality for every node and update their `centrality` field.
221    ///
222    /// Degree centrality for node *v* is defined as:
223    ///   `(in_degree(v) + out_degree(v)) / (N - 1)`
224    /// where *N* is the total number of nodes.  When N <= 1, centrality is 0.
225    pub fn compute_centrality(&mut self) {
226        let n = self.nodes.len();
227        if n <= 1 {
228            for node in self.nodes.values_mut() {
229                node.centrality = 0.0;
230            }
231            return;
232        }
233
234        let denominator = (n - 1) as f64;
235
236        // Pre-compute centrality values by node ID.
237        let centrality_map: HashMap<String, f64> = self
238            .id_to_index
239            .iter()
240            .map(|(id, &idx)| {
241                let in_deg = self
242                    .graph
243                    .neighbors_directed(idx, Direction::Incoming)
244                    .count();
245                let out_deg = self
246                    .graph
247                    .neighbors_directed(idx, Direction::Outgoing)
248                    .count();
249                let centrality = (in_deg + out_deg) as f64 / denominator;
250                (id.clone(), centrality)
251            })
252            .collect();
253
254        // Apply centrality values to the stored nodes.
255        for (id, centrality) in &centrality_map {
256            if let Some(node) = self.nodes.get_mut(id) {
257                node.centrality = *centrality;
258            }
259        }
260    }
261
262    /// Return all nodes currently in the graph.
263    pub fn get_all_nodes(&self) -> Vec<GraphNode> {
264        self.nodes.values().cloned().collect()
265    }
266
267    /// Return references to all node IDs without cloning.
268    pub fn get_all_node_ids(&self) -> Vec<&str> {
269        self.nodes.keys().map(|s| s.as_str()).collect()
270    }
271
272    /// Return a reference to a node without cloning. Returns `None` if not found.
273    pub fn get_node_ref(&self, id: &str) -> Option<&GraphNode> {
274        self.nodes.get(id)
275    }
276
277    /// Return references to edges incident on a node without cloning.
278    ///
279    /// This is the zero-copy variant of [`GraphBackend::get_edges()`] — same
280    /// lookup logic via `edge_adj`, but returns `&Edge` instead of owned `Edge`.
281    pub fn get_edges_ref(&self, node_id: &str) -> Vec<&Edge> {
282        self.edge_adj
283            .get(node_id)
284            .map(|edge_ids| {
285                edge_ids
286                    .iter()
287                    .filter_map(|eid| self.edges.get(eid))
288                    .collect()
289            })
290            .unwrap_or_default()
291    }
292
293    /// Recompute and cache PageRank and betweenness centrality scores.
294    ///
295    /// This should be called after loading the graph (e.g., on server start)
296    /// and periodically when the graph changes significantly.
297    pub fn recompute_centrality(&mut self) {
298        self.recompute_centrality_with_options(true);
299    }
300
301    /// Recompute centrality caches with control over which algorithms run.
302    ///
303    /// PageRank is always computed. Betweenness centrality is only computed
304    /// when `include_betweenness` is true, since it is O(V * E) and can be
305    /// expensive on large graphs.
306    pub fn recompute_centrality_with_options(&mut self, include_betweenness: bool) {
307        self.cached_pagerank = self.pagerank(0.85, 100, 1e-6);
308        if include_betweenness {
309            self.cached_betweenness = self.betweenness_centrality();
310        } else {
311            // L1: Clear stale betweenness cache so ensure_betweenness_computed()
312            // knows it needs to recompute when lazily invoked.
313            self.cached_betweenness.clear();
314        }
315    }
316
317    /// Lazily ensure betweenness centrality has been computed.
318    ///
319    /// If `cached_betweenness` is empty (e.g., after `recompute_centrality_with_options(false)`),
320    /// this method computes and caches betweenness centrality on demand. If the
321    /// cache is already populated, this is a no-op.
322    pub fn ensure_betweenness_computed(&mut self) {
323        if self.cached_betweenness.is_empty() && self.graph.node_count() > 0 {
324            self.cached_betweenness = self.betweenness_centrality();
325        }
326    }
327
328    /// Get the cached PageRank score for a node. Returns 0.0 if not found.
329    pub fn get_pagerank(&self, node_id: &str) -> f64 {
330        self.cached_pagerank.get(node_id).copied().unwrap_or(0.0)
331    }
332
333    /// Get the cached betweenness centrality score for a node. Returns 0.0 if not found.
334    pub fn get_betweenness(&self, node_id: &str) -> f64 {
335        self.cached_betweenness.get(node_id).copied().unwrap_or(0.0)
336    }
337
338    /// Collect raw graph metrics for a memory node by bridging to code-graph neighbors.
339    ///
340    /// Memory nodes (UUIDs) and code nodes (`sym:`, `file:`) exist in disconnected
341    /// ID spaces. This method looks up a memory node's neighbors and collects
342    /// centrality data from any connected code-graph nodes.
343    ///
344    /// Returns `None` if the memory node is not in the graph or has no code neighbors.
345    pub fn raw_graph_metrics_for_memory(&self, memory_id: &str) -> Option<RawGraphMetrics> {
346        let idx = *self.id_to_index.get(memory_id)?;
347
348        let mut max_pagerank = 0.0_f64;
349        let mut max_betweenness = 0.0_f64;
350        let mut code_neighbor_count = 0_usize;
351        let mut total_edge_weight = 0.0_f64;
352
353        // Iterate both outgoing and incoming neighbors
354        for direction in &[Direction::Outgoing, Direction::Incoming] {
355            for neighbor_idx in self.graph.neighbors_directed(idx, *direction) {
356                if let Some(neighbor_id) = self.graph.node_weight(neighbor_idx) {
357                    // Only consider code-graph nodes (sym:, file:, chunk:, pkg:)
358                    if neighbor_id.starts_with("sym:")
359                        || neighbor_id.starts_with("file:")
360                        || neighbor_id.starts_with("chunk:")
361                        || neighbor_id.starts_with("pkg:")
362                    {
363                        code_neighbor_count += 1;
364                        let pr = self
365                            .cached_pagerank
366                            .get(neighbor_id)
367                            .copied()
368                            .unwrap_or(0.0);
369                        let bt = self
370                            .cached_betweenness
371                            .get(neighbor_id)
372                            .copied()
373                            .unwrap_or(0.0);
374                        max_pagerank = max_pagerank.max(pr);
375                        max_betweenness = max_betweenness.max(bt);
376
377                        // Collect edge weight from our edge adjacency index
378                        if let Some(edge_ids) = self.edge_adj.get(memory_id) {
379                            for eid in edge_ids {
380                                if let Some(edge) = self.edges.get(eid) {
381                                    if (edge.src == memory_id && edge.dst == *neighbor_id)
382                                        || (edge.dst == memory_id && edge.src == *neighbor_id)
383                                    {
384                                        total_edge_weight += edge.weight;
385                                        break;
386                                    }
387                                }
388                            }
389                        }
390                    }
391                }
392            }
393        }
394
395        if code_neighbor_count == 0 {
396            return None;
397        }
398
399        Some(RawGraphMetrics {
400            max_pagerank,
401            max_betweenness,
402            code_neighbor_count,
403            total_edge_weight,
404        })
405    }
406
407    /// Get the maximum degree (in + out) across all nodes in the graph.
408    /// Returns 1.0 if the graph has fewer than 2 nodes to avoid division by zero.
409    pub fn max_degree(&self) -> f64 {
410        if self.nodes.len() <= 1 {
411            return 1.0;
412        }
413        self.id_to_index
414            .values()
415            .map(|&idx| {
416                let in_deg = self
417                    .graph
418                    .neighbors_directed(idx, Direction::Incoming)
419                    .count();
420                let out_deg = self
421                    .graph
422                    .neighbors_directed(idx, Direction::Outgoing)
423                    .count();
424                (in_deg + out_deg) as f64
425            })
426            .fold(1.0f64, f64::max)
427    }
428}
429
430#[cfg(test)]
431#[path = "../tests/graph_tests.rs"]
432mod tests;