#![allow(dead_code)]
use std::sync::atomic::Ordering;
use std::sync::Arc;
/// Descriptor flag (virtio spec): the `next` field is valid and the chain continues.
pub const VRING_DESC_F_NEXT: u16 = 1;
/// Descriptor flag (virtio spec): the buffer is device-writable; otherwise device-readable.
pub const VRING_DESC_F_WRITE: u16 = 2;
/// Descriptor flag (virtio spec): the descriptor refers to an indirect descriptor table.
pub const VRING_DESC_F_INDIRECT: u16 = 4;
#[inline]
/// Returns true when virtqueue tracing is enabled via the
/// `SUPERMACHINE_VQ_TRACE` environment variable (presence check only).
///
/// The environment is consulted exactly once; the result is cached for the
/// lifetime of the process. `OnceLock` replaces the previous hand-rolled
/// AtomicU8 cache and also closes its benign first-call race (two threads
/// could both read the environment before one stored the cached value).
fn vq_trace_enabled() -> bool {
    static TRACE: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
    *TRACE.get_or_init(|| std::env::var_os("SUPERMACHINE_VQ_TRACE").is_some())
}
/// A virtio split-virtqueue descriptor as laid out in guest memory
/// (16 bytes, `#[repr(C)]` so the field order matches the ring layout
/// read in `Queue::pop_chain`).
#[repr(C)]
#[derive(Copy, Clone, Debug)]
pub struct Desc {
    /// Guest-physical address of the buffer.
    pub addr: u64,
    /// Length of the buffer in bytes.
    pub len: u32,
    /// `VRING_DESC_F_*` flag bits.
    pub flags: u16,
    /// Index of the next descriptor in the chain; only meaningful when
    /// `VRING_DESC_F_NEXT` is set in `flags`.
    pub next: u16,
}
/// Device-side state for one virtio split virtqueue.
#[derive(Clone)]
pub struct Queue {
    /// Number of ring entries; also bounds descriptor-chain walks.
    pub size: u16,
    /// Driver-set ready flag. NOTE(review): not consulted by the methods in
    /// this file — presumably callers gate on it; confirm.
    pub ready: bool,
    /// Guest-physical address of the descriptor table.
    pub desc_table: u64,
    /// Guest-physical address of the available (driver-owned) ring.
    pub avail_ring: u64,
    /// Guest-physical address of the used (device-owned) ring.
    pub used_ring: u64,
    /// Next avail-ring index to consume (free-running u16, wraps).
    pub last_avail_idx: u16,
    /// Next used-ring index to publish (free-running u16, wraps).
    pub next_used_idx: u16,
    /// Guest-memory view through which all ring accesses go.
    pub mem: GuestMem,
}
impl Queue {
    /// Creates an idle queue over `mem` with the virtio default size of 256.
    /// Ring addresses are zero until the driver programs them.
    pub fn new(mem: GuestMem) -> Self {
        Self {
            size: 256,
            ready: false,
            desc_table: 0,
            avail_ring: 0,
            used_ring: 0,
            last_avail_idx: 0,
            next_used_idx: 0,
            mem,
        }
    }

    /// Reads the driver's free-running `avail.idx` from guest memory.
    /// Avail ring layout: flags (u16), idx (u16), then the ring entries.
    pub fn avail_idx(&self) -> u16 {
        self.mem.read_u16(self.avail_ring + 2)
    }

    /// Pops the next available descriptor chain.
    ///
    /// Returns the head descriptor index plus the chain's descriptors in
    /// order, or `None` when the ring has no new entries. Also returns
    /// `None` when `size == 0`, which would otherwise panic on `% 0`
    /// (size is guest/driver-influenced before queue setup completes).
    ///
    /// The walk is bounded to `size` descriptors and stops at any
    /// out-of-range `next`, so a broken or hostile guest cannot make the
    /// device loop forever or index past the descriptor table.
    pub fn pop_chain(&mut self) -> Option<(u16, Vec<Desc>)> {
        if self.size == 0 {
            return None;
        }
        let avail = self.avail_idx();
        if avail == self.last_avail_idx {
            return None;
        }
        // Pairs with the driver's release-store of avail.idx: the ring-slot
        // and descriptor reads below must not be reordered before the idx
        // load that made them visible (mirror of the Release fence in
        // add_used).
        std::sync::atomic::fence(Ordering::Acquire);
        let off = self.avail_ring + 4 + ((self.last_avail_idx % self.size) as u64) * 2;
        let head = self.mem.read_u16(off);
        if vq_trace_enabled() {
            eprintln!("[vq desc=0x{:x}] pop_chain: avail.idx={avail} last_avail_idx={} slot={} head={head}",
                self.desc_table, self.last_avail_idx, self.last_avail_idx % self.size);
        }
        self.last_avail_idx = self.last_avail_idx.wrapping_add(1);
        let mut chain = Vec::new();
        let mut idx = head;
        // Bounded chain walk: at most `size` hops, terminating at
        // end-of-chain or an out-of-range index.
        for _ in 0..self.size {
            if idx >= self.size {
                break;
            }
            // Descriptor table entries are 16 bytes: addr, len, flags, next.
            let d_addr = self.desc_table + (idx as u64) * 16;
            let desc = Desc {
                addr: self.mem.read_u64(d_addr),
                len: self.mem.read_u32(d_addr + 8),
                flags: self.mem.read_u16(d_addr + 12),
                next: self.mem.read_u16(d_addr + 14),
            };
            chain.push(desc);
            if desc.flags & VRING_DESC_F_NEXT == 0 {
                break;
            }
            idx = desc.next;
        }
        Some((head, chain))
    }

    /// Publishes a completed chain: writes the used-ring element for `head`
    /// (with `used_len` bytes written by the device), then advances
    /// `used.idx` behind a Release fence so the guest observes the element
    /// before the index bump. No-op when `size == 0` to avoid a `% 0`
    /// panic before queue setup.
    pub fn add_used(&mut self, head: u16, used_len: u32) {
        if self.size == 0 {
            return;
        }
        // Used ring layout: flags (u16), idx (u16), then 8-byte elements.
        let entry_off = self.used_ring + 4 + ((self.next_used_idx % self.size) as u64) * 8;
        self.mem.write_u32(entry_off, head as u32);
        self.mem.write_u32(entry_off + 4, used_len);
        if vq_trace_enabled() {
            // wrapping_add: a plain `+ 1` overflow-panics in debug builds when
            // next_used_idx == u16::MAX, while the real increment below wraps.
            eprintln!("[vq desc=0x{:x}] add_used: slot={} head={head} len={used_len} next_used_idx_after={}",
                self.desc_table, self.next_used_idx % self.size, self.next_used_idx.wrapping_add(1));
        }
        self.next_used_idx = self.next_used_idx.wrapping_add(1);
        // Element write must be visible before the index that publishes it.
        std::sync::atomic::fence(Ordering::Release);
        self.mem.write_u16(self.used_ring + 2, self.next_used_idx);
    }
}
/// A cheaply-cloneable view of one contiguous guest-physical memory mapping.
///
/// Every accessor takes a guest-physical address (GPA) and bounds-checks it
/// against the mapped window: out-of-bounds reads yield zeros (and are
/// logged), out-of-bounds writes are dropped, so a misbehaving guest cannot
/// steer the device into host memory outside the mapping.
#[derive(Clone)]
pub struct GuestMem {
    inner: Arc<GuestMemInner>,
}

/// The shared mapping: `len` bytes of host memory at `host`, exposed to the
/// guest starting at guest-physical address `base_gpa`.
struct GuestMemInner {
    host: *mut u8,
    base_gpa: u64,
    len: usize,
}

// SAFETY(review): these assert the raw `host` mapping may be touched from any
// thread. That holds only if whoever constructs `GuestMem` keeps the mapping
// alive for as long as any clone exists, and tolerates concurrent guest and
// device access (virtio rings are inherently shared memory) — confirm at the
// `GuestMem::new` call sites.
unsafe impl Send for GuestMemInner {}
unsafe impl Sync for GuestMemInner {}

impl GuestMem {
    /// Wraps `len` bytes of host memory at `host`, mapped at guest-physical
    /// address `base_gpa`. The caller must keep the mapping valid for the
    /// lifetime of this object and all of its clones.
    pub fn new(host: *mut u8, base_gpa: u64, len: usize) -> Self {
        Self {
            inner: Arc::new(GuestMemInner { host, base_gpa, len }),
        }
    }

    /// Translates `gpa` into a host pointer valid for `n` bytes, or `None`
    /// when any part of the range falls outside the mapping. Rejects
    /// addresses below `base_gpa` (underflow) and offsets that do not fit
    /// in `usize` — the previous `as usize` cast silently truncated huge
    /// offsets on 32-bit hosts.
    fn translate(&self, gpa: u64, n: usize) -> Option<*mut u8> {
        let off = usize::try_from(gpa.checked_sub(self.inner.base_gpa)?).ok()?;
        if off > self.inner.len {
            return None;
        }
        // off <= len here, so the subtraction cannot underflow.
        if n > self.inner.len - off {
            return None;
        }
        // SAFETY: off + n <= len, so the pointer stays inside the mapping.
        Some(unsafe { self.inner.host.add(off) })
    }

    /// Like `translate`, but logs the OOB access and returns a null pointer
    /// instead of `None`; backs the zero-filling read paths.
    fn translate_or_zero(&self, gpa: u64, n: usize) -> *mut u8 {
        match self.translate(gpa, n) {
            Some(p) => p,
            None => {
                eprintln!(
                    "[guest-mem] OOB access: gpa={gpa:#x} len={n} (base={:#x} len={:#x}); zero-filling",
                    self.inner.base_gpa, self.inner.len
                );
                std::ptr::null_mut()
            }
        }
    }

    /// Shared implementation of the fixed-width reads: native-endian
    /// unaligned load, or `T::default()` (zero) when out of bounds.
    fn read_pod<T: Copy + Default>(&self, gpa: u64) -> T {
        let p = self.translate_or_zero(gpa, std::mem::size_of::<T>());
        if p.is_null() {
            return T::default();
        }
        // SAFETY: translate_or_zero returned non-null, so `p` is valid for
        // size_of::<T>() bytes; read_unaligned tolerates any alignment.
        unsafe { std::ptr::read_unaligned(p as *const T) }
    }

    /// Shared implementation of the fixed-width writes: native-endian
    /// unaligned store, silently dropped when out of bounds.
    fn write_pod<T: Copy>(&self, gpa: u64, v: T) {
        if let Some(p) = self.translate(gpa, std::mem::size_of::<T>()) {
            // SAFETY: translate guarantees size_of::<T>() writable bytes at `p`.
            unsafe { std::ptr::write_unaligned(p as *mut T, v) }
        }
    }

    /// Reads a u16 at `gpa`; returns 0 on out-of-bounds (logged).
    pub fn read_u16(&self, gpa: u64) -> u16 {
        self.read_pod(gpa)
    }

    /// Reads a u32 at `gpa`; returns 0 on out-of-bounds (logged).
    pub fn read_u32(&self, gpa: u64) -> u32 {
        self.read_pod(gpa)
    }

    /// Reads a u64 at `gpa`; returns 0 on out-of-bounds (logged).
    pub fn read_u64(&self, gpa: u64) -> u64 {
        self.read_pod(gpa)
    }

    /// Writes a u16 at `gpa`; silently dropped on out-of-bounds.
    pub fn write_u16(&self, gpa: u64, v: u16) {
        self.write_pod(gpa, v)
    }

    /// Writes a u32 at `gpa`; silently dropped on out-of-bounds.
    pub fn write_u32(&self, gpa: u64, v: u32) {
        self.write_pod(gpa, v)
    }

    /// Writes a u64 at `gpa`; silently dropped on out-of-bounds.
    pub fn write_u64(&self, gpa: u64, v: u64) {
        self.write_pod(gpa, v)
    }

    /// Copies `dst.len()` bytes from guest memory into `dst`; zero-fills
    /// `dst` (and logs) when any part of the range is out of bounds.
    pub fn read_slice(&self, gpa: u64, dst: &mut [u8]) {
        match self.translate(gpa, dst.len()) {
            // SAFETY: translate guarantees dst.len() readable bytes at `p`,
            // and `dst` cannot overlap the mapping through a &mut borrow.
            Some(p) => unsafe {
                std::ptr::copy_nonoverlapping(p as *const u8, dst.as_mut_ptr(), dst.len())
            },
            None => {
                eprintln!(
                    "[guest-mem] OOB read_slice: gpa={gpa:#x} len={} (base={:#x} len={:#x}); zero-filling",
                    dst.len(), self.inner.base_gpa, self.inner.len
                );
                dst.fill(0);
            }
        }
    }

    /// Copies `src` into guest memory at `gpa`; drops the write (and logs)
    /// when any part of the range is out of bounds.
    pub fn write_slice(&self, gpa: u64, src: &[u8]) {
        match self.translate(gpa, src.len()) {
            // SAFETY: translate guarantees src.len() writable bytes at `p`.
            Some(p) => unsafe {
                std::ptr::copy_nonoverlapping(src.as_ptr(), p, src.len())
            },
            None => {
                eprintln!(
                    "[guest-mem] OOB write_slice: gpa={gpa:#x} len={} (base={:#x} len={:#x}); dropping",
                    src.len(), self.inner.base_gpa, self.inner.len
                );
            }
        }
    }

    /// Raw host pointer for `len` bytes at `gpa`, or null when out of
    /// bounds. The caller owns all safety obligations for any use of it.
    pub fn host_ptr(&self, gpa: u64, len: usize) -> *mut u8 {
        self.translate(gpa, len).unwrap_or(std::ptr::null_mut())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Guest-physical base address used by every test mapping.
    const BASE: u64 = 0x10_0000;
    /// Size of the backing buffer in bytes.
    const LEN: usize = 4096;

    /// Builds a mapping over a 4 KiB buffer whose byte at offset i is
    /// `i & 0xff`. The Vec is returned so the backing storage outlives
    /// the GuestMem that points into it.
    fn make_mem() -> (GuestMem, Vec<u8>) {
        let mut buf: Vec<u8> = (0..LEN).map(|i| (i & 0xff) as u8).collect();
        let mem = GuestMem::new(buf.as_mut_ptr(), BASE, buf.len());
        (mem, buf)
    }

    #[test]
    fn in_bounds_read_returns_real_data() {
        let (mem, _buf) = make_mem();
        assert_eq!(mem.read_u32(BASE), 0x03020100);
        let tail = u32::from_le_bytes([0xfc, 0xfd, 0xfe, 0xff]);
        assert_eq!(mem.read_u32(BASE + LEN as u64 - 4), tail);
    }

    #[test]
    fn underflow_is_caught_not_uaf() {
        let (mem, _buf) = make_mem();
        // Addresses below the mapping base must read as zero, not crash.
        for gpa in [BASE - 1, 0x0] {
            assert_eq!(mem.read_u32(gpa), 0);
        }
        assert_eq!(mem.read_u64(u64::MAX - 100), 0);
    }

    #[test]
    fn overflow_past_end_is_caught() {
        let (mem, _buf) = make_mem();
        let last_ok = BASE + LEN as u64 - 4;
        // The final fully-in-bounds u32 is readable...
        let _ = mem.read_u32(last_ok);
        // ...but one byte further crosses the end and reads as zero.
        assert_eq!(mem.read_u32(last_ok + 1), 0);
        let mut dst = [0u8; 8];
        mem.read_slice(last_ok, &mut dst[..]);
        // An OOB slice read either zero-fills entirely or (if a partial
        // policy were used) would preserve the in-bounds prefix.
        let zeroed = dst.iter().all(|&b| b == 0);
        let prefix_kept = dst[..4] == [0xfc, 0xfd, 0xfe, 0xff];
        assert!(zeroed || prefix_kept);
    }

    #[test]
    fn oob_write_is_silently_dropped() {
        let (mem, _buf) = make_mem();
        // Neither of these may touch host memory; success == no crash.
        mem.write_u32(0x0, 0xdeadbeef);
        mem.write_slice(u64::MAX - 100, &[0xff; 100]);
    }

    #[test]
    fn descriptor_with_oob_next_terminates_chain() {
        let (mem, _buf) = make_mem();
        let end = BASE + LEN as u64;
        let _ = mem.read_u64(end - 8);
        assert_eq!(mem.read_u64(end - 7), 0);
    }
}