1use pin_project_lite::pin_project;
2use std::future::Future;
3use std::pin::Pin;
4use std::task::{ready, Context, Poll};
5use std::time::Duration;
6use tower::retry::future::ResponseFuture;
7use tower::retry::{Policy, Retry, RetryLayer};
8use tower::{Layer, Service};
9
/// A tower [`Layer`] that adds retries with backoff to a service.
///
/// Retry decisions come from the user policy `P` (wrapped in a [`BackoffPolicy`]);
/// the delay before each retried attempt comes from the backoff strategy `B`, which
/// is cloned into every service this layer produces.
pub struct BackoffLayer<P, B> {
    // Tower retry layer driving the wrapped user policy.
    retry: RetryLayer<BackoffPolicy<P>>,
    // Strategy handed (by clone) to each `BackoffInnerService` built by `layer`.
    backoff: B,
}
17
18impl<P, B> BackoffLayer<P, B> {
19 pub fn new(policy: P, backoff_strategy: B) -> Self {
20 BackoffLayer {
21 retry: RetryLayer::new(BackoffPolicy::new(policy)),
22 backoff: backoff_strategy,
23 }
24 }
25}
26
27impl<S, P, B> Layer<S> for BackoffLayer<P, B>
28where
29 P: Clone,
30 B: Clone,
31{
32 type Service = BackoffService<P, S, B>;
33
34 fn layer(&self, inner: S) -> Self::Service {
35 BackoffService::new_from_retry(self.retry.layer(BackoffInnerService {
36 inner,
37 backoff: self.backoff.clone(),
38 }))
39 }
40}
41
/// Service produced by [`BackoffLayer`] (or built directly with [`BackoffService::new`]).
///
/// Every incoming request is tagged with a call counter and passed to a tower
/// [`Retry`] middleware. Retried attempts are held back by the backoff strategy `B`
/// before the response future of the inner service `S` is polled.
#[derive(Clone)]
pub struct BackoffService<P, S, B> {
    // Retry middleware running `BackoffPolicy<P>` over the delaying inner service.
    backoff_retry: Retry<BackoffPolicy<P>, BackoffInnerService<S, B>>,
}
51
52impl<P, S, B> BackoffService<P, S, B> {
53 pub fn new(policy: P, inner: S, backoff: B) -> Self {
54 BackoffService::new_from_retry(Retry::new(
55 BackoffPolicy::new(policy),
56 BackoffInnerService::new(inner, backoff),
57 ))
58 }
59
60 fn new_from_retry(retry: Retry<BackoffPolicy<P>, BackoffInnerService<S, B>>) -> Self {
61 BackoffService {
62 backoff_retry: retry,
63 }
64 }
65}
66
67impl<P, S, B, Req> Service<Req> for BackoffService<P, S, B>
68where
69 P: Policy<Req, S::Response, S::Error> + Clone,
70 B: BackoffStrategy,
71 S: Service<Req> + Clone,
72{
73 type Response = S::Response;
74 type Error = S::Error;
75 type Future = ResponseFuture<BackoffPolicy<P>, BackoffInnerService<S, B>, BackoffRequest<Req>>;
76
77 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
78 self.backoff_retry.poll_ready(cx)
79 }
80
81 fn call(&mut self, req: Req) -> Self::Future {
82 self.backoff_retry.call(BackoffRequest::new(req))
83 }
84}
85
/// Service placed *inside* the retry middleware.
///
/// For each [`BackoffRequest`] it computes a delay from the request's call count. It
/// then returns a [`BackoffFut`] that waits out that delay (skipped for the first
/// call) before polling the response future of `inner`.
#[derive(Debug, Clone)]
pub struct BackoffInnerService<S, B> {
    // The user's service that actually handles requests.
    inner: S,
    // Strategy mapping a call count to a delay.
    backoff: B,
}
95
96impl<S, B> BackoffInnerService<S, B> {
97 fn new(inner: S, backoff: B) -> Self {
98 BackoffInnerService { inner, backoff }
99 }
100}
101
102impl<S, B, Req> Service<BackoffRequest<Req>> for BackoffInnerService<S, B>
103where
104 S: Service<Req>,
105 B: BackoffStrategy,
106{
107 type Response = S::Response;
108 type Error = S::Error;
109 type Future = BackoffFut<S::Future>;
110
111 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
112 self.inner.poll_ready(cx)
113 }
114
115 fn call(&mut self, req: BackoffRequest<Req>) -> Self::Future {
116 let BackoffRequest { calls, req } = req;
117 let backoff = self.backoff.backoff_duration(calls);
118 let is_first_call = calls == 0;
119 BackoffFut::new(is_first_call, backoff, self.inner.call(req))
120 }
121}
122
#[cfg(feature = "tokio")]
pin_project! {
    /// Future returned by [`BackoffInnerService`]. It optionally waits on a backoff
    /// timer, then drives the inner response future `F` to completion.
    ///
    /// This variant uses `tokio::time::Sleep` and is compiled with the `tokio` feature.
    /// The `async_std` variant defines the same type, so exactly one of the two
    /// features must be enabled.
    pub struct BackoffFut<F> {
        // `true` once the timer has fired, or from construction when no wait is wanted
        // (the first attempt). While `false`, `sleep` is polled before `fut`.
        slept: bool,
        // Backoff timer, created in `BackoffFut::new`.
        #[pin]
        sleep: tokio::time::Sleep,
        // Response future of the wrapped service.
        #[pin]
        fut: F,
    }
}
134
#[cfg(feature = "async_std")]
pin_project! {
    /// Future returned by [`BackoffInnerService`]. It optionally waits on a backoff
    /// timer, then drives the inner response future `F` to completion.
    ///
    /// This variant uses `async_io::Timer` and is compiled with the `async_std` feature.
    /// The `tokio` variant defines the same type, so exactly one of the two features
    /// must be enabled.
    pub struct BackoffFut<F> {
        // `true` once the timer has fired, or from construction when no wait is wanted
        // (the first attempt). While `false`, `sleep` is polled before `fut`.
        slept: bool,
        // Backoff timer, created in `BackoffFut::new`.
        #[pin]
        sleep: async_io::Timer,
        // Response future of the wrapped service.
        #[pin]
        fut: F,
    }
}
146
147impl<F> BackoffFut<F> {
148 fn new(slept: bool, duration: Duration, fut: F) -> Self {
149 BackoffFut {
150 slept,
151 #[cfg(feature = "tokio")]
152 sleep: tokio::time::sleep(duration),
153 #[cfg(feature = "async_std")]
154 sleep: async_io::Timer::after(duration),
155 fut,
156 }
157 }
158}
159
160impl<F> Future for BackoffFut<F>
161where
162 F: Future,
163{
164 type Output = F::Output;
165
166 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
167 let this = self.project();
168
169 if !*this.slept {
170 ready!(this.sleep.poll(cx));
171 *this.slept = true;
172 }
173
174 this.fut.poll(cx)
175 }
176}
177
/// Adapter that lets a user retry policy `P`, written against plain requests, act as a
/// tower retry policy over [`BackoffRequest`]s. Each request clone made for a retry
/// has its call counter bumped.
#[derive(Debug, Clone)]
pub struct BackoffPolicy<P> {
    // The user's retry policy.
    inner: P,
}
183
184impl<P> BackoffPolicy<P> {
185 fn new(policy: P) -> Self {
186 Self { inner: policy }
187 }
188}
189
pin_project! {
    /// Future that resolves the user policy's future `F` and re-wraps its output in a
    /// [`BackoffPolicy`], so the retry middleware receives the adapter type it expects.
    pub struct IntoBackoffPolicyFut<F> {
        #[pin]
        fut: F
    }
}
196
197impl<F> IntoBackoffPolicyFut<F> {
198 fn new(fut: F) -> Self {
199 IntoBackoffPolicyFut { fut }
200 }
201}
202
203impl<F> Future for IntoBackoffPolicyFut<F>
204where
205 F: Future,
206{
207 type Output = BackoffPolicy<F::Output>;
208
209 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
210 let this = self.project();
211 let res = ready!(this.fut.poll(cx));
212 Poll::Ready(BackoffPolicy::new(res))
213 }
214}
215
216impl<P, Req, Res, Err> Policy<BackoffRequest<Req>, Res, Err> for BackoffPolicy<P>
219where
220 P: Policy<Req, Res, Err>,
221{
222 type Future = IntoBackoffPolicyFut<P::Future>;
223
224 fn retry(&self, req: &BackoffRequest<Req>, result: Result<&Res, &Err>) -> Option<Self::Future> {
225 let BackoffRequest { req, .. } = req;
226 self.inner.retry(req, result).map(IntoBackoffPolicyFut::new)
227 }
228
229 fn clone_request(&self, req: &BackoffRequest<Req>) -> Option<BackoffRequest<Req>> {
230 let BackoffRequest { calls, req } = req;
231 self.inner
232 .clone_request(req)
233 .map(|req| BackoffRequest::new_with_calls(req, calls + 1))
234 }
235}
236
/// A user request paired with the number of times it has already been sent.
pub struct BackoffRequest<R> {
    /// Previous attempts: `0` for the original request, incremented by one each time
    /// [`BackoffPolicy`] clones the request for a retry.
    calls: u32,
    /// The user's request.
    req: R,
}
243
244impl<R> BackoffRequest<R> {
245 fn new(req: R) -> Self {
246 BackoffRequest { calls: 0, req }
247 }
248
249 fn new_with_calls(req: R, calls: u32) -> Self {
250 BackoffRequest { calls, req }
251 }
252}
253
/// Maps a retry count to how long to wait before sending that attempt.
pub trait BackoffStrategy: Clone {
    /// Returns the delay for an attempt whose request has already been sent `repeats`
    /// times.
    ///
    /// It is also called with `0` for the original attempt, but that delay is not
    /// waited on.
    fn backoff_duration(&self, repeats: u32) -> Duration;
}
258
259pub mod backoff_strategies {
260 use crate::BackoffStrategy;
261 use std::time::Duration;
262
263 #[derive(Debug, Clone)]
265 pub struct ExponentialBackoffStrategy;
266
267 impl BackoffStrategy for ExponentialBackoffStrategy {
268 fn backoff_duration(&self, repeats: u32) -> Duration {
269 Duration::from_millis(1 << repeats)
270 }
271 }
272
273 #[derive(Debug, Clone)]
275 pub struct FibonacciBackoffStrategy;
276
277 impl BackoffStrategy for FibonacciBackoffStrategy {
278 fn backoff_duration(&self, repeats: u32) -> Duration {
279 let mut a = 0;
280 let mut b = 1;
281 for _ in 0..repeats {
282 let c = a + b;
283 a = b;
284 b = c;
285 }
286 Duration::from_millis(a)
287 }
288 }
289
290 #[derive(Debug, Clone)]
292 pub struct LinearBackoffStrategy {
293 duration_multiple: Duration,
294 }
295
296 impl LinearBackoffStrategy {
297 pub fn new(duration_multiple: Duration) -> Self {
298 Self { duration_multiple }
299 }
300 }
301
302 impl BackoffStrategy for LinearBackoffStrategy {
303 fn backoff_duration(&self, repeats: u32) -> Duration {
304 self.duration_multiple * repeats
305 }
306 }
307
308 #[derive(Debug, Clone)]
310 pub struct ConstantBackoffStrategy {
311 duration: Duration,
312 }
313
314 impl ConstantBackoffStrategy {
315 pub fn new(duration: Duration) -> Self {
316 Self { duration }
317 }
318 }
319
320 impl BackoffStrategy for ConstantBackoffStrategy {
321 fn backoff_duration(&self, _repeats: u32) -> Duration {
322 self.duration
323 }
324 }
325}
326
#[cfg(test)]
mod tests {
    use super::backoff_strategies::ExponentialBackoffStrategy;
    use super::{BackoffInnerService, BackoffLayer, BackoffRequest};
    use std::error::Error;
    use std::future::{ready, Ready};
    use tokio::select;
    use tower::retry::Policy;
    use tower::{Service, ServiceBuilder};

    /// Test policy: retries failed requests up to `attempts_left` times. Each clone of
    /// the request bumps its value by one, so retries walk towards a multiple of 10.
    #[derive(Clone)]
    struct MyPolicy {
        attempts_left: usize,
    }

    impl Policy<usize, usize, &'static str> for MyPolicy {
        type Future = Ready<Self>;

        fn retry(
            &self,
            _req: &usize,
            result: Result<&usize, &&'static str>,
        ) -> Option<Self::Future> {
            if self.attempts_left == 0 {
                return None;
            }

            match result {
                Ok(_) => None,
                Err(_) => Some(ready(MyPolicy {
                    attempts_left: self.attempts_left - 1,
                })),
            }
        }

        fn clone_request(&self, req: &usize) -> Option<usize> {
            Some(req + 1)
        }
    }

    #[tokio::test]
    async fn retries_work() -> Result<(), Box<dyn Error>> {
        let mut service = ServiceBuilder::new()
            .layer(BackoffLayer::new(
                MyPolicy { attempts_left: 4 },
                ExponentialBackoffStrategy,
            ))
            .service_fn(|x: usize| async move {
                if x % 10 == 0 {
                    Ok(x / 10)
                } else {
                    Err("bad input")
                }
            });

        // With four retries allowed, every input from 56 to 60 reaches 60.
        for input in (56..=60).rev() {
            assert_eq!(
                Ok(6),
                service.call(input).await,
                "should be the next multiple of 10 divided by 10 (input {input})"
            );
        }
        assert_eq!(
            Err("bad input"),
            service.call(55).await,
            "should error as ran out of retries"
        );

        Ok(())
    }

    #[tokio::test]
    async fn subsequent_retries_have_different_wait_periods() -> Result<(), Box<dyn Error>> {
        let mut backoff_inner_svc = BackoffInnerService::new(
            tower::service_fn(|x: usize| async move {
                if x % 10 == 0 {
                    Ok(x / 10)
                } else {
                    Err("bad input")
                }
            }),
            ExponentialBackoffStrategy,
        );

        assert_eq!(6, backoff_inner_svc.call(BackoffRequest::new(60)).await?);

        let a = backoff_inner_svc.call(BackoffRequest::new(60));
        let b = backoff_inner_svc.call(BackoffRequest::new_with_calls(60, 1));
        let c = backoff_inner_svc.call(BackoffRequest::new_with_calls(60, 2));

        assert!(a.slept, "0 calls should have no backoff");
        assert!(!b.slept, "1 or more calls should have backoffs");
        assert!(!c.slept, "1 or more calls should have backoffs");

        #[cfg(feature = "tokio")]
        assert!(b.sleep.deadline() < c.sleep.deadline());

        select! {
            _ = b => {}
            _ = c => {
                panic!("call b should respond first due to a smaller backoff")
            }
        }

        Ok(())
    }
}