use std::time::{Duration, Instant};

use rand::Rng;
/// Tuning parameters for an exponential retry backoff.
///
/// A [`BackoffState`] is built from one of these and derives per-attempt delays and an
/// overall deadline from the values below.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct BackoffConfig {
    /// Base delay used for the first non-immediate retry; doubled for each later retry.
    pub initial_delay: Duration,
    /// Upper bound applied to every computed delay, including after jitter is added.
    pub max_delay: Duration,
    /// Wall-clock budget for the whole retry sequence, measured from the moment the
    /// [`BackoffState`] is created.
    pub total_timeout: Duration,
    /// Number of delays `BackoffState::next_delay` may hand out before
    /// `BackoffState::can_retry` reports `false`.
    pub max_attempts: usize,
}
16
17impl BackoffConfig {
18 pub fn standard_operation() -> Self {
20 Self {
21 initial_delay: Duration::from_secs(1),
22 max_delay: Duration::from_secs(64),
23 total_timeout: Duration::from_secs(2 * 60),
24 max_attempts: 8,
25 }
26 }
27
28 pub fn upload_operation(max_retry_time: Duration) -> Self {
30 Self {
31 total_timeout: max_retry_time,
32 ..Self::standard_operation()
33 }
34 }
35
36 pub fn with_total_timeout(mut self, timeout: Duration) -> Self {
37 self.total_timeout = timeout;
38 self
39 }
40}
41
42#[derive(Debug)]
44pub struct BackoffState {
45 config: BackoffConfig,
46 attempt: usize,
47 deadline: Instant,
48}
49
50impl BackoffState {
51 pub fn new(config: BackoffConfig) -> Self {
52 let deadline = Instant::now() + config.total_timeout;
53 Self {
54 config,
55 attempt: 0,
56 deadline,
57 }
58 }
59
60 pub fn attempts(&self) -> usize {
61 self.attempt
62 }
63
64 pub fn has_time_remaining(&self) -> bool {
65 Instant::now() < self.deadline
66 }
67
68 pub fn can_retry(&self) -> bool {
69 self.attempt < self.config.max_attempts && self.has_time_remaining()
70 }
71
72 pub fn next_delay(&mut self) -> Duration {
73 if self.attempt == 0 {
74 self.attempt += 1;
75 return Duration::from_millis(0);
76 }
77
78 let exp = 2u64.pow((self.attempt - 1) as u32);
79 let base = self.config.initial_delay.mul_f64(exp as f64);
80 self.attempt += 1;
81
82 let capped = if base > self.config.max_delay {
83 self.config.max_delay
84 } else {
85 base
86 };
87
88 let jitter: f64 = rand::thread_rng().gen();
89 let jittered = capped.mul_f64(1.0 + jitter);
90 if jittered > self.config.max_delay {
91 self.config.max_delay
92 } else {
93 jittered
94 }
95 }
96}
97
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn first_delay_is_zero() {
        let mut backoff = BackoffState::new(BackoffConfig::standard_operation());
        assert_eq!(backoff.next_delay(), Duration::from_millis(0));
    }

    #[test]
    fn delays_increase_with_jitter() {
        let mut backoff = BackoffState::new(BackoffConfig::standard_operation());
        backoff.next_delay();
        let d1 = backoff.next_delay();
        backoff.next_delay();
        let d2 = backoff.next_delay();
        assert!(d2 >= d1);
    }

    #[test]
    fn first_real_delay_is_initial_delay_plus_bounded_jitter() {
        // The second call uses the un-doubled initial delay; jitter can at most double it.
        let config = BackoffConfig::standard_operation();
        let mut backoff = BackoffState::new(config.clone());
        backoff.next_delay();
        let delay = backoff.next_delay();
        assert!(delay >= config.initial_delay);
        assert!(delay <= config.initial_delay * 2);
    }

    #[test]
    fn delays_never_exceed_max_delay() {
        let config = BackoffConfig::standard_operation();
        let mut backoff = BackoffState::new(config.clone());
        for _ in 0..config.max_attempts * 2 {
            assert!(backoff.next_delay() <= config.max_delay);
        }
    }

    #[test]
    fn attempt_budget_is_enforced() {
        let config = BackoffConfig::standard_operation();
        let mut backoff = BackoffState::new(config.clone());
        for _ in 0..config.max_attempts {
            assert!(backoff.can_retry());
            backoff.next_delay();
        }
        assert_eq!(backoff.attempts(), config.max_attempts);
        assert!(!backoff.can_retry());
    }

    #[test]
    fn zero_timeout_leaves_no_time() {
        let config =
            BackoffConfig::standard_operation().with_total_timeout(Duration::from_secs(0));
        let backoff = BackoffState::new(config);
        assert!(!backoff.has_time_remaining());
        assert!(!backoff.can_retry());
    }
}