1use core::mem::MaybeUninit;
4use core::ops::Range;
5use core::sync::atomic::{AtomicUsize, Ordering};
6
7#[cfg(feature = "usbd-hid")]
8use ssmarshal::serialize;
9#[cfg(feature = "usbd-hid")]
10use usbd_hid::descriptor::AsInputReport;
11
12use crate::control::{InResponse, OutResponse, Recipient, Request, RequestType};
13use crate::driver::{Driver, Endpoint, EndpointError, EndpointIn, EndpointOut};
14use crate::types::InterfaceNumber;
15use crate::{Builder, Handler};
16
// Interface class triple advertised for the HID function: HID class, no
// subclass (i.e. not a boot device), no protocol.
const USB_CLASS_HID: u8 = 0x03;
const USB_SUBCLASS_NONE: u8 = 0x00;
const USB_PROTOCOL_NONE: u8 = 0x00;

// Descriptor type codes used in GET_DESCRIPTOR and in the HID class descriptor.
const HID_DESC_DESCTYPE_HID: u8 = 0x21;
const HID_DESC_DESCTYPE_HID_REPORT: u8 = 0x22;
/// `bcdHID` for HID 1.10, little-endian.
const HID_DESC_SPEC_1_10: [u8; 2] = [0x10, 0x01];
/// `bCountryCode`: not localized.
const HID_DESC_COUNTRY_UNSPEC: u8 = 0x00;

// Class-specific request codes (`bRequest`), in numeric order.
const HID_REQ_GET_REPORT: u8 = 0x01;
const HID_REQ_GET_IDLE: u8 = 0x02;
const HID_REQ_GET_PROTOCOL: u8 = 0x03;
const HID_REQ_SET_REPORT: u8 = 0x09;
const HID_REQ_SET_IDLE: u8 = 0x0a;
const HID_REQ_SET_PROTOCOL: u8 = 0x0b;
33
34pub struct Config<'d> {
36 pub report_descriptor: &'d [u8],
38
39 pub request_handler: Option<&'d mut dyn RequestHandler>,
41
42 pub poll_ms: u8,
48
49 pub max_packet_size: u16,
51}
52
/// A HID report type together with its report ID.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum ReportId {
    /// Input report (device to host).
    In(u8),
    /// Output report (host to device).
    Out(u8),
    /// Feature report.
    Feature(u8),
}

impl ReportId {
    /// Decodes the `wValue` field of a `GET_REPORT` / `SET_REPORT` request.
    ///
    /// The high byte selects the report type (1 = input, 2 = output,
    /// 3 = feature) and the low byte is the report ID. Any other report type
    /// yields `Err(())`.
    const fn try_from(value: u16) -> Result<Self, ()> {
        let id = value as u8;
        let report = match value >> 8 {
            1 => ReportId::In(id),
            2 => ReportId::Out(id),
            3 => ReportId::Feature(id),
            _ => return Err(()),
        };
        Ok(report)
    }
}
75
76pub struct State<'d> {
78 control: MaybeUninit<Control<'d>>,
79 out_report_offset: AtomicUsize,
80}
81
82impl<'d> Default for State<'d> {
83 fn default() -> Self {
84 Self::new()
85 }
86}
87
88impl<'d> State<'d> {
89 pub const fn new() -> Self {
91 State {
92 control: MaybeUninit::uninit(),
93 out_report_offset: AtomicUsize::new(0),
94 }
95 }
96}
97
98pub struct HidReaderWriter<'d, D: Driver<'d>, const READ_N: usize, const WRITE_N: usize> {
100 reader: HidReader<'d, D, READ_N>,
101 writer: HidWriter<'d, D, WRITE_N>,
102}
103
104fn build<'d, D: Driver<'d>>(
105 builder: &mut Builder<'d, D>,
106 state: &'d mut State<'d>,
107 config: Config<'d>,
108 with_out_endpoint: bool,
109) -> (Option<D::EndpointOut>, D::EndpointIn, &'d AtomicUsize) {
110 let len = config.report_descriptor.len();
111
112 let mut func = builder.function(USB_CLASS_HID, USB_SUBCLASS_NONE, USB_PROTOCOL_NONE);
113 let mut iface = func.interface();
114 let if_num = iface.interface_number();
115 let mut alt = iface.alt_setting(USB_CLASS_HID, USB_SUBCLASS_NONE, USB_PROTOCOL_NONE, None);
116
117 alt.descriptor(
119 HID_DESC_DESCTYPE_HID,
120 &[
121 HID_DESC_SPEC_1_10[0],
123 HID_DESC_SPEC_1_10[1],
124 HID_DESC_COUNTRY_UNSPEC,
126 1,
128 HID_DESC_DESCTYPE_HID_REPORT,
130 (len & 0xFF) as u8,
132 (len >> 8 & 0xFF) as u8,
133 ],
134 );
135
136 let ep_in = alt.endpoint_interrupt_in(None, config.max_packet_size, config.poll_ms);
137 let ep_out = if with_out_endpoint {
138 Some(alt.endpoint_interrupt_out(None, config.max_packet_size, config.poll_ms))
139 } else {
140 None
141 };
142
143 drop(func);
144
145 let control = state.control.write(Control::new(
146 if_num,
147 config.report_descriptor,
148 config.request_handler,
149 &state.out_report_offset,
150 ));
151 builder.handler(control);
152
153 (ep_out, ep_in, &state.out_report_offset)
154}
155
156impl<'d, D: Driver<'d>, const READ_N: usize, const WRITE_N: usize> HidReaderWriter<'d, D, READ_N, WRITE_N> {
157 pub fn new(builder: &mut Builder<'d, D>, state: &'d mut State<'d>, config: Config<'d>) -> Self {
163 let (ep_out, ep_in, offset) = build(builder, state, config, true);
164
165 Self {
166 reader: HidReader {
167 ep_out: ep_out.unwrap(),
168 offset,
169 },
170 writer: HidWriter { ep_in },
171 }
172 }
173
174 pub fn split(self) -> (HidReader<'d, D, READ_N>, HidWriter<'d, D, WRITE_N>) {
176 (self.reader, self.writer)
177 }
178
179 pub async fn ready(&mut self) {
181 self.reader.ready().await;
182 self.writer.ready().await;
183 }
184
185 #[cfg(feature = "usbd-hid")]
187 pub async fn write_serialize<IR: AsInputReport>(&mut self, r: &IR) -> Result<(), EndpointError> {
188 self.writer.write_serialize(r).await
189 }
190
191 pub async fn write(&mut self, report: &[u8]) -> Result<(), EndpointError> {
193 self.writer.write(report).await
194 }
195
196 pub async fn read(&mut self, buf: &mut [u8]) -> Result<usize, ReadError> {
200 self.reader.read(buf).await
201 }
202}
203
204pub struct HidWriter<'d, D: Driver<'d>, const N: usize> {
208 ep_in: D::EndpointIn,
209}
210
211pub struct HidReader<'d, D: Driver<'d>, const N: usize> {
215 ep_out: D::EndpointOut,
216 offset: &'d AtomicUsize,
217}
218
/// Error returned by `HidReader::read`.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum ReadError {
    /// A packet did not fit in the space left in the read buffer.
    BufferOverflow,
    /// The OUT endpoint is disabled.
    Disabled,
    /// A previous read was interrupted part-way through a report.
    ///
    /// This call finished that report; the range covers the bytes of it that
    /// this call wrote into the buffer.
    Sync(Range<usize>),
}
230
231impl From<EndpointError> for ReadError {
232 fn from(val: EndpointError) -> Self {
233 use EndpointError::{BufferOverflow, Disabled};
234 match val {
235 BufferOverflow => ReadError::BufferOverflow,
236 Disabled => ReadError::Disabled,
237 }
238 }
239}
240
241impl<'d, D: Driver<'d>, const N: usize> HidWriter<'d, D, N> {
242 pub fn new(builder: &mut Builder<'d, D>, state: &'d mut State<'d>, config: Config<'d>) -> Self {
252 let (ep_out, ep_in, _offset) = build(builder, state, config, false);
253
254 assert!(ep_out.is_none());
255
256 Self { ep_in }
257 }
258
259 pub async fn ready(&mut self) {
261 self.ep_in.wait_enabled().await;
262 }
263
264 #[cfg(feature = "usbd-hid")]
266 pub async fn write_serialize<IR: AsInputReport>(&mut self, r: &IR) -> Result<(), EndpointError> {
267 let mut buf: [u8; N] = [0; N];
268 let Ok(size) = serialize(&mut buf, r) else {
269 return Err(EndpointError::BufferOverflow);
270 };
271 self.write(&buf[0..size]).await
272 }
273
274 pub async fn write(&mut self, report: &[u8]) -> Result<(), EndpointError> {
276 assert!(report.len() <= N);
277
278 let max_packet_size = usize::from(self.ep_in.info().max_packet_size);
279 let zlp_needed = report.len() < N && (report.len() % max_packet_size == 0);
280 for chunk in report.chunks(max_packet_size) {
281 self.ep_in.write(chunk).await?;
282 }
283
284 if zlp_needed {
285 self.ep_in.write(&[]).await?;
286 }
287
288 Ok(())
289 }
290}
291
292impl<'d, D: Driver<'d>, const N: usize> HidReader<'d, D, N> {
293 pub async fn ready(&mut self) {
295 self.ep_out.wait_enabled().await;
296 }
297
298 pub async fn run<T: RequestHandler>(mut self, use_report_ids: bool, handler: &mut T) -> ! {
303 let offset = self.offset.load(Ordering::Acquire);
304 assert!(offset == 0);
305 let mut buf = [0; N];
306 loop {
307 match self.read(&mut buf).await {
308 Ok(len) => {
309 let id = if use_report_ids { buf[0] } else { 0 };
310 handler.set_report(ReportId::Out(id), &buf[..len]);
311 }
312 Err(ReadError::BufferOverflow) => warn!(
313 "Host sent output report larger than the configured maximum output report length ({})",
314 N
315 ),
316 Err(ReadError::Disabled) => self.ep_out.wait_enabled().await,
317 Err(ReadError::Sync(_)) => unreachable!(),
318 }
319 }
320 }
321
322 pub async fn read(&mut self, buf: &mut [u8]) -> Result<usize, ReadError> {
337 assert!(N != 0);
338 assert!(buf.len() >= N);
339
340 let max_packet_size = usize::from(self.ep_out.info().max_packet_size);
342 let starting_offset = self.offset.load(Ordering::Acquire);
343 let mut total = starting_offset;
344 loop {
345 for chunk in buf[starting_offset..N].chunks_mut(max_packet_size) {
346 match self.ep_out.read(chunk).await {
347 Ok(size) => {
348 total += size;
349 if size < max_packet_size || total == N {
350 self.offset.store(0, Ordering::Release);
351 break;
352 }
353 self.offset.store(total, Ordering::Release);
354 }
355 Err(err) => {
356 self.offset.store(0, Ordering::Release);
357 return Err(err.into());
358 }
359 }
360 }
361
362 if total > 0 {
364 break;
365 }
366 }
367
368 if starting_offset > 0 {
369 Err(ReadError::Sync(starting_offset..total))
370 } else {
371 Ok(total)
372 }
373 }
374}
375
376pub trait RequestHandler {
378 fn get_report(&mut self, id: ReportId, buf: &mut [u8]) -> Option<usize> {
382 let _ = (id, buf);
383 None
384 }
385
386 fn set_report(&mut self, id: ReportId, data: &[u8]) -> OutResponse {
388 let _ = (id, data);
389 OutResponse::Rejected
390 }
391
392 fn get_idle_ms(&mut self, id: Option<ReportId>) -> Option<u32> {
398 let _ = id;
399 None
400 }
401
402 fn set_idle_ms(&mut self, id: Option<ReportId>, duration_ms: u32) {
407 let _ = (id, duration_ms);
408 }
409}
410
411struct Control<'d> {
412 if_num: InterfaceNumber,
413 report_descriptor: &'d [u8],
414 request_handler: Option<&'d mut dyn RequestHandler>,
415 out_report_offset: &'d AtomicUsize,
416 hid_descriptor: [u8; 9],
417}
418
419impl<'d> Control<'d> {
420 fn new(
421 if_num: InterfaceNumber,
422 report_descriptor: &'d [u8],
423 request_handler: Option<&'d mut dyn RequestHandler>,
424 out_report_offset: &'d AtomicUsize,
425 ) -> Self {
426 Control {
427 if_num,
428 report_descriptor,
429 request_handler,
430 out_report_offset,
431 hid_descriptor: [
432 9,
434 HID_DESC_DESCTYPE_HID,
436 HID_DESC_SPEC_1_10[0],
438 HID_DESC_SPEC_1_10[1],
439 HID_DESC_COUNTRY_UNSPEC,
441 1,
443 HID_DESC_DESCTYPE_HID_REPORT,
445 (report_descriptor.len() & 0xFF) as u8,
447 (report_descriptor.len() >> 8 & 0xFF) as u8,
448 ],
449 }
450 }
451}
452
453impl<'d> Handler for Control<'d> {
454 fn reset(&mut self) {
455 self.out_report_offset.store(0, Ordering::Release);
456 }
457
458 fn control_out(&mut self, req: Request, data: &[u8]) -> Option<OutResponse> {
459 if (req.request_type, req.recipient, req.index)
460 != (RequestType::Class, Recipient::Interface, self.if_num.0 as u16)
461 {
462 return None;
463 }
464
465 #[cfg(feature = "defmt")]
468 trace!("HID control_out {:?} {=[u8]:x}", req, data);
469 match req.request {
470 HID_REQ_SET_IDLE => {
471 if let Some(handler) = self.request_handler.as_mut() {
472 let id = req.value as u8;
473 let id = (id != 0).then_some(ReportId::In(id));
474 let dur = u32::from(req.value >> 8);
475 let dur = if dur == 0 { u32::MAX } else { 4 * dur };
476 handler.set_idle_ms(id, dur);
477 }
478 Some(OutResponse::Accepted)
479 }
480 HID_REQ_SET_REPORT => match (ReportId::try_from(req.value), self.request_handler.as_mut()) {
481 (Ok(id), Some(handler)) => Some(handler.set_report(id, data)),
482 _ => Some(OutResponse::Rejected),
483 },
484 HID_REQ_SET_PROTOCOL => {
485 if req.value == 1 {
486 Some(OutResponse::Accepted)
487 } else {
488 warn!("HID Boot Protocol is unsupported.");
489 Some(OutResponse::Rejected) }
491 }
492 _ => Some(OutResponse::Rejected),
493 }
494 }
495
496 fn control_in<'a>(&'a mut self, req: Request, buf: &'a mut [u8]) -> Option<InResponse<'a>> {
497 if req.index != self.if_num.0 as u16 {
498 return None;
499 }
500
501 match (req.request_type, req.recipient) {
502 (RequestType::Standard, Recipient::Interface) => match req.request {
503 Request::GET_DESCRIPTOR => match (req.value >> 8) as u8 {
504 HID_DESC_DESCTYPE_HID_REPORT => Some(InResponse::Accepted(self.report_descriptor)),
505 HID_DESC_DESCTYPE_HID => Some(InResponse::Accepted(&self.hid_descriptor)),
506 _ => Some(InResponse::Rejected),
507 },
508
509 _ => Some(InResponse::Rejected),
510 },
511 (RequestType::Class, Recipient::Interface) => {
512 trace!("HID control_in {:?}", req);
513 match req.request {
514 HID_REQ_GET_REPORT => {
515 let size = match ReportId::try_from(req.value) {
516 Ok(id) => self.request_handler.as_mut().and_then(|x| x.get_report(id, buf)),
517 Err(_) => None,
518 };
519
520 if let Some(size) = size {
521 Some(InResponse::Accepted(&buf[0..size]))
522 } else {
523 Some(InResponse::Rejected)
524 }
525 }
526 HID_REQ_GET_IDLE => {
527 if let Some(handler) = self.request_handler.as_mut() {
528 let id = req.value as u8;
529 let id = (id != 0).then_some(ReportId::In(id));
530 if let Some(dur) = handler.get_idle_ms(id) {
531 let dur = u8::try_from(dur / 4).unwrap_or(0);
532 buf[0] = dur;
533 Some(InResponse::Accepted(&buf[0..1]))
534 } else {
535 Some(InResponse::Rejected)
536 }
537 } else {
538 Some(InResponse::Rejected)
539 }
540 }
541 HID_REQ_GET_PROTOCOL => {
542 buf[0] = 1;
544 Some(InResponse::Accepted(&buf[0..1]))
545 }
546 _ => Some(InResponse::Rejected),
547 }
548 }
549 _ => None,
550 }
551 }
552}