use std::future::Future;
use nusb::{
transfer::{Direction, EndpointType, Queue, RequestBuffer, TransferError},
DeviceInfo,
};
use postcard_schema::Schema;
use serde::de::DeserializeOwned;
use crate::{
header::VarSeqKind,
host_client::{HostClient, WireRx, WireSpawn, WireTx},
};
/// Size, in bytes, of every read buffer submitted to the bulk IN queue.
pub(crate) const MAX_TRANSFER_SIZE: usize = 1 << 10;
/// Number of bulk IN reads kept queued at the same time.
pub(crate) const IN_FLIGHT_REQS: usize = 4;
/// Consecutive `Stall`/`Unknown` receive errors tolerated before giving up.
pub(crate) const MAX_STALL_RETRIES: usize = 10;
impl<WireErr> HostClient<WireErr>
where
    WireErr: DeserializeOwned + Schema,
{
    /// Try to create a [`HostClient`] that talks to a USB device over raw bulk
    /// endpoints using `nusb`.
    ///
    /// The first device for which `func` returns `true` is opened. On
    /// non-Windows targets the first interface whose class is `0xFF` is
    /// claimed; on Windows interface `0` is used.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message if devices cannot be listed, no device
    /// matches `func`, no interface with class `0xFF` exists (non-Windows), or
    /// opening/claiming the device fails (see
    /// [`Self::try_from_nusb_and_interface`]).
    pub fn try_new_raw_nusb<F: FnMut(&DeviceInfo) -> bool>(
        func: F,
        err_uri_path: &str,
        outgoing_depth: usize,
        seq_no_kind: VarSeqKind,
    ) -> Result<Self, String> {
        let x = nusb::list_devices()
            .map_err(|e| format!("Error listing devices: {e:?}"))?
            .find(func)
            .ok_or_else(|| String::from("Failed to find matching nusb device!"))?;
        // Use the interface's own number rather than its position in the
        // enumeration: the two only coincide when the OS lists interfaces
        // contiguously from 0 in ascending order.
        #[cfg(not(target_os = "windows"))]
        let interface_id = x
            .interfaces()
            .find(|i| i.class() == 0xFF)
            .map(|i| usize::from(i.interface_number()))
            .ok_or_else(|| String::from("Failed to find matching interface!!"))?;
        #[cfg(target_os = "windows")]
        let interface_id = 0;
        Self::try_from_nusb_and_interface(
            &x,
            interface_id,
            err_uri_path,
            outgoing_depth,
            seq_no_kind,
        )
    }

    /// Like [`Self::try_new_raw_nusb`], but the interface to claim is chosen
    /// by `interface_func` instead of by class.
    ///
    /// The first device accepted by `device_func` is opened, and the first of
    /// its interfaces accepted by `interface_func` is claimed. Only available
    /// on non-Windows targets.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message if devices cannot be listed, no device
    /// or interface matches, or opening/claiming the device fails.
    #[cfg(not(target_os = "windows"))]
    pub fn try_new_raw_nusb_with_interface<
        F1: FnMut(&DeviceInfo) -> bool,
        F2: FnMut(&nusb::InterfaceInfo) -> bool,
    >(
        device_func: F1,
        mut interface_func: F2,
        err_uri_path: &str,
        outgoing_depth: usize,
        seq_no_kind: VarSeqKind,
    ) -> Result<Self, String> {
        let x = nusb::list_devices()
            .map_err(|e| format!("Error listing devices: {e:?}"))?
            .find(device_func)
            .ok_or_else(|| String::from("Failed to find matching nusb device!"))?;
        // As above: claim by interface number, not by enumeration index.
        let interface_id = x
            .interfaces()
            .find(|i| interface_func(*i))
            .map(|i| usize::from(i.interface_number()))
            .ok_or_else(|| String::from("Failed to find matching interface!!"))?;
        Self::try_from_nusb_and_interface(
            &x,
            interface_id,
            err_uri_path,
            outgoing_depth,
            seq_no_kind,
        )
    }

    /// Open `dev`, claim interface number `interface_id`, and build a
    /// [`HostClient`] on top of its bulk IN and bulk OUT endpoints.
    ///
    /// Every alternate setting of the interface is scanned for bulk
    /// endpoints. If several exist in the same direction, the last one seen is
    /// used. The smallest bulk OUT max packet size is recorded so that outgoing
    /// frames can be terminated with a zero-length packet when required.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message if `interface_id` does not fit in a USB
    /// interface number (`u8`), the device cannot be opened, the interface
    /// cannot be claimed, or no bulk OUT / bulk IN endpoint is found.
    pub fn try_from_nusb_and_interface(
        dev: &DeviceInfo,
        interface_id: usize,
        err_uri_path: &str,
        outgoing_depth: usize,
        seq_no_kind: VarSeqKind,
    ) -> Result<Self, String> {
        // USB interface numbers are 8-bit; refuse out-of-range ids instead of
        // silently truncating them and claiming an unrelated interface.
        let interface_number = u8::try_from(interface_id).map_err(|_| {
            format!("Interface id {interface_id} is not a valid USB interface number")
        })?;
        let dev = dev
            .open()
            .map_err(|e| format!("Failed opening device: {e:?}"))?;
        let interface = dev
            .claim_interface(interface_number)
            .map_err(|e| format!("Failed claiming interface: {e:?}"))?;

        let mut mps: Option<usize> = None;
        let mut ep_in: Option<u8> = None;
        let mut ep_out: Option<u8> = None;
        for ias in interface.descriptors() {
            for ep in ias
                .endpoints()
                .filter(|e| e.transfer_type() == EndpointType::Bulk)
            {
                match ep.direction() {
                    Direction::Out => {
                        // Keep the smallest OUT packet size seen so far.
                        mps = Some(match mps.take() {
                            Some(old) => old.min(ep.max_packet_size()),
                            None => ep.max_packet_size(),
                        });
                        ep_out = Some(ep.address());
                    }
                    Direction::In => ep_in = Some(ep.address()),
                }
            }
        }

        if let Some(max_packet_size) = &mps {
            tracing::debug!(max_packet_size, "Detected max packet size");
        } else {
            tracing::warn!("Unable to detect Max Packet Size!");
        };

        let ep_out = ep_out.ok_or("Failed to find OUT EP")?;
        tracing::debug!("OUT EP: {ep_out}");
        let ep_in = ep_in.ok_or("Failed to find IN EP")?;
        tracing::debug!("IN EP: {ep_in}");

        let boq = interface.bulk_out_queue(ep_out);
        let biq = interface.bulk_in_queue(ep_in);
        Ok(HostClient::new_with_wire(
            NusbWireTx {
                boq,
                max_packet_size: mps,
            },
            NusbWireRx {
                biq,
                consecutive_errs: 0,
            },
            NusbSpawn,
            seq_no_kind,
            err_uri_path,
            outgoing_depth,
        ))
    }

    /// Infallible variant of [`Self::try_new_raw_nusb`].
    ///
    /// # Panics
    ///
    /// Panics if [`Self::try_new_raw_nusb`] returns an error (no matching
    /// device or interface, or the device could not be opened/claimed).
    pub fn new_raw_nusb<F: FnMut(&DeviceInfo) -> bool>(
        func: F,
        err_uri_path: &str,
        outgoing_depth: usize,
        seq_no_kind: VarSeqKind,
    ) -> Self {
        Self::try_new_raw_nusb(func, err_uri_path, outgoing_depth, seq_no_kind)
            .expect("should have found nusb device")
    }
}
/// [`WireSpawn`] implementation that hands background futures to Tokio.
struct NusbSpawn;

impl WireSpawn for NusbSpawn {
    fn spawn(&mut self, fut: impl Future<Output = ()> + Send + 'static) {
        // The task is detached: dropping the JoinHandle does not cancel it.
        let handle = tokio::task::spawn(fut);
        drop(handle);
    }
}
struct NusbWireTx {
boq: Queue<Vec<u8>>,
max_packet_size: Option<usize>,
}
#[derive(thiserror::Error, Debug)]
enum NusbWireTxError {
#[error("Transfer Error on Send")]
Transfer(#[from] TransferError),
}
impl WireTx for NusbWireTx {
type Error = NusbWireTxError;
#[inline]
fn send(&mut self, data: Vec<u8>) -> impl Future<Output = Result<(), Self::Error>> + Send {
self.send_inner(data)
}
}
impl NusbWireTx {
    /// Send one frame as a bulk OUT transfer, followed by a zero-length packet
    /// (ZLP) when one is needed to mark the end of the frame.
    ///
    /// A ZLP is appended when the frame length is an exact multiple of the
    /// max packet size, or when the max packet size is unknown (or reported as
    /// zero). An empty frame is already a zero-length transfer and never gets a
    /// second one.
    ///
    /// If any transfer fails, every transfer still queued (e.g. the ZLP) is
    /// cancelled and reaped before returning. Otherwise the next call would
    /// pick up this call's stale completion as its own.
    ///
    /// # Errors
    ///
    /// Returns [`NusbWireTxError::Transfer`] with the first failed transfer's
    /// error.
    async fn send_inner(&mut self, data: Vec<u8>) -> Result<(), NusbWireTxError> {
        let needs_zlp = match self.max_packet_size {
            _ if data.is_empty() => false,
            Some(mps) if mps != 0 => data.len() % mps == 0,
            // Unknown (or nonsensical zero) packet size: always terminate.
            _ => true,
        };

        self.boq.submit(data);
        if needs_zlp {
            self.boq.submit(Vec::new());
        }

        let submitted = if needs_zlp { 2 } else { 1 };
        for _ in 0..submitted {
            let completion = self.boq.next_complete().await;
            if let Err(e) = completion.status {
                tracing::error!("Output Queue Error: {e:?}");
                self.abandon_pending().await;
                return Err(e.into());
            }
        }
        Ok(())
    }

    /// Cancel every transfer still in flight on the OUT queue and wait for each
    /// one to complete, leaving no pending transfers behind.
    async fn abandon_pending(&mut self) {
        self.boq.cancel_all();
        while self.boq.pending() > 0 {
            let _ = self.boq.next_complete().await;
        }
    }
}
struct NusbWireRx {
biq: Queue<RequestBuffer>,
consecutive_errs: usize,
}
#[derive(thiserror::Error, Debug)]
enum NusbWireRxError {
#[error("Transfer Error on Recv")]
Transfer(#[from] TransferError),
}
impl WireRx for NusbWireRx {
type Error = NusbWireRxError;
#[inline]
fn receive(&mut self) -> impl Future<Output = Result<Vec<u8>, Self::Error>> + Send {
self.recv_inner()
}
}
impl NusbWireRx {
async fn recv_inner(&mut self) -> Result<Vec<u8>, NusbWireRxError> {
loop {
let pending = self.biq.pending();
for _ in 0..(IN_FLIGHT_REQS.saturating_sub(pending)) {
self.biq.submit(RequestBuffer::new(MAX_TRANSFER_SIZE));
}
let res = self.biq.next_complete().await;
if let Err(e) = res.status {
self.consecutive_errs += 1;
tracing::error!(
"In Worker error: {e:?}, consecutive: {}",
self.consecutive_errs
);
let recoverable = match e {
TransferError::Stall | TransferError::Unknown => {
self.consecutive_errs <= MAX_STALL_RETRIES
}
TransferError::Cancelled => false,
TransferError::Disconnected => false,
TransferError::Fault => false,
};
let fatal = if recoverable {
tracing::warn!("Attempting stall recovery!");
self.biq.cancel_all();
tracing::info!("Cancelled all in-flight requests");
for _ in 0..(IN_FLIGHT_REQS - 1) {
let res = self.biq.next_complete().await;
tracing::info!("Drain state: {:?}", res.status);
}
match self.biq.clear_halt() {
Ok(()) => false,
Err(e) => {
tracing::error!("Failed to clear stall: {e:?}, Fatal.");
true
}
}
} else {
tracing::error!(
"Giving up after {} errors in a row, final error: {e:?}",
self.consecutive_errs
);
true
};
if fatal {
tracing::error!("Fatal Error, exiting");
return Err(e.into());
} else {
tracing::info!("Potential recovery, resuming NusbWireRx::recv_inner");
continue;
}
}
if self.consecutive_errs != 0 {
tracing::info!("Clearing consecutive error counter after good header decode");
self.consecutive_errs = 0;
}
return Ok(res.data);
}
}
}