postcard_rpc/host_client/raw_nusb.rs
1//! Implementation of transport using nusb
2
3use std::future::Future;
4
5use nusb::{
6 transfer::{Direction, EndpointType, Queue, RequestBuffer, TransferError},
7 DeviceInfo,
8};
9use postcard_schema::Schema;
10use serde::de::DeserializeOwned;
11
12use crate::{
13 header::VarSeqKind,
14 host_client::{HostClient, WireRx, WireSpawn, WireTx},
15};
16
17// TODO: These should all be configurable, PRs welcome
18
/// Buffer size, in bytes, requested for every bulk IN transfer. A single
/// incoming transfer can therefore never be larger than this.
pub(crate) const MAX_TRANSFER_SIZE: usize = 1024;
/// Number of IN transfers kept queued with nusb at any one time, so frames can
/// keep arriving even while the host side is still busy with earlier ones.
pub(crate) const IN_FLIGHT_REQS: usize = 4;
/// Upper bound on back-to-back IN errors for which recovery is attempted
/// before the receiver gives up.
pub(crate) const MAX_STALL_RETRIES: usize = 10;
26
27/// # `nusb` Constructor Methods
28///
29/// These methods are used to create a new [HostClient] instance for use with `nusb` and
30/// USB bulk transfer encoding.
31///
32/// **Requires feature**: `raw-nusb`
33impl<WireErr> HostClient<WireErr>
34where
35 WireErr: DeserializeOwned + Schema,
36{
37 /// Try to create a new link using [`nusb`] for connectivity
38 ///
39 /// The provided function will be used to find a matching device. The first
40 /// matching device will be connected to. `err_uri_path` is
41 /// the path associated with the `WireErr` message type.
42 ///
43 /// Returns an error if no device could be found, or if there was an error
44 /// connecting to the device.
45 ///
46 /// This constructor is available when the `raw-nusb` feature is enabled.
47 ///
48 /// ## Platform specific support
49 ///
50 /// When using Windows, the WinUSB driver does not allow enumerating interfaces.
51 /// When on windows, this method will ALWAYS try to connect to interface zero.
52 /// This limitation may be removed in the future, and if so, will be changed to
53 /// look for the first interface with the class of 0xFF.
54 ///
55 /// ## Example
56 ///
57 /// ```rust,no_run
58 /// use postcard_rpc::host_client::HostClient;
59 /// use postcard_rpc::header::VarSeqKind;
60 /// use serde::{Serialize, Deserialize};
61 /// use postcard_schema::Schema;
62 ///
63 /// /// A "wire error" type your server can use to respond to any
64 /// /// kind of request, for example if deserializing a request fails
65 /// #[derive(Debug, PartialEq, Schema, Serialize, Deserialize)]
66 /// pub enum Error {
67 /// SomethingBad
68 /// }
69 ///
70 /// let client = HostClient::<Error>::try_new_raw_nusb(
71 /// // Find the first device with the serial 12345678
72 /// |d| d.serial_number() == Some("12345678"),
73 /// // the URI/path for `Error` messages
74 /// "error",
75 /// // Outgoing queue depth in messages
76 /// 8,
77 /// // Use one-byte sequence numbers
78 /// VarSeqKind::Seq1,
79 /// ).unwrap();
80 /// ```
81 pub fn try_new_raw_nusb<F: FnMut(&DeviceInfo) -> bool>(
82 func: F,
83 err_uri_path: &str,
84 outgoing_depth: usize,
85 seq_no_kind: VarSeqKind,
86 ) -> Result<Self, String> {
87 let x = nusb::list_devices()
88 .map_err(|e| format!("Error listing devices: {e:?}"))?
89 .find(func)
90 .ok_or_else(|| String::from("Failed to find matching nusb device!"))?;
91
92 // NOTE: We can't enumerate interfaces on Windows. For now, just use
93 // a hardcoded interface of zero instead of trying to find the right one
94 #[cfg(not(target_os = "windows"))]
95 let interface_id = x
96 .interfaces()
97 .position(|i| i.class() == 0xFF)
98 .ok_or_else(|| String::from("Failed to find matching interface!!"))?;
99
100 #[cfg(target_os = "windows")]
101 let interface_id = 0;
102
103 Self::try_from_nusb_and_interface(
104 &x,
105 interface_id,
106 err_uri_path,
107 outgoing_depth,
108 seq_no_kind,
109 )
110 }
111
112 /// Try to create a new link using [`nusb`] for connectivity
113 ///
114 /// The provided function will be used to find a matching device and interface. The first
115 /// matching device will be connected to. `err_uri_path` is
116 /// the path associated with the `WireErr` message type.
117 ///
118 /// Returns an error if no device or interface could be found, or if there was an error
119 /// connecting to the device or interface.
120 ///
121 /// This constructor is available when the `raw-nusb` feature is enabled.
122 ///
123 /// ## Platform specific support
124 ///
125 /// When using Windows, the WinUSB driver does not allow enumerating interfaces.
126 /// Therefore, this constructor is not available on windows. This limitation may
127 /// be removed in the future.
128 ///
129 /// ## Example
130 ///
131 /// ```rust,no_run
132 /// use postcard_rpc::host_client::HostClient;
133 /// use postcard_rpc::header::VarSeqKind;
134 /// use serde::{Serialize, Deserialize};
135 /// use postcard_schema::Schema;
136 ///
137 /// /// A "wire error" type your server can use to respond to any
138 /// /// kind of request, for example if deserializing a request fails
139 /// #[derive(Debug, PartialEq, Schema, Serialize, Deserialize)]
140 /// pub enum Error {
141 /// SomethingBad
142 /// }
143 ///
144 /// let client = HostClient::<Error>::try_new_raw_nusb_with_interface(
145 /// // Find the first device with the serial 12345678
146 /// |d| d.serial_number() == Some("12345678"),
147 /// // Find the "Vendor Specific" interface
148 /// |i| i.class() == 0xFF,
149 /// // the URI/path for `Error` messages
150 /// "error",
151 /// // Outgoing queue depth in messages
152 /// 8,
153 /// // Use one-byte sequence numbers
154 /// VarSeqKind::Seq1,
155 /// ).unwrap();
156 /// ```
157 #[cfg(not(target_os = "windows"))]
158 pub fn try_new_raw_nusb_with_interface<
159 F1: FnMut(&DeviceInfo) -> bool,
160 F2: FnMut(&nusb::InterfaceInfo) -> bool,
161 >(
162 device_func: F1,
163 interface_func: F2,
164 err_uri_path: &str,
165 outgoing_depth: usize,
166 seq_no_kind: VarSeqKind,
167 ) -> Result<Self, String> {
168 let x = nusb::list_devices()
169 .map_err(|e| format!("Error listing devices: {e:?}"))?
170 .find(device_func)
171 .ok_or_else(|| String::from("Failed to find matching nusb device!"))?;
172 let interface_id = x
173 .interfaces()
174 .position(interface_func)
175 .ok_or_else(|| String::from("Failed to find matching interface!!"))?;
176
177 Self::try_from_nusb_and_interface(
178 &x,
179 interface_id,
180 err_uri_path,
181 outgoing_depth,
182 seq_no_kind,
183 )
184 }
185
186 /// Try to create a new link using [`nusb`] for connectivity
187 ///
188 /// This will connect to the given device and interface. `err_uri_path` is
189 /// the path associated with the `WireErr` message type.
190 ///
191 /// Returns an error if there was an error connecting to the device or interface.
192 ///
193 /// This constructor is available when the `raw-nusb` feature is enabled.
194 ///
195 /// ## Example
196 ///
197 /// ```rust,no_run
198 /// use postcard_rpc::host_client::HostClient;
199 /// use postcard_rpc::header::VarSeqKind;
200 /// use serde::{Serialize, Deserialize};
201 /// use postcard_schema::Schema;
202 ///
203 /// /// A "wire error" type your server can use to respond to any
204 /// /// kind of request, for example if deserializing a request fails
205 /// #[derive(Debug, PartialEq, Schema, Serialize, Deserialize)]
206 /// pub enum Error {
207 /// SomethingBad
208 /// }
209 ///
210 /// // Assume the first usb device is the one we're interested
211 /// let dev = nusb::list_devices().unwrap().next().unwrap();
212 /// let client = HostClient::<Error>::try_from_nusb_and_interface(
213 /// // Device to open
214 /// &dev,
215 /// // Use the first interface (0)
216 /// 0,
217 /// // the URI/path for `Error` messages
218 /// "error",
219 /// // Outgoing queue depth in messages
220 /// 8,
221 /// // Use one-byte sequence numbers
222 /// VarSeqKind::Seq1,
223 /// ).unwrap();
224 /// ```
225 pub fn try_from_nusb_and_interface(
226 dev: &DeviceInfo,
227 interface_id: usize,
228 err_uri_path: &str,
229 outgoing_depth: usize,
230 seq_no_kind: VarSeqKind,
231 ) -> Result<Self, String> {
232 let dev = dev
233 .open()
234 .map_err(|e| format!("Failed opening device: {e:?}"))?;
235 let interface = dev
236 .claim_interface(interface_id as u8)
237 .map_err(|e| format!("Failed claiming interface: {e:?}"))?;
238
239 let mut mps: Option<usize> = None;
240 let mut ep_in: Option<u8> = None;
241 let mut ep_out: Option<u8> = None;
242 for ias in interface.descriptors() {
243 for ep in ias
244 .endpoints()
245 .filter(|e| e.transfer_type() == EndpointType::Bulk)
246 {
247 match ep.direction() {
248 Direction::Out => {
249 mps = Some(match mps.take() {
250 Some(old) => old.min(ep.max_packet_size()),
251 None => ep.max_packet_size(),
252 });
253 ep_out = Some(ep.address());
254 }
255 Direction::In => ep_in = Some(ep.address()),
256 }
257 }
258 }
259
260 if let Some(max_packet_size) = &mps {
261 tracing::debug!(max_packet_size, "Detected max packet size");
262 } else {
263 tracing::warn!("Unable to detect Max Packet Size!");
264 };
265
266 let ep_out = ep_out.ok_or("Failed to find OUT EP")?;
267 tracing::debug!("OUT EP: {ep_out}");
268
269 let ep_in = ep_in.ok_or("Failed to find IN EP")?;
270 tracing::debug!("IN EP: {ep_in}");
271
272 let boq = interface.bulk_out_queue(ep_out);
273 let biq = interface.bulk_in_queue(ep_in);
274
275 Ok(HostClient::new_with_wire(
276 NusbWireTx {
277 boq,
278 max_packet_size: mps,
279 },
280 NusbWireRx {
281 biq,
282 consecutive_errs: 0,
283 },
284 NusbSpawn,
285 seq_no_kind,
286 err_uri_path,
287 outgoing_depth,
288 ))
289 }
290
291 /// Create a new link using [`nusb`] for connectivity
292 ///
293 /// Panics if connection fails. See [`Self::try_new_raw_nusb()`] for more details.
294 ///
295 /// This constructor is available when the `raw-nusb` feature is enabled.
296 ///
297 /// ## Example
298 ///
299 /// ```rust,no_run
300 /// use postcard_rpc::host_client::HostClient;
301 /// use postcard_rpc::header::VarSeqKind;
302 /// use serde::{Serialize, Deserialize};
303 /// use postcard_schema::Schema;
304 ///
305 /// /// A "wire error" type your server can use to respond to any
306 /// /// kind of request, for example if deserializing a request fails
307 /// #[derive(Debug, PartialEq, Schema, Serialize, Deserialize)]
308 /// pub enum Error {
309 /// SomethingBad
310 /// }
311 ///
312 /// let client = HostClient::<Error>::new_raw_nusb(
313 /// // Find the first device with the serial 12345678
314 /// |d| d.serial_number() == Some("12345678"),
315 /// // the URI/path for `Error` messages
316 /// "error",
317 /// // Outgoing queue depth in messages
318 /// 8,
319 /// // Use one-byte sequence numbers
320 /// VarSeqKind::Seq1,
321 /// );
322 /// ```
323 pub fn new_raw_nusb<F: FnMut(&DeviceInfo) -> bool>(
324 func: F,
325 err_uri_path: &str,
326 outgoing_depth: usize,
327 seq_no_kind: VarSeqKind,
328 ) -> Self {
329 Self::try_new_raw_nusb(func, err_uri_path, outgoing_depth, seq_no_kind)
330 .expect("should have found nusb device")
331 }
332}
333
334//////////////////////////////////////////////////////////////////////////////
335// Wire Interface Implementation
336//////////////////////////////////////////////////////////////////////////////
337
338/// NUSB Wire Interface Implementor
339///
340/// Uses Tokio for spawning tasks
341struct NusbSpawn;
342
343impl WireSpawn for NusbSpawn {
344 fn spawn(&mut self, fut: impl Future<Output = ()> + Send + 'static) {
345 // Explicitly drop the joinhandle as it impls Future and this makes
346 // clippy mad if you just let it drop implicitly
347 core::mem::drop(tokio::task::spawn(fut));
348 }
349}
350
351/// NUSB Wire Transmit Interface Implementor
352struct NusbWireTx {
353 boq: Queue<Vec<u8>>,
354 max_packet_size: Option<usize>,
355}
356
357#[derive(thiserror::Error, Debug)]
358enum NusbWireTxError {
359 #[error("Transfer Error on Send")]
360 Transfer(#[from] TransferError),
361}
362
363impl WireTx for NusbWireTx {
364 type Error = NusbWireTxError;
365
366 #[inline]
367 fn send(&mut self, data: Vec<u8>) -> impl Future<Output = Result<(), Self::Error>> + Send {
368 self.send_inner(data)
369 }
370}
371
372impl NusbWireTx {
373 async fn send_inner(&mut self, data: Vec<u8>) -> Result<(), NusbWireTxError> {
374 let needs_zlp = if let Some(mps) = self.max_packet_size {
375 (data.len() % mps) == 0
376 } else {
377 true
378 };
379
380 self.boq.submit(data);
381
382 // Append ZLP if we are a multiple of max packet
383 if needs_zlp {
384 self.boq.submit(vec![]);
385 }
386
387 let send_res = self.boq.next_complete().await;
388 if let Err(e) = send_res.status {
389 tracing::error!("Output Queue Error: {e:?}");
390 return Err(e.into());
391 }
392
393 if needs_zlp {
394 let send_res = self.boq.next_complete().await;
395 if let Err(e) = send_res.status {
396 tracing::error!("Output Queue Error: {e:?}");
397 return Err(e.into());
398 }
399 }
400
401 Ok(())
402 }
403}
404
405/// NUSB Wire Receive Interface Implementor
406struct NusbWireRx {
407 biq: Queue<RequestBuffer>,
408 consecutive_errs: usize,
409}
410
411#[derive(thiserror::Error, Debug)]
412enum NusbWireRxError {
413 #[error("Transfer Error on Recv")]
414 Transfer(#[from] TransferError),
415}
416
417impl WireRx for NusbWireRx {
418 type Error = NusbWireRxError;
419
420 #[inline]
421 fn receive(&mut self) -> impl Future<Output = Result<Vec<u8>, Self::Error>> + Send {
422 self.recv_inner()
423 }
424}
425
426impl NusbWireRx {
427 async fn recv_inner(&mut self) -> Result<Vec<u8>, NusbWireRxError> {
428 loop {
429 // Rehydrate the queue
430 let pending = self.biq.pending();
431 for _ in 0..(IN_FLIGHT_REQS.saturating_sub(pending)) {
432 self.biq.submit(RequestBuffer::new(MAX_TRANSFER_SIZE));
433 }
434
435 let res = self.biq.next_complete().await;
436
437 if let Err(e) = res.status {
438 self.consecutive_errs += 1;
439
440 tracing::error!(
441 "In Worker error: {e:?}, consecutive: {}",
442 self.consecutive_errs
443 );
444
445 // Docs only recommend this for Stall, but it seems to work with
446 // UNKNOWN on MacOS as well, todo: look into why!
447 //
448 // Update: This stall condition seems to have been due to an errata in the
449 // STM32F4 USB hardware. See https://github.com/embassy-rs/embassy/pull/2823
450 //
451 // It is now questionable whether we should be doing this stall recovery at all,
452 // as it likely indicates an issue with the connected USB device
453 let recoverable = match e {
454 TransferError::Stall | TransferError::Unknown => {
455 self.consecutive_errs <= MAX_STALL_RETRIES
456 }
457 TransferError::Cancelled => false,
458 TransferError::Disconnected => false,
459 TransferError::Fault => false,
460 };
461
462 let fatal = if recoverable {
463 tracing::warn!("Attempting stall recovery!");
464
465 // Stall recovery shouldn't be used with in-flight requests, so
466 // cancel them all. They'll still pop out of next_complete.
467 self.biq.cancel_all();
468 tracing::info!("Cancelled all in-flight requests");
469
470 // Now we need to join all in flight requests
471 for _ in 0..(IN_FLIGHT_REQS - 1) {
472 let res = self.biq.next_complete().await;
473 tracing::info!("Drain state: {:?}", res.status);
474 }
475
476 // Now we can mark the stall as clear
477 match self.biq.clear_halt() {
478 Ok(()) => false,
479 Err(e) => {
480 tracing::error!("Failed to clear stall: {e:?}, Fatal.");
481 true
482 }
483 }
484 } else {
485 tracing::error!(
486 "Giving up after {} errors in a row, final error: {e:?}",
487 self.consecutive_errs
488 );
489 true
490 };
491
492 if fatal {
493 tracing::error!("Fatal Error, exiting");
494 // When we close the channel, all pending receivers and subscribers
495 // will be notified
496 return Err(e.into());
497 } else {
498 tracing::info!("Potential recovery, resuming NusbWireRx::recv_inner");
499 continue;
500 }
501 }
502
503 // If we get a good decode, clear the error flag
504 if self.consecutive_errs != 0 {
505 tracing::info!("Clearing consecutive error counter after good header decode");
506 self.consecutive_errs = 0;
507 }
508
509 return Ok(res.data);
510 }
511 }
512}