use futures::FutureExt;
use serde::{Deserialize, Serialize};
use std::{error::Error, fmt, marker::PhantomData, sync::Mutex};
use super::{
super::{
RemoteSendError, SendErrorExt,
base::{self, PortDeserializer, PortSerializer},
},
Receiver, Ref,
receiver::RecvError,
};
use crate::{RemoteSend, chmux, codec};
/// Error reported when sending a value through a watch [`Sender`] fails.
///
/// The remote-related variants are converted one-to-one from the internal
/// `RemoteSendError` by the `From<RemoteSendError>` impl in this file.
///
/// NOTE: this type derives `Serialize`/`Deserialize`, so the variant order is
/// part of its serialized form for index-based formats — do not reorder.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum SendError {
/// The channel is closed. `Sender::send` returns this when the underlying
/// watch send fails and no remote error is pending; it is also converted
/// from `RemoteSendError::Closed`.
Closed,
/// Sending to the remote endpoint failed (from `RemoteSendError::Send`).
RemoteSend(base::SendErrorKind),
/// Connecting the chmux port failed (from `RemoteSendError::Connect`).
RemoteConnect(chmux::ConnectError),
/// Accepting the chmux port failed (from `RemoteSendError::Listen`); the
/// `Deserialize` impl in this file queues this when `accept_from` fails.
RemoteListen(chmux::ListenerError),
/// Forwarding failed (from `RemoteSendError::Forward`).
RemoteForward,
}
impl SendError {
pub fn is_closed(&self) -> bool {
matches!(self, Self::Closed)
}
pub fn is_disconnected(&self) -> bool {
!matches!(self, Self::RemoteSend(base::SendErrorKind::Serialize(_)))
}
pub fn is_final(&self) -> bool {
match self {
Self::RemoteSend(err) => err.is_final(),
Self::Closed | Self::RemoteConnect(_) | Self::RemoteListen(_) | Self::RemoteForward => true,
}
}
pub fn is_item_specific(&self) -> bool {
matches!(self, Self::RemoteSend(err) if err.is_item_specific())
}
}
impl SendErrorExt for SendError {
fn is_closed(&self) -> bool {
self.is_closed()
}
fn is_disconnected(&self) -> bool {
self.is_disconnected()
}
fn is_final(&self) -> bool {
self.is_final()
}
fn is_item_specific(&self) -> bool {
self.is_item_specific()
}
}
impl fmt::Display for SendError {
    /// Writes a short human-readable description. Wrapped errors are appended
    /// after a prefix that names the failed stage.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Self::Closed => f.write_str("channel is closed"),
            Self::RemoteForward => f.write_str("forwarding error"),
            Self::RemoteSend(kind) => write!(f, "send error: {kind}"),
            Self::RemoteConnect(cause) => write!(f, "connect error: {cause}"),
            Self::RemoteListen(cause) => write!(f, "listen error: {cause}"),
        }
    }
}
impl Error for SendError {}
impl From<RemoteSendError> for SendError {
fn from(err: RemoteSendError) -> Self {
match err {
RemoteSendError::Send(err) => Self::RemoteSend(err),
RemoteSendError::Connect(err) => Self::RemoteConnect(err),
RemoteSendError::Listen(err) => Self::RemoteListen(err),
RemoteSendError::Forward => Self::RemoteForward,
RemoteSendError::Closed => Self::Closed,
}
}
}
/// Sending half of a watch channel that can itself be serialized and sent to a
/// remote endpoint (see the `Serialize`/`Deserialize` impls in this file).
pub struct Sender<T, Codec = codec::Default> {
/// Internal state. `Some` from construction until `Drop`, which takes it out
/// to hand it to a successor if one was registered.
inner: Option<SenderInner<T, Codec>>,
/// Successor channel installed by `Serialize::serialize`. It sits behind a
/// `Mutex` because `serialize` only has `&self`. On drop, `inner` is sent
/// through it.
successor_tx: Mutex<Option<tokio::sync::oneshot::Sender<SenderInner<T, Codec>>>>,
}
impl<T, Codec> fmt::Debug for Sender<T, Codec> {
    /// Prints only the type name. The watched value and internal channels are
    /// not shown, so no `Debug` bound on `T` is needed.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("Sender")
    }
}
/// State owned by a `Sender`. When a serialized sender is dropped, this state
/// is moved as a whole to the connect task of its port.
pub(crate) struct SenderInner<T, Codec> {
/// Local watch sender. `Sender`'s own methods only ever store `Ok` values;
/// the connect task may publish an `Err` after takeover.
tx: tokio::sync::watch::Sender<Result<T, RecvError>>,
/// Producer side of the remote-error queue; cloned into every receiver
/// created by `subscribe`.
remote_send_err_tx: tokio::sync::mpsc::UnboundedSender<RemoteSendError>,
/// Pending remote errors. Behind a `Mutex` so that `&self` methods can
/// `try_recv` from it.
remote_send_err_rx: Mutex<tokio::sync::mpsc::UnboundedReceiver<RemoteSendError>>,
/// Latched error reported by `error()` until `clear_error()` resets it.
current_err: Mutex<Option<RemoteSendError>>,
/// Item size limit; captured and transported when the sender is serialized.
max_item_size: usize,
_codec: PhantomData<Codec>,
}
/// Serialized form of a `Sender`. This is the wire format — changing field
/// names, types or order affects compatibility with peers.
#[derive(Serialize, Deserialize)]
pub(crate) struct TransportedSender<T, Codec> {
/// chmux port registered by `PortSerializer::connect` on the sending side.
port: u32,
/// Snapshot of the watched value at serialization time; it seeds the watch
/// channel on the receiving side.
data: Result<T, RecvError>,
/// Codec marker.
codec: PhantomData<Codec>,
/// Item size limit. Defaults to `u64::MAX` when the field is absent —
/// presumably for peers that do not send it; confirm.
#[serde(default = "default_max_item_size")]
max_item_size: u64,
}
/// Serde default for `TransportedSender::max_item_size`: the largest
/// representable `u64`, i.e. effectively no limit.
const fn default_max_item_size() -> u64 {
    u64::MAX
}
impl<T, Codec> Sender<T, Codec>
where
T: Send + 'static,
{
pub(crate) fn new(
tx: tokio::sync::watch::Sender<Result<T, RecvError>>,
remote_send_err_tx: tokio::sync::mpsc::UnboundedSender<RemoteSendError>,
remote_send_err_rx: tokio::sync::mpsc::UnboundedReceiver<RemoteSendError>, max_item_size: usize,
) -> Self {
let inner = SenderInner {
tx,
remote_send_err_tx,
remote_send_err_rx: Mutex::new(remote_send_err_rx),
current_err: Mutex::new(None),
max_item_size,
_codec: PhantomData,
};
Self { inner: Some(inner), successor_tx: Mutex::new(None) }
}
pub fn send(&self, value: T) -> Result<(), SendError> {
match self.inner.as_ref().unwrap().tx.send(Ok(value)) {
Ok(()) => Ok(()),
Err(_) => match self.error() {
Some(err) => Err(err),
None => Err(SendError::Closed),
},
}
}
pub fn send_modify<F>(&self, func: F)
where
F: FnOnce(&mut T),
{
self.inner.as_ref().unwrap().tx.send_modify(move |v| func(v.as_mut().unwrap()))
}
pub fn send_replace(&self, value: T) -> T {
self.inner.as_ref().unwrap().tx.send_replace(Ok(value)).unwrap()
}
pub fn borrow(&self) -> Ref<'_, T> {
Ref(self.inner.as_ref().unwrap().tx.borrow())
}
pub async fn closed(&self) {
self.inner.as_ref().unwrap().tx.closed().await
}
pub fn is_closed(&self) -> bool {
self.inner.as_ref().unwrap().tx.is_closed()
}
pub fn subscribe(&self) -> Receiver<T, Codec> {
let inner = self.inner.as_ref().unwrap();
Receiver::new(inner.tx.subscribe(), inner.remote_send_err_tx.clone(), None)
}
fn update_error(&self) {
let inner = self.inner.as_ref().unwrap();
let mut current_err = inner.current_err.lock().unwrap();
if current_err.is_some() {
return;
}
let mut remote_send_err_rx = inner.remote_send_err_rx.lock().unwrap();
if let Ok(err) = remote_send_err_rx.try_recv() {
*current_err = Some(err);
}
}
pub fn error(&self) -> Option<SendError> {
self.update_error();
let inner = self.inner.as_ref().unwrap();
let current_err = inner.current_err.lock().unwrap();
current_err.clone().map(|err| err.into())
}
pub fn clear_error(&mut self) {
self.update_error();
let inner = self.inner.as_ref().unwrap();
let mut current_err = inner.current_err.lock().unwrap();
*current_err = None;
}
pub fn check(&mut self) -> Result<(), SendError> {
while let Some(err) = self.error() {
if err.is_item_specific() {
return Err(err);
}
self.clear_error();
}
Ok(())
}
pub fn max_item_size(&self) -> usize {
self.inner.as_ref().unwrap().max_item_size
}
pub fn set_max_item_size(&mut self, max_item_size: usize) {
self.inner.as_mut().unwrap().max_item_size = max_item_size;
}
}
impl<T, Codec> Drop for Sender<T, Codec> {
    /// Hands the internal state over to a pending serialization, if any.
    ///
    /// `Serialize::serialize` installs a successor channel. If one is present,
    /// `SenderInner` is moved through it so the port's connect task can keep
    /// driving the watch channel. Otherwise the state is dropped with `self`.
    fn drop(&mut self) {
        // `&mut self` already grants exclusive access, so no locking is needed.
        // A poisoned mutex is tolerated rather than unwrapped: panicking in
        // `drop` while already unwinding would abort the process.
        let successor_tx = self
            .successor_tx
            .get_mut()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .take();

        if let Some(successor_tx) = successor_tx {
            if let Some(inner) = self.inner.take() {
                // The connect task may already be gone; then the state is
                // simply dropped.
                let _ = successor_tx.send(inner);
            }
        }
    }
}
/// Serializes the sender by registering a chmux port and transporting a
/// snapshot of its current value.
///
/// The local state is not moved here. Instead a successor channel is
/// installed, and when this `Sender` is dropped its `Drop` impl delivers the
/// `SenderInner` to the port's connect task, which then runs `recv_impl`.
impl<T, Codec> Serialize for Sender<T, Codec>
where
T: RemoteSend + Sync + Clone,
Codec: codec::Codec,
{
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
// Captured now and moved into the connect task below.
let max_item_size = self.max_item_size();
// Install the successor before registering the port. Serializing again
// overwrites (and thereby drops) an earlier successor sender.
let (successor_tx, successor_rx) = tokio::sync::oneshot::channel();
*self.successor_tx.lock().unwrap() = Some(successor_tx);
let port = PortSerializer::connect(move |connect| {
async move {
// Wait for the dropped sender's state. If the successor sender
// was dropped without sending (e.g. replaced by a later
// serialization), there is nothing to drive.
let SenderInner { tx, remote_send_err_rx, current_err, .. } = match successor_rx.await {
Ok(inner) => inner,
Err(_) => return,
};
// The state is owned exclusively now, so the mutexes can be unwrapped.
let remote_send_err_rx = remote_send_err_rx.into_inner().unwrap();
let current_err = current_err.into_inner().unwrap();
let (raw_tx, raw_rx) = match connect.await {
Ok(tx_rx) => tx_rx,
Err(err) => {
// Publish the connect failure through the watch channel
// to any receivers still subscribed; the result is ignored.
let _ = tx.send(Err(RecvError::RemoteConnect(err)));
return;
}
};
// NOTE(review): presumably feeds values arriving from the remote peer
// into the local watch channel (inferred from the name) — confirm.
super::recv_impl::<T, Codec>(tx, raw_tx, raw_rx, remote_send_err_rx, current_err, max_item_size)
.await;
}
.boxed()
})?;
// Snapshot of the current value. The receiving side seeds its watch
// channel with it in the `Deserialize` impl.
let data = self.inner.as_ref().unwrap().tx.borrow().clone();
let transported = TransportedSender::<T, Codec> {
port,
data,
// Saturate should `usize` be wider than `u64` on this platform.
max_item_size: max_item_size.try_into().unwrap_or(u64::MAX),
codec: PhantomData,
};
transported.serialize(serializer)
}
}
impl<'de, T, Codec> Deserialize<'de> for Sender<T, Codec>
where
T: RemoteSend + Sync + Clone,
Codec: codec::Codec,
{
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let TransportedSender { port, data, max_item_size, .. } =
TransportedSender::<T, Codec>::deserialize(deserializer)?;
let max_item_size = usize::try_from(max_item_size).unwrap_or(usize::MAX);
if data.is_err() {
return Err(serde::de::Error::custom("received watch data with error"));
}
let (tx, rx) = tokio::sync::watch::channel(data);
let (remote_send_err_tx, remote_send_err_rx) = tokio::sync::mpsc::unbounded_channel();
let remote_send_err_tx2 = remote_send_err_tx.clone();
PortDeserializer::accept(port, move |local_port, request| {
async move {
let (raw_tx, raw_rx) = match request.accept_from(local_port).await {
Ok(tx_rx) => tx_rx,
Err(err) => {
let _ = remote_send_err_tx.send(RemoteSendError::Listen(err));
return;
}
};
super::send_impl::<T, Codec>(rx, raw_tx, raw_rx, remote_send_err_tx, max_item_size).await;
}
.boxed()
})?;
Ok(Self::new(tx, remote_send_err_tx2, remote_send_err_rx, max_item_size))
}
}