actix_web_lab/
middleware_map_response_body.rs

1use std::{
2    future::{Future, Ready, ready},
3    marker::PhantomData,
4    pin::Pin,
5    rc::Rc,
6    task::{Context, Poll},
7};
8
9use actix_service::{Service, Transform, forward_ready};
10use actix_web::{
11    Error, HttpRequest, HttpResponse,
12    body::MessageBody,
13    dev::{ServiceRequest, ServiceResponse},
14};
15use futures_core::ready;
16use pin_project_lite::pin_project;
17
18/// Creates a middleware from an async function that is used as a mapping function for an
19/// [`impl MessageBody`][MessageBody].
20///
21/// # Examples
22/// Completely replaces the body:
23/// ```
24/// # use actix_web_lab::middleware::map_response_body;
25/// use actix_web::{HttpRequest, body::MessageBody};
26///
27/// async fn replace_body(
28///     _req: HttpRequest,
29///     _: impl MessageBody,
30/// ) -> actix_web::Result<impl MessageBody> {
31///     Ok("foo".to_owned())
32/// }
33/// # actix_web::App::new().wrap(map_response_body(replace_body));
34/// ```
35///
36/// Appends some bytes to the body:
37/// ```
38/// # use actix_web_lab::middleware::map_response_body;
39/// use actix_web::{
40///     HttpRequest,
41///     body::{self, MessageBody},
42///     web::{BufMut as _, BytesMut},
43/// };
44///
45/// async fn append_bytes(
46///     _req: HttpRequest,
47///     body: impl MessageBody,
48/// ) -> actix_web::Result<impl MessageBody> {
49///     let buf = body::to_bytes(body).await.ok().unwrap();
50///
51///     let mut body = BytesMut::from(&buf[..]);
52///     body.put_slice(b" - hope you like things ruining your payload format");
53///
54///     Ok(body)
55/// }
56/// # actix_web::App::new().wrap(map_response_body(append_bytes));
57/// ```
58pub fn map_response_body<F>(mapper_fn: F) -> MapResBodyMiddleware<F> {
59    MapResBodyMiddleware {
60        mw_fn: Rc::new(mapper_fn),
61    }
62}
63
/// Middleware transform for [`map_response_body`].
///
/// Holds the user-supplied mapper behind an [`Rc`] so that every service produced by the
/// transform can share it without requiring `F: Clone`.
// NOTE(review): `Debug` is presumably skipped because `F` is usually a closure or fn item with
// no useful `Debug` impl; the struct places no `Debug` bound on `F`.
#[allow(missing_debug_implementations)]
pub struct MapResBodyMiddleware<F> {
    // The mapper; `Rc::clone`d into each `MapResBodyService` the transform creates.
    mw_fn: Rc<F>,
}
69
70impl<S, F, Fut, B, B2> Transform<S, ServiceRequest> for MapResBodyMiddleware<F>
71where
72    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
73    F: Fn(HttpRequest, B) -> Fut,
74    Fut: Future<Output = Result<B2, Error>>,
75    B2: MessageBody,
76{
77    type Response = ServiceResponse<B2>;
78    type Error = Error;
79    type Transform = MapResBodyService<S, F, B>;
80    type InitError = ();
81    type Future = Ready<Result<Self::Transform, Self::InitError>>;
82
83    fn new_transform(&self, service: S) -> Self::Future {
84        ready(Ok(MapResBodyService {
85            service,
86            mw_fn: Rc::clone(&self.mw_fn),
87            _phantom: PhantomData,
88        }))
89    }
90}
91
/// Middleware service for [`map_response_body`].
///
/// Forwards each request to the wrapped service, then runs the mapper over the body of the
/// response that service produces.
// `Debug` is not implemented: the struct requires neither `S` nor `F` to implement it.
#[allow(missing_debug_implementations)]
pub struct MapResBodyService<S, F, B> {
    // The next service in the chain.
    service: S,
    // Mapper shared with the `MapResBodyMiddleware` that created this service.
    mw_fn: Rc<F>,
    // Records the inner response body type `B`; no value of `B` is stored here.
    _phantom: PhantomData<(B,)>,
}
99
100impl<S, F, Fut, B, B2> Service<ServiceRequest> for MapResBodyService<S, F, B>
101where
102    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
103    F: Fn(HttpRequest, B) -> Fut,
104    Fut: Future<Output = Result<B2, Error>>,
105    B2: MessageBody,
106{
107    type Response = ServiceResponse<B2>;
108    type Error = Error;
109    type Future = MapResBodyFut<S::Future, F, Fut>;
110
111    forward_ready!(service);
112
113    fn call(&self, req: ServiceRequest) -> Self::Future {
114        let mw_fn = Rc::clone(&self.mw_fn);
115        let fut = self.service.call(req);
116
117        MapResBodyFut {
118            mw_fn,
119            state: MapResBodyFutState::Svc { fut },
120        }
121    }
122}
123
pin_project! {
    /// Future returned by [`MapResBodyService`]: it first awaits the inner service's response,
    /// then awaits the mapper over that response's body.
    pub struct MapResBodyFut<SvcFut, F, FnFut> {
        // Mapper invoked once the inner service's response is available.
        mw_fn: Rc<F>,
        // Current stage of the two-stage state machine (see `MapResBodyFutState`).
        #[pin]
        state: MapResBodyFutState<SvcFut, FnFut>,
    }
}
131
pin_project! {
    // Stages of `MapResBodyFut`, advanced by its `Future::poll` implementation.
    #[project = MapResBodyFutStateProj]
    enum MapResBodyFutState<SvcFut, FnFut> {
        // Stage 1: waiting on the wrapped service's response.
        Svc { #[pin] fut: SvcFut },

        // Stage 2: waiting on the mapper's future. The request and the body-less response head
        // are parked here until the new body is ready, then reassembled into a response.
        Fn {
            #[pin]
            fut: FnFut,

            // Wrapped in `Option` so they can be moved out with `take()` through the projection.
            req: Option<HttpRequest>,
            res: Option<HttpResponse<()>>
        },
    }
}
146
147impl<SvcFut, B, F, FnFut, B2> Future for MapResBodyFut<SvcFut, F, FnFut>
148where
149    SvcFut: Future<Output = Result<ServiceResponse<B>, Error>>,
150    F: Fn(HttpRequest, B) -> FnFut,
151    FnFut: Future<Output = Result<B2, Error>>,
152{
153    type Output = Result<ServiceResponse<B2>, Error>;
154
155    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
156        let mut this = self.as_mut().project();
157
158        match this.state.as_mut().project() {
159            MapResBodyFutStateProj::Svc { fut } => {
160                let res = ready!(fut.poll(cx))?;
161
162                let (req, res) = res.into_parts();
163                let (res, body) = res.into_parts();
164
165                let fut = (this.mw_fn)(req.clone(), body);
166                this.state.set(MapResBodyFutState::Fn {
167                    fut,
168                    req: Some(req),
169                    res: Some(res),
170                });
171
172                self.poll(cx)
173            }
174
175            MapResBodyFutStateProj::Fn { fut, req, res } => {
176                let body = ready!(fut.poll(cx))?;
177
178                let req = req.take().unwrap();
179                let res = res.take().unwrap();
180
181                let res = res.set_body(body);
182                let res = ServiceResponse::new(req, res);
183
184                Poll::Ready(Ok(res))
185            }
186        }
187    }
188}
189
#[cfg(test)]
mod tests {
    use actix_web::{
        App, HttpResponse,
        middleware::{Compat, Logger},
        test, web,
    };

    use super::*;

    /// Mapper that returns the body untouched.
    async fn noop(_req: HttpRequest, body: impl MessageBody) -> Result<impl MessageBody, Error> {
        Ok(body)
    }

    /// Mapper that discards the body and replaces it with one of a different type.
    async fn mutate_body_type(
        _req: HttpRequest,
        _body: impl MessageBody + 'static,
    ) -> Result<impl MessageBody, Error> {
        Ok("foo".to_owned())
    }

    /// Mapper that reads the whole incoming body and appends a suffix, so tests can observe
    /// that the inner service's payload actually reaches the mapper.
    async fn append_suffix(
        _req: HttpRequest,
        body: impl MessageBody,
    ) -> Result<impl MessageBody, Error> {
        let bytes = actix_web::body::to_bytes(body)
            .await
            .map_err(|_| actix_web::error::ErrorInternalServerError("failed to read body"))?;

        let mut buf = bytes.to_vec();
        buf.extend_from_slice(b" world");
        Ok(buf)
    }

    /// The middleware can be wrapped in `Compat`, both when the body type is preserved and
    /// when it changes.
    #[actix_web::test]
    async fn compat_compat() {
        let _ = App::new().wrap(Compat::new(map_response_body(noop)));
        let _ = App::new().wrap(Compat::new(map_response_body(mutate_body_type)));
    }

    /// Several mappers stack with other middleware; the outermost replacement wins.
    #[actix_web::test]
    async fn feels_good() {
        let app = test::init_service(
            App::new()
                .default_service(web::to(HttpResponse::Ok))
                .wrap(map_response_body(|_req, body| async move { Ok(body) }))
                .wrap(map_response_body(noop))
                .wrap(Logger::default())
                .wrap(map_response_body(mutate_body_type)),
        )
        .await;

        let req = test::TestRequest::default().to_request();
        let body = test::call_and_read_body(&app, req).await;
        assert_eq!(body, "foo");
    }

    /// The mapper receives the inner service's body and its output becomes the response body.
    #[actix_web::test]
    async fn mapper_receives_inner_body() {
        let app = test::init_service(
            App::new()
                .default_service(web::to(|| async { "hello" }))
                .wrap(map_response_body(append_suffix)),
        )
        .await;

        let req = test::TestRequest::default().to_request();
        let body = test::call_and_read_body(&app, req).await;
        assert_eq!(body, "hello world");
    }
}