Skip to main content

ntex_service/
middleware.rs

1use std::{fmt, marker::PhantomData, rc::Rc};
2
3use crate::dev::{Apply, ApplyCtx, ServiceChainFactory};
4use crate::{IntoServiceFactory, Service, ServiceFactory};
5
6/// Apply middleware to a service.
7pub fn apply<M, S, R, C, U>(
8    mw: M,
9    factory: U,
10) -> ServiceChainFactory<ApplyMiddleware<M, S, C>, R, C>
11where
12    S: ServiceFactory<R, C>,
13    M: Middleware<S::Service, C>,
14    U: IntoServiceFactory<S, R, C>,
15{
16    ServiceChainFactory {
17        factory: ApplyMiddleware::new(mw, factory.into_factory()),
18        _t: PhantomData,
19    }
20}
21
22/// The `Middleware` trait defines the interface for a service factory
23/// that wraps an inner service during construction.
24///
25/// Middleware runs during inbound and/or outbound processing in the
26/// request/response lifecycle, and may modify the request and/or response.
27///
28/// For example, timeout middleware:
29///
30/// ```rust
31/// use ntex_service::{Service, ServiceCtx};
32/// use ntex::{time::sleep, util::Either, util::select};
33///
34/// pub struct Timeout<S> {
35///     service: S,
36///     timeout: std::time::Duration,
37/// }
38///
39/// pub enum TimeoutError<E> {
40///    Service(E),
41///    Timeout,
42/// }
43///
44/// impl<S, R> Service<R> for Timeout<S>
45/// where
46///     S: Service<R>,
47/// {
48///     type Response = S::Response;
49///     type Error = TimeoutError<S::Error>;
50///
51///     async fn ready(&self, ctx: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
52///         ctx.ready(&self.service).await.map_err(TimeoutError::Service)
53///     }
54///
55///     async fn call(&self, req: R, ctx: ServiceCtx<'_, Self>) -> Result<Self::Response, Self::Error> {
56///         match select(sleep(self.timeout), ctx.call(&self.service, req)).await {
57///             Either::Left(_) => Err(TimeoutError::Timeout),
58///             Either::Right(res) => res.map_err(TimeoutError::Service),
59///         }
60///     }
61/// }
62/// ```
63///
64/// The timeout service in the example above is decoupled from the underlying
65/// service implementation and can be applied to any service.
66///
67/// The `Middleware` trait defines the interface for a middleware factory,
68/// specifying how to construct a middleware `Service`. A service constructed
69/// by the factory takes the following service in the execution chain as a
70/// parameter, assuming ownership of that service.
71///
72/// Factory for `Timeout` middleware from the above example could look like this:
73///
74/// ```rust,ignore
75/// pub struct TimeoutMiddleware {
76///     timeout: std::time::Duration,
77/// }
78///
79/// impl<S> Middleware<S> for TimeoutMiddleware
80/// {
81///     type Service = Timeout<S>;
82///
83///     fn create(&self, service: S) -> Self::Service {
84///         Timeout {
85///             service,
86///             timeout: self.timeout,
87///         }
88///     }
89/// }
90/// ```
91pub trait Middleware<Svc, Cfg = ()> {
92    /// The middleware `Service` value created by this factory
93    type Service;
94
95    /// Creates and returns a new middleware service.
96    fn create(&self, service: Svc, cfg: Cfg) -> Self::Service;
97
98    /// Creates a service factory that instantiates a service and applies
99    /// the current middleware to it.
100    ///
101    /// This is equivalent to `apply(self, factory)`.
102    fn apply<Fac, Req>(
103        self,
104        factory: Fac,
105    ) -> ServiceChainFactory<ApplyMiddleware<Self, Fac, Cfg>, Req, Cfg>
106    where
107        Fac: ServiceFactory<Req, Cfg, Service = Svc>,
108        Cfg: Clone,
109        Self: Sized,
110        Self::Service: Service<Req>,
111    {
112        crate::chain_factory(ApplyMiddleware::new(self, factory))
113    }
114}
115
116impl<M, Svc, Cfg> Middleware<Svc, Cfg> for Rc<M>
117where
118    M: Middleware<Svc, Cfg>,
119{
120    type Service = M::Service;
121
122    fn create(&self, service: Svc, cfg: Cfg) -> M::Service {
123        self.as_ref().create(service, cfg)
124    }
125}
126
/// Service factory that pairs a middleware with an inner service factory.
///
/// The `(middleware, factory)` pair is stored behind a single `Rc`; `Cfg` is
/// only tracked at the type level.
pub struct ApplyMiddleware<M, Fac, Cfg>(Rc<(M, Fac)>, PhantomData<Cfg>);

impl<M, Fac, Cfg> ApplyMiddleware<M, Fac, Cfg> {
    /// Builds a new `ApplyMiddleware` from the middleware `mw` and the
    /// factory `fac`.
    pub(crate) fn new(mw: M, fac: Fac) -> Self {
        let shared = Rc::new((mw, fac));
        Self(shared, PhantomData)
    }
}
136
137impl<M, Fac, Cfg> Clone for ApplyMiddleware<M, Fac, Cfg> {
138    fn clone(&self) -> Self {
139        Self(self.0.clone(), PhantomData)
140    }
141}
142
143impl<M, Fac, Cfg> fmt::Debug for ApplyMiddleware<M, Fac, Cfg>
144where
145    M: fmt::Debug,
146    Fac: fmt::Debug,
147{
148    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
149        f.debug_struct("ApplyMiddleware")
150            .field("factory", &self.0.1)
151            .field("middleware", &self.0.0)
152            .finish()
153    }
154}
155
156impl<M, Fac, Req, Cfg> ServiceFactory<Req, Cfg> for ApplyMiddleware<M, Fac, Cfg>
157where
158    Fac: ServiceFactory<Req, Cfg>,
159    M: Middleware<Fac::Service, Cfg>,
160    M::Service: Service<Req>,
161    Cfg: Clone,
162{
163    type Response = <M::Service as Service<Req>>::Response;
164    type Error = <M::Service as Service<Req>>::Error;
165
166    type Service = M::Service;
167    type InitError = Fac::InitError;
168
169    #[inline]
170    async fn create(&self, cfg: Cfg) -> Result<Self::Service, Self::InitError> {
171        Ok(self.0.0.create(self.0.1.create(cfg.clone()).await?, cfg))
172    }
173}
174
175/// Identity is a middleware.
176///
177/// It returns service without modifications.
178#[derive(Debug, Clone, Copy)]
179pub struct Identity;
180
181impl<S, Cfg> Middleware<S, Cfg> for Identity {
182    type Service = S;
183
184    #[inline]
185    fn create(&self, service: S, _: Cfg) -> Self::Service {
186        service
187    }
188}
189
/// Pair of middlewares, `inner` and `outer`, kept together as one value.
#[derive(Debug, Clone)]
pub struct Stack<Inner, Outer> {
    inner: Inner,
    outer: Outer,
}

impl<Inner, Outer> Stack<Inner, Outer> {
    /// Builds a stack from an `inner` and an `outer` middleware.
    pub fn new(inner: Inner, outer: Outer) -> Self {
        Self { inner, outer }
    }
}
202
203impl<S, Inner, Outer, C> Middleware<S, C> for Stack<Inner, Outer>
204where
205    Inner: Middleware<S, C>,
206    Outer: Middleware<Inner::Service, C>,
207    C: Clone,
208{
209    type Service = Outer::Service;
210
211    fn create(&self, service: S, cfg: C) -> Self::Service {
212        self.outer
213            .create(self.inner.create(service, cfg.clone()), cfg)
214    }
215}
216
217#[doc(hidden)]
218/// Service factory that produces `middleware` from `Fn`.
219pub fn fn_layer<T, Req, F, In, Out, Err>(f: F) -> FnMiddleware<T, Req, F, In, Out, Err>
220where
221    F: AsyncFn(In, &ApplyCtx<'_, T>) -> Result<Out, Err> + Clone,
222{
223    FnMiddleware { f, r: PhantomData }
224}
225
/// Middleware backed by a function value.
///
/// Only `f` is stored; `T`, `Req`, `In`, `Out` and `Err` appear solely in a
/// `PhantomData` function-pointer marker.
#[allow(clippy::type_complexity)]
pub struct FnMiddleware<T, Req, F, In, Out, Err> {
    f: F,
    r: PhantomData<fn(T, Req) -> (In, Out, Err)>,
}

impl<T, Req, F, In, Out, Err> Clone for FnMiddleware<T, Req, F, In, Out, Err>
where
    F: Clone,
{
    /// Clones the stored function; only `F` has to be `Clone`.
    fn clone(&self) -> Self {
        let f = self.f.clone();
        Self { f, r: PhantomData }
    }
}
244
245impl<T, Req, F, In, Out, Err> fmt::Debug for FnMiddleware<T, Req, F, In, Out, Err> {
246    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
247        f.debug_struct("FnMiddleware")
248            .field("layer", &std::any::type_name::<F>())
249            .finish()
250    }
251}
252
253impl<T, C, R, F, In, Out, Err> Middleware<T, C> for FnMiddleware<T, R, F, In, Out, Err>
254where
255    T: Service<R>,
256    F: AsyncFn(In, &ApplyCtx<'_, T>) -> Result<Out, Err> + Clone,
257    Err: From<T::Error>,
258{
259    type Service = Apply<T, R, F, In, Out, Err>;
260
261    fn create(&self, service: T, _: C) -> Self::Service {
262        Apply::new(service, self.f.clone())
263    }
264}
265
#[cfg(test)]
// NOTE(review): the extra `.clone()` calls below look deliberate (they
// exercise the `Clone` impls) — confirm before removing any of them.
#[allow(clippy::redundant_clone)]
mod tests {
    use std::{cell::Cell, rc::Rc};

    use super::*;
    use crate::{Pipeline, ServiceCtx, fn_service};

    /// Test middleware: bumps the shared counter on every `create` call and
    /// hands the same counter to the `Srv` it builds.
    #[derive(Debug, Clone)]
    struct Mw<R>(PhantomData<R>, Rc<Cell<usize>>);

    impl<S, R, C> Middleware<S, C> for Mw<R> {
        type Service = Srv<S, R>;

        fn create(&self, service: S, _: C) -> Self::Service {
            let counter = &self.1;
            counter.set(counter.get() + 1);
            Srv(service, PhantomData, Rc::clone(counter))
        }
    }

    /// Pass-through service: forwards `ready` and `call` to the wrapped
    /// service and bumps the shared counter on `shutdown`.
    #[derive(Debug, Clone)]
    struct Srv<S, R>(S, PhantomData<R>, Rc<Cell<usize>>);

    impl<S, R> Service<R> for Srv<S, R>
    where
        S: Service<R>,
    {
        type Response = S::Response;
        type Error = S::Error;

        async fn ready(&self, ctx: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
            ctx.ready(&self.0).await
        }

        async fn call(
            &self,
            req: R,
            ctx: ServiceCtx<'_, Self>,
        ) -> Result<Self::Response, Self::Error> {
            ctx.call(&self.0, req).await
        }

        async fn shutdown(&self) {
            let counter = &self.2;
            counter.set(counter.get() + 1);
        }
    }

    #[ntex::test]
    async fn middleware() {
        // Free-function `apply` with an `Rc`-wrapped middleware.
        let counter = Rc::new(Cell::new(0));
        let mw = Rc::new(Mw(PhantomData, counter.clone()).clone());
        let factory =
            apply(mw, fn_service(|i: usize| async move { Ok::<_, ()>(i * 2) })).clone();

        let srv = Pipeline::new(factory.create(&()).await.unwrap().clone());
        assert_eq!(srv.call(10).await, Ok(20));
        let _ = format!("{factory:?} {srv:?}");

        assert_eq!(srv.ready().await, Ok(()));
        srv.shutdown().await;
        // One increment from `create`, one from `shutdown`.
        assert_eq!(counter.get(), 2);

        // Same middleware attached through the chain-factory builder.
        let mw = Rc::new(Mw(PhantomData, Rc::new(Cell::new(0))).clone());
        let factory =
            crate::chain_factory(fn_service(|i: usize| async move { Ok::<_, ()>(i * 2) }))
                .apply(mw)
                .clone();

        let srv = Pipeline::new(factory.create(&()).await.unwrap().clone());
        assert_eq!(srv.call(10).await, Ok(20));
        let _ = format!("{factory:?} {srv:?}");

        assert_eq!(srv.ready().await, Ok(()));
    }

    #[ntex::test]
    async fn middleware_apply() {
        let counter = Rc::new(Cell::new(0));
        let factory = Mw(PhantomData, counter.clone())
            .apply(fn_service(|i: usize| async move { Ok::<_, ()>(i * 2) }))
            .boxed();

        let srv = factory.pipeline(&()).await.unwrap();
        assert_eq!(srv.call(10).await, Ok(20));
        let _ = format!("{factory:?} {srv:?}");

        assert_eq!(srv.ready().await, Ok(()));
        srv.shutdown().await;
        assert_eq!(counter.get(), 2);
    }

    #[ntex::test]
    async fn middleware_chain() {
        let counter = Rc::new(Cell::new(0));
        let factory =
            crate::chain_factory(fn_service(|i: usize| async move { Ok::<_, ()>(i * 2) }))
                .apply(Mw(PhantomData, counter.clone()).clone());

        let srv = Pipeline::new(factory.create(&()).await.unwrap().clone());
        assert_eq!(srv.call(10).await, Ok(20));
        let _ = format!("{factory:?} {srv:?}");

        assert_eq!(srv.ready().await, Ok(()));
        srv.shutdown().await;
        assert_eq!(counter.get(), 2);
    }

    #[ntex::test]
    async fn stack() {
        let counter = Rc::new(Cell::new(0));
        let mw = Stack::new(Identity, Mw(PhantomData, counter.clone()));
        let _ = format!("{mw:?}");

        let svc = Middleware::create(
            &mw,
            fn_service(|i: usize| async move { Ok::<_, ()>(i * 2) }),
            (),
        );
        let pl = Pipeline::new(svc);

        assert_eq!(pl.call(10).await, Ok(20));
        assert_eq!(pl.ready().await, Ok(()));
        pl.shutdown().await;
        assert_eq!(counter.get(), 2);
    }

    #[ntex::test]
    async fn fn_middleware_service() {
        let counter = Rc::new(Cell::new(0));
        let counter2 = counter.clone();
        let mw = fn_layer(async move |req: &'static str, svc| {
            counter2.set(counter2.get() + 1);
            let result = svc.call(1).await?;
            Ok::<_, ()>((req, result))
        })
        .clone();
        let _ = format!("{mw:?}");

        let svc = Pipeline::new(
            mw.create(fn_service(async move |i: usize| Ok::<_, ()>(i * 2)), ()),
        );

        assert_eq!(svc.call("test").await, Ok(("test", 2)));
        let _ = format!("{svc:?}");

        assert_eq!(svc.ready().await, Ok(()));
        svc.shutdown().await;
        // Only the layer closure touches this counter, once per call.
        assert_eq!(counter.get(), 1);
    }
}
423}