ntex_service/
middleware.rs

1use std::{fmt, marker::PhantomData, rc::Rc};
2
3use crate::{IntoServiceFactory, Service, ServiceFactory};
4
5/// Apply middleware to a service.
6pub fn apply<T, S, R, C, U>(t: T, factory: U) -> ApplyMiddleware<T, S, C>
7where
8    S: ServiceFactory<R, C>,
9    T: Middleware<S::Service>,
10    U: IntoServiceFactory<S, R, C>,
11{
12    ApplyMiddleware::new(t, factory.into_factory())
13}
14
/// A `Middleware` is a factory that wraps an inner service at construction time.
///
/// The wrapping service sits in front of the inner one, so it can act on the
/// inbound and/or outbound side of the request/response lifecycle and may
/// modify the request and/or the response.
///
/// As an illustration, here is a timeout service:
///
/// ```rust
/// use ntex_service::{Service, ServiceCtx};
/// use ntex_util::{time::sleep, future::Either, future::select};
///
/// pub struct Timeout<S> {
///     service: S,
///     timeout: std::time::Duration,
/// }
///
/// pub enum TimeoutError<E> {
///     Service(E),
///     Timeout,
/// }
///
/// impl<S, R> Service<R> for Timeout<S>
/// where
///     S: Service<R>,
/// {
///     type Response = S::Response;
///     type Error = TimeoutError<S::Error>;
///
///     async fn ready(&self, ctx: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
///         ctx.ready(&self.service).await.map_err(TimeoutError::Service)
///     }
///
///     async fn call(&self, req: R, ctx: ServiceCtx<'_, Self>) -> Result<Self::Response, Self::Error> {
///         match select(sleep(self.timeout), ctx.call(&self.service, req)).await {
///             Either::Left(_) => Err(TimeoutError::Timeout),
///             Either::Right(res) => res.map_err(TimeoutError::Service),
///         }
///     }
/// }
/// ```
///
/// The `Timeout` service above knows nothing about the concrete service it wraps,
/// so it can be placed in front of any service.
///
/// The `Middleware` trait is the factory side of that picture: it describes how a
/// middleware service is built. The factory receives the next service by value,
/// and the service it returns takes ownership of it.
///
/// A factory for the `Timeout` service shown above could be written as:
///
/// ```rust,ignore
/// pub struct TimeoutMiddleware {
///     timeout: std::time::Duration,
/// }
///
/// impl<S> Middleware<S> for TimeoutMiddleware
/// {
///     type Service = Timeout<S>;
///
///     fn create(&self, service: S) -> Self::Service {
///         Timeout {
///             service,
///             timeout: self.timeout,
///         }
///     }
/// }
/// ```
pub trait Middleware<S> {
    /// Type of the middleware service this factory produces.
    type Service;

    /// Build a middleware service around `service`, taking ownership of it.
    fn create(&self, service: S) -> Self::Service;
}
92
93impl<T, S> Middleware<S> for Rc<T>
94where
95    T: Middleware<S>,
96{
97    type Service = T::Service;
98
99    fn create(&self, service: S) -> T::Service {
100        self.as_ref().create(service)
101    }
102}
103
/// Service factory that pairs a middleware with an inner service factory.
///
/// The middleware (`T`) and the inner factory (`S`) are stored together behind a
/// single `Rc`. `C` is carried only as a type-level marker.
pub struct ApplyMiddleware<T, S, C>(Rc<(T, S)>, PhantomData<C>);
106
107impl<T, S, C> ApplyMiddleware<T, S, C> {
108    /// Create new `ApplyMiddleware` service factory instance
109    pub(crate) fn new(mw: T, svc: S) -> Self {
110        Self(Rc::new((mw, svc)), PhantomData)
111    }
112}
113
114impl<T, S, C> Clone for ApplyMiddleware<T, S, C> {
115    fn clone(&self) -> Self {
116        Self(self.0.clone(), PhantomData)
117    }
118}
119
120impl<T, S, C> fmt::Debug for ApplyMiddleware<T, S, C>
121where
122    T: fmt::Debug,
123    S: fmt::Debug,
124{
125    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
126        f.debug_struct("ApplyMiddleware")
127            .field("service", &self.0 .1)
128            .field("middleware", &self.0 .0)
129            .finish()
130    }
131}
132
133impl<T, S, R, C> ServiceFactory<R, C> for ApplyMiddleware<T, S, C>
134where
135    S: ServiceFactory<R, C>,
136    T: Middleware<S::Service>,
137    T::Service: Service<R>,
138{
139    type Response = <T::Service as Service<R>>::Response;
140    type Error = <T::Service as Service<R>>::Error;
141
142    type Service = T::Service;
143    type InitError = S::InitError;
144
145    #[inline]
146    async fn create(&self, cfg: C) -> Result<Self::Service, Self::InitError> {
147        Ok(self.0 .0.create(self.0 .1.create(cfg).await?))
148    }
149}
150
/// A middleware that does nothing.
///
/// The wrapped service is returned without modifications.
///
/// `Identity` is a unit struct, so it also implements `Default`; this lets it be
/// used wherever a default-constructible middleware is expected.
#[derive(Debug, Clone, Copy, Default)]
pub struct Identity;
156
157impl<S> Middleware<S> for Identity {
158    type Service = S;
159
160    #[inline]
161    fn create(&self, service: S) -> Self::Service {
162        service
163    }
164}
165
/// Two middlewares combined into one.
#[derive(Debug, Clone)]
pub struct Stack<Inner, Outer> {
    /// The `Inner` half of the pair.
    inner: Inner,
    /// The `Outer` half of the pair.
    outer: Outer,
}
172
173impl<Inner, Outer> Stack<Inner, Outer> {
174    pub fn new(inner: Inner, outer: Outer) -> Self {
175        Stack { inner, outer }
176    }
177}
178
179impl<S, Inner, Outer> Middleware<S> for Stack<Inner, Outer>
180where
181    Inner: Middleware<S>,
182    Outer: Middleware<Inner::Service>,
183{
184    type Service = Outer::Service;
185
186    fn create(&self, service: S) -> Self::Service {
187        self.outer.create(self.inner.create(service))
188    }
189}
190
#[cfg(test)]
// The extra `.clone()` calls below are deliberate: they exercise the `Clone` impls.
#[allow(clippy::redundant_clone)]
mod tests {
    use std::{cell::Cell, rc::Rc};

    use super::*;
    use crate::{fn_service, Pipeline, ServiceCtx};

    /// Test middleware; hands its shared counter to every `Srv` it creates.
    #[derive(Debug, Clone)]
    struct Tr<R>(PhantomData<R>, Rc<Cell<usize>>);

    impl<S, R> Middleware<S> for Tr<R> {
        type Service = Srv<S, R>;

        fn create(&self, service: S) -> Self::Service {
            let shutdowns = self.1.clone();
            Srv(service, PhantomData, shutdowns)
        }
    }

    /// Pass-through service that counts how many times `shutdown` was called.
    #[derive(Debug, Clone)]
    struct Srv<S, R>(S, PhantomData<R>, Rc<Cell<usize>>);

    impl<S: Service<R>, R> Service<R> for Srv<S, R> {
        type Response = S::Response;
        type Error = S::Error;

        async fn ready(&self, ctx: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
            ctx.ready(&self.0).await
        }

        async fn call(
            &self,
            req: R,
            ctx: ServiceCtx<'_, Self>,
        ) -> Result<S::Response, S::Error> {
            ctx.call(&self.0, req).await
        }

        async fn shutdown(&self) {
            let counter = &self.2;
            counter.set(counter.get() + 1);
        }
    }

    #[ntex::test]
    async fn middleware() {
        // Middleware applied through the free `apply` function.
        let shutdowns = Rc::new(Cell::new(0));
        let mw = Rc::new(Tr(PhantomData, shutdowns.clone()).clone());
        let doubler = fn_service(|i: usize| async move { Ok::<_, ()>(i * 2) });
        let factory = apply(mw, doubler).clone();

        let built = factory.create(&()).await.unwrap();
        let srv = Pipeline::new(built.clone());
        assert_eq!(srv.call(10).await, Ok(20));
        let _ = format!("{:?} {:?}", factory, srv);

        assert_eq!(srv.ready().await, Ok(()));
        srv.shutdown().await;
        assert_eq!(shutdowns.get(), 1);

        // Middleware applied through the chain builder.
        let mw = Rc::new(Tr(PhantomData, Rc::new(Cell::new(0))).clone());
        let doubler = fn_service(|i: usize| async move { Ok::<_, ()>(i * 2) });
        let factory = crate::chain_factory(doubler).apply(mw).clone();

        let built = factory.create(&()).await.unwrap();
        let srv = Pipeline::new(built.clone());
        assert_eq!(srv.call(10).await, Ok(20));
        let _ = format!("{:?} {:?}", factory, srv);

        assert_eq!(srv.ready().await, Ok(()));
    }
}