1#![allow(clippy::type_complexity)]
2use std::{fmt, future::Future, marker};
3
4use super::{
5 IntoService, IntoServiceFactory, Pipeline, Service, ServiceCtx, ServiceFactory,
6};
7
8pub fn apply_fn<T, Req, F, R, In, Out, Err, U>(
10 service: U,
11 f: F,
12) -> Apply<T, Req, F, R, In, Out, Err>
13where
14 T: Service<Req>,
15 F: Fn(In, Pipeline<T>) -> R,
16 R: Future<Output = Result<Out, Err>>,
17 U: IntoService<T, Req>,
18 Err: From<T::Error>,
19{
20 Apply::new(service.into_service(), f)
21}
22
23pub fn apply_fn_factory<T, Req, Cfg, F, R, In, Out, Err, U>(
25 service: U,
26 f: F,
27) -> ApplyFactory<T, Req, Cfg, F, R, In, Out, Err>
28where
29 T: ServiceFactory<Req, Cfg>,
30 F: Fn(In, Pipeline<T::Service>) -> R + Clone,
31 R: Future<Output = Result<Out, Err>>,
32 U: IntoServiceFactory<T, Req, Cfg>,
33 Err: From<T::Error>,
34{
35 ApplyFactory::new(service.into_factory(), f)
36}
37
38pub struct Apply<T, Req, F, R, In, Out, Err>
40where
41 T: Service<Req>,
42{
43 service: Pipeline<T>,
44 f: F,
45 r: marker::PhantomData<fn(Req) -> (In, Out, R, Err)>,
46}
47
48impl<T, Req, F, R, In, Out, Err> Apply<T, Req, F, R, In, Out, Err>
49where
50 T: Service<Req>,
51 F: Fn(In, Pipeline<T>) -> R,
52 R: Future<Output = Result<Out, Err>>,
53 Err: From<T::Error>,
54{
55 pub(crate) fn new(service: T, f: F) -> Self {
56 Apply {
57 f,
58 service: Pipeline::new(service),
59 r: marker::PhantomData,
60 }
61 }
62}
63
64impl<T, Req, F, R, In, Out, Err> Clone for Apply<T, Req, F, R, In, Out, Err>
65where
66 T: Service<Req> + Clone,
67 F: Fn(In, Pipeline<T>) -> R + Clone,
68 R: Future<Output = Result<Out, Err>>,
69 Err: From<T::Error>,
70{
71 fn clone(&self) -> Self {
72 Apply {
73 service: self.service.clone(),
74 f: self.f.clone(),
75 r: marker::PhantomData,
76 }
77 }
78}
79
80impl<T, Req, F, R, In, Out, Err> fmt::Debug for Apply<T, Req, F, R, In, Out, Err>
81where
82 T: Service<Req> + fmt::Debug,
83{
84 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
85 f.debug_struct("Apply")
86 .field("service", &self.service)
87 .field("map", &std::any::type_name::<F>())
88 .finish()
89 }
90}
91
92impl<T, Req, F, R, In, Out, Err> Service<In> for Apply<T, Req, F, R, In, Out, Err>
93where
94 T: Service<Req>,
95 F: Fn(In, Pipeline<T>) -> R,
96 R: Future<Output = Result<Out, Err>>,
97 Err: From<T::Error>,
98{
99 type Response = Out;
100 type Error = Err;
101
102 #[inline]
103 async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Err> {
104 self.service.ready().await.map_err(From::from)
105 }
106
107 #[inline]
108 async fn call(
109 &self,
110 req: In,
111 _: ServiceCtx<'_, Self>,
112 ) -> Result<Self::Response, Self::Error> {
113 (self.f)(req, self.service.clone()).await
114 }
115
116 crate::forward_poll!(service);
117 crate::forward_shutdown!(service);
118}
119
120pub struct ApplyFactory<T, Req, Cfg, F, R, In, Out, Err>
122where
123 T: ServiceFactory<Req, Cfg>,
124 F: Fn(In, Pipeline<T::Service>) -> R + Clone,
125 R: Future<Output = Result<Out, Err>>,
126{
127 service: T,
128 f: F,
129 r: marker::PhantomData<fn(Req, Cfg) -> (R, In, Out)>,
130}
131
132impl<T, Req, Cfg, F, R, In, Out, Err> ApplyFactory<T, Req, Cfg, F, R, In, Out, Err>
133where
134 T: ServiceFactory<Req, Cfg>,
135 F: Fn(In, Pipeline<T::Service>) -> R + Clone,
136 R: Future<Output = Result<Out, Err>>,
137 Err: From<T::Error>,
138{
139 pub(crate) fn new(service: T, f: F) -> Self {
141 Self {
142 f,
143 service,
144 r: marker::PhantomData,
145 }
146 }
147}
148
149impl<T, Req, Cfg, F, R, In, Out, Err> Clone
150 for ApplyFactory<T, Req, Cfg, F, R, In, Out, Err>
151where
152 T: ServiceFactory<Req, Cfg> + Clone,
153 F: Fn(In, Pipeline<T::Service>) -> R + Clone,
154 R: Future<Output = Result<Out, Err>>,
155 Err: From<T::Error>,
156{
157 fn clone(&self) -> Self {
158 Self {
159 service: self.service.clone(),
160 f: self.f.clone(),
161 r: marker::PhantomData,
162 }
163 }
164}
165
166impl<T, Req, Cfg, F, R, In, Out, Err> fmt::Debug
167 for ApplyFactory<T, Req, Cfg, F, R, In, Out, Err>
168where
169 T: ServiceFactory<Req, Cfg> + fmt::Debug,
170 F: Fn(In, Pipeline<T::Service>) -> R + Clone,
171 R: Future<Output = Result<Out, Err>>,
172 Err: From<T::Error>,
173{
174 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
175 f.debug_struct("ApplyFactory")
176 .field("factory", &self.service)
177 .field("map", &std::any::type_name::<F>())
178 .finish()
179 }
180}
181
182impl<T, Req, Cfg, F, R, In, Out, Err> ServiceFactory<In, Cfg>
183 for ApplyFactory<T, Req, Cfg, F, R, In, Out, Err>
184where
185 T: ServiceFactory<Req, Cfg>,
186 F: Fn(In, Pipeline<T::Service>) -> R + Clone,
187 R: Future<Output = Result<Out, Err>>,
188 Err: From<T::Error>,
189{
190 type Response = Out;
191 type Error = Err;
192
193 type Service = Apply<T::Service, Req, F, R, In, Out, Err>;
194 type InitError = T::InitError;
195
196 #[inline]
197 async fn create(&self, cfg: Cfg) -> Result<Self::Service, Self::InitError> {
198 self.service.create(cfg).await.map(|svc| Apply {
199 service: svc.into(),
200 f: self.f.clone(),
201 r: marker::PhantomData,
202 })
203 }
204}
205
#[cfg(test)]
mod tests {
    use std::{cell::Cell, rc::Rc, task::Context};

    use ntex_util::future::lazy;

    use super::*;
    use crate::{chain, chain_factory, fn_factory};

    /// Dummy service; `poll` and `shutdown` each increment the shared counter.
    #[derive(Debug, Default, Clone)]
    struct Srv(Rc<Cell<usize>>);

    impl Srv {
        fn bump(&self) {
            let Srv(counter) = self;
            counter.set(counter.get() + 1);
        }
    }

    impl Service<()> for Srv {
        type Response = ();
        type Error = ();

        async fn call(&self, _: (), _: ServiceCtx<'_, Self>) -> Result<(), ()> {
            Ok(())
        }

        fn poll(&self, _: &mut Context<'_>) -> Result<(), Self::Error> {
            self.bump();
            Ok(())
        }

        async fn shutdown(&self) {
            self.bump();
        }
    }

    /// Error type of the applied services; converts from the inner `()` error.
    #[derive(Debug, PartialEq, Eq)]
    struct TestErr;

    impl From<()> for TestErr {
        fn from(_: ()) -> Self {
            TestErr
        }
    }

    #[ntex::test]
    async fn test_call() {
        let counter = Rc::new(Cell::new(0));
        let applied = apply_fn(Srv(counter.clone()), |req: &'static str, svc| async move {
            svc.call(()).await.unwrap();
            Ok((req, ()))
        });
        let srv = chain(applied.clone()).into_pipeline();

        assert_eq!(srv.ready().await, Ok::<_, TestErr>(()));

        // `poll` reaches the inner service...
        lazy(|cx| srv.poll(cx)).await.unwrap();
        assert_eq!(counter.get(), 1);

        // ...and so does `shutdown`.
        srv.shutdown().await;
        assert_eq!(counter.get(), 2);

        assert_eq!(srv.call("srv").await, Ok(("srv", ())));
    }

    #[ntex::test]
    async fn test_call_chain() {
        let counter = Rc::new(Cell::new(0));
        let srv = chain(Srv(counter.clone()))
            .apply_fn(|req: &'static str, svc| async move {
                svc.call(()).await.unwrap();
                Ok((req, ()))
            })
            .clone()
            .into_pipeline();

        assert_eq!(srv.ready().await, Ok::<_, TestErr>(()));

        srv.shutdown().await;
        assert_eq!(counter.get(), 1);

        assert_eq!(srv.call("srv").await, Ok(("srv", ())));
        let _ = format!("{srv:?}");
    }

    #[ntex::test]
    async fn test_create() {
        let factory = apply_fn_factory(
            fn_factory(|| async { Ok::<_, ()>(Srv::default()) }),
            |req: &'static str, srv| async move {
                srv.call(()).await.unwrap();
                Ok((req, ()))
            },
        );
        let new_srv = chain_factory(factory.clone());

        let srv = new_srv.pipeline(&()).await.unwrap();
        assert_eq!(srv.ready().await, Ok::<_, TestErr>(()));
        assert_eq!(srv.call("srv").await, Ok(("srv", ())));
        let _ = format!("{new_srv:?}");

        assert_eq!(TestErr, TestErr::from(()));
    }

    #[ntex::test]
    async fn test_create_chain() {
        let new_srv = chain_factory(fn_factory(|| async { Ok::<_, ()>(Srv::default()) }))
            .apply_fn(|req: &'static str, srv| async move {
                srv.call(()).await.unwrap();
                Ok((req, ()))
            })
            .clone();

        let srv = new_srv.pipeline(&()).await.unwrap();
        assert_eq!(srv.ready().await, Ok::<_, TestErr>(()));
        assert_eq!(srv.call("srv").await, Ok(("srv", ())));
        let _ = format!("{new_srv:?}");
    }
}