1use parking_lot::Mutex;
2use std::fmt::{self, Display};
3use std::time::Duration;
4use tokio::time;
5
6use super::Backoff;
7use crate::rng::{HasherRng, Rng};
8
9pub struct ExponentialBackoff<F, R = HasherRng> {
18 min: time::Duration,
19 max: time::Duration,
20 jitter: f64,
21 rng_creator: F,
22 state: Mutex<ExponentialBackoffState<R>>,
23}
24
25impl<F: fmt::Debug, R: fmt::Debug> fmt::Debug for ExponentialBackoff<F, R> {
26 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
27 f.debug_struct("ExponentialBackoff")
28 .field("min", &self.min)
29 .field("max", &self.max)
30 .field("jitter", &self.jitter)
31 .field("rng_creator", &self.rng_creator)
32 .field("state", &self.state)
33 .finish()
34 }
35}
36
37impl<F, R> Clone for ExponentialBackoff<F, R>
38where
39 R: Rng + Clone,
40 F: Fn() -> R + Clone,
41{
42 fn clone(&self) -> Self {
43 Self {
44 min: self.min,
45 max: self.max,
46 jitter: self.jitter,
47 rng_creator: self.rng_creator.clone(),
48 state: Mutex::new(ExponentialBackoffState {
49 rng: (self.rng_creator)(),
50 iterations: 0,
51 }),
52 }
53 }
54}
55
56impl Clone for ExponentialBackoff<(), HasherRng> {
57 fn clone(&self) -> Self {
58 Self {
59 min: self.min,
60 max: self.max,
61 jitter: self.jitter,
62 rng_creator: (),
63 state: Mutex::new(ExponentialBackoffState {
64 rng: HasherRng::default(),
65 iterations: 0,
66 }),
67 }
68 }
69}
70
71struct ExponentialBackoffState<R = HasherRng> {
72 rng: R,
73 iterations: u32,
74}
75
76impl<R: fmt::Debug> fmt::Debug for ExponentialBackoffState<R> {
77 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
78 f.debug_struct("ExponentialBackoffState")
79 .field("rng", &self.rng)
80 .field("iterations", &self.iterations)
81 .finish()
82 }
83}
84
85impl<F, R> ExponentialBackoff<F, R>
86where
87 R: Rng + Clone,
88 F: Fn() -> R + Clone,
89{
90 pub fn new(
101 min: time::Duration,
102 max: time::Duration,
103 jitter: f64,
104 rng_creator: F,
105 ) -> Result<Self, InvalidBackoff> {
106 let rng = rng_creator();
107 Self::new_inner(min, max, jitter, rng_creator, rng)
108 }
109}
110
111impl<F, R> ExponentialBackoff<F, R> {
112 fn new_inner(
113 min: time::Duration,
114 max: time::Duration,
115 jitter: f64,
116 rng_creator: F,
117 rng: R,
118 ) -> Result<Self, InvalidBackoff> {
119 if min > max {
120 return Err(InvalidBackoff("maximum must not be less than minimum"));
121 }
122 if max == time::Duration::from_millis(0) {
123 return Err(InvalidBackoff("maximum must be non-zero"));
124 }
125 if jitter < 0.0 {
126 return Err(InvalidBackoff("jitter must not be negative"));
127 }
128 if jitter > 100.0 {
129 return Err(InvalidBackoff("jitter must not be greater than 100"));
130 }
131 if !jitter.is_finite() {
132 return Err(InvalidBackoff("jitter must be finite"));
133 }
134
135 Ok(ExponentialBackoff {
136 min,
137 max,
138 jitter,
139 rng_creator,
140 state: Mutex::new(ExponentialBackoffState { rng, iterations: 0 }),
141 })
142 }
143}
144
145impl<F, R: Rng> ExponentialBackoff<F, R> {
146 fn base(&self) -> time::Duration {
147 debug_assert!(
148 self.min <= self.max,
149 "maximum backoff must not be less than minimum backoff"
150 );
151 debug_assert!(
152 self.max > time::Duration::from_millis(0),
153 "Maximum backoff must be non-zero"
154 );
155 self.min
156 .checked_mul(2_u32.saturating_pow(self.state.lock().iterations))
157 .unwrap_or(self.max)
158 .min(self.max)
159 }
160
161 fn jitter(&self, base: time::Duration) -> Option<time::Duration> {
164 if self.jitter <= 0.0 {
165 None
166 } else {
167 let jitter_factor = self.state.lock().rng.next_f64();
168 debug_assert!(
169 jitter_factor > 0.0,
170 "rng returns values between 0.0 and 1.0"
171 );
172 let rand_jitter = jitter_factor * self.jitter;
173 let secs = (base.as_secs() as f64) * rand_jitter;
174 let nanos = (base.subsec_nanos() as f64) * rand_jitter;
175 let remaining = self.max - base;
176 let result = time::Duration::new(secs as u64, nanos as u32);
177 if remaining.is_zero() || result.is_zero() {
178 None
179 } else {
180 Some(result.min(remaining))
181 }
182 }
183 }
184}
185
186impl<F, R> Backoff for ExponentialBackoff<F, R>
187where
188 R: Rng,
189 F: Send + Sync + 'static,
190{
191 async fn next_backoff(&self) -> bool {
192 let base = self.base();
193 let jitter = match self.jitter(base) {
194 Some(jitter) => jitter,
195 None => {
196 self.reset().await;
197 return false;
198 }
199 };
200
201 let next = base + jitter;
202
203 self.state.lock().iterations += 1;
204
205 tokio::time::sleep(next).await;
206 true
207 }
208
209 async fn reset(&self) {
210 self.state.lock().iterations = 0;
211 }
212}
213
214impl Default for ExponentialBackoff<(), HasherRng> {
215 fn default() -> Self {
216 ExponentialBackoff::new_inner(
217 Duration::from_millis(50),
218 Duration::from_secs(3),
219 0.99,
220 (),
221 HasherRng::default(),
222 )
223 .expect("Unable to create ExponentialBackoff")
224 }
225}
226
/// Error describing why a set of backoff parameters was rejected.
///
/// Wraps a static, human-readable reason; `Display` prints it prefixed with
/// `invalid backoff: `.
#[derive(Debug)]
pub struct InvalidBackoff(&'static str);

impl Display for InvalidBackoff {
    /// Writes `invalid backoff: ` followed by the stored reason.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self(reason) = self;
        write!(f, "invalid backoff: {reason}")
    }
}

impl std::error::Error for InvalidBackoff {}
238
#[cfg(test)]
mod tests {
    use super::*;
    use quickcheck::*;

    /// The default backoff performs a wait on its first use.
    #[tokio::test]
    async fn backoff_default() {
        let backoff = ExponentialBackoff::default();
        let waited = backoff.next_backoff().await;
        assert!(waited);
    }

    /// A backoff records one iteration per wait and forgets it on reset.
    #[tokio::test]
    async fn backoff_reset() {
        let backoff = ExponentialBackoff::default();
        assert!(backoff.next_backoff().await);
        assert_eq!(backoff.state.lock().iterations, 1);

        backoff.reset().await;
        assert_eq!(backoff.state.lock().iterations, 0);
    }

    /// A clone starts with fresh state and does not share its counter with the original.
    #[tokio::test]
    async fn backoff_clone() {
        let backoff = ExponentialBackoff::default();

        assert_eq!(backoff.state.lock().iterations, 0);
        assert!(backoff.next_backoff().await);
        assert_eq!(backoff.state.lock().iterations, 1);

        let cloned = backoff.clone();
        assert_eq!(cloned.state.lock().iterations, 0);
        assert_eq!(backoff.state.lock().iterations, 1);

        assert!(cloned.next_backoff().await);
        assert_eq!(cloned.state.lock().iterations, 1);
        assert_eq!(backoff.state.lock().iterations, 1);
    }

    quickcheck! {
        // With no recorded iteration the base delay equals the minimum.
        fn backoff_base_first(min_ms: u64, max_ms: u64) -> TestResult {
            let min = time::Duration::from_millis(min_ms);
            let max = time::Duration::from_millis(max_ms);
            let Ok(backoff) = ExponentialBackoff::new(min, max, 0.0, HasherRng::default) else {
                return TestResult::discard();
            };

            TestResult::from_bool(backoff.base() == min)
        }

        // For any iteration count the base delay stays within [min, max].
        fn backoff_base(min_ms: u64, max_ms: u64, iterations: u32) -> TestResult {
            let min = time::Duration::from_millis(min_ms);
            let max = time::Duration::from_millis(max_ms);
            let Ok(backoff) = ExponentialBackoff::new(min, max, 0.0, HasherRng::default) else {
                return TestResult::discard();
            };

            backoff.state.lock().iterations = iterations;
            let delay = backoff.base();
            TestResult::from_bool((min..=max).contains(&delay))
        }

        // Jitter is absent exactly when it is disabled, the base is zero, or the base
        // already equals the maximum.
        fn backoff_jitter(base_ms: u64, max_ms: u64, jitter: f64) -> TestResult {
            let base = time::Duration::from_millis(base_ms);
            let max = time::Duration::from_millis(max_ms);
            let Ok(backoff) = ExponentialBackoff::new(base, max, jitter, HasherRng::default) else {
                return TestResult::discard();
            };

            let j = backoff.jitter(base);
            let expect_none = jitter == 0.0 || base_ms == 0 || max_ms == base_ms;
            TestResult::from_bool(j.is_none() == expect_none)
        }
    }
}