use super::VirtQueue;
use crate::{Error, Hal, Result, transport::Transport};
use alloc::boxed::Box;
use core::convert::TryInto;
use core::ptr::{NonNull, null_mut};
use zerocopy::FromZeros;
/// A [`VirtQueue`] bundled with the buffers that have been added to it, which this type owns.
///
/// When the queue is created, one zeroed buffer of `BUFFER_SIZE` bytes is allocated for each of
/// the `SIZE` descriptors. Each buffer is passed to [`VirtQueue::add`] in its mutable buffer
/// list, so presumably the device writes into it. All buffers are freed when the
/// `OwningQueue` is dropped.
#[derive(Debug)]
pub struct OwningQueue<H: Hal, const SIZE: usize, const BUFFER_SIZE: usize> {
/// The underlying virtqueue that the buffers are added to.
queue: VirtQueue<H, SIZE>,
/// Heap buffers indexed by token: `buffers[i]` is the buffer whose descriptor token is `i`.
/// They come from `Box::into_raw` and are released in `Drop`.
buffers: [NonNull<[u8; BUFFER_SIZE]>; SIZE],
}
impl<H: Hal, const SIZE: usize, const BUFFER_SIZE: usize> OwningQueue<H, SIZE, BUFFER_SIZE> {
pub fn new(mut queue: VirtQueue<H, SIZE>) -> Result<Self> {
let mut buffers = [null_mut(); SIZE];
for (i, queue_buffer) in buffers.iter_mut().enumerate() {
let mut buffer: Box<[u8; BUFFER_SIZE]> = FromZeros::new_box_zeroed().unwrap();
let token = unsafe { queue.add(&[], &mut [buffer.as_mut_slice()]) }?;
assert_eq!(i, token.into());
*queue_buffer = Box::into_raw(buffer);
}
let buffers = buffers.map(|ptr| NonNull::new(ptr).unwrap());
Ok(Self { queue, buffers })
}
pub fn should_notify(&self) -> bool {
self.queue.should_notify()
}
pub fn set_dev_notify(&mut self, enable: bool) {
self.queue.set_dev_notify(enable);
}
unsafe fn add_buffer_to_queue(&mut self, index: u16, transport: &mut impl Transport) -> Result {
unsafe {
let buffer = self
.buffers
.get_mut(usize::from(index))
.ok_or(Error::WrongToken)?
.as_mut();
let new_token = self.queue.add(&[], &mut [buffer])?;
assert_eq!(new_token, index);
}
if self.queue.should_notify() {
transport.notify(self.queue.queue_idx);
}
Ok(())
}
fn pop(&mut self) -> Result<Option<(&[u8], u16)>> {
let Some(token) = self.queue.peek_used() else {
return Ok(None);
};
let buffer = unsafe { self.buffers[usize::from(token)].as_mut() };
let len = unsafe { self.queue.pop_used(token, &[], &mut [buffer])? }
.try_into()
.unwrap();
Ok(Some((&buffer[0..len], token)))
}
pub fn poll<T>(
&mut self,
transport: &mut impl Transport,
handler: impl FnOnce(&[u8]) -> Result<Option<T>>,
) -> Result<Option<T>> {
let Some((buffer, token)) = self.pop()? else {
return Ok(None);
};
let result = handler(buffer);
unsafe {
self.add_buffer_to_queue(token, transport)?;
}
result
}
}
// SAFETY: The buffer pointers are uniquely owned by this `OwningQueue`: they come from
// `Box::into_raw` and are released in `Drop`. Moving the wrapper to another thread therefore
// moves ownership of the plain byte buffers with it. The only other state is the `VirtQueue`,
// which the bound below requires to be `Send`.
unsafe impl<H: Hal, const SIZE: usize, const BUFFER_SIZE: usize> Send
for OwningQueue<H, SIZE, BUFFER_SIZE>
where
VirtQueue<H, SIZE>: Send,
{
}
// SAFETY: The buffers are only dereferenced through `&mut self` methods (`pop`,
// `add_buffer_to_queue`, `drop`). The only `&self` method (`should_notify`) reads just the
// `VirtQueue`, which the bound below requires to be `Sync`.
unsafe impl<H: Hal, const SIZE: usize, const BUFFER_SIZE: usize> Sync
for OwningQueue<H, SIZE, BUFFER_SIZE>
where
VirtQueue<H, SIZE>: Sync,
{
}
impl<H: Hal, const SIZE: usize, const BUFFER_SIZE: usize> Drop
    for OwningQueue<H, SIZE, BUFFER_SIZE>
{
    /// Reclaims and frees every buffer this queue owns.
    ///
    /// NOTE(review): the buffers are freed while `self.queue` (dropped after this method
    /// returns) still lists them. Confirm the device has stopped using the queue by now.
    fn drop(&mut self) {
        self.buffers.iter().for_each(|owned| {
            // SAFETY: Each pointer was produced by `Box::into_raw` when the queue was built, and
            // this is the only place it is turned back into a `Box`.
            let reclaimed = unsafe { Box::from_raw(owned.as_ptr()) };
            drop(reclaimed);
        });
    }
}