1use std::{
6 collections::{HashMap, VecDeque},
7 error::Error,
8 sync::{
9 Arc, Mutex,
10 atomic::{AtomicBool, AtomicU8, Ordering},
11 },
12 thread::{self, JoinHandle},
13};
14
15use async_trait::async_trait;
16use futures::{FutureExt, channel::oneshot, select};
17use hidreport::{Field, Report, ReportDescriptor, Usage, UsageId, UsagePage};
18use rand::Rng;
19use thiserror::Error;
20
21use crate::nibble::U4;
22
/// Size of the buffer that a channel's HID report descriptor is read into.
const MAX_REPORT_DESCRIPTOR_LENGTH: usize = 4096;

/// Report ID that starts every short HID++ report.
pub const SHORT_REPORT_ID: u8 = 0x10;
/// Usage page searched for in the report descriptor to detect short-report support.
pub const SHORT_REPORT_USAGE_PAGE: u16 = 0xff00;
/// Usage searched for in the report descriptor to detect short-report support.
pub const SHORT_REPORT_USAGE: u16 = 0x0001;
/// Total length of a short report in bytes, report ID included.
pub const SHORT_REPORT_LENGTH: usize = 7;

/// Report ID that starts every long HID++ report.
pub const LONG_REPORT_ID: u8 = 0x11;
/// Usage page searched for in the report descriptor to detect long-report support.
pub const LONG_REPORT_USAGE_PAGE: u16 = 0xff00;
/// Usage searched for in the report descriptor to detect long-report support.
pub const LONG_REPORT_USAGE: u16 = 0x0002;
/// Total length of a long report in bytes, report ID included.
pub const LONG_REPORT_LENGTH: usize = 20;

/// Buffer size large enough to hold either kind of HID++ report.
const MAX_REPORT_LENGTH: usize = if LONG_REPORT_LENGTH > SHORT_REPORT_LENGTH {
    LONG_REPORT_LENGTH
} else {
    SHORT_REPORT_LENGTH
};
55
56#[async_trait]
64pub trait RawHidChannel: Sync + Send + 'static {
65 fn vendor_id(&self) -> u16;
67
68 fn product_id(&self) -> u16;
70
71 async fn write_report(&self, src: &[u8]) -> Result<usize, Box<dyn Error + Sync + Send>>;
75
76 async fn read_report(&self, buf: &mut [u8]) -> Result<usize, Box<dyn Error + Sync + Send>>;
84
85 fn supports_short_long_hidpp(&self) -> Option<(bool, bool)>;
91
92 async fn get_report_descriptor(
98 &self,
99 buf: &mut [u8],
100 ) -> Result<usize, Box<dyn Error + Sync + Send>>;
101}
102
103async fn supports_short_long_hidpp(
105 chan: &impl RawHidChannel,
106) -> Result<(bool, bool), ChannelError> {
107 if let Some((supports_short, supports_long)) = chan.supports_short_long_hidpp() {
108 return Ok((supports_short, supports_long));
109 }
110
111 let mut raw_descriptor = vec![0u8; MAX_REPORT_DESCRIPTOR_LENGTH];
112 let descriptor_size = chan.get_report_descriptor(&mut raw_descriptor).await?;
113
114 let descriptor = match ReportDescriptor::try_from(&raw_descriptor[..descriptor_size]) {
115 Ok(val) => val,
116 Err(err) => return Err(ChannelError::ReportDescriptor(err)),
117 };
118
119 let supports_short = descriptor
120 .find_input_report(&[SHORT_REPORT_ID])
121 .and_then(|report| report.fields().first())
122 .and_then(|field| match field {
123 Field::Array(arr) => Some(arr.usage_range()),
124 _ => None,
125 })
126 .is_some_and(|range| {
127 range
128 .lookup_usage(&Usage::from_page_and_id(
129 UsagePage::from(SHORT_REPORT_USAGE_PAGE),
130 UsageId::from(SHORT_REPORT_USAGE),
131 ))
132 .is_some()
133 });
134
135 let supports_long = descriptor
136 .find_input_report(&[LONG_REPORT_ID])
137 .and_then(|report| report.fields().first())
138 .and_then(|field| match field {
139 Field::Array(arr) => Some(arr.usage_range()),
140 _ => None,
141 })
142 .is_some_and(|range| {
143 range
144 .lookup_usage(&Usage::from_page_and_id(
145 UsagePage::from(LONG_REPORT_USAGE_PAGE),
146 UsageId::from(LONG_REPORT_USAGE),
147 ))
148 .is_some()
149 });
150
151 Ok((supports_short, supports_long))
152}
153
154#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
156pub enum HidppMessage {
157 Short([u8; SHORT_REPORT_LENGTH - 1]),
162
163 Long([u8; LONG_REPORT_LENGTH - 1]),
168}
169
170impl HidppMessage {
171 pub fn read_raw(data: &[u8]) -> Option<Self> {
173 if data.is_empty() {
174 return None;
175 }
176
177 if data[0] == SHORT_REPORT_ID {
178 if data.len() != SHORT_REPORT_LENGTH {
179 return None;
180 }
181
182 return Some(HidppMessage::Short(data[1..].try_into().unwrap()));
183 } else if data[0] == LONG_REPORT_ID {
184 if data.len() != LONG_REPORT_LENGTH {
185 return None;
186 }
187
188 return Some(HidppMessage::Long(data[1..].try_into().unwrap()));
189 }
190
191 None
192 }
193
194 pub fn write_raw(&self, buf: &mut [u8]) -> usize {
198 match self {
199 Self::Short(payload) => {
200 buf[0] = SHORT_REPORT_ID;
201 buf[1..SHORT_REPORT_LENGTH].copy_from_slice(payload);
202 SHORT_REPORT_LENGTH
203 }
204 Self::Long(payload) => {
205 buf[0] = LONG_REPORT_ID;
206 buf[1..LONG_REPORT_LENGTH].copy_from_slice(payload);
207 LONG_REPORT_LENGTH
208 }
209 }
210 }
211}
212
213type MessageListener = Box<dyn Fn(HidppMessage, bool) + Send>;
214
/// A HID++ channel layered on top of a [`RawHidChannel`].
///
/// A background reader thread, started by [`HidppChannel::from_raw_channel`], does three
/// things: it reads incoming reports, hands responses to pending [`HidppChannel::send`]
/// calls, and passes every parsed message to the registered listeners. Dropping the channel
/// signals that thread to stop and joins it.
pub struct HidppChannel {
    /// Whether the channel supports short HID++ reports (report ID `SHORT_REPORT_ID`).
    pub supports_short: bool,

    /// Whether the channel supports long HID++ reports (report ID `LONG_REPORT_ID`).
    pub supports_long: bool,

    /// Vendor ID reported by the raw channel when this channel was created.
    pub vendor_id: u16,

    /// Product ID reported by the raw channel when this channel was created.
    pub product_id: u16,

    /// Underlying transport; also held by the reader thread.
    raw_channel: Arc<dyn RawHidChannel>,

    /// When set, each `get_sw_id` call advances `software_id`.
    rotate_software_id: AtomicBool,

    /// Current software ID. Assumes `U4::to_lo` places the nibble in the low 4 bits
    /// (TODO confirm against `crate::nibble`).
    software_id: AtomicU8,

    /// Requests awaiting a response, queued in send order; shared with the reader thread,
    /// which completes the first entry whose predicate accepts an incoming message.
    pending_messages: Arc<Mutex<VecDeque<PendingMessage>>>,

    /// Listeners keyed by the handle returned from `add_msg_listener`; shared with the
    /// reader thread, which invokes them for every parsed incoming message.
    message_listeners: Arc<Mutex<HashMap<u32, MessageListener>>>,

    /// Sending on (or dropping) this makes the reader thread exit; taken on drop.
    read_thread_close: Option<oneshot::Sender<()>>,

    /// Handle of the reader thread; taken and joined on drop.
    read_thread_hdl: Option<JoinHandle<()>>,
}
252
253impl Drop for HidppChannel {
254 fn drop(&mut self) {
255 if let Some(read_thread_close) = self.read_thread_close.take() {
256 let _ = read_thread_close.send(());
261 }
262
263 if let Some(read_thread_hdl) = self.read_thread_hdl.take() {
264 read_thread_hdl.join().unwrap();
265 }
266 }
267}
268
/// A request awaiting its response, queued by [`HidppChannel::send`].
struct PendingMessage {
    /// Decides whether an incoming message is the response to this request.
    response_predicate: Box<dyn Fn(&HidppMessage) -> bool + Send>,

    /// Delivers the matching response to the waiting `send` call.
    sender: oneshot::Sender<HidppMessage>,
}
279
280impl HidppChannel {
281 pub async fn from_raw_channel(raw: impl RawHidChannel) -> Result<Self, ChannelError> {
286 let (supports_short, supports_long) = supports_short_long_hidpp(&raw).await?;
287
288 if !supports_short && !supports_long {
289 return Err(ChannelError::HidppNotSupported);
290 }
291
292 let raw_channel_rc = Arc::new(raw);
293 let pending_messages_rc = Arc::new(Mutex::new(VecDeque::<PendingMessage>::new()));
294 let message_listeners_rc = Arc::new(Mutex::new(HashMap::<u32, MessageListener>::new()));
295
296 let (close_sender, mut close_receiver) = oneshot::channel::<()>();
297
298 let read_thread_hdl = thread::spawn({
299 let raw_channel = Arc::clone(&raw_channel_rc);
300 let pending_messages = Arc::clone(&pending_messages_rc);
301 let message_listeners = Arc::clone(&message_listeners_rc);
302
303 move || {
304 futures::executor::block_on(async {
305 let mut buf = [0u8; MAX_REPORT_LENGTH];
306
307 loop {
308 let res = select! {
309 _ = close_receiver => {
310 break;
311 },
312 res = raw_channel.read_report(&mut buf).fuse() => res
313 };
314
315 let Ok(len) = res else {
316 continue;
317 };
318
319 let Some(msg) = HidppMessage::read_raw(&buf[..len]) else {
320 continue;
321 };
322
323 let mut msgs = pending_messages.lock().unwrap();
324 let mut matched = false;
325 if let Some(pos) =
326 msgs.iter().position(|elem| (elem.response_predicate)(&msg))
327 {
328 let waiting = msgs.remove(pos).unwrap();
329 let _ = waiting.sender.send(msg);
330 matched = true;
331 }
332
333 for listener in message_listeners.lock().unwrap().values() {
334 listener(msg, matched);
335 }
336 }
337 });
338 }
339 });
340
341 Ok(Self {
342 supports_short,
343 supports_long,
344 vendor_id: raw_channel_rc.vendor_id(),
345 product_id: raw_channel_rc.product_id(),
346 raw_channel: raw_channel_rc,
347 rotate_software_id: AtomicBool::new(false),
348 software_id: AtomicU8::new(0x01),
349 pending_messages: pending_messages_rc,
350 message_listeners: message_listeners_rc,
351 read_thread_close: Some(close_sender),
352 read_thread_hdl: Some(read_thread_hdl),
353 })
354 }
355
356 pub fn set_sw_id(&self, sw_id: U4) {
362 self.software_id.store(sw_id.to_lo(), Ordering::SeqCst);
363 }
364
365 pub fn set_rotating_sw_id(&self, enable: bool) {
374 self.rotate_software_id.store(enable, Ordering::SeqCst);
375 }
376
377 pub fn get_sw_id(&self) -> U4 {
383 if self.rotate_software_id.load(Ordering::SeqCst) {
384 U4::from_lo(
385 self.software_id
386 .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |old| {
387 Some(if old & 0x0f == 0x0f {
388 0x01
389 } else {
390 old.wrapping_add(1)
391 })
392 })
393 .unwrap(),
394 )
395 } else {
396 U4::from_lo(self.software_id.load(Ordering::SeqCst))
397 }
398 }
399
400 pub fn supports_msg(&self, msg: &HidppMessage) -> bool {
402 match msg {
403 HidppMessage::Short(_) => self.supports_short,
404 HidppMessage::Long(_) => self.supports_long,
405 }
406 }
407
408 fn normalize_outgoing(&self, msg: HidppMessage) -> HidppMessage {
418 match msg {
419 HidppMessage::Short(payload) if !self.supports_short && self.supports_long => {
420 HidppMessage::Long(short_payload_as_long(&payload))
421 }
422 other => other,
423 }
424 }
425
426 pub async fn send(
432 &self,
433 msg: HidppMessage,
434 response_predicate: impl Fn(&HidppMessage) -> bool + Send + 'static,
435 ) -> Result<HidppMessage, ChannelError> {
436 let msg = self.normalize_outgoing(msg);
437 if !self.supports_msg(&msg) {
438 return Err(ChannelError::MessageTypeNotSupported);
439 }
440
441 let (sender, receiver) = oneshot::channel::<HidppMessage>();
442
443 self.pending_messages
444 .lock()
445 .unwrap()
446 .push_back(PendingMessage {
447 response_predicate: Box::new(response_predicate),
448 sender,
449 });
450
451 self.send_and_forget(msg).await?;
452
453 receiver.await.map_err(|_| ChannelError::NoResponse)
454 }
455
456 pub async fn send_and_forget(&self, msg: HidppMessage) -> Result<(), ChannelError> {
461 let msg = self.normalize_outgoing(msg);
462 if !self.supports_msg(&msg) {
463 return Err(ChannelError::MessageTypeNotSupported);
464 }
465
466 let mut buf = [0u8; LONG_REPORT_LENGTH];
467 let len = msg.write_raw(&mut buf);
468 self.raw_channel
469 .write_report(&buf[..len])
470 .await
471 .map(|_| ())
472 .map_err(ChannelError::Implementation)
473 }
474
475 pub fn add_msg_listener(&self, listener: impl Fn(HidppMessage, bool) + Send + 'static) -> u32 {
480 let mut listeners = self.message_listeners.lock().unwrap();
481
482 let mut rng = rand::rng();
483 let mut hdl = rng.random::<u32>();
484 while listeners.contains_key(&hdl) {
485 hdl = rng.random::<u32>();
486 }
487
488 listeners.insert(hdl, Box::new(listener));
489 hdl
490 }
491
492 pub fn remove_msg_listener(&self, hdl: u32) -> bool {
496 self.message_listeners
497 .lock()
498 .unwrap()
499 .remove(&hdl)
500 .is_some()
501 }
502}
503
504#[derive(Debug, Error)]
507#[non_exhaustive]
508pub enum ChannelError {
509 #[error("the HID channel implementation returned an error")]
512 Implementation(#[from] Box<dyn Error + Sync + Send>),
513
514 #[error("the report descriptor could not be parsed")]
516 ReportDescriptor(hidreport::ParserError),
517
518 #[error("the HID channel does not support HID++")]
520 HidppNotSupported,
521
522 #[error("the channel does not support the given HID++ message type")]
525 MessageTypeNotSupported,
526
527 #[error("the device did not respond to the request")]
529 NoResponse,
530}
531
532fn short_payload_as_long(payload: &[u8; SHORT_REPORT_LENGTH - 1]) -> [u8; LONG_REPORT_LENGTH - 1] {
538 let mut long = [0u8; LONG_REPORT_LENGTH - 1];
539 long[..payload.len()].copy_from_slice(payload);
540 long
541}
542
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn short_payload_widens_preserving_header_and_padding() {
        let short = [0xff, 0x05, 0x1e, 0xaa, 0xbb, 0xcc];

        let long = short_payload_as_long(&short);

        assert_eq!(long.len(), LONG_REPORT_LENGTH - 1);
        assert_eq!(&long[..short.len()], &short[..]);
        assert!(long[short.len()..].iter().all(|&b| b == 0));
    }

    #[test]
    fn read_raw_parses_short_and_long_reports() {
        let short = [SHORT_REPORT_ID, 1, 2, 3, 4, 5, 6];
        assert_eq!(
            HidppMessage::read_raw(&short),
            Some(HidppMessage::Short([1, 2, 3, 4, 5, 6]))
        );

        let mut long = [0u8; LONG_REPORT_LENGTH];
        long[0] = LONG_REPORT_ID;
        long[1] = 0xab;
        let mut expected_payload = [0u8; LONG_REPORT_LENGTH - 1];
        expected_payload[0] = 0xab;
        assert_eq!(
            HidppMessage::read_raw(&long),
            Some(HidppMessage::Long(expected_payload))
        );
    }

    #[test]
    fn read_raw_rejects_malformed_reports() {
        assert_eq!(HidppMessage::read_raw(&[]), None);
        // Known report IDs with the wrong total length.
        assert_eq!(HidppMessage::read_raw(&[SHORT_REPORT_ID, 1, 2]), None);
        assert_eq!(HidppMessage::read_raw(&[LONG_REPORT_ID; SHORT_REPORT_LENGTH]), None);
        // Unknown report ID.
        assert_eq!(HidppMessage::read_raw(&[0x12; SHORT_REPORT_LENGTH]), None);
    }

    #[test]
    fn write_raw_round_trips_through_read_raw() {
        let cases = [
            (
                HidppMessage::Short([0x01, 0x02, 0x03, 0x04, 0x05, 0x06]),
                SHORT_REPORT_LENGTH,
            ),
            (
                HidppMessage::Long([0x7f; LONG_REPORT_LENGTH - 1]),
                LONG_REPORT_LENGTH,
            ),
        ];

        for (msg, expected_len) in cases {
            let mut buf = [0u8; MAX_REPORT_LENGTH];
            let len = msg.write_raw(&mut buf);

            assert_eq!(len, expected_len);
            assert_eq!(HidppMessage::read_raw(&buf[..len]), Some(msg));
        }
    }
}