Skip to main content

nidus_http/middleware/
catch_panic.rs

//! Panic-catching middleware that preserves the response body type.
//!
//! Unlike `tower_http::catch_panic`, this layer keeps the response as
//! `Response<axum::body::Body>` so it composes with [`crate::error::ErrorEnvelopeLayer`].
//! Place it inside the envelope (as [`crate::middleware::ApiDefaults::production`] does)
//! so a handler panic surfaces as a structured `500` envelope with a request id and
//! metrics, instead of aborting the connection.
8
9use std::{
10    future::Future,
11    panic::{AssertUnwindSafe, catch_unwind},
12    pin::Pin,
13    task::{Context, Poll},
14};
15
16use axum::{body::Body, extract::Request};
17use futures_util::FutureExt;
18use http::{Response, StatusCode};
19use tower::{Layer, Service};
20
21/// Creates a layer that catches panics from the inner service and maps them to a
22/// `500 Internal Server Error` response.
23///
24/// See [`CatchPanicLayer`] for details.
25pub fn catch_panic_layer() -> CatchPanicLayer {
26    CatchPanicLayer
27}
28
/// A [`Layer`] that wraps a service in [`CatchPanicService`], turning panics into
/// `500 Internal Server Error` responses.
///
/// A panic can come from the inner service's `call` or from polling the future it
/// returned. In either case the inner future (if any) is abandoned, the panic
/// payload is reported through `tracing::error!`, and an empty-bodied `500` is
/// produced instead. Nest this inside [`crate::error::ErrorEnvelopeLayer`] so that
/// `500` is rendered as the production error envelope.
///
/// Panics are only catchable under the default `panic = "unwind"` strategy; with
/// `panic = "abort"` the process still aborts.
#[derive(Debug, Default, Clone, Copy)]
pub struct CatchPanicLayer;
37
38impl<S> Layer<S> for CatchPanicLayer {
39    type Service = CatchPanicService<S>;
40
41    fn layer(&self, inner: S) -> Self::Service {
42        CatchPanicService { inner }
43    }
44}
45
/// The [`Service`] built by [`CatchPanicLayer`].
///
/// It forwards every request to the wrapped service and converts a panic into a
/// `500 Internal Server Error` response.
#[derive(Debug, Clone)]
pub struct CatchPanicService<S> {
    /// The wrapped service that actually handles requests.
    inner: S,
}
51
52impl<S> Service<Request> for CatchPanicService<S>
53where
54    S: Service<Request, Response = Response<Body>> + Send + 'static,
55    S::Future: Send + 'static,
56    S::Error: Send + 'static,
57{
58    type Response = Response<Body>;
59    type Error = S::Error;
60    type Future = Pin<Box<dyn Future<Output = Result<Response<Body>, S::Error>> + Send>>;
61
62    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
63        self.inner.poll_ready(cx)
64    }
65
66    fn call(&mut self, request: Request) -> Self::Future {
67        // Catch a panic that occurs synchronously while starting the inner
68        // service (e.g. inside `call`), then catch a panic that occurs while the
69        // inner future is polled.
70        match catch_unwind(AssertUnwindSafe(|| self.inner.call(request))) {
71            Ok(future) => Box::pin(async move {
72                match AssertUnwindSafe(future).catch_unwind().await {
73                    Ok(result) => result,
74                    Err(payload) => {
75                        log_panic(&payload);
76                        Ok(internal_server_error())
77                    }
78                }
79            }),
80            Err(payload) => Box::pin(async move {
81                log_panic(&payload);
82                Ok(internal_server_error())
83            }),
84        }
85    }
86}
87
88fn log_panic(payload: &Box<dyn std::any::Any + Send + 'static>) {
89    if let Some(message) = payload.downcast_ref::<String>() {
90        tracing::error!(http.status = 500, panic.message = %message, "request handler panicked");
91    } else if let Some(message) = payload.downcast_ref::<&'static str>() {
92        tracing::error!(http.status = 500, panic.message = %message, "request handler panicked");
93    } else {
94        tracing::error!(
95            http.status = 500,
96            "request handler panicked with non-string payload"
97        );
98    }
99}
100
101fn internal_server_error() -> Response<Body> {
102    let mut response = Response::new(Body::empty());
103    *response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
104    response
105}