1use std::collections::HashMap;
10use std::mem::size_of;
11
12use nodedb_mem::EngineId;
13
14use super::index::CsrIndex;
15use crate::GraphError;
16
/// Magic prefix identifying an rkyv-encoded CSR checkpoint. Every blob written by
/// [`CsrIndex::checkpoint_to_bytes`] starts with these 6 bytes, and
/// [`CsrIndex::from_checkpoint`] only attempts to decode inputs that start with them.
const RKYV_MAGIC: &[u8; 6] = b"RKCS2\0";
/// Checkpoint format version, stored as the single byte immediately after the
/// `RKYV_MAGIC` prefix. [`CsrIndex::from_checkpoint`] rejects any other value with
/// [`CsrCheckpointError::UnsupportedVersion`].
pub const CSR_FORMAT_VERSION: u8 = 1;
21
/// Errors produced while restoring a [`CsrIndex`] from checkpoint bytes.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum CsrCheckpointError {
    /// The header's version byte is not one this build can read.
    ///
    /// `found` is the version byte read from the header; `expected` is
    /// [`CSR_FORMAT_VERSION`].
    #[error("unsupported CSR checkpoint version {found}; expected {expected}")]
    UnsupportedVersion { found: u8, expected: u8 },
    /// The rkyv payload following the checkpoint header could not be
    /// validated or deserialized.
    #[error("CSR checkpoint rkyv deserialization failed")]
    RkyvDeserialize,
}
31
/// Owned, serializable image of a [`CsrIndex`]; its rkyv archive is the checkpoint
/// payload that follows the magic/version header.
///
/// Field order and field types define the archived byte layout. Changing them
/// changes the on-disk format, which presumably warrants bumping
/// [`CSR_FORMAT_VERSION`].
///
/// Node-label bits, surrogate maps, access counters, the query epoch and the
/// memory governor are not part of the snapshot. The restore paths reset them
/// to defaults.
#[derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)]
struct CsrSnapshotRkyv {
    /// Node names; on restore a node's id is its position in this list.
    nodes: Vec<String>,
    /// Edge-label names; on restore a label's id is its position in this list.
    labels: Vec<String>,
    /// Copy of `CsrIndex::out_offsets` (CSR offsets for outgoing edges).
    out_offsets: Vec<u32>,
    /// Copy of `CsrIndex::out_targets`.
    out_targets: Vec<u32>,
    /// Copy of `CsrIndex::out_labels`.
    out_labels: Vec<u32>,
    /// Copy of `CsrIndex::in_offsets` (CSR offsets for incoming edges).
    in_offsets: Vec<u32>,
    /// Copy of `CsrIndex::in_targets`.
    in_targets: Vec<u32>,
    /// Copy of `CsrIndex::in_labels`.
    in_labels: Vec<u32>,
    /// Copy of `CsrIndex::buffer_out` (presumably outgoing edges added since the
    /// last compaction — confirm).
    buffer_out: Vec<Vec<(u32, u32)>>,
    /// Copy of `CsrIndex::buffer_in` (incoming counterpart of `buffer_out`).
    buffer_in: Vec<Vec<(u32, u32)>>,
    /// Contents of `CsrIndex::deleted_edges`; tuple meaning is presumably
    /// (source, label, destination) ids — confirm against `CsrIndex`.
    deleted: Vec<(u32, u32, u32)>,
    /// Copy of `CsrIndex::has_weights`.
    has_weights: bool,
    /// Outgoing edge weights, present only when the index stores them.
    out_weights: Option<Vec<f64>>,
    /// Incoming edge weights, present only when the index stores them.
    in_weights: Option<Vec<f64>>,
    /// Per-node weights for `buffer_out`. The restore paths replace it with empty
    /// per-node buffers when its length differs from `nodes.len()`.
    buffer_out_weights: Vec<Vec<f64>>,
    /// Per-node weights for `buffer_in`. The restore paths replace it with empty
    /// per-node buffers when its length differs from `nodes.len()`.
    buffer_in_weights: Vec<Vec<f64>>,
}
52
53impl CsrIndex {
54 pub fn checkpoint_to_bytes(&self) -> Result<Vec<u8>, GraphError> {
61 let snapshot = CsrSnapshotRkyv {
62 nodes: self.id_to_node.clone(),
63 labels: self.id_to_label.clone(),
64 out_offsets: self.out_offsets.clone(),
65 out_targets: self.out_targets.to_vec(),
66 out_labels: self.out_labels.to_vec(),
67 in_offsets: self.in_offsets.clone(),
68 in_targets: self.in_targets.to_vec(),
69 in_labels: self.in_labels.to_vec(),
70 buffer_out: self.buffer_out.clone(),
71 buffer_in: self.buffer_in.clone(),
72 deleted: self.deleted_edges.iter().copied().collect(),
73 has_weights: self.has_weights,
74 out_weights: self.out_weights.as_ref().map(|w| w.to_vec()),
75 in_weights: self.in_weights.as_ref().map(|w| w.to_vec()),
76 buffer_out_weights: self.buffer_out_weights.clone(),
77 buffer_in_weights: self.buffer_in_weights.clone(),
78 };
79 let rkyv_bytes = rkyv::to_bytes::<rkyv::rancor::Error>(&snapshot)
80 .expect("CSR rkyv serialization should not fail");
81 let buf_capacity = RKYV_MAGIC.len() + 1 + rkyv_bytes.len();
82 let _budget_guard = self
83 .governor
84 .as_ref()
85 .map(|g| g.reserve(EngineId::Graph, buf_capacity * size_of::<u8>()))
86 .transpose()?;
87 let mut buf = Vec::with_capacity(buf_capacity);
88 buf.extend_from_slice(RKYV_MAGIC);
89 buf.push(CSR_FORMAT_VERSION);
90 buf.extend_from_slice(&rkyv_bytes);
91 Ok(buf)
92 }
93
94 pub fn from_checkpoint(bytes: &[u8]) -> Result<Option<Self>, CsrCheckpointError> {
103 let header_len = RKYV_MAGIC.len() + 1; if bytes.len() > header_len && &bytes[..RKYV_MAGIC.len()] == RKYV_MAGIC {
105 let version = bytes[RKYV_MAGIC.len()];
106 if version != CSR_FORMAT_VERSION {
107 return Err(CsrCheckpointError::UnsupportedVersion {
108 found: version,
109 expected: CSR_FORMAT_VERSION,
110 });
111 }
112 return Ok(Self::from_rkyv_checkpoint(&bytes[header_len..]));
113 }
114 Ok(None)
115 }
116
117 fn from_rkyv_checkpoint(bytes: &[u8]) -> Option<Self> {
124 let mut aligned = rkyv::util::AlignedVec::<16>::with_capacity(bytes.len());
125 aligned.extend_from_slice(bytes);
126
127 #[cfg(target_endian = "little")]
128 {
129 Self::from_rkyv_zero_copy(aligned)
130 }
131 #[cfg(not(target_endian = "little"))]
132 {
133 let snap: CsrSnapshotRkyv =
134 rkyv::from_bytes::<CsrSnapshotRkyv, rkyv::rancor::Error>(&aligned).ok()?;
135 Some(Self::from_snapshot_fields(snap))
136 }
137 }
138
139 #[cfg(target_endian = "little")]
146 fn from_rkyv_zero_copy(aligned: rkyv::util::AlignedVec) -> Option<Self> {
147 use super::dense_array::DenseArray;
148
149 let backing = std::sync::Arc::new(aligned);
150
151 let archived =
153 rkyv::access::<rkyv::Archived<CsrSnapshotRkyv>, rkyv::rancor::Error>(&backing).ok()?;
154
155 let out_targets = unsafe {
157 let s = archived.out_targets.as_slice();
158 DenseArray::zero_copy(backing.clone(), s.as_ptr().cast::<u32>(), s.len())
159 };
160 let out_labels = unsafe {
161 let s = archived.out_labels.as_slice();
162 DenseArray::zero_copy(backing.clone(), s.as_ptr().cast::<u32>(), s.len())
163 };
164 let in_targets = unsafe {
165 let s = archived.in_targets.as_slice();
166 DenseArray::zero_copy(backing.clone(), s.as_ptr().cast::<u32>(), s.len())
167 };
168 let in_labels = unsafe {
169 let s = archived.in_labels.as_slice();
170 DenseArray::zero_copy(backing.clone(), s.as_ptr().cast::<u32>(), s.len())
171 };
172 let out_weights = archived.out_weights.as_ref().map(|w| unsafe {
173 let s = w.as_slice();
174 DenseArray::zero_copy(backing.clone(), s.as_ptr().cast::<f64>(), s.len())
175 });
176 let in_weights = archived.in_weights.as_ref().map(|w| unsafe {
177 let s = w.as_slice();
178 DenseArray::zero_copy(backing.clone(), s.as_ptr().cast::<f64>(), s.len())
179 });
180
181 let snap: CsrSnapshotRkyv =
183 rkyv::from_bytes::<CsrSnapshotRkyv, rkyv::rancor::Error>(&backing).ok()?;
184
185 let node_to_id: HashMap<String, u32> = snap
186 .nodes
187 .iter()
188 .enumerate()
189 .map(|(i, n)| (n.clone(), i as u32))
190 .collect();
191 let label_to_id: HashMap<String, u32> = snap
192 .labels
193 .iter()
194 .enumerate()
195 .map(|(i, l)| (l.clone(), i as u32))
196 .collect();
197 let node_count = snap.nodes.len();
198 let access_counts = (0..node_count).map(|_| std::cell::Cell::new(0)).collect();
199 let buffer_out_weights = if snap.buffer_out_weights.len() == node_count {
200 snap.buffer_out_weights
201 } else {
202 vec![Vec::new(); node_count]
203 };
204 let buffer_in_weights = if snap.buffer_in_weights.len() == node_count {
205 snap.buffer_in_weights
206 } else {
207 vec![Vec::new(); node_count]
208 };
209
210 Some(Self {
211 node_to_id,
212 id_to_node: snap.nodes,
213 label_to_id,
214 id_to_label: snap.labels,
215 out_offsets: snap.out_offsets,
216 out_targets,
217 out_labels,
218 out_weights,
219 in_offsets: snap.in_offsets,
220 in_targets,
221 in_labels,
222 in_weights,
223 buffer_out: snap.buffer_out,
224 buffer_in: snap.buffer_in,
225 buffer_out_weights,
226 buffer_in_weights,
227 deleted_edges: snap.deleted.into_iter().collect(),
228 has_weights: snap.has_weights,
229 node_label_bits: vec![0; node_count],
230 node_label_to_id: HashMap::new(),
231 node_label_names: Vec::new(),
232 node_surrogates: vec![0; node_count],
235 surrogate_to_local: HashMap::new(),
236 access_counts,
237 query_epoch: 0,
238 partition_tag: crate::csr::local_node_id::next_partition_tag(),
239 governor: None,
242 })
243 }
244
/// Builds a [`CsrIndex`] from a fully deserialized snapshot.
///
/// Used on big-endian targets, where the archived little-endian arrays cannot be
/// reinterpreted in place. The owned vectors are converted into the index's array
/// types via `Into`.
///
/// Node-label state and surrogate maps are not stored in checkpoints and start
/// empty. Per-node weight buffers whose length differs from the node count are
/// replaced with empty per-node buffers.
#[cfg(not(target_endian = "little"))]
fn from_snapshot_fields(snap: CsrSnapshotRkyv) -> Self {
    let CsrSnapshotRkyv {
        nodes,
        labels,
        out_offsets,
        out_targets,
        out_labels,
        in_offsets,
        in_targets,
        in_labels,
        buffer_out,
        buffer_in,
        deleted,
        has_weights,
        out_weights,
        in_weights,
        buffer_out_weights,
        buffer_in_weights,
    } = snap;

    let node_count = nodes.len();

    // An entry's id is its position in the name table.
    let position_map = |names: &[String]| -> HashMap<String, u32> {
        names
            .iter()
            .enumerate()
            .map(|(i, name)| (name.clone(), i as u32))
            .collect()
    };
    let node_to_id = position_map(&nodes);
    let label_to_id = position_map(&labels);

    let per_node_or_empty = |weights: Vec<Vec<f64>>| {
        if weights.len() == node_count {
            weights
        } else {
            vec![Vec::new(); node_count]
        }
    };

    Self {
        node_to_id,
        id_to_node: nodes,
        label_to_id,
        id_to_label: labels,
        out_offsets,
        out_targets: out_targets.into(),
        out_labels: out_labels.into(),
        out_weights: out_weights.map(Into::into),
        in_offsets,
        in_targets: in_targets.into(),
        in_labels: in_labels.into(),
        in_weights: in_weights.map(Into::into),
        buffer_out,
        buffer_in,
        buffer_out_weights: per_node_or_empty(buffer_out_weights),
        buffer_in_weights: per_node_or_empty(buffer_in_weights),
        deleted_edges: deleted.into_iter().collect(),
        has_weights,
        node_label_bits: vec![0; node_count],
        node_label_to_id: HashMap::new(),
        node_label_names: Vec::new(),
        node_surrogates: vec![0; node_count],
        surrogate_to_local: HashMap::new(),
        access_counts: (0..node_count).map(|_| std::cell::Cell::new(0)).collect(),
        query_epoch: 0,
        partition_tag: crate::csr::local_node_id::next_partition_tag(),
        governor: None,
    }
}
308}
309
#[cfg(test)]
mod tests {
    use super::*;
    use crate::csr::index::Direction;

    #[test]
    fn checkpoint_roundtrip_unweighted() {
        let mut csr = CsrIndex::new();
        csr.add_edge("a", "KNOWS", "b").unwrap();
        csr.add_edge("b", "KNOWS", "c").unwrap();
        csr.compact().expect("no governor, cannot fail");

        let bytes = csr.checkpoint_to_bytes().expect("no governor, cannot fail");
        let restored = CsrIndex::from_checkpoint(&bytes)
            .expect("roundtrip")
            .unwrap();
        assert_eq!(restored.node_count(), 3);
        assert_eq!(restored.edge_count(), 2);
        assert!(!restored.has_weights());

        let n = restored.neighbors("a", Some("KNOWS"), Direction::Out);
        assert_eq!(n.len(), 1);
        assert_eq!(n[0].1, "b");
    }

    #[test]
    fn checkpoint_roundtrip_weighted() {
        let mut csr = CsrIndex::new();
        csr.add_edge_weighted("a", "R", "b", 2.5).unwrap();
        csr.add_edge_weighted("b", "R", "c", 7.0).unwrap();
        csr.add_edge("c", "R", "d").unwrap();
        csr.compact().expect("no governor, cannot fail");

        let bytes = csr.checkpoint_to_bytes().expect("no governor, cannot fail");
        let restored = CsrIndex::from_checkpoint(&bytes)
            .expect("roundtrip")
            .unwrap();
        assert!(restored.has_weights());
        assert_eq!(restored.edge_weight("a", "R", "b"), Some(2.5));
        assert_eq!(restored.edge_weight("b", "R", "c"), Some(7.0));
        assert_eq!(restored.edge_weight("c", "R", "d"), Some(1.0));
    }

    #[test]
    fn checkpoint_roundtrip_with_buffer() {
        let mut csr = CsrIndex::new();
        csr.add_edge("a", "L", "b").unwrap();
        // No compaction: the edge is checkpointed from the pending buffer.
        let bytes = csr.checkpoint_to_bytes().expect("no governor, cannot fail");
        let restored = CsrIndex::from_checkpoint(&bytes)
            .expect("roundtrip")
            .unwrap();
        assert_eq!(restored.edge_count(), 1);
    }

    #[test]
    fn golden_header_layout() {
        let mut csr = CsrIndex::new();
        csr.add_edge("a", "KNOWS", "b").unwrap();
        let bytes = csr.checkpoint_to_bytes().expect("no governor, cannot fail");
        assert_eq!(&bytes[0..6], b"RKCS2\0");
        assert_eq!(bytes[6], super::CSR_FORMAT_VERSION);
        assert!(bytes.len() > 7);
    }

    #[test]
    fn version_mismatch_returns_error() {
        let mut csr = CsrIndex::new();
        csr.add_edge("a", "KNOWS", "b").unwrap();
        let mut bytes = csr.checkpoint_to_bytes().expect("no governor, cannot fail");
        bytes[6] = 0;
        match CsrIndex::from_checkpoint(&bytes) {
            Err(CsrCheckpointError::UnsupportedVersion { found, expected }) => {
                assert_eq!(found, 0);
                assert_eq!(expected, super::CSR_FORMAT_VERSION);
            }
            Err(other) => panic!("unexpected error: {other}"),
            Ok(_) => panic!("expected UnsupportedVersion error, got Ok"),
        }
    }

    #[test]
    fn non_checkpoint_bytes_return_none() {
        // Inputs without the rkyv magic are not CSR checkpoints: the loader reports
        // `Ok(None)` instead of an error so callers can try another format.
        assert!(matches!(CsrIndex::from_checkpoint(&[]), Ok(None)));
        assert!(matches!(
            CsrIndex::from_checkpoint(b"not a csr checkpoint"),
            Ok(None)
        ));
    }
}