use std::collections::BTreeMap;
use serde::{Deserialize, Serialize};
use crate::error::SnapshotError;
/// General-purpose register file captured for one vCPU.
///
/// The shape (31 `x` registers plus `sp`, `pc` and `pstate`) appears to mirror
/// the AArch64 core registers — confirm against the capture code.
/// Field order is likely part of the encoded layout for positional encoders
/// such as bitcode (used by the tests in this file); do not reorder fields.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct GpRegs {
/// General-purpose registers by index (presumably `X0`–`X30`).
pub x: [u64; 31],
/// Saved `sp` (stack pointer) value.
pub sp: u64,
/// Saved program counter.
pub pc: u64,
/// Saved processor-state word (presumably AArch64 `PSTATE` bits — confirm).
pub pstate: u64,
}
/// Floating-point / SIMD register file captured for one vCPU.
///
/// Field order is likely part of the encoded layout for positional encoders;
/// do not reorder fields.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct FpSimdRegs {
/// 32 vector registers, each stored as two `u64` halves (128 bits total).
/// NOTE(review): which element holds the low half is not established here — confirm.
pub v: [[u64; 2]; 32],
/// Saved floating-point status register value.
pub fpsr: u64,
/// Saved floating-point control register value.
pub fpcr: u64,
}
/// Power state of a vCPU for PSCI-style CPU on/off handling.
///
/// `Off` is the [`Default`]. Index-based encoders presumably record the
/// variant position, so append new variants rather than reordering existing ones.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
pub enum PsciVcpuState {
/// The vCPU is powered on (presumably running).
On,
/// The vCPU is powered off; also what `VcpuState::new` starts with.
#[default]
Off,
/// A power-on has presumably been requested but not yet completed — confirm semantics.
OnPending,
}
/// Complete saved state of a single vCPU.
///
/// Field order is likely part of the encoded layout for positional encoders;
/// do not reorder fields.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct VcpuState {
/// Identifier of this vCPU (presumably its MPIDR affinity value).
pub mpidr: u64,
/// Core general-purpose registers.
pub regs: GpRegs,
/// Floating-point / SIMD registers.
pub fp_regs: FpSimdRegs,
/// System registers as `id -> value` pairs. The key encoding is not defined
/// here (presumably a hypervisor register id — confirm). A `BTreeMap` iterates
/// in key order, so the encoded output is deterministic.
pub sys_regs: BTreeMap<u64, u64>,
/// PSCI power state at snapshot time.
pub psci_state: PsciVcpuState,
}
impl VcpuState {
    /// Creates an all-zero state for the vCPU identified by `mpidr`.
    ///
    /// Every general-purpose and FP/SIMD register starts at zero, no system
    /// registers are recorded, and the PSCI state is [`PsciVcpuState::Off`].
    #[must_use]
    pub fn new(mpidr: u64) -> Self {
        let regs = GpRegs::default();
        let fp_regs = FpSimdRegs::default();
        let sys_regs = BTreeMap::default();
        Self {
            psci_state: PsciVcpuState::Off,
            sys_regs,
            fp_regs,
            regs,
            mpidr,
        }
    }
}
/// Opaque saved interrupt-controller (GIC) state.
///
/// `len` duplicates `bytes.len()`, presumably as an integrity check;
/// `MicrovmState::verify_compatible` rejects states where the two differ.
/// `GicState::from_bytes` keeps them in sync.
#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
pub struct GicState {
/// Recorded length of `bytes`, in bytes.
pub len: u64,
/// Serialized controller state; this module never interprets its contents.
pub bytes: Vec<u8>,
}
impl GicState {
    /// Wraps a raw GIC blob and records its length next to it.
    ///
    /// The recorded length saturates at `u64::MAX` if `bytes.len()` cannot be
    /// represented as a `u64`. That cannot happen on platforms where `usize`
    /// is at most 64 bits.
    #[must_use]
    pub fn from_bytes(bytes: Vec<u8>) -> Self {
        Self {
            len: u64::try_from(bytes.len()).unwrap_or(u64::MAX),
            bytes,
        }
    }
}
/// Saved MMDS (presumably the microVM metadata service) data and token settings.
///
/// The data is stored as JSON text rather than as a `serde_json::Value`.
/// Presumably this lets it pass through non-self-describing encoders such as
/// bitcode — confirm. An empty `data_json` (the `Default`) is read back as
/// JSON `null` by `MmdsState::data_value`.
#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
pub struct MmdsState {
/// Serialized JSON document; an empty string means "no data".
pub data_json: String,
/// Token TTL in seconds, if one was configured.
pub token_ttl_seconds: Option<u32>,
}
impl MmdsState {
    /// Builds an MMDS state by serializing `value` to JSON text.
    ///
    /// # Errors
    ///
    /// Returns the `serde_json::Error` produced if `value` cannot be serialized.
    pub fn with_data(
        value: &serde_json::Value,
        token_ttl_seconds: Option<u32>,
    ) -> Result<Self, serde_json::Error> {
        serde_json::to_string(value).map(|data_json| Self {
            data_json,
            token_ttl_seconds,
        })
    }

    /// Parses the stored JSON text back into a `serde_json::Value`.
    ///
    /// An empty `data_json`, as in `MmdsState::default()`, yields
    /// `serde_json::Value::Null` instead of a parse error.
    ///
    /// # Errors
    ///
    /// Returns the parse error if `data_json` is non-empty and is not valid JSON.
    pub fn data_value(&self) -> Result<serde_json::Value, serde_json::Error> {
        match self.data_json.as_str() {
            "" => Ok(serde_json::Value::Null),
            text => serde_json::from_str(text),
        }
    }
}
/// Machine configuration captured alongside the snapshot.
///
/// Field order is likely part of the encoded layout for positional encoders;
/// do not reorder fields.
#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
pub struct VmInfo {
/// Guest memory size in MiB.
pub mem_size_mib: u64,
/// Whether SMT was enabled; `MicrovmState::verify_compatible` rejects `true`.
pub smt: bool,
/// CPU template name. The tests here use an empty string, presumably meaning "none".
pub cpu_template: String,
/// Path of the guest kernel image.
pub kernel_image_path: String,
/// Path of the initrd, if one was used.
pub initrd_path: Option<String>,
/// Kernel command line.
pub boot_args: String,
/// Whether dirty-page tracking was enabled.
pub track_dirty_pages: bool,
}
/// Saved state of a single device.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct DeviceState {
/// Device type tag, e.g. `"virtio-net"` or `"virtio-block"`.
pub kind: String,
/// Device identifier, e.g. `"eth0"` or `"rootfs"`; secondary sort key in
/// `DeviceStates::from_devices`.
pub id: String,
/// MMIO slot the device occupied; primary sort key in `DeviceStates::from_devices`.
pub mmio_slot: u32,
/// Device-specific serialized state; opaque at this level.
pub blob: Vec<u8>,
}
/// Collection of saved device states.
///
/// `DeviceStates::from_devices` sorts entries by `(mmio_slot, id)`. Because
/// `devices` is public, values built directly are not guaranteed to be sorted.
#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
pub struct DeviceStates {
/// Device states, sorted when built through `from_devices`.
pub devices: Vec<DeviceState>,
}
impl DeviceStates {
    /// Collects device states and orders them by MMIO slot, then by id.
    ///
    /// The sort is stable, so entries that share both slot and id keep their
    /// input order.
    pub fn from_devices<I: IntoIterator<Item = DeviceState>>(iter: I) -> Self {
        let mut devices = iter.into_iter().collect::<Vec<DeviceState>>();
        devices.sort_by(|lhs, rhs| {
            lhs.mmio_slot
                .cmp(&rhs.mmio_slot)
                .then_with(|| lhs.id.cmp(&rhs.id))
        });
        Self { devices }
    }
}
/// Top-level snapshot payload for a microVM.
///
/// Derives `Eq` like every other state type in this module, since all of its
/// fields are `Eq`. Field order is likely part of the encoded layout for
/// positional encoders; do not reorder fields. See
/// [`MicrovmState::verify_compatible`] for the structural checks applied to
/// this type.
#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
pub struct MicrovmState {
    /// Machine configuration.
    pub vm_info: VmInfo,
    /// One saved state per vCPU.
    pub vcpu_states: Vec<VcpuState>,
    /// Saved device states.
    pub device_states: DeviceStates,
    /// Saved interrupt-controller state.
    pub gic_state: GicState,
    /// MMDS contents, if any were captured.
    pub mmds_state: Option<MmdsState>,
}
impl MicrovmState {
    /// Largest number of vCPU states a snapshot may carry and still pass
    /// [`MicrovmState::verify_compatible`].
    pub const MAX_VCPUS: usize = 32;

    /// Checks the structural invariants of this snapshot.
    ///
    /// # Errors
    ///
    /// Returns [`SnapshotError::Incompatible`] when any of the following hold:
    /// - there are no vCPU states, or more than [`Self::MAX_VCPUS`];
    /// - `vm_info.smt` is `true`;
    /// - the GIC blob is empty;
    /// - `gic_state.len` disagrees with the actual length of `gic_state.bytes`.
    pub fn verify_compatible(&self) -> Result<(), SnapshotError> {
        let vcpu_count = self.vcpu_states.len();
        if vcpu_count == 0 || vcpu_count > Self::MAX_VCPUS {
            return Err(SnapshotError::Incompatible);
        }
        // SMT snapshots are rejected outright (presumably unsupported by the
        // restore target — confirm).
        if self.vm_info.smt {
            return Err(SnapshotError::Incompatible);
        }
        let gic = &self.gic_state;
        if gic.bytes.is_empty() {
            return Err(SnapshotError::Incompatible);
        }
        // Compare in the u64 domain. The recorded length comes from the encoded
        // snapshot and may not fit in `usize` on 32-bit hosts; any value the
        // blob's real length cannot equal is treated as a mismatch.
        if u64::try_from(gic.bytes.len()).ok() != Some(gic.len) {
            return Err(SnapshotError::Incompatible);
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `value` through bitcode's serde support and decodes it again.
    fn round_trip<T: Serialize + serde::de::DeserializeOwned>(value: &T) -> T {
        let bytes = bitcode::serialize(value).expect("encode");
        bitcode::deserialize(&bytes).expect("decode")
    }

    /// Smallest state that `verify_compatible` accepts. Negative tests start
    /// from this and break exactly one invariant, so a rejection can only come
    /// from the check under test.
    fn valid_state() -> MicrovmState {
        MicrovmState {
            vm_info: VmInfo {
                mem_size_mib: 1,
                ..VmInfo::default()
            },
            vcpu_states: vec![VcpuState::new(0)],
            device_states: DeviceStates::default(),
            gic_state: GicState::from_bytes(vec![1]),
            mmds_state: None,
        }
    }

    /// Asserts that `state` is rejected with `SnapshotError::Incompatible`.
    fn assert_incompatible(state: &MicrovmState) {
        assert!(matches!(
            state.verify_compatible(),
            Err(SnapshotError::Incompatible)
        ));
    }

    // Round-trips a populated (non-default) state; the name is kept so existing
    // test filters keep matching.
    #[test]
    fn test_should_round_trip_default_microvm_state() {
        let state = MicrovmState {
            vm_info: VmInfo {
                mem_size_mib: 256,
                smt: false,
                cpu_template: String::new(),
                kernel_image_path: "/tmp/vmlinux".into(),
                initrd_path: None,
                boot_args: "console=ttyAMA0 panic=1".into(),
                track_dirty_pages: false,
            },
            vcpu_states: vec![VcpuState::new(0)],
            device_states: DeviceStates::default(),
            gic_state: GicState::from_bytes(vec![1, 2, 3, 4]),
            mmds_state: None,
        };
        let back = round_trip(&state);
        assert_eq!(state, back);
    }

    #[test]
    fn test_should_sort_device_states_by_slot_then_id() {
        let states = DeviceStates::from_devices(vec![
            DeviceState {
                kind: "virtio-net".into(),
                id: "eth0".into(),
                mmio_slot: 2,
                blob: vec![],
            },
            DeviceState {
                kind: "virtio-block".into(),
                id: "rootfs".into(),
                mmio_slot: 0,
                blob: vec![],
            },
            DeviceState {
                kind: "virtio-block".into(),
                id: "data".into(),
                mmio_slot: 1,
                blob: vec![],
            },
        ]);
        assert_eq!(states.devices[0].id, "rootfs");
        assert_eq!(states.devices[1].id, "data");
        assert_eq!(states.devices[2].id, "eth0");
    }

    #[test]
    fn test_should_break_slot_ties_by_id() {
        let device = |id: &str| DeviceState {
            kind: "virtio-block".into(),
            id: id.into(),
            mmio_slot: 0,
            blob: vec![],
        };
        let states = DeviceStates::from_devices(vec![device("vdb"), device("vda")]);
        let ids: Vec<&str> = states.devices.iter().map(|d| d.id.as_str()).collect();
        assert_eq!(ids, ["vda", "vdb"]);
    }

    #[test]
    fn test_should_accept_valid_state() {
        assert!(valid_state().verify_compatible().is_ok());
    }

    #[test]
    fn test_should_accept_maximum_vcpu_count() {
        let state = MicrovmState {
            vcpu_states: (0u64..32).map(VcpuState::new).collect(),
            ..valid_state()
        };
        assert!(state.verify_compatible().is_ok());
    }

    #[test]
    fn test_should_reject_too_many_vcpus() {
        let state = MicrovmState {
            vcpu_states: (0u64..33).map(VcpuState::new).collect(),
            ..valid_state()
        };
        assert_incompatible(&state);
    }

    #[test]
    fn test_should_reject_zero_vcpu_state() {
        let state = MicrovmState {
            vcpu_states: vec![],
            ..valid_state()
        };
        assert_incompatible(&state);
    }

    #[test]
    fn test_should_reject_smt_enabled() {
        let mut state = valid_state();
        state.vm_info.smt = true;
        assert_incompatible(&state);
    }

    #[test]
    fn test_should_reject_empty_gic_blob() {
        let state = MicrovmState {
            gic_state: GicState::default(),
            ..valid_state()
        };
        assert_incompatible(&state);
    }

    #[test]
    fn test_should_reject_gic_length_mismatch() {
        let mut state = MicrovmState {
            gic_state: GicState::from_bytes(vec![1, 2, 3]),
            ..valid_state()
        };
        state.gic_state.len = 99;
        assert_incompatible(&state);
    }

    #[test]
    fn test_should_record_gic_len_from_bytes() {
        let gic = GicState::from_bytes(vec![0; 7]);
        assert_eq!(gic.len, 7);
        assert_eq!(gic.bytes.len(), 7);
    }

    #[test]
    fn test_should_treat_empty_mmds_data_as_null() {
        let value = MmdsState::default().data_value().unwrap();
        assert_eq!(value, serde_json::Value::Null);
    }

    #[test]
    fn test_should_round_trip_populated_mmds_state() {
        let mmds = MmdsState::with_data(
            &serde_json::json!({"latest": {"meta-data": {"instance-id": "demo"}}}),
            Some(3600),
        )
        .unwrap();
        let state = MicrovmState {
            vm_info: VmInfo {
                mem_size_mib: 64,
                smt: false,
                cpu_template: String::new(),
                kernel_image_path: "/k".into(),
                initrd_path: None,
                boot_args: String::new(),
                track_dirty_pages: false,
            },
            vcpu_states: vec![VcpuState::new(0)],
            device_states: DeviceStates::default(),
            gic_state: GicState::from_bytes(vec![0xAA; 16]),
            mmds_state: Some(mmds),
        };
        let back = round_trip(&state);
        let restored = back.mmds_state.expect("MMDS round-trip dropped");
        assert_eq!(restored.token_ttl_seconds, Some(3600));
        let value = restored.data_value().unwrap();
        assert_eq!(value["latest"]["meta-data"]["instance-id"], "demo");
    }

    #[test]
    fn test_should_round_trip_psci_state() {
        for s in [
            PsciVcpuState::On,
            PsciVcpuState::Off,
            PsciVcpuState::OnPending,
        ] {
            let back = round_trip(&s);
            assert_eq!(back, s);
        }
    }
}