Skip to main content

ntex_service/
middleware.rs

1use std::{fmt, marker::PhantomData, rc::Rc};
2
3use crate::{IntoServiceFactory, Service, ServiceFactory, dev::ServiceChainFactory};
4
5/// Apply middleware to a service.
6pub fn apply<M, S, R, C, U>(
7    mw: M,
8    factory: U,
9) -> ServiceChainFactory<ApplyMiddleware<M, S, C>, R, C>
10where
11    S: ServiceFactory<R, C>,
12    M: Middleware<S::Service, C>,
13    U: IntoServiceFactory<S, R, C>,
14{
15    ServiceChainFactory {
16        factory: ApplyMiddleware::new(mw, factory.into_factory()),
17        _t: PhantomData,
18    }
19}
20
21/// The `Middleware` trait defines the interface for a service factory
22/// that wraps an inner service during construction.
23///
24/// Middleware runs during inbound and/or outbound processing in the
25/// request/response lifecycle, and may modify the request and/or response.
26///
27/// For example, timeout middleware:
28///
29/// ```rust
30/// use ntex_service::{Service, ServiceCtx};
31/// use ntex::{time::sleep, util::Either, util::select};
32///
33/// pub struct Timeout<S> {
34///     service: S,
35///     timeout: std::time::Duration,
36/// }
37///
38/// pub enum TimeoutError<E> {
39///    Service(E),
40///    Timeout,
41/// }
42///
43/// impl<S, R> Service<R> for Timeout<S>
44/// where
45///     S: Service<R>,
46/// {
47///     type Response = S::Response;
48///     type Error = TimeoutError<S::Error>;
49///
50///     async fn ready(&self, ctx: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
51///         ctx.ready(&self.service).await.map_err(TimeoutError::Service)
52///     }
53///
54///     async fn call(&self, req: R, ctx: ServiceCtx<'_, Self>) -> Result<Self::Response, Self::Error> {
55///         match select(sleep(self.timeout), ctx.call(&self.service, req)).await {
56///             Either::Left(_) => Err(TimeoutError::Timeout),
57///             Either::Right(res) => res.map_err(TimeoutError::Service),
58///         }
59///     }
60/// }
61/// ```
62///
63/// The timeout service in the example above is decoupled from the underlying
64/// service implementation and can be applied to any service.
65///
66/// The `Middleware` trait defines the interface for a middleware factory,
67/// specifying how to construct a middleware `Service`. A service constructed
68/// by the factory takes the following service in the execution chain as a
69/// parameter, assuming ownership of that service.
70///
71/// Factory for `Timeout` middleware from the above example could look like this:
72///
73/// ```rust,ignore
74/// pub struct TimeoutMiddleware {
75///     timeout: std::time::Duration,
76/// }
77///
78/// impl<S> Middleware<S> for TimeoutMiddleware
79/// {
80///     type Service = Timeout<S>;
81///
82///     fn create(&self, service: S) -> Self::Service {
83///         Timeout {
84///             service,
85///             timeout: self.timeout,
86///         }
87///     }
88/// }
89/// ```
90pub trait Middleware<Svc, Cfg = ()> {
91    /// The middleware `Service` value created by this factory
92    type Service;
93
94    /// Creates and returns a new middleware service.
95    fn create(&self, service: Svc, cfg: Cfg) -> Self::Service;
96
97    /// Creates a service factory that instantiates a service and applies
98    /// the current middleware to it.
99    ///
100    /// This is equivalent to `apply(self, factory)`.
101    fn apply<Fac, Req>(
102        self,
103        factory: Fac,
104    ) -> ServiceChainFactory<ApplyMiddleware<Self, Fac, Cfg>, Req, Cfg>
105    where
106        Fac: ServiceFactory<Req, Cfg, Service = Svc>,
107        Cfg: Clone,
108        Self: Sized,
109        Self::Service: Service<Req>,
110    {
111        crate::chain_factory(ApplyMiddleware::new(self, factory))
112    }
113}
114
115impl<M, Svc, Cfg> Middleware<Svc, Cfg> for Rc<M>
116where
117    M: Middleware<Svc, Cfg>,
118{
119    type Service = M::Service;
120
121    fn create(&self, service: Svc, cfg: Cfg) -> M::Service {
122        self.as_ref().create(service, cfg)
123    }
124}
125
/// Service factory that builds a service with `Fac` and wraps it with the
/// middleware `M`.
///
/// The middleware and the factory are stored together behind one `Rc`.
pub struct ApplyMiddleware<M, Fac, Cfg>(Rc<(M, Fac)>, PhantomData<Cfg>);

impl<M, Fac, Cfg> ApplyMiddleware<M, Fac, Cfg> {
    /// Pairs middleware `mw` with service factory `fac`.
    pub(crate) fn new(mw: M, fac: Fac) -> Self {
        let shared = Rc::new((mw, fac));
        ApplyMiddleware(shared, PhantomData)
    }
}
135
136impl<M, Fac, Cfg> Clone for ApplyMiddleware<M, Fac, Cfg> {
137    fn clone(&self) -> Self {
138        Self(self.0.clone(), PhantomData)
139    }
140}
141
142impl<M, Fac, Cfg> fmt::Debug for ApplyMiddleware<M, Fac, Cfg>
143where
144    M: fmt::Debug,
145    Fac: fmt::Debug,
146{
147    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
148        f.debug_struct("ApplyMiddleware")
149            .field("factory", &self.0.1)
150            .field("middleware", &self.0.0)
151            .finish()
152    }
153}
154
155impl<M, Fac, Req, Cfg> ServiceFactory<Req, Cfg> for ApplyMiddleware<M, Fac, Cfg>
156where
157    Fac: ServiceFactory<Req, Cfg>,
158    M: Middleware<Fac::Service, Cfg>,
159    M::Service: Service<Req>,
160    Cfg: Clone,
161{
162    type Response = <M::Service as Service<Req>>::Response;
163    type Error = <M::Service as Service<Req>>::Error;
164
165    type Service = M::Service;
166    type InitError = Fac::InitError;
167
168    #[inline]
169    async fn create(&self, cfg: Cfg) -> Result<Self::Service, Self::InitError> {
170        Ok(self.0.0.create(self.0.1.create(cfg.clone()).await?, cfg))
171    }
172}
173
174/// Identity is a middleware.
175///
176/// It returns service without modifications.
177#[derive(Debug, Clone, Copy)]
178pub struct Identity;
179
180impl<S, Cfg> Middleware<S, Cfg> for Identity {
181    type Service = S;
182
183    #[inline]
184    fn create(&self, service: S, _: Cfg) -> Self::Service {
185        service
186    }
187}
188
/// A composition of two middlewares.
///
/// `inner` is applied to the service first and `outer` is then applied to
/// the result, so `outer` forms the outermost layer of the final service.
#[derive(Debug, Clone)]
pub struct Stack<Inner, Outer> {
    inner: Inner,
    outer: Outer,
}

impl<Inner, Outer> Stack<Inner, Outer> {
    /// Builds a stack from the `inner` and `outer` middlewares.
    pub fn new(inner: Inner, outer: Outer) -> Self {
        Self { inner, outer }
    }
}
201
202impl<S, Inner, Outer, C> Middleware<S, C> for Stack<Inner, Outer>
203where
204    Inner: Middleware<S, C>,
205    Outer: Middleware<Inner::Service, C>,
206    C: Clone,
207{
208    type Service = Outer::Service;
209
210    fn create(&self, service: S, cfg: C) -> Self::Service {
211        self.outer
212            .create(self.inner.create(service, cfg.clone()), cfg)
213    }
214}
215
#[cfg(test)]
// The "redundant" `.clone()` calls below are deliberate: they exercise the
// `Clone` impls of the middleware, factory, and service types under test.
#[allow(clippy::redundant_clone)]
mod tests {
    use std::{cell::Cell, rc::Rc};

    use super::*;
    use crate::{Pipeline, ServiceCtx, fn_service};

    /// Test middleware: every `create` call bumps the shared counter and
    /// wraps the inner service in `Srv`, which holds the same counter.
    #[derive(Debug, Clone)]
    struct Mw<R>(PhantomData<R>, Rc<Cell<usize>>);

    impl<S, R, C> Middleware<S, C> for Mw<R> {
        type Service = Srv<S, R>;

        fn create(&self, service: S, _: C) -> Self::Service {
            self.1.set(self.1.get() + 1);
            Srv(service, PhantomData, self.1.clone())
        }
    }

    /// Test service: forwards `ready`/`call` to the inner service; `shutdown`
    /// does not forward, it only bumps the shared counter so tests can observe
    /// that it ran.
    #[derive(Debug, Clone)]
    struct Srv<S, R>(S, PhantomData<R>, Rc<Cell<usize>>);

    impl<S: Service<R>, R> Service<R> for Srv<S, R> {
        type Response = S::Response;
        type Error = S::Error;

        async fn ready(&self, ctx: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
            ctx.ready(&self.0).await
        }

        async fn call(
            &self,
            req: R,
            ctx: ServiceCtx<'_, Self>,
        ) -> Result<S::Response, S::Error> {
            ctx.call(&self.0, req).await
        }

        async fn shutdown(&self) {
            self.2.set(self.2.get() + 1);
        }
    }

    // Exercises the free `apply` function with an `Rc`-wrapped middleware,
    // then the chain factory's `.apply(...)` method.
    #[ntex::test]
    async fn middleware() {
        let cnt_sht = Rc::new(Cell::new(0));
        let factory = apply(
            Rc::new(Mw(PhantomData, cnt_sht.clone()).clone()),
            fn_service(|i: usize| async move { Ok::<_, ()>(i * 2) }),
        )
        .clone();

        let srv = Pipeline::new(factory.create(&()).await.unwrap().clone());
        let res = srv.call(10).await;
        assert!(res.is_ok());
        assert_eq!(res.unwrap(), 20);
        let _ = format!("{factory:?} {srv:?}");

        assert_eq!(srv.ready().await, Ok(()));
        srv.shutdown().await;
        // One increment from `Mw::create` plus one from `Srv::shutdown`.
        assert_eq!(cnt_sht.get(), 2);

        // Second half: same service built through the chain's `.apply(...)`.
        // Its counter is a fresh one that is not inspected afterwards.
        let factory =
            crate::chain_factory(fn_service(|i: usize| async move { Ok::<_, ()>(i * 2) }))
                .apply(Rc::new(Mw(PhantomData, Rc::new(Cell::new(0))).clone()))
                .clone();

        let srv = Pipeline::new(factory.create(&()).await.unwrap().clone());
        let res = srv.call(10).await;
        assert!(res.is_ok());
        assert_eq!(res.unwrap(), 20);
        let _ = format!("{factory:?} {srv:?}");

        assert_eq!(srv.ready().await, Ok(()));
    }

    // Exercises the `Middleware::apply` trait method on a boxed factory.
    #[ntex::test]
    async fn middleware_apply() {
        let cnt_sht = Rc::new(Cell::new(0));
        let factory = Mw(PhantomData, cnt_sht.clone())
            .apply(fn_service(|i: usize| async move { Ok::<_, ()>(i * 2) }))
            .boxed();

        let srv = factory.pipeline(&()).await.unwrap();
        let res = srv.call(10).await;
        assert!(res.is_ok());
        assert_eq!(res.unwrap(), 20);
        let _ = format!("{factory:?} {srv:?}");

        assert_eq!(srv.ready().await, Ok(()));
        srv.shutdown().await;
        // One increment from `Mw::create` plus one from `Srv::shutdown`.
        assert_eq!(cnt_sht.get(), 2);
    }

    // Exercises the chain factory's `.apply(...)` with a plain (non-`Rc`) middleware.
    #[ntex::test]
    async fn middleware_chain() {
        let cnt_sht = Rc::new(Cell::new(0));
        let factory =
            crate::chain_factory(fn_service(|i: usize| async move { Ok::<_, ()>(i * 2) }))
                .apply(Mw(PhantomData, cnt_sht.clone()).clone());

        let srv = Pipeline::new(factory.create(&()).await.unwrap().clone());
        let res = srv.call(10).await;
        assert!(res.is_ok());
        assert_eq!(res.unwrap(), 20);
        let _ = format!("{factory:?} {srv:?}");

        assert_eq!(srv.ready().await, Ok(()));
        srv.shutdown().await;
        // One increment from `Mw::create` plus one from `Srv::shutdown`.
        assert_eq!(cnt_sht.get(), 2);
    }

    // Exercises `Stack` with `Identity` as the inner layer, calling
    // `Middleware::create` directly instead of going through a factory.
    #[ntex::test]
    async fn stack() {
        let cnt_sht = Rc::new(Cell::new(0));
        let mw = Stack::new(Identity, Mw(PhantomData, cnt_sht.clone()));
        let _ = format!("{mw:?}");

        let pl = Pipeline::new(Middleware::create(
            &mw,
            fn_service(|i: usize| async move { Ok::<_, ()>(i * 2) }),
            (),
        ));
        let res = pl.call(10).await;
        assert!(res.is_ok());
        assert_eq!(res.unwrap(), 20);
        assert_eq!(pl.ready().await, Ok(()));
        pl.shutdown().await;
        // One increment from `Mw::create` plus one from `Srv::shutdown`.
        assert_eq!(cnt_sht.get(), 2);
    }
}