ntex_service/
middleware.rs

1use std::{fmt, marker::PhantomData, rc::Rc};
2
3use crate::{IntoServiceFactory, Service, ServiceFactory};
4
5/// Apply middleware to a service.
6pub fn apply<T, S, R, C, U>(t: T, factory: U) -> ApplyMiddleware<T, S, C>
7where
8    S: ServiceFactory<R, C>,
9    T: Middleware<S::Service>,
10    U: IntoServiceFactory<S, R, C>,
11{
12    ApplyMiddleware::new(t, factory.into_factory())
13}
14
15/// Apply middleware to a service.
16pub fn apply2<T, S, R, C, U>(t: T, factory: U) -> ApplyMiddleware2<T, S, C>
17where
18    S: ServiceFactory<R, C>,
19    T: Middleware2<S::Service, C>,
20    U: IntoServiceFactory<S, R, C>,
21{
22    ApplyMiddleware2::new(t, factory.into_factory())
23}
24
/// Service factory that wraps an inner service while it is being constructed.
///
/// A middleware sits in front of an inner service and takes part in inbound and/or
/// outbound processing of the request/response lifecycle; it may modify the request
/// and/or the response.
///
/// For example, timeout middleware:
///
/// ```rust
/// use ntex_service::{Service, ServiceCtx};
/// use ntex_util::{time::sleep, future::Either, future::select};
///
/// pub struct Timeout<S> {
///     service: S,
///     timeout: std::time::Duration,
/// }
///
/// pub enum TimeoutError<E> {
///    Service(E),
///    Timeout,
/// }
///
/// impl<S, R> Service<R> for Timeout<S>
/// where
///     S: Service<R>,
/// {
///     type Response = S::Response;
///     type Error = TimeoutError<S::Error>;
///
///     async fn ready(&self, ctx: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
///         ctx.ready(&self.service).await.map_err(TimeoutError::Service)
///     }
///
///     async fn call(&self, req: R, ctx: ServiceCtx<'_, Self>) -> Result<Self::Response, Self::Error> {
///         match select(sleep(self.timeout), ctx.call(&self.service, req)).await {
///             Either::Left(_) => Err(TimeoutError::Timeout),
///             Either::Right(res) => res.map_err(TimeoutError::Service),
///         }
///     }
/// }
/// ```
///
/// The `Timeout` service above does not depend on the concrete service it wraps, so it
/// can be put in front of any service.
///
/// An implementor of `Middleware` describes how to build such a wrapping service:
/// [`Middleware::create`] receives the next service in the chain and takes ownership
/// of it.
///
/// A factory for the `Timeout` middleware above could look like this:
///
/// ```rust,ignore
/// pub struct TimeoutMiddleware {
///     timeout: std::time::Duration,
/// }
///
/// impl<S> Middleware<S> for TimeoutMiddleware {
///     type Service = Timeout<S>;
///
///     fn create(&self, service: S) -> Self::Service {
///         Timeout {
///             service,
///             timeout: self.timeout,
///         }
///     }
/// }
/// ```
///
/// See [`Middleware2`] for a variant whose `create` also receives the factory configuration.
pub trait Middleware<S> {
    /// Service type produced by [`Middleware::create`].
    type Service;

    /// Wraps `service`, taking ownership of it, and returns the middleware service.
    fn create(&self, service: S) -> Self::Service;
}
102
/// The `Middleware2` trait defines the interface of a service factory that wraps inner
/// service during construction and also receives the factory configuration.
///
/// It is the configuration-aware counterpart of [`Middleware`]. Besides the inner
/// service, [`Middleware2::create`] receives a configuration value of type `Cfg`, which
/// defaults to `()`. When applied through [`apply2`], this is the same configuration the
/// wrapped service factory is created with.
///
/// Middleware wraps inner service and runs during
/// inbound and/or outbound processing in the request/response lifecycle.
/// It may modify request and/or response.
///
/// For example, timeout middleware:
///
/// ```rust
/// use ntex_service::{Service, ServiceCtx};
/// use ntex_util::{time::sleep, future::Either, future::select};
///
/// pub struct Timeout<S> {
///     service: S,
///     timeout: std::time::Duration,
/// }
///
/// pub enum TimeoutError<E> {
///    Service(E),
///    Timeout,
/// }
///
/// impl<S, R> Service<R> for Timeout<S>
/// where
///     S: Service<R>,
/// {
///     type Response = S::Response;
///     type Error = TimeoutError<S::Error>;
///
///     async fn ready(&self, ctx: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
///         ctx.ready(&self.service).await.map_err(TimeoutError::Service)
///     }
///
///     async fn call(&self, req: R, ctx: ServiceCtx<'_, Self>) -> Result<Self::Response, Self::Error> {
///         match select(sleep(self.timeout), ctx.call(&self.service, req)).await {
///             Either::Left(_) => Err(TimeoutError::Timeout),
///             Either::Right(res) => res.map_err(TimeoutError::Service),
///         }
///     }
/// }
/// ```
///
/// Timeout service in above example is decoupled from underlying service implementation
/// and could be applied to any service.
///
/// A Service constructed by a `Middleware2` factory takes the Service that follows it
/// during execution as a parameter, assuming ownership of the next Service. The
/// configuration value is available at construction time only.
///
/// A factory for the `Timeout` middleware from the above example could look like this.
/// It ignores the configuration. A real middleware could use it to parameterize the
/// service it builds.
///
/// ```rust,ignore
/// pub struct TimeoutMiddleware {
///     timeout: std::time::Duration,
/// }
///
/// impl<S, C> Middleware2<S, C> for TimeoutMiddleware {
///     type Service = Timeout<S>;
///
///     fn create(&self, service: S, _cfg: C) -> Self::Service {
///         Timeout {
///             service,
///             timeout: self.timeout,
///         }
///     }
/// }
/// ```
pub trait Middleware2<S, Cfg = ()> {
    /// The middleware `Service` value created by this factory
    type Service;

    /// Creates and returns a new middleware Service wrapping `service`.
    ///
    /// `cfg` is the configuration value supplied by the caller, for example the factory
    /// configuration forwarded by [`apply2`].
    fn create(&self, service: S, cfg: Cfg) -> Self::Service;
}
180
181impl<T, S> Middleware<S> for Rc<T>
182where
183    T: Middleware<S>,
184{
185    type Service = T::Service;
186
187    fn create(&self, service: S) -> T::Service {
188        self.as_ref().create(service)
189    }
190}
191
192impl<T, S, C> Middleware2<S, C> for Rc<T>
193where
194    T: Middleware2<S, C>,
195{
196    type Service = T::Service;
197
198    fn create(&self, service: S, cfg: C) -> T::Service {
199        self.as_ref().create(service, cfg)
200    }
201}
202
/// `Apply` middleware to a service factory.
///
/// Service factory returned by [`apply`]. The middleware and the wrapped factory live
/// together behind one `Rc`, so cloning this factory only bumps a reference count and
/// needs neither of them to be `Clone`.
pub struct ApplyMiddleware<T, S, C>(Rc<(T, S)>, PhantomData<C>);

impl<T, S, C> ApplyMiddleware<T, S, C> {
    /// Create new `ApplyMiddleware` service factory instance
    pub(crate) fn new(mw: T, svc: S) -> Self {
        let parts = Rc::new((mw, svc));
        ApplyMiddleware(parts, PhantomData)
    }
}

impl<T, S, C> Clone for ApplyMiddleware<T, S, C> {
    fn clone(&self) -> Self {
        ApplyMiddleware(Rc::clone(&self.0), PhantomData)
    }
}

impl<T, S, C> fmt::Debug for ApplyMiddleware<T, S, C>
where
    T: fmt::Debug,
    S: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (middleware, service) = &*self.0;
        let mut out = f.debug_struct("ApplyMiddleware");
        out.field("service", service);
        out.field("middleware", middleware);
        out.finish()
    }
}
231
232impl<T, S, R, C> ServiceFactory<R, C> for ApplyMiddleware<T, S, C>
233where
234    S: ServiceFactory<R, C>,
235    T: Middleware<S::Service>,
236    T::Service: Service<R>,
237{
238    type Response = <T::Service as Service<R>>::Response;
239    type Error = <T::Service as Service<R>>::Error;
240
241    type Service = T::Service;
242    type InitError = S::InitError;
243
244    #[inline]
245    async fn create(&self, cfg: C) -> Result<Self::Service, Self::InitError> {
246        Ok(self.0.0.create(self.0.1.create(cfg).await?))
247    }
248}
249
/// `Apply` configuration-aware middleware to a service factory.
///
/// Service factory returned by [`apply2`]. The [`Middleware2`] and the wrapped factory are
/// stored together behind one `Rc`, so cloning this factory is cheap and requires neither
/// of them to be `Clone`.
pub struct ApplyMiddleware2<T, S, C>(Rc<(T, S)>, PhantomData<C>);

impl<T, S, C> ApplyMiddleware2<T, S, C> {
    /// Create new `ApplyMiddleware2` service factory instance
    pub(crate) fn new(mw: T, svc: S) -> Self {
        Self(Rc::new((mw, svc)), PhantomData)
    }
}

impl<T, S, C> Clone for ApplyMiddleware2<T, S, C> {
    /// Shares the same middleware and inner factory (reference-count bump only).
    fn clone(&self) -> Self {
        Self(Rc::clone(&self.0), PhantomData)
    }
}

impl<T, S, C> fmt::Debug for ApplyMiddleware2<T, S, C>
where
    T: fmt::Debug,
    S: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Report the actual type name. It previously printed "ApplyMiddleware", which made
        // this factory indistinguishable from `ApplyMiddleware` in debug output.
        f.debug_struct("ApplyMiddleware2")
            .field("service", &self.0.1)
            .field("middleware", &self.0.0)
            .finish()
    }
}
278
279impl<T, S, R, C> ServiceFactory<R, C> for ApplyMiddleware2<T, S, C>
280where
281    S: ServiceFactory<R, C>,
282    T: Middleware2<S::Service, C>,
283    T::Service: Service<R>,
284    C: Clone,
285{
286    type Response = <T::Service as Service<R>>::Response;
287    type Error = <T::Service as Service<R>>::Error;
288
289    type Service = T::Service;
290    type InitError = S::InitError;
291
292    #[inline]
293    async fn create(&self, cfg: C) -> Result<Self::Service, Self::InitError> {
294        Ok(self.0.0.create(self.0.1.create(cfg.clone()).await?, cfg))
295    }
296}
297
298/// Identity is a middleware.
299///
300/// It returns service without modifications.
301#[derive(Debug, Clone, Copy)]
302pub struct Identity;
303
304impl<S> Middleware<S> for Identity {
305    type Service = S;
306
307    #[inline]
308    fn create(&self, service: S) -> Self::Service {
309        service
310    }
311}
312
313impl<S, Cfg> Middleware2<S, Cfg> for Identity {
314    type Service = S;
315
316    #[inline]
317    fn create(&self, service: S, _: Cfg) -> Self::Service {
318        service
319    }
320}
321
322/// Stack of middlewares.
323#[derive(Debug, Clone)]
324pub struct Stack<Inner, Outer> {
325    inner: Inner,
326    outer: Outer,
327}
328
329impl<Inner, Outer> Stack<Inner, Outer> {
330    pub fn new(inner: Inner, outer: Outer) -> Self {
331        Stack { inner, outer }
332    }
333}
334
335impl<S, Inner, Outer> Middleware<S> for Stack<Inner, Outer>
336where
337    Inner: Middleware<S>,
338    Outer: Middleware<Inner::Service>,
339{
340    type Service = Outer::Service;
341
342    fn create(&self, service: S) -> Self::Service {
343        self.outer.create(self.inner.create(service))
344    }
345}
346
347impl<S, Inner, Outer, C> Middleware2<S, C> for Stack<Inner, Outer>
348where
349    Inner: Middleware2<S, C>,
350    Outer: Middleware2<Inner::Service, C>,
351    C: Clone,
352{
353    type Service = Outer::Service;
354
355    fn create(&self, service: S, cfg: C) -> Self::Service {
356        self.outer
357            .create(self.inner.create(service, cfg.clone()), cfg)
358    }
359}
360
#[cfg(test)]
#[allow(clippy::redundant_clone)]
mod tests {
    use std::{cell::Cell, rc::Rc};

    use super::*;
    use crate::{Pipeline, ServiceCtx, fn_service};

    /// Test middleware. The shared counter is incremented by `Middleware2::create` (but not
    /// by `Middleware::create`) and by the created service's `shutdown`.
    #[derive(Debug, Clone)]
    struct Mw<R>(PhantomData<R>, Rc<Cell<usize>>);

    impl<S, R> Middleware<S> for Mw<R> {
        type Service = Srv<S, R>;

        fn create(&self, service: S) -> Self::Service {
            Srv(service, PhantomData, self.1.clone())
        }
    }

    impl<S, R, C> Middleware2<S, C> for Mw<R> {
        type Service = Srv<S, R>;

        fn create(&self, service: S, _: C) -> Self::Service {
            self.1.set(self.1.get() + 1);
            Srv(service, PhantomData, self.1.clone())
        }
    }

    #[derive(Debug, Clone)]
    struct Srv<S, R>(S, PhantomData<R>, Rc<Cell<usize>>);

    impl<S: Service<R>, R> Service<R> for Srv<S, R> {
        type Response = S::Response;
        type Error = S::Error;

        async fn ready(&self, ctx: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
            ctx.ready(&self.0).await
        }

        async fn call(
            &self,
            req: R,
            ctx: ServiceCtx<'_, Self>,
        ) -> Result<S::Response, S::Error> {
            ctx.call(&self.0, req).await
        }

        async fn shutdown(&self) {
            self.2.set(self.2.get() + 1);
        }
    }

    #[ntex::test]
    async fn middleware() {
        let cnt_sht = Rc::new(Cell::new(0));
        let factory = apply(
            Rc::new(Mw(PhantomData, cnt_sht.clone()).clone()),
            fn_service(|i: usize| async move { Ok::<_, ()>(i * 2) }),
        )
        .clone();

        let srv = Pipeline::new(factory.create(&()).await.unwrap().clone());
        let res = srv.call(10).await;
        assert!(res.is_ok());
        assert_eq!(res.unwrap(), 20);
        let _ = format!("{factory:?} {srv:?}");

        assert_eq!(srv.ready().await, Ok(()));
        srv.shutdown().await;
        assert_eq!(cnt_sht.get(), 1);

        let factory =
            crate::chain_factory(fn_service(|i: usize| async move { Ok::<_, ()>(i * 2) }))
                .apply(Rc::new(Mw(PhantomData, Rc::new(Cell::new(0))).clone()))
                .clone();

        let srv = Pipeline::new(factory.create(&()).await.unwrap().clone());
        let res = srv.call(10).await;
        assert!(res.is_ok());
        assert_eq!(res.unwrap(), 20);
        let _ = format!("{factory:?} {srv:?}");

        assert_eq!(srv.ready().await, Ok(()));
    }

    #[ntex::test]
    async fn middleware2() {
        let cnt_sht = Rc::new(Cell::new(0));
        let factory = apply2(
            Rc::new(Mw(PhantomData, cnt_sht.clone()).clone()),
            fn_service(|i: usize| async move { Ok::<_, ()>(i * 2) }),
        )
        .clone();

        let srv = Pipeline::new(factory.create(&()).await.unwrap().clone());
        let res = srv.call(10).await;
        assert!(res.is_ok());
        assert_eq!(res.unwrap(), 20);
        let _ = format!("{factory:?} {srv:?}");

        assert_eq!(srv.ready().await, Ok(()));
        srv.shutdown().await;
        assert_eq!(cnt_sht.get(), 2);

        // Same check through the chain API with an `Rc`-wrapped `Middleware2`. The previous
        // version used `.apply(...)` here, which duplicated the `middleware` test and
        // never exercised `Middleware2`.
        let cnt_create = Rc::new(Cell::new(0));
        let factory =
            crate::chain_factory(fn_service(|i: usize| async move { Ok::<_, ()>(i * 2) }))
                .apply2(Rc::new(Mw(PhantomData, cnt_create.clone()).clone()))
                .clone();

        let srv = Pipeline::new(factory.create(&()).await.unwrap().clone());
        assert_eq!(cnt_create.get(), 1);
        let res = srv.call(10).await;
        assert!(res.is_ok());
        assert_eq!(res.unwrap(), 20);
        let _ = format!("{factory:?} {srv:?}");

        assert_eq!(srv.ready().await, Ok(()));
        srv.shutdown().await;
        assert_eq!(cnt_create.get(), 2);
    }

    #[ntex::test]
    async fn middleware2_chain() {
        let cnt_sht = Rc::new(Cell::new(0));
        let factory =
            crate::chain_factory(fn_service(|i: usize| async move { Ok::<_, ()>(i * 2) }))
                .apply2(Mw(PhantomData, cnt_sht.clone()).clone());

        let srv = Pipeline::new(factory.create(&()).await.unwrap().clone());
        let res = srv.call(10).await;
        assert!(res.is_ok());
        assert_eq!(res.unwrap(), 20);
        let _ = format!("{factory:?} {srv:?}");

        assert_eq!(srv.ready().await, Ok(()));
        srv.shutdown().await;
        assert_eq!(cnt_sht.get(), 2);
    }
}