async_rayon/
async_handle.rs

1use futures::channel::oneshot::Receiver;
2
3use std::future::Future;
4use std::panic::resume_unwind;
5use std::pin::Pin;
6use std::task::{Context, Poll};
7use std::thread;
8
9/// Async handle for a blocking task running in a Rayon thread pool.
10///
11/// If the spawned task panics, `poll()` will propagate the panic.
12#[must_use]
13#[derive(Debug)]
14pub struct AsyncRayonHandle<T> {
15    pub(crate) rx: Receiver<thread::Result<T>>,
16}
17
18impl<T> Future for AsyncRayonHandle<T> {
19    type Output = T;
20
21    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
22        let rx = Pin::new(&mut self.rx);
23        rx.poll(cx).map(|result| {
24            result
25                .expect("Unreachable error: Tokio channel closed")
26                .unwrap_or_else(|err| resume_unwind(err))
27        })
28    }
29}
30
31#[cfg(test)]
32mod tests {
33    use futures::channel::oneshot::channel;
34
35    use super::*;
36    use crate::test::init;
37    use std::panic::catch_unwind;
38    use std::thread;
39
40    #[tokio::test]
41    #[should_panic(expected = "Task failed successfully")]
42    async fn test_poll_propagates_panic() {
43        init();
44        let panic_err = catch_unwind(|| {
45            panic!("Task failed successfully");
46        })
47        .unwrap_err();
48
49        let (tx, rx) = channel::<thread::Result<()>>();
50        let handle = AsyncRayonHandle { rx };
51        tx.send(Err(panic_err)).unwrap();
52        handle.await;
53    }
54
55    #[tokio::test]
56    #[should_panic(expected = "Unreachable error: Tokio channel closed")]
57    async fn test_unreachable_channel_closed() {
58        init();
59        let (_, rx) = channel::<thread::Result<()>>();
60        let handle = AsyncRayonHandle { rx };
61        handle.await;
62    }
63}