use std::mem::{ManuallyDrop, MaybeUninit};
use std::num::Wrapping;

use vm_memory::{
    AtomicAccess, ByteValued, GuestMemoryError, GuestMemoryRegion, GuestMemoryResult,
    MemoryRegionAddress, VolatileSlice,
};

use super::*;
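
/// A minimal guest memory region backed by a raw pointer, providing just
/// enough of the `vm-memory` traits for the Kani harnesses below. The backing
/// buffer must outlive the region.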
struct StubRegion {
buffer: *mut u8,
region_len: usize,
region_start: GuestAddress,
}
impl StubRegion {
fn new(buf_ptr: *mut u8, buf_len: usize, start_offset: u64) -> Self {
Self {
buffer: buf_ptr,
region_len: buf_len,
region_start: GuestAddress(start_offset),
}
}
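    /// Translates an absolute guest address into an offset inside this
    /// region, or `None` if the address falls outside of it.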
fn to_region_addr(&self, addr: GuestAddress) -> Option<MemoryRegionAddress> {
let offset = addr
.raw_value()
.checked_sub(self.region_start.raw_value())?;
if offset < self.region_len as u64 {
Some(MemoryRegionAddress(offset))
} else {
None
}
}
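    /// Checks that `count` bytes starting at `addr` fit inside the region,
    /// returning the exclusive end offset on success.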
fn checked_offset(
&self,
addr: MemoryRegionAddress,
count: usize,
) -> Option<MemoryRegionAddress> {
let end = addr.0.checked_add(count as u64)?;
if end <= self.region_len as u64 {
Some(MemoryRegionAddress(end))
} else {
None
}
}
}
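/// Only the parts of `GuestMemoryRegion` that the proofs exercise are
/// implemented; the remaining methods fall back to the trait's defaults.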
impl GuestMemoryRegion for StubRegion {
type B = ();
fn len(&self) -> <GuestAddress as vm_memory::AddressValue>::V {
self.region_len.try_into().unwrap()
}
fn start_addr(&self) -> GuestAddress {
self.region_start
}
fn bitmap(&self) -> Self::B {
()
}
fn get_slice(
&self,
offset: MemoryRegionAddress,
count: usize,
) -> GuestMemoryResult<VolatileSlice<()>> {
Ok(unsafe {
VolatileSlice::with_bitmap(
self.buffer.add(offset.raw_value() as usize),
count,
(),
None,
)
})
}
}
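/// Byte-level access to the stub region. Every method performs an explicit
/// bounds check before copying through the raw buffer pointer.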
impl Bytes<MemoryRegionAddress> for StubRegion {
type E = GuestMemoryError;
fn write(&self, buf: &[u8], addr: MemoryRegionAddress) -> Result<usize, Self::E> {
let offset = addr.0 as usize;
let end = offset
.checked_add(buf.len())
.ok_or(GuestMemoryError::InvalidGuestAddress(GuestAddress(addr.0)))?;
if end > self.region_len as usize {
return Err(GuestMemoryError::InvalidGuestAddress(GuestAddress(addr.0)));
}
unsafe {
std::ptr::copy_nonoverlapping(buf.as_ptr(), self.buffer.add(offset), buf.len());
}
Ok(buf.len())
}
fn read(&self, buf: &mut [u8], addr: MemoryRegionAddress) -> Result<usize, Self::E> {
let offset = addr.0 as usize;
let end = offset
.checked_add(buf.len())
.ok_or(GuestMemoryError::InvalidGuestAddress(GuestAddress(addr.0)))?;
if end > self.region_len as usize {
return Err(GuestMemoryError::InvalidGuestAddress(GuestAddress(addr.0)));
}
unsafe {
std::ptr::copy_nonoverlapping(self.buffer.add(offset), buf.as_mut_ptr(), buf.len());
}
Ok(buf.len())
}
fn write_slice(&self, buf: &[u8], addr: MemoryRegionAddress) -> Result<(), Self::E> {
self.write(buf, addr)?;
Ok(())
}
fn read_slice(&self, buf: &mut [u8], addr: MemoryRegionAddress) -> Result<(), Self::E> {
self.read(buf, addr)?;
Ok(())
}
fn read_obj<T: ByteValued>(&self, addr: MemoryRegionAddress) -> Result<T, Self::E> {
let size = std::mem::size_of::<T>();
let offset = addr.0 as usize;
let end = offset
.checked_add(size)
.ok_or(GuestMemoryError::InvalidGuestAddress(GuestAddress(addr.0)))?;
if end > self.region_len as usize {
return Err(GuestMemoryError::InvalidGuestAddress(GuestAddress(addr.0)));
}
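        // Zero-initialization is sound: `ByteValued` guarantees that any bit
        // pattern, including all zeroes, is a valid `T`.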
let mut result: T = unsafe { MaybeUninit::<T>::zeroed().assume_init() };
unsafe {
std::ptr::copy_nonoverlapping(
self.buffer.add(offset),
result.as_mut_slice().as_mut_ptr(),
size,
);
}
Ok(result)
}
fn write_obj<T: ByteValued>(&self, val: T, addr: MemoryRegionAddress) -> Result<(), Self::E> {
let size = std::mem::size_of::<T>();
let offset = addr.0 as usize;
let end = offset
.checked_add(size)
.ok_or(GuestMemoryError::InvalidGuestAddress(GuestAddress(addr.0)))?;
if end > self.region_len as usize {
return Err(GuestMemoryError::InvalidGuestAddress(GuestAddress(addr.0)));
}
let bytes = val.as_slice();
unsafe {
std::ptr::copy_nonoverlapping(bytes.as_ptr(), self.buffer.add(offset), size);
}
Ok(())
}
    fn read_volatile_from<F>(
        &self,
        addr: MemoryRegionAddress,
        src: &mut F,
        count: usize,
    ) -> Result<usize, Self::E>
    where
        F: vm_memory::ReadVolatile,
    {
        let offset = addr.0 as usize;
        let end = offset
            .checked_add(count)
            .ok_or(GuestMemoryError::InvalidGuestAddress(GuestAddress(addr.0)))?;
        if end > self.region_len as usize {
            return Err(GuestMemoryError::InvalidGuestAddress(GuestAddress(addr.0)));
        }
        // SAFETY: the bounds check above guarantees that `offset + count`
        // stays within the backing buffer.
        let mut slice = unsafe {
            VolatileSlice::from(std::slice::from_raw_parts_mut(self.buffer.add(offset), count))
        };
        src.read_volatile(&mut slice).map_err(Into::into)
    }
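    // NOTE: for these harnesses a partial read is treated as success; a real
    // implementation of `read_exact_volatile_from` would loop until `count`
    // bytes have been transferred.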
fn read_exact_volatile_from<F>(
&self,
addr: MemoryRegionAddress,
src: &mut F,
count: usize,
) -> Result<(), Self::E>
where
F: vm_memory::ReadVolatile,
{
let _ = self.read_volatile_from(addr, src, count)?;
Ok(())
}
fn write_volatile_to<F>(
&self,
addr: MemoryRegionAddress,
dst: &mut F,
count: usize,
) -> Result<usize, Self::E>
where
F: vm_memory::WriteVolatile,
{
let offset = addr.0 as usize;
let end = offset
.checked_add(count)
.ok_or(GuestMemoryError::InvalidGuestAddress(GuestAddress(addr.0)))?;
if end > self.region_len as usize {
return Err(GuestMemoryError::InvalidGuestAddress(GuestAddress(addr.0)));
}
        // SAFETY: the bounds check above guarantees that `offset + count`
        // stays within the backing buffer.
        let slice = unsafe {
            VolatileSlice::from(std::slice::from_raw_parts_mut(self.buffer.add(offset), count))
        };
        dst.write_volatile(&slice).map_err(Into::into)
}
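    // NOTE: as with `read_exact_volatile_from`, a partial write is accepted
    // here, which is sufficient for the proofs below.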
fn write_all_volatile_to<F>(
&self,
addr: MemoryRegionAddress,
dst: &mut F,
count: usize,
) -> Result<(), Self::E>
where
F: vm_memory::WriteVolatile,
{
let _ = self.write_volatile_to(addr, dst, count)?;
Ok(())
}
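    // The proofs are single-threaded, so the atomic `store`/`load` below
    // ignore the requested ordering and perform plain copies.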
fn store<T: AtomicAccess>(
&self,
val: T,
addr: MemoryRegionAddress,
_order: Ordering,
) -> Result<(), Self::E> {
let size = std::mem::size_of::<T>();
let offset = addr.0 as usize;
let end = offset
.checked_add(size)
.ok_or(GuestMemoryError::InvalidGuestAddress(GuestAddress(addr.0)))?;
if end > self.region_len as usize {
return Err(GuestMemoryError::InvalidGuestAddress(GuestAddress(addr.0)));
}
let bytes = val.as_slice();
unsafe {
std::ptr::copy_nonoverlapping(bytes.as_ptr(), self.buffer.add(offset), size);
}
Ok(())
}
fn load<T: AtomicAccess>(
&self,
addr: MemoryRegionAddress,
_order: Ordering,
) -> Result<T, Self::E> {
let size = std::mem::size_of::<T>();
let offset = addr.0 as usize;
let end = offset
.checked_add(size)
.ok_or(GuestMemoryError::InvalidGuestAddress(GuestAddress(addr.0)))?;
if end > self.region_len as usize {
return Err(GuestMemoryError::InvalidGuestAddress(GuestAddress(addr.0)));
}
unsafe {
let slice = std::slice::from_raw_parts(self.buffer.add(offset), size);
T::from_slice(slice)
.ok_or_else(|| GuestMemoryError::InvalidGuestAddress(GuestAddress(addr.0)))
.copied()
}
}
}
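/// Sanity check for the stub itself: any 16-byte payload written through
/// `StubRegion::write` must be readable back unchanged via `StubRegion::read`.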
#[kani::proof]
#[kani::unwind(0)]
fn verify_stubregion_write_read() {
let mut buffer = kani::vec::exact_vec::<u8, 16>();
let region = StubRegion::new(buffer.as_mut_ptr(), buffer.len(), 0);
let bytes: [u8; 16] = kani::any();
let write_offset: usize = kani::any();
kani::assume(write_offset <= buffer.len() - 16);
assert!(region
.write(&bytes, MemoryRegionAddress(write_offset as u64))
.is_ok());
let mut readback = kani::vec::exact_vec::<u8, 16>();
assert!(region
.read(&mut readback, MemoryRegionAddress(write_offset as u64))
.is_ok());
let idx: usize = kani::any();
kani::assume(idx < 16);
assert_eq!(bytes[idx], readback[idx]);
}
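/// A `GuestMemory` implementation that consists of exactly one `StubRegion`,
/// keeping the memory model the proofs have to reason about as small as
/// possible.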
struct SingleRegionGuestMemory {
the_region: StubRegion,
}
impl GuestMemory for SingleRegionGuestMemory {
type R = StubRegion;
fn num_regions(&self) -> usize {
1
}
fn find_region(&self, addr: GuestAddress) -> Option<&Self::R> {
self.the_region
.to_region_addr(addr)
.map(|_| &self.the_region)
}
fn iter(&self) -> impl Iterator<Item = &Self::R> {
std::iter::once(&self.the_region)
}
fn try_access<F>(
&self,
count: usize,
addr: GuestAddress,
mut f: F,
) -> vm_memory::guest_memory::Result<usize>
where
F: FnMut(
usize,
usize,
MemoryRegionAddress,
&Self::R,
) -> vm_memory::guest_memory::Result<usize>,
{
let region_addr = self
.the_region
.to_region_addr(addr)
.ok_or(vm_memory::guest_memory::Error::InvalidGuestAddress(addr))?;
self.the_region
.checked_offset(region_addr, count)
.ok_or(vm_memory::guest_memory::Error::InvalidGuestAddress(addr))?;
f(0, count, region_addr, &self.the_region)
}
}
impl kani::Arbitrary for SingleRegionGuestMemory {
fn any() -> Self {
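        // Leak the backing buffer via `ManuallyDrop` so the raw pointer handed
        // to `StubRegion` remains valid for the whole proof.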
let memory =
ManuallyDrop::new(kani::vec::exact_vec::<u8, GUEST_MEMORY_SIZE>()).as_mut_ptr();
Self {
the_region: StubRegion::new(memory, GUEST_MEMORY_SIZE, GUEST_MEMORY_BASE),
}
}
}
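/// A valid virtio queue together with the guest memory it lives in.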
struct ProofContext {
queue: Queue,
memory: SingleRegionGuestMemory,
}
const GUEST_MEMORY_BASE: u64 = 0;
// Leave some slack after the queue so that every valid ring placement fits.
const GUEST_MEMORY_SIZE: usize = QUEUE_END as usize + 30;
const MAX_QUEUE_SIZE: u16 = 4;
const QUEUE_BASE_ADDRESS: u64 = GUEST_MEMORY_BASE;
// The descriptor table occupies 16 bytes per entry.
const MAX_START_AVAIL_RING_BASE_ADDRESS: u64 = QUEUE_BASE_ADDRESS + MAX_QUEUE_SIZE as u64 * 16;
// The avail ring occupies 6 bytes (flags, idx, used_event) plus 2 bytes per
// entry, followed by 2 bytes of padding.
const MAX_START_USED_RING_BASE_ADDRESS: u64 =
    MAX_START_AVAIL_RING_BASE_ADDRESS + 6 + 2 * MAX_QUEUE_SIZE as u64 + 2;
// The used ring occupies 6 bytes (flags, idx, avail_event) plus 8 bytes per entry.
const QUEUE_END: u64 = MAX_START_USED_RING_BASE_ADDRESS + 6 + 8 * MAX_QUEUE_SIZE as u64;
impl kani::Arbitrary for ProofContext {
fn any() -> Self {
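        // Choose a nondeterministic descriptor table size (a multiple of 16
        // bytes, i.e. whole descriptors, up to MAX_QUEUE_SIZE entries) and lay
        // the avail and used rings out directly behind it.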
let desc_tbl_queue_size = kani::any::<u16>();
kani::assume(
desc_tbl_queue_size <= 16 * MAX_QUEUE_SIZE && (desc_tbl_queue_size & 0xF == 0),
);
let avail_ring_base_address: u64 = QUEUE_BASE_ADDRESS + desc_tbl_queue_size as u64;
let avail_ring_queue_size = kani::any::<u16>();
kani::assume(
avail_ring_queue_size <= 6 + 2 * MAX_QUEUE_SIZE + 2
&& (avail_ring_queue_size & 0x1 == 0),
);
let used_ring_base_address: u64 = avail_ring_base_address + avail_ring_queue_size as u64;
        let used_ring_queue_size = kani::any::<u16>();
        kani::assume(
            used_ring_queue_size <= 6 + 8 * MAX_QUEUE_SIZE && (used_ring_queue_size & 0x3 == 0),
        );
kani::assume(QUEUE_END == used_ring_base_address + used_ring_queue_size as u64);
let mem = SingleRegionGuestMemory::any();
let mut queue = Queue::new(MAX_QUEUE_SIZE).unwrap();
queue.ready = true;
queue.set_desc_table_address(
Some(QUEUE_BASE_ADDRESS as u32),
Some((QUEUE_BASE_ADDRESS >> 32) as u32),
);
queue.set_avail_ring_address(
Some(avail_ring_base_address as u32),
Some((avail_ring_base_address >> 32) as u32),
);
queue.set_used_ring_address(
Some(used_ring_base_address as u32),
Some((used_ring_base_address >> 32) as u32),
);
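        // Start from a fully arbitrary internal queue state and keep only
        // those configurations the device itself considers valid.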
queue.set_next_avail(kani::any());
queue.set_next_used(kani::any());
queue.set_event_idx(kani::any());
queue.num_added = Wrapping(kani::any());
kani::assume(queue.is_valid(&mem));
ProofContext { queue, memory: mem }
}
}
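/// Verifies the device-side notification suppression logic: with
/// VIRTIO_RING_F_EVENT_IDX disabled a notification is always required, and
/// with it enabled a notification must be signalled when the driver's
/// `used_event` matches the previously published used index and descriptors
/// were added since the last check.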
#[kani::proof]
#[kani::unwind(0)]
fn verify_device_notification_suppression() {
let ProofContext {
mut queue,
memory: mem,
} = kani::any();
let num_added_old = queue.num_added.0;
let needs_notification = queue.needs_notification(&mem);
    if !queue.event_idx_enabled {
        assert!(needs_notification.unwrap());
    } else if Wrapping(queue.used_event(&mem, Ordering::Relaxed).unwrap())
        == queue.next_used - Wrapping(1)
        && num_added_old > 0
    {
        assert!(needs_notification.unwrap());
        kani::cover!();
    }
}
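/// Reads the `idx` field of the used ring, which sits 2 bytes past the ring's
/// base address, right after the `flags` field.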
fn get_used_idx(
queue: &Queue,
mem: &SingleRegionGuestMemory,
) -> Result<u16, vm_memory::GuestMemoryError> {
let addr =
queue
.used_ring
.checked_add(2)
.ok_or(vm_memory::GuestMemoryError::InvalidGuestAddress(
queue.used_ring,
))?;
let val = mem.load(addr, Ordering::Acquire)?;
Ok(u16::from_le(val))
}
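/// Verifies that `add_used` either publishes exactly one used element,
/// advancing `next_used`, the ring's `idx` field, and `num_added` in lockstep,
/// or rejects an out-of-bounds descriptor index and leaves the queue and
/// guest memory untouched.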
#[kani::proof]
#[kani::unwind(0)]
fn verify_add_used() {
let ProofContext { mut queue, memory } = kani::any();
let used_idx = queue.next_used;
let used_desc_table_index = kani::any();
let old_val = get_used_idx(&queue, &memory).unwrap();
let old_num_added = queue.num_added;
if queue
.add_used(&memory, used_desc_table_index, kani::any())
.is_ok()
{
assert_eq!(queue.next_used, used_idx + Wrapping(1));
assert_eq!(queue.next_used.0, get_used_idx(&queue, &memory).unwrap());
assert_eq!(old_num_added + Wrapping(1), queue.num_added);
kani::cover!();
} else {
assert_eq!(queue.next_used, used_idx);
assert!(used_desc_table_index >= queue.size());
assert_eq!(old_val, get_used_idx(&queue, &memory).unwrap());
assert_eq!(old_num_added, queue.num_added);
kani::cover!();
}
}