#![allow(clippy::type_complexity)]
use crate::service::handler::BodyWriteAborted;
use crate::service::{Layer, Service};
use futures_util::FutureExt;
use http::{HeaderMap, Response, StatusCode};
use http_body::{Body, SizeHint};
use pin_project::pin_project;
use std::panic::{self, AssertUnwindSafe};
use std::pin::Pin;
use std::task::{Context, Poll};
/// A [`Layer`] that wraps a service in a [`CatchUnwindService`].
///
/// The layer itself carries no state, so it is cheap to copy and can be created with
/// [`Default::default`].
#[derive(Debug, Clone, Copy, Default)]
pub struct CatchUnwindLayer;
impl<S> Layer<S> for CatchUnwindLayer {
    type Service = CatchUnwindService<S>;

    /// Wraps `inner` in a [`CatchUnwindService`].
    fn layer(self, inner: S) -> CatchUnwindService<S> {
        CatchUnwindService { inner }
    }
}
/// A service wrapper that turns panics raised by the wrapped service into a
/// `500 Internal Server Error` response.
///
/// Instances are produced by [`CatchUnwindLayer`].
#[derive(Debug, Clone)]
pub struct CatchUnwindService<S> {
    /// The wrapped service whose panics are intercepted.
    inner: S,
}
impl<S, R, B> Service<R> for CatchUnwindService<S>
where
    S: Service<R, Response = Response<B>> + Sync,
    R: Send,
{
    type Response = Response<CatchUnwindBody<B>>;

    /// Invokes the inner service and returns its response with the body wrapped in a
    /// [`CatchUnwindBody`].
    ///
    /// If the inner service panics, either while creating its future or while that future is
    /// being polled, the panic is caught and the response built by `panic_response` is
    /// returned instead.
    async fn call(&self, req: R) -> Self::Response {
        // Creating the future can panic before anything is polled, so that step is guarded
        // separately from the polling below.
        let Ok(future) = panic::catch_unwind(AssertUnwindSafe(|| self.inner.call(req))) else {
            return panic_response();
        };

        match AssertUnwindSafe(future).catch_unwind().await {
            Ok(response) => response.map(|body| CatchUnwindBody { inner: Some(body) }),
            Err(_) => panic_response(),
        }
    }
}
/// Builds the response returned when the wrapped service panicked: status
/// `500 Internal Server Error` with a [`CatchUnwindBody`] that holds no inner body.
fn panic_response<B>() -> Response<CatchUnwindBody<B>> {
    let empty_body = CatchUnwindBody { inner: None };
    let mut failure = Response::new(empty_body);
    *failure.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
    failure
}
/// A response body wrapper that catches panics raised while the wrapped body is polled.
///
/// When the wrapped body panics in `poll_data` or `poll_trailers`, the wrapper reports
/// [`BodyWriteAborted`] and discards the wrapped body.
#[pin_project]
pub struct CatchUnwindBody<B> {
    // `None` when the response was produced by `panic_response`, or once the wrapped body
    // has panicked and been discarded.
    #[pin]
    inner: Option<B>,
}
impl<B> Body for CatchUnwindBody<B>
where
    B: Body<Error = BodyWriteAborted>,
{
    type Data = B::Data;
    type Error = B::Error;

    /// Polls the wrapped body for its next data frame.
    ///
    /// Returns `Poll::Ready(None)` when there is no wrapped body. If the wrapped body panics,
    /// it is discarded and `BodyWriteAborted` is returned as the frame error.
    fn poll_data(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
        let mut this = self.project();

        let Some(body) = this.inner.as_mut().as_pin_mut() else {
            return Poll::Ready(None);
        };

        match panic::catch_unwind(AssertUnwindSafe(|| body.poll_data(cx))) {
            Ok(poll) => poll,
            Err(_) => {
                // Discard the body that panicked so it is never polled again.
                this.inner.set(None);
                Poll::Ready(Some(Err(BodyWriteAborted)))
            }
        }
    }

    /// Polls the wrapped body for its trailers.
    ///
    /// Returns `Poll::Ready(Ok(None))` when there is no wrapped body. If the wrapped body
    /// panics, it is discarded and `BodyWriteAborted` is returned.
    fn poll_trailers(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<Option<HeaderMap>, Self::Error>> {
        let mut this = self.project();

        let Some(body) = this.inner.as_mut().as_pin_mut() else {
            return Poll::Ready(Ok(None));
        };

        match panic::catch_unwind(AssertUnwindSafe(|| body.poll_trailers(cx))) {
            Ok(poll) => poll,
            Err(_) => {
                // Discard the body that panicked so it is never polled again.
                this.inner.set(None);
                Poll::Ready(Err(BodyWriteAborted))
            }
        }
    }

    /// Reports end-of-stream when there is no wrapped body; otherwise delegates to the
    /// wrapped body. The delegated call is not guarded against panics.
    fn is_end_stream(&self) -> bool {
        match &self.inner {
            Some(body) => body.is_end_stream(),
            None => true,
        }
    }

    /// Reports an exact size of zero when there is no wrapped body; otherwise delegates to
    /// the wrapped body. The delegated call is not guarded against panics.
    fn size_hint(&self) -> SizeHint {
        match &self.inner {
            Some(body) => body.size_hint(),
            None => SizeHint::with_exact(0),
        }
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::service::test_util::service_fn;
    use bytes::Bytes;
    use futures::future;

    /// The handler panics when it is called, before it returns a future.
    #[tokio::test]
    async fn service_panic() {
        fn explode() -> future::Ready<Response<()>> {
            panic!()
        }

        let layered = CatchUnwindLayer.layer(service_fn(|_| explode()));
        let outcome = layered.call(()).await;

        assert_eq!(outcome.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }

    /// The handler panics inside its async block, i.e. while its future is being polled.
    #[tokio::test]
    async fn service_async_panic() {
        fn explode() -> Response<()> {
            panic!()
        }

        let layered = CatchUnwindLayer.layer(service_fn(|_| async { explode() }));
        let outcome = layered.call(()).await;

        assert_eq!(outcome.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }

    /// The handler succeeds, but the body it returns panics when polled for data.
    #[tokio::test]
    async fn body_panic() {
        struct TestBody;

        impl Body for TestBody {
            type Data = Bytes;
            type Error = BodyWriteAborted;

            fn poll_data(
                self: Pin<&mut Self>,
                _: &mut Context<'_>,
            ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
                panic!()
            }

            fn poll_trailers(
                self: Pin<&mut Self>,
                _: &mut Context<'_>,
            ) -> Poll<Result<Option<HeaderMap>, Self::Error>> {
                unimplemented!()
            }
        }

        let layered = CatchUnwindLayer.layer(service_fn(|_| async { Response::new(TestBody) }));
        let outcome = layered.call(()).await;

        // The service call itself did not panic, so the status is untouched.
        assert_eq!(outcome.status(), StatusCode::OK);

        let mut body = outcome.into_body();
        let first_frame = body.data().await;
        assert!(matches!(first_frame, Some(Err(BodyWriteAborted))));
    }
}