futures_retry_policies/
tokio.rs

1#![cfg(feature = "tokio")]
2#![cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
3//! Retry features for the [tokio runtime](https://tokio.rs)
4
5use std::{future::Future, time::Duration};
6use tokio::time::{sleep, Sleep};
7
8use crate::RetryPolicy;
9
10/// Retry a future using the given [retry policy](`RetryPolicy`) and [tokio's sleep](`sleep`) method.
11///
12/// ```
13/// use futures_retry_policies::{tokio::retry, RetryPolicy};
14/// use std::{ops::ControlFlow, time::Duration};
15///
16/// pub struct Retries(usize);
17/// impl RetryPolicy<Result<(), &'static str>> for Retries {
18///     fn should_retry(&mut self, result: Result<(), &'static str>) -> ControlFlow<Result<(), &'static str>, Duration> {
19///         if self.0 > 0 && result.is_err() {
20///             self.0 -= 1;
21///             // continue to retry on error
22///             ControlFlow::Continue(Duration::from_millis(100))
23///         } else {
24///             // We've got a success, or we've exhausted our retries, so break
25///             ControlFlow::Break(result)
26///         }
27///     }
28/// }
29///
30/// async fn make_request() -> Result<(), &'static str>  {
31///     // make a request
32///     # static COUNT: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(0);
33///     # if COUNT.fetch_add(1, std::sync::atomic::Ordering::SeqCst) < 2 { Err("fail") } else { Ok(()) }
34/// }
35///
36/// #[tokio::main]
37/// async fn main() -> Result<(), &'static str> {
38///     retry(Retries(3), make_request).await
39/// }
40/// ```
41pub fn retry<Policy, Futures, Fut>(
42    backoff: Policy,
43    futures: Futures,
44) -> RetryFuture<Policy, Futures, Fut>
45where
46    Policy: RetryPolicy<Fut::Output>,
47    Futures: FnMut() -> Fut,
48    Fut: Future,
49{
50    super::retry(backoff, sleep, futures)
51}
52
/// The future returned by [`retry`](fn@retry) and [`RetryFutureExt::retry`].
///
/// This is [`crate::RetryFuture`] with its second and third type parameters fixed to
/// `fn(Duration) -> Sleep` and [`Sleep`], matching the signature of [tokio's sleep](`sleep`).
pub type RetryFuture<Policy, Futures, Fut> =
    crate::RetryFuture<Policy, fn(Duration) -> Sleep, Sleep, Futures, Fut>;
55
56/// Easy helper trait to retry futures
57///
58/// ```
59/// use futures_retry_policies::{tokio::RetryFutureExt, RetryPolicy};
60/// use std::{ops::ControlFlow, time::Duration};
61///
62/// pub struct Attempts(usize);
63/// impl RetryPolicy<Result<(), &'static str>> for Attempts {
64///     fn should_retry(&mut self, result: Result<(), &'static str>) -> ControlFlow<Result<(), &'static str>, Duration> {
65///         self.0 -= 1;
66///         if self.0 > 0 && result.is_err() {
67///             // continue to retry on error
68///             ControlFlow::Continue(Duration::from_millis(100))
69///         } else {
70///             // We've got a success, or we've exhausted our retries, so break
71///             ControlFlow::Break(result)
72///         }
73///     }
74/// }
75///
76/// async fn make_request() -> Result<(), &'static str>  {
77///     // make a request
78///     # static COUNT: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(0);
79///     # if COUNT.fetch_add(1, std::sync::atomic::Ordering::SeqCst) < 2 { Err("fail") } else { Ok(()) }
80/// }
81///
82/// #[tokio::main]
83/// async fn main() -> Result<(), &'static str> {
84///     make_request.retry(Attempts(3)).await
85/// }
86/// ```
87pub trait RetryFutureExt<Fut>
88where
89    Fut: Future,
90{
91    fn retry<Policy>(self, policy: Policy) -> RetryFuture<Policy, Self, Fut>
92    where
93        Policy: RetryPolicy<Fut::Output>,
94        Self: FnMut() -> Fut + Sized,
95    {
96        retry(policy, self)
97    }
98}
99
100impl<Futures, Fut> RetryFutureExt<Fut> for Futures
101where
102    Futures: FnMut() -> Fut,
103    Fut: Future,
104{
105}
106
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use futures_retry_policies_core::RetryPolicy;
    use retry_policies::policies::ExponentialBackoff;

    use crate::{retry_policies::RetryPolicies, tokio::RetryFutureExt, ShouldRetry};

    /// Test error that asks to be retried for as long as `attempts` does not exceed the
    /// wrapped limit.
    struct Error(u32);
    impl ShouldRetry for Error {
        fn should_retry(&self, attempts: u32) -> bool {
            self.0 >= attempts
        }
    }

    /// Exponential backoff with no jitter, built with a maximum of 3 retries.
    /// The tests below expect the delays 1s * 2^0, 2^1, 2^2.
    fn policy<R: ShouldRetry>() -> impl RetryPolicy<R> {
        RetryPolicies::new(
            ExponentialBackoff::builder()
                .retry_bounds(Duration::from_secs(1), Duration::from_secs(60))
                .jitter(retry_policies::Jitter::None)
                .build_with_max_retries(3),
        )
    }

    /// Retries a request that always fails with `Error(limit)` under [`policy`] and reports
    /// how much (paused-clock) time passed before the final error was returned.
    async fn elapsed_retrying(limit: u32) -> Duration {
        let began = tokio::time::Instant::now();
        let request = move || async move { Err::<(), Error>(Error(limit)) };
        request.retry(policy()).await.unwrap_err();
        began.elapsed()
    }

    #[tokio::test(start_paused = true)]
    async fn retry_full() {
        // always retry
        let waited = elapsed_retrying(u32::MAX).await;
        assert_eq!(waited, Duration::from_secs(1 + 2 + 4));
    }

    #[tokio::test(start_paused = true)]
    async fn retry_none() {
        // never retry
        let waited = elapsed_retrying(0).await;
        assert_eq!(waited, Duration::ZERO);
    }

    #[tokio::test(start_paused = true)]
    async fn retry_twice() {
        // retry twice
        let waited = elapsed_retrying(2).await;
        assert_eq!(waited, Duration::from_secs(1 + 2));
    }
}
168}