// mentedb_graph/csr.rs
1//! Compressed Sparse Row/Column graph storage with delta log for incremental updates.
2
3use ahash::HashMap;
4use mentedb_core::edge::{EdgeType, MemoryEdge};
5use mentedb_core::error::{MenteError, MenteResult};
6use mentedb_core::types::{MemoryId, Timestamp};
7use serde::{Deserialize, Serialize};
8
/// Compact edge data stored in CSR/CSC arrays.
///
/// Endpoint ids are deliberately absent: in compressed storage they are
/// implied by the entry's position (row/column), so only the payload is kept.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct StoredEdge {
    /// Semantic relation kind of the edge.
    pub edge_type: EdgeType,
    /// Edge strength, copied verbatim from the originating `MemoryEdge`.
    pub weight: f32,
    /// Creation time of the originating edge.
    pub created_at: Timestamp,
}
16
17impl StoredEdge {
18    pub fn from_memory_edge(edge: &MemoryEdge) -> Self {
19        Self {
20            edge_type: edge.edge_type,
21            weight: edge.weight,
22            created_at: edge.created_at,
23        }
24    }
25}
26
27/// A pending edge in the delta log before compaction into CSR.
28#[derive(Debug, Clone, Serialize, Deserialize)]
29struct DeltaEdge {
30    source_idx: u32,
31    target_idx: u32,
32    data: StoredEdge,
33}
34
35/// Compressed Sparse Row storage for one direction (outgoing or incoming).
36#[derive(Debug, Clone, Default, Serialize, Deserialize)]
37struct CompressedStorage {
38    /// Length = num_nodes + 1. row_offsets[i]..row_offsets[i+1] gives the range in col_indices/edge_data.
39    row_offsets: Vec<u32>,
40    /// Column indices (target node for CSR, source node for CSC).
41    col_indices: Vec<u32>,
42    /// Edge metadata parallel to col_indices.
43    edge_data: Vec<StoredEdge>,
44}
45
46impl CompressedStorage {
47    #[allow(dead_code)]
48    fn new(num_nodes: usize) -> Self {
49        Self {
50            row_offsets: vec![0; num_nodes + 1],
51            col_indices: Vec::new(),
52            edge_data: Vec::new(),
53        }
54    }
55
56    /// Get neighbors and edge data for a given row index.
57    fn neighbors(&self, row: u32) -> &[u32] {
58        let row = row as usize;
59        if row + 1 >= self.row_offsets.len() {
60            return &[];
61        }
62        let start = self.row_offsets[row] as usize;
63        let end = self.row_offsets[row + 1] as usize;
64        &self.col_indices[start..end]
65    }
66
67    fn edge_data_for(&self, row: u32) -> &[StoredEdge] {
68        let row = row as usize;
69        if row + 1 >= self.row_offsets.len() {
70            return &[];
71        }
72        let start = self.row_offsets[row] as usize;
73        let end = self.row_offsets[row + 1] as usize;
74        &self.edge_data[start..end]
75    }
76}
77
78/// Bidirectional graph with CSR (outgoing) and CSC (incoming) plus a delta log.
79#[derive(Debug, Clone, Serialize, Deserialize)]
80pub struct CsrGraph {
81    /// Maps MemoryId -> internal u32 index.
82    id_to_idx: HashMap<MemoryId, u32>,
83    /// Maps internal u32 index -> MemoryId.
84    idx_to_id: Vec<MemoryId>,
85
86    /// CSR for outgoing edges.
87    csr: CompressedStorage,
88    /// CSC for incoming edges.
89    csc: CompressedStorage,
90
91    /// Recent edges not yet merged into the compressed storage.
92    delta_edges: Vec<DeltaEdge>,
93    /// Edges marked for removal (source_idx, target_idx).
94    removed_edges: Vec<(u32, u32)>,
95}
96
97impl CsrGraph {
98    pub fn new() -> Self {
99        Self {
100            id_to_idx: HashMap::default(),
101            idx_to_id: Vec::new(),
102            csr: CompressedStorage::default(),
103            csc: CompressedStorage::default(),
104            delta_edges: Vec::new(),
105            removed_edges: Vec::new(),
106        }
107    }
108
109    /// Register a node. Returns its internal index.
110    pub fn add_node(&mut self, id: MemoryId) -> u32 {
111        if let Some(&idx) = self.id_to_idx.get(&id) {
112            return idx;
113        }
114        let idx = self.idx_to_id.len() as u32;
115        self.id_to_idx.insert(id, idx);
116        self.idx_to_id.push(id);
117        idx
118    }
119
120    /// Remove a node and all its edges.
121    pub fn remove_node(&mut self, id: MemoryId) {
122        let Some(&idx) = self.id_to_idx.get(&id) else {
123            return;
124        };
125        // Mark all outgoing and incoming edges for removal
126        for &neighbor in self.csr.neighbors(idx) {
127            self.removed_edges.push((idx, neighbor));
128        }
129        for &neighbor in self.csc.neighbors(idx) {
130            self.removed_edges.push((neighbor, idx));
131        }
132        // Also remove from delta
133        self.delta_edges
134            .retain(|e| e.source_idx != idx && e.target_idx != idx);
135    }
136
137    /// Add an edge to the delta log.
138    pub fn add_edge(&mut self, edge: &MemoryEdge) {
139        let source_idx = self.add_node(edge.source);
140        let target_idx = self.add_node(edge.target);
141        self.delta_edges.push(DeltaEdge {
142            source_idx,
143            target_idx,
144            data: StoredEdge::from_memory_edge(edge),
145        });
146    }
147
148    /// Mark an edge for removal.
149    pub fn remove_edge(&mut self, source: MemoryId, target: MemoryId) {
150        let (Some(&src_idx), Some(&tgt_idx)) =
151            (self.id_to_idx.get(&source), self.id_to_idx.get(&target))
152        else {
153            return;
154        };
155        self.removed_edges.push((src_idx, tgt_idx));
156        self.delta_edges
157            .retain(|e| !(e.source_idx == src_idx && e.target_idx == tgt_idx));
158    }
159
160    /// Get all outgoing edges from a node (CSR + delta, minus removed).
161    pub fn outgoing(&self, id: MemoryId) -> Vec<(MemoryId, StoredEdge)> {
162        let Some(&idx) = self.id_to_idx.get(&id) else {
163            return Vec::new();
164        };
165        self.outgoing_by_idx(idx)
166    }
167
168    pub(crate) fn outgoing_by_idx(&self, idx: u32) -> Vec<(MemoryId, StoredEdge)> {
169        let mut results = Vec::new();
170
171        // From compressed storage
172        let neighbors = self.csr.neighbors(idx);
173        let edges = self.csr.edge_data_for(idx);
174        for (i, &neighbor) in neighbors.iter().enumerate() {
175            if !self.is_removed(idx, neighbor)
176                && let Some(&id) = self.idx_to_id.get(neighbor as usize)
177            {
178                results.push((id, edges[i]));
179            }
180        }
181
182        // From delta
183        for delta in &self.delta_edges {
184            if delta.source_idx == idx
185                && let Some(&id) = self.idx_to_id.get(delta.target_idx as usize)
186            {
187                results.push((id, delta.data));
188            }
189        }
190
191        results
192    }
193
194    /// Get all incoming edges to a node (CSC + delta, minus removed).
195    pub fn incoming(&self, id: MemoryId) -> Vec<(MemoryId, StoredEdge)> {
196        let Some(&idx) = self.id_to_idx.get(&id) else {
197            return Vec::new();
198        };
199        self.incoming_by_idx(idx)
200    }
201
202    pub(crate) fn incoming_by_idx(&self, idx: u32) -> Vec<(MemoryId, StoredEdge)> {
203        let mut results = Vec::new();
204
205        // From compressed storage (CSC)
206        let neighbors = self.csc.neighbors(idx);
207        let edges = self.csc.edge_data_for(idx);
208        for (i, &neighbor) in neighbors.iter().enumerate() {
209            if !self.is_removed(neighbor, idx)
210                && let Some(&id) = self.idx_to_id.get(neighbor as usize)
211            {
212                results.push((id, edges[i]));
213            }
214        }
215
216        // From delta
217        for delta in &self.delta_edges {
218            if delta.target_idx == idx
219                && let Some(&id) = self.idx_to_id.get(delta.source_idx as usize)
220            {
221                results.push((id, delta.data));
222            }
223        }
224
225        results
226    }
227
228    /// Check if a node exists in the graph.
229    pub fn contains_node(&self, id: MemoryId) -> bool {
230        self.id_to_idx.contains_key(&id)
231    }
232
233    /// Number of registered nodes.
234    pub fn node_count(&self) -> usize {
235        self.idx_to_id.len()
236    }
237
238    /// Resolve a MemoryId to its internal index.
239    pub(crate) fn get_idx(&self, id: MemoryId) -> Option<u32> {
240        self.id_to_idx.get(&id).copied()
241    }
242
243    /// Resolve an internal index to its MemoryId.
244    #[allow(dead_code)]
245    pub(crate) fn get_id(&self, idx: u32) -> Option<MemoryId> {
246        self.idx_to_id.get(idx as usize).copied()
247    }
248
249    /// All registered node IDs.
250    pub fn node_ids(&self) -> &[MemoryId] {
251        &self.idx_to_id
252    }
253
254    fn is_removed(&self, source: u32, target: u32) -> bool {
255        self.removed_edges
256            .iter()
257            .any(|&(s, t)| s == source && t == target)
258    }
259
260    /// Merge all delta edges and removals into the compressed CSR/CSC storage.
261    pub fn compact(&mut self) {
262        let num_nodes = self.idx_to_id.len();
263
264        // Collect all edges: existing (minus removed) + delta
265        let mut all_edges: Vec<(u32, u32, StoredEdge)> = Vec::new();
266
267        // Existing CSR edges
268        for row in 0..num_nodes {
269            let row = row as u32;
270            let neighbors = self.csr.neighbors(row);
271            let edges = self.csr.edge_data_for(row);
272            for (i, &col) in neighbors.iter().enumerate() {
273                if !self.is_removed(row, col) {
274                    all_edges.push((row, col, edges[i]));
275                }
276            }
277        }
278
279        // Delta edges
280        for delta in &self.delta_edges {
281            all_edges.push((delta.source_idx, delta.target_idx, delta.data));
282        }
283
284        // Build CSR (sorted by source)
285        self.csr = Self::build_compressed(&all_edges, num_nodes, false);
286
287        // Build CSC (sorted by target)
288        self.csc = Self::build_compressed(&all_edges, num_nodes, true);
289
290        self.delta_edges.clear();
291        self.removed_edges.clear();
292    }
293
294    fn build_compressed(
295        edges: &[(u32, u32, StoredEdge)],
296        num_nodes: usize,
297        transpose: bool,
298    ) -> CompressedStorage {
299        // Count edges per row
300        let mut counts = vec![0u32; num_nodes];
301        for &(src, tgt, _) in edges {
302            let row = if transpose { tgt } else { src };
303            if (row as usize) < num_nodes {
304                counts[row as usize] += 1;
305            }
306        }
307
308        // Build offsets via prefix sum
309        let mut row_offsets = vec![0u32; num_nodes + 1];
310        for i in 0..num_nodes {
311            row_offsets[i + 1] = row_offsets[i] + counts[i];
312        }
313
314        let total = row_offsets[num_nodes] as usize;
315        let mut col_indices = vec![0u32; total];
316        let mut edge_data = vec![
317            StoredEdge {
318                edge_type: EdgeType::Related,
319                weight: 0.0,
320                created_at: 0,
321            };
322            total
323        ];
324
325        // Fill using write cursors
326        let mut cursors = row_offsets[..num_nodes].to_vec();
327        for &(src, tgt, data) in edges {
328            let (row, col) = if transpose { (tgt, src) } else { (src, tgt) };
329            if (row as usize) < num_nodes {
330                let pos = cursors[row as usize] as usize;
331                col_indices[pos] = col;
332                edge_data[pos] = data;
333                cursors[row as usize] += 1;
334            }
335        }
336
337        CompressedStorage {
338            row_offsets,
339            col_indices,
340            edge_data,
341        }
342    }
343    /// Save the graph to a JSON file.
344    pub fn save(&self, path: &std::path::Path) -> MenteResult<()> {
345        let data =
346            serde_json::to_vec(self).map_err(|e| MenteError::Serialization(e.to_string()))?;
347        std::fs::write(path, data)?;
348        Ok(())
349    }
350
351    /// Load the graph from a JSON file.
352    pub fn load(path: &std::path::Path) -> MenteResult<Self> {
353        let data = std::fs::read(path)?;
354        let graph: Self =
355            serde_json::from_slice(&data).map_err(|e| MenteError::Serialization(e.to_string()))?;
356        Ok(graph)
357    }
358}
359
impl Default for CsrGraph {
    /// Equivalent to [`CsrGraph::new`]: an empty graph.
    fn default() -> Self {
        Self::new()
    }
}
365
#[cfg(test)]
mod tests {
    use super::*;
    use uuid::Uuid;

    /// Helper: build a `MemoryEdge` with a fixed weight and timestamp.
    fn make_edge(src: MemoryId, tgt: MemoryId, etype: EdgeType) -> MemoryEdge {
        MemoryEdge {
            source: src,
            target: tgt,
            edge_type: etype,
            weight: 0.8,
            created_at: 1000,
        }
    }

    #[test]
    fn test_add_node_idempotent() {
        let mut graph = CsrGraph::new();
        let node = Uuid::new_v4();

        let first = graph.add_node(node);
        let second = graph.add_node(node);

        assert_eq!(first, second);
        assert_eq!(graph.node_count(), 1);
    }

    #[test]
    fn test_add_and_query_edges() {
        let (a, b, c) = (Uuid::new_v4(), Uuid::new_v4(), Uuid::new_v4());

        let mut graph = CsrGraph::new();
        graph.add_edge(&make_edge(a, b, EdgeType::Caused));
        graph.add_edge(&make_edge(a, c, EdgeType::Related));

        // Both fan-out edges from `a` are visible via the delta log.
        assert_eq!(graph.outgoing(a).len(), 2);

        // `b` has exactly one inbound edge, originating at `a`.
        let inbound = graph.incoming(b);
        assert_eq!(inbound.len(), 1);
        assert_eq!(inbound[0].0, a);
    }

    #[test]
    fn test_remove_edge() {
        let (a, b) = (Uuid::new_v4(), Uuid::new_v4());

        let mut graph = CsrGraph::new();
        graph.add_edge(&make_edge(a, b, EdgeType::Caused));
        assert_eq!(graph.outgoing(a).len(), 1);

        // Removal of a delta-only edge takes effect immediately.
        graph.remove_edge(a, b);
        assert_eq!(graph.outgoing(a).len(), 0);
    }

    #[test]
    fn test_compact() {
        let (a, b, c) = (Uuid::new_v4(), Uuid::new_v4(), Uuid::new_v4());

        let mut graph = CsrGraph::new();
        graph.add_edge(&make_edge(a, b, EdgeType::Caused));
        graph.add_edge(&make_edge(b, c, EdgeType::Before));
        graph.compact();

        // After compaction, queries are served from CSR/CSC.
        let from_a = graph.outgoing(a);
        assert_eq!(from_a.len(), 1);
        assert_eq!(from_a[0].0, b);

        let into_c = graph.incoming(c);
        assert_eq!(into_c.len(), 1);
        assert_eq!(into_c[0].0, b);
    }

    #[test]
    fn test_compact_with_removals() {
        let (a, b, c) = (Uuid::new_v4(), Uuid::new_v4(), Uuid::new_v4());

        let mut graph = CsrGraph::new();
        graph.add_edge(&make_edge(a, b, EdgeType::Caused));
        graph.add_edge(&make_edge(a, c, EdgeType::Related));
        graph.compact();

        // Tombstone a compressed edge, then fold the removal in.
        graph.remove_edge(a, b);
        graph.compact();

        let from_a = graph.outgoing(a);
        assert_eq!(from_a.len(), 1);
        assert_eq!(from_a[0].0, c);
    }
}