// mentedb_graph/csr.rs
1//! Compressed Sparse Row/Column graph storage with delta log for incremental updates.
2
3use ahash::HashMap;
4use mentedb_core::edge::{EdgeType, MemoryEdge};
5use mentedb_core::error::{MenteError, MenteResult};
6use mentedb_core::types::{MemoryId, Timestamp};
7use serde::{Deserialize, Serialize};
8
/// Compact edge data stored in CSR/CSC arrays.
///
/// Per-edge payload kept parallel to the column-index arrays; the edge's
/// endpoints are implied by its position in the compressed storage, so they
/// are not repeated here.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct StoredEdge {
    /// The relationship type.
    pub edge_type: EdgeType,
    /// Edge weight (0.0 to 1.0).
    pub weight: f32,
    /// When this edge was created.
    pub created_at: Timestamp,
}
19
20impl StoredEdge {
21    /// Converts a [`MemoryEdge`] into a compact stored representation.
22    pub fn from_memory_edge(edge: &MemoryEdge) -> Self {
23        Self {
24            edge_type: edge.edge_type,
25            weight: edge.weight,
26            created_at: edge.created_at,
27        }
28    }
29}
30
/// A pending edge in the delta log before compaction into CSR.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct DeltaEdge {
    // Internal index of the source node (see CsrGraph::id_to_idx).
    source_idx: u32,
    // Internal index of the target node.
    target_idx: u32,
    // Compact edge payload.
    data: StoredEdge,
}
38
/// Compressed Sparse Row storage for one direction (outgoing or incoming).
///
/// The same layout serves as CSC when rows are targets instead of sources.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
struct CompressedStorage {
    /// Length = num_nodes + 1. row_offsets[i]..row_offsets[i+1] gives the range in col_indices/edge_data.
    row_offsets: Vec<u32>,
    /// Column indices (target node for CSR, source node for CSC).
    col_indices: Vec<u32>,
    /// Edge metadata parallel to col_indices.
    edge_data: Vec<StoredEdge>,
}
49
50impl CompressedStorage {
51    #[allow(dead_code)]
52    fn new(num_nodes: usize) -> Self {
53        Self {
54            row_offsets: vec![0; num_nodes + 1],
55            col_indices: Vec::new(),
56            edge_data: Vec::new(),
57        }
58    }
59
60    /// Get neighbors and edge data for a given row index.
61    fn neighbors(&self, row: u32) -> &[u32] {
62        let row = row as usize;
63        if row + 1 >= self.row_offsets.len() {
64            return &[];
65        }
66        let start = self.row_offsets[row] as usize;
67        let end = self.row_offsets[row + 1] as usize;
68        &self.col_indices[start..end]
69    }
70
71    fn edge_data_for(&self, row: u32) -> &[StoredEdge] {
72        let row = row as usize;
73        if row + 1 >= self.row_offsets.len() {
74            return &[];
75        }
76        let start = self.row_offsets[row] as usize;
77        let end = self.row_offsets[row + 1] as usize;
78        &self.edge_data[start..end]
79    }
80}
81
/// Bidirectional graph with CSR (outgoing) and CSC (incoming) plus a delta log.
///
/// Reads merge the compacted arrays with the delta log; `compact` folds
/// pending additions and removals back into the CSR/CSC storage.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CsrGraph {
    /// Maps MemoryId -> internal u32 index.
    id_to_idx: HashMap<MemoryId, u32>,
    /// Maps internal u32 index -> MemoryId (dense: position == index).
    idx_to_id: Vec<MemoryId>,

    /// CSR for outgoing edges.
    csr: CompressedStorage,
    /// CSC for incoming edges.
    csc: CompressedStorage,

    /// Recent edges not yet merged into the compressed storage.
    delta_edges: Vec<DeltaEdge>,
    /// Tombstones for compacted edges as (source_idx, target_idx) pairs;
    /// applied as a read-time filter until the next `compact`.
    removed_edges: Vec<(u32, u32)>,
}
100
101impl CsrGraph {
102    /// Creates a new empty CSR graph.
103    pub fn new() -> Self {
104        Self {
105            id_to_idx: HashMap::default(),
106            idx_to_id: Vec::new(),
107            csr: CompressedStorage::default(),
108            csc: CompressedStorage::default(),
109            delta_edges: Vec::new(),
110            removed_edges: Vec::new(),
111        }
112    }
113
114    /// Register a node. Returns its internal index.
115    pub fn add_node(&mut self, id: MemoryId) -> u32 {
116        if let Some(&idx) = self.id_to_idx.get(&id) {
117            return idx;
118        }
119        let idx = self.idx_to_id.len() as u32;
120        self.id_to_idx.insert(id, idx);
121        self.idx_to_id.push(id);
122        idx
123    }
124
125    /// Remove a node and all its edges.
126    pub fn remove_node(&mut self, id: MemoryId) {
127        let Some(&idx) = self.id_to_idx.get(&id) else {
128            return;
129        };
130        // Mark all outgoing and incoming edges for removal
131        for &neighbor in self.csr.neighbors(idx) {
132            self.removed_edges.push((idx, neighbor));
133        }
134        for &neighbor in self.csc.neighbors(idx) {
135            self.removed_edges.push((neighbor, idx));
136        }
137        // Also remove from delta
138        self.delta_edges
139            .retain(|e| e.source_idx != idx && e.target_idx != idx);
140    }
141
142    /// Add an edge to the delta log.
143    pub fn add_edge(&mut self, edge: &MemoryEdge) {
144        let source_idx = self.add_node(edge.source);
145        let target_idx = self.add_node(edge.target);
146        self.delta_edges.push(DeltaEdge {
147            source_idx,
148            target_idx,
149            data: StoredEdge::from_memory_edge(edge),
150        });
151    }
152
153    /// Mark an edge for removal.
154    pub fn remove_edge(&mut self, source: MemoryId, target: MemoryId) {
155        let (Some(&src_idx), Some(&tgt_idx)) =
156            (self.id_to_idx.get(&source), self.id_to_idx.get(&target))
157        else {
158            return;
159        };
160        self.removed_edges.push((src_idx, tgt_idx));
161        self.delta_edges
162            .retain(|e| !(e.source_idx == src_idx && e.target_idx == tgt_idx));
163    }
164
165    /// Get all outgoing edges from a node (CSR + delta, minus removed).
166    pub fn outgoing(&self, id: MemoryId) -> Vec<(MemoryId, StoredEdge)> {
167        let Some(&idx) = self.id_to_idx.get(&id) else {
168            return Vec::new();
169        };
170        self.outgoing_by_idx(idx)
171    }
172
173    pub(crate) fn outgoing_by_idx(&self, idx: u32) -> Vec<(MemoryId, StoredEdge)> {
174        let mut results = Vec::new();
175
176        // From compressed storage
177        let neighbors = self.csr.neighbors(idx);
178        let edges = self.csr.edge_data_for(idx);
179        for (i, &neighbor) in neighbors.iter().enumerate() {
180            if !self.is_removed(idx, neighbor)
181                && let Some(&id) = self.idx_to_id.get(neighbor as usize)
182            {
183                results.push((id, edges[i]));
184            }
185        }
186
187        // From delta
188        for delta in &self.delta_edges {
189            if delta.source_idx == idx
190                && let Some(&id) = self.idx_to_id.get(delta.target_idx as usize)
191            {
192                results.push((id, delta.data));
193            }
194        }
195
196        results
197    }
198
199    /// Get all incoming edges to a node (CSC + delta, minus removed).
200    pub fn incoming(&self, id: MemoryId) -> Vec<(MemoryId, StoredEdge)> {
201        let Some(&idx) = self.id_to_idx.get(&id) else {
202            return Vec::new();
203        };
204        self.incoming_by_idx(idx)
205    }
206
207    pub(crate) fn incoming_by_idx(&self, idx: u32) -> Vec<(MemoryId, StoredEdge)> {
208        let mut results = Vec::new();
209
210        // From compressed storage (CSC)
211        let neighbors = self.csc.neighbors(idx);
212        let edges = self.csc.edge_data_for(idx);
213        for (i, &neighbor) in neighbors.iter().enumerate() {
214            if !self.is_removed(neighbor, idx)
215                && let Some(&id) = self.idx_to_id.get(neighbor as usize)
216            {
217                results.push((id, edges[i]));
218            }
219        }
220
221        // From delta
222        for delta in &self.delta_edges {
223            if delta.target_idx == idx
224                && let Some(&id) = self.idx_to_id.get(delta.source_idx as usize)
225            {
226                results.push((id, delta.data));
227            }
228        }
229
230        results
231    }
232
233    /// Check if a node exists in the graph.
234    pub fn contains_node(&self, id: MemoryId) -> bool {
235        self.id_to_idx.contains_key(&id)
236    }
237
238    /// Number of registered nodes.
239    pub fn node_count(&self) -> usize {
240        self.idx_to_id.len()
241    }
242
243    /// Resolve a MemoryId to its internal index.
244    pub(crate) fn get_idx(&self, id: MemoryId) -> Option<u32> {
245        self.id_to_idx.get(&id).copied()
246    }
247
248    /// Resolve an internal index to its MemoryId.
249    #[allow(dead_code)]
250    pub(crate) fn get_id(&self, idx: u32) -> Option<MemoryId> {
251        self.idx_to_id.get(idx as usize).copied()
252    }
253
254    /// All registered node IDs.
255    pub fn node_ids(&self) -> &[MemoryId] {
256        &self.idx_to_id
257    }
258
259    fn is_removed(&self, source: u32, target: u32) -> bool {
260        self.removed_edges
261            .iter()
262            .any(|&(s, t)| s == source && t == target)
263    }
264
265    /// Merge all delta edges and removals into the compressed CSR/CSC storage.
266    pub fn compact(&mut self) {
267        let num_nodes = self.idx_to_id.len();
268
269        // Collect all edges: existing (minus removed) + delta
270        let mut all_edges: Vec<(u32, u32, StoredEdge)> = Vec::new();
271
272        // Existing CSR edges
273        for row in 0..num_nodes {
274            let row = row as u32;
275            let neighbors = self.csr.neighbors(row);
276            let edges = self.csr.edge_data_for(row);
277            for (i, &col) in neighbors.iter().enumerate() {
278                if !self.is_removed(row, col) {
279                    all_edges.push((row, col, edges[i]));
280                }
281            }
282        }
283
284        // Delta edges
285        for delta in &self.delta_edges {
286            all_edges.push((delta.source_idx, delta.target_idx, delta.data));
287        }
288
289        // Build CSR (sorted by source)
290        self.csr = Self::build_compressed(&all_edges, num_nodes, false);
291
292        // Build CSC (sorted by target)
293        self.csc = Self::build_compressed(&all_edges, num_nodes, true);
294
295        self.delta_edges.clear();
296        self.removed_edges.clear();
297    }
298
299    fn build_compressed(
300        edges: &[(u32, u32, StoredEdge)],
301        num_nodes: usize,
302        transpose: bool,
303    ) -> CompressedStorage {
304        // Count edges per row
305        let mut counts = vec![0u32; num_nodes];
306        for &(src, tgt, _) in edges {
307            let row = if transpose { tgt } else { src };
308            if (row as usize) < num_nodes {
309                counts[row as usize] += 1;
310            }
311        }
312
313        // Build offsets via prefix sum
314        let mut row_offsets = vec![0u32; num_nodes + 1];
315        for i in 0..num_nodes {
316            row_offsets[i + 1] = row_offsets[i] + counts[i];
317        }
318
319        let total = row_offsets[num_nodes] as usize;
320        let mut col_indices = vec![0u32; total];
321        let mut edge_data = vec![
322            StoredEdge {
323                edge_type: EdgeType::Related,
324                weight: 0.0,
325                created_at: 0,
326            };
327            total
328        ];
329
330        // Fill using write cursors
331        let mut cursors = row_offsets[..num_nodes].to_vec();
332        for &(src, tgt, data) in edges {
333            let (row, col) = if transpose { (tgt, src) } else { (src, tgt) };
334            if (row as usize) < num_nodes {
335                let pos = cursors[row as usize] as usize;
336                col_indices[pos] = col;
337                edge_data[pos] = data;
338                cursors[row as usize] += 1;
339            }
340        }
341
342        CompressedStorage {
343            row_offsets,
344            col_indices,
345            edge_data,
346        }
347    }
348    /// Save the graph to a JSON file.
349    pub fn save(&self, path: &std::path::Path) -> MenteResult<()> {
350        let data =
351            serde_json::to_vec(self).map_err(|e| MenteError::Serialization(e.to_string()))?;
352        std::fs::write(path, data)?;
353        Ok(())
354    }
355
356    /// Load the graph from a JSON file.
357    pub fn load(path: &std::path::Path) -> MenteResult<Self> {
358        let data = std::fs::read(path)?;
359        let graph: Self =
360            serde_json::from_slice(&data).map_err(|e| MenteError::Serialization(e.to_string()))?;
361        Ok(graph)
362    }
363}
364
365impl Default for CsrGraph {
366    fn default() -> Self {
367        Self::new()
368    }
369}
370
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `MemoryEdge` with fixed weight/timestamp so tests only
    /// vary the endpoints and type.
    fn make_edge(src: MemoryId, tgt: MemoryId, etype: EdgeType) -> MemoryEdge {
        MemoryEdge {
            source: src,
            target: tgt,
            edge_type: etype,
            weight: 0.8,
            created_at: 1000,
        }
    }

    // Registering the same id twice must return the same index and not
    // grow the node count.
    #[test]
    fn test_add_node_idempotent() {
        let mut g = CsrGraph::new();
        let id = MemoryId::new();
        let idx1 = g.add_node(id);
        let idx2 = g.add_node(id);
        assert_eq!(idx1, idx2);
        assert_eq!(g.node_count(), 1);
    }

    // Edges in the delta log (pre-compaction) are visible in both
    // outgoing and incoming queries.
    #[test]
    fn test_add_and_query_edges() {
        let mut g = CsrGraph::new();
        let a = MemoryId::new();
        let b = MemoryId::new();
        let c = MemoryId::new();

        g.add_edge(&make_edge(a, b, EdgeType::Caused));
        g.add_edge(&make_edge(a, c, EdgeType::Related));

        let out = g.outgoing(a);
        assert_eq!(out.len(), 2);

        let inc_b = g.incoming(b);
        assert_eq!(inc_b.len(), 1);
        assert_eq!(inc_b[0].0, a);
    }

    // Removing a delta edge hides it immediately, without compaction.
    #[test]
    fn test_remove_edge() {
        let mut g = CsrGraph::new();
        let a = MemoryId::new();
        let b = MemoryId::new();

        g.add_edge(&make_edge(a, b, EdgeType::Caused));
        assert_eq!(g.outgoing(a).len(), 1);

        g.remove_edge(a, b);
        assert_eq!(g.outgoing(a).len(), 0);
    }

    // After compaction, edges move from the delta log into CSR/CSC and
    // remain queryable from both directions.
    #[test]
    fn test_compact() {
        let mut g = CsrGraph::new();
        let a = MemoryId::new();
        let b = MemoryId::new();
        let c = MemoryId::new();

        g.add_edge(&make_edge(a, b, EdgeType::Caused));
        g.add_edge(&make_edge(b, c, EdgeType::Before));
        g.compact();

        let out_a = g.outgoing(a);
        assert_eq!(out_a.len(), 1);
        assert_eq!(out_a[0].0, b);

        let inc_c = g.incoming(c);
        assert_eq!(inc_c.len(), 1);
        assert_eq!(inc_c[0].0, b);
    }

    // A tombstoned compacted edge is dropped by the next compaction.
    #[test]
    fn test_compact_with_removals() {
        let mut g = CsrGraph::new();
        let a = MemoryId::new();
        let b = MemoryId::new();
        let c = MemoryId::new();

        g.add_edge(&make_edge(a, b, EdgeType::Caused));
        g.add_edge(&make_edge(a, c, EdgeType::Related));
        g.compact();

        g.remove_edge(a, b);
        g.compact();

        let out = g.outgoing(a);
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].0, c);
    }
}