1use std::io::{Read, Write};
36
37use serde::{Deserialize, Serialize};
38use thiserror::Error;
39
40use crate::{NodeId, NodeRecord, RelationshipId, RelationshipRecord};
41
/// Eight-byte magic prefix that opens every LORASNAP snapshot.
pub const SNAPSHOT_MAGIC: &[u8; 8] = b"LORASNAP";

/// Snapshot format version produced by this build.
pub const SNAPSHOT_FORMAT_VERSION: u32 = 1;

/// Oldest snapshot format version this build still accepts when reading.
pub const SNAPSHOT_MIN_SUPPORTED_FORMAT_VERSION: u32 = 1;

// Readers accept the inclusive range MIN..=CURRENT; an inverted range would
// reject every snapshot, including the ones this build writes itself.
const _: () = assert!(
    SNAPSHOT_MIN_SUPPORTED_FORMAT_VERSION <= SNAPSHOT_FORMAT_VERSION,
    "SNAPSHOT_MIN_SUPPORTED_FORMAT_VERSION must not exceed SNAPSHOT_FORMAT_VERSION",
);

/// Total size of the fixed snapshot header in bytes.
///
/// Only the first 24 bytes carry data (magic, format version, flags, WAL
/// LSN); the remaining bytes are zero-filled on write and not inspected on
/// read, leaving room for future header fields.
pub(crate) const HEADER_LEN: usize = 40;

// The header encoder writes fixed fields up to byte offset 24; a shorter
// header would turn that slicing into a runtime panic, so reject it at
// compile time instead.
const _: () = assert!(
    HEADER_LEN >= 24,
    "HEADER_LEN must be large enough for the 24 bytes of encoded header fields",
);

/// Header flag bit: set when the header's WAL LSN field holds a real value.
pub const HEADER_FLAG_HAS_WAL_LSN: u32 = 1 << 0;
68
/// Complete logical contents of a graph snapshot.
///
/// The whole struct is bincode-encoded as-is between the snapshot header and
/// the CRC trailer. Serde's derive serializes fields in declaration order, so
/// the field order below is part of the on-disk format. Reordering, adding or
/// removing fields changes the encoding and would need a new
/// `SNAPSHOT_FORMAT_VERSION` plus a matching arm in `decode_payload_for_version`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SnapshotPayload {
    /// Node id-allocator state (presumably the next node id to hand out —
    /// NOTE(review): confirm against the store that fills this in).
    pub next_node_id: NodeId,
    /// Relationship id-allocator state (presumably the next relationship id
    /// to hand out — NOTE(review): confirm against the store).
    pub next_rel_id: RelationshipId,
    /// Every node record captured in the snapshot.
    pub nodes: Vec<NodeRecord>,
    /// Every relationship record captured in the snapshot.
    pub relationships: Vec<RelationshipRecord>,
}
81
82impl SnapshotPayload {
83 pub fn empty() -> Self {
84 Self {
85 next_node_id: 0,
86 next_rel_id: 0,
87 nodes: Vec::new(),
88 relationships: Vec::new(),
89 }
90 }
91}
92
/// Summary of a snapshot that was just written or read.
///
/// Built by both `write_snapshot` and `read_snapshot`; for a successful
/// round trip the two values compare equal.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SnapshotMeta {
    /// Format version recorded in the snapshot header.
    pub format_version: u32,
    /// Number of node records in the payload.
    pub node_count: usize,
    /// Number of relationship records in the payload.
    pub relationship_count: usize,
    /// Value of the header's WAL LSN field, or `None` when the header's
    /// `HEADER_FLAG_HAS_WAL_LSN` bit is clear.
    pub wal_lsn: Option<u64>,
}
107
/// Errors produced while writing or reading a snapshot.
#[derive(Debug, Error)]
pub enum SnapshotError {
    /// Reading from or writing to the underlying stream failed.
    #[error("snapshot I/O error: {0}")]
    Io(#[from] std::io::Error),

    /// The first eight bytes are not `SNAPSHOT_MAGIC`.
    #[error("snapshot is not a LORASNAP file (bad magic)")]
    BadMagic,

    /// The header's format version lies outside
    /// `SNAPSHOT_MIN_SUPPORTED_FORMAT_VERSION..=SNAPSHOT_FORMAT_VERSION`,
    /// or has no payload decoder.
    #[error("unsupported snapshot format version: {0}")]
    UnsupportedVersion(u32),

    /// The input is shorter than required. `read_snapshot` reports the
    /// header-plus-CRC minimum here, while header decoding reports the bare
    /// header length.
    #[error("snapshot header too short (expected {expected} bytes, got {actual})")]
    TruncatedHeader { expected: usize, actual: usize },

    /// The CRC32 stored in the trailer (`expected`) differs from the one
    /// computed over the header and payload bytes (`actual`).
    #[error("snapshot CRC mismatch: expected 0x{expected:08x}, got 0x{actual:08x}")]
    CrcMismatch { expected: u32, actual: u32 },

    /// bincode could not deserialize the payload; holds the bincode error text.
    #[error("snapshot payload could not be decoded: {0}")]
    Decode(String),

    /// bincode could not serialize the payload; holds the bincode error text.
    #[error("snapshot payload could not be encoded: {0}")]
    Encode(String),
}
132
/// Implemented by types that can save their full contents to a byte stream
/// and restore them from one.
pub trait Snapshotable {
    /// Writes a snapshot of `self` to `writer` and returns a summary of what
    /// was written.
    ///
    /// # Errors
    ///
    /// Returns a [`SnapshotError`] if the snapshot cannot be encoded or written.
    fn save_snapshot<W: Write>(&self, writer: W) -> Result<SnapshotMeta, SnapshotError>;

    /// Reads a snapshot from `reader` into `self` and returns its summary.
    ///
    /// NOTE(review): whether existing state is replaced or merged, and whether
    /// `self` is left untouched when an error is returned, is decided by each
    /// implementation — confirm before relying on either.
    ///
    /// # Errors
    ///
    /// Returns a [`SnapshotError`] if the snapshot cannot be read or validated.
    fn load_snapshot<R: Read>(&mut self, reader: R) -> Result<SnapshotMeta, SnapshotError>;
}
147
/// In-memory form of the fixed-size snapshot header.
///
/// Encoded layout, `HEADER_LEN` bytes total, integers little-endian:
///
/// | bytes   | contents                       |
/// |---------|--------------------------------|
/// | 0..8    | `SNAPSHOT_MAGIC`               |
/// | 8..12   | `format_version`               |
/// | 12..16  | `header_flags`                 |
/// | 16..24  | `wal_lsn`                      |
/// | 24..    | zero on write, ignored on read |
#[derive(Debug, Clone, Copy)]
pub(crate) struct SnapshotHeader {
    /// Snapshot format version; decoding rejects values outside the supported range.
    pub format_version: u32,
    /// Bit set of `HEADER_FLAG_*` values.
    pub header_flags: u32,
    /// WAL LSN; meaningful only when `HEADER_FLAG_HAS_WAL_LSN` is set in
    /// `header_flags` (headers built with `new(_, None)` store zero here).
    pub wal_lsn: u64,
}
158
159impl SnapshotHeader {
160 pub(crate) fn new(format_version: u32, wal_lsn: Option<u64>) -> Self {
161 let (flags, lsn) = match wal_lsn {
162 Some(lsn) => (HEADER_FLAG_HAS_WAL_LSN, lsn),
163 None => (0, 0),
164 };
165 Self {
166 format_version,
167 header_flags: flags,
168 wal_lsn: lsn,
169 }
170 }
171
172 pub(crate) fn wal_lsn_if_set(&self) -> Option<u64> {
173 if self.header_flags & HEADER_FLAG_HAS_WAL_LSN != 0 {
174 Some(self.wal_lsn)
175 } else {
176 None
177 }
178 }
179
180 pub(crate) fn encode(&self) -> [u8; HEADER_LEN] {
181 let mut out = [0u8; HEADER_LEN];
182 out[0..8].copy_from_slice(SNAPSHOT_MAGIC);
183 out[8..12].copy_from_slice(&self.format_version.to_le_bytes());
184 out[12..16].copy_from_slice(&self.header_flags.to_le_bytes());
185 out[16..24].copy_from_slice(&self.wal_lsn.to_le_bytes());
186 out
188 }
189
190 pub(crate) fn decode(bytes: &[u8]) -> Result<Self, SnapshotError> {
191 if bytes.len() < HEADER_LEN {
192 return Err(SnapshotError::TruncatedHeader {
193 expected: HEADER_LEN,
194 actual: bytes.len(),
195 });
196 }
197 if &bytes[0..8] != SNAPSHOT_MAGIC {
198 return Err(SnapshotError::BadMagic);
199 }
200 let format_version = u32::from_le_bytes(bytes[8..12].try_into().unwrap());
201 if format_version < SNAPSHOT_MIN_SUPPORTED_FORMAT_VERSION
202 || format_version > SNAPSHOT_FORMAT_VERSION
203 {
204 return Err(SnapshotError::UnsupportedVersion(format_version));
205 }
206 let header_flags = u32::from_le_bytes(bytes[12..16].try_into().unwrap());
207 let wal_lsn = u64::from_le_bytes(bytes[16..24].try_into().unwrap());
208 Ok(Self {
209 format_version,
210 header_flags,
211 wal_lsn,
212 })
213 }
214}
215
216pub(crate) fn write_snapshot<W: Write>(
223 mut writer: W,
224 payload: &SnapshotPayload,
225 wal_lsn: Option<u64>,
226) -> Result<SnapshotMeta, SnapshotError> {
227 let header = SnapshotHeader::new(SNAPSHOT_FORMAT_VERSION, wal_lsn);
228 let header_bytes = header.encode();
229
230 let payload_bytes =
234 bincode::serialize(payload).map_err(|e| SnapshotError::Encode(e.to_string()))?;
235
236 let mut hasher = crc32fast::Hasher::new();
237 hasher.update(&header_bytes);
238 hasher.update(&payload_bytes);
239 let crc = hasher.finalize();
240
241 writer.write_all(&header_bytes)?;
242 writer.write_all(&payload_bytes)?;
243 writer.write_all(&crc.to_le_bytes())?;
244
245 Ok(SnapshotMeta {
246 format_version: SNAPSHOT_FORMAT_VERSION,
247 node_count: payload.nodes.len(),
248 relationship_count: payload.relationships.len(),
249 wal_lsn: header.wal_lsn_if_set(),
250 })
251}
252
253fn decode_payload_for_version(
269 format_version: u32,
270 bytes: &[u8],
271) -> Result<SnapshotPayload, SnapshotError> {
272 match format_version {
273 1 => bincode::deserialize::<SnapshotPayload>(bytes)
274 .map_err(|e| SnapshotError::Decode(e.to_string())),
275 other => Err(SnapshotError::UnsupportedVersion(other)),
276 }
277}
278
279pub(crate) fn read_snapshot<R: Read>(
281 mut reader: R,
282) -> Result<(SnapshotPayload, SnapshotMeta), SnapshotError> {
283 let mut buf = Vec::new();
287 reader.read_to_end(&mut buf)?;
288
289 if buf.len() < HEADER_LEN + 4 {
290 return Err(SnapshotError::TruncatedHeader {
291 expected: HEADER_LEN + 4,
292 actual: buf.len(),
293 });
294 }
295
296 let header = SnapshotHeader::decode(&buf[..HEADER_LEN])?;
297
298 let crc_offset = buf.len() - 4;
299 let stored_crc = u32::from_le_bytes(buf[crc_offset..].try_into().unwrap());
300
301 let mut hasher = crc32fast::Hasher::new();
302 hasher.update(&buf[..crc_offset]);
303 let actual_crc = hasher.finalize();
304 if stored_crc != actual_crc {
305 return Err(SnapshotError::CrcMismatch {
306 expected: stored_crc,
307 actual: actual_crc,
308 });
309 }
310
311 let payload = decode_payload_for_version(header.format_version, &buf[HEADER_LEN..crc_offset])?;
312
313 let meta = SnapshotMeta {
314 format_version: header.format_version,
315 node_count: payload.nodes.len(),
316 relationship_count: payload.relationships.len(),
317 wal_lsn: header.wal_lsn_if_set(),
318 };
319 Ok((payload, meta))
320}
321
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{NodeRecord, Properties, PropertyValue, RelationshipRecord};

    /// Two `Person` nodes joined by a single `KNOWS` relationship.
    fn sample_payload() -> SnapshotPayload {
        let mut props = Properties::new();
        props.insert("name".into(), PropertyValue::String("alice".into()));
        let nodes = vec![
            NodeRecord {
                id: 0,
                labels: vec!["Person".into()],
                properties: props.clone(),
            },
            NodeRecord {
                id: 1,
                labels: vec!["Person".into()],
                properties: Properties::new(),
            },
        ];
        let relationships = vec![RelationshipRecord {
            id: 0,
            src: 0,
            dst: 1,
            rel_type: "KNOWS".into(),
            properties: Properties::new(),
        }];
        SnapshotPayload {
            next_node_id: 2,
            next_rel_id: 1,
            nodes,
            relationships,
        }
    }

    /// Rewrites the CRC trailer so it matches the (possibly edited) bytes before it.
    fn reseal_crc(buf: &mut [u8]) {
        let crc_offset = buf.len() - 4;
        let mut hasher = crc32fast::Hasher::new();
        hasher.update(&buf[..crc_offset]);
        let crc = hasher.finalize();
        buf[crc_offset..].copy_from_slice(&crc.to_le_bytes());
    }

    #[test]
    fn roundtrip_without_wal_lsn() {
        let payload = sample_payload();
        let mut buf = Vec::new();
        let meta = write_snapshot(&mut buf, &payload, None).unwrap();

        assert_eq!(meta.format_version, SNAPSHOT_FORMAT_VERSION);
        assert_eq!(meta.node_count, 2);
        assert_eq!(meta.relationship_count, 1);
        assert_eq!(meta.wal_lsn, None);

        let (decoded, decoded_meta) = read_snapshot(&buf[..]).unwrap();
        assert_eq!(decoded, payload);
        assert_eq!(decoded_meta, meta);
    }

    #[test]
    fn roundtrip_with_wal_lsn() {
        let payload = sample_payload();
        let mut buf = Vec::new();
        let meta = write_snapshot(&mut buf, &payload, Some(42)).unwrap();
        assert_eq!(meta.wal_lsn, Some(42));

        let (decoded, decoded_meta) = read_snapshot(&buf[..]).unwrap();
        assert_eq!(decoded, payload);
        assert_eq!(decoded_meta.wal_lsn, Some(42));
    }

    #[test]
    fn empty_payload_roundtrip_with_zero_lsn() {
        let payload = SnapshotPayload::empty();
        let mut buf = Vec::new();
        let meta = write_snapshot(&mut buf, &payload, Some(0)).unwrap();
        // Presence is tracked by the header flag, so an LSN of 0 is still `Some`.
        assert_eq!(meta.wal_lsn, Some(0));
        assert_eq!(meta.node_count, 0);
        assert_eq!(meta.relationship_count, 0);

        let (decoded, decoded_meta) = read_snapshot(&buf[..]).unwrap();
        assert_eq!(decoded, payload);
        assert_eq!(decoded_meta, meta);
    }

    #[test]
    fn wal_lsn_bytes_ignored_without_flag() {
        let mut buf = Vec::new();
        write_snapshot(&mut buf, &sample_payload(), None).unwrap();
        // Put a non-zero value in the LSN field but leave the flag clear.
        buf[16] = 7;
        reseal_crc(&mut buf);

        let (_, meta) = read_snapshot(&buf[..]).unwrap();
        assert_eq!(meta.wal_lsn, None);
    }

    #[test]
    fn bad_magic_rejected() {
        let payload = sample_payload();
        let mut buf = Vec::new();
        write_snapshot(&mut buf, &payload, None).unwrap();
        buf[0] = b'X';
        let err = read_snapshot(&buf[..]).unwrap_err();
        assert!(matches!(err, SnapshotError::BadMagic));
    }

    #[test]
    fn future_version_rejected() {
        let payload = sample_payload();
        let mut buf = Vec::new();
        write_snapshot(&mut buf, &payload, None).unwrap();
        // The CRC is now stale too, but the header check runs first.
        buf[8] = 99;
        let err = read_snapshot(&buf[..]).unwrap_err();
        assert!(matches!(err, SnapshotError::UnsupportedVersion(99)));
    }

    #[test]
    fn below_min_version_rejected() {
        let payload = sample_payload();
        let mut buf = Vec::new();
        write_snapshot(&mut buf, &payload, None).unwrap();
        buf[8] = 0;
        let err = read_snapshot(&buf[..]).unwrap_err();
        assert!(matches!(err, SnapshotError::UnsupportedVersion(0)));
    }

    #[test]
    fn crc_mismatch_rejected() {
        let payload = sample_payload();
        let mut buf = Vec::new();
        write_snapshot(&mut buf, &payload, None).unwrap();
        // Flip a byte inside the payload region.
        let mid = HEADER_LEN + 4;
        buf[mid] ^= 0xff;
        let err = read_snapshot(&buf[..]).unwrap_err();
        assert!(matches!(err, SnapshotError::CrcMismatch { .. }));
    }

    #[test]
    fn corrupted_crc_trailer_rejected() {
        let mut buf = Vec::new();
        write_snapshot(&mut buf, &sample_payload(), None).unwrap();
        let last = buf.len() - 1;
        buf[last] ^= 0xff;
        let err = read_snapshot(&buf[..]).unwrap_err();
        assert!(matches!(err, SnapshotError::CrcMismatch { .. }));
    }

    #[test]
    fn truncated_file_rejected() {
        let payload = sample_payload();
        let mut buf = Vec::new();
        write_snapshot(&mut buf, &payload, None).unwrap();
        buf.truncate(10);
        let err = read_snapshot(&buf[..]).unwrap_err();
        assert!(matches!(err, SnapshotError::TruncatedHeader { .. }));
    }

    #[test]
    fn one_byte_short_of_header_plus_crc_rejected() {
        let mut buf = Vec::new();
        write_snapshot(&mut buf, &sample_payload(), None).unwrap();
        buf.truncate(HEADER_LEN + 3);
        let err = read_snapshot(&buf[..]).unwrap_err();
        assert!(matches!(
            err,
            SnapshotError::TruncatedHeader { expected, actual }
                if expected == HEADER_LEN + 4 && actual == HEADER_LEN + 3
        ));
    }

    #[test]
    fn missing_payload_with_valid_crc_fails_to_decode() {
        // A well-formed header and a correct CRC, but zero payload bytes.
        let header = SnapshotHeader::new(SNAPSHOT_FORMAT_VERSION, None).encode();
        let mut buf = header.to_vec();
        buf.extend_from_slice(&[0u8; 4]);
        reseal_crc(&mut buf);

        let err = read_snapshot(&buf[..]).unwrap_err();
        assert!(matches!(err, SnapshotError::Decode(_)));
    }
}