Skip to main content

hydra/
catch_unwind.rs

1use std::future::Future;
2use std::panic::AssertUnwindSafe;
3use std::panic::UnwindSafe;
4use std::panic::catch_unwind;
5use std::pin::Pin;
6use std::task::Context;
7use std::task::Poll;
8
9use pin_project_lite::pin_project;
10
11pin_project! {
12    /// A future that will catch panics and unwind them.
13    pub struct AsyncCatchUnwind<Fut>
14    where
15        Fut: Future,
16    {
17        #[pin]
18        future: Fut,
19    }
20}
21
22impl<Fut> AsyncCatchUnwind<Fut>
23where
24    Fut: Future + UnwindSafe,
25{
26    /// Constructs a new [CatchUnwind] for the given future.
27    pub fn new(future: Fut) -> Self {
28        Self { future }
29    }
30}
31
32impl<Fut> Future for AsyncCatchUnwind<Fut>
33where
34    Fut: Future + UnwindSafe,
35{
36    type Output = Result<Fut::Output, String>;
37
38    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
39        let f = self.project().future;
40
41        catch_unwind(AssertUnwindSafe(|| f.poll(cx)))
42            .map_err(|x| {
43                if x.is::<String>() {
44                    return *x.downcast::<String>().unwrap();
45                } else if x.is::<&str>() {
46                    return x.downcast::<&str>().unwrap().to_string();
47                }
48
49                "Unknown error!".to_string()
50            })?
51            .map(Ok)
52    }
53}