use std::{
marker::PhantomData,
sync::{
atomic::{AtomicU32, Ordering},
Arc,
},
};
use crate::{
accumulator::raw::{CobsAccumulator, FeedResult},
headered::{extract_header_from_bytes, to_stdvec},
Key,
};
use cobs::encode_vec;
use maitake_sync::{
wait_map::{WaitError, WakeOutcome},
WaitMap,
};
use postcard::experimental::schema::Schema;
use serde::{de::DeserializeOwned, Serialize};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
select,
sync::mpsc::{Receiver, Sender},
};
use tokio_serial::{SerialPortBuilderExt, SerialStream};
/// Errors produced by [`HostClient::send_resp`].
///
/// `WireErr` is the payload type the device sends back when it replies with
/// the client's configured error key.
#[derive(Debug, PartialEq)]
pub enum HostErr<WireErr> {
    /// The device answered with an error frame; the decoded payload is attached.
    Wire(WireErr),
    /// The response key matched neither the expected response type nor the
    /// configured error type.
    BadResponse,
    /// A postcard (de)serialization step failed.
    Postcard(postcard::Error),
    /// The connection is unusable: the request could not be queued to the
    /// background worker, or waiting for the response failed.
    Closed,
}

impl<WireErr: std::fmt::Debug> std::fmt::Display for HostErr<WireErr> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            HostErr::Wire(e) => write!(f, "device returned an error: {e:?}"),
            HostErr::BadResponse => f.write_str(
                "response key matched neither the expected response nor the error type",
            ),
            HostErr::Postcard(e) => write!(f, "postcard (de)serialization failed: {e}"),
            HostErr::Closed => f.write_str("connection to the device is closed"),
        }
    }
}

// `Debug` (derived) and `Display` (above) are all `std::error::Error` requires;
// the default `source()` is kept so no extra bound on `postcard::Error` is needed.
impl<WireErr: std::fmt::Debug> std::error::Error for HostErr<WireErr> {}
impl<T> From<postcard::Error> for HostErr<T> {
fn from(value: postcard::Error) -> Self {
Self::Postcard(value)
}
}
impl<T> From<WaitError> for HostErr<T> {
fn from(_: WaitError) -> Self {
Self::Closed
}
}
/// Background task that owns the serial port on behalf of every [`HostClient`] clone.
///
/// * Frames received on `outgoing` are COBS-encoded, terminated with a `0`
///   delimiter byte, and written to `port`.
/// * Bytes read from `port` are fed through a [`CobsAccumulator`]. Each complete
///   frame whose header decodes is handed, whole, to the waiter registered in
///   `ctx.map` under the header's sequence number. Frames that overflow the
///   accumulator or fail to decode are skipped. Wake outcomes other than
///   "map closed" are ignored.
///
/// The task stops when every sender for `outgoing` is dropped, when a write or
/// read fails, when the port reports end-of-stream, or when the map is closed.
/// On every exit path it closes `ctx.map`. As a result, requests already
/// waiting for a response fail with a `WaitError` instead of waiting forever.
async fn wire_worker(port: SerialStream, outgoing: Receiver<Vec<u8>>, ctx: Arc<HostContext>) {
    run_wire_loop(port, outgoing, &ctx).await;
    // Nothing will ever be delivered through the map again; release any waiters.
    ctx.map.close();
}

/// Drives the port I/O for [`wire_worker`]; returns when the worker should stop.
async fn run_wire_loop(
    mut port: SerialStream,
    mut outgoing: Receiver<Vec<u8>>,
    ctx: &HostContext,
) {
    let mut buf = [0u8; 1024];
    let mut acc = CobsAccumulator::<1024>::new();
    loop {
        select! {
            out = outgoing.recv() => {
                // `None`: every sender, and therefore every `HostClient`, is gone.
                let Some(msg) = out else {
                    return;
                };
                let mut frame = encode_vec(&msg);
                // COBS frames are delimited by a zero byte.
                frame.push(0);
                if port.write_all(&frame).await.is_err() {
                    return;
                }
            }
            inc = port.read(&mut buf) => {
                let used = match inc {
                    // `Ok(0)` is end-of-stream (e.g. the device went away). Looping
                    // on it would make `select!` spin on an always-ready read.
                    Ok(0) | Err(_) => return,
                    Ok(n) => n,
                };
                let mut window = &buf[..used];
                while !window.is_empty() {
                    window = match acc.feed(window) {
                        FeedResult::Consumed => break,
                        FeedResult::OverFull(rest) => rest,
                        FeedResult::DeserError(rest) => rest,
                        FeedResult::Success { data, remaining } => {
                            if let Ok((hdr, _body)) = extract_header_from_bytes(data) {
                                if let WakeOutcome::Closed(_) =
                                    ctx.map.wake(&hdr.seq_no, data.to_vec())
                                {
                                    return;
                                }
                            }
                            remaining
                        }
                    };
                }
            }
        }
    }
}
/// A cloneable handle for request/response calls to a device over a serial port.
///
/// All clones share the same background worker (spawned by `HostClient::new`)
/// and the same sequence-number counter. `WireErr` is the type the device uses
/// to report errors.
pub struct HostClient<WireErr> {
    // State shared with the worker task: pending-response map and seq counter.
    ctx: Arc<HostContext>,
    // Serialized request frames, handed to the worker for COBS-encoding and transmission.
    out: Sender<Vec<u8>>,
    // Response key that marks a reply as a `WireErr` rather than the expected type.
    err_key: Key,
    // Mentions `WireErr` without storing one; `fn() -> WireErr` is always Send + Sync,
    // so the handle's auto traits do not depend on `WireErr`.
    _pd: PhantomData<fn() -> WireErr>,
}
impl<WireErr> HostClient<WireErr>
where
WireErr: DeserializeOwned + Schema,
{
pub fn new(serial_path: &str, err_uri_path: &str) -> Self {
let (tx_pc, rx_pc) = tokio::sync::mpsc::channel(8);
let port = tokio_serial::new(serial_path, 115_200)
.open_native_async()
.unwrap();
let ctx = Arc::new(HostContext {
map: WaitMap::new(),
seq: AtomicU32::new(0),
});
tokio::task::spawn({
let ctx = ctx.clone();
async move { wire_worker(port, rx_pc, ctx).await }
});
let err_key = Key::for_path::<WireErr>(err_uri_path);
HostClient {
ctx,
out: tx_pc,
err_key,
_pd: PhantomData,
}
}
pub async fn send_resp<TX, RX>(&self, path: &str, t: TX) -> Result<RX, HostErr<WireErr>>
where
TX: Serialize + Schema,
RX: DeserializeOwned + Schema,
{
let seq_no = self.ctx.seq.fetch_add(1, Ordering::Relaxed);
let msg = to_stdvec(seq_no, path, &t).expect("Allocations should not ever fail");
self.out.send(msg).await.map_err(|_| HostErr::Closed)?;
let resp = self.ctx.map.wait(seq_no).await?;
let (hdr, body) = extract_header_from_bytes(&resp)?;
if hdr.key == Key::for_path::<RX>(path) {
let r = postcard::from_bytes::<RX>(body)?;
Ok(r)
} else if hdr.key == self.err_key {
let r = postcard::from_bytes::<WireErr>(body)?;
Err(HostErr::Wire(r))
} else {
Err(HostErr::BadResponse)
}
}
}
impl<WireErr> Clone for HostClient<WireErr> {
fn clone(&self) -> Self {
Self {
ctx: self.ctx.clone(),
out: self.out.clone(),
err_key: self.err_key,
_pd: PhantomData,
}
}
}
/// State shared between every `HostClient` clone and the `wire_worker` task.
struct HostContext {
    // Pending requests keyed by sequence number. The worker wakes the matching
    // waiter with the whole decoded response frame (header + body).
    map: WaitMap<u32, Vec<u8>>,
    // Source of per-request sequence numbers (incremented with fetch_add, wraps on overflow).
    seq: AtomicU32,
}