// async_callback_manager/panicking_receiver_stream.rs

1use std::pin::Pin;
2use std::task::{Context, Poll};
3use tokio::sync::mpsc::Receiver;
4use tokio::task::JoinHandle;
5use tokio_stream::Stream;
6use tokio_stream::wrappers::ReceiverStream; // or std::future::Future
7
8/// A modification to tokio's ReceiverStream that awaits a JoinHandle prior to
9/// reporting closed, rethrowing the panic if there was one. Use this when the
10/// ReceiverStream is driven by a task that may panic.
11pub struct PanickingReceiverStream<T> {
12    pub inner: ReceiverStream<T>,
13    pub handle: JoinHandle<()>,
14}
15
16impl<T> PanickingReceiverStream<T> {
17    pub fn new(recv: Receiver<T>, join_handle: JoinHandle<()>) -> Self {
18        Self {
19            inner: ReceiverStream::new(recv),
20            handle: join_handle,
21        }
22    }
23}
24
25impl<T> Stream for PanickingReceiverStream<T> {
26    type Item = T;
27
28    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
29        match Pin::new(&mut self.inner).poll_next(cx) {
30            Poll::Ready(Some(item)) => Poll::Ready(Some(item)),
31            Poll::Ready(None) => {
32                match Pin::new(&mut self.handle).poll(cx) {
33                    // Task is still tearing down; wait for it to finish to capture the panic.
34                    Poll::Pending => Poll::Pending,
35                    // Task panicked! Rethrow it.
36                    Poll::Ready(Err(e)) if e.is_panic() => {
37                        std::panic::resume_unwind(e.into_panic());
38                    }
39                    // Task finished normally or was cancelled.
40                    _ => Poll::Ready(None),
41                }
42            }
43            Poll::Pending => Poll::Pending,
44        }
45    }
46}
47
#[cfg(test)]
mod tests {
    use crate::PanickingReceiverStream;
    use futures::StreamExt;
    use tokio_stream::wrappers::ReceiverStream;

    /// Message the producer task panics with. The `should_panic` test pins this text,
    /// so it can tell the rethrown task panic apart from any unrelated panic.
    const PRODUCER_PANIC: &str = "producer task panicked at 6";

    /// Spawns a task that sends 0..=5 into a fresh channel and then panics.
    /// Returns the channel's receiver and the task's handle.
    fn spawn_panicking_producer() -> (
        tokio::sync::mpsc::Receiver<i32>,
        tokio::task::JoinHandle<()>,
    ) {
        let (tx, rx) = tokio::sync::mpsc::channel(30);
        let handle = tokio::spawn(async move {
            for i in 0..=10 {
                if i == 6 {
                    panic!("{}", PRODUCER_PANIC);
                }
                tx.send(i).await.unwrap();
            }
        });
        (rx, handle)
    }

    /// Documents the motivation: a plain `ReceiverStream` ends silently when its
    /// producer panics.
    #[tokio::test]
    async fn assert_tokio_receiver_stream_does_not_panic_if_task_panics() {
        let (rx, _handle) = spawn_panicking_producer();
        let stream = ReceiverStream::new(rx);
        let output: Vec<_> = stream.collect().await;
        assert_eq!(output, vec![0, 1, 2, 3, 4, 5]);
    }

    /// The wrapper must surface the producer's own panic, not merely some panic.
    #[tokio::test]
    #[should_panic(expected = "producer task panicked at 6")]
    async fn panicking_receiver_stream_should_panic_if_task_panics() {
        let (rx, handle) = spawn_panicking_producer();
        let stream = PanickingReceiverStream::new(rx, handle);
        let output: Vec<_> = stream.collect().await;
        assert_eq!(output, vec![0, 1, 2, 3, 4, 5]);
    }
}