use std::future::Future;
use nusb::{
transfer::{Queue, RequestBuffer, TransferError},
DeviceInfo,
};
use postcard_schema::Schema;
use serde::de::DeserializeOwned;
use crate::{
header::VarSeqKind,
host_client::{HostClient, WireRx, WireSpawn, WireTx},
};
/// Bulk OUT endpoint address (host → device): endpoint number 1, direction bit clear.
pub(crate) const BULK_OUT_EP: u8 = 0x01;
/// Bulk IN endpoint address (device → host): endpoint number 1 with the IN
/// direction bit (0x80) set, i.e. `0x81`.
pub(crate) const BULK_IN_EP: u8 = 0x80 | 0x01;
/// Capacity, in bytes, of every request buffer submitted on the bulk IN queue.
pub(crate) const MAX_TRANSFER_SIZE: usize = 1024;
/// How many bulk IN requests the receiver tries to keep outstanding at once.
pub(crate) const IN_FLIGHT_REQS: usize = 4;
/// Number of back-to-back `Stall`/`Unknown` IN errors tolerated before the
/// receiver treats the failure as fatal.
pub(crate) const MAX_STALL_RETRIES: usize = 10;
impl<WireErr> HostClient<WireErr>
where
WireErr: DeserializeOwned + Schema,
{
pub fn try_new_raw_nusb<F: FnMut(&DeviceInfo) -> bool>(
func: F,
err_uri_path: &str,
outgoing_depth: usize,
seq_no_kind: VarSeqKind,
) -> Result<Self, String> {
let x = nusb::list_devices()
.map_err(|e| format!("Error listing devices: {e:?}"))?
.find(func)
.ok_or_else(|| String::from("Failed to find matching nusb device!"))?;
let dev = x
.open()
.map_err(|e| format!("Failed opening device: {e:?}"))?;
let interface = dev
.claim_interface(0)
.map_err(|e| format!("Failed claiming interface: {e:?}"))?;
let boq = interface.bulk_out_queue(BULK_OUT_EP);
let biq = interface.bulk_in_queue(BULK_IN_EP);
Ok(HostClient::new_with_wire(
NusbWireTx { boq },
NusbWireRx {
biq,
consecutive_errs: 0,
},
NusbSpawn,
seq_no_kind,
err_uri_path,
outgoing_depth,
))
}
pub fn new_raw_nusb<F: FnMut(&DeviceInfo) -> bool>(
func: F,
err_uri_path: &str,
outgoing_depth: usize,
seq_no_kind: VarSeqKind,
) -> Self {
Self::try_new_raw_nusb(func, err_uri_path, outgoing_depth, seq_no_kind)
.expect("should have found nusb device")
}
}
/// [`WireSpawn`] implementation that runs background futures as Tokio tasks.
struct NusbSpawn;

impl WireSpawn for NusbSpawn {
    /// Spawns `fut` as a detached Tokio task.
    ///
    /// The returned `JoinHandle` is discarded immediately. Dropping a Tokio join
    /// handle does not cancel the task, so the task keeps running but its result
    /// is never observed. `tokio::task::spawn` panics if no Tokio runtime is active.
    fn spawn(&mut self, fut: impl Future<Output = ()> + Send + 'static) {
        let _ = tokio::task::spawn(fut);
    }
}
struct NusbWireTx {
boq: Queue<Vec<u8>>,
}
#[derive(thiserror::Error, Debug)]
enum NusbWireTxError {
#[error("Transfer Error on Send")]
Transfer(#[from] TransferError),
}
impl WireTx for NusbWireTx {
type Error = NusbWireTxError;
#[inline]
fn send(&mut self, data: Vec<u8>) -> impl Future<Output = Result<(), Self::Error>> + Send {
self.send_inner(data)
}
}
impl NusbWireTx {
async fn send_inner(&mut self, data: Vec<u8>) -> Result<(), NusbWireTxError> {
self.boq.submit(data);
let send_res = self.boq.next_complete().await;
if let Err(e) = send_res.status {
tracing::error!("Output Queue Error: {e:?}");
return Err(e.into());
}
Ok(())
}
}
struct NusbWireRx {
biq: Queue<RequestBuffer>,
consecutive_errs: usize,
}
#[derive(thiserror::Error, Debug)]
enum NusbWireRxError {
#[error("Transfer Error on Recv")]
Transfer(#[from] TransferError),
}
impl WireRx for NusbWireRx {
type Error = NusbWireRxError;
#[inline]
fn receive(&mut self) -> impl Future<Output = Result<Vec<u8>, Self::Error>> + Send {
self.recv_inner()
}
}
impl NusbWireRx {
async fn recv_inner(&mut self) -> Result<Vec<u8>, NusbWireRxError> {
loop {
let pending = self.biq.pending();
for _ in 0..(IN_FLIGHT_REQS.saturating_sub(pending)) {
self.biq.submit(RequestBuffer::new(MAX_TRANSFER_SIZE));
}
let res = self.biq.next_complete().await;
if let Err(e) = res.status {
self.consecutive_errs += 1;
tracing::error!(
"In Worker error: {e:?}, consecutive: {}",
self.consecutive_errs
);
let recoverable = match e {
TransferError::Stall | TransferError::Unknown => {
self.consecutive_errs <= MAX_STALL_RETRIES
}
TransferError::Cancelled => false,
TransferError::Disconnected => false,
TransferError::Fault => false,
};
let fatal = if recoverable {
tracing::warn!("Attempting stall recovery!");
self.biq.cancel_all();
tracing::info!("Cancelled all in-flight requests");
for _ in 0..(IN_FLIGHT_REQS - 1) {
let res = self.biq.next_complete().await;
tracing::info!("Drain state: {:?}", res.status);
}
match self.biq.clear_halt() {
Ok(()) => false,
Err(e) => {
tracing::error!("Failed to clear stall: {e:?}, Fatal.");
true
}
}
} else {
tracing::error!(
"Giving up after {} errors in a row, final error: {e:?}",
self.consecutive_errs
);
true
};
if fatal {
tracing::error!("Fatal Error, exiting");
return Err(e.into());
} else {
tracing::info!("Potential recovery, resuming NusbWireRx::recv_inner");
continue;
}
}
if self.consecutive_errs != 0 {
tracing::info!("Clearing consecutive error counter after good header decode");
self.consecutive_errs = 0;
}
return Ok(res.data);
}
}
}