1use std::sync::Arc;
12use std::time::Duration;
13
14use http::{HeaderMap, StatusCode};
15
16use crate::response::Response;
17
18pub type ShouldRetryFn = Arc<dyn Fn(&Response) -> bool + Send + Sync>;
20
21#[derive(Clone)]
36#[must_use = "apply with `ClientBuilder::retry` or `RequestBuilder::retry`"]
37pub enum RetryPolicy {
38 Count {
40 attempts: u32,
42 should_retry: Option<ShouldRetryFn>,
44 },
45 Linear {
47 attempts: u32,
49 delay: Duration,
51 should_retry: Option<ShouldRetryFn>,
52 jitter: bool,
54 },
55 Exponential {
57 attempts: u32,
59 base_delay: Duration,
61 max_delay: Duration,
63 should_retry: Option<ShouldRetryFn>,
64 jitter: bool,
65 },
66}
67
68impl RetryPolicy {
69 pub fn count(attempts: u32) -> Self {
74 Self::Count {
75 attempts,
76 should_retry: None,
77 }
78 }
79
80 pub fn linear(attempts: u32, delay: Duration) -> Self {
82 Self::Linear {
83 attempts,
84 delay,
85 should_retry: None,
86 jitter: false,
87 }
88 }
89
90 pub fn exponential(attempts: u32, base_delay: Duration, max_delay: Duration) -> Self {
92 Self::Exponential {
93 attempts,
94 base_delay,
95 max_delay,
96 should_retry: None,
97 jitter: true,
98 }
99 }
100
101 #[must_use = "chain with `ClientBuilder::retry` or `RequestBuilder::retry`"]
105 pub fn with_jitter(mut self, jitter: bool) -> Self {
106 match &mut self {
107 Self::Linear { jitter: j, .. } | Self::Exponential { jitter: j, .. } => *j = jitter,
108 Self::Count { .. } => {}
109 }
110 self
111 }
112
113 #[must_use = "chain with `ClientBuilder::retry` or `RequestBuilder::retry`"]
115 pub fn with_should_retry(self, f: ShouldRetryFn) -> Self {
116 match self {
117 Self::Count { attempts, .. } => Self::Count {
118 attempts,
119 should_retry: Some(f),
120 },
121 Self::Linear {
122 attempts,
123 delay,
124 jitter,
125 ..
126 } => Self::Linear {
127 attempts,
128 delay,
129 should_retry: Some(f),
130 jitter,
131 },
132 Self::Exponential {
133 attempts,
134 base_delay,
135 max_delay,
136 jitter,
137 ..
138 } => Self::Exponential {
139 attempts,
140 base_delay,
141 max_delay,
142 should_retry: Some(f),
143 jitter,
144 },
145 }
146 }
147
148 pub fn max_attempts(&self) -> u32 {
150 match self {
151 Self::Count { attempts, .. }
152 | Self::Linear { attempts, .. }
153 | Self::Exponential { attempts, .. } => *attempts,
154 }
155 }
156
157 pub(crate) fn delay_before_attempt(&self, attempt: u32) -> Duration {
158 match self {
159 Self::Count { .. } => Duration::from_secs(1),
160 Self::Linear { delay, .. } => *delay,
161 Self::Exponential {
162 base_delay,
163 max_delay,
164 ..
165 } => {
166 let exp = base_delay.saturating_mul(2u32.saturating_pow(attempt));
167 exp.min(*max_delay)
168 }
169 }
170 }
171
172 pub(crate) fn delay_after_response(&self, attempt: u32, headers: &HeaderMap) -> Duration {
174 let base = self.delay_before_attempt(attempt);
175 let delay = parse_retry_after(headers).unwrap_or(base);
176 if self.uses_jitter() {
177 apply_jitter(delay)
178 } else {
179 delay
180 }
181 }
182
183 pub(crate) fn uses_jitter(&self) -> bool {
184 match self {
185 Self::Count { .. } => true,
186 Self::Linear { jitter, .. } | Self::Exponential { jitter, .. } => *jitter,
187 }
188 }
189
190 pub(crate) fn has_custom_should_retry(&self) -> bool {
192 matches!(
193 self,
194 Self::Count {
195 should_retry: Some(_),
196 ..
197 } | Self::Linear {
198 should_retry: Some(_),
199 ..
200 } | Self::Exponential {
201 should_retry: Some(_),
202 ..
203 }
204 )
205 }
206
207 pub(crate) fn should_retry_response(
208 &self,
209 response: &Response,
210 transport_failed: bool,
211 ) -> bool {
212 if transport_failed {
213 return true;
214 }
215
216 let custom = match self {
217 Self::Count { should_retry, .. }
218 | Self::Linear { should_retry, .. }
219 | Self::Exponential { should_retry, .. } => should_retry.as_ref(),
220 };
221
222 if let Some(f) = custom {
223 return f(response);
224 }
225
226 default_should_retry(response.status())
227 }
228}
229
230pub fn default_should_retry(status: StatusCode) -> bool {
232 matches!(status.as_u16(), 408 | 429 | 502 | 503 | 504)
233}
234
235pub fn parse_retry_after(headers: &HeaderMap) -> Option<Duration> {
237 let value = headers.get(http::header::RETRY_AFTER)?.to_str().ok()?;
238 let secs = value.trim().parse::<u64>().ok()?;
239 Some(Duration::from_secs(secs))
240}
241
242fn apply_jitter(delay: Duration) -> Duration {
243 let nanos = delay.as_nanos().min(u128::from(u64::MAX)) as u64;
244 if nanos == 0 {
245 return delay;
246 }
247 let half = nanos / 2;
248 let span = nanos.saturating_sub(half).max(1);
249 Duration::from_nanos(half + fastrand::u64(..span))
250}
251
252pub(crate) use crate::cancel::sleep_or_cancel;
253
#[cfg(test)]
mod tests {
    use super::*;
    use crate::response::Response;
    use http::StatusCode;

    /// Builds an empty-bodied response with the given status code.
    fn response_with_status(status: u16) -> Response {
        Response::new(
            StatusCode::from_u16(status).unwrap(),
            http::HeaderMap::new(),
            bytes::Bytes::new(),
            None,
            #[cfg(feature = "json")]
            None,
        )
    }

    #[test]
    fn default_should_retry_codes() {
        assert!(default_should_retry(StatusCode::REQUEST_TIMEOUT));
        assert!(default_should_retry(StatusCode::TOO_MANY_REQUESTS));
        assert!(default_should_retry(StatusCode::SERVICE_UNAVAILABLE));
        assert!(!default_should_retry(StatusCode::NOT_FOUND));
    }

    #[test]
    fn default_should_retry_gateway_errors_but_not_500() {
        assert!(default_should_retry(StatusCode::BAD_GATEWAY));
        assert!(default_should_retry(StatusCode::GATEWAY_TIMEOUT));
        assert!(!default_should_retry(StatusCode::INTERNAL_SERVER_ERROR));
        assert!(!default_should_retry(StatusCode::OK));
    }

    #[test]
    fn count_policy_max_attempts() {
        assert_eq!(RetryPolicy::count(3).max_attempts(), 3);
    }

    #[test]
    fn count_policy_uses_fixed_one_second_delay() {
        let policy = RetryPolicy::count(2);
        assert_eq!(policy.delay_before_attempt(0), Duration::from_secs(1));
        assert_eq!(policy.delay_before_attempt(5), Duration::from_secs(1));
    }

    #[test]
    fn count_with_should_retry_stays_count() {
        let policy = RetryPolicy::count(2)
            .with_should_retry(Arc::new(|r| r.status() == StatusCode::NOT_FOUND));
        assert!(matches!(policy, RetryPolicy::Count { .. }));
        assert!(policy.should_retry_response(&response_with_status(404), false));
        assert!(!policy.should_retry_response(&response_with_status(503), false));
    }

    #[test]
    fn with_should_retry_keeps_other_settings() {
        let policy = RetryPolicy::exponential(4, Duration::from_secs(1), Duration::from_secs(3))
            .with_jitter(false)
            .with_should_retry(Arc::new(|_: &Response| true));
        assert!(policy.has_custom_should_retry());
        assert_eq!(policy.max_attempts(), 4);
        assert!(!policy.uses_jitter());
        assert_eq!(policy.delay_before_attempt(5), Duration::from_secs(3));
    }

    #[test]
    fn transport_failure_retries_even_with_custom_predicate() {
        let policy = RetryPolicy::count(1).with_should_retry(Arc::new(|_: &Response| false));
        assert!(policy.should_retry_response(&response_with_status(200), true));
        assert!(!policy.should_retry_response(&response_with_status(200), false));
    }

    #[test]
    fn linear_delay_is_constant() {
        let policy = RetryPolicy::linear(3, Duration::from_millis(500));
        assert_eq!(policy.delay_before_attempt(0), Duration::from_millis(500));
        assert_eq!(policy.delay_before_attempt(2), Duration::from_millis(500));
    }

    #[test]
    fn exponential_delay_caps_at_max() {
        let policy = RetryPolicy::exponential(5, Duration::from_secs(1), Duration::from_secs(5));
        assert_eq!(policy.delay_before_attempt(0), Duration::from_secs(1));
        assert_eq!(policy.delay_before_attempt(10), Duration::from_secs(5));
    }

    #[test]
    fn exponential_delay_doubles_and_saturates() {
        let policy =
            RetryPolicy::exponential(5, Duration::from_millis(100), Duration::from_secs(10));
        assert_eq!(policy.delay_before_attempt(1), Duration::from_millis(200));
        assert_eq!(policy.delay_before_attempt(3), Duration::from_millis(800));
        // Huge attempt numbers saturate instead of overflowing, then hit the cap.
        assert_eq!(policy.delay_before_attempt(u32::MAX), Duration::from_secs(10));
    }

    #[test]
    fn custom_should_retry_overrides_default() {
        let policy = RetryPolicy::linear(2, Duration::from_millis(1))
            .with_should_retry(Arc::new(|r| r.status() == StatusCode::NOT_FOUND));
        assert!(policy.should_retry_response(&response_with_status(404), false));
        assert!(!policy.should_retry_response(&response_with_status(503), false));
    }

    #[test]
    fn parse_retry_after_seconds() {
        let mut headers = HeaderMap::new();
        headers.insert(http::header::RETRY_AFTER, "3".parse().unwrap());
        assert_eq!(parse_retry_after(&headers), Some(Duration::from_secs(3)));
    }

    #[test]
    fn parse_retry_after_missing_or_negative_is_none() {
        assert!(parse_retry_after(&HeaderMap::new()).is_none());
        let mut headers = HeaderMap::new();
        headers.insert(http::header::RETRY_AFTER, "-1".parse().unwrap());
        assert!(parse_retry_after(&headers).is_none());
    }

    #[test]
    fn delay_after_response_uses_retry_after() {
        let mut headers = HeaderMap::new();
        headers.insert(http::header::RETRY_AFTER, "2".parse().unwrap());
        let policy = RetryPolicy::linear(1, Duration::from_millis(100)).with_jitter(false);
        assert_eq!(
            policy.delay_after_response(0, &headers),
            Duration::from_secs(2)
        );
    }

    #[test]
    fn delay_after_response_without_retry_after_uses_policy_delay() {
        let policy = RetryPolicy::linear(1, Duration::from_millis(100));
        assert_eq!(
            policy.delay_after_response(0, &HeaderMap::new()),
            Duration::from_millis(100)
        );
    }

    #[test]
    fn jitter_stays_within_bounds() {
        let base = Duration::from_secs(4);
        for _ in 0..20 {
            let jittered = apply_jitter(base);
            assert!(jittered >= Duration::from_secs(2));
            assert!(jittered <= base);
        }
    }

    #[test]
    fn jitter_of_zero_is_zero() {
        assert_eq!(apply_jitter(Duration::ZERO), Duration::ZERO);
    }

    #[test]
    fn parse_retry_after_invalid_is_none() {
        let mut headers = HeaderMap::new();
        headers.insert(http::header::RETRY_AFTER, "not-a-number".parse().unwrap());
        assert!(parse_retry_after(&headers).is_none());
    }

    #[test]
    fn exponential_uses_jitter_by_default() {
        let policy = RetryPolicy::exponential(3, Duration::from_secs(1), Duration::from_secs(8));
        assert!(policy.uses_jitter());
    }

    #[test]
    fn linear_jitter_disabled_by_default() {
        assert!(!RetryPolicy::linear(1, Duration::from_secs(1)).uses_jitter());
    }
}