// nodedb_graph/csr/persist.rs
//! CSR checkpoint serialization/deserialization.
//!
//! Supports two serialization formats:
//! - **MessagePack** (zerompk): Legacy format, backwards-compatible.
//! - **rkyv**: ~3x faster serialization/deserialization. Detected on load
//!   by magic bytes (`RKCS2\0` header, see `RKYV_MAGIC`). Future: mmap
//!   zero-copy access.
//!
//! Used by both Origin (via redb storage) and Lite (via embedded checkpoint).

10use std::collections::HashMap;
11
12use zerompk::{FromMessagePack, ToMessagePack};
13
14use super::index::CsrIndex;
15
/// Magic header for rkyv-serialized CSR snapshots (6 bytes).
///
/// Prepended by `checkpoint_to_bytes` and checked by `from_checkpoint` to
/// distinguish rkyv payloads from legacy MessagePack checkpoints.
const RKYV_MAGIC: &[u8; 6] = b"RKCS2\0";
18
/// Legacy MessagePack-serialized CSR snapshot.
///
/// Mirrors the persisted fields of `CsrIndex` (see `checkpoint_to_bytes`).
/// Field names are part of the on-disk MessagePack format — renaming or
/// retyping them would break loading of old checkpoints. Keep this struct
/// field-for-field in sync with `CsrSnapshotRkyv`.
#[derive(ToMessagePack, FromMessagePack)]
struct CsrSnapshotMsgpack {
    /// Node names, indexed by node id.
    nodes: Vec<String>,
    /// Edge-label names, indexed by label id.
    labels: Vec<String>,
    /// CSR row offsets into the outgoing adjacency arrays.
    out_offsets: Vec<u32>,
    /// Outgoing-edge target node ids (dense CSR array).
    out_targets: Vec<u32>,
    /// Outgoing-edge label ids, parallel to `out_targets`.
    out_labels: Vec<u32>,
    /// CSR row offsets into the incoming adjacency arrays.
    in_offsets: Vec<u32>,
    /// Incoming-edge neighbor node ids (dense CSR array).
    in_targets: Vec<u32>,
    /// Incoming-edge label ids, parallel to `in_targets`.
    in_labels: Vec<u32>,
    /// Per-node edges not yet compacted into the CSR arrays — presumably
    /// (neighbor id, label id) pairs; confirm against `CsrIndex::buffer_out`.
    buffer_out: Vec<Vec<(u32, u32)>>,
    /// Incoming counterpart of `buffer_out`.
    buffer_in: Vec<Vec<(u32, u32)>>,
    /// Deleted-edge tombstones (from `CsrIndex::deleted_edges`).
    deleted: Vec<(u32, u32, u32)>,
    /// True when the index stores explicit edge weights.
    has_weights: bool,
    /// Weights parallel to `out_targets`; `None` for unweighted graphs.
    out_weights: Option<Vec<f64>>,
    /// Weights parallel to `in_targets`; `None` for unweighted graphs.
    in_weights: Option<Vec<f64>>,
    /// Weights for the buffered outgoing edges.
    buffer_out_weights: Vec<Vec<f64>>,
    /// Weights for the buffered incoming edges.
    buffer_in_weights: Vec<Vec<f64>>,
}
38
/// rkyv-serialized CSR snapshot for fast save/load.
///
/// Carries the same fields as `CsrSnapshotMsgpack`; keep the two structs in
/// sync. On little-endian targets the dense arrays (`*_targets`, `*_labels`,
/// `*_weights`) are wrapped zero-copy straight out of the archived buffer by
/// `from_rkyv_zero_copy`, so their element types must remain plain
/// `u32`/`f64` vectors.
#[derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)]
struct CsrSnapshotRkyv {
    /// Node names, indexed by node id.
    nodes: Vec<String>,
    /// Edge-label names, indexed by label id.
    labels: Vec<String>,
    /// CSR row offsets into the outgoing adjacency arrays.
    out_offsets: Vec<u32>,
    /// Outgoing-edge target node ids (dense CSR array, zero-copy on load).
    out_targets: Vec<u32>,
    /// Outgoing-edge label ids, parallel to `out_targets` (zero-copy on load).
    out_labels: Vec<u32>,
    /// CSR row offsets into the incoming adjacency arrays.
    in_offsets: Vec<u32>,
    /// Incoming-edge neighbor node ids (dense CSR array, zero-copy on load).
    in_targets: Vec<u32>,
    /// Incoming-edge label ids, parallel to `in_targets` (zero-copy on load).
    in_labels: Vec<u32>,
    /// Per-node edges not yet compacted into the CSR arrays — presumably
    /// (neighbor id, label id) pairs; confirm against `CsrIndex::buffer_out`.
    buffer_out: Vec<Vec<(u32, u32)>>,
    /// Incoming counterpart of `buffer_out`.
    buffer_in: Vec<Vec<(u32, u32)>>,
    /// Deleted-edge tombstones (from `CsrIndex::deleted_edges`).
    deleted: Vec<(u32, u32, u32)>,
    /// True when the index stores explicit edge weights.
    has_weights: bool,
    /// Weights parallel to `out_targets`; `None` for unweighted graphs.
    out_weights: Option<Vec<f64>>,
    /// Weights parallel to `in_targets`; `None` for unweighted graphs.
    in_weights: Option<Vec<f64>>,
    /// Weights for the buffered outgoing edges.
    buffer_out_weights: Vec<Vec<f64>>,
    /// Weights for the buffered incoming edges.
    buffer_in_weights: Vec<Vec<f64>>,
}
59
60impl CsrIndex {
61    /// Serialize the index to rkyv bytes (with magic header) for storage.
62    ///
63    /// rkyv is ~3x faster than MessagePack for both serialization and
64    /// deserialization. The magic header allows `from_checkpoint` to
65    /// auto-detect the format for backward compatibility.
66    pub fn checkpoint_to_bytes(&self) -> Vec<u8> {
67        let snapshot = CsrSnapshotRkyv {
68            nodes: self.id_to_node.clone(),
69            labels: self.id_to_label.clone(),
70            out_offsets: self.out_offsets.clone(),
71            out_targets: self.out_targets.to_vec(),
72            out_labels: self.out_labels.to_vec(),
73            in_offsets: self.in_offsets.clone(),
74            in_targets: self.in_targets.to_vec(),
75            in_labels: self.in_labels.to_vec(),
76            buffer_out: self.buffer_out.clone(),
77            buffer_in: self.buffer_in.clone(),
78            deleted: self.deleted_edges.iter().copied().collect(),
79            has_weights: self.has_weights,
80            out_weights: self.out_weights.as_ref().map(|w| w.to_vec()),
81            in_weights: self.in_weights.as_ref().map(|w| w.to_vec()),
82            buffer_out_weights: self.buffer_out_weights.clone(),
83            buffer_in_weights: self.buffer_in_weights.clone(),
84        };
85        let rkyv_bytes = rkyv::to_bytes::<rkyv::rancor::Error>(&snapshot)
86            .expect("CSR rkyv serialization should not fail");
87        let mut buf = Vec::with_capacity(RKYV_MAGIC.len() + rkyv_bytes.len());
88        buf.extend_from_slice(RKYV_MAGIC);
89        buf.extend_from_slice(&rkyv_bytes);
90        buf
91    }
92
    /// Restore an index from a checkpoint snapshot.
    ///
    /// Auto-detects format: rkyv (magic header `RKCS2\0`, i.e. `RKYV_MAGIC`)
    /// or legacy MessagePack. Backwards-compatible with old checkpoints.
    ///
    /// Returns `None` when the bytes decode in neither format.
    pub fn from_checkpoint(bytes: &[u8]) -> Option<Self> {
        // Strictly greater-than: a payload of exactly magic length carries no
        // archive body, so it falls through to the MessagePack path.
        if bytes.len() > RKYV_MAGIC.len() && &bytes[..RKYV_MAGIC.len()] == RKYV_MAGIC {
            return Self::from_rkyv_checkpoint(&bytes[RKYV_MAGIC.len()..]);
        }
        Self::from_msgpack_checkpoint(bytes)
    }
103
    /// Restore from rkyv-serialized bytes (magic header already stripped).
    ///
    /// On little-endian platforms (x86_64, ARM), dense arrays (targets, labels,
    /// weights) are zero-copy: DenseArray points directly into the archived
    /// buffer with no per-element parsing. On big-endian, falls back to full
    /// deserialization.
    fn from_rkyv_checkpoint(bytes: &[u8]) -> Option<Self> {
        // rkyv needs aligned backing storage; copy the input into a
        // 16-byte-aligned buffer before accessing the archive.
        let mut aligned = rkyv::util::AlignedVec::<16>::with_capacity(bytes.len());
        aligned.extend_from_slice(bytes);

        #[cfg(target_endian = "little")]
        {
            Self::from_rkyv_zero_copy(aligned)
        }
        #[cfg(not(target_endian = "little"))]
        {
            // Big-endian: archived little-endian integers/floats don't match the
            // native layout, so fully deserialize into owned vectors instead.
            let snap: CsrSnapshotRkyv =
                rkyv::from_bytes::<CsrSnapshotRkyv, rkyv::rancor::Error>(&aligned).ok()?;
            Some(Self::from_snapshot_fields(snap))
        }
    }
125
    /// Zero-copy restore on little-endian platforms.
    ///
    /// SAFETY: On little-endian, rkyv's `u32_le`/`u16_le`/`f64_le` have
    /// identical memory layout to native `u32`/`u16`/`f64`. The pointer
    /// casts are sound because `ArchivedVec<T>` stores contiguous `T_le`
    /// values, and the `Arc<AlignedVec>` keeps the buffer alive.
    #[cfg(target_endian = "little")]
    fn from_rkyv_zero_copy(aligned: rkyv::util::AlignedVec) -> Option<Self> {
        use super::dense_array::DenseArray;

        // Shared ownership of the backing bytes: every zero-copy DenseArray
        // clones this Arc, so the buffer outlives all slices borrowed from it.
        let backing = std::sync::Arc::new(aligned);

        // Access archived data (zero-copy reference into the buffer).
        let archived =
            rkyv::access::<rkyv::Archived<CsrSnapshotRkyv>, rkyv::rancor::Error>(&backing).ok()?;

        // Zero-copy DenseArrays for dense CSR arrays. Each unsafe block below
        // is covered by the SAFETY note above: the slice points into `backing`,
        // and the `_le` element layout equals the native one on this target.
        let out_targets = unsafe {
            let s = archived.out_targets.as_slice();
            DenseArray::zero_copy(backing.clone(), s.as_ptr().cast::<u32>(), s.len())
        };
        let out_labels = unsafe {
            let s = archived.out_labels.as_slice();
            DenseArray::zero_copy(backing.clone(), s.as_ptr().cast::<u32>(), s.len())
        };
        let in_targets = unsafe {
            let s = archived.in_targets.as_slice();
            DenseArray::zero_copy(backing.clone(), s.as_ptr().cast::<u32>(), s.len())
        };
        let in_labels = unsafe {
            let s = archived.in_labels.as_slice();
            DenseArray::zero_copy(backing.clone(), s.as_ptr().cast::<u32>(), s.len())
        };
        let out_weights = archived.out_weights.as_ref().map(|w| unsafe {
            let s = w.as_slice();
            DenseArray::zero_copy(backing.clone(), s.as_ptr().cast::<f64>(), s.len())
        });
        let in_weights = archived.in_weights.as_ref().map(|w| unsafe {
            let s = w.as_slice();
            DenseArray::zero_copy(backing.clone(), s.as_ptr().cast::<f64>(), s.len())
        });

        // Deserialize mutable/small fields (strings, buffers, offsets).
        // NOTE(review): `from_bytes` materializes the *entire* snapshot, so the
        // dense arrays wrapped zero-copy above are also deserialized here and
        // immediately dropped unused — worth revisiting if load time matters.
        let snap: CsrSnapshotRkyv =
            rkyv::from_bytes::<CsrSnapshotRkyv, rkyv::rancor::Error>(&backing).ok()?;

        // Rebuild derived lookup tables; these are not persisted.
        let node_to_id: HashMap<String, u32> = snap
            .nodes
            .iter()
            .enumerate()
            .map(|(i, n)| (n.clone(), i as u32))
            .collect();
        let label_to_id: HashMap<String, u32> = snap
            .labels
            .iter()
            .enumerate()
            .map(|(i, l)| (l.clone(), i as u32))
            .collect();
        let node_count = snap.nodes.len();
        // Access counters start at zero for every node.
        let access_counts = (0..node_count).map(|_| std::cell::Cell::new(0)).collect();
        // Length mismatch (e.g. a checkpoint written without per-node weight
        // buffers) falls back to empty buffers rather than failing the load.
        let buffer_out_weights = if snap.buffer_out_weights.len() == node_count {
            snap.buffer_out_weights
        } else {
            vec![Vec::new(); node_count]
        };
        let buffer_in_weights = if snap.buffer_in_weights.len() == node_count {
            snap.buffer_in_weights
        } else {
            vec![Vec::new(); node_count]
        };

        Some(Self {
            node_to_id,
            id_to_node: snap.nodes,
            label_to_id,
            id_to_label: snap.labels,
            out_offsets: snap.out_offsets,
            out_targets,
            out_labels,
            out_weights,
            in_offsets: snap.in_offsets,
            in_targets,
            in_labels,
            in_weights,
            buffer_out: snap.buffer_out,
            buffer_in: snap.buffer_in,
            buffer_out_weights,
            buffer_in_weights,
            deleted_edges: snap.deleted.into_iter().collect(),
            has_weights: snap.has_weights,
            // Node-label metadata is not part of the snapshot; start empty.
            node_label_bits: vec![0; node_count],
            node_label_to_id: HashMap::new(),
            node_label_names: Vec::new(),
            access_counts,
            query_epoch: 0,
        })
    }
223
224    /// Restore from legacy MessagePack bytes.
225    fn from_msgpack_checkpoint(bytes: &[u8]) -> Option<Self> {
226        let snap: CsrSnapshotMsgpack = zerompk::from_msgpack(bytes).ok()?;
227        Some(Self::from_snapshot_fields(CsrSnapshotRkyv {
228            nodes: snap.nodes,
229            labels: snap.labels,
230            out_offsets: snap.out_offsets,
231            out_targets: snap.out_targets,
232            out_labels: snap.out_labels,
233            in_offsets: snap.in_offsets,
234            in_targets: snap.in_targets,
235            in_labels: snap.in_labels,
236            buffer_out: snap.buffer_out,
237            buffer_in: snap.buffer_in,
238            deleted: snap.deleted,
239            has_weights: snap.has_weights,
240            out_weights: snap.out_weights,
241            in_weights: snap.in_weights,
242            buffer_out_weights: snap.buffer_out_weights,
243            buffer_in_weights: snap.buffer_in_weights,
244        }))
245    }
246
247    /// Reconstruct CsrIndex from deserialized snapshot fields.
248    fn from_snapshot_fields(snap: CsrSnapshotRkyv) -> Self {
249        let node_to_id: HashMap<String, u32> = snap
250            .nodes
251            .iter()
252            .enumerate()
253            .map(|(i, n)| (n.clone(), i as u32))
254            .collect();
255        let label_to_id: HashMap<String, u32> = snap
256            .labels
257            .iter()
258            .enumerate()
259            .map(|(i, l)| (l.clone(), i as u32))
260            .collect();
261
262        let node_count = snap.nodes.len();
263        let access_counts = (0..node_count).map(|_| std::cell::Cell::new(0)).collect();
264
265        let buffer_out_weights = if snap.buffer_out_weights.len() == node_count {
266            snap.buffer_out_weights
267        } else {
268            vec![Vec::new(); node_count]
269        };
270        let buffer_in_weights = if snap.buffer_in_weights.len() == node_count {
271            snap.buffer_in_weights
272        } else {
273            vec![Vec::new(); node_count]
274        };
275
276        Self {
277            node_to_id,
278            id_to_node: snap.nodes,
279            label_to_id,
280            id_to_label: snap.labels,
281            out_offsets: snap.out_offsets,
282            out_targets: snap.out_targets.into(),
283            out_labels: snap.out_labels.into(),
284            out_weights: snap.out_weights.map(Into::into),
285            in_offsets: snap.in_offsets,
286            in_targets: snap.in_targets.into(),
287            in_labels: snap.in_labels.into(),
288            in_weights: snap.in_weights.map(Into::into),
289            buffer_out: snap.buffer_out,
290            buffer_in: snap.buffer_in,
291            buffer_out_weights,
292            buffer_in_weights,
293            deleted_edges: snap.deleted.into_iter().collect(),
294            has_weights: snap.has_weights,
295            node_label_bits: vec![0; node_count],
296            node_label_to_id: HashMap::new(),
297            node_label_names: Vec::new(),
298            access_counts,
299            query_epoch: 0,
300        }
301    }
302}
303
#[cfg(test)]
mod tests {
    use super::*;
    use crate::csr::index::Direction;

    /// Round-trips an unweighted, compacted graph and checks counts,
    /// weight flag, and neighbor lookup on the restored index.
    #[test]
    fn checkpoint_roundtrip_unweighted() {
        let mut csr = CsrIndex::new();
        csr.add_edge("a", "KNOWS", "b").unwrap();
        csr.add_edge("b", "KNOWS", "c").unwrap();
        csr.compact();

        let bytes = csr.checkpoint_to_bytes();
        let restored = CsrIndex::from_checkpoint(&bytes).expect("roundtrip");
        assert_eq!(restored.node_count(), 3);
        assert_eq!(restored.edge_count(), 2);
        assert!(!restored.has_weights());

        let n = restored.neighbors("a", Some("KNOWS"), Direction::Out);
        assert_eq!(n.len(), 1);
        assert_eq!(n[0].1, "b");
    }

    /// Round-trips explicit weights; an edge added without a weight in a
    /// weighted graph comes back as 1.0.
    #[test]
    fn checkpoint_roundtrip_weighted() {
        let mut csr = CsrIndex::new();
        csr.add_edge_weighted("a", "R", "b", 2.5).unwrap();
        csr.add_edge_weighted("b", "R", "c", 7.0).unwrap();
        csr.add_edge("c", "R", "d").unwrap();
        csr.compact();

        let bytes = csr.checkpoint_to_bytes();
        let restored = CsrIndex::from_checkpoint(&bytes).expect("roundtrip");
        assert!(restored.has_weights());
        assert_eq!(restored.edge_weight("a", "R", "b"), Some(2.5));
        assert_eq!(restored.edge_weight("b", "R", "c"), Some(7.0));
        assert_eq!(restored.edge_weight("c", "R", "d"), Some(1.0));
    }

    /// Edges still in the write buffer (no compact) must survive a checkpoint.
    #[test]
    fn checkpoint_roundtrip_with_buffer() {
        let mut csr = CsrIndex::new();
        csr.add_edge("a", "L", "b").unwrap();
        // Don't compact — edges in buffer.
        let bytes = csr.checkpoint_to_bytes();
        let restored = CsrIndex::from_checkpoint(&bytes).expect("roundtrip");
        assert_eq!(restored.edge_count(), 1);
    }
}