actix_web_extras/middleware/
condition.rs

//! For middleware documentation, see [`Condition`].
//!
//! Pending upstream [actix-web PR 2623](https://github.com/actix/actix-web/pull/2623).
4
5use actix_web::{
6    body::EitherBody,
7    dev::{Service, ServiceResponse, Transform},
8};
9use futures_core::{future::LocalBoxFuture, ready};
10use futures_util::future::FutureExt as _;
11use pin_project_lite::pin_project;
12use std::{
13    future::Future,
14    pin::Pin,
15    task::{Context, Poll},
16};
17
/// Middleware for conditionally enabling other middleware.
///
/// When enabled, requests are routed through the wrapped middleware `T`. When disabled, they go
/// straight to the inner service. The two paths can produce different body types, so the
/// response body is an [`EitherBody`]: the left side comes from `T` and the right side comes
/// from the inner service.
///
/// Whether the middleware is enabled is fixed when the `Condition` is constructed (see
/// [`Condition::new`] and [`Condition::from_option`]). It is not re-evaluated per request.
///
/// # Examples
/// ```
/// use actix_web_extras::middleware::Condition;
/// use actix_web::middleware::NormalizePath;
/// use actix_web::App;
///
/// let enable_normalize = std::env::var("NORMALIZE_PATH").is_ok();
/// let app = App::new()
///     .wrap(Condition::new(enable_normalize, NormalizePath::trim()));
/// ```
///
/// Or you can use an `Option` to create a new instance:
/// ```
/// use actix_web_extras::middleware::Condition;
/// use actix_web::middleware::NormalizePath;
/// use actix_web::App;
///
/// let app = App::new()
///     .wrap(Condition::from_option(Some(NormalizePath::trim())));
/// ```
#[derive(Debug, Clone)]
pub struct Condition<T> {
    /// The middleware to apply; `None` means the condition is disabled.
    transformer: Option<T>,
}
42
43impl<T> Condition<T> {
44    pub fn new(enable: bool, transformer: T) -> Self {
45        Self {
46            transformer: match enable {
47                true => Some(transformer),
48                false => None,
49            },
50        }
51    }
52
53    pub fn from_option(transformer: Option<T>) -> Self {
54        Self { transformer }
55    }
56}
57
58impl<S, T, Req, BE, BD, Err> Transform<S, Req> for Condition<T>
59where
60    S: Service<Req, Response = ServiceResponse<BD>, Error = Err> + 'static,
61    T: Transform<S, Req, Response = ServiceResponse<BE>, Error = Err>,
62    T::Future: 'static,
63    T::InitError: 'static,
64    T::Transform: 'static,
65{
66    type Response = ServiceResponse<EitherBody<BE, BD>>;
67    type Error = Err;
68    type Transform = ConditionMiddleware<T::Transform, S>;
69    type InitError = T::InitError;
70    type Future = LocalBoxFuture<'static, Result<Self::Transform, Self::InitError>>;
71
72    fn new_transform(&self, service: S) -> Self::Future {
73        if let Some(transformer) = &self.transformer {
74            let fut = transformer.new_transform(service);
75            async move {
76                let wrapped_svc = fut.await?;
77                Ok(ConditionMiddleware::Enable(wrapped_svc))
78            }
79            .boxed_local()
80        } else {
81            async move { Ok(ConditionMiddleware::Disable(service)) }.boxed_local()
82        }
83    }
84}
85
/// Service produced by [`Condition`]'s `Transform` implementation.
///
/// `Enable` holds the service built by the wrapped middleware. `Disable` holds the original
/// inner service, which then handles requests directly.
#[derive(Debug)]
pub enum ConditionMiddleware<E, D> {
    /// The wrapped middleware is active; requests go through this service.
    Enable(E),
    /// The wrapped middleware is inactive; requests go straight to the inner service.
    Disable(D),
}
90
91impl<E, D, Req, BE, BD, Err> Service<Req> for ConditionMiddleware<E, D>
92where
93    E: Service<Req, Response = ServiceResponse<BE>, Error = Err>,
94    D: Service<Req, Response = ServiceResponse<BD>, Error = Err>,
95{
96    type Response = ServiceResponse<EitherBody<BE, BD>>;
97    type Error = Err;
98    type Future = ConditionMiddlewareFuture<E::Future, D::Future>;
99
100    fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
101        match self {
102            ConditionMiddleware::Enable(service) => service.poll_ready(cx),
103            ConditionMiddleware::Disable(service) => service.poll_ready(cx),
104        }
105    }
106
107    fn call(&self, req: Req) -> Self::Future {
108        match self {
109            ConditionMiddleware::Enable(service) => ConditionMiddlewareFuture::Enabled {
110                fut: service.call(req),
111            },
112            ConditionMiddleware::Disable(service) => ConditionMiddlewareFuture::Disabled {
113                fut: service.call(req),
114            },
115        }
116    }
117}
118
// Future returned by `ConditionMiddleware::call`.
//
// The variant records which service produced the inner future. `poll` uses it to decide which
// side of `EitherBody` the response body belongs to. `ConditionProj` is the pinned projection
// type generated by `pin_project!`, and it is what `poll` matches on.
pin_project! {
    #[doc(hidden)]
    #[project = ConditionProj]
    pub enum ConditionMiddlewareFuture<E, D> {
        // Future from the service built by the wrapped (enabled) middleware.
        Enabled { #[pin] fut: E, },
        // Future from the inner service when the middleware is disabled.
        Disabled { #[pin] fut: D, },
    }
}
127
128impl<E, D, BE, BD, Err> Future for ConditionMiddlewareFuture<E, D>
129where
130    E: Future<Output = Result<ServiceResponse<BE>, Err>>,
131    D: Future<Output = Result<ServiceResponse<BD>, Err>>,
132{
133    type Output = Result<ServiceResponse<EitherBody<BE, BD>>, Err>;
134
135    #[inline]
136    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
137        let res = match self.project() {
138            ConditionProj::Enabled { fut } => ready!(fut.poll(cx))?.map_into_left_body(),
139            ConditionProj::Disabled { fut } => ready!(fut.poll(cx))?.map_into_right_body(),
140        };
141
142        Poll::Ready(Ok(res))
143    }
144}
145
// Tests for `Condition`, using `ErrorHandlers` as the wrapped middleware. `render_500` tags
// 500 responses with a marker `Content-Type` value, so each test can tell whether the
// wrapped middleware ran.
#[cfg(test)]
mod tests {
    use actix_service::IntoService as _;
    use futures_util::future::ok;

    use super::*;
    use actix_web::{
        body::BoxBody,
        dev::{ServiceRequest, ServiceResponse},
        error::Result,
        http::{
            header::{HeaderValue, CONTENT_TYPE},
            StatusCode,
        },
        middleware::{self, ErrorHandlerResponse, ErrorHandlers},
        test::{self, TestRequest},
        web::Bytes,
        HttpResponse,
    };

    /// Error handler that sets `Content-Type: 0001` as a marker and returns the response.
    ///
    /// The response body is mapped into the left side of an `EitherBody`.
    // The function is passed to `ErrorHandlers::handler`, so it keeps that `Result`-returning
    // handler signature even though it never fails; hence the clippy allowance.
    #[allow(clippy::unnecessary_wraps)]
    fn render_500<B>(mut res: ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> {
        res.response_mut()
            .headers_mut()
            .insert(CONTENT_TYPE, HeaderValue::from_static("0001"));

        Ok(ErrorHandlerResponse::Response(res.map_into_left_body()))
    }

    // Checks that `Condition::new` accepts actix-web's built-in middleware types.
    // NOTE(review): `Condition::new` has no bounds on `T`, so this only proves construction.
    // It does not prove that `Condition<T>` implements `Transform` for these middleware types.
    #[test]
    fn compat_with_builtin_middleware() {
        let _ = Condition::new(true, middleware::Logger::default());
        let _ = Condition::new(true, middleware::Compress::default());
        let _ = Condition::new(true, middleware::NormalizePath::trim());
        let _ = Condition::new(true, middleware::DefaultHeaders::new());
        let _ = Condition::new(true, middleware::ErrorHandlers::<BoxBody>::new());
        let _ = Condition::new(true, middleware::ErrorHandlers::<Bytes>::new());
    }

    // Returns an `ErrorHandlers` with `render_500` registered for 500 responses when `enabled`
    // is true, and `None` otherwise. Used to exercise `Condition::from_option`.
    fn create_optional_mw<B>(enabled: bool) -> Option<ErrorHandlers<B>>
    where
        B: 'static,
    {
        match enabled {
            true => {
                Some(ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500))
            }
            false => None,
        }
    }

    // Enabled via `Condition::new(true, ..)`: the 500 response must pass through `render_500`,
    // so the marker header is present.
    #[actix_rt::test]
    async fn test_handler_enabled() {
        let srv = |req: ServiceRequest| async move {
            let resp = HttpResponse::InternalServerError().message_body(String::new())?;
            Ok(req.into_response(resp))
        };

        let mw = ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500);

        let mw = Condition::new(true, mw)
            .new_transform(srv.into_service())
            .await
            .unwrap();

        // Body type spelled out: left = the `ErrorHandlers` output body, right = the raw
        // `String` body from `srv`.
        let resp: ServiceResponse<EitherBody<EitherBody<_, _>, String>> =
            test::call_service(&mw, TestRequest::default().to_srv_request()).await;
        assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001");
    }

    // Disabled via `Condition::new(false, ..)`: the request goes straight to `srv`, so the
    // marker header must be absent.
    #[actix_rt::test]
    async fn test_handler_disabled() {
        let srv = |req: ServiceRequest| async move {
            let resp = HttpResponse::InternalServerError().message_body(String::new())?;
            Ok(req.into_response(resp))
        };

        let mw = ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500);

        let mw = Condition::new(false, mw)
            .new_transform(srv.into_service())
            .await
            .unwrap();

        let resp: ServiceResponse<EitherBody<EitherBody<_, _>, String>> =
            test::call_service(&mw, TestRequest::default().to_srv_request()).await;
        assert_eq!(resp.headers().get(CONTENT_TYPE), None);
    }

    // Enabled via `Condition::from_option(Some(..))`: the marker header must be present.
    #[actix_rt::test]
    async fn test_handler_some() {
        let srv = |req: ServiceRequest| {
            ok(req.into_response(HttpResponse::InternalServerError().finish()))
        };

        let mw = Condition::from_option(create_optional_mw(true))
            .new_transform(srv.into_service())
            .await
            .unwrap();
        let resp = test::call_service(&mw, TestRequest::default().to_srv_request()).await;
        assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001");
    }

    // Disabled via `Condition::from_option(None)`: the marker header must be absent.
    #[actix_rt::test]
    async fn test_handler_none() {
        let srv = |req: ServiceRequest| {
            ok(req.into_response(HttpResponse::InternalServerError().finish()))
        };

        let mw = Condition::from_option(create_optional_mw(false))
            .new_transform(srv.into_service())
            .await
            .unwrap();

        let resp = test::call_service(&mw, TestRequest::default().to_srv_request()).await;
        assert_eq!(resp.headers().get(CONTENT_TYPE), None);
    }
}