1use crate::worker::TaskExecError;
2use futures::future::BoxFuture;
3use futures::FutureExt;
4use std::future::Future;
5use std::pin::Pin;
6use std::task::Context;
7use std::task::Poll;
8
9pub(crate) struct CatchUnwindFuture<F: Future + Send + 'static> {
10 inner: BoxFuture<'static, F::Output>,
11}
12
13impl<F: Future + Send + 'static> CatchUnwindFuture<F> {
14 pub fn create(f: F) -> CatchUnwindFuture<F> {
15 Self { inner: f.boxed() }
16 }
17}
18
19impl<F: Future + Send + 'static> Future for CatchUnwindFuture<F> {
20 type Output = Result<F::Output, TaskExecError>;
21
22 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
23 let inner = &mut self.inner;
24
25 match catch_unwind(move || inner.poll_unpin(cx)) {
26 Ok(Poll::Pending) => Poll::Pending,
27 Ok(Poll::Ready(value)) => Poll::Ready(Ok(value)),
28 Err(cause) => Poll::Ready(Err(cause)),
29 }
30 }
31}
32
33fn catch_unwind<F: FnOnce() -> R, R>(f: F) -> Result<R, TaskExecError> {
34 match std::panic::catch_unwind(std::panic::AssertUnwindSafe(f)) {
35 Ok(res) => Ok(res),
36 Err(cause) => match cause.downcast_ref::<&'static str>() {
37 Some(message) => Err(TaskExecError::Panicked(message.to_string())),
38 None => match cause.downcast_ref::<String>() {
39 Some(message) => Err(TaskExecError::Panicked(message.to_string())),
40 None => Err(TaskExecError::Panicked(
41 "Sorry, unknown panic message".to_string(),
42 )),
43 },
44 },
45 }
46}