actix_web_lab/
catch_panic.rs1use std::{
2 future::{Ready, ready},
3 panic::AssertUnwindSafe,
4 rc::Rc,
5};
6
7use actix_web::{
8 dev::{Service, ServiceRequest, ServiceResponse, Transform, forward_ready},
9 error,
10};
11use futures_core::future::LocalBoxFuture;
12use futures_util::FutureExt as _;
13
/// Middleware that catches panics raised while handling a request and turns them into errors.
///
/// When the wrapped service panics, the request resolves to an [`actix_web::Error`] created by
/// `error::ErrorInternalServerError("")`. That error renders as a `500 Internal Server Error`
/// response with an empty body. Non-panicking requests, including those that return ordinary
/// service errors, pass through unchanged.
///
/// Register it like any other middleware, e.g. `App::new().wrap(CatchPanic::default())`.
/// The type is `#[non_exhaustive]`, so construct it with `CatchPanic::default()`.
///
/// Only unwinding panics can be caught; if the binary is built with `panic = "abort"`, the
/// process still aborts.
#[derive(Debug, Clone, Default)]
#[non_exhaustive]
pub struct CatchPanic;
48
49impl<S, B> Transform<S, ServiceRequest> for CatchPanic
50where
51 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error> + 'static,
52{
53 type Response = ServiceResponse<B>;
54 type Error = actix_web::Error;
55 type Transform = CatchPanicMiddleware<S>;
56 type InitError = ();
57 type Future = Ready<Result<Self::Transform, Self::InitError>>;
58
59 fn new_transform(&self, service: S) -> Self::Future {
60 ready(Ok(CatchPanicMiddleware {
61 service: Rc::new(service),
62 }))
63 }
64}
65
/// Service produced by [`CatchPanic`]'s `Transform::new_transform`, wrapping the inner
/// service `S`.
///
/// Not intended to be named directly by users; register [`CatchPanic`] instead.
#[doc(hidden)]
pub struct CatchPanicMiddleware<S> {
    /// The wrapped inner service.
    service: Rc<S>,
}

// Hand-written rather than derived so that `S` is not required to implement `Debug`;
// the inner service is intentionally omitted from the output.
impl<S> std::fmt::Debug for CatchPanicMiddleware<S> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CatchPanicMiddleware").finish_non_exhaustive()
    }
}
74
75impl<S, B> Service<ServiceRequest> for CatchPanicMiddleware<S>
76where
77 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error> + 'static,
78{
79 type Response = ServiceResponse<B>;
80 type Error = actix_web::Error;
81 type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
82
83 forward_ready!(service);
84
85 fn call(&self, req: ServiceRequest) -> Self::Future {
86 AssertUnwindSafe(self.service.call(req))
87 .catch_unwind()
88 .map(move |res| match res {
89 Ok(Ok(res)) => Ok(res),
90 Ok(Err(svc_err)) => Err(svc_err),
91 Err(_panic_err) => Err(error::ErrorInternalServerError("")),
92 })
93 .boxed_local()
94 }
95}
96
#[cfg(test)]
mod tests {
    use actix_web::{
        App, Error,
        body::{MessageBody, to_bytes},
        dev::{Service as _, ServiceFactory},
        http::StatusCode,
        test, web,
    };

    use super::*;

    /// Handler that always panics when polled; used to exercise the middleware.
    async fn disco() -> &'static str {
        panic!("the disco")
    }

    /// Builds an app wrapped in [`CatchPanic`] with a well-behaved route (`/`) and a
    /// panicking route (`/disco`).
    fn test_app() -> App<
        impl ServiceFactory<
            ServiceRequest,
            Response = ServiceResponse<impl MessageBody>,
            Config = (),
            InitError = (),
            Error = Error,
        >,
    > {
        let ok_route = web::get().to(|| async { "content" });
        let panic_route = web::get().to(disco);

        App::new()
            .wrap(CatchPanic::default())
            .route("/", ok_route)
            .route("/disco", panic_route)
    }

    #[actix_web::test]
    async fn pass_through_no_panic() {
        let app = test::init_service(test_app()).await;

        let res = test::call_service(&app, test::TestRequest::default().to_request()).await;
        assert_eq!(res.status(), StatusCode::OK);

        assert_eq!(test::read_body(res).await, "content");
    }

    #[actix_web::test]
    async fn catch_panic_return_internal_server_error_response() {
        let app = test::init_service(test_app()).await;

        let req = test::TestRequest::with_uri("/disco").to_request();
        let Err(err) = app.call(req).await else {
            panic!("unexpected Ok response");
        };

        let res = err.error_response();
        assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);

        let body = to_bytes(res.into_body()).await.unwrap();
        assert!(body.is_empty());
    }
}