use std::error::Error;
use std::pin::Pin;
use std::sync::Arc;
use futures::Sink;
use futures::channel::mpsc;
use derive_deftly::{Deftly, define_derive_deftly};
use thiserror::Error;
/// A [`Sink`] which additionally supports attempting a send synchronously.
///
/// Both methods are plain (non-`async`) functions taking no task context, so an
/// implementation cannot wait for capacity: it either accepts the item now or
/// reports why it could not.
pub trait SinkTrySend<T>: Sink<T> {
    /// Error reported when an item could not be sent; it can classify itself
    /// via [`SinkTrySendError`].
    type Error: SinkTrySendError;

    /// Attempt to send `item` right now, handing it back on failure.
    ///
    /// # Errors
    ///
    /// On failure returns the error together with the unsent `item`, so the
    /// caller can retry or dispose of it.
    fn try_send_or_return(
        self: Pin<&mut Self>,
        item: T,
    ) -> Result<(), (<Self as SinkTrySend<T>>::Error, T)>;

    /// Attempt to send `item` right now, discarding it on failure.
    ///
    /// The provided implementation delegates to
    /// [`try_send_or_return`](SinkTrySend::try_send_or_return) and drops the
    /// returned item.
    ///
    /// # Errors
    ///
    /// Returns the error from `try_send_or_return` when the send fails.
    fn try_send(self: Pin<&mut Self>, item: T) -> Result<(), <Self as SinkTrySend<T>>::Error> {
        match self.try_send_or_return(item) {
            Ok(()) => Ok(()),
            // The unsent item is not needed by this variant of the API.
            Err((error, _unsent)) => Err(error),
        }
    }
}
/// An error from a [`SinkTrySend`] attempt, able to report which kind of
/// failure it represents.
///
/// The `'static` bound allows such errors to be put in an `Arc` and downcast
/// (as `ErasedSinkTrySendError::from` does).
pub trait SinkTrySendError: Error + 'static {
    /// Whether this error means the receiving side has gone away.
    fn is_disconnected(&self) -> bool;

    /// Whether this error means the sink currently has no room (backpressure).
    fn is_full(&self) -> bool;
}
// derive-deftly template for the `ErasedSinkTrySendError` enum.
//
// Every variant marked `#[deftly(predicate)]` gets a predicate name of the form
// `is_<variant name in snake_case>` (e.g. `Full` -> `is_full`). Those
// predicates are then used, in variant declaration order, to:
//  * implement `SinkTrySendError` for the enum itself;
//  * classify an arbitrary `SinkTrySendError` in `ErasedSinkTrySendError::from`;
//  * classify a futures `mpsc::TrySendError` in `handle_mpsc_error`.
// Because each expansion is an `if .. else if ..` chain, the first variant
// whose predicate holds wins.
#[rustfmt::skip] define_derive_deftly! {
    ErasedSinkTrySendError expect items:

    // Condition: true for variants carrying `#[deftly(predicate)]`.
    ${defcond PREDICATE vmeta(predicate)}
    // Expansion: the predicate method name for the current variant.
    ${define PREDICATE { $<is_ ${snake_case $vname}> }}

    // The enum is itself a `SinkTrySendError`: `is_full()` is true exactly for
    // the `Full` variant, and so on for each predicate variant.
    impl SinkTrySendError for ErasedSinkTrySendError {
        $(
            ${when PREDICATE}
            fn $PREDICATE(&self) -> bool {
                matches!(self, $vtype)
            }
        )
    }

    impl ErasedSinkTrySendError {
        // Erase an arbitrary `SinkTrySendError` into this type.
        //
        // If any predicate holds, the matching unit variant is returned and `e`
        // is discarded. Otherwise `e` is kept, type-erased, in `Other`, except
        // that an `e` which already *is* an `ErasedSinkTrySendError` is
        // returned unchanged rather than being nested in another `Other`.
        pub fn from<E>(e: E) -> ErasedSinkTrySendError
        where E: SinkTrySendError + Send + Sync
        {
            // Expands to `if e.is_full() { .. } else if e.is_disconnected() { .. } else`
            $(
                ${when PREDICATE}
                if e.$PREDICATE() {
                    $vtype
                } else
            )
            {
                let e = Arc::new(e);
                // Second handle, consumed by the downcast attempt.
                // (`Arc::downcast` needs `Any + Send + Sync`; `'static` comes
                // from the `SinkTrySendError` bound.)
                let e2 = e.clone();
                match Arc::downcast(e2) {
                    // `E` is already `ErasedSinkTrySendError`: hand back the
                    // original value instead of wrapping it a second time.
                    Ok::<Arc<ErasedSinkTrySendError>, _>(y2) => {
                        // Dropping `e` first leaves `y2` as the only owner, so
                        // `Arc::into_inner` returns `Some`.
                        drop(e); let inner: ErasedSinkTrySendError =
                        Arc::into_inner(y2).expect(
                            "somehow we weren't the only owner, despite us just having made an Arc!"
                        );
                        return inner;
                    }
                    // Some other error type: keep it (via the first handle) in `Other`.
                    Err(other_e2) => {
                        drop(other_e2);
                        ErasedSinkTrySendError::Other(e)
                    },
                }
            }
        }
    }

    // Convert a futures `mpsc::TrySendError` into an erased error plus the
    // unsent item (recovered with `into_inner`). `mpsc::TrySendError` provides
    // the same `is_*` predicates; an error matching none of them becomes
    // `Other(MpscOtherSinkTrySendError)`.
    fn handle_mpsc_error<T>(me: mpsc::TrySendError<T>) -> (ErasedSinkTrySendError, T) {
        let error = $(
            ${when PREDICATE}
            if me.$PREDICATE() {
                $vtype
            } else
        )
        {
            $ttype::Other(Arc::new(MpscOtherSinkTrySendError {}))
        };
        (error, me.into_inner())
    }
}
/// Type-erased error from a [`SinkTrySend`] attempt.
///
/// Produced by `ErasedSinkTrySendError::from` (for any [`SinkTrySendError`]) and
/// by the futures `mpsc` sender impls. The unit variants mirror the
/// `is_full` / `is_disconnected` predicates; everything else is kept in `Other`.
///
/// Variant order matters: if an error reports more than one predicate, the
/// earliest variant listed here is chosen.
#[derive(Debug, Error, Clone, Deftly)]
#[derive_deftly(ErasedSinkTrySendError)]
// Intentionally exhaustive: `Other` already serves as the catch-all for any
// error kind not covered by a dedicated variant.
#[allow(clippy::exhaustive_enums)]
pub enum ErasedSinkTrySendError {
    /// The sink had no room for the item.
    #[deftly(predicate)]
    #[error("stream full (backpressure)")]
    Full,

    /// The sink's receiving side has gone away.
    #[deftly(predicate)]
    #[error("stream disconnected")]
    Disconnected,

    /// Any other failure; the underlying error is kept as the source.
    #[error("failed to convey data")]
    Other(#[source] Arc<dyn Error + Send + Sync + 'static>),
}
#[derive(Debug, Error)]
#[error("mpsc::Sender::try_send returned an error which is neither .full() nor .disconnected()")]
#[non_exhaustive]
pub struct MpscOtherSinkTrySendError {}
impl<T> SinkTrySend<T> for mpsc::Sender<T> {
type Error = ErasedSinkTrySendError;
fn try_send_or_return(
self: Pin<&mut Self>,
item: T,
) -> Result<(), (ErasedSinkTrySendError, T)> {
let self_: &mut Self = Pin::into_inner(self);
mpsc::Sender::try_send(self_, item).map_err(handle_mpsc_error)
}
}
impl<T> SinkTrySend<T> for mpsc::UnboundedSender<T> {
type Error = ErasedSinkTrySendError;
fn try_send_or_return(
self: Pin<&mut Self>,
item: T,
) -> Result<(), (ErasedSinkTrySendError, T)> {
let self_: &mut Self = Pin::into_inner(self);
mpsc::UnboundedSender::unbounded_send(self_, item).map_err(handle_mpsc_error)
}
}
#[cfg(test)]
mod test {
    // Lints relaxed for test code.
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    #![allow(clippy::arithmetic_side_effects)]
    #![allow(clippy::useless_format)]

    use super::*;
    use derive_deftly::derive_deftly_adhoc;
    use tor_error::ErrorReport as _;

    /// Checks that `ErasedSinkTrySendError::from` classifies a concrete error
    /// correctly for every predicate combination, and that erasing an
    /// already-erased error yields the same result.
    #[test]
    fn chk_erased_sink() {
        /// Test error whose predicates simply report its two fields.
        #[derive(Error, Clone, Debug, Deftly)]
        #[error("concrete {is_full} {is_disconnected}")]
        #[derive_deftly_adhoc]
        struct Concrete {
            is_full: bool,
            is_disconnected: bool,
        }

        derive_deftly_adhoc! {
            Concrete:
            impl SinkTrySendError for Concrete { $(
                fn $fname(&self) -> bool { self.$fname }
            ) }
        }

        let combinations = [(false, false), (false, true), (true, false), (true, true)];
        for (is_full, is_disconnected) in combinations {
            let concrete = Concrete {
                is_full,
                is_disconnected,
            };
            let erased = ErasedSinkTrySendError::from(concrete.clone());
            let re_erased = ErasedSinkTrySendError::from(erased.clone());

            let concrete_msg = format!("concrete {is_full} {is_disconnected}");
            // `Full` takes priority over `Disconnected` when both predicates hold.
            let erased_msg = match (is_full, is_disconnected) {
                (true, _) => "stream full (backpressure)".to_owned(),
                (false, true) => "stream disconnected".to_owned(),
                (false, false) => format!("failed to convey data: {concrete_msg}"),
            };

            assert_eq!(
                concrete.report().to_string(),
                format!("error: {concrete_msg}")
            );
            assert_eq!(erased.report().to_string(), format!("error: {erased_msg}"));
            assert_eq!(
                re_erased.report().to_string(),
                format!("error: {erased_msg}")
            );
        }
    }
}