Skip to main content

hpx_util/tower/delay/
service.rs

1use std::{
2    task::{Context, Poll},
3    time::Duration,
4};
5
6use tower::{BoxError, Service};
7
8use super::{future::ResponseFuture, jittered_duration};
9
/// A Tower [`Service`] that holds every request back for a fixed
/// [`Duration`] before forwarding it to the wrapped service.
#[derive(Clone, Debug)]
pub struct Delay<S> {
    /// The wrapped service that ultimately handles each request.
    inner: S,
    /// How long every request waits before it is dispatched.
    delay: Duration,
}
16
17/// A Tower [`Service`] that conditionally applies fixed delay based on a predicate.
18///
19/// Requests that match the predicate will have the delay applied;
20/// other requests pass through immediately.
21#[derive(Clone, Debug)]
22pub struct DelayWith<S, P> {
23    inner: Delay<S>,
24    predicate: P,
25}
26
/// A Tower [`Service`] that waits a randomized (jittered) delay before
/// forwarding each request to the wrapped service.
///
/// The per-request delay is computed from a base duration and a jitter
/// fraction; see [`JitterDelay::new`].
#[derive(Debug, Clone)]
pub struct JitterDelay<S> {
    /// The wrapped service that ultimately handles each request.
    inner: S,
    /// Base delay handed to `jittered_duration` for every request.
    base: Duration,
    /// Jitter fraction, kept within `[0.0, 1.0]` by the constructor.
    pct: f64,
}
37
38/// A Tower [`Service`] that conditionally applies jittered delay based on a predicate.
39///
40/// Requests that match the predicate will have a jittered delay applied;
41/// other requests pass through immediately.
42#[derive(Clone, Debug)]
43pub struct JitterDelayWith<S, P> {
44    inner: JitterDelay<S>,
45    predicate: P,
46}
47
48// ===== impl Delay =====
49
50impl<S> Delay<S> {
51    /// Create a new [`Delay`] service wrapping the given inner service
52    #[inline]
53    pub fn new(inner: S, delay: Duration) -> Self {
54        Delay { inner, delay }
55    }
56}
57
58impl<S, Request> Service<Request> for Delay<S>
59where
60    S: Service<Request> + Clone,
61    S::Error: Into<BoxError>,
62{
63    type Response = S::Response;
64    type Error = BoxError;
65    type Future = ResponseFuture<S, Request>;
66
67    #[inline]
68    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
69        self.inner.poll_ready(cx).map_err(Into::into)
70    }
71
72    fn call(&mut self, req: Request) -> Self::Future {
73        let sleep = tokio::time::sleep(self.delay);
74        ResponseFuture::new(self.inner.clone(), req, sleep)
75    }
76}
77
78// ===== impl DelayWith =====
79
80impl<S, P> DelayWith<S, P> {
81    /// Creates a new [`DelayWith`].
82    #[inline]
83    pub fn new(inner: S, delay: Duration, predicate: P) -> Self {
84        Self {
85            inner: Delay::new(inner, delay),
86            predicate,
87        }
88    }
89}
90
91impl<S, Req, P> Service<Req> for DelayWith<S, P>
92where
93    S: Service<Req> + Clone,
94    S::Error: Into<BoxError>,
95    P: Fn(&Req) -> bool,
96{
97    type Response = S::Response;
98    type Error = BoxError;
99    type Future = ResponseFuture<S, Req>;
100
101    #[inline]
102    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
103        self.inner.poll_ready(cx).map_err(Into::into)
104    }
105
106    fn call(&mut self, req: Req) -> Self::Future {
107        let delay = if (self.predicate)(&req) {
108            self.inner.delay
109        } else {
110            Duration::ZERO
111        };
112        ResponseFuture::new(self.inner.inner.clone(), req, tokio::time::sleep(delay))
113    }
114}
115
116// ===== impl JitterDelay =====
117
118impl<S> JitterDelay<S> {
119    /// Creates a new [`JitterDelay`].
120    #[inline]
121    pub fn new(inner: S, base: Duration, pct: f64) -> Self {
122        Self {
123            inner,
124            base,
125            pct: pct.clamp(0.0, 1.0),
126        }
127    }
128}
129
130impl<S, Req> Service<Req> for JitterDelay<S>
131where
132    S: Service<Req> + Clone,
133    S::Error: Into<BoxError>,
134{
135    type Response = S::Response;
136    type Error = BoxError;
137    type Future = ResponseFuture<S, Req>;
138
139    #[inline]
140    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
141        self.inner.poll_ready(cx).map_err(Into::into)
142    }
143
144    fn call(&mut self, req: Req) -> Self::Future {
145        let delay = jittered_duration(self.base, self.pct);
146        let sleep = tokio::time::sleep(delay);
147        ResponseFuture::new(self.inner.clone(), req, sleep)
148    }
149}
150
151// ===== impl JitterDelayWith =====
152
153impl<S, P> JitterDelayWith<S, P> {
154    /// Creates a new [`JitterDelayWith`].
155    #[inline]
156    pub fn new(inner: S, base: Duration, pct: f64, predicate: P) -> Self {
157        Self {
158            inner: JitterDelay::new(inner, base, pct),
159            predicate,
160        }
161    }
162}
163
164impl<S, Req, P> Service<Req> for JitterDelayWith<S, P>
165where
166    S: Service<Req> + Clone,
167    S::Error: Into<BoxError>,
168    P: Fn(&Req) -> bool,
169{
170    type Response = S::Response;
171    type Error = BoxError;
172    type Future = ResponseFuture<S, Req>;
173
174    #[inline]
175    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
176        self.inner.poll_ready(cx).map_err(Into::into)
177    }
178
179    fn call(&mut self, req: Req) -> Self::Future {
180        let delay = if (self.predicate)(&req) {
181            jittered_duration(self.inner.base, self.inner.pct)
182        } else {
183            Duration::ZERO
184        };
185
186        ResponseFuture::new(self.inner.inner.clone(), req, tokio::time::sleep(delay))
187    }
188}
189
#[cfg(test)]
mod tests {
    use std::{
        convert::Infallible,
        sync::{
            Arc,
            atomic::{AtomicUsize, Ordering},
        },
        task::{Context, Poll},
        time::Duration,
    };

    use tower::Service;

    use super::Delay;

    /// Test double that records how many times `call` has been invoked.
    #[derive(Clone)]
    struct SideEffectService {
        calls: Arc<AtomicUsize>,
    }

    impl Service<()> for SideEffectService {
        type Response = ();
        type Error = Infallible;
        type Future = std::future::Ready<Result<(), Infallible>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _req: ()) -> Self::Future {
            self.calls.fetch_add(1, Ordering::SeqCst);
            std::future::ready(Ok(()))
        }
    }

    #[tokio::test]
    async fn test_delay_invokes_inner_service_after_sleep() {
        const DELAY: Duration = Duration::from_millis(25);

        let counter = Arc::new(AtomicUsize::new(0));
        let observed = || counter.load(Ordering::SeqCst);

        let mut service = Delay::new(
            SideEffectService {
                calls: Arc::clone(&counter),
            },
            DELAY,
        );
        let start = tokio::time::Instant::now();

        let response = service.call(());
        tokio::pin!(response);
        assert_eq!(observed(), 0);

        // The sleep has not elapsed yet, so the first poll must not reach the inner service.
        let first_poll = futures_util::poll!(response.as_mut());
        assert!(first_poll.is_pending());
        assert_eq!(observed(), 0);

        response.await.unwrap();
        assert_eq!(observed(), 1);
        assert!(start.elapsed() >= DELAY);
    }
}