backie/
catch_unwind.rs

1use crate::worker::TaskExecError;
2use futures::future::BoxFuture;
3use futures::FutureExt;
4use std::future::Future;
5use std::pin::Pin;
6use std::task::Context;
7use std::task::Poll;
8
9pub(crate) struct CatchUnwindFuture<F: Future + Send + 'static> {
10    inner: BoxFuture<'static, F::Output>,
11}
12
13impl<F: Future + Send + 'static> CatchUnwindFuture<F> {
14    pub fn create(f: F) -> CatchUnwindFuture<F> {
15        Self { inner: f.boxed() }
16    }
17}
18
19impl<F: Future + Send + 'static> Future for CatchUnwindFuture<F> {
20    type Output = Result<F::Output, TaskExecError>;
21
22    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
23        let inner = &mut self.inner;
24
25        match catch_unwind(move || inner.poll_unpin(cx)) {
26            Ok(Poll::Pending) => Poll::Pending,
27            Ok(Poll::Ready(value)) => Poll::Ready(Ok(value)),
28            Err(cause) => Poll::Ready(Err(cause)),
29        }
30    }
31}
32
33fn catch_unwind<F: FnOnce() -> R, R>(f: F) -> Result<R, TaskExecError> {
34    match std::panic::catch_unwind(std::panic::AssertUnwindSafe(f)) {
35        Ok(res) => Ok(res),
36        Err(cause) => match cause.downcast_ref::<&'static str>() {
37            Some(message) => Err(TaskExecError::Panicked(message.to_string())),
38            None => match cause.downcast_ref::<String>() {
39                Some(message) => Err(TaskExecError::Panicked(message.to_string())),
40                None => Err(TaskExecError::Panicked(
41                    "Sorry, unknown panic message".to_string(),
42                )),
43            },
44        },
45    }
46}