mod utils;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Arc, Mutex, PoisonError};
use std::task::{Context, Poll};
use async_channel::{Receiver, Sender, WeakSender};
use futures_lite::stream::Boxed;
use futures_lite::{Stream, StreamExt};
use log::{debug, trace};
use once_cell::sync::OnceCell;
use windows::core::{h, Ref, HRESULT, HSTRING};
use windows::Devices::Enumeration::{DeviceInformation, DeviceInformationUpdate, DeviceWatcher};
use windows::Devices::HumanInterfaceDevice::{HidDevice, HidInputReport, HidInputReportReceivedEventArgs};
use windows::Foundation::TypedEventHandler;
use windows::Storage::FileAccessMode;
use windows::Win32::Foundation::ERROR_FILE_NOT_FOUND;
use crate::backend::winrt::utils::{DeviceInformationSteam, IBufferExt, WinResultExt};
use crate::backend::{Backend, DeviceInfoStream};
use crate::device_info::DeviceId;
use crate::error::HidResult;
use crate::{ensure, AsyncHidRead, AsyncHidWrite, AsyncHidFeatureHandle, DeviceEvent, DeviceInfo, HidError};
const DEVICE_SELECTOR: &HSTRING = h!(
r#"System.Devices.InterfaceClassGuid:="{4D1E55B2-F16F-11CF-88CB-001111000030}" AND System.Devices.InterfaceEnabled:=System.StructuredQueryType.Boolean#True"#
);
#[derive(Default)]
struct DeviceWatcherContext {
next_id: AtomicU64,
active_readers: Mutex<Vec<(u64, HSTRING, WeakSender<HidInputReport>)>>,
watchers: Mutex<Vec<Sender<DeviceEvent>>>
}
#[derive(Default, Clone)]
pub struct WinRtBackend {
context: Arc<DeviceWatcherContext>,
watcher: Arc<OnceCell<DeviceWatcher>>
}
impl Backend for WinRtBackend {
type Reader = InputReceiver;
type Writer = HidDevice;
type FeatureHandle = HidDevice;
async fn enumerate(&self) -> HidResult<DeviceInfoStream> {
let devices = DeviceInformation::FindAllAsyncAqsFilter(DEVICE_SELECTOR)?.await?;
let devices = DeviceInformationSteam::from(devices)
.then(|info| Box::pin(get_device_information(info)))
.filter_map(|r| r.transpose());
Ok(devices.boxed())
}
fn watch(&self) -> HidResult<Boxed<DeviceEvent>> {
struct WatchHelper(WinRtBackend);
impl Drop for WatchHelper {
fn drop(&mut self) {
self.0.clear_closed_event_watchers()
}
}
impl Stream for WatchHelper {
type Item = DeviceEvent;
fn poll_next(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Option<Self::Item>> {
Poll::Pending
}
}
let (sender, receiver) = async_channel::bounded(64);
self.register_event_watcher(sender)?;
Ok(receiver.chain(WatchHelper(self.clone())).boxed())
}
async fn query_info(&self, id: &DeviceId) -> HidResult<Vec<DeviceInfo>> {
let DeviceId::UncPath(id) = id;
let info = DeviceInformation::CreateFromIdAsync(id)?.await?;
Ok(get_device_information(info).await?.into_iter().collect())
}
async fn open(&self, id: &DeviceId, read: bool, write: bool) -> HidResult<(Option<Self::Reader>, Option<Self::Writer>)> {
let mode = match (read, write) {
(true, false) => FileAccessMode::Read,
(_, true) => FileAccessMode::ReadWrite,
(false, false) => panic!("Not supported")
};
let DeviceId::UncPath(id) = id;
let device = HidDevice::FromIdAsync(id, mode)?
.await
.map_err(|err| match err {
e if e.code().is_ok() || e.code() == HRESULT::from_win32(ERROR_FILE_NOT_FOUND.0) => HidError::NotConnected,
e => e.into()
})?;
let input = match read {
true => Some(InputReceiver::new(self, id, device.clone())?),
false => None
};
Ok((input, read.then_some(device)))
}
async fn open_feature_handle(&self, id: &DeviceId) -> HidResult<Self::FeatureHandle> {
let DeviceId::UncPath(id) = id;
let device = HidDevice::FromIdAsync(id, FileAccessMode::ReadWrite)?
.await
.map_err(|err| match err {
e if e.code().is_ok() || e.code() == HRESULT::from_win32(ERROR_FILE_NOT_FOUND.0) => HidError::NotConnected,
e => e.into()
})?;
Ok(device)
}
}
async fn get_device_information(device: DeviceInformation) -> HidResult<Option<DeviceInfo>> {
let id = device.Id()?;
let name = device.Name()?.to_string_lossy();
let device = HidDevice::FromIdAsync(&id, FileAccessMode::Read)?;
let Some(device) = device.await.extract_null()? else {
return Ok(None);
};
Ok(Some(DeviceInfo {
id: DeviceId::UncPath(id),
name,
product_id: device.ProductId()?,
vendor_id: device.VendorId()?,
usage_id: device.UsageId()?,
usage_page: device.UsagePage()?,
manufacturer: None,
serial_number: None
}))
}
pub struct InputReceiver {
backend: WinRtBackend,
device: HidDevice,
buffer: Receiver<HidInputReport>,
token: i64,
watcher_registration: u64
}
impl InputReceiver {
fn new(backend: &WinRtBackend, id: &HSTRING, device: HidDevice) -> HidResult<Self> {
let (sender, receiver) = async_channel::bounded(64);
let watcher_registration = backend.register_active_reader(id.clone(), &sender)?;
let token = device.InputReportReceived(&TypedEventHandler::new(move |_, args: Ref<HidInputReportReceivedEventArgs>| {
if let Some(args) = args.as_ref() {
let msg = args.Report()?;
let _ = sender.force_send(msg);
}
Ok(())
}))?;
Ok(Self {
backend: backend.clone(),
device,
buffer: receiver,
token,
watcher_registration
})
}
}
impl Drop for InputReceiver {
fn drop(&mut self) {
self.device
.RemoveInputReportReceived(self.token)
.unwrap_or_else(|err| log::warn!("Failed to unregister input report callback\n\t{err:?}"));
self.backend
.unregister_active_reader(self.watcher_registration);
}
}
impl AsyncHidRead for InputReceiver {
async fn read_input_report<'a>(&'a mut self, buf: &'a mut [u8]) -> HidResult<usize> {
let report = self
.buffer
.recv()
.await
.map_err(|_| HidError::Disconnected)?;
let buffer = report.Data()?;
let buffer = buffer.as_slice()?;
ensure!(!buffer.is_empty(), HidError::message("Input report is empty"));
let size = buf.len().min(buffer.len());
let start = if buffer[0] == 0x0 { 1 } else { 0 };
buf[..(size - start)].copy_from_slice(&buffer[start..size]);
Ok(size - start)
}
}
impl AsyncHidWrite for HidDevice {
async fn write_output_report<'a>(&'a mut self, buf: &'a [u8]) -> HidResult<()> {
let report = self.CreateOutputReport()?;
{
let mut buffer = report.Data()?;
ensure!(buffer.Length()? as usize >= buf.len(), HidError::message("Output report is too large"));
let (buffer, remainder) = buffer.as_mut_slice()?.split_at_mut(buf.len());
buffer.copy_from_slice(buf);
remainder.fill(0);
}
self.SendOutputReportAsync(&report)?.await?;
Ok(())
}
}
impl AsyncHidFeatureHandle for HidDevice {
async fn read_feature_report<'a>(&'a mut self, buf: &'a mut [u8]) -> HidResult<usize> {
ensure!(!buf.is_empty(), HidError::message("Buffer cannot be empty"));
let report = self.GetFeatureReportByIdAsync(buf[0] as u16)?.await?;
let data = report.Data()?;
let data_slice = data.as_slice()?;
ensure!(!data_slice.is_empty(), HidError::message("Feature report is empty"));
let copy_size = buf.len().min(data_slice.len());
buf[..copy_size].copy_from_slice(&data_slice[..copy_size]);
Ok(copy_size)
}
async fn write_feature_report<'a>(&'a mut self, buf: &'a [u8]) -> HidResult<()> {
let report = self.CreateFeatureReport()?;
{
let mut buffer = report.Data()?;
ensure!(buffer.Length()? as usize >= buf.len(), HidError::message("Feature report is too large"));
let (buffer, remainder) = buffer.as_mut_slice()?.split_at_mut(buf.len());
buffer.copy_from_slice(buf);
remainder.fill(0);
}
self.SendFeatureReportAsync(&report)?.await?;
Ok(())
}
}
impl WinRtBackend {
fn initialize_watcher(&self) -> HidResult<()> {
let _initialized = self.watcher.get_or_try_init(|| {
let watcher = DeviceInformation::CreateWatcherAqsFilter(DEVICE_SELECTOR)?;
watcher.Removed(&TypedEventHandler::new({
let ctx = self.context.clone();
move |_, info: Ref<DeviceInformationUpdate>| {
let info = info.ok()?;
let id = info.Id()?;
ctx.active_readers
.lock()
.unwrap_or_else(PoisonError::into_inner)
.retain(|(reg, rid, channel)| match rid == &id {
true => {
if let Some(channel) = channel.upgrade() {
trace!("Force close channel of reader {}", reg);
channel.close();
}
false
}
false => true
});
ctx.watchers
.lock()
.unwrap_or_else(PoisonError::into_inner)
.retain(|channel| {
channel
.force_send(DeviceEvent::Disconnected(DeviceId::UncPath(id.clone())))
.is_ok()
});
Ok(())
}
}))?;
let enumeration_complete = Arc::new(AtomicBool::new(false));
watcher.Added(&TypedEventHandler::new({
let ctx = self.context.clone();
let enumeration_complete = enumeration_complete.clone();
move |_, info: Ref<DeviceInformation>| {
if !enumeration_complete.load(Ordering::Relaxed) {
return Ok(());
}
let info = info.ok()?;
let id = info.Id()?;
ctx.watchers
.lock()
.unwrap_or_else(PoisonError::into_inner)
.retain(|channel| {
channel
.force_send(DeviceEvent::Connected(DeviceId::UncPath(id.clone())))
.is_ok()
});
Ok(())
}
}))?;
watcher.EnumerationCompleted(&TypedEventHandler::new(move |_, _| {
enumeration_complete.store(true, Ordering::Relaxed);
Ok(())
}))?;
debug!("Starting device watcher");
watcher.Start()?;
Ok::<_, HidError>(watcher)
})?;
Ok(())
}
fn register_active_reader(&self, id: HSTRING, sender: &Sender<HidInputReport>) -> HidResult<u64> {
self.initialize_watcher()?;
let registration = self.context.next_id.fetch_add(1, Ordering::Relaxed);
let mut readers = self
.context
.active_readers
.lock()
.unwrap_or_else(PoisonError::into_inner);
readers.push((registration, id, sender.downgrade()));
trace!(
"Registered active reader with device watcher (id: {}, number of registered readers: {})",
registration,
readers.len()
);
Ok(registration)
}
fn unregister_active_reader(&self, registration: u64) {
let mut readers = self
.context
.active_readers
.lock()
.unwrap_or_else(PoisonError::into_inner);
let count = readers.len();
readers.retain(|(id, _, _)| *id != registration);
if readers.len() == count {
trace!("Reader {} was already removed from the device watcher", registration);
} else {
trace!("Unregistered reader {} from the device watcher", registration);
}
}
fn register_event_watcher(&self, sender: Sender<DeviceEvent>) -> HidResult<()> {
self.initialize_watcher()?;
let mut watchers = self
.context
.watchers
.lock()
.unwrap_or_else(PoisonError::into_inner);
watchers.push(sender);
trace!("Registered new event watcher (total: {})", watchers.len());
Ok(())
}
fn clear_closed_event_watchers(&self) {
let mut watchers = self
.context
.watchers
.lock()
.unwrap_or_else(PoisonError::into_inner);
let count = watchers.len();
watchers.retain(|watcher| !watcher.is_closed());
trace!("Cleared {} event watchers ({} remaining)", count - watchers.len(), watchers.len());
}
}