use bytes::Buf;
use futures::FutureExt;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::{marker::PhantomData, sync::Mutex};
use super::{
super::{
mpsc::{BACKCHANNEL_MSG_CLOSE, BACKCHANNEL_MSG_ERROR},
remote::{self, PortDeserializer, PortSerializer},
},
RemoteSendError,
};
use crate::{chmux, codec::CodecT};
/// Error that can be yielded by [`Receiver::recv`].
///
/// Errors travel across hops together with items: the forwarding tasks send and receive
/// `Result<T, ReceiveError>` values, which is why this type derives `Serialize`/`Deserialize`.
///
/// NOTE: variant order determines the derived serde variant index (relevant for index-based
/// formats); do not reorder variants.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum ReceiveError {
    /// Receiving from the remote endpoint failed; queued by the relay task set up in
    /// `Receiver::deserialize`.
    RemoteReceive(remote::ReceiveError),
    /// Connecting to the remote endpoint failed.
    ///
    /// NOTE(review): this variant is not constructed anywhere in this file — confirm where it
    /// originates.
    RemoteConnect(chmux::ConnectError),
    /// Accepting the incoming port connection failed while setting up a deserialized receiver.
    RemoteListen(chmux::ListenerError),
}
/// Oneshot sender through which a dropped [`Receiver`] hands its channel state over to the
/// forwarding task spawned by `serialize`.
type SuccessorTx<T, Codec, const BUFFER: usize> = tokio::sync::oneshot::Sender<ReceiverInner<T, Codec, BUFFER>>;

/// Receiving half of a channel whose values can be transported to another endpoint.
///
/// `BUFFER` is the capacity of the local queue created when a receiver is deserialized; it
/// must be non-zero.
pub struct Receiver<T, Codec, const BUFFER: usize> {
    /// Channel state. It is present for the receiver's whole lifetime and only taken in `drop`.
    inner: Option<ReceiverInner<T, Codec, BUFFER>>,
    /// Registered by `serialize`; consumed in `drop` to pass `inner` on to the forwarding task.
    successor_tx: Mutex<Option<SuccessorTx<T, Codec, BUFFER>>>,
}
/// Channel state owned by a [`Receiver`].
///
/// When a receiver has been serialized, this state is handed to the forwarding task once the
/// receiver is dropped.
///
/// NOTE: fields drop in declaration order; reordering them changes when `closed_tx` and
/// `remote_send_err_tx` are dropped relative to `rx`.
pub(crate) struct ReceiverInner<T, Codec, const BUFFER: usize> {
    /// Local queue of received values and errors.
    rx: tokio::sync::mpsc::Receiver<Result<T, ReceiveError>>,
    /// Closed flag. It is set to `true` by `Receiver::close`, or by the forwarding task when the
    /// remote side sends a close backchannel message. For a deserialized receiver it is watched
    /// by the task that relays a close message upstream.
    closed_tx: tokio::sync::watch::Sender<bool>,
    /// Latest error recorded by the forwarding task of `serialize` (connect, forward or send
    /// failure).
    remote_send_err_tx: tokio::sync::watch::Sender<Option<RemoteSendError>>,
    /// Marker for the codec type parameter.
    _codec: PhantomData<Codec>,
}
/// Serialized form of a [`Receiver`]: the port number to connect to, plus type markers.
///
/// Produced by `Receiver::serialize` and consumed by `Receiver::deserialize`.
///
/// NOTE: field order defines the derived serialized layout; do not reorder fields.
#[derive(Serialize, Deserialize)]
pub struct TransportedReceiver<T, Codec> {
    /// Port returned by `PortSerializer::connect` on the serializing side; passed to
    /// `PortDeserializer::accept` on the deserializing side.
    port: u32,
    /// Marker for the item type.
    data: PhantomData<T>,
    /// Marker for the codec type.
    codec: PhantomData<Codec>,
}
impl<T, Codec, const BUFFER: usize> Receiver<T, Codec, BUFFER> {
pub(crate) fn new(
rx: tokio::sync::mpsc::Receiver<Result<T, ReceiveError>>, closed_tx: tokio::sync::watch::Sender<bool>,
remote_send_err_tx: tokio::sync::watch::Sender<Option<RemoteSendError>>,
) -> Self {
Self {
inner: Some(ReceiverInner { rx, closed_tx, remote_send_err_tx, _codec: PhantomData }),
successor_tx: Mutex::new(None),
}
}
pub async fn recv(&mut self) -> Result<Option<T>, ReceiveError> {
match self.inner.as_mut().unwrap().rx.recv().await {
Some(Ok(value_opt)) => Ok(Some(value_opt)),
Some(Err(err)) => Err(err),
None => Ok(None),
}
}
pub fn close(&mut self) {
let _ = self.inner.as_mut().unwrap().closed_tx.send(true);
}
}
impl<T, Codec, const BUFFER: usize> Drop for Receiver<T, Codec, BUFFER> {
    /// Hands the channel state over to the forwarding task registered by `serialize`, if any.
    ///
    /// If the receiver was never serialized, there is no successor and the state is simply
    /// dropped along with `self`. If the forwarding task is already gone, the send fails and
    /// the state is dropped here.
    fn drop(&mut self) {
        // `&mut self` guarantees exclusive access, so the mutex does not need to be locked;
        // a poisoned mutex still panics exactly as `lock().unwrap()` would.
        let registered = self.successor_tx.get_mut().unwrap().take();
        if let Some(successor) = registered {
            let state = self.inner.take().unwrap();
            let _ = successor.send(state);
        }
    }
}
impl<T, Codec, const BUFFER: usize> Serialize for Receiver<T, Codec, BUFFER>
where
    T: Serialize + DeserializeOwned + Send + 'static,
    Codec: CodecT,
{
    /// Serializes the receiver as a [`TransportedReceiver`] holding a port number.
    ///
    /// The port number is obtained from `PortSerializer::connect`. The closure given to it
    /// spawns a forwarding task that:
    ///
    /// 1. Waits for this receiver's channel state. The `Drop` impl sends that state through
    ///    the `successor_tx` registered here, so forwarding only starts once this receiver is
    ///    dropped. If that sender is dropped instead (e.g. replaced by a later `serialize`
    ///    call), the task returns immediately.
    /// 2. Awaits the connection, recording a failure as `RemoteSendError::Connect`.
    /// 3. Relays every value of the local queue to the remote endpoint while handling
    ///    backchannel messages, until the local queue yields `None`.
    ///
    /// # Errors
    /// Propagates (via `?`) the error of `PortSerializer::connect`, then returns the result
    /// of serializing the [`TransportedReceiver`].
    ///
    /// # Panics
    /// Panics if the `successor_tx` mutex is poisoned.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let (successor_tx, successor_rx) = tokio::sync::oneshot::channel();
        // Replacing a previously registered sender drops it. The task spawned by an earlier
        // `serialize` call then exits at its `successor_rx.await`.
        *self.successor_tx.lock().unwrap() = Some(successor_tx);
        let port = PortSerializer::connect(|connect, allocator| {
            tokio::spawn(async move {
                // Wait for the channel state handed over by `Drop`.
                let ReceiverInner { mut rx, closed_tx, remote_send_err_tx, _codec } = match successor_rx.await {
                    Ok(inner) => inner,
                    Err(_) => return,
                };
                let (raw_tx, mut raw_rx) = match connect.await {
                    Ok(tx_rx) => tx_rx,
                    Err(err) => {
                        let _ = remote_send_err_tx.send(Some(RemoteSendError::Connect(err)));
                        return;
                    }
                };
                let mut remote_tx = remote::Sender::<Result<T, ReceiveError>, Codec>::new(raw_tx, allocator);
                // Cleared as soon as the backchannel yields anything other than a non-empty
                // message (e.g. end of stream or an error); it is then no longer polled.
                let mut backchannel_active = true;
                loop {
                    tokio::select! {
                        // Poll branches in declaration order: backchannel messages take
                        // priority over forwarding queued values.
                        biased;
                        backchannel_msg = raw_rx.recv(), if backchannel_active => {
                            match backchannel_msg {
                                // The first byte selects the message type.
                                Ok(Some(mut msg)) if msg.remaining() >= 1 => {
                                    match msg.get_u8() {
                                        BACKCHANNEL_MSG_CLOSE => {
                                            let _ = closed_tx.send(true);
                                        }
                                        BACKCHANNEL_MSG_ERROR => {
                                            let _ = remote_send_err_tx.send(Some(RemoteSendError::Forward));
                                        }
                                        // Unknown message types are ignored.
                                        _ => (),
                                    }
                                },
                                _ => backchannel_active = false,
                            }
                        }
                        res_opt = rx.recv() => {
                            // `None`: the local queue has no senders left and is drained.
                            let res = match res_opt {
                                Some(res) => res,
                                None => break,
                            };
                            // A send failure is recorded, but forwarding continues with the
                            // next value.
                            if let Err(err) = remote_tx.send(res).await {
                                let _ = remote_send_err_tx.send(Some(RemoteSendError::Send(err.kind)));
                            }
                        }
                    }
                }
            })
            // Discard the task's join result and box the future handed back to
            // `PortSerializer::connect`.
            .map(|_| ())
            .boxed()
        })?;
        let transported = TransportedReceiver::<T, Codec> { port, data: PhantomData, codec: PhantomData };
        transported.serialize(serializer)
    }
}
impl<'de, T, Codec, const BUFFER: usize> Deserialize<'de> for Receiver<T, Codec, BUFFER>
where
    T: Serialize + DeserializeOwned + Send + 'static,
    Codec: CodecT,
{
    /// Deserializes a [`TransportedReceiver`] and returns a receiver fed from its port.
    ///
    /// Setup proceeds as follows:
    /// - Creates a local queue of capacity `BUFFER`, plus watch channels for the closed flag
    ///   and for send errors.
    /// - Registers the port with `PortDeserializer::accept`.
    ///
    /// The closure given to `accept` spawns a task that accepts the connection. Until stopped,
    /// the task then relays:
    ///
    /// - the closed flag changing to `true`, as a single `BACKCHANNEL_MSG_CLOSE` upstream;
    /// - a send error recorded through `remote_send_err_tx`, as `BACKCHANNEL_MSG_ERROR`
    ///   upstream. In this file such errors are only recorded by the forwarding task of
    ///   `serialize`, i.e. when this receiver is forwarded again;
    /// - values and errors received from the remote endpoint, into the local queue.
    ///
    /// If accepting the connection fails, the error is queued as `ReceiveError::RemoteListen`.
    ///
    /// # Errors
    /// Returns an error if the [`TransportedReceiver`] cannot be deserialized or if
    /// `PortDeserializer::accept` fails.
    ///
    /// # Panics
    /// Panics if `BUFFER` is zero.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        assert!(BUFFER > 0, "BUFFER must not be zero");
        let TransportedReceiver { port, .. } = TransportedReceiver::<T, Codec>::deserialize(deserializer)?;
        let (tx, rx) = tokio::sync::mpsc::channel(BUFFER);
        let (closed_tx, mut closed_rx) = tokio::sync::watch::channel(false);
        let (remote_send_err_tx, mut remote_send_err_rx) = tokio::sync::watch::channel(None);
        PortDeserializer::accept(port, |local_port, request, allocator| {
            tokio::spawn(async move {
                let (mut raw_tx, raw_rx) = match request.accept_from(local_port).await {
                    Ok(tx_rx) => tx_rx,
                    Err(err) => {
                        // Surface the failure to the local receiver, then give up.
                        let _ = tx.send(Err(ReceiveError::RemoteListen(err))).await;
                        return;
                    }
                };
                let mut remote_rx = remote::Receiver::<Result<T, ReceiveError>, Codec>::new(raw_rx, allocator);
                // Ensures the close message is sent upstream at most once.
                let mut close_sent = false;
                loop {
                    tokio::select! {
                        // Poll branches in declaration order: close and error notifications
                        // take priority over receiving values.
                        biased;
                        res = closed_rx.changed() => {
                            match res {
                                Ok(()) if *closed_rx.borrow() && !close_sent => {
                                    let _ = raw_tx.send(vec![BACKCHANNEL_MSG_CLOSE].into()).await;
                                    close_sent = true;
                                }
                                Ok(()) => (),
                                // `closed_tx` has been dropped (the receiver's channel state is
                                // gone); stop relaying.
                                Err(_) => break,
                            }
                        }
                        // If the pattern does not match (the watch sender is gone), `select!`
                        // disables this branch for the current iteration only.
                        Ok(()) = remote_send_err_rx.changed() => {
                            if remote_send_err_rx.borrow().as_ref().is_some() {
                                let _ = raw_tx.send(vec![BACKCHANNEL_MSG_ERROR].into()).await;
                            }
                        }
                        res = remote_rx.recv() => {
                            let value = match res {
                                Ok(Some(value)) => value,
                                // The remote receiver yields nothing more; stop.
                                Ok(None) => break,
                                // Receive errors are delivered to the local receiver and the
                                // loop keeps going.
                                Err(err) => Err(ReceiveError::RemoteReceive(err)),
                            };
                            // Stop once the local queue's receiver is gone.
                            if tx.send(value).await.is_err() {
                                break;
                            }
                        }
                    }
                }
            })
            // Discard the task's join result and box the future handed back to
            // `PortDeserializer::accept`.
            .map(|_| ())
            .boxed()
        })?;
        Ok(Self::new(rx, closed_tx, remote_send_err_tx))
    }
}