// mentedb_graph/csr.rs
1//! Compressed Sparse Row/Column graph storage with delta log for incremental updates.
2
3use ahash::HashMap;
4use mentedb_core::edge::{EdgeType, MemoryEdge};
5use mentedb_core::error::{MenteError, MenteResult};
6use mentedb_core::types::{MemoryId, Timestamp};
7use serde::{Deserialize, Serialize};
8
9/// Compact edge data stored in CSR/CSC arrays.
10#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
11pub struct StoredEdge {
12    /// The relationship type.
13    pub edge_type: EdgeType,
14    /// Edge weight (0.0 to 1.0).
15    pub weight: f32,
16    /// When this edge was created.
17    pub created_at: Timestamp,
18    /// When this relationship became valid. None = since creation.
19    #[serde(default, skip_serializing_if = "Option::is_none")]
20    pub valid_from: Option<Timestamp>,
21    /// When this relationship stopped being valid. None = still valid.
22    #[serde(default, skip_serializing_if = "Option::is_none")]
23    pub valid_until: Option<Timestamp>,
24    /// Semantic label for the relationship (e.g. "owns", "attends").
25    #[serde(default, skip_serializing_if = "Option::is_none")]
26    pub label: Option<String>,
27}
28
29impl StoredEdge {
30    /// Converts a [`MemoryEdge`] into a compact stored representation.
31    pub fn from_memory_edge(edge: &MemoryEdge) -> Self {
32        Self {
33            edge_type: edge.edge_type,
34            weight: edge.weight,
35            created_at: edge.created_at,
36            valid_from: edge.valid_from,
37            valid_until: edge.valid_until,
38            label: edge.label.clone(),
39        }
40    }
41
42    /// Returns true if this edge is temporally valid at the given timestamp.
43    pub fn is_valid_at(&self, at: Timestamp) -> bool {
44        let from = self.valid_from.unwrap_or(0);
45        match self.valid_until {
46            Some(until) => at >= from && at < until,
47            None => at >= from,
48        }
49    }
50
51    /// Returns true if this edge has been invalidated.
52    pub fn is_invalidated(&self) -> bool {
53        self.valid_until.is_some()
54    }
55}
56
57/// A pending edge in the delta log before compaction into CSR.
58#[derive(Debug, Clone, Serialize, Deserialize)]
59struct DeltaEdge {
60    source_idx: u32,
61    target_idx: u32,
62    data: StoredEdge,
63}
64
65/// Compressed Sparse Row storage for one direction (outgoing or incoming).
66#[derive(Debug, Clone, Default, Serialize, Deserialize)]
67struct CompressedStorage {
68    /// Length = num_nodes + 1. row_offsets[i]..row_offsets[i+1] gives the range in col_indices/edge_data.
69    row_offsets: Vec<u32>,
70    /// Column indices (target node for CSR, source node for CSC).
71    col_indices: Vec<u32>,
72    /// Edge metadata parallel to col_indices.
73    edge_data: Vec<StoredEdge>,
74}
75
76impl CompressedStorage {
77    #[allow(dead_code)]
78    fn new(num_nodes: usize) -> Self {
79        Self {
80            row_offsets: vec![0; num_nodes + 1],
81            col_indices: Vec::new(),
82            edge_data: Vec::new(),
83        }
84    }
85
86    /// Get neighbors and edge data for a given row index.
87    fn neighbors(&self, row: u32) -> &[u32] {
88        let row = row as usize;
89        if row + 1 >= self.row_offsets.len() {
90            return &[];
91        }
92        let start = self.row_offsets[row] as usize;
93        let end = self.row_offsets[row + 1] as usize;
94        &self.col_indices[start..end]
95    }
96
97    fn edge_data_for(&self, row: u32) -> &[StoredEdge] {
98        let row = row as usize;
99        if row + 1 >= self.row_offsets.len() {
100            return &[];
101        }
102        let start = self.row_offsets[row] as usize;
103        let end = self.row_offsets[row + 1] as usize;
104        &self.edge_data[start..end]
105    }
106}
107
108/// Bidirectional graph with CSR (outgoing) and CSC (incoming) plus a delta log.
109#[derive(Debug, Clone, Serialize, Deserialize)]
110pub struct CsrGraph {
111    /// Maps MemoryId -> internal u32 index.
112    id_to_idx: HashMap<MemoryId, u32>,
113    /// Maps internal u32 index -> MemoryId.
114    idx_to_id: Vec<MemoryId>,
115
116    /// CSR for outgoing edges.
117    csr: CompressedStorage,
118    /// CSC for incoming edges.
119    csc: CompressedStorage,
120
121    /// Recent edges not yet merged into the compressed storage.
122    delta_edges: Vec<DeltaEdge>,
123    /// Edges marked for removal (source_idx, target_idx).
124    removed_edges: Vec<(u32, u32)>,
125}
126
127impl CsrGraph {
128    /// Creates a new empty CSR graph.
129    pub fn new() -> Self {
130        Self {
131            id_to_idx: HashMap::default(),
132            idx_to_id: Vec::new(),
133            csr: CompressedStorage::default(),
134            csc: CompressedStorage::default(),
135            delta_edges: Vec::new(),
136            removed_edges: Vec::new(),
137        }
138    }
139
140    /// Register a node. Returns its internal index.
141    pub fn add_node(&mut self, id: MemoryId) -> u32 {
142        if let Some(&idx) = self.id_to_idx.get(&id) {
143            return idx;
144        }
145        let idx = self.idx_to_id.len() as u32;
146        self.id_to_idx.insert(id, idx);
147        self.idx_to_id.push(id);
148        idx
149    }
150
151    /// Remove a node and all its edges.
152    pub fn remove_node(&mut self, id: MemoryId) {
153        let Some(&idx) = self.id_to_idx.get(&id) else {
154            return;
155        };
156        // Mark all outgoing and incoming edges for removal
157        for &neighbor in self.csr.neighbors(idx) {
158            self.removed_edges.push((idx, neighbor));
159        }
160        for &neighbor in self.csc.neighbors(idx) {
161            self.removed_edges.push((neighbor, idx));
162        }
163        // Also remove from delta
164        self.delta_edges
165            .retain(|e| e.source_idx != idx && e.target_idx != idx);
166    }
167
168    /// Add an edge to the delta log.
169    pub fn add_edge(&mut self, edge: &MemoryEdge) {
170        let source_idx = self.add_node(edge.source);
171        let target_idx = self.add_node(edge.target);
172        self.delta_edges.push(DeltaEdge {
173            source_idx,
174            target_idx,
175            data: StoredEdge::from_memory_edge(edge),
176        });
177    }
178
179    /// Strengthen an edge by incrementing its weight (Hebbian learning).
180    /// Adds a delta edge with the new weight; compaction will merge it.
181    pub fn strengthen_edge(&mut self, source: MemoryId, target: MemoryId, delta: f32) {
182        // Find the existing edge to get its current data
183        if let Some(existing) = self
184            .outgoing(source)
185            .into_iter()
186            .find(|(id, _)| *id == target)
187        {
188            let (_, stored) = existing;
189            let new_weight = (stored.weight + delta).min(1.0);
190            let source_idx = self.add_node(source);
191            let target_idx = self.add_node(target);
192            self.delta_edges.push(DeltaEdge {
193                source_idx,
194                target_idx,
195                data: StoredEdge {
196                    weight: new_weight,
197                    ..stored
198                },
199            });
200        }
201    }
202
203    /// Mark an edge for removal.
204    pub fn remove_edge(&mut self, source: MemoryId, target: MemoryId) {
205        let (Some(&src_idx), Some(&tgt_idx)) =
206            (self.id_to_idx.get(&source), self.id_to_idx.get(&target))
207        else {
208            return;
209        };
210        self.removed_edges.push((src_idx, tgt_idx));
211        self.delta_edges
212            .retain(|e| !(e.source_idx == src_idx && e.target_idx == tgt_idx));
213    }
214
215    /// Get all outgoing edges from a node (CSR + delta, minus removed).
216    pub fn outgoing(&self, id: MemoryId) -> Vec<(MemoryId, StoredEdge)> {
217        let Some(&idx) = self.id_to_idx.get(&id) else {
218            return Vec::new();
219        };
220        self.outgoing_by_idx(idx)
221    }
222
223    /// Get outgoing edges that are temporally valid at the given timestamp.
224    pub fn outgoing_valid_at(&self, id: MemoryId, at: Timestamp) -> Vec<(MemoryId, StoredEdge)> {
225        self.outgoing(id)
226            .into_iter()
227            .filter(|(_, e)| e.is_valid_at(at))
228            .collect()
229    }
230
231    pub(crate) fn outgoing_by_idx(&self, idx: u32) -> Vec<(MemoryId, StoredEdge)> {
232        let mut results = Vec::new();
233
234        // From compressed storage
235        let neighbors = self.csr.neighbors(idx);
236        let edges = self.csr.edge_data_for(idx);
237        for (i, &neighbor) in neighbors.iter().enumerate() {
238            if !self.is_removed(idx, neighbor)
239                && let Some(&id) = self.idx_to_id.get(neighbor as usize)
240            {
241                results.push((id, edges[i].clone()));
242            }
243        }
244
245        // From delta
246        for delta in &self.delta_edges {
247            if delta.source_idx == idx
248                && let Some(&id) = self.idx_to_id.get(delta.target_idx as usize)
249            {
250                results.push((id, delta.data.clone()));
251            }
252        }
253
254        results
255    }
256
257    /// Get all incoming edges to a node (CSC + delta, minus removed).
258    pub fn incoming(&self, id: MemoryId) -> Vec<(MemoryId, StoredEdge)> {
259        let Some(&idx) = self.id_to_idx.get(&id) else {
260            return Vec::new();
261        };
262        self.incoming_by_idx(idx)
263    }
264
265    /// Get incoming edges that are temporally valid at the given timestamp.
266    pub fn incoming_valid_at(&self, id: MemoryId, at: Timestamp) -> Vec<(MemoryId, StoredEdge)> {
267        self.incoming(id)
268            .into_iter()
269            .filter(|(_, e)| e.is_valid_at(at))
270            .collect()
271    }
272
273    pub(crate) fn incoming_by_idx(&self, idx: u32) -> Vec<(MemoryId, StoredEdge)> {
274        let mut results = Vec::new();
275
276        // From compressed storage (CSC)
277        let neighbors = self.csc.neighbors(idx);
278        let edges = self.csc.edge_data_for(idx);
279        for (i, &neighbor) in neighbors.iter().enumerate() {
280            if !self.is_removed(neighbor, idx)
281                && let Some(&id) = self.idx_to_id.get(neighbor as usize)
282            {
283                results.push((id, edges[i].clone()));
284            }
285        }
286
287        // From delta
288        for delta in &self.delta_edges {
289            if delta.target_idx == idx
290                && let Some(&id) = self.idx_to_id.get(delta.source_idx as usize)
291            {
292                results.push((id, delta.data.clone()));
293            }
294        }
295
296        results
297    }
298
299    /// Check if a node exists in the graph.
300    pub fn contains_node(&self, id: MemoryId) -> bool {
301        self.id_to_idx.contains_key(&id)
302    }
303
304    /// Number of registered nodes.
305    pub fn node_count(&self) -> usize {
306        self.idx_to_id.len()
307    }
308
309    /// Resolve a MemoryId to its internal index.
310    pub(crate) fn get_idx(&self, id: MemoryId) -> Option<u32> {
311        self.id_to_idx.get(&id).copied()
312    }
313
314    /// Resolve an internal index to its MemoryId.
315    #[allow(dead_code)]
316    pub(crate) fn get_id(&self, idx: u32) -> Option<MemoryId> {
317        self.idx_to_id.get(idx as usize).copied()
318    }
319
320    /// All registered node IDs.
321    pub fn node_ids(&self) -> &[MemoryId] {
322        &self.idx_to_id
323    }
324
325    fn is_removed(&self, source: u32, target: u32) -> bool {
326        self.removed_edges
327            .iter()
328            .any(|&(s, t)| s == source && t == target)
329    }
330
331    /// Merge all delta edges and removals into the compressed CSR/CSC storage.
332    pub fn compact(&mut self) {
333        let num_nodes = self.idx_to_id.len();
334
335        // Collect all edges: existing (minus removed) + delta
336        let mut all_edges: Vec<(u32, u32, StoredEdge)> = Vec::new();
337
338        // Existing CSR edges
339        for row in 0..num_nodes {
340            let row = row as u32;
341            let neighbors = self.csr.neighbors(row);
342            let edges = self.csr.edge_data_for(row);
343            for (i, &col) in neighbors.iter().enumerate() {
344                if !self.is_removed(row, col) {
345                    all_edges.push((row, col, edges[i].clone()));
346                }
347            }
348        }
349
350        // Delta edges
351        for delta in &self.delta_edges {
352            all_edges.push((delta.source_idx, delta.target_idx, delta.data.clone()));
353        }
354
355        // Build CSR (sorted by source)
356        self.csr = Self::build_compressed(&all_edges, num_nodes, false);
357
358        // Build CSC (sorted by target)
359        self.csc = Self::build_compressed(&all_edges, num_nodes, true);
360
361        self.delta_edges.clear();
362        self.removed_edges.clear();
363    }
364
    /// Builds a [`CompressedStorage`] from an edge list via a two-pass
    /// counting sort: count per-row degrees, prefix-sum into offsets, then
    /// fill with per-row write cursors.
    ///
    /// When `transpose` is false the rows are edge sources (CSR); when true
    /// the rows are edge targets (CSC). Edges whose row index is out of range
    /// (>= `num_nodes`) are silently dropped in both passes.
    fn build_compressed(
        edges: &[(u32, u32, StoredEdge)],
        num_nodes: usize,
        transpose: bool,
    ) -> CompressedStorage {
        // Pass 1: count edges per row.
        let mut counts = vec![0u32; num_nodes];
        for &(src, tgt, ref _data) in edges {
            let row = if transpose { tgt } else { src };
            if (row as usize) < num_nodes {
                counts[row as usize] += 1;
            }
        }

        // Build offsets via prefix sum; row_offsets[num_nodes] = total edges.
        let mut row_offsets = vec![0u32; num_nodes + 1];
        for i in 0..num_nodes {
            row_offsets[i + 1] = row_offsets[i] + counts[i];
        }

        // Pre-size the payload arrays. The placeholder StoredEdge values are
        // overwritten in pass 2; EdgeType::Related here is arbitrary filler.
        let total = row_offsets[num_nodes] as usize;
        let mut col_indices = vec![0u32; total];
        let mut edge_data = vec![
            StoredEdge {
                edge_type: EdgeType::Related,
                weight: 0.0,
                created_at: 0,
                valid_from: None,
                valid_until: None,
                label: None,
            };
            total
        ];

        // Pass 2: fill using per-row write cursors. Each cursor starts at the
        // row's offset and advances as edges land, so edges for a row stay in
        // input order (stable within a row).
        let mut cursors = row_offsets[..num_nodes].to_vec();
        for &(src, tgt, ref data) in edges {
            let (row, col) = if transpose { (tgt, src) } else { (src, tgt) };
            if (row as usize) < num_nodes {
                let pos = cursors[row as usize] as usize;
                col_indices[pos] = col;
                edge_data[pos] = data.clone();
                cursors[row as usize] += 1;
            }
        }

        CompressedStorage {
            row_offsets,
            col_indices,
            edge_data,
        }
    }
417    /// Save the graph to a JSON file.
418    pub fn save(&self, path: &std::path::Path) -> MenteResult<()> {
419        let data =
420            serde_json::to_vec(self).map_err(|e| MenteError::Serialization(e.to_string()))?;
421        std::fs::write(path, data)?;
422        Ok(())
423    }
424
425    /// Load the graph from a JSON file.
426    pub fn load(path: &std::path::Path) -> MenteResult<Self> {
427        let data = std::fs::read(path)?;
428        let graph: Self =
429            serde_json::from_slice(&data).map_err(|e| MenteError::Serialization(e.to_string()))?;
430        Ok(graph)
431    }
432}
433
434impl Default for CsrGraph {
435    fn default() -> Self {
436        Self::new()
437    }
438}
439
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fresh edge between two nodes with fixed weight/timestamp.
    fn make_edge(src: MemoryId, tgt: MemoryId, etype: EdgeType) -> MemoryEdge {
        MemoryEdge {
            source: src,
            target: tgt,
            edge_type: etype,
            weight: 0.8,
            created_at: 1000,
            valid_from: None,
            valid_until: None,
            label: None,
        }
    }

    #[test]
    fn test_add_node_idempotent() {
        let mut graph = CsrGraph::new();
        let id = MemoryId::new();
        // Registering the same id twice must yield the same index.
        assert_eq!(graph.add_node(id), graph.add_node(id));
        assert_eq!(graph.node_count(), 1);
    }

    #[test]
    fn test_add_and_query_edges() {
        let mut graph = CsrGraph::new();
        let (a, b, c) = (MemoryId::new(), MemoryId::new(), MemoryId::new());

        graph.add_edge(&make_edge(a, b, EdgeType::Caused));
        graph.add_edge(&make_edge(a, c, EdgeType::Related));

        assert_eq!(graph.outgoing(a).len(), 2);

        let incoming_b = graph.incoming(b);
        assert_eq!(incoming_b.len(), 1);
        assert_eq!(incoming_b[0].0, a);
    }

    #[test]
    fn test_remove_edge() {
        let mut graph = CsrGraph::new();
        let (a, b) = (MemoryId::new(), MemoryId::new());

        graph.add_edge(&make_edge(a, b, EdgeType::Caused));
        assert_eq!(graph.outgoing(a).len(), 1);

        graph.remove_edge(a, b);
        assert!(graph.outgoing(a).is_empty());
    }

    #[test]
    fn test_compact() {
        let mut graph = CsrGraph::new();
        let (a, b, c) = (MemoryId::new(), MemoryId::new(), MemoryId::new());

        graph.add_edge(&make_edge(a, b, EdgeType::Caused));
        graph.add_edge(&make_edge(b, c, EdgeType::Before));
        graph.compact();

        let out_a = graph.outgoing(a);
        assert_eq!(out_a.len(), 1);
        assert_eq!(out_a[0].0, b);

        let inc_c = graph.incoming(c);
        assert_eq!(inc_c.len(), 1);
        assert_eq!(inc_c[0].0, b);
    }

    #[test]
    fn test_compact_with_removals() {
        let mut graph = CsrGraph::new();
        let (a, b, c) = (MemoryId::new(), MemoryId::new(), MemoryId::new());

        graph.add_edge(&make_edge(a, b, EdgeType::Caused));
        graph.add_edge(&make_edge(a, c, EdgeType::Related));
        graph.compact();

        // Removal queued after compaction must survive the next compaction.
        graph.remove_edge(a, b);
        graph.compact();

        let out = graph.outgoing(a);
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].0, c);
    }
}