hidapi_async/
lib.rs

1// Copyright 2020 Shift Cryptosecurity AG
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#[macro_use]
16extern crate log;
17
18use futures::prelude::*;
19use futures::task::SpawnError;
20use hidapi::{HidDevice, HidError};
21use std::io;
22use std::pin::Pin;
23use std::sync::mpsc;
24use std::sync::{Arc, Mutex};
25use std::task::{Context, Poll, Waker};
26use thiserror::Error;
27
#[cfg(test)]
mod tests {
    /// Placeholder check that the test harness is wired up.
    #[test]
    fn it_works() {
        let sum = 2 + 2;
        assert_eq!(sum, 4);
    }
}
35
36#[derive(Error, Debug)]
37pub enum Error {
38    #[error("libhid failed")]
39    HidApi(#[from] HidError),
40    #[error("io failed")]
41    Io(#[from] io::Error),
42    #[error("spawn failed")]
43    Spawn(#[from] SpawnError),
44}
45
/// Progress of the read side of a `Device`.
enum ReadState {
    /// No read request is outstanding; the next `poll_read` issues one.
    Idle,
    /// A read request has been handed to the reader thread, or bytes of the
    /// last report are still waiting to be returned to the caller.
    Busy,
}
50
51struct DeviceInner {
52    device: Arc<Mutex<HidDevice>>,
53    read_thread: Option<std::thread::JoinHandle<()>>,
54    rstate: ReadState,
55    data_rx: mpsc::Receiver<Option<[u8; 64]>>, // One message per read
56    req_tx: Option<mpsc::Sender<Waker>>,       // One message per expected read
57    buffer: Option<[u8; 64]>,
58    buffer_pos: usize,
59}
60
61pub struct Device {
62    // store an Option so that `close` works
63    inner: Option<Arc<Mutex<DeviceInner>>>,
64}
65
66impl Clone for Device {
67    fn clone(&self) -> Self {
68        Device {
69            inner: self.inner.as_ref().map(|dev| Arc::clone(&dev)),
70        }
71    }
72}
73
74impl Drop for Device {
75    fn drop(&mut self) {
76        debug!("dropping hid connection");
77        if let Some(inner) = self.inner.take() {
78            if let Ok(mut guard) = inner.lock() {
79                // Take the waker queue and drop it so that the reader thread finihes
80                let req_tx = guard.req_tx.take();
81                drop(req_tx);
82
83                // Wait for the reader thread to finish
84                match guard.read_thread.take() {
85                    Some(jh) => match jh.join() {
86                        Ok(_) => info!("device read thread joined"),
87                        Err(_) => error!("failed to join device read thread"),
88                    },
89                    None => error!("already joined"),
90                }
91            } else {
92                error!("Failed to take lock on device");
93            }
94        } else {
95            error!("there was no inner");
96        }
97    }
98}
99
100impl Device {
101    pub fn new(device: HidDevice) -> Result<Self, Error> {
102        let (data_tx, data_rx) = mpsc::channel();
103        let (req_tx, req_rx) = mpsc::channel::<Waker>();
104        // set non-blocking so that we can ignore spurious wakeups.
105        //device.set_blocking_mode(false);
106        // Must be accessed from both inner thread and asyn_write
107        let device = Arc::new(Mutex::new(device));
108        let jh = std::thread::spawn({
109            let device = Arc::clone(&device);
110            move || {
111                loop {
112                    // Wait for read request
113                    debug!("waiting for request");
114                    let waker = match req_rx.recv() {
115                        Ok(waker) => waker,
116                        Err(_e) => {
117                            info!("No more wakers, shutting down");
118                            return;
119                        }
120                    };
121                    debug!("Got notified");
122                    match device.lock() {
123                        Ok(guard) => {
124                            let mut buf = [0u8; 64];
125                            //match guard.read_timeout(&mut buf[..], 1000) {
126                            match guard.read(&mut buf[..]) {
127                                Err(e) => {
128                                    error!("hidapi failed: {}", e);
129                                    drop(data_tx);
130                                    waker.wake_by_ref();
131                                    break;
132                                }
133                                Ok(len) => {
134                                    if len == 0 {
135                                        data_tx.send(None).unwrap();
136                                        waker.wake_by_ref();
137                                        continue;
138                                    }
139                                    debug!("Read data");
140                                    if let Err(e) = data_tx.send(Some(buf)) {
141                                        error!("Sending internally: {}", e);
142                                        break;
143                                    }
144                                    waker.wake_by_ref();
145                                }
146                            }
147                        }
148                        Err(e) => {
149                            error!("Broken lock: {:?}", e);
150                            return;
151                        }
152                    }
153                }
154            }
155        });
156        Ok(Device {
157            inner: Some(Arc::new(Mutex::new(DeviceInner {
158                device,
159                read_thread: Some(jh),
160                rstate: ReadState::Idle,
161                data_rx,
162                req_tx: Some(req_tx),
163                buffer: None,
164                buffer_pos: 0,
165            }))),
166        })
167    }
168}
169
170impl AsyncWrite for Device {
171    fn poll_write(
172        mut self: Pin<&mut Self>,
173        _cx: &mut Context,
174        mut buf: &[u8],
175    ) -> Poll<Result<usize, io::Error>> {
176        let len = buf.len();
177        if self.inner.is_none() {
178            return Poll::Ready(Err(io::Error::new(
179                io::ErrorKind::InvalidData,
180                "Cannot poll a closed device",
181            )));
182        }
183        loop {
184            let max_len = usize::min(64, buf.len());
185            // The hidapi API requires that you put the report ID in the first byte.
186            // If you don't use report IDs you must put a 0 there.
187            let mut buf_with_report_id = [0u8; 1 + 64];
188            (&mut buf_with_report_id[1..1 + max_len]).copy_from_slice(&buf[..max_len]);
189
190            //let this: &mut Self = &mut self;
191            debug!("Will write {:?}", &buf_with_report_id[..]);
192            match self.inner.as_mut().unwrap().lock() {
193                Ok(guard) => match guard.device.lock() {
194                    Ok(guard) => {
195                        guard
196                            .write(&buf_with_report_id[..])
197                            .map_err(|_| io::Error::new(io::ErrorKind::Other, "hidapi failed"))?;
198                        debug!("Wrote: {:?}", &buf[0..max_len]);
199                    }
200                    Err(e) => error!("{:?}", e),
201                },
202                Err(e) => {
203                    return Poll::Ready(Err(io::Error::new(
204                        io::ErrorKind::Other,
205                        format!("Mutex broken: {:?}", e),
206                    )))
207                }
208            }
209            buf = &buf[max_len..];
210            if buf.len() == 0 {
211                debug!("Wrote total {}: {:?}", buf.len(), buf);
212                return Poll::Ready(Ok(len));
213            }
214        }
215    }
216    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Result<(), io::Error>> {
217        Poll::Ready(Ok(()))
218    }
219    // TODO cleanup read thread...
220    fn poll_close(mut self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Result<(), io::Error>> {
221        let this: &mut Self = &mut self;
222        // take the device and drop it
223        let _device = this.inner.take();
224        Poll::Ready(Ok(()))
225    }
226}
227
228// Will always read out 64 bytes. Make sure to read out all bytes to avoid trailing bytes in next
229// readout.
230// Will store all bytes that did not fit in provided buffer and give them next time.
231impl AsyncRead for Device {
232    fn poll_read(
233        mut self: Pin<&mut Self>,
234        cx: &mut Context,
235        buf: &mut [u8],
236    ) -> Poll<Result<usize, io::Error>> {
237        if self.inner.is_none() {
238            return Poll::Ready(Err(io::Error::new(
239                io::ErrorKind::InvalidData,
240                "Cannot poll a closed device",
241            )));
242        }
243        let mut this =
244            self.inner.as_mut().unwrap().lock().map_err(|e| {
245                io::Error::new(io::ErrorKind::Other, format!("Mutex broken: {:?}", e))
246            })?;
247        loop {
248            let waker = cx.waker().clone();
249            match this.rstate {
250                ReadState::Idle => {
251                    debug!("Sending waker");
252                    if let Some(req_tx) = &mut this.req_tx {
253                        if let Err(_e) = req_tx.send(waker) {
254                            error!("failed to send waker");
255                        }
256                    } else {
257                        return Poll::Ready(Err(io::Error::new(
258                            io::ErrorKind::InvalidData,
259                            "Failed internal send",
260                        )));
261                    }
262                    this.rstate = ReadState::Busy;
263                }
264                ReadState::Busy => {
265                    // First send any bytes from the previous readout
266                    if let Some(inner_buf) = this.buffer.take() {
267                        let len = usize::min(buf.len(), inner_buf.len());
268                        let inner_slice = &inner_buf[this.buffer_pos..this.buffer_pos + len];
269                        let buf_slice = &mut buf[..len];
270                        buf_slice.copy_from_slice(inner_slice);
271                        // Check if there is more data left
272                        if this.buffer_pos + inner_slice.len() < inner_buf.len() {
273                            this.buffer = Some(inner_buf);
274                            this.buffer_pos += inner_slice.len();
275                        } else {
276                            this.rstate = ReadState::Idle;
277                        }
278                        return Poll::Ready(Ok(len));
279                    }
280
281                    // Second try to receive more bytes
282                    let vec = match this.data_rx.try_recv() {
283                        Ok(Some(vec)) => vec,
284                        Ok(None) => {
285                            // end of stream?
286                            return Poll::Pending;
287                        }
288                        Err(e) => match e {
289                            mpsc::TryRecvError::Disconnected => {
290                                return Poll::Ready(Err(io::Error::new(
291                                    io::ErrorKind::Other,
292                                    format!("Inner channel dead"),
293                                )));
294                            }
295                            mpsc::TryRecvError::Empty => {
296                                return Poll::Pending;
297                            }
298                        },
299                    };
300                    debug!("Read data {:?}", &vec[..]);
301                    let len = usize::min(vec.len(), buf.len());
302                    let buf_slice = &mut buf[..len];
303                    let vec_slice = &vec[..len];
304                    buf_slice.copy_from_slice(vec_slice);
305                    if len < vec.len() {
306                        // If bytes did not fit in buf, store bytes for next readout
307                        this.buffer = Some(vec);
308                        this.buffer_pos = 0;
309                    } else {
310                        this.rstate = ReadState::Idle;
311                    }
312                    debug!("returning {}", len);
313                    return Poll::Ready(Ok(len));
314                }
315            };
316        }
317    }
318}