1use std::io::{Read, Write};
36
37use serde::{Deserialize, Serialize};
38use thiserror::Error;
39
40use crate::{NodeId, NodeRecord, RelationshipId, RelationshipRecord};
41
/// Magic bytes that open every snapshot file (header bytes `0..8`).
pub const SNAPSHOT_MAGIC: &[u8; 8] = b"LORASNAP";

/// Format version stamped into header bytes `8..12` (little-endian) by `write_snapshot`.
pub const SNAPSHOT_FORMAT_VERSION: u32 = 1;

/// Oldest format version that header decoding still accepts.
///
/// Headers whose version falls outside
/// `SNAPSHOT_MIN_SUPPORTED_FORMAT_VERSION..=SNAPSHOT_FORMAT_VERSION` are rejected with
/// `SnapshotError::UnsupportedVersion`.
pub const SNAPSHOT_MIN_SUPPORTED_FORMAT_VERSION: u32 = 1;

// Compile-time guard: the accepted version range must never be empty.
const _: () = assert!(
    SNAPSHOT_MIN_SUPPORTED_FORMAT_VERSION <= SNAPSHOT_FORMAT_VERSION,
    "SNAPSHOT_MIN_SUPPORTED_FORMAT_VERSION must not exceed SNAPSHOT_FORMAT_VERSION",
);

/// Size in bytes of the fixed snapshot header.
///
/// Only bytes `0..24` are currently populated: magic (8), format version (4), flags (4) and
/// WAL LSN (8). The remaining bytes are written as zero and ignored when decoding; they are
/// presumably reserved for future header fields.
pub(crate) const HEADER_LEN: usize = 40;

/// Header flag bit meaning "the WAL LSN field (header bytes `16..24`) holds a real value".
///
/// When this bit is clear, the LSN field is written as zero and reported as `None`.
pub const HEADER_FLAG_HAS_WAL_LSN: u32 = 1 << 0;
68
/// The body of a snapshot: every node and relationship, plus the id counters.
///
/// This whole struct is bincode-encoded between the fixed header and the CRC trailer.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SnapshotPayload {
    /// Node id counter to restore. In the tests this is one past the highest node id present.
    pub next_node_id: NodeId,
    /// Relationship id counter to restore, analogous to `next_node_id`.
    pub next_rel_id: RelationshipId,
    /// All node records contained in the snapshot.
    pub nodes: Vec<NodeRecord>,
    /// All relationship records contained in the snapshot.
    pub relationships: Vec<RelationshipRecord>,
}
81
82impl SnapshotPayload {
83 pub fn empty() -> Self {
84 Self {
85 next_node_id: 0,
86 next_rel_id: 0,
87 nodes: Vec::new(),
88 relationships: Vec::new(),
89 }
90 }
91}
92
/// Summary of a snapshot that was just written or read.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SnapshotMeta {
    /// Format version of the snapshot. After a write this is always `SNAPSHOT_FORMAT_VERSION`;
    /// after a read it is the version found in the header.
    pub format_version: u32,
    /// Number of node records in the payload.
    pub node_count: usize,
    /// Number of relationship records in the payload.
    pub relationship_count: usize,
    /// WAL LSN recorded in the header, or `None` when `HEADER_FLAG_HAS_WAL_LSN` was not set.
    pub wal_lsn: Option<u64>,
}
107
/// Everything that can go wrong while writing or reading a snapshot.
#[derive(Debug, Error)]
pub enum SnapshotError {
    /// An underlying read from or write to the snapshot stream failed.
    #[error("snapshot I/O error: {0}")]
    Io(#[from] std::io::Error),

    /// The first eight bytes were not `SNAPSHOT_MAGIC`.
    #[error("snapshot is not a LORASNAP file (bad magic)")]
    BadMagic,

    /// The header's format version is outside the supported range, or has no payload decoder.
    #[error("unsupported snapshot format version: {0}")]
    UnsupportedVersion(u32),

    /// The input was shorter than required.
    ///
    /// Header decoding reports `expected = HEADER_LEN`. When reading a whole file, the minimum
    /// also includes the 4-byte CRC trailer, so `expected` can exceed the header size.
    #[error("snapshot header too short (expected {expected} bytes, got {actual})")]
    TruncatedHeader { expected: usize, actual: usize },

    /// The checksum did not match.
    ///
    /// `expected` is the CRC stored in the file's trailer; `actual` is the CRC computed over
    /// the header and payload bytes.
    #[error("snapshot CRC mismatch: expected 0x{expected:08x}, got 0x{actual:08x}")]
    CrcMismatch { expected: u32, actual: u32 },

    /// bincode could not decode the payload; carries the bincode error message.
    #[error("snapshot payload could not be decoded: {0}")]
    Decode(String),

    /// bincode could not encode the payload; carries the bincode error message.
    #[error("snapshot payload could not be encoded: {0}")]
    Encode(String),
}
132
/// A store whose contents can be written to, and restored from, a snapshot stream.
pub trait Snapshotable {
    /// Writes a snapshot of the current state to `writer`.
    ///
    /// Returns the [`SnapshotMeta`] describing what was written.
    ///
    /// NOTE(review): unlike `save_checkpoint`, this takes no WAL LSN, so the written snapshot
    /// presumably carries `wal_lsn: None`. Confirm in the implementations.
    ///
    /// # Errors
    /// Returns a [`SnapshotError`] if encoding or writing fails.
    fn save_snapshot<W: Write>(&self, writer: W) -> Result<SnapshotMeta, SnapshotError>;

    /// Reads a snapshot from `reader` and loads it into `self`.
    ///
    /// Returns the [`SnapshotMeta`] of the snapshot that was read.
    ///
    /// NOTE(review): this signature does not say whether existing state is replaced or merged,
    /// or whether `self` is left untouched on error. Confirm in the implementations before
    /// relying on either.
    ///
    /// # Errors
    /// Returns a [`SnapshotError`] if the stream cannot be read or is not a valid snapshot.
    fn load_snapshot<R: Read>(&mut self, reader: R) -> Result<SnapshotMeta, SnapshotError>;

    /// Writes a snapshot to `writer` that also records `wal_lsn` in its header.
    ///
    /// NOTE(review): `wal_lsn` presumably marks the WAL position this snapshot reflects, so
    /// recovery can replay only later entries. Confirm against the WAL code.
    ///
    /// # Errors
    /// Returns a [`SnapshotError`] if encoding or writing fails.
    fn save_checkpoint<W: Write>(
        &self,
        writer: W,
        wal_lsn: u64,
    ) -> Result<SnapshotMeta, SnapshotError>;
}
164
/// In-memory form of the fixed-size snapshot header (see `HEADER_LEN`).
///
/// `wal_lsn` is meaningful only when `HEADER_FLAG_HAS_WAL_LSN` is set in `header_flags`.
/// Use `wal_lsn_if_set` rather than reading the field directly.
#[derive(Debug, Clone, Copy)]
pub(crate) struct SnapshotHeader {
    /// On-disk format version (header bytes `8..12`, little-endian).
    pub format_version: u32,
    /// Bit flags (header bytes `12..16`, little-endian); see `HEADER_FLAG_HAS_WAL_LSN`.
    pub header_flags: u32,
    /// WAL LSN slot (header bytes `16..24`, little-endian); zero when the flag is clear.
    pub wal_lsn: u64,
}
175
176impl SnapshotHeader {
177 pub(crate) fn new(format_version: u32, wal_lsn: Option<u64>) -> Self {
178 let (flags, lsn) = match wal_lsn {
179 Some(lsn) => (HEADER_FLAG_HAS_WAL_LSN, lsn),
180 None => (0, 0),
181 };
182 Self {
183 format_version,
184 header_flags: flags,
185 wal_lsn: lsn,
186 }
187 }
188
189 pub(crate) fn wal_lsn_if_set(&self) -> Option<u64> {
190 if self.header_flags & HEADER_FLAG_HAS_WAL_LSN != 0 {
191 Some(self.wal_lsn)
192 } else {
193 None
194 }
195 }
196
197 pub(crate) fn encode(&self) -> [u8; HEADER_LEN] {
198 let mut out = [0u8; HEADER_LEN];
199 out[0..8].copy_from_slice(SNAPSHOT_MAGIC);
200 out[8..12].copy_from_slice(&self.format_version.to_le_bytes());
201 out[12..16].copy_from_slice(&self.header_flags.to_le_bytes());
202 out[16..24].copy_from_slice(&self.wal_lsn.to_le_bytes());
203 out
205 }
206
207 pub(crate) fn decode(bytes: &[u8]) -> Result<Self, SnapshotError> {
208 if bytes.len() < HEADER_LEN {
209 return Err(SnapshotError::TruncatedHeader {
210 expected: HEADER_LEN,
211 actual: bytes.len(),
212 });
213 }
214 if &bytes[0..8] != SNAPSHOT_MAGIC {
215 return Err(SnapshotError::BadMagic);
216 }
217 let format_version = u32::from_le_bytes(bytes[8..12].try_into().unwrap());
218 if format_version < SNAPSHOT_MIN_SUPPORTED_FORMAT_VERSION
219 || format_version > SNAPSHOT_FORMAT_VERSION
220 {
221 return Err(SnapshotError::UnsupportedVersion(format_version));
222 }
223 let header_flags = u32::from_le_bytes(bytes[12..16].try_into().unwrap());
224 let wal_lsn = u64::from_le_bytes(bytes[16..24].try_into().unwrap());
225 Ok(Self {
226 format_version,
227 header_flags,
228 wal_lsn,
229 })
230 }
231}
232
233pub(crate) fn write_snapshot<W: Write>(
240 mut writer: W,
241 payload: &SnapshotPayload,
242 wal_lsn: Option<u64>,
243) -> Result<SnapshotMeta, SnapshotError> {
244 let header = SnapshotHeader::new(SNAPSHOT_FORMAT_VERSION, wal_lsn);
245 let header_bytes = header.encode();
246
247 let payload_bytes =
251 bincode::serialize(payload).map_err(|e| SnapshotError::Encode(e.to_string()))?;
252
253 let mut hasher = crc32fast::Hasher::new();
254 hasher.update(&header_bytes);
255 hasher.update(&payload_bytes);
256 let crc = hasher.finalize();
257
258 writer.write_all(&header_bytes)?;
259 writer.write_all(&payload_bytes)?;
260 writer.write_all(&crc.to_le_bytes())?;
261
262 Ok(SnapshotMeta {
263 format_version: SNAPSHOT_FORMAT_VERSION,
264 node_count: payload.nodes.len(),
265 relationship_count: payload.relationships.len(),
266 wal_lsn: header.wal_lsn_if_set(),
267 })
268}
269
270fn decode_payload_for_version(
286 format_version: u32,
287 bytes: &[u8],
288) -> Result<SnapshotPayload, SnapshotError> {
289 match format_version {
290 1 => bincode::deserialize::<SnapshotPayload>(bytes)
291 .map_err(|e| SnapshotError::Decode(e.to_string())),
292 other => Err(SnapshotError::UnsupportedVersion(other)),
293 }
294}
295
296pub(crate) fn read_snapshot<R: Read>(
298 mut reader: R,
299) -> Result<(SnapshotPayload, SnapshotMeta), SnapshotError> {
300 let mut buf = Vec::new();
304 reader.read_to_end(&mut buf)?;
305
306 if buf.len() < HEADER_LEN + 4 {
307 return Err(SnapshotError::TruncatedHeader {
308 expected: HEADER_LEN + 4,
309 actual: buf.len(),
310 });
311 }
312
313 let header = SnapshotHeader::decode(&buf[..HEADER_LEN])?;
314
315 let crc_offset = buf.len() - 4;
316 let stored_crc = u32::from_le_bytes(buf[crc_offset..].try_into().unwrap());
317
318 let mut hasher = crc32fast::Hasher::new();
319 hasher.update(&buf[..crc_offset]);
320 let actual_crc = hasher.finalize();
321 if stored_crc != actual_crc {
322 return Err(SnapshotError::CrcMismatch {
323 expected: stored_crc,
324 actual: actual_crc,
325 });
326 }
327
328 let payload = decode_payload_for_version(header.format_version, &buf[HEADER_LEN..crc_offset])?;
329
330 let meta = SnapshotMeta {
331 format_version: header.format_version,
332 node_count: payload.nodes.len(),
333 relationship_count: payload.relationships.len(),
334 wal_lsn: header.wal_lsn_if_set(),
335 };
336 Ok((payload, meta))
337}
338
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{NodeRecord, Properties, PropertyValue, RelationshipRecord};

    /// Two `Person` nodes joined by one `KNOWS` relationship; id counters sit one past the
    /// highest ids in use.
    fn sample_payload() -> SnapshotPayload {
        let mut props = Properties::new();
        props.insert("name".into(), PropertyValue::String("alice".into()));
        let nodes = vec![
            NodeRecord {
                id: 0,
                labels: vec!["Person".into()],
                properties: props.clone(),
            },
            NodeRecord {
                id: 1,
                labels: vec!["Person".into()],
                properties: Properties::new(),
            },
        ];
        let relationships = vec![RelationshipRecord {
            id: 0,
            src: 0,
            dst: 1,
            rel_type: "KNOWS".into(),
            properties: Properties::new(),
        }];
        SnapshotPayload {
            next_node_id: 2,
            next_rel_id: 1,
            nodes,
            relationships,
        }
    }

    #[test]
    fn roundtrip_without_wal_lsn() {
        let payload = sample_payload();
        let mut buf = Vec::new();
        let meta = write_snapshot(&mut buf, &payload, None).unwrap();

        assert_eq!(meta.format_version, SNAPSHOT_FORMAT_VERSION);
        assert_eq!(meta.node_count, 2);
        assert_eq!(meta.relationship_count, 1);
        assert_eq!(meta.wal_lsn, None);

        let (decoded, decoded_meta) = read_snapshot(&buf[..]).unwrap();
        assert_eq!(decoded, payload);
        assert_eq!(decoded_meta, meta);
    }

    #[test]
    fn roundtrip_with_wal_lsn() {
        let payload = sample_payload();
        let mut buf = Vec::new();
        let meta = write_snapshot(&mut buf, &payload, Some(42)).unwrap();
        assert_eq!(meta.wal_lsn, Some(42));

        let (decoded, decoded_meta) = read_snapshot(&buf[..]).unwrap();
        assert_eq!(decoded, payload);
        assert_eq!(decoded_meta.wal_lsn, Some(42));
    }

    #[test]
    fn empty_payload_roundtrip() {
        let payload = SnapshotPayload::empty();
        let mut buf = Vec::new();
        let meta = write_snapshot(&mut buf, &payload, None).unwrap();
        assert_eq!(meta.node_count, 0);
        assert_eq!(meta.relationship_count, 0);

        let (decoded, decoded_meta) = read_snapshot(&buf[..]).unwrap();
        assert_eq!(decoded, payload);
        assert_eq!(decoded_meta, meta);
    }

    #[test]
    fn header_layout_is_stable() {
        // Pins the on-disk header layout so accidental format changes fail loudly.
        let payload = sample_payload();
        let mut buf = Vec::new();
        write_snapshot(&mut buf, &payload, Some(7)).unwrap();

        assert_eq!(&buf[0..8], SNAPSHOT_MAGIC);
        assert_eq!(&buf[8..12], &SNAPSHOT_FORMAT_VERSION.to_le_bytes());
        assert_eq!(&buf[12..16], &HEADER_FLAG_HAS_WAL_LSN.to_le_bytes());
        assert_eq!(&buf[16..24], &7u64.to_le_bytes());
        // Unused header bytes must stay zero.
        assert!(buf[24..HEADER_LEN].iter().all(|&b| b == 0));
    }

    #[test]
    fn bad_magic_rejected() {
        let payload = sample_payload();
        let mut buf = Vec::new();
        write_snapshot(&mut buf, &payload, None).unwrap();
        buf[0] = b'X';
        let err = read_snapshot(&buf[..]).unwrap_err();
        assert!(matches!(err, SnapshotError::BadMagic));
    }

    #[test]
    fn future_version_rejected() {
        let payload = sample_payload();
        let mut buf = Vec::new();
        write_snapshot(&mut buf, &payload, None).unwrap();
        buf[8] = 99;
        // The header version is validated before the CRC, so the now-stale checksum does not
        // need fixing up for the version error to surface.
        let err = read_snapshot(&buf[..]).unwrap_err();
        assert!(matches!(err, SnapshotError::UnsupportedVersion(99)));
    }

    #[test]
    fn below_min_version_rejected() {
        let payload = sample_payload();
        let mut buf = Vec::new();
        write_snapshot(&mut buf, &payload, None).unwrap();
        buf[8] = 0;
        // As above: the version check runs before CRC verification.
        let err = read_snapshot(&buf[..]).unwrap_err();
        assert!(matches!(err, SnapshotError::UnsupportedVersion(0)));
    }

    #[test]
    fn crc_mismatch_rejected() {
        let payload = sample_payload();
        let mut buf = Vec::new();
        write_snapshot(&mut buf, &payload, None).unwrap();
        let mid = HEADER_LEN + 4;
        // Flip a byte inside the payload; the header stays valid, so only the CRC catches it.
        buf[mid] ^= 0xff;
        let err = read_snapshot(&buf[..]).unwrap_err();
        assert!(matches!(err, SnapshotError::CrcMismatch { .. }));
    }

    #[test]
    fn corrupted_crc_trailer_rejected() {
        let payload = sample_payload();
        let mut buf = Vec::new();
        write_snapshot(&mut buf, &payload, None).unwrap();
        let last = buf.len() - 1;
        buf[last] ^= 0xff;
        let err = read_snapshot(&buf[..]).unwrap_err();
        assert!(matches!(err, SnapshotError::CrcMismatch { .. }));
    }

    #[test]
    fn checksummed_header_without_payload_rejected() {
        // A well-formed header with a valid CRC but no payload bytes must fail at payload
        // decoding instead of yielding a snapshot.
        let header_bytes = SnapshotHeader::new(SNAPSHOT_FORMAT_VERSION, None).encode();
        let mut hasher = crc32fast::Hasher::new();
        hasher.update(&header_bytes);
        let mut buf = header_bytes.to_vec();
        buf.extend_from_slice(&hasher.finalize().to_le_bytes());

        let err = read_snapshot(&buf[..]).unwrap_err();
        assert!(matches!(err, SnapshotError::Decode(_)));
    }

    #[test]
    fn truncated_file_rejected() {
        let payload = sample_payload();
        let mut buf = Vec::new();
        write_snapshot(&mut buf, &payload, None).unwrap();
        buf.truncate(10);
        let err = read_snapshot(&buf[..]).unwrap_err();
        assert!(matches!(err, SnapshotError::TruncatedHeader { .. }));
    }

    #[test]
    fn one_byte_short_of_minimum_rejected() {
        // The minimum readable size is a full header plus the 4-byte CRC trailer.
        let payload = sample_payload();
        let mut buf = Vec::new();
        write_snapshot(&mut buf, &payload, None).unwrap();
        buf.truncate(HEADER_LEN + 3);
        let err = read_snapshot(&buf[..]).unwrap_err();
        assert!(matches!(
            err,
            SnapshotError::TruncatedHeader { expected, actual }
                if expected == HEADER_LEN + 4 && actual == HEADER_LEN + 3
        ));
    }
}