// async_callback_manager/panicking_receiver_stream.rs
use std::pin::Pin;
use std::task::{Context, Poll};

use tokio::sync::mpsc::Receiver;
use tokio::task::JoinHandle;
use tokio_stream::Stream;
6use tokio_stream::wrappers::ReceiverStream; pub struct PanickingReceiverStream<T> {
12 pub inner: ReceiverStream<T>,
13 pub handle: JoinHandle<()>,
14}
15
16impl<T> PanickingReceiverStream<T> {
17 pub fn new(recv: Receiver<T>, join_handle: JoinHandle<()>) -> Self {
18 Self {
19 inner: ReceiverStream::new(recv),
20 handle: join_handle,
21 }
22 }
23}
24
25impl<T> Stream for PanickingReceiverStream<T> {
26 type Item = T;
27
28 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
29 match Pin::new(&mut self.inner).poll_next(cx) {
30 Poll::Ready(Some(item)) => Poll::Ready(Some(item)),
31 Poll::Ready(None) => {
32 match Pin::new(&mut self.handle).poll(cx) {
33 Poll::Pending => Poll::Pending,
35 Poll::Ready(Err(e)) if e.is_panic() => {
37 std::panic::resume_unwind(e.into_panic());
38 }
39 _ => Poll::Ready(None),
41 }
42 }
43 Poll::Pending => Poll::Pending,
44 }
45 }
46}
47
#[cfg(test)]
mod tests {
    use crate::PanickingReceiverStream;
    use futures::StreamExt;
    use tokio_stream::wrappers::ReceiverStream;

    /// Baseline: a plain `ReceiverStream` silently ends when its producer task panics.
    #[tokio::test]
    async fn assert_tokio_receiver_stream_does_not_panic_if_task_panics() {
        let (tx, rx) = tokio::sync::mpsc::channel(30);
        tokio::spawn(async move {
            for i in 0..=10 {
                if i == 6 {
                    panic!();
                }
                tx.send(i).await.unwrap();
            }
        });
        let stream = ReceiverStream::new(rx);
        let output: Vec<_> = stream.collect().await;
        assert_eq!(output, vec![0, 1, 2, 3, 4, 5]);
    }

    /// When the producer finishes normally, every value is yielded and the stream ends.
    #[tokio::test]
    async fn panicking_receiver_stream_yields_all_items_if_task_completes() {
        let (tx, rx) = tokio::sync::mpsc::channel(30);
        let handle = tokio::spawn(async move {
            for i in 0..=10 {
                tx.send(i).await.unwrap();
            }
        });
        let stream = PanickingReceiverStream::new(rx, handle);
        let output: Vec<_> = stream.collect().await;
        assert_eq!(output, (0..=10).collect::<Vec<i32>>());
    }

    /// The producer's own panic (identified by its message) must surface from the stream.
    #[tokio::test]
    #[should_panic(expected = "producer task panicked")]
    async fn panicking_receiver_stream_should_panic_if_task_panics() {
        let (tx, rx) = tokio::sync::mpsc::channel(30);
        let handle = tokio::spawn(async move {
            for i in 0..=10 {
                if i == 6 {
                    panic!("producer task panicked");
                }
                tx.send(i).await.unwrap();
            }
        });
        let stream = PanickingReceiverStream::new(rx, handle);
        let output: Vec<_> = stream.collect().await;
        assert_eq!(output, vec![0, 1, 2, 3, 4, 5]);
    }
}