1use std::{
4 fmt,
5 pin::Pin,
6 rc::Rc,
7 task::{Context, Poll, ready},
8};
9
10use actix_service::{Service, Transform};
11use actix_web::{
12 Error, Result,
13 body::EitherBody,
14 dev::{ServiceRequest, ServiceResponse},
15 http::StatusCode,
16};
17use ahash::AHashMap;
18use futures_core::future::LocalBoxFuture;
19use pin_project_lite::pin_project;
20
/// Result produced by an error handler: the final response, whose body is either the original
/// `B` (left side of [`EitherBody`]) or a replacement body (right side).
type ErrorHandlerRes<B> = Result<ServiceResponse<EitherBody<B>>>;
/// Type-erased error handler: receives the matched response and returns a boxed local future.
type ErrorHandler<B> = dyn Fn(ServiceResponse<B>) -> LocalBoxFuture<'static, ErrorHandlerRes<B>>;
/// Status code → handler map, reference-counted so the builder, each middleware instance and
/// each in-flight request future can share it without copying the closures.
type Handlers<B> = Rc<AHashMap<StatusCode, Box<ErrorHandler<B>>>>;
24
/// Middleware that lets callers register a handler per [`StatusCode`].
///
/// When the wrapped service produces a response whose status has a registered handler, that
/// handler receives the response and its result becomes the final response. Responses with any
/// other status are passed through unchanged (as the left side of [`EitherBody`]).
pub struct ErrorHandlers<B> {
    // Shared (by `Rc`) with every middleware instance built from this value.
    handlers: Handlers<B>,
}
58
59impl<B> fmt::Debug for ErrorHandlers<B> {
60 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
61 f.debug_struct("ErrorHandlers")
62 .field(
63 "handlers",
64 &format_args!("[<{} items>]", self.handlers.len()),
65 )
66 .finish()
67 }
68}
69
70impl<B> Default for ErrorHandlers<B> {
71 fn default() -> Self {
72 ErrorHandlers {
73 handlers: Default::default(),
74 }
75 }
76}
77
78impl<B> ErrorHandlers<B> {
79 pub fn new() -> Self {
81 ErrorHandlers::default()
82 }
83
84 pub fn handler<F, Fut>(mut self, status: StatusCode, handler: F) -> Self
86 where
87 F: Fn(ServiceResponse<B>) -> Fut + 'static,
88 Fut: Future<Output = ErrorHandlerRes<B>> + 'static,
89 {
90 Rc::get_mut(&mut self.handlers)
91 .unwrap()
92 .insert(status, Box::new(move |res| Box::pin((handler)(res))));
93 self
94 }
95}
96
97impl<S, B> Transform<S, ServiceRequest> for ErrorHandlers<B>
98where
99 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
100 S::Future: 'static,
101 B: 'static,
102{
103 type Response = ServiceResponse<EitherBody<B>>;
104 type Error = Error;
105 type Transform = ErrorHandlersMiddleware<S, B>;
106 type InitError = ();
107 type Future = LocalBoxFuture<'static, Result<Self::Transform, Self::InitError>>;
108
109 fn new_transform(&self, service: S) -> Self::Future {
110 let handlers = self.handlers.clone();
111 Box::pin(async move { Ok(ErrorHandlersMiddleware { service, handlers }) })
112 }
113}
114
/// Service produced by [`ErrorHandlers`]: the wrapped inner service plus the shared
/// status-code → handler map.
#[doc(hidden)]
// Neither the inner service `S` (no `Debug` bound) nor the boxed handler closures implement
// `Debug`, so no `Debug` impl is provided.
#[allow(missing_debug_implementations)]
pub struct ErrorHandlersMiddleware<S, B> {
    service: S,
    handlers: Handlers<B>,
}
124
125impl<S, B> Service<ServiceRequest> for ErrorHandlersMiddleware<S, B>
126where
127 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
128 S::Future: 'static,
129 B: 'static,
130{
131 type Response = ServiceResponse<EitherBody<B>>;
132 type Error = Error;
133 type Future = ErrorHandlersFuture<S::Future, B>;
134
135 actix_service::forward_ready!(service);
136
137 fn call(&self, req: ServiceRequest) -> Self::Future {
138 let handlers = self.handlers.clone();
139 let fut = self.service.call(req);
140 ErrorHandlersFuture::ServiceFuture { fut, handlers }
141 }
142}
143
// Future returned by `ErrorHandlersMiddleware::call`. It starts in `ServiceFuture`, awaiting the
// inner service. If the response status has a registered handler, it switches to
// `ErrorHandlerFuture` and resolves to that handler's output.
pin_project! {
    #[project = ErrorHandlersProj]
    pub enum ErrorHandlersFuture<Fut, B>
    where
        Fut: Future,
    {
        // Waiting on the wrapped service's response.
        ServiceFuture {
            #[pin]
            fut: Fut,
            // Kept so the status code can be looked up once `fut` resolves.
            handlers: Handlers<B>,
        },
        // Running the matched error handler. The future is already `Pin<Box<..>>`, so this
        // field needs no structural pinning.
        ErrorHandlerFuture {
            fut: LocalBoxFuture<'static, Result<ServiceResponse<EitherBody<B>>, Error>>,
        },
    }
}
160
161impl<Fut, B> Future for ErrorHandlersFuture<Fut, B>
162where
163 Fut: Future<Output = Result<ServiceResponse<B>, Error>>,
164{
165 type Output = Result<ServiceResponse<EitherBody<B>>, Error>;
166
167 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
168 match self.as_mut().project() {
169 ErrorHandlersProj::ServiceFuture { fut, handlers } => {
170 let res = ready!(fut.poll(cx))?;
171
172 match handlers.get(&res.status()) {
173 Some(handler) => {
174 let fut = handler(res);
175
176 self.as_mut()
177 .set(ErrorHandlersFuture::ErrorHandlerFuture { fut });
178
179 self.poll(cx)
180 }
181
182 None => Poll::Ready(Ok(res.map_into_left_body())),
183 }
184 }
185
186 ErrorHandlersProj::ErrorHandlerFuture { fut } => fut.as_mut().poll(cx),
187 }
188 }
189}
190
#[cfg(test)]
mod tests {
    use actix_service::IntoService;
    use actix_web::{
        body,
        http::{
            StatusCode,
            header::{CONTENT_TYPE, HeaderValue},
        },
        test::{self, TestRequest},
    };
    use bytes::Bytes;

    use super::*;

    #[actix_web::test]
    async fn add_header_error_handler() {
        // `ErrorHandlers::handler` requires a `Result`-returning future, even if it never fails.
        #[allow(clippy::unnecessary_wraps)]
        async fn error_handler<B>(
            mut res: ServiceResponse<B>,
        ) -> Result<ServiceResponse<EitherBody<B>>> {
            res.response_mut()
                .headers_mut()
                .insert(CONTENT_TYPE, HeaderValue::from_static("0001"));

            Ok(res.map_into_left_body())
        }

        let srv = test::status_service(StatusCode::INTERNAL_SERVER_ERROR);

        let mw = ErrorHandlers::new()
            .handler(StatusCode::INTERNAL_SERVER_ERROR, error_handler)
            .new_transform(srv.into_service())
            .await
            .unwrap();

        let resp = test::call_service(&mw, TestRequest::default().to_srv_request()).await;
        assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001");
    }

    #[actix_web::test]
    async fn add_header_error_handler_async() {
        // `ErrorHandlers::handler` requires a `Result`-returning future, even if it never fails.
        #[allow(clippy::unnecessary_wraps)]
        async fn error_handler<B: 'static>(
            mut res: ServiceResponse<B>,
        ) -> Result<ServiceResponse<EitherBody<B>>> {
            res.response_mut()
                .headers_mut()
                .insert(CONTENT_TYPE, HeaderValue::from_static("0001"));

            Ok(res.map_into_left_body())
        }

        let srv = test::status_service(StatusCode::INTERNAL_SERVER_ERROR);

        let mw = ErrorHandlers::new()
            .handler(StatusCode::INTERNAL_SERVER_ERROR, error_handler)
            .new_transform(srv.into_service())
            .await
            .unwrap();

        let resp = test::call_service(&mw, TestRequest::default().to_srv_request()).await;
        assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001");
    }

    #[actix_web::test]
    async fn changes_body_type() {
        // `ErrorHandlers::handler` requires a `Result`-returning future, even if it never fails.
        #[allow(clippy::unnecessary_wraps)]
        async fn error_handler<B>(
            res: ServiceResponse<B>,
        ) -> Result<ServiceResponse<EitherBody<B>>> {
            let (req, res) = res.into_parts();
            let res = res.set_body(Bytes::from("sorry, that's no bueno"));

            let res = ServiceResponse::new(req, res)
                .map_into_boxed_body()
                .map_into_right_body();

            Ok(res)
        }

        let srv = test::status_service(StatusCode::INTERNAL_SERVER_ERROR);

        let mw = ErrorHandlers::new()
            .handler(StatusCode::INTERNAL_SERVER_ERROR, error_handler)
            .new_transform(srv.into_service())
            .await
            .unwrap();

        let res = test::call_service(&mw, TestRequest::default().to_srv_request()).await;
        assert_eq!(test::read_body(res).await, "sorry, that's no bueno");
    }

    #[actix_web::test]
    async fn error_thrown() {
        // `ErrorHandlers::handler` requires a `Result`-returning future.
        #[allow(clippy::unnecessary_wraps)]
        async fn error_handler<B>(
            _res: ServiceResponse<B>,
        ) -> Result<ServiceResponse<EitherBody<B>>> {
            Err(actix_web::error::ErrorInternalServerError(
                "error in error handler",
            ))
        }

        let srv = test::status_service(StatusCode::BAD_REQUEST);

        let mw = ErrorHandlers::new()
            .handler(StatusCode::BAD_REQUEST, error_handler)
            .new_transform(srv.into_service())
            .await
            .unwrap();

        let err = mw
            .call(TestRequest::default().to_srv_request())
            .await
            .unwrap_err();
        let res = err.error_response();

        assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(
            body::to_bytes(res.into_body()).await.unwrap(),
            "error in error handler"
        );
    }

    /// A response whose status has no registered handler must pass through untouched.
    #[actix_web::test]
    async fn unhandled_status_passes_through() {
        // `ErrorHandlers::handler` requires a `Result`-returning future, even if it never fails.
        #[allow(clippy::unnecessary_wraps)]
        async fn error_handler<B>(
            mut res: ServiceResponse<B>,
        ) -> Result<ServiceResponse<EitherBody<B>>> {
            res.response_mut()
                .headers_mut()
                .insert(CONTENT_TYPE, HeaderValue::from_static("0001"));

            Ok(res.map_into_left_body())
        }

        // The handler is registered for 500 but the service answers 200, so it must not run.
        let srv = test::status_service(StatusCode::OK);

        let mw = ErrorHandlers::new()
            .handler(StatusCode::INTERNAL_SERVER_ERROR, error_handler)
            .new_transform(srv.into_service())
            .await
            .unwrap();

        let resp = test::call_service(&mw, TestRequest::default().to_srv_request()).await;
        assert_eq!(resp.status(), StatusCode::OK);
        assert_ne!(
            resp.headers().get(CONTENT_TYPE),
            Some(&HeaderValue::from_static("0001"))
        );
    }
}