actix_web_lab/
middleware_map_response.rs

1use std::{
2    future::{Future, Ready, ready},
3    marker::PhantomData,
4    pin::Pin,
5    rc::Rc,
6    task::{Context, Poll},
7};
8
9use actix_service::{Service, Transform, forward_ready};
10use actix_web::{
11    Error,
12    body::MessageBody,
13    dev::{ServiceRequest, ServiceResponse},
14};
15use futures_core::ready;
16use pin_project_lite::pin_project;
17
18/// Creates a middleware from an async function that is used as a mapping function for a
19/// [`ServiceResponse`].
20///
21/// # Examples
22/// Adds header:
23/// ```
24/// # use actix_web_lab::middleware::map_response;
25/// use actix_web::{body::MessageBody, dev::ServiceResponse, http::header};
26///
27/// async fn add_header(
28///     mut res: ServiceResponse<impl MessageBody>,
29/// ) -> actix_web::Result<ServiceResponse<impl MessageBody>> {
30///     res.headers_mut()
31///         .insert(header::WARNING, header::HeaderValue::from_static("42"));
32///
33///     Ok(res)
34/// }
35/// # actix_web::App::new().wrap(map_response(add_header));
36/// ```
37///
38/// Maps body:
39/// ```
40/// # use actix_web_lab::middleware::map_response;
41/// use actix_web::{body::MessageBody, dev::ServiceResponse};
42///
43/// async fn mutate_body_type(
44///     res: ServiceResponse<impl MessageBody + 'static>,
45/// ) -> actix_web::Result<ServiceResponse<impl MessageBody>> {
46///     Ok(res.map_into_left_body::<()>())
47/// }
48/// # actix_web::App::new().wrap(map_response(mutate_body_type));
49/// ```
50pub fn map_response<F>(mapper_fn: F) -> MapResMiddleware<F> {
51    MapResMiddleware {
52        mw_fn: Rc::new(mapper_fn),
53    }
54}
55
/// Transform returned by [`map_response`].
///
/// It owns the mapping function and gives each service it creates a shared handle to it.
// `Debug` is not derived: `F` is an unconstrained function type, so no `Debug` bound is imposed.
#[allow(missing_debug_implementations)]
pub struct MapResMiddleware<F> {
    // Shared handle to the user-supplied response mapper.
    mw_fn: Rc<F>,
}
61
62impl<S, F, Fut, B, B2> Transform<S, ServiceRequest> for MapResMiddleware<F>
63where
64    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
65    F: Fn(ServiceResponse<B>) -> Fut,
66    Fut: Future<Output = Result<ServiceResponse<B2>, Error>>,
67    B2: MessageBody,
68{
69    type Response = ServiceResponse<B2>;
70    type Error = Error;
71    type Transform = MapResService<S, F, B>;
72    type InitError = ();
73    type Future = Ready<Result<Self::Transform, Self::InitError>>;
74
75    fn new_transform(&self, service: S) -> Self::Future {
76        ready(Ok(MapResService {
77            service,
78            mw_fn: Rc::clone(&self.mw_fn),
79            _phantom: PhantomData,
80        }))
81    }
82}
83
/// Middleware service for [`map_response`].
///
/// Holds the wrapped service together with a shared handle to the mapping function that is
/// applied to each successful response.
// `Debug` is not derived: neither `S` nor `F` is required to implement it.
#[allow(missing_debug_implementations)]
pub struct MapResService<S, F, B> {
    // The inner service whose responses are mapped.
    service: S,
    // Shared handle to the user-supplied response mapper.
    mw_fn: Rc<F>,
    // Names the inner service's body type `B` in the struct's parameters without storing one.
    _phantom: PhantomData<(B,)>,
}
91
92impl<S, F, Fut, B, B2> Service<ServiceRequest> for MapResService<S, F, B>
93where
94    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
95    F: Fn(ServiceResponse<B>) -> Fut,
96    Fut: Future<Output = Result<ServiceResponse<B2>, Error>>,
97    B2: MessageBody,
98{
99    type Response = ServiceResponse<B2>;
100    type Error = Error;
101    type Future = MapResFut<S::Future, F, Fut>;
102
103    forward_ready!(service);
104
105    fn call(&self, req: ServiceRequest) -> Self::Future {
106        let mw_fn = Rc::clone(&self.mw_fn);
107        let fut = self.service.call(req);
108
109        MapResFut {
110            mw_fn,
111            state: MapResFutState::Svc { fut },
112        }
113    }
114}
115
116pin_project! {
117    pub struct MapResFut<SvcFut, F, FnFut> {
118        mw_fn: Rc<F>,
119        #[pin]
120        state: MapResFutState<SvcFut, FnFut>,
121    }
122}
123
124pin_project! {
125    #[project = MapResFutStateProj]
126    enum MapResFutState<SvcFut, FnFut> {
127        Svc { #[pin] fut: SvcFut },
128        Fn { #[pin] fut: FnFut },
129    }
130}
131
132impl<SvcFut, B, F, FnFut, B2> Future for MapResFut<SvcFut, F, FnFut>
133where
134    SvcFut: Future<Output = Result<ServiceResponse<B>, Error>>,
135    F: Fn(ServiceResponse<B>) -> FnFut,
136    FnFut: Future<Output = Result<ServiceResponse<B2>, Error>>,
137{
138    type Output = Result<ServiceResponse<B2>, Error>;
139
140    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
141        let mut this = self.as_mut().project();
142
143        match this.state.as_mut().project() {
144            MapResFutStateProj::Svc { fut } => {
145                let res = ready!(fut.poll(cx))?;
146
147                let fut = (this.mw_fn)(res);
148                this.state.set(MapResFutState::Fn { fut });
149                self.poll(cx)
150            }
151
152            MapResFutStateProj::Fn { fut } => fut.poll(cx),
153        }
154    }
155}
156
#[cfg(test)]
mod tests {
    use actix_web::{
        App, HttpResponse,
        http::header::{self, HeaderValue},
        middleware::{Compat, Logger},
        test, web,
    };

    use super::*;

    /// Mapper that returns the response unchanged.
    async fn noop(
        response: ServiceResponse<impl MessageBody>,
    ) -> Result<ServiceResponse<impl MessageBody>, Error> {
        Ok(response)
    }

    /// Mapper that sets a `Warning` header with the value `42`.
    async fn add_header(
        mut response: ServiceResponse<impl MessageBody>,
    ) -> Result<ServiceResponse<impl MessageBody>, Error> {
        let warning = HeaderValue::from_static("42");
        response.headers_mut().insert(header::WARNING, warning);

        Ok(response)
    }

    /// Mapper that changes the body type by mapping it into the left side of an either body.
    async fn mutate_body_type(
        response: ServiceResponse<impl MessageBody + 'static>,
    ) -> Result<ServiceResponse<impl MessageBody>, Error> {
        let mapped = response.map_into_left_body::<()>();

        Ok(mapped)
    }

    /// Compile-time check that the middleware can be wrapped in `Compat`.
    #[actix_web::test]
    async fn compat_compat() {
        let _ = App::new().wrap(Compat::new(map_response(noop)));
        let _ = App::new().wrap(Compat::new(map_response(mutate_body_type)));
    }

    /// Stacks several mappers (closure, fn items, body-changing) and checks the header one ran.
    #[actix_web::test]
    async fn feels_good() {
        let app = App::new()
            .default_service(web::to(HttpResponse::Ok))
            .wrap(map_response(|res| async move { Ok(res) }))
            .wrap(map_response(noop))
            .wrap(map_response(add_header))
            .wrap(Logger::default())
            .wrap(map_response(mutate_body_type));
        let service = test::init_service(app).await;

        let request = test::TestRequest::default().to_request();
        let response = test::call_service(&service, request).await;

        assert!(response.headers().contains_key(header::WARNING));
    }
}