#![allow(dead_code)]
use std::sync::{Arc, Mutex};
use super::queue::{GuestMem, Queue};
use super::VirtioDevice;
use crate::devices::mmio_bus::MmioDevice;
/// Value of the `MagicValue` register (offset 0x000): the ASCII bytes "virt" read as a
/// little-endian `u32` (0x7472_6976).
const MAGIC: u32 = u32::from_le_bytes(*b"virt");
/// Value of the `Version` register (offset 0x004).
const VERSION: u32 = 2;
/// Transport-visible state of one virtqueue, as captured by `MmioVirtio::capture_state`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct QueueSnapshot {
    /// Queue size programmed by the driver.
    pub size: u16,
    /// Whether the driver marked the queue ready.
    pub ready: bool,
    /// Descriptor-table address programmed by the driver.
    pub desc_table: u64,
    /// Available (driver) ring address programmed by the driver.
    pub avail_ring: u64,
    /// Used (device) ring address programmed by the driver.
    pub used_ring: u64,
    /// Ring cursor copied from the queue at capture time.
    pub last_avail_idx: u16,
    /// Ring cursor copied from the queue at capture time.
    pub next_used_idx: u16,
}

/// Serializable virtio-mmio transport state.
///
/// Wire format (all integers little-endian, no padding):
/// `driver_features[0]: u32`, `driver_features[1]: u32`, `status: u32`,
/// `interrupt_status: u32`, queue count `u32`, then per queue (31 bytes):
/// `size: u16`, `ready: u8` (non-zero = true), `desc_table: u64`, `avail_ring: u64`,
/// `used_ring: u64`, `last_avail_idx: u16`, `next_used_idx: u16`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct MmioSnapshot {
    pub driver_features: [u32; 2],
    pub status: u32,
    pub interrupt_status: u32,
    pub queues: Vec<QueueSnapshot>,
}

impl MmioSnapshot {
    /// Largest queue count accepted on either side of the codec.
    ///
    /// Queue notifications are 16-bit indices in this transport, so more queues than this
    /// cannot be addressed. The cap also stops a corrupt stream from driving an unbounded
    /// read loop.
    const MAX_SNAPSHOT_QUEUES: u32 = 1 << 16;

    /// Serializes `self` in the wire format documented on [`MmioSnapshot`].
    ///
    /// # Errors
    /// Returns `InvalidInput` (before writing anything) when there are more than
    /// `MAX_SNAPSHOT_QUEUES` queues. Otherwise propagates any error from `w`.
    pub fn write_to<W: std::io::Write>(&self, w: &mut W) -> std::io::Result<()> {
        let count = u32::try_from(self.queues.len())
            .ok()
            .filter(|&n| n <= Self::MAX_SNAPSHOT_QUEUES)
            .ok_or_else(|| {
                std::io::Error::new(
                    std::io::ErrorKind::InvalidInput,
                    format!(
                        "snapshot has {} queues, more than {}",
                        self.queues.len(),
                        Self::MAX_SNAPSHOT_QUEUES
                    ),
                )
            })?;
        let header = [
            self.driver_features[0],
            self.driver_features[1],
            self.status,
            self.interrupt_status,
            count,
        ];
        for word in header {
            w.write_all(&word.to_le_bytes())?;
        }
        for q in &self.queues {
            w.write_all(&q.size.to_le_bytes())?;
            w.write_all(&[u8::from(q.ready)])?;
            w.write_all(&q.desc_table.to_le_bytes())?;
            w.write_all(&q.avail_ring.to_le_bytes())?;
            w.write_all(&q.used_ring.to_le_bytes())?;
            w.write_all(&q.last_avail_idx.to_le_bytes())?;
            w.write_all(&q.next_used_idx.to_le_bytes())?;
        }
        Ok(())
    }

    /// Deserializes a snapshot written by [`MmioSnapshot::write_to`].
    ///
    /// # Errors
    /// Returns `InvalidData` if the queue count exceeds `MAX_SNAPSHOT_QUEUES`. Returns
    /// `UnexpectedEof` (from `read_exact`) if the stream is truncated. Other reader errors are
    /// propagated.
    pub fn read_from<R: std::io::Read>(r: &mut R) -> std::io::Result<MmioSnapshot> {
        fn read_array<const N: usize, R: std::io::Read>(r: &mut R) -> std::io::Result<[u8; N]> {
            let mut buf = [0u8; N];
            r.read_exact(&mut buf)?;
            Ok(buf)
        }
        fn read_u16<R: std::io::Read>(r: &mut R) -> std::io::Result<u16> {
            read_array::<2, R>(r).map(u16::from_le_bytes)
        }
        fn read_u32<R: std::io::Read>(r: &mut R) -> std::io::Result<u32> {
            read_array::<4, R>(r).map(u32::from_le_bytes)
        }
        fn read_u64<R: std::io::Read>(r: &mut R) -> std::io::Result<u64> {
            read_array::<8, R>(r).map(u64::from_le_bytes)
        }

        let driver_features = [read_u32(r)?, read_u32(r)?];
        let status = read_u32(r)?;
        let interrupt_status = read_u32(r)?;
        let nq = read_u32(r)?;
        if nq > Self::MAX_SNAPSHOT_QUEUES {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                format!(
                    "snapshot queue count {nq} exceeds {}",
                    Self::MAX_SNAPSHOT_QUEUES
                ),
            ));
        }
        // The count is still untrusted, so only pre-reserve a small amount.
        let mut queues = Vec::with_capacity(nq.min(64) as usize);
        for _ in 0..nq {
            // Explicit bindings keep the read order identical to the wire order.
            let size = read_u16(r)?;
            let ready = read_array::<1, _>(r)?[0] != 0;
            let desc_table = read_u64(r)?;
            let avail_ring = read_u64(r)?;
            let used_ring = read_u64(r)?;
            let last_avail_idx = read_u16(r)?;
            let next_used_idx = read_u16(r)?;
            queues.push(QueueSnapshot {
                size,
                ready,
                desc_table,
                avail_ring,
                used_ring,
                last_avail_idx,
                next_used_idx,
            });
        }
        Ok(MmioSnapshot {
            driver_features,
            status,
            interrupt_status,
            queues,
        })
    }
}
/// Mutable transport register state of an `MmioVirtio`, guarded by its `state` mutex.
struct State {
    /// Written via register 0x014; selects the device-feature bank exposed at 0x010.
    device_features_sel: u32,
    /// Driver-acknowledged feature bits as two 32-bit banks, written via 0x020 at the bank
    /// chosen by `driver_features_sel`.
    driver_features: [u32; 2],
    /// Written via register 0x024.
    driver_features_sel: u32,
    /// Written via 0x030; index into `queues` used by the per-queue registers. Out-of-range
    /// values make those register accesses no-ops.
    queue_sel: u32,
    /// Device status register (0x070).
    status: u32,
    /// Pending interrupt bits: read at 0x060, cleared by writing the bits to 0x064.
    /// 0x1 is latched for used-buffer notifications, 0x2 for configuration changes.
    interrupt_status: u32,
    /// Transport-side copy of each virtqueue's registers, one per `dev.num_queues()`.
    queues: Vec<Queue>,
    /// Set when the queues have been handed to `dev.activate`.
    activated: bool,
    /// Written via 0x0ac; selects the shared-memory region reported at 0x0b0..=0x0bc.
    shm_sel: u32,
    /// Callback fired after an interrupt bit is latched; callers release the lock first.
    irq_raise: Arc<dyn Fn() + Send + Sync>,
}
/// virtio-mmio transport exposing a [`VirtioDevice`] through a 0x200-byte register window.
pub struct MmioVirtio {
    /// Backend device. Identity, features, config space, queue notifications and activation
    /// are forwarded to it.
    dev: Arc<dyn VirtioDevice>,
    /// Transport register state. Every MMIO access and every interrupt latch takes this lock.
    state: Mutex<State>,
}
impl MmioVirtio {
    /// InterruptStatus bit latched for a used-buffer notification.
    const INT_USED_BUFFER: u32 = 0x1;
    /// InterruptStatus bit latched for a configuration-change notification.
    const INT_CONFIG_CHANGE: u32 = 0x2;

    /// Creates a transport for `dev`.
    ///
    /// One transport-side [`Queue`] is created per `dev.num_queues()`, each holding a clone of
    /// `mem`. `irq_raise` is the callback fired after an interrupt-status bit has been latched;
    /// it is invoked with the internal state lock released.
    pub fn new(
        dev: Arc<dyn VirtioDevice>,
        mem: GuestMem,
        irq_raise: Arc<dyn Fn() + Send + Sync>,
    ) -> Self {
        let queues = (0..dev.num_queues())
            .map(|_| Queue::new(mem.clone()))
            .collect();
        Self {
            dev,
            state: Mutex::new(State {
                device_features_sel: 0,
                driver_features: [0; 2],
                driver_features_sel: 0,
                queue_sel: 0,
                status: 0,
                interrupt_status: 0,
                queues,
                activated: false,
                shm_sel: 0,
                irq_raise,
            }),
        }
    }

    /// Captures the driver-visible transport state.
    ///
    /// For each transport queue, the device's live copy from `dev.snapshot_queues()` is used
    /// when it has an entry at that index. Otherwise the transport's own register copy is used.
    pub fn capture_state(&self) -> MmioSnapshot {
        // Query the device *before* locking `state`. The IRQ closures below lock `state` from
        // device context, and the write/restore paths already drop the guard before calling into
        // the device. Holding `state` across this call would invert that lock order.
        let live = self.dev.snapshot_queues();
        let st = self.state.lock().unwrap();
        let queues: Vec<QueueSnapshot> = st
            .queues
            .iter()
            .enumerate()
            .map(|(i, own)| {
                let q = live.get(i).unwrap_or(own);
                QueueSnapshot {
                    size: q.size,
                    ready: q.ready,
                    desc_table: q.desc_table,
                    avail_ring: q.avail_ring,
                    used_ring: q.used_ring,
                    last_avail_idx: q.last_avail_idx,
                    next_used_idx: q.next_used_idx,
                }
            })
            .collect();
        MmioSnapshot {
            driver_features: st.driver_features,
            status: st.status,
            interrupt_status: st.interrupt_status,
            queues,
        }
    }

    /// Applies `snap` to the transport registers and queues.
    ///
    /// Snapshot queues beyond this transport's queue count are ignored. Transport queues the
    /// snapshot does not cover keep their current values. Selector registers are not part of
    /// the snapshot and are left unchanged.
    ///
    /// If the snapshot's status has `STATUS_DRIVER_OK`, the transport is marked activated and
    /// `dev.activate` receives a copy of the queues. This happens even if it was already
    /// activated, unlike the status-register write path, which activates at most once.
    /// NOTE(review): confirm this is only called on a not-yet-activated device.
    pub fn restore_state(&self, snap: &MmioSnapshot) {
        let mut st = self.state.lock().unwrap();
        st.driver_features = snap.driver_features;
        st.status = snap.status;
        st.interrupt_status = snap.interrupt_status;
        for (q, qs) in st.queues.iter_mut().zip(&snap.queues) {
            q.size = qs.size;
            q.ready = qs.ready;
            q.desc_table = qs.desc_table;
            q.avail_ring = qs.avail_ring;
            q.used_ring = qs.used_ring;
            q.last_avail_idx = qs.last_avail_idx;
            q.next_used_idx = qs.next_used_idx;
        }
        if snap.status & super::STATUS_DRIVER_OK != 0 {
            st.activated = true;
            let queues = st.queues.clone();
            // Release the lock before calling into the device.
            drop(st);
            self.dev.activate(queues);
        }
    }

    /// Returns a callback that latches the configuration-change interrupt bit and raises the IRQ.
    ///
    /// The callback holds only a weak reference and becomes a no-op once the transport is dropped.
    pub fn make_config_change_irq(self: &Arc<Self>) -> Arc<dyn Fn() + Send + Sync> {
        self.make_interrupt_raiser(Self::INT_CONFIG_CHANGE)
    }

    /// Virtio device ID reported by the backend device.
    pub fn device_id(&self) -> u32 {
        self.dev.device_id()
    }

    /// Returns a callback that latches the used-buffer interrupt bit and raises the IRQ.
    ///
    /// The callback holds only a weak reference and becomes a no-op once the transport is dropped.
    pub fn make_used_buffer_irq(self: &Arc<Self>) -> Arc<dyn Fn() + Send + Sync> {
        self.make_interrupt_raiser(Self::INT_USED_BUFFER)
    }

    /// Builds a weak-referencing callback that latches `bit` and fires `irq_raise`.
    fn make_interrupt_raiser(self: &Arc<Self>, bit: u32) -> Arc<dyn Fn() + Send + Sync> {
        // Weak, so the returned closure does not keep the transport alive.
        let weak = Arc::downgrade(self);
        Arc::new(move || {
            if let Some(me) = weak.upgrade() {
                me.latch_interrupt(bit);
            }
        })
    }

    /// ORs `bit` into the interrupt status, then fires `irq_raise` with the lock released.
    fn latch_interrupt(&self, bit: u32) {
        let raise = {
            let mut st = self.state.lock().unwrap();
            st.interrupt_status |= bit;
            Arc::clone(&st.irq_raise)
        };
        raise();
    }
}
impl MmioDevice for MmioVirtio {
    /// Services a guest load from the virtio-mmio register window.
    ///
    /// Offsets not handled below read as 0. Offsets from 0x100 upward index `dev.config()`.
    /// Those reads honour the access width (1, 2 or 4 bytes; any other size is treated as 4),
    /// are little-endian, and return 0 when the access would extend past the config bytes.
    fn read(&self, offset: u64, size: u8) -> u64 {
        let st = self.state.lock().unwrap();
        let v: u32 = match offset {
            0x000 => MAGIC,
            0x004 => VERSION,
            0x008 => self.dev.device_id(),
            0x00c => self.dev.vendor_id(),
            0x010 => {
                // Only bank 0 (bits 0..32) and bank 1 (bits 32..64) exist. Other selectors read
                // as 0 instead of aliasing the high word.
                let f = self.dev.features();
                match st.device_features_sel {
                    0 => f as u32,
                    1 => (f >> 32) as u32,
                    _ => 0,
                }
            }
            0x034 => self.dev.queue_max_size() as u32,
            0x038 => st
                .queues
                .get(st.queue_sel as usize)
                .map_or(0, |q| u32::from(q.size)),
            0x044 => st
                .queues
                .get(st.queue_sel as usize)
                .map_or(0, |q| u32::from(q.ready)),
            0x060 => st.interrupt_status,
            0x070 => st.status,
            0x0b0 | 0x0b4 | 0x0b8 | 0x0bc => {
                // Compare at full selector width. Truncating to u8 would make selector 256
                // alias region 0.
                let sel = st.shm_sel;
                // An unknown region reports base 0 and an all-ones length.
                let (base, len) = self
                    .dev
                    .shm_regions()
                    .into_iter()
                    .find(|r| u32::from(r.id) == sel)
                    .map_or((0u64, u64::MAX), |r| (r.gpa, r.len));
                match offset {
                    0x0b0 => len as u32,
                    0x0b4 => (len >> 32) as u32,
                    0x0b8 => base as u32,
                    0x0bc => (base >> 32) as u32,
                    _ => unreachable!(),
                }
            }
            0x100.. => {
                let cfg = self.dev.config();
                let off = (offset - 0x100) as usize;
                // Honour narrow accesses so the last 1-3 config bytes are still readable.
                let width = match size {
                    1 | 2 | 4 => usize::from(size),
                    _ => 4,
                };
                off.checked_add(width)
                    .and_then(|end| cfg.get(off..end))
                    .map(|bytes| {
                        let mut word = [0u8; 4];
                        word[..width].copy_from_slice(bytes);
                        u32::from_le_bytes(word)
                    })
                    .unwrap_or(0)
            }
            _ => 0,
        };
        u64::from(v)
    }

    /// Services a guest store to the virtio-mmio register window.
    ///
    /// Only the low 32 bits of `value` are used. Writes to unknown offsets are ignored.
    /// Device calls (`notify`, `activate`, `config_write`) happen after the state lock is released.
    fn write(&self, offset: u64, value: u64, _size: u8) {
        let mut st = self.state.lock().unwrap();
        let v32 = value as u32;
        match offset {
            0x014 => st.device_features_sel = v32,
            0x020 => {
                // Only banks 0 and 1 exist. Ignore out-of-range selectors instead of masking
                // them onto an existing bank and clobbering acknowledged features.
                let bank = st.driver_features_sel as usize;
                if let Some(word) = st.driver_features.get_mut(bank) {
                    *word = v32;
                }
            }
            0x024 => st.driver_features_sel = v32,
            0x030 => st.queue_sel = v32,
            0x038 => {
                let sel = st.queue_sel as usize;
                if let Some(q) = st.queues.get_mut(sel) {
                    q.size = v32 as u16;
                }
            }
            0x044 => {
                let sel = st.queue_sel as usize;
                if let Some(q) = st.queues.get_mut(sel) {
                    q.ready = v32 != 0;
                }
            }
            0x050 => {
                drop(st);
                self.dev.notify(v32 as u16);
            }
            0x064 => st.interrupt_status &= !v32,
            0x0ac => st.shm_sel = v32,
            0x070 => {
                // NOTE(review): the virtio-mmio spec defines a write of 0 as a device reset.
                // Here it only clears `status`; `activated`, features and queue registers are
                // kept, so a later DRIVER_OK will not re-activate the device.
                st.status = v32;
                if v32 & super::STATUS_DRIVER_OK != 0 && !st.activated {
                    st.activated = true;
                    let queues = st.queues.clone();
                    drop(st);
                    self.dev.activate(queues);
                }
            }
            0x080 => set_low(&mut st, |q| &mut q.desc_table, v32),
            0x084 => set_high(&mut st, |q| &mut q.desc_table, v32),
            0x090 => set_low(&mut st, |q| &mut q.avail_ring, v32),
            0x094 => set_high(&mut st, |q| &mut q.avail_ring, v32),
            0x0a0 => set_low(&mut st, |q| &mut q.used_ring, v32),
            0x0a4 => set_high(&mut st, |q| &mut q.used_ring, v32),
            0x100.. => {
                drop(st);
                // NOTE(review): the access width is not forwarded, so the device cannot tell
                // a sub-word write from a 32-bit one.
                self.dev.config_write((offset - 0x100) as usize, v32);
            }
            _ => {}
        }
    }

    /// Size of the register window in bytes.
    fn len(&self) -> u64 {
        0x200
    }
}
/// Overwrites bits 0..32 of one 64-bit address field of the currently selected queue.
///
/// `accessor` picks the field (e.g. `desc_table`). The upper 32 bits are preserved. The write
/// is dropped if `queue_sel` does not name an existing queue.
fn set_low(st: &mut State, accessor: impl FnOnce(&mut Queue) -> &mut u64, v: u32) {
    let index = st.queue_sel as usize;
    let Some(queue) = st.queues.get_mut(index) else {
        return;
    };
    let field = accessor(queue);
    *field = (*field & 0xffff_ffff_0000_0000) | u64::from(v);
}

/// Overwrites bits 32..64 of one 64-bit address field of the currently selected queue.
///
/// `accessor` picks the field. The lower 32 bits are preserved. The write is dropped if
/// `queue_sel` does not name an existing queue.
fn set_high(st: &mut State, accessor: impl FnOnce(&mut Queue) -> &mut u64, v: u32) {
    let index = st.queue_sel as usize;
    let Some(queue) = st.queues.get_mut(index) else {
        return;
    };
    let field = accessor(queue);
    *field = (*field & 0x0000_0000_ffff_ffff) | (u64::from(v) << 32);
}
pub fn raise_used_buffer_irq(mmio: &MmioVirtio) {
let mut st = mmio.state.lock().unwrap();
st.interrupt_status |= 0x1; let f = st.irq_raise.clone();
drop(st);
f();
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::devices::mmio_bus::MmioDevice;
    use crate::devices::virtio::fs::{VirtioFs, VirtioFsConfig};
    use crate::devices::virtio::queue::GuestMem;
    use crate::devices::virtio::STATUS_DRIVER_OK;
    use std::sync::Arc;

    // Register offsets exercised by these tests.
    const REG_QUEUE_SEL: u64 = 0x030;
    const REG_QUEUE_NUM: u64 = 0x038;
    const REG_QUEUE_READY: u64 = 0x044;
    const REG_STATUS: u64 = 0x070;
    const REG_QUEUE_DESC_LOW: u64 = 0x080;
    const REG_QUEUE_DESC_HIGH: u64 = 0x084;
    const REG_QUEUE_DRIVER_LOW: u64 = 0x090;
    const REG_QUEUE_DEVICE_LOW: u64 = 0x0a0;
    const REG_SHM_SEL: u64 = 0x0ac;
    const REG_SHM_LEN_LOW: u64 = 0x0b0;
    const REG_SHM_LEN_HIGH: u64 = 0x0b4;
    const REG_SHM_BASE_LOW: u64 = 0x0b8;
    const REG_SHM_BASE_HIGH: u64 = 0x0bc;

    /// Builds a virtio-fs transport with one request queue and a DAX window at
    /// 0x80_0000_0000 of length 0x4_0000_0000, backed by an empty guest memory.
    fn make_fs_mmio() -> Arc<MmioVirtio> {
        let config = VirtioFsConfig {
            tag: "shared".into(),
            num_request_queues: 1,
            dax_window_gpa: 0x80_0000_0000,
            dax_window_len: 0x4_0000_0000,
        };
        let device: Arc<dyn VirtioDevice> = Arc::new(VirtioFs::new(config));
        let mem = GuestMem::new(std::ptr::null_mut(), 0, 0);
        let irq: Arc<dyn Fn() + Send + Sync> = Arc::new(|| {});
        Arc::new(MmioVirtio::new(device, mem, irq))
    }

    fn read_u32(d: &MmioVirtio, off: u64) -> u32 {
        d.read(off, 4) as u32
    }

    fn write_u32(d: &MmioVirtio, off: u64, value: u64) {
        d.write(off, value, 4);
    }

    /// Reads a 64-bit value split across a low/high register pair (low first).
    fn read_u64_pair(d: &MmioVirtio, lo: u64, hi: u64) -> u64 {
        let low = u64::from(read_u32(d, lo));
        let high = u64::from(read_u32(d, hi));
        (high << 32) | low
    }

    #[test]
    fn mmio_snapshot_codec_round_trips_and_is_stable() {
        let first = QueueSnapshot {
            size: 256,
            ready: true,
            desc_table: 0x1_0000,
            avail_ring: 0x1_1000,
            used_ring: 0x1_2000,
            last_avail_idx: 7,
            next_used_idx: 7,
        };
        let second = QueueSnapshot {
            size: 64,
            ready: false,
            desc_table: 0x2_0000,
            avail_ring: 0,
            used_ring: 0,
            last_avail_idx: 0,
            next_used_idx: 0,
        };
        let snap = MmioSnapshot {
            driver_features: [0x1234_5678, 0x9abc_def0],
            status: 0xf,
            interrupt_status: 0x1,
            queues: vec![first, second],
        };

        let mut buf = Vec::new();
        snap.write_to(&mut buf).unwrap();
        assert_eq!(buf.len(), 82, "wire layout drifted (alignment pad?)");

        let got = MmioSnapshot::read_from(&mut &buf[..]).unwrap();
        assert_eq!(got.driver_features, snap.driver_features);
        assert_eq!(got.status, snap.status);
        assert_eq!(got.interrupt_status, snap.interrupt_status);
        assert_eq!(got.queues.len(), 2);
        assert_eq!(got.queues[0].size, 256);
        assert!(got.queues[0].ready);
        assert_eq!(got.queues[0].desc_table, 0x1_0000);
        assert_eq!(got.queues[0].last_avail_idx, 7);
        assert!(!got.queues[1].ready);
        assert_eq!(got.queues[1].size, 64);
    }

    #[test]
    fn shm_region_0_round_trip() {
        let m = make_fs_mmio();
        write_u32(&m, REG_SHM_SEL, 0);
        let len = read_u64_pair(&m, REG_SHM_LEN_LOW, REG_SHM_LEN_HIGH);
        assert_eq!(len, 0x4_0000_0000, "DAX window len");
        let base = read_u64_pair(&m, REG_SHM_BASE_LOW, REG_SHM_BASE_HIGH);
        assert_eq!(base, 0x80_0000_0000, "DAX window base");
    }

    #[test]
    fn shm_unknown_region_reports_minus_one_len() {
        let m = make_fs_mmio();
        write_u32(&m, REG_SHM_SEL, 99);
        let len_lo = read_u32(&m, REG_SHM_LEN_LOW);
        let len_hi = read_u32(&m, REG_SHM_LEN_HIGH);
        assert_eq!(len_lo, 0xFFFF_FFFF);
        assert_eq!(len_hi, 0xFFFF_FFFF);
    }

    #[test]
    fn device_id_and_vendor_id_visible() {
        let m = make_fs_mmio();
        assert_eq!(read_u32(&m, 0x000), MAGIC);
        assert_eq!(read_u32(&m, 0x004), VERSION);
        assert_eq!(read_u32(&m, 0x008), 26, "VIRTIO_ID_FS");
    }

    #[test]
    fn queue_num_and_ready_round_trip() {
        let m = make_fs_mmio();
        write_u32(&m, REG_QUEUE_SEL, 0);
        write_u32(&m, REG_QUEUE_NUM, 64);
        write_u32(&m, REG_QUEUE_READY, 1);
        let s = m.capture_state();
        assert_eq!(s.queues[0].size, 64);
        assert!(s.queues[0].ready);
    }

    #[test]
    fn queue_sel_out_of_range_writes_are_ignored_safely() {
        let m = make_fs_mmio();
        write_u32(&m, REG_QUEUE_SEL, 0);
        write_u32(&m, REG_QUEUE_NUM, 32);
        // Select a queue that does not exist; the writes below must not touch queue 0.
        write_u32(&m, REG_QUEUE_SEL, 999);
        write_u32(&m, REG_QUEUE_NUM, 4096);
        write_u32(&m, REG_QUEUE_READY, 1);
        write_u32(&m, REG_QUEUE_DESC_LOW, 0xdead_beef);
        let s = m.capture_state();
        assert_eq!(s.queues[0].size, 32, "real queue untouched by OOR writes");
        assert!(!s.queues[0].ready);
    }

    #[test]
    fn queue_address_triples_round_trip() {
        let m = make_fs_mmio();
        write_u32(&m, REG_QUEUE_SEL, 0);
        write_u32(&m, REG_QUEUE_DESC_LOW, 0x1111_2222);
        write_u32(&m, REG_QUEUE_DESC_HIGH, 0xaaaa_bbbb);
        write_u32(&m, REG_QUEUE_DRIVER_LOW, 0x3333_4444);
        write_u32(&m, REG_QUEUE_DEVICE_LOW, 0x5555_6666);
        let s = m.capture_state();
        assert_eq!(s.queues[0].desc_table, 0xaaaa_bbbb_1111_2222, "lo/hi merge");
        assert_eq!(s.queues[0].avail_ring & 0xffff_ffff, 0x3333_4444);
        assert_eq!(s.queues[0].used_ring & 0xffff_ffff, 0x5555_6666);
    }

    #[test]
    fn hostile_zero_size_queue_activate_does_not_panic() {
        let m = make_fs_mmio();
        write_u32(&m, REG_QUEUE_SEL, 0);
        write_u32(&m, REG_QUEUE_NUM, 0);
        write_u32(&m, REG_QUEUE_READY, 1);
        write_u32(&m, REG_STATUS, STATUS_DRIVER_OK as u64);
        let s = m.capture_state();
        assert_eq!(s.queues[0].size, 0);
        assert_ne!(s.status & STATUS_DRIVER_OK, 0);
    }
}