actix_web_lab/
err_handler.rs

1//! For middleware documentation, see [`ErrorHandlers`].
2
3use std::{
4    fmt,
5    pin::Pin,
6    rc::Rc,
7    task::{Context, Poll, ready},
8};
9
10use actix_service::{Service, Transform};
11use actix_web::{
12    Error, Result,
13    body::EitherBody,
14    dev::{ServiceRequest, ServiceResponse},
15    http::StatusCode,
16};
17use ahash::AHashMap;
18use futures_core::future::LocalBoxFuture;
19use pin_project_lite::pin_project;
20
21type ErrorHandlerRes<B> = Result<ServiceResponse<EitherBody<B>>>;
22type ErrorHandler<B> = dyn Fn(ServiceResponse<B>) -> LocalBoxFuture<'static, ErrorHandlerRes<B>>;
23type Handlers<B> = Rc<AHashMap<StatusCode, Box<ErrorHandler<B>>>>;
24
25/// Middleware for registering custom status code based error handlers.
26///
27/// Register handlers with the `ErrorHandlers::handler()` method to register a custom error handler
28/// for a given status code. Handlers can modify existing responses or create completely new ones.
29///
30/// # Examples
31/// ```
32/// use actix_web::{
33///     App, HttpResponse, Result,
34///     body::EitherBody,
35///     dev::ServiceResponse,
36///     http::{StatusCode, header},
37///     web,
38/// };
39/// use actix_web_lab::middleware::ErrorHandlers;
40///
41/// async fn add_error_header<B>(
42///     mut res: ServiceResponse<B>,
43/// ) -> Result<ServiceResponse<EitherBody<B>>> {
44///     res.response_mut().headers_mut().insert(
45///         header::CONTENT_TYPE,
46///         header::HeaderValue::from_static("Error"),
47///     );
48///     Ok(res.map_into_left_body())
49/// }
50///
51/// let app = App::new()
52///     .wrap(ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, add_error_header))
53///     .service(web::resource("/").route(web::get().to(HttpResponse::InternalServerError)));
54/// ```
55pub struct ErrorHandlers<B> {
56    handlers: Handlers<B>,
57}
58
59impl<B> fmt::Debug for ErrorHandlers<B> {
60    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
61        f.debug_struct("ErrorHandlers")
62            .field(
63                "handlers",
64                &format_args!("[<{} items>]", self.handlers.len()),
65            )
66            .finish()
67    }
68}
69
70impl<B> Default for ErrorHandlers<B> {
71    fn default() -> Self {
72        ErrorHandlers {
73            handlers: Default::default(),
74        }
75    }
76}
77
78impl<B> ErrorHandlers<B> {
79    /// Construct new `ErrorHandlers` instance.
80    pub fn new() -> Self {
81        ErrorHandlers::default()
82    }
83
84    /// Register error handler for specified status code.
85    pub fn handler<F, Fut>(mut self, status: StatusCode, handler: F) -> Self
86    where
87        F: Fn(ServiceResponse<B>) -> Fut + 'static,
88        Fut: Future<Output = ErrorHandlerRes<B>> + 'static,
89    {
90        Rc::get_mut(&mut self.handlers)
91            .unwrap()
92            .insert(status, Box::new(move |res| Box::pin((handler)(res))));
93        self
94    }
95}
96
97impl<S, B> Transform<S, ServiceRequest> for ErrorHandlers<B>
98where
99    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
100    S::Future: 'static,
101    B: 'static,
102{
103    type Response = ServiceResponse<EitherBody<B>>;
104    type Error = Error;
105    type Transform = ErrorHandlersMiddleware<S, B>;
106    type InitError = ();
107    type Future = LocalBoxFuture<'static, Result<Self::Transform, Self::InitError>>;
108
109    fn new_transform(&self, service: S) -> Self::Future {
110        let handlers = self.handlers.clone();
111        Box::pin(async move { Ok(ErrorHandlersMiddleware { service, handlers }) })
112    }
113}
114
115/// Middleware for registering custom status code based error handlers.
116///
117/// See [`ErrorHandlers`].
118#[doc(hidden)]
119#[allow(missing_debug_implementations)]
120pub struct ErrorHandlersMiddleware<S, B> {
121    service: S,
122    handlers: Handlers<B>,
123}
124
125impl<S, B> Service<ServiceRequest> for ErrorHandlersMiddleware<S, B>
126where
127    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
128    S::Future: 'static,
129    B: 'static,
130{
131    type Response = ServiceResponse<EitherBody<B>>;
132    type Error = Error;
133    type Future = ErrorHandlersFuture<S::Future, B>;
134
135    actix_service::forward_ready!(service);
136
137    fn call(&self, req: ServiceRequest) -> Self::Future {
138        let handlers = self.handlers.clone();
139        let fut = self.service.call(req);
140        ErrorHandlersFuture::ServiceFuture { fut, handlers }
141    }
142}
143
pin_project! {
    // Future returned by `ErrorHandlersMiddleware::call`.
    //
    // It starts in `ServiceFuture`, awaiting the wrapped service. If a handler is registered for
    // the resulting response's status, it switches to `ErrorHandlerFuture` and awaits that
    // handler's future instead.
    //
    // Plain `//` comments are used because `pin_project!` only accepts a restricted set of
    // attributes in these positions.
    #[project = ErrorHandlersProj]
    pub enum ErrorHandlersFuture<Fut, B>
    where
        Fut: Future,
    {
        // Awaiting the wrapped service's response.
        ServiceFuture {
            // Structurally pinned: `Fut` has no `Unpin` bound.
            #[pin]
            fut: Fut,
            // Used to look up a handler once the response status is known.
            handlers: Handlers<B>,
        },
        // Awaiting a registered error handler. `LocalBoxFuture` is already a pinned box, so this
        // field needs no structural pinning.
        ErrorHandlerFuture {
            fut: LocalBoxFuture<'static, Result<ServiceResponse<EitherBody<B>>, Error>>,
        },
    }
}
160
161impl<Fut, B> Future for ErrorHandlersFuture<Fut, B>
162where
163    Fut: Future<Output = Result<ServiceResponse<B>, Error>>,
164{
165    type Output = Result<ServiceResponse<EitherBody<B>>, Error>;
166
167    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
168        match self.as_mut().project() {
169            ErrorHandlersProj::ServiceFuture { fut, handlers } => {
170                let res = ready!(fut.poll(cx))?;
171
172                match handlers.get(&res.status()) {
173                    Some(handler) => {
174                        let fut = handler(res);
175
176                        self.as_mut()
177                            .set(ErrorHandlersFuture::ErrorHandlerFuture { fut });
178
179                        self.poll(cx)
180                    }
181
182                    None => Poll::Ready(Ok(res.map_into_left_body())),
183                }
184            }
185
186            ErrorHandlersProj::ErrorHandlerFuture { fut } => fut.as_mut().poll(cx),
187        }
188    }
189}
190
#[cfg(test)]
mod tests {
    use actix_service::IntoService;
    use actix_web::{
        body,
        http::{
            StatusCode,
            header::{CONTENT_TYPE, HeaderValue},
        },
        test::{self, TestRequest},
    };
    use bytes::Bytes;

    use super::*;

    #[actix_web::test]
    async fn add_header_error_handler() {
        // The handler signature requires a `Result`, even though this handler never fails.
        #[allow(clippy::unnecessary_wraps)]
        async fn error_handler<B>(
            mut res: ServiceResponse<B>,
        ) -> Result<ServiceResponse<EitherBody<B>>> {
            res.response_mut()
                .headers_mut()
                .insert(CONTENT_TYPE, HeaderValue::from_static("0001"));

            Ok(res.map_into_left_body())
        }

        let srv = test::status_service(StatusCode::INTERNAL_SERVER_ERROR);

        let mw = ErrorHandlers::new()
            .handler(StatusCode::INTERNAL_SERVER_ERROR, error_handler)
            .new_transform(srv.into_service())
            .await
            .unwrap();

        let resp = test::call_service(&mw, TestRequest::default().to_srv_request()).await;
        assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001");
    }

    #[actix_web::test]
    async fn add_header_error_handler_async() {
        // The handler signature requires a `Result`, even though this handler never fails.
        #[allow(clippy::unnecessary_wraps)]
        async fn error_handler<B: 'static>(
            mut res: ServiceResponse<B>,
        ) -> Result<ServiceResponse<EitherBody<B>>> {
            res.response_mut()
                .headers_mut()
                .insert(CONTENT_TYPE, HeaderValue::from_static("0001"));

            Ok(res.map_into_left_body())
        }

        let srv = test::status_service(StatusCode::INTERNAL_SERVER_ERROR);

        let mw = ErrorHandlers::new()
            .handler(StatusCode::INTERNAL_SERVER_ERROR, error_handler)
            .new_transform(srv.into_service())
            .await
            .unwrap();

        let resp = test::call_service(&mw, TestRequest::default().to_srv_request()).await;
        assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001");
    }

    /// Responses whose status has no registered handler must pass through untouched.
    #[actix_web::test]
    async fn unmatched_status_passes_through() {
        // The handler signature requires a `Result`, even though this handler never fails.
        #[allow(clippy::unnecessary_wraps)]
        async fn error_handler<B>(
            mut res: ServiceResponse<B>,
        ) -> Result<ServiceResponse<EitherBody<B>>> {
            res.response_mut()
                .headers_mut()
                .insert(CONTENT_TYPE, HeaderValue::from_static("0001"));

            Ok(res.map_into_left_body())
        }

        let srv = test::status_service(StatusCode::OK);

        let mw = ErrorHandlers::new()
            .handler(StatusCode::INTERNAL_SERVER_ERROR, error_handler)
            .new_transform(srv.into_service())
            .await
            .unwrap();

        let resp = test::call_service(&mw, TestRequest::default().to_srv_request()).await;
        assert_eq!(resp.status(), StatusCode::OK);
        assert!(resp.headers().get(CONTENT_TYPE).is_none());
    }

    #[actix_web::test]
    async fn changes_body_type() {
        // The handler signature requires a `Result`, even though this handler never fails.
        #[allow(clippy::unnecessary_wraps)]
        async fn error_handler<B>(
            res: ServiceResponse<B>,
        ) -> Result<ServiceResponse<EitherBody<B>>> {
            let (req, res) = res.into_parts();
            let res = res.set_body(Bytes::from("sorry, that's no bueno"));

            let res = ServiceResponse::new(req, res)
                .map_into_boxed_body()
                .map_into_right_body();

            Ok(res)
        }

        let srv = test::status_service(StatusCode::INTERNAL_SERVER_ERROR);

        let mw = ErrorHandlers::new()
            .handler(StatusCode::INTERNAL_SERVER_ERROR, error_handler)
            .new_transform(srv.into_service())
            .await
            .unwrap();

        let res = test::call_service(&mw, TestRequest::default().to_srv_request()).await;
        assert_eq!(test::read_body(res).await, "sorry, that's no bueno");
    }

    #[actix_web::test]
    async fn error_thrown() {
        #[allow(clippy::unnecessary_wraps)]
        async fn error_handler<B>(
            _res: ServiceResponse<B>,
        ) -> Result<ServiceResponse<EitherBody<B>>> {
            Err(actix_web::error::ErrorInternalServerError(
                "error in error handler",
            ))
        }

        let srv = test::status_service(StatusCode::BAD_REQUEST);

        let mw = ErrorHandlers::new()
            .handler(StatusCode::BAD_REQUEST, error_handler)
            .new_transform(srv.into_service())
            .await
            .unwrap();

        let err = mw
            .call(TestRequest::default().to_srv_request())
            .await
            .unwrap_err();
        let res = err.error_response();

        assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(
            body::to_bytes(res.into_body()).await.unwrap(),
            "error in error handler"
        );
    }
}