use thiserror::Error;
/// Four-byte magic that every snapshot chunk frame starts with (`b"NDSN"`).
pub const SNAPSHOT_MAGIC: [u8; 4] = [b'N', b'D', b'S', b'N'];
/// Framing layout version written by `encode_snapshot_chunk`; decoding rejects
/// any frame whose version field differs from this value.
pub const SNAPSHOT_FORMAT_VERSION: u16 = 1;
/// Fixed header size in bytes: magic + version (`u16`) + engine id (`u16`) +
/// CRC32C (`u32`) = 12.
const HEADER_LEN: usize = SNAPSHOT_MAGIC.len()
    + std::mem::size_of::<u16>()
    + std::mem::size_of::<u16>()
    + std::mem::size_of::<u32>();
/// Storage engine that a snapshot chunk belongs to.
///
/// Serialized by `encode_snapshot_chunk` as the big-endian `u16` form of the
/// explicit discriminant below (`engine_id as u16`), and parsed back by
/// `SnapshotEngineId::from_u16`. Renumbering an existing variant would make
/// previously encoded frames decode as a different engine (or fail), so treat
/// these values as part of the wire format.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u16)]
pub enum SnapshotEngineId {
Vector = 1,
Graph = 2,
DocumentSchemaless = 3,
DocumentStrict = 4,
Columnar = 5,
KeyValue = 6,
Fts = 7,
Spatial = 8,
Crdt = 9,
}
impl SnapshotEngineId {
pub fn from_u16(v: u16) -> Result<Self, SnapshotFramingError> {
match v {
1 => Ok(Self::Vector),
2 => Ok(Self::Graph),
3 => Ok(Self::DocumentSchemaless),
4 => Ok(Self::DocumentStrict),
5 => Ok(Self::Columnar),
6 => Ok(Self::KeyValue),
7 => Ok(Self::Fts),
8 => Ok(Self::Spatial),
9 => Ok(Self::Crdt),
other => Err(SnapshotFramingError::UnknownEngineId(other)),
}
}
}
/// Reasons a snapshot chunk frame can fail to decode.
///
/// The `From` impl that follows converts any of these into
/// `RaftError::SnapshotFormat`, keeping only the `Display` text.
#[derive(Debug, Error, Clone)]
pub enum SnapshotFramingError {
/// The first four bytes were not `SNAPSHOT_MAGIC`; carries the bytes found.
#[error("snapshot frame magic mismatch: expected {SNAPSHOT_MAGIC:?}, got {0:?}")]
MagicMismatch([u8; 4]),
/// The header's version field differed from `SNAPSHOT_FORMAT_VERSION`;
/// carries the version that was read.
#[error("snapshot frame version mismatch: expected {SNAPSHOT_FORMAT_VERSION}, got {0}")]
VersionMismatch(u16),
/// The CRC32C stored in the header did not match the one recomputed over the
/// engine-id bytes and payload.
#[error("snapshot frame CRC mismatch: stored {stored:#010x}, computed {computed:#010x}")]
CrcMismatch { stored: u32, computed: u32 },
/// The engine-id field is not a known `SnapshotEngineId` discriminant;
/// carries the raw value.
#[error("unknown snapshot engine id: {0}")]
UnknownEngineId(u16),
/// The input was shorter than the fixed `HEADER_LEN`-byte header; carries the
/// actual input length.
#[error("snapshot frame truncated: need at least {HEADER_LEN} bytes, got {0}")]
Truncated(usize),
}
impl From<SnapshotFramingError> for crate::error::RaftError {
fn from(e: SnapshotFramingError) -> Self {
crate::error::RaftError::SnapshotFormat {
detail: e.to_string(),
}
}
}
/// Wraps `payload` in a snapshot chunk frame tagged with `engine_id`.
///
/// Layout (all integers big-endian):
///
/// | bytes   | field                                     |
/// |---------|-------------------------------------------|
/// | `0..4`  | `SNAPSHOT_MAGIC`                          |
/// | `4..6`  | `SNAPSHOT_FORMAT_VERSION`                 |
/// | `6..8`  | engine id (`engine_id as u16`)            |
/// | `8..12` | CRC32C over the engine-id bytes + payload |
/// | `12..`  | payload, verbatim                         |
///
/// The CRC does not cover the magic or version bytes. No payload length is
/// stored, so the payload runs to the end of the returned buffer.
pub fn encode_snapshot_chunk(engine_id: SnapshotEngineId, payload: &[u8]) -> Vec<u8> {
    let version_bytes = SNAPSHOT_FORMAT_VERSION.to_be_bytes();
    let engine_bytes = (engine_id as u16).to_be_bytes();
    // Checksum starts from the engine-id bytes and is then extended over the payload.
    let checksum = crc32c::crc32c_append(crc32c::crc32c(&engine_bytes), payload);
    let crc_bytes = checksum.to_be_bytes();
    // `concat` sizes the output exactly once from the summed part lengths.
    [
        &SNAPSHOT_MAGIC[..],
        &version_bytes[..],
        &engine_bytes[..],
        &crc_bytes[..],
        payload,
    ]
    .concat()
}
/// Decodes a frame produced by `encode_snapshot_chunk`.
///
/// On success returns the engine id and a slice borrowing the payload, which
/// is every byte after the `HEADER_LEN`-byte header.
///
/// # Errors
///
/// Checks run in this order and the first failure is returned:
/// 1. [`SnapshotFramingError::Truncated`] if `data` is shorter than the header.
/// 2. [`SnapshotFramingError::MagicMismatch`] if the first four bytes are not
///    `SNAPSHOT_MAGIC`.
/// 3. [`SnapshotFramingError::VersionMismatch`] if the version field is not
///    `SNAPSHOT_FORMAT_VERSION`.
/// 4. [`SnapshotFramingError::CrcMismatch`] if the stored CRC32C does not match
///    the one recomputed over the engine-id bytes and payload.
/// 5. [`SnapshotFramingError::UnknownEngineId`] if the (now checksum-verified)
///    engine id is not a known discriminant.
pub fn decode_snapshot_chunk(
    data: &[u8],
) -> Result<(SnapshotEngineId, &[u8]), SnapshotFramingError> {
    if data.len() < HEADER_LEN {
        return Err(SnapshotFramingError::Truncated(data.len()));
    }
    let (header, payload) = data.split_at(HEADER_LEN);

    let magic: [u8; 4] = [header[0], header[1], header[2], header[3]];
    if magic != SNAPSHOT_MAGIC {
        return Err(SnapshotFramingError::MagicMismatch(magic));
    }

    let version = u16::from_be_bytes([header[4], header[5]]);
    if version != SNAPSHOT_FORMAT_VERSION {
        return Err(SnapshotFramingError::VersionMismatch(version));
    }

    // The checksum covers the engine-id bytes plus the payload (not magic/version).
    let engine_bytes = [header[6], header[7]];
    let stored_crc = u32::from_be_bytes([header[8], header[9], header[10], header[11]]);
    let computed_crc = crc32c::crc32c_append(crc32c::crc32c(&engine_bytes), payload);

    // Verify integrity before interpreting any checksummed field. Otherwise a
    // corrupted engine-id byte would surface as a misleading "unknown engine
    // id" error instead of a CRC failure.
    if stored_crc != computed_crc {
        return Err(SnapshotFramingError::CrcMismatch {
            stored: stored_crc,
            computed: computed_crc,
        });
    }

    let engine_id = SnapshotEngineId::from_u16(u16::from_be_bytes(engine_bytes))?;
    Ok((engine_id, payload))
}
#[cfg(test)]
mod tests {
    //! Round-trip, tamper-detection and wire-format tests for snapshot framing.
    use super::*;

    const ALL_ENGINES: &[SnapshotEngineId] = &[
        SnapshotEngineId::Vector,
        SnapshotEngineId::Graph,
        SnapshotEngineId::DocumentSchemaless,
        SnapshotEngineId::DocumentStrict,
        SnapshotEngineId::Columnar,
        SnapshotEngineId::KeyValue,
        SnapshotEngineId::Fts,
        SnapshotEngineId::Spatial,
        SnapshotEngineId::Crdt,
    ];

    #[test]
    fn roundtrip_all_engine_ids() {
        for &engine_id in ALL_ENGINES {
            let payload = b"test snapshot payload";
            let framed = encode_snapshot_chunk(engine_id, payload);
            let (decoded_id, decoded_payload) = decode_snapshot_chunk(&framed).unwrap();
            assert_eq!(decoded_id, engine_id);
            assert_eq!(decoded_payload, payload);
        }
    }

    #[test]
    fn roundtrip_empty_payload() {
        let framed = encode_snapshot_chunk(SnapshotEngineId::KeyValue, &[]);
        let (id, payload) = decode_snapshot_chunk(&framed).unwrap();
        assert_eq!(id, SnapshotEngineId::KeyValue);
        assert!(payload.is_empty());
    }

    #[test]
    fn tamper_magic_returns_magic_mismatch() {
        let mut framed = encode_snapshot_chunk(SnapshotEngineId::Vector, b"data");
        framed[0] ^= 0xFF;
        let err = decode_snapshot_chunk(&framed).unwrap_err();
        assert!(
            matches!(err, SnapshotFramingError::MagicMismatch(_)),
            "{err}"
        );
    }

    #[test]
    fn tamper_version_returns_version_mismatch() {
        let mut framed = encode_snapshot_chunk(SnapshotEngineId::Graph, b"data");
        let bad_version = SNAPSHOT_FORMAT_VERSION.wrapping_add(1).to_be_bytes();
        framed[4] = bad_version[0];
        framed[5] = bad_version[1];
        let err = decode_snapshot_chunk(&framed).unwrap_err();
        assert!(
            matches!(err, SnapshotFramingError::VersionMismatch(_)),
            "{err}"
        );
    }

    #[test]
    fn tamper_crc_returns_crc_mismatch() {
        let mut framed = encode_snapshot_chunk(SnapshotEngineId::Fts, b"important data");
        framed[9] ^= 0x01;
        let err = decode_snapshot_chunk(&framed).unwrap_err();
        assert!(
            matches!(err, SnapshotFramingError::CrcMismatch { .. }),
            "{err}"
        );
    }

    #[test]
    fn tamper_payload_returns_crc_mismatch() {
        let mut framed = encode_snapshot_chunk(SnapshotEngineId::Columnar, b"payload bytes");
        let last = framed.len() - 1;
        framed[last] ^= 0x80;
        let err = decode_snapshot_chunk(&framed).unwrap_err();
        assert!(
            matches!(err, SnapshotFramingError::CrcMismatch { .. }),
            "{err}"
        );
    }

    #[test]
    fn tamper_engine_id_to_other_known_id_returns_crc_mismatch() {
        // KeyValue (6) -> Fts (7): still a valid id, so only the CRC can catch it.
        let mut framed = encode_snapshot_chunk(SnapshotEngineId::KeyValue, b"data");
        framed[7] = SnapshotEngineId::Fts as u16 as u8;
        let err = decode_snapshot_chunk(&framed).unwrap_err();
        assert!(
            matches!(err, SnapshotFramingError::CrcMismatch { .. }),
            "{err}"
        );
    }

    #[test]
    fn reject_unknown_engine_id() {
        let engine_id_raw: u16 = 99;
        let engine_bytes = engine_id_raw.to_be_bytes();
        let payload = b"payload";
        let crc = {
            let mut h = crc32c::crc32c(&engine_bytes);
            h = crc32c::crc32c_append(h, payload);
            h
        };
        let mut frame = Vec::new();
        frame.extend_from_slice(&SNAPSHOT_MAGIC);
        frame.extend_from_slice(&SNAPSHOT_FORMAT_VERSION.to_be_bytes());
        frame.extend_from_slice(&engine_bytes);
        frame.extend_from_slice(&crc.to_be_bytes());
        frame.extend_from_slice(payload);
        let err = decode_snapshot_chunk(&frame).unwrap_err();
        assert!(
            matches!(err, SnapshotFramingError::UnknownEngineId(99)),
            "{err}"
        );
    }

    #[test]
    fn truncated_frame_returns_truncated_error() {
        let framed = encode_snapshot_chunk(SnapshotEngineId::Crdt, b"data");
        let err = decode_snapshot_chunk(&framed[..5]).unwrap_err();
        assert!(matches!(err, SnapshotFramingError::Truncated(5)), "{err}");
    }

    #[test]
    fn truncation_boundaries() {
        // Empty input is truncated.
        assert!(matches!(
            decode_snapshot_chunk(&[]),
            Err(SnapshotFramingError::Truncated(0))
        ));

        // A frame with an empty payload is exactly one header long...
        let framed = encode_snapshot_chunk(SnapshotEngineId::Vector, &[]);
        assert_eq!(framed.len(), HEADER_LEN);

        // ...and one byte short of that is rejected as truncated.
        let err = decode_snapshot_chunk(&framed[..HEADER_LEN - 1]).unwrap_err();
        assert!(
            matches!(err, SnapshotFramingError::Truncated(n) if n == HEADER_LEN - 1),
            "{err}"
        );
    }

    #[test]
    fn from_u16_roundtrip_all_discriminants() {
        for &engine_id in ALL_ENGINES {
            let raw = engine_id as u16;
            let decoded = SnapshotEngineId::from_u16(raw).unwrap();
            assert_eq!(decoded, engine_id);
        }
    }

    #[test]
    fn from_u16_unknown_returns_error() {
        let err = SnapshotEngineId::from_u16(0).unwrap_err();
        assert!(matches!(err, SnapshotFramingError::UnknownEngineId(0)));
        let err = SnapshotEngineId::from_u16(255).unwrap_err();
        assert!(matches!(err, SnapshotFramingError::UnknownEngineId(255)));
    }

    #[test]
    fn crc32c_matches_castagnoli_check_value() {
        // Standard CRC-32C (Castagnoli) check value for ASCII "123456789".
        // Pins the algorithm so a crate swap cannot silently change the
        // checksum, which would make frames from older builds fail verification.
        assert_eq!(crc32c::crc32c(b"123456789"), 0xE306_9283);
        // The framing seeds with the engine-id bytes and appends the payload.
        // That must equal a one-shot checksum over the concatenation.
        assert_eq!(
            crc32c::crc32c_append(crc32c::crc32c(b"1234"), b"56789"),
            0xE306_9283
        );
    }

    #[test]
    fn golden_header_prefix_bytes() {
        // Literal bytes: "NDSN", version 1, engine id 6 (KeyValue), all big-endian.
        let framed = encode_snapshot_chunk(SnapshotEngineId::KeyValue, b"x");
        assert_eq!(
            &framed[..8],
            &[b'N', b'D', b'S', b'N', 0x00, 0x01, 0x00, 0x06]
        );
    }

    #[test]
    fn golden_raft_snapshot_frame_format() {
        let payload = b"golden-payload";
        let framed = encode_snapshot_chunk(SnapshotEngineId::KeyValue, payload);
        assert_eq!(&framed[0..4], b"NDSN", "magic mismatch");
        let version = u16::from_be_bytes([framed[4], framed[5]]);
        assert_eq!(version, SNAPSHOT_FORMAT_VERSION, "version mismatch");
        assert_eq!(version, 1u16, "expected SNAPSHOT_FORMAT_VERSION == 1");
        let engine_id_raw = u16::from_be_bytes([framed[6], framed[7]]);
        assert_eq!(engine_id_raw, SnapshotEngineId::KeyValue as u16);
        let stored_crc = u32::from_be_bytes([framed[8], framed[9], framed[10], framed[11]]);
        let engine_bytes = (SnapshotEngineId::KeyValue as u16).to_be_bytes();
        let mut h = crc32c::crc32c(&engine_bytes);
        h = crc32c::crc32c_append(h, payload);
        assert_eq!(stored_crc, h, "CRC mismatch");
        assert_eq!(&framed[12..], payload);
    }
}