ntex_util/services/timeout.rs

1//! Service that applies a timeout to requests.
2//!
3//! If the response does not complete within the specified timeout, the response
4//! will be aborted.
5use std::{fmt, marker};
6
7use ntex_service::{Middleware, Middleware2, Service, ServiceCtx};
8
9use crate::future::{Either, select};
10use crate::time::{Millis, sleep};
11
/// Middleware that applies a timeout to requests.
///
/// Wrapping a service with this middleware produces a [`TimeoutService`].
/// The timeout is disabled if it is set to 0: calls are then forwarded to the
/// inner service without any deadline.
#[derive(Debug)]
pub struct Timeout<E = ()> {
    /// Per-call deadline handed to every [`TimeoutService`] this middleware creates.
    timeout: Millis,
    /// Marker for the `E` type parameter; no value of `E` is stored.
    _t: marker::PhantomData<E>,
}
20
/// Error produced by [`TimeoutService`].
///
/// Either the wrapped service failed on its own ([`TimeoutError::Service`]),
/// or the call did not finish before the configured timeout elapsed
/// ([`TimeoutError::Timeout`]).
#[derive(Clone)]
pub enum TimeoutError<E> {
    /// The inner service returned this error.
    Service(E),
    /// The inner service did not respond within the timeout.
    Timeout,
}

impl<E> From<E> for TimeoutError<E> {
    /// Wraps an inner service error as [`TimeoutError::Service`].
    fn from(err: E) -> Self {
        TimeoutError::Service(err)
    }
}

impl<E: fmt::Debug> fmt::Debug for TimeoutError<E> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            TimeoutError::Service(e) => write!(f, "TimeoutError::Service({e:?})"),
            TimeoutError::Timeout => write!(f, "TimeoutError::Timeout"),
        }
    }
}

impl<E: fmt::Display> fmt::Display for TimeoutError<E> {
    /// Service errors are displayed transparently (as the inner error);
    /// timeouts display a fixed message.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            TimeoutError::Service(e) => e.fmt(f),
            TimeoutError::Timeout => write!(f, "Service call timeout"),
        }
    }
}

impl<E: fmt::Display + fmt::Debug> std::error::Error for TimeoutError<E> {}

impl<E: PartialEq> PartialEq for TimeoutError<E> {
    /// Two errors are equal when they are the same variant and, for
    /// [`TimeoutError::Service`], the inner errors compare equal.
    fn eq(&self, other: &TimeoutError<E>) -> bool {
        // Explicit arms (no catch-all) so adding a variant forces a review here.
        match (self, other) {
            (TimeoutError::Service(a), TimeoutError::Service(b)) => a == b,
            (TimeoutError::Timeout, TimeoutError::Timeout) => true,
            (TimeoutError::Service(_), TimeoutError::Timeout)
            | (TimeoutError::Timeout, TimeoutError::Service(_)) => false,
        }
    }
}

// `eq` is reflexive whenever `E`'s equality is, so `Eq` can be offered too.
impl<E: Eq> Eq for TimeoutError<E> {}
69
70impl Timeout {
71    pub fn new<T: Into<Millis>>(timeout: T) -> Self {
72        Timeout {
73            timeout: timeout.into(),
74            _t: marker::PhantomData,
75        }
76    }
77}
78
79impl Clone for Timeout {
80    fn clone(&self) -> Self {
81        Timeout {
82            timeout: self.timeout,
83            _t: marker::PhantomData,
84        }
85    }
86}
87
88impl<S> Middleware<S> for Timeout {
89    type Service = TimeoutService<S>;
90
91    fn create(&self, service: S) -> Self::Service {
92        TimeoutService {
93            service,
94            timeout: self.timeout,
95        }
96    }
97}
98
99impl<S, C> Middleware2<S, C> for Timeout {
100    type Service = TimeoutService<S>;
101
102    fn create(&self, service: S, _: C) -> Self::Service {
103        TimeoutService {
104            service,
105            timeout: self.timeout,
106        }
107    }
108}
109
/// Service wrapper that applies a timeout to requests.
///
/// Each call to the inner service is raced against a timer; if the timer
/// finishes first, the call yields [`TimeoutError::Timeout`]. A zero timeout
/// disables the timer, and calls are forwarded unchanged apart from wrapping
/// errors in [`TimeoutError::Service`].
#[derive(Debug, Clone)]
pub struct TimeoutService<S> {
    /// Inner service that handles the requests.
    service: S,
    /// Per-call deadline; zero means "no timeout".
    timeout: Millis,
}
116
117impl<S> TimeoutService<S> {
118    pub fn new<T, R>(timeout: T, service: S) -> Self
119    where
120        T: Into<Millis>,
121        S: Service<R>,
122    {
123        TimeoutService {
124            service,
125            timeout: timeout.into(),
126        }
127    }
128}
129
130impl<S, R> Service<R> for TimeoutService<S>
131where
132    S: Service<R>,
133{
134    type Response = S::Response;
135    type Error = TimeoutError<S::Error>;
136
137    async fn call(
138        &self,
139        request: R,
140        ctx: ServiceCtx<'_, Self>,
141    ) -> Result<Self::Response, Self::Error> {
142        if self.timeout.is_zero() {
143            ctx.call(&self.service, request)
144                .await
145                .map_err(TimeoutError::Service)
146        } else {
147            match select(sleep(self.timeout), ctx.call(&self.service, request)).await {
148                Either::Left(_) => Err(TimeoutError::Timeout),
149                Either::Right(res) => res.map_err(TimeoutError::Service),
150            }
151        }
152    }
153
154    ntex_service::forward_poll!(service, TimeoutError::Service);
155    ntex_service::forward_ready!(service, TimeoutError::Service);
156    ntex_service::forward_shutdown!(service);
157}
158
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use ntex_service::{Pipeline, ServiceFactory, apply, apply2, fn_factory};

    use super::*;

    /// Test service that sleeps for the wrapped duration, then succeeds.
    #[derive(Clone, Debug, PartialEq)]
    struct SleepService(Duration);

    /// Dummy error type for the test service.
    #[derive(Clone, Debug, PartialEq)]
    struct SrvError;

    impl fmt::Display for SrvError {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(f, "SrvError")
        }
    }

    impl Service<()> for SleepService {
        type Response = ();
        type Error = SrvError;

        async fn call(&self, _: (), _: ServiceCtx<'_, Self>) -> Result<(), SrvError> {
            crate::time::sleep(self.0).await;
            Ok::<_, SrvError>(())
        }
    }

    /// Inner call finishes before the deadline; readiness and shutdown are forwarded.
    #[ntex::test]
    async fn test_success() {
        let resolution = Duration::from_millis(100);
        let wait_time = Duration::from_millis(50);

        let timeout =
            Pipeline::new(TimeoutService::new(resolution, SleepService(wait_time)).clone());
        assert_eq!(timeout.call(()).await, Ok(()));
        assert_eq!(timeout.ready().await, Ok(()));
        timeout.shutdown().await;
    }

    /// A zero timeout disables the deadline entirely.
    #[ntex::test]
    async fn test_zero() {
        let wait_time = Duration::from_millis(50);
        let resolution = Duration::from_millis(0);

        let timeout =
            Pipeline::new(TimeoutService::new(resolution, SleepService(wait_time)));
        assert_eq!(timeout.call(()).await, Ok(()));
        assert_eq!(timeout.ready().await, Ok(()));
    }

    /// Inner call outlives the deadline and is reported as a timeout.
    #[ntex::test]
    async fn test_timeout() {
        let resolution = Duration::from_millis(100);
        let wait_time = Duration::from_millis(500);

        let timeout =
            Pipeline::new(TimeoutService::new(resolution, SleepService(wait_time)));
        assert_eq!(timeout.call(()).await, Err(TimeoutError::Timeout));
    }

    #[ntex::test]
    // The clone is deliberate: it exercises `Timeout`'s manual `Clone` impl.
    #[allow(clippy::redundant_clone)]
    async fn test_timeout_middleware() {
        let resolution = Duration::from_millis(100);
        let wait_time = Duration::from_millis(500);

        let timeout = apply(
            Timeout::new(resolution).clone(),
            fn_factory(|| async { Ok::<_, ()>(SleepService(wait_time)) }),
        );
        let srv = timeout.pipeline(&()).await.unwrap();

        let res = srv.call(()).await.unwrap_err();
        assert_eq!(res, TimeoutError::Timeout);
    }

    #[ntex::test]
    // The clone is deliberate: it exercises `Timeout`'s manual `Clone` impl.
    #[allow(clippy::redundant_clone)]
    async fn test_timeout_middleware2() {
        let resolution = Duration::from_millis(100);
        let wait_time = Duration::from_millis(500);

        let timeout = apply2(
            Timeout::new(resolution).clone(),
            fn_factory(|| async { Ok::<_, ()>(SleepService(wait_time)) }),
        );
        let srv = timeout.pipeline(&()).await.unwrap();

        let res = srv.call(()).await.unwrap_err();
        assert_eq!(res, TimeoutError::Timeout);
    }

    #[test]
    fn test_error() {
        let err1 = TimeoutError::<SrvError>::Timeout;
        assert!(format!("{err1:?}").contains("TimeoutError::Timeout"));
        assert!(format!("{err1}").contains("Service call timeout"));

        let err2: TimeoutError<_> = SrvError.into();
        assert!(format!("{err2:?}").contains("TimeoutError::Service"));
        assert!(format!("{err2}").contains("SrvError"));
    }

    /// `PartialEq` matches same-variant pairs and rejects mixed-variant pairs.
    #[test]
    fn test_error_eq() {
        assert_eq!(TimeoutError::<SrvError>::Timeout, TimeoutError::Timeout);
        assert_eq!(
            TimeoutError::Service(SrvError),
            TimeoutError::Service(SrvError)
        );
        assert_ne!(TimeoutError::Service(SrvError), TimeoutError::Timeout);
        assert_ne!(
            TimeoutError::<SrvError>::Timeout,
            TimeoutError::Service(SrvError)
        );
    }
}