use std::path::Path;
use crate::{
atomic::AtomicWriter,
dirty::DirtyBitmap,
envelope::Snapshot,
error::{Result, SnapshotError},
memory::{MemoryWriter, PageReader},
state::MicrovmState,
};
/// Selects how `save` writes guest memory.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SnapshotKind {
/// Guest memory is written with `MemoryWriter::write_full`.
Full,
/// Guest memory is written with `MemoryWriter::write_diff`, driven by a dirty bitmap.
/// `save` rejects this kind unless `vm_info.track_dirty_pages` is true and a
/// bitmap is supplied in the request.
Diff,
}
/// Everything `save` needs to write one snapshot: destination paths, the VM
/// state, and a reader over guest memory.
#[derive(Debug)]
pub struct SaveRequest<'a, R: PageReader> {
/// Destination of the state file; handed to `AtomicWriter::open`.
pub state_path: &'a Path,
/// Destination of the guest-memory file; handed to `MemoryWriter::open`.
pub memory_path: &'a Path,
/// Whether to write a full or a diff memory image.
pub kind: SnapshotKind,
/// VM state that is checked with `verify_compatible` and then wrapped in a
/// `Snapshot` envelope and written to the state file.
pub state: MicrovmState,
/// Source of guest-memory pages passed to `write_full` / `write_diff`.
pub memory: &'a R,
/// Forwarded to `MemoryWriter::open` (presumably the guest RAM size in bytes — TODO confirm).
pub ram_size: u64,
/// Forwarded to `MemoryWriter::open` (presumably the page size in bytes — TODO confirm).
pub memory_page_size: u64,
/// Dirty-page bitmap; must be `Some` for `SnapshotKind::Diff`, not read for `Full`.
pub dirty: Option<&'a DirtyBitmap>,
}
/// Summary returned by a successful `save`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SaveReport {
/// The snapshot kind that was requested and written.
pub kind: SnapshotKind,
/// For `Diff`, the value returned by `MemoryWriter::write_diff`.
/// For `Full`, `save` always reports 0.
pub pages_written: u64,
}
/// Writes a snapshot pair: the VM state file and the guest-memory file.
///
/// The whole request is validated before any writer is opened, so a request
/// that is rejected up front never touches either destination.
///
/// Steps, in order:
/// 1. Reject `Diff` when `vm_info.track_dirty_pages` is false.
/// 2. Run `MicrovmState::verify_compatible`.
/// 3. Reject `Diff` when no dirty bitmap was supplied.
/// 4. Write the state envelope and the memory image through their writers.
/// 5. Commit the state writer, then the memory writer.
///
/// # Errors
///
/// * `SnapshotError::InvalidPath` if `kind` is `Diff` and either
///   `track_dirty_pages` is false or `dirty` is `None`.
/// * Any error returned by `verify_compatible`, by opening or writing through
///   `AtomicWriter` / `MemoryWriter`, by `Snapshot::save`, or by either commit
///   is propagated unchanged.
pub fn save<R: PageReader>(req: SaveRequest<'_, R>) -> Result<SaveReport> {
    let SaveRequest {
        state_path,
        memory_path,
        kind,
        state,
        memory,
        ram_size,
        memory_page_size,
        dirty,
    } = req;

    if matches!(kind, SnapshotKind::Diff) && !state.vm_info.track_dirty_pages {
        return Err(SnapshotError::InvalidPath(
            "Diff snapshot requested but vm_info.track_dirty_pages is false".into(),
        ));
    }
    state.verify_compatible()?;

    // Resolve the bitmap before opening any writer: a Diff request without a
    // bitmap can never succeed, so it must not stage anything on disk.
    let diff_bitmap = match (kind, dirty) {
        (SnapshotKind::Full, _) => None,
        (SnapshotKind::Diff, Some(bitmap)) => Some(bitmap),
        (SnapshotKind::Diff, None) => {
            return Err(SnapshotError::InvalidPath(
                "Diff snapshot requires track_dirty_pages=true and a dirty bitmap".into(),
            ));
        }
    };

    let mut state_writer = AtomicWriter::open(state_path)?;
    let envelope = Snapshot::new(state);
    envelope.save(state_writer.file_mut())?;

    let mut mem_writer = MemoryWriter::open(memory_path, ram_size, memory_page_size)?;
    let pages_written = match diff_bitmap {
        None => {
            mem_writer.write_full(memory)?;
            // NOTE(review): a full snapshot reports 0 pages here; confirm this is
            // the intended meaning of `pages_written` for `Full`.
            0
        }
        Some(bitmap) => mem_writer.write_diff(memory, bitmap)?,
    };

    // NOTE(review): the two commits are separate steps. If the memory commit
    // fails after the state commit succeeded, the pair may be left mismatched —
    // confirm whether callers rely on the pair being replaced together.
    state_writer.commit()?;
    mem_writer.commit()?;

    Ok(SaveReport {
        kind,
        pages_written,
    })
}
pub use crate::envelope::SnapshotHdr as Header;
#[cfg(test)]
mod tests {
use std::path::Path;
use tempfile::TempDir;
use super::*;
use crate::{
memory::VecPageReader,
state::{GicState, MicrovmState, VcpuState, VmInfo},
};
/// Test fixture: a `MicrovmState` with one vCPU, an 8-byte GIC blob, no MMDS
/// state, and dirty-page tracking switched off.
fn build_state() -> MicrovmState {
    let vm_info = VmInfo {
        mem_size_mib: 256,
        smt: false,
        cpu_template: "V1N1".into(),
        kernel_image_path: "/tmp/vmlinux".into(),
        initrd_path: None,
        boot_args: "console=ttyAMA0 panic=1".into(),
        track_dirty_pages: false,
    };
    let vcpu_states = vec![VcpuState::new(0)];
    let device_states = crate::state::DeviceStates::default();
    let gic_state = GicState::from_bytes(vec![1, 2, 3, 4, 5, 6, 7, 8]);

    MicrovmState {
        vm_info,
        vcpu_states,
        device_states,
        gic_state,
        mmds_state: None,
    }
}
/// Returns the path of `name` placed inside `dir`.
fn dest_in(dir: &Path, name: &str) -> std::path::PathBuf {
    let mut dest = dir.to_path_buf();
    dest.push(name);
    dest
}
/// A full save must produce both files, and the memory file must be exactly
/// `ram_size` bytes long.
#[test]
fn test_should_save_full_snapshot_pair_atomically() {
    let tmp = TempDir::new().unwrap();
    let state_file = dest_in(tmp.path(), "x.snap");
    let memory_file = dest_in(tmp.path(), "x.mem");
    let ram_size: u64 = 32 * 1024;
    let guest_ram = VecPageReader::new(vec![7u8; ram_size as usize]);

    let request = SaveRequest {
        state_path: &state_file,
        memory_path: &memory_file,
        kind: SnapshotKind::Full,
        state: build_state(),
        memory: &guest_ram,
        ram_size,
        memory_page_size: 16 * 1024,
        dirty: None,
    };
    let report = save(request).unwrap();

    assert_eq!(report.kind, SnapshotKind::Full);
    assert!(state_file.exists());
    assert!(memory_file.exists());
    let memory_len = std::fs::metadata(&memory_file).unwrap().len();
    assert_eq!(memory_len, ram_size);
}
/// A Diff request with tracking enabled but no bitmap must fail with
/// `SnapshotError::InvalidPath`.
#[test]
fn test_should_reject_diff_without_dirty_bitmap() {
    let tmp = TempDir::new().unwrap();
    let state_file = dest_in(tmp.path(), "x.snap");
    let memory_file = dest_in(tmp.path(), "x.mem");
    let tracked_state = {
        let mut s = build_state();
        s.vm_info.track_dirty_pages = true;
        s
    };
    let guest_ram = VecPageReader::new(vec![0u8; 32 * 1024]);

    let request = SaveRequest {
        state_path: &state_file,
        memory_path: &memory_file,
        kind: SnapshotKind::Diff,
        state: tracked_state,
        memory: &guest_ram,
        ram_size: 32 * 1024,
        memory_page_size: 16 * 1024,
        dirty: None,
    };
    let outcome = save(request);

    assert!(matches!(outcome, Err(SnapshotError::InvalidPath(_))));
}
/// A Diff request must fail with `SnapshotError::InvalidPath` when the state
/// has dirty-page tracking disabled, even if a bitmap is supplied.
#[test]
fn test_should_reject_diff_when_track_dirty_is_false() {
    let tmp = TempDir::new().unwrap();
    let state_file = dest_in(tmp.path(), "x.snap");
    let memory_file = dest_in(tmp.path(), "x.mem");
    let bitmap = DirtyBitmap::new(0, 32 * 1024, 16 * 1024).unwrap();
    let guest_ram = VecPageReader::new(vec![0u8; 32 * 1024]);

    let request = SaveRequest {
        state_path: &state_file,
        memory_path: &memory_file,
        kind: SnapshotKind::Diff,
        // The fixture leaves `track_dirty_pages` false.
        state: build_state(),
        memory: &guest_ram,
        ram_size: 32 * 1024,
        memory_page_size: 16 * 1024,
        dirty: Some(&bitmap),
    };
    let outcome = save(request);

    assert!(matches!(outcome, Err(SnapshotError::InvalidPath(_))));
}
/// With only page index 1 marked dirty, the memory file must hold zeros in
/// the first page and the reader's bytes in the second.
#[test]
fn test_should_save_diff_only_dirty_pages() {
    const PAGE: usize = 16 * 1024;

    let tmp = TempDir::new().unwrap();
    let state_file = dest_in(tmp.path(), "x.snap");
    let memory_file = dest_in(tmp.path(), "x.mem");
    let tracked_state = {
        let mut s = build_state();
        s.vm_info.track_dirty_pages = true;
        s
    };
    let bitmap = DirtyBitmap::new(0, 32 * 1024, 16 * 1024).unwrap();
    bitmap.set_dirty_by_index(1);
    let guest_ram = VecPageReader::new(vec![9u8; 32 * 1024]);

    let request = SaveRequest {
        state_path: &state_file,
        memory_path: &memory_file,
        kind: SnapshotKind::Diff,
        state: tracked_state,
        memory: &guest_ram,
        ram_size: 32 * 1024,
        memory_page_size: 16 * 1024,
        dirty: Some(&bitmap),
    };
    let report = save(request).unwrap();

    assert_eq!(report.pages_written, 1);
    let written = std::fs::read(&memory_file).unwrap();
    let clean_page = &written[..PAGE];
    let dirty_page = &written[PAGE..2 * PAGE];
    assert!(clean_page.iter().all(|&b| b == 0));
    assert!(dirty_page.iter().all(|&b| b == 9));
}
/// A state with no vCPUs must be rejected as `SnapshotError::Incompatible`
/// before anything is created in the destination directory.
#[test]
fn test_should_reject_save_when_state_is_incompatible() {
    let dir = TempDir::new().unwrap();
    let snap = dest_in(dir.path(), "x.snap");
    let mem = dest_in(dir.path(), "x.mem");
    let mut state = build_state();
    state.vcpu_states.clear();
    let reader = VecPageReader::new(vec![0u8; 32 * 1024]);
    let res = save(SaveRequest {
        state_path: &snap,
        memory_path: &mem,
        kind: SnapshotKind::Full,
        state,
        memory: &reader,
        ram_size: 32 * 1024,
        memory_page_size: 16 * 1024,
        dirty: None,
    });
    assert!(matches!(res, Err(SnapshotError::Incompatible)));
    assert!(
        !snap.exists(),
        "incompatible state must not stage temp file"
    );
    assert!(
        !mem.exists(),
        "incompatible state must not create the memory file"
    );
    // Checking only the final paths would miss a staged temp file with a
    // different name, so also require the directory to be untouched.
    assert!(
        std::fs::read_dir(dir.path()).unwrap().next().is_none(),
        "incompatible state must leave the destination directory empty"
    );
}
/// When `save` fails, a previously written state/memory pair must be left
/// exactly as it was.
#[test]
fn test_should_keep_existing_pair_when_save_fails_during_state_validation() {
    let dir = TempDir::new().unwrap();
    let snap = dest_in(dir.path(), "x.snap");
    let mem = dest_in(dir.path(), "x.mem");
    std::fs::write(&snap, b"prior good state").unwrap();
    std::fs::write(&mem, b"prior good mem").unwrap();
    let mut state = build_state();
    state.vm_info.smt = true;
    let reader = VecPageReader::new(vec![0u8; 32 * 1024]);
    let res = save(SaveRequest {
        state_path: &snap,
        memory_path: &mem,
        kind: SnapshotKind::Full,
        state,
        memory: &reader,
        ram_size: 32 * 1024,
        memory_page_size: 16 * 1024,
        dirty: None,
    });
    // Assert the failure explicitly so a regression points at the cause
    // instead of surfacing only as a file-content mismatch below.
    assert!(res.is_err(), "save must fail for this state");
    assert_eq!(std::fs::read_to_string(&snap).unwrap(), "prior good state");
    assert_eq!(std::fs::read_to_string(&mem).unwrap(), "prior good mem");
}
}