use crate::{diagnostic::type_name_of_val, Diagnostic, LambdaEvent};
use futures::{future::CatchUnwind, FutureExt};
use pin_project::pin_project;
use std::{any::Any, fmt::Debug, future::Future, marker::PhantomData, panic::AssertUnwindSafe, pin::Pin, task};
use tower::Service;
use tracing::error;
/// Tower [`Service`] middleware wrapping an inner service so that a panic in
/// the handler is turned into a [`Diagnostic`] error instead of unwinding.
///
/// Panics are caught both while the inner service's `call` runs and while the
/// future it returns is being polled. Errors returned normally by the inner
/// service are converted into [`Diagnostic`] via `Into`.
#[derive(Clone)]
pub struct CatchPanicService<'a, S> {
/// The wrapped service that handles each event.
inner: S,
/// Carries the `'a` lifetime parameter without storing a reference; the
/// `Service` impl uses it to bound `S::Future: 'a`.
_phantom: PhantomData<&'a ()>,
}
impl<S> CatchPanicService<'_, S> {
pub fn new(inner: S) -> Self {
Self {
inner,
_phantom: PhantomData,
}
}
}
impl<'a, S, Payload> Service<LambdaEvent<Payload>> for CatchPanicService<'a, S>
where
S: Service<LambdaEvent<Payload>>,
S::Future: 'a,
S::Error: Into<Diagnostic> + Debug,
{
type Error = Diagnostic;
type Response = S::Response;
type Future = CatchPanicFuture<'a, S::Future>;
fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> task::Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx).map_err(|err| err.into())
}
fn call(&mut self, req: LambdaEvent<Payload>) -> Self::Future {
let task = std::panic::catch_unwind(AssertUnwindSafe(|| self.inner.call(req)));
match task {
Ok(task) => {
let fut = AssertUnwindSafe(task).catch_unwind();
CatchPanicFuture::Future(fut, PhantomData)
}
Err(error) => {
error!(?error, "user handler panicked");
CatchPanicFuture::Error(error)
}
}
}
}
/// Future returned by [`CatchPanicService`]'s `Service::call`.
///
/// It either wraps the inner service's future, so that a panic during polling
/// is caught, or holds the payload of a panic raised while `call` was creating
/// that future.
#[pin_project(project = CatchPanicFutureProj)]
pub enum CatchPanicFuture<'a, F> {
/// The inner future, wrapped so that a panic while polling it is caught.
Future(#[pin] CatchUnwind<AssertUnwindSafe<F>>, PhantomData<&'a ()>),
/// Payload of a panic raised by the inner service's `call`. Polling this
/// variant resolves immediately to an error built from the payload.
Error(Box<dyn Any + Send + 'static>),
}
impl<F, T, E> Future for CatchPanicFuture<'_, F>
where
F: Future<Output = Result<T, E>>,
E: Into<Diagnostic> + Debug,
{
type Output = Result<T, Diagnostic>;
fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<Self::Output> {
use task::Poll;
match self.project() {
CatchPanicFutureProj::Future(fut, _) => match fut.poll(cx) {
Poll::Ready(ready) => match ready {
Ok(Ok(success)) => Poll::Ready(Ok(success)),
Ok(Err(error)) => {
error!("{error:?}");
Poll::Ready(Err(error.into()))
}
Err(error) => {
error!(?error, "user handler panicked");
Poll::Ready(Err(Self::build_panic_diagnostic(&error)))
}
},
Poll::Pending => Poll::Pending,
},
CatchPanicFutureProj::Error(error) => Poll::Ready(Err(Self::build_panic_diagnostic(error))),
}
}
}
impl<F> CatchPanicFuture<'_, F> {
    /// Builds the [`Diagnostic`] reported for a caught panic.
    ///
    /// When the panic payload is a `&str` or a `String`, the message is
    /// "Lambda panicked: " followed by that text. Any other payload yields the
    /// bare message "Lambda panicked".
    fn build_panic_diagnostic(err: &Box<dyn Any + Send>) -> Diagnostic {
        // Method resolution auto-derefs through the `Box` to the inherent
        // `downcast_ref` on `dyn Any + Send`, so this inspects the panic
        // payload rather than the `Box` itself.
        let payload_text: Option<&str> = err
            .downcast_ref::<&str>()
            .copied()
            .or_else(|| err.downcast_ref::<String>().map(String::as_str));

        let error_message = match payload_text {
            Some(msg) => format!("Lambda panicked: {msg}"),
            None => String::from("Lambda panicked"),
        };

        // NOTE(review): `type_name_of_val` receives `&Box<dyn Any + Send>`, so
        // `error_type` presumably names that box type rather than the payload's
        // type. Confirm this is the intended value.
        Diagnostic {
            error_type: type_name_of_val(err),
            error_message,
        }
    }
}