named_retry/
lib.rs

//! Utilities for retrying fallible, asynchronous operations.
2
3use std::fmt::Debug;
4use std::time::Duration;
5
6use tokio::time;
7use tracing::warn;
8
/// Configuration for retrying a fallible async operation.
///
/// Every failed attempt (except the last) is followed by a delay before the
/// next one. The delay starts at `base_delay` and is multiplied by
/// `delay_factor` after each failed attempt, until `attempts` runs have been
/// made.
///
/// [`Retry::run`] yields the first successful result, or the error from the
/// final attempt.
#[derive(Debug, Clone, Copy)]
pub struct Retry {
    /// Name of the operation being retried; included in the warning logged
    /// on each failed attempt.
    pub name: &'static str,

    /// Total number of attempts to make, including the first one.
    pub attempts: u32,

    /// Delay applied after the first failed attempt. Later delays are scaled
    /// by `delay_factor`.
    pub base_delay: Duration,

    /// Multiplier applied to the delay after each failed attempt.
    pub delay_factor: f64,

    /// When true, each delay is drawn at random from `[delay/2, delay)`.
    pub enable_jitter: bool,
}
32
33impl Retry {
34    /// Construct a new [`Retry`] object with default parameters.
35    pub const fn new(name: &'static str) -> Self {
36        Self {
37            name,
38            attempts: 3,
39            base_delay: Duration::ZERO,
40            delay_factor: 1.0,
41            enable_jitter: false,
42        }
43    }
44
45    /// Set the number of attempts to make.
46    pub const fn attempts(mut self, attempts: u32) -> Self {
47        self.attempts = attempts;
48        self
49    }
50
51    /// Set the base delay.
52    pub const fn base_delay(mut self, base_delay: Duration) -> Self {
53        self.base_delay = base_delay;
54        self
55    }
56
57    /// Set the exponential factor increasing delay.
58    pub const fn delay_factor(mut self, delay_factor: f64) -> Self {
59        self.delay_factor = delay_factor;
60        self
61    }
62
63    /// Enable jitter.
64    pub const fn jitter(mut self, enabled: bool) -> Self {
65        self.enable_jitter = enabled;
66        self
67    }
68
69    fn apply_jitter(&self, delay: Duration) -> Duration {
70        if self.enable_jitter {
71            // [0.5, 1.0)
72            delay.mul_f64(0.5 + fastrand::f64() / 2.0)
73        } else {
74            delay
75        }
76    }
77
78    /// Run a falliable asynchronous function using this retry configuration.
79    ///
80    /// Panics if the number of attempts is set to `0`, or the base delay is
81    /// incorrectly set to a negative duration.
82    pub async fn run<T, E: Debug>(
83        self,
84        mut func: impl AsyncFnMut() -> Result<T, E>,
85    ) -> Result<T, E> {
86        assert!(self.attempts > 0, "attempts must be greater than 0");
87        assert!(
88            self.base_delay >= Duration::ZERO && self.delay_factor >= 0.0,
89            "retry delay cannot be negative"
90        );
91        let mut delay = self.base_delay;
92        for i in 0..self.attempts {
93            match func().await {
94                Ok(value) => return Ok(value),
95                Err(err) if i == self.attempts - 1 => return Err(err),
96                Err(err) => {
97                    warn!(?err, "failed retryable operation {}, retrying", self.name);
98                    time::sleep(self.apply_jitter(delay)).await;
99                    delay = delay.mul_f64(self.delay_factor);
100                }
101            }
102        }
103        unreachable!();
104    }
105}
106
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use tokio::time::Instant;

    use super::Retry;

    /// A configuration with zero attempts must panic when run.
    #[tokio::test]
    #[should_panic]
    async fn zero_retry_attempts() {
        let config = Retry::new("test").attempts(0);
        let _ = config.run(async || Ok::<_, std::io::Error>(())).await;
    }

    /// An operation that succeeds immediately is only called once.
    #[tokio::test]
    async fn successful_retry() {
        let mut calls = 0;
        let outcome = Retry::new("test")
            .run(async || {
                calls += 1;
                Ok::<_, std::io::Error>(())
            })
            .await;
        assert_eq!(calls, 1);
        assert!(outcome.is_ok());
    }

    /// An operation that always fails is called once per configured attempt.
    #[tokio::test]
    async fn failed_retry() {
        let config = Retry::new("test");
        let mut calls = 0;
        let outcome = config
            .run(async || {
                calls += 1;
                Err::<(), ()>(())
            })
            .await;
        assert_eq!(calls, config.attempts);
        assert!(outcome.is_err());
    }

    /// With a 1s base delay doubling each time, the operation first succeeds
    /// on the attempt made at 7s.
    #[tokio::test(start_paused = true)]
    async fn delayed_retry() {
        let began = Instant::now();
        let threshold = Duration::from_secs(5);

        // Attempts are made at 0s, 1s, 3s, 7s, 15s
        let config = Retry::new("test")
            .attempts(5)
            .base_delay(Duration::from_secs(1))
            .delay_factor(2.0);

        let mut calls = 0;
        let outcome = config
            .run(async || {
                calls += 1;
                println!("elapsed = {:?}", began.elapsed());
                if began.elapsed() >= threshold {
                    Ok::<(), ()>(())
                } else {
                    Err(())
                }
            })
            .await;
        assert_eq!(calls, 4);
        assert!(outcome.is_ok());
    }

    /// With jitter enabled, the third attempt is the first that can land at
    /// or after the 500ms threshold.
    #[tokio::test(start_paused = true)]
    async fn delayed_retry_with_jitter() {
        let began = Instant::now();
        let threshold = Duration::from_millis(500);

        // Each sleep is at least half the nominal delay (100ms, 1s, 10s), so
        // the earliest possible attempt times are 0s, 50ms, 550ms, 5.55s
        let config = Retry::new("test_jitter")
            .attempts(4)
            .base_delay(Duration::from_millis(100))
            .delay_factor(10.0)
            .jitter(true);

        let mut calls = 0;
        let outcome = config
            .run(async || {
                calls += 1;
                println!("elapsed = {:?}", began.elapsed());
                if began.elapsed() >= threshold {
                    Ok::<(), ()>(())
                } else {
                    Err(())
                }
            })
            .await;
        assert_eq!(calls, 3);
        assert!(outcome.is_ok());
    }
}
197}