//! Versioned snapshot support.
//!
//! A snapshot consists of a magic id (an architecture-specific base id with the
//! snapshot format version in its low 16 bits), a header recording the data
//! version, and the object serialized with `versionize` at that data version.
//! When written with `Snapshot::save`, a trailing CRC64 checksum is appended.
#![deny(missing_docs)]
mod persist;
use std::fmt::Debug;
use std::io::{Read, Write};
use versionize::crc::{CRC64Reader, CRC64Writer};
use versionize::{VersionMap, Versionize, VersionizeResult};
use versionize_derive::Versionize;
pub use crate::persist::Persist;
// Selects the architecture-specific upper 48 bits of a magic id; the low 16 bits
// carry the snapshot format version.
const BASE_MAGIC_ID_MASK: u64 = !0xFFFFu64;

// Per-architecture base magic id. The low 16 bits are zero so that the format
// version can be OR-ed in (see `build_magic_id`). Snapshots therefore cannot be
// loaded on a different architecture than the one that produced them.
#[cfg(target_arch = "x86_64")]
const BASE_MAGIC_ID: u64 = 0x0710_1984_8664_0000u64;
#[cfg(target_arch = "aarch64")]
const BASE_MAGIC_ID: u64 = 0x0710_1984_AAAA_0000u64;
#[cfg(target_arch = "powerpc64")]
const BASE_MAGIC_ID: u64 = 0x0710_1984_CC64_0000u64;
#[cfg(target_arch = "riscv64")]
const BASE_MAGIC_ID: u64 = 0x0710_1984_C564_0000u64;

// Fail early with a clear message instead of an unresolved-name error at every
// use site of `BASE_MAGIC_ID`.
#[cfg(not(any(
    target_arch = "x86_64",
    target_arch = "aarch64",
    target_arch = "powerpc64",
    target_arch = "riscv64"
)))]
compile_error!("snapshot: unsupported target architecture (no BASE_MAGIC_ID defined)");
/// Errors returned by the snapshot API.
///
/// Each variant's doc comment doubles as its `Display` message (via
/// `displaydoc`), so every variant must keep one.
#[derive(Debug, thiserror::Error, displaydoc::Display, PartialEq)]
pub enum Error {
    /// CRC64 validation failed: {0}
    Crc64(u64),
    /// Invalid data version: {0}
    InvalidDataVersion(u16),
    /// Invalid format version: {0}
    InvalidFormatVersion(u16),
    /// Magic value does not match arch: {0}
    InvalidMagic(u64),
    /// Snapshot file is smaller than CRC length.
    InvalidSnapshotSize,
    /// An IO error occurred: {0}
    Io(i32),
    /// A versioned serialization/deserialization error occurred: {0}
    Versionize(versionize::VersionizeError),
}
/// Header written right after the magic id. It records the data version at
/// which the snapshot's object was serialized.
#[derive(Default, Debug, Versionize)]
struct SnapshotHdr {
    // Data version passed to the object's `deserialize` when loading.
    data_version: u16,
}
/// Versioned snapshot context.
///
/// Holds the version map and the target data version used when saving an
/// object. Loading is done through associated functions that take the version
/// map explicitly.
#[derive(Debug)]
pub struct Snapshot {
    // Header from the most recent save; `Default` until a save happens.
    hdr: SnapshotHdr,
    // Version map used to serialize saved objects.
    version_map: VersionMap,
    // Data version that saved objects are serialized at.
    target_version: u16,
}
fn get_format_version(magic_id: u64) -> Result<u16, Error> {
let magic_arch = magic_id & BASE_MAGIC_ID_MASK;
if magic_arch == BASE_MAGIC_ID {
return Ok((magic_id & !BASE_MAGIC_ID_MASK) as u16);
}
Err(Error::InvalidMagic(magic_id))
}
/// Builds the magic id written at the start of a snapshot.
///
/// The result is this architecture's `BASE_MAGIC_ID` with `format_version`
/// placed in the low 16 bits. It is the inverse of `get_format_version`.
fn build_magic_id(format_version: u16) -> u64 {
    let version_bits = u64::from(format_version);
    version_bits | BASE_MAGIC_ID
}
impl Snapshot {
    /// Creates a new `Snapshot` that serializes objects at `target_version`
    /// using `version_map`.
    ///
    /// The header starts out as its default value. It is filled in by
    /// [`Snapshot::save_without_crc`] (and therefore by [`Snapshot::save`]).
    pub fn new(version_map: VersionMap, target_version: u16) -> Snapshot {
        Snapshot {
            version_map,
            hdr: SnapshotHdr::default(),
            target_version,
        }
    }

    /// Reads the magic id and header from `reader` and returns the data
    /// version recorded in the header.
    ///
    /// After a successful call, `reader` has consumed the magic id and the
    /// header. The serialized object can then be read from it next, as
    /// [`Snapshot::unchecked_load`] does.
    ///
    /// # Errors
    ///
    /// * [`Error::Versionize`] if the magic id or the header cannot be
    ///   deserialized.
    /// * [`Error::InvalidMagic`] if the magic id does not belong to this
    ///   architecture.
    /// * [`Error::InvalidFormatVersion`] if the format version is 0 or newer
    ///   than the latest format version known to this crate.
    /// * [`Error::InvalidDataVersion`] if the data version is 0 or newer than
    ///   the latest version in `version_map`.
    pub fn get_data_version<T>(mut reader: &mut T, version_map: &VersionMap) -> Result<u16, Error>
    where
        T: Read + Debug,
    {
        let format_version_map = Self::format_version_map();
        // The magic id is always (de)serialized with version 0, matching
        // `save_without_crc`.
        let magic_id = <u64 as Versionize>::deserialize(&mut reader, &format_version_map, 0)
            .map_err(Error::Versionize)?;
        let format_version = get_format_version(magic_id)?;
        if format_version > format_version_map.latest_version() || format_version == 0 {
            return Err(Error::InvalidFormatVersion(format_version));
        }
        // The header layout is governed by the snapshot format version.
        let hdr: SnapshotHdr =
            SnapshotHdr::deserialize(&mut reader, &format_version_map, format_version)
                .map_err(Error::Versionize)?;
        if hdr.data_version > version_map.latest_version() || hdr.data_version == 0 {
            return Err(Error::InvalidDataVersion(hdr.data_version));
        }
        Ok(hdr.data_version)
    }

    /// Deserializes an object of type `O` from `reader` without verifying any
    /// checksum.
    ///
    /// `reader` must be positioned at the magic id, i.e. contain what
    /// [`Snapshot::save_without_crc`] wrote. Returns the object together with
    /// the data version it was stored at.
    ///
    /// # Errors
    ///
    /// Any error from [`Snapshot::get_data_version`], or
    /// [`Error::Versionize`] if the object itself cannot be deserialized.
    pub fn unchecked_load<T: Read + Debug, O: Versionize + Debug>(
        mut reader: &mut T,
        version_map: VersionMap,
    ) -> Result<(O, u16), Error> {
        let data_version = Self::get_data_version(&mut reader, &version_map)?;
        let res =
            O::deserialize(&mut reader, &version_map, data_version).map_err(Error::Versionize)?;
        Ok((res, data_version))
    }

    /// Loads an object of type `O` from a snapshot produced by
    /// [`Snapshot::save`], verifying its trailing CRC64 checksum first.
    ///
    /// `snapshot_len` is the total snapshot length in bytes, including the
    /// trailing 8-byte checksum. The first `snapshot_len - 8` bytes are read
    /// into memory and checksummed before anything is deserialized.
    ///
    /// # Errors
    ///
    /// * [`Error::InvalidSnapshotSize`] if `snapshot_len` is smaller than the
    ///   checksum.
    /// * [`Error::Io`] if reading the snapshot bytes fails. It carries the OS
    ///   error code, or `EINVAL` when none is available.
    /// * [`Error::Versionize`] if the stored checksum cannot be deserialized.
    /// * [`Error::Crc64`] carrying the computed checksum if it differs from the
    ///   stored one.
    /// * Any error from [`Snapshot::unchecked_load`].
    pub fn load<T: Read + Debug, O: Versionize + Debug>(
        reader: &mut T,
        snapshot_len: usize,
        version_map: VersionMap,
    ) -> Result<(O, u16), Error> {
        let mut crc_reader = CRC64Reader::new(reader);
        // The last `size_of::<u64>()` bytes hold the stored checksum.
        let raw_snapshot_len = snapshot_len
            .checked_sub(std::mem::size_of::<u64>())
            .ok_or(Error::InvalidSnapshotSize)?;
        let mut snapshot = vec![0u8; raw_snapshot_len];
        crc_reader
            .read_exact(&mut snapshot)
            .map_err(|err| Self::io_error(&err))?;
        // Checksum of everything read so far, i.e. excluding the stored CRC.
        let computed_checksum = crc_reader.checksum();
        let format_vm = Self::format_version_map();
        let stored_checksum: u64 =
            Versionize::deserialize(&mut crc_reader, &format_vm, 0).map_err(Error::Versionize)?;
        if computed_checksum != stored_checksum {
            return Err(Error::Crc64(computed_checksum));
        }
        // Only read access is needed to deserialize from the verified bytes.
        let mut snapshot_slice: &[u8] = snapshot.as_slice();
        Snapshot::unchecked_load::<_, O>(&mut snapshot_slice, version_map)
    }

    /// Saves `object` to `writer`, followed by a CRC64 checksum of everything
    /// written before it.
    ///
    /// # Errors
    ///
    /// Any error from [`Snapshot::save_without_crc`], or
    /// [`Error::Versionize`] if the checksum cannot be serialized.
    pub fn save<T, O>(&mut self, writer: &mut T, object: &O) -> Result<(), Error>
    where
        T: Write + Debug,
        O: Versionize + Debug,
    {
        let mut crc_writer = CRC64Writer::new(writer);
        self.save_without_crc(&mut crc_writer, object)?;
        let checksum = crc_writer.checksum();
        checksum
            .serialize(&mut crc_writer, &Self::format_version_map(), 0)
            .map_err(Error::Versionize)?;
        Ok(())
    }

    /// Saves `object` to `writer` without appending a checksum.
    ///
    /// Writes, in order:
    /// 1. the magic id, with the latest format version;
    /// 2. the header, recording `target_version` as the data version;
    /// 3. `object`, serialized at `target_version`.
    ///
    /// It then flushes `writer`. `self.hdr` is updated to the header that was
    /// written.
    ///
    /// # Errors
    ///
    /// [`Error::Versionize`] if any part fails to serialize, or [`Error::Io`]
    /// if flushing fails.
    pub fn save_without_crc<T, O>(&mut self, mut writer: &mut T, object: &O) -> Result<(), Error>
    where
        T: Write,
        O: Versionize + Debug,
    {
        self.hdr = SnapshotHdr {
            data_version: self.target_version,
        };
        let format_version_map = Self::format_version_map();
        let format_version = format_version_map.latest_version();
        let magic_id = build_magic_id(format_version);
        // Version 0 for the magic id, matching `get_data_version`.
        magic_id
            .serialize(&mut writer, &format_version_map, 0)
            .map_err(Error::Versionize)?;
        self.hdr
            .serialize(&mut writer, &format_version_map, format_version)
            .map_err(Error::Versionize)?;
        object
            .serialize(&mut writer, &self.version_map, self.target_version)
            .map_err(Error::Versionize)?;
        writer.flush().map_err(|err| Self::io_error(&err))
    }

    /// Returns the version map that governs the snapshot framing itself (magic
    /// id, header and checksum). It is independent of the caller's data
    /// version map.
    fn format_version_map() -> VersionMap {
        // NOTE(review): a fresh map presumably only knows format version 1;
        // `get_data_version` rejects anything newer than its `latest_version()`.
        VersionMap::new()
    }

    /// Maps an I/O error to [`Error::Io`]. It uses the OS error code when
    /// present and `EINVAL` otherwise.
    fn io_error(err: &std::io::Error) -> Error {
        Error::Io(err.raw_os_error().unwrap_or(libc::EINVAL))
    }
}
/// Unit tests for snapshot save/load, version translation and CRC checking.
#[cfg(test)]
mod tests {
    use super::*;

    // Mirrors the version-1 layout of `Test` below: only the fields that carry
    // no `#[version(start = ...)]` attribute.
    #[derive(Clone, Debug, Versionize)]
    pub struct Test1 {
        field_x: u64,
        field0: u64,
        field1: u32,
    }

    // Struct whose fields appear at type versions 2, 3 and 4. Each has a
    // default value, and some have semantic (de)serialization hooks that let
    // the tests observe version translation.
    #[derive(Clone, Debug, Versionize)]
    pub struct Test {
        field_x: u64,
        field0: u64,
        field1: u32,
        #[version(start = 2, default_fn = "field2_default")]
        field2: u64,
        #[version(
            start = 3,
            default_fn = "field3_default",
            ser_fn = "field3_serialize",
            de_fn = "field3_deserialize"
        )]
        field3: String,
        #[version(
            start = 4,
            default_fn = "field4_default",
            ser_fn = "field4_serialize",
            de_fn = "field4_deserialize"
        )]
        field4: Vec<u64>,
    }

    impl Test {
        fn field2_default(_: u16) -> u64 {
            20
        }
        fn field3_default(_: u16) -> String {
            "default".to_owned()
        }
        fn field4_default(_: u16) -> Vec<u64> {
            vec![1, 2, 3, 4]
        }
        // Folds `field4` into `field0` when saving at an older type version
        // (asserted below). An element sum of 6666 triggers a semantic error.
        fn field4_serialize(&mut self, target_version: u16) -> VersionizeResult<()> {
            assert_ne!(target_version, Test::version());
            self.field0 = self.field4.iter().sum();
            if self.field0 == 6666 {
                return Err(versionize::VersionizeError::Semantic(
                    "field4 element sum is 6666".to_owned(),
                ));
            }
            Ok(())
        }
        // Rebuilds `field4` from `field0` when loading from an older type
        // version.
        fn field4_deserialize(&mut self, source_version: u16) -> VersionizeResult<()> {
            assert_ne!(source_version, Test::version());
            self.field4 = vec![self.field0; 4];
            Ok(())
        }
        // Bumps `field_x` so the tests can observe that the hook ran.
        fn field3_serialize(&mut self, target_version: u16) -> VersionizeResult<()> {
            assert!(target_version < 3);
            self.field_x += 1;
            Ok(())
        }
        // Bumps `field_x` as well. A `field0` of 7777 triggers a semantic error.
        fn field3_deserialize(&mut self, source_version: u16) -> VersionizeResult<()> {
            assert!(source_version < 3);
            self.field_x += 1;
            if self.field0 == 7777 {
                return Err(versionize::VersionizeError::Semantic(
                    "field0 is 7777".to_owned(),
                ));
            }
            Ok(())
        }
    }

    #[test]
    fn test_get_format_version() {
        #[cfg(target_arch = "x86_64")]
        let good_magic_id = 0x0710_1984_8664_0001u64;
        #[cfg(target_arch = "aarch64")]
        let good_magic_id = 0x0710_1984_AAAA_0001u64;
        #[cfg(target_arch = "powerpc64")]
        let good_magic_id = 0x0710_1984_CC64_0001u64;
        #[cfg(target_arch = "riscv64")]
        let good_magic_id = 0x0710_1984_C564_0001u64;
        assert_eq!(get_format_version(good_magic_id).unwrap(), 1u16);
        // Setting the top bit corrupts the architecture part of the magic id.
        let invalid_magic_id = good_magic_id | (1u64 << 63);
        assert_eq!(
            get_format_version(invalid_magic_id).unwrap_err(),
            Error::InvalidMagic(invalid_magic_id)
        );
    }

    // Checks that the semantic hooks run when saving/loading `Test` at data
    // versions that map to older `Test` type versions.
    #[test]
    fn test_struct_semantic_fn() {
        // Data versions 2, 3 and 4 map `Test` to type versions 2, 3 and 4.
        let mut vm = VersionMap::new();
        vm.new_version()
            .set_type_version(Test::type_id(), 2)
            .new_version()
            .set_type_version(Test::type_id(), 3)
            .new_version()
            .set_type_version(Test::type_id(), 4);
        let state = Test {
            field0: 0,
            field1: 1,
            field2: 2,
            field3: "test".to_owned(),
            field4: vec![4, 3, 2, 1],
            field_x: 0,
        };
        let mut snapshot_mem = vec![0u8; 1024];
        let mut snapshot = Snapshot::new(vm.clone(), 1);
        snapshot
            .save_without_crc(&mut snapshot_mem.as_mut_slice(), &state)
            .unwrap();
        let (mut restored_state, _) =
            Snapshot::unchecked_load::<_, Test>(&mut snapshot_mem.as_slice(), vm.clone()).unwrap();
        assert_eq!(restored_state.field0, state.field4.iter().sum::<u64>());
        assert_eq!(restored_state.field4, vec![restored_state.field0; 4]);
        // Both field3 hooks ran (one on save, one on load).
        assert_eq!(restored_state.field_x, 2);
        assert_eq!(restored_state.field1, 1);
        assert_eq!(restored_state.field2, 20);
        let mut snapshot = Snapshot::new(vm.clone(), 3);
        snapshot
            .save_without_crc(&mut snapshot_mem.as_mut_slice(), &state)
            .unwrap();
        (restored_state, _) =
            Snapshot::unchecked_load::<_, Test>(&mut snapshot_mem.as_slice(), vm.clone()).unwrap();
        assert_eq!(restored_state.field0, state.field4.iter().sum::<u64>());
        assert_eq!(restored_state.field4, vec![restored_state.field0; 4]);
        assert_eq!(restored_state.field_x, 0);
        snapshot = Snapshot::new(vm.clone(), 4);
        snapshot
            .save_without_crc(&mut snapshot_mem.as_mut_slice(), &state)
            .unwrap();
        (restored_state, _) =
            Snapshot::unchecked_load::<_, Test>(&mut snapshot_mem.as_slice(), vm.clone()).unwrap();
        assert_eq!(restored_state.field0, 0);
        assert_eq!(restored_state.field4, vec![4, 3, 2, 1]);
        // Keep only the first 10 bytes (presumably magic id + header), so
        // deserializing the object hits end of input.
        snapshot_mem.truncate(10);
        let restored_state_result: Result<(Test, _), Error> =
            Snapshot::unchecked_load(&mut snapshot_mem.as_slice(), vm);
        assert_eq!(
            restored_state_result.unwrap_err(),
            Error::Versionize(versionize::VersionizeError::Deserialize(String::from(
                "Io(Error { kind: UnexpectedEof, message: \"failed to fill whole buffer\" })"
            )))
        );
    }

    // Checks that fields missing from older data versions get their default
    // values on load.
    #[test]
    fn test_struct_default_fn() {
        let mut vm = VersionMap::new();
        vm.new_version()
            .set_type_version(Test::type_id(), 2)
            .new_version()
            .set_type_version(Test::type_id(), 3)
            .new_version()
            .set_type_version(Test::type_id(), 4);
        let state = Test {
            field0: 0,
            field1: 1,
            field2: 2,
            field3: "test".to_owned(),
            field4: vec![4, 3, 2, 1],
            field_x: 0,
        };
        let state_1 = Test1 {
            field_x: 0,
            field0: 0,
            field1: 1,
        };
        let mut snapshot_mem = vec![0u8; 1024];
        // A `Test1` saved at version 1 is loaded back as a `Test`.
        let mut snapshot = Snapshot::new(vm.clone(), 1);
        snapshot
            .save_without_crc(&mut snapshot_mem.as_mut_slice(), &state_1)
            .unwrap();
        let (mut restored_state, _) =
            Snapshot::unchecked_load::<_, Test>(&mut snapshot_mem.as_slice(), vm.clone()).unwrap();
        assert_eq!(restored_state.field1, state_1.field1);
        assert_eq!(restored_state.field2, 20);
        assert_eq!(restored_state.field3, "default");
        snapshot = Snapshot::new(vm.clone(), 2);
        snapshot
            .save_without_crc(&mut snapshot_mem.as_mut_slice(), &state)
            .unwrap();
        (restored_state, _) =
            Snapshot::unchecked_load::<_, Test>(&mut snapshot_mem.as_slice(), vm.clone()).unwrap();
        assert_eq!(restored_state.field1, state.field1);
        assert_eq!(restored_state.field2, 2);
        assert_eq!(restored_state.field3, "default");
        snapshot = Snapshot::new(vm.clone(), 3);
        snapshot
            .save_without_crc(&mut snapshot_mem.as_mut_slice(), &state)
            .unwrap();
        (restored_state, _) =
            Snapshot::unchecked_load::<_, Test>(&mut snapshot_mem.as_slice(), vm.clone()).unwrap();
        assert_eq!(restored_state.field1, state.field1);
        assert_eq!(restored_state.field2, 2);
        assert_eq!(restored_state.field3, "test");
        snapshot = Snapshot::new(vm.clone(), 4);
        snapshot
            .save_without_crc(&mut snapshot_mem.as_mut_slice(), &state)
            .unwrap();
        (restored_state, _) =
            Snapshot::unchecked_load::<_, Test>(&mut snapshot_mem.as_slice(), vm.clone()).unwrap();
        assert_eq!(restored_state.field1, state.field1);
        assert_eq!(restored_state.field2, 2);
        assert_eq!(restored_state.field3, "test");
    }

    #[test]
    fn test_crc_ok() {
        let vm = VersionMap::new();
        let state_1 = Test1 {
            field_x: 0,
            field0: 0,
            field1: 1,
        };
        let mut snapshot_mem = vec![0u8; 1024];
        let mut snapshot = Snapshot::new(vm.clone(), 1);
        snapshot
            .save(&mut snapshot_mem.as_mut_slice(), &state_1)
            .unwrap();
        // 38 bytes: presumably 8 (magic id) + 2 (header) + 20 (Test1 fields) + 8 (CRC).
        let _ = Snapshot::load::<_, Test1>(&mut snapshot_mem.as_slice(), 38, vm).unwrap();
    }

    // A length shorter than the 8-byte checksum is rejected before any read.
    #[test]
    fn test_invalid_snapshot_size() {
        let vm = VersionMap::new();
        let snapshot_mem = vec![0u8; 4];
        let expected_err = Error::InvalidSnapshotSize;
        let load_result: Result<(Test1, _), Error> =
            Snapshot::load(&mut snapshot_mem.as_slice(), 4, vm);
        assert_eq!(load_result.unwrap_err(), expected_err);
    }

    #[test]
    fn test_corrupted_snapshot() {
        let vm = VersionMap::new();
        let state_1 = Test1 {
            field_x: 0,
            field0: 0,
            field1: 1,
        };
        let mut snapshot_mem = vec![0u8; 1024];
        let mut snapshot = Snapshot::new(vm.clone(), 1);
        snapshot
            .save(&mut snapshot_mem.as_mut_slice(), &state_1)
            .unwrap();
        // Corrupt one byte inside the checksummed region.
        snapshot_mem[20] = 123;
        // The computed CRC differs per architecture because the magic id does.
        #[cfg(target_arch = "aarch64")]
        let expected_err = Error::Crc64(0x1960_4E6A_A13F_6615);
        #[cfg(target_arch = "x86_64")]
        let expected_err = Error::Crc64(0x103F_8F52_8F51_20B1);
        #[cfg(target_arch = "powerpc64")]
        let expected_err = Error::Crc64(0x33D0_CCE5_DA3C_CCEA);
        #[cfg(target_arch = "riscv64")]
        let expected_err = Error::Crc64(0xFAC5_E225_5586_9011);
        let load_result: Result<(Test1, _), Error> =
            Snapshot::load(&mut snapshot_mem.as_slice(), 38, vm);
        assert_eq!(load_result.unwrap_err(), expected_err);
    }

    // Round-trips a `#[repr(C)]` struct with an array field, shaped like a
    // kvm-bindings type (hence the C-style naming allowances).
    #[allow(non_upper_case_globals)]
    #[allow(non_camel_case_types)]
    #[allow(non_snake_case)]
    #[test]
    fn test_kvm_bindings_struct() {
        #[repr(C)]
        #[derive(Debug, PartialEq, Eq, Versionize)]
        pub struct kvm_pit_config {
            pub flags: ::std::os::raw::c_uint,
            pub pad: [::std::os::raw::c_uint; 15usize],
        }
        let state = kvm_pit_config {
            flags: 123_456,
            pad: [0; 15usize],
        };
        let vm = VersionMap::new();
        let mut snapshot_mem = vec![0u8; 1024];
        let mut snapshot = Snapshot::new(vm.clone(), 1);
        snapshot
            .save_without_crc(&mut snapshot_mem.as_mut_slice(), &state)
            .unwrap();
        let (restored_state, _) =
            Snapshot::unchecked_load::<_, kvm_pit_config>(&mut snapshot_mem.as_slice(), vm)
                .unwrap();
        assert_eq!(restored_state, state);
    }
}