use bytes::Bytes;
use futures_core::ready;
use futures_util::future::{CatchUnwind, FutureExt};
use http::{HeaderValue, Request, Response, StatusCode};
use http_body::{combinators::UnsyncBoxBody, Body, Full};
use pin_project_lite::pin_project;
use std::{
any::Any,
future::Future,
panic::AssertUnwindSafe,
pin::Pin,
task::{Context, Poll},
};
use tower_layer::Layer;
use tower_service::Service;
use crate::BoxError;
#[derive(Debug, Clone, Copy, Default)]
pub struct CatchPanicLayer<T> {
panic_handler: T,
}
impl CatchPanicLayer<DefaultResponseForPanic> {
pub fn new() -> Self {
CatchPanicLayer {
panic_handler: DefaultResponseForPanic,
}
}
}
impl<T> CatchPanicLayer<T> {
pub fn custom(panic_handler: T) -> Self
where
T: ResponseForPanic,
{
Self { panic_handler }
}
}
impl<T, S> Layer<S> for CatchPanicLayer<T>
where
T: Clone,
{
type Service = CatchPanic<S, T>;
fn layer(&self, inner: S) -> Self::Service {
CatchPanic {
inner,
panic_handler: self.panic_handler.clone(),
}
}
}
#[derive(Debug, Clone, Copy)]
pub struct CatchPanic<S, T> {
inner: S,
panic_handler: T,
}
impl<S> CatchPanic<S, DefaultResponseForPanic> {
pub fn new(inner: S) -> Self {
Self {
inner,
panic_handler: DefaultResponseForPanic,
}
}
}
impl<S, T> CatchPanic<S, T> {
define_inner_service_accessors!();
pub fn custom(inner: S, panic_handler: T) -> Self
where
T: ResponseForPanic,
{
Self {
inner,
panic_handler,
}
}
}
impl<S, T, ReqBody, ResBody> Service<Request<ReqBody>> for CatchPanic<S, T>
where
S: Service<Request<ReqBody>, Response = Response<ResBody>>,
ResBody: Body<Data = Bytes> + Send + 'static,
ResBody::Error: Into<BoxError>,
T: ResponseForPanic + Clone,
T::ResponseBody: Body<Data = Bytes> + Send + 'static,
<T::ResponseBody as Body>::Error: Into<BoxError>,
{
type Response = Response<UnsyncBoxBody<Bytes, BoxError>>;
type Error = S::Error;
type Future = ResponseFuture<S::Future, T>;
#[inline]
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
match std::panic::catch_unwind(AssertUnwindSafe(|| self.inner.call(req))) {
Ok(future) => ResponseFuture {
kind: Kind::Future {
future: AssertUnwindSafe(future).catch_unwind(),
panic_handler: Some(self.panic_handler.clone()),
},
},
Err(panic_err) => ResponseFuture {
kind: Kind::Panicked {
panic_err: Some(panic_err),
panic_handler: Some(self.panic_handler.clone()),
},
},
}
}
}
pin_project! {
pub struct ResponseFuture<F, T> {
#[pin]
kind: Kind<F, T>,
}
}
pin_project! {
#[project = KindProj]
enum Kind<F, T> {
Panicked {
panic_err: Option<Box<dyn Any + Send + 'static>>,
panic_handler: Option<T>,
},
Future {
#[pin]
future: CatchUnwind<AssertUnwindSafe<F>>,
panic_handler: Option<T>,
}
}
}
impl<F, ResBody, E, T> Future for ResponseFuture<F, T>
where
F: Future<Output = Result<Response<ResBody>, E>>,
ResBody: Body<Data = Bytes> + Send + 'static,
ResBody::Error: Into<BoxError>,
T: ResponseForPanic,
T::ResponseBody: Body<Data = Bytes> + Send + 'static,
<T::ResponseBody as Body>::Error: Into<BoxError>,
{
type Output = Result<Response<UnsyncBoxBody<Bytes, BoxError>>, E>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.project().kind.project() {
KindProj::Panicked {
panic_err,
panic_handler,
} => {
let panic_handler = panic_handler
.take()
.expect("future polled after completion");
let panic_err = panic_err.take().expect("future polled after completion");
Poll::Ready(Ok(response_for_panic(panic_handler, panic_err)))
}
KindProj::Future {
future,
panic_handler,
} => match ready!(future.poll(cx)) {
Ok(Ok(res)) => {
Poll::Ready(Ok(res.map(|body| body.map_err(Into::into).boxed_unsync())))
}
Ok(Err(svc_err)) => Poll::Ready(Err(svc_err)),
Err(panic_err) => Poll::Ready(Ok(response_for_panic(
panic_handler
.take()
.expect("future polled after completion"),
panic_err,
))),
},
}
}
}
fn response_for_panic<T>(
mut panic_handler: T,
err: Box<dyn Any + Send + 'static>,
) -> Response<UnsyncBoxBody<Bytes, BoxError>>
where
T: ResponseForPanic,
T::ResponseBody: Body<Data = Bytes> + Send + 'static,
<T::ResponseBody as Body>::Error: Into<BoxError>,
{
panic_handler
.response_for_panic(err)
.map(|body| body.map_err(Into::into).boxed_unsync())
}
pub trait ResponseForPanic: Clone {
type ResponseBody;
fn response_for_panic(
&mut self,
err: Box<dyn Any + Send + 'static>,
) -> Response<Self::ResponseBody>;
}
impl<F, B> ResponseForPanic for F
where
F: FnMut(Box<dyn Any + Send + 'static>) -> Response<B> + Clone,
{
type ResponseBody = B;
fn response_for_panic(
&mut self,
err: Box<dyn Any + Send + 'static>,
) -> Response<Self::ResponseBody> {
self(err)
}
}
#[derive(Debug, Default, Clone, Copy)]
#[non_exhaustive]
pub struct DefaultResponseForPanic;
impl ResponseForPanic for DefaultResponseForPanic {
type ResponseBody = Full<Bytes>;
fn response_for_panic(
&mut self,
err: Box<dyn Any + Send + 'static>,
) -> Response<Self::ResponseBody> {
if let Some(s) = err.downcast_ref::<String>() {
tracing::error!("Service panicked: {}", s);
} else if let Some(s) = err.downcast_ref::<&str>() {
tracing::error!("Service panicked: {}", s);
} else {
tracing::error!(
"Service panicked but `CatchPanic` was unable to downcast the panic info"
);
};
let mut res = Response::new(Full::from("Service panicked"));
*res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
#[allow(clippy::declare_interior_mutable_const)]
const TEXT_PLAIN: HeaderValue = HeaderValue::from_static("text/plain; charset=utf-8");
res.headers_mut()
.insert(http::header::CONTENT_TYPE, TEXT_PLAIN);
res
}
}
#[cfg(test)]
mod tests {
#![allow(unreachable_code)]
use super::*;
use hyper::{Body, Response};
use std::convert::Infallible;
use tower::{ServiceBuilder, ServiceExt};
#[tokio::test]
async fn panic_before_returning_future() {
let svc = ServiceBuilder::new()
.layer(CatchPanicLayer::new())
.service_fn(|_: Request<Body>| {
panic!("service panic");
async { Ok::<_, Infallible>(Response::new(Body::empty())) }
});
let req = Request::new(Body::empty());
let res = svc.oneshot(req).await.unwrap();
assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
let body = hyper::body::to_bytes(res).await.unwrap();
assert_eq!(&body[..], b"Service panicked");
}
#[tokio::test]
async fn panic_in_future() {
let svc = ServiceBuilder::new()
.layer(CatchPanicLayer::new())
.service_fn(|_: Request<Body>| async {
panic!("future panic");
Ok::<_, Infallible>(Response::new(Body::empty()))
});
let req = Request::new(Body::empty());
let res = svc.oneshot(req).await.unwrap();
assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
let body = hyper::body::to_bytes(res).await.unwrap();
assert_eq!(&body[..], b"Service panicked");
}
}