1use std::{fmt, marker};
6
7use ntex_service::{Middleware, Middleware2, Service, ServiceCtx};
8
9use crate::future::{Either, select};
10use crate::time::{Millis, sleep};
11
/// Middleware that bounds how long each call to the wrapped service may take.
///
/// Wrapping a service with this middleware produces a [`TimeoutService`].
/// A zero timeout disables the limit: calls are forwarded without a timer.
#[derive(Debug)]
pub struct Timeout<E = ()> {
    /// Maximum duration allowed for a single call.
    timeout: Millis,
    /// Type marker only; no value of `E` is ever stored.
    _t: marker::PhantomData<E>,
}
20
/// Error produced by [`TimeoutService`].
///
/// Either carries the wrapped service's own error, or reports that the call
/// did not finish before the configured timeout elapsed.
pub enum TimeoutError<E> {
    /// The wrapped service reported an error.
    Service(E),
    /// The wrapped service did not respond within the configured timeout.
    Timeout,
}

/// Wraps an inner service error, which also makes `?` conversions work.
impl<E> From<E> for TimeoutError<E> {
    fn from(err: E) -> Self {
        TimeoutError::Service(err)
    }
}

impl<E: fmt::Debug> fmt::Debug for TimeoutError<E> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            TimeoutError::Service(e) => write!(f, "TimeoutError::Service({e:?})"),
            TimeoutError::Timeout => write!(f, "TimeoutError::Timeout"),
        }
    }
}

/// Service errors display exactly like the inner error; timeouts display as
/// `Service call timeout`.
impl<E: fmt::Display> fmt::Display for TimeoutError<E> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // Forward with the caller's formatter so width/precision flags apply.
            TimeoutError::Service(e) => fmt::Display::fmt(e, f),
            TimeoutError::Timeout => write!(f, "Service call timeout"),
        }
    }
}

// `E` is only required to be `Display + Debug`, not `Error`, so no `source()`
// can be exposed here.
impl<E: fmt::Display + fmt::Debug> std::error::Error for TimeoutError<E> {}

/// Two errors are equal when they are the same variant and, for
/// `Service`, the wrapped errors compare equal.
impl<E: PartialEq> PartialEq for TimeoutError<E> {
    fn eq(&self, other: &TimeoutError<E>) -> bool {
        match (self, other) {
            (TimeoutError::Service(lhs), TimeoutError::Service(rhs)) => lhs == rhs,
            (TimeoutError::Timeout, TimeoutError::Timeout) => true,
            (TimeoutError::Service(_), TimeoutError::Timeout)
            | (TimeoutError::Timeout, TimeoutError::Service(_)) => false,
        }
    }
}

/// The equality above is reflexive, symmetric and transitive whenever the
/// inner error's equality is, so `TimeoutError<E>` is `Eq` when `E` is.
impl<E: Eq> Eq for TimeoutError<E> {}
69
70impl Timeout {
71 pub fn new<T: Into<Millis>>(timeout: T) -> Self {
72 Timeout {
73 timeout: timeout.into(),
74 _t: marker::PhantomData,
75 }
76 }
77}
78
79impl Clone for Timeout {
80 fn clone(&self) -> Self {
81 Timeout {
82 timeout: self.timeout,
83 _t: marker::PhantomData,
84 }
85 }
86}
87
88impl<S> Middleware<S> for Timeout {
89 type Service = TimeoutService<S>;
90
91 fn create(&self, service: S) -> Self::Service {
92 TimeoutService {
93 service,
94 timeout: self.timeout,
95 }
96 }
97}
98
99impl<S, C> Middleware2<S, C> for Timeout {
100 type Service = TimeoutService<S>;
101
102 fn create(&self, service: S, _: C) -> Self::Service {
103 TimeoutService {
104 service,
105 timeout: self.timeout,
106 }
107 }
108}
109
/// Service wrapper that fails a call with [`TimeoutError::Timeout`] when the
/// inner service does not respond within `timeout`.
///
/// A zero `timeout` disables the limit.
#[derive(Debug, Clone)]
pub struct TimeoutService<S> {
    /// The wrapped service.
    service: S,
    /// Per-call limit; zero means "no limit".
    timeout: Millis,
}
116
117impl<S> TimeoutService<S> {
118 pub fn new<T, R>(timeout: T, service: S) -> Self
119 where
120 T: Into<Millis>,
121 S: Service<R>,
122 {
123 TimeoutService {
124 service,
125 timeout: timeout.into(),
126 }
127 }
128}
129
130impl<S, R> Service<R> for TimeoutService<S>
131where
132 S: Service<R>,
133{
134 type Response = S::Response;
135 type Error = TimeoutError<S::Error>;
136
137 async fn call(
138 &self,
139 request: R,
140 ctx: ServiceCtx<'_, Self>,
141 ) -> Result<Self::Response, Self::Error> {
142 if self.timeout.is_zero() {
143 ctx.call(&self.service, request)
144 .await
145 .map_err(TimeoutError::Service)
146 } else {
147 match select(sleep(self.timeout), ctx.call(&self.service, request)).await {
148 Either::Left(_) => Err(TimeoutError::Timeout),
149 Either::Right(res) => res.map_err(TimeoutError::Service),
150 }
151 }
152 }
153
154 ntex_service::forward_poll!(service, TimeoutError::Service);
155 ntex_service::forward_ready!(service, TimeoutError::Service);
156 ntex_service::forward_shutdown!(service);
157}
158
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use ntex_service::{Pipeline, ServiceFactory, apply, apply2, fn_factory};

    use super::*;

    /// Service that sleeps for the given duration and then succeeds.
    #[derive(Clone, Debug, PartialEq)]
    struct SleepService(Duration);

    /// Service that fails immediately; used to check error propagation.
    #[derive(Clone, Debug, PartialEq)]
    struct ErrService;

    #[derive(Clone, Debug, PartialEq)]
    struct SrvError;

    impl fmt::Display for SrvError {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(f, "SrvError")
        }
    }

    impl Service<()> for SleepService {
        type Response = ();
        type Error = SrvError;

        async fn call(&self, _: (), _: ServiceCtx<'_, Self>) -> Result<(), SrvError> {
            crate::time::sleep(self.0).await;
            Ok::<_, SrvError>(())
        }
    }

    impl Service<()> for ErrService {
        type Response = ();
        type Error = SrvError;

        async fn call(&self, _: (), _: ServiceCtx<'_, Self>) -> Result<(), SrvError> {
            Err(SrvError)
        }
    }

    #[ntex::test]
    async fn test_success() {
        let resolution = Duration::from_millis(100);
        let wait_time = Duration::from_millis(50);

        let timeout =
            Pipeline::new(TimeoutService::new(resolution, SleepService(wait_time)).clone());
        assert_eq!(timeout.call(()).await, Ok(()));
        assert_eq!(timeout.ready().await, Ok(()));
        timeout.shutdown().await;
    }

    #[ntex::test]
    async fn test_zero() {
        let wait_time = Duration::from_millis(50);
        let resolution = Duration::from_millis(0);

        let timeout = Pipeline::new(TimeoutService::new(resolution, SleepService(wait_time)));
        assert_eq!(timeout.call(()).await, Ok(()));
        assert_eq!(timeout.ready().await, Ok(()));
    }

    #[ntex::test]
    async fn test_timeout() {
        let resolution = Duration::from_millis(100);
        let wait_time = Duration::from_millis(500);

        let timeout = Pipeline::new(TimeoutService::new(resolution, SleepService(wait_time)));
        assert_eq!(timeout.call(()).await, Err(TimeoutError::Timeout));
    }

    /// Inner service errors must surface as `TimeoutError::Service` on both
    /// the timer path (non-zero timeout) and the bypass path (zero timeout).
    #[ntex::test]
    async fn test_service_error() {
        let srv = Pipeline::new(TimeoutService::new(Duration::from_millis(100), ErrService));
        assert_eq!(srv.call(()).await, Err(TimeoutError::Service(SrvError)));

        let srv = Pipeline::new(TimeoutService::new(Duration::from_millis(0), ErrService));
        assert_eq!(srv.call(()).await, Err(TimeoutError::Service(SrvError)));
    }

    #[ntex::test]
    #[allow(clippy::redundant_clone)]
    async fn test_timeout_middleware() {
        let resolution = Duration::from_millis(100);
        let wait_time = Duration::from_millis(500);

        let timeout = apply(
            Timeout::new(resolution).clone(),
            fn_factory(|| async { Ok::<_, ()>(SleepService(wait_time)) }),
        );
        let srv = timeout.pipeline(&()).await.unwrap();

        let res = srv.call(()).await.unwrap_err();
        assert_eq!(res, TimeoutError::Timeout);
    }

    #[ntex::test]
    #[allow(clippy::redundant_clone)]
    async fn test_timeout_middleware2() {
        let resolution = Duration::from_millis(100);
        let wait_time = Duration::from_millis(500);

        let timeout = apply2(
            Timeout::new(resolution).clone(),
            fn_factory(|| async { Ok::<_, ()>(SleepService(wait_time)) }),
        );
        let srv = timeout.pipeline(&()).await.unwrap();

        let res = srv.call(()).await.unwrap_err();
        assert_eq!(res, TimeoutError::Timeout);
    }

    #[test]
    fn test_error() {
        let err1 = TimeoutError::<SrvError>::Timeout;
        assert!(format!("{err1:?}").contains("TimeoutError::Timeout"));
        assert!(format!("{err1}").contains("Service call timeout"));

        let err2: TimeoutError<_> = SrvError.into();
        assert!(format!("{err2:?}").contains("TimeoutError::Service"));
        assert!(format!("{err2}").contains("SrvError"));

        // Equality distinguishes variants and compares wrapped errors.
        assert_ne!(err1, err2);
        assert_eq!(err2, TimeoutError::Service(SrvError));
        assert_eq!(err1, TimeoutError::Timeout);
    }
}