actix_web_extras/middleware/
condition.rs1use actix_web::{
6 body::EitherBody,
7 dev::{Service, ServiceResponse, Transform},
8};
9use futures_core::{future::LocalBoxFuture, ready};
10use futures_util::future::FutureExt as _;
11use pin_project_lite::pin_project;
12use std::{
13 future::Future,
14 pin::Pin,
15 task::{Context, Poll},
16};
17
/// Middleware wrapper that applies an inner middleware only when it is enabled.
///
/// If a transformer is present, the wrapped service is built by that transformer; otherwise
/// the wrapped service is used as-is. Either way the resulting response body is an
/// [`EitherBody`], so the enabled and disabled cases share one concrete service type.
pub struct Condition<T> {
    /// The middleware to apply; `None` means the condition is disabled.
    transformer: Option<T>,
}

impl<T> Condition<T> {
    /// Creates a `Condition` that applies `transformer` only when `enable` is `true`.
    ///
    /// When `enable` is `false` the transformer is dropped and requests pass straight
    /// through to the wrapped service.
    pub fn new(enable: bool, transformer: T) -> Self {
        Self::from_option(enable.then_some(transformer))
    }

    /// Creates a `Condition` from an optional middleware.
    ///
    /// `Some(transformer)` enables the middleware; `None` disables it.
    pub fn from_option(transformer: Option<T>) -> Self {
        Self { transformer }
    }
}
57
58impl<S, T, Req, BE, BD, Err> Transform<S, Req> for Condition<T>
59where
60 S: Service<Req, Response = ServiceResponse<BD>, Error = Err> + 'static,
61 T: Transform<S, Req, Response = ServiceResponse<BE>, Error = Err>,
62 T::Future: 'static,
63 T::InitError: 'static,
64 T::Transform: 'static,
65{
66 type Response = ServiceResponse<EitherBody<BE, BD>>;
67 type Error = Err;
68 type Transform = ConditionMiddleware<T::Transform, S>;
69 type InitError = T::InitError;
70 type Future = LocalBoxFuture<'static, Result<Self::Transform, Self::InitError>>;
71
72 fn new_transform(&self, service: S) -> Self::Future {
73 if let Some(transformer) = &self.transformer {
74 let fut = transformer.new_transform(service);
75 async move {
76 let wrapped_svc = fut.await?;
77 Ok(ConditionMiddleware::Enable(wrapped_svc))
78 }
79 .boxed_local()
80 } else {
81 async move { Ok(ConditionMiddleware::Disable(service)) }.boxed_local()
82 }
83 }
84}
85
/// Service produced by [`Condition`]: either the service built by the inner middleware
/// (`Enable`) or the original, unwrapped service (`Disable`).
pub enum ConditionMiddleware<E, D> {
    /// The condition was enabled; requests go through the wrapping middleware.
    Enable(E),
    /// The condition was disabled; requests go directly to the original service.
    Disable(D),
}
90
91impl<E, D, Req, BE, BD, Err> Service<Req> for ConditionMiddleware<E, D>
92where
93 E: Service<Req, Response = ServiceResponse<BE>, Error = Err>,
94 D: Service<Req, Response = ServiceResponse<BD>, Error = Err>,
95{
96 type Response = ServiceResponse<EitherBody<BE, BD>>;
97 type Error = Err;
98 type Future = ConditionMiddlewareFuture<E::Future, D::Future>;
99
100 fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
101 match self {
102 ConditionMiddleware::Enable(service) => service.poll_ready(cx),
103 ConditionMiddleware::Disable(service) => service.poll_ready(cx),
104 }
105 }
106
107 fn call(&self, req: Req) -> Self::Future {
108 match self {
109 ConditionMiddleware::Enable(service) => ConditionMiddlewareFuture::Enabled {
110 fut: service.call(req),
111 },
112 ConditionMiddleware::Disable(service) => ConditionMiddlewareFuture::Disabled {
113 fut: service.call(req),
114 },
115 }
116 }
117}
118
119pin_project! {
120 #[doc(hidden)]
121 #[project = ConditionProj]
122 pub enum ConditionMiddlewareFuture<E, D> {
123 Enabled { #[pin] fut: E, },
124 Disabled { #[pin] fut: D, },
125 }
126}
127
128impl<E, D, BE, BD, Err> Future for ConditionMiddlewareFuture<E, D>
129where
130 E: Future<Output = Result<ServiceResponse<BE>, Err>>,
131 D: Future<Output = Result<ServiceResponse<BD>, Err>>,
132{
133 type Output = Result<ServiceResponse<EitherBody<BE, BD>>, Err>;
134
135 #[inline]
136 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
137 let res = match self.project() {
138 ConditionProj::Enabled { fut } => ready!(fut.poll(cx))?.map_into_left_body(),
139 ConditionProj::Disabled { fut } => ready!(fut.poll(cx))?.map_into_right_body(),
140 };
141
142 Poll::Ready(Ok(res))
143 }
144}
145
#[cfg(test)]
mod tests {
    use actix_service::IntoService as _;
    use futures_util::future::ok;

    use super::*;
    use actix_web::{
        body::BoxBody,
        dev::{ServiceRequest, ServiceResponse},
        error::Result,
        http::{
            header::{HeaderValue, CONTENT_TYPE},
            StatusCode,
        },
        middleware::{self, ErrorHandlerResponse, ErrorHandlers},
        test::{self, TestRequest},
        web::Bytes,
        HttpResponse,
    };

    /// Error handler used by these tests: stamps a recognizable `Content-Type` value so a
    /// test can tell whether the handler ran.
    // The `Result` wrapper is required by the `ErrorHandlers::handler` signature even though
    // this handler never fails.
    #[allow(clippy::unnecessary_wraps)]
    fn render_500<B>(mut res: ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> {
        let headers = res.response_mut().headers_mut();
        headers.insert(CONTENT_TYPE, HeaderValue::from_static("0001"));

        Ok(ErrorHandlerResponse::Response(res.map_into_left_body()))
    }

    #[test]
    fn compat_with_builtin_middleware() {
        // NOTE(review): `Condition::new` places no bounds on `T`, so this only exercises
        // construction; it does not check that each `Condition<_>` implements `Transform`.
        let _ = Condition::new(true, middleware::Logger::default());
        let _ = Condition::new(true, middleware::Compress::default());
        let _ = Condition::new(true, middleware::NormalizePath::trim());
        let _ = Condition::new(true, middleware::DefaultHeaders::new());
        let _ = Condition::new(true, middleware::ErrorHandlers::<BoxBody>::new());
        let _ = Condition::new(true, middleware::ErrorHandlers::<Bytes>::new());
    }

    /// Returns the `render_500` error-handler middleware when `enabled`, otherwise `None`.
    fn create_optional_mw<B>(enabled: bool) -> Option<ErrorHandlers<B>>
    where
        B: 'static,
    {
        enabled.then(|| ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500))
    }

    #[actix_rt::test]
    async fn test_handler_enabled() {
        let srv = |req: ServiceRequest| async move {
            let resp = HttpResponse::InternalServerError().message_body(String::new())?;
            Ok(req.into_response(resp))
        };

        let handlers = ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500);
        let svc = Condition::new(true, handlers)
            .new_transform(srv.into_service())
            .await
            .unwrap();

        let req = TestRequest::default().to_srv_request();
        let resp: ServiceResponse<EitherBody<EitherBody<_, _>, String>> =
            test::call_service(&svc, req).await;
        assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001");
    }

    #[actix_rt::test]
    async fn test_handler_disabled() {
        let srv = |req: ServiceRequest| async move {
            let resp = HttpResponse::InternalServerError().message_body(String::new())?;
            Ok(req.into_response(resp))
        };

        let handlers = ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500);
        let svc = Condition::new(false, handlers)
            .new_transform(srv.into_service())
            .await
            .unwrap();

        let req = TestRequest::default().to_srv_request();
        let resp: ServiceResponse<EitherBody<EitherBody<_, _>, String>> =
            test::call_service(&svc, req).await;
        assert_eq!(resp.headers().get(CONTENT_TYPE), None);
    }

    #[actix_rt::test]
    async fn test_handler_some() {
        let srv = |req: ServiceRequest| {
            ok(req.into_response(HttpResponse::InternalServerError().finish()))
        };

        let svc = Condition::from_option(create_optional_mw(true))
            .new_transform(srv.into_service())
            .await
            .unwrap();

        let req = TestRequest::default().to_srv_request();
        let resp = test::call_service(&svc, req).await;
        assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001");
    }

    #[actix_rt::test]
    async fn test_handler_none() {
        let srv = |req: ServiceRequest| {
            ok(req.into_response(HttpResponse::InternalServerError().finish()))
        };

        let svc = Condition::from_option(create_optional_mw(false))
            .new_transform(srv.into_service())
            .await
            .unwrap();

        let req = TestRequest::default().to_srv_request();
        let resp = test::call_service(&svc, req).await;
        assert_eq!(resp.headers().get(CONTENT_TYPE), None);
    }
}