lambda_runtime 1.1.3

AWS Lambda Runtime
Documentation
use crate::{diagnostic::type_name_of_val, Diagnostic, LambdaEvent};
use futures::{future::CatchUnwind, FutureExt};
use pin_project::pin_project;
use std::{any::Any, fmt::Debug, future::Future, marker::PhantomData, panic::AssertUnwindSafe, pin::Pin, task};
use tower::Service;
use tracing::error;

/// Tower service that transforms panics into an error. Panics are converted to errors both when
/// they occur inside [tower::Service::call] and when they occur while polling the returned
/// [tower::Service::Future].
///
/// This type is only meant for internal use in the Lambda runtime crate. It neither augments the
/// inner service's request type, nor its response type. It merely transforms the error type
/// from `Into<Diagnostic> + Debug` into [`Diagnostic`] to turn panics into diagnostics.
///
/// The `'a` lifetime is carried only by a [`PhantomData`] marker; the [`Service`] implementation
/// uses it to bound the inner service's future (`S::Future: 'a`).
#[derive(Clone)]
pub struct CatchPanicService<'a, S> {
    /// The wrapped service whose errors and panics are converted into diagnostics.
    inner: S,
    /// Marker tying the otherwise-unused `'a` lifetime to this type.
    _phantom: PhantomData<&'a ()>,
}

impl<S> CatchPanicService<'_, S> {
    /// Wraps `inner` so that its errors and panics surface as [`Diagnostic`] errors.
    pub fn new(inner: S) -> Self {
        let _phantom = PhantomData;
        Self { inner, _phantom }
    }
}

impl<'a, S, Payload> Service<LambdaEvent<Payload>> for CatchPanicService<'a, S>
where
    S: Service<LambdaEvent<Payload>>,
    S::Future: 'a,
    S::Error: Into<Diagnostic> + Debug,
{
    type Error = Diagnostic;
    type Response = S::Response;
    type Future = CatchPanicFuture<'a, S::Future>;

    fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> task::Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx).map_err(|err| err.into())
    }

    fn call(&mut self, req: LambdaEvent<Payload>) -> Self::Future {
        // Catch panics that result from calling `call` on the service
        let task = std::panic::catch_unwind(AssertUnwindSafe(|| self.inner.call(req)));

        // Catch panics that result from polling the future returned from `call`
        match task {
            Ok(task) => {
                let fut = AssertUnwindSafe(task).catch_unwind();
                CatchPanicFuture::Future(fut, PhantomData)
            }
            Err(error) => {
                error!(?error, "user handler panicked");
                CatchPanicFuture::Error(error)
            }
        }
    }
}

/// Future returned by [CatchPanicService].
///
/// It either drives the inner service's future through a panic-catching adapter, or holds the
/// payload of a panic that was already caught while calling the inner service.
#[pin_project(project = CatchPanicFutureProj)]
pub enum CatchPanicFuture<'a, F> {
    /// The inner future, polled through `CatchUnwind` so that panics raised while polling are
    /// caught; the `PhantomData` carries the `'a` lifetime.
    Future(#[pin] CatchUnwind<AssertUnwindSafe<F>>, PhantomData<&'a ()>),
    /// Panic payload caught while calling the inner service; polling yields an error built
    /// from it.
    Error(Box<dyn Any + Send + 'static>),
}

impl<F, T, E> Future for CatchPanicFuture<'_, F>
where
    F: Future<Output = Result<T, E>>,
    E: Into<Diagnostic> + Debug,
{
    type Output = Result<T, Diagnostic>;

    fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<Self::Output> {
        use task::Poll;
        match self.project() {
            CatchPanicFutureProj::Future(fut, _) => match fut.poll(cx) {
                Poll::Ready(ready) => match ready {
                    Ok(Ok(success)) => Poll::Ready(Ok(success)),
                    Ok(Err(error)) => {
                        error!("{error:?}");
                        Poll::Ready(Err(error.into()))
                    }
                    Err(error) => {
                        error!(?error, "user handler panicked");
                        Poll::Ready(Err(Self::build_panic_diagnostic(&error)))
                    }
                },
                Poll::Pending => Poll::Pending,
            },
            CatchPanicFutureProj::Error(error) => Poll::Ready(Err(Self::build_panic_diagnostic(error))),
        }
    }
}

impl<F> CatchPanicFuture<'_, F> {
    /// Builds the [`Diagnostic`] reported for a caught panic.
    ///
    /// If the payload is a `&'static str` or a `String`, its text is appended to the
    /// message (`"Lambda panicked: <text>"`). Any other payload type yields the bare
    /// message `"Lambda panicked"`.
    ///
    /// NOTE(review): `error_type` comes from `type_name_of_val(err)`, which presumably reports
    /// the static type of `err` (the boxed `dyn Any`), not the payload's concrete type.
    /// Confirm before changing the parameter type, since that would likely change `error_type`.
    fn build_panic_diagnostic(err: &Box<dyn Any + Send>) -> Diagnostic {
        // `downcast_ref` is inherent on `dyn Any + Send`, so these calls auto-deref through
        // the box and inspect the payload itself.
        let payload_text: Option<&str> = err
            .downcast_ref::<&str>()
            .copied()
            .or_else(|| err.downcast_ref::<String>().map(String::as_str));

        let error_message = match payload_text {
            Some(msg) => format!("Lambda panicked: {msg}"),
            None => "Lambda panicked".to_string(),
        };

        Diagnostic {
            error_message,
            error_type: type_name_of_val(err),
        }
    }
}