ntex_util/services/
timeout.rs

1//! Service that applies a timeout to requests.
2//!
3//! If the response does not complete within the specified timeout, the response
4//! will be aborted.
5use std::{fmt, marker};
6
7use ntex_service::{Middleware, Service, ServiceCtx};
8
9use crate::future::{select, Either};
10use crate::time::{sleep, Millis};
11
12/// Applies a timeout to requests.
13///
14/// Timeout transform is disabled if timeout is set to 0
15#[derive(Debug)]
16pub struct Timeout<E = ()> {
17    timeout: Millis,
18    _t: marker::PhantomData<E>,
19}
20
/// Error produced by a service that runs under a timeout.
///
/// It either carries the wrapped service's own error, or reports that the
/// call did not finish before the timeout elapsed.
///
/// Two values compare equal when they are the same variant and, for
/// [`TimeoutError::Service`], the inner errors compare equal.
#[derive(PartialEq)]
pub enum TimeoutError<E> {
    /// Service error
    Service(E),
    /// Service call timeout
    Timeout,
}

/// Wraps any service error into [`TimeoutError::Service`], so `?` and
/// `.into()` can be used to lift inner errors.
impl<E> From<E> for TimeoutError<E> {
    fn from(err: E) -> Self {
        TimeoutError::Service(err)
    }
}

/// Debug output is prefixed with the enum name, e.g.
/// `TimeoutError::Service(..)` or `TimeoutError::Timeout`.
impl<E: fmt::Debug> fmt::Debug for TimeoutError<E> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            TimeoutError::Service(e) => write!(f, "TimeoutError::Service({e:?})"),
            TimeoutError::Timeout => write!(f, "TimeoutError::Timeout"),
        }
    }
}

/// A service error is displayed exactly as the inner error displays itself;
/// a timeout is displayed as a fixed message.
impl<E: fmt::Display> fmt::Display for TimeoutError<E> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // Spell out the trait so the call can never be confused with `Debug::fmt`.
            TimeoutError::Service(e) => fmt::Display::fmt(e, f),
            TimeoutError::Timeout => write!(f, "Service call timeout"),
        }
    }
}

impl<E: fmt::Display + fmt::Debug> std::error::Error for TimeoutError<E> {}
69
70impl Timeout {
71    pub fn new<T: Into<Millis>>(timeout: T) -> Self {
72        Timeout {
73            timeout: timeout.into(),
74            _t: marker::PhantomData,
75        }
76    }
77}
78
79impl Clone for Timeout {
80    fn clone(&self) -> Self {
81        Timeout {
82            timeout: self.timeout,
83            _t: marker::PhantomData,
84        }
85    }
86}
87
88impl<S> Middleware<S> for Timeout {
89    type Service = TimeoutService<S>;
90
91    fn create(&self, service: S) -> Self::Service {
92        TimeoutService {
93            service,
94            timeout: self.timeout,
95        }
96    }
97}
98
99/// Applies a timeout to requests.
100#[derive(Debug, Clone)]
101pub struct TimeoutService<S> {
102    service: S,
103    timeout: Millis,
104}
105
106impl<S> TimeoutService<S> {
107    pub fn new<T, R>(timeout: T, service: S) -> Self
108    where
109        T: Into<Millis>,
110        S: Service<R>,
111    {
112        TimeoutService {
113            service,
114            timeout: timeout.into(),
115        }
116    }
117}
118
119impl<S, R> Service<R> for TimeoutService<S>
120where
121    S: Service<R>,
122{
123    type Response = S::Response;
124    type Error = TimeoutError<S::Error>;
125
126    async fn call(
127        &self,
128        request: R,
129        ctx: ServiceCtx<'_, Self>,
130    ) -> Result<Self::Response, Self::Error> {
131        if self.timeout.is_zero() {
132            ctx.call(&self.service, request)
133                .await
134                .map_err(TimeoutError::Service)
135        } else {
136            match select(sleep(self.timeout), ctx.call(&self.service, request)).await {
137                Either::Left(_) => Err(TimeoutError::Timeout),
138                Either::Right(res) => res.map_err(TimeoutError::Service),
139            }
140        }
141    }
142
143    ntex_service::forward_poll!(service, TimeoutError::Service);
144    ntex_service::forward_ready!(service, TimeoutError::Service);
145    ntex_service::forward_shutdown!(service);
146}
147
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use ntex_service::{apply, fn_factory, Pipeline, ServiceFactory};

    use super::*;

    /// Test service that sleeps for the wrapped duration and then succeeds.
    #[derive(Clone, Debug, PartialEq)]
    struct SleepService(Duration);

    /// Error type declared by [`SleepService`].
    #[derive(Clone, Debug, PartialEq)]
    struct SrvError;

    impl fmt::Display for SrvError {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str("SrvError")
        }
    }

    impl Service<()> for SleepService {
        type Response = ();
        type Error = SrvError;

        async fn call(&self, _: (), _: ServiceCtx<'_, Self>) -> Result<(), SrvError> {
            let SleepService(delay) = self;
            crate::time::sleep(*delay).await;
            Ok(())
        }
    }

    /// The inner call (50ms) finishes within the limit (100ms).
    #[ntex_macros::rt_test2]
    async fn test_success() {
        let limit = Duration::from_millis(100);
        let delay = Duration::from_millis(50);

        let pipeline =
            Pipeline::new(TimeoutService::new(limit, SleepService(delay)).clone());
        assert_eq!(pipeline.call(()).await, Ok(()));
        assert_eq!(pipeline.ready().await, Ok(()));
        pipeline.shutdown().await;
    }

    /// A zero limit disables the timeout, so the call succeeds.
    #[ntex_macros::rt_test2]
    async fn test_zero() {
        let delay = Duration::from_millis(50);
        let limit = Duration::from_millis(0);

        let pipeline = Pipeline::new(TimeoutService::new(limit, SleepService(delay)));
        assert_eq!(pipeline.call(()).await, Ok(()));
        assert_eq!(pipeline.ready().await, Ok(()));
    }

    /// The inner call (500ms) outlasts the limit (100ms) and is reported as a timeout.
    #[ntex_macros::rt_test2]
    async fn test_timeout() {
        let limit = Duration::from_millis(100);
        let delay = Duration::from_millis(500);

        let pipeline = Pipeline::new(TimeoutService::new(limit, SleepService(delay)));
        assert_eq!(pipeline.call(()).await, Err(TimeoutError::Timeout));
    }

    /// Same timeout scenario, but built through the `Timeout` middleware.
    #[ntex_macros::rt_test2]
    // The `.clone()` below trips this lint; it is presumably kept on purpose
    // so that `Clone for Timeout` gets exercised.
    #[allow(clippy::redundant_clone)]
    async fn test_timeout_middleware() {
        let limit = Duration::from_millis(100);
        let delay = Duration::from_millis(500);

        let factory = apply(
            Timeout::new(limit).clone(),
            fn_factory(|| async { Ok::<_, ()>(SleepService(delay)) }),
        );
        let srv = factory.pipeline(&()).await.unwrap();

        let err = srv.call(()).await.unwrap_err();
        assert_eq!(err, TimeoutError::Timeout);
    }

    /// Checks the `Debug` and `Display` output of both error variants.
    #[test]
    fn test_error() {
        let timed_out = TimeoutError::<SrvError>::Timeout;
        assert!(format!("{timed_out:?}").contains("TimeoutError::Timeout"));
        assert!(format!("{timed_out}").contains("Service call timeout"));

        let wrapped: TimeoutError<_> = SrvError.into();
        assert!(format!("{wrapped:?}").contains("TimeoutError::Service"));
        assert!(format!("{wrapped}").contains("SrvError"));
    }
}