Skip to main content

codemem_storage/graph/
mod.rs

1//! codemem-graph: Graph engine with petgraph algorithms and SQLite persistence.
2//!
3//! Provides BFS, DFS, shortest path, and connected components over
4//! a knowledge graph with 13 node kinds and 24 relationship types.
5
6mod algorithms;
7mod traversal;
8
9use codemem_core::{CodememError, Edge, GraphBackend, GraphNode, NodeKind, RawGraphMetrics};
10use petgraph::graph::{DiGraph, NodeIndex};
11use petgraph::Direction;
12use std::collections::{HashMap, HashSet, VecDeque};
13
/// In-memory graph engine backed by petgraph, synced to SQLite via codemem-storage.
///
/// # Design: intentional in-memory architecture
///
/// All graph data (nodes, edges, adjacency) is held entirely in memory using
/// `HashMap`-based structures. This is deliberate: graph traversals, centrality
/// algorithms, and multi-hop expansions benefit enormously from avoiding disk
/// I/O on every edge follow. The trade-off is higher memory usage, which is
/// acceptable for the typical code-graph sizes this engine targets.
///
/// # Memory characteristics
///
/// - **`nodes`**: `HashMap<String, GraphNode>` — ~200 bytes per node (ID, kind,
///   label, namespace, metadata, centrality).
/// - **`edges`**: `HashMap<String, Edge>` — ~150 bytes per edge (ID, src, dst,
///   relationship, weight, properties, timestamps).
/// - **`edge_adj`**: `HashMap<String, Vec<String>>` — adjacency index mapping
///   node IDs to incident edge IDs for O(degree) lookups.
/// - **`id_to_index`**: maps string IDs to petgraph `NodeIndex` values.
/// - **`cached_pagerank` / `cached_betweenness`**: centrality caches populated
///   by [`recompute_centrality()`](Self::recompute_centrality).
///
/// Use [`CodememEngine::graph_memory_estimate()`](../../codemem_engine) for a
/// byte-level sizing estimate based on current node and edge counts.
///
/// # Thread safety
///
/// `GraphEngine` has no internal locking: all mutation goes through
/// `&mut self`, so sharing one instance between threads requires external
/// synchronization. Callers in codemem-engine wrap it in `Mutex<GraphEngine>`
/// (via `lock_graph()`) to ensure exclusive access. All public `&mut self`
/// methods (e.g., `add_node`, `recompute_centrality`) require the caller to
/// hold the lock.
///
/// NOTE(review): `Sync` is an auto trait derived from the field types and
/// nothing in this file opts out of it, so the type is presumably still
/// `Sync`; the constraint above is about exclusive access for mutation, not
/// about the auto trait — confirm against the `codemem_core` field types.
pub struct GraphEngine {
    /// Underlying petgraph directed graph. Node weights are the string node
    /// IDs (they are used as keys into `nodes`); edge weights are `f64`.
    pub(crate) graph: DiGraph<String, f64>,
    /// Map from string node IDs to petgraph `NodeIndex`.
    pub(crate) id_to_index: HashMap<String, NodeIndex>,
    /// Node data by ID.
    pub(crate) nodes: HashMap<String, GraphNode>,
    /// Edge data by ID.
    pub(crate) edges: HashMap<String, Edge>,
    /// Edge adjacency index: maps node IDs to the IDs of edges incident on that node.
    ///
    /// Maintained alongside `edges` to allow O(degree) edge lookups instead of O(E).
    /// The string duplication (~40 bytes/edge for source+target node ID copies) is
    /// intentional: using `Arc<str>` for shared ownership would be too invasive for
    /// the marginal memory savings, and the adjacency index is critical for
    /// performance in `get_edges()`, `bfs_filtered()`, and `raw_graph_metrics_for_memory()`.
    pub(crate) edge_adj: HashMap<String, Vec<String>>,
    /// Cached PageRank scores (populated by [`recompute_centrality()`](Self::recompute_centrality)).
    pub(crate) cached_pagerank: HashMap<String, f64>,
    /// Cached betweenness centrality scores (populated by [`recompute_centrality()`](Self::recompute_centrality)).
    pub(crate) cached_betweenness: HashMap<String, f64>,
}
67
68impl GraphEngine {
    /// Create a new empty graph.
    ///
    /// Every index, data map, and centrality cache starts out empty; nothing
    /// is read from storage here (see `from_storage` for that).
    pub fn new() -> Self {
        Self {
            graph: DiGraph::new(),
            id_to_index: HashMap::new(),
            nodes: HashMap::new(),
            edges: HashMap::new(),
            edge_adj: HashMap::new(),
            cached_pagerank: HashMap::new(),
            cached_betweenness: HashMap::new(),
        }
    }
81
82    /// Load graph from storage.
83    pub fn from_storage(storage: &dyn codemem_core::StorageBackend) -> Result<Self, CodememError> {
84        let mut engine = Self::new();
85
86        // Load all nodes
87        let nodes = storage.all_graph_nodes()?;
88        for node in nodes {
89            engine.add_node(node)?;
90        }
91
92        // Load all edges
93        let edges = storage.all_graph_edges()?;
94        for edge in edges {
95            engine.add_edge(edge)?;
96        }
97
98        // Compute degree centrality so subgraph queries can rank nodes
99        engine.compute_centrality();
100
101        Ok(engine)
102    }
103
    /// Get the number of nodes.
    ///
    /// Counts entries in the `nodes` map, not vertices in the petgraph graph.
    pub fn node_count(&self) -> usize {
        self.nodes.len()
    }
108
    /// Get the number of edges.
    ///
    /// Counts entries in the `edges` map, not edges in the petgraph graph.
    pub fn edge_count(&self) -> usize {
        self.edges.len()
    }
113
114    /// Multi-hop expansion: given a set of node IDs, expand N hops to find related nodes.
115    pub fn expand(
116        &self,
117        start_ids: &[String],
118        max_hops: usize,
119    ) -> Result<Vec<GraphNode>, CodememError> {
120        let mut visited = std::collections::HashSet::new();
121        let mut result = Vec::new();
122
123        for start_id in start_ids {
124            let nodes = self.bfs(start_id, max_hops)?;
125            for node in nodes {
126                if visited.insert(node.id.clone()) {
127                    result.push(node);
128                }
129            }
130        }
131
132        Ok(result)
133    }
134
135    /// Get neighbors of a node (1-hop).
136    pub fn neighbors(&self, node_id: &str) -> Result<Vec<GraphNode>, CodememError> {
137        let idx = self
138            .id_to_index
139            .get(node_id)
140            .ok_or_else(|| CodememError::NotFound(format!("Node {node_id}")))?;
141
142        let mut result = Vec::new();
143        for neighbor_idx in self.graph.neighbors(*idx) {
144            if let Some(neighbor_id) = self.graph.node_weight(neighbor_idx) {
145                if let Some(node) = self.nodes.get(neighbor_id) {
146                    result.push(node.clone());
147                }
148            }
149        }
150
151        Ok(result)
152    }
153
154    /// Return groups of connected node IDs.
155    ///
156    /// Treats the directed graph as undirected: two nodes are in the same
157    /// component if there is a path between them in either direction.
158    /// Each inner `Vec<String>` is one connected component.
159    pub fn connected_components(&self) -> Vec<Vec<String>> {
160        let mut visited: HashSet<NodeIndex> = HashSet::new();
161        let mut components: Vec<Vec<String>> = Vec::new();
162
163        for &start_idx in self.id_to_index.values() {
164            if visited.contains(&start_idx) {
165                continue;
166            }
167
168            // BFS treating edges as undirected
169            let mut component: Vec<String> = Vec::new();
170            let mut queue: VecDeque<NodeIndex> = VecDeque::new();
171            queue.push_back(start_idx);
172            visited.insert(start_idx);
173
174            while let Some(current) = queue.pop_front() {
175                if let Some(node_id) = self.graph.node_weight(current) {
176                    component.push(node_id.clone());
177                }
178
179                // Follow outgoing edges
180                for neighbor in self.graph.neighbors_directed(current, Direction::Outgoing) {
181                    if visited.insert(neighbor) {
182                        queue.push_back(neighbor);
183                    }
184                }
185
186                // Follow incoming edges (treat as undirected)
187                for neighbor in self.graph.neighbors_directed(current, Direction::Incoming) {
188                    if visited.insert(neighbor) {
189                        queue.push_back(neighbor);
190                    }
191                }
192            }
193
194            component.sort();
195            components.push(component);
196        }
197
198        components.sort();
199        components
200    }
201
202    /// Compute degree centrality for every node and update their `centrality` field.
203    ///
204    /// Degree centrality for node *v* is defined as:
205    ///   `(in_degree(v) + out_degree(v)) / (N - 1)`
206    /// where *N* is the total number of nodes.  When N <= 1, centrality is 0.
207    pub fn compute_centrality(&mut self) {
208        let n = self.nodes.len();
209        if n <= 1 {
210            for node in self.nodes.values_mut() {
211                node.centrality = 0.0;
212            }
213            return;
214        }
215
216        let denominator = (n - 1) as f64;
217
218        // Pre-compute centrality values by node ID.
219        let centrality_map: HashMap<String, f64> = self
220            .id_to_index
221            .iter()
222            .map(|(id, &idx)| {
223                let in_deg = self
224                    .graph
225                    .neighbors_directed(idx, Direction::Incoming)
226                    .count();
227                let out_deg = self
228                    .graph
229                    .neighbors_directed(idx, Direction::Outgoing)
230                    .count();
231                let centrality = (in_deg + out_deg) as f64 / denominator;
232                (id.clone(), centrality)
233            })
234            .collect();
235
236        // Apply centrality values to the stored nodes.
237        for (id, centrality) in &centrality_map {
238            if let Some(node) = self.nodes.get_mut(id) {
239                node.centrality = *centrality;
240            }
241        }
242    }
243
244    /// Return all nodes currently in the graph.
245    pub fn get_all_nodes(&self) -> Vec<GraphNode> {
246        self.nodes.values().cloned().collect()
247    }
248
    /// Return a reference to a node without cloning. Returns `None` if not found.
    ///
    /// The lookup goes straight to the `nodes` map by ID.
    pub fn get_node_ref(&self, id: &str) -> Option<&GraphNode> {
        self.nodes.get(id)
    }
253
254    /// Return references to edges incident on a node without cloning.
255    ///
256    /// This is the zero-copy variant of [`GraphBackend::get_edges()`] — same
257    /// lookup logic via `edge_adj`, but returns `&Edge` instead of owned `Edge`.
258    pub fn get_edges_ref(&self, node_id: &str) -> Vec<&Edge> {
259        self.edge_adj
260            .get(node_id)
261            .map(|edge_ids| {
262                edge_ids
263                    .iter()
264                    .filter_map(|eid| self.edges.get(eid))
265                    .collect()
266            })
267            .unwrap_or_default()
268    }
269
    /// Recompute and cache PageRank and betweenness centrality scores.
    ///
    /// This should be called after loading the graph (e.g., on server start)
    /// and periodically when the graph changes significantly.
    ///
    /// Equivalent to `recompute_centrality_with_options(true)`, i.e. the
    /// betweenness pass is always included.
    pub fn recompute_centrality(&mut self) {
        self.recompute_centrality_with_options(true);
    }
277
278    /// Recompute centrality caches with control over which algorithms run.
279    ///
280    /// PageRank is always computed. Betweenness centrality is only computed
281    /// when `include_betweenness` is true, since it is O(V * E) and can be
282    /// expensive on large graphs.
283    pub fn recompute_centrality_with_options(&mut self, include_betweenness: bool) {
284        self.cached_pagerank = self.pagerank(0.85, 100, 1e-6);
285        if include_betweenness {
286            self.cached_betweenness = self.betweenness_centrality();
287        } else {
288            // L1: Clear stale betweenness cache so ensure_betweenness_computed()
289            // knows it needs to recompute when lazily invoked.
290            self.cached_betweenness.clear();
291        }
292    }
293
294    /// Lazily ensure betweenness centrality has been computed.
295    ///
296    /// If `cached_betweenness` is empty (e.g., after `recompute_centrality_with_options(false)`),
297    /// this method computes and caches betweenness centrality on demand. If the
298    /// cache is already populated, this is a no-op.
299    pub fn ensure_betweenness_computed(&mut self) {
300        if self.cached_betweenness.is_empty() && self.graph.node_count() > 0 {
301            self.cached_betweenness = self.betweenness_centrality();
302        }
303    }
304
305    /// Get the cached PageRank score for a node. Returns 0.0 if not found.
306    pub fn get_pagerank(&self, node_id: &str) -> f64 {
307        self.cached_pagerank.get(node_id).copied().unwrap_or(0.0)
308    }
309
310    /// Get the cached betweenness centrality score for a node. Returns 0.0 if not found.
311    pub fn get_betweenness(&self, node_id: &str) -> f64 {
312        self.cached_betweenness.get(node_id).copied().unwrap_or(0.0)
313    }
314
315    /// Collect raw graph metrics for a memory node from both code-graph and
316    /// memory-graph neighbors.
317    ///
318    /// Code neighbors (`sym:`, `file:`, `chunk:`, `pkg:`) contribute centrality
319    /// data (PageRank, betweenness). Memory neighbors (UUID-based) contribute
320    /// connectivity data (count + edge weights). A function with many linked
321    /// memories is more important; a memory with many linked memories has richer
322    /// context.
323    ///
324    /// Returns `None` if the memory node is not in the graph or has no neighbors.
325    pub fn raw_graph_metrics_for_memory(&self, memory_id: &str) -> Option<RawGraphMetrics> {
326        let idx = *self.id_to_index.get(memory_id)?;
327
328        let mut max_pagerank = 0.0_f64;
329        let mut max_betweenness = 0.0_f64;
330        let mut code_neighbor_count = 0_usize;
331        let mut total_edge_weight = 0.0_f64;
332        let mut memory_neighbor_count = 0_usize;
333        let mut memory_edge_weight = 0.0_f64;
334
335        // Iterate both outgoing and incoming neighbors
336        for direction in &[Direction::Outgoing, Direction::Incoming] {
337            for neighbor_idx in self.graph.neighbors_directed(idx, *direction) {
338                if let Some(neighbor_id) = self.graph.node_weight(neighbor_idx) {
339                    let is_code_node = self
340                        .nodes
341                        .get(neighbor_id.as_str())
342                        .map(|n| n.kind != NodeKind::Memory)
343                        .unwrap_or(false);
344
345                    // Collect edge weight from our edge adjacency index
346                    let mut edge_w = 0.0_f64;
347                    if let Some(edge_ids) = self.edge_adj.get(memory_id) {
348                        for eid in edge_ids {
349                            if let Some(edge) = self.edges.get(eid) {
350                                if (edge.src == memory_id && edge.dst == *neighbor_id)
351                                    || (edge.dst == memory_id && edge.src == *neighbor_id)
352                                {
353                                    edge_w = edge.weight;
354                                    break;
355                                }
356                            }
357                        }
358                    }
359
360                    if is_code_node {
361                        code_neighbor_count += 1;
362                        total_edge_weight += edge_w;
363
364                        let pr = self
365                            .cached_pagerank
366                            .get(neighbor_id)
367                            .copied()
368                            .unwrap_or(0.0);
369                        let bt = self
370                            .cached_betweenness
371                            .get(neighbor_id)
372                            .copied()
373                            .unwrap_or(0.0);
374                        max_pagerank = max_pagerank.max(pr);
375                        max_betweenness = max_betweenness.max(bt);
376                    } else {
377                        memory_neighbor_count += 1;
378                        memory_edge_weight += edge_w;
379                    }
380                }
381            }
382        }
383
384        if code_neighbor_count == 0 && memory_neighbor_count == 0 {
385            return None;
386        }
387
388        Some(RawGraphMetrics {
389            max_pagerank,
390            max_betweenness,
391            code_neighbor_count,
392            total_edge_weight,
393            memory_neighbor_count,
394            memory_edge_weight,
395        })
396    }
397
398    /// Get the maximum degree (in + out) across all nodes in the graph.
399    /// Returns 1.0 if the graph has fewer than 2 nodes to avoid division by zero.
400    #[cfg(test)]
401    pub fn max_degree(&self) -> f64 {
402        if self.nodes.len() <= 1 {
403            return 1.0;
404        }
405        self.id_to_index
406            .values()
407            .map(|&idx| {
408                let in_deg = self
409                    .graph
410                    .neighbors_directed(idx, Direction::Incoming)
411                    .count();
412                let out_deg = self
413                    .graph
414                    .neighbors_directed(idx, Direction::Outgoing)
415                    .count();
416                (in_deg + out_deg) as f64
417            })
418            .fold(1.0f64, f64::max)
419    }
420}
421
422#[cfg(test)]
423#[path = "../tests/graph_tests.rs"]
424mod tests;