// Copyright 2020 Shift Cryptosecurity AG
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#[macro_use]
extern crate log;
use futures::prelude::*;
use futures::task::SpawnError;
use hidapi::{HidDevice, HidError};
use std::io;
use std::pin::Pin;
use std::sync::mpsc;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Waker};
use thiserror::Error;
#[cfg(test)]
mod tests {
    /// Placeholder smoke test: confirms the crate's test harness builds and runs.
    #[test]
    fn it_works() {
        let expected = 4;
        assert_eq!(2 + 2, expected);
    }
}
/// Error type of this crate; returned for example by `Device::new`.
///
/// Each variant wraps the underlying error (available via `source()`), and `#[from]` lets `?`
/// convert the wrapped error types automatically.
#[derive(Error, Debug)]
pub enum Error {
    /// An error reported by the `hidapi` library.
    #[error("libhid failed")]
    HidApi(#[from] HidError),
    /// An I/O error.
    #[error("io failed")]
    Io(#[from] io::Error),
    /// A `futures` spawner failed to spawn a task.
    #[error("spawn failed")]
    Spawn(#[from] SpawnError),
}
/// Progress of the read protocol between `poll_read` and the reader thread.
///
/// Derives `Debug`/`PartialEq` so the state can be logged and compared; it is a plain tag,
/// so `Copy` is free.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ReadState {
    /// No read request is outstanding; the next `poll_read` sends one to the reader thread.
    Idle,
    /// A request has been sent or leftover report bytes are buffered; `poll_read` consumes
    /// them before going back to `Idle`.
    Busy,
}
/// State shared by all clones of a `Device`, guarded by a `Mutex`.
struct DeviceInner {
    /// The HID handle. The reader thread holds another `Arc` to it and locks it for each read;
    /// `poll_write` locks it for writing.
    device: Arc<Mutex<HidDevice>>,
    /// Join handle of the reader thread; `None` once it has been taken for joining.
    read_thread: Option<std::thread::JoinHandle<()>>,
    /// Whether `poll_read` is idle or has a request outstanding / bytes buffered.
    rstate: ReadState,
    /// One message per read performed by the reader thread. `None` signals a read that
    /// returned zero bytes.
    data_rx: mpsc::Receiver<Option<[u8; 64]>>,
    /// One waker per requested read. Wrapped in `Option` so it can be taken and dropped during
    /// teardown, which makes the reader thread's `recv` fail and the thread stop.
    req_tx: Option<mpsc::Sender<Waker>>,
    /// Report bytes that did not fit into the caller's buffer on a previous `poll_read`.
    buffer: Option<[u8; 64]>,
    /// Offset into `buffer` from which the next `poll_read` copies.
    buffer_pos: usize,
}
/// Asynchronous handle to a HID device, usable through `AsyncRead` and `AsyncWrite`.
///
/// Cloning yields another handle to the same shared state.
pub struct Device {
    // Stored as an `Option` so that `poll_close` can release it; once `None`, every later read
    // or write on this handle fails with "Cannot poll a closed device".
    inner: Option<Arc<Mutex<DeviceInner>>>,
}
impl Clone for Device {
fn clone(&self) -> Self {
Device {
inner: self.inner.as_ref().map(|dev| Arc::clone(&dev)),
}
}
}
impl Drop for Device {
fn drop(&mut self) {
debug!("dropping hid connection");
if let Some(inner) = self.inner.take() {
if let Ok(mut guard) = inner.lock() {
// Take the waker queue and drop it so that the reader thread finihes
let req_tx = guard.req_tx.take();
drop(req_tx);
// Wait for the reader thread to finish
match guard.read_thread.take() {
Some(jh) => match jh.join() {
Ok(_) => info!("device read thread joined"),
Err(_) => error!("failed to join device read thread"),
},
None => error!("already joined"),
}
} else {
error!("Failed to take lock on device");
}
} else {
error!("there was no inner");
}
}
}
impl Device {
pub fn new(device: HidDevice) -> Result<Self, Error> {
let (data_tx, data_rx) = mpsc::channel();
let (req_tx, req_rx) = mpsc::channel::<Waker>();
// set non-blocking so that we can ignore spurious wakeups.
//device.set_blocking_mode(false);
// Must be accessed from both inner thread and asyn_write
let device = Arc::new(Mutex::new(device));
let jh = std::thread::spawn({
let device = Arc::clone(&device);
move || {
loop {
// Wait for read request
debug!("waiting for request");
let waker = match req_rx.recv() {
Ok(waker) => waker,
Err(_e) => {
info!("No more wakers, shutting down");
return;
}
};
debug!("Got notified");
match device.lock() {
Ok(guard) => {
let mut buf = [0u8; 64];
//match guard.read_timeout(&mut buf[..], 1000) {
match guard.read(&mut buf[..]) {
Err(e) => {
error!("hidapi failed: {}", e);
drop(data_tx);
waker.wake_by_ref();
break;
}
Ok(len) => {
if len == 0 {
data_tx.send(None).unwrap();
waker.wake_by_ref();
continue;
}
debug!("Read data");
if let Err(e) = data_tx.send(Some(buf)) {
error!("Sending internally: {}", e);
break;
}
waker.wake_by_ref();
}
}
}
Err(e) => {
error!("Broken lock: {:?}", e);
return;
}
}
}
}
});
Ok(Device {
inner: Some(Arc::new(Mutex::new(DeviceInner {
device,
read_thread: Some(jh),
rstate: ReadState::Idle,
data_rx,
req_tx: Some(req_tx),
buffer: None,
buffer_pos: 0,
}))),
})
}
}
impl AsyncWrite for Device {
fn poll_write(
mut self: Pin<&mut Self>,
_cx: &mut Context,
mut buf: &[u8],
) -> Poll<Result<usize, io::Error>> {
let len = buf.len();
if self.inner.is_none() {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::InvalidData,
"Cannot poll a closed device",
)));
}
loop {
let max_len = usize::min(64, buf.len());
// The hidapi API requires that you put the report ID in the first byte.
// If you don't use report IDs you must put a 0 there.
let mut buf_with_report_id = [0u8; 1 + 64];
(&mut buf_with_report_id[1..1 + max_len]).copy_from_slice(&buf[..max_len]);
//let this: &mut Self = &mut self;
debug!("Will write {:?}", &buf_with_report_id[..]);
match self.inner.as_mut().unwrap().lock() {
Ok(guard) => match guard.device.lock() {
Ok(guard) => {
guard
.write(&buf_with_report_id[..])
.map_err(|_| io::Error::new(io::ErrorKind::Other, "hidapi failed"))?;
debug!("Wrote: {:?}", &buf[0..max_len]);
}
Err(e) => error!("{:?}", e),
},
Err(e) => {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::Other,
format!("Mutex broken: {:?}", e),
)))
}
}
buf = &buf[max_len..];
if buf.len() == 0 {
debug!("Wrote total {}: {:?}", buf.len(), buf);
return Poll::Ready(Ok(len));
}
}
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Result<(), io::Error>> {
Poll::Ready(Ok(()))
}
// TODO cleanup read thread...
fn poll_close(mut self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Result<(), io::Error>> {
let this: &mut Self = &mut self;
// take the device and drop it
let _device = this.inner.take();
Poll::Ready(Ok(()))
}
}
/// Reads 64-byte HID input reports from the device.
///
/// Every device read yields a full 64-byte report. Bytes that do not fit into the caller's
/// buffer are kept and returned by the following `poll_read` calls before a new report is
/// requested. Read out complete reports to avoid trailing bytes of one report showing up in
/// the next readout.
impl AsyncRead for Device {
    /// Copies bytes from the buffered or next report into `buf` and returns how many were
    /// copied.
    ///
    /// Returns `Ok(0)` immediately for an empty `buf`. Otherwise, if no bytes are buffered, it
    /// sends a read request carrying this task's waker to the reader thread. It returns
    /// `Pending` until the report arrives.
    ///
    /// NOTE: the waker is captured only when a request is sent. If a read is already in flight
    /// and this handle is then polled from a different task, that task is not woken.
    ///
    /// # Errors
    ///
    /// - `InvalidData` if this handle has been closed or the request channel has been torn
    ///   down.
    /// - `Other` if the shared-state mutex is poisoned or the reader thread has stopped.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context,
        buf: &mut [u8],
    ) -> Poll<Result<usize, io::Error>> {
        let inner = match self.inner.as_ref() {
            Some(inner) => inner,
            None => {
                return Poll::Ready(Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    "Cannot poll a closed device",
                )))
            }
        };
        if buf.is_empty() {
            // Nothing can be copied; do not trigger a device read.
            return Poll::Ready(Ok(0));
        }
        let mut this = inner.lock().map_err(|e| {
            io::Error::new(io::ErrorKind::Other, format!("Mutex broken: {:?}", e))
        })?;
        loop {
            match this.rstate {
                ReadState::Idle => {
                    debug!("Sending waker");
                    let req_tx = match this.req_tx.as_ref() {
                        Some(req_tx) => req_tx,
                        None => {
                            return Poll::Ready(Err(io::Error::new(
                                io::ErrorKind::InvalidData,
                                "Failed internal send",
                            )))
                        }
                    };
                    if req_tx.send(cx.waker().clone()).is_err() {
                        // The reader thread has exited, so no report will ever arrive.
                        error!("failed to send waker");
                        return Poll::Ready(Err(io::Error::new(
                            io::ErrorKind::Other,
                            "Inner channel dead",
                        )));
                    }
                    this.rstate = ReadState::Busy;
                }
                ReadState::Busy => {
                    // First hand out bytes left over from the previous report.
                    if let Some(report) = this.buffer.take() {
                        let pending = &report[this.buffer_pos..];
                        let len = usize::min(buf.len(), pending.len());
                        buf[..len].copy_from_slice(&pending[..len]);
                        if len < pending.len() {
                            this.buffer = Some(report);
                            this.buffer_pos += len;
                        } else {
                            this.rstate = ReadState::Idle;
                        }
                        return Poll::Ready(Ok(len));
                    }

                    // Then check whether the reader thread has delivered a new report.
                    let received = this.data_rx.try_recv();
                    let report = match received {
                        Ok(Some(report)) => report,
                        Ok(None) => {
                            // Zero-length read: the waker was consumed without producing data.
                            // Issue a new request, otherwise nothing would ever wake us again.
                            this.rstate = ReadState::Idle;
                            continue;
                        }
                        Err(mpsc::TryRecvError::Empty) => return Poll::Pending,
                        Err(mpsc::TryRecvError::Disconnected) => {
                            return Poll::Ready(Err(io::Error::new(
                                io::ErrorKind::Other,
                                "Inner channel dead",
                            )))
                        }
                    };
                    debug!("Read data {:?}", &report[..]);
                    let len = usize::min(report.len(), buf.len());
                    buf[..len].copy_from_slice(&report[..len]);
                    if len < report.len() {
                        // Remember where the caller stopped so the next call resumes there
                        // instead of re-delivering bytes already returned.
                        this.buffer = Some(report);
                        this.buffer_pos = len;
                    } else {
                        this.rstate = ReadState::Idle;
                    }
                    debug!("returning {}", len);
                    return Poll::Ready(Ok(len));
                }
            }
        }
    }
}