1use std::{
2 task::{Context, Poll},
3 time::Duration,
4};
5
6use tower::{BoxError, Service};
7
8use super::{future::ResponseFuture, jittered_duration};
9
/// Middleware that postpones every request by a fixed [`Duration`].
///
/// The wrapped service and the delay are stored as-is; the waiting itself is
/// set up in the `Service` implementation for this type.
#[derive(Clone, Debug)]
pub struct Delay<S> {
    /// The wrapped service that eventually handles the request.
    inner: S,
    /// How long each request is held back.
    delay: Duration,
}
16
/// Middleware that applies a fixed delay only to requests selected by a predicate.
///
/// Requests for which the predicate returns `false` are forwarded with a
/// zero-length sleep instead of the configured delay.
#[derive(Clone, Debug)]
pub struct DelayWith<S, P> {
    /// The wrapped service bundled with the configured delay.
    inner: Delay<S>,
    /// Decides, per request, whether the delay is applied.
    predicate: P,
}
26
/// Middleware that postpones every request by a randomized ("jittered") delay.
///
/// The actual delay for each request is obtained from `jittered_duration(base, pct)`,
/// a helper defined in the parent module.
#[derive(Clone, Debug)]
pub struct JitterDelay<S> {
    /// The wrapped service that eventually handles the request.
    inner: S,
    /// Base duration handed to `jittered_duration`.
    base: Duration,
    /// Jitter factor handed to `jittered_duration`; `new` clamps it to `0.0..=1.0`.
    pct: f64,
}
37
/// Middleware that applies a jittered delay only to requests selected by a predicate.
///
/// Requests for which the predicate returns `false` are forwarded with a
/// zero-length sleep instead of a jittered delay.
#[derive(Clone, Debug)]
pub struct JitterDelayWith<S, P> {
    /// The wrapped service bundled with the base duration and jitter factor.
    inner: JitterDelay<S>,
    /// Decides, per request, whether the jittered delay is applied.
    predicate: P,
}
47
48impl<S> Delay<S> {
51 #[inline]
53 pub fn new(inner: S, delay: Duration) -> Self {
54 Delay { inner, delay }
55 }
56}
57
58impl<S, Request> Service<Request> for Delay<S>
59where
60 S: Service<Request> + Clone,
61 S::Error: Into<BoxError>,
62{
63 type Response = S::Response;
64 type Error = BoxError;
65 type Future = ResponseFuture<S, Request>;
66
67 #[inline]
68 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
69 self.inner.poll_ready(cx).map_err(Into::into)
70 }
71
72 fn call(&mut self, req: Request) -> Self::Future {
73 let sleep = tokio::time::sleep(self.delay);
74 ResponseFuture::new(self.inner.clone(), req, sleep)
75 }
76}
77
78impl<S, P> DelayWith<S, P> {
81 #[inline]
83 pub fn new(inner: S, delay: Duration, predicate: P) -> Self {
84 Self {
85 inner: Delay::new(inner, delay),
86 predicate,
87 }
88 }
89}
90
91impl<S, Req, P> Service<Req> for DelayWith<S, P>
92where
93 S: Service<Req> + Clone,
94 S::Error: Into<BoxError>,
95 P: Fn(&Req) -> bool,
96{
97 type Response = S::Response;
98 type Error = BoxError;
99 type Future = ResponseFuture<S, Req>;
100
101 #[inline]
102 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
103 self.inner.poll_ready(cx).map_err(Into::into)
104 }
105
106 fn call(&mut self, req: Req) -> Self::Future {
107 let delay = if (self.predicate)(&req) {
108 self.inner.delay
109 } else {
110 Duration::ZERO
111 };
112 ResponseFuture::new(self.inner.inner.clone(), req, tokio::time::sleep(delay))
113 }
114}
115
116impl<S> JitterDelay<S> {
119 #[inline]
121 pub fn new(inner: S, base: Duration, pct: f64) -> Self {
122 Self {
123 inner,
124 base,
125 pct: pct.clamp(0.0, 1.0),
126 }
127 }
128}
129
130impl<S, Req> Service<Req> for JitterDelay<S>
131where
132 S: Service<Req> + Clone,
133 S::Error: Into<BoxError>,
134{
135 type Response = S::Response;
136 type Error = BoxError;
137 type Future = ResponseFuture<S, Req>;
138
139 #[inline]
140 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
141 self.inner.poll_ready(cx).map_err(Into::into)
142 }
143
144 fn call(&mut self, req: Req) -> Self::Future {
145 let delay = jittered_duration(self.base, self.pct);
146 let sleep = tokio::time::sleep(delay);
147 ResponseFuture::new(self.inner.clone(), req, sleep)
148 }
149}
150
151impl<S, P> JitterDelayWith<S, P> {
154 #[inline]
156 pub fn new(inner: S, base: Duration, pct: f64, predicate: P) -> Self {
157 Self {
158 inner: JitterDelay::new(inner, base, pct),
159 predicate,
160 }
161 }
162}
163
164impl<S, Req, P> Service<Req> for JitterDelayWith<S, P>
165where
166 S: Service<Req> + Clone,
167 S::Error: Into<BoxError>,
168 P: Fn(&Req) -> bool,
169{
170 type Response = S::Response;
171 type Error = BoxError;
172 type Future = ResponseFuture<S, Req>;
173
174 #[inline]
175 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
176 self.inner.poll_ready(cx).map_err(Into::into)
177 }
178
179 fn call(&mut self, req: Req) -> Self::Future {
180 let delay = if (self.predicate)(&req) {
181 jittered_duration(self.inner.base, self.inner.pct)
182 } else {
183 Duration::ZERO
184 };
185
186 ResponseFuture::new(self.inner.inner.clone(), req, tokio::time::sleep(delay))
187 }
188}
189
// Unit tests for the fixed-delay middleware.
#[cfg(test)]
mod tests {
    use std::{
        convert::Infallible,
        sync::{
            Arc,
            atomic::{AtomicUsize, Ordering},
        },
        task::{Context, Poll},
        time::Duration,
    };

    use tower::Service;

    use super::Delay;

    /// Test double that counts how many times `call` has been invoked.
    ///
    /// The counter lives behind an `Arc`, so clones of the service share it.
    #[derive(Clone)]
    struct SideEffectService {
        calls: Arc<AtomicUsize>,
    }

    impl Service<()> for SideEffectService {
        type Response = ();
        type Error = Infallible;
        type Future = std::future::Ready<Result<(), Infallible>>;

        /// Always ready.
        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        /// Records the invocation and resolves immediately.
        fn call(&mut self, _req: ()) -> Self::Future {
            self.calls.fetch_add(1, Ordering::SeqCst);
            std::future::ready(Ok(()))
        }
    }

    /// Verifies that `Delay` does not invoke the wrapped service when `call`
    /// is made or on the first poll, and invokes it exactly once by the time
    /// the future completes, no earlier than the configured delay.
    #[tokio::test]
    async fn test_delay_invokes_inner_service_after_sleep() {
        let calls = Arc::new(AtomicUsize::new(0));
        let inner = SideEffectService {
            calls: Arc::clone(&calls),
        };
        let mut delayed = Delay::new(inner, Duration::from_millis(25));
        let started = tokio::time::Instant::now();

        // Creating the future must not touch the inner service.
        let fut = delayed.call(());
        tokio::pin!(fut);
        assert_eq!(calls.load(Ordering::SeqCst), 0);

        // A single poll right away is expected to be pending, with the inner
        // service still untouched.
        assert!(matches!(futures_util::poll!(fut.as_mut()), Poll::Pending));
        assert_eq!(calls.load(Ordering::SeqCst), 0);

        // Driving the future to completion invokes the inner service once,
        // and at least the configured delay has passed since `started`.
        let _ = fut.await.unwrap();
        assert_eq!(calls.load(Ordering::SeqCst), 1);
        assert!(started.elapsed() >= Duration::from_millis(25));
    }
}