use bytes::Buf;
use serde::{Serialize, de::DeserializeOwned};
use super::{ClosedReason, RemoteSendError, Sending, base};
use crate::{
RemoteSend, chmux, codec,
rch::{BACKCHANNEL_MSG_CLOSE, BACKCHANNEL_MSG_ERROR},
};
mod distributor;
mod receiver;
mod sender;
pub use distributor::{DistributedReceiverHandle, Distributor};
pub use receiver::{Receiver, RecvError, TryRecvError};
pub use sender::{Permit, SendError, Sender, SenderSink, TrySendError};
/// Creates a new channel and returns its sending and receiving halves.
///
/// Values are buffered locally in a `tokio::sync::mpsc` channel that holds up to
/// `local_buffer` items. The two halves are also linked by two watch channels:
/// - The receiver owns the writing side.
/// - The sender observes the closure state and any remote send error.
///
/// # Panics
/// Panics if `local_buffer` is zero.
pub fn channel<T, Codec>(local_buffer: usize) -> (Sender<T, Codec>, Receiver<T, Codec>)
where
    T: RemoteSend,
{
    assert!(local_buffer > 0, "local_buffer must not be zero");

    // Watch channels: the receiver half publishes, the sender half observes.
    let (closed_tx, closed_rx) = tokio::sync::watch::channel(None);
    let (err_tx, err_rx) = tokio::sync::watch::channel(None);

    // Local queue carrying the actual send requests.
    let (queue_tx, queue_rx) = tokio::sync::mpsc::channel(local_buffer);

    (
        Sender::new(queue_tx, closed_rx, err_rx),
        Receiver::new(queue_rx, closed_tx, false, err_tx, None),
    )
}
/// Extension methods that reconfigure the const-generic parameters of a
/// `(Sender, Receiver)` pair while consuming it.
pub trait MpscExt<T, Codec, const BUFFER: usize, const MAX_ITEM_SIZE: usize> {
    /// Consumes the pair and returns it with the `BUFFER` parameter of both
    /// halves replaced by `NEW_BUFFER`.
    fn with_buffer<const NEW_BUFFER: usize>(
        self,
    ) -> (Sender<T, Codec, NEW_BUFFER>, Receiver<T, Codec, NEW_BUFFER, MAX_ITEM_SIZE>);

    /// Consumes the pair and returns it with the receiver's `MAX_ITEM_SIZE`
    /// parameter replaced by `NEW_MAX_ITEM_SIZE`.
    ///
    /// The sender type carries no max-item-size parameter, so its type is unchanged.
    fn with_max_item_size<const NEW_MAX_ITEM_SIZE: usize>(
        self,
    ) -> (Sender<T, Codec, BUFFER>, Receiver<T, Codec, BUFFER, NEW_MAX_ITEM_SIZE>);
}
/// Implements [`MpscExt`] for a sender/receiver tuple as produced by [`channel`].
impl<T, Codec, const BUFFER: usize, const MAX_ITEM_SIZE: usize> MpscExt<T, Codec, BUFFER, MAX_ITEM_SIZE>
    for (Sender<T, Codec, BUFFER>, Receiver<T, Codec, BUFFER, MAX_ITEM_SIZE>)
where
    T: Send + 'static,
{
    fn with_buffer<const NEW_BUFFER: usize>(
        self,
    ) -> (Sender<T, Codec, NEW_BUFFER>, Receiver<T, Codec, NEW_BUFFER, MAX_ITEM_SIZE>) {
        // Tuple elements are evaluated left to right: sender first, then receiver.
        (self.0.set_buffer(), self.1.set_buffer())
    }

    fn with_max_item_size<const NEW_MAX_ITEM_SIZE: usize>(
        self,
    ) -> (Sender<T, Codec, BUFFER>, Receiver<T, Codec, BUFFER, NEW_MAX_ITEM_SIZE>) {
        let (mut sender, receiver) = self;
        // The sender stores the limit at runtime; the receiver encodes it in its type.
        sender.set_max_item_size(NEW_MAX_ITEM_SIZE);
        (sender, receiver.set_max_item_size())
    }
}
/// A value queued for forwarding, together with a one-shot channel on which the
/// outcome of the send can be reported back.
pub(crate) struct SendReq<T> {
    /// The value to forward, or a receive error forwarded in its place.
    pub value: Result<T, RecvError>,
    /// Completion channel: `Ok(())` on success, or the `base::SendError` (which
    /// carries the item back) on failure.
    pub result_tx: tokio::sync::oneshot::Sender<Result<(), base::SendError<T>>>,
}
impl<T> SendReq<T> {
    /// Wraps `value` in a request whose completion channel has no listener.
    ///
    /// The receiving end of the one-shot channel is dropped right away, so any
    /// outcome reported on `result_tx` is discarded.
    fn new(value: Result<T, RecvError>) -> Self {
        let (result_tx, _) = tokio::sync::oneshot::channel();
        Self { value, result_tx }
    }

    /// Acknowledges successful delivery and returns the contained value.
    ///
    /// A failure to deliver the acknowledgement (no one waiting) is ignored.
    fn ack(self) -> Result<T, RecvError> {
        let Self { value: payload, result_tx: reply } = self;
        let _ = reply.send(Ok(()));
        payload
    }
}
/// Builds a [`SendReq`] for `value` and the matching [`Sending`] handle that
/// resolves once the request's outcome is reported.
pub(crate) fn send_req<T>(value: Result<T, RecvError>) -> (SendReq<T>, Sending<T>) {
    let (result_tx, result_rx) = tokio::sync::oneshot::channel();
    (SendReq { value, result_tx }, Sending(result_rx))
}
/// Forwards locally queued [`SendReq`]s to the remote endpoint while watching the
/// backchannel for control messages from the remote side.
///
/// The loop ends when:
/// - the backchannel delivers `BACKCHANNEL_MSG_CLOSE` or `BACKCHANNEL_MSG_ERROR`;
/// - the backchannel ends (`Ok(None)`) or returns an error;
/// - the local request queue `rx` is closed (`recv` returns `None`).
///
/// On every backchannel-triggered exit, the outcome is first published through
/// `remote_send_err_tx` and `closed_tx`.
///
/// # Parameters
/// - `rx`: local queue of send requests to forward.
/// - `raw_tx`: chmux sender wrapped in a `base::Sender` to transmit values.
/// - `raw_rx`: chmux receiver carrying backchannel messages from the remote side.
/// - `remote_send_err_tx`: publishes the error that stopped forwarding.
/// - `closed_tx`: publishes the reason the channel was closed.
/// - `max_item_size`: maximum item size applied to the underlying `base::Sender`.
async fn send_impl<T, Codec>(
    mut rx: tokio::sync::mpsc::Receiver<SendReq<T>>, raw_tx: chmux::Sender, mut raw_rx: chmux::Receiver,
    remote_send_err_tx: tokio::sync::watch::Sender<Option<RemoteSendError>>,
    closed_tx: tokio::sync::watch::Sender<Option<ClosedReason>>, max_item_size: usize,
) where
    T: Serialize + Send + 'static,
    Codec: codec::Codec,
{
    let mut remote_tx = base::Sender::<Result<T, RecvError>, Codec>::new(raw_tx);
    remote_tx.set_max_item_size(max_item_size);
    loop {
        tokio::select! {
            // Poll the backchannel before the local queue so that a close or error
            // from the remote side is handled before forwarding more values.
            biased;
            backchannel_msg = raw_rx.recv() => {
                match backchannel_msg {
                    // Non-empty message: its first byte is the control code.
                    Ok(Some(mut msg)) if msg.remaining() >= 1 => {
                        match msg.get_u8() {
                            BACKCHANNEL_MSG_CLOSE => {
                                let _ = remote_send_err_tx.send(Some(RemoteSendError::Closed));
                                let _ = closed_tx.send(Some(ClosedReason::Closed));
                                break;
                            }
                            BACKCHANNEL_MSG_ERROR => {
                                let _ = remote_send_err_tx.send(Some(RemoteSendError::Forward));
                                let _ = closed_tx.send(Some(ClosedReason::Failed));
                                break;
                            }
                            // Unknown control codes are ignored.
                            _ => (),
                        }
                    },
                    // Empty backchannel messages are ignored.
                    Ok(Some(_)) => (),
                    // Backchannel ended without a control message: report a
                    // non-graceful close and mark the channel as dropped.
                    Ok(None) => {
                        let _ = remote_send_err_tx.send(Some(RemoteSendError::Send(
                            base::SendErrorKind::Send(chmux::SendError::Closed { gracefully: false })
                        )));
                        let _ = closed_tx.send(Some(ClosedReason::Dropped));
                        break;
                    }
                    // Receiving from the backchannel failed.
                    _ => {
                        let _ = remote_send_err_tx.send(Some(RemoteSendError::Send(
                            base::SendErrorKind::Send(chmux::SendError::ChMux)
                        )));
                        let _ = closed_tx.send(Some(ClosedReason::Failed));
                        break;
                    },
                }
            }
            value_opt = rx.recv() => {
                match value_opt {
                    Some(value) => {
                        let SendReq { value, result_tx } = value;
                        match remote_tx.send(value).await {
                            Ok(()) => {
                                // The requester may have stopped waiting; ignore that.
                                let _ = result_tx.send(Ok(()));
                            }
                            Err(err) => {
                                // Publish the failure and mark the channel failed. The loop
                                // is not exited here; it continues until the backchannel
                                // or the local queue ends.
                                let _ = remote_send_err_tx.send(Some(RemoteSendError::Send(err.kind.clone())));
                                let _ = closed_tx.send(Some(ClosedReason::Failed));
                                // If the failed item was a value (not a forwarded error),
                                // hand it back to the requester. If no one is waiting for
                                // the result and the error concerns this specific item,
                                // log it so the failure is not lost silently.
                                if let Ok(item) = err.item
                                    && let Err(Err(err)) = result_tx.send(Err(base::SendError {
                                        kind: err.kind,
                                        item,
                                    }))
                                    && err.is_item_specific() {
                                    tracing::warn!(%err, "sending over remote channel failed");
                                }
                            }
                        }
                    }
                    // All local senders are gone: nothing more to forward.
                    None => break,
                }
            }
        }
    }
}
/// Receives values from the remote endpoint and pushes them into the local
/// queue `tx`. It also relays the local close and error state to the remote side
/// over the backchannel (`raw_tx`).
///
/// The loop ends when:
/// - `closed_rx` reports `ClosedReason::Dropped`, or its sender is dropped;
/// - the remote stream ends (`Ok(None)`);
/// - the local queue `tx` is closed;
/// - a receive error for which `is_final()` is true has been forwarded locally.
///
/// # Parameters
/// - `tx`: local queue that received values are pushed into.
/// - `raw_tx`: chmux sender used for backchannel control messages.
/// - `raw_rx`: chmux receiver wrapped in a `base::Receiver` to decode values.
/// - `remote_send_err_rx`: watched; when it becomes `Some`, an error message is
///   sent over the backchannel.
/// - `closed_rx`: watched; changes are translated into backchannel messages.
/// - `max_item_size`: maximum item size applied to the underlying `base::Receiver`.
async fn recv_impl<T, Codec>(
    tx: &tokio::sync::mpsc::Sender<SendReq<T>>, mut raw_tx: chmux::Sender, raw_rx: chmux::Receiver,
    mut remote_send_err_rx: tokio::sync::watch::Receiver<Option<RemoteSendError>>,
    mut closed_rx: tokio::sync::watch::Receiver<Option<ClosedReason>>, max_item_size: usize,
) where
    T: DeserializeOwned + Send + 'static,
    Codec: codec::Codec,
{
    let mut remote_rx = base::Receiver::<Result<T, RecvError>, Codec>::new(raw_rx);
    remote_rx.set_max_item_size(max_item_size);
    loop {
        tokio::select! {
            // Check local state changes before receiving more remote data.
            biased;
            res = closed_rx.changed() => {
                match res {
                    Ok(()) => {
                        let reason = closed_rx.borrow().clone();
                        match reason {
                            // Tell the remote side to close, but keep receiving
                            // (presumably to drain values already in flight).
                            // Backchannel send errors are ignored (best effort).
                            Some(ClosedReason::Closed) => {
                                let _ = raw_tx.send(vec![BACKCHANNEL_MSG_CLOSE].into()).await;
                            }
                            Some(ClosedReason::Dropped) => break,
                            Some(ClosedReason::Failed) => {
                                let _ = raw_tx.send(vec![BACKCHANNEL_MSG_ERROR].into()).await;
                            }
                            None => (),
                        }
                    },
                    // The watch sender is gone.
                    Err(_) => break,
                }
            }
            // If the watch sender is dropped, `changed()` returns `Err`, the pattern
            // does not match, and this branch is disabled for this select iteration.
            Ok(()) = remote_send_err_rx.changed() => {
                if remote_send_err_rx.borrow().as_ref().is_some() {
                    let _ = raw_tx.send(vec![BACKCHANNEL_MSG_ERROR].into()).await;
                }
            }
            res = remote_rx.recv() => {
                let mut is_final_err = false;
                // Receive errors are forwarded to the local receiver as values.
                let value = match res {
                    Ok(Some(value)) => value,
                    Ok(None) => break,
                    Err(err) => {
                        is_final_err = err.is_final();
                        Err(RecvError::RemoteReceive(err))
                    },
                };
                // `SendReq::new` has no listener on its completion channel, so no
                // delivery outcome is reported back here.
                if tx.send(SendReq::new(value)).await.is_err() {
                    break;
                }
                // A final error is forwarded first, then receiving stops.
                if is_final_err {
                    break;
                }
            }
        }
    }
}