#![allow(clippy::declare_interior_mutable_const)]
use core::{
cell::UnsafeCell,
mem::MaybeUninit,
sync::atomic::{AtomicU32, Ordering},
};
use crate::{buffer::Buffer, endpoint::Endpoint, qh::Qh, td::Td};
use usb_device::{
UsbDirection,
endpoint::{EndpointAddress, EndpointType},
};
/// Fixed-size storage for `COUNT` [`Td`] values, one per endpoint slot.
///
/// Each element sits in an [`UnsafeCell`] so that a mutable reference can be
/// produced from a shared reference to the list. The whole list is aligned to
/// 32 bytes.
// NOTE(review): the 32-byte alignment is presumably a controller requirement
// for `Td`; confirm against the hardware reference manual.
#[repr(align(32))]
struct TdList<const COUNT: usize>([UnsafeCell<Td>; COUNT]);

impl<const COUNT: usize> TdList<COUNT> {
    /// Builds a list in which every element is a fresh `Td::new()`.
    const fn new() -> Self {
        // An inline const block lets the non-`Copy` cell be repeated.
        Self([const { UnsafeCell::new(Td::new()) }; COUNT])
    }
}
/// Fixed-size storage for `COUNT` [`Qh`] values, one per endpoint slot.
///
/// Each element sits in an [`UnsafeCell`] so that a mutable reference can be
/// produced from a shared reference to the list. The whole list is aligned to
/// 4096 bytes.
// NOTE(review): the 4096-byte alignment is presumably required by the
// controller register that receives this list's address; confirm against the
// hardware reference manual.
#[repr(align(4096))]
struct QhList<const COUNT: usize>([UnsafeCell<Qh>; COUNT]);

impl<const COUNT: usize> QhList<COUNT> {
    /// Builds a list in which every element is a fresh `Qh::new()`.
    const fn new() -> Self {
        // An inline const block lets the non-`Copy` cell be repeated.
        Self([const { UnsafeCell::new(Qh::new()) }; COUNT])
    }
}
struct EpList<const COUNT: usize>([UnsafeCell<MaybeUninit<Endpoint>>; COUNT]);
impl<const COUNT: usize> EpList<COUNT> {
const fn new() -> Self {
const EP: UnsafeCell<MaybeUninit<Endpoint>> = UnsafeCell::new(MaybeUninit::uninit());
Self([EP; COUNT])
}
}
/// The largest number of endpoint slots an allocator will manage:
/// 8 endpoint numbers, each with two directions (see `index` for how an
/// endpoint address maps onto a slot).
pub const MAX_ENDPOINTS: usize = 8 * 2;
/// Maps an endpoint address onto its slot in the QH / TD / endpoint lists.
///
/// Slots are interleaved by direction: endpoint number `n` uses slot `2 * n`
/// for OUT and slot `2 * n + 1` for IN.
fn index(ep_addr: EndpointAddress) -> usize {
    let direction_offset = usize::from(ep_addr.direction() == UsbDirection::In);
    2 * ep_addr.index() + direction_offset
}
/// Backing storage for up to `COUNT` endpoint slots.
///
/// The storage is only reachable through the [`EndpointAllocator`] returned
/// by `allocator`, which can be obtained at most once per state.
pub struct EndpointState<const COUNT: usize = MAX_ENDPOINTS> {
/// `Qh` storage, one entry per endpoint slot. Its base address is exposed
/// through `EndpointAllocator::qh_list_addr`.
qh_list: QhList<COUNT>,
/// `Td` storage, one entry per endpoint slot.
td_list: TdList<COUNT>,
/// Endpoint slots; a slot is initialized when its endpoint is allocated.
ep_list: EpList<COUNT>,
/// Allocation bookkeeping. Bit 31 records that the allocator has been
/// taken; bits 0..16 record which endpoint slots have been allocated.
alloc_mask: AtomicU32,
}
// SAFETY: the `UnsafeCell` contents are only reachable through the
// `EndpointAllocator` returned by `allocator`, and the atomic "taken" bit in
// `alloc_mask` lets that call succeed at most once, so a shared
// `&EndpointState` never gives two parties access to the cells.
// NOTE(review): this relies on no other code in the crate touching the lists
// directly; confirm if new accessors are added.
unsafe impl<const COUNT: usize> Sync for EndpointState<COUNT> {}
impl EndpointState<MAX_ENDPOINTS> {
pub const fn max_endpoints() -> Self {
Self::new()
}
}
impl<const COUNT: usize> Default for EndpointState<COUNT> {
fn default() -> Self {
Self::new()
}
}
impl<const COUNT: usize> EndpointState<COUNT> {
    /// Creates endpoint state with no endpoints allocated and the allocator
    /// not yet taken.
    pub const fn new() -> Self {
        let qh_list = QhList::<COUNT>::new();
        let td_list = TdList::<COUNT>::new();
        let ep_list = EpList::<COUNT>::new();
        let alloc_mask = AtomicU32::new(0);
        Self {
            qh_list,
            td_list,
            ep_list,
            alloc_mask,
        }
    }

    /// Hands out the endpoint allocator for this state.
    ///
    /// Returns `Some` on the first call and `None` on every later call; the
    /// "taken" flag is never cleared. The allocator manages at most
    /// [`MAX_ENDPOINTS`] slots, even when `COUNT` is larger.
    pub(crate) fn allocator(&self) -> Option<EndpointAllocator<'_>> {
        /// Flag bit in `alloc_mask` marking the allocator as handed out.
        const ALLOCATOR_TAKEN: u32 = 1 << 31;

        let previous = self.alloc_mask.fetch_or(ALLOCATOR_TAKEN, Ordering::SeqCst);
        if previous & ALLOCATOR_TAKEN != 0 {
            return None;
        }

        // All three lists hold exactly `COUNT` elements, so one limit fits all.
        let limit = COUNT.min(MAX_ENDPOINTS);
        Some(EndpointAllocator {
            qh_list: &self.qh_list.0[..limit],
            td_list: &self.td_list.0[..limit],
            ep_list: &self.ep_list.0[..limit],
            alloc_mask: &self.alloc_mask,
        })
    }
}
/// Hands out and looks up endpoints stored in an [`EndpointState`].
///
/// The three slices have the same length, and a slot index (see `index`)
/// addresses the same endpoint in each of them.
pub struct EndpointAllocator<'a> {
/// `Qh` storage borrowed from the state, one entry per slot.
qh_list: &'a [UnsafeCell<Qh>],
/// `Td` storage borrowed from the state, one entry per slot.
td_list: &'a [UnsafeCell<Td>],
/// Endpoint slots; slot `i` is initialized iff bit `i` of `alloc_mask` is set.
ep_list: &'a [UnsafeCell<MaybeUninit<Endpoint>>],
/// Allocation bitmask shared with the owning [`EndpointState`].
alloc_mask: &'a AtomicU32,
}
// The allocator is not automatically `Send` because it holds shared
// references to `UnsafeCell` data.
// SAFETY: `EndpointState::allocator` hands out at most one allocator per
// state, so moving it to another thread moves the only accessor of the cells
// rather than creating a second one.
// NOTE(review): also assumes `Qh`, `Td` and `Endpoint` carry no
// thread-affine state; confirm against their definitions.
unsafe impl Send for EndpointAllocator<'_> {}
impl EndpointAllocator<'_> {
fn try_mask_update(&mut self, mask: u16) -> Option<()> {
let mask = mask.into();
(mask & self.alloc_mask.fetch_or(mask, Ordering::SeqCst) == 0).then_some(())
}
fn check_allocated(&self, index: usize) -> Option<()> {
(index < self.qh_list.len()).then_some(())?;
let mask = 1u16 << index;
(mask & self.alloc_mask.load(Ordering::SeqCst) as u16 != 0).then_some(())
}
pub fn qh_list_addr(&self) -> *const () {
self.qh_list.as_ptr().cast()
}
pub fn endpoint(&self, addr: EndpointAddress) -> Option<&Endpoint> {
let index = index(addr);
self.check_allocated(index)?;
let ep = unsafe { &*self.ep_list[index].get() };
Some(unsafe { ep.assume_init_ref() })
}
#[expect(clippy::mut_from_ref, reason = "Only called while &mut available")]
unsafe fn endpoint_mut_inner(&self, addr: EndpointAddress) -> Option<&mut Endpoint> {
let index = index(addr);
self.check_allocated(index)?;
let ep = unsafe { &mut *self.ep_list[index].get() };
Some(unsafe { ep.assume_init_mut() })
}
pub fn endpoint_mut(&mut self, addr: EndpointAddress) -> Option<&mut Endpoint> {
unsafe { self.endpoint_mut_inner(addr) }
}
pub fn endpoints_iter_mut(&mut self) -> impl Iterator<Item = &mut Endpoint> {
(0..8)
.flat_map(|index| {
let ep_out = EndpointAddress::from_parts(index, UsbDirection::Out);
let ep_in = EndpointAddress::from_parts(index, UsbDirection::In);
[ep_out, ep_in]
})
.flat_map(|ep| unsafe { self.endpoint_mut_inner(ep) })
}
pub fn nonzero_endpoints_iter_mut(&mut self) -> impl Iterator<Item = &mut Endpoint> {
self.endpoints_iter_mut()
.filter(|ep| ep.address().index() != 0)
}
pub fn allocate_endpoint(
&mut self,
addr: EndpointAddress,
buffer: Buffer,
kind: EndpointType,
) -> Option<&mut Endpoint> {
let index = index(addr);
(index < self.qh_list.len()).then_some(())?;
let mask = 1u16 << index;
self.try_mask_update(mask)?;
let qh = unsafe { &mut *self.qh_list[index].get() };
let td = unsafe { &mut *self.td_list[index].get() };
let ep = unsafe { &mut *self.ep_list[index].get() };
ep.write(Endpoint::new(addr, qh, td, buffer, kind));
Some(unsafe { ep.assume_init_mut() })
}
}
#[cfg(test)]
mod tests {
use super::{EndpointAddress, EndpointState, EndpointType};
use crate::buffer;
/// The allocator can be taken exactly once; every later request fails.
#[test]
fn acquire_allocator() {
let ep_state = EndpointState::max_endpoints();
ep_state.allocator().unwrap();
for _ in 0..10 {
assert!(ep_state.allocator().is_none());
}
}
/// Walks through allocation, lookup, double allocation, iteration order,
/// and rejection of an out-of-range address.
#[test]
fn allocate_endpoint() {
let mut buffer = [0; 128];
let mut buffer_alloc = unsafe { buffer::Allocator::from_buffer(&mut buffer) };
let ep_state = EndpointState::max_endpoints();
let mut ep_alloc = ep_state.allocator().unwrap();
// Address 0: not visible before allocation.
let addr = EndpointAddress::from(0);
assert!(ep_alloc.endpoint(addr).is_none());
assert!(ep_alloc.endpoint_mut(addr).is_none());
let ep = ep_alloc
.allocate_endpoint(
addr,
buffer_alloc.allocate(2).unwrap(),
EndpointType::Control,
)
.unwrap();
assert_eq!(ep.address(), addr);
assert!(ep_alloc.endpoint(addr).is_some());
assert!(ep_alloc.endpoint_mut(addr).is_some());
// Allocating the same address a second time is rejected, and the first
// allocation stays in place.
let ep = ep_alloc.allocate_endpoint(
addr,
buffer_alloc.allocate(2).unwrap(),
EndpointType::Control,
);
assert!(ep.is_none());
assert!(ep_alloc.endpoint(addr).is_some());
assert!(ep_alloc.endpoint_mut(addr).is_some());
// Address with bit 7 set: still endpoint number 0 (see the iteration
// checks below), but a separate slot from address 0.
let addr = EndpointAddress::from(1 << 7);
assert!(ep_alloc.endpoint(addr).is_none());
assert!(ep_alloc.endpoint_mut(addr).is_none());
let ep = ep_alloc
.allocate_endpoint(
addr,
buffer_alloc.allocate(2).unwrap(),
EndpointType::Control,
)
.unwrap();
assert_eq!(ep.address(), addr);
// Address 3: a third, non-zero endpoint.
let addr = EndpointAddress::from(3);
assert!(ep_alloc.endpoint(addr).is_none());
assert!(ep_alloc.endpoint_mut(addr).is_none());
let ep = ep_alloc
.allocate_endpoint(addr, buffer_alloc.allocate(4).unwrap(), EndpointType::Bulk)
.unwrap();
assert_eq!(ep.address(), addr);
// Iteration yields only the allocated endpoints, ordered by endpoint
// number; the non-zero iterator skips both endpoint-0 entries.
assert_eq!(ep_alloc.endpoints_iter_mut().count(), 3);
assert_eq!(ep_alloc.nonzero_endpoints_iter_mut().count(), 1);
for (actual, expected) in ep_alloc.endpoints_iter_mut().zip([0usize, 0, 3]) {
assert_eq!(actual.address().index(), expected, "{:?}", actual.address());
}
for (actual, expected) in ep_alloc.nonzero_endpoints_iter_mut().zip([3]) {
assert_eq!(actual.address().index(), expected, "{:?}", actual.address());
}
// Address 42 maps beyond the managed slots, so allocation fails and the
// set of allocated endpoints is unchanged.
let addr = EndpointAddress::from(42);
let ep = ep_alloc.allocate_endpoint(
addr,
buffer_alloc.allocate(4).unwrap(),
EndpointType::Interrupt,
);
assert!(ep.is_none());
assert_eq!(ep_alloc.endpoints_iter_mut().count(), 3);
}
}