use std::{fmt::Debug, sync::Arc};
use maitake_sync::WaitQueue;
use postcard_schema::Schema;
use serde::de::DeserializeOwned;
use tokio::{
select,
sync::{
mpsc::{error::TrySendError, Receiver, Sender},
Mutex,
},
};
use tracing::{debug, trace, warn};
use crate::{
header::{VarHeader, VarKey, VarSeqKind},
host_client::{
HostClient, HostContext, ProcessError, RpcFrame, WireContext, WireRx, WireSpawn, WireTx,
},
Key,
};
/// Table of active subscriptions consulted by the incoming-wire worker.
///
/// `in_worker_inner` looks up incoming frames by key in `list` and forwards matches through the
/// paired channel. `in_worker` sets `stopped` and clears `list` once the worker exits.
#[derive(Default, Debug)]
pub(crate) struct Subscriptions {
    // Each entry pairs a subscription key with the channel that receives frames whose header
    // key matches it.
    pub(crate) list: Vec<(Key, Sender<RpcFrame>)>,
    // Set to `true` by `in_worker` after the receive loop has ended (for any reason).
    pub(crate) stopped: bool,
}
/// Shared shutdown signal for the wire worker tasks.
///
/// Cloning a `Stopper` produces another handle to the same underlying [`WaitQueue`]. A call to
/// [`stop`](Stopper::stop) through any handle is therefore visible through all of them.
#[derive(Clone)]
pub struct Stopper {
    inner: Arc<WaitQueue>,
}

impl Stopper {
    /// Creates a new signal that has not been stopped yet.
    pub fn new() -> Self {
        let queue = Arc::new(WaitQueue::new());
        Self { inner: queue }
    }

    /// Suspends until the shared queue stops waiting.
    ///
    /// Within this impl the only operation performed on the queue is `close` (via
    /// [`stop`](Stopper::stop)). The wait's result is deliberately discarded: callers only need
    /// to know that the wait ended, not why.
    pub async fn wait_stopped(&self) {
        let _ = self.inner.wait().await;
    }

    /// Returns whether the shared queue has been closed, which is what [`stop`](Stopper::stop)
    /// does.
    pub fn is_stopped(&self) -> bool {
        let queue: &WaitQueue = &self.inner;
        queue.is_closed()
    }

    /// Closes the shared queue, releasing tasks awaiting [`wait_stopped`](Stopper::wait_stopped)
    /// on any handle.
    pub fn stop(&self) {
        let queue: &WaitQueue = &self.inner;
        queue.close();
    }
}
impl<WireErr> HostClient<WireErr>
where
    WireErr: DeserializeOwned + Schema,
{
    /// Builds a `HostClient` on top of a wire transport and spawns its two I/O workers.
    ///
    /// The client and its internal `WireContext` come from `new_manual_priv`. The outgoing
    /// channel is handed to an `out_worker` driving `tx`, and the incoming context to an
    /// `in_worker` driving `rx`. Both workers share the client's stopper, so either one ending
    /// stops the other.
    ///
    /// * `tx` / `rx` – the transmit and receive halves of the wire.
    /// * `sp` – spawner used to launch both worker futures (outgoing first, then incoming).
    /// * `seq_kind`, `err_uri_path`, `outgoing_depth` – forwarded unchanged to
    ///   `new_manual_priv`. `outgoing_depth` presumably sizes the outgoing queue; confirm there.
    pub fn new_with_wire<WTX, WRX, WSP>(
        tx: WTX,
        rx: WRX,
        mut sp: WSP,
        seq_kind: VarSeqKind,
        err_uri_path: &str,
        outgoing_depth: usize,
    ) -> Self
    where
        WTX: WireTx,
        WRX: WireRx,
        WSP: WireSpawn,
    {
        let (client, WireContext { outgoing, incoming }) =
            Self::new_manual_priv(err_uri_path, outgoing_depth, seq_kind);

        let tx_task = out_worker(tx, outgoing, client.stopper.clone());
        sp.spawn(tx_task);

        let rx_task = in_worker(
            rx,
            incoming,
            client.subscriptions.clone(),
            client.stopper.clone(),
        );
        sp.spawn(rx_task);

        client
    }
}
/// Runs the outgoing send loop until it finishes or the shared `stop` signal fires.
///
/// If the send loop ends on its own (closed receiver or wire error), this worker stops the
/// shared signal so the rest of the client winds down as well. If the signal fires first, the
/// send loop is simply dropped.
async fn out_worker<W>(wire: W, rec: Receiver<RpcFrame>, stop: Stopper)
where
    W: WireTx,
    W::Error: Debug,
{
    select! {
        _ = stop.wait_stopped() => {
            // Shutdown was requested elsewhere; nothing left to do.
        },
        _ = out_worker_inner(wire, rec) => {
            // The send loop gave up by itself: propagate the shutdown.
            stop.stop();
        },
    }
}
/// Forwards every frame received on `rec` to `wire` as bytes.
///
/// Returns after logging an error when a wire send fails, or after logging a warning once
/// `rec` yields no more frames (all senders gone).
async fn out_worker_inner<W>(mut wire: W, mut rec: Receiver<RpcFrame>)
where
    W: WireTx,
    W::Error: Debug,
{
    while let Some(frame) = rec.recv().await {
        let bytes = frame.to_bytes();
        if let Err(e) = wire.send(bytes).await {
            tracing::error!("Output Queue Error: {e:?}, exiting");
            return;
        }
    }
    tracing::warn!("Receiver Closed, this could be bad");
}
/// Runs the incoming receive loop until it finishes or the shared `stop` signal fires.
///
/// If the receive loop ends on its own, the shared signal is stopped so the outgoing worker
/// exits too. In every case the subscription table is then marked as stopped and emptied,
/// which drops all subscription senders held there.
async fn in_worker<W>(
    wire: W,
    host_ctx: Arc<HostContext>,
    subscriptions: Arc<Mutex<Subscriptions>>,
    stop: Stopper,
) where
    W: WireRx,
    W::Error: Debug,
{
    select! {
        _ = stop.wait_stopped() => {},
        _ = in_worker_inner(wire, host_ctx, Arc::clone(&subscriptions)) => stop.stop(),
    }

    // Both writes happen under one lock acquisition, so no observer sees a half-updated table.
    let mut table = subscriptions.lock().await;
    table.list.clear();
    table.stopped = true;
}
async fn in_worker_inner<W>(
mut wire: W,
host_ctx: Arc<HostContext>,
subscriptions: Arc<Mutex<Subscriptions>>,
) where
W: WireRx,
W::Error: Debug,
{
loop {
let Ok(res) = wire.receive().await else {
warn!("in_worker: wire receive error, exiting");
return;
};
let Some((hdr, body)) = VarHeader::take_from_slice(&res) else {
warn!("Header decode error!");
continue;
};
trace!("in_worker received {hdr:?}");
let mut handled = false;
{
let mut subs_guard = subscriptions.lock().await;
let key = hdr.key;
let remove_sub = if let Some((_h, m)) = subs_guard
.list
.iter()
.find(|(k, _)| VarKey::Key8(*k) == key)
{
handled = true;
let frame = RpcFrame {
header: hdr,
body: body.to_vec(),
};
let res = m.try_send(frame);
match res {
Ok(()) => {
trace!("Handled message via subscription");
false
}
Err(TrySendError::Full(_)) => {
tracing::error!("Subscription channel full! Message dropped.");
false
}
Err(TrySendError::Closed(_)) => true,
}
} else {
false
};
if remove_sub {
debug!("Dropping subscription");
subs_guard.list.retain(|(k, _)| VarKey::Key8(*k) != key);
}
}
if handled {
continue;
}
let frame = RpcFrame {
header: hdr,
body: body.to_vec(),
};
match host_ctx.process_did_wake(frame) {
Ok(true) => debug!("Handled message via map"),
Ok(false) => debug!("Message not handled"),
Err(ProcessError::Closed) => {
warn!("Got process error, quitting");
return;
}
}
}
}