embassy_usb/class/
hid.rs

1//! USB HID (Human Interface Device) class implementation.
2
3use core::mem::MaybeUninit;
4use core::ops::Range;
5use core::sync::atomic::{AtomicUsize, Ordering};
6
7#[cfg(feature = "usbd-hid")]
8use ssmarshal::serialize;
9#[cfg(feature = "usbd-hid")]
10use usbd_hid::descriptor::AsInputReport;
11
12use crate::control::{InResponse, OutResponse, Recipient, Request, RequestType};
13use crate::driver::{Driver, Endpoint, EndpointError, EndpointIn, EndpointOut};
14use crate::types::InterfaceNumber;
15use crate::{Builder, Handler};
16
/// `bInterfaceClass` for HID interfaces.
const USB_CLASS_HID: u8 = 0x03;
/// No subclass (not a boot interface).
const USB_SUBCLASS_NONE: u8 = 0x00;
/// No boot protocol.
const USB_PROTOCOL_NONE: u8 = 0x00;

// HID class descriptor types and fields.
const HID_DESC_DESCTYPE_HID: u8 = 0x21;
const HID_DESC_DESCTYPE_HID_REPORT: u8 = 0x22;
/// `bcdHID` for HID spec 1.10, little-endian.
const HID_DESC_SPEC_1_10: [u8; 2] = [0x10, 0x01];
/// `bCountryCode`: country code not supported.
const HID_DESC_COUNTRY_UNSPEC: u8 = 0x00;

// HID class-specific request codes (`bRequest`), in numeric order.
const HID_REQ_GET_REPORT: u8 = 0x01;
const HID_REQ_GET_IDLE: u8 = 0x02;
const HID_REQ_GET_PROTOCOL: u8 = 0x03;
const HID_REQ_SET_REPORT: u8 = 0x09;
const HID_REQ_SET_IDLE: u8 = 0x0a;
const HID_REQ_SET_PROTOCOL: u8 = 0x0b;
33
34/// Configuration for the HID class.
35pub struct Config<'d> {
36    /// HID report descriptor.
37    pub report_descriptor: &'d [u8],
38
39    /// Handler for control requests.
40    pub request_handler: Option<&'d mut dyn RequestHandler>,
41
42    /// Configures how frequently the host should poll for reading/writing HID reports.
43    ///
44    /// A lower value means better throughput & latency, at the expense
45    /// of CPU on the device & bandwidth on the bus. A value of 10 is reasonable for
46    /// high performance uses, and a value of 255 is good for best-effort usecases.
47    pub poll_ms: u8,
48
49    /// Max packet size for both the IN and OUT endpoints.
50    pub max_packet_size: u16,
51}
52
/// Identifies a HID report by its type and report ID byte.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum ReportId {
    /// IN (input) report with the given report ID.
    In(u8),
    /// OUT (output) report with the given report ID.
    Out(u8),
    /// Feature report with the given report ID.
    Feature(u8),
}

impl ReportId {
    /// Decodes the `wValue` field of a GET_REPORT / SET_REPORT request.
    ///
    /// The high byte selects the report type (1 = input, 2 = output,
    /// 3 = feature) and the low byte is the report ID. Any other report type
    /// yields `Err(())`.
    const fn try_from(value: u16) -> Result<Self, ()> {
        let [id, report_type] = value.to_le_bytes();
        match report_type {
            1 => Ok(Self::In(id)),
            2 => Ok(Self::Out(id)),
            3 => Ok(Self::Feature(id)),
            _ => Err(()),
        }
    }
}
75
76/// Internal state for USB HID.
77pub struct State<'d> {
78    control: MaybeUninit<Control<'d>>,
79    out_report_offset: AtomicUsize,
80}
81
82impl<'d> Default for State<'d> {
83    fn default() -> Self {
84        Self::new()
85    }
86}
87
88impl<'d> State<'d> {
89    /// Create a new `State`.
90    pub const fn new() -> Self {
91        State {
92            control: MaybeUninit::uninit(),
93            out_report_offset: AtomicUsize::new(0),
94        }
95    }
96}
97
98/// USB HID reader/writer.
99pub struct HidReaderWriter<'d, D: Driver<'d>, const READ_N: usize, const WRITE_N: usize> {
100    reader: HidReader<'d, D, READ_N>,
101    writer: HidWriter<'d, D, WRITE_N>,
102}
103
104fn build<'d, D: Driver<'d>>(
105    builder: &mut Builder<'d, D>,
106    state: &'d mut State<'d>,
107    config: Config<'d>,
108    with_out_endpoint: bool,
109) -> (Option<D::EndpointOut>, D::EndpointIn, &'d AtomicUsize) {
110    let len = config.report_descriptor.len();
111
112    let mut func = builder.function(USB_CLASS_HID, USB_SUBCLASS_NONE, USB_PROTOCOL_NONE);
113    let mut iface = func.interface();
114    let if_num = iface.interface_number();
115    let mut alt = iface.alt_setting(USB_CLASS_HID, USB_SUBCLASS_NONE, USB_PROTOCOL_NONE, None);
116
117    // HID descriptor
118    alt.descriptor(
119        HID_DESC_DESCTYPE_HID,
120        &[
121            // HID Class spec version
122            HID_DESC_SPEC_1_10[0],
123            HID_DESC_SPEC_1_10[1],
124            // Country code not supported
125            HID_DESC_COUNTRY_UNSPEC,
126            // Number of following descriptors
127            1,
128            // We have a HID report descriptor the host should read
129            HID_DESC_DESCTYPE_HID_REPORT,
130            // HID report descriptor size,
131            (len & 0xFF) as u8,
132            (len >> 8 & 0xFF) as u8,
133        ],
134    );
135
136    let ep_in = alt.endpoint_interrupt_in(None, config.max_packet_size, config.poll_ms);
137    let ep_out = if with_out_endpoint {
138        Some(alt.endpoint_interrupt_out(None, config.max_packet_size, config.poll_ms))
139    } else {
140        None
141    };
142
143    drop(func);
144
145    let control = state.control.write(Control::new(
146        if_num,
147        config.report_descriptor,
148        config.request_handler,
149        &state.out_report_offset,
150    ));
151    builder.handler(control);
152
153    (ep_out, ep_in, &state.out_report_offset)
154}
155
156impl<'d, D: Driver<'d>, const READ_N: usize, const WRITE_N: usize> HidReaderWriter<'d, D, READ_N, WRITE_N> {
157    /// Creates a new `HidReaderWriter`.
158    ///
159    /// This will allocate one IN and one OUT endpoints. If you only need writing (sending)
160    /// HID reports, consider using [`HidWriter::new`] instead, which allocates an IN endpoint only.
161    ///
162    pub fn new(builder: &mut Builder<'d, D>, state: &'d mut State<'d>, config: Config<'d>) -> Self {
163        let (ep_out, ep_in, offset) = build(builder, state, config, true);
164
165        Self {
166            reader: HidReader {
167                ep_out: ep_out.unwrap(),
168                offset,
169            },
170            writer: HidWriter { ep_in },
171        }
172    }
173
174    /// Splits into separate readers/writers for input and output reports.
175    pub fn split(self) -> (HidReader<'d, D, READ_N>, HidWriter<'d, D, WRITE_N>) {
176        (self.reader, self.writer)
177    }
178
179    /// Waits for both IN and OUT endpoints to be enabled.
180    pub async fn ready(&mut self) {
181        self.reader.ready().await;
182        self.writer.ready().await;
183    }
184
185    /// Writes an input report by serializing the given report structure.
186    #[cfg(feature = "usbd-hid")]
187    pub async fn write_serialize<IR: AsInputReport>(&mut self, r: &IR) -> Result<(), EndpointError> {
188        self.writer.write_serialize(r).await
189    }
190
191    /// Writes `report` to its interrupt endpoint.
192    pub async fn write(&mut self, report: &[u8]) -> Result<(), EndpointError> {
193        self.writer.write(report).await
194    }
195
196    /// Reads an output report from the Interrupt Out pipe.
197    ///
198    /// See [`HidReader::read`].
199    pub async fn read(&mut self, buf: &mut [u8]) -> Result<usize, ReadError> {
200        self.reader.read(buf).await
201    }
202}
203
204/// USB HID writer.
205///
206/// You can obtain a `HidWriter` using [`HidReaderWriter::split`].
207pub struct HidWriter<'d, D: Driver<'d>, const N: usize> {
208    ep_in: D::EndpointIn,
209}
210
211/// USB HID reader.
212///
213/// You can obtain a `HidReader` using [`HidReaderWriter::split`].
214pub struct HidReader<'d, D: Driver<'d>, const N: usize> {
215    ep_out: D::EndpointOut,
216    offset: &'d AtomicUsize,
217}
218
/// Error returned when reading a HID output report.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum ReadError {
    /// The received report did not fit in the given buffer.
    BufferOverflow,
    /// The endpoint is disabled.
    Disabled,
    /// Only the tail of a report was read by this call, because an earlier call
    /// was cancelled mid-report. The range is the part of `buf` filled by this
    /// call. See [`HidReader::read`] for details.
    Sync(Range<usize>),
}
230
231impl From<EndpointError> for ReadError {
232    fn from(val: EndpointError) -> Self {
233        use EndpointError::{BufferOverflow, Disabled};
234        match val {
235            BufferOverflow => ReadError::BufferOverflow,
236            Disabled => ReadError::Disabled,
237        }
238    }
239}
240
241impl<'d, D: Driver<'d>, const N: usize> HidWriter<'d, D, N> {
242    /// Creates a new HidWriter.
243    ///
244    /// This will allocate one IN endpoint only, so the host won't be able to send
245    /// reports to us. If you need that, consider using [`HidReaderWriter::new`] instead.
246    ///
247    /// poll_ms configures how frequently the host should poll for reading/writing
248    /// HID reports. A lower value means better throughput & latency, at the expense
249    /// of CPU on the device & bandwidth on the bus. A value of 10 is reasonable for
250    /// high performance uses, and a value of 255 is good for best-effort usecases.
251    pub fn new(builder: &mut Builder<'d, D>, state: &'d mut State<'d>, config: Config<'d>) -> Self {
252        let (ep_out, ep_in, _offset) = build(builder, state, config, false);
253
254        assert!(ep_out.is_none());
255
256        Self { ep_in }
257    }
258
259    /// Waits for the interrupt in endpoint to be enabled.
260    pub async fn ready(&mut self) {
261        self.ep_in.wait_enabled().await;
262    }
263
264    /// Writes an input report by serializing the given report structure.
265    #[cfg(feature = "usbd-hid")]
266    pub async fn write_serialize<IR: AsInputReport>(&mut self, r: &IR) -> Result<(), EndpointError> {
267        let mut buf: [u8; N] = [0; N];
268        let Ok(size) = serialize(&mut buf, r) else {
269            return Err(EndpointError::BufferOverflow);
270        };
271        self.write(&buf[0..size]).await
272    }
273
274    /// Writes `report` to its interrupt endpoint.
275    pub async fn write(&mut self, report: &[u8]) -> Result<(), EndpointError> {
276        assert!(report.len() <= N);
277
278        let max_packet_size = usize::from(self.ep_in.info().max_packet_size);
279        let zlp_needed = report.len() < N && (report.len() % max_packet_size == 0);
280        for chunk in report.chunks(max_packet_size) {
281            self.ep_in.write(chunk).await?;
282        }
283
284        if zlp_needed {
285            self.ep_in.write(&[]).await?;
286        }
287
288        Ok(())
289    }
290}
291
292impl<'d, D: Driver<'d>, const N: usize> HidReader<'d, D, N> {
293    /// Waits for the interrupt out endpoint to be enabled.
294    pub async fn ready(&mut self) {
295        self.ep_out.wait_enabled().await;
296    }
297
298    /// Delivers output reports from the Interrupt Out pipe to `handler`.
299    ///
300    /// If `use_report_ids` is true, the first byte of the report will be used as
301    /// the `ReportId` value. Otherwise the `ReportId` value will be 0.
302    pub async fn run<T: RequestHandler>(mut self, use_report_ids: bool, handler: &mut T) -> ! {
303        let offset = self.offset.load(Ordering::Acquire);
304        assert!(offset == 0);
305        let mut buf = [0; N];
306        loop {
307            match self.read(&mut buf).await {
308                Ok(len) => {
309                    let id = if use_report_ids { buf[0] } else { 0 };
310                    handler.set_report(ReportId::Out(id), &buf[..len]);
311                }
312                Err(ReadError::BufferOverflow) => warn!(
313                    "Host sent output report larger than the configured maximum output report length ({})",
314                    N
315                ),
316                Err(ReadError::Disabled) => self.ep_out.wait_enabled().await,
317                Err(ReadError::Sync(_)) => unreachable!(),
318            }
319        }
320    }
321
322    /// Reads an output report from the Interrupt Out pipe.
323    ///
324    /// **Note:** Any reports sent from the host over the control pipe will be
325    /// passed to [`RequestHandler::set_report()`] for handling. The application
326    /// is responsible for ensuring output reports from both pipes are handled
327    /// correctly.
328    ///
329    /// **Note:** If `N` > the maximum packet size of the endpoint (i.e. output
330    /// reports may be split across multiple packets) and this method's future
331    /// is dropped after some packets have been read, the next call to `read()`
332    /// will return a [`ReadError::Sync`]. The range in the sync error
333    /// indicates the portion `buf` that was filled by the current call to
334    /// `read()`. If the dropped future used the same `buf`, then `buf` will
335    /// contain the full report.
336    pub async fn read(&mut self, buf: &mut [u8]) -> Result<usize, ReadError> {
337        assert!(N != 0);
338        assert!(buf.len() >= N);
339
340        // Read packets from the endpoint
341        let max_packet_size = usize::from(self.ep_out.info().max_packet_size);
342        let starting_offset = self.offset.load(Ordering::Acquire);
343        let mut total = starting_offset;
344        loop {
345            for chunk in buf[starting_offset..N].chunks_mut(max_packet_size) {
346                match self.ep_out.read(chunk).await {
347                    Ok(size) => {
348                        total += size;
349                        if size < max_packet_size || total == N {
350                            self.offset.store(0, Ordering::Release);
351                            break;
352                        }
353                        self.offset.store(total, Ordering::Release);
354                    }
355                    Err(err) => {
356                        self.offset.store(0, Ordering::Release);
357                        return Err(err.into());
358                    }
359                }
360            }
361
362            // Some hosts may send ZLPs even when not required by the HID spec, so we'll loop as long as total == 0.
363            if total > 0 {
364                break;
365            }
366        }
367
368        if starting_offset > 0 {
369            Err(ReadError::Sync(starting_offset..total))
370        } else {
371            Ok(total)
372        }
373    }
374}
375
376/// Handler for HID-related control requests.
377pub trait RequestHandler {
378    /// Reads the value of report `id` into `buf` returning the size.
379    ///
380    /// Returns `None` if `id` is invalid or no data is available.
381    fn get_report(&mut self, id: ReportId, buf: &mut [u8]) -> Option<usize> {
382        let _ = (id, buf);
383        None
384    }
385
386    /// Sets the value of report `id` to `data`.
387    fn set_report(&mut self, id: ReportId, data: &[u8]) -> OutResponse {
388        let _ = (id, data);
389        OutResponse::Rejected
390    }
391
392    /// Get the idle rate for `id`.
393    ///
394    /// If `id` is `None`, get the idle rate for all reports. Returning `None`
395    /// will reject the control request. Any duration at or above 1.024 seconds
396    /// or below 4ms will be returned as an indefinite idle rate.
397    fn get_idle_ms(&mut self, id: Option<ReportId>) -> Option<u32> {
398        let _ = id;
399        None
400    }
401
402    /// Set the idle rate for `id` to `dur`.
403    ///
404    /// If `id` is `None`, set the idle rate of all input reports to `dur`. If
405    /// an indefinite duration is requested, `dur` will be set to `u32::MAX`.
406    fn set_idle_ms(&mut self, id: Option<ReportId>, duration_ms: u32) {
407        let _ = (id, duration_ms);
408    }
409}
410
411struct Control<'d> {
412    if_num: InterfaceNumber,
413    report_descriptor: &'d [u8],
414    request_handler: Option<&'d mut dyn RequestHandler>,
415    out_report_offset: &'d AtomicUsize,
416    hid_descriptor: [u8; 9],
417}
418
419impl<'d> Control<'d> {
420    fn new(
421        if_num: InterfaceNumber,
422        report_descriptor: &'d [u8],
423        request_handler: Option<&'d mut dyn RequestHandler>,
424        out_report_offset: &'d AtomicUsize,
425    ) -> Self {
426        Control {
427            if_num,
428            report_descriptor,
429            request_handler,
430            out_report_offset,
431            hid_descriptor: [
432                // Length of buf inclusive of size prefix
433                9,
434                // Descriptor type
435                HID_DESC_DESCTYPE_HID,
436                // HID Class spec version
437                HID_DESC_SPEC_1_10[0],
438                HID_DESC_SPEC_1_10[1],
439                // Country code not supported
440                HID_DESC_COUNTRY_UNSPEC,
441                // Number of following descriptors
442                1,
443                // We have a HID report descriptor the host should read
444                HID_DESC_DESCTYPE_HID_REPORT,
445                // HID report descriptor size,
446                (report_descriptor.len() & 0xFF) as u8,
447                (report_descriptor.len() >> 8 & 0xFF) as u8,
448            ],
449        }
450    }
451}
452
453impl<'d> Handler for Control<'d> {
454    fn reset(&mut self) {
455        self.out_report_offset.store(0, Ordering::Release);
456    }
457
458    fn control_out(&mut self, req: Request, data: &[u8]) -> Option<OutResponse> {
459        if (req.request_type, req.recipient, req.index)
460            != (RequestType::Class, Recipient::Interface, self.if_num.0 as u16)
461        {
462            return None;
463        }
464
465        // This uses a defmt-specific formatter that causes use of the `log`
466        // feature to fail to build, so leave it defmt-specific for now.
467        #[cfg(feature = "defmt")]
468        trace!("HID control_out {:?} {=[u8]:x}", req, data);
469        match req.request {
470            HID_REQ_SET_IDLE => {
471                if let Some(handler) = self.request_handler.as_mut() {
472                    let id = req.value as u8;
473                    let id = (id != 0).then_some(ReportId::In(id));
474                    let dur = u32::from(req.value >> 8);
475                    let dur = if dur == 0 { u32::MAX } else { 4 * dur };
476                    handler.set_idle_ms(id, dur);
477                }
478                Some(OutResponse::Accepted)
479            }
480            HID_REQ_SET_REPORT => match (ReportId::try_from(req.value), self.request_handler.as_mut()) {
481                (Ok(id), Some(handler)) => Some(handler.set_report(id, data)),
482                _ => Some(OutResponse::Rejected),
483            },
484            HID_REQ_SET_PROTOCOL => {
485                if req.value == 1 {
486                    Some(OutResponse::Accepted)
487                } else {
488                    warn!("HID Boot Protocol is unsupported.");
489                    Some(OutResponse::Rejected) // UNSUPPORTED: Boot Protocol
490                }
491            }
492            _ => Some(OutResponse::Rejected),
493        }
494    }
495
496    fn control_in<'a>(&'a mut self, req: Request, buf: &'a mut [u8]) -> Option<InResponse<'a>> {
497        if req.index != self.if_num.0 as u16 {
498            return None;
499        }
500
501        match (req.request_type, req.recipient) {
502            (RequestType::Standard, Recipient::Interface) => match req.request {
503                Request::GET_DESCRIPTOR => match (req.value >> 8) as u8 {
504                    HID_DESC_DESCTYPE_HID_REPORT => Some(InResponse::Accepted(self.report_descriptor)),
505                    HID_DESC_DESCTYPE_HID => Some(InResponse::Accepted(&self.hid_descriptor)),
506                    _ => Some(InResponse::Rejected),
507                },
508
509                _ => Some(InResponse::Rejected),
510            },
511            (RequestType::Class, Recipient::Interface) => {
512                trace!("HID control_in {:?}", req);
513                match req.request {
514                    HID_REQ_GET_REPORT => {
515                        let size = match ReportId::try_from(req.value) {
516                            Ok(id) => self.request_handler.as_mut().and_then(|x| x.get_report(id, buf)),
517                            Err(_) => None,
518                        };
519
520                        if let Some(size) = size {
521                            Some(InResponse::Accepted(&buf[0..size]))
522                        } else {
523                            Some(InResponse::Rejected)
524                        }
525                    }
526                    HID_REQ_GET_IDLE => {
527                        if let Some(handler) = self.request_handler.as_mut() {
528                            let id = req.value as u8;
529                            let id = (id != 0).then_some(ReportId::In(id));
530                            if let Some(dur) = handler.get_idle_ms(id) {
531                                let dur = u8::try_from(dur / 4).unwrap_or(0);
532                                buf[0] = dur;
533                                Some(InResponse::Accepted(&buf[0..1]))
534                            } else {
535                                Some(InResponse::Rejected)
536                            }
537                        } else {
538                            Some(InResponse::Rejected)
539                        }
540                    }
541                    HID_REQ_GET_PROTOCOL => {
542                        // UNSUPPORTED: Boot Protocol
543                        buf[0] = 1;
544                        Some(InResponse::Accepted(&buf[0..1]))
545                    }
546                    _ => Some(InResponse::Rejected),
547                }
548            }
549            _ => None,
550        }
551    }
552}