rama_http/layer/
catch_panic.rs

1//! Convert panics into responses.
2//!
3//! Note that using panics for error handling is _not_ recommended. Prefer instead to use `Result`
4//! whenever possible.
5//!
6//! # Example
7//!
8//! ```rust
9//! use std::convert::Infallible;
10//!
11//! use rama_http::{Request, Response, Body, header::HeaderName};
12//! use rama_http::layer::catch_panic::CatchPanicLayer;
13//! use rama_core::service::service_fn;
14//! use rama_core::{Context, Service, Layer};
15//! use rama_core::error::BoxError;
16//!
17//! # #[tokio::main]
18//! # async fn main() -> Result<(), BoxError> {
19//! async fn handle(req: Request) -> Result<Response, Infallible> {
20//!     panic!("something went wrong...")
21//! }
22//!
//! let svc = (
24//!     // Catch panics and convert them into responses.
25//!     CatchPanicLayer::new(),
26//! ).layer(service_fn(handle));
27//!
28//! // Call the service.
29//! let request = Request::new(Body::default());
30//!
31//! let response = svc.serve(Context::default(), request).await?;
32//!
33//! assert_eq!(response.status(), 500);
34//! #
35//! # Ok(())
36//! # }
37//! ```
38//!
39//! Using a custom panic handler:
40//!
41//! ```rust
42//! use std::{any::Any, convert::Infallible};
43//!
44//! use rama_http::{Body, Request, StatusCode, Response, header::{self, HeaderName}};
45//! use rama_http::layer::catch_panic::CatchPanicLayer;
46//! use rama_core::service::{Service, service_fn};
47//! use rama_core::Layer;
48//! use rama_core::error::BoxError;
49//!
50//! # #[tokio::main]
51//! # async fn main() -> Result<(), BoxError> {
52//! async fn handle(req: Request) -> Result<Response, Infallible> {
53//!     panic!("something went wrong...")
54//! }
55//!
56//! fn handle_panic(err: Box<dyn Any + Send + 'static>) -> Response {
57//!     let details = if let Some(s) = err.downcast_ref::<String>() {
58//!         s.clone()
59//!     } else if let Some(s) = err.downcast_ref::<&str>() {
60//!         s.to_string()
61//!     } else {
62//!         "Unknown panic message".to_string()
63//!     };
64//!
65//!     let body = serde_json::json!({
66//!         "error": {
67//!             "kind": "panic",
68//!             "details": details,
69//!         }
70//!     });
71//!     let body = serde_json::to_string(&body).unwrap();
72//!
73//!     Response::builder()
74//!         .status(StatusCode::INTERNAL_SERVER_ERROR)
75//!         .header(header::CONTENT_TYPE, "application/json")
76//!         .body(Body::from(body))
77//!         .unwrap()
78//! }
79//!
80//! let svc = (
81//!     // Use `handle_panic` to create the response.
82//!     CatchPanicLayer::custom(handle_panic),
83//! ).layer(service_fn(handle));
84//! #
85//! # Ok(())
86//! # }
87//! ```
88
89use crate::{Body, HeaderValue, Request, Response, StatusCode};
90use futures_lite::future::FutureExt;
91use rama_core::{Context, Layer, Service};
92use rama_utils::macros::define_inner_service_accessors;
93use std::fmt;
94use std::{any::Any, panic::AssertUnwindSafe};
95
/// Layer that applies the [`CatchPanic`] middleware, which catches panics and converts them into
/// `500 Internal Server Error` responses.
///
/// The response is built by the panic handler `T`: use [`CatchPanicLayer::new`] for the
/// [`DefaultResponseForPanic`] handler, or [`CatchPanicLayer::custom`] to supply your own.
///
/// See the [module docs](self) for an example.
pub struct CatchPanicLayer<T> {
    /// Handler that turns a caught panic payload into a response; every service produced by
    /// this layer receives its own clone of it.
    panic_handler: T,
}
103
104impl<T: fmt::Debug> fmt::Debug for CatchPanicLayer<T> {
105    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
106        f.debug_struct("CatchPanicLayer")
107            .field("panic_handler", &self.panic_handler)
108            .finish()
109    }
110}
111
112impl<T: Clone> Clone for CatchPanicLayer<T> {
113    fn clone(&self) -> Self {
114        Self {
115            panic_handler: self.panic_handler.clone(),
116        }
117    }
118}
119
120impl Default for CatchPanicLayer<DefaultResponseForPanic> {
121    fn default() -> Self {
122        Self::new()
123    }
124}
125
126impl CatchPanicLayer<DefaultResponseForPanic> {
127    /// Create a new `CatchPanicLayer` with the [`Default`]] panic handler.
128    pub const fn new() -> Self {
129        CatchPanicLayer {
130            panic_handler: DefaultResponseForPanic,
131        }
132    }
133}
134
135impl<T> CatchPanicLayer<T> {
136    /// Create a new `CatchPanicLayer` with a custom panic handler.
137    pub fn custom(panic_handler: T) -> Self
138    where
139        T: ResponseForPanic,
140    {
141        Self { panic_handler }
142    }
143}
144
145impl<T, S> Layer<S> for CatchPanicLayer<T>
146where
147    T: Clone,
148{
149    type Service = CatchPanic<S, T>;
150
151    fn layer(&self, inner: S) -> Self::Service {
152        CatchPanic {
153            inner,
154            panic_handler: self.panic_handler.clone(),
155        }
156    }
157}
158
/// Middleware that catches panics and converts them into `500 Internal Server Error` responses.
///
/// Two kinds of panic are caught:
/// - panics raised while the inner service constructs its future;
/// - panics raised while that future is polled.
///
/// The panic payload is handed to the panic handler `T`, and its response is returned as a
/// successful result.
///
/// See the [module docs](self) for an example.
pub struct CatchPanic<S, T> {
    /// The wrapped service.
    inner: S,
    /// Handler that turns a caught panic payload into a response.
    panic_handler: T,
}
166
167impl<S> CatchPanic<S, DefaultResponseForPanic> {
168    /// Create a new `CatchPanic` with the default panic handler.
169    pub const fn new(inner: S) -> Self {
170        Self {
171            inner,
172            panic_handler: DefaultResponseForPanic,
173        }
174    }
175}
176
177impl<S, T> CatchPanic<S, T> {
178    define_inner_service_accessors!();
179
180    /// Create a new `CatchPanic` with a custom panic handler.
181    pub const fn custom(inner: S, panic_handler: T) -> Self
182    where
183        T: ResponseForPanic,
184    {
185        Self {
186            inner,
187            panic_handler,
188        }
189    }
190}
191
192impl<S: fmt::Debug, T: fmt::Debug> fmt::Debug for CatchPanic<S, T> {
193    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
194        f.debug_struct("CatchPanic")
195            .field("inner", &self.inner)
196            .field("panic_handler", &self.panic_handler)
197            .finish()
198    }
199}
200
201impl<S: Clone, T: Clone> Clone for CatchPanic<S, T> {
202    fn clone(&self) -> Self {
203        CatchPanic {
204            inner: self.inner.clone(),
205            panic_handler: self.panic_handler.clone(),
206        }
207    }
208}
209
210impl<State, S, T, ReqBody, ResBody> Service<State, Request<ReqBody>> for CatchPanic<S, T>
211where
212    S: Service<State, Request<ReqBody>, Response = Response<ResBody>>,
213    ResBody: Into<Body> + Send + 'static,
214    T: ResponseForPanic + Clone + Send + Sync + 'static,
215    ReqBody: Send + 'static,
216    ResBody: Send + 'static,
217    State: Clone + Send + Sync + 'static,
218{
219    type Response = Response;
220    type Error = S::Error;
221
222    async fn serve(
223        &self,
224        ctx: Context<State>,
225        req: Request<ReqBody>,
226    ) -> Result<Self::Response, Self::Error> {
227        let future = match std::panic::catch_unwind(AssertUnwindSafe(|| self.inner.serve(ctx, req)))
228        {
229            Ok(future) => future,
230            Err(panic_err) => return Ok(self.panic_handler.response_for_panic(panic_err)),
231        };
232        match AssertUnwindSafe(future).catch_unwind().await {
233            Ok(res) => match res {
234                Ok(res) => Ok(res.map(Into::into)),
235                Err(err) => Err(err),
236            },
237            Err(panic_err) => Ok(self.panic_handler.response_for_panic(panic_err)),
238        }
239    }
240}
241
242/// Trait for creating responses from panics.
243pub trait ResponseForPanic: Clone {
244    /// Create a response from the panic error.
245    fn response_for_panic(&self, err: Box<dyn Any + Send + 'static>) -> Response<Body>;
246}
247
248impl<F> ResponseForPanic for F
249where
250    F: Fn(Box<dyn Any + Send + 'static>) -> Response + Clone,
251{
252    fn response_for_panic(&self, err: Box<dyn Any + Send + 'static>) -> Response {
253        self(err)
254    }
255}
256
/// The default `ResponseForPanic` used by `CatchPanic`.
///
/// It logs the panic message at `error` level, when the panic payload is a `String` or `&str`.
/// It then returns a `500 Internal Server Error` response with a `text/plain; charset=utf-8`
/// body of `Service panicked`.
#[derive(Debug, Default, Clone)]
#[non_exhaustive]
pub struct DefaultResponseForPanic;
264
265impl ResponseForPanic for DefaultResponseForPanic {
266    fn response_for_panic(&self, err: Box<dyn Any + Send + 'static>) -> Response {
267        if let Some(s) = err.downcast_ref::<String>() {
268            tracing::error!("Service panicked: {}", s);
269        } else if let Some(s) = err.downcast_ref::<&str>() {
270            tracing::error!("Service panicked: {}", s);
271        } else {
272            tracing::error!(
273                "Service panicked but `CatchPanic` was unable to downcast the panic info"
274            );
275        };
276
277        let mut res = Response::new(Body::from("Service panicked"));
278        *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
279
280        #[allow(clippy::declare_interior_mutable_const)]
281        const TEXT_PLAIN: HeaderValue = HeaderValue::from_static("text/plain; charset=utf-8");
282        res.headers_mut()
283            .insert(http::header::CONTENT_TYPE, TEXT_PLAIN);
284
285        res
286    }
287}
288
#[cfg(test)]
mod tests {
    // The service closures below panic before their `Ok(..)` expression on purpose.
    #![allow(unreachable_code)]

    use super::*;

    use crate::dep::http_body_util::BodyExt;
    use crate::{Body, Response};
    use rama_core::service::service_fn;
    use rama_core::{Context, Service};
    use std::convert::Infallible;

    /// Checks that `res` is the response produced by `DefaultResponseForPanic`.
    async fn assert_default_panic_response(res: Response) {
        assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
        let bytes = res.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(&bytes[..], b"Service panicked");
    }

    #[tokio::test]
    async fn panic_before_returning_future() {
        // The panic fires while the service builds its future, before any polling.
        let service = CatchPanicLayer::new().layer(service_fn(|_: Request| {
            panic!("service panic");
            async { Ok::<_, Infallible>(Response::new(Body::empty())) }
        }));

        let request = Request::new(Body::empty());
        let response = service.serve(Context::default(), request).await.unwrap();

        assert_default_panic_response(response).await;
    }

    #[tokio::test]
    async fn panic_in_future() {
        // The panic fires only once the returned future is polled.
        let service = CatchPanicLayer::new().layer(service_fn(|_: Request<Body>| async {
            panic!("future panic");
            Ok::<_, Infallible>(Response::new(Body::empty()))
        }));

        let request = Request::new(Body::empty());
        let response = service.serve(Context::default(), request).await.unwrap();

        assert_default_panic_response(response).await;
    }
}