Skip to main content

ntex_service/
apply.rs

1#![allow(clippy::type_complexity)]
2use std::{fmt, marker};
3
4use crate::dev::{ServiceChain, ServiceChainFactory};
5use crate::{
6    IntoService, IntoServiceFactory, Pipeline, Service, ServiceCtx, ServiceFactory,
7};
8
9/// Apply transform function to a service.
10pub fn apply_fn<T, Req, F, In, Out, Err, U>(
11    service: U,
12    f: F,
13) -> ServiceChain<Apply<T, Req, F, In, Out, Err>, In>
14where
15    T: Service<Req>,
16    F: AsyncFn(In, &Pipeline<T>) -> Result<Out, Err>,
17    U: IntoService<T, Req>,
18    Err: From<T::Error>,
19{
20    crate::chain(Apply::new(service.into_service(), f))
21}
22
23/// Service factory that produces `apply_fn` service.
24pub fn apply_fn_factory<T, Req, Cfg, F, In, Out, Err, U>(
25    service: U,
26    f: F,
27) -> ServiceChainFactory<ApplyFactory<T, Req, Cfg, F, In, Out, Err>, In, Cfg>
28where
29    T: ServiceFactory<Req, Cfg>,
30    F: AsyncFn(In, &Pipeline<T::Service>) -> Result<Out, Err> + Clone,
31    U: IntoServiceFactory<T, Req, Cfg>,
32    Err: From<T::Error>,
33{
34    crate::chain_factory(ApplyFactory::new(service.into_factory(), f))
35}
36
37/// `Apply` service combinator
38pub struct Apply<T, Req, F, In, Out, Err>
39where
40    T: Service<Req>,
41{
42    service: Pipeline<T>,
43    f: F,
44    r: marker::PhantomData<fn(Req) -> (In, Out, Err)>,
45}
46
47impl<T, Req, F, In, Out, Err> Apply<T, Req, F, In, Out, Err>
48where
49    T: Service<Req>,
50    F: AsyncFn(In, &Pipeline<T>) -> Result<Out, Err>,
51    Err: From<T::Error>,
52{
53    pub(crate) fn new(service: T, f: F) -> Self {
54        Apply {
55            f,
56            service: Pipeline::new(service),
57            r: marker::PhantomData,
58        }
59    }
60}
61
62impl<T, Req, F, In, Out, Err> Clone for Apply<T, Req, F, In, Out, Err>
63where
64    T: Service<Req> + Clone,
65    F: AsyncFn(In, &Pipeline<T>) -> Result<Out, Err> + Clone,
66    Err: From<T::Error>,
67{
68    fn clone(&self) -> Self {
69        Apply {
70            service: self.service.clone(),
71            f: self.f.clone(),
72            r: marker::PhantomData,
73        }
74    }
75}
76
77impl<T, Req, F, In, Out, Err> fmt::Debug for Apply<T, Req, F, In, Out, Err>
78where
79    T: Service<Req> + fmt::Debug,
80{
81    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
82        f.debug_struct("Apply")
83            .field("service", &self.service)
84            .field("map", &std::any::type_name::<F>())
85            .finish()
86    }
87}
88
89impl<T, Req, F, In, Out, Err> Service<In> for Apply<T, Req, F, In, Out, Err>
90where
91    T: Service<Req>,
92    F: AsyncFn(In, &Pipeline<T>) -> Result<Out, Err>,
93    Err: From<T::Error>,
94{
95    type Response = Out;
96    type Error = Err;
97
98    #[inline]
99    async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Err> {
100        self.service.ready().await.map_err(From::from)
101    }
102
103    #[inline]
104    async fn call(
105        &self,
106        req: In,
107        _: ServiceCtx<'_, Self>,
108    ) -> Result<Self::Response, Self::Error> {
109        (self.f)(req, &self.service).await
110    }
111
112    crate::forward_poll!(service);
113    crate::forward_shutdown!(service);
114}
115
116/// `apply()` service factory
117pub struct ApplyFactory<T, Req, Cfg, F, In, Out, Err>
118where
119    T: ServiceFactory<Req, Cfg>,
120    F: AsyncFn(In, &Pipeline<T::Service>) -> Result<Out, Err> + Clone,
121{
122    service: T,
123    f: F,
124    r: marker::PhantomData<fn(Req, Cfg) -> (In, Out)>,
125}
126
127impl<T, Req, Cfg, F, In, Out, Err> ApplyFactory<T, Req, Cfg, F, In, Out, Err>
128where
129    T: ServiceFactory<Req, Cfg>,
130    F: AsyncFn(In, &Pipeline<T::Service>) -> Result<Out, Err> + Clone,
131    Err: From<T::Error>,
132{
133    /// Create new `ApplyNewService` new service instance
134    pub(crate) fn new(service: T, f: F) -> Self {
135        Self {
136            f,
137            service,
138            r: marker::PhantomData,
139        }
140    }
141}
142
143impl<T, Req, Cfg, F, In, Out, Err> Clone for ApplyFactory<T, Req, Cfg, F, In, Out, Err>
144where
145    T: ServiceFactory<Req, Cfg> + Clone,
146    F: AsyncFn(In, &Pipeline<T::Service>) -> Result<Out, Err> + Clone,
147    Err: From<T::Error>,
148{
149    fn clone(&self) -> Self {
150        Self {
151            service: self.service.clone(),
152            f: self.f.clone(),
153            r: marker::PhantomData,
154        }
155    }
156}
157
158impl<T, Req, Cfg, F, In, Out, Err> fmt::Debug for ApplyFactory<T, Req, Cfg, F, In, Out, Err>
159where
160    T: ServiceFactory<Req, Cfg> + fmt::Debug,
161    F: AsyncFn(In, &Pipeline<T::Service>) -> Result<Out, Err> + Clone,
162    Err: From<T::Error>,
163{
164    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
165        f.debug_struct("ApplyFactory")
166            .field("factory", &self.service)
167            .field("map", &std::any::type_name::<F>())
168            .finish()
169    }
170}
171
172impl<T, Req, Cfg, F, In, Out, Err> ServiceFactory<In, Cfg>
173    for ApplyFactory<T, Req, Cfg, F, In, Out, Err>
174where
175    T: ServiceFactory<Req, Cfg>,
176    F: AsyncFn(In, &Pipeline<T::Service>) -> Result<Out, Err> + Clone,
177    Err: From<T::Error>,
178{
179    type Response = Out;
180    type Error = Err;
181
182    type Service = Apply<T::Service, Req, F, In, Out, Err>;
183    type InitError = T::InitError;
184
185    #[inline]
186    async fn create(&self, cfg: Cfg) -> Result<Self::Service, Self::InitError> {
187        self.service.create(cfg).await.map(|svc| Apply {
188            service: svc.into(),
189            f: self.f.clone(),
190            r: marker::PhantomData,
191        })
192    }
193}
194
#[cfg(test)]
mod tests {
    use ntex::util::lazy;
    use std::{cell::Cell, rc::Rc, task::Context};

    use super::*;
    use crate::{chain, chain_factory, fn_factory};

    /// Inner test service: `poll` and `shutdown` each increment the shared counter.
    #[derive(Debug, Default, Clone)]
    struct Srv(Rc<Cell<usize>>);

    impl Service<()> for Srv {
        type Response = ();
        type Error = ();

        async fn call(&self, _r: (), _: ServiceCtx<'_, Self>) -> Result<(), ()> {
            Ok(())
        }

        fn poll(&self, _: &mut Context<'_>) -> Result<(), Self::Error> {
            let hits = &self.0;
            hits.set(hits.get() + 1);
            Ok(())
        }

        async fn shutdown(&self) {
            let hits = &self.0;
            hits.set(hits.get() + 1);
        }
    }

    /// Outer error type; `From<()>` lets the wrapper convert `Srv`'s `()` error.
    #[derive(Debug, PartialEq, Eq)]
    struct Err;

    impl From<()> for Err {
        fn from(_: ()) -> Self {
            Err
        }
    }

    #[ntex::test]
    async fn test_call() {
        let hits = Rc::new(Cell::new(0));
        let pipeline = chain(
            apply_fn(Srv(hits.clone()), async move |req: &'static str, svc| {
                svc.call(()).await.unwrap();
                Ok((req, ()))
            })
            .clone(),
        )
        .into_pipeline();

        assert_eq!(pipeline.ready().await, Ok::<_, Err>(()));

        // `poll` reaches the inner service: counter goes to 1.
        lazy(|cx| pipeline.poll(cx)).await.unwrap();
        assert_eq!(hits.get(), 1);

        // `shutdown` reaches the inner service too: counter goes to 2.
        pipeline.shutdown().await;
        assert_eq!(hits.get(), 2);

        let outcome = pipeline.call("srv").await;
        assert_eq!(outcome, Ok(("srv", ())));
    }

    #[ntex::test]
    async fn test_call_chain() {
        let hits = Rc::new(Cell::new(0));
        let pipeline = chain(Srv(hits.clone()))
            .apply_fn(async move |req: &'static str, svc| {
                svc.call(()).await.unwrap();
                Ok((req, ()))
            })
            .clone()
            .into_pipeline();

        assert_eq!(pipeline.ready().await, Ok::<_, Err>(()));

        pipeline.shutdown().await;
        assert_eq!(hits.get(), 1);

        let outcome = pipeline.call("srv").await;
        assert_eq!(outcome, Ok(("srv", ())));
        let _ = format!("{pipeline:?}");
    }

    #[ntex::test]
    async fn test_create() {
        let factory = chain_factory(
            apply_fn_factory(
                fn_factory(|| async { Ok::<_, ()>(Srv::default()) }),
                async move |req: &'static str, svc| {
                    svc.call(()).await.unwrap();
                    Ok((req, ()))
                },
            )
            .clone(),
        );

        let pipeline = factory.pipeline(&()).await.unwrap();

        assert_eq!(pipeline.ready().await, Ok::<_, Err>(()));

        let outcome = pipeline.call("srv").await;
        assert_eq!(outcome, Ok(("srv", ())));
        let _ = format!("{factory:?}");

        // Exercises the `From<()>` conversion directly.
        assert_eq!(Err::from(()), Err);
    }

    #[ntex::test]
    async fn test_create_chain() {
        let factory = chain_factory(fn_factory(|| async { Ok::<_, ()>(Srv::default()) }))
            .apply_fn(async move |req: &'static str, svc| {
                svc.call(()).await.unwrap();
                Ok((req, ()))
            })
            .clone();

        let pipeline = factory.pipeline(&()).await.unwrap();

        assert_eq!(pipeline.ready().await, Ok::<_, Err>(()));

        let outcome = pipeline.call("srv").await;
        assert_eq!(outcome, Ok(("srv", ())));
        let _ = format!("{factory:?}");
    }
}