use std::{
collections::{HashMap, VecDeque},
error::Error,
sync::{
Arc, Mutex,
atomic::{AtomicBool, AtomicU8, AtomicU64, Ordering},
},
thread::{self, JoinHandle},
time::Duration,
};
use async_trait::async_trait;
use futures::{FutureExt, channel::oneshot, select};
use hidreport::{Field, Report, ReportDescriptor, Usage, UsageId, UsagePage};
use rand::Rng;
use thiserror::Error;
use crate::nibble::U4;
/// Size of the buffer used to fetch the report descriptor when the raw channel cannot report
/// HID++ support on its own.
const MAX_REPORT_DESCRIPTOR_LENGTH: usize = 4096;
/// Size of the read thread's receive buffer: the longest report handled by this module.
const MAX_REPORT_LENGTH: usize = LONG_REPORT_LENGTH;
/// Default time [`HidppChannel::send`] waits for a matching response.
pub const SEND_RESPONSE_TIMEOUT: Duration = Duration::from_secs(5);
/// Report ID (first byte) of short HID++ reports.
pub const SHORT_REPORT_ID: u8 = 0x10;
/// Usage page looked up in the report descriptor to detect short HID++ report support.
pub const SHORT_REPORT_USAGE_PAGE: u16 = 0xff00;
/// Usage looked up in the report descriptor to detect short HID++ report support.
pub const SHORT_REPORT_USAGE: u16 = 0x0001;
/// Total length in bytes of a short HID++ report, including the report ID byte.
pub const SHORT_REPORT_LENGTH: usize = 7;
/// Report ID (first byte) of long HID++ reports.
pub const LONG_REPORT_ID: u8 = 0x11;
/// Usage page looked up in the report descriptor to detect long HID++ report support.
pub const LONG_REPORT_USAGE_PAGE: u16 = 0xff00;
/// Usage looked up in the report descriptor to detect long HID++ report support.
pub const LONG_REPORT_USAGE: u16 = 0x0002;
/// Total length in bytes of a long HID++ report, including the report ID byte.
pub const LONG_REPORT_LENGTH: usize = 20;
/// A raw HID transport over which [`HidppChannel`] exchanges HID++ reports.
///
/// [`HidppChannel`] shares the implementation with a dedicated read thread through an `Arc`,
/// which is why it must be `Sync + Send + 'static`.
#[async_trait]
pub trait RawHidChannel: Sync + Send + 'static {
    /// Vendor ID of the device; copied into [`HidppChannel::vendor_id`].
    fn vendor_id(&self) -> u16;
    /// Product ID of the device; copied into [`HidppChannel::product_id`].
    fn product_id(&self) -> u16;
    /// Writes one complete raw report (its first byte is the report ID) to the device.
    ///
    /// [`HidppChannel`] only checks for success; the returned count is discarded.
    async fn write_report(&self, src: &[u8]) -> Result<usize, Box<dyn Error + Sync + Send>>;
    /// Reads one raw report into `buf` and returns how many bytes of `buf` were filled.
    ///
    /// Called in a loop on the read thread and raced against the channel's shutdown signal.
    /// Shutdown can only be noticed while this future yields, so a future that blocks the
    /// thread until data arrives also delays dropping the [`HidppChannel`]. Errors are ignored
    /// by the read loop, which simply calls this again.
    async fn read_report(&self, buf: &mut [u8]) -> Result<usize, Box<dyn Error + Sync + Send>>;
    /// Returns `Some((supports_short, supports_long))` when the implementation already knows
    /// which HID++ report types the device supports, or `None` to have the report descriptor
    /// fetched and parsed instead.
    fn supports_short_long_hidpp(&self) -> Option<(bool, bool)>;
    /// Copies the device's HID report descriptor into `buf` and returns its length in bytes.
    ///
    /// Only called when [`RawHidChannel::supports_short_long_hidpp`] returns `None`; `buf` is
    /// `MAX_REPORT_DESCRIPTOR_LENGTH` bytes long.
    async fn get_report_descriptor(
        &self,
        buf: &mut [u8],
    ) -> Result<usize, Box<dyn Error + Sync + Send>>;
}
async fn supports_short_long_hidpp(
chan: &impl RawHidChannel,
) -> Result<(bool, bool), ChannelError> {
if let Some((supports_short, supports_long)) = chan.supports_short_long_hidpp() {
return Ok((supports_short, supports_long));
}
let mut raw_descriptor = vec![0u8; MAX_REPORT_DESCRIPTOR_LENGTH];
let descriptor_size = chan.get_report_descriptor(&mut raw_descriptor).await?;
let descriptor = match ReportDescriptor::try_from(&raw_descriptor[..descriptor_size]) {
Ok(val) => val,
Err(err) => return Err(ChannelError::ReportDescriptor(err)),
};
let supports_short = descriptor
.find_input_report(&[SHORT_REPORT_ID])
.and_then(|report| report.fields().first())
.and_then(|field| match field {
Field::Array(arr) => Some(arr.usage_range()),
_ => None,
})
.is_some_and(|range| {
range
.lookup_usage(&Usage::from_page_and_id(
UsagePage::from(SHORT_REPORT_USAGE_PAGE),
UsageId::from(SHORT_REPORT_USAGE),
))
.is_some()
});
let supports_long = descriptor
.find_input_report(&[LONG_REPORT_ID])
.and_then(|report| report.fields().first())
.and_then(|field| match field {
Field::Array(arr) => Some(arr.usage_range()),
_ => None,
})
.is_some_and(|range| {
range
.lookup_usage(&Usage::from_page_and_id(
UsagePage::from(LONG_REPORT_USAGE_PAGE),
UsageId::from(LONG_REPORT_USAGE),
))
.is_some()
});
Ok((supports_short, supports_long))
}
/// A HID++ message, stored without its leading report ID byte.
///
/// The variant determines the report ID: [`HidppMessage::Short`] travels as report
/// `SHORT_REPORT_ID` and [`HidppMessage::Long`] as report `LONG_REPORT_ID`.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub enum HidppMessage {
    /// Bytes of a short report that follow the report ID.
    Short([u8; SHORT_REPORT_LENGTH - 1]),
    /// Bytes of a long report that follow the report ID.
    Long([u8; LONG_REPORT_LENGTH - 1]),
}

impl HidppMessage {
    /// Parses a raw report, including its report ID byte, into a message.
    ///
    /// Returns `None` when `data` is empty, starts with a report ID other than
    /// `SHORT_REPORT_ID`/`LONG_REPORT_ID`, or is not exactly as long as that report type
    /// requires.
    pub fn read_raw(data: &[u8]) -> Option<Self> {
        let (&report_id, payload) = data.split_first()?;
        // The array conversion only succeeds when `payload` has exactly the array's length,
        // which is the same as `data` having the full report length.
        match report_id {
            SHORT_REPORT_ID => <[u8; SHORT_REPORT_LENGTH - 1]>::try_from(payload)
                .ok()
                .map(Self::Short),
            LONG_REPORT_ID => <[u8; LONG_REPORT_LENGTH - 1]>::try_from(payload)
                .ok()
                .map(Self::Long),
            _ => None,
        }
    }

    /// Serializes the message, report ID first, into the start of `buf`.
    ///
    /// Returns the number of bytes written: `SHORT_REPORT_LENGTH` or `LONG_REPORT_LENGTH`.
    ///
    /// # Panics
    ///
    /// Panics if `buf` is shorter than the report being written.
    pub fn write_raw(&self, buf: &mut [u8]) -> usize {
        let (report_id, payload): (u8, &[u8]) = match self {
            Self::Short(bytes) => (SHORT_REPORT_ID, &bytes[..]),
            Self::Long(bytes) => (LONG_REPORT_ID, &bytes[..]),
        };
        let report_len = payload.len() + 1;
        buf[0] = report_id;
        buf[1..report_len].copy_from_slice(payload);
        report_len
    }
}
/// Callback run on the read thread for every parsed incoming message. The `bool` is `true`
/// when the message was handed to a pending request made with [`HidppChannel::send`].
type MessageListener = Box<dyn Fn(HidppMessage, bool) + Send>;
/// A HID++ channel on top of a [`RawHidChannel`].
///
/// A dedicated background thread reads incoming reports, hands them to pending requests made
/// with [`HidppChannel::send`], and passes them to registered listeners. Dropping the channel
/// signals that thread to stop and joins it.
pub struct HidppChannel {
    /// Whether the device supports short HID++ reports.
    pub supports_short: bool,
    /// Whether the device supports long HID++ reports.
    pub supports_long: bool,
    /// Vendor ID reported by the raw channel.
    pub vendor_id: u16,
    /// Product ID reported by the raw channel.
    pub product_id: u16,
    // Transport used for writes; the read thread holds its own `Arc` to the same channel.
    raw_channel: Arc<dyn RawHidChannel>,
    // Whether `get_sw_id` advances the software ID on every call.
    rotate_software_id: AtomicBool,
    // Current software ID as stored by `set_sw_id` / advanced by `get_sw_id`.
    software_id: AtomicU8,
    // Requests awaiting a response, oldest first; shared with the read thread.
    pending_messages: Arc<Mutex<VecDeque<PendingMessage>>>,
    // Counter that hands out the `id` of each `PendingMessage`.
    pending_message_id: AtomicU64,
    // Listeners keyed by the handle returned from `add_msg_listener`; shared with the read thread.
    message_listeners: Arc<Mutex<HashMap<u32, MessageListener>>>,
    // Signals the read thread to stop; taken in `Drop`.
    read_thread_close: Option<oneshot::Sender<()>>,
    // Read thread handle; joined in `Drop`.
    read_thread_hdl: Option<JoinHandle<()>>,
}
impl Drop for HidppChannel {
fn drop(&mut self) {
if let Some(read_thread_close) = self.read_thread_close.take() {
let _ = read_thread_close.send(());
}
if let Some(read_thread_hdl) = self.read_thread_hdl.take() {
read_thread_hdl.join().unwrap();
}
}
}
/// A request registered by [`HidppChannel::send_with_timeout`] that is waiting for its response.
struct PendingMessage {
    // Unique id, used to find and remove this entry again.
    id: u64,
    // Returns `true` for the incoming message that answers this request.
    response_predicate: Box<dyn Fn(&HidppMessage) -> bool + Send>,
    // Delivers the matching incoming message to the waiting request.
    sender: oneshot::Sender<HidppMessage>,
}
impl HidppChannel {
pub async fn from_raw_channel(raw: impl RawHidChannel) -> Result<Self, ChannelError> {
let (supports_short, supports_long) = supports_short_long_hidpp(&raw).await?;
if !supports_short && !supports_long {
return Err(ChannelError::HidppNotSupported);
}
let raw_channel_rc = Arc::new(raw);
let pending_messages_rc = Arc::new(Mutex::new(VecDeque::<PendingMessage>::new()));
let message_listeners_rc = Arc::new(Mutex::new(HashMap::<u32, MessageListener>::new()));
let (close_sender, mut close_receiver) = oneshot::channel::<()>();
let read_thread_hdl = thread::spawn({
let raw_channel = Arc::clone(&raw_channel_rc);
let pending_messages = Arc::clone(&pending_messages_rc);
let message_listeners = Arc::clone(&message_listeners_rc);
move || {
futures::executor::block_on(async {
let mut buf = [0u8; MAX_REPORT_LENGTH];
loop {
let res = select! {
_ = close_receiver => {
break;
},
res = raw_channel.read_report(&mut buf).fuse() => res
};
let Ok(len) = res else {
continue;
};
let Some(msg) = HidppMessage::read_raw(&buf[..len]) else {
continue;
};
let mut msgs = pending_messages.lock().unwrap();
let mut matched = false;
if let Some(pos) =
msgs.iter().position(|elem| (elem.response_predicate)(&msg))
{
let waiting = msgs.remove(pos).unwrap();
let _ = waiting.sender.send(msg);
matched = true;
}
for listener in message_listeners.lock().unwrap().values() {
listener(msg, matched);
}
}
});
}
});
Ok(Self {
supports_short,
supports_long,
vendor_id: raw_channel_rc.vendor_id(),
product_id: raw_channel_rc.product_id(),
raw_channel: raw_channel_rc,
rotate_software_id: AtomicBool::new(false),
software_id: AtomicU8::new(0x01),
pending_messages: pending_messages_rc,
pending_message_id: AtomicU64::new(1),
message_listeners: message_listeners_rc,
read_thread_close: Some(close_sender),
read_thread_hdl: Some(read_thread_hdl),
})
}
pub fn set_sw_id(&self, sw_id: U4) {
self.software_id.store(sw_id.to_lo(), Ordering::SeqCst);
}
pub fn set_rotating_sw_id(&self, enable: bool) {
self.rotate_software_id.store(enable, Ordering::SeqCst);
}
pub fn get_sw_id(&self) -> U4 {
if self.rotate_software_id.load(Ordering::SeqCst) {
U4::from_lo(
self.software_id
.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |old| {
Some(if old & 0x0f == 0x0f {
0x01
} else {
old.wrapping_add(1)
})
})
.unwrap(),
)
} else {
U4::from_lo(self.software_id.load(Ordering::SeqCst))
}
}
pub fn supports_msg(&self, msg: &HidppMessage) -> bool {
match msg {
HidppMessage::Short(_) => self.supports_short,
HidppMessage::Long(_) => self.supports_long,
}
}
fn normalize_outgoing(&self, msg: HidppMessage) -> HidppMessage {
match msg {
HidppMessage::Short(payload) if !self.supports_short && self.supports_long => {
HidppMessage::Long(short_payload_as_long(&payload))
}
other => other,
}
}
pub async fn send(
&self,
msg: HidppMessage,
response_predicate: impl Fn(&HidppMessage) -> bool + Send + 'static,
) -> Result<HidppMessage, ChannelError> {
self.send_with_timeout(msg, response_predicate, SEND_RESPONSE_TIMEOUT)
.await
}
pub async fn send_with_timeout(
&self,
msg: HidppMessage,
response_predicate: impl Fn(&HidppMessage) -> bool + Send + 'static,
timeout: Duration,
) -> Result<HidppMessage, ChannelError> {
let msg = self.normalize_outgoing(msg);
if !self.supports_msg(&msg) {
return Err(ChannelError::MessageTypeNotSupported);
}
let (sender, receiver) = oneshot::channel::<HidppMessage>();
let pending_id = self.pending_message_id.fetch_add(1, Ordering::SeqCst);
{
let mut pending = self.pending_messages.lock().unwrap();
pending.retain(|m| !m.sender.is_canceled());
pending.push_back(PendingMessage {
id: pending_id,
response_predicate: Box::new(response_predicate),
sender,
});
}
let mut request = std::pin::pin!(
async {
self.send_and_forget(msg).await?;
receiver.await.map_err(|_| ChannelError::NoResponse)
}
.fuse()
);
let result = select! {
result = request => result,
_ = futures_timer::Delay::new(timeout).fuse() => Err(ChannelError::Timeout),
};
if result.is_err() {
self.remove_pending_message(pending_id);
}
result
}
fn remove_pending_message(&self, id: u64) {
let mut pending = self.pending_messages.lock().unwrap();
if let Some(pos) = pending.iter().position(|msg| msg.id == id) {
pending.remove(pos);
}
}
pub async fn send_and_forget(&self, msg: HidppMessage) -> Result<(), ChannelError> {
let msg = self.normalize_outgoing(msg);
if !self.supports_msg(&msg) {
return Err(ChannelError::MessageTypeNotSupported);
}
let mut buf = [0u8; LONG_REPORT_LENGTH];
let len = msg.write_raw(&mut buf);
self.raw_channel
.write_report(&buf[..len])
.await
.map(|_| ())
.map_err(ChannelError::Implementation)
}
pub fn add_msg_listener(&self, listener: impl Fn(HidppMessage, bool) + Send + 'static) -> u32 {
let mut listeners = self.message_listeners.lock().unwrap();
let mut rng = rand::rng();
let mut hdl = rng.random::<u32>();
while listeners.contains_key(&hdl) {
hdl = rng.random::<u32>();
}
listeners.insert(hdl, Box::new(listener));
hdl
}
pub fn remove_msg_listener(&self, hdl: u32) -> bool {
self.message_listeners
.lock()
.unwrap()
.remove(&hdl)
.is_some()
}
}
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum ChannelError {
#[error("the HID channel implementation returned an error")]
Implementation(#[from] Box<dyn Error + Sync + Send>),
#[error("the report descriptor could not be parsed")]
ReportDescriptor(hidreport::ParserError),
#[error("the HID channel does not support HID++")]
HidppNotSupported,
#[error("the channel does not support the given HID++ message type")]
MessageTypeNotSupported,
#[error("the device did not respond to the request")]
NoResponse,
#[error("the request timed out before the device responded")]
Timeout,
}
/// Widens a short HID++ payload into a long one.
///
/// The short payload becomes the prefix of the result, and every remaining byte is zero.
fn short_payload_as_long(payload: &[u8; SHORT_REPORT_LENGTH - 1]) -> [u8; LONG_REPORT_LENGTH - 1] {
    std::array::from_fn(|index| payload.get(index).copied().unwrap_or(0))
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        io,
        sync::{Arc, Mutex},
        time::{Duration, Instant},
    };

    /// Widening keeps the short payload as a prefix and zero-fills the remaining bytes.
    #[test]
    fn short_payload_widens_preserving_header_and_padding() {
        let short = [0xff, 0x05, 0x1e, 0xaa, 0xbb, 0xcc];
        let long = short_payload_as_long(&short);
        assert_eq!(&long[..short.len()], &short[..]);
        assert!(long[short.len()..].iter().all(|&b| b == 0));
        assert_eq!(long.len(), LONG_REPORT_LENGTH - 1);
    }

    /// A response produced by the write resolves the request and leaves no pending entry.
    #[test]
    fn send_returns_response_before_timeout() {
        futures::executor::block_on(async {
            let (raw, handle) = MockRawHidChannel::new();
            let channel = HidppChannel::from_raw_channel(raw).await.unwrap();
            let request = short_msg(0x10);
            let response = short_msg(0x20);
            handle.queue_response(response);
            let actual = channel
                .send_with_timeout(
                    request,
                    move |candidate| *candidate == response,
                    Duration::from_secs(1),
                )
                .await
                .unwrap();
            assert_eq!(actual, response);
            assert_eq!(handle.written_reports().len(), 1);
            assert_pending_empty(&channel);
        });
    }

    /// Without a response, the request fails with `Timeout` promptly and cleans up after itself.
    #[test]
    fn send_times_out_and_removes_pending_message() {
        futures::executor::block_on(async {
            let (raw, handle) = MockRawHidChannel::new();
            let channel = HidppChannel::from_raw_channel(raw).await.unwrap();
            let request = short_msg(0x10);
            let response = short_msg(0x20);
            let started = Instant::now();
            let err = channel
                .send_with_timeout(
                    request,
                    move |candidate| *candidate == response,
                    Duration::from_millis(25),
                )
                .await
                .unwrap_err();
            assert!(matches!(err, ChannelError::Timeout));
            assert!(started.elapsed() < Duration::from_secs(1));
            assert_eq!(handle.written_reports().len(), 1);
            assert_pending_empty(&channel);
        });
    }

    /// One request timing out must not disturb another concurrent request that gets answered
    /// later.
    #[test]
    fn timeout_removes_only_its_own_pending_message() {
        futures::executor::block_on(async {
            let (raw, handle) = MockRawHidChannel::new();
            let channel = HidppChannel::from_raw_channel(raw).await.unwrap();
            let never_answered = short_msg(0x20);
            let slow_response = short_msg(0x21);
            let timed_out = channel.send_with_timeout(
                short_msg(0x10),
                move |candidate| *candidate == never_answered,
                Duration::from_millis(25),
            );
            let answered = channel.send_with_timeout(
                short_msg(0x11),
                move |candidate| *candidate == slow_response,
                Duration::from_secs(1),
            );
            // Arrives after the 25 ms timeout but well within the 1 s one.
            let respond_late = async {
                futures_timer::Delay::new(Duration::from_millis(100)).await;
                handle.send_incoming(slow_response).await;
            };
            let (timed_out, answered, ()) = futures::join!(timed_out, answered, respond_late);
            assert!(matches!(timed_out.unwrap_err(), ChannelError::Timeout));
            assert_eq!(answered.unwrap(), slow_response);
            assert_pending_empty(&channel);
        });
    }

    /// A response arriving after its request timed out is reported to listeners as unmatched
    /// and does not interfere with later requests.
    #[test]
    fn late_response_after_timeout_is_ignored() {
        futures::executor::block_on(async {
            let (raw, handle) = MockRawHidChannel::new();
            let channel = HidppChannel::from_raw_channel(raw).await.unwrap();
            let events = Arc::new(Mutex::new(Vec::new()));
            let listener_events = Arc::clone(&events);
            channel.add_msg_listener(move |msg, matched| {
                listener_events.lock().unwrap().push((msg, matched));
            });
            let request = short_msg(0x10);
            let late_response = short_msg(0x20);
            let err = channel
                .send_with_timeout(
                    request,
                    move |candidate| *candidate == late_response,
                    Duration::from_millis(25),
                )
                .await
                .unwrap_err();
            assert!(matches!(err, ChannelError::Timeout));
            assert_pending_empty(&channel);
            handle.send_incoming(late_response).await;
            // Listeners run on the read thread, so poll until the event shows up.
            wait_for_event_count(&events, 1).await;
            assert_eq!(events.lock().unwrap()[0], (late_response, false));
            assert_pending_empty(&channel);
            let later_request = short_msg(0x30);
            let later_response = short_msg(0x40);
            handle.queue_response(later_response);
            let actual = channel
                .send_with_timeout(
                    later_request,
                    move |candidate| *candidate == later_response,
                    Duration::from_secs(1),
                )
                .await
                .unwrap();
            assert_eq!(actual, later_response);
            wait_for_event_count(&events, 2).await;
            assert_eq!(events.lock().unwrap()[1], (later_response, true));
            assert_pending_empty(&channel);
        });
    }

    /// `send_and_forget` writes the report but never registers a pending request.
    #[test]
    fn send_and_forget_writes_without_pending_message() {
        futures::executor::block_on(async {
            let (raw, handle) = MockRawHidChannel::new();
            let channel = HidppChannel::from_raw_channel(raw).await.unwrap();
            channel.send_and_forget(short_msg(0x10)).await.unwrap();
            assert_eq!(handle.written_reports().len(), 1);
            assert_pending_empty(&channel);
        });
    }

    /// Test-side handle sharing state with a `MockRawHidChannel`: injects incoming reports and
    /// inspects written ones.
    #[derive(Clone)]
    struct MockRawHidHandle {
        incoming_tx: async_channel::Sender<Vec<u8>>,
        written_reports: Arc<Mutex<Vec<Vec<u8>>>>,
        responses_on_write: Arc<Mutex<VecDeque<Vec<u8>>>>,
    }

    impl MockRawHidHandle {
        /// Makes the next `write_report` call feed `msg` back as an incoming report.
        fn queue_response(&self, msg: HidppMessage) {
            self.responses_on_write
                .lock()
                .unwrap()
                .push_back(raw_report(msg));
        }

        /// Immediately delivers `msg` as an incoming report.
        async fn send_incoming(&self, msg: HidppMessage) {
            self.incoming_tx.send(raw_report(msg)).await.unwrap();
        }

        /// Snapshot of every report written so far.
        fn written_reports(&self) -> Vec<Vec<u8>> {
            self.written_reports.lock().unwrap().clone()
        }
    }

    /// In-memory `RawHidChannel`: reads come from an unbounded queue, writes are recorded.
    struct MockRawHidChannel {
        // Kept here so the queue never closes while the mock is alive.
        incoming_tx: async_channel::Sender<Vec<u8>>,
        incoming_rx: async_channel::Receiver<Vec<u8>>,
        written_reports: Arc<Mutex<Vec<Vec<u8>>>>,
        responses_on_write: Arc<Mutex<VecDeque<Vec<u8>>>>,
    }

    impl MockRawHidChannel {
        /// Creates the mock together with a handle sharing its queues.
        fn new() -> (Self, MockRawHidHandle) {
            let (incoming_tx, incoming_rx) = async_channel::unbounded();
            let written_reports = Arc::new(Mutex::new(Vec::new()));
            let responses_on_write = Arc::new(Mutex::new(VecDeque::new()));
            let handle = MockRawHidHandle {
                incoming_tx: incoming_tx.clone(),
                written_reports: Arc::clone(&written_reports),
                responses_on_write: Arc::clone(&responses_on_write),
            };
            (
                Self {
                    incoming_tx,
                    incoming_rx,
                    written_reports,
                    responses_on_write,
                },
                handle,
            )
        }
    }

    #[async_trait]
    impl RawHidChannel for MockRawHidChannel {
        fn vendor_id(&self) -> u16 {
            0x046d
        }

        fn product_id(&self) -> u16 {
            0xc539
        }

        /// Records the report, then enqueues the next queued response (if any) as incoming.
        async fn write_report(&self, src: &[u8]) -> Result<usize, Box<dyn Error + Sync + Send>> {
            self.written_reports.lock().unwrap().push(src.to_vec());
            let response = self.responses_on_write.lock().unwrap().pop_front();
            if let Some(response) = response {
                self.incoming_tx.send(response).await.unwrap();
            }
            Ok(src.len())
        }

        /// Waits for the next incoming report and copies as much of it as fits into `buf`.
        async fn read_report(&self, buf: &mut [u8]) -> Result<usize, Box<dyn Error + Sync + Send>> {
            let report = self.incoming_rx.recv().await.map_err(|_| mock_error())?;
            let len = report.len().min(buf.len());
            buf[..len].copy_from_slice(&report[..len]);
            Ok(len)
        }

        /// Declares both report types so the descriptor is never requested.
        fn supports_short_long_hidpp(&self) -> Option<(bool, bool)> {
            Some((true, true))
        }

        async fn get_report_descriptor(
            &self,
            _buf: &mut [u8],
        ) -> Result<usize, Box<dyn Error + Sync + Send>> {
            unreachable!("mock declares HID++ support")
        }
    }

    /// Builds a distinct short message tagged with `marker`.
    fn short_msg(marker: u8) -> HidppMessage {
        HidppMessage::Short([0xff, marker, 0x10, marker, marker, marker])
    }

    /// Serializes `msg` into its raw report bytes, report ID included.
    fn raw_report(msg: HidppMessage) -> Vec<u8> {
        let mut buf = [0u8; LONG_REPORT_LENGTH];
        let len = msg.write_raw(&mut buf);
        buf[..len].to_vec()
    }

    /// Asserts that no request is left registered on `channel`.
    fn assert_pending_empty(channel: &HidppChannel) {
        assert!(channel.pending_messages.lock().unwrap().is_empty());
    }

    /// Polls every 10 ms until `events` holds at least `count` entries; panics after 1 s.
    async fn wait_for_event_count(events: &Arc<Mutex<Vec<(HidppMessage, bool)>>>, count: usize) {
        let started = Instant::now();
        while started.elapsed() < Duration::from_secs(1) {
            if events.lock().unwrap().len() >= count {
                return;
            }
            futures_timer::Delay::new(Duration::from_millis(10)).await;
        }
        panic!("timed out waiting for {count} listener events");
    }

    /// Error returned by the mock's reads once its incoming queue is closed.
    fn mock_error() -> Box<dyn Error + Sync + Send> {
        Box::new(io::Error::new(
            io::ErrorKind::BrokenPipe,
            "mock channel closed",
        ))
    }
}