#![allow(clippy::declare_interior_mutable_const)]
use core::{
cell::UnsafeCell,
mem::MaybeUninit,
sync::atomic::{AtomicU32, Ordering},
};
use crate::{buffer::Buffer, endpoint::Endpoint, qh::Qh, td::Td};
use usb_device::{
endpoint::{EndpointAddress, EndpointType},
UsbDirection,
};
/// Backing storage for the transfer descriptors (TDs), one per endpoint slot.
///
/// 32-byte alignment: presumably required by the USB controller's dTD layout
/// (each descriptor must sit on a 32-byte boundary) — NOTE(review): inferred
/// from the repr attribute; confirm against the `Td` definition / reference
/// manual.
#[repr(align(32))]
struct TdList<const COUNT: usize>([UnsafeCell<Td>; COUNT]);

impl<const COUNT: usize> TdList<COUNT> {
    const fn new() -> Self {
        // A `const` item is instantiated once per array element; this is the
        // stable way to build `[UnsafeCell<Td>; COUNT]` in a `const fn`,
        // since `UnsafeCell` is not `Copy`.
        const TD: UnsafeCell<Td> = UnsafeCell::new(Td::new());
        Self([TD; COUNT])
    }
}
/// Backing storage for the queue heads (QHs), one per endpoint slot.
///
/// 4096-byte alignment: presumably required because the controller's
/// endpoint-list base-address register ignores the low address bits, so the
/// QH list must be page-aligned — NOTE(review): inferred from the repr
/// attribute; confirm against the controller's reference manual.
#[repr(align(4096))]
struct QhList<const COUNT: usize>([UnsafeCell<Qh>; COUNT]);

impl<const COUNT: usize> QhList<COUNT> {
    const fn new() -> Self {
        // Per-element `const` item: the stable pattern for const-initializing
        // an array of non-`Copy` `UnsafeCell`s.
        const QH: UnsafeCell<Qh> = UnsafeCell::new(Qh::new());
        Self([QH; COUNT])
    }
}
/// Backing storage for the software-side `Endpoint` wrappers.
///
/// Stored as `MaybeUninit` because entries are only constructed lazily, when
/// `EndpointAllocator::allocate_endpoint` writes them; no special alignment
/// is needed since this state is never read by hardware directly (unlike the
/// QH/TD lists).
struct EpList<const COUNT: usize>([UnsafeCell<MaybeUninit<Endpoint>>; COUNT]);

impl<const COUNT: usize> EpList<COUNT> {
    const fn new() -> Self {
        // Per-element `const` item, as in `TdList`/`QhList`; every slot
        // starts uninitialized.
        const EP: UnsafeCell<MaybeUninit<Endpoint>> = UnsafeCell::new(MaybeUninit::uninit());
        Self([EP; COUNT])
    }
}
/// Eight hardware endpoints, each with an OUT half and an IN half.
pub const MAX_ENDPOINTS: usize = 8 * 2;

/// Maps a USB endpoint address onto its slot in the QH/TD/endpoint lists.
///
/// OUT endpoints occupy the even indices, IN endpoints the odd indices, so
/// endpoint N uses slots `2 * N` and `2 * N + 1`.
fn index(ep_addr: EndpointAddress) -> usize {
    let direction_bit = match ep_addr.direction() {
        UsbDirection::Out => 0,
        UsbDirection::In => 1,
    };
    2 * ep_addr.index() + direction_bit
}
/// All static state for `COUNT` endpoint slots: the hardware-visible QH and
/// TD lists, the software `Endpoint` wrappers, and the allocation bitmask.
pub struct EndpointState<const COUNT: usize = MAX_ENDPOINTS> {
    qh_list: QhList<COUNT>,
    td_list: TdList<COUNT>,
    ep_list: EpList<COUNT>,
    // Bit 31 records that the (single) allocator has been taken; the low
    // bits record which endpoint slots have been allocated. See `allocator`
    // and `EndpointAllocator`.
    alloc_mask: AtomicU32,
}

// SAFETY: the `UnsafeCell` contents are only reached through the
// `EndpointAllocator`, which `allocator()` hands out at most once (guarded by
// bit 31 of `alloc_mask`), so the cells are never aliased across threads.
// NOTE(review): soundness also relies on `Qh`/`Td`/`Endpoint` being safe to
// access from whichever thread holds the allocator — confirm.
unsafe impl<const COUNT: usize> Sync for EndpointState<COUNT> {}
impl EndpointState<MAX_ENDPOINTS> {
pub const fn max_endpoints() -> Self {
Self::new()
}
}
impl<const COUNT: usize> EndpointState<COUNT> {
    /// Creates a fresh state: all QHs/TDs zero-initialized, all endpoint
    /// slots uninitialized, nothing allocated.
    pub const fn new() -> Self {
        Self {
            qh_list: QhList::new(),
            td_list: TdList::new(),
            ep_list: EpList::new(),
            alloc_mask: AtomicU32::new(0),
        }
    }

    /// Hands out the endpoint allocator, at most once.
    ///
    /// Every call after the first returns `None`. The lists exposed to the
    /// allocator are capped at `MAX_ENDPOINTS` entries, because the
    /// allocation bitmask can only track that many slots.
    pub(crate) fn allocator(&self) -> Option<EndpointAllocator> {
        const ALLOCATOR_TAKEN: u32 = 1 << 31;
        let previous = self.alloc_mask.fetch_or(ALLOCATOR_TAKEN, Ordering::SeqCst);
        if previous & ALLOCATOR_TAKEN != 0 {
            // Someone already claimed the allocator.
            return None;
        }
        let limit = COUNT.min(MAX_ENDPOINTS);
        Some(EndpointAllocator {
            qh_list: &self.qh_list.0[..limit],
            td_list: &self.td_list.0[..limit],
            ep_list: &self.ep_list.0[..limit],
            alloc_mask: &self.alloc_mask,
        })
    }
}
/// One-shot allocator over the endpoint storage, obtained from
/// `EndpointState::allocator`.
pub struct EndpointAllocator<'a> {
    // Slices over the state's lists, capped at `MAX_ENDPOINTS` entries.
    qh_list: &'a [UnsafeCell<Qh>],
    td_list: &'a [UnsafeCell<Td>],
    ep_list: &'a [UnsafeCell<MaybeUninit<Endpoint>>],
    // Shared with `EndpointState`: low bits record allocated slots.
    alloc_mask: &'a AtomicU32,
}

// SAFETY: there is only ever one `EndpointAllocator` per `EndpointState`
// (enforced by bit 31 of `alloc_mask`), so moving it to another thread
// cannot introduce aliased access to the `UnsafeCell`s. NOTE(review): also
// relies on `Qh`/`Td`/`Endpoint` being usable from another thread — confirm.
unsafe impl Send for EndpointAllocator<'_> {}
impl EndpointAllocator<'_> {
    /// Atomically claims the bits in `mask`; returns `None` if any of them
    /// were already set (i.e. the slot was already allocated).
    fn try_mask_update(&mut self, mask: u16) -> Option<()> {
        let mask = mask.into();
        // `fetch_or` returns the previous value: success iff none of our
        // bits were set before.
        (mask & self.alloc_mask.fetch_or(mask, Ordering::SeqCst) == 0).then_some(())
    }

    /// Returns `Some(())` iff the slot at `index` has been allocated.
    fn check_allocated(&self, index: usize) -> Option<()> {
        // The list length is capped at MAX_ENDPOINTS (16) by
        // `EndpointState::allocator`, so the u16 shift cannot overflow.
        let mask = (index < self.qh_list.len()).then_some(1u16 << index)?;
        // `as u16` intentionally discards the allocator-taken bit (bit 31).
        (mask & self.alloc_mask.load(Ordering::SeqCst) as u16 != 0).then_some(())
    }

    /// Base address of the QH list, for programming into the controller's
    /// endpoint-list-address register.
    pub fn qh_list_addr(&self) -> *const () {
        self.qh_list.as_ptr().cast()
    }

    /// Number of endpoint slots this allocator manages.
    pub fn capacity(&self) -> usize {
        self.ep_list.len()
    }

    /// Shared access to an already-allocated endpoint, or `None` if the
    /// address was never allocated.
    pub fn endpoint(&self, addr: EndpointAddress) -> Option<&Endpoint> {
        let index = index(addr);
        self.check_allocated(index)?;
        // SAFETY: only this allocator touches the cell, and `&self` here can
        // only coexist with other shared borrows (no `&mut` methods can run
        // concurrently on the same allocator).
        let ep = unsafe { &*self.ep_list[index].get() };
        // SAFETY: `check_allocated` succeeded, so `allocate_endpoint` has
        // initialized this slot.
        Some(unsafe { ep.assume_init_ref() })
    }

    /// Exclusive access to an already-allocated endpoint, or `None` if the
    /// address was never allocated.
    pub fn endpoint_mut(&mut self, addr: EndpointAddress) -> Option<&mut Endpoint> {
        let index = index(addr);
        self.check_allocated(index)?;
        // SAFETY: `&mut self` guarantees no other borrow of any cell is
        // live, and this allocator is the only handle to the cells.
        let ep = unsafe { &mut *self.ep_list[index].get() };
        // SAFETY: `check_allocated` succeeded, so the slot is initialized.
        Some(unsafe { ep.assume_init_mut() })
    }

    /// Allocates the endpoint at `addr`, wiring it to its QH/TD pair and the
    /// given transfer `buffer`. Returns `None` if `addr` is out of range or
    /// was already allocated.
    pub fn allocate_endpoint(
        &mut self,
        addr: EndpointAddress,
        buffer: Buffer,
        kind: EndpointType,
    ) -> Option<&mut Endpoint> {
        let index = index(addr);
        // Same in-range guard as `check_allocated`: the u16 shift is safe
        // because the list holds at most MAX_ENDPOINTS (16) entries.
        let mask = (index < self.qh_list.len()).then_some(1u16 << index)?;
        self.try_mask_update(mask)?;
        // SAFETY: `try_mask_update` succeeded, so this is the first (and
        // only) claim of slot `index`; together with `&mut self` that gives
        // exclusive access to these three cells. NOTE(review): the `&mut`
        // borrows handed to `Endpoint::new` outlive this call via the
        // allocator's lifetime — relies on `Endpoint` owning them soundly;
        // confirm against `Endpoint::new`.
        let qh = unsafe { &mut *self.qh_list[index].get() };
        let td = unsafe { &mut *self.td_list[index].get() };
        let ep = unsafe { &mut *self.ep_list[index].get() };
        ep.write(Endpoint::new(addr, qh, td, buffer, kind));
        // SAFETY: the slot was just written above.
        Some(unsafe { ep.assume_init_mut() })
    }
}
#[cfg(test)]
mod tests {
    use super::{EndpointAddress, EndpointState, EndpointType};
    use crate::buffer;

    /// The allocator can be taken exactly once per state.
    #[test]
    fn acquire_allocator() {
        let state = EndpointState::max_endpoints();
        assert!(state.allocator().is_some());
        for _ in 0..10 {
            assert!(state.allocator().is_none());
        }
    }

    /// Endpoints can be allocated once per address, and the OUT/IN halves
    /// of the same endpoint number are independent slots.
    #[test]
    fn allocate_endpoint() {
        let mut backing = [0; 128];
        let mut mem = unsafe { buffer::Allocator::from_buffer(&mut backing) };
        let state = EndpointState::max_endpoints();
        let mut alloc = state.allocator().unwrap();

        // Before allocation, lookups fail.
        let out_addr = EndpointAddress::from(0);
        assert!(alloc.endpoint(out_addr).is_none());
        assert!(alloc.endpoint_mut(out_addr).is_none());

        // First allocation succeeds and is visible through both accessors.
        let ep = alloc
            .allocate_endpoint(out_addr, mem.allocate(2).unwrap(), EndpointType::Bulk)
            .unwrap();
        assert_eq!(ep.address(), out_addr);
        assert!(alloc.endpoint(out_addr).is_some());
        assert!(alloc.endpoint_mut(out_addr).is_some());

        // Re-allocating the same address fails without disturbing the slot.
        let second = alloc
            .allocate_endpoint(out_addr, mem.allocate(2).unwrap(), EndpointType::Bulk);
        assert!(second.is_none());
        assert!(alloc.endpoint(out_addr).is_some());
        assert!(alloc.endpoint_mut(out_addr).is_some());

        // The IN half (direction bit set) is a separate, still-free slot.
        let in_addr = EndpointAddress::from(1 << 7);
        assert!(alloc.endpoint(in_addr).is_none());
        assert!(alloc.endpoint_mut(in_addr).is_none());
        let ep = alloc
            .allocate_endpoint(in_addr, mem.allocate(2).unwrap(), EndpointType::Bulk)
            .unwrap();
        assert_eq!(ep.address(), in_addr);
    }
}