Skip to main content

mentedb_graph/
csr.rs

//! Compressed Sparse Row/Column graph storage with a delta log for incremental updates.
2
3use ahash::HashMap;
4use mentedb_core::edge::{EdgeType, MemoryEdge};
5use mentedb_core::error::{MenteError, MenteResult};
6use mentedb_core::types::{MemoryId, Timestamp};
7use serde::{Deserialize, Serialize};
8
9/// Compact edge data stored in CSR/CSC arrays.
10#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
11pub struct StoredEdge {
12    /// The relationship type.
13    pub edge_type: EdgeType,
14    /// Edge weight (0.0 to 1.0).
15    pub weight: f32,
16    /// When this edge was created.
17    pub created_at: Timestamp,
18    /// When this relationship became valid. None = since creation.
19    #[serde(default, skip_serializing_if = "Option::is_none")]
20    pub valid_from: Option<Timestamp>,
21    /// When this relationship stopped being valid. None = still valid.
22    #[serde(default, skip_serializing_if = "Option::is_none")]
23    pub valid_until: Option<Timestamp>,
24    /// Semantic label for the relationship (e.g. "owns", "attends").
25    #[serde(default, skip_serializing_if = "Option::is_none")]
26    pub label: Option<String>,
27}
28
29impl StoredEdge {
30    /// Converts a [`MemoryEdge`] into a compact stored representation.
31    pub fn from_memory_edge(edge: &MemoryEdge) -> Self {
32        Self {
33            edge_type: edge.edge_type,
34            weight: edge.weight,
35            created_at: edge.created_at,
36            valid_from: edge.valid_from,
37            valid_until: edge.valid_until,
38            label: edge.label.clone(),
39        }
40    }
41
42    /// Returns true if this edge is temporally valid at the given timestamp.
43    pub fn is_valid_at(&self, at: Timestamp) -> bool {
44        let from = self.valid_from.unwrap_or(0);
45        match self.valid_until {
46            Some(until) => at >= from && at < until,
47            None => at >= from,
48        }
49    }
50
51    /// Returns true if this edge has been invalidated.
52    pub fn is_invalidated(&self) -> bool {
53        self.valid_until.is_some()
54    }
55}
56
57/// A pending edge in the delta log before compaction into CSR.
58#[derive(Debug, Clone, Serialize, Deserialize)]
59struct DeltaEdge {
60    source_idx: u32,
61    target_idx: u32,
62    data: StoredEdge,
63}
64
65/// Compressed Sparse Row storage for one direction (outgoing or incoming).
66#[derive(Debug, Clone, Default, Serialize, Deserialize)]
67struct CompressedStorage {
68    /// Length = num_nodes + 1. row_offsets[i]..row_offsets[i+1] gives the range in col_indices/edge_data.
69    row_offsets: Vec<u32>,
70    /// Column indices (target node for CSR, source node for CSC).
71    col_indices: Vec<u32>,
72    /// Edge metadata parallel to col_indices.
73    edge_data: Vec<StoredEdge>,
74}
75
76impl CompressedStorage {
77    #[allow(dead_code)]
78    fn new(num_nodes: usize) -> Self {
79        Self {
80            row_offsets: vec![0; num_nodes + 1],
81            col_indices: Vec::new(),
82            edge_data: Vec::new(),
83        }
84    }
85
86    /// Get neighbors and edge data for a given row index.
87    fn neighbors(&self, row: u32) -> &[u32] {
88        let row = row as usize;
89        if row + 1 >= self.row_offsets.len() {
90            return &[];
91        }
92        let start = self.row_offsets[row] as usize;
93        let end = self.row_offsets[row + 1] as usize;
94        &self.col_indices[start..end]
95    }
96
97    fn edge_data_for(&self, row: u32) -> &[StoredEdge] {
98        let row = row as usize;
99        if row + 1 >= self.row_offsets.len() {
100            return &[];
101        }
102        let start = self.row_offsets[row] as usize;
103        let end = self.row_offsets[row + 1] as usize;
104        &self.edge_data[start..end]
105    }
106}
107
108/// Bidirectional graph with CSR (outgoing) and CSC (incoming) plus a delta log.
109#[derive(Debug, Clone, Serialize, Deserialize)]
110pub struct CsrGraph {
111    /// Maps MemoryId -> internal u32 index.
112    id_to_idx: HashMap<MemoryId, u32>,
113    /// Maps internal u32 index -> MemoryId.
114    idx_to_id: Vec<MemoryId>,
115
116    /// CSR for outgoing edges.
117    csr: CompressedStorage,
118    /// CSC for incoming edges.
119    csc: CompressedStorage,
120
121    /// Recent edges not yet merged into the compressed storage.
122    delta_edges: Vec<DeltaEdge>,
123    /// Edges marked for removal (source_idx, target_idx).
124    removed_edges: Vec<(u32, u32)>,
125}
126
127impl CsrGraph {
128    /// Creates a new empty CSR graph.
129    pub fn new() -> Self {
130        Self {
131            id_to_idx: HashMap::default(),
132            idx_to_id: Vec::new(),
133            csr: CompressedStorage::default(),
134            csc: CompressedStorage::default(),
135            delta_edges: Vec::new(),
136            removed_edges: Vec::new(),
137        }
138    }
139
140    /// Register a node. Returns its internal index.
141    pub fn add_node(&mut self, id: MemoryId) -> u32 {
142        if let Some(&idx) = self.id_to_idx.get(&id) {
143            return idx;
144        }
145        let idx = self.idx_to_id.len() as u32;
146        self.id_to_idx.insert(id, idx);
147        self.idx_to_id.push(id);
148        idx
149    }
150
151    /// Remove a node and all its edges.
152    pub fn remove_node(&mut self, id: MemoryId) {
153        let Some(&idx) = self.id_to_idx.get(&id) else {
154            return;
155        };
156        // Mark all outgoing and incoming edges for removal
157        for &neighbor in self.csr.neighbors(idx) {
158            self.removed_edges.push((idx, neighbor));
159        }
160        for &neighbor in self.csc.neighbors(idx) {
161            self.removed_edges.push((neighbor, idx));
162        }
163        // Also remove from delta
164        self.delta_edges
165            .retain(|e| e.source_idx != idx && e.target_idx != idx);
166        self.id_to_idx.remove(&id);
167    }
168
169    /// Add an edge to the delta log.
170    pub fn add_edge(&mut self, edge: &MemoryEdge) {
171        let source_idx = self.add_node(edge.source);
172        let target_idx = self.add_node(edge.target);
173        self.delta_edges.push(DeltaEdge {
174            source_idx,
175            target_idx,
176            data: StoredEdge::from_memory_edge(edge),
177        });
178    }
179
180    /// Strengthen an edge by incrementing its weight (Hebbian learning).
181    /// Adds a delta edge with the new weight; compaction will merge it.
182    pub fn strengthen_edge(&mut self, source: MemoryId, target: MemoryId, delta: f32) {
183        // Find the existing edge to get its current data
184        if let Some(existing) = self
185            .outgoing(source)
186            .into_iter()
187            .find(|(id, _)| *id == target)
188        {
189            let (_, stored) = existing;
190            let new_weight = (stored.weight + delta).min(1.0);
191            let source_idx = self.add_node(source);
192            let target_idx = self.add_node(target);
193            self.delta_edges.push(DeltaEdge {
194                source_idx,
195                target_idx,
196                data: StoredEdge {
197                    weight: new_weight,
198                    ..stored
199                },
200            });
201        }
202    }
203
204    /// Mark an edge for removal.
205    pub fn remove_edge(&mut self, source: MemoryId, target: MemoryId) {
206        let (Some(&src_idx), Some(&tgt_idx)) =
207            (self.id_to_idx.get(&source), self.id_to_idx.get(&target))
208        else {
209            return;
210        };
211        self.removed_edges.push((src_idx, tgt_idx));
212        self.delta_edges
213            .retain(|e| !(e.source_idx == src_idx && e.target_idx == tgt_idx));
214    }
215
216    /// Get all outgoing edges from a node (CSR + delta, minus removed).
217    pub fn outgoing(&self, id: MemoryId) -> Vec<(MemoryId, StoredEdge)> {
218        let Some(&idx) = self.id_to_idx.get(&id) else {
219            return Vec::new();
220        };
221        self.outgoing_by_idx(idx)
222    }
223
224    /// Get outgoing edges that are temporally valid at the given timestamp.
225    pub fn outgoing_valid_at(&self, id: MemoryId, at: Timestamp) -> Vec<(MemoryId, StoredEdge)> {
226        self.outgoing(id)
227            .into_iter()
228            .filter(|(_, e)| e.is_valid_at(at))
229            .collect()
230    }
231
232    pub(crate) fn outgoing_by_idx(&self, idx: u32) -> Vec<(MemoryId, StoredEdge)> {
233        let mut results = Vec::new();
234
235        // From compressed storage
236        let neighbors = self.csr.neighbors(idx);
237        let edges = self.csr.edge_data_for(idx);
238        for (i, &neighbor) in neighbors.iter().enumerate() {
239            if !self.is_removed(idx, neighbor)
240                && let Some(&id) = self.idx_to_id.get(neighbor as usize)
241            {
242                results.push((id, edges[i].clone()));
243            }
244        }
245
246        // From delta
247        for delta in &self.delta_edges {
248            if delta.source_idx == idx
249                && let Some(&id) = self.idx_to_id.get(delta.target_idx as usize)
250            {
251                results.push((id, delta.data.clone()));
252            }
253        }
254
255        results
256    }
257
258    /// Get all incoming edges to a node (CSC + delta, minus removed).
259    pub fn incoming(&self, id: MemoryId) -> Vec<(MemoryId, StoredEdge)> {
260        let Some(&idx) = self.id_to_idx.get(&id) else {
261            return Vec::new();
262        };
263        self.incoming_by_idx(idx)
264    }
265
266    /// Get incoming edges that are temporally valid at the given timestamp.
267    pub fn incoming_valid_at(&self, id: MemoryId, at: Timestamp) -> Vec<(MemoryId, StoredEdge)> {
268        self.incoming(id)
269            .into_iter()
270            .filter(|(_, e)| e.is_valid_at(at))
271            .collect()
272    }
273
274    pub(crate) fn incoming_by_idx(&self, idx: u32) -> Vec<(MemoryId, StoredEdge)> {
275        let mut results = Vec::new();
276
277        // From compressed storage (CSC)
278        let neighbors = self.csc.neighbors(idx);
279        let edges = self.csc.edge_data_for(idx);
280        for (i, &neighbor) in neighbors.iter().enumerate() {
281            if !self.is_removed(neighbor, idx)
282                && let Some(&id) = self.idx_to_id.get(neighbor as usize)
283            {
284                results.push((id, edges[i].clone()));
285            }
286        }
287
288        // From delta
289        for delta in &self.delta_edges {
290            if delta.target_idx == idx
291                && let Some(&id) = self.idx_to_id.get(delta.source_idx as usize)
292            {
293                results.push((id, delta.data.clone()));
294            }
295        }
296
297        results
298    }
299
300    /// Check if a node exists in the graph.
301    pub fn contains_node(&self, id: MemoryId) -> bool {
302        self.id_to_idx.contains_key(&id)
303    }
304
305    /// Number of registered nodes.
306    pub fn node_count(&self) -> usize {
307        self.idx_to_id.len()
308    }
309
310    /// Resolve a MemoryId to its internal index.
311    pub(crate) fn get_idx(&self, id: MemoryId) -> Option<u32> {
312        self.id_to_idx.get(&id).copied()
313    }
314
315    /// Resolve an internal index to its MemoryId.
316    #[allow(dead_code)]
317    pub(crate) fn get_id(&self, idx: u32) -> Option<MemoryId> {
318        self.idx_to_id.get(idx as usize).copied()
319    }
320
321    /// All registered node IDs.
322    pub fn node_ids(&self) -> &[MemoryId] {
323        &self.idx_to_id
324    }
325
326    fn is_removed(&self, source: u32, target: u32) -> bool {
327        self.removed_edges
328            .iter()
329            .any(|&(s, t)| s == source && t == target)
330    }
331
332    /// Merge all delta edges and removals into the compressed CSR/CSC storage.
333    pub fn compact(&mut self) {
334        let num_nodes = self.idx_to_id.len();
335
336        // Collect all edges: existing (minus removed) + delta
337        let mut all_edges: Vec<(u32, u32, StoredEdge)> = Vec::new();
338
339        // Existing CSR edges
340        for row in 0..num_nodes {
341            let row = row as u32;
342            let neighbors = self.csr.neighbors(row);
343            let edges = self.csr.edge_data_for(row);
344            for (i, &col) in neighbors.iter().enumerate() {
345                if !self.is_removed(row, col) {
346                    all_edges.push((row, col, edges[i].clone()));
347                }
348            }
349        }
350
351        // Delta edges
352        for delta in &self.delta_edges {
353            all_edges.push((delta.source_idx, delta.target_idx, delta.data.clone()));
354        }
355
356        // Build CSR (sorted by source)
357        self.csr = Self::build_compressed(&all_edges, num_nodes, false);
358
359        // Build CSC (sorted by target)
360        self.csc = Self::build_compressed(&all_edges, num_nodes, true);
361
362        self.delta_edges.clear();
363        self.removed_edges.clear();
364    }
365
366    fn build_compressed(
367        edges: &[(u32, u32, StoredEdge)],
368        num_nodes: usize,
369        transpose: bool,
370    ) -> CompressedStorage {
371        // Count edges per row
372        let mut counts = vec![0u32; num_nodes];
373        for &(src, tgt, ref _data) in edges {
374            let row = if transpose { tgt } else { src };
375            if (row as usize) < num_nodes {
376                counts[row as usize] += 1;
377            }
378        }
379
380        // Build offsets via prefix sum
381        let mut row_offsets = vec![0u32; num_nodes + 1];
382        for i in 0..num_nodes {
383            row_offsets[i + 1] = row_offsets[i] + counts[i];
384        }
385
386        let total = row_offsets[num_nodes] as usize;
387        let mut col_indices = vec![0u32; total];
388        let mut edge_data = vec![
389            StoredEdge {
390                edge_type: EdgeType::Related,
391                weight: 0.0,
392                created_at: 0,
393                valid_from: None,
394                valid_until: None,
395                label: None,
396            };
397            total
398        ];
399
400        // Fill using write cursors
401        let mut cursors = row_offsets[..num_nodes].to_vec();
402        for &(src, tgt, ref data) in edges {
403            let (row, col) = if transpose { (tgt, src) } else { (src, tgt) };
404            if (row as usize) < num_nodes {
405                let pos = cursors[row as usize] as usize;
406                col_indices[pos] = col;
407                edge_data[pos] = data.clone();
408                cursors[row as usize] += 1;
409            }
410        }
411
412        CompressedStorage {
413            row_offsets,
414            col_indices,
415            edge_data,
416        }
417    }
418    /// Save the graph to a binary file.
419    pub fn save(&self, path: &std::path::Path) -> MenteResult<()> {
420        let data =
421            serde_json::to_vec(self).map_err(|e| MenteError::Serialization(e.to_string()))?;
422        std::fs::write(path, data)?;
423        Ok(())
424    }
425
426    /// Load the graph from a file.
427    pub fn load(path: &std::path::Path) -> MenteResult<Self> {
428        let data = std::fs::read(path)?;
429        let graph: Self =
430            serde_json::from_slice(&data).map_err(|e| MenteError::Serialization(e.to_string()))?;
431        Ok(graph)
432    }
433}
434
435impl Default for CsrGraph {
436    fn default() -> Self {
437        Self::new()
438    }
439}
440
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an open-ended edge with weight 0.8 created at t=1000.
    fn make_edge(src: MemoryId, tgt: MemoryId, etype: EdgeType) -> MemoryEdge {
        MemoryEdge {
            source: src,
            target: tgt,
            edge_type: etype,
            weight: 0.8,
            created_at: 1000,
            valid_from: None,
            valid_until: None,
            label: None,
        }
    }

    #[test]
    fn test_add_node_idempotent() {
        let mut g = CsrGraph::new();
        let id = MemoryId::new();
        let idx1 = g.add_node(id);
        let idx2 = g.add_node(id);
        assert_eq!(idx1, idx2);
        assert_eq!(g.node_count(), 1);
    }

    #[test]
    fn test_add_and_query_edges() {
        let mut g = CsrGraph::new();
        let a = MemoryId::new();
        let b = MemoryId::new();
        let c = MemoryId::new();

        g.add_edge(&make_edge(a, b, EdgeType::Caused));
        g.add_edge(&make_edge(a, c, EdgeType::Related));

        let out = g.outgoing(a);
        assert_eq!(out.len(), 2);

        let inc_b = g.incoming(b);
        assert_eq!(inc_b.len(), 1);
        assert_eq!(inc_b[0].0, a);
    }

    #[test]
    fn test_remove_edge() {
        let mut g = CsrGraph::new();
        let a = MemoryId::new();
        let b = MemoryId::new();

        g.add_edge(&make_edge(a, b, EdgeType::Caused));
        assert_eq!(g.outgoing(a).len(), 1);

        g.remove_edge(a, b);
        assert_eq!(g.outgoing(a).len(), 0);
    }

    #[test]
    fn test_compact() {
        let mut g = CsrGraph::new();
        let a = MemoryId::new();
        let b = MemoryId::new();
        let c = MemoryId::new();

        g.add_edge(&make_edge(a, b, EdgeType::Caused));
        g.add_edge(&make_edge(b, c, EdgeType::Before));
        g.compact();

        let out_a = g.outgoing(a);
        assert_eq!(out_a.len(), 1);
        assert_eq!(out_a[0].0, b);

        let inc_c = g.incoming(c);
        assert_eq!(inc_c.len(), 1);
        assert_eq!(inc_c[0].0, b);
    }

    #[test]
    fn test_compact_with_removals() {
        let mut g = CsrGraph::new();
        let a = MemoryId::new();
        let b = MemoryId::new();
        let c = MemoryId::new();

        g.add_edge(&make_edge(a, b, EdgeType::Caused));
        g.add_edge(&make_edge(a, c, EdgeType::Related));
        g.compact();

        g.remove_edge(a, b);
        g.compact();

        let out = g.outgoing(a);
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].0, c);
    }

    #[test]
    fn test_remove_node_cleans_id_to_idx() {
        let mut g = CsrGraph::new();
        let a = MemoryId::new();
        let b = MemoryId::new();

        g.add_edge(&make_edge(a, b, EdgeType::Caused));
        assert!(g.contains_node(a));
        assert!(g.contains_node(b));

        g.remove_node(a);
        assert!(
            !g.contains_node(a),
            "removed node should not be in id_to_idx"
        );
        assert!(g.contains_node(b), "unrelated node should still exist");

        // Edges involving the removed node should be gone
        assert!(g.outgoing(a).is_empty());
        assert!(g.incoming(b).is_empty());
    }

    #[test]
    fn test_remove_node_then_readd() {
        let mut g = CsrGraph::new();
        let a = MemoryId::new();
        let b = MemoryId::new();
        let c = MemoryId::new();

        g.add_edge(&make_edge(a, b, EdgeType::Caused));
        g.remove_node(a);

        // Re-adding the same ID should get a fresh index
        g.add_edge(&make_edge(a, c, EdgeType::Related));
        assert!(g.contains_node(a));
        let out = g.outgoing(a);
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].0, c);
    }

    #[test]
    fn test_is_valid_at_respects_half_open_window() {
        let mut edge = make_edge(MemoryId::new(), MemoryId::new(), EdgeType::Related);
        edge.valid_from = Some(100);
        edge.valid_until = Some(200);
        let stored = StoredEdge::from_memory_edge(&edge);

        assert!(!stored.is_valid_at(99));
        assert!(stored.is_valid_at(100), "valid_from is inclusive");
        assert!(stored.is_valid_at(199));
        assert!(!stored.is_valid_at(200), "valid_until is exclusive");
        assert!(stored.is_invalidated());
    }

    #[test]
    fn test_open_ended_edge_is_always_valid() {
        let edge = make_edge(MemoryId::new(), MemoryId::new(), EdgeType::Related);
        let stored = StoredEdge::from_memory_edge(&edge);

        assert!(stored.is_valid_at(0));
        assert!(stored.is_valid_at(1_000_000));
        assert!(!stored.is_invalidated());
    }

    #[test]
    fn test_valid_at_queries_filter_expired_edges() {
        let mut g = CsrGraph::new();
        let a = MemoryId::new();
        let b = MemoryId::new();
        let c = MemoryId::new();

        let mut expired = make_edge(a, b, EdgeType::Related);
        expired.valid_until = Some(500);
        g.add_edge(&expired);
        g.add_edge(&make_edge(a, c, EdgeType::Caused));
        g.compact();

        let valid = g.outgoing_valid_at(a, 1000);
        assert_eq!(valid.len(), 1);
        assert_eq!(valid[0].0, c);

        assert!(g.incoming_valid_at(b, 1000).is_empty());
        assert_eq!(g.incoming_valid_at(b, 100).len(), 1);
    }

    #[test]
    fn test_strengthen_edge_caps_weight_at_one() {
        let mut g = CsrGraph::new();
        let a = MemoryId::new();
        let b = MemoryId::new();

        g.add_edge(&make_edge(a, b, EdgeType::Caused));
        g.compact();
        g.strengthen_edge(a, b, 0.5);

        let capped = |w: f32| (w - 1.0).abs() < f32::EPSILON;
        assert!(g.outgoing(a).iter().any(|(id, e)| *id == b && capped(e.weight)));
        assert!(g.incoming(b).iter().any(|(id, e)| *id == a && capped(e.weight)));
    }

    #[test]
    fn test_strengthen_missing_edge_is_noop() {
        let mut g = CsrGraph::new();
        let a = MemoryId::new();
        let b = MemoryId::new();

        g.add_edge(&make_edge(a, b, EdgeType::Caused));
        // Only a -> b exists; strengthening b -> a must not create anything.
        g.strengthen_edge(b, a, 0.1);
        assert!(g.outgoing(b).is_empty());
        assert!(g.incoming(a).is_empty());
    }
}