use futures::FutureExt;
use serde::{Deserialize, Serialize, de::DeserializeOwned, ser};
use std::{
error::Error,
fmt,
marker::PhantomData,
sync::{Arc, Mutex},
};
use super::{
super::{
ConnectError, SendErrorExt,
base::{self, PortDeserializer, PortSerializer},
},
Interlock, Location,
};
use crate::{
chmux,
codec::{self, SerializationError},
};
pub use super::super::base::Closed;
/// An error that occurred while sending an item over the channel.
///
/// Besides the failure reason, the error carries the item it relates to,
/// so the caller gets it back instead of losing it.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SendError<T> {
/// The reason why sending failed.
pub kind: SendErrorKind,
/// The item associated with the failed send operation.
pub item: T,
}
/// The reason why sending an item over the channel failed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SendErrorKind {
/// A serialization error occurred.
Serialize(SerializationError),
/// A `chmux::SendError` occurred while sending.
Send(chmux::SendError),
/// A `ConnectError` occurred, i.e. the underlying sender could not be obtained.
Connect(ConnectError),
/// The maximum item size was exceeded.
MaxItemSizeExceeded,
}
impl<T> SendError<T> {
    /// Builds a send error from its failure kind and the affected item.
    pub(crate) fn new(kind: SendErrorKind, item: T) -> Self {
        SendError { kind, item }
    }

    /// Returns `true` if the error is a send error whose own `is_closed()`
    /// check reports `true`.
    pub fn is_closed(&self) -> bool {
        match &self.kind {
            SendErrorKind::Send(cause) => cause.is_closed(),
            SendErrorKind::Serialize(_)
            | SendErrorKind::Connect(_)
            | SendErrorKind::MaxItemSizeExceeded => false,
        }
    }

    /// Returns `true` for send and connect errors, `false` for all other kinds.
    pub fn is_disconnected(&self) -> bool {
        match &self.kind {
            SendErrorKind::Send(_) | SendErrorKind::Connect(_) => true,
            SendErrorKind::Serialize(_) | SendErrorKind::MaxItemSizeExceeded => false,
        }
    }

    /// Returns whether the error is final; this is exactly the result of
    /// [`is_disconnected`](Self::is_disconnected).
    pub fn is_final(&self) -> bool {
        self.is_disconnected()
    }

    /// Returns `true` if the error was caused by the item itself, i.e. it is a
    /// serialization error or the maximum item size was exceeded.
    pub fn is_item_specific(&self) -> bool {
        match &self.kind {
            SendErrorKind::Serialize(_) | SendErrorKind::MaxItemSizeExceeded => true,
            SendErrorKind::Send(_) | SendErrorKind::Connect(_) => false,
        }
    }

    /// Discards the carried item and keeps only the failure kind.
    pub fn without_item(self) -> SendError<()> {
        let Self { kind, .. } = self;
        SendError { kind, item: () }
    }
}
impl<T> SendErrorExt for SendError<T> {
fn is_closed(&self) -> bool {
self.is_closed()
}
fn is_disconnected(&self) -> bool {
self.is_disconnected()
}
fn is_final(&self) -> bool {
self.is_final()
}
fn is_item_specific(&self) -> bool {
self.is_item_specific()
}
}
impl fmt::Display for SendErrorKind {
    /// Writes a short description of the failure, followed by the wrapped
    /// error for the kinds that carry one.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Self::MaxItemSizeExceeded => f.write_str("maximum item size exceeded"),
            Self::Serialize(err) => f.write_fmt(format_args!("serialization error: {err}")),
            Self::Send(err) => f.write_fmt(format_args!("send error: {err}")),
            Self::Connect(err) => f.write_fmt(format_args!("connect error: {err}")),
        }
    }
}
impl From<base::SendErrorKind> for SendErrorKind {
fn from(err: base::SendErrorKind) -> Self {
match err {
base::SendErrorKind::Serialize(err) => Self::Serialize(err),
base::SendErrorKind::Send(err) => Self::Send(err),
base::SendErrorKind::MaxItemSizeExceeded => Self::MaxItemSizeExceeded,
}
}
}
impl<T> fmt::Display for SendError<T> {
    /// Displays only the failure kind; the carried item is not printed.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let Self { kind, .. } = self;
        write!(f, "{}", kind)
    }
}
impl<T> From<base::SendError<T>> for SendError<T> {
fn from(err: base::SendError<T>) -> Self {
Self { kind: err.kind.into(), item: err.item }
}
}
impl<T> Error for SendError<T> where T: fmt::Debug {}
/// The sending half of the channel.
///
/// The underlying base sender is not available at construction time; it is
/// delivered through `sender_rx` and cached in `sender` on first use.
pub struct Sender<T, Codec = codec::Default> {
/// Cached outcome of obtaining the base sender: `None` until it has been
/// received from `sender_rx`, afterwards either the base sender or the
/// connect error.
pub(super) sender: Option<Result<base::Sender<T, Codec>, ConnectError>>,
/// Channel over which the base sender (or the connect error) is delivered.
pub(super) sender_rx: tokio::sync::mpsc::UnboundedReceiver<Result<base::Sender<T, Codec>, ConnectError>>,
/// Channel used while serializing this sender to hand over the base
/// receiver (or the connect error). It is `None` for a sender that was
/// itself obtained by deserialization; such a sender cannot be serialized.
pub(super) receiver_tx:
Option<tokio::sync::mpsc::UnboundedSender<Result<base::Receiver<T, Codec>, ConnectError>>>,
/// Shared interlock state; consulted during serialization to check that
/// the receiver is still local.
pub(super) interlock: Arc<Mutex<Interlock>>,
/// Maximum item size, applied to the base sender once it is available.
pub(super) max_item_size: usize,
}
impl<T, Codec> fmt::Debug for Sender<T, Codec> {
    /// Prints only the type name; no field is included in the output.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut builder = f.debug_struct("Sender");
        builder.finish()
    }
}
/// The serialized representation of a `Sender` while it is being transported.
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct TransportedSender<T, Codec> {
/// Port number: produced by `PortSerializer::connect` when serializing and
/// passed to `PortDeserializer::accept` when deserializing.
pub port: u32,
/// Marker for the item type; carries no data.
data: PhantomData<T>,
/// Marker for the codec type; carries no data.
codec: PhantomData<Codec>,
/// Maximum item size. Falls back to `default_max_item_size()` when the
/// field is absent from the serialized data.
#[serde(default = "default_max_item_size")]
max_item_size: u64,
}
/// Returns the maximum item size used when none is specified: the largest
/// value representable by `u64`.
const fn default_max_item_size() -> u64 {
    const LARGEST: u64 = u64::MAX;
    LARGEST
}
impl<T, Codec> Sender<T, Codec>
where
    T: Serialize + Send + 'static,
    Codec: codec::Codec,
{
    /// Returns the base sender, waiting for it to be delivered on first use.
    ///
    /// The outcome of the first delivery is cached. If the delivery channel
    /// yields nothing, `ConnectError::Dropped` is cached. A cached error is
    /// returned (as a clone) on every subsequent call.
    async fn get(&mut self) -> Result<&mut base::Sender<T, Codec>, ConnectError> {
        if self.sender.is_none() {
            let delivered = match self.sender_rx.recv().await {
                Some(outcome) => outcome,
                None => Err(ConnectError::Dropped),
            };
            self.sender = Some(delivered);

            // Apply the configured limit to the freshly obtained base sender.
            if let Some(Ok(base_tx)) = self.sender.as_mut() {
                base_tx.set_max_item_size(self.max_item_size);
            }
        }

        // The slot is always populated at this point.
        let slot = self.sender.as_mut().unwrap();
        match slot {
            Ok(base_tx) => Ok(base_tx),
            Err(err) => Err(err.clone()),
        }
    }

    /// Sends an item over the channel.
    ///
    /// If the base sender cannot be obtained, the item is handed back inside a
    /// `SendErrorKind::Connect` error. Errors from the base sender are
    /// converted into [`SendError`].
    pub async fn send(&mut self, item: T) -> Result<(), SendError<T>> {
        let base_tx = match self.get().await {
            Ok(base_tx) => base_tx,
            Err(err) => return Err(SendError::new(SendErrorKind::Connect(err), item)),
        };
        base_tx.send(item).await.map_err(SendError::from)
    }

    /// Returns the result of `is_closed()` on the base sender, or the connect
    /// error if the base sender could not be obtained.
    pub async fn is_closed(&mut self) -> Result<bool, ConnectError> {
        let base_tx = self.get().await?;
        Ok(base_tx.is_closed())
    }

    /// Returns the result of `closed()` on the base sender, or the connect
    /// error if the base sender could not be obtained.
    pub async fn closed(&mut self) -> Result<Closed, ConnectError> {
        let base_tx = self.get().await?;
        Ok(base_tx.closed())
    }

    /// Returns the configured maximum item size.
    pub fn max_item_size(&self) -> usize {
        self.max_item_size
    }

    /// Stores the maximum item size and, if the base sender is already
    /// available, applies it there as well.
    pub fn set_max_item_size(&mut self, max_item_size: usize) {
        self.max_item_size = max_item_size;
        match self.sender.as_mut() {
            Some(Ok(base_tx)) => {
                base_tx.set_max_item_size(max_item_size);
            }
            Some(Err(_)) | None => {}
        }
    }
}
impl<T, Codec> Serialize for Sender<T, Codec>
where
T: DeserializeOwned + Send + 'static,
Codec: codec::Codec,
{
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let max_item_size = self.max_item_size;
let receiver_tx =
self.receiver_tx.clone().ok_or_else(|| ser::Error::custom("cannot forward received sender"))?;
let interlock_confirm = {
let mut interlock = self.interlock.lock().unwrap();
if !interlock.receiver.check_local() {
return Err(ser::Error::custom("cannot send sender because receiver has been sent"));
}
interlock.receiver.start_send()
};
let port = PortSerializer::connect(move |connect| {
async move {
let _ = interlock_confirm.send(());
match connect.await {
Ok((_, raw_rx)) => {
let mut rx = base::Receiver::new(raw_rx);
rx.set_max_item_size(max_item_size);
let _ = receiver_tx.send(Ok(rx));
}
Err(err) => {
let _ = receiver_tx.send(Err(ConnectError::Connect(err)));
}
}
}
.boxed()
})?;
TransportedSender::<T, Codec> {
port,
data: PhantomData,
max_item_size: max_item_size.try_into().unwrap_or(u64::MAX),
codec: PhantomData,
}
.serialize(serializer)
}
}
impl<'de, T, Codec> Deserialize<'de> for Sender<T, Codec>
where
T: Serialize + Send + 'static,
Codec: codec::Codec,
{
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let TransportedSender::<T, Codec> { port, max_item_size, .. } =
TransportedSender::deserialize(deserializer)?;
let max_item_size = usize::try_from(max_item_size).unwrap_or(usize::MAX);
let (sender_tx, sender_rx) = tokio::sync::mpsc::unbounded_channel();
PortDeserializer::accept(port, move |local_port, request| {
async move {
match request.accept_from(local_port).await {
Ok((raw_tx, _)) => {
let mut tx = base::Sender::new(raw_tx);
tx.set_max_item_size(max_item_size);
let _ = sender_tx.send(Ok(tx));
}
Err(err) => {
let _ = sender_tx.send(Err(ConnectError::Listen(err)));
}
}
}
.boxed()
})?;
Ok(Self {
sender: None,
sender_rx,
receiver_tx: None,
interlock: Arc::new(Mutex::new(Interlock { sender: Location::Local, receiver: Location::Remote })),
max_item_size,
})
}
}