1use std::time::{Duration, Instant};
34
/// Iterator over exponentially growing retry delays.
///
/// Each call to [`Iterator::next`] yields the delay to wait before the next
/// attempt. The first delay is the initial delay; every following delay is the
/// previous one multiplied by the growth factor (2.0 unless configured
/// otherwise). Every un-jittered delay is capped at `max`.
///
/// Iteration ends (`None`) once `max_retries` delays have been yielded or, when
/// a wall-clock budget is configured, once that budget has been exceeded.
#[derive(Debug, Clone)]
pub struct ExponentialBackoff {
    /// Un-jittered delay that the next call to `next` will yield (before capping).
    current: Duration,
    /// Upper bound applied to every un-jittered delay.
    max: Duration,
    /// Growth factor applied after each yielded delay.
    multiplier: f64,
    /// Number of delays yielded since construction or the last `reset`.
    retries: u32,
    /// Maximum number of delays to yield; `None` means unlimited.
    max_retries: Option<u32>,
    /// Whether each yielded delay is randomized by up to ±30%.
    jitter: bool,
    /// Wall-clock budget measured from `start_time`; `None` means no budget.
    max_total_duration: Option<Duration>,
    /// Instant at which the wall-clock budget started counting.
    start_time: Option<Instant>,
}

impl ExponentialBackoff {
    /// Creates a backoff that doubles the delay after every attempt.
    ///
    /// * `initial` – first delay yielded (capped at `max`).
    /// * `max` – upper bound for every un-jittered delay.
    /// * `max_retries` – number of delays to yield before returning `None`;
    ///   `None` yields indefinitely.
    ///
    /// Jitter is disabled and no wall-clock budget is set.
    pub fn new(initial: Duration, max: Duration, max_retries: Option<u32>) -> Self {
        Self {
            current: initial,
            max,
            multiplier: 2.0,
            retries: 0,
            max_retries,
            jitter: false,
            max_total_duration: None,
            start_time: None,
        }
    }

    /// Like [`ExponentialBackoff::new`], but with a custom growth factor.
    ///
    /// A factor below 1.0 shrinks the delay between attempts; a NaN or negative
    /// factor collapses every delay after the first to zero.
    pub fn with_multiplier(
        initial: Duration,
        max: Duration,
        max_retries: Option<u32>,
        multiplier: f64,
    ) -> Self {
        Self {
            multiplier,
            ..Self::new(initial, max, max_retries)
        }
    }

    /// Like [`ExponentialBackoff::new`], but iteration also stops once more than
    /// `max_total_duration` of wall-clock time has elapsed.
    ///
    /// The clock starts now, at construction, not at the first call to `next`.
    pub fn with_total_duration(
        initial: Duration,
        max: Duration,
        max_retries: Option<u32>,
        max_total_duration: Duration,
    ) -> Self {
        Self {
            max_total_duration: Some(max_total_duration),
            start_time: Some(Instant::now()),
            ..Self::new(initial, max, max_retries)
        }
    }

    /// Returns a [`BackoffBuilder`] populated with its defaults: 1 s initial
    /// delay, 60 s cap, unlimited retries, multiplier 2.0, no jitter and no
    /// wall-clock budget.
    pub fn builder() -> BackoffBuilder {
        BackoffBuilder::default()
    }

    /// Number of delays yielded since construction or the last [`reset`](Self::reset).
    pub fn retry_count(&self) -> u32 {
        self.retries
    }

    /// Restarts the sequence at `initial` and zeroes the retry counter.
    ///
    /// If a wall-clock budget is configured, its clock restarts now. The cap,
    /// growth factor, retry limit and jitter setting are left unchanged.
    pub fn reset(&mut self, initial: Duration) {
        self.retries = 0;
        self.current = initial;
        if self.max_total_duration.is_some() {
            self.start_time = Some(Instant::now());
        }
    }

    /// Returns `true` once the wall-clock budget has been strictly exceeded.
    ///
    /// Always `false` when no budget is configured or the clock has not started.
    fn is_duration_exceeded(&self) -> bool {
        match (self.max_total_duration, self.start_time) {
            (Some(budget), Some(start)) => start.elapsed() > budget,
            _ => false,
        }
    }

    /// Computes the delay that follows `delay`: `delay * multiplier`, capped at `max`.
    ///
    /// The product is computed in nanoseconds so that sub-millisecond delays
    /// and fractional multipliers keep growing instead of truncating to a
    /// whole millisecond. A product at or above `u64::MAX` nanoseconds
    /// saturates to `max`; a NaN or negative product becomes zero.
    fn grow(&self, delay: Duration) -> Duration {
        let scaled_nanos = delay.as_nanos() as f64 * self.multiplier;
        let grown = if scaled_nanos >= u64::MAX as f64 {
            self.max
        } else {
            // Float-to-int `as` casts map NaN and negative values to 0.
            Duration::from_nanos(scaled_nanos as u64)
        };
        grown.min(self.max)
    }

    /// Randomizes `duration` by a factor drawn from [0.7, 1.3] in steps of 0.001.
    ///
    /// Randomness comes from a freshly keyed std `RandomState` hasher. This is a
    /// lightweight, dependency-free source and is not suitable for
    /// cryptographic use. The result has nanosecond resolution, is never zero
    /// (floored at 1 ns), and saturates at `u64::MAX` nanoseconds.
    fn apply_jitter(duration: Duration) -> Duration {
        use std::collections::hash_map::RandomState;
        use std::hash::{BuildHasher, Hasher};

        let random = RandomState::new().build_hasher().finish();
        // 601 evenly spaced offsets in [-0.3, 0.3].
        let factor = (random % 601) as f64 / 1000.0 - 0.3;
        let jittered_nanos = duration.as_nanos() as f64 * (1.0 + factor);
        Duration::from_nanos(jittered_nanos.max(1.0) as u64)
    }
}

impl Iterator for ExponentialBackoff {
    type Item = Duration;

    /// Yields the next delay.
    ///
    /// Returns `None` once the retry limit is reached or the wall-clock budget
    /// is exceeded.
    fn next(&mut self) -> Option<Duration> {
        // Every constructor that sets a budget already starts the clock; this
        // is a defensive fallback so a budget never goes unenforced.
        if self.max_total_duration.is_some() && self.start_time.is_none() {
            self.start_time = Some(Instant::now());
        }

        if self.is_duration_exceeded() {
            return None;
        }
        if matches!(self.max_retries, Some(limit) if self.retries >= limit) {
            return None;
        }

        // Cap at yield time too, so an initial (or reset) delay above `max`
        // is not handed out verbatim.
        let delay = self.current.min(self.max);
        self.current = self.grow(delay);
        // Saturate so an unlimited backoff cannot overflow the counter
        // (which would panic in debug builds).
        self.retries = self.retries.saturating_add(1);

        Some(if self.jitter {
            Self::apply_jitter(delay)
        } else {
            delay
        })
    }
}

/// Fluent builder for [`ExponentialBackoff`].
///
/// Obtain one via [`ExponentialBackoff::builder`] or `BackoffBuilder::default()`,
/// chain the setters, then call [`build`](BackoffBuilder::build).
#[derive(Debug, Clone)]
#[must_use = "a BackoffBuilder does nothing until `build` is called"]
pub struct BackoffBuilder {
    initial_delay: Duration,
    max_delay: Duration,
    max_retries: Option<u32>,
    multiplier: f64,
    jitter: bool,
    max_total_duration: Option<Duration>,
}

impl Default for BackoffBuilder {
    /// Defaults: 1 s initial delay, 60 s cap, unlimited retries, multiplier 2.0,
    /// no jitter, no wall-clock budget.
    fn default() -> Self {
        Self {
            initial_delay: Duration::from_secs(1),
            max_delay: Duration::from_secs(60),
            max_retries: None,
            multiplier: 2.0,
            jitter: false,
            max_total_duration: None,
        }
    }
}

impl BackoffBuilder {
    /// Sets the first delay yielded (default 1 s).
    pub fn initial_delay(mut self, delay: Duration) -> Self {
        self.initial_delay = delay;
        self
    }

    /// Sets the cap applied to every un-jittered delay (default 60 s).
    pub fn max_delay(mut self, delay: Duration) -> Self {
        self.max_delay = delay;
        self
    }

    /// Limits the number of delays yielded (unlimited by default).
    pub fn max_retries(mut self, retries: u32) -> Self {
        self.max_retries = Some(retries);
        self
    }

    /// Sets the growth factor applied after each delay (default 2.0).
    pub fn multiplier(mut self, multiplier: f64) -> Self {
        self.multiplier = multiplier;
        self
    }

    /// Enables ±30% randomization of each yielded delay.
    ///
    /// Jitter is applied after capping, so a jittered delay may exceed
    /// `max_delay` by up to 30%.
    pub fn with_jitter(mut self) -> Self {
        self.jitter = true;
        self
    }

    /// Stops iteration once more than `duration` of wall-clock time has elapsed
    /// since [`build`](BackoffBuilder::build) was called.
    pub fn max_total_duration(mut self, duration: Duration) -> Self {
        self.max_total_duration = Some(duration);
        self
    }

    /// Creates the configured [`ExponentialBackoff`].
    ///
    /// If a wall-clock budget was set, its clock starts now.
    pub fn build(self) -> ExponentialBackoff {
        ExponentialBackoff {
            current: self.initial_delay,
            max: self.max_delay,
            multiplier: self.multiplier,
            retries: 0,
            max_retries: self.max_retries,
            jitter: self.jitter,
            max_total_duration: self.max_total_duration,
            start_time: self.max_total_duration.map(|_| Instant::now()),
        }
    }
}
295
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_basic_backoff() {
        let mut backoff =
            ExponentialBackoff::new(Duration::from_millis(100), Duration::from_secs(2), Some(4));

        assert_eq!(backoff.next(), Some(Duration::from_millis(100)));
        assert_eq!(backoff.next(), Some(Duration::from_millis(200)));
        assert_eq!(backoff.next(), Some(Duration::from_millis(400)));
        assert_eq!(backoff.next(), Some(Duration::from_millis(800)));
        assert_eq!(backoff.next(), None);
    }

    #[test]
    fn test_capped_backoff() {
        let mut backoff = ExponentialBackoff::new(
            Duration::from_millis(100),
            Duration::from_millis(500),
            Some(5),
        );

        assert_eq!(backoff.next(), Some(Duration::from_millis(100)));
        assert_eq!(backoff.next(), Some(Duration::from_millis(200)));
        assert_eq!(backoff.next(), Some(Duration::from_millis(400)));
        // 800 ms and 1000 ms are both clamped to the 500 ms cap.
        assert_eq!(backoff.next(), Some(Duration::from_millis(500)));
        assert_eq!(backoff.next(), Some(Duration::from_millis(500)));
        assert_eq!(backoff.next(), None);
    }

    #[test]
    fn test_unlimited_retries() {
        let mut backoff =
            ExponentialBackoff::new(Duration::from_millis(50), Duration::from_secs(10), None);

        for i in 0..20 {
            let delay = backoff.next();
            assert!(delay.is_some(), "Retry {i} should succeed");
        }
    }

    #[test]
    fn test_custom_multiplier() {
        let mut backoff = ExponentialBackoff::with_multiplier(
            Duration::from_millis(100),
            Duration::from_secs(10),
            Some(3),
            1.5,
        );

        assert_eq!(backoff.next(), Some(Duration::from_millis(100)));
        assert_eq!(backoff.next(), Some(Duration::from_millis(150)));
        assert_eq!(backoff.next(), Some(Duration::from_millis(225)));
        assert_eq!(backoff.next(), None);
    }

    #[test]
    fn test_retry_count() {
        let mut backoff =
            ExponentialBackoff::new(Duration::from_millis(100), Duration::from_secs(1), None);

        assert_eq!(backoff.retry_count(), 0);
        backoff.next();
        assert_eq!(backoff.retry_count(), 1);
        backoff.next();
        assert_eq!(backoff.retry_count(), 2);
    }

    #[test]
    fn test_total_duration_limit() {
        // A generous budget never triggers, so the retry limit ends iteration.
        let backoff = ExponentialBackoff::with_total_duration(
            Duration::from_millis(10),
            Duration::from_millis(100),
            Some(3),
            Duration::from_secs(3600),
        );
        assert_eq!(backoff.count(), 3);

        // With no retry limit, only the clock can end iteration. Sleeping past
        // the budget makes this deterministic instead of busy-spinning
        // `collect()` until wall-clock time runs out.
        let mut backoff = ExponentialBackoff::with_total_duration(
            Duration::from_millis(10),
            Duration::from_millis(100),
            None,
            Duration::from_millis(10),
        );
        std::thread::sleep(Duration::from_millis(20));
        assert_eq!(backoff.next(), None);
    }

    #[test]
    fn test_builder_basic() {
        let mut backoff = ExponentialBackoff::builder()
            .initial_delay(Duration::from_millis(100))
            .max_delay(Duration::from_secs(2))
            .max_retries(4)
            .build();

        assert_eq!(backoff.next(), Some(Duration::from_millis(100)));
        assert_eq!(backoff.next(), Some(Duration::from_millis(200)));
        assert_eq!(backoff.next(), Some(Duration::from_millis(400)));
        assert_eq!(backoff.next(), Some(Duration::from_millis(800)));
        assert_eq!(backoff.next(), None);
    }

    #[test]
    fn test_builder_with_jitter() {
        let mut backoff = ExponentialBackoff::builder()
            .initial_delay(Duration::from_millis(1000))
            .max_delay(Duration::from_secs(60))
            .max_retries(3)
            .with_jitter()
            .build();

        let d1 = backoff.next().unwrap();
        let d2 = backoff.next().unwrap();
        let d3 = backoff.next().unwrap();

        // Each delay must lie within ±30% of its un-jittered value.
        assert!(
            d1.as_millis() >= 700 && d1.as_millis() <= 1300,
            "delay 1 = {:?}, expected 700-1300ms",
            d1
        );
        assert!(
            d2.as_millis() >= 1400 && d2.as_millis() <= 2600,
            "delay 2 = {:?}, expected 1400-2600ms",
            d2
        );
        assert!(
            d3.as_millis() >= 2800 && d3.as_millis() <= 5200,
            "delay 3 = {:?}, expected 2800-5200ms",
            d3
        );

        assert_eq!(backoff.next(), None);
    }

    #[test]
    fn test_builder_with_custom_multiplier() {
        let mut backoff = ExponentialBackoff::builder()
            .initial_delay(Duration::from_millis(100))
            .max_delay(Duration::from_secs(10))
            .max_retries(3)
            .multiplier(3.0)
            .build();

        assert_eq!(backoff.next(), Some(Duration::from_millis(100)));
        assert_eq!(backoff.next(), Some(Duration::from_millis(300)));
        assert_eq!(backoff.next(), Some(Duration::from_millis(900)));
        assert_eq!(backoff.next(), None);
    }

    #[test]
    fn test_reset() {
        let mut backoff =
            ExponentialBackoff::new(Duration::from_millis(100), Duration::from_secs(2), Some(2));

        assert_eq!(backoff.next(), Some(Duration::from_millis(100)));
        assert_eq!(backoff.next(), Some(Duration::from_millis(200)));
        assert_eq!(backoff.next(), None);

        backoff.reset(Duration::from_millis(100));
        assert_eq!(backoff.retry_count(), 0);
        assert_eq!(backoff.next(), Some(Duration::from_millis(100)));
    }

    #[test]
    fn test_jitter_distribution() {
        // 100 samples: the chance that all land on one side of 1000 ms is negligible.
        let results: Vec<Duration> = (0..100)
            .map(|_| ExponentialBackoff::apply_jitter(Duration::from_millis(1000)))
            .collect();

        let min = results.iter().min().unwrap().as_millis();
        let max = results.iter().max().unwrap().as_millis();

        assert!(min < 1000, "min={min}ms, should be < 1000ms for jitter");
        assert!(max > 1000, "max={max}ms, should be > 1000ms for jitter");
    }
}