Skip to main content

hidpp/
channel.rs

1//! Implements basic messaging across HID and HID++ channels.
2//!
3//! This includes mapping incoming messages to previously sent requests.
4
5use std::{
6    collections::{HashMap, VecDeque},
7    error::Error,
8    sync::{
9        Arc, Mutex,
10        atomic::{AtomicBool, AtomicU8, AtomicU64, Ordering},
11    },
12    thread::{self, JoinHandle},
13    time::Duration,
14};
15
16use async_trait::async_trait;
17use futures::{FutureExt, channel::oneshot, select};
18use hidreport::{Field, Report, ReportDescriptor, Usage, UsageId, UsagePage};
19use rand::Rng;
20use thiserror::Error;
21
22use crate::nibble::U4;
23
/// hidapi defines this as the maximum EXPECTED size of report descriptors.
/// We will trust this for now, but a workaround may be required if devices do
/// in fact return longer descriptors.
const MAX_REPORT_DESCRIPTOR_LENGTH: usize = 4096;

/// This is the size of the buffer incoming reports are read into.
/// As we only care about HID++ reports, this equals [`LONG_REPORT_LENGTH`],
/// the larger of the two HID++ report lengths, so every HID++ report fits.
const MAX_REPORT_LENGTH: usize = LONG_REPORT_LENGTH;

/// The default time budget for a [`HidppChannel::send`] request: the report
/// write plus the wait for a matching response. Callers that need a different
/// budget can use [`HidppChannel::send_with_timeout`].
pub const SEND_RESPONSE_TIMEOUT: Duration = Duration::from_secs(5);

/// The ID of the HID report that is used to transmit short HID++ messages.
pub const SHORT_REPORT_ID: u8 = 0x10;

/// The HID usage page ID of short HID++ message reports.
pub const SHORT_REPORT_USAGE_PAGE: u16 = 0xff00;

/// The HID usage ID of short HID++ message reports.
pub const SHORT_REPORT_USAGE: u16 = 0x0001;

/// The length of short HID++ message reports (including the report ID byte).
pub const SHORT_REPORT_LENGTH: usize = 7;

/// The ID of the HID report that is used to transmit long HID++ messages.
pub const LONG_REPORT_ID: u8 = 0x11;

/// The HID usage page ID of long HID++ message reports.
pub const LONG_REPORT_USAGE_PAGE: u16 = 0xff00;

/// The HID usage ID of long HID++ message reports.
pub const LONG_REPORT_USAGE: u16 = 0x0002;

/// The length of long HID++ message reports (including the report ID byte).
pub const LONG_REPORT_LENGTH: usize = 20;
61
62/// Represents an arbitrary HID communication channel that is both readable and
63/// writable. It has to support async I/O.
64///
65/// Any type this trait is implemented for can be used for HID(++)
66/// communication. If a specific channel supports HID++ is determined at a later
67/// stage and is not directly related to potential implementations of this
68/// trait.
69#[async_trait]
70pub trait RawHidChannel: Sync + Send + 'static {
71    /// Provides the vendor ID of the connected HID device.
72    fn vendor_id(&self) -> u16;
73
74    /// Provides the product ID of the connected HID device.
75    fn product_id(&self) -> u16;
76
77    /// Writes a raw report to the channel.
78    ///
79    /// Returns the exact amount of written bytes on success.
80    async fn write_report(&self, src: &[u8]) -> Result<usize, Box<dyn Error + Sync + Send>>;
81
82    /// Reads a raw report from the channel.
83    ///
84    /// If the buffer is not large enough to fit the whole report, its remainder
85    /// should be discarded and must not be returned by any succeeding call to
86    /// [`Self::read_report`].
87    ///
88    /// Returns the exact amount or read bytes on success. An `Err` is treated
89    /// as transient: the [`HidppChannel`] read loop logs it and retries, so an
90    /// implementation must not surface a condition that will never clear (it
91    /// would busy-spin the loop). For a *permanent* failure — the device is
92    /// gone and no report will ever arrive — the future may instead park
93    /// forever. That is sound because the read loop always races this future
94    /// against the channel's close signal in a `select!`; any other caller
95    /// must do the same and must not await `read_report` bare.
96    async fn read_report(&self, buf: &mut [u8]) -> Result<usize, Box<dyn Error + Sync + Send>>;
97
98    /// If the implementation already knows whether the underlying HID channel
99    /// supports HID++ messages, it should return `Some((supports_short,
100    /// supports_long))` from this method.
101    ///
102    /// In this case, the report descriptor will not be read and parsed.
103    fn supports_short_long_hidpp(&self) -> Option<(bool, bool)>;
104
105    /// Retrieves the raw HID report descriptor from the channel.
106    ///
107    /// This is used to determine whether the channel supports HID++.
108    ///
109    /// Returns the exact size of the report descriptor on success.
110    async fn get_report_descriptor(
111        &self,
112        buf: &mut [u8],
113    ) -> Result<usize, Box<dyn Error + Sync + Send>>;
114}
115
116/// Checks whether a raw channel supports short or long HID++ messages.
117async fn supports_short_long_hidpp(
118    chan: &impl RawHidChannel,
119) -> Result<(bool, bool), ChannelError> {
120    if let Some((supports_short, supports_long)) = chan.supports_short_long_hidpp() {
121        return Ok((supports_short, supports_long));
122    }
123
124    let mut raw_descriptor = vec![0u8; MAX_REPORT_DESCRIPTOR_LENGTH];
125    let descriptor_size = chan.get_report_descriptor(&mut raw_descriptor).await?;
126
127    let descriptor = match ReportDescriptor::try_from(&raw_descriptor[..descriptor_size]) {
128        Ok(val) => val,
129        Err(err) => return Err(ChannelError::ReportDescriptor(err)),
130    };
131
132    let supports_short = descriptor
133        .find_input_report(&[SHORT_REPORT_ID])
134        .and_then(|report| report.fields().first())
135        .and_then(|field| match field {
136            Field::Array(arr) => Some(arr.usage_range()),
137            _ => None,
138        })
139        .is_some_and(|range| {
140            range
141                .lookup_usage(&Usage::from_page_and_id(
142                    UsagePage::from(SHORT_REPORT_USAGE_PAGE),
143                    UsageId::from(SHORT_REPORT_USAGE),
144                ))
145                .is_some()
146        });
147
148    let supports_long = descriptor
149        .find_input_report(&[LONG_REPORT_ID])
150        .and_then(|report| report.fields().first())
151        .and_then(|field| match field {
152            Field::Array(arr) => Some(arr.usage_range()),
153            _ => None,
154        })
155        .is_some_and(|range| {
156            range
157                .lookup_usage(&Usage::from_page_and_id(
158                    UsagePage::from(LONG_REPORT_USAGE_PAGE),
159                    UsageId::from(LONG_REPORT_USAGE),
160                ))
161                .is_some()
162        });
163
164    Ok((supports_short, supports_long))
165}
166
167/// Represents an unversioned HID++ message.
168#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
169pub enum HidppMessage {
170    /// Represents a short HID++ message.
171    ///
172    /// Please check [`HidppChannel::supports_short`] before sending this kind
173    /// of message.
174    Short([u8; SHORT_REPORT_LENGTH - 1]),
175
176    /// Represents a long HID++ message.
177    ///
178    /// Please check [`HidppChannel::supports_long`] before sending this kind of
179    /// message.
180    Long([u8; LONG_REPORT_LENGTH - 1]),
181}
182
183impl HidppMessage {
184    /// Tries to read a HID++ message from raw data.
185    pub fn read_raw(data: &[u8]) -> Option<Self> {
186        if data.is_empty() {
187            return None;
188        }
189
190        if data[0] == SHORT_REPORT_ID {
191            if data.len() != SHORT_REPORT_LENGTH {
192                return None;
193            }
194
195            return Some(HidppMessage::Short(data[1..].try_into().unwrap()));
196        } else if data[0] == LONG_REPORT_ID {
197            if data.len() != LONG_REPORT_LENGTH {
198                return None;
199            }
200
201            return Some(HidppMessage::Long(data[1..].try_into().unwrap()));
202        }
203
204        None
205    }
206
207    /// Writes a HID++ message in its raw byte form into a buffer.
208    ///
209    /// Returns the amount of written bytes.
210    pub fn write_raw(&self, buf: &mut [u8]) -> usize {
211        match self {
212            Self::Short(payload) => {
213                buf[0] = SHORT_REPORT_ID;
214                buf[1..SHORT_REPORT_LENGTH].copy_from_slice(payload);
215                SHORT_REPORT_LENGTH
216            }
217            Self::Long(payload) => {
218                buf[0] = LONG_REPORT_ID;
219                buf[1..LONG_REPORT_LENGTH].copy_from_slice(payload);
220                LONG_REPORT_LENGTH
221            }
222        }
223    }
224}
225
226type MessageListener = Box<dyn Fn(HidppMessage, bool) + Send>;
227
228/// Represents a HID communication channel supporting HID++.
229pub struct HidppChannel {
230    /// Whether the channel supports short (7 bytes) HID++ messages.
231    pub supports_short: bool,
232
233    /// Whether the channel supports long (20 bytes) HID++ messages.
234    pub supports_long: bool,
235
236    /// The vendor ID of the connected HID device.
237    pub vendor_id: u16,
238
239    // The product ID of the connected HID device.
240    pub product_id: u16,
241
242    /// The underlying raw HID channel.
243    raw_channel: Arc<dyn RawHidChannel>,
244
245    /// Whether to rotate the [`Self::software_id`].
246    rotate_software_id: AtomicBool,
247
248    /// The software ID to provide at the next call to [`Self::get_sw_id`].
249    software_id: AtomicU8,
250
251    /// All sent messages that are waiting for a response.
252    pending_messages: Arc<Mutex<VecDeque<PendingMessage>>>,
253
254    /// The request ID assigned to the next pending message.
255    pending_message_id: AtomicU64,
256
257    /// Registered listeners that will receive notifications about incoming
258    /// messages.
259    message_listeners: Arc<Mutex<HashMap<u32, MessageListener>>>,
260
261    /// The sender signaling the read thread to stop.
262    read_thread_close: Option<oneshot::Sender<()>>,
263
264    /// The handle to the read thread. Should be joined after signaling
265    /// [`Self::read_thread_close`].
266    read_thread_hdl: Option<JoinHandle<()>>,
267}
268
269impl Drop for HidppChannel {
270    fn drop(&mut self) {
271        if let Some(read_thread_close) = self.read_thread_close.take() {
272            // This only fails if the receiving end, which is owned by the read thread in
273            // this case, is dropped.
274            // This just means that the read thread is already stopped, so we can ignore the
275            // error here.
276            let _ = read_thread_close.send(());
277        }
278
279        if let Some(read_thread_hdl) = self.read_thread_hdl.take() {
280            read_thread_hdl.join().unwrap();
281        }
282    }
283}
284
285/// Represents a message that was sent and is waiting for a response.
286struct PendingMessage {
287    /// Unique ID used to remove this request if it times out.
288    id: u64,
289
290    /// The predicate that has to match for an incoming message to be classified
291    /// as the response.
292    response_predicate: Box<dyn Fn(&HidppMessage) -> bool + Send>,
293
294    /// The oneshot sender used to provide the response message to the receiving
295    /// end.
296    sender: oneshot::Sender<HidppMessage>,
297}
298
299impl HidppChannel {
300    /// Tries to construct a HID++ channel from a raw HID channel.
301    ///
302    /// If the given HID channel does not support HID++,
303    /// [`ChannelError::HidppNotSupported`] will be returned.
304    pub async fn from_raw_channel(raw: impl RawHidChannel) -> Result<Self, ChannelError> {
305        let (supports_short, supports_long) = supports_short_long_hidpp(&raw).await?;
306
307        if !supports_short && !supports_long {
308            return Err(ChannelError::HidppNotSupported);
309        }
310
311        let raw_channel_rc = Arc::new(raw);
312        let pending_messages_rc = Arc::new(Mutex::new(VecDeque::<PendingMessage>::new()));
313        let message_listeners_rc = Arc::new(Mutex::new(HashMap::<u32, MessageListener>::new()));
314
315        let (close_sender, mut close_receiver) = oneshot::channel::<()>();
316
317        let read_thread_hdl = thread::spawn({
318            let raw_channel = Arc::clone(&raw_channel_rc);
319            let pending_messages = Arc::clone(&pending_messages_rc);
320            let message_listeners = Arc::clone(&message_listeners_rc);
321
322            move || {
323                futures::executor::block_on(async {
324                    let mut buf = [0u8; MAX_REPORT_LENGTH];
325
326                    loop {
327                        let res = select! {
328                            _ = close_receiver => {
329                                break;
330                            },
331                            res = raw_channel.read_report(&mut buf).fuse() => res
332                        };
333
334                        let Ok(len) = res else {
335                            continue;
336                        };
337
338                        let Some(msg) = HidppMessage::read_raw(&buf[..len]) else {
339                            continue;
340                        };
341
342                        let mut msgs = pending_messages.lock().unwrap();
343                        let mut matched = false;
344                        if let Some(pos) =
345                            msgs.iter().position(|elem| (elem.response_predicate)(&msg))
346                        {
347                            let waiting = msgs.remove(pos).unwrap();
348                            let _ = waiting.sender.send(msg);
349                            matched = true;
350                        }
351
352                        for listener in message_listeners.lock().unwrap().values() {
353                            listener(msg, matched);
354                        }
355                    }
356                });
357            }
358        });
359
360        Ok(Self {
361            supports_short,
362            supports_long,
363            vendor_id: raw_channel_rc.vendor_id(),
364            product_id: raw_channel_rc.product_id(),
365            raw_channel: raw_channel_rc,
366            rotate_software_id: AtomicBool::new(false),
367            software_id: AtomicU8::new(0x01),
368            pending_messages: pending_messages_rc,
369            pending_message_id: AtomicU64::new(1),
370            message_listeners: message_listeners_rc,
371            read_thread_close: Some(close_sender),
372            read_thread_hdl: Some(read_thread_hdl),
373        })
374    }
375
376    /// Sets the software ID that should be returned by the next call to
377    /// [`Self::get_sw_id`].
378    ///
379    /// Using software ID `0` is highly discouraged as it is used for device
380    /// notifications.
381    pub fn set_sw_id(&self, sw_id: U4) {
382        self.software_id.store(sw_id.to_lo(), Ordering::SeqCst);
383    }
384
385    /// Sets whether the software ID returned by a call to [`Self::get_sw_id`]
386    /// should increment (and potentially wrap around) after each call.
387    ///
388    /// This comes in handy when trying to map responses to requests
389    /// consistently.
390    ///
391    /// Software ID `0` will be skipped in the rotation process as it is
392    /// reserved for device notifications.
393    pub fn set_rotating_sw_id(&self, enable: bool) {
394        self.rotate_software_id.store(enable, Ordering::SeqCst);
395    }
396
397    /// Provides a software ID that can be used to send a HID++ message across
398    /// the channel.
399    ///
400    /// This method should be called separately for every message to send as it
401    /// may rotate (as indicated by [`Self::set_rotating_sw_id`]).
402    pub fn get_sw_id(&self) -> U4 {
403        if self.rotate_software_id.load(Ordering::SeqCst) {
404            U4::from_lo(
405                self.software_id
406                    .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |old| {
407                        Some(if old & 0x0f == 0x0f {
408                            0x01
409                        } else {
410                            old.wrapping_add(1)
411                        })
412                    })
413                    .unwrap(),
414            )
415        } else {
416            U4::from_lo(self.software_id.load(Ordering::SeqCst))
417        }
418    }
419
420    /// Checks whether the channel supports the given HID++ message.
421    pub fn supports_msg(&self, msg: &HidppMessage) -> bool {
422        match msg {
423            HidppMessage::Short(_) => self.supports_short,
424            HidppMessage::Long(_) => self.supports_long,
425        }
426    }
427
428    /// Re-frames a short message as long on a long-only channel — a device that
429    /// exposes only the long HID++ report (e.g. a Bluetooth-LE-direct mouse on
430    /// macOS, where `IOHIDDeviceSetReport` rejects the short report). The HID++
431    /// header bytes sit at the same offsets in both widths, so the only change
432    /// is the report id plus zero-padding the extra payload; the device answers
433    /// with a long report, which still matches the request by header. A no-op on
434    /// channels that advertise short support.
435    ///
436    /// (OpenLogi local addition — candidate for upstreaming.)
437    fn normalize_outgoing(&self, msg: HidppMessage) -> HidppMessage {
438        match msg {
439            HidppMessage::Short(payload) if !self.supports_short && self.supports_long => {
440                HidppMessage::Long(short_payload_as_long(&payload))
441            }
442            other => other,
443        }
444    }
445
446    /// Sends a HID++ message across the channel and waits for a response.
447    ///
448    /// If no response is expected/required, use [`Self::send_and_forget`].
449    ///
450    /// The whole request — the report write plus the wait for a matching
451    /// response — is bounded by [`SEND_RESPONSE_TIMEOUT`]; the future resolves
452    /// to [`ChannelError::Timeout`] on elapse. Use [`Self::send_with_timeout`]
453    /// to choose a different budget.
454    pub async fn send(
455        &self,
456        msg: HidppMessage,
457        response_predicate: impl Fn(&HidppMessage) -> bool + Send + 'static,
458    ) -> Result<HidppMessage, ChannelError> {
459        self.send_with_timeout(msg, response_predicate, SEND_RESPONSE_TIMEOUT)
460            .await
461    }
462
463    /// Sends a HID++ message across the channel and waits for a response,
464    /// bounding the whole request — the report write plus the wait for a
465    /// matching response — by `timeout`.
466    ///
467    /// On elapse the request's pending entry is removed (concurrent in-flight
468    /// requests are unaffected) and [`ChannelError::Timeout`] is returned; a
469    /// response that still arrives later reaches message listeners as an
470    /// unmatched message.
471    ///
472    /// [`Self::send`] uses this with [`SEND_RESPONSE_TIMEOUT`], which suits
473    /// requests to a device that may be asleep. Requests that should fail
474    /// faster — e.g. probing a receiver that answers immediately or not at
475    /// all — can pass a tighter budget.
476    pub async fn send_with_timeout(
477        &self,
478        msg: HidppMessage,
479        response_predicate: impl Fn(&HidppMessage) -> bool + Send + 'static,
480        timeout: Duration,
481    ) -> Result<HidppMessage, ChannelError> {
482        let msg = self.normalize_outgoing(msg);
483        if !self.supports_msg(&msg) {
484            return Err(ChannelError::MessageTypeNotSupported);
485        }
486
487        let (sender, receiver) = oneshot::channel::<HidppMessage>();
488        let pending_id = self.pending_message_id.fetch_add(1, Ordering::SeqCst);
489
490        {
491            let mut pending = self.pending_messages.lock().unwrap();
492            // Drop abandoned requests before queuing this one. Timeouts and
493            // write failures remove their entry eagerly below, but a caller
494            // cancelled mid-flight (an outer `timeout(..)` dropping the whole
495            // future) still leaves its `PendingMessage` behind. On a channel
496            // reused across inventory ticks those would accumulate unboundedly
497            // — and a late response could be mis-delivered to a recycled
498            // software id. `is_canceled()` is true once the receiver is gone,
499            // so this prunes exactly the give-ups.
500            pending.retain(|m| !m.sender.is_canceled());
501            pending.push_back(PendingMessage {
502                id: pending_id,
503                response_predicate: Box::new(response_predicate),
504                sender,
505            });
506        }
507
508        // The deadline covers the write as well: `write_report` has no
509        // bounded-time contract of its own, so a wedged device could otherwise
510        // park `send` forever before the response wait even starts.
511        let mut request = std::pin::pin!(
512            async {
513                self.send_and_forget(msg).await?;
514                receiver.await.map_err(|_| ChannelError::NoResponse)
515            }
516            .fuse()
517        );
518
519        let result = select! {
520            result = request => result,
521            _ = futures_timer::Delay::new(timeout).fuse() => Err(ChannelError::Timeout),
522        };
523
524        if result.is_err() {
525            // A timeout or write failure leaves the entry queued — remove it
526            // eagerly. After a matched response the read thread has already
527            // taken it, so this is a no-op then.
528            self.remove_pending_message(pending_id);
529        }
530
531        result
532    }
533
534    fn remove_pending_message(&self, id: u64) {
535        let mut pending = self.pending_messages.lock().unwrap();
536        if let Some(pos) = pending.iter().position(|msg| msg.id == id) {
537            pending.remove(pos);
538        }
539    }
540
541    /// Sends a HID++ message across the channel and does not wait for a
542    /// response.
543    ///
544    /// If a response is expected, use [`Self::send`],
545    pub async fn send_and_forget(&self, msg: HidppMessage) -> Result<(), ChannelError> {
546        let msg = self.normalize_outgoing(msg);
547        if !self.supports_msg(&msg) {
548            return Err(ChannelError::MessageTypeNotSupported);
549        }
550
551        let mut buf = [0u8; LONG_REPORT_LENGTH];
552        let len = msg.write_raw(&mut buf);
553        self.raw_channel
554            .write_report(&buf[..len])
555            .await
556            .map(|_| ())
557            .map_err(ChannelError::Implementation)
558    }
559
560    /// Registers a listener that will be called for every incoming message.
561    ///
562    /// Returns a handle that can be used to remove the listener using a call to
563    /// [`Self::remove_msg_listener`].
564    pub fn add_msg_listener(&self, listener: impl Fn(HidppMessage, bool) + Send + 'static) -> u32 {
565        let mut listeners = self.message_listeners.lock().unwrap();
566
567        let mut rng = rand::rng();
568        let mut hdl = rng.random::<u32>();
569        while listeners.contains_key(&hdl) {
570            hdl = rng.random::<u32>();
571        }
572
573        listeners.insert(hdl, Box::new(listener));
574        hdl
575    }
576
577    /// Removes a previously registered message listener.
578    ///
579    /// Returns whether a listener was found using the given handle.
580    pub fn remove_msg_listener(&self, hdl: u32) -> bool {
581        self.message_listeners
582            .lock()
583            .unwrap()
584            .remove(&hdl)
585            .is_some()
586    }
587}
588
589/// Represents an error that occurred when creating or interacting with a HID or
590/// HID++ communication channel.
591#[derive(Debug, Error)]
592#[non_exhaustive]
593pub enum ChannelError {
594    /// Indicates that the concrete implementation of [`RawHidChannel`] returned
595    /// an error.
596    #[error("the HID channel implementation returned an error")]
597    Implementation(#[from] Box<dyn Error + Sync + Send>),
598
599    /// Indicates that the HID report descriptor could not be parsed.
600    #[error("the report descriptor could not be parsed")]
601    ReportDescriptor(hidreport::ParserError),
602
603    /// Indicates that the channel in question does not support HID++.
604    #[error("the HID channel does not support HID++")]
605    HidppNotSupported,
606
607    /// Indicates that the HID++ channel does not support messages of the given
608    /// type (short/long).
609    #[error("the channel does not support the given HID++ message type")]
610    MessageTypeNotSupported,
611
612    /// Indicates that no response was received following a request.
613    #[error("the device did not respond to the request")]
614    NoResponse,
615
616    /// Indicates that a request did not complete within its time budget —
617    /// typically the device is asleep, out of range or connected to another
618    /// host. See [`HidppChannel::send_with_timeout`].
619    #[error("the request timed out before the device responded")]
620    Timeout,
621}
622
623/// Widen a short HID++ payload (6 bytes) to a long one (19 bytes): the HID++
624/// header bytes (device / feature / function|sw) sit at the same offsets in
625/// both widths, so the only change is zero-padding the trailing payload. Used
626/// to re-frame short messages as long on a long-only channel — see
627/// [`HidppChannel::normalize_outgoing`]. (OpenLogi local addition.)
628fn short_payload_as_long(payload: &[u8; SHORT_REPORT_LENGTH - 1]) -> [u8; LONG_REPORT_LENGTH - 1] {
629    let mut long = [0u8; LONG_REPORT_LENGTH - 1];
630    long[..payload.len()].copy_from_slice(payload);
631    long
632}
633
634#[cfg(test)]
635mod tests {
636    use super::*;
637    use std::{
638        io,
639        sync::{Arc, Mutex},
640        time::{Duration, Instant},
641    };
642
#[test]
fn short_payload_widens_preserving_header_and_padding() {
    // Layout: [device, feature, function|sw, p0, p1, p2]
    let narrow = [0xff, 0x05, 0x1e, 0xaa, 0xbb, 0xcc];
    let wide = short_payload_as_long(&narrow);
    let (head, tail) = wide.split_at(narrow.len());

    // Header and payload bytes are carried over verbatim...
    assert_eq!(head, &narrow[..]);
    // ...and everything after them is zero padding.
    assert!(tail.iter().all(|&byte| byte == 0));
    assert_eq!(wide.len(), LONG_REPORT_LENGTH - 1);
}
652
#[test]
fn send_returns_response_before_timeout() {
    futures::executor::block_on(async {
        let (raw, mock) = MockRawHidChannel::new();
        let channel = HidppChannel::from_raw_channel(raw).await.unwrap();

        let request = short_msg(0x10);
        let expected = short_msg(0x20);
        // The mock answers the upcoming write with `expected`.
        mock.queue_response(expected);

        let received = channel
            .send_with_timeout(
                request,
                move |incoming| *incoming == expected,
                Duration::from_secs(1),
            )
            .await
            .unwrap();

        assert_eq!(received, expected);
        assert_eq!(mock.written_reports().len(), 1);
        assert_pending_empty(&channel);
    });
}
677
#[test]
fn send_times_out_and_removes_pending_message() {
    futures::executor::block_on(async {
        let (raw, mock) = MockRawHidChannel::new();
        let channel = HidppChannel::from_raw_channel(raw).await.unwrap();
        let request = short_msg(0x10);
        let never_sent = short_msg(0x20);

        let start = Instant::now();
        let outcome = channel
            .send_with_timeout(
                request,
                move |incoming| *incoming == never_sent,
                Duration::from_millis(25),
            )
            .await;
        let err = outcome.unwrap_err();

        // The short budget must expire promptly and clean up after itself.
        assert!(matches!(err, ChannelError::Timeout));
        assert!(start.elapsed() < Duration::from_secs(1));
        assert_eq!(mock.written_reports().len(), 1);
        assert_pending_empty(&channel);
    });
}
702
#[test]
fn timeout_removes_only_its_own_pending_message() {
    futures::executor::block_on(async {
        let (raw, mock) = MockRawHidChannel::new();
        let channel = HidppChannel::from_raw_channel(raw).await.unwrap();

        let unanswered = short_msg(0x20);
        let delayed = short_msg(0x21);

        let first = channel.send_with_timeout(
            short_msg(0x10),
            move |incoming| *incoming == unanswered,
            Duration::from_millis(25),
        );
        let second = channel.send_with_timeout(
            short_msg(0x11),
            move |incoming| *incoming == delayed,
            Duration::from_secs(1),
        );
        // Deliver the second response only after the first request has timed
        // out; if the timeout removed the wrong pending entry, the second
        // request would then never be answered.
        let answer_late = async {
            futures_timer::Delay::new(Duration::from_millis(100)).await;
            mock.send_incoming(delayed).await;
        };

        let (first, second, ()) = futures::join!(first, second, answer_late);

        assert!(matches!(first.unwrap_err(), ChannelError::Timeout));
        assert_eq!(second.unwrap(), delayed);
        assert_pending_empty(&channel);
    });
}
736
#[test]
fn late_response_after_timeout_is_ignored() {
    futures::executor::block_on(async {
        let (raw, mock) = MockRawHidChannel::new();
        let channel = HidppChannel::from_raw_channel(raw).await.unwrap();

        // Record every (message, matched) pair the channel reports.
        let seen = Arc::new(Mutex::new(Vec::new()));
        channel.add_msg_listener({
            let seen = Arc::clone(&seen);
            move |msg, matched| seen.lock().unwrap().push((msg, matched))
        });

        // Phase 1: the request times out without an answer.
        let request = short_msg(0x10);
        let late = short_msg(0x20);
        let err = channel
            .send_with_timeout(
                request,
                move |incoming| *incoming == late,
                Duration::from_millis(25),
            )
            .await
            .unwrap_err();
        assert!(matches!(err, ChannelError::Timeout));
        assert_pending_empty(&channel);

        // Phase 2: the late answer reaches listeners as unmatched.
        mock.send_incoming(late).await;
        wait_for_event_count(&seen, 1).await;
        assert_eq!(seen.lock().unwrap()[0], (late, false));
        assert_pending_empty(&channel);

        // Phase 3: the channel still serves later requests normally.
        let next_request = short_msg(0x30);
        let next_response = short_msg(0x40);
        mock.queue_response(next_response);
        let actual = channel
            .send_with_timeout(
                next_request,
                move |incoming| *incoming == next_response,
                Duration::from_secs(1),
            )
            .await
            .unwrap();

        assert_eq!(actual, next_response);
        wait_for_event_count(&seen, 2).await;
        assert_eq!(seen.lock().unwrap()[1], (next_response, true));
        assert_pending_empty(&channel);
    });
}
785
786    #[test]
787    fn send_and_forget_writes_without_pending_message() {
788        futures::executor::block_on(async {
789            let (raw, handle) = MockRawHidChannel::new();
790            let channel = HidppChannel::from_raw_channel(raw).await.unwrap();
791
792            channel.send_and_forget(short_msg(0x10)).await.unwrap();
793
794            assert_eq!(handle.written_reports().len(), 1);
795            assert_pending_empty(&channel);
796        });
797    }
798
799    #[derive(Clone)]
800    struct MockRawHidHandle {
801        incoming_tx: async_channel::Sender<Vec<u8>>,
802        written_reports: Arc<Mutex<Vec<Vec<u8>>>>,
803        responses_on_write: Arc<Mutex<VecDeque<Vec<u8>>>>,
804    }
805
806    impl MockRawHidHandle {
807        fn queue_response(&self, msg: HidppMessage) {
808            self.responses_on_write
809                .lock()
810                .unwrap()
811                .push_back(raw_report(msg));
812        }
813
814        async fn send_incoming(&self, msg: HidppMessage) {
815            self.incoming_tx.send(raw_report(msg)).await.unwrap();
816        }
817
818        fn written_reports(&self) -> Vec<Vec<u8>> {
819            self.written_reports.lock().unwrap().clone()
820        }
821    }
822
823    struct MockRawHidChannel {
824        incoming_tx: async_channel::Sender<Vec<u8>>,
825        incoming_rx: async_channel::Receiver<Vec<u8>>,
826        written_reports: Arc<Mutex<Vec<Vec<u8>>>>,
827        responses_on_write: Arc<Mutex<VecDeque<Vec<u8>>>>,
828    }
829
830    impl MockRawHidChannel {
831        fn new() -> (Self, MockRawHidHandle) {
832            let (incoming_tx, incoming_rx) = async_channel::unbounded();
833            let written_reports = Arc::new(Mutex::new(Vec::new()));
834            let responses_on_write = Arc::new(Mutex::new(VecDeque::new()));
835
836            let handle = MockRawHidHandle {
837                incoming_tx: incoming_tx.clone(),
838                written_reports: Arc::clone(&written_reports),
839                responses_on_write: Arc::clone(&responses_on_write),
840            };
841
842            (
843                Self {
844                    incoming_tx,
845                    incoming_rx,
846                    written_reports,
847                    responses_on_write,
848                },
849                handle,
850            )
851        }
852    }
853
854    #[async_trait]
855    impl RawHidChannel for MockRawHidChannel {
856        fn vendor_id(&self) -> u16 {
857            0x046d
858        }
859
860        fn product_id(&self) -> u16 {
861            0xc539
862        }
863
864        async fn write_report(&self, src: &[u8]) -> Result<usize, Box<dyn Error + Sync + Send>> {
865            self.written_reports.lock().unwrap().push(src.to_vec());
866            let response = self.responses_on_write.lock().unwrap().pop_front();
867            if let Some(response) = response {
868                self.incoming_tx.send(response).await.unwrap();
869            }
870
871            Ok(src.len())
872        }
873
874        async fn read_report(&self, buf: &mut [u8]) -> Result<usize, Box<dyn Error + Sync + Send>> {
875            let report = self.incoming_rx.recv().await.map_err(|_| mock_error())?;
876            let len = report.len().min(buf.len());
877            buf[..len].copy_from_slice(&report[..len]);
878            Ok(len)
879        }
880
881        fn supports_short_long_hidpp(&self) -> Option<(bool, bool)> {
882            Some((true, true))
883        }
884
885        async fn get_report_descriptor(
886            &self,
887            _buf: &mut [u8],
888        ) -> Result<usize, Box<dyn Error + Sync + Send>> {
889            unreachable!("mock declares HID++ support")
890        }
891    }
892
893    fn short_msg(marker: u8) -> HidppMessage {
894        HidppMessage::Short([0xff, marker, 0x10, marker, marker, marker])
895    }
896
897    fn raw_report(msg: HidppMessage) -> Vec<u8> {
898        let mut buf = [0u8; LONG_REPORT_LENGTH];
899        let len = msg.write_raw(&mut buf);
900        buf[..len].to_vec()
901    }
902
903    fn assert_pending_empty(channel: &HidppChannel) {
904        assert!(channel.pending_messages.lock().unwrap().is_empty());
905    }
906
907    async fn wait_for_event_count(events: &Arc<Mutex<Vec<(HidppMessage, bool)>>>, count: usize) {
908        let started = Instant::now();
909        while started.elapsed() < Duration::from_secs(1) {
910            if events.lock().unwrap().len() >= count {
911                return;
912            }
913            futures_timer::Delay::new(Duration::from_millis(10)).await;
914        }
915
916        panic!("timed out waiting for {count} listener events");
917    }
918
919    fn mock_error() -> Box<dyn Error + Sync + Send> {
920        Box::new(io::Error::new(
921            io::ErrorKind::BrokenPipe,
922            "mock channel closed",
923        ))
924    }
925}