Skip to main content

embassy_usb_host/class/
kbd.rs

1//! Host driver for HID boot-protocol keyboards.
2#![allow(missing_docs)]
3
4use core::num::NonZeroU8;
5
6use bitflags::bitflags;
7use embassy_usb_driver::host::{HostError, UsbHostAllocator, UsbPipe, pipe};
8use embassy_usb_driver::{Direction, EndpointInfo, EndpointType};
9
10use crate::control::ControlPipeExt;
11use crate::descriptor::{DEFAULT_MAX_DESCRIPTOR_SIZE, InterfaceDescriptor, USBDescriptor};
12use crate::handler::{EnumerationInfo, HandlerEvent, RegisterError};
13
/// One HID boot-protocol keyboard input report (8 bytes).
#[repr(C)]
#[derive(Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct KeyStatusUpdate {
    /// Modifier keys bitmask (LeftCtrl, LeftShift, LeftAlt, LeftGUI, RightCtrl, RightShift, RightAlt, RightGUI).
    pub modifiers: u8,
    /// Reserved (OEM).
    pub reserved: u8,
    /// Keycodes of currently pressed keys: `None` for an empty slot (raw value 0),
    /// `Some(1)` for rollover (raw value 1), otherwise the raw keycode.
    pub keypress: [Option<NonZeroU8>; 6],
}

impl KeyStatusUpdate {
    /// Decodes a raw 8-byte boot-protocol report.
    ///
    /// Byte 0 becomes `modifiers`, byte 1 becomes `reserved`, and bytes 2..8 become
    /// `keypress`, with each zero keycode mapped to `None`. Every byte pattern is a valid
    /// report, so decoding cannot fail; the `_unchecked` name is kept for existing callers.
    fn from_buffer_unchecked(value: [u8; 8]) -> Self {
        // Decode field by field rather than transmuting: no `unsafe`, and no reliance on
        // the struct's memory layout or on the `Option<NonZeroU8>` niche representation.
        let [modifiers, reserved, keys @ ..] = value;
        Self {
            modifiers,
            reserved,
            keypress: keys.map(NonZeroU8::new),
        }
    }
}
32
33#[derive(Debug)]
34#[cfg_attr(feature = "defmt", derive(defmt::Format))]
35pub enum KbdEvent {
36    KeyStatusUpdate(KeyStatusUpdate),
37}
38
39/// Host-side HID boot-keyboard driver.
40pub struct KbdHandler<'d, A: UsbHostAllocator<'d>> {
41    interrupt_channel: A::Pipe<pipe::Interrupt, pipe::In>,
42    control_channel: A::Pipe<pipe::Control, pipe::InOut>,
43    _phantom: core::marker::PhantomData<&'d ()>,
44}
45
46impl<'d, A: UsbHostAllocator<'d>> KbdHandler<'d, A> {
47    /// Attempt to register a keyboard handler for the given device.
48    pub async fn try_register(alloc: &A, enum_info: &EnumerationInfo) -> Result<Self, RegisterError> {
49        let mut control_channel = alloc.alloc_pipe::<pipe::Control, pipe::InOut>(
50            enum_info.device_address,
51            &EndpointInfo {
52                addr: 0.into(),
53                ep_type: EndpointType::Control,
54                max_packet_size: (enum_info.device_desc.max_packet_size0 as u16)
55                    .min(enum_info.speed().max_packet_size()),
56                interval_ms: 0,
57            },
58            enum_info.split(),
59        )?;
60
61        let mut cfg_desc_buf = [0u8; DEFAULT_MAX_DESCRIPTOR_SIZE];
62        let configuration = enum_info
63            .active_config_or_set_default(&mut control_channel, &mut cfg_desc_buf)
64            .await?;
65
66        let iface = configuration
67            .iter_interface()
68            .find(|v| {
69                matches!(
70                    v,
71                    InterfaceDescriptor {
72                        interface_class: 0x03,
73                        interface_subclass: 0x1,
74                        interface_protocol: 0x1,
75                        ..
76                    }
77                )
78            })
79            .ok_or(RegisterError::NoSupportedInterface)?;
80
81        let interrupt_ep = iface
82            .iter_endpoints()
83            .find(|v| v.ep_type() == EndpointType::Interrupt && v.ep_dir() == Direction::In)
84            .ok_or(RegisterError::NoSupportedInterface)?;
85
86        control_channel
87            .set_configuration(configuration.configuration_value)
88            .await?;
89
90        let interrupt_channel = alloc.alloc_pipe::<pipe::Interrupt, pipe::In>(
91            enum_info.device_address,
92            &interrupt_ep.into(),
93            enum_info.split(),
94        )?;
95
96        debug!("[kbd]: Setting PROTOCOL & idle");
97        const SET_PROTOCOL: u8 = 0x0B;
98        const BOOT_PROTOCOL: u16 = 0x0000;
99        if let Err(err) = control_channel
100            .class_request_out(SET_PROTOCOL, BOOT_PROTOCOL, iface.interface_number as u16, &[])
101            .await
102        {
103            error!("[kbd]: Failed to set protocol: {:?}", err);
104        }
105
106        const SET_IDLE: u8 = 0x0A;
107        if let Err(err) = control_channel
108            .class_request_out(SET_IDLE, 0, iface.interface_number as u16, &[])
109            .await
110        {
111            error!("[kbd]: Failed to set idle: {:?}", err);
112        }
113
114        Ok(KbdHandler {
115            interrupt_channel,
116            control_channel,
117            _phantom: core::marker::PhantomData,
118        })
119    }
120
121    /// Wait for the next keyboard event.
122    pub async fn wait_for_event(&mut self) -> Result<HandlerEvent<KbdEvent>, HostError> {
123        let mut buffer = [0u8; 8];
124        debug!("[kbd]: Requesting interrupt IN");
125        self.interrupt_channel.request_in(&mut buffer[..]).await?;
126        debug!("[kbd]: Got interrupt {:?}", buffer);
127        Ok(HandlerEvent::HandlerEvent(KbdEvent::KeyStatusUpdate(
128            KeyStatusUpdate::from_buffer_unchecked(buffer),
129        )))
130    }
131
132    /// SET_REPORT — update keyboard LEDs.
133    pub async fn set_state(&mut self, state: &KeyboardState) -> Result<(), HostError> {
134        const SET_REPORT: u8 = 0x09;
135        const OUTPUT_REPORT: u16 = 2 << 8;
136        self.control_channel
137            .class_request_out(SET_REPORT, OUTPUT_REPORT, 0, &[state.bits()])
138            .await
139    }
140}
141
142bitflags! {
143    /// Keyboard LED state.
144    pub struct KeyboardState: u8 {
145        const NUM_LOCK    = 1 << 0;
146        const CAPS_LOCK   = 1 << 1;
147        const SCROLL_LOCK = 1 << 2;
148        const COMPOSE     = 1 << 3;
149        const KANA        = 1 << 4;
150    }
151}
152
/// HID class descriptor (type 0x21).
///
/// Only the first class-descriptor entry (`descriptor_type0` / `descriptor_length0`) is kept.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct HIDDescriptor {
    /// `bLength`: total size of this descriptor in bytes.
    pub len: u8,
    /// `bDescriptorType`: 0x21 for a HID descriptor.
    pub descriptor_type: u8,
    /// `bcdHID`: HID specification release, BCD-encoded.
    pub bcd_hid: u16,
    /// `bCountryCode`: hardware target country (0 = not localized).
    pub country_code: u8,
    /// `bNumDescriptors`: number of class descriptors that follow.
    pub num_descriptors: u8,
    /// `bDescriptorType` of the first class descriptor (0x22 for a report descriptor).
    pub descriptor_type0: u8,
    /// `wDescriptorLength` of the first class descriptor.
    pub descriptor_length0: u16,
}
164
165impl USBDescriptor for HIDDescriptor {
166    const SIZE: usize = 9;
167    const DESC_TYPE: u8 = 33;
168    type Error = ();
169
170    fn try_from_bytes(bytes: &[u8]) -> Result<Self, Self::Error> {
171        if bytes.len() < Self::SIZE || bytes[1] != Self::DESC_TYPE {
172            return Err(());
173        }
174        Ok(Self {
175            len: bytes[0],
176            descriptor_type: bytes[1],
177            bcd_hid: u16::from_le_bytes([bytes[2], bytes[3]]),
178            country_code: bytes[4],
179            num_descriptors: bytes[5],
180            descriptor_type0: bytes[6],
181            descriptor_length0: u16::from_le_bytes([bytes[7], bytes[8]]),
182        })
183    }
184}