use std::io::{Read, Write};
use serde::{Deserialize, Serialize};
use thiserror::Error;
use crate::{NodeId, NodeRecord, RelationshipId, RelationshipRecord};
/// Eight-byte magic prefix that opens every snapshot file.
pub const SNAPSHOT_MAGIC: &[u8; 8] = b"LORASNAP";
/// Format version stamped into the header of snapshots written by this build.
pub const SNAPSHOT_FORMAT_VERSION: u32 = 1;
/// Oldest format version this build still accepts when reading a header.
pub const SNAPSHOT_MIN_SUPPORTED_FORMAT_VERSION: u32 = 1;

// Compile-time guard: the accepted version window
// `SNAPSHOT_MIN_SUPPORTED_FORMAT_VERSION..=SNAPSHOT_FORMAT_VERSION` must not be empty.
const _: () = assert!(
    SNAPSHOT_MIN_SUPPORTED_FORMAT_VERSION <= SNAPSHOT_FORMAT_VERSION,
    "SNAPSHOT_MIN_SUPPORTED_FORMAT_VERSION must not exceed SNAPSHOT_FORMAT_VERSION",
);

/// Size in bytes of the fixed snapshot header:
/// magic (8) + format version (4) + flags (4) + WAL LSN (8), followed by 16
/// trailing bytes that the header encoder leaves zero-filled.
pub(crate) const HEADER_LEN: usize = 8 + 4 + 4 + 8 + 16;
/// Header flag bit indicating that the header's WAL LSN field carries a value.
pub const HEADER_FLAG_HAS_WAL_LSN: u32 = 0b1;
/// Decoded body of a snapshot: the node and relationship records plus the two
/// `next_*_id` counters.
///
/// This type is encoded with bincode (see `write_snapshot` and
/// `decode_payload_for_version`). Serde's derive serializes struct fields in
/// declaration order and bincode is not self-describing, so adding, removing,
/// or reordering fields changes the on-disk encoding.
/// NOTE(review): such a change presumably requires bumping
/// `SNAPSHOT_FORMAT_VERSION` and adding a decode branch for the new version.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SnapshotPayload {
    /// Presumably the next node ID to hand out after restore — confirm against the store.
    pub next_node_id: NodeId,
    /// Presumably the next relationship ID to hand out after restore — confirm.
    pub next_rel_id: RelationshipId,
    /// Node records contained in the snapshot.
    pub nodes: Vec<NodeRecord>,
    /// Relationship records contained in the snapshot.
    pub relationships: Vec<RelationshipRecord>,
}
impl SnapshotPayload {
    /// Builds a payload with no nodes, no relationships, and both
    /// `next_node_id` / `next_rel_id` set to zero.
    pub fn empty() -> Self {
        let (nodes, relationships) = (vec![], vec![]);
        Self {
            nodes,
            relationships,
            next_node_id: 0,
            next_rel_id: 0,
        }
    }
}
/// Summary of a snapshot, returned by both `write_snapshot` and `read_snapshot`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SnapshotMeta {
    /// Format version recorded in the snapshot header.
    pub format_version: u32,
    /// Number of node records in the payload (`payload.nodes.len()`).
    pub node_count: usize,
    /// Number of relationship records in the payload (`payload.relationships.len()`).
    pub relationship_count: usize,
    /// WAL LSN stored in the header, or `None` when the header's
    /// `HEADER_FLAG_HAS_WAL_LSN` bit is clear.
    pub wal_lsn: Option<u64>,
}
/// Errors produced while writing or reading a snapshot.
#[derive(Debug, Error)]
pub enum SnapshotError {
    /// Underlying reader/writer failure.
    #[error("snapshot I/O error: {0}")]
    Io(#[from] std::io::Error),
    /// The first 8 bytes are not `SNAPSHOT_MAGIC`.
    #[error("snapshot is not a LORASNAP file (bad magic)")]
    BadMagic,
    /// The header's format version is outside
    /// `SNAPSHOT_MIN_SUPPORTED_FORMAT_VERSION..=SNAPSHOT_FORMAT_VERSION`,
    /// or no payload decoder exists for it.
    #[error("unsupported snapshot format version: {0}")]
    UnsupportedVersion(u32),
    /// Input is too short. `read_snapshot` also uses this when the whole file is
    /// shorter than header + 4-byte CRC trailer, in which case `expected` is
    /// `HEADER_LEN + 4`.
    #[error("snapshot header too short (expected {expected} bytes, got {actual})")]
    TruncatedHeader { expected: usize, actual: usize },
    /// Checksum failure: `expected` is the CRC stored in the file's trailer,
    /// `actual` is the CRC recomputed over header + payload.
    #[error("snapshot CRC mismatch: expected 0x{expected:08x}, got 0x{actual:08x}")]
    CrcMismatch { expected: u32, actual: u32 },
    /// bincode could not deserialize the payload; carries the bincode error text.
    #[error("snapshot payload could not be decoded: {0}")]
    Decode(String),
    /// bincode could not serialize the payload; carries the bincode error text.
    #[error("snapshot payload could not be encoded: {0}")]
    Encode(String),
}
/// Types whose state can be saved to and restored from a snapshot stream.
///
/// NOTE(review): no implementation is visible in this file; implementors
/// presumably build on `write_snapshot` / `read_snapshot` — confirm.
pub trait Snapshotable {
    /// Saves `self` as a snapshot to `writer`, returning metadata about what
    /// was written.
    fn save_snapshot<W: Write>(&self, writer: W) -> Result<SnapshotMeta, SnapshotError>;
    /// Loads a snapshot from `reader` into `self` (hence `&mut self`),
    /// returning its metadata. Presumably replaces existing state — confirm.
    fn load_snapshot<R: Read>(&mut self, reader: R) -> Result<SnapshotMeta, SnapshotError>;
    /// Like `save_snapshot`, but associates the snapshot with `wal_lsn`.
    /// Presumably the LSN is recorded in the header (`HEADER_FLAG_HAS_WAL_LSN`)
    /// — confirm against implementations.
    fn save_checkpoint<W: Write>(
        &self,
        writer: W,
        wal_lsn: u64,
    ) -> Result<SnapshotMeta, SnapshotError>;
}
/// In-memory form of the fixed `HEADER_LEN`-byte header that precedes the
/// snapshot payload (magic is implicit and not stored here).
#[derive(Debug, Clone, Copy)]
pub(crate) struct SnapshotHeader {
    /// Format version (header bytes 8..12, little-endian).
    pub format_version: u32,
    /// Flag bits (header bytes 12..16, little-endian); `HEADER_FLAG_HAS_WAL_LSN`
    /// marks `wal_lsn` as meaningful.
    pub header_flags: u32,
    /// WAL LSN (header bytes 16..24, little-endian). `SnapshotHeader::new` stores 0
    /// when no LSN is given; only reported by `wal_lsn_if_set` when the flag is set.
    pub wal_lsn: u64,
}
impl SnapshotHeader {
    /// Creates a header for `format_version`.
    ///
    /// With `Some(lsn)` the LSN is stored and [`HEADER_FLAG_HAS_WAL_LSN`] is set;
    /// with `None` both the flags and the LSN field are zero.
    pub(crate) fn new(format_version: u32, wal_lsn: Option<u64>) -> Self {
        let header_flags = if wal_lsn.is_some() {
            HEADER_FLAG_HAS_WAL_LSN
        } else {
            0
        };
        Self {
            format_version,
            header_flags,
            wal_lsn: wal_lsn.unwrap_or(0),
        }
    }

    /// Returns the WAL LSN only when [`HEADER_FLAG_HAS_WAL_LSN`] is set.
    pub(crate) fn wal_lsn_if_set(&self) -> Option<u64> {
        let has_lsn = (self.header_flags & HEADER_FLAG_HAS_WAL_LSN) != 0;
        has_lsn.then(|| self.wal_lsn)
    }

    /// Serializes the header into its fixed `HEADER_LEN`-byte form.
    ///
    /// Layout (integers little-endian):
    /// - bytes `0..8`: `SNAPSHOT_MAGIC`
    /// - bytes `8..12`: format version
    /// - bytes `12..16`: flags
    /// - bytes `16..24`: WAL LSN
    /// - bytes `24..HEADER_LEN`: left zero-filled
    pub(crate) fn encode(&self) -> [u8; HEADER_LEN] {
        let version = self.format_version.to_le_bytes();
        let flags = self.header_flags.to_le_bytes();
        let lsn = self.wal_lsn.to_le_bytes();
        let fields: [&[u8]; 4] = [&SNAPSHOT_MAGIC[..], &version, &flags, &lsn];

        let mut encoded = [0u8; HEADER_LEN];
        let mut cursor = 0;
        for field in fields {
            encoded[cursor..cursor + field.len()].copy_from_slice(field);
            cursor += field.len();
        }
        encoded
    }

    /// Parses and validates a header from the first `HEADER_LEN` bytes of `bytes`.
    ///
    /// Checks run in order: length, magic, then that the version lies within
    /// `SNAPSHOT_MIN_SUPPORTED_FORMAT_VERSION..=SNAPSHOT_FORMAT_VERSION`.
    /// Unknown flag bits and the bytes after the WAL LSN field are not inspected.
    ///
    /// # Errors
    ///
    /// [`SnapshotError::TruncatedHeader`] if `bytes` is shorter than `HEADER_LEN`,
    /// [`SnapshotError::BadMagic`] on a magic mismatch, and
    /// [`SnapshotError::UnsupportedVersion`] for an out-of-range version.
    pub(crate) fn decode(bytes: &[u8]) -> Result<Self, SnapshotError> {
        if bytes.len() < HEADER_LEN {
            return Err(SnapshotError::TruncatedHeader {
                expected: HEADER_LEN,
                actual: bytes.len(),
            });
        }

        let (magic, fields) = bytes.split_at(SNAPSHOT_MAGIC.len());
        if magic != SNAPSHOT_MAGIC {
            return Err(SnapshotError::BadMagic);
        }

        // `fields` begins right after the magic, so offsets below are relative to byte 8.
        let le_u32 = |at: usize| {
            u32::from_le_bytes(fields[at..at + 4].try_into().expect("length checked above"))
        };

        let format_version = le_u32(0);
        let supported = SNAPSHOT_MIN_SUPPORTED_FORMAT_VERSION..=SNAPSHOT_FORMAT_VERSION;
        if !supported.contains(&format_version) {
            return Err(SnapshotError::UnsupportedVersion(format_version));
        }

        let header_flags = le_u32(4);
        let wal_lsn = u64::from_le_bytes(fields[8..16].try_into().expect("length checked above"));
        Ok(Self {
            format_version,
            header_flags,
            wal_lsn,
        })
    }
}
/// Serializes `payload` into the snapshot container format and writes it to `writer`.
///
/// Layout: the `HEADER_LEN`-byte header from [`SnapshotHeader::encode`], then the
/// bincode-encoded payload, then a little-endian CRC32 computed over the header
/// bytes and the payload bytes. When `wal_lsn` is `Some`, it is stored in the
/// header together with [`HEADER_FLAG_HAS_WAL_LSN`].
///
/// The writer is flushed before returning, so buffered writers report their
/// final write errors here instead of losing them on drop.
///
/// # Errors
///
/// [`SnapshotError::Encode`] if bincode fails to serialize `payload`;
/// [`SnapshotError::Io`] if writing to or flushing `writer` fails.
pub(crate) fn write_snapshot<W: Write>(
    mut writer: W,
    payload: &SnapshotPayload,
    wal_lsn: Option<u64>,
) -> Result<SnapshotMeta, SnapshotError> {
    let header = SnapshotHeader::new(SNAPSHOT_FORMAT_VERSION, wal_lsn);
    let header_bytes = header.encode();
    let payload_bytes =
        bincode::serialize(payload).map_err(|e| SnapshotError::Encode(e.to_string()))?;

    // The checksum covers the header bytes as well as the payload.
    let mut hasher = crc32fast::Hasher::new();
    hasher.update(&header_bytes);
    hasher.update(&payload_bytes);
    let crc = hasher.finalize();

    writer.write_all(&header_bytes)?;
    writer.write_all(&payload_bytes)?;
    writer.write_all(&crc.to_le_bytes())?;
    // `writer` is owned and dropped when this function returns. A buffered writer
    // such as `BufWriter` flushes on drop but silently discards any error, which
    // would let a partially written snapshot report success. Flush explicitly so
    // the failure surfaces as `SnapshotError::Io`.
    writer.flush()?;

    Ok(SnapshotMeta {
        format_version: header.format_version,
        node_count: payload.nodes.len(),
        relationship_count: payload.relationships.len(),
        wal_lsn: header.wal_lsn_if_set(),
    })
}
/// Decodes payload bytes according to the header's `format_version`.
///
/// Only version 1 is understood: the bytes are a bincode-encoded
/// [`SnapshotPayload`]. Any other version yields
/// [`SnapshotError::UnsupportedVersion`]; bincode failures become
/// [`SnapshotError::Decode`] carrying the bincode error text.
fn decode_payload_for_version(
    format_version: u32,
    bytes: &[u8],
) -> Result<SnapshotPayload, SnapshotError> {
    if format_version != 1 {
        return Err(SnapshotError::UnsupportedVersion(format_version));
    }
    let decoded: Result<SnapshotPayload, _> = bincode::deserialize(bytes);
    decoded.map_err(|err| SnapshotError::Decode(err.to_string()))
}
/// Reads an entire snapshot from `reader` and returns the decoded payload with its metadata.
///
/// The whole stream is buffered in memory. Validation happens in this order:
/// minimum length (`HEADER_LEN` plus the 4-byte CRC trailer), then header magic
/// and version via [`SnapshotHeader::decode`], then the trailing CRC32 over every
/// preceding byte, and finally payload decoding for the header's format version.
///
/// # Errors
///
/// - [`SnapshotError::Io`] if reading fails.
/// - [`SnapshotError::TruncatedHeader`] if fewer than `HEADER_LEN + 4` bytes are read.
/// - [`SnapshotError::BadMagic`] / [`SnapshotError::UnsupportedVersion`] from the header.
/// - [`SnapshotError::CrcMismatch`] (`expected` = stored CRC, `actual` = recomputed).
/// - [`SnapshotError::Decode`] if the payload cannot be decoded.
pub(crate) fn read_snapshot<R: Read>(
    mut reader: R,
) -> Result<(SnapshotPayload, SnapshotMeta), SnapshotError> {
    const CRC_LEN: usize = 4;
    let min_len = HEADER_LEN + CRC_LEN;

    let mut bytes = Vec::new();
    reader.read_to_end(&mut bytes)?;
    if bytes.len() < min_len {
        return Err(SnapshotError::TruncatedHeader {
            expected: min_len,
            actual: bytes.len(),
        });
    }

    // The header is validated before the checksum, so a damaged magic or version
    // is reported as such rather than as a CRC mismatch.
    let header = SnapshotHeader::decode(&bytes[..HEADER_LEN])?;

    let (covered, trailer) = bytes.split_at(bytes.len() - CRC_LEN);
    let stored_crc = u32::from_le_bytes(
        trailer
            .try_into()
            .expect("trailer is exactly CRC_LEN bytes"),
    );
    let computed_crc = {
        let mut hasher = crc32fast::Hasher::new();
        hasher.update(covered);
        hasher.finalize()
    };
    if computed_crc != stored_crc {
        return Err(SnapshotError::CrcMismatch {
            expected: stored_crc,
            actual: computed_crc,
        });
    }

    let payload = decode_payload_for_version(header.format_version, &covered[HEADER_LEN..])?;
    let meta = SnapshotMeta {
        format_version: header.format_version,
        node_count: payload.nodes.len(),
        relationship_count: payload.relationships.len(),
        wal_lsn: header.wal_lsn_if_set(),
    };
    Ok((payload, meta))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{NodeRecord, Properties, PropertyValue, RelationshipRecord};

    /// Two `Person` nodes joined by a single `KNOWS` relationship.
    fn sample_payload() -> SnapshotPayload {
        let mut props = Properties::new();
        props.insert("name".into(), PropertyValue::String("alice".into()));
        let nodes = vec![
            NodeRecord {
                id: 0,
                labels: vec!["Person".into()],
                properties: props.clone(),
            },
            NodeRecord {
                id: 1,
                labels: vec!["Person".into()],
                properties: Properties::new(),
            },
        ];
        let relationships = vec![RelationshipRecord {
            id: 0,
            src: 0,
            dst: 1,
            rel_type: "KNOWS".into(),
            properties: Properties::new(),
        }];
        SnapshotPayload {
            next_node_id: 2,
            next_rel_id: 1,
            nodes,
            relationships,
        }
    }

    #[test]
    fn roundtrip_without_wal_lsn() {
        let payload = sample_payload();
        let mut buf = Vec::new();
        let meta = write_snapshot(&mut buf, &payload, None).unwrap();
        assert_eq!(meta.format_version, SNAPSHOT_FORMAT_VERSION);
        assert_eq!(meta.node_count, 2);
        assert_eq!(meta.relationship_count, 1);
        assert_eq!(meta.wal_lsn, None);
        let (decoded, decoded_meta) = read_snapshot(&buf[..]).unwrap();
        assert_eq!(decoded, payload);
        assert_eq!(decoded_meta, meta);
    }

    #[test]
    fn roundtrip_with_wal_lsn() {
        let payload = sample_payload();
        let mut buf = Vec::new();
        let meta = write_snapshot(&mut buf, &payload, Some(42)).unwrap();
        assert_eq!(meta.wal_lsn, Some(42));
        let (decoded, decoded_meta) = read_snapshot(&buf[..]).unwrap();
        assert_eq!(decoded, payload);
        assert_eq!(decoded_meta.wal_lsn, Some(42));
    }

    /// The degenerate payload (no records, zero counters) must survive a round trip.
    #[test]
    fn empty_payload_roundtrip() {
        let payload = SnapshotPayload::empty();
        let mut buf = Vec::new();
        let meta = write_snapshot(&mut buf, &payload, None).unwrap();
        assert_eq!(meta.node_count, 0);
        assert_eq!(meta.relationship_count, 0);
        let (decoded, decoded_meta) = read_snapshot(&buf[..]).unwrap();
        assert_eq!(decoded, payload);
        assert_eq!(decoded_meta, meta);
    }

    /// Header encode/decode is lossless, the reserved tail is zero-filled, and the
    /// LSN is only reported when its flag bit is set.
    #[test]
    fn header_roundtrip_and_reserved_bytes_zeroed() {
        let bytes = SnapshotHeader::new(SNAPSHOT_FORMAT_VERSION, Some(7)).encode();
        assert_eq!(&bytes[..8], SNAPSHOT_MAGIC);
        assert!(bytes[24..].iter().all(|&b| b == 0));
        let decoded = SnapshotHeader::decode(&bytes).unwrap();
        assert_eq!(decoded.format_version, SNAPSHOT_FORMAT_VERSION);
        assert_eq!(decoded.wal_lsn_if_set(), Some(7));

        let without = SnapshotHeader::decode(
            &SnapshotHeader::new(SNAPSHOT_FORMAT_VERSION, None).encode(),
        )
        .unwrap();
        assert_eq!(without.header_flags & HEADER_FLAG_HAS_WAL_LSN, 0);
        assert_eq!(without.wal_lsn_if_set(), None);
    }

    #[test]
    fn decode_rejects_short_header() {
        let err = SnapshotHeader::decode(&[0u8; HEADER_LEN - 1]).unwrap_err();
        assert!(matches!(
            err,
            SnapshotError::TruncatedHeader { expected, actual }
                if expected == HEADER_LEN && actual == HEADER_LEN - 1
        ));
    }

    #[test]
    fn bad_magic_rejected() {
        let payload = sample_payload();
        let mut buf = Vec::new();
        write_snapshot(&mut buf, &payload, None).unwrap();
        buf[0] = b'X';
        let err = read_snapshot(&buf[..]).unwrap_err();
        assert!(matches!(err, SnapshotError::BadMagic));
    }

    #[test]
    fn future_version_rejected() {
        let payload = sample_payload();
        let mut buf = Vec::new();
        write_snapshot(&mut buf, &payload, None).unwrap();
        buf[8] = 99;
        let err = read_snapshot(&buf[..]).unwrap_err();
        assert!(matches!(err, SnapshotError::UnsupportedVersion(99)));
    }

    #[test]
    fn below_min_version_rejected() {
        let payload = sample_payload();
        let mut buf = Vec::new();
        write_snapshot(&mut buf, &payload, None).unwrap();
        buf[8] = 0;
        let err = read_snapshot(&buf[..]).unwrap_err();
        assert!(matches!(err, SnapshotError::UnsupportedVersion(0)));
    }

    #[test]
    fn crc_mismatch_rejected() {
        let payload = sample_payload();
        let mut buf = Vec::new();
        write_snapshot(&mut buf, &payload, None).unwrap();
        let mid = HEADER_LEN + 4;
        buf[mid] ^= 0xff;
        let err = read_snapshot(&buf[..]).unwrap_err();
        assert!(matches!(err, SnapshotError::CrcMismatch { .. }));
    }

    /// Corrupting the stored checksum itself (not the covered bytes) is also detected.
    #[test]
    fn corrupted_crc_trailer_rejected() {
        let payload = sample_payload();
        let mut buf = Vec::new();
        write_snapshot(&mut buf, &payload, None).unwrap();
        let last = buf.len() - 1;
        buf[last] ^= 0xff;
        let err = read_snapshot(&buf[..]).unwrap_err();
        assert!(matches!(err, SnapshotError::CrcMismatch { .. }));
    }

    #[test]
    fn truncated_file_rejected() {
        let payload = sample_payload();
        let mut buf = Vec::new();
        write_snapshot(&mut buf, &payload, None).unwrap();
        buf.truncate(10);
        let err = read_snapshot(&buf[..]).unwrap_err();
        assert!(matches!(err, SnapshotError::TruncatedHeader { .. }));
    }

    /// A file holding a full header but no room for the CRC trailer reports the
    /// header-plus-trailer minimum as the expected length.
    #[test]
    fn file_shorter_than_header_plus_crc_rejected() {
        let payload = sample_payload();
        let mut buf = Vec::new();
        write_snapshot(&mut buf, &payload, None).unwrap();
        buf.truncate(HEADER_LEN + 2);
        let err = read_snapshot(&buf[..]).unwrap_err();
        assert!(matches!(
            err,
            SnapshotError::TruncatedHeader { expected, actual }
                if expected == HEADER_LEN + 4 && actual == HEADER_LEN + 2
        ));
    }
}