rama_utils/backoff/
exponential.rs

1use parking_lot::Mutex;
2use std::fmt::{self, Display};
3use std::time::Duration;
4use tokio::time;
5
6use super::Backoff;
7use crate::rng::{HasherRng, Rng};
8
9/// A jittered [exponential backoff] strategy.
10///
11/// The backoff duration will increase exponentially for every subsequent
12/// backoff, up to a maximum duration. A small amount of [random jitter] is
13/// added to each backoff duration, in order to avoid retry spikes.
14///
15/// [exponential backoff]: https://en.wikipedia.org/wiki/Exponential_backoff
16/// [random jitter]: https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
17pub struct ExponentialBackoff<F, R = HasherRng> {
18    min: time::Duration,
19    max: time::Duration,
20    jitter: f64,
21    rng_creator: F,
22    state: Mutex<ExponentialBackoffState<R>>,
23}
24
25impl<F: fmt::Debug, R: fmt::Debug> fmt::Debug for ExponentialBackoff<F, R> {
26    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
27        f.debug_struct("ExponentialBackoff")
28            .field("min", &self.min)
29            .field("max", &self.max)
30            .field("jitter", &self.jitter)
31            .field("rng_creator", &self.rng_creator)
32            .field("state", &self.state)
33            .finish()
34    }
35}
36
37impl<F, R> Clone for ExponentialBackoff<F, R>
38where
39    R: Rng + Clone,
40    F: Fn() -> R + Clone,
41{
42    fn clone(&self) -> Self {
43        Self {
44            min: self.min,
45            max: self.max,
46            jitter: self.jitter,
47            rng_creator: self.rng_creator.clone(),
48            state: Mutex::new(ExponentialBackoffState {
49                rng: (self.rng_creator)(),
50                iterations: 0,
51            }),
52        }
53    }
54}
55
56impl Clone for ExponentialBackoff<(), HasherRng> {
57    fn clone(&self) -> Self {
58        Self {
59            min: self.min,
60            max: self.max,
61            jitter: self.jitter,
62            rng_creator: (),
63            state: Mutex::new(ExponentialBackoffState {
64                rng: HasherRng::default(),
65                iterations: 0,
66            }),
67        }
68    }
69}
70
71struct ExponentialBackoffState<R = HasherRng> {
72    rng: R,
73    iterations: u32,
74}
75
76impl<R: fmt::Debug> fmt::Debug for ExponentialBackoffState<R> {
77    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
78        f.debug_struct("ExponentialBackoffState")
79            .field("rng", &self.rng)
80            .field("iterations", &self.iterations)
81            .finish()
82    }
83}
84
85impl<F, R> ExponentialBackoff<F, R>
86where
87    R: Rng + Clone,
88    F: Fn() -> R + Clone,
89{
90    /// Create a new `ExponentialBackoff`.
91    ///
92    /// # Error
93    ///
94    /// Returns a config validation error if:
95    /// - `min` > `max`
96    /// - `max` == 0
97    /// - `jitter` < `0.0`
98    /// - `jitter` > `100.0`
99    /// - `jitter` is not finite
100    pub fn new(
101        min: time::Duration,
102        max: time::Duration,
103        jitter: f64,
104        rng_creator: F,
105    ) -> Result<Self, InvalidBackoff> {
106        let rng = rng_creator();
107        Self::new_inner(min, max, jitter, rng_creator, rng)
108    }
109}
110
111impl<F, R> ExponentialBackoff<F, R> {
112    fn new_inner(
113        min: time::Duration,
114        max: time::Duration,
115        jitter: f64,
116        rng_creator: F,
117        rng: R,
118    ) -> Result<Self, InvalidBackoff> {
119        if min > max {
120            return Err(InvalidBackoff("maximum must not be less than minimum"));
121        }
122        if max == time::Duration::from_millis(0) {
123            return Err(InvalidBackoff("maximum must be non-zero"));
124        }
125        if jitter < 0.0 {
126            return Err(InvalidBackoff("jitter must not be negative"));
127        }
128        if jitter > 100.0 {
129            return Err(InvalidBackoff("jitter must not be greater than 100"));
130        }
131        if !jitter.is_finite() {
132            return Err(InvalidBackoff("jitter must be finite"));
133        }
134
135        Ok(ExponentialBackoff {
136            min,
137            max,
138            jitter,
139            rng_creator,
140            state: Mutex::new(ExponentialBackoffState { rng, iterations: 0 }),
141        })
142    }
143}
144
145impl<F, R: Rng> ExponentialBackoff<F, R> {
146    fn base(&self) -> time::Duration {
147        debug_assert!(
148            self.min <= self.max,
149            "maximum backoff must not be less than minimum backoff"
150        );
151        debug_assert!(
152            self.max > time::Duration::from_millis(0),
153            "Maximum backoff must be non-zero"
154        );
155        self.min
156            .checked_mul(2_u32.saturating_pow(self.state.lock().iterations))
157            .unwrap_or(self.max)
158            .min(self.max)
159    }
160
161    /// Returns a random, uniform duration on `[0, base*self.jitter]` no greater
162    /// than `self.max`.
163    fn jitter(&self, base: time::Duration) -> Option<time::Duration> {
164        if self.jitter <= 0.0 {
165            None
166        } else {
167            let jitter_factor = self.state.lock().rng.next_f64();
168            debug_assert!(
169                jitter_factor > 0.0,
170                "rng returns values between 0.0 and 1.0"
171            );
172            let rand_jitter = jitter_factor * self.jitter;
173            let secs = (base.as_secs() as f64) * rand_jitter;
174            let nanos = (base.subsec_nanos() as f64) * rand_jitter;
175            let remaining = self.max - base;
176            let result = time::Duration::new(secs as u64, nanos as u32);
177            if remaining.is_zero() || result.is_zero() {
178                None
179            } else {
180                Some(result.min(remaining))
181            }
182        }
183    }
184}
185
186impl<F, R> Backoff for ExponentialBackoff<F, R>
187where
188    R: Rng,
189    F: Send + Sync + 'static,
190{
191    async fn next_backoff(&self) -> bool {
192        let base = self.base();
193        let jitter = match self.jitter(base) {
194            Some(jitter) => jitter,
195            None => {
196                self.reset().await;
197                return false;
198            }
199        };
200
201        let next = base + jitter;
202
203        self.state.lock().iterations += 1;
204
205        tokio::time::sleep(next).await;
206        true
207    }
208
209    async fn reset(&self) {
210        self.state.lock().iterations = 0;
211    }
212}
213
214impl Default for ExponentialBackoff<(), HasherRng> {
215    fn default() -> Self {
216        ExponentialBackoff::new_inner(
217            Duration::from_millis(50),
218            Duration::from_secs(3),
219            0.99,
220            (),
221            HasherRng::default(),
222        )
223        .expect("Unable to create ExponentialBackoff")
224    }
225}
226
/// Error returned when an `ExponentialBackoff` is configured with invalid
/// settings; the message names the constraint that was violated.
#[derive(Debug)]
pub struct InvalidBackoff(&'static str);

impl Display for InvalidBackoff {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self(reason) = self;
        f.write_str("invalid backoff: ")?;
        f.write_str(reason)
    }
}

impl std::error::Error for InvalidBackoff {}
238
#[cfg(test)]
mod tests {
    use super::*;
    use quickcheck::*;

    #[tokio::test]
    async fn backoff_default() {
        let backoff = ExponentialBackoff::default();
        assert!(backoff.next_backoff().await);
    }

    #[tokio::test]
    async fn backoff_reset() {
        let backoff = ExponentialBackoff::default();
        assert!(backoff.next_backoff().await);
        assert_eq!(backoff.state.lock().iterations, 1);
        backoff.reset().await;
        assert_eq!(backoff.state.lock().iterations, 0);
    }

    #[tokio::test]
    async fn backoff_clone() {
        let backoff = ExponentialBackoff::default();

        assert_eq!(backoff.state.lock().iterations, 0);
        assert!(backoff.next_backoff().await);
        assert_eq!(backoff.state.lock().iterations, 1);

        // A clone starts over at iteration zero and does not affect the original.
        let cloned = backoff.clone();
        assert_eq!(cloned.state.lock().iterations, 0);
        assert_eq!(backoff.state.lock().iterations, 1);

        assert!(cloned.next_backoff().await);
        assert_eq!(cloned.state.lock().iterations, 1);
        assert_eq!(backoff.state.lock().iterations, 1);
    }

    quickcheck! {
        fn backoff_base_first(min_ms: u64, max_ms: u64) -> TestResult {
            let min = time::Duration::from_millis(min_ms);
            let max = time::Duration::from_millis(max_ms);
            let backoff = match ExponentialBackoff::new(min, max, 0.0, HasherRng::default) {
                Err(_) => return TestResult::discard(),
                Ok(backoff) => backoff,
            };

            let delay = backoff.base();
            TestResult::from_bool(min == delay)
        }

        fn backoff_base(min_ms: u64, max_ms: u64, iterations: u32) -> TestResult {
            let min = time::Duration::from_millis(min_ms);
            let max = time::Duration::from_millis(max_ms);
            let backoff = match ExponentialBackoff::new(min, max, 0.0, HasherRng::default) {
                Err(_) => return TestResult::discard(),
                Ok(backoff) => backoff,
            };

            backoff.state.lock().iterations = iterations;
            let delay = backoff.base();
            TestResult::from_bool(min <= delay && delay <= max)
        }

        fn backoff_jitter(base_ms: u64, max_ms: u64, jitter: f64) -> TestResult {
            let base = time::Duration::from_millis(base_ms);
            let max = time::Duration::from_millis(max_ms);
            let backoff = match ExponentialBackoff::new(base, max, jitter, HasherRng::default) {
                Err(_) => return TestResult::discard(),
                Ok(backoff) => backoff,
            };

            let j = backoff.jitter(base);
            if jitter == 0.0 || base_ms == 0 || max_ms == base_ms {
                // Jitter disabled, nothing to scale, or no headroom below `max`.
                return TestResult::from_bool(j.is_none());
            }

            match j {
                None => TestResult::failed(),
                Some(j) => {
                    // Never more than the headroom left below `max`, and never
                    // more than `base * jitter` (slack for float rounding).
                    let upper_secs = base.as_secs_f64() * jitter;
                    TestResult::from_bool(
                        j <= max - base && j.as_secs_f64() <= upper_secs * (1.0 + 1e-9) + 1e-9,
                    )
                }
            }
        }
    }
}