poem/middleware/
catch_panic.rs1use std::{any::Any, panic::AssertUnwindSafe};
2
3use futures_util::FutureExt;
4use http::StatusCode;
5
6use crate::{Endpoint, IntoResponse, Middleware, Request, Response, Result};
7
8pub trait PanicHandler: Clone + Sync + Send + 'static {
10 type Response: IntoResponse;
12
13 fn get_response(&self, err: Box<dyn Any + Send + 'static>) -> Self::Response;
15}
16
17impl PanicHandler for () {
18 type Response = (StatusCode, &'static str);
19
20 fn get_response(&self, _err: Box<dyn Any + Send + 'static>) -> Self::Response {
21 (StatusCode::INTERNAL_SERVER_ERROR, "internal server error")
22 }
23}
24
25impl<F, R> PanicHandler for F
26where
27 F: Fn(Box<dyn Any + Send + 'static>) -> R + Send + Sync + Clone + 'static,
28 R: IntoResponse,
29{
30 type Response = R;
31
32 fn get_response(&self, err: Box<dyn Any + Send + 'static>) -> Self::Response {
33 (self)(err)
34 }
35}
36
/// Middleware that wraps an endpoint so that a panic raised while handling a
/// request is handed to a [`PanicHandler`] instead of propagating.
///
/// `H` is the handler type. `CatchPanic::new()` uses the unit handler `()`,
/// and `with_handler` swaps in a custom one.
#[derive(Debug, Clone)]
pub struct CatchPanic<H> {
    /// Handler that is cloned into every endpoint produced by this middleware.
    panic_handler: H,
}
62
63impl CatchPanic<()> {
64 #[inline]
66 pub fn new() -> Self {
67 CatchPanic { panic_handler: () }
68 }
69}
70
71impl Default for CatchPanic<()> {
72 #[inline]
73 fn default() -> Self {
74 Self::new()
75 }
76}
77
78impl<H> CatchPanic<H> {
79 #[inline]
107 pub fn with_handler<T: PanicHandler>(self, handler: T) -> CatchPanic<T> {
108 CatchPanic {
109 panic_handler: handler,
110 }
111 }
112}
113
114impl<E: Endpoint, H: PanicHandler> Middleware<E> for CatchPanic<H> {
115 type Output = CatchPanicEndpoint<E, H>;
116
117 fn transform(&self, ep: E) -> Self::Output {
118 CatchPanicEndpoint {
119 inner: ep,
120 panic_handler: self.panic_handler.clone(),
121 }
122 }
123}
124
/// Endpoint produced by the [`CatchPanic`] middleware.
///
/// It forwards requests to `inner` and passes any panic payload caught during
/// that call to `panic_handler`.
#[derive(Debug, Clone)]
pub struct CatchPanicEndpoint<E, H> {
    /// The wrapped endpoint that actually handles the request.
    inner: E,
    /// Handler used to build a response from a caught panic payload.
    panic_handler: H,
}
130
131impl<E: Endpoint, H: PanicHandler> Endpoint for CatchPanicEndpoint<E, H> {
132 type Output = Response;
133
134 async fn call(&self, req: Request) -> Result<Self::Output> {
135 match AssertUnwindSafe(self.inner.call(req)).catch_unwind().await {
136 Ok(resp) => resp.map(IntoResponse::into_response),
137 Err(err) => Ok(self.panic_handler.get_response(err).into_response()),
138 }
139 }
140}