poem/middleware/catch_panic.rs

1use std::{any::Any, panic::AssertUnwindSafe};
2
3use futures_util::FutureExt;
4use http::StatusCode;
5
6use crate::{Endpoint, IntoResponse, Middleware, Request, Response, Result};
7
8/// Panics handler
9pub trait PanicHandler: Clone + Sync + Send + 'static {
10    /// Response type
11    type Response: IntoResponse;
12
13    /// Call this method to create a response when a panic occurs.
14    fn get_response(&self, err: Box<dyn Any + Send + 'static>) -> Self::Response;
15}
16
17impl PanicHandler for () {
18    type Response = (StatusCode, &'static str);
19
20    fn get_response(&self, _err: Box<dyn Any + Send + 'static>) -> Self::Response {
21        (StatusCode::INTERNAL_SERVER_ERROR, "internal server error")
22    }
23}
24
25impl<F, R> PanicHandler for F
26where
27    F: Fn(Box<dyn Any + Send + 'static>) -> R + Send + Sync + Clone + 'static,
28    R: IntoResponse,
29{
30    type Response = R;
31
32    fn get_response(&self, err: Box<dyn Any + Send + 'static>) -> Self::Response {
33        (self)(err)
34    }
35}
36
/// Middleware that catches panics raised by the wrapped endpoint and converts
/// them into responses.
///
/// By default ([`CatchPanic::new`]), a panic becomes a
/// `500 INTERNAL SERVER ERROR` response with the body `internal server error`.
/// Use [`CatchPanic::with_handler`] to build a custom response from the panic
/// payload instead.
///
/// # Example
///
/// ```rust
/// use http::StatusCode;
/// use poem::{EndpointExt, Route, handler, middleware::CatchPanic, test::TestClient};
///
/// #[handler]
/// async fn index() {
///     panic!()
/// }
///
/// let app = Route::new().at("/", index).with(CatchPanic::new());
///
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let cli = TestClient::new(app);
/// let resp = cli.get("/").send().await;
/// resp.assert_status(StatusCode::INTERNAL_SERVER_ERROR);
/// # });
/// ```
pub struct CatchPanic<H> {
    /// Builds the response returned when the wrapped endpoint panics.
    panic_handler: H,
}
62
63impl CatchPanic<()> {
64    /// Create new `CatchPanic` middleware.
65    #[inline]
66    pub fn new() -> Self {
67        CatchPanic { panic_handler: () }
68    }
69}
70
71impl Default for CatchPanic<()> {
72    #[inline]
73    fn default() -> Self {
74        Self::new()
75    }
76}
77
78impl<H> CatchPanic<H> {
79    /// Specifies a panic handler to be used to create a custom response when
80    /// a panic occurs.
81    ///
82    /// # Example
83    ///
84    /// ```rust
85    /// use http::StatusCode;
86    /// use poem::{
87    ///     EndpointExt, IntoResponse, Route, handler, middleware::CatchPanic, test::TestClient,
88    /// };
89    ///
90    /// #[handler]
91    /// async fn index() {
92    ///     panic!()
93    /// }
94    ///
95    /// let app = Route::new().at("/", index).with(
96    ///     CatchPanic::new().with_handler(|_| "error!".with_status(StatusCode::INTERNAL_SERVER_ERROR)),
97    /// );
98    ///
99    /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
100    /// let cli = TestClient::new(app);
101    /// let resp = cli.get("/").send().await;
102    /// resp.assert_status(StatusCode::INTERNAL_SERVER_ERROR);
103    /// resp.assert_text("error!").await;
104    /// # });
105    /// ```
106    #[inline]
107    pub fn with_handler<T: PanicHandler>(self, handler: T) -> CatchPanic<T> {
108        CatchPanic {
109            panic_handler: handler,
110        }
111    }
112}
113
114impl<E: Endpoint, H: PanicHandler> Middleware<E> for CatchPanic<H> {
115    type Output = CatchPanicEndpoint<E, H>;
116
117    fn transform(&self, ep: E) -> Self::Output {
118        CatchPanicEndpoint {
119            inner: ep,
120            panic_handler: self.panic_handler.clone(),
121        }
122    }
123}
124
/// Endpoint produced by the [`CatchPanic`] middleware.
///
/// Forwards requests to the wrapped endpoint. If that endpoint panics, the
/// panic handler builds the response instead.
pub struct CatchPanicEndpoint<E, H> {
    /// The wrapped endpoint.
    inner: E,
    /// Builds the response returned when `inner` panics.
    panic_handler: H,
}
130
131impl<E: Endpoint, H: PanicHandler> Endpoint for CatchPanicEndpoint<E, H> {
132    type Output = Response;
133
134    async fn call(&self, req: Request) -> Result<Self::Output> {
135        match AssertUnwindSafe(self.inner.call(req)).catch_unwind().await {
136            Ok(resp) => resp.map(IntoResponse::into_response),
137            Err(err) => Ok(self.panic_handler.get_response(err).into_response()),
138        }
139    }
140}