#![cfg(feature = "tokio")]
use std::{
pin::Pin,
sync::Arc,
task::{Context, Poll},
time::Duration,
};
use futures::{Sink, SinkExt, StreamExt};
use samod::{AcceptorHandle, BackoffConfig, Dialer, PeerId, Repo, Transport};
use tokio_stream::wrappers::ReceiverStream;
use url::Url;
/// Installs a global `tracing` subscriber whose filter is read from the default
/// environment variable (`RUST_LOG`).
///
/// Safe to call more than once: if a global subscriber is already installed,
/// `try_init` fails and that failure is deliberately discarded.
fn init_logging() {
    let env_filter = tracing_subscriber::EnvFilter::from_default_env();
    let subscriber = tracing_subscriber::fmt().with_env_filter(env_filter);
    // An "already initialised" error is expected and harmless here.
    drop(subscriber.try_init());
}
/// Error type produced by the fault-injecting transport pieces in this test.
///
/// It wraps a human-readable message, which is also its entire `Display` output.
#[derive(Debug)]
struct FaultyError(String);

impl std::fmt::Display for FaultyError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let FaultyError(message) = self;
        f.write_str(message)
    }
}

impl std::error::Error for FaultyError {}
/// A byte sink that swallows every message but reports I/O failures.
///
/// Despite the name, `start_send` always succeeds. The failures happen in
/// `poll_flush`, which errors exactly once, and in `poll_close`, which always errors.
struct FailOnSendAndCloseSink {
    /// Becomes `true` once `poll_flush` has returned its one-off error.
    send_errored: bool,
}

impl FailOnSendAndCloseSink {
    /// Creates a sink whose next flush will fail.
    fn new() -> Self {
        FailOnSendAndCloseSink { send_errored: false }
    }
}

// The only field is a `bool`, so this restates the auto trait. `Unpin` is what lets
// `poll_flush` mutate `send_errored` through `Pin<&mut Self>`.
impl Unpin for FailOnSendAndCloseSink {}
impl Sink<Vec<u8>> for FailOnSendAndCloseSink {
    type Error = FaultyError;

    /// Always ready: this sink never applies backpressure.
    fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Accepts and silently discards every message.
    fn start_send(self: Pin<&mut Self>, _item: Vec<u8>) -> Result<(), Self::Error> {
        Ok(())
    }

    /// The first flush fails with "Connection timed out"; every later flush succeeds.
    fn poll_flush(
        mut self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
    ) -> Poll<Result<(), Self::Error>> {
        if self.send_errored {
            return Poll::Ready(Ok(()));
        }
        self.send_errored = true;
        Poll::Ready(Err(FaultyError("Connection timed out".into())))
    }

    /// Closing always fails, simulating a broken pipe during shutdown.
    fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let close_error = FaultyError("broken pipe on close".into());
        Poll::Ready(Err(close_error))
    }
}
struct FaultySinkDialer {
url: Url,
acceptor: AcceptorHandle,
}
impl Dialer for FaultySinkDialer {
fn url(&self) -> Url {
self.url.clone()
}
fn connect(
&self,
) -> Pin<
Box<
dyn std::future::Future<
Output = Result<Transport, Box<dyn std::error::Error + Send + Sync + 'static>>,
> + Send,
>,
> {
let acceptor = self.acceptor.clone();
Box::pin(async move {
let (acc_tx, dialer_rx) = tokio::sync::mpsc::channel::<Vec<u8>>(16);
let (dialer_tx, acc_rx) = tokio::sync::mpsc::channel::<Vec<u8>>(16);
let acc_stream = ReceiverStream::new(acc_rx).map(Ok::<_, FaultyError>);
let acc_sink = tokio_util::sync::PollSender::new(acc_tx)
.sink_map_err(|e| FaultyError(format!("send error: {e:?}")));
acceptor
.accept(Transport::new(acc_stream, acc_sink))
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync + 'static>)?;
drop(dialer_tx);
let dialer_stream = ReceiverStream::new(dialer_rx).map(Ok::<_, FaultyError>);
let faulty_sink = FailOnSendAndCloseSink::new();
Ok(Transport::new(dialer_stream, faulty_sink))
})
}
}
/// Regression test: if a transport's sink fails on flush and then fails again on
/// close, no task may panic.
///
/// A process-wide panic hook records the first panic message seen during the test.
/// The hook forwards to the previously installed hook, so locations and backtraces
/// are still printed. If anything panicked, the test fails with the recorded message
/// after the dial has had time to hit the faulty sink.
#[tokio::test]
async fn sink_send_error_then_close_error_does_not_panic() {
    init_logging();

    let panicked: Arc<std::sync::Mutex<Option<String>>> = Arc::new(std::sync::Mutex::new(None));
    let panicked_hook = panicked.clone();
    // Shared so the custom hook can forward to it and it can be reinstated afterwards.
    let prev_hook = Arc::new(std::panic::take_hook());
    let prev_for_hook = prev_hook.clone();
    std::panic::set_hook(Box::new(move |info| {
        let msg = if let Some(s) = info.payload().downcast_ref::<&str>() {
            s.to_string()
        } else if let Some(s) = info.payload().downcast_ref::<String>() {
            s.clone()
        } else {
            info.to_string()
        };
        {
            // Keep only the first panic: it is the root cause, later ones are fallout.
            let mut slot = panicked_hook
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            if slot.is_none() {
                *slot = Some(msg);
            }
        }
        // Still run the previous hook so the panic location (and backtrace, if enabled)
        // is reported instead of being silently swallowed.
        prev_for_hook(info);
    }));

    let alice = Repo::build_tokio()
        .with_peer_id(PeerId::from("alice"))
        .load()
        .await;
    let bob = Repo::build_tokio()
        .with_peer_id(PeerId::from("bob"))
        .load()
        .await;

    let url = Url::parse("ws://test-faulty-sink:0").unwrap();
    let acceptor = bob.make_acceptor(url.clone()).unwrap();
    let dialer = FaultySinkDialer { url, acceptor };
    let _handle = alice
        .dial(BackoffConfig::default(), Arc::new(dialer))
        .unwrap();

    // Give the dial time to connect, hit the flush error and close the transport.
    tokio::time::sleep(Duration::from_millis(500)).await;

    std::panic::set_hook(Box::new(move |info| prev_hook(info)));

    // Take the message out first so the mutex guard is not held (and the mutex not
    // poisoned) while the failure panic below unwinds.
    let captured = panicked
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner)
        .take();
    if let Some(msg) = captured {
        panic!("spawned task panicked: {msg}");
    }
}