use crate::{Body, HeaderValue, Request, Response, StatusCode};
use futures_lite::future::FutureExt;
use rama_core::{Context, Layer, Service};
use rama_utils::macros::define_inner_service_accessors;
use std::fmt;
use std::{any::Any, panic::AssertUnwindSafe};
/// Layer that wraps an inner service in a [`CatchPanic`] middleware.
///
/// The wrapped service turns panics raised by the inner service into the
/// response produced by the panic handler `T` instead of unwinding further.
pub struct CatchPanicLayer<T> {
    // Handler invoked with the panic payload; cloned into every produced `CatchPanic`.
    panic_handler: T,
}
/// Debug output lists the configured panic handler.
impl<T: fmt::Debug> fmt::Debug for CatchPanicLayer<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { panic_handler } = self;
        let mut builder = f.debug_struct("CatchPanicLayer");
        builder.field("panic_handler", panic_handler);
        builder.finish()
    }
}
/// Cloning the layer clones its panic handler; no `T: Debug` or other bounds are required.
impl<T: Clone> Clone for CatchPanicLayer<T> {
    fn clone(&self) -> Self {
        let Self { panic_handler } = self;
        let panic_handler = panic_handler.clone();
        Self { panic_handler }
    }
}
impl Default for CatchPanicLayer<DefaultResponseForPanic> {
fn default() -> Self {
Self::new()
}
}
impl CatchPanicLayer<DefaultResponseForPanic> {
    /// Creates a layer whose panic handler is [`DefaultResponseForPanic`].
    ///
    /// This is a `const fn`, so it can be used to initialise constants and statics.
    pub const fn new() -> Self {
        Self {
            panic_handler: DefaultResponseForPanic,
        }
    }
}
impl<T> CatchPanicLayer<T> {
    /// Creates a layer that uses `panic_handler` to build the response sent when
    /// the wrapped service panics.
    ///
    /// `const` for parity with [`CatchPanic::custom`], so custom layers can also be
    /// built in constant contexts.
    pub const fn custom(panic_handler: T) -> Self
    where
        T: ResponseForPanic,
    {
        Self { panic_handler }
    }
}
/// Wraps `inner` in a [`CatchPanic`] that carries this layer's panic handler.
impl<T, S> Layer<S> for CatchPanicLayer<T>
where
    T: Clone,
{
    type Service = CatchPanic<S, T>;

    /// Borrowing variant: the panic handler is cloned into the new service.
    fn layer(&self, inner: S) -> Self::Service {
        let panic_handler = self.panic_handler.clone();
        CatchPanic {
            inner,
            panic_handler,
        }
    }

    /// Consuming variant: the panic handler is moved, avoiding a clone.
    fn into_layer(self, inner: S) -> Self::Service {
        let Self { panic_handler } = self;
        CatchPanic {
            inner,
            panic_handler,
        }
    }
}
/// Middleware that catches panics from the inner service and converts them into
/// a response built by `panic_handler`.
///
/// Panics are caught both while the inner service creates its response future and
/// while that future is being awaited.
pub struct CatchPanic<S, T> {
    // The wrapped service.
    inner: S,
    // Builds the replacement response from the caught panic payload.
    panic_handler: T,
}
impl<S> CatchPanic<S, DefaultResponseForPanic> {
    /// Wraps `inner` so that its panics are answered by [`DefaultResponseForPanic`].
    pub const fn new(inner: S) -> Self {
        Self {
            panic_handler: DefaultResponseForPanic,
            inner,
        }
    }
}
impl<S, T> CatchPanic<S, T> {
    define_inner_service_accessors!();

    /// Wraps `inner` so that its panics are answered by `panic_handler`.
    pub const fn custom(inner: S, panic_handler: T) -> Self
    where
        T: ResponseForPanic,
    {
        Self {
            panic_handler,
            inner,
        }
    }
}
/// Debug output lists the inner service followed by the panic handler.
impl<S: fmt::Debug, T: fmt::Debug> fmt::Debug for CatchPanic<S, T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            inner,
            panic_handler,
        } = self;
        let mut builder = f.debug_struct("CatchPanic");
        builder.field("inner", inner);
        builder.field("panic_handler", panic_handler);
        builder.finish()
    }
}
impl<S: Clone, T: Clone> Clone for CatchPanic<S, T> {
fn clone(&self) -> Self {
CatchPanic {
inner: self.inner.clone(),
panic_handler: self.panic_handler.clone(),
}
}
}
/// Serves requests through the inner service, replacing any panic with the
/// response built by the configured [`ResponseForPanic`] handler.
///
/// Errors returned (not panicked) by the inner service are propagated unchanged.
/// Successful responses have their body converted into [`Body`].
impl<State, S, T, ReqBody, ResBody> Service<State, Request<ReqBody>> for CatchPanic<S, T>
where
    S: Service<State, Request<ReqBody>, Response = Response<ResBody>>,
    ResBody: Into<Body> + Send + 'static,
    T: ResponseForPanic + Clone + Send + Sync + 'static,
    ReqBody: Send + 'static,
    State: Clone + Send + Sync + 'static,
{
    type Response = Response;
    type Error = S::Error;

    async fn serve(
        &self,
        ctx: Context<State>,
        req: Request<ReqBody>,
    ) -> Result<Self::Response, Self::Error> {
        // A service may panic synchronously while *creating* its future
        // (e.g. a non-async closure that panics before returning one), so guard that step too.
        let future = match std::panic::catch_unwind(AssertUnwindSafe(|| self.inner.serve(ctx, req)))
        {
            Ok(future) => future,
            Err(panic_err) => return Ok(self.panic_handler.response_for_panic(panic_err)),
        };

        // Second guard: panics raised while the future is being polled.
        match AssertUnwindSafe(future).catch_unwind().await {
            Ok(result) => result.map(|res| res.map(Into::into)),
            Err(panic_err) => Ok(self.panic_handler.response_for_panic(panic_err)),
        }
    }
}
/// Builds the response that [`CatchPanic`] returns when the wrapped service panics.
pub trait ResponseForPanic: Clone {
    /// Creates the replacement response.
    ///
    /// `err` is the panic payload captured by `catch_unwind`.
    fn response_for_panic(&self, err: Box<dyn Any + Send + 'static>) -> Response<Body>;
}
/// Any cloneable closure taking the panic payload and returning a [`Response`] can
/// be used directly as a panic handler.
impl<F> ResponseForPanic for F
where
    F: Fn(Box<dyn Any + Send + 'static>) -> Response + Clone,
{
    /// Delegates straight to the closure.
    fn response_for_panic(&self, err: Box<dyn Any + Send + 'static>) -> Response {
        (self)(err)
    }
}
/// Default panic handler.
///
/// It logs the panic message through `tracing` and responds with
/// `500 Internal Server Error` and a plain-text `Service panicked` body.
#[derive(Debug, Default, Clone)]
#[non_exhaustive]
pub struct DefaultResponseForPanic;
impl ResponseForPanic for DefaultResponseForPanic {
fn response_for_panic(&self, err: Box<dyn Any + Send + 'static>) -> Response {
if let Some(s) = err.downcast_ref::<String>() {
tracing::error!("Service panicked: {}", s);
} else if let Some(s) = err.downcast_ref::<&str>() {
tracing::error!("Service panicked: {}", s);
} else {
tracing::error!(
"Service panicked but `CatchPanic` was unable to downcast the panic info"
);
};
let mut res = Response::new(Body::from("Service panicked"));
*res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
#[allow(clippy::declare_interior_mutable_const)]
const TEXT_PLAIN: HeaderValue = HeaderValue::from_static("text/plain; charset=utf-8");
res.headers_mut()
.insert(rama_http_types::header::CONTENT_TYPE, TEXT_PLAIN);
res
}
}
#[cfg(test)]
mod tests {
    // The panicking test services intentionally contain code after `panic!`.
    #![allow(unreachable_code)]

    use super::*;
    use crate::dep::http_body_util::BodyExt;
    use crate::{Body, Response};
    use rama_core::service::service_fn;
    use rama_core::{Context, Service};
    use std::convert::Infallible;

    /// A panic raised while the service is still building its future is caught.
    #[tokio::test]
    async fn panic_before_returning_future() {
        let svc = CatchPanicLayer::new().into_layer(service_fn(|_: Request| {
            panic!("service panic");
            async { Ok::<_, Infallible>(Response::new(Body::empty())) }
        }));
        let req = Request::new(Body::empty());
        let res = svc.serve(Context::default(), req).await.unwrap();
        assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
        let body = res.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(&body[..], b"Service panicked");
    }

    /// A panic raised while the returned future is polled is caught.
    #[tokio::test]
    async fn panic_in_future() {
        let svc = CatchPanicLayer::new().into_layer(service_fn(async |_: Request<Body>| {
            panic!("future panic");
            Ok::<_, Infallible>(Response::new(Body::empty()))
        }));
        let req = Request::new(Body::empty());
        let res = svc.serve(Context::default(), req).await.unwrap();
        assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
        let body = res.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(&body[..], b"Service panicked");
    }

    /// Without a panic, the inner response (status and body) passes through untouched.
    #[tokio::test]
    async fn no_panic_passes_response_through() {
        let svc = CatchPanicLayer::new().into_layer(service_fn(async |_: Request<Body>| {
            Ok::<_, Infallible>(Response::new(Body::from("all good")))
        }));
        let req = Request::new(Body::empty());
        let res = svc.serve(Context::default(), req).await.unwrap();
        assert_eq!(res.status(), StatusCode::OK);
        let body = res.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(&body[..], b"all good");
    }

    /// A closure passed to `CatchPanicLayer::custom` builds the panic response.
    #[tokio::test]
    async fn custom_panic_handler_is_used() {
        let svc = CatchPanicLayer::custom(|_: Box<dyn Any + Send + 'static>| {
            let mut res = Response::new(Body::from("custom"));
            *res.status_mut() = StatusCode::SERVICE_UNAVAILABLE;
            res
        })
        .into_layer(service_fn(async |_: Request<Body>| {
            panic!("future panic");
            Ok::<_, Infallible>(Response::new(Body::empty()))
        }));
        let req = Request::new(Body::empty());
        let res = svc.serve(Context::default(), req).await.unwrap();
        assert_eq!(res.status(), StatusCode::SERVICE_UNAVAILABLE);
        let body = res.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(&body[..], b"custom");
    }
}