#![allow(clippy::declare_interior_mutable_const)]

use core::{
    cell::UnsafeCell,
    mem::MaybeUninit,
    sync::atomic::{AtomicU32, Ordering},
};

use crate::{buffer::Buffer, endpoint::Endpoint, qh::Qh, td::Td};
use usb_device::{
    endpoint::{EndpointAddress, EndpointType},
    UsbDirection,
};

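/// Transfer descriptor (TD) storage.
///
/// The 32-byte alignment is taken to be the controller's required
/// alignment for transfer descriptors.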
#[repr(align(32))]
struct TdList<const COUNT: usize>([UnsafeCell<Td>; COUNT]);

impl<const COUNT: usize> TdList<COUNT> {
    const fn new() -> Self {
        const TD: UnsafeCell<Td> = UnsafeCell::new(Td::new());
        Self([TD; COUNT])
    }
}

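/// Queue head (QH) storage.
///
/// The 4096-byte alignment keeps the list suitable for the controller's
/// endpoint list pointer; see [`EndpointAllocator::qh_list_addr`].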
#[repr(align(4096))]
struct QhList<const COUNT: usize>([UnsafeCell<Qh>; COUNT]);

impl<const COUNT: usize> QhList<COUNT> {
    const fn new() -> Self {
        const QH: UnsafeCell<Qh> = UnsafeCell::new(Qh::new());
        Self([QH; COUNT])
    }
}

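/// Endpoint storage.
///
/// Every entry starts uninitialized, and is initialized at most once by
/// [`EndpointAllocator::allocate_endpoint`].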
struct EpList<const COUNT: usize>([UnsafeCell<MaybeUninit<Endpoint>>; COUNT]);

impl<const COUNT: usize> EpList<COUNT> {
    const fn new() -> Self {
        const EP: UnsafeCell<MaybeUninit<Endpoint>> = UnsafeCell::new(MaybeUninit::uninit());
        Self([EP; COUNT])
    }
}

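/// Eight endpoint pairs, each with an OUT and an IN endpoint.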
pub const MAX_ENDPOINTS: usize = 8 * 2;

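/// Compute the storage index for an endpoint address.
///
/// OUT endpoints occupy even slots, and IN endpoints odd slots, so each
/// endpoint number owns an adjacent (OUT, IN) pair: EP1 OUT maps to
/// slot 2, and EP1 IN maps to slot 3.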
fn index(ep_addr: EndpointAddress) -> usize {
    (ep_addr.index() * 2) + (UsbDirection::In == ep_addr.direction()) as usize
}

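/// Endpoint state: queue heads, transfer descriptors, and endpoints,
/// along with the allocation bookkeeping for all of them.
///
/// `EndpointState` is `Sync`, so it can back a `static`; its allocator is
/// handed out at most once. A minimal sketch of how a driver might hold
/// one (the `static` name is illustrative):
///
/// ```ignore
/// static EP_STATE: EndpointState = EndpointState::max_endpoints();
///
/// // During driver construction:
/// let ep_alloc = EP_STATE.allocator().expect("allocator already taken");
/// ```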
pub struct EndpointState<const COUNT: usize = MAX_ENDPOINTS> {
    qh_list: QhList<COUNT>,
    td_list: TdList<COUNT>,
    ep_list: EpList<COUNT>,
    /// Bit `i` tracks allocation of endpoint slot `i`; the high bit marks
    /// the allocator itself as taken.
    alloc_mask: AtomicU32,
}

// Safety: shared access to the interior cells is mediated by the one-time
// allocator and the per-endpoint bits in `alloc_mask`.
unsafe impl<const COUNT: usize> Sync for EndpointState<COUNT> {}

impl EndpointState<MAX_ENDPOINTS> {
    /// Create the state for the maximum number of supported endpoints.
    pub const fn max_endpoints() -> Self {
        Self::new()
    }
}

impl<const COUNT: usize> Default for EndpointState<COUNT> {
    fn default() -> Self {
        Self::new()
    }
}

impl<const COUNT: usize> EndpointState<COUNT> {
    pub const fn new() -> Self {
        Self {
            qh_list: QhList::new(),
            td_list: TdList::new(),
            ep_list: EpList::new(),
            alloc_mask: AtomicU32::new(0),
        }
    }

    /// Acquire the endpoint allocator.
    ///
    /// Returns `None` if the allocator was already taken.
    pub(crate) fn allocator(&self) -> Option<EndpointAllocator> {
        // The high bit of `alloc_mask` marks the allocator itself as
        // taken, ensuring it's handed out at most once.
        const ALLOCATOR_TAKEN: u32 = 1 << 31;
        let alloc_mask = self.alloc_mask.fetch_or(ALLOCATOR_TAKEN, Ordering::SeqCst);
        (alloc_mask & ALLOCATOR_TAKEN == 0).then(|| EndpointAllocator {
            qh_list: &self.qh_list.0[..self.qh_list.0.len().min(MAX_ENDPOINTS)],
            td_list: &self.td_list.0[..self.td_list.0.len().min(MAX_ENDPOINTS)],
            ep_list: &self.ep_list.0[..self.ep_list.0.len().min(MAX_ENDPOINTS)],
            alloc_mask: &self.alloc_mask,
        })
    }
}

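/// Allocates endpoints, assigning each one a queue head and a transfer
/// descriptor from the backing [`EndpointState`].
///
/// A sketch of allocating the control OUT endpoint, assuming an allocator
/// `ep_alloc` and a `buffer` from this crate's buffer allocator:
///
/// ```ignore
/// let addr = EndpointAddress::from_parts(0, UsbDirection::Out);
/// let ep = ep_alloc
///     .allocate_endpoint(addr, buffer, EndpointType::Control)
///     .unwrap();
/// assert_eq!(ep.address(), addr);
/// ```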
pub struct EndpointAllocator<'a> {
    qh_list: &'a [UnsafeCell<Qh>],
    td_list: &'a [UnsafeCell<Td>],
    ep_list: &'a [UnsafeCell<MaybeUninit<Endpoint>>],
    alloc_mask: &'a AtomicU32,
}

// Safety: the allocator is the only way to reach the cells it borrows,
// and `alloc_mask` ensures each entry is claimed at most once.
unsafe impl Send for EndpointAllocator<'_> {}

impl EndpointAllocator<'_> {
    /// Atomically claim the bits in `mask`, failing if any bit was
    /// already claimed.
    fn try_mask_update(&mut self, mask: u16) -> Option<()> {
        let mask = mask.into();
        (mask & self.alloc_mask.fetch_or(mask, Ordering::SeqCst) == 0).then_some(())
    }

    /// Succeeds if the endpoint at `index` is in range and allocated.
    fn check_allocated(&self, index: usize) -> Option<()> {
        (index < self.qh_list.len()).then_some(())?;
        let mask = 1u16 << index;
        (mask & self.alloc_mask.load(Ordering::SeqCst) as u16 != 0).then_some(())
    }

    /// The starting address of the queue head list.
    pub fn qh_list_addr(&self) -> *const () {
        self.qh_list.as_ptr().cast()
    }

    /// Returns the endpoint at `addr`, if it was allocated.
    pub fn endpoint(&self, addr: EndpointAddress) -> Option<&Endpoint> {
        let index = index(addr);
        self.check_allocated(index)?;

        // Safety: `check_allocated` ensures this entry was initialized
        // by `allocate_endpoint`.
        let ep = unsafe { &*self.ep_list[index].get() };
        Some(unsafe { ep.assume_init_ref() })
    }

    /// # Safety
    ///
    /// This takes `&self` but returns a `&mut Endpoint`, so the caller
    /// must guarantee that no other reference to this endpoint is alive.
    unsafe fn endpoint_mut_inner(&self, addr: EndpointAddress) -> Option<&mut Endpoint> {
        let index = index(addr);
        self.check_allocated(index)?;

        // Safety: initialized per `check_allocated`; exclusivity is the
        // caller's obligation per this function's safety contract.
        let ep = unsafe { &mut *self.ep_list[index].get() };

        Some(unsafe { ep.assume_init_mut() })
    }

    /// Returns a mutable endpoint at `addr`, if it was allocated.
    pub fn endpoint_mut(&mut self, addr: EndpointAddress) -> Option<&mut Endpoint> {
        // Safety: `&mut self` guarantees exclusive access to the endpoint.
        unsafe { self.endpoint_mut_inner(addr) }
    }

    /// Iterate over all allocated endpoints, visiting each endpoint
    /// number's OUT endpoint before its IN endpoint.
    pub fn endpoints_iter_mut(&mut self) -> impl Iterator<Item = &mut Endpoint> {
        (0..8)
            .flat_map(|index| {
                let ep_out = EndpointAddress::from_parts(index, UsbDirection::Out);
                let ep_in = EndpointAddress::from_parts(index, UsbDirection::In);
                [ep_out, ep_in]
            })
            // Safety: `&mut self` guarantees exclusivity, and each address
            // is visited at most once, so no endpoint is aliased.
            .flat_map(|ep| unsafe { self.endpoint_mut_inner(ep) })
    }

    /// Iterate over all allocated endpoints, excluding endpoint 0.
    pub fn nonzero_endpoints_iter_mut(&mut self) -> impl Iterator<Item = &mut Endpoint> {
        self.endpoints_iter_mut()
            .filter(|ep| ep.address().index() != 0)
    }

    /// Allocate the endpoint for `addr`, binding it to its queue head,
    /// transfer descriptor, and the provided buffer.
    ///
    /// Returns `None` if `addr` is out of range, or if the endpoint was
    /// already allocated.
    pub fn allocate_endpoint(
        &mut self,
        addr: EndpointAddress,
        buffer: Buffer,
        kind: EndpointType,
    ) -> Option<&mut Endpoint> {
        let index = index(addr);
        (index < self.qh_list.len()).then_some(())?;
        let mask = 1u16 << index;

        self.try_mask_update(mask)?;

        // Safety: `try_mask_update` just claimed this slot, so these are
        // the only live references to its QH, TD, and endpoint cells.
        let qh = unsafe { &mut *self.qh_list[index].get() };
        let td = unsafe { &mut *self.td_list[index].get() };
        let ep = unsafe { &mut *self.ep_list[index].get() };
        ep.write(Endpoint::new(addr, qh, td, buffer, kind));
        Some(unsafe { ep.assume_init_mut() })
    }
}

#[cfg(test)]
mod tests {
    use super::{EndpointAddress, EndpointState, EndpointType};
    use crate::buffer;

    #[test]
    fn acquire_allocator() {
        let ep_state = EndpointState::max_endpoints();
        ep_state.allocator().unwrap();
        for _ in 0..10 {
            assert!(ep_state.allocator().is_none());
        }
    }

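    // Sanity-check the slot mapping implemented by `index`: OUT endpoints
    // take even slots, and IN endpoints odd slots.
    #[test]
    fn endpoint_index_mapping() {
        use super::{index, UsbDirection};

        let ep1_out = EndpointAddress::from_parts(1, UsbDirection::Out);
        let ep1_in = EndpointAddress::from_parts(1, UsbDirection::In);
        assert_eq!(index(ep1_out), 2);
        assert_eq!(index(ep1_in), 3);
    }
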
    #[test]
    fn allocate_endpoint() {
        let mut buffer = [0; 128];
        let mut buffer_alloc = unsafe { buffer::Allocator::from_buffer(&mut buffer) };
        let ep_state = EndpointState::max_endpoints();
        let mut ep_alloc = ep_state.allocator().unwrap();

        // Endpoint 0, OUT.
        let addr = EndpointAddress::from(0);
        assert!(ep_alloc.endpoint(addr).is_none());
        assert!(ep_alloc.endpoint_mut(addr).is_none());

        let ep = ep_alloc
            .allocate_endpoint(
                addr,
                buffer_alloc.allocate(2).unwrap(),
                EndpointType::Control,
            )
            .unwrap();
        assert_eq!(ep.address(), addr);

        assert!(ep_alloc.endpoint(addr).is_some());
        assert!(ep_alloc.endpoint_mut(addr).is_some());

        // The same endpoint cannot be allocated twice.
        let ep = ep_alloc.allocate_endpoint(
            addr,
            buffer_alloc.allocate(2).unwrap(),
            EndpointType::Control,
        );
        assert!(ep.is_none());

        assert!(ep_alloc.endpoint(addr).is_some());
        assert!(ep_alloc.endpoint_mut(addr).is_some());

        // Endpoint 0, IN.
        let addr = EndpointAddress::from(1 << 7);

        assert!(ep_alloc.endpoint(addr).is_none());
        assert!(ep_alloc.endpoint_mut(addr).is_none());

        let ep = ep_alloc
            .allocate_endpoint(
                addr,
                buffer_alloc.allocate(2).unwrap(),
                EndpointType::Control,
            )
            .unwrap();
        assert_eq!(ep.address(), addr);

        // Endpoint 3, OUT.
        let addr = EndpointAddress::from(3);
        assert!(ep_alloc.endpoint(addr).is_none());
        assert!(ep_alloc.endpoint_mut(addr).is_none());

        let ep = ep_alloc
            .allocate_endpoint(addr, buffer_alloc.allocate(4).unwrap(), EndpointType::Bulk)
            .unwrap();
        assert_eq!(ep.address(), addr);

        assert_eq!(ep_alloc.endpoints_iter_mut().count(), 3);
        assert_eq!(ep_alloc.nonzero_endpoints_iter_mut().count(), 1);

        for (actual, expected) in ep_alloc.endpoints_iter_mut().zip([0usize, 0, 3]) {
            assert_eq!(actual.address().index(), expected, "{:?}", actual.address());
        }

        for (actual, expected) in ep_alloc.nonzero_endpoints_iter_mut().zip([3]) {
            assert_eq!(actual.address().index(), expected, "{:?}", actual.address());
        }

        // An endpoint number beyond the supported range cannot be allocated.
        let addr = EndpointAddress::from(42);
        let ep = ep_alloc.allocate_endpoint(
            addr,
            buffer_alloc.allocate(4).unwrap(),
            EndpointType::Interrupt,
        );
        assert!(ep.is_none());

        assert_eq!(ep_alloc.endpoints_iter_mut().count(), 3);
    }
}