ntex_service/
apply.rs

1#![allow(clippy::type_complexity)]
2use std::{fmt, future::Future, marker};
3
4use super::{
5    IntoService, IntoServiceFactory, Pipeline, Service, ServiceCtx, ServiceFactory,
6};
7
8/// Apply transform function to a service.
9pub fn apply_fn<T, Req, F, R, In, Out, Err, U>(
10    service: U,
11    f: F,
12) -> Apply<T, Req, F, R, In, Out, Err>
13where
14    T: Service<Req>,
15    F: Fn(In, Pipeline<T>) -> R,
16    R: Future<Output = Result<Out, Err>>,
17    U: IntoService<T, Req>,
18    Err: From<T::Error>,
19{
20    Apply::new(service.into_service(), f)
21}
22
23/// Service factory that produces `apply_fn` service.
24pub fn apply_fn_factory<T, Req, Cfg, F, R, In, Out, Err, U>(
25    service: U,
26    f: F,
27) -> ApplyFactory<T, Req, Cfg, F, R, In, Out, Err>
28where
29    T: ServiceFactory<Req, Cfg>,
30    F: Fn(In, Pipeline<T::Service>) -> R + Clone,
31    R: Future<Output = Result<Out, Err>>,
32    U: IntoServiceFactory<T, Req, Cfg>,
33    Err: From<T::Error>,
34{
35    ApplyFactory::new(service.into_factory(), f)
36}
37
38/// `Apply` service combinator
39pub struct Apply<T, Req, F, R, In, Out, Err>
40where
41    T: Service<Req>,
42{
43    service: Pipeline<T>,
44    f: F,
45    r: marker::PhantomData<fn(Req) -> (In, Out, R, Err)>,
46}
47
48impl<T, Req, F, R, In, Out, Err> Apply<T, Req, F, R, In, Out, Err>
49where
50    T: Service<Req>,
51    F: Fn(In, Pipeline<T>) -> R,
52    R: Future<Output = Result<Out, Err>>,
53    Err: From<T::Error>,
54{
55    pub(crate) fn new(service: T, f: F) -> Self {
56        Apply {
57            f,
58            service: Pipeline::new(service),
59            r: marker::PhantomData,
60        }
61    }
62}
63
64impl<T, Req, F, R, In, Out, Err> Clone for Apply<T, Req, F, R, In, Out, Err>
65where
66    T: Service<Req> + Clone,
67    F: Fn(In, Pipeline<T>) -> R + Clone,
68    R: Future<Output = Result<Out, Err>>,
69    Err: From<T::Error>,
70{
71    fn clone(&self) -> Self {
72        Apply {
73            service: self.service.clone(),
74            f: self.f.clone(),
75            r: marker::PhantomData,
76        }
77    }
78}
79
80impl<T, Req, F, R, In, Out, Err> fmt::Debug for Apply<T, Req, F, R, In, Out, Err>
81where
82    T: Service<Req> + fmt::Debug,
83{
84    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
85        f.debug_struct("Apply")
86            .field("service", &self.service)
87            .field("map", &std::any::type_name::<F>())
88            .finish()
89    }
90}
91
92impl<T, Req, F, R, In, Out, Err> Service<In> for Apply<T, Req, F, R, In, Out, Err>
93where
94    T: Service<Req>,
95    F: Fn(In, Pipeline<T>) -> R,
96    R: Future<Output = Result<Out, Err>>,
97    Err: From<T::Error>,
98{
99    type Response = Out;
100    type Error = Err;
101
102    #[inline]
103    async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Err> {
104        self.service.ready().await.map_err(From::from)
105    }
106
107    #[inline]
108    async fn call(
109        &self,
110        req: In,
111        _: ServiceCtx<'_, Self>,
112    ) -> Result<Self::Response, Self::Error> {
113        (self.f)(req, self.service.clone()).await
114    }
115
116    crate::forward_poll!(service);
117    crate::forward_shutdown!(service);
118}
119
120/// `apply()` service factory
121pub struct ApplyFactory<T, Req, Cfg, F, R, In, Out, Err>
122where
123    T: ServiceFactory<Req, Cfg>,
124    F: Fn(In, Pipeline<T::Service>) -> R + Clone,
125    R: Future<Output = Result<Out, Err>>,
126{
127    service: T,
128    f: F,
129    r: marker::PhantomData<fn(Req, Cfg) -> (R, In, Out)>,
130}
131
132impl<T, Req, Cfg, F, R, In, Out, Err> ApplyFactory<T, Req, Cfg, F, R, In, Out, Err>
133where
134    T: ServiceFactory<Req, Cfg>,
135    F: Fn(In, Pipeline<T::Service>) -> R + Clone,
136    R: Future<Output = Result<Out, Err>>,
137    Err: From<T::Error>,
138{
139    /// Create new `ApplyNewService` new service instance
140    pub(crate) fn new(service: T, f: F) -> Self {
141        Self {
142            f,
143            service,
144            r: marker::PhantomData,
145        }
146    }
147}
148
149impl<T, Req, Cfg, F, R, In, Out, Err> Clone
150    for ApplyFactory<T, Req, Cfg, F, R, In, Out, Err>
151where
152    T: ServiceFactory<Req, Cfg> + Clone,
153    F: Fn(In, Pipeline<T::Service>) -> R + Clone,
154    R: Future<Output = Result<Out, Err>>,
155    Err: From<T::Error>,
156{
157    fn clone(&self) -> Self {
158        Self {
159            service: self.service.clone(),
160            f: self.f.clone(),
161            r: marker::PhantomData,
162        }
163    }
164}
165
166impl<T, Req, Cfg, F, R, In, Out, Err> fmt::Debug
167    for ApplyFactory<T, Req, Cfg, F, R, In, Out, Err>
168where
169    T: ServiceFactory<Req, Cfg> + fmt::Debug,
170    F: Fn(In, Pipeline<T::Service>) -> R + Clone,
171    R: Future<Output = Result<Out, Err>>,
172    Err: From<T::Error>,
173{
174    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
175        f.debug_struct("ApplyFactory")
176            .field("factory", &self.service)
177            .field("map", &std::any::type_name::<F>())
178            .finish()
179    }
180}
181
182impl<T, Req, Cfg, F, R, In, Out, Err> ServiceFactory<In, Cfg>
183    for ApplyFactory<T, Req, Cfg, F, R, In, Out, Err>
184where
185    T: ServiceFactory<Req, Cfg>,
186    F: Fn(In, Pipeline<T::Service>) -> R + Clone,
187    R: Future<Output = Result<Out, Err>>,
188    Err: From<T::Error>,
189{
190    type Response = Out;
191    type Error = Err;
192
193    type Service = Apply<T::Service, Req, F, R, In, Out, Err>;
194    type InitError = T::InitError;
195
196    #[inline]
197    async fn create(&self, cfg: Cfg) -> Result<Self::Service, Self::InitError> {
198        self.service.create(cfg).await.map(|svc| Apply {
199            service: svc.into(),
200            f: self.f.clone(),
201            r: marker::PhantomData,
202        })
203    }
204}
205
#[cfg(test)]
mod tests {
    use ntex_util::future::lazy;
    use std::{cell::Cell, rc::Rc, task::Context};

    use super::*;
    use crate::{chain, chain_factory, fn_factory};

    /// Test service whose `poll` and `shutdown` each increment a shared counter.
    #[derive(Debug, Default, Clone)]
    struct Srv(Rc<Cell<usize>>);

    impl Srv {
        /// Adds one to the shared counter.
        fn bump(&self) {
            let n = self.0.get();
            self.0.set(n + 1);
        }
    }

    impl Service<()> for Srv {
        type Response = ();
        type Error = ();

        async fn call(&self, _: (), _: ServiceCtx<'_, Self>) -> Result<(), ()> {
            Ok(())
        }

        fn poll(&self, _: &mut Context<'_>) -> Result<(), Self::Error> {
            self.bump();
            Ok(())
        }

        async fn shutdown(&self) {
            self.bump();
        }
    }

    /// Error type of the transformed services; `From<()>` converts inner errors.
    #[derive(Debug, PartialEq, Eq)]
    struct Err;

    impl From<()> for Err {
        fn from(_: ()) -> Self {
            Err
        }
    }

    #[ntex::test]
    async fn test_call() {
        let counter = Rc::new(Cell::new(0));
        let srv = chain(
            apply_fn(Srv(counter.clone()), |req: &'static str, svc| async move {
                svc.call(()).await.unwrap();
                Ok((req, ()))
            })
            .clone(),
        )
        .into_pipeline();

        // The closure never names its error type; this assertion pins it to `Err`.
        assert_eq!(srv.ready().await, Ok::<_, Err>(()));

        // Both `poll` and `shutdown` must reach the inner `Srv`.
        lazy(|cx| srv.poll(cx)).await.unwrap();
        assert_eq!(counter.get(), 1);

        srv.shutdown().await;
        assert_eq!(counter.get(), 2);

        let res = srv.call("srv").await;
        assert_eq!(res, Ok(("srv", ())));
    }

    #[ntex::test]
    async fn test_call_chain() {
        let counter = Rc::new(Cell::new(0));
        let srv = chain(Srv(counter.clone()))
            .apply_fn(|req: &'static str, svc| async move {
                svc.call(()).await.unwrap();
                Ok((req, ()))
            })
            .clone()
            .into_pipeline();

        // Pins the closure's error type to `Err`.
        assert_eq!(srv.ready().await, Ok::<_, Err>(()));

        srv.shutdown().await;
        assert_eq!(counter.get(), 1);

        let res = srv.call("srv").await;
        assert_eq!(res, Ok(("srv", ())));
        let _ = format!("{srv:?}");
    }

    #[ntex::test]
    async fn test_create() {
        let factory = chain_factory(
            apply_fn_factory(
                fn_factory(|| async { Ok::<_, ()>(Srv::default()) }),
                |req: &'static str, srv| async move {
                    srv.call(()).await.unwrap();
                    Ok((req, ()))
                },
            )
            .clone(),
        );

        let srv = factory.pipeline(&()).await.unwrap();

        // Pins the closure's error type to `Err`.
        assert_eq!(srv.ready().await, Ok::<_, Err>(()));

        let res = srv.call("srv").await;
        assert_eq!(res, Ok(("srv", ())));
        let _ = format!("{factory:?}");

        assert_eq!(Err, Err::from(()));
    }

    #[ntex::test]
    async fn test_create_chain() {
        let factory = chain_factory(fn_factory(|| async { Ok::<_, ()>(Srv::default()) }))
            .apply_fn(|req: &'static str, srv| async move {
                srv.call(()).await.unwrap();
                Ok((req, ()))
            })
            .clone();

        let srv = factory.pipeline(&()).await.unwrap();

        // Pins the closure's error type to `Err`.
        assert_eq!(srv.ready().await, Ok::<_, Err>(()));

        let res = srv.call("srv").await;
        assert_eq!(res, Ok(("srv", ())));
        let _ = format!("{factory:?}");
    }
}