fluvio_future/
retry.rs

1use async_trait::async_trait;
2use futures_lite::FutureExt as lite_ext;
3use futures_util::FutureExt;
4use std::error::Error;
5use std::fmt::{Debug, Display, Formatter};
6use std::future::Future;
7use std::time::Duration;
8use tracing::warn;
9
10use crate::timer::sleep;
11pub use delay::ExponentialBackoff;
12pub use delay::FibonacciBackoff;
13pub use delay::FixedDelay;
14
15/// An extension trait for `Future` that provides a convenient methods for retries.
16#[async_trait]
17pub trait RetryExt: Future {
18    /// Transforms the current `Future` to a new one that is time-limited by the given timeout.
19    /// Returns [TimeoutError] if timeout exceeds. Otherwise, it returns the original Future’s result.
20    /// Example:
21    /// ```
22    /// use fluvio_future::timer::sleep;
23    /// use fluvio_future::retry::RetryExt;
24    /// use fluvio_future::retry::TimeoutError;
25    /// use std::time::{Duration, Instant};
26    ///
27    /// fluvio_future::task::run(async {
28    ///     let result = sleep(Duration::from_secs(10)).timeout(Duration::from_secs(1)).await;
29    ///     assert_eq!(result, Err(TimeoutError));
30    /// });
31    /// ```
32    async fn timeout(self, timeout: Duration) -> Result<Self::Output, TimeoutError>;
33}
34
35#[async_trait]
36impl<F: Future + Send> RetryExt for F {
37    async fn timeout(self, timeout: Duration) -> Result<Self::Output, TimeoutError> {
38        self.map(Ok)
39            .or(async move {
40                let _ = sleep(timeout).await;
41                Err(TimeoutError)
42            })
43            .await
44    }
45}
46
/// Error produced by [`RetryExt::timeout`] when the timeout elapses before the wrapped
/// future completes.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct TimeoutError;

impl Display for TimeoutError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // The derived `Debug` of a unit struct renders as the bare type name, which is also
        // the user-facing message.
        Debug::fmt(self, f)
    }
}

impl Error for TimeoutError {}
57
/// Awaits one attempt produced by calling `$function()` and decides what happens next:
///
/// * `Ok(output)`: returns `Ok(output)` from the *enclosing* function;
/// * `Err(err)` for which `$condition(&err)` is `false` (not retryable): returns `Err(err)`
///   from the enclosing function;
/// * `Err(err)` for which `$condition(&err)` is `true` (retryable): the macro evaluates to
///   `err`, so the caller can keep it and try again.
///
/// Because it uses `.await` and `return`, it must be expanded inside an `async` function whose
/// return type is `Result<O, E>`.
macro_rules! poll_err {
    ($function:ident, $condition:ident) => {{
        match $function().await {
            Ok(output) => return Ok(output),
            Err(err) if !$condition(&err) => return Err(err),
            Err(err) => err,
        }
    }};
}
69
70/// Provides `Future` with specified retries strategy. See [retry_if] for details.
71pub fn retry<I, O, F, E, A>(retries: I, factory: A) -> impl Future<Output = Result<O, E>>
72where
73    I: IntoIterator<Item = Duration>,
74    A: FnMut() -> F,
75    F: Future<Output = Result<O, E>>,
76    E: Debug,
77{
78    retry_if(retries, factory, |_| true)
79}
80
81/// Provides retry functionality in async context. The `Future` that you want to retry needs to be
82/// represented in `FnMut() -> Future` structure. Each retry creates a new instance of `Future` and
83/// awaits it. Iterator `Iterator<Item=Duration>` controls the number of retries and delays between
84/// them. If iterator returns None, retries stop. There are three common implementations of retry
85/// strategies: [FixedDelay], [FibonacciBackoff], and [ExponentialBackoff].
86///
87/// Example:
88/// ```
89/// use std::io::{Error, ErrorKind};
90/// use std::ops::AddAssign;
91/// use std::time::Duration;
92/// use fluvio_future::retry::FixedDelay;
93/// use fluvio_future::retry::retry;
94///
95/// fluvio_future::task::run(async {
96///     let mut attempts = 0u8;
97///     let result = retry(FixedDelay::from_millis(100).take(2), || {
98///         attempts.add_assign(1);
99///         operation()
100///     }).await;
101///     assert!(matches!(result, Err(err) if err.kind() == ErrorKind::NotFound));
102///     assert_eq!(attempts, 3); // first attempt + 2 retries
103/// });
104///
105/// async fn operation() -> Result<(), Error> {
106///     Err(Error::from(ErrorKind::NotFound))
107/// }
108/// ```
109pub async fn retry_if<I, O, F, E, A, P>(retries: I, mut factory: A, condition: P) -> Result<O, E>
110where
111    I: IntoIterator<Item = Duration>,
112    A: FnMut() -> F,
113    F: Future<Output = Result<O, E>>,
114    P: Fn(&E) -> bool,
115    E: Debug,
116{
117    let mut err = poll_err!(factory, condition);
118    for delay_duration in retries.into_iter() {
119        cfg_if::cfg_if! {
120            if #[cfg(target_arch = "wasm32")] {
121                sleep(delay_duration).await.unwrap();
122            } else {
123                sleep(delay_duration).await;
124            }
125        }
126        warn!(?err, "retrying");
127        err = poll_err!(factory, condition);
128    }
129    Err(err)
130}
131
mod delay {
    use std::time::Duration;

    /// A retry strategy that waits the same fixed interval before every retry.
    ///
    /// The iterator never ends on its own, so bound it (e.g. with [`Iterator::take`]) to limit
    /// the number of retries. `FixedDelay::default()` yields a zero delay.
    ///
    /// ```
    /// use std::io::Error;
    /// use fluvio_future::retry::{FixedDelay, retry};
    ///
    /// fluvio_future::task::run(async {
    ///     let _: Result<(), Error> = retry(FixedDelay::from_millis(100).take(4), || async {Ok(())}).await; // at most 4 retries
    /// });
    /// ```
    #[derive(Default, Clone, Debug, Eq, PartialEq)]
    pub struct FixedDelay {
        delay: Duration,
    }

    impl FixedDelay {
        /// Creates a strategy that yields `delay` on every call to `next`.
        pub fn new(delay: Duration) -> Self {
            Self { delay }
        }

        /// Creates a strategy that yields a delay of `millis` milliseconds.
        pub fn from_millis(millis: u64) -> Self {
            Self::new(Duration::from_millis(millis))
        }

        /// Creates a strategy that yields a delay of `secs` seconds.
        pub fn from_secs(secs: u64) -> Self {
            Self::new(Duration::from_secs(secs))
        }
    }

    impl Iterator for FixedDelay {
        type Item = Duration;

        fn next(&mut self) -> Option<Duration> {
            Some(self.delay)
        }
    }

    /// A retry strategy whose delays follow the Fibonacci series scaled by an initial delay
    /// `d`: `d, d, 2d, 3d, 5d, 8d, ...`.
    ///
    /// The series saturates at [`Duration::MAX`] instead of overflowing, and can be capped with
    /// [`FibonacciBackoff::max_delay`]. The iterator never ends on its own. An initial delay of
    /// zero (also the `Default`) yields zero forever.
    ///
    /// ```
    /// use std::io::Error;
    /// use fluvio_future::retry::{FibonacciBackoff, retry};
    ///
    /// fluvio_future::task::run(async {
    ///     let _: Result<(), Error> = retry(FibonacciBackoff::from_millis(100).take(4), || async {Ok(())}).await; // at most 4 retries
    /// });
    /// ```
    #[derive(Default, Clone, Debug, Eq, PartialEq)]
    pub struct FibonacciBackoff {
        current: Duration,
        next: Duration,
        max_delay: Option<Duration>,
    }

    impl FibonacciBackoff {
        /// Creates a series whose first two delays are both `initial_delay`.
        pub fn new(initial_delay: Duration) -> Self {
            Self {
                current: initial_delay,
                next: initial_delay,
                max_delay: None,
            }
        }

        /// Creates a series starting at `millis` milliseconds.
        pub fn from_millis(millis: u64) -> Self {
            Self::new(Duration::from_millis(millis))
        }

        /// Creates a series starting at `secs` seconds.
        pub fn from_secs(secs: u64) -> Self {
            Self::new(Duration::from_secs(secs))
        }

        /// Caps every yielded delay at `max_delay`.
        ///
        /// Once the series exceeds the cap, `max_delay` is yielded from then on.
        #[must_use]
        pub fn max_delay(mut self, max_delay: Duration) -> Self {
            self.max_delay = Some(max_delay);
            self
        }
    }

    impl Iterator for FibonacciBackoff {
        type Item = Duration;

        fn next(&mut self) -> Option<Self::Item> {
            let duration = self.current;
            if let Some(max_delay) = self.max_delay {
                if duration > max_delay {
                    // `next >= current` always holds, so `current` never shrinks: once it is
                    // above the cap it stays above, and there is no need to advance further.
                    return Some(max_delay);
                }
            }
            // Advance the series; if the sum overflows, pin the look-ahead term at the maximum.
            let following = self.current.checked_add(self.next).unwrap_or(Duration::MAX);
            self.current = self.next;
            self.next = following;
            Some(duration)
        }
    }

    /// A retry strategy driven by exponential back-off.
    ///
    /// `ExponentialBackoff::from_millis(m)` uses `m` both as the first delay and as the growth
    /// factor. The n-th delay (counting from 0) is `m^(n + 1)` milliseconds, saturating at
    /// `u64::MAX` milliseconds. For example, `from_millis(2)` yields 2 ms, 4 ms, 8 ms, ..., and
    /// `from_millis(100)` yields 100 ms, 10 s, 1000 s, .... A value of 0 or 1 yields a
    /// constant delay.
    ///
    /// Delays can be capped with [`ExponentialBackoff::max_delay`]. The iterator never ends on
    /// its own.
    ///
    /// ```
    /// use std::io::Error;
    /// use fluvio_future::retry::{ExponentialBackoff, retry};
    ///
    /// fluvio_future::task::run(async {
    ///     let _: Result<(), Error> = retry(ExponentialBackoff::from_millis(100).take(4), || async {Ok(())}).await; // at most 4 retries
    /// });
    /// ```
    #[derive(Default, Clone, Debug, Eq, PartialEq)]
    pub struct ExponentialBackoff {
        base_millis: u64,
        current_millis: u64,
        max_delay: Option<Duration>,
    }

    impl ExponentialBackoff {
        /// Creates a back-off whose first delay is `millis` milliseconds and whose delays are
        /// multiplied by `millis` after every step.
        pub fn from_millis(millis: u64) -> Self {
            Self {
                base_millis: millis,
                current_millis: millis,
                max_delay: None,
            }
        }

        /// Caps every yielded delay at `max_delay`.
        ///
        /// Once the back-off exceeds the cap, `max_delay` is yielded from then on.
        #[must_use]
        pub fn max_delay(mut self, max_delay: Duration) -> Self {
            self.max_delay = Some(max_delay);
            self
        }
    }

    impl Iterator for ExponentialBackoff {
        type Item = Duration;

        fn next(&mut self) -> Option<Self::Item> {
            let duration = Duration::from_millis(self.current_millis);
            if let Some(max_delay) = self.max_delay {
                if duration > max_delay {
                    return Some(max_delay);
                }
            }
            // Grow by the base factor, saturating at `u64::MAX` milliseconds on overflow.
            self.current_millis = self.current_millis.saturating_mul(self.base_millis);
            Some(duration)
        }
    }

    #[cfg(test)]
    mod test {
        use super::*;

        #[test]
        fn test_fixed_delay_repeats_forever() {
            let mut iter = FixedDelay::from_millis(100);
            assert_eq!(iter.next(), Some(Duration::from_millis(100)));
            assert_eq!(iter.next(), Some(Duration::from_millis(100)));
            assert_eq!(FixedDelay::from_secs(1), FixedDelay::new(Duration::from_secs(1)));
        }

        #[test]
        fn test_fixed_delay_default_is_zero() {
            assert_eq!(FixedDelay::default().next(), Some(Duration::ZERO));
        }

        #[test]
        fn test_fibonacci_series_starting_at_10() {
            let mut iter = FibonacciBackoff::from_millis(10);
            assert_eq!(iter.next(), Some(Duration::from_millis(10)));
            assert_eq!(iter.next(), Some(Duration::from_millis(10)));
            assert_eq!(iter.next(), Some(Duration::from_millis(20)));
            assert_eq!(iter.next(), Some(Duration::from_millis(30)));
            assert_eq!(iter.next(), Some(Duration::from_millis(50)));
            assert_eq!(iter.next(), Some(Duration::from_millis(80)));
        }

        #[test]
        fn test_fibonacci_saturates_at_maximum_value() {
            let mut iter = FibonacciBackoff::from_millis(u64::MAX);
            assert_eq!(iter.next(), Some(Duration::from_millis(u64::MAX)));
            assert_eq!(iter.next(), Some(Duration::from_millis(u64::MAX)));
            // Keep going well past the point where the sum overflows `Duration`; from then on
            // the series is pinned at `Duration::MAX` instead of panicking or wrapping.
            let tail: Vec<Duration> = iter.skip(100).take(3).collect();
            assert_eq!(tail, vec![Duration::MAX; 3]);
        }

        #[test]
        fn test_fibonacci_stops_increasing_at_max_delay() {
            let mut iter = FibonacciBackoff::from_millis(10).max_delay(Duration::from_millis(50));
            assert_eq!(iter.next(), Some(Duration::from_millis(10)));
            assert_eq!(iter.next(), Some(Duration::from_millis(10)));
            assert_eq!(iter.next(), Some(Duration::from_millis(20)));
            assert_eq!(iter.next(), Some(Duration::from_millis(30)));
            assert_eq!(iter.next(), Some(Duration::from_millis(50)));
            assert_eq!(iter.next(), Some(Duration::from_millis(50)));
        }

        #[test]
        fn test_fibonacci_returns_max_when_max_less_than_base() {
            let mut iter = FibonacciBackoff::from_secs(20).max_delay(Duration::from_secs(10));

            assert_eq!(iter.next(), Some(Duration::from_secs(10)));
            assert_eq!(iter.next(), Some(Duration::from_secs(10)));
        }

        #[test]
        fn test_exponential_some_exponential_base_10() {
            let mut s = ExponentialBackoff::from_millis(10);

            assert_eq!(s.next(), Some(Duration::from_millis(10)));
            assert_eq!(s.next(), Some(Duration::from_millis(100)));
            assert_eq!(s.next(), Some(Duration::from_millis(1000)));
        }

        #[test]
        fn test_exponential_some_exponential_base_2() {
            let mut s = ExponentialBackoff::from_millis(2);

            assert_eq!(s.next(), Some(Duration::from_millis(2)));
            assert_eq!(s.next(), Some(Duration::from_millis(4)));
            assert_eq!(s.next(), Some(Duration::from_millis(8)));
        }

        #[test]
        fn test_exponential_saturates_at_maximum_value() {
            let mut s = ExponentialBackoff::from_millis(u64::MAX - 1);

            assert_eq!(s.next(), Some(Duration::from_millis(u64::MAX - 1)));
            assert_eq!(s.next(), Some(Duration::from_millis(u64::MAX)));
            assert_eq!(s.next(), Some(Duration::from_millis(u64::MAX)));
        }

        #[test]
        fn test_exponential_stops_increasing_at_max_delay() {
            let mut s = ExponentialBackoff::from_millis(2).max_delay(Duration::from_millis(4));

            assert_eq!(s.next(), Some(Duration::from_millis(2)));
            assert_eq!(s.next(), Some(Duration::from_millis(4)));
            assert_eq!(s.next(), Some(Duration::from_millis(4)));
        }

        #[test]
        fn test_exponential_max_when_max_less_than_base() {
            let mut s = ExponentialBackoff::from_millis(20).max_delay(Duration::from_millis(10));

            assert_eq!(s.next(), Some(Duration::from_millis(10)));
            assert_eq!(s.next(), Some(Duration::from_millis(10)));
        }
    }
}
370
#[cfg(test)]
mod test {
    use super::*;
    use std::io::{Error as IoError, ErrorKind};
    use std::time::Duration;
    use tracing::debug;

    /// Result produced by every simulated operation in this module.
    type Attempt = Result<usize, IoError>;

    /// Zero-length delays: one initial attempt plus one per delay, then the last error.
    #[fluvio_future::test]
    async fn test_fixed_retries_no_delay() {
        let mut calls = 0u8;
        let attempt = || {
            let call_no = calls;
            calls += 1;
            async move {
                debug!("called retry#{}", call_no);

                Attempt::Err(IoError::from(ErrorKind::NotFound))
            }
        };
        let outcome = retry(FixedDelay::default().take(2), attempt).await;
        assert!(matches!(outcome, Err(e) if e.kind() == ErrorKind::NotFound));
        assert_eq!(calls, 3);
    }

    /// An outer `timeout` cuts the retry loop short before its delays are exhausted.
    #[fluvio_future::test]
    async fn test_fixed_retries_timeout() {
        let mut calls = 0u8;
        let attempt = || {
            let call_no = calls;
            calls += 1;
            async move {
                debug!("called retry#{}", call_no);
                Attempt::Err(IoError::from(ErrorKind::NotFound))
            }
        };
        let outcome = retry(FixedDelay::from_millis(100).take(10), attempt)
            .timeout(Duration::from_millis(300))
            .await;

        assert!(outcome.is_err());
        assert!(calls < 10);
    }

    /// A condition that rejects every error stops after the very first attempt.
    #[fluvio_future::test]
    async fn test_fixed_retries_not_retryable() {
        let mut calls = 0u8;
        let attempt = || {
            let call_no = calls;
            calls += 1;
            async move {
                debug!("called retry#{}", call_no);
                Attempt::Err(IoError::from(ErrorKind::NotFound))
            }
        };
        let outcome = retry_if(FixedDelay::from_millis(100).take(10), attempt, |_| false).await;

        assert!(matches!(outcome, Err(e) if e.kind() == ErrorKind::NotFound));
        assert_eq!(calls, 1);
    }

    /// Retries continue only while the error matches the condition.
    #[fluvio_future::test]
    async fn test_conditional_retry() {
        let mut calls = 0u8;
        let attempt = || {
            calls += 1;
            let call_no = calls;
            async move {
                debug!("called retry#{}", call_no);
                let kind = if call_no < 2 {
                    ErrorKind::NotFound
                } else {
                    ErrorKind::AddrNotAvailable
                };
                Attempt::Err(IoError::from(kind))
            }
        };
        let only_not_found = |e: &IoError| e.kind() == ErrorKind::NotFound;
        let outcome = retry_if(FixedDelay::default().take(10), attempt, only_not_found).await;

        assert!(matches!(outcome, Err(e) if e.kind() == ErrorKind::AddrNotAvailable));
        assert_eq!(calls, 2);
    }
}