rpk_firmware/
hid.rs

1use core::sync::atomic::{AtomicUsize, Ordering};
2use embassy_usb::{
3    class::hid::{ReadError, ReportId, RequestHandler},
4    driver::{Driver, Endpoint, EndpointError, EndpointIn, EndpointOut},
5};
6
7use crate::warn;
8
9pub struct HidWriter<'d, D: Driver<'d>, const N: usize> {
10    ep_in: D::EndpointIn,
11}
12
13impl<'d, D: Driver<'d>, const N: usize> HidWriter<'d, D, N> {
14    pub fn new(ep_in: <D>::EndpointIn) -> Self {
15        Self { ep_in }
16    }
17
18    /// Writes `report` to its interrupt endpoint.
19    pub async fn write(&mut self, report: &[u8]) -> Result<(), EndpointError> {
20        assert!(report.len() <= N);
21
22        let max_packet_size = usize::from(self.ep_in.info().max_packet_size);
23        let zlp_needed = report.len() < N && (report.len() % max_packet_size == 0);
24        for chunk in report.chunks(max_packet_size) {
25            self.ep_in.write(chunk).await?;
26        }
27
28        if zlp_needed {
29            self.ep_in.write(&[]).await?;
30        }
31
32        Ok(())
33    }
34}
35
36pub struct HidReader<'d, D: Driver<'d>, const N: usize> {
37    ep_out: D::EndpointOut,
38    offset: &'d AtomicUsize,
39}
40
41impl<'d, D: Driver<'d>, const N: usize> HidReader<'d, D, N> {
42    pub fn new(ep_out: <D>::EndpointOut, offset: &'d AtomicUsize) -> Self {
43        Self { ep_out, offset }
44    }
45
46    /// Delivers output reports from the Interrupt Out pipe to `handler`.
47    ///
48    /// If `use_report_ids` is true, the first byte of the report will be used as
49    /// the `ReportId` value. Otherwise the `ReportId` value will be 0.
50    pub async fn run<T: RequestHandler>(mut self, use_report_ids: bool, handler: &mut T) -> ! {
51        let offset = self.offset.load(Ordering::Acquire);
52        assert!(offset == 0);
53        let mut buf = [0; N];
54        loop {
55            match self.read(&mut buf).await {
56                Ok(len) => {
57                    let id = if use_report_ids { buf[0] } else { 0 };
58                    handler.set_report(ReportId::Out(id), &buf[..len]);
59                }
60                Err(ReadError::BufferOverflow) => {
61                    warn!(
62                    "Host sent output report larger than the configured maximum output report length ({})",
63                    N
64                );
65                }
66                Err(ReadError::Disabled) => self.ep_out.wait_enabled().await,
67                Err(ReadError::Sync(_)) => unreachable!(),
68            }
69        }
70    }
71
72    /// Reads an output report from the Interrupt Out pipe.
73    ///
74    /// **Note:** Any reports sent from the host over the control pipe will be
75    /// passed to [`RequestHandler::set_report()`] for handling. The application
76    /// is responsible for ensuring output reports from both pipes are handled
77    /// correctly.
78    ///
79    /// **Note:** If `N` > the maximum packet size of the endpoint (i.e. output
80    /// reports may be split across multiple packets) and this method's future
81    /// is dropped after some packets have been read, the next call to `read()`
82    /// will return a [`ReadError::Sync`]. The range in the sync error
83    /// indicates the portion `buf` that was filled by the current call to
84    /// `read()`. If the dropped future used the same `buf`, then `buf` will
85    /// contain the full report.
86    pub async fn read(&mut self, buf: &mut [u8]) -> Result<usize, ReadError> {
87        assert!(N != 0);
88        assert!(buf.len() >= N);
89
90        // Read packets from the endpoint
91        let max_packet_size = usize::from(self.ep_out.info().max_packet_size);
92        let starting_offset = self.offset.load(Ordering::Acquire);
93        let mut total = starting_offset;
94        loop {
95            for chunk in buf[starting_offset..N].chunks_mut(max_packet_size) {
96                match self.ep_out.read(chunk).await {
97                    Ok(size) => {
98                        total += size;
99                        if size < max_packet_size || total == N {
100                            self.offset.store(0, Ordering::Release);
101                            break;
102                        }
103                        self.offset.store(total, Ordering::Release);
104                    }
105                    Err(err) => {
106                        self.offset.store(0, Ordering::Release);
107                        return Err(err.into());
108                    }
109                }
110            }
111
112            // Some hosts may send ZLPs even when not required by the HID spec, so we'll loop as long as total == 0.
113            if total > 0 {
114                break;
115            }
116        }
117
118        if starting_offset > 0 {
119            Err(ReadError::Sync(starting_offset..total))
120        } else {
121            Ok(total)
122        }
123    }
124}