1use thiserror::Error;
36
/// Four-byte tag written at the very start of every snapshot frame.
pub const SNAPSHOT_MAGIC: [u8; 4] = *b"NDSN";

/// Frame layout version, stored big-endian immediately after the magic.
pub const SNAPSHOT_FORMAT_VERSION: u16 = 1;

/// Size of the fixed frame header that precedes the payload:
/// magic (4 bytes) | format version (u16, BE) | engine id (u16, BE) | CRC32C (u32, BE).
const HEADER_LEN: usize = SNAPSHOT_MAGIC.len()
    + std::mem::size_of::<u16>() // format version
    + std::mem::size_of::<u16>() // engine id
    + std::mem::size_of::<u32>(); // CRC32C
48
/// Identifies which storage engine produced a snapshot chunk.
///
/// The `u16` discriminant of each variant is written big-endian into the frame
/// header by `encode_snapshot_chunk` and mapped back by `from_u16`. These
/// numeric values are therefore part of the on-wire format; changing one would
/// make frames written by other builds decode to a different engine.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(u16)]
pub enum SnapshotEngineId {
    /// Vector engine (wire value 1).
    Vector = 1,
    /// Graph engine (wire value 2).
    Graph = 2,
    /// Schemaless document engine (wire value 3).
    DocumentSchemaless = 3,
    /// Strict-schema document engine (wire value 4).
    DocumentStrict = 4,
    /// Columnar engine (wire value 5).
    Columnar = 5,
    /// Key-value engine (wire value 6).
    KeyValue = 6,
    /// Full-text search engine (wire value 7).
    Fts = 7,
    /// Spatial engine (wire value 8).
    Spatial = 8,
    /// CRDT engine (wire value 9).
    Crdt = 9,
}
67
68impl SnapshotEngineId {
69 pub fn from_u16(v: u16) -> Result<Self, SnapshotFramingError> {
71 match v {
72 1 => Ok(Self::Vector),
73 2 => Ok(Self::Graph),
74 3 => Ok(Self::DocumentSchemaless),
75 4 => Ok(Self::DocumentStrict),
76 5 => Ok(Self::Columnar),
77 6 => Ok(Self::KeyValue),
78 7 => Ok(Self::Fts),
79 8 => Ok(Self::Spatial),
80 9 => Ok(Self::Crdt),
81 other => Err(SnapshotFramingError::UnknownEngineId(other)),
82 }
83 }
84}
85
86#[derive(Debug, Error, Clone)]
88pub enum SnapshotFramingError {
89 #[error("snapshot frame magic mismatch: expected {SNAPSHOT_MAGIC:?}, got {0:?}")]
90 MagicMismatch([u8; 4]),
91
92 #[error("snapshot frame version mismatch: expected {SNAPSHOT_FORMAT_VERSION}, got {0}")]
93 VersionMismatch(u16),
94
95 #[error("snapshot frame CRC mismatch: stored {stored:#010x}, computed {computed:#010x}")]
96 CrcMismatch { stored: u32, computed: u32 },
97
98 #[error("unknown snapshot engine id: {0}")]
99 UnknownEngineId(u16),
100
101 #[error("snapshot frame truncated: need at least {HEADER_LEN} bytes, got {0}")]
102 Truncated(usize),
103}
104
105impl From<SnapshotFramingError> for crate::error::RaftError {
106 fn from(e: SnapshotFramingError) -> Self {
107 crate::error::RaftError::SnapshotFormat {
108 detail: e.to_string(),
109 }
110 }
111}
112
113pub fn encode_snapshot_chunk(engine_id: SnapshotEngineId, payload: &[u8]) -> Vec<u8> {
118 let engine_bytes = (engine_id as u16).to_be_bytes();
119 let crc = {
120 let mut h = crc32c::crc32c(&engine_bytes);
121 h = crc32c::crc32c_append(h, payload);
122 h
123 };
124
125 let mut out = Vec::with_capacity(HEADER_LEN + payload.len());
126 out.extend_from_slice(&SNAPSHOT_MAGIC);
127 out.extend_from_slice(&SNAPSHOT_FORMAT_VERSION.to_be_bytes());
128 out.extend_from_slice(&engine_bytes);
129 out.extend_from_slice(&crc.to_be_bytes());
130 out.extend_from_slice(payload);
131 out
132}
133
134pub fn decode_snapshot_chunk(
140 data: &[u8],
141) -> Result<(SnapshotEngineId, &[u8]), SnapshotFramingError> {
142 if data.len() < HEADER_LEN {
143 return Err(SnapshotFramingError::Truncated(data.len()));
144 }
145
146 let magic: [u8; 4] = [data[0], data[1], data[2], data[3]];
148 if magic != SNAPSHOT_MAGIC {
149 return Err(SnapshotFramingError::MagicMismatch(magic));
150 }
151
152 let version = u16::from_be_bytes([data[4], data[5]]);
154 if version != SNAPSHOT_FORMAT_VERSION {
155 return Err(SnapshotFramingError::VersionMismatch(version));
156 }
157
158 let engine_id_raw = u16::from_be_bytes([data[6], data[7]]);
160 let engine_id = SnapshotEngineId::from_u16(engine_id_raw)?;
161
162 let stored_crc = u32::from_be_bytes([data[8], data[9], data[10], data[11]]);
164 let payload = &data[HEADER_LEN..];
165 let computed_crc = {
166 let mut h = crc32c::crc32c(&data[6..8]); h = crc32c::crc32c_append(h, payload);
168 h
169 };
170 if stored_crc != computed_crc {
171 return Err(SnapshotFramingError::CrcMismatch {
172 stored: stored_crc,
173 computed: computed_crc,
174 });
175 }
176
177 Ok((engine_id, payload))
178}
179
#[cfg(test)]
mod tests {
    use super::*;

    /// Every engine variant, used to drive exhaustive round-trip tests.
    const ALL_ENGINES: &[SnapshotEngineId] = &[
        SnapshotEngineId::Vector,
        SnapshotEngineId::Graph,
        SnapshotEngineId::DocumentSchemaless,
        SnapshotEngineId::DocumentStrict,
        SnapshotEngineId::Columnar,
        SnapshotEngineId::KeyValue,
        SnapshotEngineId::Fts,
        SnapshotEngineId::Spatial,
        SnapshotEngineId::Crdt,
    ];

    #[test]
    fn roundtrip_all_engine_ids() {
        for &engine_id in ALL_ENGINES {
            let payload = b"test snapshot payload";
            let framed = encode_snapshot_chunk(engine_id, payload);
            let (decoded_id, decoded_payload) = decode_snapshot_chunk(&framed).unwrap();
            assert_eq!(decoded_id, engine_id);
            assert_eq!(decoded_payload, payload);
        }
    }

    #[test]
    fn roundtrip_empty_payload() {
        let framed = encode_snapshot_chunk(SnapshotEngineId::KeyValue, &[]);
        assert_eq!(framed.len(), HEADER_LEN);
        let (id, payload) = decode_snapshot_chunk(&framed).unwrap();
        assert_eq!(id, SnapshotEngineId::KeyValue);
        assert!(payload.is_empty());
    }

    #[test]
    fn tamper_magic_returns_magic_mismatch() {
        let mut framed = encode_snapshot_chunk(SnapshotEngineId::Vector, b"data");
        framed[0] ^= 0xFF;
        let err = decode_snapshot_chunk(&framed).unwrap_err();
        assert!(
            matches!(err, SnapshotFramingError::MagicMismatch(_)),
            "{err}"
        );
    }

    #[test]
    fn tamper_version_returns_version_mismatch() {
        let mut framed = encode_snapshot_chunk(SnapshotEngineId::Graph, b"data");
        let bad_version = SNAPSHOT_FORMAT_VERSION.wrapping_add(1).to_be_bytes();
        framed[4] = bad_version[0];
        framed[5] = bad_version[1];
        let err = decode_snapshot_chunk(&framed).unwrap_err();
        assert!(
            matches!(err, SnapshotFramingError::VersionMismatch(_)),
            "{err}"
        );
    }

    #[test]
    fn tamper_crc_returns_crc_mismatch() {
        let mut framed = encode_snapshot_chunk(SnapshotEngineId::Fts, b"important data");
        framed[9] ^= 0x01;
        let err = decode_snapshot_chunk(&framed).unwrap_err();
        assert!(
            matches!(err, SnapshotFramingError::CrcMismatch { .. }),
            "{err}"
        );
    }

    #[test]
    fn tamper_payload_returns_crc_mismatch() {
        let mut framed = encode_snapshot_chunk(SnapshotEngineId::Fts, b"important data");
        // First payload byte sits right after the fixed header.
        framed[HEADER_LEN] ^= 0x01;
        let err = decode_snapshot_chunk(&framed).unwrap_err();
        assert!(
            matches!(err, SnapshotFramingError::CrcMismatch { .. }),
            "{err}"
        );
    }

    #[test]
    fn swapped_engine_id_returns_crc_mismatch() {
        // Rewriting the engine id to another *valid* id must still be caught,
        // because the encoder's CRC covers the engine-id bytes.
        let mut framed = encode_snapshot_chunk(SnapshotEngineId::Vector, b"data");
        let other = (SnapshotEngineId::Graph as u16).to_be_bytes();
        framed[6] = other[0];
        framed[7] = other[1];
        let err = decode_snapshot_chunk(&framed).unwrap_err();
        assert!(
            matches!(err, SnapshotFramingError::CrcMismatch { .. }),
            "{err}"
        );
    }

    #[test]
    fn reject_unknown_engine_id() {
        // Hand-build a frame with a valid CRC but an unassigned engine id.
        let engine_id_raw: u16 = 99;
        let engine_bytes = engine_id_raw.to_be_bytes();
        let payload = b"payload";
        let crc = {
            let mut h = crc32c::crc32c(&engine_bytes);
            h = crc32c::crc32c_append(h, payload);
            h
        };
        let mut frame = Vec::new();
        frame.extend_from_slice(&SNAPSHOT_MAGIC);
        frame.extend_from_slice(&SNAPSHOT_FORMAT_VERSION.to_be_bytes());
        frame.extend_from_slice(&engine_bytes);
        frame.extend_from_slice(&crc.to_be_bytes());
        frame.extend_from_slice(payload);

        let err = decode_snapshot_chunk(&frame).unwrap_err();
        assert!(
            matches!(err, SnapshotFramingError::UnknownEngineId(99)),
            "{err}"
        );
    }

    #[test]
    fn truncated_frame_returns_truncated_error() {
        let framed = encode_snapshot_chunk(SnapshotEngineId::Crdt, b"data");

        let err = decode_snapshot_chunk(&framed[..5]).unwrap_err();
        assert!(matches!(err, SnapshotFramingError::Truncated(5)), "{err}");

        let empty: &[u8] = &[];
        let err = decode_snapshot_chunk(empty).unwrap_err();
        assert!(matches!(err, SnapshotFramingError::Truncated(0)), "{err}");

        // One byte short of a complete header is still truncated.
        let err = decode_snapshot_chunk(&framed[..HEADER_LEN - 1]).unwrap_err();
        assert!(
            matches!(err, SnapshotFramingError::Truncated(n) if n == HEADER_LEN - 1),
            "{err}"
        );
    }

    #[test]
    fn from_u16_roundtrip_all_discriminants() {
        for &engine_id in ALL_ENGINES {
            let raw = engine_id as u16;
            let decoded = SnapshotEngineId::from_u16(raw).unwrap();
            assert_eq!(decoded, engine_id);
        }
    }

    #[test]
    fn from_u16_unknown_returns_error() {
        let err = SnapshotEngineId::from_u16(0).unwrap_err();
        assert!(matches!(err, SnapshotFramingError::UnknownEngineId(0)));

        let err = SnapshotEngineId::from_u16(255).unwrap_err();
        assert!(matches!(err, SnapshotFramingError::UnknownEngineId(255)));
    }

    #[test]
    fn engine_id_wire_values_are_stable() {
        // Discriminants are written into every frame; pin them so an accidental
        // renumbering fails loudly instead of silently changing the wire format.
        let expected: [(SnapshotEngineId, u16); 9] = [
            (SnapshotEngineId::Vector, 1),
            (SnapshotEngineId::Graph, 2),
            (SnapshotEngineId::DocumentSchemaless, 3),
            (SnapshotEngineId::DocumentStrict, 4),
            (SnapshotEngineId::Columnar, 5),
            (SnapshotEngineId::KeyValue, 6),
            (SnapshotEngineId::Fts, 7),
            (SnapshotEngineId::Spatial, 8),
            (SnapshotEngineId::Crdt, 9),
        ];
        assert_eq!(expected.len(), ALL_ENGINES.len());
        for (engine_id, raw) in expected {
            assert_eq!(engine_id as u16, raw, "{engine_id:?}");
        }
    }

    #[test]
    fn golden_raft_snapshot_frame_format() {
        let payload = b"golden-payload";
        let framed = encode_snapshot_chunk(SnapshotEngineId::KeyValue, payload);

        assert_eq!(framed.len(), HEADER_LEN + payload.len());

        // Pin the literal header bytes independently of the constants:
        // "NDSN" | version 1 (BE) | engine id KeyValue = 6 (BE).
        assert_eq!(
            &framed[0..8],
            &[b'N', b'D', b'S', b'N', 0x00, 0x01, 0x00, 0x06],
            "header bytes changed"
        );

        assert_eq!(&framed[0..4], b"NDSN", "magic mismatch");

        let version = u16::from_be_bytes([framed[4], framed[5]]);
        assert_eq!(version, SNAPSHOT_FORMAT_VERSION, "version mismatch");
        assert_eq!(version, 1u16, "expected SNAPSHOT_FORMAT_VERSION == 1");

        let engine_id_raw = u16::from_be_bytes([framed[6], framed[7]]);
        assert_eq!(engine_id_raw, SnapshotEngineId::KeyValue as u16);

        let stored_crc = u32::from_be_bytes([framed[8], framed[9], framed[10], framed[11]]);
        let engine_bytes = (SnapshotEngineId::KeyValue as u16).to_be_bytes();
        let mut h = crc32c::crc32c(&engine_bytes);
        h = crc32c::crc32c_append(h, payload);
        assert_eq!(stored_crc, h, "CRC mismatch");

        assert_eq!(&framed[12..], payload);
    }
}