// mentedb_graph/csr.rs
1//! Compressed Sparse Row/Column graph storage with delta log for incremental updates.
2
3use ahash::HashMap;
4use mentedb_core::edge::{EdgeType, MemoryEdge};
5use mentedb_core::error::{MenteError, MenteResult};
6use mentedb_core::types::{MemoryId, Timestamp};
7use serde::{Deserialize, Serialize};
8
/// Compact edge data stored in CSR/CSC arrays.
///
/// Endpoint ids are deliberately absent: the source/target of a stored edge
/// is implied by its position in the CSR/CSC arrays (see `CompressedStorage`).
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct StoredEdge {
    /// The relationship type.
    pub edge_type: EdgeType,
    /// Edge weight (0.0 to 1.0).
    pub weight: f32,
    /// When this edge was created.
    pub created_at: Timestamp,
    /// When this relationship became valid. None = since creation (treated as 0).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub valid_from: Option<Timestamp>,
    /// When this relationship stopped being valid (exclusive bound; validity is
    /// the half-open interval `[valid_from, valid_until)`). None = still valid.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub valid_until: Option<Timestamp>,
}
25
26impl StoredEdge {
27    /// Converts a [`MemoryEdge`] into a compact stored representation.
28    pub fn from_memory_edge(edge: &MemoryEdge) -> Self {
29        Self {
30            edge_type: edge.edge_type,
31            weight: edge.weight,
32            created_at: edge.created_at,
33            valid_from: edge.valid_from,
34            valid_until: edge.valid_until,
35        }
36    }
37
38    /// Returns true if this edge is temporally valid at the given timestamp.
39    pub fn is_valid_at(&self, at: Timestamp) -> bool {
40        let from = self.valid_from.unwrap_or(0);
41        match self.valid_until {
42            Some(until) => at >= from && at < until,
43            None => at >= from,
44        }
45    }
46
47    /// Returns true if this edge has been invalidated.
48    pub fn is_invalidated(&self) -> bool {
49        self.valid_until.is_some()
50    }
51}
52
/// A pending edge in the delta log before compaction into CSR.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct DeltaEdge {
    /// Internal index of the source node (see `CsrGraph::id_to_idx`).
    source_idx: u32,
    /// Internal index of the target node.
    target_idx: u32,
    /// Edge payload, identical to what compaction writes into `edge_data`.
    data: StoredEdge,
}
60
/// Compressed Sparse Row storage for one direction (outgoing or incoming).
///
/// Invariants: `row_offsets` is non-decreasing, and `col_indices` and
/// `edge_data` are parallel arrays of the same length (`row_offsets.last()`).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
struct CompressedStorage {
    /// Length = num_nodes + 1. row_offsets[i]..row_offsets[i+1] gives the range in col_indices/edge_data.
    row_offsets: Vec<u32>,
    /// Column indices (target node for CSR, source node for CSC).
    col_indices: Vec<u32>,
    /// Edge metadata parallel to col_indices.
    edge_data: Vec<StoredEdge>,
}
71
72impl CompressedStorage {
73    #[allow(dead_code)]
74    fn new(num_nodes: usize) -> Self {
75        Self {
76            row_offsets: vec![0; num_nodes + 1],
77            col_indices: Vec::new(),
78            edge_data: Vec::new(),
79        }
80    }
81
82    /// Get neighbors and edge data for a given row index.
83    fn neighbors(&self, row: u32) -> &[u32] {
84        let row = row as usize;
85        if row + 1 >= self.row_offsets.len() {
86            return &[];
87        }
88        let start = self.row_offsets[row] as usize;
89        let end = self.row_offsets[row + 1] as usize;
90        &self.col_indices[start..end]
91    }
92
93    fn edge_data_for(&self, row: u32) -> &[StoredEdge] {
94        let row = row as usize;
95        if row + 1 >= self.row_offsets.len() {
96            return &[];
97        }
98        let start = self.row_offsets[row] as usize;
99        let end = self.row_offsets[row + 1] as usize;
100        &self.edge_data[start..end]
101    }
102}
103
/// Bidirectional graph with CSR (outgoing) and CSC (incoming) plus a delta log.
///
/// Writes land in `delta_edges`/`removed_edges` and are visible to queries
/// immediately; `compact` folds them into the compressed arrays.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CsrGraph {
    /// Maps MemoryId -> internal u32 index. Inverse of `idx_to_id`.
    id_to_idx: HashMap<MemoryId, u32>,
    /// Maps internal u32 index -> MemoryId. Indices are never reused.
    idx_to_id: Vec<MemoryId>,

    /// CSR for outgoing edges.
    csr: CompressedStorage,
    /// CSC for incoming edges.
    csc: CompressedStorage,

    /// Recent edges not yet merged into the compressed storage.
    delta_edges: Vec<DeltaEdge>,
    /// Edges marked for removal (source_idx, target_idx); tombstones consulted
    /// by queries until the next compaction clears them.
    removed_edges: Vec<(u32, u32)>,
}
122
123impl CsrGraph {
124    /// Creates a new empty CSR graph.
125    pub fn new() -> Self {
126        Self {
127            id_to_idx: HashMap::default(),
128            idx_to_id: Vec::new(),
129            csr: CompressedStorage::default(),
130            csc: CompressedStorage::default(),
131            delta_edges: Vec::new(),
132            removed_edges: Vec::new(),
133        }
134    }
135
136    /// Register a node. Returns its internal index.
137    pub fn add_node(&mut self, id: MemoryId) -> u32 {
138        if let Some(&idx) = self.id_to_idx.get(&id) {
139            return idx;
140        }
141        let idx = self.idx_to_id.len() as u32;
142        self.id_to_idx.insert(id, idx);
143        self.idx_to_id.push(id);
144        idx
145    }
146
147    /// Remove a node and all its edges.
148    pub fn remove_node(&mut self, id: MemoryId) {
149        let Some(&idx) = self.id_to_idx.get(&id) else {
150            return;
151        };
152        // Mark all outgoing and incoming edges for removal
153        for &neighbor in self.csr.neighbors(idx) {
154            self.removed_edges.push((idx, neighbor));
155        }
156        for &neighbor in self.csc.neighbors(idx) {
157            self.removed_edges.push((neighbor, idx));
158        }
159        // Also remove from delta
160        self.delta_edges
161            .retain(|e| e.source_idx != idx && e.target_idx != idx);
162    }
163
164    /// Add an edge to the delta log.
165    pub fn add_edge(&mut self, edge: &MemoryEdge) {
166        let source_idx = self.add_node(edge.source);
167        let target_idx = self.add_node(edge.target);
168        self.delta_edges.push(DeltaEdge {
169            source_idx,
170            target_idx,
171            data: StoredEdge::from_memory_edge(edge),
172        });
173    }
174
175    /// Mark an edge for removal.
176    pub fn remove_edge(&mut self, source: MemoryId, target: MemoryId) {
177        let (Some(&src_idx), Some(&tgt_idx)) =
178            (self.id_to_idx.get(&source), self.id_to_idx.get(&target))
179        else {
180            return;
181        };
182        self.removed_edges.push((src_idx, tgt_idx));
183        self.delta_edges
184            .retain(|e| !(e.source_idx == src_idx && e.target_idx == tgt_idx));
185    }
186
187    /// Get all outgoing edges from a node (CSR + delta, minus removed).
188    pub fn outgoing(&self, id: MemoryId) -> Vec<(MemoryId, StoredEdge)> {
189        let Some(&idx) = self.id_to_idx.get(&id) else {
190            return Vec::new();
191        };
192        self.outgoing_by_idx(idx)
193    }
194
195    /// Get outgoing edges that are temporally valid at the given timestamp.
196    pub fn outgoing_valid_at(&self, id: MemoryId, at: Timestamp) -> Vec<(MemoryId, StoredEdge)> {
197        self.outgoing(id)
198            .into_iter()
199            .filter(|(_, e)| e.is_valid_at(at))
200            .collect()
201    }
202
203    pub(crate) fn outgoing_by_idx(&self, idx: u32) -> Vec<(MemoryId, StoredEdge)> {
204        let mut results = Vec::new();
205
206        // From compressed storage
207        let neighbors = self.csr.neighbors(idx);
208        let edges = self.csr.edge_data_for(idx);
209        for (i, &neighbor) in neighbors.iter().enumerate() {
210            if !self.is_removed(idx, neighbor)
211                && let Some(&id) = self.idx_to_id.get(neighbor as usize)
212            {
213                results.push((id, edges[i]));
214            }
215        }
216
217        // From delta
218        for delta in &self.delta_edges {
219            if delta.source_idx == idx
220                && let Some(&id) = self.idx_to_id.get(delta.target_idx as usize)
221            {
222                results.push((id, delta.data));
223            }
224        }
225
226        results
227    }
228
229    /// Get all incoming edges to a node (CSC + delta, minus removed).
230    pub fn incoming(&self, id: MemoryId) -> Vec<(MemoryId, StoredEdge)> {
231        let Some(&idx) = self.id_to_idx.get(&id) else {
232            return Vec::new();
233        };
234        self.incoming_by_idx(idx)
235    }
236
237    /// Get incoming edges that are temporally valid at the given timestamp.
238    pub fn incoming_valid_at(&self, id: MemoryId, at: Timestamp) -> Vec<(MemoryId, StoredEdge)> {
239        self.incoming(id)
240            .into_iter()
241            .filter(|(_, e)| e.is_valid_at(at))
242            .collect()
243    }
244
245    pub(crate) fn incoming_by_idx(&self, idx: u32) -> Vec<(MemoryId, StoredEdge)> {
246        let mut results = Vec::new();
247
248        // From compressed storage (CSC)
249        let neighbors = self.csc.neighbors(idx);
250        let edges = self.csc.edge_data_for(idx);
251        for (i, &neighbor) in neighbors.iter().enumerate() {
252            if !self.is_removed(neighbor, idx)
253                && let Some(&id) = self.idx_to_id.get(neighbor as usize)
254            {
255                results.push((id, edges[i]));
256            }
257        }
258
259        // From delta
260        for delta in &self.delta_edges {
261            if delta.target_idx == idx
262                && let Some(&id) = self.idx_to_id.get(delta.source_idx as usize)
263            {
264                results.push((id, delta.data));
265            }
266        }
267
268        results
269    }
270
271    /// Check if a node exists in the graph.
272    pub fn contains_node(&self, id: MemoryId) -> bool {
273        self.id_to_idx.contains_key(&id)
274    }
275
276    /// Number of registered nodes.
277    pub fn node_count(&self) -> usize {
278        self.idx_to_id.len()
279    }
280
281    /// Resolve a MemoryId to its internal index.
282    pub(crate) fn get_idx(&self, id: MemoryId) -> Option<u32> {
283        self.id_to_idx.get(&id).copied()
284    }
285
286    /// Resolve an internal index to its MemoryId.
287    #[allow(dead_code)]
288    pub(crate) fn get_id(&self, idx: u32) -> Option<MemoryId> {
289        self.idx_to_id.get(idx as usize).copied()
290    }
291
292    /// All registered node IDs.
293    pub fn node_ids(&self) -> &[MemoryId] {
294        &self.idx_to_id
295    }
296
297    fn is_removed(&self, source: u32, target: u32) -> bool {
298        self.removed_edges
299            .iter()
300            .any(|&(s, t)| s == source && t == target)
301    }
302
303    /// Merge all delta edges and removals into the compressed CSR/CSC storage.
304    pub fn compact(&mut self) {
305        let num_nodes = self.idx_to_id.len();
306
307        // Collect all edges: existing (minus removed) + delta
308        let mut all_edges: Vec<(u32, u32, StoredEdge)> = Vec::new();
309
310        // Existing CSR edges
311        for row in 0..num_nodes {
312            let row = row as u32;
313            let neighbors = self.csr.neighbors(row);
314            let edges = self.csr.edge_data_for(row);
315            for (i, &col) in neighbors.iter().enumerate() {
316                if !self.is_removed(row, col) {
317                    all_edges.push((row, col, edges[i]));
318                }
319            }
320        }
321
322        // Delta edges
323        for delta in &self.delta_edges {
324            all_edges.push((delta.source_idx, delta.target_idx, delta.data));
325        }
326
327        // Build CSR (sorted by source)
328        self.csr = Self::build_compressed(&all_edges, num_nodes, false);
329
330        // Build CSC (sorted by target)
331        self.csc = Self::build_compressed(&all_edges, num_nodes, true);
332
333        self.delta_edges.clear();
334        self.removed_edges.clear();
335    }
336
337    fn build_compressed(
338        edges: &[(u32, u32, StoredEdge)],
339        num_nodes: usize,
340        transpose: bool,
341    ) -> CompressedStorage {
342        // Count edges per row
343        let mut counts = vec![0u32; num_nodes];
344        for &(src, tgt, _) in edges {
345            let row = if transpose { tgt } else { src };
346            if (row as usize) < num_nodes {
347                counts[row as usize] += 1;
348            }
349        }
350
351        // Build offsets via prefix sum
352        let mut row_offsets = vec![0u32; num_nodes + 1];
353        for i in 0..num_nodes {
354            row_offsets[i + 1] = row_offsets[i] + counts[i];
355        }
356
357        let total = row_offsets[num_nodes] as usize;
358        let mut col_indices = vec![0u32; total];
359        let mut edge_data = vec![
360            StoredEdge {
361                edge_type: EdgeType::Related,
362                weight: 0.0,
363                created_at: 0,
364                valid_from: None,
365                valid_until: None,
366            };
367            total
368        ];
369
370        // Fill using write cursors
371        let mut cursors = row_offsets[..num_nodes].to_vec();
372        for &(src, tgt, data) in edges {
373            let (row, col) = if transpose { (tgt, src) } else { (src, tgt) };
374            if (row as usize) < num_nodes {
375                let pos = cursors[row as usize] as usize;
376                col_indices[pos] = col;
377                edge_data[pos] = data;
378                cursors[row as usize] += 1;
379            }
380        }
381
382        CompressedStorage {
383            row_offsets,
384            col_indices,
385            edge_data,
386        }
387    }
388    /// Save the graph to a JSON file.
389    pub fn save(&self, path: &std::path::Path) -> MenteResult<()> {
390        let data =
391            serde_json::to_vec(self).map_err(|e| MenteError::Serialization(e.to_string()))?;
392        std::fs::write(path, data)?;
393        Ok(())
394    }
395
396    /// Load the graph from a JSON file.
397    pub fn load(path: &std::path::Path) -> MenteResult<Self> {
398        let data = std::fs::read(path)?;
399        let graph: Self =
400            serde_json::from_slice(&data).map_err(|e| MenteError::Serialization(e.to_string()))?;
401        Ok(graph)
402    }
403}
404
405impl Default for CsrGraph {
406    fn default() -> Self {
407        Self::new()
408    }
409}
410
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a test edge with a fixed weight/timestamp and open validity.
    fn make_edge(src: MemoryId, tgt: MemoryId, etype: EdgeType) -> MemoryEdge {
        MemoryEdge {
            source: src,
            target: tgt,
            edge_type: etype,
            weight: 0.8,
            created_at: 1000,
            valid_from: None,
            valid_until: None,
        }
    }

    #[test]
    fn test_add_node_idempotent() {
        let mut graph = CsrGraph::new();
        let node = MemoryId::new();
        let first = graph.add_node(node);
        let second = graph.add_node(node);
        // Re-registering must hand back the same index without growing the graph.
        assert_eq!(first, second);
        assert_eq!(graph.node_count(), 1);
    }

    #[test]
    fn test_add_and_query_edges() {
        let mut graph = CsrGraph::new();
        let (a, b, c) = (MemoryId::new(), MemoryId::new(), MemoryId::new());

        graph.add_edge(&make_edge(a, b, EdgeType::Caused));
        graph.add_edge(&make_edge(a, c, EdgeType::Related));

        // Both edges are visible from the delta log without compaction.
        assert_eq!(graph.outgoing(a).len(), 2);

        let incoming_b = graph.incoming(b);
        assert_eq!(incoming_b.len(), 1);
        assert_eq!(incoming_b[0].0, a);
    }

    #[test]
    fn test_remove_edge() {
        let mut graph = CsrGraph::new();
        let (a, b) = (MemoryId::new(), MemoryId::new());

        graph.add_edge(&make_edge(a, b, EdgeType::Caused));
        assert_eq!(graph.outgoing(a).len(), 1);

        // Removal of a delta-only edge takes effect immediately.
        graph.remove_edge(a, b);
        assert!(graph.outgoing(a).is_empty());
    }

    #[test]
    fn test_compact() {
        let mut graph = CsrGraph::new();
        let (a, b, c) = (MemoryId::new(), MemoryId::new(), MemoryId::new());

        graph.add_edge(&make_edge(a, b, EdgeType::Caused));
        graph.add_edge(&make_edge(b, c, EdgeType::Before));
        graph.compact();

        // Edges survive the move from the delta log into CSR/CSC.
        let from_a = graph.outgoing(a);
        assert_eq!(from_a.len(), 1);
        assert_eq!(from_a[0].0, b);

        let into_c = graph.incoming(c);
        assert_eq!(into_c.len(), 1);
        assert_eq!(into_c[0].0, b);
    }

    #[test]
    fn test_compact_with_removals() {
        let mut graph = CsrGraph::new();
        let (a, b, c) = (MemoryId::new(), MemoryId::new(), MemoryId::new());

        graph.add_edge(&make_edge(a, b, EdgeType::Caused));
        graph.add_edge(&make_edge(a, c, EdgeType::Related));
        graph.compact();

        // Removing a compressed edge requires a second compaction to purge it.
        graph.remove_edge(a, b);
        graph.compact();

        let remaining = graph.outgoing(a);
        assert_eq!(remaining.len(), 1);
        assert_eq!(remaining[0].0, c);
    }
}