aggligator_transport_usb/
lib.rs

1#![warn(missing_docs)]
2#![cfg_attr(docsrs, feature(doc_cfg))]
3#![doc(
4    html_logo_url = "https://raw.githubusercontent.com/surban/aggligator/master/.misc/aggligator.png",
5    html_favicon_url = "https://raw.githubusercontent.com/surban/aggligator/master/.misc/aggligator.png",
6    issue_tracker_base_url = "https://github.com/surban/aggligator/issues/"
7)]
8
9//! [Aggligator](aggligator) transport: USB
10//!
11//! This uses a [USB packet channel (UPC)](upc) to encapsulate data over a USB connection.
12
13use std::time::Duration;
14
/// Transport name returned by the link tags and transports defined in this crate.
static NAME: &str = "usb";
/// Timeout passed to the USB string descriptor requests issued while probing a device.
const TIMEOUT: Duration = Duration::from_secs(1);
17
18#[cfg(feature = "host")]
19#[cfg_attr(docsrs, doc(cfg(feature = "host")))]
20mod host {
21    use aggligator::io::{StreamBox, TxRxBox};
22    use async_trait::async_trait;
23    use futures::{FutureExt, StreamExt};
24    use std::{
25        any::Any,
26        cmp::Ordering,
27        collections::HashSet,
28        fmt,
29        hash::{Hash, Hasher},
30        io::{Error, ErrorKind, Result},
31        time::Duration,
32    };
33    use tokio::{
34        sync::{watch, Mutex},
35        time::sleep,
36    };
37
38    use aggligator::{
39        control::Direction,
40        transport::{ConnectingTransport, LinkTag, LinkTagBox},
41    };
42
43    use super::{NAME, TIMEOUT};
44
    /// Idle time after which the USB devices are enumerated again, unless a hotplug
    /// event triggers the enumeration earlier.
    const PROBE_INTERVAL: Duration = Duration::from_secs(3);
46
    /// Link tag for outgoing USB link.
    ///
    /// Identifies one interface of one USB device by the bus id and address of the
    /// device together with the interface number.
    // The derived ordering and hash use the fields in declaration order; the
    // `LinkTag` implementation forwards `dyn_cmp` and `dyn_hash` to them.
    #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
    pub struct OutgoingUsbLinkTag {
        /// Bus id of the USB device.
        pub bus_id: String,
        /// Address of the USB device.
        pub address: u8,
        /// Number of the interface on the USB device.
        pub interface: u8,
    }
57
58    impl fmt::Display for OutgoingUsbLinkTag {
59        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
60            write!(f, "USB {} -> {}:{}", self.bus_id, self.address, self.interface)
61        }
62    }
63
64    impl LinkTag for OutgoingUsbLinkTag {
65        fn transport_name(&self) -> &str {
66            NAME
67        }
68
69        fn direction(&self) -> Direction {
70            Direction::Outgoing
71        }
72
73        fn user_data(&self) -> Vec<u8> {
74            Vec::new()
75        }
76
77        fn as_any(&self) -> &dyn Any {
78            self
79        }
80
81        fn box_clone(&self) -> LinkTagBox {
82            Box::new(self.clone())
83        }
84
85        fn dyn_cmp(&self, other: &dyn LinkTag) -> Ordering {
86            let other = other.as_any().downcast_ref::<Self>().unwrap();
87            Ord::cmp(self, other)
88        }
89
90        fn dyn_hash(&self, mut state: &mut dyn Hasher) {
91            Hash::hash(self, &mut state)
92        }
93    }
94
    /// USB device information.
    ///
    /// Passed to the filter function of [`UsbConnector`] for each probed device.
    /// The string fields are read from the device's string descriptors and are
    /// `None` when the device has no such descriptor or reading it failed.
    #[derive(Debug, Clone, PartialEq, Eq)]
    #[non_exhaustive]
    pub struct DeviceInfo {
        /// Bus id.
        pub bus_id: String,
        /// Bus number. (Linux and Android only)
        #[cfg(any(target_os = "linux", target_os = "android"))]
        pub busnum: u8,
        /// Device address.
        pub address: u8,
        /// USB port number chain.
        pub port_numbers: Vec<u8>,
        /// Vendor id.
        pub vendor_id: u16,
        /// Product id.
        pub product_id: u16,
        /// Device class code.
        pub class_code: u8,
        /// Device sub class code.
        pub sub_class_code: u8,
        /// Device protocol code.
        pub protocol_code: u8,
        /// Device version.
        pub version: u16,
        /// USB version.
        pub usb_version: u16,
        /// Manufacturer string, if it could be read.
        pub manufacturer: Option<String>,
        /// Product string, if it could be read.
        pub product: Option<String>,
        /// Serial number string, if it could be read.
        pub serial_number: Option<String>,
    }
129
    /// USB interface information.
    ///
    /// Passed to the filter function of [`UsbConnector`] for each interface of a
    /// probed device. The values are taken from the first alternate setting of the
    /// interface in the active configuration of the device.
    #[derive(Debug, Clone, PartialEq, Eq)]
    #[non_exhaustive]
    pub struct InterfaceInfo {
        /// Interface number.
        pub number: u8,
        /// Interface class code.
        pub class_code: u8,
        /// Interface sub class code.
        pub sub_class_code: u8,
        /// Interface protocol code.
        pub protocol_code: u8,
        /// Interface description string; `None` when the interface has no such
        /// string descriptor or reading it failed.
        pub description: Option<String>,
    }
145
    /// Boxed predicate deciding whether an interface of a USB device is a connection target.
    type FilterFn = Box<dyn Fn(&DeviceInfo, &InterfaceInfo) -> bool + Send + Sync>;
147
    /// USB transport for outgoing connections.
    ///
    /// This transport is packet-based.
    pub struct UsbConnector {
        /// Predicate selecting the device interfaces for which link tags are published.
        filter: FilterFn,
        /// Hotplug event source; `None` when hotplug detection could not be set up.
        hotplug: Option<Mutex<nusb::hotplug::HotplugWatch>>,
    }
155
156    impl fmt::Debug for UsbConnector {
157        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
158            f.debug_struct("UsbConnector").finish_non_exhaustive()
159        }
160    }
161
162    impl UsbConnector {
163        /// Creates a new USB transport for outgoing connections.
164        ///
165        /// The `filter` function is called for each discovered USB device and should return `true` if the target
166        /// USB device and interface is matched.
167        ///
168        /// USB devices are re-enumerated when a hotplug event occurs, or, if hotplug events are unsupported
169        /// by the operating system, periodically.
170        pub fn new(filter: impl Fn(&DeviceInfo, &InterfaceInfo) -> bool + Send + Sync + 'static) -> Result<Self> {
171            let hotplug = match nusb::watch_devices() {
172                Ok(hotplug) => Some(Mutex::new(hotplug)),
173                Err(err) => {
174                    tracing::warn!(%err, "USB hotplug detection not available");
175                    None
176                }
177            };
178
179            Ok(Self { filter: Box::new(filter), hotplug })
180        }
181
182        async fn probe_device(&self, dev_info: &nusb::DeviceInfo) -> Result<Vec<OutgoingUsbLinkTag>> {
183            let dev = dev_info.open().await?;
184            let cfg = dev.active_configuration()?;
185            let desc = dev.device_descriptor();
186
187            let lang = match dev.get_string_descriptor_supported_languages(TIMEOUT).await {
188                Ok(mut langs) => langs.next(),
189                Err(err) => {
190                    tracing::warn!(%err, "cannot get string descriptor languages");
191                    None
192                }
193            };
194
195            let read_desc = async |desc_index| {
196                let desc_index = desc_index?;
197                match lang {
198                    Some(lang) => match dev.get_string_descriptor(desc_index, lang, TIMEOUT).await {
199                        Ok(s) => Some(s),
200                        Err(err) => {
201                            tracing::warn!(%err, "cannot read string descriptor {desc_index}");
202                            None
203                        }
204                    },
205                    None => None,
206                }
207            };
208
209            let device_info = DeviceInfo {
210                bus_id: dev_info.bus_id().to_string(),
211                #[cfg(any(target_os = "linux", target_os = "android"))]
212                busnum: dev_info.busnum(),
213                address: dev_info.device_address(),
214                port_numbers: dev_info.port_chain().to_vec(),
215                vendor_id: dev_info.vendor_id(),
216                product_id: dev_info.product_id(),
217                class_code: dev_info.class(),
218                sub_class_code: dev_info.subclass(),
219                protocol_code: dev_info.protocol(),
220                version: dev_info.device_version(),
221                usb_version: dev_info.usb_version(),
222                manufacturer: read_desc(desc.manufacturer_string_index()).await,
223                product: read_desc(desc.product_string_index()).await,
224                serial_number: read_desc(desc.serial_number_string_index()).await,
225            };
226
227            let mut tags = Vec::new();
228
229            for iface in cfg.interfaces() {
230                let desc = iface.first_alt_setting();
231
232                let interface_info = InterfaceInfo {
233                    number: desc.interface_number(),
234                    class_code: desc.class(),
235                    sub_class_code: desc.subclass(),
236                    protocol_code: desc.protocol(),
237                    description: read_desc(desc.string_index()).await,
238                };
239
240                if (self.filter)(&device_info, &interface_info) {
241                    tags.push(OutgoingUsbLinkTag {
242                        bus_id: dev_info.bus_id().to_string(),
243                        address: dev_info.device_address(),
244                        interface: desc.interface_number(),
245                    });
246                }
247            }
248
249            Ok(tags)
250        }
251    }
252
253    #[async_trait]
254    impl ConnectingTransport for UsbConnector {
255        fn name(&self) -> &str {
256            NAME
257        }
258
259        async fn link_tags(&self, tx: watch::Sender<HashSet<LinkTagBox>>) -> Result<()> {
260            loop {
261                let mut tags = HashSet::new();
262                for dev_info in nusb::list_devices().await? {
263                    match self.probe_device(&dev_info).await {
264                        Ok(dev_tags) => {
265                            tags.extend(dev_tags.into_iter().map(|tag| Box::new(tag) as Box<dyn LinkTag>))
266                        }
267                        Err(err) => {
268                            tracing::trace!(
269                                bus_id =% dev_info.bus_id(), address =% dev_info.device_address(), %err,
270                                "cannot probe device"
271                            )
272                        }
273                    }
274                }
275
276                tx.send_if_modified(|v| {
277                    if *v != tags {
278                        *v = tags;
279                        true
280                    } else {
281                        false
282                    }
283                });
284
285                match &self.hotplug {
286                    Some(hotplug) => {
287                        let mut hotplug = hotplug.lock().await;
288                        tokio::select! {
289                            Some(_) = hotplug.next() => {
290                                tracing::debug!("USB devices changed");
291                                sleep(Duration::from_millis(100)).await;
292                                while let Some(Some(_)) = hotplug.next().now_or_never() {}
293
294                            }
295                            () = sleep(PROBE_INTERVAL) => (),
296                        }
297                    }
298                    None => sleep(PROBE_INTERVAL).await,
299                }
300            }
301        }
302
303        async fn connect(&self, tag: &dyn LinkTag) -> Result<StreamBox> {
304            let tag: &OutgoingUsbLinkTag = tag.as_any().downcast_ref().unwrap();
305
306            let Some(dev) = nusb::list_devices()
307                .await?
308                .find(|cand| cand.bus_id() == tag.bus_id && cand.device_address() == tag.address)
309            else {
310                return Err(Error::new(ErrorKind::NotFound, "USB device gone"));
311            };
312
313            let dev = dev.open().await?;
314            let (tx, rx) = upc::host::connect(dev, tag.interface, &[]).await?;
315
316            Ok(TxRxBox::new(tx.into_sink(), rx.into_stream()).into())
317        }
318    }
319}
320
321#[cfg(feature = "host")]
322#[cfg_attr(docsrs, doc(cfg(feature = "host")))]
323pub use host::*;
324
325#[cfg(feature = "device")]
326#[cfg_attr(docsrs, doc(cfg(feature = "device")))]
327mod device {
328    use aggligator::{control::Direction, io::TxRxBox};
329    use async_trait::async_trait;
330    use core::fmt;
331    use futures::TryStreamExt;
332    use std::{
333        any::Any,
334        cmp::Ordering,
335        ffi::{OsStr, OsString},
336        hash::{Hash, Hasher},
337        io::Result,
338    };
339    use tokio::sync::{mpsc, Mutex};
340    use upc::device::UpcFunction;
341
342    use aggligator::transport::{AcceptedStreamBox, AcceptingTransport, LinkTag, LinkTagBox};
343
344    use super::NAME;
345
346    pub use upc;
347    pub use usb_gadget;
348
    /// Link tag for incoming USB link.
    ///
    /// Identifies the link by the name of the USB device controller (UDC) that was
    /// given to the accepting transport.
    #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
    pub struct IncomingUsbLinkTag {
        /// USB device controller name.
        pub udc: OsString,
    }
355
356    impl fmt::Display for IncomingUsbLinkTag {
357        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
358            write!(f, "UDC <- {}", self.udc.to_string_lossy())
359        }
360    }
361
362    impl LinkTag for IncomingUsbLinkTag {
363        fn transport_name(&self) -> &str {
364            NAME
365        }
366
367        fn direction(&self) -> Direction {
368            Direction::Incoming
369        }
370
371        fn user_data(&self) -> Vec<u8> {
372            Vec::new()
373        }
374
375        fn as_any(&self) -> &dyn Any {
376            self
377        }
378
379        fn box_clone(&self) -> LinkTagBox {
380            Box::new(self.clone())
381        }
382
383        fn dyn_cmp(&self, other: &dyn LinkTag) -> Ordering {
384            let other = other.as_any().downcast_ref::<Self>().unwrap();
385            Ord::cmp(self, other)
386        }
387
388        fn dyn_hash(&self, mut state: &mut dyn Hasher) {
389            Hash::hash(self, &mut state)
390        }
391    }
392
    /// USB transport for incoming connections.
    ///
    /// This transport is packet-based.
    #[derive(Debug)]
    pub struct UsbAcceptor {
        /// USB function from which connections are accepted; locked while listening.
        upc_function: Mutex<UpcFunction>,
        /// Name of the USB device controller, copied into the tag of every accepted link.
        udc_name: OsString,
    }
401
402    impl UsbAcceptor {
403        /// Creates a new USB transport accepting incoming connections from `upc_function`.
404        ///
405        /// `udc_name` specifies the name of the USB device controller (UDC).
406        pub fn new(upc_function: UpcFunction, udc_name: impl AsRef<OsStr>) -> Self {
407            Self { upc_function: Mutex::new(upc_function), udc_name: udc_name.as_ref().to_os_string() }
408        }
409    }
410
411    #[async_trait]
412    impl AcceptingTransport for UsbAcceptor {
413        fn name(&self) -> &str {
414            NAME
415        }
416
417        async fn listen(&self, conn_tx: mpsc::Sender<AcceptedStreamBox>) -> Result<()> {
418            let mut upc_function = self.upc_function.lock().await;
419
420            loop {
421                let (tx, rx) = upc_function.accept().await?;
422                let tx_rx = TxRxBox::new(tx.into_sink(), rx.into_stream().map_ok(|p| p.freeze()));
423
424                let tag = IncomingUsbLinkTag { udc: self.udc_name.clone() };
425
426                if conn_tx.send(AcceptedStreamBox::new(tx_rx.into(), tag)).await.is_err() {
427                    break;
428                }
429            }
430
431            Ok(())
432        }
433    }
434}
435
436#[cfg(feature = "device")]
437#[cfg_attr(docsrs, doc(cfg(feature = "device")))]
438pub use device::*;