1use std::{
2 future::{Future, Ready, ready},
3 marker::PhantomData,
4 pin::Pin,
5 rc::Rc,
6 task::{Context, Poll},
7};
8
9use actix_service::{Service, Transform, forward_ready};
10use actix_web::{
11 Error, HttpRequest, HttpResponse,
12 body::MessageBody,
13 dev::{ServiceRequest, ServiceResponse},
14};
15use futures_core::ready;
16use pin_project_lite::pin_project;
17
18pub fn map_response_body<F>(mapper_fn: F) -> MapResBodyMiddleware<F> {
59 MapResBodyMiddleware {
60 mw_fn: Rc::new(mapper_fn),
61 }
62}
63
/// Middleware type returned by [`map_response_body`].
///
/// Holds the body-mapping function `F` behind an [`Rc`] so that `new_transform` can hand a
/// cheap clone of the handle to each service it builds.
// NOTE(review): `Debug` is presumably skipped because `F` is usually a closure — confirm.
#[allow(missing_debug_implementations)]
pub struct MapResBodyMiddleware<F> {
    // Shared handle to the user-supplied mapping function.
    mw_fn: Rc<F>,
}
69
70impl<S, F, Fut, B, B2> Transform<S, ServiceRequest> for MapResBodyMiddleware<F>
71where
72 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
73 F: Fn(HttpRequest, B) -> Fut,
74 Fut: Future<Output = Result<B2, Error>>,
75 B2: MessageBody,
76{
77 type Response = ServiceResponse<B2>;
78 type Error = Error;
79 type Transform = MapResBodyService<S, F, B>;
80 type InitError = ();
81 type Future = Ready<Result<Self::Transform, Self::InitError>>;
82
83 fn new_transform(&self, service: S) -> Self::Future {
84 ready(Ok(MapResBodyService {
85 service,
86 mw_fn: Rc::clone(&self.mw_fn),
87 _phantom: PhantomData,
88 }))
89 }
90}
91
/// Service built by [`MapResBodyMiddleware`]'s `new_transform`.
///
/// Pairs the inner service `S` with a handle to the body-mapping function `F`.
#[allow(missing_debug_implementations)]
pub struct MapResBodyService<S, F, B> {
    // Inner service whose response body is handed to the mapping function.
    service: S,
    // Handle to the mapping function, cloned from the middleware.
    mw_fn: Rc<F>,
    // `B` (the inner body type) is not stored in any field; this marker keeps the type
    // parameter in use.
    _phantom: PhantomData<(B,)>,
}
99
100impl<S, F, Fut, B, B2> Service<ServiceRequest> for MapResBodyService<S, F, B>
101where
102 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
103 F: Fn(HttpRequest, B) -> Fut,
104 Fut: Future<Output = Result<B2, Error>>,
105 B2: MessageBody,
106{
107 type Response = ServiceResponse<B2>;
108 type Error = Error;
109 type Future = MapResBodyFut<S::Future, F, Fut>;
110
111 forward_ready!(service);
112
113 fn call(&self, req: ServiceRequest) -> Self::Future {
114 let mw_fn = Rc::clone(&self.mw_fn);
115 let fut = self.service.call(req);
116
117 MapResBodyFut {
118 mw_fn,
119 state: MapResBodyFutState::Svc { fut },
120 }
121 }
122}
123
pin_project! {
    /// Future returned by `MapResBodyService::call`.
    ///
    /// Drives the inner service future first and the body-mapping future second.
    pub struct MapResBodyFut<SvcFut, F, FnFut> {
        // Mapping function; invoked once the inner service has produced its response.
        mw_fn: Rc<F>,
        // Current stage of the two-step state machine; pinned because it stores futures.
        #[pin]
        state: MapResBodyFutState<SvcFut, FnFut>,
    }
}
131
// Two-stage state machine driven by `MapResBodyFut::poll`.
pin_project! {
    #[project = MapResBodyFutStateProj]
    enum MapResBodyFutState<SvcFut, FnFut> {
        // Stage 1: waiting for the inner service to produce its response.
        Svc { #[pin] fut: SvcFut },

        // Stage 2: waiting for the mapping function's future to produce the new body.
        // `req` and `res` (the response with its body removed) are kept so the final
        // `ServiceResponse` can be rebuilt; they are `Option`s so `poll` can move them out
        // through the projection with `take()`.
        Fn {
            #[pin]
            fut: FnFut,

            req: Option<HttpRequest>,
            res: Option<HttpResponse<()>>
        },
    }
}
146
147impl<SvcFut, B, F, FnFut, B2> Future for MapResBodyFut<SvcFut, F, FnFut>
148where
149 SvcFut: Future<Output = Result<ServiceResponse<B>, Error>>,
150 F: Fn(HttpRequest, B) -> FnFut,
151 FnFut: Future<Output = Result<B2, Error>>,
152{
153 type Output = Result<ServiceResponse<B2>, Error>;
154
155 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
156 let mut this = self.as_mut().project();
157
158 match this.state.as_mut().project() {
159 MapResBodyFutStateProj::Svc { fut } => {
160 let res = ready!(fut.poll(cx))?;
161
162 let (req, res) = res.into_parts();
163 let (res, body) = res.into_parts();
164
165 let fut = (this.mw_fn)(req.clone(), body);
166 this.state.set(MapResBodyFutState::Fn {
167 fut,
168 req: Some(req),
169 res: Some(res),
170 });
171
172 self.poll(cx)
173 }
174
175 MapResBodyFutStateProj::Fn { fut, req, res } => {
176 let body = ready!(fut.poll(cx))?;
177
178 let req = req.take().unwrap();
179 let res = res.take().unwrap();
180
181 let res = res.set_body(body);
182 let res = ServiceResponse::new(req, res);
183
184 Poll::Ready(Ok(res))
185 }
186 }
187 }
188}
189
#[cfg(test)]
mod tests {
    use actix_web::{
        App, HttpResponse,
        middleware::{Compat, Logger},
        test, web,
    };

    use super::*;

    // Mapper that hands the body back unchanged.
    async fn noop(_req: HttpRequest, body: impl MessageBody) -> Result<impl MessageBody, Error> {
        Ok(body)
    }

    // Mapper that discards the incoming body and replaces it with the owned string "foo",
    // thereby changing the body type.
    async fn mutate_body_type(
        _req: HttpRequest,
        _body: impl MessageBody + 'static,
    ) -> Result<impl MessageBody, Error> {
        Ok("foo".to_owned())
    }

    // Only builds the apps: checks that both mappers can be wrapped in `Compat` and registered.
    #[actix_web::test]
    async fn compat_compat() {
        let _ = App::new().wrap(Compat::new(map_response_body(noop)));
        let _ = App::new().wrap(Compat::new(map_response_body(mutate_body_type)));
    }

    // Stacks closure-based and fn-based mappers together with `Logger`, then checks that the
    // body read back from the service is "foo".
    #[actix_web::test]
    async fn feels_good() {
        let app = test::init_service(
            App::new()
                .default_service(web::to(HttpResponse::Ok))
                .wrap(map_response_body(|_req, body| async move { Ok(body) }))
                .wrap(map_response_body(noop))
                .wrap(Logger::default())
                .wrap(map_response_body(mutate_body_type)),
        )
        .await;

        let req = test::TestRequest::default().to_request();
        let body = test::call_and_read_body(&app, req).await;
        assert_eq!(body, "foo");
    }
}