named_retry/
lib.rs

//! Utilities for retrying fallible, asynchronous operations.
2
3use std::fmt::Debug;
4use std::future::Future;
5use std::time::Duration;
6
7use tokio::time;
8use tracing::warn;
9
/// Configuration for retrying a fallible async operation a bounded number of times.
///
/// Between attempts the caller waits for a delay that starts at `base_delay`
/// and is multiplied by `delay_factor` after every failed attempt. With the
/// default base delay of zero, retries happen back to back.
///
/// Running the configuration yields the first successful result if any, or
/// the last error.
#[derive(Clone, Copy, Debug)]
pub struct Retry {
    /// Label for the operation being retried; it appears in retry log messages.
    pub name: &'static str,

    /// Total number of times the operation may be attempted.
    pub attempts: u32,

    /// Delay to wait after the first failed attempt.
    pub base_delay: Duration,

    /// Multiplier applied to the delay after each failed attempt.
    pub delay_factor: f64,

    /// When set, each delay is drawn at random from the range [delay/2, delay).
    pub enable_jitter: bool,
}
33
34impl Retry {
35    /// Construct a new [`Retry`] object with default parameters.
36    pub const fn new(name: &'static str) -> Self {
37        Self {
38            name,
39            attempts: 3,
40            base_delay: Duration::ZERO,
41            delay_factor: 1.0,
42            enable_jitter: false,
43        }
44    }
45
46    /// Set the number of attempts to make.
47    pub const fn attempts(mut self, attempts: u32) -> Self {
48        self.attempts = attempts;
49        self
50    }
51
52    /// Set the base delay.
53    pub const fn base_delay(mut self, base_delay: Duration) -> Self {
54        self.base_delay = base_delay;
55        self
56    }
57
58    /// Set the exponential factor increasing delay.
59    pub const fn delay_factor(mut self, delay_factor: f64) -> Self {
60        self.delay_factor = delay_factor;
61        self
62    }
63
64    /// Enable jitter.
65    pub const fn jitter(mut self, enabled: bool) -> Self {
66        self.enable_jitter = enabled;
67        self
68    }
69
70    fn apply_jitter(&self, delay: Duration) -> Duration {
71        if self.enable_jitter {
72            // [0.5, 1.0)
73            delay.mul_f64(0.5 + fastrand::f64() / 2.0)
74        } else {
75            delay
76        }
77    }
78
79    /// Run a falliable asynchronous function using this retry configuration.
80    ///
81    /// Panics if the number of attempts is set to `0`, or the base delay is
82    /// incorrectly set to a negative duration.
83    pub async fn run<T, E: Debug, Fut>(self, mut func: impl FnMut() -> Fut) -> Result<T, E>
84    where
85        Fut: Future<Output = Result<T, E>>,
86    {
87        assert!(self.attempts > 0, "attempts must be greater than 0");
88        assert!(
89            self.base_delay >= Duration::ZERO && self.delay_factor >= 0.0,
90            "retry delay cannot be negative"
91        );
92        let mut delay = self.base_delay;
93        for i in 0..self.attempts {
94            match func().await {
95                Ok(value) => return Ok(value),
96                Err(err) if i == self.attempts - 1 => return Err(err),
97                Err(err) => {
98                    warn!(?err, "failed retryable operation {}, retrying", self.name);
99                    time::sleep(self.apply_jitter(delay)).await;
100                    delay = delay.mul_f64(self.delay_factor);
101                }
102            }
103        }
104        unreachable!();
105    }
106}
107
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use tokio::time::Instant;

    use super::Retry;

    /// Running with zero attempts must trip the precondition assertion.
    #[tokio::test]
    #[should_panic]
    async fn zero_retry_attempts() {
        let policy = Retry::new("test").attempts(0);
        let _ = policy.run(|| async { Ok::<_, std::io::Error>(()) }).await;
    }

    /// An operation that succeeds immediately is invoked exactly once.
    #[tokio::test]
    async fn successful_retry() {
        let mut calls = 0;
        let outcome = Retry::new("test")
            .run(|| {
                calls += 1;
                async { Ok::<_, std::io::Error>(()) }
            })
            .await;
        assert_eq!(calls, 1);
        assert!(outcome.is_ok());
    }

    /// An operation that always fails uses up every attempt and yields an error.
    #[tokio::test]
    async fn failed_retry() {
        let policy = Retry::new("test");
        let mut calls = 0;
        let outcome = policy
            .run(|| {
                calls += 1;
                async { Err::<(), ()>(()) }
            })
            .await;
        assert_eq!(calls, policy.attempts);
        assert!(outcome.is_err());
    }

    /// With a 1s base delay doubling each time, the first attempt made at or
    /// after the 5s mark is the fourth one.
    #[tokio::test(start_paused = true)]
    async fn delayed_retry() {
        let started = Instant::now();

        // Attempts are scheduled at 0s, 1s, 3s, 7s, 15s.
        let policy = Retry::new("test")
            .attempts(5)
            .base_delay(Duration::from_secs(1))
            .delay_factor(2.0);

        let mut calls = 0;
        let outcome = policy
            .run(|| {
                calls += 1;
                async {
                    let elapsed = started.elapsed();
                    println!("elapsed = {:?}", elapsed);
                    if elapsed < Duration::from_secs(5) {
                        Err::<(), ()>(())
                    } else {
                        Ok(())
                    }
                }
            })
            .await;
        assert_eq!(calls, 4);
        assert!(outcome.is_ok());
    }

    /// With jitter the second attempt always lands before 500ms and the third
    /// always at or after it, so exactly three calls are made.
    #[tokio::test(start_paused = true)]
    async fn delayed_retry_with_jitter() {
        let started = Instant::now();

        // Earliest possible attempts are at 0s, 50ms, 550ms, 5.55s.
        let policy = Retry::new("test_jitter")
            .attempts(4)
            .base_delay(Duration::from_millis(100))
            .delay_factor(10.0)
            .jitter(true);

        let mut calls = 0;
        let outcome = policy
            .run(|| {
                calls += 1;
                async {
                    let elapsed = started.elapsed();
                    println!("elapsed = {:?}", elapsed);
                    if elapsed < Duration::from_millis(500) {
                        Err::<(), ()>(())
                    } else {
                        Ok(())
                    }
                }
            })
            .await;
        assert_eq!(calls, 3);
        assert!(outcome.is_ok());
    }
}