1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140
use std::{any::Any, panic::AssertUnwindSafe};
use futures_util::FutureExt;
use http::StatusCode;
use crate::{Endpoint, IntoResponse, Middleware, Request, Response, Result};
/// Panics handler
/// Strategy for turning a caught panic into a response.
///
/// This file implements it for `()` (a fixed `500 INTERNAL SERVER ERROR`
/// response) and for any cloneable closure taking the panic payload.
pub trait PanicHandler: Send + Sync + Clone + 'static {
    /// Type of the response produced after a panic.
    type Response: IntoResponse;

    /// Builds the response to return when a panic has been caught.
    ///
    /// `err` is the boxed panic payload captured by the unwind.
    fn get_response(&self, err: Box<dyn Any + Send + 'static>) -> Self::Response;
}
/// Default handler: ignores the panic payload and always answers with
/// `500 INTERNAL SERVER ERROR` and the body `"internal server error"`.
impl PanicHandler for () {
    type Response = (StatusCode, &'static str);

    fn get_response(&self, _err: Box<dyn Any + Send + 'static>) -> Self::Response {
        let status = StatusCode::INTERNAL_SERVER_ERROR;
        (status, "internal server error")
    }
}
/// Lets any cloneable, thread-safe closure that maps the panic payload to
/// something response-like act as a panic handler.
impl<F, R> PanicHandler for F
where
    F: Fn(Box<dyn Any + Send + 'static>) -> R + Send + Sync + Clone + 'static,
    R: IntoResponse,
{
    type Response = R;

    fn get_response(&self, err: Box<dyn Any + Send + 'static>) -> Self::Response {
        // `&F` is itself callable because `F: Fn`.
        self(err)
    }
}
/// Middleware that catches panics raised by the wrapped endpoint and converts
/// them into responses.
///
/// With the default handler (see [`CatchPanic::new`]) a panic becomes a
/// `500 INTERNAL SERVER ERROR` response. Use [`CatchPanic::with_handler`] to
/// build a custom response instead.
///
/// # Example
///
/// ```rust
/// use http::StatusCode;
/// use poem::{handler, middleware::CatchPanic, test::TestClient, EndpointExt, Route};
///
/// #[handler]
/// async fn index() {
///     panic!()
/// }
///
/// let app = Route::new().at("/", index).with(CatchPanic::new());
///
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let cli = TestClient::new(app);
/// let resp = cli.get("/").send().await;
/// resp.assert_status(StatusCode::INTERNAL_SERVER_ERROR);
/// # });
/// ```
pub struct CatchPanic<H> {
    /// Handler used to build the response once a panic has been caught.
    panic_handler: H,
}
impl CatchPanic<()> {
/// Create new `CatchPanic` middleware.
#[inline]
pub fn new() -> Self {
CatchPanic { panic_handler: () }
}
}
impl Default for CatchPanic<()> {
#[inline]
fn default() -> Self {
Self::new()
}
}
impl<H> CatchPanic<H> {
    /// Replaces the panic handler with `handler`, which builds a custom
    /// response whenever a panic is caught.
    ///
    /// # Example
    ///
    /// ```rust
    /// use http::StatusCode;
    /// use poem::{
    ///     handler, middleware::CatchPanic, test::TestClient, EndpointExt, IntoResponse, Route,
    /// };
    ///
    /// #[handler]
    /// async fn index() {
    ///     panic!()
    /// }
    ///
    /// let app = Route::new().at("/", index).with(
    ///     CatchPanic::new().with_handler(|_| "error!".with_status(StatusCode::INTERNAL_SERVER_ERROR)),
    /// );
    ///
    /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
    /// let cli = TestClient::new(app);
    /// let resp = cli.get("/").send().await;
    /// resp.assert_status(StatusCode::INTERNAL_SERVER_ERROR);
    /// resp.assert_text("error!").await;
    /// # });
    /// ```
    #[inline]
    pub fn with_handler<T: PanicHandler>(self, handler: T) -> CatchPanic<T> {
        // `self` (and its previous handler) is consumed; only the new handler is kept.
        let panic_handler = handler;
        CatchPanic { panic_handler }
    }
}
impl<E: Endpoint, H: PanicHandler> Middleware<E> for CatchPanic<H> {
    type Output = CatchPanicEndpoint<E, H>;

    /// Wraps `ep` in a [`CatchPanicEndpoint`] that owns its own clone of the
    /// panic handler.
    fn transform(&self, ep: E) -> Self::Output {
        let panic_handler = self.panic_handler.clone();
        CatchPanicEndpoint {
            panic_handler,
            inner: ep,
        }
    }
}
/// Endpoint produced by the [`CatchPanic`] middleware.
///
/// It forwards requests to the wrapped endpoint and uses the panic handler to
/// build a response if that endpoint panics.
pub struct CatchPanicEndpoint<E, H> {
    /// The wrapped endpoint.
    inner: E,
    /// Handler used to build the response once a panic has been caught.
    panic_handler: H,
}
impl<E: Endpoint, H: PanicHandler> Endpoint for CatchPanicEndpoint<E, H> {
    type Output = Response;

    /// Calls the inner endpoint and converts any panic into the handler's
    /// response.
    ///
    /// Errors returned normally by the inner endpoint are propagated unchanged;
    /// successful outputs are converted with `IntoResponse::into_response`.
    async fn call(&self, req: Request) -> Result<Self::Output> {
        // Invoke `inner.call` *inside* the guarded future. That way a panic raised
        // synchronously while the inner endpoint builds its future is caught as well,
        // not only panics raised while that future is being polled.
        let guarded = AssertUnwindSafe(async move { self.inner.call(req).await });
        match guarded.catch_unwind().await {
            Ok(resp) => resp.map(IntoResponse::into_response),
            Err(err) => Ok(self.panic_handler.get_response(err).into_response()),
        }
    }
}