Skip to main content

ntex_service/
apply.rs

1#![allow(clippy::type_complexity)]
2use std::{fmt, marker};
3
4use crate::ctx::WaitersRef;
5use crate::dev::{ServiceChain, ServiceChainFactory};
6use crate::{IntoService, IntoServiceFactory, Service, ServiceCtx, ServiceFactory};
7
8/// Apply transform function to a service.
9pub fn apply_fn<T, Req, F, In, Out, Err, U>(
10    service: U,
11    f: F,
12) -> ServiceChain<Apply<T, Req, F, In, Out, Err>, In>
13where
14    T: Service<Req>,
15    F: AsyncFn(In, &ApplyCtx<'_, T>) -> Result<Out, Err>,
16    U: IntoService<T, Req>,
17    Err: From<T::Error>,
18{
19    crate::chain(Apply::new(service.into_service(), f))
20}
21
22/// Service factory that produces `apply_fn` service.
23pub fn apply_fn_factory<T, Req, Cfg, F, In, Out, Err, U>(
24    service: U,
25    f: F,
26) -> ServiceChainFactory<ApplyFactory<T, Req, Cfg, F, In, Out, Err>, In, Cfg>
27where
28    T: ServiceFactory<Req, Cfg>,
29    F: AsyncFn(In, &ApplyCtx<'_, T::Service>) -> Result<Out, Err> + Clone,
30    U: IntoServiceFactory<T, Req, Cfg>,
31    Err: From<T::Error>,
32{
33    crate::chain_factory(ApplyFactory::new(service.into_factory(), f))
34}
35
36#[derive(Debug)]
37pub struct ApplyCtx<'a, S> {
38    idx: u32,
39    waiters: &'a WaitersRef,
40    service: &'a S,
41}
42
43impl<'a, S> ApplyCtx<'a, S> {
44    #[inline]
45    /// Wait for service readiness and then call service.
46    pub async fn call<R>(&self, req: R) -> Result<S::Response, S::Error>
47    where
48        S: Service<R>,
49        R: 'a,
50    {
51        let ctx = ServiceCtx::new(self.idx, self.waiters);
52
53        self.service.ready(ctx).await?;
54        self.service.call(req, ctx).await
55    }
56}
57
/// `Apply` service combinator.
///
/// Wraps `service` and routes every incoming request through the async
/// transform `f`, which receives an [`ApplyCtx`] for calling `service`.
pub struct Apply<T, Req, F, In, Out, Err> {
    service: T,
    f: F,
    // Records the type parameters that appear in no field. A fn-pointer marker
    // does not claim ownership of those types, so it does not affect drop
    // checking or `Send`/`Sync`.
    r: marker::PhantomData<fn(Req) -> (In, Out, Err)>,
}
64
65impl<T, Req, F, In, Out, Err> Apply<T, Req, F, In, Out, Err>
66where
67    T: Service<Req>,
68    F: AsyncFn(In, &ApplyCtx<'_, T>) -> Result<Out, Err>,
69    Err: From<T::Error>,
70{
71    pub(crate) fn new(service: T, f: F) -> Self {
72        Apply {
73            f,
74            service,
75            r: marker::PhantomData,
76        }
77    }
78}
79
80impl<T, Req, F, In, Out, Err> Clone for Apply<T, Req, F, In, Out, Err>
81where
82    T: Clone,
83    F: Clone,
84{
85    fn clone(&self) -> Self {
86        Apply {
87            service: self.service.clone(),
88            f: self.f.clone(),
89            r: marker::PhantomData,
90        }
91    }
92}
93
94impl<T, Req, F, In, Out, Err> fmt::Debug for Apply<T, Req, F, In, Out, Err>
95where
96    T: fmt::Debug,
97{
98    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
99        f.debug_struct("Apply")
100            .field("service", &self.service)
101            .field("map", &std::any::type_name::<F>())
102            .finish()
103    }
104}
105
106impl<T, Req, F, In, Out, Err> Service<In> for Apply<T, Req, F, In, Out, Err>
107where
108    T: Service<Req>,
109    F: AsyncFn(In, &ApplyCtx<'_, T>) -> Result<Out, Err>,
110    Err: From<T::Error>,
111{
112    type Response = Out;
113    type Error = Err;
114
115    #[inline]
116    async fn ready(&self, ctx: ServiceCtx<'_, Self>) -> Result<(), Err> {
117        ctx.ready(&self.service).await.map_err(From::from)
118    }
119
120    #[inline]
121    async fn call(
122        &self,
123        req: In,
124        ctx: ServiceCtx<'_, Self>,
125    ) -> Result<Self::Response, Self::Error> {
126        let (idx, waiters) = ctx.inner();
127
128        let ctx = ApplyCtx {
129            idx,
130            waiters,
131            service: &self.service,
132        };
133        (self.f)(req, &ctx).await
134    }
135
136    crate::forward_poll!(service);
137    crate::forward_shutdown!(service);
138}
139
140/// `apply()` service factory
141pub struct ApplyFactory<T, Req, Cfg, F, In, Out, Err>
142where
143    T: ServiceFactory<Req, Cfg>,
144    F: AsyncFn(In, &ApplyCtx<'_, T::Service>) -> Result<Out, Err> + Clone,
145{
146    service: T,
147    f: F,
148    r: marker::PhantomData<fn(Req, Cfg) -> (In, Out)>,
149}
150
151impl<T, Req, Cfg, F, In, Out, Err> ApplyFactory<T, Req, Cfg, F, In, Out, Err>
152where
153    T: ServiceFactory<Req, Cfg>,
154    F: AsyncFn(In, &ApplyCtx<'_, T::Service>) -> Result<Out, Err> + Clone,
155    Err: From<T::Error>,
156{
157    /// Create new `ApplyNewService` new service instance
158    pub(crate) fn new(service: T, f: F) -> Self {
159        Self {
160            f,
161            service,
162            r: marker::PhantomData,
163        }
164    }
165}
166
167impl<T, Req, Cfg, F, In, Out, Err> Clone for ApplyFactory<T, Req, Cfg, F, In, Out, Err>
168where
169    T: ServiceFactory<Req, Cfg> + Clone,
170    F: AsyncFn(In, &ApplyCtx<'_, T::Service>) -> Result<Out, Err> + Clone,
171    Err: From<T::Error>,
172{
173    fn clone(&self) -> Self {
174        Self {
175            service: self.service.clone(),
176            f: self.f.clone(),
177            r: marker::PhantomData,
178        }
179    }
180}
181
182impl<T, Req, Cfg, F, In, Out, Err> fmt::Debug for ApplyFactory<T, Req, Cfg, F, In, Out, Err>
183where
184    T: ServiceFactory<Req, Cfg> + fmt::Debug,
185    F: AsyncFn(In, &ApplyCtx<'_, T::Service>) -> Result<Out, Err> + Clone,
186    Err: From<T::Error>,
187{
188    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
189        f.debug_struct("ApplyFactory")
190            .field("factory", &self.service)
191            .field("map", &std::any::type_name::<F>())
192            .finish()
193    }
194}
195
196impl<T, Req, Cfg, F, In, Out, Err> ServiceFactory<In, Cfg>
197    for ApplyFactory<T, Req, Cfg, F, In, Out, Err>
198where
199    T: ServiceFactory<Req, Cfg>,
200    F: AsyncFn(In, &ApplyCtx<'_, T::Service>) -> Result<Out, Err> + Clone,
201    Err: From<T::Error>,
202{
203    type Response = Out;
204    type Error = Err;
205
206    type Service = Apply<T::Service, Req, F, In, Out, Err>;
207    type InitError = T::InitError;
208
209    #[inline]
210    async fn create(&self, cfg: Cfg) -> Result<Self::Service, Self::InitError> {
211        self.service.create(cfg).await.map(|service| Apply {
212            service,
213            f: self.f.clone(),
214            r: marker::PhantomData,
215        })
216    }
217}
218
#[cfg(test)]
mod tests {
    use ntex::util::lazy;
    use std::{cell::Cell, rc::Rc, task::Context};

    use super::*;
    use crate::{chain, chain_factory, fn_factory};

    /// Test service whose `poll` and `shutdown` each increment a shared counter.
    #[derive(Debug, Default, Clone)]
    struct Srv(Rc<Cell<usize>>);

    impl Srv {
        /// Increment the shared counter by one.
        fn bump(&self) {
            let next = self.0.get() + 1;
            self.0.set(next);
        }
    }

    impl Service<()> for Srv {
        type Response = ();
        type Error = ();

        async fn call(&self, _r: (), _: ServiceCtx<'_, Self>) -> Result<(), ()> {
            Ok(())
        }

        fn poll(&self, _: &mut Context<'_>) -> Result<(), Self::Error> {
            self.bump();
            Ok(())
        }

        async fn shutdown(&self) {
            self.bump();
        }
    }

    /// Error type used by the transforms; built from `Srv`'s `()` error.
    #[derive(Debug, PartialEq, Eq)]
    struct Err;

    impl From<()> for Err {
        fn from(_e: ()) -> Self {
            Err
        }
    }

    #[ntex::test]
    async fn test_call() {
        let counter = Rc::new(Cell::new(0));
        let pipeline = chain(
            apply_fn(Srv(counter.clone()), async move |req: &'static str, svc| {
                svc.call(()).await.unwrap();
                Ok((req, ()))
            })
            .clone(),
        )
        .into_pipeline();

        assert_eq!(pipeline.ready().await, Ok::<_, Err>(()));

        // `poll` is forwarded to the wrapped service: counter goes to 1.
        lazy(|cx| pipeline.poll(cx)).await.unwrap();
        assert_eq!(counter.get(), 1);

        // `shutdown` is forwarded as well: counter goes to 2.
        pipeline.shutdown().await;
        assert_eq!(counter.get(), 2);

        let outcome = pipeline.call("srv").await;
        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap(), ("srv", ()));
    }

    #[ntex::test]
    async fn test_call_chain() {
        let counter = Rc::new(Cell::new(0));
        let pipeline = chain(Srv(counter.clone()))
            .apply_fn(async move |req: &'static str, svc| {
                svc.call(()).await.unwrap();
                Ok((req, ()))
            })
            .clone()
            .into_pipeline();

        assert_eq!(pipeline.ready().await, Ok::<_, Err>(()));

        pipeline.shutdown().await;
        assert_eq!(counter.get(), 1);

        let outcome = pipeline.call("srv").await;
        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap(), ("srv", ()));
        let _ = format!("{pipeline:?}");
    }

    #[ntex::test]
    async fn test_create() {
        let factory = chain_factory(
            apply_fn_factory(
                fn_factory(|| async { Ok::<_, ()>(Srv::default()) }),
                async move |req: &'static str, srv| {
                    srv.call(()).await.unwrap();
                    Ok((req, ()))
                },
            )
            .clone(),
        );

        let pipeline = factory.pipeline(&()).await.unwrap();

        assert_eq!(pipeline.ready().await, Ok::<_, Err>(()));

        let outcome = pipeline.call("srv").await;
        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap(), ("srv", ()));
        let _ = format!("{factory:?}");

        // Exercise the `From<()>` conversion of the test error type.
        assert!(Err == Err::from(()));
    }

    #[ntex::test]
    async fn test_create_chain() {
        let factory = chain_factory(fn_factory(|| async { Ok::<_, ()>(Srv::default()) }))
            .apply_fn(async move |req: &'static str, srv| {
                srv.call(()).await.unwrap();
                Ok((req, ()))
            })
            .clone();

        let pipeline = factory.pipeline(&()).await.unwrap();

        assert_eq!(pipeline.ready().await, Ok::<_, Err>(()));

        let outcome = pipeline.call("srv").await;
        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap(), ("srv", ()));
        let _ = format!("{factory:?}");
    }
}