1use std::future::Future;
2use std::panic::AssertUnwindSafe;
3use std::panic::UnwindSafe;
4use std::panic::catch_unwind;
5use std::pin::Pin;
6use std::task::Context;
7use std::task::Poll;
8
9use pin_project_lite::pin_project;
10
11pin_project! {
12 pub struct AsyncCatchUnwind<Fut>
14 where
15 Fut: Future,
16 {
17 #[pin]
18 future: Fut,
19 }
20}
21
22impl<Fut> AsyncCatchUnwind<Fut>
23where
24 Fut: Future + UnwindSafe,
25{
26 pub fn new(future: Fut) -> Self {
28 Self { future }
29 }
30}
31
32impl<Fut> Future for AsyncCatchUnwind<Fut>
33where
34 Fut: Future + UnwindSafe,
35{
36 type Output = Result<Fut::Output, String>;
37
38 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
39 let f = self.project().future;
40
41 catch_unwind(AssertUnwindSafe(|| f.poll(cx)))
42 .map_err(|x| {
43 if x.is::<String>() {
44 return *x.downcast::<String>().unwrap();
45 } else if x.is::<&str>() {
46 return x.downcast::<&str>().unwrap().to_string();
47 }
48
49 "Unknown error!".to_string()
50 })?
51 .map(Ok)
52 }
53}