// vlfd_rs/usb.rs

1use crate::error::{Error, Result};
2use nusb::{
3    self, Device, DeviceId, DeviceInfo, Interface, MaybeFuture,
4    transfer::{Bulk, In, Out},
5};
6use std::{
7    io::{Read, Write},
8    sync::{
9        Arc,
10        atomic::{AtomicBool, Ordering},
11    },
12    thread,
13    time::Duration,
14};
15
// `words_as_bytes` / `words_as_bytes_mut` further down expose `u16` buffers in host byte
// order; presumably the device protocol is little-endian, hence big-endian builds are refused.
#[cfg(target_endian = "big")]
compile_error!("vlfd-rs currently supports little-endian hosts only");

/// Interface number claimed by `UsbDevice::open`.
const INTERFACE: u8 = 0;
/// Pause between two device enumerations in the hotplug polling thread.
const HOTPLUG_POLL_INTERVAL: Duration = Duration::from_millis(100);
/// Buffer size handed to the nusb endpoint reader/writer built for each bulk transfer (16 KiB).
const IO_BUFFER_SIZE: usize = 16 * 1024;
22
/// Tunables that control how a USB device is opened and how long transfers may take.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TransportConfig {
    /// Timeout applied to bulk reads and writes issued through `UsbDevice`.
    pub usb_timeout: Duration,
    /// Timeout reserved for sync transfers (not consumed anywhere in this file).
    pub sync_timeout: Duration,
    /// When `true`, `UsbDevice::open` resets the device right after opening it.
    pub reset_on_open: bool,
    /// When `true`, `UsbDevice::open` clears the halt condition on every known endpoint.
    pub clear_halt_on_open: bool,
}

impl Default for TransportConfig {
    /// One-second timeouts, no reset on open, halt clearing enabled.
    fn default() -> Self {
        let one_second = Duration::from_secs(1);
        Self {
            usb_timeout: one_second,
            sync_timeout: one_second,
            reset_on_open: false,
            clear_halt_on_open: true,
        }
    }
}
41
/// Bulk endpoints used by this transport, identified by their USB endpoint address.
///
/// Every variant is converted with `as u8` before being handed to nusb, so the
/// representation is pinned to `u8`. The two low addresses are opened as bulk OUT
/// endpoints and the two with bit `0x80` set as bulk IN endpoints.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Endpoint {
    /// Bulk OUT endpoint at address `0x02`.
    FifoWrite = 0x02,
    /// Bulk OUT endpoint at address `0x04`.
    Command = 0x04,
    /// Bulk IN endpoint at address `0x86`.
    FifoRead = 0x86,
    /// Bulk IN endpoint at address `0x88`.
    Sync = 0x88,
}
49
/// Direction of a hotplug change reported to a watcher callback.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HotplugEventKind {
    /// A matching device is present that was not present in the previous poll
    /// (or, with `HotplugOptions::enumerate`, was present at registration time).
    Arrived,
    /// A device seen in the previous poll is no longer enumerated.
    Left,
}
55
/// Snapshot of a device's identifying attributes, carried inside a `HotplugEvent`.
#[derive(Debug, Clone)]
pub struct HotplugDeviceInfo {
    /// Bus number; populated only on Linux and `0` on other platforms.
    pub bus_number: u8,
    /// Device address as reported by nusb.
    pub address: u8,
    /// Port chain of the device; empty on platforms other than Linux, macOS and Windows.
    pub port_numbers: Vec<u8>,
    /// USB vendor id (always `Some` when built by `from_device_info`).
    pub vendor_id: Option<u16>,
    /// USB product id (always `Some` when built by `from_device_info`).
    pub product_id: Option<u16>,
    /// Device class code (always `Some` when built by `from_device_info`).
    pub class_code: Option<u8>,
    /// Device subclass code (always `Some` when built by `from_device_info`).
    pub sub_class_code: Option<u8>,
    /// Device protocol code (always `Some` when built by `from_device_info`).
    pub protocol_code: Option<u8>,
}
67
68impl HotplugDeviceInfo {
69    fn from_device_info(device: &DeviceInfo) -> Self {
70        Self {
71            #[cfg(target_os = "linux")]
72            bus_number: device.busnum(),
73            #[cfg(not(target_os = "linux"))]
74            bus_number: 0,
75            address: device.device_address(),
76            #[cfg(any(target_os = "linux", target_os = "macos", target_os = "windows"))]
77            port_numbers: device.port_chain().to_vec(),
78            #[cfg(not(any(target_os = "linux", target_os = "macos", target_os = "windows")))]
79            port_numbers: Vec::new(),
80            vendor_id: Some(device.vendor_id()),
81            product_id: Some(device.product_id()),
82            class_code: Some(device.class()),
83            sub_class_code: Some(device.subclass()),
84            protocol_code: Some(device.protocol()),
85        }
86    }
87}
88
/// Notification delivered to a hotplug callback.
#[derive(Debug, Clone)]
pub struct HotplugEvent {
    /// Whether the device arrived or left.
    pub kind: HotplugEventKind,
    /// Attributes of the device the event refers to.
    pub device: HotplugDeviceInfo,
}
94
/// Filter and behavior switches for a hotplug watcher.
///
/// A `None` filter matches every device; all given filters must match.
#[derive(Debug, Clone, Copy, Default)]
pub struct HotplugOptions {
    /// Only report devices with this vendor id.
    pub vendor_id: Option<u16>,
    /// Only report devices with this product id.
    pub product_id: Option<u16>,
    /// Only report devices with this device class code.
    pub class_code: Option<u8>,
    /// When `true`, devices already present at registration time are reported
    /// as `Arrived` before polling starts.
    pub enumerate: bool,
}
102
103#[derive(Debug, Clone, Default)]
104pub struct Probe {
105    transport: TransportConfig,
106}
107
108impl Probe {
109    pub fn new() -> Self {
110        Self::default()
111    }
112
113    pub fn with_transport_config(transport: TransportConfig) -> Self {
114        Self { transport }
115    }
116
117    pub fn transport_config(&self) -> &TransportConfig {
118        &self.transport
119    }
120
121    pub fn watch<F>(&self, options: HotplugOptions, callback: F) -> Result<HotplugRegistration>
122    where
123        F: FnMut(HotplugEvent) + Send + 'static,
124    {
125        UsbDevice::with_transport_config(self.transport)?
126            .register_hotplug_callback(options, callback)
127    }
128}
129
/// A USB device connection that is either closed (both options `None`) or open.
pub struct UsbDevice {
    /// Opened nusb device; set by `open`, released by `close`.
    handle: Option<Device>,
    /// Claimed interface; its presence is what `is_open` reports.
    interface: Option<Interface>,
    /// Configuration applied when opening and when performing transfers.
    transport: TransportConfig,
}
135
136impl UsbDevice {
137    pub fn with_transport_config(transport: TransportConfig) -> Result<Self> {
138        Ok(Self {
139            handle: None,
140            interface: None,
141            transport,
142        })
143    }
144
    /// Returns `true` while an interface is claimed, i.e. after a successful `open`
    /// and before `close`.
    pub fn is_open(&self) -> bool {
        self.interface.is_some()
    }
148
    /// Returns the transport configuration used by this device.
    pub fn transport_config(&self) -> &TransportConfig {
        &self.transport
    }
152
153    pub fn open(&mut self, vid: u16, pid: u16) -> Result<()> {
154        if self.is_open() {
155            return Ok(());
156        }
157
158        let device_info = nusb::list_devices()
159            .wait()
160            .map_err(|err| usb_error(err, "nusb_list_devices"))?
161            .find(|device| device.vendor_id() == vid && device.product_id() == pid)
162            .ok_or(Error::DeviceNotFound { vid, pid })?;
163
164        let device = device_info
165            .open()
166            .wait()
167            .map_err(|err| usb_error(err, "nusb_open_device"))?;
168
169        if self.transport.reset_on_open {
170            device
171                .reset()
172                .wait()
173                .map_err(|err| usb_error(err, "nusb_reset_device"))?;
174        }
175
176        let interface = device
177            .detach_and_claim_interface(INTERFACE)
178            .wait()
179            .map_err(|err| usb_error(err, "nusb_claim_interface"))?;
180
181        let mut usb_device = Self {
182            handle: Some(device),
183            interface: Some(interface),
184            transport: self.transport,
185        };
186
187        if usb_device.transport.clear_halt_on_open {
188            usb_device.clear_halt_all()?;
189        }
190
191        *self = usb_device;
192        Ok(())
193    }
194
195    pub fn close(&mut self) -> Result<()> {
196        self.interface.take();
197        self.handle.take();
198        Ok(())
199    }
200
201    pub fn read_bytes(&self, endpoint: Endpoint, buffer: &mut [u8]) -> Result<()> {
202        let interface = self.interface.as_ref().ok_or(Error::DeviceNotOpen)?;
203        bulk_read(interface, endpoint, buffer, self.transport.usb_timeout)
204    }
205
206    pub fn read_words(&self, endpoint: Endpoint, buffer: &mut [u16]) -> Result<()> {
207        let raw = words_as_bytes_mut(buffer);
208        self.read_bytes(endpoint, raw)
209    }
210
211    pub fn write_bytes(&self, endpoint: Endpoint, buffer: &[u8]) -> Result<()> {
212        let interface = self.interface.as_ref().ok_or(Error::DeviceNotOpen)?;
213        bulk_write(interface, endpoint, buffer, self.transport.usb_timeout)
214    }
215
216    pub fn write_words(&self, endpoint: Endpoint, buffer: &[u16]) -> Result<()> {
217        let raw = words_as_bytes(buffer);
218        self.write_bytes(endpoint, raw)
219    }
220
221    pub fn open_in_endpoint(&self, endpoint: Endpoint) -> Result<nusb::Endpoint<Bulk, In>> {
222        let interface = self.interface.as_ref().ok_or(Error::DeviceNotOpen)?;
223        interface
224            .endpoint::<Bulk, In>(endpoint as u8)
225            .map_err(|err| usb_error(err, "nusb_open_in_endpoint"))
226    }
227
228    pub fn open_out_endpoint(&self, endpoint: Endpoint) -> Result<nusb::Endpoint<Bulk, Out>> {
229        let interface = self.interface.as_ref().ok_or(Error::DeviceNotOpen)?;
230        interface
231            .endpoint::<Bulk, Out>(endpoint as u8)
232            .map_err(|err| usb_error(err, "nusb_open_out_endpoint"))
233    }
234
235    pub fn register_hotplug_callback<F>(
236        &self,
237        options: HotplugOptions,
238        mut callback: F,
239    ) -> Result<HotplugRegistration>
240    where
241        F: FnMut(HotplugEvent) + Send + 'static,
242    {
243        let mut seen_devices = Vec::<(DeviceId, HotplugDeviceInfo)>::new();
244        let initial_devices = matching_devices(options)?;
245        if options.enumerate {
246            for device in &initial_devices {
247                callback(HotplugEvent {
248                    kind: HotplugEventKind::Arrived,
249                    device: HotplugDeviceInfo::from_device_info(device),
250                });
251            }
252        }
253        seen_devices.extend(
254            initial_devices
255                .iter()
256                .map(|device| (device.id(), HotplugDeviceInfo::from_device_info(device))),
257        );
258
259        let running = Arc::new(AtomicBool::new(true));
260        let thread_running = Arc::clone(&running);
261        let thread = thread::Builder::new()
262            .name("vlfd-usb-hotplug".into())
263            .spawn(move || {
264                let mut known = seen_devices;
265                while thread_running.load(Ordering::Relaxed) {
266                    if let Ok(devices) = matching_devices(options) {
267                        let mut current = devices
268                            .iter()
269                            .map(|device| {
270                                (device.id(), HotplugDeviceInfo::from_device_info(device))
271                            })
272                            .collect::<Vec<_>>();
273
274                        for (id, info) in &current {
275                            if !known.iter().any(|(known_id, _)| known_id == id) {
276                                callback(HotplugEvent {
277                                    kind: HotplugEventKind::Arrived,
278                                    device: info.clone(),
279                                });
280                            }
281                        }
282
283                        for (id, info) in &known {
284                            if !current.iter().any(|(current_id, _)| current_id == id) {
285                                callback(HotplugEvent {
286                                    kind: HotplugEventKind::Left,
287                                    device: info.clone(),
288                                });
289                            }
290                        }
291
292                        known.clear();
293                        known.append(&mut current);
294                    }
295
296                    thread::sleep(HOTPLUG_POLL_INTERVAL);
297                }
298            })
299            .map_err(Error::Io)?;
300
301        Ok(HotplugRegistration {
302            running,
303            thread: Some(thread),
304        })
305    }
306
307    pub(crate) fn clear_halt_all(&mut self) -> Result<()> {
308        for endpoint in [
309            Endpoint::FifoWrite,
310            Endpoint::Command,
311            Endpoint::FifoRead,
312            Endpoint::Sync,
313        ] {
314            self.clear_halt(endpoint)?;
315        }
316        Ok(())
317    }
318
319    fn clear_halt(&mut self, endpoint: Endpoint) -> Result<()> {
320        let interface = self.interface.as_ref().ok_or(Error::DeviceNotOpen)?;
321        match endpoint {
322            Endpoint::FifoWrite | Endpoint::Command => {
323                let mut ep = interface
324                    .endpoint::<Bulk, Out>(endpoint as u8)
325                    .map_err(|err| usb_error(err, "nusb_open_out_endpoint"))?;
326                ep.clear_halt()
327                    .wait()
328                    .map_err(|err| usb_error(err, "nusb_clear_halt"))?;
329            }
330            Endpoint::FifoRead | Endpoint::Sync => {
331                let mut ep = interface
332                    .endpoint::<Bulk, In>(endpoint as u8)
333                    .map_err(|err| usb_error(err, "nusb_open_in_endpoint"))?;
334                ep.clear_halt()
335                    .wait()
336                    .map_err(|err| usb_error(err, "nusb_clear_halt"))?;
337            }
338        }
339        Ok(())
340    }
341}
342
/// Closes the device when the value goes out of scope.
impl Drop for UsbDevice {
    fn drop(&mut self) {
        // Best effort: `drop` cannot report an error, so the result is discarded.
        let _ = self.close();
    }
}
348
/// Guard that keeps a hotplug polling thread alive.
///
/// Dropping the guard clears the shared `running` flag and joins the thread, so the
/// drop blocks until the thread has observed the flag and exited. Because of that,
/// a registration that is not stored stops the watcher immediately; the type is
/// therefore marked `#[must_use]`.
#[derive(Debug)]
#[must_use = "dropping the registration immediately stops hotplug notifications"]
pub struct HotplugRegistration {
    /// Flag polled by the worker thread; `false` asks it to exit.
    running: Arc<AtomicBool>,
    /// Worker thread handle; taken (and joined) exactly once on drop.
    thread: Option<thread::JoinHandle<()>>,
}

impl Drop for HotplugRegistration {
    fn drop(&mut self) {
        self.running.store(false, Ordering::SeqCst);
        if let Some(handle) = self.thread.take() {
            // A panic inside the worker (e.g. from the user callback) is deliberately
            // not propagated out of `drop`.
            let _ = handle.join();
        }
    }
}
363
364fn bulk_read(
365    interface: &Interface,
366    endpoint: Endpoint,
367    buffer: &mut [u8],
368    timeout: Duration,
369) -> Result<()> {
370    let mut reader = interface
371        .endpoint::<Bulk, In>(endpoint as u8)
372        .map_err(|err| usb_error(err, "nusb_open_in_endpoint"))?
373        .reader(IO_BUFFER_SIZE)
374        .with_read_timeout(timeout);
375
376    reader
377        .read_exact(buffer)
378        .map_err(|err| io_error(err, "nusb_bulk_read"))?;
379    Ok(())
380}
381
382fn bulk_write(
383    interface: &Interface,
384    endpoint: Endpoint,
385    buffer: &[u8],
386    timeout: Duration,
387) -> Result<()> {
388    let mut writer = interface
389        .endpoint::<Bulk, Out>(endpoint as u8)
390        .map_err(|err| usb_error(err, "nusb_open_out_endpoint"))?
391        .writer(IO_BUFFER_SIZE)
392        .with_write_timeout(timeout);
393
394    writer
395        .write_all(buffer)
396        .map_err(|err| io_error(err, "nusb_bulk_write"))?;
397    writer
398        .flush()
399        .map_err(|err| io_error(err, "nusb_bulk_flush"))?;
400    Ok(())
401}
402
403fn matching_devices(options: HotplugOptions) -> Result<Vec<DeviceInfo>> {
404    let devices = nusb::list_devices()
405        .wait()
406        .map_err(|err| usb_error(err, "nusb_list_devices"))?;
407    Ok(devices
408        .filter(|device| {
409            options
410                .vendor_id
411                .is_none_or(|vendor_id| device.vendor_id() == vendor_id)
412                && options
413                    .product_id
414                    .is_none_or(|product_id| device.product_id() == product_id)
415                && options
416                    .class_code
417                    .is_none_or(|class_code| device.class() == class_code)
418        })
419        .collect())
420}
421
/// Views a `u16` slice as its underlying bytes, in host byte order, without copying.
fn words_as_bytes(words: &[u16]) -> &[u8] {
    let byte_len = std::mem::size_of_val(words);
    let start = words.as_ptr().cast::<u8>();
    // SAFETY: `start` points to `byte_len` initialized bytes owned by `words`, `u8` has
    // alignment 1 and no invalid bit patterns, and the returned slice borrows `words`
    // for the same lifetime, so no aliasing mutable access can exist.
    unsafe { std::slice::from_raw_parts(start, byte_len) }
}
425
/// Views a mutable `u16` slice as its underlying bytes, in host byte order, without copying.
fn words_as_bytes_mut(words: &mut [u16]) -> &mut [u8] {
    let byte_len = std::mem::size_of_val(words);
    let start = words.as_mut_ptr().cast::<u8>();
    // SAFETY: `start` points to `byte_len` initialized bytes exclusively borrowed through
    // `words`, `u8` has alignment 1, every byte pattern is a valid `u16`, and the returned
    // slice carries the same exclusive borrow, so `words` cannot be used while it lives.
    unsafe { std::slice::from_raw_parts_mut(start, byte_len) }
}
431
/// Wraps a nusb error in `Error::Usb`, tagging it with the name of the failing step.
fn usb_error(err: nusb::Error, context: &'static str) -> Error {
    Error::Usb {
        source: Box::new(err),
        context,
    }
}
438
439fn io_error(err: std::io::Error, context: &'static str) -> Error {
440    if err.kind() == std::io::ErrorKind::TimedOut {
441        Error::Timeout(context)
442    } else {
443        Error::Usb {
444            source: Box::new(err),
445            context,
446        }
447    }
448}
449
#[cfg(test)]
mod tests {
    use super::TransportConfig;
    use std::time::Duration;

    /// Pins the default configuration: one-second timeouts, no reset, halts cleared.
    #[test]
    fn default_transport_config_prefers_stable_open_behavior() {
        let expected = TransportConfig {
            usb_timeout: Duration::from_millis(1_000),
            sync_timeout: Duration::from_secs(1),
            reset_on_open: false,
            clear_halt_on_open: true,
        };
        assert_eq!(TransportConfig::default(), expected);
    }
}