Skip to main content

nodedb_graph/csr/
persist.rs

// SPDX-License-Identifier: Apache-2.0

//! CSR checkpoint serialization via rkyv. On little-endian platforms
//! dense arrays are restored zero-copy by pointing `DenseArray` at the
//! archived buffer.
//!
//! Used by both Origin (via redb storage) and Lite (via embedded checkpoint).
8
9use std::collections::HashMap;
10use std::mem::size_of;
11
12use nodedb_mem::EngineId;
13
14use super::index::CsrIndex;
15use crate::GraphError;
16
/// Magic header for rkyv-serialized CSR snapshots (6 bytes).
///
/// Written at the very start of every checkpoint buffer and compared
/// byte-for-byte on load to recognise the format.
const RKYV_MAGIC: &[u8; 6] = b"RKCS2\0";
/// Current format version for rkyv-serialized CSR snapshots.
///
/// Stored as the single byte immediately after `RKYV_MAGIC`; a buffer whose
/// version byte differs from this value is rejected on load.
pub const CSR_FORMAT_VERSION: u8 = 1;
21
/// Errors during CSR checkpoint operations.
///
/// The enum is `#[non_exhaustive]`, so matches outside this crate need a
/// wildcard arm.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum CsrCheckpointError {
    /// The buffer carries the CSR magic header but its version byte is not
    /// the one this build reads. `found` is the byte read from the buffer,
    /// `expected` is `CSR_FORMAT_VERSION`.
    #[error("unsupported CSR checkpoint version {found}; expected {expected}")]
    UnsupportedVersion { found: u8, expected: u8 },
    /// The rkyv payload of a checkpoint could not be decoded.
    #[error("CSR checkpoint rkyv deserialization failed")]
    RkyvDeserialize,
}
31
/// rkyv-serialized CSR snapshot for fast save/load.
///
/// Every field is an owned copy of the corresponding `CsrIndex` field, taken
/// in `checkpoint_to_bytes`. Runtime-only state (node-label bitsets,
/// surrogates, access counts, query epoch, partition tag, governor) is not
/// stored; the restore paths re-create it with fresh values.
///
/// NOTE(review): field order and types presumably determine the archived
/// layout — confirm that any change here is paired with a
/// `CSR_FORMAT_VERSION` bump.
#[derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)]
struct CsrSnapshotRkyv {
    /// Node names (`CsrIndex::id_to_node`); on restore a node's id is its
    /// position in this list.
    nodes: Vec<String>,
    /// Label names (`CsrIndex::id_to_label`); on restore a label's id is its
    /// position in this list.
    labels: Vec<String>,
    /// Copy of `CsrIndex::out_offsets`.
    out_offsets: Vec<u32>,
    /// Dense outgoing target array, flattened to a plain `Vec`.
    out_targets: Vec<u32>,
    /// Dense outgoing label array, flattened to a plain `Vec`.
    out_labels: Vec<u32>,
    /// Copy of `CsrIndex::in_offsets`.
    in_offsets: Vec<u32>,
    /// Dense incoming target array, flattened to a plain `Vec`.
    in_targets: Vec<u32>,
    /// Dense incoming label array, flattened to a plain `Vec`.
    in_labels: Vec<u32>,
    /// Copy of `CsrIndex::buffer_out` (edges not yet compacted).
    /// NOTE(review): presumably one inner list per node — confirm.
    buffer_out: Vec<Vec<(u32, u32)>>,
    /// Copy of `CsrIndex::buffer_in`; see `buffer_out`.
    buffer_in: Vec<Vec<(u32, u32)>>,
    /// Contents of `CsrIndex::deleted_edges`.
    /// NOTE(review): the meaning of the three tuple components is not
    /// visible here — confirm against `CsrIndex`.
    deleted: Vec<(u32, u32, u32)>,
    /// Copy of `CsrIndex::has_weights`.
    has_weights: bool,
    /// Dense outgoing weight array, if the index has one.
    out_weights: Option<Vec<f64>>,
    /// Dense incoming weight array, if the index has one.
    in_weights: Option<Vec<f64>>,
    /// Copy of `CsrIndex::buffer_out_weights`. On restore it is replaced by
    /// one empty list per node when its length differs from the node count.
    buffer_out_weights: Vec<Vec<f64>>,
    /// Copy of `CsrIndex::buffer_in_weights`; restored like
    /// `buffer_out_weights`.
    buffer_in_weights: Vec<Vec<f64>>,
}
52
53impl CsrIndex {
54    /// Serialize the index to rkyv bytes (with magic header) for storage.
55    ///
56    /// # Errors
57    ///
58    /// Returns [`GraphError::MemoryBudget`] if a memory governor is installed
59    /// and the serialization buffer would exceed the `Graph` engine budget.
60    pub fn checkpoint_to_bytes(&self) -> Result<Vec<u8>, GraphError> {
61        let snapshot = CsrSnapshotRkyv {
62            nodes: self.id_to_node.clone(),
63            labels: self.id_to_label.clone(),
64            out_offsets: self.out_offsets.clone(),
65            out_targets: self.out_targets.to_vec(),
66            out_labels: self.out_labels.to_vec(),
67            in_offsets: self.in_offsets.clone(),
68            in_targets: self.in_targets.to_vec(),
69            in_labels: self.in_labels.to_vec(),
70            buffer_out: self.buffer_out.clone(),
71            buffer_in: self.buffer_in.clone(),
72            deleted: self.deleted_edges.iter().copied().collect(),
73            has_weights: self.has_weights,
74            out_weights: self.out_weights.as_ref().map(|w| w.to_vec()),
75            in_weights: self.in_weights.as_ref().map(|w| w.to_vec()),
76            buffer_out_weights: self.buffer_out_weights.clone(),
77            buffer_in_weights: self.buffer_in_weights.clone(),
78        };
79        let rkyv_bytes = rkyv::to_bytes::<rkyv::rancor::Error>(&snapshot)
80            .expect("CSR rkyv serialization should not fail");
81        let buf_capacity = RKYV_MAGIC.len() + 1 + rkyv_bytes.len();
82        let _budget_guard = self
83            .governor
84            .as_ref()
85            .map(|g| g.reserve(EngineId::Graph, buf_capacity * size_of::<u8>()))
86            .transpose()?;
87        let mut buf = Vec::with_capacity(buf_capacity);
88        buf.extend_from_slice(RKYV_MAGIC);
89        buf.push(CSR_FORMAT_VERSION);
90        buf.extend_from_slice(&rkyv_bytes);
91        Ok(buf)
92    }
93
94    /// Restore an index from a checkpoint snapshot.
95    ///
96    /// Returns:
97    /// - `Ok(Some(index))` — successfully decoded.
98    /// - `Ok(None)` — buffer does not start with the magic header (no legacy
99    ///   format exists for CSR; callers should treat this as an invalid buffer).
100    /// - `Err(CsrCheckpointError::UnsupportedVersion)` — magic matches but the
101    ///   version byte is not `CSR_FORMAT_VERSION`.
102    pub fn from_checkpoint(bytes: &[u8]) -> Result<Option<Self>, CsrCheckpointError> {
103        let header_len = RKYV_MAGIC.len() + 1; // magic + version byte
104        if bytes.len() > header_len && &bytes[..RKYV_MAGIC.len()] == RKYV_MAGIC {
105            let version = bytes[RKYV_MAGIC.len()];
106            if version != CSR_FORMAT_VERSION {
107                return Err(CsrCheckpointError::UnsupportedVersion {
108                    found: version,
109                    expected: CSR_FORMAT_VERSION,
110                });
111            }
112            return Ok(Self::from_rkyv_checkpoint(&bytes[header_len..]));
113        }
114        Ok(None)
115    }
116
117    /// Restore from rkyv-serialized bytes.
118    ///
119    /// On little-endian platforms (x86_64, ARM), dense arrays (targets, labels,
120    /// weights) are zero-copy: DenseArray points directly into the archived
121    /// buffer with no per-element parsing. On big-endian, falls back to full
122    /// deserialization.
123    fn from_rkyv_checkpoint(bytes: &[u8]) -> Option<Self> {
124        let mut aligned = rkyv::util::AlignedVec::<16>::with_capacity(bytes.len());
125        aligned.extend_from_slice(bytes);
126
127        #[cfg(target_endian = "little")]
128        {
129            Self::from_rkyv_zero_copy(aligned)
130        }
131        #[cfg(not(target_endian = "little"))]
132        {
133            let snap: CsrSnapshotRkyv =
134                rkyv::from_bytes::<CsrSnapshotRkyv, rkyv::rancor::Error>(&aligned).ok()?;
135            Some(Self::from_snapshot_fields(snap))
136        }
137    }
138
139    /// Zero-copy restore on little-endian platforms.
140    ///
141    /// SAFETY: On little-endian, rkyv's `u32_le`/`u16_le`/`f64_le` have
142    /// identical memory layout to native `u32`/`u16`/`f64`. The pointer
143    /// casts are sound because `ArchivedVec<T>` stores contiguous `T_le`
144    /// values, and the `Arc<AlignedVec>` keeps the buffer alive.
145    #[cfg(target_endian = "little")]
146    fn from_rkyv_zero_copy(aligned: rkyv::util::AlignedVec) -> Option<Self> {
147        use super::dense_array::DenseArray;
148
149        let backing = std::sync::Arc::new(aligned);
150
151        // Access archived data (zero-copy reference into the buffer).
152        let archived =
153            rkyv::access::<rkyv::Archived<CsrSnapshotRkyv>, rkyv::rancor::Error>(&backing).ok()?;
154
155        // Zero-copy DenseArrays for dense CSR arrays.
156        let out_targets = unsafe {
157            let s = archived.out_targets.as_slice();
158            DenseArray::zero_copy(backing.clone(), s.as_ptr().cast::<u32>(), s.len())
159        };
160        let out_labels = unsafe {
161            let s = archived.out_labels.as_slice();
162            DenseArray::zero_copy(backing.clone(), s.as_ptr().cast::<u32>(), s.len())
163        };
164        let in_targets = unsafe {
165            let s = archived.in_targets.as_slice();
166            DenseArray::zero_copy(backing.clone(), s.as_ptr().cast::<u32>(), s.len())
167        };
168        let in_labels = unsafe {
169            let s = archived.in_labels.as_slice();
170            DenseArray::zero_copy(backing.clone(), s.as_ptr().cast::<u32>(), s.len())
171        };
172        let out_weights = archived.out_weights.as_ref().map(|w| unsafe {
173            let s = w.as_slice();
174            DenseArray::zero_copy(backing.clone(), s.as_ptr().cast::<f64>(), s.len())
175        });
176        let in_weights = archived.in_weights.as_ref().map(|w| unsafe {
177            let s = w.as_slice();
178            DenseArray::zero_copy(backing.clone(), s.as_ptr().cast::<f64>(), s.len())
179        });
180
181        // Deserialize mutable/small fields (strings, buffers, offsets).
182        let snap: CsrSnapshotRkyv =
183            rkyv::from_bytes::<CsrSnapshotRkyv, rkyv::rancor::Error>(&backing).ok()?;
184
185        let node_to_id: HashMap<String, u32> = snap
186            .nodes
187            .iter()
188            .enumerate()
189            .map(|(i, n)| (n.clone(), i as u32))
190            .collect();
191        let label_to_id: HashMap<String, u32> = snap
192            .labels
193            .iter()
194            .enumerate()
195            .map(|(i, l)| (l.clone(), i as u32))
196            .collect();
197        let node_count = snap.nodes.len();
198        let access_counts = (0..node_count).map(|_| std::cell::Cell::new(0)).collect();
199        let buffer_out_weights = if snap.buffer_out_weights.len() == node_count {
200            snap.buffer_out_weights
201        } else {
202            vec![Vec::new(); node_count]
203        };
204        let buffer_in_weights = if snap.buffer_in_weights.len() == node_count {
205            snap.buffer_in_weights
206        } else {
207            vec![Vec::new(); node_count]
208        };
209
210        Some(Self {
211            node_to_id,
212            id_to_node: snap.nodes,
213            label_to_id,
214            id_to_label: snap.labels,
215            out_offsets: snap.out_offsets,
216            out_targets,
217            out_labels,
218            out_weights,
219            in_offsets: snap.in_offsets,
220            in_targets,
221            in_labels,
222            in_weights,
223            buffer_out: snap.buffer_out,
224            buffer_in: snap.buffer_in,
225            buffer_out_weights,
226            buffer_in_weights,
227            deleted_edges: snap.deleted.into_iter().collect(),
228            has_weights: snap.has_weights,
229            node_label_bits: vec![0; node_count],
230            node_label_to_id: HashMap::new(),
231            node_label_names: Vec::new(),
232            // Surrogates are runtime-only and not persisted. After checkpoint
233            // restore they start at zero and are repopulated by subsequent EdgePuts.
234            node_surrogates: vec![0; node_count],
235            surrogate_to_local: HashMap::new(),
236            access_counts,
237            query_epoch: 0,
238            partition_tag: crate::csr::local_node_id::next_partition_tag(),
239            // Checkpoint restore creates an ungoverned index; callers that
240            // need budget enforcement should call `set_governor` afterwards.
241            governor: None,
242        })
243    }
244
245    /// Reconstruct CsrIndex from deserialized snapshot fields.
246    #[cfg(not(target_endian = "little"))]
247    fn from_snapshot_fields(snap: CsrSnapshotRkyv) -> Self {
248        let node_to_id: HashMap<String, u32> = snap
249            .nodes
250            .iter()
251            .enumerate()
252            .map(|(i, n)| (n.clone(), i as u32))
253            .collect();
254        let label_to_id: HashMap<String, u32> = snap
255            .labels
256            .iter()
257            .enumerate()
258            .map(|(i, l)| (l.clone(), i as u32))
259            .collect();
260
261        let node_count = snap.nodes.len();
262        let access_counts = (0..node_count).map(|_| std::cell::Cell::new(0)).collect();
263
264        let buffer_out_weights = if snap.buffer_out_weights.len() == node_count {
265            snap.buffer_out_weights
266        } else {
267            vec![Vec::new(); node_count]
268        };
269        let buffer_in_weights = if snap.buffer_in_weights.len() == node_count {
270            snap.buffer_in_weights
271        } else {
272            vec![Vec::new(); node_count]
273        };
274
275        Self {
276            node_to_id,
277            id_to_node: snap.nodes,
278            label_to_id,
279            id_to_label: snap.labels,
280            out_offsets: snap.out_offsets,
281            out_targets: snap.out_targets.into(),
282            out_labels: snap.out_labels.into(),
283            out_weights: snap.out_weights.map(Into::into),
284            in_offsets: snap.in_offsets,
285            in_targets: snap.in_targets.into(),
286            in_labels: snap.in_labels.into(),
287            in_weights: snap.in_weights.map(Into::into),
288            buffer_out: snap.buffer_out,
289            buffer_in: snap.buffer_in,
290            buffer_out_weights,
291            buffer_in_weights,
292            deleted_edges: snap.deleted.into_iter().collect(),
293            has_weights: snap.has_weights,
294            node_label_bits: vec![0; node_count],
295            node_label_to_id: HashMap::new(),
296            node_label_names: Vec::new(),
297            // Surrogates are runtime-only and not persisted. After checkpoint
298            // restore they start at zero and are repopulated by subsequent EdgePuts.
299            node_surrogates: vec![0; node_count],
300            surrogate_to_local: HashMap::new(),
301            access_counts,
302            query_epoch: 0,
303            partition_tag: crate::csr::local_node_id::next_partition_tag(),
304            // Checkpoint restore creates an ungoverned index.
305            governor: None,
306        }
307    }
308}
309
#[cfg(test)]
mod tests {
    use super::*;
    use crate::csr::index::Direction;

    /// Serialize `csr` into checkpoint bytes (no governor is installed in
    /// these tests).
    fn to_bytes(csr: &CsrIndex) -> Vec<u8> {
        csr.checkpoint_to_bytes().expect("no governor, cannot fail")
    }

    /// Serialize `csr` and decode the bytes back into a fresh index.
    fn roundtrip(csr: &CsrIndex) -> CsrIndex {
        let encoded = to_bytes(csr);
        CsrIndex::from_checkpoint(&encoded)
            .expect("roundtrip")
            .unwrap()
    }

    #[test]
    fn checkpoint_roundtrip_unweighted() {
        let mut original = CsrIndex::new();
        for (src, dst) in [("a", "b"), ("b", "c")] {
            original.add_edge(src, "KNOWS", dst).unwrap();
        }
        original.compact().expect("no governor, cannot fail");

        let restored = roundtrip(&original);
        assert_eq!(restored.node_count(), 3);
        assert_eq!(restored.edge_count(), 2);
        assert!(!restored.has_weights());

        let out_of_a = restored.neighbors("a", Some("KNOWS"), Direction::Out);
        assert_eq!(out_of_a.len(), 1);
        assert_eq!(out_of_a[0].1, "b");
    }

    #[test]
    fn checkpoint_roundtrip_weighted() {
        let mut original = CsrIndex::new();
        original.add_edge_weighted("a", "R", "b", 2.5).unwrap();
        original.add_edge_weighted("b", "R", "c", 7.0).unwrap();
        // Added without a weight; expected to read back as 1.0.
        original.add_edge("c", "R", "d").unwrap();
        original.compact().expect("no governor, cannot fail");

        let restored = roundtrip(&original);
        assert!(restored.has_weights());
        assert_eq!(restored.edge_weight("a", "R", "b"), Some(2.5));
        assert_eq!(restored.edge_weight("b", "R", "c"), Some(7.0));
        assert_eq!(restored.edge_weight("c", "R", "d"), Some(1.0));
    }

    #[test]
    fn checkpoint_roundtrip_with_buffer() {
        let mut original = CsrIndex::new();
        original.add_edge("a", "L", "b").unwrap();
        // Deliberately not compacted, so the edge is still in the buffer.
        let restored = roundtrip(&original);
        assert_eq!(restored.edge_count(), 1);
    }

    #[test]
    fn golden_header_layout() {
        let mut original = CsrIndex::new();
        original.add_edge("a", "KNOWS", "b").unwrap();
        let encoded = to_bytes(&original);
        // Bytes 0..6 hold the magic.
        assert_eq!(&encoded[0..6], b"RKCS2\0");
        // Byte 6 holds the format version.
        assert_eq!(encoded[6], super::CSR_FORMAT_VERSION);
        // The rkyv payload follows immediately.
        assert!(encoded.len() > 7);
    }

    #[test]
    fn version_mismatch_returns_error() {
        let mut original = CsrIndex::new();
        original.add_edge("a", "KNOWS", "b").unwrap();
        let mut encoded = to_bytes(&original);
        // Overwrite the version byte with an unsupported value.
        encoded[6] = 0;

        let Err(err) = CsrIndex::from_checkpoint(&encoded) else {
            panic!("expected UnsupportedVersion error, got Ok");
        };
        match err {
            CsrCheckpointError::UnsupportedVersion { found, expected } => {
                assert_eq!(found, 0);
                assert_eq!(expected, super::CSR_FORMAT_VERSION);
            }
            other => panic!("unexpected error: {other}"),
        }
    }
}