mod packed;
mod split;
use crate::{Iova, IovaTranslator, Le16, VirtioFeatureFlags};
use bitflags::bitflags;
use libc::iovec;
use packed::VirtqueuePacked;
use split::VirtqueueSplit;
use std::io::{Error, ErrorKind};
use std::mem;
bitflags! {
    /// Flag bits carried in a virtqueue descriptor.
    ///
    /// The values match the virtio `VIRTQ_DESC_F_*` descriptor flags.
    struct VirtqueueDescriptorFlags: u16 {
        /// The descriptor chain continues with another descriptor.
        const NEXT = 1 << 0;
        /// The buffer is written by the device (device-to-driver data).
        const WRITE = 1 << 1;
        /// The buffer contains a table of indirect descriptors.
        const INDIRECT = 1 << 2;
    }
}
/// Byte layout of the memory that backs a virtqueue.
///
/// All offsets are relative to the start of the queue buffer. The buffer holds,
/// back to back: the descriptor table, the driver area, the device area and an
/// array of `queue_size` request slots of the caller's request type `R`.
pub struct VirtqueueLayout {
    /// Number of queues this layout was computed for. It is stored as given;
    /// none of the offsets below are derived from it.
    pub num_queues: usize,
    /// Offset of the driver area (available ring for split queues, driver
    /// event suppression structure for packed queues). Equals the size of the
    /// descriptor table.
    pub driver_area_offset: usize,
    /// Offset of the device area (used ring for split queues, device event
    /// suppression structure for packed queues).
    pub device_area_offset: usize,
    /// Offset of the request slot array, rounded up to `align_of::<R>()`.
    pub req_offset: usize,
    /// Total number of bytes required, rounded up to a multiple of 16.
    pub end_offset: usize,
}
impl VirtqueueLayout {
    /// Computes the layout of a queue with `queue_size` descriptors whose
    /// per-request slots have type `R`.
    ///
    /// The ring format (packed or split) is chosen by
    /// `VirtioFeatureFlags::RING_PACKED` in `features`.
    ///
    /// # Errors
    ///
    /// For split rings, returns `ErrorKind::InvalidInput` when `queue_size` is
    /// not a power of two or is larger than 32768. Packed ring sizes are not
    /// validated here.
    pub fn new<R>(
        num_queues: usize,
        queue_size: usize,
        features: VirtioFeatureFlags,
    ) -> Result<Self, Error> {
        if features.contains(VirtioFeatureFlags::RING_PACKED) {
            // Driver and device areas are one event suppression structure each.
            let event_suppress = mem::size_of::<packed::VirtqueueEventSuppress>();
            return Self::new_layout::<R>(
                num_queues,
                queue_size,
                mem::size_of::<packed::VirtqueueDescriptor>() * queue_size,
                event_suppress,
                event_suppress,
            );
        }

        // Split ring: descriptor table, available ring, used ring. Each ring
        // reserves 8 bytes of fixed fields plus one entry per descriptor.
        let desc_table = mem::size_of::<split::VirtqueueDescriptor>() * queue_size;
        let avail_ring = 8 + mem::size_of::<Le16>() * queue_size;
        let used_ring = 8 + mem::size_of::<split::VirtqueueUsedElem>() * queue_size;

        let valid_size = queue_size <= 32768 && queue_size.is_power_of_two();
        if !valid_size {
            return Err(Error::new(ErrorKind::InvalidInput, "Invalid queue size"));
        }

        // Round to a multiple of 4, presumably so the used ring that follows
        // stays 4-byte aligned as virtio requires.
        let avail_ring_padded = (avail_ring + 3) & !0x3;
        Self::new_layout::<R>(
            num_queues,
            queue_size,
            desc_table,
            avail_ring_padded,
            used_ring,
        )
    }

    /// Places the descriptor table, driver area, device area and request slots
    /// consecutively. The request slots are aligned for `R` and the total is
    /// padded to a multiple of 16 bytes.
    fn new_layout<R>(
        num_queues: usize,
        queue_size: usize,
        desc_bytes: usize,
        driver_area_bytes: usize,
        device_area_bytes: usize,
    ) -> Result<Self, Error> {
        /// Rounds `value` up to a multiple of `align`, which must be a power of two.
        fn align_up(value: usize, align: usize) -> usize {
            (value + align - 1) & !(align - 1)
        }

        let slot_bytes = mem::size_of::<R>() * queue_size;
        let driver_area_offset = desc_bytes;
        let device_area_offset = desc_bytes + driver_area_bytes;
        let req_offset = align_up(device_area_offset + device_area_bytes, mem::align_of::<R>());
        let end_offset = align_up(req_offset + slot_bytes, 16);

        Ok(Self {
            num_queues,
            driver_area_offset,
            device_area_offset,
            req_offset,
            end_offset,
        })
    }
}
/// Operations the split and packed ring formats provide to `Virtqueue`.
///
/// A request is built as a descriptor chain. `avail_start_chain` obtains a
/// chain id, and `avail_add_desc_chain` appends descriptors to it. The chain is
/// then either handed to the device with `avail_publish` or abandoned with
/// `avail_rewind_chain`. Ids of completed chains are read back with `used_next`.
trait VirtqueueFormat {
    // ----- Geometry ------------------------------------------------------

    /// The queue size the ring was created with (presumably; confirm in impls).
    fn queue_size(&self) -> u16;
    /// Start of the descriptor table.
    fn desc_table_ptr(&self) -> *const u8;
    /// Start of the driver area.
    fn driver_area_ptr(&self) -> *const u8;
    /// Start of the device area.
    fn device_area_ptr(&self) -> *const u8;

    // ----- Driver side: building and publishing chains -------------------

    /// Starts a new descriptor chain and returns its id.
    ///
    /// Returns `None` when a chain cannot be started; `Virtqueue` reports this
    /// as "not enough free descriptors". The id is also used by `Virtqueue` as
    /// the index of the chain's request slot.
    fn avail_start_chain(&mut self) -> Option<u16>;
    /// Appends one descriptor for the buffer at IOVA `addr` of `len` bytes to
    /// the chain being built, and returns the new descriptor's index.
    fn avail_add_desc_chain(
        &mut self,
        addr: u64,
        len: u32,
        flags: VirtqueueDescriptorFlags,
    ) -> Result<u16, Error>;
    /// Makes chain `chain_id`, ending at descriptor `last_desc_idx`, visible to
    /// the device.
    ///
    /// NOTE(review): `Virtqueue` sets `NEXT` on every descriptor it adds. The
    /// format presumably clears it on `last_desc_idx` here; confirm in impls.
    fn avail_publish(&mut self, chain_id: u16, last_desc_idx: u16);
    /// Discards chain `chain_id`, which may be only partially built. Used when
    /// preparing a request fails.
    fn avail_rewind_chain(&mut self, chain_id: u16);
    /// Whether the device needs to be notified (presumably about newly
    /// published chains; confirm in impls).
    fn avail_notif_needed(&mut self) -> bool;

    // ----- Device side: completions --------------------------------------

    /// Whether a completed chain can be read with `used_next`.
    fn used_has_next(&self) -> bool;
    /// Consumes the next completed chain and returns its id, if any.
    fn used_next(&mut self) -> Option<u16>;
    /// Size hint for the number of completions available (iterator convention).
    fn used_size_hint(&self) -> (usize, Option<usize>);
    /// Enables or disables used-buffer notifications (name-based; confirm).
    fn set_used_notif_enabled(&mut self, enabled: bool);
}
/// A single virtqueue whose rings and request slots live in a caller-provided
/// buffer borrowed for `'a`.
///
/// `R` is per-request data stored in a slot alongside each descriptor chain.
/// It is copied out again when the request completes.
pub struct Virtqueue<'a, R: Copy> {
    /// Translates buffer addresses into the IOVAs written into descriptors.
    iova_translator: Box<dyn IovaTranslator>,
    /// Split or packed ring implementation operating on the ring memory.
    format: Box<dyn VirtqueueFormat + 'a>,
    /// First of `queue_size` request slots inside the queue buffer, indexed by
    /// chain id.
    req: *mut R,
    /// Offsets used to carve up the queue buffer.
    layout: VirtqueueLayout,
}

// SAFETY: `req` points into the buffer that `Virtqueue::new` borrows
// exclusively (`&'a mut [u8]`), so moving the queue moves the only access path
// with it.
// NOTE(review): this impl requires neither `R: Send` nor `Send` on the boxed
// `IovaTranslator` / `VirtqueueFormat`. A non-`Send` `R` (e.g. `&Cell<_>`)
// could therefore reach another thread via `completions()`. Confirm, and
// consider tightening the bounds.
unsafe impl<R> Send for Virtqueue<'_, R> where R: Copy {}
// SAFETY: the `&self` methods defined in this file only query the queue size,
// the layout and the ring pointers; they never touch the request slots.
// NOTE(review): this assumes the format's `&self` methods have no interior
// mutation; confirm.
unsafe impl<R> Sync for Virtqueue<'_, R> where R: Copy {}
/// A request the device has finished, as yielded by `VirtqueueIter`.
pub struct VirtqueueCompletion<R> {
    /// Id of the completed descriptor chain; the value `add_request` returned.
    pub id: u16,
    /// Copy of the request slot that belongs to `id`.
    pub req: R,
}
impl<'a, R: Copy> Virtqueue<'a, R> {
    /// Creates a virtqueue whose rings and request slots live in `buf`.
    ///
    /// `buf` is split according to `VirtqueueLayout::new` for one queue of
    /// `queue_size` descriptors. The ring format (packed or split) and
    /// event-index support are selected from `features`.
    ///
    /// # Errors
    ///
    /// Returns `ErrorKind::InvalidInput` if the layout rejects `queue_size`, if
    /// `buf` is shorter than the layout's `end_offset`, or if the request slot
    /// area is not aligned for `R`. Errors from the ring format's constructor
    /// are propagated unchanged.
    pub fn new(
        iova_translator: Box<dyn IovaTranslator>,
        buf: &'a mut [u8],
        queue_size: u16,
        features: VirtioFeatureFlags,
    ) -> Result<Self, Error> {
        let layout = VirtqueueLayout::new::<R>(1, usize::from(queue_size), features)?;
        let event_idx_enabled = features.contains(VirtioFeatureFlags::RING_EVENT_IDX);

        let queue_mem = buf.get_mut(0..layout.end_offset).ok_or_else(|| {
            Error::new(ErrorKind::InvalidInput, "Incorrectly sized queue buffer")
        })?;
        // Both ring formats use the same order: descriptor table, driver area,
        // device area, request slots.
        let (ring_mem, req_mem) = queue_mem.split_at_mut(layout.req_offset);
        let (ring_mem, device_mem) = ring_mem.split_at_mut(layout.device_area_offset);
        let (desc_mem, driver_mem) = ring_mem.split_at_mut(layout.driver_area_offset);

        let format: Box<dyn VirtqueueFormat + 'a> =
            if features.contains(VirtioFeatureFlags::RING_PACKED) {
                Box::new(VirtqueuePacked::new(
                    desc_mem,
                    driver_mem,
                    device_mem,
                    queue_size,
                    event_idx_enabled,
                )?)
            } else {
                Box::new(VirtqueueSplit::new(
                    driver_mem,
                    device_mem,
                    desc_mem,
                    queue_size,
                    event_idx_enabled,
                )?)
            };

        let req = req_mem.as_mut_ptr() as *mut R;
        // Check the address directly: `align_offset` is documented as allowed
        // to return `usize::MAX`, so it must not decide correctness.
        if (req as usize) % mem::align_of::<R>() != 0 {
            return Err(Error::new(
                ErrorKind::InvalidInput,
                "Insufficient memory alignment",
            ));
        }

        Ok(Virtqueue {
            iova_translator,
            format,
            req,
            layout,
        })
    }

    /// Queue size as reported by the ring format.
    pub fn queue_size(&self) -> u16 {
        self.format.queue_size()
    }

    /// Layout used to carve up the queue buffer.
    pub fn layout(&self) -> &VirtqueueLayout {
        &self.layout
    }

    /// Start of the descriptor table inside the queue buffer.
    pub fn desc_table_ptr(&self) -> *const u8 {
        self.format.desc_table_ptr()
    }

    /// Start of the driver area inside the queue buffer.
    pub fn driver_area_ptr(&self) -> *const u8 {
        self.format.driver_area_ptr()
    }

    /// Start of the device area inside the queue buffer.
    pub fn device_area_ptr(&self) -> *const u8 {
        self.format.device_area_ptr()
    }

    /// Builds and publishes one request.
    ///
    /// A descriptor chain is started. `prepare` is then called with the
    /// request slot of that chain and a callback that appends one buffer to the
    /// chain. The callback's `bool` argument is `true` for buffers the device
    /// writes; those get the `WRITE` flag. If `prepare` succeeds, the chain is
    /// published and its id returned. The same id appears in the matching
    /// `VirtqueueCompletion`.
    ///
    /// # Errors
    ///
    /// * `ErrorKind::Other` if no descriptor chain can be started.
    /// * `ErrorKind::InvalidInput` if `prepare` added no buffers, or a buffer is
    ///   longer than `u32::MAX` bytes.
    /// * Any error returned by `prepare`, the IOVA translator or the ring format.
    ///
    /// On every error after the chain was started, the chain is rewound.
    pub fn add_request<F>(&mut self, prepare: F) -> Result<u16, Error>
    where
        F: FnOnce(&mut R, &mut dyn FnMut(iovec, bool) -> Result<(), Error>) -> Result<(), Error>,
    {
        let chain_id = self
            .format
            .avail_start_chain()
            .ok_or_else(|| Error::new(ErrorKind::Other, "Not enough free descriptors"))?;

        // SAFETY: `self.req` points at the `queue_size` request slots carved
        // out in `new`, and the chain id selects the slot.
        // NOTE(review): this is sound only if the format hands out ids below
        // `queue_size`, and if any bit pattern in the buffer is a valid `R`.
        let req = unsafe { &mut *self.req.add(usize::from(chain_id)) };

        let mut last_desc_idx: Option<u16> = None;
        let res = prepare(req, &mut |iovec: iovec, from_dev: bool| {
            // Descriptor lengths are 32-bit; reject instead of truncating.
            if iovec.iov_len > u32::MAX as usize {
                return Err(Error::new(
                    ErrorKind::InvalidInput,
                    "Buffer too large for a descriptor",
                ));
            }
            let mut flags = VirtqueueDescriptorFlags::NEXT;
            if from_dev {
                flags.insert(VirtqueueDescriptorFlags::WRITE);
            }
            let Iova(iova) = self
                .iova_translator
                .translate_addr(iovec.iov_base as usize, iovec.iov_len)?;
            last_desc_idx = Some(self.format.avail_add_desc_chain(
                iova,
                iovec.iov_len as u32,
                flags,
            )?);
            Ok(())
        });

        match (res, last_desc_idx) {
            (Ok(()), Some(last_desc_idx)) => {
                self.format.avail_publish(chain_id, last_desc_idx);
                Ok(chain_id)
            }
            (Ok(()), None) => {
                // Publishing needs the last descriptor's index, so an empty
                // chain is rewound and rejected.
                self.format.avail_rewind_chain(chain_id);
                Err(Error::new(
                    ErrorKind::InvalidInput,
                    "Request has no buffers",
                ))
            }
            (Err(e), _) => {
                self.format.avail_rewind_chain(chain_id);
                Err(e)
            }
        }
    }

    /// Returns an iterator over the requests the device has completed.
    pub fn completions(&mut self) -> VirtqueueIter<'_, 'a, R> {
        VirtqueueIter { virtqueue: self }
    }

    /// Forwards to the ring format. Presumably reports whether the device must
    /// be notified about newly published requests.
    pub fn avail_notif_needed(&mut self) -> bool {
        self.format.avail_notif_needed()
    }

    /// Forwards to the ring format to enable or disable used-buffer
    /// notifications.
    pub fn set_used_notif_enabled(&mut self, enabled: bool) {
        self.format.set_used_notif_enabled(enabled)
    }
}
/// Iterator over the completed requests of a `Virtqueue`.
///
/// Each item pairs the id of a completed descriptor chain with a copy of that
/// chain's request slot.
pub struct VirtqueueIter<'a, 'queue, R: Copy> {
    virtqueue: &'a mut Virtqueue<'queue, R>,
}
impl<R: Copy> VirtqueueIter<'_, '_, R> {
    /// Asks the ring format whether another completion is available. This
    /// takes `&self` and does not consume the completion.
    pub fn has_next(&self) -> bool {
        let format = &self.virtqueue.format;
        format.used_has_next()
    }
}
impl<R: Copy> Iterator for VirtqueueIter<'_, '_, R> {
    type Item = VirtqueueCompletion<R>;

    fn next(&mut self) -> Option<Self::Item> {
        let queue = &mut *self.virtqueue;
        match queue.format.used_next() {
            None => None,
            Some(id) => {
                // SAFETY: `queue.req` points at `queue_size` request slots and
                // `id` is the index of a completed chain.
                // NOTE(review): sound only if `used_next` never yields an id of
                // `queue_size` or more; confirm in the format impls.
                let req = unsafe { queue.req.add(usize::from(id)).read() };
                Some(VirtqueueCompletion { id, req })
            }
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let format = &self.virtqueue.format;
        format.used_size_hint()
    }
}