actori_utils/
timeout.rs

//! Service that applies a timeout to requests.
//!
//! If the inner service's response does not complete within the specified timeout,
//! the call stops waiting for it and resolves to `TimeoutError::Timeout` instead.
5use std::future::Future;
6use std::marker::PhantomData;
7use std::pin::Pin;
8use std::task::{Context, Poll};
9use std::{fmt, time};
10
11use actori_rt::time::{delay_for, Delay};
12use actori_service::{IntoService, Service, Transform};
13use futures::future::{ok, Ready};
14
/// Transform that applies a timeout to requests.
///
/// Every service built from it wraps the inner service in a [`TimeoutService`]
/// using this timeout. `E` is the transform's `InitError` type; building the
/// wrapper always succeeds, so no `E` value is ever produced or stored.
pub struct Timeout<E = ()> {
    /// Maximum time a single call may take before it fails with a timeout.
    timeout: time::Duration,
    /// Type marker for `E` only.
    _t: PhantomData<E>,
}

// Written by hand instead of `#[derive(Debug)]`: the derive would demand
// `E: Debug` even though only a `PhantomData<E>` is stored. `PhantomData<E>` is
// `Debug` for every `E`, so the output is the same as the derived form.
impl<E> fmt::Debug for Timeout<E> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Timeout")
            .field("timeout", &self.timeout)
            .field("_t", &self._t)
            .finish()
    }
}
21
/// Error produced by [`TimeoutService`] and its response future.
pub enum TimeoutError<E> {
    /// The wrapped service failed, either in `poll_ready` or in its response future.
    Service(E),
    /// The wrapped service's response did not complete before the timeout elapsed.
    Timeout,
}

/// Wraps an inner service error, so `?` can lift an `E` into `TimeoutError<E>`.
impl<E> From<E> for TimeoutError<E> {
    fn from(err: E) -> Self {
        TimeoutError::Service(err)
    }
}

impl<E: fmt::Debug> fmt::Debug for TimeoutError<E> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            TimeoutError::Service(e) => write!(f, "TimeoutError::Service({:?})", e),
            TimeoutError::Timeout => f.write_str("TimeoutError::Timeout"),
        }
    }
}

impl<E: fmt::Display> fmt::Display for TimeoutError<E> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // Reuse the caller's formatter so width/fill/precision flags reach the
            // inner error unchanged.
            TimeoutError::Service(e) => fmt::Display::fmt(e, f),
            TimeoutError::Timeout => f.write_str("Service call timeout"),
        }
    }
}

/// Makes `TimeoutError` usable as a standard error (e.g. `Box<dyn Error>`).
/// No `source()` is reported because `E` is not required to implement `Error`.
impl<E: fmt::Debug + fmt::Display> std::error::Error for TimeoutError<E> {}

impl<E: PartialEq> PartialEq for TimeoutError<E> {
    fn eq(&self, other: &TimeoutError<E>) -> bool {
        // Exhaustive over both variants so a new variant forces this to be revisited.
        match (self, other) {
            (TimeoutError::Service(a), TimeoutError::Service(b)) => a == b,
            (TimeoutError::Timeout, TimeoutError::Timeout) => true,
            (TimeoutError::Service(_), TimeoutError::Timeout)
            | (TimeoutError::Timeout, TimeoutError::Service(_)) => false,
        }
    }
}

/// Equality on `TimeoutError<E>` is an equivalence relation whenever `E`'s is.
impl<E: Eq> Eq for TimeoutError<E> {}
68
69impl<E> Timeout<E> {
70    pub fn new(timeout: time::Duration) -> Self {
71        Timeout {
72            timeout,
73            _t: PhantomData,
74        }
75    }
76}
77
78impl<E> Clone for Timeout<E> {
79    fn clone(&self) -> Self {
80        Timeout::new(self.timeout)
81    }
82}
83
84impl<S, E> Transform<S> for Timeout<E>
85where
86    S: Service,
87{
88    type Request = S::Request;
89    type Response = S::Response;
90    type Error = TimeoutError<S::Error>;
91    type InitError = E;
92    type Transform = TimeoutService<S>;
93    type Future = Ready<Result<Self::Transform, Self::InitError>>;
94
95    fn new_transform(&self, service: S) -> Self::Future {
96        ok(TimeoutService {
97            service,
98            timeout: self.timeout,
99        })
100    }
101}
102
/// Applies a timeout to requests.
///
/// Service wrapper around `service`: each call's response must complete within
/// `timeout`, otherwise it resolves to `TimeoutError::Timeout`. Built either with
/// [`TimeoutService::new`] or by the [`Timeout`] transform.
#[derive(Debug, Clone)]
pub struct TimeoutService<S> {
    /// The wrapped service that actually handles requests.
    service: S,
    /// Per-call time limit used to arm the delay in `call`.
    timeout: time::Duration,
}
109
110impl<S> TimeoutService<S>
111where
112    S: Service,
113{
114    pub fn new<U>(timeout: time::Duration, service: U) -> Self
115    where
116        U: IntoService<S>,
117    {
118        TimeoutService {
119            timeout,
120            service: service.into_service(),
121        }
122    }
123}
124
125impl<S> Service for TimeoutService<S>
126where
127    S: Service,
128{
129    type Request = S::Request;
130    type Response = S::Response;
131    type Error = TimeoutError<S::Error>;
132    type Future = TimeoutServiceResponse<S>;
133
134    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
135        self.service.poll_ready(cx).map_err(TimeoutError::Service)
136    }
137
138    fn call(&mut self, request: S::Request) -> Self::Future {
139        TimeoutServiceResponse {
140            fut: self.service.call(request),
141            sleep: delay_for(self.timeout),
142        }
143    }
144}
145
/// `TimeoutService` response future
///
/// Resolves to the inner service's result (errors wrapped in
/// `TimeoutError::Service`), or to `TimeoutError::Timeout` if `sleep` completes
/// while `fut` is still pending.
#[pin_project::pin_project]
#[derive(Debug)]
pub struct TimeoutServiceResponse<T: Service> {
    /// Inner response future; structurally pinned (`#[pin]`) so it need not be `Unpin`.
    #[pin]
    fut: T::Future,
    /// Timeout timer; left unpinned and polled through `Pin::new`, which relies on
    /// `Delay` being `Unpin`.
    sleep: Delay,
}
154
155impl<T> Future for TimeoutServiceResponse<T>
156where
157    T: Service,
158{
159    type Output = Result<T::Response, TimeoutError<T::Error>>;
160
161    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
162        let mut this = self.project();
163
164        // First, try polling the future
165        match this.fut.poll(cx) {
166            Poll::Ready(Ok(v)) => return Poll::Ready(Ok(v)),
167            Poll::Ready(Err(e)) => return Poll::Ready(Err(TimeoutError::Service(e))),
168            Poll::Pending => {}
169        }
170
171        // Now check the sleep
172        match Pin::new(&mut this.sleep).poll(cx) {
173            Poll::Pending => Poll::Pending,
174            Poll::Ready(_) => Poll::Ready(Err(TimeoutError::Timeout)),
175        }
176    }
177}
178
#[cfg(test)]
mod tests {
    use std::task::{Context, Poll};
    use std::time::Duration;

    use super::*;
    use actori_service::{apply, fn_factory, Service, ServiceFactory};
    use futures::future::{ok, FutureExt, LocalBoxFuture};

    /// Test service whose response completes only after sleeping for the wrapped duration.
    struct SleepService(Duration);

    impl Service for SleepService {
        type Request = ();
        type Response = ();
        type Error = ();
        type Future = LocalBoxFuture<'static, Result<(), ()>>;

        fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _: ()) -> Self::Future {
            let nap = actori_rt::time::delay_for(self.0);
            nap.then(|_| ok::<_, ()>(())).boxed_local()
        }
    }

    // The 50ms response arrives well inside the 100ms limit.
    #[actori_rt::test]
    async fn test_success() {
        let limit = Duration::from_millis(100);
        let work = Duration::from_millis(50);

        let mut srv = TimeoutService::new(limit, SleepService(work));
        let outcome = srv.call(()).await;
        assert_eq!(outcome, Ok(()));
    }

    // The 500ms response exceeds the 100ms limit.
    #[actori_rt::test]
    async fn test_timeout() {
        let limit = Duration::from_millis(100);
        let work = Duration::from_millis(500);

        let mut srv = TimeoutService::new(limit, SleepService(work));
        let outcome = srv.call(()).await;
        assert_eq!(outcome, Err(TimeoutError::Timeout));
    }

    // Same as `test_timeout`, but the wrapper is built through the `Timeout` transform.
    #[actori_rt::test]
    async fn test_timeout_newservice() {
        let limit = Duration::from_millis(100);
        let work = Duration::from_millis(500);

        let factory = apply(
            Timeout::new(limit),
            fn_factory(|| ok::<_, ()>(SleepService(work))),
        );
        let mut srv = factory.new_service(&()).await.unwrap();

        let outcome = srv.call(()).await;
        assert_eq!(outcome, Err(TimeoutError::Timeout));
    }
}