use bytes::Buf;
use futures::FutureExt;
use serde::{Serialize, de::DeserializeOwned};
use std::{
fmt,
future::Future,
ops::Deref,
pin::Pin,
task::{Context, Poll, ready},
};
use super::{DEFAULT_MAX_ITEM_SIZE, RemoteSendError, base};
use crate::{RemoteSend, chmux, codec, exec, rch::BACKCHANNEL_MSG_ERROR};
mod receiver;
mod sender;
pub use receiver::{ChangedError, Receiver, ReceiverStream, RecvError};
pub use sender::{SendError, Sender};
/// A borrowed reference to the current value of a watch channel.
///
/// Wraps a borrow of the underlying `tokio::sync::watch` slot, which stores
/// `Result<T, RecvError>`, and dereferences to the `Ok` value.
///
/// Per the tokio documentation, an outstanding watch borrow holds a read lock
/// on the slot, so a `Ref` should not be kept alive for long.
pub struct Ref<'a, T>(tokio::sync::watch::Ref<'a, Result<T, RecvError>>);

impl<T> Deref for Ref<'_, T> {
    type Target = T;

    /// Returns the wrapped value.
    ///
    /// # Panics
    /// Panics if the wrapped slot holds an `Err`. Code constructing a `Ref` is
    /// expected to only wrap `Ok` values (construction sites are outside this
    /// module section — confirm when changing them).
    fn deref(&self) -> &Self::Target {
        self.0
            .as_ref()
            .expect("watch::Ref must only be constructed around an Ok value")
    }
}

impl<T> fmt::Debug for Ref<'_, T>
where
    T: fmt::Debug,
{
    /// Formats the wrapped value, forwarding all formatter options
    /// (e.g. `{:#?}` pretty-printing, width, precision) to `T`'s `Debug` impl.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        fmt::Debug::fmt(&**self, f)
    }
}
/// Creates a new watch channel whose initial value is `init`.
///
/// The returned sender uses `DEFAULT_MAX_ITEM_SIZE`. Both halves hold a handle
/// to a shared unbounded channel used to report remote send errors; the sender
/// also owns its receiving end.
pub fn channel<T, Codec>(init: T) -> (Sender<T, Codec>, Receiver<T, Codec>)
where
    T: RemoteSend,
{
    let (value_tx, value_rx) = tokio::sync::watch::channel(Ok(init));
    let (err_tx, err_rx) = tokio::sync::mpsc::unbounded_channel();

    // Tuple elements are evaluated left to right, so the sender is built first.
    (
        Sender::new(value_tx, err_tx.clone(), err_rx, DEFAULT_MAX_ITEM_SIZE),
        Receiver::new(value_rx, err_tx, None),
    )
}
/// Forwards a local tokio watch channel over a remote-sendable watch channel.
///
/// The current value of `local_rx` (marked as seen) becomes the initial value of
/// the returned [`Receiver`]. A background task then copies every subsequent
/// change into the new channel.
///
/// The task ends when:
/// - the new channel's sender reports being closed, or
/// - the local watch sender is dropped, or
/// - a send fails with a closed error.
///
/// In these cases it returns the result of the sender's `check()`. Any other send
/// error is returned immediately. The task is exposed through the returned
/// [`Forwarding`] handle.
pub fn forward<T, Codec>(mut local_rx: tokio::sync::watch::Receiver<T>) -> (Forwarding, Receiver<T, Codec>)
where
    T: RemoteSend + Sync + Clone,
    Codec: codec::Codec,
{
    // Clone in a separate statement so the local borrow is released before the
    // channel is created.
    let initial = local_rx.borrow_and_update().clone();
    let (mut tx, rx) = channel(initial);

    let forwarder = async move {
        loop {
            tokio::select! {
                biased;

                () = tx.closed() => break,

                res = local_rx.changed() => {
                    if res.is_err() {
                        // Local sender dropped: nothing more to forward.
                        break;
                    }

                    let current = local_rx.borrow_and_update().clone();
                    if let Err(err) = tx.send(current) {
                        if err.is_closed() {
                            break;
                        }
                        return Err(err);
                    }
                }
            }
        }

        tx.check()
    };

    let task = exec::spawn(forwarder);
    (Forwarding(task), rx)
}
/// Handle to the background task spawned by [`forward`].
///
/// Awaiting the handle yields the task's result. If joining the task fails,
/// the join error is discarded and [`SendError::Closed`] is returned instead.
pub struct Forwarding(exec::task::JoinHandle<Result<(), SendError>>);

impl fmt::Debug for Forwarding {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Same output as an empty `debug_struct`: the task handle is not shown.
        f.write_str("Forwarding")
    }
}

impl Future for Forwarding {
    type Output = Result<(), SendError>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let joined = ready!(self.0.poll_unpin(cx));
        // A failed join carries no send result; report it as a closed channel.
        Poll::Ready(joined.unwrap_or_else(|_| Err(SendError::Closed)))
    }
}

impl Forwarding {
    /// Stops forwarding by aborting the background task.
    pub fn stop(self) {
        let Self(task) = self;
        task.abort();
    }
}
/// Extension methods for a `(Sender, Receiver)` pair of a watch channel.
pub trait WatchExt<T, Codec, const MAX_ITEM_SIZE: usize> {
    /// Changes the maximum item size of both halves of the channel.
    ///
    /// The sender's limit is set at runtime, while the receiver's limit is
    /// carried in its type as `NEW_MAX_ITEM_SIZE`.
    fn with_max_item_size<const NEW_MAX_ITEM_SIZE: usize>(
        self,
    ) -> (Sender<T, Codec>, Receiver<T, Codec, NEW_MAX_ITEM_SIZE>);
}

impl<T, Codec, const MAX_ITEM_SIZE: usize> WatchExt<T, Codec, MAX_ITEM_SIZE>
    for (Sender<T, Codec>, Receiver<T, Codec, MAX_ITEM_SIZE>)
where
    T: Send + 'static,
{
    fn with_max_item_size<const NEW_MAX_ITEM_SIZE: usize>(
        self,
    ) -> (Sender<T, Codec>, Receiver<T, Codec, NEW_MAX_ITEM_SIZE>) {
        let (mut sender, receiver) = self;
        sender.set_max_item_size(NEW_MAX_ITEM_SIZE);
        (sender, receiver.set_max_item_size())
    }
}
/// Forwards values from the local watch channel to the remote endpoint.
///
/// Each change observed on `rx` is cloned and sent through a `base::Sender`
/// limited to `max_item_size`. Back-channel messages whose first byte is
/// `BACKCHANNEL_MSG_ERROR` are reported as `RemoteSendError::Forward` on
/// `remote_send_err_tx`.
///
/// Every failed send is reported as `RemoteSendError::Send` on
/// `remote_send_err_tx`. Item-specific failures only concern the value just
/// sent, so forwarding continues with the next change. Any other failure means
/// the remote channel itself is unusable, so the task stops.
///
/// The task also stops when:
/// - the back channel closes, errors, or delivers an empty message, or
/// - the local watch sender is dropped.
async fn send_impl<T, Codec>(
    mut rx: tokio::sync::watch::Receiver<Result<T, RecvError>>, raw_tx: chmux::Sender,
    mut raw_rx: chmux::Receiver, remote_send_err_tx: tokio::sync::mpsc::UnboundedSender<RemoteSendError>,
    max_item_size: usize,
) where
    T: Serialize + Send + Clone + 'static,
    Codec: codec::Codec,
{
    let mut remote_tx = base::Sender::<Result<T, RecvError>, Codec>::new(raw_tx);
    remote_tx.set_max_item_size(max_item_size);

    loop {
        tokio::select! {
            biased;

            // Back-channel messages from the remote endpoint take priority.
            backchannel_msg = raw_rx.recv() => {
                match backchannel_msg {
                    Ok(Some(mut msg)) if msg.remaining() >= 1 => {
                        if msg.get_u8() == BACKCHANNEL_MSG_ERROR {
                            let _ = remote_send_err_tx.send(RemoteSendError::Forward);
                        }
                    }
                    // Back channel closed, failed, or sent an empty message.
                    _ => break,
                }
            }

            // New local value to send to the remote endpoint.
            changed = rx.changed() => {
                match changed {
                    Ok(()) => {
                        let value = rx.borrow_and_update().clone();
                        if let Err(err) = remote_tx.send(value).await {
                            let _ = remote_send_err_tx.send(RemoteSendError::Send(err.kind.clone()));
                            // Item-specific errors (e.g. one value failing to serialize)
                            // leave the channel usable; anything else means it is broken.
                            if !err.is_item_specific() {
                                tracing::warn!(%err, "sending over remote channel failed");
                                break;
                            }
                        }
                    }
                    // Local watch sender dropped.
                    Err(_) => break,
                }
            }
        }
    }
}
/// Receives values from the remote endpoint and publishes them into the local
/// watch channel `tx`.
///
/// Whenever a message arrives on `remote_send_err_rx`, and once for an initial
/// `current_err`, a `BACKCHANNEL_MSG_ERROR` byte is written to the back channel
/// via `raw_tx`. Failures of that write are ignored.
///
/// Receive errors are published as `Err(RecvError::RemoteReceive(_))`. The task
/// stops when:
/// - all local receivers are gone, or
/// - the remote sender finishes, or
/// - publishing fails, or
/// - a final receive error has been published.
async fn recv_impl<T, Codec>(
    tx: tokio::sync::watch::Sender<Result<T, RecvError>>, mut raw_tx: chmux::Sender, raw_rx: chmux::Receiver,
    mut remote_send_err_rx: tokio::sync::mpsc::UnboundedReceiver<RemoteSendError>,
    mut current_err: Option<RemoteSendError>, max_item_size: usize,
) where
    T: DeserializeOwned + Send + 'static,
    Codec: codec::Codec,
{
    let mut incoming = base::Receiver::<Result<T, RecvError>, Codec>::new(raw_rx);
    incoming.set_max_item_size(max_item_size);

    loop {
        tokio::select! {
            biased;

            // Every local receiver has been dropped.
            () = tx.closed() => break,

            // Signal a reported error to the remote sender.
            Some(_) = remote_send_err_rx.recv() => {
                let _ = raw_tx.send(vec![BACKCHANNEL_MSG_ERROR].into()).await;
            }

            // Signal the initial pending error exactly once.
            () = futures::future::ready(()), if current_err.is_some() => {
                let _ = raw_tx.send(vec![BACKCHANNEL_MSG_ERROR].into()).await;
                current_err = None;
            }

            res = incoming.recv() => {
                let (value, stop_after_publish) = match res {
                    Ok(Some(value)) => (value, false),
                    Ok(None) => break,
                    Err(err) => {
                        let is_final = err.is_final();
                        (Err(RecvError::RemoteReceive(err)), is_final)
                    }
                };

                // `||` short-circuits after the send, so the value is always
                // published before a final error ends the loop.
                if tx.send(value).is_err() || stop_after_publish {
                    break;
                }
            }
        }
    }
}