backoff_tower/
backoff_layer.rs

1use pin_project_lite::pin_project;
2use std::future::Future;
3use std::pin::Pin;
4use std::task::{ready, Context, Poll};
5use std::time::Duration;
6use tower::retry::future::ResponseFuture;
7use tower::retry::{Policy, Retry, RetryLayer};
8use tower::{Layer, Service};
9
10/// A layer that creates a service that will attempt & reattempt to perform a service call based on a policy.
11///
12/// Each subsequent call will have a backoff period as defined by the passed strategy
13pub struct BackoffLayer<P, B> {
14    retry: RetryLayer<BackoffPolicy<P>>,
15    backoff: B,
16}
17
18impl<P, B> BackoffLayer<P, B> {
19    pub fn new(policy: P, backoff_strategy: B) -> Self {
20        BackoffLayer {
21            retry: RetryLayer::new(BackoffPolicy::new(policy)),
22            backoff: backoff_strategy,
23        }
24    }
25}
26
27impl<S, P, B> Layer<S> for BackoffLayer<P, B>
28where
29    P: Clone,
30    B: Clone,
31{
32    type Service = BackoffService<P, S, B>;
33
34    fn layer(&self, inner: S) -> Self::Service {
35        BackoffService::new_from_retry(self.retry.layer(BackoffInnerService {
36            inner,
37            backoff: self.backoff.clone(),
38        }))
39    }
40}
41
42/// A service for the retrying of a call with back offs.
43///
44/// This service adds the backoff wrapper to the request
45/// so that the inner service can choose an appropriate
46/// backoff period before reattempting its service call
47#[derive(Clone)]
48pub struct BackoffService<P, S, B> {
49    backoff_retry: Retry<BackoffPolicy<P>, BackoffInnerService<S, B>>,
50}
51
52impl<P, S, B> BackoffService<P, S, B> {
53    pub fn new(policy: P, inner: S, backoff: B) -> Self {
54        BackoffService::new_from_retry(Retry::new(
55            BackoffPolicy::new(policy),
56            BackoffInnerService::new(inner, backoff),
57        ))
58    }
59
60    fn new_from_retry(retry: Retry<BackoffPolicy<P>, BackoffInnerService<S, B>>) -> Self {
61        BackoffService {
62            backoff_retry: retry,
63        }
64    }
65}
66
67impl<P, S, B, Req> Service<Req> for BackoffService<P, S, B>
68where
69    P: Policy<Req, S::Response, S::Error> + Clone,
70    B: BackoffStrategy,
71    S: Service<Req> + Clone,
72{
73    type Response = S::Response;
74    type Error = S::Error;
75    type Future = ResponseFuture<BackoffPolicy<P>, BackoffInnerService<S, B>, BackoffRequest<Req>>;
76
77    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
78        self.backoff_retry.poll_ready(cx)
79    }
80
81    fn call(&mut self, req: Req) -> Self::Future {
82        self.backoff_retry.call(BackoffRequest::new(req))
83    }
84}
85
/// The inner service which carries out the backed off request.
///
/// It takes the request back out of its backoff wrapper and, where needed,
/// puts a backoff period in front of the resulting future.
#[derive(Debug, Clone)]
pub struct BackoffInnerService<S, B> {
    // Service that actually handles the unwrapped request
    inner: S,
    // Strategy consulted for the backoff period of each call
    backoff: B,
}

impl<S, B> BackoffInnerService<S, B> {
    /// Pairs the service `inner` with the strategy `backoff`.
    fn new(inner: S, backoff: B) -> Self {
        Self { inner, backoff }
    }
}
101
102impl<S, B, Req> Service<BackoffRequest<Req>> for BackoffInnerService<S, B>
103where
104    S: Service<Req>,
105    B: BackoffStrategy,
106{
107    type Response = S::Response;
108    type Error = S::Error;
109    type Future = BackoffFut<S::Future>;
110
111    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
112        self.inner.poll_ready(cx)
113    }
114
115    fn call(&mut self, req: BackoffRequest<Req>) -> Self::Future {
116        let BackoffRequest { calls, req } = req;
117        let backoff = self.backoff.backoff_duration(calls);
118        let is_first_call = calls == 0;
119        BackoffFut::new(is_first_call, backoff, self.inner.call(req))
120    }
121}
122
#[cfg(feature = "tokio")]
pin_project! {
    /// A future which must finish sleeping before the wrapped future is polled
    pub struct BackoffFut<F> {
        // `true` once the sleep has completed, or when no sleep was wanted at all
        slept: bool,
        // Timer for the backoff period (tokio flavour)
        #[pin]
        sleep: tokio::time::Sleep,
        // The wrapped future
        #[pin]
        fut: F,
    }
}
134
#[cfg(feature = "async_std")]
pin_project! {
    /// A future which must finish sleeping before the wrapped future is polled
    pub struct BackoffFut<F> {
        // `true` once the sleep has completed, or when no sleep was wanted at all
        slept: bool,
        // Timer for the backoff period (async-io flavour)
        #[pin]
        sleep: async_io::Timer,
        // The wrapped future
        #[pin]
        fut: F,
    }
}
146
147impl<F> BackoffFut<F> {
148    fn new(slept: bool, duration: Duration, fut: F) -> Self {
149        BackoffFut {
150            slept,
151            #[cfg(feature = "tokio")]
152            sleep: tokio::time::sleep(duration),
153            #[cfg(feature = "async_std")]
154            sleep: async_io::Timer::after(duration),
155            fut,
156        }
157    }
158}
159
160impl<F> Future for BackoffFut<F>
161where
162    F: Future,
163{
164    type Output = F::Output;
165
166    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
167        let this = self.project();
168
169        if !*this.slept {
170            ready!(this.sleep.poll(cx));
171            *this.slept = true;
172        }
173
174        this.fut.poll(cx)
175    }
176}
177
/// A policy wrapping another policy that is written for the raw request type
#[derive(Debug, Clone)]
pub struct BackoffPolicy<P> {
    // The policy for the raw request type
    inner: P,
}

impl<P> BackoffPolicy<P> {
    /// Wraps `policy`.
    fn new(policy: P) -> Self {
        BackoffPolicy { inner: policy }
    }
}
189
190pin_project! {
191    pub struct IntoBackoffPolicyFut<F> {
192        #[pin]
193        fut: F
194    }
195}
196
197impl<F> IntoBackoffPolicyFut<F> {
198    fn new(fut: F) -> Self {
199        IntoBackoffPolicyFut { fut }
200    }
201}
202
203impl<F> Future for IntoBackoffPolicyFut<F>
204where
205    F: Future,
206{
207    type Output = BackoffPolicy<F::Output>;
208
209    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
210        let this = self.project();
211        let res = ready!(this.fut.poll(cx));
212        Poll::Ready(BackoffPolicy::new(res))
213    }
214}
215
216/// Policy for a backed off request defers to the policy for the raw request
217/// and updates the calls count upon clone
218impl<P, Req, Res, Err> Policy<BackoffRequest<Req>, Res, Err> for BackoffPolicy<P>
219where
220    P: Policy<Req, Res, Err>,
221{
222    type Future = IntoBackoffPolicyFut<P::Future>;
223
224    fn retry(&self, req: &BackoffRequest<Req>, result: Result<&Res, &Err>) -> Option<Self::Future> {
225        let BackoffRequest { req, .. } = req;
226        self.inner.retry(req, result).map(IntoBackoffPolicyFut::new)
227    }
228
229    fn clone_request(&self, req: &BackoffRequest<Req>) -> Option<BackoffRequest<Req>> {
230        let BackoffRequest { calls, req } = req;
231        self.inner
232            .clone_request(req)
233            .map(|req| BackoffRequest::new_with_calls(req, calls + 1))
234    }
235}
236
/// A wrapper around a request which tracks the number of retries of that request
pub struct BackoffRequest<R> {
    // A `u32` leaves room for roughly four billion calls, which should be plenty
    calls: u32,
    // The wrapped raw request
    req: R,
}

impl<R> BackoffRequest<R> {
    /// Wraps a request that has no previous calls recorded.
    fn new(req: R) -> Self {
        Self::new_with_calls(req, 0)
    }

    /// Wraps a request with an explicit number of recorded calls.
    fn new_with_calls(req: R, calls: u32) -> Self {
        Self { calls, req }
    }
}
253
/// Describes how long to back off before each subsequent attempt of a call
pub trait BackoffStrategy: Clone {
    /// Returns the backoff period for an attempt, given the number of
    /// `repeats` recorded so far.
    fn backoff_duration(&self, repeats: u32) -> Duration;
}
258
259pub mod backoff_strategies {
260    use crate::BackoffStrategy;
261    use std::time::Duration;
262
263    /// Performs backoffs in millisecond powers of 2
264    #[derive(Debug, Clone)]
265    pub struct ExponentialBackoffStrategy;
266
267    impl BackoffStrategy for ExponentialBackoffStrategy {
268        fn backoff_duration(&self, repeats: u32) -> Duration {
269            Duration::from_millis(1 << repeats)
270        }
271    }
272
273    /// Performs backoffs in fibonacci milliseconds
274    #[derive(Debug, Clone)]
275    pub struct FibonacciBackoffStrategy;
276
277    impl BackoffStrategy for FibonacciBackoffStrategy {
278        fn backoff_duration(&self, repeats: u32) -> Duration {
279            let mut a = 0;
280            let mut b = 1;
281            for _ in 0..repeats {
282                let c = a + b;
283                a = b;
284                b = c;
285            }
286            Duration::from_millis(a)
287        }
288    }
289
290    /// Performs backoffs in multiples of a duration
291    #[derive(Debug, Clone)]
292    pub struct LinearBackoffStrategy {
293        duration_multiple: Duration,
294    }
295
296    impl LinearBackoffStrategy {
297        pub fn new(duration_multiple: Duration) -> Self {
298            Self { duration_multiple }
299        }
300    }
301
302    impl BackoffStrategy for LinearBackoffStrategy {
303        fn backoff_duration(&self, repeats: u32) -> Duration {
304            self.duration_multiple * repeats
305        }
306    }
307
308    /// Backoff is a constant value
309    #[derive(Debug, Clone)]
310    pub struct ConstantBackoffStrategy {
311        duration: Duration,
312    }
313
314    impl ConstantBackoffStrategy {
315        pub fn new(duration: Duration) -> Self {
316            Self { duration }
317        }
318    }
319
320    impl BackoffStrategy for ConstantBackoffStrategy {
321        fn backoff_duration(&self, _repeats: u32) -> Duration {
322            self.duration
323        }
324    }
325}
326
#[cfg(test)]
mod tests {
    use crate::backoff_layer::{BackoffInnerService, BackoffRequest};
    use crate::backoff_strategies::ExponentialBackoffStrategy;
    use crate::BackoffLayer;
    use std::error::Error;
    use std::future::{ready, Ready};
    use tokio::select;
    use tower::retry::Policy;
    use tower::{Service, ServiceBuilder};

    /// Test policy: retries failed calls while attempts remain, and bumps the
    /// request by one every time it is cloned for a retry.
    #[derive(Clone)]
    struct MyPolicy {
        attempts_left: usize,
    }

    impl Policy<usize, usize, &'static str> for MyPolicy {
        type Future = Ready<Self>;

        fn retry(
            &self,
            _req: &usize,
            result: Result<&usize, &&'static str>,
        ) -> Option<Self::Future> {
            if self.attempts_left == 0 {
                return None;
            }

            match result {
                Ok(_) => None,
                Err(_) => Some(ready(MyPolicy {
                    attempts_left: self.attempts_left - 1,
                })),
            }
        }

        fn clone_request(&self, req: &usize) -> Option<usize> {
            Some(req + 1)
        }
    }

    #[tokio::test]
    async fn retries_work() -> Result<(), Box<dyn Error>> {
        let mut service = ServiceBuilder::new()
            .layer(BackoffLayer::new(
                MyPolicy { attempts_left: 4 },
                ExponentialBackoffStrategy,
            ))
            .service_fn(|x: usize| async move {
                if x % 10 == 0 {
                    Ok(x / 10)
                } else {
                    Err("bad input")
                }
            });

        // 60 succeeds straight away; 59 down to 56 need one to four retries,
        // which is within the four attempts the policy allows.
        for input in [60, 59, 58, 57, 56] {
            assert_eq!(
                Ok(6),
                service.call(input).await,
                "should be the next multiple of 10 divided by 10"
            );
        }

        // 55 would need a fifth retry
        assert_eq!(
            Err("bad input"),
            service.call(55).await,
            "should error as ran out of retries"
        );

        Ok(())
    }

    #[tokio::test]
    async fn subsequent_retries_have_different_wait_periods() -> Result<(), Box<dyn Error>> {
        let mut backoff_inner_svc = BackoffInnerService::new(
            tower::service_fn(|x: usize| async move {
                if x % 10 == 0 {
                    Ok(x / 10)
                } else {
                    Err("bad input")
                }
            }),
            ExponentialBackoffStrategy,
        );

        assert_eq!(6, backoff_inner_svc.call(BackoffRequest::new(60)).await?);

        let a = backoff_inner_svc.call(BackoffRequest::new(60));
        let b = backoff_inner_svc.call(BackoffRequest::new_with_calls(60, 1));
        let c = backoff_inner_svc.call(BackoffRequest::new_with_calls(60, 2));

        assert!(a.slept, "0 calls should have no backoff");
        assert!(!b.slept, "1 or more calls should have backoffs");
        assert!(!c.slept, "1 or more calls should have backoffs");

        #[cfg(feature = "tokio")]
        assert!(b.sleep.deadline() < c.sleep.deadline());

        select! {
            _ = b => {}
            _ = c => {
                panic!("call b should respond first due to a smaller backoff")
            }
        }

        Ok(())
    }
}