1use std::collections::HashMap;
11
12use zerompk::{FromMessagePack, ToMessagePack};
13
14use super::index::CsrIndex;
15
/// Six-byte header (`"RKCS2"` plus a NUL) written in front of rkyv-encoded
/// checkpoints. `CsrIndex::from_checkpoint` decodes any blob that is not this
/// prefix followed by a payload as MessagePack instead.
const RKYV_MAGIC: &[u8; 6] = &[b'R', b'K', b'C', b'S', b'2', 0x00];
18
/// MessagePack form of a [`CsrIndex`] checkpoint.
///
/// `checkpoint_to_bytes` never writes this format. `from_checkpoint` decodes a
/// blob as this struct only when the blob is not [`RKYV_MAGIC`] followed by a
/// payload. Presumably this is an older checkpoint format kept for backward
/// compatibility — confirm before removing.
///
/// The field list mirrors `CsrSnapshotRkyv` one-to-one; `from_msgpack_checkpoint`
/// moves every field across unchanged.
/// NOTE(review): do not rename or reorder fields without checking how `zerompk`
/// encodes structs (by position or by name); either change may break decoding of
/// existing blobs.
#[derive(ToMessagePack, FromMessagePack)]
struct CsrSnapshotMsgpack {
    /// Node names; the position in this vector is the node id (`id_to_node`).
    nodes: Vec<String>,
    /// Label names; the position in this vector is the label id (`id_to_label`).
    labels: Vec<String>,
    /// Outgoing CSR offset array (`out_offsets`).
    out_offsets: Vec<u32>,
    /// Outgoing CSR target ids (`out_targets`).
    out_targets: Vec<u32>,
    /// Outgoing CSR label ids (`out_labels`).
    out_labels: Vec<u32>,
    /// Incoming CSR offset array (`in_offsets`).
    in_offsets: Vec<u32>,
    /// Incoming CSR target ids (`in_targets`).
    in_targets: Vec<u32>,
    /// Incoming CSR label ids (`in_labels`).
    in_labels: Vec<u32>,
    /// Per-node outgoing edge buffer (`buffer_out`); pair meaning is defined by
    /// `CsrIndex`, not visible here.
    buffer_out: Vec<Vec<(u32, u32)>>,
    /// Per-node incoming edge buffer (`buffer_in`).
    buffer_in: Vec<Vec<(u32, u32)>>,
    /// Deleted edges as id triples (`deleted_edges`); triple order is defined by
    /// `CsrIndex`.
    deleted: Vec<(u32, u32, u32)>,
    /// Whether the index carries edge weights.
    has_weights: bool,
    /// Weights parallel to `out_targets`, if any.
    out_weights: Option<Vec<f64>>,
    /// Weights parallel to `in_targets`, if any.
    in_weights: Option<Vec<f64>>,
    /// Weights for `buffer_out`; replaced by empty per-node vectors on load when
    /// its length differs from `nodes.len()`.
    buffer_out_weights: Vec<Vec<f64>>,
    /// Weights for `buffer_in`; same length rule as `buffer_out_weights`.
    buffer_in_weights: Vec<Vec<f64>>,
}
38
/// rkyv snapshot of a [`CsrIndex`]; `checkpoint_to_bytes` writes its archive
/// directly after [`RKYV_MAGIC`].
///
/// On little-endian targets the edge arrays (`out_targets`, `out_labels`,
/// `in_targets`, `in_labels`) and the optional weight arrays are exposed as
/// zero-copy `DenseArray` views into the archive buffer. On other targets the
/// whole archive is deserialized. The MessagePack loader also uses this struct
/// as its in-memory intermediate before building the index.
///
/// Field order and types define the archived layout, so changing them breaks
/// existing checkpoints (presumably the version in `RKYV_MAGIC` should be bumped
/// when that happens — confirm).
#[derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)]
struct CsrSnapshotRkyv {
    /// Node names; the position in this vector is the node id (`id_to_node`).
    nodes: Vec<String>,
    /// Label names; the position in this vector is the label id (`id_to_label`).
    labels: Vec<String>,
    /// Outgoing CSR offset array (`out_offsets`).
    out_offsets: Vec<u32>,
    /// Outgoing CSR target ids (`out_targets`).
    out_targets: Vec<u32>,
    /// Outgoing CSR label ids (`out_labels`).
    out_labels: Vec<u32>,
    /// Incoming CSR offset array (`in_offsets`).
    in_offsets: Vec<u32>,
    /// Incoming CSR target ids (`in_targets`).
    in_targets: Vec<u32>,
    /// Incoming CSR label ids (`in_labels`).
    in_labels: Vec<u32>,
    /// Per-node outgoing edge buffer (`buffer_out`).
    buffer_out: Vec<Vec<(u32, u32)>>,
    /// Per-node incoming edge buffer (`buffer_in`).
    buffer_in: Vec<Vec<(u32, u32)>>,
    /// Deleted edges as id triples (`deleted_edges`).
    deleted: Vec<(u32, u32, u32)>,
    /// Whether the index carries edge weights.
    has_weights: bool,
    /// Weights parallel to `out_targets`, if any.
    out_weights: Option<Vec<f64>>,
    /// Weights parallel to `in_targets`, if any.
    in_weights: Option<Vec<f64>>,
    /// Weights for `buffer_out`; replaced by empty per-node vectors on load when
    /// its length differs from `nodes.len()`.
    buffer_out_weights: Vec<Vec<f64>>,
    /// Weights for `buffer_in`; same length rule as `buffer_out_weights`.
    buffer_in_weights: Vec<Vec<f64>>,
}
59
60impl CsrIndex {
61 pub fn checkpoint_to_bytes(&self) -> Vec<u8> {
67 let snapshot = CsrSnapshotRkyv {
68 nodes: self.id_to_node.clone(),
69 labels: self.id_to_label.clone(),
70 out_offsets: self.out_offsets.clone(),
71 out_targets: self.out_targets.to_vec(),
72 out_labels: self.out_labels.to_vec(),
73 in_offsets: self.in_offsets.clone(),
74 in_targets: self.in_targets.to_vec(),
75 in_labels: self.in_labels.to_vec(),
76 buffer_out: self.buffer_out.clone(),
77 buffer_in: self.buffer_in.clone(),
78 deleted: self.deleted_edges.iter().copied().collect(),
79 has_weights: self.has_weights,
80 out_weights: self.out_weights.as_ref().map(|w| w.to_vec()),
81 in_weights: self.in_weights.as_ref().map(|w| w.to_vec()),
82 buffer_out_weights: self.buffer_out_weights.clone(),
83 buffer_in_weights: self.buffer_in_weights.clone(),
84 };
85 let rkyv_bytes = rkyv::to_bytes::<rkyv::rancor::Error>(&snapshot)
86 .expect("CSR rkyv serialization should not fail");
87 let mut buf = Vec::with_capacity(RKYV_MAGIC.len() + rkyv_bytes.len());
88 buf.extend_from_slice(RKYV_MAGIC);
89 buf.extend_from_slice(&rkyv_bytes);
90 buf
91 }
92
93 pub fn from_checkpoint(bytes: &[u8]) -> Option<Self> {
98 if bytes.len() > RKYV_MAGIC.len() && &bytes[..RKYV_MAGIC.len()] == RKYV_MAGIC {
99 return Self::from_rkyv_checkpoint(&bytes[RKYV_MAGIC.len()..]);
100 }
101 Self::from_msgpack_checkpoint(bytes)
102 }
103
104 fn from_rkyv_checkpoint(bytes: &[u8]) -> Option<Self> {
111 let mut aligned = rkyv::util::AlignedVec::<16>::with_capacity(bytes.len());
112 aligned.extend_from_slice(bytes);
113
114 #[cfg(target_endian = "little")]
115 {
116 Self::from_rkyv_zero_copy(aligned)
117 }
118 #[cfg(not(target_endian = "little"))]
119 {
120 let snap: CsrSnapshotRkyv =
121 rkyv::from_bytes::<CsrSnapshotRkyv, rkyv::rancor::Error>(&aligned).ok()?;
122 Some(Self::from_snapshot_fields(snap))
123 }
124 }
125
126 #[cfg(target_endian = "little")]
133 fn from_rkyv_zero_copy(aligned: rkyv::util::AlignedVec) -> Option<Self> {
134 use super::dense_array::DenseArray;
135
136 let backing = std::sync::Arc::new(aligned);
137
138 let archived =
140 rkyv::access::<rkyv::Archived<CsrSnapshotRkyv>, rkyv::rancor::Error>(&backing).ok()?;
141
142 let out_targets = unsafe {
144 let s = archived.out_targets.as_slice();
145 DenseArray::zero_copy(backing.clone(), s.as_ptr().cast::<u32>(), s.len())
146 };
147 let out_labels = unsafe {
148 let s = archived.out_labels.as_slice();
149 DenseArray::zero_copy(backing.clone(), s.as_ptr().cast::<u32>(), s.len())
150 };
151 let in_targets = unsafe {
152 let s = archived.in_targets.as_slice();
153 DenseArray::zero_copy(backing.clone(), s.as_ptr().cast::<u32>(), s.len())
154 };
155 let in_labels = unsafe {
156 let s = archived.in_labels.as_slice();
157 DenseArray::zero_copy(backing.clone(), s.as_ptr().cast::<u32>(), s.len())
158 };
159 let out_weights = archived.out_weights.as_ref().map(|w| unsafe {
160 let s = w.as_slice();
161 DenseArray::zero_copy(backing.clone(), s.as_ptr().cast::<f64>(), s.len())
162 });
163 let in_weights = archived.in_weights.as_ref().map(|w| unsafe {
164 let s = w.as_slice();
165 DenseArray::zero_copy(backing.clone(), s.as_ptr().cast::<f64>(), s.len())
166 });
167
168 let snap: CsrSnapshotRkyv =
170 rkyv::from_bytes::<CsrSnapshotRkyv, rkyv::rancor::Error>(&backing).ok()?;
171
172 let node_to_id: HashMap<String, u32> = snap
173 .nodes
174 .iter()
175 .enumerate()
176 .map(|(i, n)| (n.clone(), i as u32))
177 .collect();
178 let label_to_id: HashMap<String, u32> = snap
179 .labels
180 .iter()
181 .enumerate()
182 .map(|(i, l)| (l.clone(), i as u32))
183 .collect();
184 let node_count = snap.nodes.len();
185 let access_counts = (0..node_count).map(|_| std::cell::Cell::new(0)).collect();
186 let buffer_out_weights = if snap.buffer_out_weights.len() == node_count {
187 snap.buffer_out_weights
188 } else {
189 vec![Vec::new(); node_count]
190 };
191 let buffer_in_weights = if snap.buffer_in_weights.len() == node_count {
192 snap.buffer_in_weights
193 } else {
194 vec![Vec::new(); node_count]
195 };
196
197 Some(Self {
198 node_to_id,
199 id_to_node: snap.nodes,
200 label_to_id,
201 id_to_label: snap.labels,
202 out_offsets: snap.out_offsets,
203 out_targets,
204 out_labels,
205 out_weights,
206 in_offsets: snap.in_offsets,
207 in_targets,
208 in_labels,
209 in_weights,
210 buffer_out: snap.buffer_out,
211 buffer_in: snap.buffer_in,
212 buffer_out_weights,
213 buffer_in_weights,
214 deleted_edges: snap.deleted.into_iter().collect(),
215 has_weights: snap.has_weights,
216 node_label_bits: vec![0; node_count],
217 node_label_to_id: HashMap::new(),
218 node_label_names: Vec::new(),
219 access_counts,
220 query_epoch: 0,
221 })
222 }
223
224 fn from_msgpack_checkpoint(bytes: &[u8]) -> Option<Self> {
226 let snap: CsrSnapshotMsgpack = zerompk::from_msgpack(bytes).ok()?;
227 Some(Self::from_snapshot_fields(CsrSnapshotRkyv {
228 nodes: snap.nodes,
229 labels: snap.labels,
230 out_offsets: snap.out_offsets,
231 out_targets: snap.out_targets,
232 out_labels: snap.out_labels,
233 in_offsets: snap.in_offsets,
234 in_targets: snap.in_targets,
235 in_labels: snap.in_labels,
236 buffer_out: snap.buffer_out,
237 buffer_in: snap.buffer_in,
238 deleted: snap.deleted,
239 has_weights: snap.has_weights,
240 out_weights: snap.out_weights,
241 in_weights: snap.in_weights,
242 buffer_out_weights: snap.buffer_out_weights,
243 buffer_in_weights: snap.buffer_in_weights,
244 }))
245 }
246
247 fn from_snapshot_fields(snap: CsrSnapshotRkyv) -> Self {
249 let node_to_id: HashMap<String, u32> = snap
250 .nodes
251 .iter()
252 .enumerate()
253 .map(|(i, n)| (n.clone(), i as u32))
254 .collect();
255 let label_to_id: HashMap<String, u32> = snap
256 .labels
257 .iter()
258 .enumerate()
259 .map(|(i, l)| (l.clone(), i as u32))
260 .collect();
261
262 let node_count = snap.nodes.len();
263 let access_counts = (0..node_count).map(|_| std::cell::Cell::new(0)).collect();
264
265 let buffer_out_weights = if snap.buffer_out_weights.len() == node_count {
266 snap.buffer_out_weights
267 } else {
268 vec![Vec::new(); node_count]
269 };
270 let buffer_in_weights = if snap.buffer_in_weights.len() == node_count {
271 snap.buffer_in_weights
272 } else {
273 vec![Vec::new(); node_count]
274 };
275
276 Self {
277 node_to_id,
278 id_to_node: snap.nodes,
279 label_to_id,
280 id_to_label: snap.labels,
281 out_offsets: snap.out_offsets,
282 out_targets: snap.out_targets.into(),
283 out_labels: snap.out_labels.into(),
284 out_weights: snap.out_weights.map(Into::into),
285 in_offsets: snap.in_offsets,
286 in_targets: snap.in_targets.into(),
287 in_labels: snap.in_labels.into(),
288 in_weights: snap.in_weights.map(Into::into),
289 buffer_out: snap.buffer_out,
290 buffer_in: snap.buffer_in,
291 buffer_out_weights,
292 buffer_in_weights,
293 deleted_edges: snap.deleted.into_iter().collect(),
294 has_weights: snap.has_weights,
295 node_label_bits: vec![0; node_count],
296 node_label_to_id: HashMap::new(),
297 node_label_names: Vec::new(),
298 access_counts,
299 query_epoch: 0,
300 }
301 }
302}
303
#[cfg(test)]
mod tests {
    use super::*;
    use crate::csr::index::Direction;

    #[test]
    fn checkpoint_roundtrip_unweighted() {
        let mut csr = CsrIndex::new();
        csr.add_edge("a", "KNOWS", "b").unwrap();
        csr.add_edge("b", "KNOWS", "c").unwrap();
        csr.compact();

        let bytes = csr.checkpoint_to_bytes();
        let restored = CsrIndex::from_checkpoint(&bytes).expect("roundtrip");
        assert_eq!(restored.node_count(), 3);
        assert_eq!(restored.edge_count(), 2);
        assert!(!restored.has_weights());

        let n = restored.neighbors("a", Some("KNOWS"), Direction::Out);
        assert_eq!(n.len(), 1);
        assert_eq!(n[0].1, "b");
    }

    #[test]
    fn checkpoint_roundtrip_weighted() {
        let mut csr = CsrIndex::new();
        csr.add_edge_weighted("a", "R", "b", 2.5).unwrap();
        csr.add_edge_weighted("b", "R", "c", 7.0).unwrap();
        csr.add_edge("c", "R", "d").unwrap();
        csr.compact();

        let bytes = csr.checkpoint_to_bytes();
        let restored = CsrIndex::from_checkpoint(&bytes).expect("roundtrip");
        assert!(restored.has_weights());
        assert_eq!(restored.edge_weight("a", "R", "b"), Some(2.5));
        assert_eq!(restored.edge_weight("b", "R", "c"), Some(7.0));
        assert_eq!(restored.edge_weight("c", "R", "d"), Some(1.0));
    }

    #[test]
    fn checkpoint_roundtrip_with_buffer() {
        // No compact(): the edge lives only in the per-node buffers.
        let mut csr = CsrIndex::new();
        csr.add_edge("a", "L", "b").unwrap();
        let bytes = csr.checkpoint_to_bytes();
        let restored = CsrIndex::from_checkpoint(&bytes).expect("roundtrip");
        assert_eq!(restored.edge_count(), 1);
    }

    #[test]
    fn checkpoint_starts_with_rkyv_magic() {
        let mut csr = CsrIndex::new();
        csr.add_edge("a", "L", "b").unwrap();
        let bytes = csr.checkpoint_to_bytes();
        assert!(bytes.starts_with(RKYV_MAGIC));
        assert!(bytes.len() > RKYV_MAGIC.len());
    }

    #[test]
    fn checkpoint_roundtrip_empty_index() {
        let csr = CsrIndex::new();
        let restored = CsrIndex::from_checkpoint(&csr.checkpoint_to_bytes()).expect("roundtrip");
        assert_eq!(restored.node_count(), 0);
        assert_eq!(restored.edge_count(), 0);
    }

    #[test]
    fn checkpoint_of_restored_index_roundtrips() {
        // The first restore yields zero-copy arrays; re-checkpointing must read
        // them back correctly.
        let mut csr = CsrIndex::new();
        csr.add_edge_weighted("a", "R", "b", 2.5).unwrap();
        csr.add_edge("b", "R", "c").unwrap();
        csr.compact();

        let first = CsrIndex::from_checkpoint(&csr.checkpoint_to_bytes()).expect("first");
        let second = CsrIndex::from_checkpoint(&first.checkpoint_to_bytes()).expect("second");
        assert_eq!(second.node_count(), 3);
        assert_eq!(second.edge_count(), 2);
        assert_eq!(second.edge_weight("a", "R", "b"), Some(2.5));
        assert_eq!(second.edge_weight("b", "R", "c"), Some(1.0));
    }

    #[test]
    fn from_checkpoint_rejects_malformed_input() {
        assert!(CsrIndex::from_checkpoint(b"").is_none());
        assert!(CsrIndex::from_checkpoint(b"garbage").is_none());
        // Bare magic with no payload falls back to MessagePack and fails there.
        assert!(CsrIndex::from_checkpoint(RKYV_MAGIC).is_none());

        // A truncated rkyv archive must be rejected by validation, not panic.
        let mut csr = CsrIndex::new();
        csr.add_edge("a", "L", "b").unwrap();
        csr.compact();
        let bytes = csr.checkpoint_to_bytes();
        assert!(CsrIndex::from_checkpoint(&bytes[..bytes.len() - 1]).is_none());
    }
}