use std::io::{Read, Write};
use semver::Version;
use serde::{Serialize, de::DeserializeOwned};
use crate::error::SnapshotError;
/// Magic number that [`arch_magic`] returns; the load path rejects any snapshot whose header
/// carries a different value.
pub const SNAPSHOT_MAGIC_AARCH64: u64 = 0x0710_1984_AAAA_0000;
/// Snapshot format version stamped into headers built by [`Snapshot::new`] and used as the
/// reference point for the compatibility check performed when loading.
pub const SNAPSHOT_VERSION: Version = Version::new(5, 0, 0);
/// Largest input, in bytes, that the [`Snapshot`] load functions accept before failing with
/// [`SnapshotError::SizeLimitExceeded`].
pub const SNAPSHOT_DESERIALIZATION_BYTES_LIMIT: usize = 10_000_000;
/// Returns the magic number that snapshot headers are expected to carry.
///
/// This always yields [`SNAPSHOT_MAGIC_AARCH64`]; the function performs no per-target selection.
#[must_use]
pub const fn arch_magic() -> u64 {
SNAPSHOT_MAGIC_AARCH64
}
/// Header serialized together with the snapshot payload.
///
/// Both fields are validated by [`Snapshot::load_without_crc_check`] after decoding.
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct SnapshotHdr {
/// Magic number; loading fails unless it equals the value returned by [`arch_magic`].
pub magic: u64,
/// Format version of the snapshot; compared against [`SNAPSHOT_VERSION`] when loading.
pub version: Version,
}
/// A payload of type `Data` paired with the [`SnapshotHdr`] that is serialized alongside it.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
pub struct Snapshot<Data> {
/// Magic number and format version describing this snapshot.
pub header: SnapshotHdr,
/// The wrapped payload.
pub data: Data,
}
impl<Data> Snapshot<Data> {
    /// Wraps `data` in a snapshot whose header carries the magic from [`arch_magic`] and the
    /// current [`SNAPSHOT_VERSION`].
    #[must_use]
    pub fn new(data: Data) -> Self {
        let header = SnapshotHdr {
            version: SNAPSHOT_VERSION,
            magic: arch_magic(),
        };
        Self { header, data }
    }

    /// Returns the format version recorded in this snapshot's header.
    #[must_use]
    pub fn version(&self) -> &Version {
        let SnapshotHdr { version, .. } = &self.header;
        version
    }
}
impl<Data: Serialize> Snapshot<Data> {
    /// Encodes the snapshot (header and data) with bitcode and writes it to `writer`, followed by
    /// a CRC-64 trailer.
    ///
    /// The trailer is the checksum of the encoded bytes, written as eight little-endian bytes.
    /// `writer` is not flushed by this method.
    ///
    /// # Errors
    ///
    /// Returns [`SnapshotError::Bitcode`] when encoding fails. Failures reported by `writer` are
    /// propagated through `SnapshotError`'s conversion from `std::io::Error`.
    pub fn save<W: Write>(&self, writer: &mut W) -> Result<(), SnapshotError> {
        let payload =
            bitcode::serialize(self).map_err(|err| SnapshotError::Bitcode(err.to_string()))?;

        // Route the payload through the checksumming adapter so the CRC reflects exactly the
        // bytes that were handed to the underlying writer.
        let mut hashing = Crc64Writer::new(writer);
        hashing.write_all(&payload)?;
        let trailer = hashing.checksum().to_le_bytes();

        // The trailer itself must bypass the adapter, so unwrap it first.
        let sink = hashing.into_inner();
        sink.write_all(&trailer)?;
        Ok(())
    }
}
impl<Data: DeserializeOwned> Snapshot<Data> {
pub fn load<R: Read>(reader: &mut R) -> Result<Self, SnapshotError> {
let buf = read_with_limit(reader, SNAPSHOT_DESERIALIZATION_BYTES_LIMIT)?;
Self::load_from_slice(&buf)
}
pub fn load_from_slice(buf: &[u8]) -> Result<Self, SnapshotError> {
if buf.len() > SNAPSHOT_DESERIALIZATION_BYTES_LIMIT {
return Err(SnapshotError::SizeLimitExceeded {
limit: SNAPSHOT_DESERIALIZATION_BYTES_LIMIT,
});
}
if buf.len() < 8 {
return Err(SnapshotError::TooShort);
}
if crc64::crc64(0, buf) != 0 {
return Err(SnapshotError::CrcMismatch);
}
let (data_buf, _crc_buf) = buf.split_at(buf.len() - 8);
Self::load_without_crc_check(data_buf)
}
pub fn load_without_crc_check(data_buf: &[u8]) -> Result<Self, SnapshotError> {
if data_buf.len() > SNAPSHOT_DESERIALIZATION_BYTES_LIMIT {
return Err(SnapshotError::SizeLimitExceeded {
limit: SNAPSHOT_DESERIALIZATION_BYTES_LIMIT,
});
}
let snapshot: Self =
bitcode::deserialize(data_buf).map_err(|e| SnapshotError::Bitcode(e.to_string()))?;
if snapshot.header.magic != arch_magic() {
return Err(SnapshotError::MagicMismatch {
found: snapshot.header.magic,
expected: arch_magic(),
});
}
if snapshot.header.version.major != SNAPSHOT_VERSION.major
|| snapshot.header.version.minor > SNAPSHOT_VERSION.minor
{
return Err(SnapshotError::VersionMismatch {
found: snapshot.header.version.clone(),
expected: SNAPSHOT_VERSION,
});
}
Ok(snapshot)
}
}
/// Reads `reader` to the end, failing if it yields more than `limit` bytes.
///
/// At most `limit + 1` bytes are pulled from `reader`; the one extra byte is what reveals that
/// the input is over the limit without reading the rest of it.
///
/// # Errors
///
/// Returns [`SnapshotError::SizeLimitExceeded`] when more than `limit` bytes were read, or the
/// converted I/O error when reading fails.
fn read_with_limit<R: Read>(reader: &mut R, limit: usize) -> Result<Vec<u8>, SnapshotError> {
    let probe_len = limit.saturating_add(1);
    let probe_len = u64::try_from(probe_len).unwrap_or(u64::MAX);

    let mut contents = Vec::new();
    let mut capped = reader.take(probe_len);
    let total = capped.read_to_end(&mut contents)?;

    if total <= limit {
        Ok(contents)
    } else {
        Err(SnapshotError::SizeLimitExceeded { limit })
    }
}
/// Writer adapter that forwards bytes to an inner writer while keeping a running CRC-64 of the
/// bytes the inner writer accepted.
#[derive(Debug)]
pub struct Crc64Writer<W> {
    /// Destination that receives every byte written through the adapter.
    writer: W,
    /// Checksum accumulated so far; starts at zero.
    crc: u64,
}

impl<W: Write> Crc64Writer<W> {
    /// Wraps `writer` with a checksum that starts at zero.
    pub fn new(writer: W) -> Self {
        Crc64Writer { crc: 0, writer }
    }

    /// Returns the checksum accumulated so far.
    #[must_use]
    pub fn checksum(&self) -> u64 {
        let Self { crc, .. } = self;
        *crc
    }

    /// Consumes the adapter and hands back the wrapped writer, discarding the checksum.
    pub fn into_inner(self) -> W {
        let Self { writer, .. } = self;
        writer
    }
}
impl<W: Write> Write for Crc64Writer<W> {
    /// Forwards `buf` to the inner writer and folds the accepted prefix into the checksum.
    ///
    /// Only the bytes the inner writer reports as written are checksummed, so a short write
    /// leaves the checksum consistent with what was actually emitted. On error the checksum is
    /// left untouched.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        self.writer.write(buf).map(|accepted| {
            self.crc = crc64::crc64(self.crc, &buf[..accepted]);
            accepted
        })
    }

    /// Flushes the inner writer; the checksum is unaffected.
    fn flush(&mut self) -> std::io::Result<()> {
        Write::flush(&mut self.writer)
    }
}
#[cfg(test)]
mod tests {
    use std::io::Cursor;

    use super::*;
    use crate::state::MicrovmState;

    /// Builds a default-state snapshot and returns it with the bytes `save` produced for it.
    fn saved_default_snapshot() -> (Snapshot<MicrovmState>, Vec<u8>) {
        let snapshot = Snapshot::new(MicrovmState::default());
        let mut bytes = Vec::new();
        snapshot.save(&mut bytes).unwrap();
        (snapshot, bytes)
    }

    /// Encodes a default-state snapshot (no CRC trailer) after `adjust` has edited its header.
    fn encode_with_header(adjust: impl FnOnce(&mut SnapshotHdr)) -> Vec<u8> {
        let mut snapshot = Snapshot::new(MicrovmState::default());
        adjust(&mut snapshot.header);
        bitcode::serialize(&snapshot).unwrap()
    }

    /// Encodes a default-state snapshot (no CRC trailer) whose header claims `version`.
    fn encode_with_version(version: Version) -> Vec<u8> {
        encode_with_header(|header| header.version = version)
    }

    /// A saved snapshot loads back with the same header.
    #[test]
    fn test_should_round_trip_default_state_through_save_and_load() {
        let (original, bytes) = saved_default_snapshot();
        let restored = Snapshot::<MicrovmState>::load(&mut Cursor::new(&bytes)).unwrap();
        assert_eq!(original.header, restored.header);
    }

    /// Dropping half of the trailer must surface as a CRC failure.
    #[test]
    fn test_should_reject_truncated_crc_trailer() {
        let (_, mut bytes) = saved_default_snapshot();
        bytes.truncate(bytes.len() - 4);
        let outcome = Snapshot::<MicrovmState>::load_from_slice(&bytes);
        assert!(matches!(outcome, Err(SnapshotError::CrcMismatch)));
    }

    /// Buffers that cannot even hold the trailer are rejected up front.
    #[test]
    fn test_should_reject_too_short_buffer() {
        let inputs: [&[u8]; 2] = [&[], &[0u8; 4]];
        for &short in &inputs {
            let outcome = Snapshot::<MicrovmState>::load_from_slice(short);
            assert!(matches!(outcome, Err(SnapshotError::TooShort)));
        }
    }

    /// A single flipped payload bit is caught by the checksum.
    #[test]
    fn test_should_reject_bit_flipped_state() {
        let (_, mut bytes) = saved_default_snapshot();
        bytes[2] ^= 0x01;
        let outcome = Snapshot::<MicrovmState>::load_from_slice(&bytes);
        assert!(matches!(outcome, Err(SnapshotError::CrcMismatch)));
    }

    /// Inverting every trailer byte is caught by the checksum.
    #[test]
    fn test_should_reject_clobbered_crc_trailer() {
        let (_, mut bytes) = saved_default_snapshot();
        let trailer_start = bytes.len() - 8;
        bytes[trailer_start..]
            .iter_mut()
            .for_each(|byte| *byte = !*byte);
        let outcome = Snapshot::<MicrovmState>::load_from_slice(&bytes);
        assert!(matches!(outcome, Err(SnapshotError::CrcMismatch)));
    }

    /// A header with a foreign magic number is refused.
    #[test]
    fn test_should_reject_wrong_magic_via_load_without_crc_check() {
        let body = encode_with_header(|header| header.magic = 0xDEAD_BEEF);
        let outcome = Snapshot::<MicrovmState>::load_without_crc_check(&body);
        assert!(matches!(outcome, Err(SnapshotError::MagicMismatch { .. })));
    }

    /// A snapshot from a newer major version is refused.
    #[test]
    fn test_should_reject_higher_major_version() {
        let newer = Version::new(SNAPSHOT_VERSION.major + 1, SNAPSHOT_VERSION.minor, 0);
        let body = encode_with_version(newer);
        let outcome = Snapshot::<MicrovmState>::load_without_crc_check(&body);
        assert!(matches!(outcome, Err(SnapshotError::VersionMismatch { .. })));
    }

    /// A snapshot from a newer minor version is refused.
    #[test]
    fn test_should_reject_higher_minor_version() {
        let newer = Version::new(SNAPSHOT_VERSION.major, SNAPSHOT_VERSION.minor + 1, 0);
        let body = encode_with_version(newer);
        let outcome = Snapshot::<MicrovmState>::load_without_crc_check(&body);
        assert!(matches!(outcome, Err(SnapshotError::VersionMismatch { .. })));
    }

    /// A snapshot from an older minor version loads; nothing is checked while the current
    /// minor version is zero, because no older minor exists.
    #[test]
    fn test_should_accept_lower_minor_version() {
        if let Some(older_minor) = SNAPSHOT_VERSION.minor.checked_sub(1) {
            let body = encode_with_version(Version::new(SNAPSHOT_VERSION.major, older_minor, 0));
            let _ok = Snapshot::<MicrovmState>::load_without_crc_check(&body).unwrap();
        }
    }

    /// The patch component plays no part in the compatibility check.
    #[test]
    fn test_should_accept_arbitrary_patch_version() {
        let patched = Version::new(
            SNAPSHOT_VERSION.major,
            SNAPSHOT_VERSION.minor,
            SNAPSHOT_VERSION.patch + 12345,
        );
        let body = encode_with_version(patched);
        let _ok = Snapshot::<MicrovmState>::load_without_crc_check(&body).unwrap();
    }

    /// Oversized buffers are refused before any further processing.
    #[test]
    fn test_should_enforce_size_limit_on_load() {
        let oversized = vec![0u8; SNAPSHOT_DESERIALIZATION_BYTES_LIMIT + 32];
        let outcome = Snapshot::<MicrovmState>::load_from_slice(&oversized);
        assert!(matches!(
            outcome,
            Err(SnapshotError::SizeLimitExceeded { .. })
        ));
    }

    /// Pins the magic value so an accidental change is noticed.
    #[test]
    fn test_should_keep_arch_magic_aarch64_constant() {
        assert_eq!(arch_magic(), 0x0710_1984_AAAA_0000);
    }
}