use core::time::Duration;
use std::{
collections::HashSet,
future::Future,
marker::PhantomData,
sync::{
atomic::{AtomicU32, Ordering},
Arc, RwLock,
},
};
use maitake_sync::{
wait_map::{WaitError, WakeOutcome},
WaitMap,
};
use postcard_schema::{schema::owned::OwnedNamedType, Schema};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use tokio::{
select,
sync::{
mpsc::{Receiver, Sender},
Mutex,
},
};
use util::Subscriptions;
use crate::{
header::{VarHeader, VarKey, VarKeyKind, VarSeq, VarSeqKind},
standard_icd::{GetAllSchemaDataTopic, GetAllSchemasEndpoint, OwnedSchemaData},
Endpoint, Key, Topic, TopicDirection,
};
use self::util::Stopper;
#[cfg(all(feature = "raw-nusb", not(target_family = "wasm")))]
mod raw_nusb;
#[cfg(all(feature = "cobs-serial", not(target_family = "wasm")))]
mod serial;
#[cfg(all(feature = "webusb", target_family = "wasm"))]
pub mod webusb;
pub(crate) mod util;
#[cfg(feature = "test-utils")]
pub mod test_channels;
/// Errors returned by [`HostClient`] request/response operations.
///
/// `WireErr` is the error type the remote side sends back on the client's
/// error key (decoded in `HostClient::send_resp_raw`).
#[derive(Debug, PartialEq)]
pub enum HostErr<WireErr> {
    /// The remote replied on the error key; this is the decoded error value.
    Wire(WireErr),
    /// NOTE(review): not constructed anywhere in this section; presumably
    /// signals an unexpected reply — confirm against other callers.
    BadResponse,
    /// A postcard encode/decode failure.
    Postcard(postcard::Error),
    /// The client, its channels, or the response wait map have been closed.
    Closed,
}
impl<T> From<postcard::Error> for HostErr<T> {
fn from(value: postcard::Error) -> Self {
Self::Postcard(value)
}
}
impl<T> From<WaitError> for HostErr<T> {
fn from(_: WaitError) -> Self {
Self::Closed
}
}
#[cfg(target_family = "wasm")]
/// Transmit half of a wire transport (wasm variant: no `Send` requirement).
pub trait WireTx: 'static {
    /// Error reported by the transport when a send fails.
    type Error: std::error::Error;
    /// Sends `data` over the transport (presumably one encoded frame — confirm with implementors).
    fn send(&mut self, data: Vec<u8>) -> impl Future<Output = Result<(), Self::Error>>;
}
#[cfg(target_family = "wasm")]
/// Receive half of a wire transport (wasm variant: no `Send` requirement).
pub trait WireRx: 'static {
    /// Error reported by the transport when a receive fails.
    type Error: std::error::Error;
    /// Receives the next chunk of bytes from the transport.
    fn receive(&mut self) -> impl Future<Output = Result<Vec<u8>, Self::Error>>;
}
#[cfg(target_family = "wasm")]
/// Abstraction over spawning a `'static` future (wasm variant: future need not be `Send`).
pub trait WireSpawn: 'static {
    /// Hands `fut` to an executor.
    fn spawn(&mut self, fut: impl Future<Output = ()> + 'static);
}
#[cfg(not(target_family = "wasm"))]
/// Transmit half of a wire transport (native variant: implementor and future must be `Send`).
pub trait WireTx: Send + 'static {
    /// Error reported by the transport when a send fails.
    type Error: std::error::Error;
    /// Sends `data` over the transport (presumably one encoded frame — confirm with implementors).
    fn send(&mut self, data: Vec<u8>) -> impl Future<Output = Result<(), Self::Error>> + Send;
}
#[cfg(not(target_family = "wasm"))]
/// Receive half of a wire transport (native variant: implementor and future must be `Send`).
pub trait WireRx: Send + 'static {
    /// Error reported by the transport when a receive fails.
    type Error: std::error::Error;
    /// Receives the next chunk of bytes from the transport.
    fn receive(&mut self) -> impl Future<Output = Result<Vec<u8>, Self::Error>> + Send;
}
#[cfg(not(target_family = "wasm"))]
/// Abstraction over spawning a `'static` future (native variant: future must be `Send`).
pub trait WireSpawn: 'static {
    /// Hands `fut` to an executor.
    fn spawn(&mut self, fut: impl Future<Output = ()> + Send + 'static);
}
/// Cloneable handle for issuing requests, publishing, and subscribing to topics
/// on a remote device.
///
/// All clones share the same context, outgoing channel, subscription table and
/// stopper.
pub struct HostClient<WireErr> {
    /// State shared with the transport side (handed out as `WireContext::incoming`).
    ctx: Arc<HostContext>,
    /// Outgoing frames; the matching receiver is `WireContext::outgoing`.
    out: Sender<RpcFrame>,
    /// Active topic subscriptions, one sender per topic key.
    subscriptions: Arc<Mutex<Subscriptions>>,
    /// Key on which the remote reports errors of type `WireErr`.
    err_key: Key,
    /// Tracks whether the client has been closed.
    stopper: Stopper,
    /// Sequence-number width applied to outgoing frames.
    seq_kind: VarSeqKind,
    // `fn() -> WireErr` marks the type parameter without storing one, so the
    // auto traits (`Send`/`Sync`) of the client don't depend on `WireErr`.
    _pd: PhantomData<fn() -> WireErr>,
}
impl<WireErr> HostClient<WireErr>
where
    WireErr: DeserializeOwned + Schema,
{
    /// Builds a client together with the [`WireContext`] a transport worker drives.
    ///
    /// * `err_uri_path` – path combined with `WireErr`'s schema to derive the error key.
    /// * `outgoing_depth` – capacity of the outgoing frame channel.
    /// * `seq_kind` – sequence-number width applied to outgoing frames.
    ///
    /// The key width starts at `Key8` and the sequence counter at zero.
    pub(crate) fn new_manual_priv(
        err_uri_path: &str,
        outgoing_depth: usize,
        seq_kind: VarSeqKind,
    ) -> (Self, WireContext) {
        let shared = Arc::new(HostContext {
            kkind: RwLock::new(VarKeyKind::Key8),
            map: WaitMap::new(),
            seq: AtomicU32::new(0),
        });
        let (frames_tx, frames_rx) = tokio::sync::mpsc::channel(outgoing_depth);

        let client = HostClient {
            ctx: Arc::clone(&shared),
            out: frames_tx,
            subscriptions: Arc::new(Mutex::new(Subscriptions::default())),
            err_key: Key::for_path::<WireErr>(err_uri_path),
            stopper: Stopper::new(),
            seq_kind,
            _pd: PhantomData,
        };
        let wire_side = WireContext {
            outgoing: frames_rx,
            incoming: shared,
        };

        (client, wire_side)
    }
}
/// Errors returned by [`HostClient::get_schema_report`].
#[derive(Debug)]
pub enum SchemaError<WireErr> {
    /// Subscribing or the trigger request failed.
    Comms(HostErr<WireErr>),
    /// The background task that collects schema messages could not be joined.
    TaskError,
    /// An endpoint or topic key did not match any received type.
    InvalidReportData,
    /// The received items don't match the counts the remote reported, or the
    /// remote reported errors.
    LostData,
}
impl<WireErr> From<UnableToFindType> for SchemaError<WireErr> {
fn from(_: UnableToFindType) -> Self {
Self::InvalidReportData
}
}
impl<WireErr> HostClient<WireErr>
where
    WireErr: DeserializeOwned + Schema,
{
    /// Fetches the remote's full schema description as a [`SchemaReport`].
    ///
    /// Subscribes to [`GetAllSchemaDataTopic`] and triggers [`GetAllSchemasEndpoint`].
    /// Topic messages are collected until none arrives for 100ms. The types,
    /// endpoints and topics are then assembled and checked against the counts
    /// in the endpoint's response.
    ///
    /// # Errors
    ///
    /// * [`SchemaError::Comms`] if subscribing or the trigger request fails.
    /// * [`SchemaError::TaskError`] if the collector task cannot be joined.
    /// * [`SchemaError::InvalidReportData`] if an endpoint/topic key matches no received type.
    /// * [`SchemaError::LostData`] if counts disagree or the remote reported errors.
    ///
    /// NOTE(review): the 100ms idle timer starts before the trigger is sent, so a
    /// round trip slower than that can end collection early and surface as `LostData`.
    pub async fn get_schema_report(&self) -> Result<SchemaReport, SchemaError<WireErr>> {
        let Ok(mut sub) = self.subscribe::<GetAllSchemaDataTopic>(64).await else {
            return Err(SchemaError::Comms(HostErr::Closed));
        };

        // Drain the topic on its own task while the trigger request is in flight.
        let collect_task = tokio::task::spawn(async move {
            let mut got = vec![];
            while let Ok(Some(val)) =
                tokio::time::timeout(Duration::from_millis(100), sub.recv()).await
            {
                got.push(val);
            }
            got
        });
        let trigger_res = self.send_resp::<GetAllSchemasEndpoint>(&()).await;
        let collected = collect_task.await;
        let (resp, data) = match (trigger_res, collected) {
            (Ok(resp), Ok(data)) => (resp, data),
            // A communication failure is reported in preference to a task failure.
            (Err(e), _) => return Err(SchemaError::Comms(e)),
            (Ok(_), Err(_)) => return Err(SchemaError::TaskError),
        };

        let mut rpt = SchemaReport::default();
        // Endpoints and topics are resolved by matching keys against known types,
        // so every type must be added before any of them are processed.
        let mut e_and_t = vec![];
        for d in data {
            match d {
                OwnedSchemaData::Type(d) => rpt.add_type(d),
                e @ OwnedSchemaData::Endpoint { .. } => e_and_t.push(e),
                t @ OwnedSchemaData::Topic { .. } => e_and_t.push(t),
            }
        }
        for e in e_and_t {
            match e {
                OwnedSchemaData::Type(_) => unreachable!(),
                OwnedSchemaData::Endpoint {
                    path,
                    request_key,
                    response_key,
                } => {
                    rpt.add_endpoint(path, request_key, response_key)?;
                }
                OwnedSchemaData::Topic {
                    path,
                    key,
                    direction,
                } => match direction {
                    TopicDirection::ToServer => rpt.add_topic_in(path, key)?,
                    TopicDirection::ToClient => rpt.add_topic_out(path, key)?,
                },
            }
        }

        // Everything the remote says it sent must have been received and resolved.
        let data_matches = resp.endpoints_sent as usize == rpt.endpoints.len()
            && resp.topics_in_sent as usize == rpt.topics_in.len()
            && resp.topics_out_sent as usize == rpt.topics_out.len()
            && resp.errors == 0;
        if data_matches {
            Ok(rpt)
        } else {
            Err(SchemaError::LostData)
        }
    }

    /// Sends a request to endpoint `E` and waits for its decoded response.
    ///
    /// Each call takes a fresh sequence number from the shared counter.
    ///
    /// # Errors
    ///
    /// Everything [`HostClient::send_resp_raw`] returns. Also returns
    /// [`HostErr::Postcard`] if the reply body does not decode as `E::Response`.
    ///
    /// # Panics
    ///
    /// If serializing the request into a `Vec` fails.
    pub async fn send_resp<E: Endpoint>(
        &self,
        t: &E::Request,
    ) -> Result<E::Response, HostErr<WireErr>>
    where
        E::Request: Serialize + Schema,
        E::Response: DeserializeOwned + Schema,
    {
        let seq_no = self.ctx.seq.fetch_add(1, Ordering::Relaxed);
        let msg = postcard::to_stdvec(t).expect("Allocations should not ever fail");
        let frame = RpcFrame {
            header: VarHeader {
                key: VarKey::Key8(E::REQ_KEY),
                seq_no: VarSeq::Seq4(seq_no),
            },
            body: msg,
        };
        let frame = self.send_resp_raw(frame, E::RESP_KEY).await?;
        let r = postcard::from_bytes::<E::Response>(&frame.body)?;
        Ok(r)
    }

    /// Sends a pre-built request frame and waits for the matching reply.
    ///
    /// The frame's key is shrunk to the current key width, and its sequence
    /// number is resized to this client's `seq_kind`. Replies are matched on
    /// (sequence number, key):
    ///
    /// * a reply on `resp_key` is returned as-is;
    /// * a reply on the client's error key is decoded as `WireErr`.
    ///
    /// If a reply uses a different key width, that width is stored for later
    /// requests.
    ///
    /// # Errors
    ///
    /// * [`HostErr::Closed`] if the client is closed, the outgoing channel is
    ///   gone, or waiting on the response map fails.
    /// * [`HostErr::Wire`] if the remote answered with an error.
    /// * [`HostErr::Postcard`] if that error body fails to decode.
    pub async fn send_resp_raw(
        &self,
        mut rqst: RpcFrame,
        resp_key: Key,
    ) -> Result<RpcFrame, HostErr<WireErr>> {
        let cancel_fut = self.stopper.wait_stopped();
        let kkind: VarKeyKind = *self.ctx.kkind.read().unwrap();
        rqst.header.key.shrink_to(kkind);
        rqst.header.seq_no.resize(self.seq_kind);
        let mut resp_key = VarKey::Key8(resp_key);
        let mut err_key = VarKey::Key8(self.err_key);
        resp_key.shrink_to(kkind);
        err_key.shrink_to(kkind);

        let ok_resp = self.ctx.map.wait(VarHeader {
            seq_no: rqst.header.seq_no,
            key: resp_key,
        });
        let err_resp = self.ctx.map.wait(VarHeader {
            seq_no: rqst.header.seq_no,
            key: err_key,
        });
        // A `Wait` is not registered in the map until it is polled or subscribed.
        // Register both waiters *before* sending. Otherwise a fast reply can be
        // processed while nobody is waiting (`WakeOutcome::NoMatch`) and be
        // dropped, leaving this call hung until the client is closed.
        let mut ok_resp = core::pin::pin!(ok_resp);
        let mut err_resp = core::pin::pin!(err_resp);
        ok_resp.as_mut().subscribe().await?;
        err_resp.as_mut().subscribe().await?;

        self.out.send(rqst).await.map_err(|_| HostErr::Closed)?;

        select! {
            _c = cancel_fut => Err(HostErr::Closed),
            o = ok_resp => {
                let (hdr, resp) = o?;
                if hdr.key.kind() != kkind {
                    *self.ctx.kkind.write().unwrap() = hdr.key.kind();
                }
                Ok(RpcFrame { header: hdr, body: resp })
            },
            e = err_resp => {
                let (hdr, resp) = e?;
                if hdr.key.kind() != kkind {
                    *self.ctx.kkind.write().unwrap() = hdr.key.kind();
                }
                let r = postcard::from_bytes::<WireErr>(&resp)?;
                Err(HostErr::Wire(r))
            },
        }
    }

    /// Publishes `msg` on topic `T` using the caller-provided `seq_no`.
    ///
    /// # Errors
    ///
    /// [`IoClosed`] if the client is closed or the outgoing channel is gone.
    ///
    /// # Panics
    ///
    /// If serializing the message into a `Vec` fails.
    pub async fn publish<T: Topic>(&self, seq_no: VarSeq, msg: &T::Message) -> Result<(), IoClosed>
    where
        T::Message: Serialize,
    {
        let smsg = postcard::to_stdvec(msg).expect("alloc should never fail");
        let frame = RpcFrame {
            header: VarHeader {
                key: VarKey::Key8(T::TOPIC_KEY),
                seq_no,
            },
            body: smsg,
        };
        self.publish_raw(frame).await
    }

    /// Queues a pre-built frame for sending, without waiting for any reply.
    ///
    /// The key is shrunk to the current key width and the sequence number is
    /// resized to this client's `seq_kind`. Sending races against client closure.
    ///
    /// # Errors
    ///
    /// [`IoClosed`] if the client is closed or the outgoing channel is gone.
    pub async fn publish_raw(&self, mut frame: RpcFrame) -> Result<(), IoClosed> {
        let kkind: VarKeyKind = *self.ctx.kkind.read().unwrap();
        frame.header.key.shrink_to(kkind);
        frame.header.seq_no.resize(self.seq_kind);
        let cancel_fut = self.stopper.wait_stopped();
        let operate_fut = self.out.send(frame);
        select! {
            _ = cancel_fut => Err(IoClosed),
            res = operate_fut => res.map_err(|_| IoClosed),
        }
    }

    /// Subscribes to topic `T` with a buffer of `depth` messages.
    ///
    /// Only one subscription per topic key is kept. A new one replaces the
    /// existing sender (logging a warning).
    ///
    /// # Errors
    ///
    /// [`IoClosed`] if the client is closed or the subscription table is stopped.
    pub async fn subscribe<T: Topic>(
        &self,
        depth: usize,
    ) -> Result<Subscription<T::Message>, IoClosed>
    where
        T::Message: DeserializeOwned,
    {
        let cancel_fut = self.stopper.wait_stopped();
        let operate_fut = self.subscribe_inner::<T>(depth);
        select! {
            _ = cancel_fut => Err(IoClosed),
            res = operate_fut => res,
        }
    }

    /// Registers (or replaces) the sender for `T::TOPIC_KEY` in the subscription table.
    async fn subscribe_inner<T: Topic>(
        &self,
        depth: usize,
    ) -> Result<Subscription<T::Message>, IoClosed>
    where
        T::Message: DeserializeOwned,
    {
        let (tx, rx) = tokio::sync::mpsc::channel(depth);
        {
            let mut guard = self.subscriptions.lock().await;
            if guard.stopped {
                return Err(IoClosed);
            }
            if let Some(entry) = guard.list.iter_mut().find(|(k, _)| *k == T::TOPIC_KEY) {
                tracing::warn!("replacing subscription for topic path '{}'", T::PATH);
                entry.1 = tx;
            } else {
                guard.list.push((T::TOPIC_KEY, tx));
            }
        }
        Ok(Subscription {
            rx,
            _pd: PhantomData,
        })
    }

    /// Subscribes to raw frames on `key` with a buffer of `depth` frames.
    ///
    /// Replaces any existing subscription for the same key (logging a warning).
    ///
    /// # Errors
    ///
    /// [`IoClosed`] if the client is closed or the subscription table is stopped.
    pub async fn subscribe_raw(&self, key: Key, depth: usize) -> Result<RawSubscription, IoClosed> {
        let cancel_fut = self.stopper.wait_stopped();
        let operate_fut = self.subscribe_inner_raw(key, depth);
        select! {
            _ = cancel_fut => Err(IoClosed),
            res = operate_fut => res,
        }
    }

    /// Registers (or replaces) the sender for `key` in the subscription table.
    async fn subscribe_inner_raw(
        &self,
        key: Key,
        depth: usize,
    ) -> Result<RawSubscription, IoClosed> {
        let (tx, rx) = tokio::sync::mpsc::channel(depth);
        {
            let mut guard = self.subscriptions.lock().await;
            if guard.stopped {
                return Err(IoClosed);
            }
            if let Some(entry) = guard.list.iter_mut().find(|(k, _)| *k == key) {
                tracing::warn!("replacing subscription for raw topic key '{:?}'", key);
                entry.1 = tx;
            } else {
                guard.list.push((key, tx));
            }
        }
        Ok(RawSubscription { rx })
    }

    /// Marks the client as closed. Operations racing against the stopper then
    /// return a closed error.
    pub fn close(&self) {
        self.stopper.stop()
    }

    /// Returns whether the client has been closed.
    pub fn is_closed(&self) -> bool {
        self.stopper.is_stopped()
    }

    /// Waits until the client's stopper has been stopped.
    pub async fn wait_closed(&self) {
        self.stopper.wait_stopped().await;
    }
}
/// Subscription that yields undecoded frames for a single key.
pub struct RawSubscription {
    rx: Receiver<RpcFrame>,
}
impl RawSubscription {
pub async fn recv(&mut self) -> Option<RpcFrame> {
self.rx.recv().await
}
}
/// Typed topic subscription; each received frame body is decoded as `M`.
pub struct Subscription<M> {
    rx: Receiver<RpcFrame>,
    // Marks the message type this subscription decodes into.
    _pd: PhantomData<M>,
}
impl<M> Subscription<M>
where
    M: DeserializeOwned,
{
    /// Waits for the next message that decodes as `M`.
    ///
    /// Frames whose body fails to decode are skipped silently. Returns `None`
    /// once the underlying channel is closed and drained.
    pub async fn recv(&mut self) -> Option<M> {
        while let Some(frame) = self.rx.recv().await {
            if let Ok(msg) = postcard::from_bytes::<M>(&frame.body) {
                return Some(msg);
            }
        }
        None
    }
}
impl<WireErr> Clone for HostClient<WireErr> {
fn clone(&self) -> Self {
Self {
ctx: self.ctx.clone(),
out: self.out.clone(),
err_key: self.err_key,
_pd: PhantomData,
subscriptions: self.subscriptions.clone(),
stopper: self.stopper.clone(),
seq_kind: self.seq_kind,
}
}
}
/// Transport-facing half created alongside a [`HostClient`].
pub struct WireContext {
    /// Frames queued by the client for transmission.
    pub outgoing: Receiver<RpcFrame>,
    /// Shared context used to deliver incoming replies (see `HostContext::process`).
    pub incoming: Arc<HostContext>,
}
/// A single message: a variable-width header followed by the body bytes.
pub struct RpcFrame {
    /// Key and sequence number identifying the message.
    pub header: VarHeader,
    /// Payload bytes (postcard-encoded when built by this module's helpers).
    pub body: Vec<u8>,
}
impl RpcFrame {
    /// Encodes the frame as the serialized header immediately followed by the body.
    pub fn to_bytes(&self) -> Vec<u8> {
        let mut encoded = self.header.write_to_vec();
        encoded.extend(self.body.iter().copied());
        encoded
    }
}
/// State shared between a [`HostClient`] and the side that processes incoming frames.
pub struct HostContext {
    /// Key width currently used for outgoing keys; updated when a reply arrives
    /// using a different width.
    kkind: RwLock<VarKeyKind>,
    /// Pending reply waiters keyed by the expected reply header; woken with the
    /// reply's header and body.
    map: WaitMap<VarHeader, (VarHeader, Vec<u8>)>,
    /// Counter that request sequence numbers are drawn from.
    seq: AtomicU32,
}
/// Returned by publish/subscribe operations when the client or its channels are closed.
#[derive(Debug)]
pub struct IoClosed;
/// Error from [`HostContext`] frame processing.
#[derive(Debug, PartialEq)]
pub enum ProcessError {
    /// The response wait map has been closed.
    Closed,
}
impl HostContext {
pub fn process_did_wake(&self, frame: RpcFrame) -> Result<bool, ProcessError> {
match self.map.wake(&frame.header, (frame.header, frame.body)) {
WakeOutcome::Woke => Ok(true),
WakeOutcome::NoMatch(_) => Ok(false),
WakeOutcome::Closed(_) => Err(ProcessError::Closed),
}
}
pub fn process(&self, frame: RpcFrame) -> Result<(), ProcessError> {
if let WakeOutcome::Closed(_) = self.map.wake(&frame.header, (frame.header, frame.body)) {
Err(ProcessError::Closed)
} else {
Ok(())
}
}
}
/// Description of a remote's schema: known types plus the topics and endpoints using them.
///
/// Field order is part of the derived `Serialize`/`Schema` representation; don't reorder.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Schema)]
pub struct SchemaReport {
    /// All known types (builtins from `Default` plus those reported by the remote).
    pub types: HashSet<OwnedNamedType>,
    /// Topics flowing client -> server (`TopicDirection::ToServer`).
    pub topics_in: Vec<TopicReport>,
    /// Topics flowing server -> client (`TopicDirection::ToClient`).
    pub topics_out: Vec<TopicReport>,
    /// Request/response endpoints.
    pub endpoints: Vec<EndpointReport>,
}
impl Default for SchemaReport {
fn default() -> Self {
let mut me = Self {
types: Default::default(),
topics_in: Default::default(),
topics_out: Default::default(),
endpoints: Default::default(),
};
me.add_type(OwnedNamedType::from(<bool as Schema>::SCHEMA));
me.add_type(OwnedNamedType::from(<i8 as Schema>::SCHEMA));
me.add_type(OwnedNamedType::from(<u8 as Schema>::SCHEMA));
me.add_type(OwnedNamedType::from(<i16 as Schema>::SCHEMA));
me.add_type(OwnedNamedType::from(<i32 as Schema>::SCHEMA));
me.add_type(OwnedNamedType::from(<i64 as Schema>::SCHEMA));
me.add_type(OwnedNamedType::from(<i128 as Schema>::SCHEMA));
me.add_type(OwnedNamedType::from(<u16 as Schema>::SCHEMA));
me.add_type(OwnedNamedType::from(<u32 as Schema>::SCHEMA));
me.add_type(OwnedNamedType::from(<u64 as Schema>::SCHEMA));
me.add_type(OwnedNamedType::from(<u128 as Schema>::SCHEMA));
me.add_type(OwnedNamedType::from(<f32 as Schema>::SCHEMA));
me.add_type(OwnedNamedType::from(<f64 as Schema>::SCHEMA));
me.add_type(OwnedNamedType::from(<char as Schema>::SCHEMA));
me.add_type(OwnedNamedType::from(<String as Schema>::SCHEMA));
me.add_type(OwnedNamedType::from(<Vec<u8> as Schema>::SCHEMA));
me.add_type(OwnedNamedType::from(<() as Schema>::SCHEMA));
me.add_type(OwnedNamedType::from(<OwnedNamedType as Schema>::SCHEMA));
me
}
}
/// A topic entry in a [`SchemaReport`]; field order is part of the serialized form.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Schema)]
pub struct TopicReport {
    /// Topic path.
    pub path: String,
    /// Topic key as reported by the remote.
    pub key: Key,
    /// Message type whose key for `path` equals `key`.
    pub ty: OwnedNamedType,
}
/// An endpoint entry in a [`SchemaReport`]; field order is part of the serialized form.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Schema)]
pub struct EndpointReport {
    /// Endpoint path.
    pub path: String,
    /// Request key as reported by the remote.
    pub req_key: Key,
    /// Type whose key for `path` equals `req_key`.
    pub req_ty: OwnedNamedType,
    /// Response key as reported by the remote.
    pub resp_key: Key,
    /// Type whose key for `path` equals `resp_key`.
    pub resp_ty: OwnedNamedType,
}
/// Returned when no known type produces the requested key for a path.
#[derive(Debug)]
pub struct UnableToFindType;
impl SchemaReport {
    /// Adds `t` to the set of known types; duplicates are ignored by the set.
    pub fn add_type(&mut self, t: OwnedNamedType) {
        self.types.insert(t);
    }

    /// Records a client -> server topic, resolving its message type from `key`.
    ///
    /// # Errors
    ///
    /// [`UnableToFindType`] if no known type yields `key` for `path`.
    pub fn add_topic_in(&mut self, path: String, key: Key) -> Result<(), UnableToFindType> {
        let ty = self.type_for_key(&path, key)?;
        self.topics_in.push(TopicReport { path, key, ty });
        Ok(())
    }

    /// Records a server -> client topic, resolving its message type from `key`.
    ///
    /// # Errors
    ///
    /// [`UnableToFindType`] if no known type yields `key` for `path`.
    pub fn add_topic_out(&mut self, path: String, key: Key) -> Result<(), UnableToFindType> {
        let ty = self.type_for_key(&path, key)?;
        self.topics_out.push(TopicReport { path, key, ty });
        Ok(())
    }

    /// Records an endpoint, resolving both its request and response types.
    ///
    /// # Errors
    ///
    /// [`UnableToFindType`] if either key cannot be resolved. The request key
    /// is checked first, and nothing is recorded on failure.
    pub fn add_endpoint(
        &mut self,
        path: String,
        req_key: Key,
        resp_key: Key,
    ) -> Result<(), UnableToFindType> {
        let req_ty = self.type_for_key(&path, req_key)?;
        let resp_ty = self.type_for_key(&path, resp_key)?;
        self.endpoints.push(EndpointReport {
            path,
            req_key,
            req_ty,
            resp_key,
            resp_ty,
        });
        Ok(())
    }

    /// Returns a clone of the first known type whose key for `path` equals `key`.
    fn type_for_key(&self, path: &str, key: Key) -> Result<OwnedNamedType, UnableToFindType> {
        self.types
            .iter()
            .find(|ty| Key::for_owned_schema_path(path, ty) == key)
            .cloned()
            .ok_or(UnableToFindType)
    }
}