use std::{
collections::{HashMap, VecDeque},
error::Error,
sync::{
Arc,
Mutex,
atomic::{AtomicBool, AtomicU8, Ordering},
},
thread::{self, JoinHandle},
};
use async_trait::async_trait;
use futures::{FutureExt, channel::oneshot, select};
use hidreport::{Field, Report, ReportDescriptor, Usage, UsageId, UsagePage};
use rand::Rng;
use thiserror::Error;
use crate::nibble::U4;
/// Size of the buffer handed to `RawHidChannel::get_report_descriptor` when
/// probing for HID++ support.
const MAX_REPORT_DESCRIPTOR_LENGTH: usize = 4096;

/// Report ID that marks a short HID++ report.
pub const SHORT_REPORT_ID: u8 = 0x10;
/// Usage page the report descriptor must declare for short HID++ reports.
pub const SHORT_REPORT_USAGE_PAGE: u16 = 0xff00;
/// Usage the report descriptor must declare for short HID++ reports.
pub const SHORT_REPORT_USAGE: u16 = 0x0001;
/// Total size of a short HID++ report, counting the leading report ID byte.
pub const SHORT_REPORT_LENGTH: usize = 7;

/// Report ID that marks a long HID++ report.
pub const LONG_REPORT_ID: u8 = 0x11;
/// Usage page the report descriptor must declare for long HID++ reports.
pub const LONG_REPORT_USAGE_PAGE: u16 = 0xff00;
/// Usage the report descriptor must declare for long HID++ reports.
pub const LONG_REPORT_USAGE: u16 = 0x0002;
/// Total size of a long HID++ report, counting the leading report ID byte.
pub const LONG_REPORT_LENGTH: usize = 20;

/// Size of the read buffer: the largest report handled, i.e. a long report.
const MAX_REPORT_LENGTH: usize = LONG_REPORT_LENGTH;
/// Transport abstraction that [`HidppChannel`] builds on.
///
/// An implementation exposes the device's identifiers, optionally states which
/// HID++ report types it supports, and performs report-level I/O.
#[async_trait]
pub trait RawHidChannel: Sync + Send + 'static {
    /// Vendor ID of the device behind this channel.
    fn vendor_id(&self) -> u16;

    /// Product ID of the device behind this channel.
    fn product_id(&self) -> u16;

    /// Returns `Some((supports_short, supports_long))` when the implementation
    /// already knows which HID++ report types it supports. Returning `None`
    /// makes the channel fetch and parse the report descriptor instead.
    fn supports_short_long_hidpp(&self) -> Option<(bool, bool)>;

    /// Copies the device's raw report descriptor into `buf` and returns the
    /// number of bytes written. Only consulted when
    /// [`supports_short_long_hidpp`](Self::supports_short_long_hidpp) returns
    /// `None`.
    async fn get_report_descriptor(&self, buf: &mut [u8]) -> Result<usize, Box<dyn Error>>;

    /// Reads one incoming report (report ID byte first) into `buf` and returns
    /// its length.
    ///
    /// The channel awaits this on its read thread alongside a shutdown signal,
    /// so the future should yield while waiting rather than block the thread;
    /// otherwise dropping the channel cannot interrupt a pending read.
    async fn read_report(&self, buf: &mut [u8]) -> Result<usize, Box<dyn Error>>;

    /// Writes one outgoing report (report ID byte first) taken from `src`.
    /// The returned count is presumably the number of bytes written;
    /// `HidppChannel` ignores it.
    async fn write_report(&self, src: &[u8]) -> Result<usize, Box<dyn Error>>;
}
async fn supports_short_long_hidpp(
chan: &impl RawHidChannel,
) -> Result<(bool, bool), ChannelError> {
if let Some((supports_short, supports_long)) = chan.supports_short_long_hidpp() {
return Ok((supports_short, supports_long));
}
let mut raw_descriptor = vec![0u8; MAX_REPORT_DESCRIPTOR_LENGTH];
let descriptor_size = chan.get_report_descriptor(&mut raw_descriptor).await?;
let descriptor = match ReportDescriptor::try_from(&raw_descriptor[..descriptor_size]) {
Ok(val) => val,
Err(err) => return Err(ChannelError::ReportDescriptor(err)),
};
let supports_short = descriptor
.find_input_report(&[SHORT_REPORT_ID])
.and_then(|report| report.fields().first())
.and_then(|field| match field {
Field::Array(arr) => Some(arr.usage_range()),
_ => None,
})
.is_some_and(|range| {
range
.lookup_usage(&Usage::from_page_and_id(
UsagePage::from(SHORT_REPORT_USAGE_PAGE),
UsageId::from(SHORT_REPORT_USAGE),
))
.is_some()
});
let supports_long = descriptor
.find_input_report(&[LONG_REPORT_ID])
.and_then(|report| report.fields().first())
.and_then(|field| match field {
Field::Array(arr) => Some(arr.usage_range()),
_ => None,
})
.is_some_and(|range| {
range
.lookup_usage(&Usage::from_page_and_id(
UsagePage::from(LONG_REPORT_USAGE_PAGE),
UsageId::from(LONG_REPORT_USAGE),
))
.is_some()
});
Ok((supports_short, supports_long))
}
/// One HID++ message, stored without its leading report ID byte.
///
/// The variant selects the report ID and length used on the wire:
/// [`SHORT_REPORT_ID`]/[`SHORT_REPORT_LENGTH`] for `Short`, and
/// [`LONG_REPORT_ID`]/[`LONG_REPORT_LENGTH`] for `Long`.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub enum HidppMessage {
    /// Payload of a short report (every byte after the report ID).
    Short([u8; SHORT_REPORT_LENGTH - 1]),
    /// Payload of a long report (every byte after the report ID).
    Long([u8; LONG_REPORT_LENGTH - 1]),
}

impl HidppMessage {
    /// Parses a raw report consisting of a report ID byte followed by the
    /// payload.
    ///
    /// Returns `None` if `data` is empty, starts with an unknown report ID, or
    /// is not exactly the length required for its report ID.
    pub fn read_raw(data: &[u8]) -> Option<Self> {
        let (&report_id, payload) = data.split_first()?;
        match report_id {
            SHORT_REPORT_ID if data.len() == SHORT_REPORT_LENGTH => {
                <[u8; SHORT_REPORT_LENGTH - 1]>::try_from(payload)
                    .ok()
                    .map(Self::Short)
            },
            LONG_REPORT_ID if data.len() == LONG_REPORT_LENGTH => {
                <[u8; LONG_REPORT_LENGTH - 1]>::try_from(payload)
                    .ok()
                    .map(Self::Long)
            },
            _ => None,
        }
    }

    /// Serialises the message into `buf` as a raw report: the report ID byte
    /// followed by the payload. Returns the number of bytes written.
    ///
    /// # Panics
    ///
    /// Panics if `buf` is shorter than this message type's report length
    /// ([`SHORT_REPORT_LENGTH`] or [`LONG_REPORT_LENGTH`]).
    pub fn write_raw(&self, buf: &mut [u8]) -> usize {
        let (report_id, payload): (u8, &[u8]) = match self {
            Self::Short(bytes) => (SHORT_REPORT_ID, &bytes[..]),
            Self::Long(bytes) => (LONG_REPORT_ID, &bytes[..]),
        };
        let report_len = payload.len() + 1;
        buf[0] = report_id;
        buf[1..report_len].copy_from_slice(payload);
        report_len
    }
}
/// Callback invoked from the read thread for every parsed incoming HID++
/// message. The `bool` is `true` when the message was claimed by a pending
/// request's response predicate.
type MessageListener = Box<dyn Fn(HidppMessage, bool) + Send>;
/// A HID++ channel layered on top of a [`RawHidChannel`].
///
/// A background thread reads incoming reports, hands responses to pending
/// requests and notifies registered message listeners. The thread is told to
/// stop and is joined when the channel is dropped.
pub struct HidppChannel {
    /// Whether short HID++ reports can be sent on this channel.
    pub supports_short: bool,
    /// Whether long HID++ reports can be sent on this channel.
    pub supports_long: bool,
    /// Vendor ID reported by the raw channel.
    pub vendor_id: u16,
    /// Product ID reported by the raw channel.
    pub product_id: u16,
    /// Underlying transport; the read thread holds its own clone of this `Arc`.
    raw_channel: Arc<dyn RawHidChannel>,
    /// When set, each software-ID request advances the stored software ID.
    rotate_software_id: AtomicBool,
    /// Current software ID as stored by the setter / rotation logic.
    software_id: AtomicU8,
    /// Requests awaiting a response; the read thread hands each incoming
    /// message to the earliest-queued entry whose predicate matches.
    pending_messages: Arc<Mutex<VecDeque<PendingMessage>>>,
    /// Registered listeners, keyed by the handle returned when adding them.
    message_listeners: Arc<Mutex<HashMap<u32, MessageListener>>>,
    /// Signals the read thread to stop; taken when the channel is dropped.
    read_thread_close: Option<oneshot::Sender<()>>,
    /// Handle of the read thread; taken and joined when the channel is dropped.
    read_thread_hdl: Option<JoinHandle<()>>,
}
impl Drop for HidppChannel {
    /// Signals the background read thread to stop and joins it.
    ///
    /// If the read thread panicked, its panic is re-raised here with the
    /// original payload, unless this thread is already unwinding (a second
    /// panic inside `drop` would abort the process). When the channel is
    /// dropped on the read thread itself (e.g. from inside a message listener),
    /// the join is skipped because a thread cannot join itself. That thread
    /// leaves its loop once it observes the close signal.
    fn drop(&mut self) {
        if let Some(read_thread_close) = self.read_thread_close.take() {
            // Sending fails only if the receiver is gone, i.e. the read thread
            // has already exited; there is nothing left to stop in that case.
            let _ = read_thread_close.send(());
        }
        let Some(read_thread_hdl) = self.read_thread_hdl.take() else {
            return;
        };
        if read_thread_hdl.thread().id() == thread::current().id() {
            return;
        }
        if let Err(panic_payload) = read_thread_hdl.join() {
            // Re-raising while already unwinding would abort the process.
            if !thread::panicking() {
                std::panic::resume_unwind(panic_payload);
            }
        }
    }
}
/// A request waiting for its response to arrive on the read thread.
struct PendingMessage {
    /// Returns `true` for an incoming message that answers this request; the
    /// read thread picks the earliest-queued entry whose predicate matches.
    response_predicate: Box<dyn Fn(&HidppMessage) -> bool + Send>,
    /// Completes the waiting request with the matched message.
    sender: oneshot::Sender<HidppMessage>,
}
impl HidppChannel {
    /// Wraps `raw` in a HID++ channel and starts the background read thread.
    ///
    /// Short/long HID++ support is probed first, using the transport's own
    /// answer or else its report descriptor. The software ID starts at `0x01`
    /// with rotation disabled.
    ///
    /// # Errors
    ///
    /// - Any error from probing HID++ support ([`ChannelError::Implementation`]
    ///   or [`ChannelError::ReportDescriptor`]).
    /// - [`ChannelError::HidppNotSupported`] if neither short nor long HID++
    ///   reports are supported.
    pub async fn from_raw_channel(raw: impl RawHidChannel) -> Result<Self, ChannelError> {
        let (supports_short, supports_long) = supports_short_long_hidpp(&raw).await?;
        if !supports_short && !supports_long {
            return Err(ChannelError::HidppNotSupported);
        }

        let raw_channel_rc = Arc::new(raw);
        let pending_messages_rc = Arc::new(Mutex::new(VecDeque::<PendingMessage>::new()));
        let message_listeners_rc = Arc::new(Mutex::new(HashMap::<u32, MessageListener>::new()));
        let (close_sender, mut close_receiver) = oneshot::channel::<()>();

        let read_thread_hdl = thread::spawn({
            let raw_channel = Arc::clone(&raw_channel_rc);
            let pending_messages = Arc::clone(&pending_messages_rc);
            let message_listeners = Arc::clone(&message_listeners_rc);
            move || {
                futures::executor::block_on(async {
                    let mut buf = [0u8; MAX_REPORT_LENGTH];
                    loop {
                        let res = select! {
                            _ = close_receiver => {
                                break;
                            },
                            res = raw_channel.read_report(&mut buf).fuse() => res
                        };
                        // NOTE(review): read errors are skipped and the read is
                        // retried immediately, so a transport that fails
                        // persistently (e.g. a disconnected device) makes this
                        // loop spin until the channel is dropped.
                        let Ok(len) = res else {
                            continue;
                        };
                        // `get` guards against a transport reporting a length
                        // larger than the buffer it was given.
                        let Some(msg) = buf.get(..len).and_then(HidppMessage::read_raw) else {
                            continue;
                        };
                        Self::dispatch_incoming(msg, &pending_messages, &message_listeners);
                    }
                });
            }
        });

        Ok(Self {
            supports_short,
            supports_long,
            vendor_id: raw_channel_rc.vendor_id(),
            product_id: raw_channel_rc.product_id(),
            raw_channel: raw_channel_rc,
            rotate_software_id: AtomicBool::new(false),
            software_id: AtomicU8::new(0x01),
            pending_messages: pending_messages_rc,
            message_listeners: message_listeners_rc,
            read_thread_close: Some(close_sender),
            read_thread_hdl: Some(read_thread_hdl),
        })
    }

    /// Routes one incoming message on the read thread. The message goes to the
    /// earliest-queued pending request whose predicate matches; then every
    /// listener is notified.
    ///
    /// Requests whose waiting future has been dropped are discarded first, so
    /// they cannot swallow a response meant for a live request. The pending
    /// queue lock is released before listeners run. Listener code therefore
    /// never executes under that lock, and a panicking listener cannot poison it.
    fn dispatch_incoming(
        msg: HidppMessage,
        pending_messages: &Mutex<VecDeque<PendingMessage>>,
        message_listeners: &Mutex<HashMap<u32, MessageListener>>,
    ) {
        let matched = {
            let mut pending = pending_messages.lock().unwrap();
            pending.retain(|entry| !entry.sender.is_canceled());
            match pending
                .iter()
                .position(|entry| (entry.response_predicate)(&msg))
            {
                Some(pos) => {
                    let waiting = pending
                        .remove(pos)
                        .expect("position() returned an in-bounds index");
                    // The receiver may still vanish concurrently; nothing to do then.
                    let _ = waiting.sender.send(msg);
                    true
                },
                None => false,
            }
        };
        for listener in message_listeners.lock().unwrap().values() {
            listener(msg, matched);
        }
    }

    /// Sets the stored software ID. With rotation enabled, this is the next
    /// value handed out by [`get_sw_id`](Self::get_sw_id).
    pub fn set_sw_id(&self, sw_id: U4) {
        self.software_id.store(sw_id.to_lo(), Ordering::SeqCst);
    }

    /// Enables or disables software ID rotation in [`get_sw_id`](Self::get_sw_id).
    pub fn set_rotating_sw_id(&self, enable: bool) {
        self.rotate_software_id.store(enable, Ordering::SeqCst);
    }

    /// Returns the software ID to use for a request.
    ///
    /// Without rotation this is the stored ID. With rotation, the stored ID is
    /// returned and then advanced by one; a value whose low nibble is `0x0f`
    /// wraps around to `0x01`.
    pub fn get_sw_id(&self) -> U4 {
        if self.rotate_software_id.load(Ordering::SeqCst) {
            U4::from_lo(
                self.software_id
                    .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |old| {
                        Some(if old & 0x0f == 0x0f {
                            0x01
                        } else {
                            old.wrapping_add(1)
                        })
                    })
                    .unwrap(),
            )
        } else {
            U4::from_lo(self.software_id.load(Ordering::SeqCst))
        }
    }

    /// Returns whether this channel can carry `msg`'s report type.
    pub fn supports_msg(&self, msg: &HidppMessage) -> bool {
        match msg {
            HidppMessage::Short(_) => self.supports_short,
            HidppMessage::Long(_) => self.supports_long,
        }
    }

    /// Sends `msg` and waits for the first incoming message accepted by
    /// `response_predicate`.
    ///
    /// The request is queued before the report is written, so a fast reply
    /// cannot be missed. Pending requests are served in the order they were
    /// queued.
    ///
    /// # Errors
    ///
    /// - [`ChannelError::MessageTypeNotSupported`] if the channel cannot carry
    ///   this message type.
    /// - [`ChannelError::Implementation`] if writing the report fails; the
    ///   queued request is withdrawn in that case.
    /// - [`ChannelError::NoResponse`] if the request is discarded before a
    ///   matching message arrives (e.g. the read thread has stopped).
    pub async fn send(
        &self,
        msg: HidppMessage,
        response_predicate: impl Fn(&HidppMessage) -> bool + Send + 'static,
    ) -> Result<HidppMessage, ChannelError> {
        if !self.supports_msg(&msg) {
            return Err(ChannelError::MessageTypeNotSupported);
        }
        let (sender, receiver) = oneshot::channel::<HidppMessage>();
        self.pending_messages
            .lock()
            .unwrap()
            .push_back(PendingMessage {
                response_predicate: Box::new(response_predicate),
                sender,
            });
        if let Err(err) = self.send_and_forget(msg).await {
            // Nothing will answer a request that was never written. Cancel it
            // and purge it so it cannot claim an unrelated response later.
            drop(receiver);
            self.pending_messages
                .lock()
                .unwrap()
                .retain(|entry| !entry.sender.is_canceled());
            return Err(err);
        }
        receiver.await.map_err(|_| ChannelError::NoResponse)
    }

    /// Writes `msg` to the device without waiting for a response.
    ///
    /// # Errors
    ///
    /// - [`ChannelError::MessageTypeNotSupported`] if the channel cannot carry
    ///   this message type.
    /// - [`ChannelError::Implementation`] if the raw channel fails to write.
    pub async fn send_and_forget(&self, msg: HidppMessage) -> Result<(), ChannelError> {
        if !self.supports_msg(&msg) {
            return Err(ChannelError::MessageTypeNotSupported);
        }
        let mut buf = [0u8; LONG_REPORT_LENGTH];
        let len = msg.write_raw(&mut buf);
        self.raw_channel
            .write_report(&buf[..len])
            .await
            .map(|_| ())
            .map_err(ChannelError::Implementation)
    }

    /// Registers `listener` to be called on the read thread for every parsed
    /// incoming message. The `bool` argument is `true` when the message was
    /// claimed by a pending [`send`](Self::send) request.
    ///
    /// Returns a handle for [`remove_msg_listener`](Self::remove_msg_listener).
    /// Listeners run while the listener map is locked, so a listener must not
    /// call `add_msg_listener`/`remove_msg_listener` itself. Doing so would
    /// deadlock or panic.
    pub fn add_msg_listener(&self, listener: impl Fn(HidppMessage, bool) + Send + 'static) -> u32 {
        let mut listeners = self.message_listeners.lock().unwrap();
        let mut rng = rand::rng();
        let mut hdl = rng.random::<u32>();
        while listeners.contains_key(&hdl) {
            hdl = rng.random::<u32>();
        }
        listeners.insert(hdl, Box::new(listener));
        hdl
    }

    /// Unregisters the listener with handle `hdl`. Returns `true` if one was
    /// registered.
    pub fn remove_msg_listener(&self, hdl: u32) -> bool {
        self.message_listeners
            .lock()
            .unwrap()
            .remove(&hdl)
            .is_some()
    }
}
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum ChannelError {
#[error("the HID channel implementation returned an error")]
Implementation(#[from] Box<dyn Error>),
#[error("the report descriptor could not be parsed")]
ReportDescriptor(hidreport::ParserError),
#[error("the HID channel does not support HID++")]
HidppNotSupported,
#[error("the channel does not support the given HID++ message type")]
MessageTypeNotSupported,
#[error("the device did not respond to the request")]
NoResponse,
}