1use std::{
6 collections::{HashMap, VecDeque},
7 error::Error,
8 sync::{
9 Arc, Mutex,
10 atomic::{AtomicBool, AtomicU8, AtomicU64, Ordering},
11 },
12 thread::{self, JoinHandle},
13 time::Duration,
14};
15
16use async_trait::async_trait;
17use futures::{FutureExt, channel::oneshot, select};
18use hidreport::{Field, Report, ReportDescriptor, Usage, UsageId, UsagePage};
19use rand::Rng;
20use thiserror::Error;
21
22use crate::nibble::U4;
23
/// Size of the buffer used to fetch a HID report descriptor.
const MAX_REPORT_DESCRIPTOR_LENGTH: usize = 4096;

/// Size of the largest HID++ report in bytes; used to size the read buffer.
const MAX_REPORT_LENGTH: usize = if LONG_REPORT_LENGTH > SHORT_REPORT_LENGTH {
    LONG_REPORT_LENGTH
} else {
    SHORT_REPORT_LENGTH
};

/// How long [`HidppChannel::send`] waits for a matching response by default.
pub const SEND_RESPONSE_TIMEOUT: Duration = Duration::from_millis(5_000);

/// Report ID of short HID++ reports.
pub const SHORT_REPORT_ID: u8 = 0x10;
/// Usage page (vendor-defined range) expected on the short report's descriptor entry.
pub const SHORT_REPORT_USAGE_PAGE: u16 = 0xff00;
/// Usage within [`SHORT_REPORT_USAGE_PAGE`] identifying short HID++ reports.
pub const SHORT_REPORT_USAGE: u16 = 0x0001;
/// Total length of a short HID++ report in bytes, including the report ID.
pub const SHORT_REPORT_LENGTH: usize = 7;

/// Report ID of long HID++ reports.
pub const LONG_REPORT_ID: u8 = 0x11;
/// Usage page (vendor-defined range) expected on the long report's descriptor entry.
pub const LONG_REPORT_USAGE_PAGE: u16 = 0xff00;
/// Usage within [`LONG_REPORT_USAGE_PAGE`] identifying long HID++ reports.
pub const LONG_REPORT_USAGE: u16 = 0x0002;
/// Total length of a long HID++ report in bytes, including the report ID.
pub const LONG_REPORT_LENGTH: usize = 20;
61
62#[async_trait]
70pub trait RawHidChannel: Sync + Send + 'static {
71 fn vendor_id(&self) -> u16;
73
74 fn product_id(&self) -> u16;
76
77 async fn write_report(&self, src: &[u8]) -> Result<usize, Box<dyn Error + Sync + Send>>;
81
82 async fn read_report(&self, buf: &mut [u8]) -> Result<usize, Box<dyn Error + Sync + Send>>;
97
98 fn supports_short_long_hidpp(&self) -> Option<(bool, bool)>;
104
105 async fn get_report_descriptor(
111 &self,
112 buf: &mut [u8],
113 ) -> Result<usize, Box<dyn Error + Sync + Send>>;
114}
115
116async fn supports_short_long_hidpp(
118 chan: &impl RawHidChannel,
119) -> Result<(bool, bool), ChannelError> {
120 if let Some((supports_short, supports_long)) = chan.supports_short_long_hidpp() {
121 return Ok((supports_short, supports_long));
122 }
123
124 let mut raw_descriptor = vec![0u8; MAX_REPORT_DESCRIPTOR_LENGTH];
125 let descriptor_size = chan.get_report_descriptor(&mut raw_descriptor).await?;
126
127 let descriptor = match ReportDescriptor::try_from(&raw_descriptor[..descriptor_size]) {
128 Ok(val) => val,
129 Err(err) => return Err(ChannelError::ReportDescriptor(err)),
130 };
131
132 let supports_short = descriptor
133 .find_input_report(&[SHORT_REPORT_ID])
134 .and_then(|report| report.fields().first())
135 .and_then(|field| match field {
136 Field::Array(arr) => Some(arr.usage_range()),
137 _ => None,
138 })
139 .is_some_and(|range| {
140 range
141 .lookup_usage(&Usage::from_page_and_id(
142 UsagePage::from(SHORT_REPORT_USAGE_PAGE),
143 UsageId::from(SHORT_REPORT_USAGE),
144 ))
145 .is_some()
146 });
147
148 let supports_long = descriptor
149 .find_input_report(&[LONG_REPORT_ID])
150 .and_then(|report| report.fields().first())
151 .and_then(|field| match field {
152 Field::Array(arr) => Some(arr.usage_range()),
153 _ => None,
154 })
155 .is_some_and(|range| {
156 range
157 .lookup_usage(&Usage::from_page_and_id(
158 UsagePage::from(LONG_REPORT_USAGE_PAGE),
159 UsageId::from(LONG_REPORT_USAGE),
160 ))
161 .is_some()
162 });
163
164 Ok((supports_short, supports_long))
165}
166
167#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
169pub enum HidppMessage {
170 Short([u8; SHORT_REPORT_LENGTH - 1]),
175
176 Long([u8; LONG_REPORT_LENGTH - 1]),
181}
182
183impl HidppMessage {
184 pub fn read_raw(data: &[u8]) -> Option<Self> {
186 if data.is_empty() {
187 return None;
188 }
189
190 if data[0] == SHORT_REPORT_ID {
191 if data.len() != SHORT_REPORT_LENGTH {
192 return None;
193 }
194
195 return Some(HidppMessage::Short(data[1..].try_into().unwrap()));
196 } else if data[0] == LONG_REPORT_ID {
197 if data.len() != LONG_REPORT_LENGTH {
198 return None;
199 }
200
201 return Some(HidppMessage::Long(data[1..].try_into().unwrap()));
202 }
203
204 None
205 }
206
207 pub fn write_raw(&self, buf: &mut [u8]) -> usize {
211 match self {
212 Self::Short(payload) => {
213 buf[0] = SHORT_REPORT_ID;
214 buf[1..SHORT_REPORT_LENGTH].copy_from_slice(payload);
215 SHORT_REPORT_LENGTH
216 }
217 Self::Long(payload) => {
218 buf[0] = LONG_REPORT_ID;
219 buf[1..LONG_REPORT_LENGTH].copy_from_slice(payload);
220 LONG_REPORT_LENGTH
221 }
222 }
223 }
224}
225
/// Callback invoked on the read thread for every incoming HID++ message.
///
/// The `bool` argument is `true` when the message was handed to a pending request as its
/// response.
type MessageListener = Box<dyn Fn(HidppMessage, bool) + Send>;

/// A HID++ channel layered on top of a [`RawHidChannel`].
///
/// Owns a background thread with three jobs: reading incoming reports, routing responses to
/// pending requests, and notifying registered message listeners. The thread is signalled to
/// stop and joined on drop.
pub struct HidppChannel {
    /// Whether the device supports short HID++ reports.
    pub supports_short: bool,

    /// Whether the device supports long HID++ reports.
    pub supports_long: bool,

    /// Vendor ID, as reported by the raw channel when this channel was created.
    pub vendor_id: u16,

    /// Product ID, as reported by the raw channel when this channel was created.
    pub product_id: u16,

    /// The underlying transport; the read thread holds its own clone of this `Arc`.
    raw_channel: Arc<dyn RawHidChannel>,

    /// When `true`, `get_sw_id` advances `software_id` on every call.
    rotate_software_id: AtomicBool,

    /// Current software ID, stored as a byte and converted with `U4::from_lo` when read.
    software_id: AtomicU8,

    /// Requests waiting for a response, oldest first; shared with the read thread.
    pending_messages: Arc<Mutex<VecDeque<PendingMessage>>>,

    /// Counter that hands out a unique ID to each pending request.
    pending_message_id: AtomicU64,

    /// Registered listeners keyed by their handle; shared with the read thread.
    message_listeners: Arc<Mutex<HashMap<u32, MessageListener>>>,

    /// Signals the read thread to stop; taken and fired on drop.
    read_thread_close: Option<oneshot::Sender<()>>,

    /// Join handle of the read thread; taken and joined on drop.
    read_thread_hdl: Option<JoinHandle<()>>,
}
268
269impl Drop for HidppChannel {
270 fn drop(&mut self) {
271 if let Some(read_thread_close) = self.read_thread_close.take() {
272 let _ = read_thread_close.send(());
277 }
278
279 if let Some(read_thread_hdl) = self.read_thread_hdl.take() {
280 read_thread_hdl.join().unwrap();
281 }
282 }
283}
284
/// A request awaiting its response, registered by [`HidppChannel::send_with_timeout`].
struct PendingMessage {
    /// Unique ID, used to remove this entry when the request fails or times out.
    id: u64,

    /// Decides whether an incoming message is the response to this request.
    response_predicate: Box<dyn Fn(&HidppMessage) -> bool + Send>,

    /// Delivers the matching response to the task waiting on the request.
    sender: oneshot::Sender<HidppMessage>,
}
298
299impl HidppChannel {
300 pub async fn from_raw_channel(raw: impl RawHidChannel) -> Result<Self, ChannelError> {
305 let (supports_short, supports_long) = supports_short_long_hidpp(&raw).await?;
306
307 if !supports_short && !supports_long {
308 return Err(ChannelError::HidppNotSupported);
309 }
310
311 let raw_channel_rc = Arc::new(raw);
312 let pending_messages_rc = Arc::new(Mutex::new(VecDeque::<PendingMessage>::new()));
313 let message_listeners_rc = Arc::new(Mutex::new(HashMap::<u32, MessageListener>::new()));
314
315 let (close_sender, mut close_receiver) = oneshot::channel::<()>();
316
317 let read_thread_hdl = thread::spawn({
318 let raw_channel = Arc::clone(&raw_channel_rc);
319 let pending_messages = Arc::clone(&pending_messages_rc);
320 let message_listeners = Arc::clone(&message_listeners_rc);
321
322 move || {
323 futures::executor::block_on(async {
324 let mut buf = [0u8; MAX_REPORT_LENGTH];
325
326 loop {
327 let res = select! {
328 _ = close_receiver => {
329 break;
330 },
331 res = raw_channel.read_report(&mut buf).fuse() => res
332 };
333
334 let Ok(len) = res else {
335 continue;
336 };
337
338 let Some(msg) = HidppMessage::read_raw(&buf[..len]) else {
339 continue;
340 };
341
342 let mut msgs = pending_messages.lock().unwrap();
343 let mut matched = false;
344 if let Some(pos) =
345 msgs.iter().position(|elem| (elem.response_predicate)(&msg))
346 {
347 let waiting = msgs.remove(pos).unwrap();
348 let _ = waiting.sender.send(msg);
349 matched = true;
350 }
351
352 for listener in message_listeners.lock().unwrap().values() {
353 listener(msg, matched);
354 }
355 }
356 });
357 }
358 });
359
360 Ok(Self {
361 supports_short,
362 supports_long,
363 vendor_id: raw_channel_rc.vendor_id(),
364 product_id: raw_channel_rc.product_id(),
365 raw_channel: raw_channel_rc,
366 rotate_software_id: AtomicBool::new(false),
367 software_id: AtomicU8::new(0x01),
368 pending_messages: pending_messages_rc,
369 pending_message_id: AtomicU64::new(1),
370 message_listeners: message_listeners_rc,
371 read_thread_close: Some(close_sender),
372 read_thread_hdl: Some(read_thread_hdl),
373 })
374 }
375
376 pub fn set_sw_id(&self, sw_id: U4) {
382 self.software_id.store(sw_id.to_lo(), Ordering::SeqCst);
383 }
384
385 pub fn set_rotating_sw_id(&self, enable: bool) {
394 self.rotate_software_id.store(enable, Ordering::SeqCst);
395 }
396
397 pub fn get_sw_id(&self) -> U4 {
403 if self.rotate_software_id.load(Ordering::SeqCst) {
404 U4::from_lo(
405 self.software_id
406 .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |old| {
407 Some(if old & 0x0f == 0x0f {
408 0x01
409 } else {
410 old.wrapping_add(1)
411 })
412 })
413 .unwrap(),
414 )
415 } else {
416 U4::from_lo(self.software_id.load(Ordering::SeqCst))
417 }
418 }
419
420 pub fn supports_msg(&self, msg: &HidppMessage) -> bool {
422 match msg {
423 HidppMessage::Short(_) => self.supports_short,
424 HidppMessage::Long(_) => self.supports_long,
425 }
426 }
427
428 fn normalize_outgoing(&self, msg: HidppMessage) -> HidppMessage {
438 match msg {
439 HidppMessage::Short(payload) if !self.supports_short && self.supports_long => {
440 HidppMessage::Long(short_payload_as_long(&payload))
441 }
442 other => other,
443 }
444 }
445
446 pub async fn send(
455 &self,
456 msg: HidppMessage,
457 response_predicate: impl Fn(&HidppMessage) -> bool + Send + 'static,
458 ) -> Result<HidppMessage, ChannelError> {
459 self.send_with_timeout(msg, response_predicate, SEND_RESPONSE_TIMEOUT)
460 .await
461 }
462
463 pub async fn send_with_timeout(
477 &self,
478 msg: HidppMessage,
479 response_predicate: impl Fn(&HidppMessage) -> bool + Send + 'static,
480 timeout: Duration,
481 ) -> Result<HidppMessage, ChannelError> {
482 let msg = self.normalize_outgoing(msg);
483 if !self.supports_msg(&msg) {
484 return Err(ChannelError::MessageTypeNotSupported);
485 }
486
487 let (sender, receiver) = oneshot::channel::<HidppMessage>();
488 let pending_id = self.pending_message_id.fetch_add(1, Ordering::SeqCst);
489
490 {
491 let mut pending = self.pending_messages.lock().unwrap();
492 pending.retain(|m| !m.sender.is_canceled());
501 pending.push_back(PendingMessage {
502 id: pending_id,
503 response_predicate: Box::new(response_predicate),
504 sender,
505 });
506 }
507
508 let mut request = std::pin::pin!(
512 async {
513 self.send_and_forget(msg).await?;
514 receiver.await.map_err(|_| ChannelError::NoResponse)
515 }
516 .fuse()
517 );
518
519 let result = select! {
520 result = request => result,
521 _ = futures_timer::Delay::new(timeout).fuse() => Err(ChannelError::Timeout),
522 };
523
524 if result.is_err() {
525 self.remove_pending_message(pending_id);
529 }
530
531 result
532 }
533
534 fn remove_pending_message(&self, id: u64) {
535 let mut pending = self.pending_messages.lock().unwrap();
536 if let Some(pos) = pending.iter().position(|msg| msg.id == id) {
537 pending.remove(pos);
538 }
539 }
540
541 pub async fn send_and_forget(&self, msg: HidppMessage) -> Result<(), ChannelError> {
546 let msg = self.normalize_outgoing(msg);
547 if !self.supports_msg(&msg) {
548 return Err(ChannelError::MessageTypeNotSupported);
549 }
550
551 let mut buf = [0u8; LONG_REPORT_LENGTH];
552 let len = msg.write_raw(&mut buf);
553 self.raw_channel
554 .write_report(&buf[..len])
555 .await
556 .map(|_| ())
557 .map_err(ChannelError::Implementation)
558 }
559
560 pub fn add_msg_listener(&self, listener: impl Fn(HidppMessage, bool) + Send + 'static) -> u32 {
565 let mut listeners = self.message_listeners.lock().unwrap();
566
567 let mut rng = rand::rng();
568 let mut hdl = rng.random::<u32>();
569 while listeners.contains_key(&hdl) {
570 hdl = rng.random::<u32>();
571 }
572
573 listeners.insert(hdl, Box::new(listener));
574 hdl
575 }
576
577 pub fn remove_msg_listener(&self, hdl: u32) -> bool {
581 self.message_listeners
582 .lock()
583 .unwrap()
584 .remove(&hdl)
585 .is_some()
586 }
587}
588
589#[derive(Debug, Error)]
592#[non_exhaustive]
593pub enum ChannelError {
594 #[error("the HID channel implementation returned an error")]
597 Implementation(#[from] Box<dyn Error + Sync + Send>),
598
599 #[error("the report descriptor could not be parsed")]
601 ReportDescriptor(hidreport::ParserError),
602
603 #[error("the HID channel does not support HID++")]
605 HidppNotSupported,
606
607 #[error("the channel does not support the given HID++ message type")]
610 MessageTypeNotSupported,
611
612 #[error("the device did not respond to the request")]
614 NoResponse,
615
616 #[error("the request timed out before the device responded")]
620 Timeout,
621}
622
623fn short_payload_as_long(payload: &[u8; SHORT_REPORT_LENGTH - 1]) -> [u8; LONG_REPORT_LENGTH - 1] {
629 let mut long = [0u8; LONG_REPORT_LENGTH - 1];
630 long[..payload.len()].copy_from_slice(payload);
631 long
632}
633
#[cfg(test)]
mod tests {
    // Unit tests for message (de)serialization and for request/response matching on
    // `HidppChannel`, driven through the in-memory `MockRawHidChannel` defined below.

    use super::*;
    use std::{
        io,
        sync::{Arc, Mutex},
        time::{Duration, Instant},
    };

    #[test]
    fn short_payload_widens_preserving_header_and_padding() {
        let short = [0xff, 0x05, 0x1e, 0xaa, 0xbb, 0xcc];
        let long = short_payload_as_long(&short);

        assert_eq!(&long[..short.len()], &short[..]);
        assert!(long[short.len()..].iter().all(|&b| b == 0));
        assert_eq!(long.len(), LONG_REPORT_LENGTH - 1);
    }

    #[test]
    fn read_raw_rejects_malformed_reports() {
        // Empty input, unknown report ID, and known report IDs with the wrong length.
        assert_eq!(HidppMessage::read_raw(&[]), None);
        assert_eq!(HidppMessage::read_raw(&[0x42; SHORT_REPORT_LENGTH]), None);
        assert_eq!(HidppMessage::read_raw(&[SHORT_REPORT_ID; LONG_REPORT_LENGTH]), None);
        assert_eq!(HidppMessage::read_raw(&[LONG_REPORT_ID; SHORT_REPORT_LENGTH]), None);
    }

    #[test]
    fn raw_reports_round_trip() {
        let short = short_msg(0x42);
        let long = HidppMessage::Long(short_payload_as_long(&[0xff, 0x01, 0x02, 0x03, 0x04, 0x05]));

        for msg in [short, long] {
            assert_eq!(HidppMessage::read_raw(&raw_report(msg)), Some(msg));
        }
    }

    #[test]
    fn rotating_sw_id_skips_zero_and_wraps() {
        futures::executor::block_on(async {
            let (raw, _handle) = MockRawHidChannel::new();
            let channel = HidppChannel::from_raw_channel(raw).await.unwrap();

            // Without rotation the ID stays put.
            assert_eq!(channel.get_sw_id().to_lo(), 0x01);
            assert_eq!(channel.get_sw_id().to_lo(), 0x01);

            channel.set_rotating_sw_id(true);
            let ids: Vec<u8> = (0..16).map(|_| channel.get_sw_id().to_lo()).collect();
            let expected: Vec<u8> = (0x01..=0x0f).chain([0x01]).collect();
            assert_eq!(ids, expected);
        });
    }

    #[test]
    fn send_returns_response_before_timeout() {
        futures::executor::block_on(async {
            let (raw, handle) = MockRawHidChannel::new();
            let channel = HidppChannel::from_raw_channel(raw).await.unwrap();

            let request = short_msg(0x10);
            let response = short_msg(0x20);
            handle.queue_response(response);

            let actual = channel
                .send_with_timeout(
                    request,
                    move |candidate| *candidate == response,
                    Duration::from_secs(1),
                )
                .await
                .unwrap();

            assert_eq!(actual, response);
            assert_eq!(handle.written_reports().len(), 1);
            assert_pending_empty(&channel);
        });
    }

    #[test]
    fn send_times_out_and_removes_pending_message() {
        futures::executor::block_on(async {
            let (raw, handle) = MockRawHidChannel::new();
            let channel = HidppChannel::from_raw_channel(raw).await.unwrap();
            let request = short_msg(0x10);
            let response = short_msg(0x20);

            let started = Instant::now();
            let err = channel
                .send_with_timeout(
                    request,
                    move |candidate| *candidate == response,
                    Duration::from_millis(25),
                )
                .await
                .unwrap_err();

            assert!(matches!(err, ChannelError::Timeout));
            assert!(started.elapsed() < Duration::from_secs(1));
            assert_eq!(handle.written_reports().len(), 1);
            assert_pending_empty(&channel);
        });
    }

    #[test]
    fn timeout_removes_only_its_own_pending_message() {
        futures::executor::block_on(async {
            let (raw, handle) = MockRawHidChannel::new();
            let channel = HidppChannel::from_raw_channel(raw).await.unwrap();

            let never_answered = short_msg(0x20);
            let slow_response = short_msg(0x21);

            let timed_out = channel.send_with_timeout(
                short_msg(0x10),
                move |candidate| *candidate == never_answered,
                Duration::from_millis(25),
            );
            let answered = channel.send_with_timeout(
                short_msg(0x11),
                move |candidate| *candidate == slow_response,
                Duration::from_secs(1),
            );
            let respond_late = async {
                // Answer the second request only after the first one has timed out.
                futures_timer::Delay::new(Duration::from_millis(100)).await;
                handle.send_incoming(slow_response).await;
            };

            let (timed_out, answered, ()) = futures::join!(timed_out, answered, respond_late);

            assert!(matches!(timed_out.unwrap_err(), ChannelError::Timeout));
            assert_eq!(answered.unwrap(), slow_response);
            assert_pending_empty(&channel);
        });
    }

    #[test]
    fn late_response_after_timeout_is_ignored() {
        futures::executor::block_on(async {
            let (raw, handle) = MockRawHidChannel::new();
            let channel = HidppChannel::from_raw_channel(raw).await.unwrap();
            let events = Arc::new(Mutex::new(Vec::new()));
            let listener_events = Arc::clone(&events);
            channel.add_msg_listener(move |msg, matched| {
                listener_events.lock().unwrap().push((msg, matched));
            });

            let request = short_msg(0x10);
            let late_response = short_msg(0x20);
            let err = channel
                .send_with_timeout(
                    request,
                    move |candidate| *candidate == late_response,
                    Duration::from_millis(25),
                )
                .await
                .unwrap_err();

            assert!(matches!(err, ChannelError::Timeout));
            assert_pending_empty(&channel);

            handle.send_incoming(late_response).await;
            wait_for_event_count(&events, 1).await;
            assert_eq!(events.lock().unwrap()[0], (late_response, false));
            assert_pending_empty(&channel);

            let later_request = short_msg(0x30);
            let later_response = short_msg(0x40);
            handle.queue_response(later_response);
            let actual = channel
                .send_with_timeout(
                    later_request,
                    move |candidate| *candidate == later_response,
                    Duration::from_secs(1),
                )
                .await
                .unwrap();

            assert_eq!(actual, later_response);
            wait_for_event_count(&events, 2).await;
            assert_eq!(events.lock().unwrap()[1], (later_response, true));
            assert_pending_empty(&channel);
        });
    }

    #[test]
    fn send_and_forget_writes_without_pending_message() {
        futures::executor::block_on(async {
            let (raw, handle) = MockRawHidChannel::new();
            let channel = HidppChannel::from_raw_channel(raw).await.unwrap();

            channel.send_and_forget(short_msg(0x10)).await.unwrap();

            assert_eq!(handle.written_reports().len(), 1);
            assert_pending_empty(&channel);
        });
    }

    /// Test-side handle for injecting incoming reports and inspecting written ones.
    #[derive(Clone)]
    struct MockRawHidHandle {
        incoming_tx: async_channel::Sender<Vec<u8>>,
        written_reports: Arc<Mutex<Vec<Vec<u8>>>>,
        responses_on_write: Arc<Mutex<VecDeque<Vec<u8>>>>,
    }

    impl MockRawHidHandle {
        /// Queues `msg` to be delivered as incoming right after the next `write_report`.
        fn queue_response(&self, msg: HidppMessage) {
            self.responses_on_write
                .lock()
                .unwrap()
                .push_back(raw_report(msg));
        }

        /// Delivers `msg` as an incoming report immediately.
        async fn send_incoming(&self, msg: HidppMessage) {
            self.incoming_tx.send(raw_report(msg)).await.unwrap();
        }

        /// Returns a snapshot of every report written so far.
        fn written_reports(&self) -> Vec<Vec<u8>> {
            self.written_reports.lock().unwrap().clone()
        }
    }

    /// In-memory raw HID channel.
    ///
    /// Keeps its own `incoming_tx`, so `read_report` never observes a closed channel.
    struct MockRawHidChannel {
        incoming_tx: async_channel::Sender<Vec<u8>>,
        incoming_rx: async_channel::Receiver<Vec<u8>>,
        written_reports: Arc<Mutex<Vec<Vec<u8>>>>,
        responses_on_write: Arc<Mutex<VecDeque<Vec<u8>>>>,
    }

    impl MockRawHidChannel {
        fn new() -> (Self, MockRawHidHandle) {
            let (incoming_tx, incoming_rx) = async_channel::unbounded();
            let written_reports = Arc::new(Mutex::new(Vec::new()));
            let responses_on_write = Arc::new(Mutex::new(VecDeque::new()));

            let handle = MockRawHidHandle {
                incoming_tx: incoming_tx.clone(),
                written_reports: Arc::clone(&written_reports),
                responses_on_write: Arc::clone(&responses_on_write),
            };

            (
                Self {
                    incoming_tx,
                    incoming_rx,
                    written_reports,
                    responses_on_write,
                },
                handle,
            )
        }
    }

    #[async_trait]
    impl RawHidChannel for MockRawHidChannel {
        fn vendor_id(&self) -> u16 {
            0x046d
        }

        fn product_id(&self) -> u16 {
            0xc539
        }

        async fn write_report(&self, src: &[u8]) -> Result<usize, Box<dyn Error + Sync + Send>> {
            self.written_reports.lock().unwrap().push(src.to_vec());
            let response = self.responses_on_write.lock().unwrap().pop_front();
            if let Some(response) = response {
                self.incoming_tx.send(response).await.unwrap();
            }

            Ok(src.len())
        }

        async fn read_report(&self, buf: &mut [u8]) -> Result<usize, Box<dyn Error + Sync + Send>> {
            let report = self.incoming_rx.recv().await.map_err(|_| mock_error())?;
            let len = report.len().min(buf.len());
            buf[..len].copy_from_slice(&report[..len]);
            Ok(len)
        }

        fn supports_short_long_hidpp(&self) -> Option<(bool, bool)> {
            Some((true, true))
        }

        async fn get_report_descriptor(
            &self,
            _buf: &mut [u8],
        ) -> Result<usize, Box<dyn Error + Sync + Send>> {
            unreachable!("mock declares HID++ support")
        }
    }

    /// Builds a distinguishable short message tagged with `marker`.
    fn short_msg(marker: u8) -> HidppMessage {
        HidppMessage::Short([0xff, marker, 0x10, marker, marker, marker])
    }

    /// Serializes `msg` into the raw report bytes the mock delivers.
    fn raw_report(msg: HidppMessage) -> Vec<u8> {
        let mut buf = [0u8; LONG_REPORT_LENGTH];
        let len = msg.write_raw(&mut buf);
        buf[..len].to_vec()
    }

    fn assert_pending_empty(channel: &HidppChannel) {
        assert!(channel.pending_messages.lock().unwrap().is_empty());
    }

    /// Polls until at least `count` listener events were recorded, failing after one second.
    async fn wait_for_event_count(events: &Arc<Mutex<Vec<(HidppMessage, bool)>>>, count: usize) {
        let started = Instant::now();
        while started.elapsed() < Duration::from_secs(1) {
            if events.lock().unwrap().len() >= count {
                return;
            }
            futures_timer::Delay::new(Duration::from_millis(10)).await;
        }

        panic!("timed out waiting for {count} listener events");
    }

    fn mock_error() -> Box<dyn Error + Sync + Send> {
        Box::new(io::Error::new(
            io::ErrorKind::BrokenPipe,
            "mock channel closed",
        ))
    }
}