// imxrt_usbd/state.rs

#![allow(clippy::declare_interior_mutable_const)] // Usage is legit in this module.

use core::{
    cell::UnsafeCell,
    mem::MaybeUninit,
    sync::atomic::{AtomicU32, Ordering},
};

use crate::{buffer::Buffer, endpoint::Endpoint, qh::Qh, td::Td};
use usb_device::{
    endpoint::{EndpointAddress, EndpointType},
    UsbDirection,
};

/// A list of transfer descriptors
///
/// Supports 1 TD per QH (per endpoint direction)
#[repr(align(32))]
struct TdList<const COUNT: usize>([UnsafeCell<Td>; COUNT]);

impl<const COUNT: usize> TdList<COUNT> {
    const fn new() -> Self {
        const TD: UnsafeCell<Td> = UnsafeCell::new(Td::new());
        Self([TD; COUNT])
    }
}

/// A list of queue heads
///
/// One queue head per endpoint, per direction (default).
#[repr(align(4096))]
struct QhList<const COUNT: usize>([UnsafeCell<Qh>; COUNT]);

impl<const COUNT: usize> QhList<COUNT> {
    const fn new() -> Self {
        const QH: UnsafeCell<Qh> = UnsafeCell::new(Qh::new());
        Self([QH; COUNT])
    }
}

/// The collection of endpoints.
///
/// Maintained inside the EndpointState so that it's sized just right.
struct EpList<const COUNT: usize>([UnsafeCell<MaybeUninit<Endpoint>>; COUNT]);

impl<const COUNT: usize> EpList<COUNT> {
    const fn new() -> Self {
        const EP: UnsafeCell<MaybeUninit<Endpoint>> = UnsafeCell::new(MaybeUninit::uninit());
        Self([EP; COUNT])
    }
}

/// The maximum supported number of endpoints.
///
/// Eight endpoints in each direction, sixteen in total. Any endpoints
/// allocated beyond this are wasted.
pub const MAX_ENDPOINTS: usize = 8 * 2;

/// Produces an index into the EPs, QHs, and TDs collections
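///
/// OUT endpoints map to even indices, and IN endpoints map to odd indices.
/// For example, endpoint 1 OUT maps to index 2, and endpoint 1 IN maps to index 3.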
fn index(ep_addr: EndpointAddress) -> usize {
    (ep_addr.index() * 2) + (UsbDirection::In == ep_addr.direction()) as usize
}

/// Driver state associated with endpoints.
///
/// Each USB driver needs an `EndpointState`. Allocate a `static` object
/// and supply it to your USB constructor. Make sure that states are not
/// shared across USB instances; otherwise, the driver constructor panics.
///
/// Use [`max_endpoints()`](EndpointState::max_endpoints) if you're not interested in reducing the
/// memory used by this allocation. The default object holds enough
/// state for all supported endpoints.
///
/// ```
/// use imxrt_usbd::EndpointState;
///
/// static EP_STATE: EndpointState = EndpointState::max_endpoints();
/// ```
///
/// If you know that you can use fewer endpoints, you can control the
/// memory utilization with the const generic `COUNT`. You're expected
/// to provide at least two endpoints -- one in each direction -- for
/// control endpoints.
///
/// Know that endpoints are allocated in pairs; all even endpoints are
/// OUT, and all odd endpoints are IN. For example, a `COUNT` of 5 will
/// have 3 OUT endpoints and 2 IN endpoints. You can never have more
/// IN than OUT endpoints without overallocating OUT endpoints.
///
/// ```
/// use imxrt_usbd::EndpointState;
///
/// static EP_STATE: EndpointState<5> = EndpointState::new();
/// ```
///
/// Any endpoint state allocated beyond [`MAX_ENDPOINTS`] is wasted.
pub struct EndpointState<const COUNT: usize = MAX_ENDPOINTS> {
    qh_list: QhList<COUNT>,
    td_list: TdList<COUNT>,
    ep_list: EpList<COUNT>,
    /// Low 16 bits are used for tracking endpoint allocation.
    /// Bit 31 is set when the allocator is first taken. This
    /// bit is always dropped during u32 -> u16 conversions.
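    /// For example, a value of `0x8000_0005` means the allocator has been taken
    /// and the endpoints at indices 0 and 2 (EP0 OUT and EP1 OUT) are allocated.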
    alloc_mask: AtomicU32,
}

unsafe impl<const COUNT: usize> Sync for EndpointState<COUNT> {}

impl EndpointState<MAX_ENDPOINTS> {
    /// Allocate space for the maximum number of endpoints.
    ///
    /// Use this if you don't want to consider the exact number
    /// of endpoints that you might need.
    pub const fn max_endpoints() -> Self {
        Self::new()
    }
}

impl<const COUNT: usize> Default for EndpointState<COUNT> {
    fn default() -> Self {
        Self::new()
    }
}

impl<const COUNT: usize> EndpointState<COUNT> {
    /// Allocate state for `COUNT` endpoints.
    pub const fn new() -> Self {
        Self {
            qh_list: QhList::new(),
            td_list: TdList::new(),
            ep_list: EpList::new(),
            alloc_mask: AtomicU32::new(0),
        }
    }

    /// Acquire the allocator.
    ///
    /// Returns `None` if the allocator was already taken.
    pub(crate) fn allocator(&self) -> Option<EndpointAllocator> {
        const ALLOCATOR_TAKEN: u32 = 1 << 31;
        let alloc_mask = self.alloc_mask.fetch_or(ALLOCATOR_TAKEN, Ordering::SeqCst);
        (alloc_mask & ALLOCATOR_TAKEN == 0).then(|| EndpointAllocator {
            qh_list: &self.qh_list.0[..self.qh_list.0.len().min(MAX_ENDPOINTS)],
            td_list: &self.td_list.0[..self.td_list.0.len().min(MAX_ENDPOINTS)],
            ep_list: &self.ep_list.0[..self.ep_list.0.len().min(MAX_ENDPOINTS)],
            alloc_mask: &self.alloc_mask,
        })
    }
}

pub struct EndpointAllocator<'a> {
    qh_list: &'a [UnsafeCell<Qh>],
    td_list: &'a [UnsafeCell<Td>],
    ep_list: &'a [UnsafeCell<MaybeUninit<Endpoint>>],
    alloc_mask: &'a AtomicU32,
}

unsafe impl Send for EndpointAllocator<'_> {}

impl EndpointAllocator<'_> {
    /// Atomically inserts the endpoint bit into the allocation mask, returning `None` if the
    /// bit was already set.
    fn try_mask_update(&mut self, mask: u16) -> Option<()> {
        let mask = mask.into();
        (mask & self.alloc_mask.fetch_or(mask, Ordering::SeqCst) == 0).then_some(())
    }

    /// Returns `Some` if the endpoint is allocated.
    fn check_allocated(&self, index: usize) -> Option<()> {
        (index < self.qh_list.len()).then_some(())?;
        let mask = 1u16 << index;
        (mask & self.alloc_mask.load(Ordering::SeqCst) as u16 != 0).then_some(())
    }

    /// Acquire the QH list address.
    ///
    /// Used to tell the hardware where the queue heads are located.
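    /// This is typically the value written to the controller's `ENDPTLISTADDR`
    /// register.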
    pub fn qh_list_addr(&self) -> *const () {
        self.qh_list.as_ptr().cast()
    }

    /// Acquire the endpoint.
    ///
    /// Returns `None` if the endpoint isn't allocated.
    pub fn endpoint(&self, addr: EndpointAddress) -> Option<&Endpoint> {
        let index = index(addr);
        self.check_allocated(index)?;

        // Safety: there's no other mutable access at this call site.
        // Perceived lifetime is tied to the EndpointAllocator, which has an
        // immutable receiver.

        let ep = unsafe { &*self.ep_list[index].get() };
        // Safety: endpoint is allocated. Checked above.
        Some(unsafe { ep.assume_init_ref() })
    }

    /// Implementation detail to permit endpoint iteration.
    ///
    /// # Safety
    ///
    /// This can only be called from a method that takes a mutable receiver.
    /// Otherwise, you could reach the same mutable endpoint more than once.
    unsafe fn endpoint_mut_inner(&self, addr: EndpointAddress) -> Option<&mut Endpoint> {
        let index = index(addr);
        self.check_allocated(index)?;

        // Safety: the caller ensures that we actually have a mutable reference.
        // Once we have a mutable reference, this is equivalent to calling the
        // safe UnsafeCell::get_mut method.
        let ep = unsafe { &mut *self.ep_list[index].get() };

        // Safety: endpoint is allocated. Checked above.
        Some(unsafe { ep.assume_init_mut() })
    }

    /// Acquire the mutable endpoint.
    ///
    /// Returns `None` if the endpoint isn't allocated.
    pub fn endpoint_mut(&mut self, addr: EndpointAddress) -> Option<&mut Endpoint> {
        // Safety: call from method with mutable receiver.
        unsafe { self.endpoint_mut_inner(addr) }
    }

    /// Return an iterator of all allocated endpoints.
    pub fn endpoints_iter_mut(&mut self) -> impl Iterator<Item = &mut Endpoint> {
        (0..8)
            .flat_map(|index| {
                let ep_out = EndpointAddress::from_parts(index, UsbDirection::Out);
                let ep_in = EndpointAddress::from_parts(index, UsbDirection::In);
                [ep_out, ep_in]
            })
            // Safety: call from method with mutable receiver.
            .flat_map(|ep| unsafe { self.endpoint_mut_inner(ep) })
    }

    /// Returns an iterator for all non-zero, allocated endpoints.
    ///
    /// "Non-zero" excludes the first two control endpoints.
    pub fn nonzero_endpoints_iter_mut(&mut self) -> impl Iterator<Item = &mut Endpoint> {
        self.endpoints_iter_mut()
            .filter(|ep| ep.address().index() != 0)
    }

    /// Allocate the endpoint for the specified address.
    ///
    /// Returns `None` if any of the following are true:
    ///
    /// - The endpoint is already allocated.
    /// - We cannot allocate an endpoint for the given address.
    pub fn allocate_endpoint(
        &mut self,
        addr: EndpointAddress,
        buffer: Buffer,
        kind: EndpointType,
    ) -> Option<&mut Endpoint> {
        let index = index(addr);
        (index < self.qh_list.len()).then_some(())?;
        let mask = 1u16 << index;

        // If we pass this call, we're the only caller able to observe mutable
        // QHs, TDs, and EPs at index.
        self.try_mask_update(mask)?;

        // Safety: index in range. Atomic update on alloc_mask prevents races for
        // allocation, and ensures that we only release one &mut reference for each
        // component.
        let qh = unsafe { &mut *self.qh_list[index].get() };
        let td = unsafe { &mut *self.td_list[index].get() };
        // We cannot access these two components after this call. The endpoint
        // takes mutable references, so it has exclusive ownership of both.
        // This module is designed to isolate this access so we can visually
        // see where we have these &mut accesses.

        // EP is uninitialized.
        let ep = unsafe { &mut *self.ep_list[index].get() };
        // Nothing to drop here.
        ep.write(Endpoint::new(addr, qh, td, buffer, kind));
        // Safety: EP is initialized.
        Some(unsafe { ep.assume_init_mut() })
    }
}

#[cfg(test)]
mod tests {
    use super::{EndpointAddress, EndpointState, EndpointType};
    use crate::buffer;

    #[test]
    fn acquire_allocator() {
        let ep_state = EndpointState::max_endpoints();
        ep_state.allocator().unwrap();
        for _ in 0..10 {
            assert!(ep_state.allocator().is_none());
        }
    }

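    // Illustrative check of the index() mapping used by the allocator: OUT
    // endpoints land on even indices, and IN endpoints land on odd indices.
    #[test]
    fn endpoint_index_mapping() {
        use super::{index, UsbDirection};

        assert_eq!(index(EndpointAddress::from_parts(0, UsbDirection::Out)), 0);
        assert_eq!(index(EndpointAddress::from_parts(0, UsbDirection::In)), 1);
        assert_eq!(index(EndpointAddress::from_parts(1, UsbDirection::Out)), 2);
        assert_eq!(index(EndpointAddress::from_parts(1, UsbDirection::In)), 3);
    }
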
    #[test]
    fn allocate_endpoint() {
        let mut buffer = [0; 128];
        let mut buffer_alloc = unsafe { buffer::Allocator::from_buffer(&mut buffer) };
        let ep_state = EndpointState::max_endpoints();
        let mut ep_alloc = ep_state.allocator().unwrap();

        // First endpoint allocation.
        let addr = EndpointAddress::from(0);
        assert!(ep_alloc.endpoint(addr).is_none());
        assert!(ep_alloc.endpoint_mut(addr).is_none());

        let ep = ep_alloc
            .allocate_endpoint(
                addr,
                buffer_alloc.allocate(2).unwrap(),
                EndpointType::Control,
            )
            .unwrap();
        assert_eq!(ep.address(), addr);

        assert!(ep_alloc.endpoint(addr).is_some());
        assert!(ep_alloc.endpoint_mut(addr).is_some());

        // Double-allocate existing endpoint.
        let ep = ep_alloc.allocate_endpoint(
            addr,
            buffer_alloc.allocate(2).unwrap(),
            EndpointType::Control,
        );
        assert!(ep.is_none());

        assert!(ep_alloc.endpoint(addr).is_some());
        assert!(ep_alloc.endpoint_mut(addr).is_some());

        // Allocate a new endpoint.
        let addr = EndpointAddress::from(1 << 7);

        assert!(ep_alloc.endpoint(addr).is_none());
        assert!(ep_alloc.endpoint_mut(addr).is_none());

        let ep = ep_alloc
            .allocate_endpoint(
                addr,
                buffer_alloc.allocate(2).unwrap(),
                EndpointType::Control,
            )
            .unwrap();
        assert_eq!(ep.address(), addr);

        // Allocate a non-zero endpoint.

        let addr = EndpointAddress::from(3);
        assert!(ep_alloc.endpoint(addr).is_none());
        assert!(ep_alloc.endpoint_mut(addr).is_none());

        let ep = ep_alloc
            .allocate_endpoint(addr, buffer_alloc.allocate(4).unwrap(), EndpointType::Bulk)
            .unwrap();
        assert_eq!(ep.address(), addr);

        assert_eq!(ep_alloc.endpoints_iter_mut().count(), 3);
        assert_eq!(ep_alloc.nonzero_endpoints_iter_mut().count(), 1);

        for (actual, expected) in ep_alloc.endpoints_iter_mut().zip([0usize, 0, 3]) {
            assert_eq!(actual.address().index(), expected, "{:?}", actual.address());
        }

        for (actual, expected) in ep_alloc.nonzero_endpoints_iter_mut().zip([3]) {
            assert_eq!(actual.address().index(), expected, "{:?}", actual.address());
        }

        // Try to allocate an invalid endpoint.
        let addr = EndpointAddress::from(42);
        let ep = ep_alloc.allocate_endpoint(
            addr,
            buffer_alloc.allocate(4).unwrap(),
            EndpointType::Interrupt,
        );
        assert!(ep.is_none());

        assert_eq!(ep_alloc.endpoints_iter_mut().count(), 3);
    }
}