use std::collections::BTreeMap;
use squib_arch::SysReg;
use crate::{
error::{Result, SnapshotError},
state::{FpSimdRegs, GpRegs, MmdsState, PsciVcpuState, VcpuState},
};
/// Read-side view of a single vCPU, consumed by [`capture_vcpu_state`] to
/// assemble a [`VcpuState`].
pub trait VcpuSnapshotSource {
/// Returns the value recorded as `VcpuState::mpidr`.
///
/// NOTE(review): presumably the vCPU's MPIDR_EL1 affinity value — confirm
/// against implementors.
fn mpidr(&self) -> u64;
/// Reads the general-purpose register block.
///
/// # Errors
///
/// Implementation-defined; [`capture_vcpu_state`] propagates the error unchanged.
fn read_gp_regs(&self) -> Result<GpRegs>;
/// Reads the FP/SIMD register block.
///
/// # Errors
///
/// Implementation-defined; [`capture_vcpu_state`] propagates the error unchanged.
fn read_fp_simd(&self) -> Result<FpSimdRegs>;
/// Reads one system register.
///
/// `Ok(None)` causes [`capture_vcpu_state`] to leave `reg` out of the
/// snapshot (presumably the register is not available on this source —
/// confirm with implementors).
///
/// # Errors
///
/// Implementation-defined; [`capture_vcpu_state`] propagates the error unchanged.
fn read_sys_reg(&self, reg: SysReg) -> Result<Option<u64>>;
/// Returns the PSCI state recorded as `VcpuState::psci_state`.
fn psci_state(&self) -> PsciVcpuState;
}
/// Write-side view of a single vCPU, driven by [`restore_vcpu_state`] to apply
/// a previously captured [`VcpuState`].
pub trait VcpuRestoreTarget {
/// Writes the general-purpose register block.
///
/// # Errors
///
/// Implementation-defined; [`restore_vcpu_state`] propagates the error unchanged.
fn write_gp_regs(&mut self, regs: &GpRegs) -> Result<()>;
/// Writes the FP/SIMD register block.
///
/// # Errors
///
/// Implementation-defined; [`restore_vcpu_state`] propagates the error unchanged.
fn write_fp_simd(&mut self, regs: &FpSimdRegs) -> Result<()>;
/// Writes `value` to the system register `reg`.
///
/// # Errors
///
/// Implementation-defined; [`restore_vcpu_state`] propagates the error unchanged.
fn write_sys_reg(&mut self, reg: SysReg, value: u64) -> Result<()>;
/// Sets the vCPU's PSCI state.
///
/// [`restore_vcpu_state`] always passes a value derived from the vCPU index,
/// not the state stored in the snapshot.
///
/// # Errors
///
/// Implementation-defined; [`restore_vcpu_state`] propagates the error unchanged.
fn set_psci_state(&mut self, state: PsciVcpuState) -> Result<()>;
}
/// Source of interrupt-controller state for a snapshot.
///
/// NOTE(review): "Gic" presumably refers to the Arm Generic Interrupt
/// Controller — confirm against implementors.
pub trait GicSnapshotSource {
/// Captures the controller state as a byte buffer.
///
/// The byte layout is not defined in this file; it is presumably the same
/// format [`GicRestoreTarget::restore`] accepts — confirm with implementors.
///
/// # Errors
///
/// Implementation-defined.
fn capture(&self) -> Result<Vec<u8>>;
}
/// Sink for interrupt-controller state when restoring a snapshot.
pub trait GicRestoreTarget {
/// Applies controller state from `data`.
///
/// The byte layout is not defined in this file; it is presumably the
/// format produced by [`GicSnapshotSource::capture`] — confirm with
/// implementors.
///
/// # Errors
///
/// Implementation-defined.
fn restore(&mut self, data: &[u8]) -> Result<()>;
}
/// Source of [`MmdsState`] for a snapshot.
pub trait MmdsSnapshotSource {
/// Captures the current MMDS state, if any.
///
/// NOTE(review): `Ok(None)` presumably means no MMDS is configured —
/// confirm with implementors.
///
/// # Errors
///
/// Implementation-defined.
fn capture(&self) -> Result<Option<MmdsState>>;
}
/// Sink for [`MmdsState`] when restoring a snapshot.
pub trait MmdsRestoreTarget {
/// Applies the given MMDS state.
///
/// NOTE(review): `None` presumably means the snapshot carried no MMDS
/// state; what the target should do in that case is not defined in this
/// file — confirm with implementors.
///
/// # Errors
///
/// Implementation-defined.
fn restore(&mut self, state: Option<&MmdsState>) -> Result<()>;
}
/// Builds a [`VcpuState`] by querying `source`.
///
/// The general-purpose registers are read first, then the FP/SIMD registers,
/// then every register yielded by `SysReg::all()`. A system register for
/// which the source answers `Ok(None)` is left out of the resulting map; the
/// others are keyed by `SysReg::as_encoded()`. The MPIDR value and PSCI state
/// are queried last.
///
/// # Errors
///
/// Returns the first error reported by `source`; no further registers are
/// read after a failure.
pub fn capture_vcpu_state<S: VcpuSnapshotSource>(source: &S) -> Result<VcpuState> {
    let gp = source.read_gp_regs()?;
    let fp = source.read_fp_simd()?;

    // `transpose` turns `Result<Option<_>>` into `Option<Result<_>>`, so
    // `filter_map` drops the `Ok(None)` answers while errors still reach
    // `collect`, which stops at the first one.
    let captured = SysReg::all()
        .into_iter()
        .filter_map(|reg| {
            source
                .read_sys_reg(*reg)
                .map(|maybe| maybe.map(|value| (reg.as_encoded(), value)))
                .transpose()
        })
        .collect::<Result<BTreeMap<_, _>>>()?;

    let mpidr = source.mpidr();
    let psci_state = source.psci_state();
    Ok(VcpuState {
        mpidr,
        regs: gp,
        fp_regs: fp,
        sys_regs: captured,
        psci_state,
    })
}
pub fn restore_vcpu_state<T: VcpuRestoreTarget>(
target: &mut T,
state: &VcpuState,
vcpu_index: u32,
) -> Result<()> {
target.write_gp_regs(&state.regs)?;
target.write_fp_simd(&state.fp_regs)?;
for (encoded, value) in &state.sys_regs {
let reg = SysReg::from_encoded(*encoded).ok_or(SnapshotError::Incompatible)?;
target.write_sys_reg(reg, *value)?;
}
let normalized = if vcpu_index == 0 {
PsciVcpuState::On
} else {
PsciVcpuState::Off
};
target.set_psci_state(normalized)?;
Ok(())
}
/// Returns the PSCI state a vCPU is given on restore, based only on its index.
///
/// Index `0` maps to [`PsciVcpuState::On`]; every other index maps to
/// [`PsciVcpuState::Off`].
#[must_use]
pub fn normalized_psci_state(vcpu_index: u32) -> PsciVcpuState {
    match vcpu_index {
        0 => PsciVcpuState::On,
        _ => PsciVcpuState::Off,
    }
}
#[cfg(test)]
mod tests {
    use std::cell::RefCell;

    use super::*;

    /// Snapshot source backed by an in-memory map of system registers.
    ///
    /// GP and FP/SIMD reads return default values; the PSCI state is always `On`.
    #[derive(Debug, Default)]
    struct MockSource {
        mpidr: u64,
        sys: std::collections::HashMap<SysReg, u64>,
    }

    impl VcpuSnapshotSource for MockSource {
        fn mpidr(&self) -> u64 {
            self.mpidr
        }

        fn read_gp_regs(&self) -> Result<GpRegs> {
            Ok(GpRegs::default())
        }

        fn read_fp_simd(&self) -> Result<FpSimdRegs> {
            Ok(FpSimdRegs::default())
        }

        fn read_sys_reg(&self, reg: SysReg) -> Result<Option<u64>> {
            // Registers that were never inserted read back as `None`.
            Ok(self.sys.get(&reg).copied())
        }

        fn psci_state(&self) -> PsciVcpuState {
            PsciVcpuState::On
        }
    }

    /// Restore target that records everything written to it.
    #[derive(Debug, Default)]
    struct MockTarget {
        gp: RefCell<Option<GpRegs>>,
        fp: RefCell<Option<FpSimdRegs>>,
        sys: RefCell<std::collections::HashMap<SysReg, u64>>,
        psci: RefCell<Option<PsciVcpuState>>,
    }

    impl VcpuRestoreTarget for MockTarget {
        fn write_gp_regs(&mut self, regs: &GpRegs) -> Result<()> {
            *self.gp.borrow_mut() = Some(regs.clone());
            Ok(())
        }

        fn write_fp_simd(&mut self, regs: &FpSimdRegs) -> Result<()> {
            *self.fp.borrow_mut() = Some(regs.clone());
            Ok(())
        }

        fn write_sys_reg(&mut self, reg: SysReg, value: u64) -> Result<()> {
            self.sys.borrow_mut().insert(reg, value);
            Ok(())
        }

        fn set_psci_state(&mut self, state: PsciVcpuState) -> Result<()> {
            *self.psci.borrow_mut() = Some(state);
            Ok(())
        }
    }

    #[test]
    fn test_should_round_trip_curated_sysregs_via_capture_and_restore() {
        let mut source = MockSource {
            mpidr: 0x42,
            ..Default::default()
        };
        source.sys.insert(SysReg::SctlrEl1, 0xCAFE_BEEF);
        source.sys.insert(SysReg::Ttbr0El1, 0xDEAD_BEAD);

        let state = capture_vcpu_state(&source).unwrap();
        assert_eq!(state.mpidr, 0x42);
        assert_eq!(state.sys_regs.len(), 2);

        let mut target = MockTarget::default();
        restore_vcpu_state(&mut target, &state, 0).unwrap();
        let sys = target.sys.borrow();
        assert_eq!(sys.get(&SysReg::SctlrEl1), Some(&0xCAFE_BEEF));
        assert_eq!(sys.get(&SysReg::Ttbr0El1), Some(&0xDEAD_BEAD));
    }

    #[test]
    fn test_should_skip_sys_regs_the_source_returns_none_for() {
        let source = MockSource::default();
        let state = capture_vcpu_state(&source).unwrap();
        assert!(state.sys_regs.is_empty());
    }

    #[test]
    fn test_should_normalize_psci_state_on_restore_for_bsp_only() {
        let mut state = VcpuState::new(0);
        state.psci_state = PsciVcpuState::Off;

        let mut target = MockTarget::default();
        restore_vcpu_state(&mut target, &state, 0).unwrap();
        assert_eq!(*target.psci.borrow(), Some(PsciVcpuState::On));
    }

    #[test]
    fn test_should_normalize_psci_state_to_off_for_secondaries() {
        let mut state = VcpuState::new(1);
        state.psci_state = PsciVcpuState::On;

        let mut target = MockTarget::default();
        restore_vcpu_state(&mut target, &state, 1).unwrap();
        assert_eq!(*target.psci.borrow(), Some(PsciVcpuState::Off));
    }

    #[test]
    fn test_should_reject_unknown_sys_reg_keys_on_restore() {
        let mut state = VcpuState::new(0);
        state.sys_regs.insert(u64::MAX, 0x1234);

        let mut target = MockTarget::default();
        let res = restore_vcpu_state(&mut target, &state, 0);
        assert!(matches!(res, Err(SnapshotError::Incompatible)));
    }

    #[test]
    fn test_should_provide_normalized_psci_helper() {
        assert_eq!(normalized_psci_state(0), PsciVcpuState::On);
        assert_eq!(normalized_psci_state(1), PsciVcpuState::Off);
        assert_eq!(normalized_psci_state(31), PsciVcpuState::Off);
    }
}