1use std::{
2 future::{Future, Ready, ready},
3 marker::PhantomData,
4 pin::Pin,
5 rc::Rc,
6 task::{Context, Poll},
7};
8
9use actix_service::{Service, Transform, forward_ready};
10use actix_web::{
11 Error,
12 body::MessageBody,
13 dev::{ServiceRequest, ServiceResponse},
14};
15use futures_core::ready;
16use pin_project_lite::pin_project;
17
18pub fn map_response<F>(mapper_fn: F) -> MapResMiddleware<F> {
51 MapResMiddleware {
52 mw_fn: Rc::new(mapper_fn),
53 }
54}
55
/// Middleware returned by [`map_response`]; wraps each inner service in a [`MapResService`]
/// that feeds responses through the shared mapper function.
pub struct MapResMiddleware<F> {
    mw_fn: Rc<F>,
}

// Hand-written so the type is `Debug` without requiring `F: Debug`. Mapper functions are
// typically closures, which never implement `Debug`, so a derive would be unusable in practice.
impl<F> std::fmt::Debug for MapResMiddleware<F> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("MapResMiddleware").finish_non_exhaustive()
    }
}
61
62impl<S, F, Fut, B, B2> Transform<S, ServiceRequest> for MapResMiddleware<F>
63where
64 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
65 F: Fn(ServiceResponse<B>) -> Fut,
66 Fut: Future<Output = Result<ServiceResponse<B2>, Error>>,
67 B2: MessageBody,
68{
69 type Response = ServiceResponse<B2>;
70 type Error = Error;
71 type Transform = MapResService<S, F, B>;
72 type InitError = ();
73 type Future = Ready<Result<Self::Transform, Self::InitError>>;
74
75 fn new_transform(&self, service: S) -> Self::Future {
76 ready(Ok(MapResService {
77 service,
78 mw_fn: Rc::clone(&self.mw_fn),
79 _phantom: PhantomData,
80 }))
81 }
82}
83
/// Service produced by [`MapResMiddleware`]: forwards each request to the wrapped service and
/// maps the resulting response through the shared mapper function.
pub struct MapResService<S, F, B> {
    service: S,
    mw_fn: Rc<F>,
    // Records the inner service's response body type `B`, which is not otherwise stored.
    _phantom: PhantomData<(B,)>,
}

// Hand-written so the type is `Debug` without requiring `S`, `F`, or `B` to be `Debug`.
// Mapper closures never implement `Debug`, which would make a derived impl unusable.
impl<S, F, B> std::fmt::Debug for MapResService<S, F, B> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("MapResService").finish_non_exhaustive()
    }
}
91
92impl<S, F, Fut, B, B2> Service<ServiceRequest> for MapResService<S, F, B>
93where
94 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
95 F: Fn(ServiceResponse<B>) -> Fut,
96 Fut: Future<Output = Result<ServiceResponse<B2>, Error>>,
97 B2: MessageBody,
98{
99 type Response = ServiceResponse<B2>;
100 type Error = Error;
101 type Future = MapResFut<S::Future, F, Fut>;
102
103 forward_ready!(service);
104
105 fn call(&self, req: ServiceRequest) -> Self::Future {
106 let mw_fn = Rc::clone(&self.mw_fn);
107 let fut = self.service.call(req);
108
109 MapResFut {
110 mw_fn,
111 state: MapResFutState::Svc { fut },
112 }
113 }
114}
115
pin_project! {
    /// Response future returned by [`MapResService`].
    ///
    /// First drives the wrapped service's future; when it yields a response, that response is
    /// passed to the mapper function and the mapper's future is driven to completion, its output
    /// becoming this future's output.
    pub struct MapResFut<SvcFut, F, FnFut> {
        // Shared handle to the user-supplied mapper function.
        mw_fn: Rc<F>,
        // Current stage: waiting on the inner service's future, or on the mapper's future.
        #[pin]
        state: MapResFutState<SvcFut, FnFut>,
    }
}
123
pin_project! {
    // Two-stage state machine driven by `MapResFut`.
    //
    // Starts in `Svc`. Once the inner service's future yields a response, the mapper function
    // is called and the state is replaced with `Fn`, which holds the mapper's future.
    #[project = MapResFutStateProj]
    enum MapResFutState<SvcFut, FnFut> {
        // Awaiting the wrapped service's response.
        Svc { #[pin] fut: SvcFut },
        // Awaiting the future returned by the mapper function.
        Fn { #[pin] fut: FnFut },
    }
}
131
132impl<SvcFut, B, F, FnFut, B2> Future for MapResFut<SvcFut, F, FnFut>
133where
134 SvcFut: Future<Output = Result<ServiceResponse<B>, Error>>,
135 F: Fn(ServiceResponse<B>) -> FnFut,
136 FnFut: Future<Output = Result<ServiceResponse<B2>, Error>>,
137{
138 type Output = Result<ServiceResponse<B2>, Error>;
139
140 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
141 let mut this = self.as_mut().project();
142
143 match this.state.as_mut().project() {
144 MapResFutStateProj::Svc { fut } => {
145 let res = ready!(fut.poll(cx))?;
146
147 let fut = (this.mw_fn)(res);
148 this.state.set(MapResFutState::Fn { fut });
149 self.poll(cx)
150 }
151
152 MapResFutStateProj::Fn { fut } => fut.poll(cx),
153 }
154 }
155}
156
#[cfg(test)]
mod tests {
    use actix_web::{
        App, HttpResponse,
        http::{
            StatusCode,
            header::{self, HeaderValue},
        },
        middleware::{Compat, Logger},
        test, web,
    };

    use super::*;

    /// Mapper that returns the response unchanged.
    async fn noop(
        res: ServiceResponse<impl MessageBody>,
    ) -> Result<ServiceResponse<impl MessageBody>, Error> {
        Ok(res)
    }

    /// Mapper that inserts a `Warning: 42` header.
    async fn add_header(
        mut res: ServiceResponse<impl MessageBody>,
    ) -> Result<ServiceResponse<impl MessageBody>, Error> {
        res.headers_mut()
            .insert(header::WARNING, HeaderValue::from_static("42"));

        Ok(res)
    }

    /// Mapper that changes the body type, exercising the `B -> B2` conversion.
    async fn mutate_body_type(
        res: ServiceResponse<impl MessageBody + 'static>,
    ) -> Result<ServiceResponse<impl MessageBody>, Error> {
        Ok(res.map_into_left_body::<()>())
    }

    #[actix_web::test]
    async fn compat_compat() {
        let _ = App::new().wrap(Compat::new(map_response(noop)));
        let _ = App::new().wrap(Compat::new(map_response(mutate_body_type)));
    }

    #[actix_web::test]
    async fn feels_good() {
        let app = test::init_service(
            App::new()
                .default_service(web::to(HttpResponse::Ok))
                .wrap(map_response(|res| async move { Ok(res) }))
                .wrap(map_response(noop))
                .wrap(map_response(add_header))
                .wrap(Logger::default())
                .wrap(map_response(mutate_body_type)),
        )
        .await;

        let req = test::TestRequest::default().to_request();
        let res = test::call_service(&app, req).await;

        // The chain of mappers must not alter the inner status.
        assert_eq!(res.status(), StatusCode::OK);
        // The header added by `add_header` must survive the outer mappers with its exact value.
        assert_eq!(
            res.headers().get(header::WARNING),
            Some(&HeaderValue::from_static("42")),
        );
    }
}