postcard_rpc/host_client/
raw_nusb.rs

//! Implementation of transport using nusb
2
3use std::future::Future;
4
5use nusb::{
6    transfer::{Direction, EndpointType, Queue, RequestBuffer, TransferError},
7    DeviceInfo,
8};
9use postcard_schema::Schema;
10use serde::de::DeserializeOwned;
11
12use crate::{
13    header::VarSeqKind,
14    host_client::{HostClient, WireRx, WireSpawn, WireTx},
15};
16
// TODO: These should all be configurable, PRs welcome

/// Upper bound, in bytes, on the size of a single IN transfer we request.
pub(crate) const MAX_TRANSFER_SIZE: usize = 1 << 10;
/// Number of IN requests kept queued at the same time, so that `nusb` can keep
/// receiving frames even while the host side is still processing earlier ones.
pub(crate) const IN_FLIGHT_REQS: usize = 4;
/// Number of consecutive IN errors we attempt to recover from before giving up.
pub(crate) const MAX_STALL_RETRIES: usize = 10;
26
27/// # `nusb` Constructor Methods
28///
29/// These methods are used to create a new [HostClient] instance for use with `nusb` and
30/// USB bulk transfer encoding.
31///
32/// **Requires feature**: `raw-nusb`
33impl<WireErr> HostClient<WireErr>
34where
35    WireErr: DeserializeOwned + Schema,
36{
37    /// Try to create a new link using [`nusb`] for connectivity
38    ///
39    /// The provided function will be used to find a matching device. The first
40    /// matching device will be connected to. `err_uri_path` is
41    /// the path associated with the `WireErr` message type.
42    ///
43    /// Returns an error if no device could be found, or if there was an error
44    /// connecting to the device.
45    ///
46    /// This constructor is available when the `raw-nusb` feature is enabled.
47    ///
48    /// ## Platform specific support
49    ///
50    /// When using Windows, the WinUSB driver does not allow enumerating interfaces.
51    /// When on windows, this method will ALWAYS try to connect to interface zero.
52    /// This limitation may be removed in the future, and if so, will be changed to
53    /// look for the first interface with the class of 0xFF.
54    ///
55    /// ## Example
56    ///
57    /// ```rust,no_run
58    /// use postcard_rpc::host_client::HostClient;
59    /// use postcard_rpc::header::VarSeqKind;
60    /// use serde::{Serialize, Deserialize};
61    /// use postcard_schema::Schema;
62    ///
63    /// /// A "wire error" type your server can use to respond to any
64    /// /// kind of request, for example if deserializing a request fails
65    /// #[derive(Debug, PartialEq, Schema, Serialize, Deserialize)]
66    /// pub enum Error {
67    ///    SomethingBad
68    /// }
69    ///
70    /// let client = HostClient::<Error>::try_new_raw_nusb(
71    ///     // Find the first device with the serial 12345678
72    ///     |d| d.serial_number() == Some("12345678"),
73    ///     // the URI/path for `Error` messages
74    ///     "error",
75    ///     // Outgoing queue depth in messages
76    ///     8,
77    ///     // Use one-byte sequence numbers
78    ///     VarSeqKind::Seq1,
79    /// ).unwrap();
80    /// ```
81    pub fn try_new_raw_nusb<F: FnMut(&DeviceInfo) -> bool>(
82        func: F,
83        err_uri_path: &str,
84        outgoing_depth: usize,
85        seq_no_kind: VarSeqKind,
86    ) -> Result<Self, String> {
87        let x = nusb::list_devices()
88            .map_err(|e| format!("Error listing devices: {e:?}"))?
89            .find(func)
90            .ok_or_else(|| String::from("Failed to find matching nusb device!"))?;
91
92        // NOTE: We can't enumerate interfaces on Windows. For now, just use
93        // a hardcoded interface of zero instead of trying to find the right one
94        #[cfg(not(target_os = "windows"))]
95        let interface_id = x
96            .interfaces()
97            .position(|i| i.class() == 0xFF)
98            .ok_or_else(|| String::from("Failed to find matching interface!!"))?;
99
100        #[cfg(target_os = "windows")]
101        let interface_id = 0;
102
103        Self::try_from_nusb_and_interface(
104            &x,
105            interface_id,
106            err_uri_path,
107            outgoing_depth,
108            seq_no_kind,
109        )
110    }
111
112    /// Try to create a new link using [`nusb`] for connectivity
113    ///
114    /// The provided function will be used to find a matching device and interface. The first
115    /// matching device will be connected to. `err_uri_path` is
116    /// the path associated with the `WireErr` message type.
117    ///
118    /// Returns an error if no device or interface could be found, or if there was an error
119    /// connecting to the device or interface.
120    ///
121    /// This constructor is available when the `raw-nusb` feature is enabled.
122    ///
123    /// ## Platform specific support
124    ///
125    /// When using Windows, the WinUSB driver does not allow enumerating interfaces.
126    /// Therefore, this constructor is not available on windows. This limitation may
127    /// be removed in the future.
128    ///
129    /// ## Example
130    ///
131    /// ```rust,no_run
132    /// use postcard_rpc::host_client::HostClient;
133    /// use postcard_rpc::header::VarSeqKind;
134    /// use serde::{Serialize, Deserialize};
135    /// use postcard_schema::Schema;
136    ///
137    /// /// A "wire error" type your server can use to respond to any
138    /// /// kind of request, for example if deserializing a request fails
139    /// #[derive(Debug, PartialEq, Schema, Serialize, Deserialize)]
140    /// pub enum Error {
141    ///    SomethingBad
142    /// }
143    ///
144    /// let client = HostClient::<Error>::try_new_raw_nusb_with_interface(
145    ///     // Find the first device with the serial 12345678
146    ///     |d| d.serial_number() == Some("12345678"),
147    ///     // Find the "Vendor Specific" interface
148    ///     |i| i.class() == 0xFF,
149    ///     // the URI/path for `Error` messages
150    ///     "error",
151    ///     // Outgoing queue depth in messages
152    ///     8,
153    ///     // Use one-byte sequence numbers
154    ///     VarSeqKind::Seq1,
155    /// ).unwrap();
156    /// ```
157    #[cfg(not(target_os = "windows"))]
158    pub fn try_new_raw_nusb_with_interface<
159        F1: FnMut(&DeviceInfo) -> bool,
160        F2: FnMut(&nusb::InterfaceInfo) -> bool,
161    >(
162        device_func: F1,
163        interface_func: F2,
164        err_uri_path: &str,
165        outgoing_depth: usize,
166        seq_no_kind: VarSeqKind,
167    ) -> Result<Self, String> {
168        let x = nusb::list_devices()
169            .map_err(|e| format!("Error listing devices: {e:?}"))?
170            .find(device_func)
171            .ok_or_else(|| String::from("Failed to find matching nusb device!"))?;
172        let interface_id = x
173            .interfaces()
174            .position(interface_func)
175            .ok_or_else(|| String::from("Failed to find matching interface!!"))?;
176
177        Self::try_from_nusb_and_interface(
178            &x,
179            interface_id,
180            err_uri_path,
181            outgoing_depth,
182            seq_no_kind,
183        )
184    }
185
186    /// Try to create a new link using [`nusb`] for connectivity
187    ///
188    /// This will connect to the given device and interface. `err_uri_path` is
189    /// the path associated with the `WireErr` message type.
190    ///
191    /// Returns an error if there was an error connecting to the device or interface.
192    ///
193    /// This constructor is available when the `raw-nusb` feature is enabled.
194    ///
195    /// ## Example
196    ///
197    /// ```rust,no_run
198    /// use postcard_rpc::host_client::HostClient;
199    /// use postcard_rpc::header::VarSeqKind;
200    /// use serde::{Serialize, Deserialize};
201    /// use postcard_schema::Schema;
202    ///
203    /// /// A "wire error" type your server can use to respond to any
204    /// /// kind of request, for example if deserializing a request fails
205    /// #[derive(Debug, PartialEq, Schema, Serialize, Deserialize)]
206    /// pub enum Error {
207    ///    SomethingBad
208    /// }
209    ///
210    /// // Assume the first usb device is the one we're interested
211    /// let dev = nusb::list_devices().unwrap().next().unwrap();
212    /// let client = HostClient::<Error>::try_from_nusb_and_interface(
213    ///     // Device to open
214    ///     &dev,
215    ///     // Use the first interface (0)
216    ///     0,
217    ///     // the URI/path for `Error` messages
218    ///     "error",
219    ///     // Outgoing queue depth in messages
220    ///     8,
221    ///     // Use one-byte sequence numbers
222    ///     VarSeqKind::Seq1,
223    /// ).unwrap();
224    /// ```
225    pub fn try_from_nusb_and_interface(
226        dev: &DeviceInfo,
227        interface_id: usize,
228        err_uri_path: &str,
229        outgoing_depth: usize,
230        seq_no_kind: VarSeqKind,
231    ) -> Result<Self, String> {
232        let dev = dev
233            .open()
234            .map_err(|e| format!("Failed opening device: {e:?}"))?;
235        let interface = dev
236            .claim_interface(interface_id as u8)
237            .map_err(|e| format!("Failed claiming interface: {e:?}"))?;
238
239        let mut mps: Option<usize> = None;
240        let mut ep_in: Option<u8> = None;
241        let mut ep_out: Option<u8> = None;
242        for ias in interface.descriptors() {
243            for ep in ias
244                .endpoints()
245                .filter(|e| e.transfer_type() == EndpointType::Bulk)
246            {
247                match ep.direction() {
248                    Direction::Out => {
249                        mps = Some(match mps.take() {
250                            Some(old) => old.min(ep.max_packet_size()),
251                            None => ep.max_packet_size(),
252                        });
253                        ep_out = Some(ep.address());
254                    }
255                    Direction::In => ep_in = Some(ep.address()),
256                }
257            }
258        }
259
260        if let Some(max_packet_size) = &mps {
261            tracing::debug!(max_packet_size, "Detected max packet size");
262        } else {
263            tracing::warn!("Unable to detect Max Packet Size!");
264        };
265
266        let ep_out = ep_out.ok_or("Failed to find OUT EP")?;
267        tracing::debug!("OUT EP: {ep_out}");
268
269        let ep_in = ep_in.ok_or("Failed to find IN EP")?;
270        tracing::debug!("IN EP: {ep_in}");
271
272        let boq = interface.bulk_out_queue(ep_out);
273        let biq = interface.bulk_in_queue(ep_in);
274
275        Ok(HostClient::new_with_wire(
276            NusbWireTx {
277                boq,
278                max_packet_size: mps,
279            },
280            NusbWireRx {
281                biq,
282                consecutive_errs: 0,
283            },
284            NusbSpawn,
285            seq_no_kind,
286            err_uri_path,
287            outgoing_depth,
288        ))
289    }
290
291    /// Create a new link using [`nusb`] for connectivity
292    ///
293    /// Panics if connection fails. See [`Self::try_new_raw_nusb()`] for more details.
294    ///
295    /// This constructor is available when the `raw-nusb` feature is enabled.
296    ///
297    /// ## Example
298    ///
299    /// ```rust,no_run
300    /// use postcard_rpc::host_client::HostClient;
301    /// use postcard_rpc::header::VarSeqKind;
302    /// use serde::{Serialize, Deserialize};
303    /// use postcard_schema::Schema;
304    ///
305    /// /// A "wire error" type your server can use to respond to any
306    /// /// kind of request, for example if deserializing a request fails
307    /// #[derive(Debug, PartialEq, Schema, Serialize, Deserialize)]
308    /// pub enum Error {
309    ///    SomethingBad
310    /// }
311    ///
312    /// let client = HostClient::<Error>::new_raw_nusb(
313    ///     // Find the first device with the serial 12345678
314    ///     |d| d.serial_number() == Some("12345678"),
315    ///     // the URI/path for `Error` messages
316    ///     "error",
317    ///     // Outgoing queue depth in messages
318    ///     8,
319    ///     // Use one-byte sequence numbers
320    ///     VarSeqKind::Seq1,
321    /// );
322    /// ```
323    pub fn new_raw_nusb<F: FnMut(&DeviceInfo) -> bool>(
324        func: F,
325        err_uri_path: &str,
326        outgoing_depth: usize,
327        seq_no_kind: VarSeqKind,
328    ) -> Self {
329        Self::try_new_raw_nusb(func, err_uri_path, outgoing_depth, seq_no_kind)
330            .expect("should have found nusb device")
331    }
332}
333
//////////////////////////////////////////////////////////////////////////////
// Wire Interface Implementation
//////////////////////////////////////////////////////////////////////////////
337
338/// NUSB Wire Interface Implementor
339///
340/// Uses Tokio for spawning tasks
341struct NusbSpawn;
342
343impl WireSpawn for NusbSpawn {
344    fn spawn(&mut self, fut: impl Future<Output = ()> + Send + 'static) {
345        // Explicitly drop the joinhandle as it impls Future and this makes
346        // clippy mad if you just let it drop implicitly
347        core::mem::drop(tokio::task::spawn(fut));
348    }
349}
350
351/// NUSB Wire Transmit Interface Implementor
352struct NusbWireTx {
353    boq: Queue<Vec<u8>>,
354    max_packet_size: Option<usize>,
355}
356
357#[derive(thiserror::Error, Debug)]
358enum NusbWireTxError {
359    #[error("Transfer Error on Send")]
360    Transfer(#[from] TransferError),
361}
362
363impl WireTx for NusbWireTx {
364    type Error = NusbWireTxError;
365
366    #[inline]
367    fn send(&mut self, data: Vec<u8>) -> impl Future<Output = Result<(), Self::Error>> + Send {
368        self.send_inner(data)
369    }
370}
371
372impl NusbWireTx {
373    async fn send_inner(&mut self, data: Vec<u8>) -> Result<(), NusbWireTxError> {
374        let needs_zlp = if let Some(mps) = self.max_packet_size {
375            (data.len() % mps) == 0
376        } else {
377            true
378        };
379
380        self.boq.submit(data);
381
382        // Append ZLP if we are a multiple of max packet
383        if needs_zlp {
384            self.boq.submit(vec![]);
385        }
386
387        let send_res = self.boq.next_complete().await;
388        if let Err(e) = send_res.status {
389            tracing::error!("Output Queue Error: {e:?}");
390            return Err(e.into());
391        }
392
393        if needs_zlp {
394            let send_res = self.boq.next_complete().await;
395            if let Err(e) = send_res.status {
396                tracing::error!("Output Queue Error: {e:?}");
397                return Err(e.into());
398            }
399        }
400
401        Ok(())
402    }
403}
404
405/// NUSB Wire Receive Interface Implementor
406struct NusbWireRx {
407    biq: Queue<RequestBuffer>,
408    consecutive_errs: usize,
409}
410
411#[derive(thiserror::Error, Debug)]
412enum NusbWireRxError {
413    #[error("Transfer Error on Recv")]
414    Transfer(#[from] TransferError),
415}
416
417impl WireRx for NusbWireRx {
418    type Error = NusbWireRxError;
419
420    #[inline]
421    fn receive(&mut self) -> impl Future<Output = Result<Vec<u8>, Self::Error>> + Send {
422        self.recv_inner()
423    }
424}
425
426impl NusbWireRx {
427    async fn recv_inner(&mut self) -> Result<Vec<u8>, NusbWireRxError> {
428        loop {
429            // Rehydrate the queue
430            let pending = self.biq.pending();
431            for _ in 0..(IN_FLIGHT_REQS.saturating_sub(pending)) {
432                self.biq.submit(RequestBuffer::new(MAX_TRANSFER_SIZE));
433            }
434
435            let res = self.biq.next_complete().await;
436
437            if let Err(e) = res.status {
438                self.consecutive_errs += 1;
439
440                tracing::error!(
441                    "In Worker error: {e:?}, consecutive: {}",
442                    self.consecutive_errs
443                );
444
445                // Docs only recommend this for Stall, but it seems to work with
446                // UNKNOWN on MacOS as well, todo: look into why!
447                //
448                // Update: This stall condition seems to have been due to an errata in the
449                // STM32F4 USB hardware. See https://github.com/embassy-rs/embassy/pull/2823
450                //
451                // It is now questionable whether we should be doing this stall recovery at all,
452                // as it likely indicates an issue with the connected USB device
453                let recoverable = match e {
454                    TransferError::Stall | TransferError::Unknown => {
455                        self.consecutive_errs <= MAX_STALL_RETRIES
456                    }
457                    TransferError::Cancelled => false,
458                    TransferError::Disconnected => false,
459                    TransferError::Fault => false,
460                };
461
462                let fatal = if recoverable {
463                    tracing::warn!("Attempting stall recovery!");
464
465                    // Stall recovery shouldn't be used with in-flight requests, so
466                    // cancel them all. They'll still pop out of next_complete.
467                    self.biq.cancel_all();
468                    tracing::info!("Cancelled all in-flight requests");
469
470                    // Now we need to join all in flight requests
471                    for _ in 0..(IN_FLIGHT_REQS - 1) {
472                        let res = self.biq.next_complete().await;
473                        tracing::info!("Drain state: {:?}", res.status);
474                    }
475
476                    // Now we can mark the stall as clear
477                    match self.biq.clear_halt() {
478                        Ok(()) => false,
479                        Err(e) => {
480                            tracing::error!("Failed to clear stall: {e:?}, Fatal.");
481                            true
482                        }
483                    }
484                } else {
485                    tracing::error!(
486                        "Giving up after {} errors in a row, final error: {e:?}",
487                        self.consecutive_errs
488                    );
489                    true
490                };
491
492                if fatal {
493                    tracing::error!("Fatal Error, exiting");
494                    // When we close the channel, all pending receivers and subscribers
495                    // will be notified
496                    return Err(e.into());
497                } else {
498                    tracing::info!("Potential recovery, resuming NusbWireRx::recv_inner");
499                    continue;
500                }
501            }
502
503            // If we get a good decode, clear the error flag
504            if self.consecutive_errs != 0 {
505                tracing::info!("Clearing consecutive error counter after good header decode");
506                self.consecutive_errs = 0;
507            }
508
509            return Ok(res.data);
510        }
511    }
512}