use std::{
future::Future,
marker::PhantomData,
sync::{
atomic::{AtomicU32, Ordering},
Arc,
},
};
use maitake_sync::{
wait_map::{WaitError, WakeOutcome},
WaitMap,
};
use postcard_schema::Schema;
use serde::{de::DeserializeOwned, Serialize};
use tokio::{
select,
sync::mpsc::{Receiver, Sender},
};
use crate::{Endpoint, Key, Topic, WireHeader};
use self::util::Stopper;
#[cfg(all(feature = "raw-nusb", not(target_family = "wasm")))]
mod raw_nusb;
#[cfg(all(feature = "cobs-serial", not(target_family = "wasm")))]
mod serial;
#[cfg(all(feature = "webusb", target_family = "wasm"))]
pub mod webusb;
pub(crate) mod util;
/// Errors returned by [`HostClient`] request methods such as `send_resp`.
///
/// `WireErr` is the type the remote side encodes on the client's error key.
#[derive(Debug, PartialEq)]
pub enum HostErr<WireErr> {
    /// The remote answered on the error key; this is the decoded payload.
    Wire(WireErr),
    /// Reserved for replies that could not be interpreted (not constructed in
    /// this module).
    BadResponse,
    /// Serializing or deserializing a message with postcard failed.
    Postcard(postcard::Error),
    /// The client was closed, or the channel / wait map it relies on is gone.
    Closed,
}

impl<WireErr: core::fmt::Debug> core::fmt::Display for HostErr<WireErr> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            HostErr::Wire(e) => write!(f, "remote returned an error: {e:?}"),
            HostErr::BadResponse => f.write_str("received a bad response"),
            HostErr::Postcard(e) => write!(f, "postcard (de)serialization failed: {e}"),
            HostErr::Closed => f.write_str("host client closed"),
        }
    }
}

// Lets `HostErr` be boxed as `dyn Error` and propagated with `?` into generic
// error types; `Debug` comes from the derive above.
impl<WireErr: core::fmt::Debug> std::error::Error for HostErr<WireErr> {}
impl<T> From<postcard::Error> for HostErr<T> {
fn from(value: postcard::Error) -> Self {
Self::Postcard(value)
}
}
impl<T> From<WaitError> for HostErr<T> {
fn from(_: WaitError) -> Self {
Self::Closed
}
}
// ---- wasm targets: no `Send` requirements (single-threaded executors) ----

/// Transmit half of a transport (wasm flavor).
#[cfg(target_family = "wasm")]
pub trait WireTx: 'static {
    /// Error reported when sending fails.
    type Error: std::error::Error;

    /// Send one already-encoded frame.
    fn send(&mut self, data: Vec<u8>) -> impl Future<Output = Result<(), Self::Error>>;
}

/// Receive half of a transport (wasm flavor).
#[cfg(target_family = "wasm")]
pub trait WireRx: 'static {
    /// Error reported when receiving fails.
    type Error: std::error::Error;

    /// Receive the next frame as raw bytes.
    fn receive(&mut self) -> impl Future<Output = Result<Vec<u8>, Self::Error>>;
}

/// Task spawner used to run transport workers (wasm flavor).
#[cfg(target_family = "wasm")]
pub trait WireSpawn: 'static {
    /// Start `fut` running in the background.
    fn spawn(&mut self, fut: impl Future<Output = ()> + 'static);
}

// ---- every other target: implementors and their futures must be `Send` ----

/// Transmit half of a transport.
#[cfg(not(target_family = "wasm"))]
pub trait WireTx: Send + 'static {
    /// Error reported when sending fails.
    type Error: std::error::Error;

    /// Send one already-encoded frame.
    fn send(&mut self, data: Vec<u8>) -> impl Future<Output = Result<(), Self::Error>> + Send;
}

/// Receive half of a transport.
#[cfg(not(target_family = "wasm"))]
pub trait WireRx: Send + 'static {
    /// Error reported when receiving fails.
    type Error: std::error::Error;

    /// Receive the next frame as raw bytes.
    fn receive(&mut self) -> impl Future<Output = Result<Vec<u8>, Self::Error>> + Send;
}

/// Task spawner used to run transport workers.
#[cfg(not(target_family = "wasm"))]
pub trait WireSpawn: 'static {
    /// Start `fut` running in the background.
    fn spawn(&mut self, fut: impl Future<Output = ()> + Send + 'static);
}
/// Cloneable handle for sending requests, publishing and subscribing over a
/// wire driven by a separate worker (see [`WireContext`]).
///
/// `WireErr` is the type the remote side sends on the error key.
pub struct HostClient<WireErr> {
    /// State shared with the wire side: the response wait map and the
    /// sequence-number counter.
    ctx: Arc<HostContext>,
    /// Outgoing frames; the receiver is `WireContext::outgoing`.
    out: Sender<RpcFrame>,
    /// Subscription registrations; the receiver is `WireContext::new_subs`.
    subber: Sender<SubInfo>,
    /// Key on which the remote reports errors, derived from `WireErr`.
    err_key: Key,
    /// Shutdown signal; every clone gets `stopper.clone()`.
    stopper: Stopper,
    /// `fn() -> WireErr` marks the type parameter without storing a value
    /// and without making `HostClient`'s `Send`/`Sync` depend on `WireErr`.
    _pd: PhantomData<fn() -> WireErr>,
}
impl<WireErr> HostClient<WireErr>
where
    WireErr: DeserializeOwned + Schema,
{
    /// Create a client plus the [`WireContext`] a transport worker must drive.
    ///
    /// `err_uri_path` is the path whose key (computed for `WireErr`) the
    /// remote uses for error replies. `outgoing_depth` is the capacity of both
    /// the outgoing-frame channel and the new-subscription channel.
    ///
    /// # Panics
    ///
    /// Panics if `outgoing_depth` is 0 (tokio's `mpsc::channel` rejects a zero
    /// capacity).
    #[deprecated = "HostClient::new_manual will become private in the future, use HostClient::new_with_wire instead"]
    pub fn new_manual(err_uri_path: &str, outgoing_depth: usize) -> (Self, WireContext) {
        Self::new_manual_priv(err_uri_path, outgoing_depth)
    }

    /// Non-deprecated implementation of [`HostClient::new_manual`] for use
    /// inside the crate.
    ///
    /// # Panics
    ///
    /// Panics if `outgoing_depth` is 0.
    pub(crate) fn new_manual_priv(
        err_uri_path: &str,
        outgoing_depth: usize,
    ) -> (Self, WireContext) {
        let (tx_pc, rx_pc) = tokio::sync::mpsc::channel(outgoing_depth);
        let (tx_si, rx_si) = tokio::sync::mpsc::channel(outgoing_depth);

        let ctx = Arc::new(HostContext {
            map: WaitMap::new(),
            seq: AtomicU32::new(0),
        });

        let err_key = Key::for_path::<WireErr>(err_uri_path);

        let me = HostClient {
            ctx: ctx.clone(),
            out: tx_pc,
            err_key,
            _pd: PhantomData,
            // The client is the only holder of this sender; no clone needed.
            subber: tx_si,
            stopper: Stopper::new(),
        };

        let wire = WireContext {
            outgoing: rx_pc,
            incoming: ctx,
            new_subs: rx_si,
        };

        (me, wire)
    }
}
impl<WireErr> HostClient<WireErr>
where
    WireErr: DeserializeOwned + Schema,
{
    /// Send a typed request to endpoint `E` and wait for its typed response.
    ///
    /// A fresh sequence number is taken from the shared counter, `t` is
    /// serialized with postcard, and the reply body is decoded as
    /// `E::Response`.
    ///
    /// # Errors
    ///
    /// * [`HostErr::Wire`] if the remote replied on the error key.
    /// * [`HostErr::Postcard`] if the reply body could not be decoded.
    /// * [`HostErr::Closed`] if the client was closed, the outgoing channel
    ///   is gone, or the wait map reported an error.
    ///
    /// # Panics
    ///
    /// Panics if postcard fails to serialize `t`.
    pub async fn send_resp<E: Endpoint>(
        &self,
        t: &E::Request,
    ) -> Result<E::Response, HostErr<WireErr>>
    where
        E::Request: Serialize + Schema,
        E::Response: DeserializeOwned + Schema,
    {
        let seq_no = self.ctx.seq.fetch_add(1, Ordering::Relaxed);
        let msg = postcard::to_stdvec(t).expect("Allocations should not ever fail");
        let frame = RpcFrame {
            header: WireHeader {
                key: E::REQ_KEY,
                seq_no,
            },
            body: msg,
        };
        let frame = self.send_resp_raw(frame, E::RESP_KEY).await?;
        let r = postcard::from_bytes::<E::Response>(&frame.body)?;
        Ok(r)
    }

    /// Send a pre-built request frame and wait for the matching reply.
    ///
    /// The reply is matched on `rqst.header.seq_no` combined with either
    /// `resp_key` (success) or the client's error key (failure). On success
    /// the returned frame carries `resp_key`, the request's `seq_no`, and the
    /// reply body.
    ///
    /// # Errors
    ///
    /// * [`HostErr::Wire`] if the remote replied on the error key.
    /// * [`HostErr::Postcard`] if an error reply could not be decoded as
    ///   `WireErr`.
    /// * [`HostErr::Closed`] if the client is closed at any point before a
    ///   reply arrives (including while the request is still queued), if the
    ///   outgoing channel is closed, or on any `WaitError` from the wait map.
    pub async fn send_resp_raw(
        &self,
        rqst: RpcFrame,
        resp_key: Key,
    ) -> Result<RpcFrame, HostErr<WireErr>> {
        // Race the whole operation (registration, send, wait) against shutdown,
        // matching `publish_raw` and `subscribe`; previously a full outgoing
        // channel could block forever even after `close()`.
        let cancel_fut = self.stopper.wait_stopped();
        let operate_fut = self.send_resp_raw_inner(rqst, resp_key);
        select! {
            _ = cancel_fut => Err(HostErr::Closed),
            res = operate_fut => res,
        }
    }

    /// Register both reply waiters, send `rqst`, then await whichever reply
    /// (success or error key) arrives first. Cancellation is handled by
    /// `send_resp_raw`.
    async fn send_resp_raw_inner(
        &self,
        rqst: RpcFrame,
        resp_key: Key,
    ) -> Result<RpcFrame, HostErr<WireErr>> {
        let seq_no = rqst.header.seq_no;
        let ok_resp = self.ctx.map.wait(WireHeader {
            seq_no,
            key: resp_key,
        });
        let err_resp = self.ctx.map.wait(WireHeader {
            seq_no,
            key: self.err_key,
        });
        let mut ok_resp = std::pin::pin!(ok_resp);
        let mut err_resp = std::pin::pin!(err_resp);

        // `WaitMap::wait` only inserts its entry when first polled. Make sure
        // both entries are in the map *before* the request leaves; otherwise a
        // fast reply can reach `HostContext::process` while nobody is waiting,
        // get dropped as `NoMatch`, and leave this call hanging forever.
        ok_resp.as_mut().subscribe().await?;
        err_resp.as_mut().subscribe().await?;

        self.out.send(rqst).await.map_err(|_| HostErr::Closed)?;

        select! {
            o = ok_resp => {
                let resp = o?;
                Ok(RpcFrame { header: WireHeader { key: resp_key, seq_no }, body: resp })
            },
            e = err_resp => {
                let resp = e?;
                let r = postcard::from_bytes::<WireErr>(&resp)?;
                Err(HostErr::Wire(r))
            },
        }
    }

    /// Publish `msg` on topic `T` using the caller-chosen `seq_no`.
    ///
    /// # Errors
    ///
    /// [`IoClosed`] if the client is closed or the outgoing channel is gone.
    ///
    /// # Panics
    ///
    /// Panics if postcard fails to serialize `msg`.
    pub async fn publish<T: Topic>(&self, seq_no: u32, msg: &T::Message) -> Result<(), IoClosed>
    where
        T::Message: Serialize,
    {
        let smsg = postcard::to_stdvec(msg).expect("alloc should never fail");
        let frame = RpcFrame {
            header: WireHeader {
                key: T::TOPIC_KEY,
                seq_no,
            },
            body: smsg,
        };
        self.publish_raw(frame).await
    }

    /// Queue a pre-built frame for transmission without waiting for a reply.
    ///
    /// # Errors
    ///
    /// [`IoClosed`] if the client is closed before the frame is queued or the
    /// outgoing channel is gone.
    pub async fn publish_raw(&self, frame: RpcFrame) -> Result<(), IoClosed> {
        let cancel_fut = self.stopper.wait_stopped();
        let operate_fut = self.out.send(frame);
        select! {
            _ = cancel_fut => Err(IoClosed),
            res = operate_fut => res.map_err(|_| IoClosed),
        }
    }

    /// Subscribe to topic `T`, buffering up to `depth` undelivered frames.
    ///
    /// A [`SubInfo`] for `T::TOPIC_KEY` is handed to the wire side through the
    /// `new_subs` channel; the returned [`Subscription`] decodes the frames the
    /// wire side forwards to it.
    ///
    /// # Errors
    ///
    /// [`IoClosed`] if the client is closed or the subscription channel is gone.
    pub async fn subscribe<T: Topic>(
        &self,
        depth: usize,
    ) -> Result<Subscription<T::Message>, IoClosed>
    where
        T::Message: DeserializeOwned,
    {
        let cancel_fut = self.stopper.wait_stopped();
        let operate_fut = self.subscribe_inner::<T>(depth);
        select! {
            _ = cancel_fut => Err(IoClosed),
            res = operate_fut => res,
        }
    }

    /// Uncancellable core of [`HostClient::subscribe`].
    async fn subscribe_inner<T: Topic>(
        &self,
        depth: usize,
    ) -> Result<Subscription<T::Message>, IoClosed>
    where
        T::Message: DeserializeOwned,
    {
        let (tx, rx) = tokio::sync::mpsc::channel(depth);
        self.subber
            .send(SubInfo {
                key: T::TOPIC_KEY,
                tx,
            })
            .await
            .map_err(|_| IoClosed)?;
        Ok(Subscription {
            rx,
            _pd: PhantomData,
        })
    }

    /// Subscribe to raw frames carrying `key`, buffering up to `depth` frames.
    ///
    /// # Errors
    ///
    /// [`IoClosed`] if the client is closed or the subscription channel is gone.
    pub async fn subscribe_raw(&self, key: Key, depth: usize) -> Result<RawSubscription, IoClosed> {
        let cancel_fut = self.stopper.wait_stopped();
        let operate_fut = self.subscribe_inner_raw(key, depth);
        select! {
            _ = cancel_fut => Err(IoClosed),
            res = operate_fut => res,
        }
    }

    /// Uncancellable core of [`HostClient::subscribe_raw`].
    async fn subscribe_inner_raw(
        &self,
        key: Key,
        depth: usize,
    ) -> Result<RawSubscription, IoClosed> {
        let (tx, rx) = tokio::sync::mpsc::channel(depth);
        self.subber
            .send(SubInfo { key, tx })
            .await
            .map_err(|_| IoClosed)?;
        Ok(RawSubscription { rx })
    }

    /// Trigger this client's stopper. In-flight and later operations in this
    /// impl that race against it resolve to `HostErr::Closed` / `IoClosed`.
    pub fn close(&self) {
        self.stopper.stop()
    }

    /// Whether the stopper has been triggered.
    pub fn is_closed(&self) -> bool {
        self.stopper.is_stopped()
    }

    /// Wait until the stopper is triggered.
    pub async fn wait_closed(&self) {
        self.stopper.wait_stopped().await;
    }
}
/// Receives undecoded [`RpcFrame`]s for a key registered with
/// `HostClient::subscribe_raw`.
pub struct RawSubscription {
    /// Receiving end of the channel whose sender was handed to the wire side.
    rx: Receiver<RpcFrame>,
}
impl RawSubscription {
pub async fn recv(&mut self) -> Option<RpcFrame> {
self.rx.recv().await
}
}
/// Typed subscription returned by `HostClient::subscribe`; frames are decoded
/// into `M` on receipt.
pub struct Subscription<M> {
    /// Raw frames forwarded by the wire side.
    rx: Receiver<RpcFrame>,
    /// Records the message type without storing one.
    _pd: PhantomData<M>,
}
impl<M> Subscription<M>
where
    M: DeserializeOwned,
{
    /// Wait for the next message that decodes as `M`.
    ///
    /// Frames whose body fails to deserialize are silently skipped. Returns
    /// `None` only when the underlying channel is closed and drained.
    pub async fn recv(&mut self) -> Option<M> {
        while let Some(frame) = self.rx.recv().await {
            match postcard::from_bytes::<M>(&frame.body) {
                Ok(msg) => return Some(msg),
                Err(_) => continue,
            }
        }
        None
    }
}
impl<WireErr> Clone for HostClient<WireErr> {
fn clone(&self) -> Self {
Self {
ctx: self.ctx.clone(),
out: self.out.clone(),
err_key: self.err_key,
_pd: PhantomData,
subber: self.subber.clone(),
stopper: self.stopper.clone(),
}
}
}
/// Subscription registration sent from [`HostClient`] to the wire side over
/// `WireContext::new_subs`.
pub struct SubInfo {
    /// Key of the frames the subscriber wants.
    pub key: Key,
    /// Channel the matching frames should be forwarded into.
    pub tx: Sender<RpcFrame>,
}
/// Wire-facing half created together with a [`HostClient`]; a transport
/// worker is expected to drive these endpoints.
pub struct WireContext {
    /// Frames queued by the client for transmission.
    pub outgoing: Receiver<RpcFrame>,
    /// Shared context; received frames go to `HostContext::process` or
    /// `HostContext::process_did_wake` to complete pending requests.
    pub incoming: Arc<HostContext>,
    /// Subscription registrations made through the client.
    pub new_subs: Receiver<SubInfo>,
}
/// One message on the wire: a [`WireHeader`] plus its body bytes.
pub struct RpcFrame {
    /// Key and sequence number identifying the message.
    pub header: WireHeader,
    /// Payload bytes, excluding the header (postcard-encoded for typed
    /// messages).
    pub body: Vec<u8>,
}
impl RpcFrame {
    /// Encode the frame: the postcard-serialized header immediately followed
    /// by the raw body bytes.
    ///
    /// # Panics
    ///
    /// Panics if postcard fails to serialize the header.
    pub fn to_bytes(&self) -> Vec<u8> {
        let header = postcard::to_stdvec(&self.header).expect("Alloc should never fail");
        [header.as_slice(), self.body.as_slice()].concat()
    }
}
/// State shared between [`HostClient`] handles and the wire side.
pub struct HostContext {
    /// Pending replies, keyed by the full header (key + seq_no) they expect.
    map: WaitMap<WireHeader, Vec<u8>>,
    /// Sequence-number counter, incremented by `HostClient::send_resp`.
    seq: AtomicU32,
}
/// Returned when an operation cannot complete because the client was closed
/// or the channel to the wire side is gone.
#[derive(Debug)]
pub struct IoClosed;

impl core::fmt::Display for IoClosed {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("host client I/O is closed")
    }
}

// Allows `IoClosed` to be boxed as `dyn Error` or converted by `?` into
// generic error types.
impl std::error::Error for IoClosed {}
/// Error from feeding a received frame into [`HostContext`].
#[derive(Debug, PartialEq)]
pub enum ProcessError {
    /// The context's wait map reported that it is closed.
    Closed,
}

impl core::fmt::Display for ProcessError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            ProcessError::Closed => f.write_str("cannot process frame: host context closed"),
        }
    }
}

// Allows `ProcessError` to be boxed as `dyn Error` or converted by `?` into
// generic error types.
impl std::error::Error for ProcessError {}
impl HostContext {
pub fn process_did_wake(&self, frame: RpcFrame) -> Result<bool, ProcessError> {
match self.map.wake(&frame.header, frame.body) {
WakeOutcome::Woke => Ok(true),
WakeOutcome::NoMatch(_) => Ok(false),
WakeOutcome::Closed(_) => Err(ProcessError::Closed),
}
}
pub fn process(&self, frame: RpcFrame) -> Result<(), ProcessError> {
if let WakeOutcome::Closed(_) = self.map.wake(&frame.header, frame.body) {
Err(ProcessError::Closed)
} else {
Ok(())
}
}
}