Skip to main content

embassy_usb_host/class/
hid.rs

//! HID (Human Interface Device) host class driver.
//!
//! This driver communicates with USB HID devices (keyboards, mice, gamepads, etc.).
4
5use embassy_usb_driver::host::{PipeError, UsbHostAllocator, UsbPipe, pipe};
6use embassy_usb_driver::{Direction as UsbDirection, EndpointAddress, EndpointInfo, EndpointType};
7
8pub use super::hid_report::{ReportDescriptor, ReportField};
9use crate::control::SetupPacket;
10use crate::descriptor::ConfigurationDescriptor;
11use crate::handler::EnumerationInfo;
12
// ── Descriptor-matching codes ────────────────────────────────────────────────

/// Interface class code (`bInterfaceClass`) identifying a HID interface.
const USB_CLASS_HID: u8 = 0x03;
/// Endpoint transfer-type value that denotes an interrupt endpoint.
const TRANSFER_INTERRUPT: u8 = 0x03;

// ── HID class-specific request codes (`bRequest`) ────────────────────────────

/// `GET_REPORT` class request.
const GET_REPORT: u8 = 0x01;
/// `SET_IDLE` class request.
const SET_IDLE: u8 = 0x0A;
/// `SET_PROTOCOL` class request.
const SET_PROTOCOL: u8 = 0x0B;

// ── Protocol selectors for `SET_PROTOCOL` ────────────────────────────────────

/// Selects the boot protocol (fixed report layouts for keyboards and mice).
pub const PROTOCOL_BOOT: u8 = 0x00;
/// Selects the report protocol (layout described by the Report descriptor).
pub const PROTOCOL_REPORT: u8 = 0x01;
29
// ── Boot-protocol report structs ─────────────────────────────────────────────
31
/// A keyboard input report in the fixed 8-byte USB HID boot-protocol layout.
///
/// Byte 0 holds the modifier bitmask, byte 1 is reserved, and bytes 2..8 hold
/// up to six key codes. Standard USB keyboards produce this layout once they are
/// switched to boot protocol with [`HidHost::set_protocol`] and [`PROTOCOL_BOOT`].
#[derive(Clone, Debug, Default, PartialEq, Eq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct KeyboardReport {
    /// Bitmask of held modifier keys.
    ///
    /// Bit 0: Left Ctrl  | Bit 1: Left Shift  | Bit 2: Left Alt  | Bit 3: Left GUI
    /// Bit 4: Right Ctrl | Bit 5: Right Shift | Bit 6: Right Alt | Bit 7: Right GUI
    pub modifiers: u8,
    /// Key codes (HID usage page 0x07) of up to six keys held at the same time.
    /// `0x00` marks an empty slot; `0x01` signals a rollover error.
    pub keycodes: [u8; 6],
}

impl KeyboardReport {
    /// Decode a boot-protocol keyboard report.
    ///
    /// Only the first 8 bytes of `buf` are examined; anything after them is ignored.
    /// Yields `None` when `buf` holds fewer than 8 bytes.
    pub fn parse(buf: &[u8]) -> Option<Self> {
        let report = buf.get(..8)?;
        // Byte 1 is reserved by the boot layout and intentionally skipped.
        let mut keycodes = [0u8; 6];
        keycodes.copy_from_slice(&report[2..8]);
        Some(Self {
            modifiers: report[0],
            keycodes,
        })
    }

    /// Whether `keycode` appears in any key slot.
    ///
    /// `0x00` (the "no key" marker) is never reported as pressed.
    pub fn is_pressed(&self, keycode: u8) -> bool {
        if keycode == 0 {
            return false;
        }
        self.keycodes.iter().any(|&slot| slot == keycode)
    }

    /// True when any modifier bit selected by `mask` is set.
    fn modifier_held(&self, mask: u8) -> bool {
        (self.modifiers & mask) != 0
    }

    /// Whether either Ctrl key (bit 0 or bit 4) is held.
    pub fn ctrl(&self) -> bool {
        self.modifier_held(0x11)
    }
    /// Whether either Shift key (bit 1 or bit 5) is held.
    pub fn shift(&self) -> bool {
        self.modifier_held(0x22)
    }
    /// Whether either Alt key (bit 2 or bit 6) is held.
    pub fn alt(&self) -> bool {
        self.modifier_held(0x44)
    }
    /// Whether either GUI (Win/Cmd) key (bit 3 or bit 7) is held.
    pub fn gui(&self) -> bool {
        self.modifier_held(0x88)
    }
}
85
/// Mouse button bitmask used in [`MouseReport`].
///
/// Bit 0: left button | Bit 1: right button | Bit 2: middle button
///
/// This is a plain `u8` alias; the named masks for these bits are
/// [`MouseReport::BUTTON_LEFT`], [`MouseReport::BUTTON_RIGHT`] and
/// [`MouseReport::BUTTON_MIDDLE`].
pub type MouseButtons = u8;

/// Decoded mouse report (USB HID boot protocol).
///
/// The boot-protocol mouse report is 3 bytes (buttons, X, Y). Many mice append a
/// 4th byte with scroll-wheel movement; [`MouseReport::parse`] decodes it when
/// present and reports `0` otherwise.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct MouseReport {
    /// Button state. Test bits with the associated constants
    /// [`MouseReport::BUTTON_LEFT`], [`MouseReport::BUTTON_RIGHT`] and
    /// [`MouseReport::BUTTON_MIDDLE`], or use the `left`/`right`/`middle` helpers.
    pub buttons: MouseButtons,
    /// Horizontal movement since last report (signed, positive = right).
    pub x: i8,
    /// Vertical movement since last report (signed, positive = down).
    pub y: i8,
    /// Scroll wheel movement (signed, positive = scroll up / away from user).
    /// `0` when the device sent only the 3-byte boot report.
    pub wheel: i8,
}

impl MouseReport {
    /// Left mouse button.
    pub const BUTTON_LEFT: MouseButtons = 1 << 0;
    /// Right mouse button.
    pub const BUTTON_RIGHT: MouseButtons = 1 << 1;
    /// Middle mouse button (scroll wheel click).
    pub const BUTTON_MIDDLE: MouseButtons = 1 << 2;

    /// Parse a boot-protocol mouse report.
    ///
    /// Requires at least 3 bytes (buttons, X, Y); a 4th byte, if present, is
    /// decoded as the wheel delta. Bytes beyond the 4th are ignored.
    /// Returns `None` if `buf` is shorter than 3 bytes.
    pub fn parse(buf: &[u8]) -> Option<Self> {
        if buf.len() < 3 {
            return None;
        }
        Some(Self {
            buttons: buf[0],
            // Deltas are two's-complement bytes; reinterpret them as signed.
            x: buf[1] as i8,
            y: buf[2] as i8,
            wheel: buf.get(3).map_or(0, |&w| w as i8),
        })
    }

    /// Returns `true` if the left button is pressed.
    pub fn left(&self) -> bool {
        self.buttons & Self::BUTTON_LEFT != 0
    }
    /// Returns `true` if the right button is pressed.
    pub fn right(&self) -> bool {
        self.buttons & Self::BUTTON_RIGHT != 0
    }
    /// Returns `true` if the middle button is pressed.
    pub fn middle(&self) -> bool {
        self.buttons & Self::BUTTON_MIDDLE != 0
    }
}
142
/// `bDescriptorType` of the HID class descriptor, which is carried inside the
/// configuration descriptor alongside the HID interface it describes.
const DESC_HID: u8 = 0x21;
145
/// Summary of a HID interface located by [`find_hid`] in a configuration descriptor.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct HidInfo {
    /// `bInterfaceNumber` of the HID interface.
    pub interface_number: u8,
    /// `bEndpointAddress` of the interrupt IN endpoint, direction bit included.
    pub interrupt_in_ep: u8,
    /// `wMaxPacketSize` of the interrupt IN endpoint.
    pub interrupt_in_mps: u16,
    /// Size in bytes of the HID Report Descriptor, as announced by the HID class
    /// descriptor. Use it to size the buffer given to [`HidHost::fetch_report_descriptor`].
    pub report_descriptor_len: u16,
}
160
161/// Find the first HID interface in a configuration descriptor.
162pub fn find_hid(config_desc: &[u8]) -> Option<HidInfo> {
163    let cfg = ConfigurationDescriptor::try_from_slice(config_desc).ok()?;
164
165    for iface in cfg.iter_interface() {
166        if iface.interface_class != USB_CLASS_HID {
167            continue;
168        }
169
170        // Extract report descriptor length from the HID class descriptor (type 0x21).
171        // Layout: bLength, bDescriptorType(0x21), bcdHID(2), bCountryCode,
172        //         bNumDescriptors, bDescriptorType(0x22), wDescriptorLength(2)
173        let report_desc_len = iface
174            .iter_descriptors()
175            .find_map(|(_, data)| {
176                if data.len() >= 7 && data[1] == DESC_HID {
177                    Some(u16::from_le_bytes([data[5], data[6]]))
178                } else {
179                    None
180                }
181            })
182            .unwrap_or(0);
183
184        let ep = iface
185            .iter_endpoints()
186            .find(|ep| ep.transfer_type() == TRANSFER_INTERRUPT && ep.is_in())?;
187
188        return Some(HidInfo {
189            interface_number: iface.interface_number,
190            interrupt_in_ep: ep.endpoint_address,
191            interrupt_in_mps: ep.max_packet_size,
192            report_descriptor_len: report_desc_len,
193        });
194    }
195
196    None
197}
198
199/// HID host class driver error.
200#[derive(Debug)]
201#[cfg_attr(feature = "defmt", derive(defmt::Format))]
202pub enum HidError {
203    /// Transfer error.
204    Transfer(PipeError),
205    /// No matching HID interface found in the device.
206    NoInterface,
207    /// Failed to allocate a pipe.
208    NoPipe,
209}
210
211impl From<PipeError> for HidError {
212    fn from(e: PipeError) -> Self {
213        Self::Transfer(e)
214    }
215}
216
217impl core::fmt::Display for HidError {
218    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
219        match self {
220            Self::Transfer(_e) => write!(f, "Transfer error"),
221            Self::NoInterface => write!(f, "No HID interface found"),
222            Self::NoPipe => write!(f, "No free pipe"),
223        }
224    }
225}
226
227impl core::error::Error for HidError {}
228
229/// HID host driver.
230///
231/// Provides report reading and optional class request access to a USB HID device.
232pub struct HidHost<'d, A: UsbHostAllocator<'d>> {
233    ctrl_ch: A::Pipe<pipe::Control, pipe::InOut>,
234    in_ch: A::Pipe<pipe::Interrupt, pipe::In>,
235    interface: u8,
236    report_descriptor_len: u16,
237    _phantom: core::marker::PhantomData<&'d ()>,
238}
239
240impl<'d, A: UsbHostAllocator<'d>> HidHost<'d, A> {
241    /// Create a new HID host driver.
242    ///
243    /// Parses the config descriptor to find the HID interface and its interrupt IN endpoint,
244    /// then allocates the necessary channels.
245    pub fn new(alloc: &A, config_desc: &[u8], enum_info: &EnumerationInfo) -> Result<Self, HidError> {
246        let info = find_hid(config_desc).ok_or(HidError::NoInterface)?;
247
248        let ctrl_ep_info = EndpointInfo {
249            addr: EndpointAddress::from_parts(0, UsbDirection::In),
250            ep_type: EndpointType::Control,
251            max_packet_size: enum_info.device_desc.max_packet_size0 as u16,
252            interval_ms: 0,
253        };
254
255        let in_ep_info = EndpointInfo {
256            addr: EndpointAddress::from_parts((info.interrupt_in_ep & 0x0F) as usize, UsbDirection::In),
257            ep_type: EndpointType::Interrupt,
258            max_packet_size: info.interrupt_in_mps,
259            interval_ms: 0,
260        };
261
262        let device_address = enum_info.device_address;
263        let split = enum_info.split();
264
265        let ctrl_ch = alloc
266            .alloc_pipe::<pipe::Control, pipe::InOut>(device_address, &ctrl_ep_info, split)
267            .map_err(|_| HidError::NoPipe)?;
268        let in_ch = alloc
269            .alloc_pipe::<pipe::Interrupt, pipe::In>(device_address, &in_ep_info, split)
270            .map_err(|_| HidError::NoPipe)?;
271
272        Ok(Self {
273            ctrl_ch,
274            in_ch,
275            interface: info.interface_number,
276            report_descriptor_len: info.report_descriptor_len,
277            _phantom: core::marker::PhantomData,
278        })
279    }
280
281    /// Fetch the HID Report Descriptor from the device into `buf`.
282    ///
283    /// Returns the descriptor bytes as a slice. Pass the result to
284    /// [`ReportDescriptor::parse`] to decode it:
285    ///
286    /// ```ignore
287    /// let mut buf = [0u8; 256];
288    /// let desc = hid.fetch_report_descriptor(&mut buf).await?;
289    /// let report: ReportDescriptor<32> = ReportDescriptor::parse(desc);
290    /// ```
291    ///
292    /// `buf` should be at least `HidInfo::report_descriptor_len` bytes; any
293    /// excess is unused.
294    pub async fn fetch_report_descriptor<'a>(&mut self, buf: &'a mut [u8]) -> Result<&'a [u8], HidError> {
295        let len = (self.report_descriptor_len as usize).min(buf.len()) as u16;
296        let setup = SetupPacket::get_hid_report_descriptor(self.interface, len);
297        let n = self
298            .ctrl_ch
299            .control_in(&setup.to_bytes(), &mut buf[..len as usize])
300            .await?;
301        Ok(&buf[..n])
302    }
303
304    /// Set the idle rate for a report.
305    ///
306    /// `report_id = 0` applies to all reports. `idle_duration = 0` disables idle repeat.
307    ///
308    /// Note: SET_IDLE is optional; some devices STALL this request.
309    /// A STALL is treated as success per the HID specification.
310    pub async fn set_idle(&mut self, report_id: u8, idle_duration: u8) -> Result<(), HidError> {
311        let value = (idle_duration as u16) << 8 | report_id as u16;
312        let setup = SetupPacket::class_interface_out(SET_IDLE, value, self.interface as u16, 0);
313        match self.ctrl_ch.control_out(&setup.to_bytes(), &[]).await {
314            Ok(_) => Ok(()),
315            Err(PipeError::Stall) => Ok(()),
316            Err(e) => Err(HidError::Transfer(e)),
317        }
318    }
319
320    /// Set the protocol (boot or report).
321    pub async fn set_protocol(&mut self, protocol: u8) -> Result<(), HidError> {
322        let setup = SetupPacket::class_interface_out(SET_PROTOCOL, protocol as u16, self.interface as u16, 0);
323        self.ctrl_ch.control_out(&setup.to_bytes(), &[]).await?;
324        Ok(())
325    }
326
327    /// Read a raw input report from the interrupt IN endpoint.
328    ///
329    /// Returns the number of bytes received.
330    pub async fn read(&mut self, buf: &mut [u8]) -> Result<usize, HidError> {
331        let n = self.in_ch.request_in(buf).await?;
332        Ok(n)
333    }
334
335    /// Read and parse a boot-protocol keyboard report.
336    ///
337    /// Call [`HidHost::set_protocol`] with [`PROTOCOL_BOOT`] first.
338    /// Returns `None` if the report is malformed (shorter than 8 bytes).
339    pub async fn read_keyboard(&mut self) -> Result<Option<KeyboardReport>, HidError> {
340        let mut buf = [0u8; 8];
341        self.in_ch.request_in(&mut buf).await?;
342        Ok(KeyboardReport::parse(&buf))
343    }
344
345    /// Read and parse a boot-protocol mouse report.
346    ///
347    /// Call [`HidHost::set_protocol`] with [`PROTOCOL_BOOT`] first.
348    /// Returns `None` if the report is malformed (shorter than 3 bytes).
349    pub async fn read_mouse(&mut self) -> Result<Option<MouseReport>, HidError> {
350        let mut buf = [0u8; 4];
351        // Some mice send only 3 bytes; read up to 4.
352        let n = self.in_ch.request_in(&mut buf).await?;
353        Ok(MouseReport::parse(&buf[..n]))
354    }
355
356    /// Issue a GET_REPORT control request.
357    ///
358    /// `report_type`: 1=Input, 2=Output, 3=Feature.
359    /// `report_id`: 0 if the device uses a single report.
360    ///
361    /// Returns the number of bytes received.
362    pub async fn get_report(&mut self, report_type: u8, report_id: u8, buf: &mut [u8]) -> Result<usize, HidError> {
363        let value = (report_type as u16) << 8 | report_id as u16;
364        let setup = SetupPacket::class_interface_in(GET_REPORT, value, self.interface as u16, buf.len() as u16);
365        let n = self.ctrl_ch.control_in(&setup.to_bytes(), buf).await?;
366        Ok(n)
367    }
368}