/// Decodes a little-endian `u16` from the first 2 bytes of `bytes`.
///
/// Any bytes past the first 2 are ignored; `None` is returned when fewer than
/// 2 bytes are available.
#[inline]
pub fn le_u16(bytes: &[u8]) -> Option<u16> {
    let head: [u8; 2] = bytes.get(..2)?.try_into().ok()?;
    Some(u16::from_le_bytes(head))
}
/// Decodes a little-endian `u32` from the first 4 bytes of `bytes`.
///
/// Trailing bytes are ignored; a slice shorter than 4 bytes yields `None`.
#[inline]
pub fn le_u32(bytes: &[u8]) -> Option<u32> {
    bytes
        .get(..4)
        .and_then(|head| <[u8; 4]>::try_from(head).ok())
        .map(u32::from_le_bytes)
}
/// Decodes a little-endian `u64` from the first 8 bytes of `bytes`.
///
/// Only the leading 8 bytes are consulted; shorter input returns `None`.
#[inline]
pub fn le_u64(bytes: &[u8]) -> Option<u64> {
    const WIDTH: usize = 8;
    if bytes.len() < WIDTH {
        return None;
    }
    let mut raw = [0u8; WIDTH];
    raw.copy_from_slice(&bytes[..WIDTH]);
    Some(u64::from_le_bytes(raw))
}
/// Decodes a little-endian `u128` from the first 16 bytes of `bytes`.
///
/// Extra bytes are ignored; fewer than 16 bytes yields `None`.
#[inline]
pub fn le_u128(bytes: &[u8]) -> Option<u128> {
    let head = bytes.get(..16)?;
    <[u8; 16]>::try_from(head).ok().map(u128::from_le_bytes)
}
/// Clamps `count` to the number of `min_elem`-sized elements that fit in `region`,
/// i.e. `min(count, region / min_elem)`.
///
/// A `min_elem` of zero is treated as one so the division is always defined.
#[inline]
pub fn cap_count(count: usize, min_elem: usize, region: usize) -> usize {
    let per_elem = if min_elem == 0 { 1 } else { min_elem };
    let capacity = region / per_elem;
    std::cmp::min(count, capacity)
}
/// Returns `true` when `[ram_offset, ram_offset + memory_bytes)` lies entirely
/// inside a file of `file_len` bytes.
///
/// The check subtracts instead of adding, so no intermediate value can overflow
/// even for offsets or lengths near `u64::MAX`.
#[inline]
pub fn ram_region_within(file_len: u64, ram_offset: u64, memory_bytes: u64) -> bool {
    match file_len.checked_sub(ram_offset) {
        Some(remaining) => memory_bytes <= remaining,
        None => false,
    }
}
/// Maps `memory_bytes` of `file`, starting at byte `ram_offset`, with
/// `PROT_READ | PROT_WRITE` and `MAP_PRIVATE | extra_flags`, and returns the
/// base address of the mapping.
///
/// Because the mapping is `MAP_PRIVATE`, stores through the returned pointer are
/// not written back to `file`. The caller owns the mapping and is responsible for
/// releasing it with `munmap(ptr, memory_bytes)`.
///
/// # Errors
///
/// * `InvalidData` if the requested region extends past the end of `file`.
/// * `InvalidInput` if `ram_offset` does not fit in this target's `off_t`.
/// * The OS error from `mmap` otherwise (for example an offset that is not
///   page-aligned, or a zero length).
pub fn cow_map_ram(
    file: &std::fs::File,
    ram_offset: u64,
    memory_bytes: usize,
    extra_flags: libc::c_int,
) -> std::io::Result<*mut u8> {
    use std::os::fd::AsRawFd;
    let file_len = file.metadata()?.len();
    // `usize` -> `u64` is lossless on every target with pointers of 64 bits or fewer.
    if !ram_region_within(file_len, ram_offset, memory_bytes as u64) {
        return Err(std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            format!(
                "snapshot RAM region [{ram_offset}, {ram_offset}+{memory_bytes}) \
                 exceeds file length {file_len}"
            ),
        ));
    }
    // An `as` cast would silently wrap an offset that does not fit (e.g. with a
    // 32-bit `off_t`), mapping the wrong part of the file; refuse it instead.
    let offset = libc::off_t::try_from(ram_offset).map_err(|_| {
        std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            format!("snapshot RAM offset {ram_offset} does not fit in off_t"),
        )
    })?;
    // SAFETY: a null address hint lets the kernel choose where to place the
    // mapping, so no existing memory is replaced unless the caller passes
    // MAP_FIXED in `extra_flags` (their responsibility). The descriptor is valid
    // for the duration of the call because `file` is borrowed.
    let ptr = unsafe {
        libc::mmap(
            std::ptr::null_mut(),
            memory_bytes,
            libc::PROT_READ | libc::PROT_WRITE,
            libc::MAP_PRIVATE | extra_flags,
            file.as_raw_fd(),
            offset,
        )
    };
    if ptr == libc::MAP_FAILED {
        return Err(std::io::Error::last_os_error());
    }
    Ok(ptr as *mut u8)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn le_readers_reject_short_slices() {
        assert_eq!(le_u16(&[1, 2]), Some(0x0201));
        assert_eq!(le_u16(&[1]), None);
        assert_eq!(le_u32(&[1, 0, 0, 0]), Some(1));
        assert_eq!(le_u32(&[1, 0, 0]), None);
        assert_eq!(le_u64(&[0xff, 0, 0, 0, 0, 0, 0, 0]), Some(0xff));
        assert_eq!(le_u64(&[0; 7]), None);
        assert!(le_u128(&[7u8; 16]).is_some());
        assert_eq!(le_u128(&[7u8; 15]), None);
    }

    #[test]
    fn le_readers_read_only_the_leading_bytes() {
        assert_eq!(le_u32(&[1, 0, 0, 0, 0xaa, 0xbb]), Some(1));
    }

    #[test]
    fn cap_count_caps_to_what_the_region_can_hold() {
        assert_eq!(cap_count(0xFFFF_FFFF, 12, 1200), 100);
        assert_eq!(cap_count(5, 12, 1200), 5);
        assert_eq!(cap_count(7, 0, 1200), 7);
        assert_eq!(cap_count(7, 12, 0), 0);
    }

    #[test]
    fn ram_region_bounds_are_overflow_safe() {
        assert!(ram_region_within(100, 10, 90));
        assert!(ram_region_within(100, 100, 0));
        assert!(!ram_region_within(100, 10, 91));
        assert!(!ram_region_within(100, 101, 0));
        assert!(!ram_region_within(100, u64::MAX, 1));
        assert!(!ram_region_within(100, 0, u64::MAX));
    }

    /// Deletes the wrapped path when dropped, so the temp file is removed even
    /// when an assertion in the test panics before the end.
    struct RemoveOnDrop(std::path::PathBuf);

    impl Drop for RemoveOnDrop {
        fn drop(&mut self) {
            let _ = std::fs::remove_file(&self.0);
        }
    }

    #[test]
    fn cow_map_ram_rejects_region_past_eof() {
        use std::io::Write;
        let tmp = RemoveOnDrop(
            std::env::temp_dir().join(format!("smframe_cow_{}.bin", std::process::id())),
        );
        std::fs::File::create(&tmp.0)
            .unwrap()
            .write_all(&[0u8; 4096])
            .unwrap();
        let f = std::fs::File::open(&tmp.0).unwrap();

        let err = cow_map_ram(&f, 4096, 4096, 0).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);

        let ptr = cow_map_ram(&f, 0, 4096, 0).unwrap();
        assert!(!ptr.is_null());
        // SAFETY: `ptr` is the base of the 4096-byte readable mapping returned
        // just above, and it is unmapped exactly once, after its last use.
        unsafe {
            assert_eq!(*ptr, 0);
            assert_eq!(libc::munmap(ptr.cast::<libc::c_void>(), 4096), 0);
        }
    }
}
use crate::devices::virtio::mmio::MmioSnapshot;
/// Class of a snapshotted device.
///
/// The `u8` discriminant is the tag byte `DeviceRecord::write_to` emits
/// (`kind as u8`), so the explicit values are part of the snapshot format.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(u8)]
pub enum DeviceKind {
    Blk = 0,
    Vsock = 1,
    Volume = 2,
    VirtioFs = 3,
}

impl DeviceKind {
    /// Every variant, positioned at the index equal to its tag byte.
    const BY_TAG: [Self; 4] = [Self::Blk, Self::Vsock, Self::Volume, Self::VirtioFs];

    /// Classifies a device from its virtio device-type ID.
    ///
    /// ID 19 yields `Vsock` and ID 26 yields `VirtioFs`; every other ID falls back
    /// to `Blk`. `Volume` is never produced here.
    pub fn from_virtio_id(id: u32) -> Self {
        const VIRTIO_ID_VSOCK: u32 = 19;
        const VIRTIO_ID_FS: u32 = 26;
        if id == VIRTIO_ID_VSOCK {
            Self::Vsock
        } else if id == VIRTIO_ID_FS {
            Self::VirtioFs
        } else {
            Self::Blk
        }
    }

    /// Decodes a tag byte produced by `kind as u8`; unknown tags yield `None`.
    fn from_u8(v: u8) -> Option<Self> {
        Self::BY_TAG.get(usize::from(v)).copied()
    }
}
/// Host-side backing recorded for a snapshotted device.
///
/// `DeviceRecord::write_to` introduces each variant with a one-byte tag
/// (`None` = 0, `Disk` = 1, `Volume` = 2, `VirtioFs` = 3) followed by its fields
/// in declaration order. On read, strings longer than `1 << 16` bytes and blobs
/// larger than `16 << 20` bytes are rejected.
#[derive(Clone, Debug, PartialEq, Eq, Default)]
pub enum DeviceBacking {
    /// No backing is recorded; this is the `Default` variant.
    #[default]
    None,
    /// Backing for a disk device; presumably a host image file — confirm with callers.
    Disk {
        /// Host path of the backing.
        path: String,
        /// Backing size; presumably in bytes — NOTE(review): confirm units.
        size: u64,
    },
    /// Like `Disk`, plus the location where the volume is mounted.
    Volume {
        /// Host path of the backing.
        path: String,
        /// Backing size; presumably in bytes — NOTE(review): confirm units.
        size: u64,
        /// Mount point string; presumably a guest path (tests use "/data") — confirm.
        mount: String,
    },
    /// Backing for a virtio-fs share.
    VirtioFs {
        /// Share tag.
        tag: String,
        /// Mount point string; presumably a guest path — confirm.
        mount: String,
        /// Host directory being shared (per the field name).
        host_path: String,
        /// Presumably the guest-physical base address of the DAX window — confirm.
        dax_gpa: u64,
        /// Length of the DAX window; presumably bytes — confirm.
        dax_window_len: u64,
        /// Opaque backend state, stored as a length-prefixed blob.
        backend_state: Vec<u8>,
        /// Opaque DAX state, stored as a length-prefixed blob.
        dax_state: Vec<u8>,
    },
}
/// One device entry in a snapshot container: its kind, MMIO transport state
/// and host backing.
///
/// NOTE(review): `DeviceRecord::read_from` does not check that `kind` and the
/// `backing` variant agree (e.g. `Vsock` with a `Disk` backing is accepted).
#[derive(Clone, Debug)]
pub struct DeviceRecord {
    /// Device class; written as its `u8` discriminant.
    pub kind: DeviceKind,
    /// Snapshot of the device's MMIO transport, serialized by `MmioSnapshot::write_to`.
    pub mmio: MmioSnapshot,
    /// Host-side backing for the device.
    pub backing: DeviceBacking,
}
impl DeviceRecord {
pub fn write_to<W: std::io::Write>(&self, w: &mut W) -> std::io::Result<()> {
w.write_all(&[self.kind as u8])?;
self.mmio.write_to(w)?;
fn s<W: std::io::Write>(w: &mut W, s: &str) -> std::io::Result<()> {
w.write_all(&(s.len() as u32).to_le_bytes())?;
w.write_all(s.as_bytes())
}
fn blob<W: std::io::Write>(w: &mut W, b: &[u8]) -> std::io::Result<()> {
w.write_all(&(b.len() as u32).to_le_bytes())?;
w.write_all(b)
}
match &self.backing {
DeviceBacking::None => w.write_all(&[0u8])?,
DeviceBacking::Disk { path, size } => {
w.write_all(&[1u8])?;
s(w, path)?;
w.write_all(&size.to_le_bytes())?;
}
DeviceBacking::Volume { path, size, mount } => {
w.write_all(&[2u8])?;
s(w, path)?;
w.write_all(&size.to_le_bytes())?;
s(w, mount)?;
}
DeviceBacking::VirtioFs {
tag,
mount,
host_path,
dax_gpa,
dax_window_len,
backend_state,
dax_state,
} => {
w.write_all(&[3u8])?;
s(w, tag)?;
s(w, mount)?;
s(w, host_path)?;
w.write_all(&dax_gpa.to_le_bytes())?;
w.write_all(&dax_window_len.to_le_bytes())?;
blob(w, backend_state)?;
blob(w, dax_state)?;
}
}
Ok(())
}
pub fn read_from<R: std::io::Read>(r: &mut R) -> std::io::Result<DeviceRecord> {
fn b<const N: usize, R: std::io::Read>(r: &mut R) -> std::io::Result<[u8; N]> {
let mut x = [0u8; N];
r.read_exact(&mut x)?;
Ok(x)
}
fn err(m: &str) -> std::io::Error {
std::io::Error::new(std::io::ErrorKind::InvalidData, m.to_string())
}
fn s<R: std::io::Read>(r: &mut R) -> std::io::Result<String> {
let n = u32::from_le_bytes(b::<4, _>(r)?) as usize;
if n > 1 << 16 {
return Err(err("device backing string too long"));
}
let mut buf = vec![0u8; n];
r.read_exact(&mut buf)?;
String::from_utf8(buf).map_err(|_| err("device backing string not utf8"))
}
fn blob<R: std::io::Read>(r: &mut R) -> std::io::Result<Vec<u8>> {
let n = u32::from_le_bytes(b::<4, _>(r)?) as usize;
if n > 16 << 20 {
return Err(err("device backing blob too large"));
}
let mut buf = vec![0u8; n];
r.read_exact(&mut buf)?;
Ok(buf)
}
let kind = DeviceKind::from_u8(b::<1, _>(r)?[0]).ok_or_else(|| err("bad device kind"))?;
let mmio = MmioSnapshot::read_from(r)?;
let backing = match b::<1, _>(r)?[0] {
0 => DeviceBacking::None,
1 => DeviceBacking::Disk {
path: s(r)?,
size: u64::from_le_bytes(b::<8, _>(r)?),
},
2 => DeviceBacking::Volume {
path: s(r)?,
size: u64::from_le_bytes(b::<8, _>(r)?),
mount: s(r)?,
},
3 => DeviceBacking::VirtioFs {
tag: s(r)?,
mount: s(r)?,
host_path: s(r)?,
dax_gpa: u64::from_le_bytes(b::<8, _>(r)?),
dax_window_len: u64::from_le_bytes(b::<8, _>(r)?),
backend_state: blob(r)?,
dax_state: blob(r)?,
},
_ => return Err(err("bad device backing tag")),
};
Ok(DeviceRecord {
kind,
mmio,
backing,
})
}
}
use crate::devices::virtio::vsock::muxer::TsiListenerSnapshot;
/// Machine metadata stored in a snapshot container, serialized by
/// `ContainerMeta::write_container` and parsed by `ContainerMeta::read_container`
/// in field declaration order.
#[derive(Clone, Debug)]
pub struct ContainerMeta {
    /// vCPU count as stored. `read_container` does not check it against `vcpu_blobs.len()`.
    pub num_cpus: u8,
    /// Guest memory size; presumably bytes (tests use `512 << 20`) — confirm.
    pub mem_size: u64,
    /// Six raw bytes, presumably COM1 serial-port state; opaque to this module.
    pub com1: [u8; 6],
    /// Opaque clock value, stored little-endian; semantics defined by the caller.
    pub clock_host_ticks: u64,
    /// Opaque clock reference, stored little-endian; semantics defined by the caller.
    pub clock_ref: u64,
    /// Opaque blob, presumably interrupt-controller state. Capped at `CONTAINER_BLOB_MAX` on read.
    pub intc_blob: Vec<u8>,
    /// Opaque per-vCPU blobs, each capped at `CONTAINER_BLOB_MAX` on read.
    pub vcpu_blobs: Vec<Vec<u8>>,
    /// Device records, serialized with `DeviceRecord::write_to`.
    pub devices: Vec<DeviceRecord>,
    /// Optional 32-byte TSI token. Its presence is encoded as a 0/1 byte.
    pub tsi_token: Option<[u8; 32]>,
    /// vsock listener snapshots, serialized with `TsiListenerSnapshot::write_to`.
    pub vsock_listeners: Vec<TsiListenerSnapshot>,
}
/// Largest single length-prefixed blob (16 MiB) `read_container` accepts; the
/// length is checked before the buffer is allocated.
const CONTAINER_BLOB_MAX: usize = 16 * 1024 * 1024;
/// Ceiling on the vCPU-blob capacity reserved up front from a count read off the stream.
const MAX_VCPUS_HINT: usize = 1 << 10;
/// Ceiling on the device-record capacity reserved up front from a count read off the stream.
const MAX_DEVS_HINT: usize = 1 << 12;
/// Ceiling on the listener capacity reserved up front from a count read off the stream.
const MAX_LISTENERS_HINT: usize = 65_536;
impl ContainerMeta {
    /// Serializes the container metadata to `w`.
    ///
    /// Layout, in order: `num_cpus` (`u8`), `mem_size`, `com1` (6 raw bytes),
    /// `clock_host_ticks`, `clock_ref`, `intc_blob`, then:
    ///
    /// * a `u32` vCPU-blob count followed by each blob;
    /// * a `u32` device count followed by each [`DeviceRecord`];
    /// * a `tsi_token` presence byte (0/1) followed by 32 bytes when present;
    /// * a `u32` listener count followed by each listener.
    ///
    /// Integers are little-endian. Blobs are a `u32` length followed by the bytes.
    ///
    /// # Errors
    ///
    /// Returns `InvalidInput` in two cases:
    ///
    /// * a blob exceeds `CONTAINER_BLOB_MAX`, the same limit `read_container`
    ///   enforces;
    /// * an element count does not fit in `u32`.
    ///
    /// Otherwise, errors from `w` and from the nested record writers are
    /// propagated. On error, `w` may already hold a partial container.
    pub fn write_container<W: std::io::Write>(&self, w: &mut W) -> std::io::Result<()> {
        /// Writes `n` as a little-endian `u32`, refusing values that would truncate.
        fn count<W: std::io::Write>(w: &mut W, n: usize, what: &str) -> std::io::Result<()> {
            let n32 = u32::try_from(n).map_err(|_| {
                std::io::Error::new(
                    std::io::ErrorKind::InvalidInput,
                    format!("container {what} {n} does not fit in u32"),
                )
            })?;
            w.write_all(&n32.to_le_bytes())
        }
        /// Writes `b` behind a `u32` length prefix, refusing blobs the reader would reject.
        fn blob<W: std::io::Write>(w: &mut W, b: &[u8]) -> std::io::Result<()> {
            if b.len() > CONTAINER_BLOB_MAX {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::InvalidInput,
                    format!(
                        "container blob of {} bytes exceeds limit {CONTAINER_BLOB_MAX}",
                        b.len()
                    ),
                ));
            }
            count(w, b.len(), "blob length")?;
            w.write_all(b)
        }

        w.write_all(&[self.num_cpus])?;
        w.write_all(&self.mem_size.to_le_bytes())?;
        w.write_all(&self.com1)?;
        w.write_all(&self.clock_host_ticks.to_le_bytes())?;
        w.write_all(&self.clock_ref.to_le_bytes())?;
        blob(w, &self.intc_blob)?;
        count(w, self.vcpu_blobs.len(), "vCPU count")?;
        for v in &self.vcpu_blobs {
            blob(w, v)?;
        }
        count(w, self.devices.len(), "device count")?;
        for d in &self.devices {
            d.write_to(w)?;
        }
        match &self.tsi_token {
            Some(t) => {
                w.write_all(&[1u8])?;
                w.write_all(t)?;
            }
            None => w.write_all(&[0u8])?,
        }
        count(w, self.vsock_listeners.len(), "vsock listener count")?;
        for l in &self.vsock_listeners {
            l.write_to(w)?;
        }
        Ok(())
    }

    /// Parses container metadata in the layout produced by
    /// [`ContainerMeta::write_container`].
    ///
    /// Blob lengths are checked against `CONTAINER_BLOB_MAX` before allocating.
    /// The up-front `Vec` capacities derived from element counts are clamped by
    /// the `MAX_*_HINT` constants, so a hostile count cannot force a huge
    /// reservation.
    ///
    /// # Errors
    ///
    /// * `InvalidData` for an oversize blob or a TSI presence byte other than 0/1.
    /// * `UnexpectedEof` if the stream ends early.
    /// * Errors from `DeviceRecord::read_from` and `TsiListenerSnapshot::read_from`
    ///   are propagated.
    pub fn read_container<R: std::io::Read>(r: &mut R) -> std::io::Result<ContainerMeta> {
        /// Reads exactly `N` bytes.
        fn b<const N: usize, R: std::io::Read>(r: &mut R) -> std::io::Result<[u8; N]> {
            let mut x = [0u8; N];
            r.read_exact(&mut x)?;
            Ok(x)
        }
        fn err(m: &str) -> std::io::Error {
            std::io::Error::new(std::io::ErrorKind::InvalidData, m.to_string())
        }
        /// Reads a `u32`-length-prefixed blob, rejecting oversize lengths before allocating.
        fn blob<R: std::io::Read>(r: &mut R) -> std::io::Result<Vec<u8>> {
            let n = u32::from_le_bytes(b::<4, _>(r)?) as usize;
            if n > CONTAINER_BLOB_MAX {
                return Err(err("container blob too large"));
            }
            let mut buf = vec![0u8; n];
            r.read_exact(&mut buf)?;
            Ok(buf)
        }

        let num_cpus = b::<1, _>(r)?[0];
        let mem_size = u64::from_le_bytes(b::<8, _>(r)?);
        let com1 = b::<6, _>(r)?;
        let clock_host_ticks = u64::from_le_bytes(b::<8, _>(r)?);
        let clock_ref = u64::from_le_bytes(b::<8, _>(r)?);
        let intc_blob = blob(r)?;

        let n_vcpus = u32::from_le_bytes(b::<4, _>(r)?) as usize;
        let mut vcpu_blobs = Vec::with_capacity(n_vcpus.min(MAX_VCPUS_HINT));
        for _ in 0..n_vcpus {
            vcpu_blobs.push(blob(r)?);
        }

        let n_devices = u32::from_le_bytes(b::<4, _>(r)?) as usize;
        let mut devices = Vec::with_capacity(n_devices.min(MAX_DEVS_HINT));
        for _ in 0..n_devices {
            devices.push(DeviceRecord::read_from(r)?);
        }

        let tsi_token = match b::<1, _>(r)?[0] {
            0 => None,
            1 => Some(b::<32, _>(r)?),
            _ => return Err(err("bad tsi_token presence byte")),
        };

        let n_listeners = u32::from_le_bytes(b::<4, _>(r)?) as usize;
        let mut vsock_listeners = Vec::with_capacity(n_listeners.min(MAX_LISTENERS_HINT));
        for _ in 0..n_listeners {
            vsock_listeners.push(TsiListenerSnapshot::read_from(r)?);
        }

        Ok(ContainerMeta {
            num_cpus,
            mem_size,
            com1,
            clock_host_ticks,
            clock_ref,
            intc_blob,
            vcpu_blobs,
            devices,
            tsi_token,
            vsock_listeners,
        })
    }
}
#[cfg(test)]
mod container_tests {
    use super::*;
    use crate::devices::virtio::mmio::{MmioSnapshot, QueueSnapshot};

    /// Builds an MMIO snapshot with `nq` queues whose fields are derived from the
    /// queue index, so a reordered or mangled queue is caught after a round trip.
    fn mmio(features: u32, nq: usize) -> MmioSnapshot {
        MmioSnapshot {
            driver_features: [features, features ^ 0xdead_beef],
            status: 0xf,
            interrupt_status: 1,
            queues: (0..nq)
                .map(|i| QueueSnapshot {
                    size: 256,
                    ready: i % 2 == 0,
                    desc_table: 0x1000 * (i as u64 + 1),
                    avail_ring: 0x2000 * (i as u64 + 1),
                    used_ring: 0x3000 * (i as u64 + 1),
                    last_avail_idx: i as u16,
                    next_used_idx: (i as u16) + 7,
                })
                .collect(),
        }
    }

    /// Builds a container with `devices`, `listeners` vsock listeners, an optional
    /// TSI token, and `vcpus` vCPU blobs of distinct lengths and contents.
    fn sample(
        devices: Vec<DeviceRecord>,
        listeners: usize,
        token: bool,
        vcpus: usize,
    ) -> ContainerMeta {
        ContainerMeta {
            num_cpus: vcpus as u8,
            mem_size: 512 << 20,
            com1: [1, 2, 3, 4, 5, 6],
            clock_host_ticks: 0x1122_3344_5566_7788,
            clock_ref: 0x99aa_bbcc_ddee_ff00,
            intc_blob: (0..333u32).map(|i| i as u8).collect(),
            vcpu_blobs: (0..vcpus).map(|i| vec![i as u8; 64 + i]).collect(),
            devices,
            tsi_token: token.then_some([0x5a; 32]),
            vsock_listeners: (0..listeners)
                .map(|i| TsiListenerSnapshot {
                    cid: 3,
                    peer_port: 1000 + i as u32,
                    vm_port: 80,
                    family: 2,
                    socktype: 1,
                    inet_port: if i % 2 == 0 {
                        Some(8080 + i as u16)
                    } else {
                        None
                    },
                })
                .collect(),
        }
    }

    /// Compares two containers field by field. MMIO snapshots are compared through
    /// their serialized bytes, so this helper does not rely on `MmioSnapshot:
    /// PartialEq`.
    fn assert_eq_meta(a: &ContainerMeta, b: &ContainerMeta) {
        assert_eq!(a.num_cpus, b.num_cpus);
        assert_eq!(a.mem_size, b.mem_size);
        assert_eq!(a.com1, b.com1);
        assert_eq!(a.clock_host_ticks, b.clock_host_ticks);
        assert_eq!(a.clock_ref, b.clock_ref);
        assert_eq!(a.intc_blob, b.intc_blob);
        assert_eq!(a.vcpu_blobs, b.vcpu_blobs);
        assert_eq!(a.tsi_token, b.tsi_token);
        assert_eq!(a.devices.len(), b.devices.len());
        for (x, y) in a.devices.iter().zip(&b.devices) {
            assert_eq!(x.kind, y.kind);
            assert_eq!(x.backing, y.backing);
            let (mut xb, mut yb) = (Vec::new(), Vec::new());
            x.mmio.write_to(&mut xb).unwrap();
            y.mmio.write_to(&mut yb).unwrap();
            assert_eq!(xb, yb);
        }
        assert_eq!(a.vsock_listeners.len(), b.vsock_listeners.len());
        for (x, y) in a.vsock_listeners.iter().zip(&b.vsock_listeners) {
            assert_eq!(x.cid, y.cid);
            assert_eq!(x.peer_port, y.peer_port);
            assert_eq!(x.vm_port, y.vm_port);
            assert_eq!(x.family, y.family);
            assert_eq!(x.socktype, y.socktype);
            assert_eq!(x.inet_port, y.inet_port);
        }
    }

    /// Writes `m`, reads it back, and checks that the whole buffer was consumed
    /// and every field survived.
    fn round_trip(m: &ContainerMeta) {
        let mut buf = Vec::new();
        m.write_container(&mut buf).unwrap();
        let mut cur = std::io::Cursor::new(&buf);
        let back = ContainerMeta::read_container(&mut cur).unwrap();
        assert_eq!(cur.position() as usize, buf.len());
        assert_eq_meta(m, &back);
    }

    /// Fixed-size container prefix: `num_cpus`, then a zeroed `mem_size`, `com1`,
    /// `clock_host_ticks` and `clock_ref`. Stops just before `intc_blob`'s length.
    fn fixed_header(num_cpus: u8) -> Vec<u8> {
        let mut buf = vec![num_cpus];
        buf.extend_from_slice(&0u64.to_le_bytes()); // mem_size
        buf.extend_from_slice(&[0u8; 6]); // com1
        buf.extend_from_slice(&0u64.to_le_bytes()); // clock_host_ticks
        buf.extend_from_slice(&0u64.to_le_bytes()); // clock_ref
        buf
    }

    /// Header followed by an empty `intc_blob`, zero vCPU blobs and a device count of one,
    /// positioned at the start of the first device record.
    fn one_device_prefix() -> Vec<u8> {
        let mut buf = fixed_header(0);
        buf.extend_from_slice(&0u32.to_le_bytes()); // intc_blob: empty
        buf.extend_from_slice(&0u32.to_le_bytes()); // no vCPU blobs
        buf.extend_from_slice(&1u32.to_le_bytes()); // one device record
        buf
    }

    #[test]
    fn empty_container_round_trips() {
        round_trip(&sample(vec![], 0, false, 1));
    }

    #[test]
    fn full_container_round_trips() {
        let devices = vec![
            DeviceRecord {
                kind: DeviceKind::Blk,
                mmio: mmio(0x1, 2),
                backing: DeviceBacking::Disk {
                    path: "/var/lib/sm/root.img".into(),
                    size: 1 << 30,
                },
            },
            DeviceRecord {
                kind: DeviceKind::Vsock,
                mmio: mmio(0x2, 3),
                backing: DeviceBacking::None,
            },
            DeviceRecord {
                kind: DeviceKind::Volume,
                mmio: mmio(0x3, 1),
                backing: DeviceBacking::Volume {
                    path: "/var/lib/sm/vol-b.img".into(),
                    size: 4 << 20,
                    mount: "/data".into(),
                },
            },
            DeviceRecord {
                kind: DeviceKind::VirtioFs,
                mmio: mmio(0x4, 2),
                backing: DeviceBacking::VirtioFs {
                    tag: "shared".into(),
                    mount: "/mnt/shared".into(),
                    host_path: "/srv/share".into(),
                    dax_gpa: 0x4000_0000,
                    dax_window_len: 0x800_0000,
                    backend_state: vec![0xab; 200],
                    dax_state: vec![0xcd; 80],
                },
            },
        ];
        round_trip(&sample(devices, 3, true, 4));
    }

    #[test]
    fn truncated_stream_errors_not_panics() {
        let mut buf = Vec::new();
        sample(
            vec![DeviceRecord {
                kind: DeviceKind::Blk,
                mmio: mmio(1, 1),
                backing: DeviceBacking::None,
            }],
            2,
            true,
            2,
        )
        .write_container(&mut buf)
        .unwrap();
        for cut in 0..buf.len() {
            let mut cur = std::io::Cursor::new(&buf[..cut]);
            assert!(
                ContainerMeta::read_container(&mut cur).is_err(),
                "prefix {cut}"
            );
        }
    }

    #[test]
    fn hostile_blob_length_is_rejected() {
        let mut buf = fixed_header(1);
        buf.extend_from_slice(&u32::MAX.to_le_bytes()); // intc_blob length, far past the cap
        let mut cur = std::io::Cursor::new(&buf);
        let err = ContainerMeta::read_container(&mut cur).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    #[test]
    fn hostile_device_string_length_is_rejected() {
        let mut buf = one_device_prefix();
        buf.push(DeviceKind::Blk as u8);
        mmio(0, 0).write_to(&mut buf).unwrap();
        buf.push(1u8); // backing tag: Disk
        buf.extend_from_slice(&u32::MAX.to_le_bytes()); // path length, far past the cap
        let err = ContainerMeta::read_container(&mut std::io::Cursor::new(&buf)).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    #[test]
    fn unknown_device_kind_is_rejected() {
        let mut buf = one_device_prefix();
        buf.push(0xff); // not a DeviceKind tag
        let err = ContainerMeta::read_container(&mut std::io::Cursor::new(&buf)).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    #[test]
    fn bad_tsi_presence_byte_is_rejected() {
        let mut buf = Vec::new();
        sample(vec![], 0, false, 0)
            .write_container(&mut buf)
            .unwrap();
        // With no listeners, the stream ends with the presence byte followed by
        // the 4-byte listener count.
        let presence = buf.len() - 5;
        assert_eq!(buf[presence], 0);
        buf[presence] = 2;
        let err = ContainerMeta::read_container(&mut std::io::Cursor::new(&buf)).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }
}