actix_web_lab/
panic_reporter.rs

1//! Panic reporter middleware.
2//!
3//! See [`PanicReporter`] for docs.
4
5use std::{
6    any::Any,
7    future::{Ready, ready},
8    panic::{self, AssertUnwindSafe},
9    rc::Rc,
10};
11
12use actix_web::dev::{Service, Transform, forward_ready};
13use futures_core::future::LocalBoxFuture;
14use futures_util::FutureExt as _;
15
/// Shared, type-erased callback that is handed a reference to a panic payload.
type PanicCallback = Rc<dyn Fn(&(dyn Any + Send)) + 'static>;
17
18/// A middleware that triggers a callback when the worker is panicking.
19///
20/// Mostly useful for logging or metrics publishing. The callback received the object with which
21/// panic was originally invoked to allow down-casting.
22///
23/// # Examples
24///
25/// ```no_run
26/// # use actix_web::App;
27/// use actix_web_lab::middleware::PanicReporter;
28/// # mod metrics {
29/// #   macro_rules! increment_counter {
30/// #       ($tt:tt) => {{}};
31/// #   }
32/// #   pub(crate) use increment_counter;
33/// # }
34///
35/// App::new().wrap(PanicReporter::new(|_| metrics::increment_counter!("panic")))
36///     # ;
37/// ```
38#[derive(Clone)]
39pub struct PanicReporter {
40    cb: PanicCallback,
41}
42
43impl PanicReporter {
44    /// Constructs new panic reporter middleware with `callback`.
45    pub fn new(callback: impl Fn(&(dyn Any + Send)) + 'static) -> Self {
46        Self {
47            cb: Rc::new(callback),
48        }
49    }
50}
51
52impl std::fmt::Debug for PanicReporter {
53    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54        f.debug_struct("PanicReporter")
55            .field("cb", &"<callback>")
56            .finish()
57    }
58}
59
60impl<S, Req> Transform<S, Req> for PanicReporter
61where
62    S: Service<Req>,
63    S::Future: 'static,
64{
65    type Response = S::Response;
66    type Error = S::Error;
67    type Transform = PanicReporterMiddleware<S>;
68    type InitError = ();
69    type Future = Ready<Result<Self::Transform, Self::InitError>>;
70
71    fn new_transform(&self, service: S) -> Self::Future {
72        ready(Ok(PanicReporterMiddleware {
73            service: Rc::new(service),
74            cb: Rc::clone(&self.cb),
75        }))
76    }
77}
78
79/// Middleware service implementation for [`PanicReporter`].
80#[doc(hidden)]
81#[allow(missing_debug_implementations)]
82pub struct PanicReporterMiddleware<S> {
83    service: Rc<S>,
84    cb: PanicCallback,
85}
86
87impl<S, Req> Service<Req> for PanicReporterMiddleware<S>
88where
89    S: Service<Req>,
90    S::Future: 'static,
91{
92    type Response = S::Response;
93    type Error = S::Error;
94    type Future = LocalBoxFuture<'static, Result<S::Response, S::Error>>;
95
96    forward_ready!(service);
97
98    fn call(&self, req: Req) -> Self::Future {
99        let cb = Rc::clone(&self.cb);
100
101        // catch panics in service call
102        AssertUnwindSafe(self.service.call(req))
103            .catch_unwind()
104            .map(move |maybe_res| match maybe_res {
105                Ok(res) => res,
106                Err(panic_err) => {
107                    // invoke callback with panic arg
108                    (cb)(&panic_err);
109
110                    // continue unwinding
111                    panic::resume_unwind(panic_err)
112                }
113            })
114            .boxed_local()
115    }
116}
117
118#[cfg(test)]
119mod tests {
120    use std::sync::{
121        Arc,
122        atomic::{AtomicBool, Ordering},
123    };
124
125    use actix_web::{
126        App,
127        dev::Service as _,
128        test,
129        web::{self, ServiceConfig},
130    };
131
132    use super::*;
133
134    fn configure_test_app(cfg: &mut ServiceConfig) {
135        cfg.route("/", web::get().to(|| async { "content" })).route(
136            "/disco",
137            #[allow(unreachable_code)]
138            web::get().to(|| async {
139                panic!("the disco");
140                ""
141            }),
142        );
143    }
144
145    #[actix_web::test]
146    async fn report_when_panics_occur() {
147        let triggered = Arc::new(AtomicBool::new(false));
148
149        let app = App::new()
150            .wrap(PanicReporter::new({
151                let triggered = Arc::clone(&triggered);
152                move |_| {
153                    triggered.store(true, Ordering::SeqCst);
154                }
155            }))
156            .configure(configure_test_app);
157
158        let app = test::init_service(app).await;
159
160        let req = test::TestRequest::with_uri("/").to_request();
161        assert!(app.call(req).await.is_ok());
162        assert!(!triggered.load(Ordering::SeqCst));
163
164        let req = test::TestRequest::with_uri("/disco").to_request();
165        assert!(
166            AssertUnwindSafe(app.call(req))
167                .catch_unwind()
168                .await
169                .is_err()
170        );
171        assert!(triggered.load(Ordering::SeqCst));
172    }
173}