ntex_service/
middleware.rs

1use std::{fmt, marker::PhantomData, rc::Rc};
2
3use crate::{IntoServiceFactory, Service, ServiceFactory};
4
5/// Apply middleware to a service.
6pub fn apply<T, S, R, C, U>(t: T, factory: U) -> ApplyMiddleware<T, S, C>
7where
8    S: ServiceFactory<R, C>,
9    T: Middleware<S::Service, C>,
10    U: IntoServiceFactory<S, R, C>,
11{
12    ApplyMiddleware::new(t, factory.into_factory())
13}
14
/// The `Middleware` trait defines the interface of a service factory that wraps inner service
/// during construction.
///
/// Middleware wraps inner service and runs during
/// inbound and/or outbound processing in the request/response lifecycle.
/// It may modify request and/or response.
///
/// For example, timeout middleware:
///
/// ```rust
/// use ntex_service::{Service, ServiceCtx};
/// use ntex_util::{time::sleep, future::Either, future::select};
///
/// pub struct Timeout<S> {
///     service: S,
///     timeout: std::time::Duration,
/// }
///
/// pub enum TimeoutError<E> {
///    Service(E),
///    Timeout,
/// }
///
/// impl<S, R> Service<R> for Timeout<S>
/// where
///     S: Service<R>,
/// {
///     type Response = S::Response;
///     type Error = TimeoutError<S::Error>;
///
///     async fn ready(&self, ctx: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
///         ctx.ready(&self.service).await.map_err(TimeoutError::Service)
///     }
///
///     async fn call(&self, req: R, ctx: ServiceCtx<'_, Self>) -> Result<Self::Response, Self::Error> {
///         match select(sleep(self.timeout), ctx.call(&self.service, req)).await {
///             Either::Left(_) => Err(TimeoutError::Timeout),
///             Either::Right(res) => res.map_err(TimeoutError::Service),
///         }
///     }
/// }
/// ```
///
/// Timeout service in above example is decoupled from underlying service implementation
/// and could be applied to any service.
///
/// The `Middleware` trait defines the interface of a middleware factory, defining how to
/// construct a middleware Service. A Service that is constructed by the factory takes
/// the Service that follows it during execution as a parameter, assuming
/// ownership of the next Service.
///
/// Besides the inner service, `create` receives a configuration value of type `Cfg`
/// (defaults to `()`). A middleware that does not need it can simply ignore the argument.
///
/// Factory for `Timeout` middleware from the above example could look like this:
///
/// ```rust,ignore
/// pub struct TimeoutMiddleware {
///     timeout: std::time::Duration,
/// }
///
/// impl<S, C> Middleware<S, C> for TimeoutMiddleware
/// {
///     type Service = Timeout<S>;
///
///     fn create(&self, service: S, _cfg: C) -> Self::Service {
///         Timeout {
///             service,
///             timeout: self.timeout,
///         }
///     }
/// }
/// ```
pub trait Middleware<S, Cfg = ()> {
    /// The middleware `Service` value created by this factory
    type Service;

    /// Creates and returns a new middleware Service
    ///
    /// `service` is the inner service being wrapped; ownership moves into the
    /// returned value. `cfg` is the configuration value handed over by the caller.
    fn create(&self, service: S, cfg: Cfg) -> Self::Service;
}
92
93impl<T, S, C> Middleware<S, C> for Rc<T>
94where
95    T: Middleware<S, C>,
96{
97    type Service = T::Service;
98
99    fn create(&self, service: S, cfg: C) -> T::Service {
100        self.as_ref().create(service, cfg)
101    }
102}
103
/// Service factory that applies a middleware to the services of another factory.
///
/// The `(middleware, factory)` pair is kept behind an `Rc`; `C` is only a
/// type-level marker for the configuration type.
pub struct ApplyMiddleware<T, S, C>(Rc<(T, S)>, PhantomData<C>);

impl<T, S, C> ApplyMiddleware<T, S, C> {
    /// Build an `ApplyMiddleware` from a middleware `mw` and a service factory `svc`.
    pub(crate) fn new(mw: T, svc: S) -> Self {
        let shared = Rc::new((mw, svc));
        ApplyMiddleware(shared, PhantomData)
    }
}
113
114impl<T, S, C> Clone for ApplyMiddleware<T, S, C> {
115    fn clone(&self) -> Self {
116        Self(self.0.clone(), PhantomData)
117    }
118}
119
120impl<T, S, C> fmt::Debug for ApplyMiddleware<T, S, C>
121where
122    T: fmt::Debug,
123    S: fmt::Debug,
124{
125    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
126        f.debug_struct("ApplyMiddleware")
127            .field("service", &self.0.1)
128            .field("middleware", &self.0.0)
129            .finish()
130    }
131}
132
133impl<T, S, R, C> ServiceFactory<R, C> for ApplyMiddleware<T, S, C>
134where
135    S: ServiceFactory<R, C>,
136    T: Middleware<S::Service, C>,
137    T::Service: Service<R>,
138    C: Clone,
139{
140    type Response = <T::Service as Service<R>>::Response;
141    type Error = <T::Service as Service<R>>::Error;
142
143    type Service = T::Service;
144    type InitError = S::InitError;
145
146    #[inline]
147    async fn create(&self, cfg: C) -> Result<Self::Service, Self::InitError> {
148        Ok(self.0.0.create(self.0.1.create(cfg.clone()).await?, cfg))
149    }
150}
151
152/// Identity is a middleware.
153///
154/// It returns service without modifications.
155#[derive(Debug, Clone, Copy)]
156pub struct Identity;
157
158impl<S, Cfg> Middleware<S, Cfg> for Identity {
159    type Service = S;
160
161    #[inline]
162    fn create(&self, service: S, _: Cfg) -> Self::Service {
163        service
164    }
165}
166
/// Two middlewares combined into one.
///
/// When used as a middleware, `inner` wraps the service first and `outer`
/// wraps the result.
#[derive(Debug, Clone)]
pub struct Stack<Inner, Outer> {
    // Applied directly to the service.
    inner: Inner,
    // Applied to whatever `inner` produced.
    outer: Outer,
}

impl<Inner, Outer> Stack<Inner, Outer> {
    /// Combine `inner` and `outer` into a single stacked middleware.
    pub fn new(inner: Inner, outer: Outer) -> Self {
        Self { inner, outer }
    }
}
179
180impl<S, Inner, Outer, C> Middleware<S, C> for Stack<Inner, Outer>
181where
182    Inner: Middleware<S, C>,
183    Outer: Middleware<Inner::Service, C>,
184    C: Clone,
185{
186    type Service = Outer::Service;
187
188    fn create(&self, service: S, cfg: C) -> Self::Service {
189        self.outer
190            .create(self.inner.create(service, cfg.clone()), cfg)
191    }
192}
193
#[cfg(test)]
// NOTE(review): the extra `.clone()` calls below are presumably kept on purpose
// to exercise the `Clone` impls — confirm before removing the allow.
#[allow(clippy::redundant_clone)]
mod tests {
    use std::{cell::Cell, rc::Rc};

    use super::*;
    use crate::{Pipeline, ServiceCtx, fn_service};

    /// Test middleware: bumps the shared counter each time it wraps a service.
    #[derive(Debug, Clone)]
    struct Mw<R>(PhantomData<R>, Rc<Cell<usize>>);

    impl<S, R, C> Middleware<S, C> for Mw<R> {
        type Service = Srv<S, R>;

        fn create(&self, service: S, _: C) -> Self::Service {
            let counter = Rc::clone(&self.1);
            counter.set(counter.get() + 1);
            Srv(service, PhantomData, counter)
        }
    }

    /// Pass-through service produced by `Mw`; bumps the shared counter on shutdown.
    #[derive(Debug, Clone)]
    struct Srv<S, R>(S, PhantomData<R>, Rc<Cell<usize>>);

    impl<S: Service<R>, R> Service<R> for Srv<S, R> {
        type Response = S::Response;
        type Error = S::Error;

        async fn ready(&self, ctx: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
            ctx.ready(&self.0).await
        }

        async fn call(
            &self,
            req: R,
            ctx: ServiceCtx<'_, Self>,
        ) -> Result<Self::Response, Self::Error> {
            ctx.call(&self.0, req).await
        }

        async fn shutdown(&self) {
            let counter = &self.2;
            counter.set(counter.get() + 1);
        }
    }

    #[ntex::test]
    async fn middleware() {
        // Free-function form: `apply(middleware, factory)`.
        let counter = Rc::new(Cell::new(0));
        let factory = apply(
            Rc::new(Mw(PhantomData, counter.clone()).clone()),
            fn_service(|i: usize| async move { Ok::<_, ()>(i * 2) }),
        )
        .clone();

        let srv = Pipeline::new(factory.create(&()).await.unwrap().clone());
        assert_eq!(srv.call(10).await, Ok(20));
        let _ = format!("{factory:?} {srv:?}");

        assert_eq!(srv.ready().await, Ok(()));
        srv.shutdown().await;
        // One increment from `Mw::create`, one from `Srv::shutdown`.
        assert_eq!(counter.get(), 2);

        // Builder form: `chain_factory(..).apply(middleware)`.
        let factory =
            crate::chain_factory(fn_service(|i: usize| async move { Ok::<_, ()>(i * 2) }))
                .apply(Rc::new(Mw(PhantomData, Rc::new(Cell::new(0))).clone()))
                .clone();

        let srv = Pipeline::new(factory.create(&()).await.unwrap().clone());
        assert_eq!(srv.call(10).await, Ok(20));
        let _ = format!("{factory:?} {srv:?}");

        assert_eq!(srv.ready().await, Ok(()));
    }

    #[ntex::test]
    async fn middleware_chain() {
        let counter = Rc::new(Cell::new(0));
        let factory =
            crate::chain_factory(fn_service(|i: usize| async move { Ok::<_, ()>(i * 2) }))
                .apply(Mw(PhantomData, counter.clone()).clone());

        let srv = Pipeline::new(factory.create(&()).await.unwrap().clone());
        assert_eq!(srv.call(10).await, Ok(20));
        let _ = format!("{factory:?} {srv:?}");

        assert_eq!(srv.ready().await, Ok(()));
        srv.shutdown().await;
        // One increment from `Mw::create`, one from `Srv::shutdown`.
        assert_eq!(counter.get(), 2);
    }

    #[ntex::test]
    async fn stack() {
        let counter = Rc::new(Cell::new(0));
        let stacked = Stack::new(Identity, Mw(PhantomData, counter.clone()));
        let _ = format!("{stacked:?}");

        let wrapped = Middleware::create(
            &stacked,
            fn_service(|i: usize| async move { Ok::<_, ()>(i * 2) }),
            (),
        );
        let pl = Pipeline::new(wrapped);

        assert_eq!(pl.call(10).await, Ok(20));
        assert_eq!(pl.ready().await, Ok(()));
        pl.shutdown().await;
        // One increment from `Mw::create`, one from `Srv::shutdown`.
        assert_eq!(counter.get(), 2);
    }
}