use std::{
future::Future,
panic::{AssertUnwindSafe, catch_unwind},
pin::Pin,
task::{Context, Poll},
};
use axum::{body::Body, extract::Request};
use futures_util::FutureExt;
use http::{Response, StatusCode};
use tower::{Layer, Service};
pub fn catch_panic_layer() -> CatchPanicLayer {
CatchPanicLayer
}
/// [`Layer`] that wraps a service in [`CatchPanicService`], so that panics raised while
/// handling a request become `500 Internal Server Error` responses.
///
/// The layer is a stateless unit struct; every wrapped service behaves the same way.
#[derive(Clone, Copy, Debug, Default)]
pub struct CatchPanicLayer;

impl<S> Layer<S> for CatchPanicLayer {
    type Service = CatchPanicService<S>;

    /// Wraps `inner`; no configuration is copied because the layer has none.
    fn layer(&self, inner: S) -> CatchPanicService<S> {
        CatchPanicService { inner }
    }
}
/// Service produced by [`CatchPanicLayer`]: forwards each request to `inner` and turns a
/// panic raised while handling it into a logged `500 Internal Server Error` response.
#[derive(Clone, Debug)]
pub struct CatchPanicService<S> {
    /// The wrapped service that actually handles requests.
    inner: S,
}

impl<S> Service<Request> for CatchPanicService<S>
where
    S: Service<Request, Response = Response<Body>> + Send + 'static,
    S::Future: Send + 'static,
    S::Error: Send + 'static,
{
    type Response = Response<Body>;
    type Error = S::Error;
    type Future = Pin<Box<dyn Future<Output = Result<Response<Body>, S::Error>> + Send>>;

    /// Delegates readiness to the inner service unchanged; panics raised here are not
    /// intercepted.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Dispatches `request` to the inner service, guarding two points against panics:
    /// while `inner.call` constructs its future, and while that future is polled.
    /// Either kind of panic is logged and answered with `Ok(500 response)`; errors the
    /// inner service returns normally pass through untouched.
    fn call(&mut self, request: Request) -> Self::Future {
        let inner = &mut self.inner;
        // `AssertUnwindSafe` is needed because neither the inner service nor its future is
        // required to be `UnwindSafe`; after a panic this code only builds a fallback response.
        let started = catch_unwind(AssertUnwindSafe(move || inner.call(request)));

        let future = match started {
            Ok(future) => future,
            Err(payload) => {
                // Synchronous panic: there is no inner future, so reply with the fallback
                // (logging happens when the returned future is first polled).
                return Box::pin(async move {
                    log_panic(&payload);
                    Ok(internal_server_error())
                });
            }
        };

        Box::pin(async move {
            AssertUnwindSafe(future)
                .catch_unwind()
                .await
                .unwrap_or_else(|payload| {
                    log_panic(&payload);
                    Ok(internal_server_error())
                })
        })
    }
}
/// Emits an error-level `tracing` event for a caught panic, tagged with `http.status = 500`.
///
/// `String` and `&'static str` payloads (the two shapes `panic!` produces) are recorded in
/// the `panic.message` field; any other payload type is reported without a message.
///
/// The parameter is deliberately `&Box<…>` rather than `&(dyn Any + Send)`: callers pass
/// `&payload`, and with a `&dyn Any` parameter that expression would unsize-coerce the
/// `Box` itself into the trait object, making every downcast below fail.
fn log_panic(payload: &Box<dyn std::any::Any + Send + 'static>) {
    let text: Option<&str> = payload
        .downcast_ref::<String>()
        .map(String::as_str)
        .or_else(|| payload.downcast_ref::<&'static str>().copied());

    match text {
        Some(message) => {
            tracing::error!(http.status = 500, panic.message = %message, "request handler panicked");
        }
        None => {
            tracing::error!(
                http.status = 500,
                "request handler panicked with non-string payload"
            );
        }
    }
}
/// Builds the fallback response sent in place of a panicked handler's output:
/// status `500 Internal Server Error` with an empty body.
fn internal_server_error() -> Response<Body> {
    Response::builder()
        .status(StatusCode::INTERNAL_SERVER_ERROR)
        .body(Body::empty())
        .expect("a status-only response with an empty body is always valid")
}