1use std::time::Duration;
13
/// Strategy for randomizing a computed backoff delay so that many clients
/// retrying at once do not all wake up at the same instant.
///
/// The default strategy is [`Jitter::Full`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum Jitter {
    /// No randomization: the computed delay is used as-is.
    None,
    /// Replace the computed delay with a random delay between zero and it.
    #[default]
    Full,
    /// Keep half of the computed delay fixed and randomize the remainder.
    Equal,
}
30
31impl Jitter {
32 pub fn apply(self, value: Duration) -> Duration {
34 match self {
35 Self::None => value,
36 Self::Full => {
37 let max_ms = u64::try_from(value.as_millis()).unwrap_or(u64::MAX);
38 if max_ms == 0 {
39 return Duration::ZERO;
40 }
41 Duration::from_millis(pseudo_random_u64() % (max_ms + 1))
42 }
43 Self::Equal => {
44 let total_ms = u64::try_from(value.as_millis()).unwrap_or(u64::MAX);
45 let half = total_ms / 2;
46 if half == 0 {
47 return value;
48 }
49 Duration::from_millis(half + (pseudo_random_u64() % (half + 1)))
50 }
51 }
52 }
53}
54
55#[derive(Debug, Clone)]
57#[non_exhaustive]
58pub struct RetryPolicy {
59 pub max_attempts: u32,
62 pub initial_backoff: Duration,
65 pub max_backoff: Duration,
67 pub jitter: Jitter,
69 pub respect_retry_after: bool,
72}
73
74impl Default for RetryPolicy {
75 fn default() -> Self {
76 Self {
77 max_attempts: 3,
78 initial_backoff: Duration::from_millis(500),
79 max_backoff: Duration::from_secs(30),
80 jitter: Jitter::Full,
81 respect_retry_after: true,
82 }
83 }
84}
85
86impl RetryPolicy {
87 #[must_use]
89 pub fn none() -> Self {
90 Self {
91 max_attempts: 1,
92 initial_backoff: Duration::ZERO,
93 max_backoff: Duration::ZERO,
94 jitter: Jitter::None,
95 respect_retry_after: false,
96 }
97 }
98
99 #[must_use]
105 pub fn compute_backoff(&self, attempt: u32, server_retry_after: Option<Duration>) -> Duration {
106 let factor = 2u32.saturating_pow(attempt.saturating_sub(1).min(30));
108 let exponential = self
109 .initial_backoff
110 .saturating_mul(factor)
111 .min(self.max_backoff);
112 let jittered = self.jitter.apply(exponential);
113
114 if self.respect_retry_after
115 && let Some(server) = server_retry_after
116 {
117 return jittered.max(server);
118 }
119 jittered
120 }
121}
122
/// Returns a pseudo-random `u64` for backoff jitter. Not cryptographically secure.
///
/// The previous version returned raw clock nanoseconds. Its `nanos >> 64`
/// "mixing" was always zero, and the low digits were constant on clocks with
/// coarse resolution. Instead, each call hashes the current time with a fresh
/// [`RandomState`], which std initializes with random keys. The output is well
/// mixed, and two calls in the same clock tick still differ, so jitter actually
/// de-synchronizes concurrent retries.
///
/// [`RandomState`]: std::collections::hash_map::RandomState
fn pseudo_random_u64() -> u64 {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};
    use std::time::{SystemTime, UNIX_EPOCH};

    let mut hasher = RandomState::new().build_hasher();
    // The randomly keyed hasher is the real entropy source. The timestamp is
    // extra input, so a clock set before the epoch is simply skipped.
    if let Ok(since_epoch) = SystemTime::now().duration_since(UNIX_EPOCH) {
        hasher.write_u128(since_epoch.as_nanos());
    }
    hasher.finish()
}
137
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    /// Policy with no jitter and no `Retry-After` handling, so backoffs are exact.
    fn deterministic_policy() -> RetryPolicy {
        RetryPolicy {
            max_attempts: 5,
            initial_backoff: Duration::from_millis(10),
            max_backoff: Duration::from_secs(1),
            jitter: Jitter::None,
            respect_retry_after: false,
        }
    }

    #[test]
    fn compute_backoff_grows_exponentially() {
        let p = deterministic_policy();
        assert_eq!(p.compute_backoff(1, None), Duration::from_millis(10));
        assert_eq!(p.compute_backoff(2, None), Duration::from_millis(20));
        assert_eq!(p.compute_backoff(3, None), Duration::from_millis(40));
        assert_eq!(p.compute_backoff(4, None), Duration::from_millis(80));
    }

    #[test]
    fn compute_backoff_treats_attempt_zero_like_attempt_one() {
        let p = deterministic_policy();
        assert_eq!(p.compute_backoff(0, None), p.compute_backoff(1, None));
    }

    #[test]
    fn compute_backoff_caps_at_max() {
        let p = RetryPolicy {
            max_backoff: Duration::from_millis(50),
            ..deterministic_policy()
        };
        assert_eq!(p.compute_backoff(20, None), Duration::from_millis(50));
        assert_eq!(p.compute_backoff(100, None), Duration::from_millis(50));
        assert_eq!(p.compute_backoff(u32::MAX, None), Duration::from_millis(50));
    }

    #[test]
    fn respect_retry_after_uses_max_of_server_and_jittered() {
        let p = RetryPolicy {
            respect_retry_after: true,
            ..deterministic_policy()
        };
        // The server asks for longer (5 s) than the local backoff (10 ms): the server wins.
        assert_eq!(
            p.compute_backoff(1, Some(Duration::from_secs(5))),
            Duration::from_secs(5)
        );
        // The server asks for less (1 ms) than the local backoff (40 ms): the local backoff wins.
        assert_eq!(
            p.compute_backoff(3, Some(Duration::from_millis(1))),
            Duration::from_millis(40)
        );
    }

    #[test]
    fn respect_retry_after_false_ignores_server_header() {
        let p = deterministic_policy();
        assert_eq!(
            p.compute_backoff(1, Some(Duration::from_secs(60))),
            Duration::from_millis(10)
        );
    }

    #[test]
    fn jitter_none_is_identity() {
        assert_eq!(
            Jitter::None.apply(Duration::from_millis(42)),
            Duration::from_millis(42)
        );
    }

    #[test]
    fn jitter_defaults_to_full() {
        assert_eq!(Jitter::default(), Jitter::Full);
    }

    #[test]
    fn jitter_full_stays_within_range() {
        let max = Duration::from_millis(100);
        for _ in 0..50 {
            let v = Jitter::Full.apply(max);
            assert!(v <= max, "{v:?} should be <= {max:?}");
        }
    }

    #[test]
    fn jitter_full_of_sub_millisecond_is_zero() {
        assert_eq!(Jitter::Full.apply(Duration::ZERO), Duration::ZERO);
        assert_eq!(Jitter::Full.apply(Duration::from_micros(900)), Duration::ZERO);
    }

    #[test]
    fn jitter_equal_stays_in_upper_half() {
        let max = Duration::from_millis(100);
        for _ in 0..50 {
            let v = Jitter::Equal.apply(max);
            assert!(v >= Duration::from_millis(50), "{v:?} below half");
            assert!(v <= max, "{v:?} above max");
        }
    }

    #[test]
    fn jitter_equal_leaves_tiny_values_unchanged() {
        let tiny = Duration::from_micros(1500);
        assert_eq!(Jitter::Equal.apply(tiny), tiny);
    }

    #[test]
    fn none_policy_skips_retries() {
        let p = RetryPolicy::none();
        assert_eq!(p.max_attempts, 1);
        assert_eq!(p.initial_backoff, Duration::ZERO);
        assert_eq!(p.max_backoff, Duration::ZERO);
        assert_eq!(p.jitter, Jitter::None);
        assert!(!p.respect_retry_after);
        assert_eq!(
            p.compute_backoff(3, Some(Duration::from_secs(10))),
            Duration::ZERO
        );
    }

    #[test]
    fn default_policy_matches_spec() {
        let p = RetryPolicy::default();
        assert_eq!(p.max_attempts, 3);
        assert_eq!(p.initial_backoff, Duration::from_millis(500));
        assert_eq!(p.max_backoff, Duration::from_secs(30));
        assert_eq!(p.jitter, Jitter::Full);
        assert!(p.respect_retry_after);
    }
}