1use std::sync::Arc;
14use std::time::Duration;
15
16use http::{HeaderMap, StatusCode};
17
18use crate::response::Response;
19
20pub type ShouldRetryFn = Arc<dyn Fn(&Response) -> bool + Send + Sync>;
22
23#[derive(Clone)]
38#[must_use = "apply with `ClientBuilder::retry` or `RequestBuilder::retry`"]
39pub enum RetryPolicy {
40 Count {
42 attempts: u32,
44 should_retry: Option<ShouldRetryFn>,
46 },
47 Linear {
49 attempts: u32,
51 delay: Duration,
53 should_retry: Option<ShouldRetryFn>,
54 jitter: bool,
56 },
57 Exponential {
59 attempts: u32,
61 base_delay: Duration,
63 max_delay: Duration,
65 should_retry: Option<ShouldRetryFn>,
66 jitter: bool,
67 },
68}
69
70impl RetryPolicy {
71 pub fn count(attempts: u32) -> Self {
76 Self::Count {
77 attempts,
78 should_retry: None,
79 }
80 }
81
82 pub fn linear(attempts: u32, delay: Duration) -> Self {
84 Self::Linear {
85 attempts,
86 delay,
87 should_retry: None,
88 jitter: false,
89 }
90 }
91
92 pub fn exponential(attempts: u32, base_delay: Duration, max_delay: Duration) -> Self {
94 Self::Exponential {
95 attempts,
96 base_delay,
97 max_delay,
98 should_retry: None,
99 jitter: true,
100 }
101 }
102
103 #[must_use = "chain with `ClientBuilder::retry` or `RequestBuilder::retry`"]
107 pub fn with_jitter(mut self, jitter: bool) -> Self {
108 match &mut self {
109 Self::Linear { jitter: j, .. } | Self::Exponential { jitter: j, .. } => *j = jitter,
110 Self::Count { .. } => {}
111 }
112 self
113 }
114
115 #[must_use = "chain with `ClientBuilder::retry` or `RequestBuilder::retry`"]
117 pub fn with_should_retry(self, f: ShouldRetryFn) -> Self {
118 match self {
119 Self::Count { attempts, .. } => Self::Count {
120 attempts,
121 should_retry: Some(f),
122 },
123 Self::Linear {
124 attempts,
125 delay,
126 jitter,
127 ..
128 } => Self::Linear {
129 attempts,
130 delay,
131 should_retry: Some(f),
132 jitter,
133 },
134 Self::Exponential {
135 attempts,
136 base_delay,
137 max_delay,
138 jitter,
139 ..
140 } => Self::Exponential {
141 attempts,
142 base_delay,
143 max_delay,
144 should_retry: Some(f),
145 jitter,
146 },
147 }
148 }
149
150 pub fn max_attempts(&self) -> u32 {
152 match self {
153 Self::Count { attempts, .. }
154 | Self::Linear { attempts, .. }
155 | Self::Exponential { attempts, .. } => *attempts,
156 }
157 }
158
159 pub(crate) fn delay_before_attempt(&self, attempt: u32) -> Duration {
160 match self {
161 Self::Count { .. } => Duration::from_secs(1),
162 Self::Linear { delay, .. } => *delay,
163 Self::Exponential {
164 base_delay,
165 max_delay,
166 ..
167 } => {
168 let exp = base_delay.saturating_mul(2u32.saturating_pow(attempt));
169 exp.min(*max_delay)
170 }
171 }
172 }
173
174 pub(crate) fn delay_after_response(&self, attempt: u32, headers: &HeaderMap) -> Duration {
179 if let Some(retry_after) = parse_retry_after(headers) {
180 return retry_after;
181 }
182 let base = self.delay_before_attempt(attempt);
183 if self.uses_jitter() {
184 apply_jitter(base)
185 } else {
186 base
187 }
188 }
189
190 pub(crate) fn uses_jitter(&self) -> bool {
191 match self {
192 Self::Count { .. } => true,
193 Self::Linear { jitter, .. } | Self::Exponential { jitter, .. } => *jitter,
194 }
195 }
196
197 pub(crate) fn has_custom_should_retry(&self) -> bool {
199 matches!(
200 self,
201 Self::Count {
202 should_retry: Some(_),
203 ..
204 } | Self::Linear {
205 should_retry: Some(_),
206 ..
207 } | Self::Exponential {
208 should_retry: Some(_),
209 ..
210 }
211 )
212 }
213
214 pub(crate) fn should_retry_response(
215 &self,
216 response: &Response,
217 transport_failed: bool,
218 ) -> bool {
219 if transport_failed {
220 return true;
221 }
222
223 let custom = match self {
224 Self::Count { should_retry, .. }
225 | Self::Linear { should_retry, .. }
226 | Self::Exponential { should_retry, .. } => should_retry.as_ref(),
227 };
228
229 if let Some(f) = custom {
230 return f(response);
231 }
232
233 default_should_retry(response.status())
234 }
235}
236
237pub fn default_should_retry(status: StatusCode) -> bool {
239 matches!(status.as_u16(), 408 | 429 | 502 | 503 | 504)
240}
241
242pub fn parse_retry_after(headers: &HeaderMap) -> Option<Duration> {
247 let value = headers.get(http::header::RETRY_AFTER)?.to_str().ok()?;
248 let value = value.trim();
249 if let Ok(secs) = value.parse::<u64>() {
250 return Some(Duration::from_secs(secs));
251 }
252 let when = httpdate::parse_http_date(value).ok()?;
253 Some(
254 when.duration_since(std::time::SystemTime::now())
255 .unwrap_or(Duration::ZERO),
256 )
257}
258
259fn apply_jitter(delay: Duration) -> Duration {
260 let nanos = delay.as_nanos().min(u128::from(u64::MAX)) as u64;
261 if nanos == 0 {
262 return delay;
263 }
264 let half = nanos / 2;
265 let span = nanos.saturating_sub(half).max(1);
266 Duration::from_nanos(half + fastrand::u64(..span))
267}
268
269pub(crate) use crate::cancel::sleep_or_cancel;
270
#[cfg(test)]
mod tests {
    use super::*;
    use crate::response::Response;
    use http::StatusCode;

    /// Builds a body-less response with the given status, for predicate tests.
    fn response_with_status(status: u16) -> Response {
        Response::new(
            StatusCode::from_u16(status).unwrap(),
            http::HeaderMap::new(),
            bytes::Bytes::new(),
            None,
            #[cfg(feature = "json")]
            None,
        )
    }

    #[test]
    fn default_should_retry_codes() {
        assert!(default_should_retry(StatusCode::REQUEST_TIMEOUT));
        assert!(default_should_retry(StatusCode::TOO_MANY_REQUESTS));
        assert!(default_should_retry(StatusCode::SERVICE_UNAVAILABLE));
        assert!(!default_should_retry(StatusCode::NOT_FOUND));
    }

    #[test]
    fn default_should_retry_gateway_errors_but_not_500() {
        assert!(default_should_retry(StatusCode::BAD_GATEWAY));
        assert!(default_should_retry(StatusCode::GATEWAY_TIMEOUT));
        assert!(!default_should_retry(StatusCode::INTERNAL_SERVER_ERROR));
        assert!(!default_should_retry(StatusCode::OK));
    }

    #[test]
    fn count_policy_max_attempts() {
        assert_eq!(RetryPolicy::count(3).max_attempts(), 3);
    }

    #[test]
    fn count_uses_fixed_one_second_delay() {
        let policy = RetryPolicy::count(3);
        assert_eq!(policy.delay_before_attempt(0), Duration::from_secs(1));
        assert_eq!(policy.delay_before_attempt(7), Duration::from_secs(1));
    }

    #[test]
    fn count_ignores_with_jitter() {
        assert!(RetryPolicy::count(3).with_jitter(false).uses_jitter());
    }

    #[test]
    fn count_with_should_retry_stays_count() {
        let policy = RetryPolicy::count(2)
            .with_should_retry(Arc::new(|r| r.status() == StatusCode::NOT_FOUND));
        assert!(matches!(policy, RetryPolicy::Count { .. }));
        assert!(policy.should_retry_response(&response_with_status(404), false));
        assert!(!policy.should_retry_response(&response_with_status(503), false));
    }

    #[test]
    fn transport_failure_always_retries() {
        let policy = RetryPolicy::count(1).with_should_retry(Arc::new(|_: &Response| false));
        assert!(policy.should_retry_response(&response_with_status(200), true));
    }

    #[test]
    fn has_custom_should_retry_tracks_predicate() {
        let plain =
            RetryPolicy::exponential(2, Duration::from_millis(1), Duration::from_millis(10));
        assert!(!plain.has_custom_should_retry());
        let custom = plain.with_should_retry(Arc::new(|_: &Response| true));
        assert!(custom.has_custom_should_retry());
    }

    #[test]
    fn with_should_retry_preserves_other_settings() {
        let policy = RetryPolicy::exponential(4, Duration::from_secs(1), Duration::from_secs(8))
            .with_jitter(false)
            .with_should_retry(Arc::new(|_: &Response| true));
        assert!(matches!(policy, RetryPolicy::Exponential { .. }));
        assert_eq!(policy.max_attempts(), 4);
        assert!(!policy.uses_jitter());
        assert_eq!(policy.delay_before_attempt(2), Duration::from_secs(4));
    }

    #[test]
    fn linear_delay_is_constant() {
        let policy = RetryPolicy::linear(3, Duration::from_millis(500));
        assert_eq!(policy.delay_before_attempt(0), Duration::from_millis(500));
        assert_eq!(policy.delay_before_attempt(2), Duration::from_millis(500));
    }

    #[test]
    fn exponential_delay_caps_at_max() {
        let policy = RetryPolicy::exponential(5, Duration::from_secs(1), Duration::from_secs(5));
        assert_eq!(policy.delay_before_attempt(0), Duration::from_secs(1));
        assert_eq!(policy.delay_before_attempt(10), Duration::from_secs(5));
    }

    #[test]
    fn exponential_delay_doubles_and_saturates() {
        let policy = RetryPolicy::exponential(5, Duration::from_secs(1), Duration::from_secs(30));
        assert_eq!(policy.delay_before_attempt(1), Duration::from_secs(2));
        assert_eq!(policy.delay_before_attempt(2), Duration::from_secs(4));
        assert_eq!(policy.delay_before_attempt(u32::MAX), Duration::from_secs(30));
    }

    #[test]
    fn custom_should_retry_overrides_default() {
        let policy = RetryPolicy::linear(2, Duration::from_millis(1))
            .with_should_retry(Arc::new(|r| r.status() == StatusCode::NOT_FOUND));
        assert!(policy.should_retry_response(&response_with_status(404), false));
        assert!(!policy.should_retry_response(&response_with_status(503), false));
    }

    #[test]
    fn parse_retry_after_seconds() {
        let mut headers = HeaderMap::new();
        headers.insert(http::header::RETRY_AFTER, "3".parse().unwrap());
        assert_eq!(parse_retry_after(&headers), Some(Duration::from_secs(3)));
    }

    #[test]
    fn parse_retry_after_trims_whitespace() {
        let mut headers = HeaderMap::new();
        headers.insert(http::header::RETRY_AFTER, " 7 ".parse().unwrap());
        assert_eq!(parse_retry_after(&headers), Some(Duration::from_secs(7)));
    }

    #[test]
    fn parse_retry_after_missing_header_is_none() {
        assert!(parse_retry_after(&HeaderMap::new()).is_none());
    }

    #[test]
    fn delay_after_response_uses_retry_after() {
        let mut headers = HeaderMap::new();
        headers.insert(http::header::RETRY_AFTER, "2".parse().unwrap());
        let policy = RetryPolicy::linear(1, Duration::from_millis(100)).with_jitter(false);
        assert_eq!(
            policy.delay_after_response(0, &headers),
            Duration::from_secs(2)
        );
    }

    #[test]
    fn delay_after_response_without_header_uses_policy_delay() {
        let policy = RetryPolicy::linear(3, Duration::from_millis(250));
        assert_eq!(
            policy.delay_after_response(1, &HeaderMap::new()),
            Duration::from_millis(250)
        );
    }

    #[test]
    fn retry_after_is_not_reduced_by_jitter() {
        let mut headers = HeaderMap::new();
        headers.insert(http::header::RETRY_AFTER, "5".parse().unwrap());
        let policy = RetryPolicy::exponential(3, Duration::from_secs(1), Duration::from_secs(30));
        assert!(policy.uses_jitter());
        for _ in 0..20 {
            assert_eq!(
                policy.delay_after_response(0, &headers),
                Duration::from_secs(5)
            );
        }
    }

    #[test]
    fn parse_retry_after_future_http_date() {
        let future = std::time::SystemTime::now() + Duration::from_secs(3600);
        let mut headers = HeaderMap::new();
        headers.insert(
            http::header::RETRY_AFTER,
            httpdate::fmt_http_date(future).parse().unwrap(),
        );
        let delay = parse_retry_after(&headers).expect("date delay");
        assert!(delay > Duration::from_secs(3000) && delay <= Duration::from_secs(3600));
    }

    #[test]
    fn parse_retry_after_past_http_date_is_zero() {
        let past = std::time::SystemTime::now() - Duration::from_secs(3600);
        let mut headers = HeaderMap::new();
        headers.insert(
            http::header::RETRY_AFTER,
            httpdate::fmt_http_date(past).parse().unwrap(),
        );
        assert_eq!(parse_retry_after(&headers), Some(Duration::ZERO));
    }

    #[test]
    fn jitter_stays_within_bounds() {
        let base = Duration::from_secs(4);
        for _ in 0..20 {
            let jittered = apply_jitter(base);
            assert!(jittered >= Duration::from_secs(2));
            assert!(jittered <= base);
        }
    }

    #[test]
    fn jitter_of_zero_is_zero() {
        assert_eq!(apply_jitter(Duration::ZERO), Duration::ZERO);
    }

    #[test]
    fn parse_retry_after_invalid_is_none() {
        let mut headers = HeaderMap::new();
        headers.insert(http::header::RETRY_AFTER, "not-a-number".parse().unwrap());
        assert!(parse_retry_after(&headers).is_none());
    }

    #[test]
    fn exponential_uses_jitter_by_default() {
        let policy = RetryPolicy::exponential(3, Duration::from_secs(1), Duration::from_secs(8));
        assert!(policy.uses_jitter());
    }

    #[test]
    fn linear_jitter_disabled_by_default() {
        assert!(!RetryPolicy::linear(1, Duration::from_secs(1)).uses_jitter());
    }
}