1use std::sync::Arc;
14use std::time::Duration;
15
16use http::{HeaderMap, StatusCode};
17
18use crate::response::Response;
19
20pub type ShouldRetryFn = Arc<dyn Fn(&Response) -> bool + Send + Sync>;
22
/// Describes how many times, and with what pauses, a request is retried.
///
/// Build one with [`RetryPolicy::count`], [`RetryPolicy::linear`] or
/// [`RetryPolicy::exponential`], then refine it with `with_jitter` / `with_should_retry`.
/// Transport failures are always retried. Received responses are checked with the
/// variant's `should_retry` predicate when set, otherwise with `default_should_retry`.
#[derive(Clone)]
#[must_use = "apply with `ClientBuilder::retry` or `RequestBuilder::retry`"]
pub enum RetryPolicy {
    /// A fixed number of attempts with a one-second base delay between them.
    Count {
        /// Attempt budget, as reported by `max_attempts`.
        attempts: u32,
        /// Optional predicate overriding `default_should_retry` for received responses.
        should_retry: Option<ShouldRetryFn>,
    },
    /// The same `delay` before every retry.
    Linear {
        /// Attempt budget, as reported by `max_attempts`.
        attempts: u32,
        /// Pause used before every retry.
        delay: Duration,
        /// Optional predicate overriding `default_should_retry` for received responses.
        should_retry: Option<ShouldRetryFn>,
        /// Whether the delay is randomised (off by default for `RetryPolicy::linear`).
        jitter: bool,
    },
    /// A delay of `base_delay * 2^attempt`, capped at `max_delay`.
    Exponential {
        /// Attempt budget, as reported by `max_attempts`.
        attempts: u32,
        /// Delay used for attempt 0; doubled for each later attempt.
        base_delay: Duration,
        /// Upper bound on the computed delay.
        max_delay: Duration,
        /// Optional predicate overriding `default_should_retry` for received responses.
        should_retry: Option<ShouldRetryFn>,
        /// Whether the delay is randomised (on by default for `RetryPolicy::exponential`).
        jitter: bool,
    },
}
69
70impl RetryPolicy {
71 pub fn count(attempts: u32) -> Self {
76 Self::Count {
77 attempts,
78 should_retry: None,
79 }
80 }
81
82 pub fn linear(attempts: u32, delay: Duration) -> Self {
84 Self::Linear {
85 attempts,
86 delay,
87 should_retry: None,
88 jitter: false,
89 }
90 }
91
92 pub fn exponential(attempts: u32, base_delay: Duration, max_delay: Duration) -> Self {
94 Self::Exponential {
95 attempts,
96 base_delay,
97 max_delay,
98 should_retry: None,
99 jitter: true,
100 }
101 }
102
103 #[must_use = "chain with `ClientBuilder::retry` or `RequestBuilder::retry`"]
107 pub fn with_jitter(mut self, jitter: bool) -> Self {
108 match &mut self {
109 Self::Linear { jitter: j, .. } | Self::Exponential { jitter: j, .. } => *j = jitter,
110 Self::Count { .. } => {}
111 }
112 self
113 }
114
115 #[must_use = "chain with `ClientBuilder::retry` or `RequestBuilder::retry`"]
117 pub fn with_should_retry(self, f: ShouldRetryFn) -> Self {
118 match self {
119 Self::Count { attempts, .. } => Self::Count {
120 attempts,
121 should_retry: Some(f),
122 },
123 Self::Linear {
124 attempts,
125 delay,
126 jitter,
127 ..
128 } => Self::Linear {
129 attempts,
130 delay,
131 should_retry: Some(f),
132 jitter,
133 },
134 Self::Exponential {
135 attempts,
136 base_delay,
137 max_delay,
138 jitter,
139 ..
140 } => Self::Exponential {
141 attempts,
142 base_delay,
143 max_delay,
144 should_retry: Some(f),
145 jitter,
146 },
147 }
148 }
149
150 pub fn max_attempts(&self) -> u32 {
152 match self {
153 Self::Count { attempts, .. }
154 | Self::Linear { attempts, .. }
155 | Self::Exponential { attempts, .. } => *attempts,
156 }
157 }
158
159 pub(crate) fn delay_before_attempt(&self, attempt: u32) -> Duration {
160 match self {
161 Self::Count { .. } => Duration::from_secs(1),
162 Self::Linear { delay, .. } => *delay,
163 Self::Exponential {
164 base_delay,
165 max_delay,
166 ..
167 } => {
168 let exp = base_delay.saturating_mul(2u32.saturating_pow(attempt));
169 exp.min(*max_delay)
170 }
171 }
172 }
173
174 pub(crate) fn delay_after_response(&self, attempt: u32, headers: &HeaderMap) -> Duration {
176 let base = self.delay_before_attempt(attempt);
177 let delay = parse_retry_after(headers).unwrap_or(base);
178 if self.uses_jitter() {
179 apply_jitter(delay)
180 } else {
181 delay
182 }
183 }
184
185 pub(crate) fn uses_jitter(&self) -> bool {
186 match self {
187 Self::Count { .. } => true,
188 Self::Linear { jitter, .. } | Self::Exponential { jitter, .. } => *jitter,
189 }
190 }
191
192 pub(crate) fn has_custom_should_retry(&self) -> bool {
194 matches!(
195 self,
196 Self::Count {
197 should_retry: Some(_),
198 ..
199 } | Self::Linear {
200 should_retry: Some(_),
201 ..
202 } | Self::Exponential {
203 should_retry: Some(_),
204 ..
205 }
206 )
207 }
208
209 pub(crate) fn should_retry_response(
210 &self,
211 response: &Response,
212 transport_failed: bool,
213 ) -> bool {
214 if transport_failed {
215 return true;
216 }
217
218 let custom = match self {
219 Self::Count { should_retry, .. }
220 | Self::Linear { should_retry, .. }
221 | Self::Exponential { should_retry, .. } => should_retry.as_ref(),
222 };
223
224 if let Some(f) = custom {
225 return f(response);
226 }
227
228 default_should_retry(response.status())
229 }
230}
231
232pub fn default_should_retry(status: StatusCode) -> bool {
234 matches!(status.as_u16(), 408 | 429 | 502 | 503 | 504)
235}
236
237pub fn parse_retry_after(headers: &HeaderMap) -> Option<Duration> {
239 let value = headers.get(http::header::RETRY_AFTER)?.to_str().ok()?;
240 let secs = value.trim().parse::<u64>().ok()?;
241 Some(Duration::from_secs(secs))
242}
243
244fn apply_jitter(delay: Duration) -> Duration {
245 let nanos = delay.as_nanos().min(u128::from(u64::MAX)) as u64;
246 if nanos == 0 {
247 return delay;
248 }
249 let half = nanos / 2;
250 let span = nanos.saturating_sub(half).max(1);
251 Duration::from_nanos(half + fastrand::u64(..span))
252}
253
254pub(crate) use crate::cancel::sleep_or_cancel;
255
#[cfg(test)]
mod tests {
    use super::*;
    use crate::response::Response;
    use http::StatusCode;

    /// Builds an empty-bodied response carrying the given status code.
    fn response_with_status(status: u16) -> Response {
        Response::new(
            StatusCode::from_u16(status).unwrap(),
            http::HeaderMap::new(),
            bytes::Bytes::new(),
            None,
            #[cfg(feature = "json")]
            None,
        )
    }

    #[test]
    fn default_should_retry_codes() {
        assert!(default_should_retry(StatusCode::REQUEST_TIMEOUT));
        assert!(default_should_retry(StatusCode::TOO_MANY_REQUESTS));
        assert!(default_should_retry(StatusCode::BAD_GATEWAY));
        assert!(default_should_retry(StatusCode::SERVICE_UNAVAILABLE));
        assert!(default_should_retry(StatusCode::GATEWAY_TIMEOUT));
        assert!(!default_should_retry(StatusCode::NOT_FOUND));
        assert!(!default_should_retry(StatusCode::OK));
    }

    #[test]
    fn count_policy_max_attempts() {
        assert_eq!(RetryPolicy::count(3).max_attempts(), 3);
    }

    #[test]
    fn count_policy_waits_one_second() {
        let policy = RetryPolicy::count(3);
        assert_eq!(policy.delay_before_attempt(0), Duration::from_secs(1));
        assert_eq!(policy.delay_before_attempt(5), Duration::from_secs(1));
    }

    #[test]
    fn count_with_should_retry_stays_count() {
        let policy = RetryPolicy::count(2)
            .with_should_retry(Arc::new(|r| r.status() == StatusCode::NOT_FOUND));
        assert!(matches!(policy, RetryPolicy::Count { .. }));
        assert!(policy.should_retry_response(&response_with_status(404), false));
        assert!(!policy.should_retry_response(&response_with_status(503), false));
    }

    #[test]
    fn with_should_retry_keeps_other_settings() {
        let policy = RetryPolicy::linear(4, Duration::from_millis(250))
            .with_jitter(true)
            .with_should_retry(Arc::new(|_: &Response| true));
        assert!(matches!(policy, RetryPolicy::Linear { .. }));
        assert_eq!(policy.max_attempts(), 4);
        assert!(policy.uses_jitter());
        assert_eq!(policy.delay_before_attempt(1), Duration::from_millis(250));
    }

    #[test]
    fn linear_delay_is_constant() {
        let policy = RetryPolicy::linear(3, Duration::from_millis(500));
        assert_eq!(policy.delay_before_attempt(0), Duration::from_millis(500));
        assert_eq!(policy.delay_before_attempt(2), Duration::from_millis(500));
    }

    #[test]
    fn exponential_delay_caps_at_max() {
        let policy = RetryPolicy::exponential(5, Duration::from_secs(1), Duration::from_secs(5));
        assert_eq!(policy.delay_before_attempt(0), Duration::from_secs(1));
        assert_eq!(policy.delay_before_attempt(10), Duration::from_secs(5));
    }

    #[test]
    fn exponential_delay_doubles_until_cap() {
        let policy = RetryPolicy::exponential(5, Duration::from_secs(1), Duration::from_secs(8));
        assert_eq!(policy.delay_before_attempt(1), Duration::from_secs(2));
        assert_eq!(policy.delay_before_attempt(2), Duration::from_secs(4));
        assert_eq!(policy.delay_before_attempt(3), Duration::from_secs(8));
        assert_eq!(policy.delay_before_attempt(4), Duration::from_secs(8));
    }

    #[test]
    fn exponential_delay_saturates_for_huge_attempts() {
        let policy = RetryPolicy::exponential(5, Duration::from_secs(1), Duration::from_secs(5));
        assert_eq!(policy.delay_before_attempt(u32::MAX), Duration::from_secs(5));
    }

    #[test]
    fn custom_should_retry_overrides_default() {
        let policy = RetryPolicy::linear(2, Duration::from_millis(1))
            .with_should_retry(Arc::new(|r| r.status() == StatusCode::NOT_FOUND));
        assert!(policy.should_retry_response(&response_with_status(404), false));
        assert!(!policy.should_retry_response(&response_with_status(503), false));
    }

    #[test]
    fn default_predicate_used_without_custom() {
        let policy = RetryPolicy::linear(2, Duration::from_millis(1));
        assert!(!policy.has_custom_should_retry());
        assert!(policy.should_retry_response(&response_with_status(503), false));
        assert!(!policy.should_retry_response(&response_with_status(404), false));
    }

    #[test]
    fn has_custom_should_retry_tracks_predicate() {
        assert!(!RetryPolicy::count(1).has_custom_should_retry());
        let custom =
            RetryPolicy::exponential(2, Duration::from_millis(1), Duration::from_millis(4))
                .with_should_retry(Arc::new(|_: &Response| true));
        assert!(custom.has_custom_should_retry());
    }

    #[test]
    fn transport_failure_always_retries() {
        let policy = RetryPolicy::count(1).with_should_retry(Arc::new(|_: &Response| false));
        assert!(policy.should_retry_response(&response_with_status(200), true));
        assert!(!policy.should_retry_response(&response_with_status(200), false));
    }

    #[test]
    fn parse_retry_after_seconds() {
        let mut headers = HeaderMap::new();
        headers.insert(http::header::RETRY_AFTER, "3".parse().unwrap());
        assert_eq!(parse_retry_after(&headers), Some(Duration::from_secs(3)));
    }

    #[test]
    fn parse_retry_after_missing_is_none() {
        assert!(parse_retry_after(&HeaderMap::new()).is_none());
    }

    #[test]
    fn delay_after_response_uses_retry_after() {
        let mut headers = HeaderMap::new();
        headers.insert(http::header::RETRY_AFTER, "2".parse().unwrap());
        let policy = RetryPolicy::linear(1, Duration::from_millis(100)).with_jitter(false);
        assert_eq!(
            policy.delay_after_response(0, &headers),
            Duration::from_secs(2)
        );
    }

    #[test]
    fn delay_after_response_without_header_uses_policy_delay() {
        let policy = RetryPolicy::linear(1, Duration::from_millis(100));
        assert_eq!(
            policy.delay_after_response(0, &HeaderMap::new()),
            Duration::from_millis(100)
        );
    }

    #[test]
    fn jitter_stays_within_bounds() {
        let base = Duration::from_secs(4);
        for _ in 0..20 {
            let jittered = apply_jitter(base);
            assert!(jittered >= Duration::from_secs(2));
            assert!(jittered <= base);
        }
    }

    #[test]
    fn jitter_of_zero_is_zero() {
        assert_eq!(apply_jitter(Duration::ZERO), Duration::ZERO);
    }

    #[test]
    fn parse_retry_after_invalid_is_none() {
        let mut headers = HeaderMap::new();
        headers.insert(http::header::RETRY_AFTER, "not-a-number".parse().unwrap());
        assert!(parse_retry_after(&headers).is_none());
    }

    #[test]
    fn exponential_uses_jitter_by_default() {
        let policy = RetryPolicy::exponential(3, Duration::from_secs(1), Duration::from_secs(8));
        assert!(policy.uses_jitter());
    }

    #[test]
    fn linear_jitter_disabled_by_default() {
        assert!(!RetryPolicy::linear(1, Duration::from_secs(1)).uses_jitter());
    }
}