squib_snapshot/
envelope.rs1use std::io::{Read, Write};
15
16use semver::Version;
17use serde::{Serialize, de::DeserializeOwned};
18
19use crate::error::SnapshotError;
20
21pub const SNAPSHOT_MAGIC_AARCH64: u64 = 0x0710_1984_AAAA_0000;
24
25pub const SNAPSHOT_VERSION: Version = Version::new(5, 0, 0);
30
31pub const SNAPSHOT_DESERIALIZATION_BYTES_LIMIT: usize = 10_000_000;
35
36#[must_use]
38pub const fn arch_magic() -> u64 {
39 SNAPSHOT_MAGIC_AARCH64
40}
41
42#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
45pub struct SnapshotHdr {
46 pub magic: u64,
48 pub version: Version,
50}
51
52#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
54pub struct Snapshot<Data> {
55 pub header: SnapshotHdr,
57 pub data: Data,
59}
60
61impl<Data> Snapshot<Data> {
62 #[must_use]
65 pub fn new(data: Data) -> Self {
66 Self {
67 header: SnapshotHdr {
68 magic: arch_magic(),
69 version: SNAPSHOT_VERSION,
70 },
71 data,
72 }
73 }
74
75 #[must_use]
77 pub fn version(&self) -> &Version {
78 &self.header.version
79 }
80}
81
82impl<Data: Serialize> Snapshot<Data> {
83 pub fn save<W: Write>(&self, writer: &mut W) -> Result<(), SnapshotError> {
93 let mut crc_writer = Crc64Writer::new(writer);
94 let encoded =
95 bitcode::serialize(self).map_err(|e| SnapshotError::Bitcode(e.to_string()))?;
96 crc_writer.write_all(&encoded)?;
97 let crc = crc_writer.checksum();
98 crc_writer.into_inner().write_all(&crc.to_le_bytes())?;
99 Ok(())
100 }
101}
102
103impl<Data: DeserializeOwned> Snapshot<Data> {
104 pub fn load<R: Read>(reader: &mut R) -> Result<Self, SnapshotError> {
123 let buf = read_with_limit(reader, SNAPSHOT_DESERIALIZATION_BYTES_LIMIT)?;
124 Self::load_from_slice(&buf)
125 }
126
127 pub fn load_from_slice(buf: &[u8]) -> Result<Self, SnapshotError> {
133 if buf.len() > SNAPSHOT_DESERIALIZATION_BYTES_LIMIT {
134 return Err(SnapshotError::SizeLimitExceeded {
135 limit: SNAPSHOT_DESERIALIZATION_BYTES_LIMIT,
136 });
137 }
138 if buf.len() < 8 {
139 return Err(SnapshotError::TooShort);
140 }
141 if crc64::crc64(0, buf) != 0 {
145 return Err(SnapshotError::CrcMismatch);
146 }
147 let (data_buf, _crc_buf) = buf.split_at(buf.len() - 8);
148 Self::load_without_crc_check(data_buf)
149 }
150
151 pub fn load_without_crc_check(data_buf: &[u8]) -> Result<Self, SnapshotError> {
162 if data_buf.len() > SNAPSHOT_DESERIALIZATION_BYTES_LIMIT {
163 return Err(SnapshotError::SizeLimitExceeded {
164 limit: SNAPSHOT_DESERIALIZATION_BYTES_LIMIT,
165 });
166 }
167 let snapshot: Self =
168 bitcode::deserialize(data_buf).map_err(|e| SnapshotError::Bitcode(e.to_string()))?;
169 if snapshot.header.magic != arch_magic() {
170 return Err(SnapshotError::MagicMismatch {
171 found: snapshot.header.magic,
172 expected: arch_magic(),
173 });
174 }
175 if snapshot.header.version.major != SNAPSHOT_VERSION.major
176 || snapshot.header.version.minor > SNAPSHOT_VERSION.minor
177 {
178 return Err(SnapshotError::VersionMismatch {
179 found: snapshot.header.version.clone(),
180 expected: SNAPSHOT_VERSION,
181 });
182 }
183 Ok(snapshot)
184 }
185}
186
187fn read_with_limit<R: Read>(reader: &mut R, limit: usize) -> Result<Vec<u8>, SnapshotError> {
193 let mut buf = Vec::new();
194 let read_cap = u64::try_from(limit.saturating_add(1)).unwrap_or(u64::MAX);
195 let bytes = reader.take(read_cap).read_to_end(&mut buf)?;
196 if bytes > limit {
197 return Err(SnapshotError::SizeLimitExceeded { limit });
198 }
199 Ok(buf)
200}
201
/// `Write` adapter that forwards bytes to an inner writer while keeping a
/// running CRC64, seeded with 0, of the bytes the inner writer accepted.
#[derive(Debug)]
pub struct Crc64Writer<W> {
    writer: W,
    crc: u64,
}

impl<W: Write> Crc64Writer<W> {
    /// Wraps `writer` with a checksum that starts at zero.
    pub fn new(writer: W) -> Self {
        Self { crc: 0, writer }
    }

    /// Returns the CRC64 of every byte written through this adapter so far.
    #[must_use]
    pub fn checksum(&self) -> u64 {
        let Self { crc, .. } = self;
        *crc
    }

    /// Consumes the adapter and returns the wrapped writer, discarding the checksum.
    pub fn into_inner(self) -> W {
        let Self { writer, .. } = self;
        writer
    }
}
229
230impl<W: Write> Write for Crc64Writer<W> {
231 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
232 let written = self.writer.write(buf)?;
233 self.crc = crc64::crc64(self.crc, &buf[..written]);
234 Ok(written)
235 }
236
237 fn flush(&mut self) -> std::io::Result<()> {
238 self.writer.flush()
239 }
240}
241
#[cfg(test)]
mod tests {
    use std::io::Cursor;

    use super::*;
    use crate::state::MicrovmState;

    /// Length of the CRC64 trailer that `save` appends.
    const TRAILER_LEN: usize = 8;

    /// Saves a default-state snapshot and returns the bytes `save` produced.
    fn saved_default_bytes() -> Vec<u8> {
        let mut bytes = Vec::new();
        Snapshot::new(MicrovmState::default())
            .save(&mut bytes)
            .unwrap();
        bytes
    }

    /// Bitcode-encodes a default-state snapshot carrying `header`, with no CRC trailer.
    fn encode_with_header(header: SnapshotHdr) -> Vec<u8> {
        let mut snapshot = Snapshot::new(MicrovmState::default());
        snapshot.header = header;
        bitcode::serialize(&snapshot).unwrap()
    }

    /// Like `encode_with_header`, keeping the native magic but overriding the version.
    fn encode_with_version(version: Version) -> Vec<u8> {
        encode_with_header(SnapshotHdr {
            magic: arch_magic(),
            version,
        })
    }

    #[test]
    fn test_should_round_trip_default_state_through_save_and_load() {
        let original = Snapshot::new(MicrovmState::default());
        let mut bytes = Vec::new();
        original.save(&mut bytes).unwrap();
        let mut reader = Cursor::new(&bytes);
        let restored = Snapshot::<MicrovmState>::load(&mut reader).unwrap();
        assert_eq!(original.header, restored.header);
    }

    #[test]
    fn test_should_reject_truncated_crc_trailer() {
        let bytes = saved_default_bytes();
        let truncated = &bytes[..bytes.len() - 4];
        let outcome = Snapshot::<MicrovmState>::load_from_slice(truncated);
        assert!(matches!(outcome, Err(SnapshotError::CrcMismatch)));
    }

    #[test]
    fn test_should_reject_too_short_buffer() {
        let inputs: [&[u8]; 2] = [&[], &[0u8; 4]];
        for &input in &inputs {
            let outcome = Snapshot::<MicrovmState>::load_from_slice(input);
            assert!(matches!(outcome, Err(SnapshotError::TooShort)));
        }
    }

    #[test]
    fn test_should_reject_bit_flipped_state() {
        let mut bytes = saved_default_bytes();
        bytes[2] ^= 0x01;
        let outcome = Snapshot::<MicrovmState>::load_from_slice(&bytes);
        assert!(matches!(outcome, Err(SnapshotError::CrcMismatch)));
    }

    #[test]
    fn test_should_reject_clobbered_crc_trailer() {
        let mut bytes = saved_default_bytes();
        let trailer_start = bytes.len() - TRAILER_LEN;
        bytes[trailer_start..]
            .iter_mut()
            .for_each(|byte| *byte ^= 0xFF);
        let outcome = Snapshot::<MicrovmState>::load_from_slice(&bytes);
        assert!(matches!(outcome, Err(SnapshotError::CrcMismatch)));
    }

    #[test]
    fn test_should_reject_wrong_magic_via_load_without_crc_check() {
        let body = encode_with_header(SnapshotHdr {
            magic: 0xDEAD_BEEF,
            version: SNAPSHOT_VERSION,
        });
        let outcome = Snapshot::<MicrovmState>::load_without_crc_check(&body);
        assert!(matches!(outcome, Err(SnapshotError::MagicMismatch { .. })));
    }

    #[test]
    fn test_should_reject_higher_major_version() {
        let body = encode_with_version(Version::new(
            SNAPSHOT_VERSION.major + 1,
            SNAPSHOT_VERSION.minor,
            0,
        ));
        let outcome = Snapshot::<MicrovmState>::load_without_crc_check(&body);
        assert!(matches!(outcome, Err(SnapshotError::VersionMismatch { .. })));
    }

    #[test]
    fn test_should_reject_higher_minor_version() {
        let body = encode_with_version(Version::new(
            SNAPSHOT_VERSION.major,
            SNAPSHOT_VERSION.minor + 1,
            0,
        ));
        let outcome = Snapshot::<MicrovmState>::load_without_crc_check(&body);
        assert!(matches!(outcome, Err(SnapshotError::VersionMismatch { .. })));
    }

    #[test]
    fn test_should_accept_lower_minor_version() {
        // There is no older minor to test against while the current minor is 0.
        let older_minor = match SNAPSHOT_VERSION.minor.checked_sub(1) {
            Some(minor) => minor,
            None => return,
        };
        let body = encode_with_version(Version::new(SNAPSHOT_VERSION.major, older_minor, 0));
        let _ok = Snapshot::<MicrovmState>::load_without_crc_check(&body).unwrap();
    }

    #[test]
    fn test_should_accept_arbitrary_patch_version() {
        let body = encode_with_version(Version::new(
            SNAPSHOT_VERSION.major,
            SNAPSHOT_VERSION.minor,
            SNAPSHOT_VERSION.patch + 12345,
        ));
        let _ok = Snapshot::<MicrovmState>::load_without_crc_check(&body).unwrap();
    }

    #[test]
    fn test_should_enforce_size_limit_on_load() {
        let oversized = vec![0u8; SNAPSHOT_DESERIALIZATION_BYTES_LIMIT + 32];
        let outcome = Snapshot::<MicrovmState>::load_from_slice(&oversized);
        assert!(matches!(outcome, Err(SnapshotError::SizeLimitExceeded { .. })));
    }

    #[test]
    fn test_should_keep_arch_magic_aarch64_constant() {
        assert_eq!(arch_magic(), 0x0710_1984_AAAA_0000);
    }
}