#![allow(dead_code)]
use std::sync::atomic::Ordering;
use std::sync::Arc;
/// Descriptor flag: the chain continues at the descriptor index stored in `Desc::next`.
pub const VRING_DESC_F_NEXT: u16 = 1;
/// Descriptor flag: the buffer is device-writable (the device fills it; see `write_writable`).
/// Descriptors without this flag are device-readable (see `read_readable_capped`).
pub const VRING_DESC_F_WRITE: u16 = 2;
/// Descriptor flag, presumably the virtio indirect-descriptor-table flag.
/// NOTE(review): nothing in this file interprets it; `Queue::pop_chain` does not expand
/// indirect tables.
pub const VRING_DESC_F_INDIRECT: u16 = 4;
/// Reports whether `vq` tracing is enabled, querying `crate::trace` only on the first call.
///
/// The answer is memoised in a process-wide atomic. Concurrent first calls may each query
/// `crate::trace::enabled`; they all store the same answer.
#[inline]
fn vq_trace_enabled() -> bool {
    use std::sync::atomic::{AtomicU8, Ordering as AtomicOrdering};

    // Cache states: not yet queried, queried and off, queried and on.
    const UNKNOWN: u8 = 0;
    const OFF: u8 = 1;
    const ON: u8 = 2;
    static CACHED: AtomicU8 = AtomicU8::new(UNKNOWN);

    match CACHED.load(AtomicOrdering::Relaxed) {
        UNKNOWN => {
            let on = crate::trace::enabled("vq");
            CACHED.store(if on { ON } else { OFF }, AtomicOrdering::Relaxed);
            on
        }
        state => state == ON,
    }
}
/// One 16-byte descriptor-table entry, as decoded by `Queue::pop_chain`.
///
/// `#[repr(C)]` keeps the declared field order. `pop_chain` reads the fields individually at
/// byte offsets 0, 8, 12 and 14 of each table entry.
#[repr(C)]
#[derive(Copy, Clone, Debug)]
pub struct Desc {
    /// Guest physical address of the buffer.
    pub addr: u64,
    /// Buffer length in bytes.
    pub len: u32,
    /// Bit set of `VRING_DESC_F_*` flags.
    pub flags: u16,
    /// Index of the next descriptor in the chain; only followed when `VRING_DESC_F_NEXT` is set.
    pub next: u16,
}
/// Device-side state of one split virtqueue: guest addresses of its three areas plus the
/// device's private ring indices.
#[derive(Clone)]
pub struct Queue {
    /// Number of entries in the descriptor table and in each ring. A value of 0 turns
    /// `pop_chain` and `add_used` into no-ops.
    pub size: u16,
    /// Readiness flag. NOTE(review): not consulted by `pop_chain`/`add_used`; presumably the
    /// caller checks it before using the queue.
    pub ready: bool,
    /// Guest physical address of the descriptor table (16 bytes per entry).
    pub desc_table: u64,
    /// Guest physical address of the available (driver-to-device) ring.
    pub avail_ring: u64,
    /// Guest physical address of the used (device-to-driver) ring.
    pub used_ring: u64,
    /// Next available-ring index the device will consume. Free-running; wraps at `u16::MAX`.
    pub last_avail_idx: u16,
    /// Next used-ring index the device will publish. Free-running; wraps at `u16::MAX`.
    pub next_used_idx: u16,
    /// Guest memory through which every ring and descriptor access is made.
    pub mem: GuestMem,
}
impl Queue {
    /// Creates an unconfigured queue over `mem`.
    ///
    /// The queue has size 256, is not ready, and has all ring addresses and indices at zero.
    pub fn new(mem: GuestMem) -> Self {
        Self {
            size: 256,
            ready: false,
            desc_table: 0,
            avail_ring: 0,
            used_ring: 0,
            last_avail_idx: 0,
            next_used_idx: 0,
            mem,
        }
    }

    /// Reads the available ring's `idx` field (the `u16` at byte offset 2).
    ///
    /// This is the free-running count of entries the driver has published so far.
    pub fn avail_idx(&self) -> u16 {
        self.mem.read_u16(self.avail_ring + 2)
    }

    /// Takes the next descriptor chain the driver has made available, if any.
    ///
    /// Returns `Some((head, chain))`, where `head` is the descriptor index read from the
    /// available ring and `chain` lists the descriptors reached from it by following
    /// `VRING_DESC_F_NEXT` links. The walk stops at whichever comes first:
    /// - the first descriptor without `NEXT`;
    /// - any index `>= size` (so an out-of-range `head` yields an empty chain);
    /// - `size` descriptors, which bounds the walk even if the guest links descriptors into a
    ///   cycle.
    ///
    /// `last_avail_idx` is advanced whenever `Some` is returned.
    ///
    /// Returns `None` if `size` is 0 or the driver has published nothing new. Indirect
    /// descriptors (`VRING_DESC_F_INDIRECT`) are not expanded.
    pub fn pop_chain(&mut self) -> Option<(u16, Vec<Desc>)> {
        if self.size == 0 {
            return None;
        }
        let avail = self.avail_idx();
        if avail == self.last_avail_idx {
            return None;
        }
        // Read the ring entry and descriptors only after observing the driver's idx update.
        std::sync::atomic::fence(Ordering::Acquire);
        let slot = self.last_avail_idx % self.size;
        let head = self.mem.read_u16(self.avail_ring + 4 + (slot as u64) * 2);
        if vq_trace_enabled() {
            eprintln!(
                "[vq desc=0x{:x}] pop_chain: avail.idx={avail} last_avail_idx={} slot={} head={head}",
                self.desc_table, self.last_avail_idx, slot
            );
        }
        self.last_avail_idx = self.last_avail_idx.wrapping_add(1);

        let mut chain = Vec::new();
        let mut idx = head;
        // At most `size` hops: a chain of distinct descriptors cannot exceed the table size.
        for _ in 0..self.size {
            if idx >= self.size {
                break;
            }
            let d_addr = self.desc_table + (idx as u64) * 16;
            let desc = Desc {
                addr: self.mem.read_u64(d_addr),
                len: self.mem.read_u32(d_addr + 8),
                flags: self.mem.read_u16(d_addr + 12),
                next: self.mem.read_u16(d_addr + 14),
            };
            chain.push(desc);
            if desc.flags & VRING_DESC_F_NEXT == 0 {
                break;
            }
            idx = desc.next;
        }
        Some((head, chain))
    }

    /// Publishes chain `head` as completed, with `used_len` bytes written by the device.
    ///
    /// Writes the `(id, len)` element into the next used-ring slot, then advances the used
    /// ring's `idx` (the `u16` at byte offset 2). Does nothing if `size` is 0.
    pub fn add_used(&mut self, head: u16, used_len: u32) {
        if self.size == 0 {
            return;
        }
        let slot = self.next_used_idx % self.size;
        let entry_off = self.used_ring + 4 + (slot as u64) * 8;
        self.mem.write_u32(entry_off, head as u32);
        self.mem.write_u32(entry_off + 4, used_len);
        // The index is a free-running u16. A plain `+ 1` here panicked in debug builds (with
        // tracing on) once 65535 was reached, so compute the wrapped value once and reuse it.
        let new_idx = self.next_used_idx.wrapping_add(1);
        if vq_trace_enabled() {
            eprintln!(
                "[vq desc=0x{:x}] add_used: slot={} head={head} len={used_len} next_used_idx_after={}",
                self.desc_table, slot, new_idx
            );
        }
        self.next_used_idx = new_idx;
        // The element writes above must be visible before the driver can observe the new idx.
        std::sync::atomic::fence(Ordering::Release);
        self.mem.write_u16(self.used_ring + 2, new_idx);
    }
}
/// Gathers the contents of every device-readable descriptor in `chain` into one buffer.
///
/// Descriptors flagged `VRING_DESC_F_WRITE` are skipped; the others are appended in chain
/// order. Returns `None` as soon as the next descriptor would push the total past
/// `max_bytes`. Out-of-range source addresses are zero-filled (and logged) by
/// `GuestMem::read_slice`.
pub fn read_readable_capped(chain: &[Desc], mem: &GuestMem, max_bytes: usize) -> Option<Vec<u8>> {
    let mut gathered: Vec<u8> = Vec::new();
    let readable = chain.iter().filter(|d| d.flags & VRING_DESC_F_WRITE == 0);
    for desc in readable {
        let start = gathered.len();
        let end = start.saturating_add(desc.len as usize);
        if end > max_bytes {
            return None;
        }
        gathered.resize(end, 0);
        mem.read_slice(desc.addr, &mut gathered[start..end]);
    }
    Some(gathered)
}
/// Scatters `bytes` across the device-writable descriptors of `chain`, in chain order.
///
/// Each writable descriptor receives at most `len` bytes, and readable descriptors are
/// skipped. Returns the number of bytes written, which is less than `bytes.len()` when the
/// writable capacity runs out. Out-of-range destinations are dropped (and logged) by
/// `GuestMem::write_slice` but still count toward the returned total.
pub fn write_writable(chain: &[Desc], mem: &GuestMem, bytes: &[u8]) -> usize {
    let mut remaining = bytes;
    let writable = chain.iter().filter(|d| d.flags & VRING_DESC_F_WRITE != 0);
    for desc in writable {
        if remaining.is_empty() {
            break;
        }
        let cut = (desc.len as usize).min(remaining.len());
        let (piece, rest) = remaining.split_at(cut);
        mem.write_slice(desc.addr, piece);
        remaining = rest;
    }
    bytes.len() - remaining.len()
}
/// Cheaply clonable, bounds-checked view of one contiguous range of guest physical memory.
///
/// Clones share the same host mapping through an `Arc`. Every accessor first checks that the
/// whole access lies inside `base_gpa .. base_gpa + len`. Accesses that fail the check never
/// touch host memory; each method documents what it does instead.
#[derive(Clone)]
pub struct GuestMem {
    inner: Arc<GuestMemInner>,
}

/// The host mapping shared by all clones of a [`GuestMem`].
struct GuestMemInner {
    /// Host address of the first mapped byte (guest address `base_gpa`).
    host: *mut u8,
    /// Guest physical address that `host` corresponds to.
    base_gpa: u64,
    /// Number of mapped bytes starting at `host`.
    len: usize,
}

// SAFETY: the fields are never modified after construction, and `host` is only dereferenced
// after `GuestMem::translate` has bounds-checked the access. The bytes themselves are accessed
// with plain, unsynchronised reads and writes.
// NOTE(review): callers are assumed to coordinate concurrent access to the same guest bytes;
// confirm.
unsafe impl Send for GuestMemInner {}
unsafe impl Sync for GuestMemInner {}

impl GuestMem {
    /// Wraps `len` host bytes starting at `host`, exposed to the guest at `base_gpa`.
    ///
    /// The pointer is not validated here. Every accessor dereferences it, so the caller must
    /// keep those `len` bytes mapped and valid for reads and writes for as long as any clone
    /// of the returned value is in use.
    pub fn new(host: *mut u8, base_gpa: u64, len: usize) -> Self {
        Self {
            inner: Arc::new(GuestMemInner {
                host,
                base_gpa,
                len,
            }),
        }
    }

    /// Maps the `n`-byte guest range starting at `gpa` to a host pointer.
    ///
    /// Returns `None` when `gpa` is below `base_gpa` or any part of `gpa .. gpa + n` lies past
    /// the end of the mapping. A zero-length range ending exactly at the end is accepted.
    fn translate(&self, gpa: u64, n: usize) -> Option<*mut u8> {
        let off = gpa.checked_sub(self.inner.base_gpa)?;
        // Range-check in u64 *before* narrowing. On 32-bit hosts `off as usize` would drop the
        // high bits and could turn a far-away GPA into an in-range offset that aliases
        // unrelated guest memory.
        if off > self.inner.len as u64 {
            return None;
        }
        // Lossless: `off <= len`, and `len` is a usize.
        let off = off as usize;
        // Written as a subtraction so `off + n` can never overflow.
        if n > self.inner.len - off {
            return None;
        }
        // SAFETY: `off <= len`, so the result stays within (or one past the end of) the mapping.
        Some(unsafe { self.inner.host.add(off) })
    }

    /// Like [`Self::translate`], but logs out-of-range accesses to stderr and returns a null
    /// pointer instead of `None`. The scalar readers turn null into a zero result.
    fn translate_or_zero(&self, gpa: u64, n: usize) -> *mut u8 {
        match self.translate(gpa, n) {
            Some(p) => p,
            None => {
                eprintln!(
                    "[guest-mem] OOB access: gpa={gpa:#x} len={n} (base={:#x} len={:#x}); zero-filling",
                    self.inner.base_gpa, self.inner.len
                );
                std::ptr::null_mut()
            }
        }
    }

    /// Reads a `u16` (host byte order) at `gpa`; returns 0 and logs if out of range.
    pub fn read_u16(&self, gpa: u64) -> u16 {
        let p = self.translate_or_zero(gpa, 2);
        if p.is_null() {
            return 0;
        }
        // SAFETY: `translate` verified that 2 bytes at `p` lie inside the mapping.
        unsafe { std::ptr::read_unaligned(p as *const u16) }
    }

    /// Reads a `u32` (host byte order) at `gpa`; returns 0 and logs if out of range.
    pub fn read_u32(&self, gpa: u64) -> u32 {
        let p = self.translate_or_zero(gpa, 4);
        if p.is_null() {
            return 0;
        }
        // SAFETY: `translate` verified that 4 bytes at `p` lie inside the mapping.
        unsafe { std::ptr::read_unaligned(p as *const u32) }
    }

    /// Reads a `u64` (host byte order) at `gpa`; returns 0 and logs if out of range.
    pub fn read_u64(&self, gpa: u64) -> u64 {
        let p = self.translate_or_zero(gpa, 8);
        if p.is_null() {
            return 0;
        }
        // SAFETY: `translate` verified that 8 bytes at `p` lie inside the mapping.
        unsafe { std::ptr::read_unaligned(p as *const u64) }
    }

    /// Writes `v` (host byte order) at `gpa`; silently does nothing if out of range.
    pub fn write_u16(&self, gpa: u64, v: u16) {
        if let Some(p) = self.translate(gpa, 2) {
            // SAFETY: `translate` verified that 2 bytes at `p` lie inside the mapping.
            unsafe { std::ptr::write_unaligned(p as *mut u16, v) }
        }
    }

    /// Writes `v` (host byte order) at `gpa`; silently does nothing if out of range.
    pub fn write_u32(&self, gpa: u64, v: u32) {
        if let Some(p) = self.translate(gpa, 4) {
            // SAFETY: `translate` verified that 4 bytes at `p` lie inside the mapping.
            unsafe { std::ptr::write_unaligned(p as *mut u32, v) }
        }
    }

    /// Writes `v` (host byte order) at `gpa`; silently does nothing if out of range.
    pub fn write_u64(&self, gpa: u64, v: u64) {
        if let Some(p) = self.translate(gpa, 8) {
            // SAFETY: `translate` verified that 8 bytes at `p` lie inside the mapping.
            unsafe { std::ptr::write_unaligned(p as *mut u64, v) }
        }
    }

    /// Copies `dst.len()` guest bytes starting at `gpa` into `dst`.
    ///
    /// If any part of the range is out of bounds, nothing is copied: `dst` is zero-filled and
    /// the event is logged.
    pub fn read_slice(&self, gpa: u64, dst: &mut [u8]) {
        match self.translate(gpa, dst.len()) {
            // SAFETY: `translate` verified `dst.len()` bytes at `p`; `dst` is a distinct Rust
            // borrow, so the regions do not overlap.
            Some(p) => unsafe {
                std::ptr::copy_nonoverlapping(p as *const u8, dst.as_mut_ptr(), dst.len())
            },
            None => {
                eprintln!(
                    "[guest-mem] OOB read_slice: gpa={gpa:#x} len={} (base={:#x} len={:#x}); zero-filling",
                    dst.len(), self.inner.base_gpa, self.inner.len
                );
                dst.fill(0);
            }
        }
    }

    /// Copies `src` into guest memory at `gpa`.
    ///
    /// If any part of the range is out of bounds, nothing is written and the event is logged.
    pub fn write_slice(&self, gpa: u64, src: &[u8]) {
        match self.translate(gpa, src.len()) {
            // SAFETY: `translate` verified `src.len()` bytes at `p`.
            Some(p) => unsafe { std::ptr::copy_nonoverlapping(src.as_ptr(), p, src.len()) },
            None => {
                eprintln!(
                    "[guest-mem] OOB write_slice: gpa={gpa:#x} len={} (base={:#x} len={:#x}); dropping",
                    src.len(), self.inner.base_gpa, self.inner.len
                );
            }
        }
    }

    /// Returns the host pointer for the `len`-byte guest range at `gpa`, or null if any part
    /// of it is out of range.
    ///
    /// The pointer is only valid while the underlying mapping is.
    pub fn host_ptr(&self, gpa: u64, len: usize) -> *mut u8 {
        self.translate(gpa, len).unwrap_or(std::ptr::null_mut())
    }
}
#[cfg(test)]
mod tests {
    //! Tests for descriptor gather/scatter, queue edge cases and guest-memory bounds checking,
    //! plus a property test over random ring contents.
    use super::*;

    /// Builds 4 KiB of guest memory mapped at GPA 0x10_0000, where byte `i` holds `i & 0xff`.
    ///
    /// The backing `Vec` is returned too. Bind it to a named variable (e.g. `_buf`, not `_`)
    /// so it stays alive while the `GuestMem`, which only holds a raw pointer, is in use.
    fn make_mem() -> (GuestMem, Vec<u8>) {
        let mut buf = vec![0u8; 4096];
        for (i, b) in buf.iter_mut().enumerate() {
            *b = (i & 0xff) as u8;
        }
        let mem = GuestMem::new(buf.as_mut_ptr(), 0x10_0000, buf.len());
        (mem, buf)
    }

    /// Builds a standalone descriptor (no `NEXT` link), device-writable if `writable`.
    fn desc(addr: u64, len: u32, writable: bool) -> Desc {
        Desc {
            addr,
            len,
            flags: if writable { VRING_DESC_F_WRITE } else { 0 },
            next: 0,
        }
    }

    #[test]
    fn read_readable_capped_concatenates_only_readable() {
        let (mem, _buf) = make_mem();
        let base = 0x10_0000;
        mem.write_slice(base + 0x100, b"AAAA");
        mem.write_slice(base + 0x200, b"BBBB");
        let chain = [
            desc(base + 0x100, 4, false),
            desc(base + 0x300, 4, true),
            desc(base + 0x200, 4, false),
        ];
        let got = read_readable_capped(&chain, &mem, 1024).unwrap();
        assert_eq!(got, b"AAAABBBB", "writable descriptor excluded");
    }

    #[test]
    fn read_readable_capped_empty_and_overcap() {
        let (mem, _buf) = make_mem();
        assert_eq!(read_readable_capped(&[], &mem, 1024), Some(Vec::new()));
        let chain = [desc(0x10_0000, 4096, false)];
        assert_eq!(read_readable_capped(&chain, &mem, 64), None);
        assert!(read_readable_capped(&chain, &mem, 4096).is_some());
    }

    #[test]
    fn write_writable_bounded_and_skips_readable() {
        let (mem, _buf) = make_mem();
        let base = 0x10_0000;
        let chain = [
            desc(base + 0x100, 4, true),
            desc(base + 0x200, 4, false),
            desc(base + 0x300, 4, true),
        ];
        let n = write_writable(&chain, &mem, b"12345678");
        assert_eq!(n, 8);
        let mut a = [0u8; 4];
        let mut b = [0u8; 4];
        mem.read_slice(base + 0x100, &mut a);
        mem.read_slice(base + 0x300, &mut b);
        assert_eq!(&a, b"1234");
        assert_eq!(&b, b"5678");
        // The readable descriptor's buffer must still hold the original fill pattern.
        let mut skipped = [0xffu8; 4];
        mem.read_slice(base + 0x200, &mut skipped);
        assert_eq!(skipped, [0x00, 0x01, 0x02, 0x03], "readable descriptor left untouched");
    }

    #[test]
    fn write_writable_truncates_to_descriptor_capacity() {
        let (mem, _buf) = make_mem();
        let base = 0x10_0000;
        let chain = [desc(base + 0x100, 4, true)];
        assert_eq!(write_writable(&chain, &mem, b"12345678"), 4);
        assert_eq!(write_writable(&[], &mem, b"x"), 0);
        assert_eq!(write_writable(&chain, &mem, b"hi"), 2);
    }

    #[test]
    fn size_zero_queue_does_not_divide_by_zero() {
        let (mem, _buf) = make_mem();
        let avail_ring = 0x10_0000 + 0x100;
        mem.write_u16(avail_ring + 2, 1);
        let mut q = Queue::new(mem);
        q.size = 0;
        q.ready = true;
        q.avail_ring = avail_ring;
        q.used_ring = 0x10_0000 + 0x200;
        assert!(q.pop_chain().is_none(), "size-0 queue yields no chains");
        q.add_used(0, 0);
    }

    #[test]
    fn in_bounds_read_returns_real_data() {
        let (mem, _buf) = make_mem();
        // Scalar reads are in host byte order, so build expectations with `from_ne_bytes`.
        assert_eq!(mem.read_u32(0x10_0000), u32::from_ne_bytes([0x00, 0x01, 0x02, 0x03]));
        assert_eq!(
            mem.read_u32(0x10_0000 + 4096 - 4),
            u32::from_ne_bytes([0xfc, 0xfd, 0xfe, 0xff])
        );
    }

    #[test]
    fn underflow_is_caught_not_uaf() {
        let (mem, _buf) = make_mem();
        assert_eq!(mem.read_u32(0x10_0000 - 1), 0);
        assert_eq!(mem.read_u32(0x0), 0);
        assert_eq!(mem.read_u64(u64::MAX - 100), 0);
        // Offsets wider than 32 bits must be rejected, never truncated into range.
        assert_eq!(mem.read_u32(0x10_0000 + (1u64 << 32)), 0);
    }

    #[test]
    fn overflow_past_end_is_caught() {
        let (mem, _buf) = make_mem();
        // The last fully in-range scalars are readable...
        assert_eq!(
            mem.read_u32(0x10_0000 + 4096 - 4),
            u32::from_ne_bytes([0xfc, 0xfd, 0xfe, 0xff])
        );
        assert_eq!(
            mem.read_u64(0x10_0000 + 4096 - 8),
            u64::from_ne_bytes([0xf8, 0xf9, 0xfa, 0xfb, 0xfc, 0xfd, 0xfe, 0xff])
        );
        // ...but anything straddling the end reads as zero.
        assert_eq!(mem.read_u32(0x10_0000 + 4096 - 3), 0);
        assert_eq!(mem.read_u64(0x10_0000 + 4096 - 7), 0);
        // A straddling slice read is rejected as a whole and zero-filled, not partially copied.
        let mut dst = [0xaau8; 8];
        mem.read_slice(0x10_0000 + 4096 - 4, &mut dst[..]);
        assert_eq!(dst, [0u8; 8]);
    }

    #[test]
    fn oob_write_is_silently_dropped() {
        let (mem, _buf) = make_mem();
        mem.write_u32(0x0, 0xdeadbeef);
        mem.write_slice(u64::MAX - 100, &[0xff; 100]);
        // Neither write may have landed anywhere inside the mapped region.
        let mut all = vec![0u8; 4096];
        mem.read_slice(0x10_0000, &mut all);
        assert!(all.iter().enumerate().all(|(i, &b)| b == (i & 0xff) as u8));
    }

    #[test]
    fn descriptor_with_oob_next_terminates_chain() {
        let (mem, _buf) = make_mem();
        let base: u64 = 0x10_0000;
        let desc_table = base;
        let avail_ring = base + 0x100;
        // Descriptor 0 claims NEXT but links to index 9, outside a 4-entry table.
        mem.write_u64(desc_table, base + 0x400);
        mem.write_u32(desc_table + 8, 16);
        mem.write_u16(desc_table + 12, VRING_DESC_F_NEXT);
        mem.write_u16(desc_table + 14, 9);
        // Publish one available entry whose head is descriptor 0.
        mem.write_u16(avail_ring + 2, 1);
        mem.write_u16(avail_ring + 4, 0);

        let mut q = Queue::new(mem);
        q.size = 4;
        q.ready = true;
        q.desc_table = desc_table;
        q.avail_ring = avail_ring;
        q.used_ring = base + 0x200;

        let (head, chain) = q.pop_chain().expect("one chain was published");
        assert_eq!(head, 0);
        assert_eq!(chain.len(), 1, "walk stops at the out-of-range next index");
        assert_eq!(chain[0].addr, base + 0x400);
        assert_eq!(chain[0].len, 16);
        assert_eq!(chain[0].next, 9);
        assert!(q.pop_chain().is_none(), "the single published entry was consumed");
    }

    use proptest::prelude::*;

    // Guest-physical base and size of the memory used by the property test.
    const PT_BASE: u64 = 0x10_0000;
    const PT_LEN: usize = 8192;

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(512))]
        // Random descriptor-table contents and ring indices must never panic, the chain walk
        // must stay within `size` descriptors, and `add_used` must accept any head/len.
        #[test]
        fn pop_chain_is_panic_free_terminates_and_bounded(
            seed in any::<u64>(),
            size_log in 0u32..=8,
            last_avail in any::<u16>(),
            used_head in any::<u16>(),
            used_len in any::<u32>(),
        ) {
            // Fill guest memory with xorshift noise derived from `seed`.
            let mut buf = vec![0u8; PT_LEN];
            let mut x = seed | 1;
            for b in buf.iter_mut() {
                x ^= x << 13;
                x ^= x >> 7;
                x ^= x << 17;
                *b = (x & 0xff) as u8;
            }
            let mem = GuestMem::new(buf.as_mut_ptr(), PT_BASE, PT_LEN);
            let size: u16 = 1 << size_log;
            // Descriptor table 0..4096 (256 x 16 bytes), avail ring at 4096, used ring at
            // 6000: every area stays in bounds for size <= 256.
            let desc_table = PT_BASE;
            let avail_ring = PT_BASE + 4096;
            let used_ring = PT_BASE + 6000;
            // Publish exactly one new entry relative to `last_avail`.
            mem.write_u16(avail_ring + 2, last_avail.wrapping_add(1));
            let mut q = Queue::new(mem);
            q.size = size;
            q.ready = true;
            q.desc_table = desc_table;
            q.avail_ring = avail_ring;
            q.used_ring = used_ring;
            q.last_avail_idx = last_avail;
            if let Some((_head, chain)) = q.pop_chain() {
                prop_assert!(
                    chain.len() <= size as usize,
                    "chain len {} exceeded queue size {}",
                    chain.len(), size,
                );
            }
            q.add_used(used_head, used_len);
        }
    }
}