1use std::sync::Arc;
12use std::time::Duration;
13
14use http::{HeaderMap, StatusCode};
15
16use crate::response::Response;
17
18pub type ShouldRetryFn = Arc<dyn Fn(&Response) -> bool + Send + Sync>;
20
21#[derive(Clone)]
36pub enum RetryPolicy {
37 Count {
39 attempts: u32,
41 should_retry: Option<ShouldRetryFn>,
43 },
44 Linear {
46 attempts: u32,
48 delay: Duration,
50 should_retry: Option<ShouldRetryFn>,
51 jitter: bool,
53 },
54 Exponential {
56 attempts: u32,
58 base_delay: Duration,
60 max_delay: Duration,
62 should_retry: Option<ShouldRetryFn>,
63 jitter: bool,
64 },
65}
66
67impl RetryPolicy {
68 pub fn count(attempts: u32) -> Self {
70 Self::Count {
71 attempts,
72 should_retry: None,
73 }
74 }
75
76 pub fn linear(attempts: u32, delay: Duration) -> Self {
78 Self::Linear {
79 attempts,
80 delay,
81 should_retry: None,
82 jitter: false,
83 }
84 }
85
86 pub fn exponential(attempts: u32, base_delay: Duration, max_delay: Duration) -> Self {
88 Self::Exponential {
89 attempts,
90 base_delay,
91 max_delay,
92 should_retry: None,
93 jitter: true,
94 }
95 }
96
97 pub fn with_jitter(mut self, jitter: bool) -> Self {
99 match &mut self {
100 Self::Linear { jitter: j, .. } | Self::Exponential { jitter: j, .. } => *j = jitter,
101 Self::Count { .. } => {}
102 }
103 self
104 }
105
106 pub fn with_should_retry(self, f: ShouldRetryFn) -> Self {
108 match self {
109 Self::Count { attempts, .. } => Self::Count {
110 attempts,
111 should_retry: Some(f),
112 },
113 Self::Linear {
114 attempts,
115 delay,
116 jitter,
117 ..
118 } => Self::Linear {
119 attempts,
120 delay,
121 should_retry: Some(f),
122 jitter,
123 },
124 Self::Exponential {
125 attempts,
126 base_delay,
127 max_delay,
128 jitter,
129 ..
130 } => Self::Exponential {
131 attempts,
132 base_delay,
133 max_delay,
134 should_retry: Some(f),
135 jitter,
136 },
137 }
138 }
139
140 pub fn max_attempts(&self) -> u32 {
142 match self {
143 Self::Count { attempts, .. }
144 | Self::Linear { attempts, .. }
145 | Self::Exponential { attempts, .. } => *attempts,
146 }
147 }
148
149 pub(crate) fn delay_before_attempt(&self, attempt: u32) -> Duration {
150 match self {
151 Self::Count { .. } => Duration::from_secs(1),
152 Self::Linear { delay, .. } => *delay,
153 Self::Exponential {
154 base_delay,
155 max_delay,
156 ..
157 } => {
158 let exp = base_delay.saturating_mul(2u32.saturating_pow(attempt));
159 exp.min(*max_delay)
160 }
161 }
162 }
163
164 pub(crate) fn delay_after_response(&self, attempt: u32, headers: &HeaderMap) -> Duration {
166 let base = self.delay_before_attempt(attempt);
167 let delay = parse_retry_after(headers).unwrap_or(base);
168 if self.uses_jitter() {
169 apply_jitter(delay)
170 } else {
171 delay
172 }
173 }
174
175 pub(crate) fn uses_jitter(&self) -> bool {
176 match self {
177 Self::Count { .. } => true,
178 Self::Linear { jitter, .. } | Self::Exponential { jitter, .. } => *jitter,
179 }
180 }
181
182 pub(crate) fn has_custom_should_retry(&self) -> bool {
184 matches!(
185 self,
186 Self::Count {
187 should_retry: Some(_),
188 ..
189 } | Self::Linear {
190 should_retry: Some(_),
191 ..
192 } | Self::Exponential {
193 should_retry: Some(_),
194 ..
195 }
196 )
197 }
198
199 pub(crate) fn should_retry_response(
200 &self,
201 response: &Response,
202 transport_failed: bool,
203 ) -> bool {
204 if transport_failed {
205 return true;
206 }
207
208 let custom = match self {
209 Self::Count { should_retry, .. }
210 | Self::Linear { should_retry, .. }
211 | Self::Exponential { should_retry, .. } => should_retry.as_ref(),
212 };
213
214 if let Some(f) = custom {
215 return f(response);
216 }
217
218 default_should_retry(response.status())
219 }
220}
221
222pub fn default_should_retry(status: StatusCode) -> bool {
224 matches!(status.as_u16(), 408 | 429 | 502 | 503 | 504)
225}
226
227pub fn parse_retry_after(headers: &HeaderMap) -> Option<Duration> {
229 let value = headers.get(http::header::RETRY_AFTER)?.to_str().ok()?;
230 let secs = value.trim().parse::<u64>().ok()?;
231 Some(Duration::from_secs(secs))
232}
233
234fn apply_jitter(delay: Duration) -> Duration {
235 let nanos = delay.as_nanos().min(u128::from(u64::MAX)) as u64;
236 if nanos == 0 {
237 return delay;
238 }
239 let half = nanos / 2;
240 let span = nanos.saturating_sub(half).max(1);
241 Duration::from_nanos(half + fastrand::u64(..span))
242}
243
244pub(crate) use crate::cancel::sleep_or_cancel;
245
#[cfg(test)]
mod tests {
    use super::*;
    use crate::response::Response;
    use http::StatusCode;

    /// Builds an empty-bodied response with the given status, for predicate tests.
    fn response_with_status(status: u16) -> Response {
        Response::new(
            StatusCode::from_u16(status).unwrap(),
            http::HeaderMap::new(),
            bytes::Bytes::new(),
            None,
            #[cfg(feature = "json")]
            None,
        )
    }

    #[test]
    fn default_should_retry_codes() {
        assert!(default_should_retry(StatusCode::REQUEST_TIMEOUT));
        assert!(default_should_retry(StatusCode::TOO_MANY_REQUESTS));
        assert!(default_should_retry(StatusCode::BAD_GATEWAY));
        assert!(default_should_retry(StatusCode::SERVICE_UNAVAILABLE));
        assert!(default_should_retry(StatusCode::GATEWAY_TIMEOUT));
        assert!(!default_should_retry(StatusCode::INTERNAL_SERVER_ERROR));
        assert!(!default_should_retry(StatusCode::NOT_FOUND));
    }

    #[test]
    fn count_policy_max_attempts() {
        assert_eq!(RetryPolicy::count(3).max_attempts(), 3);
    }

    #[test]
    fn count_with_should_retry_stays_count() {
        let policy = RetryPolicy::count(2)
            .with_should_retry(Arc::new(|r| r.status() == StatusCode::NOT_FOUND));
        assert!(matches!(policy, RetryPolicy::Count { .. }));
        assert!(policy.should_retry_response(&response_with_status(404), false));
        assert!(!policy.should_retry_response(&response_with_status(503), false));
    }

    #[test]
    fn transport_failure_always_retries() {
        let policy = RetryPolicy::count(1).with_should_retry(Arc::new(|_: &Response| false));
        assert!(policy.should_retry_response(&response_with_status(200), true));
    }

    #[test]
    fn has_custom_should_retry_tracks_predicate() {
        let policy =
            RetryPolicy::exponential(2, Duration::from_millis(1), Duration::from_millis(10));
        assert!(!policy.has_custom_should_retry());
        let policy = policy.with_should_retry(Arc::new(|_: &Response| true));
        assert!(policy.has_custom_should_retry());
    }

    #[test]
    fn linear_delay_is_constant() {
        let policy = RetryPolicy::linear(3, Duration::from_millis(500));
        assert_eq!(policy.delay_before_attempt(0), Duration::from_millis(500));
        assert_eq!(policy.delay_before_attempt(2), Duration::from_millis(500));
    }

    #[test]
    fn exponential_delay_caps_at_max() {
        let policy = RetryPolicy::exponential(5, Duration::from_secs(1), Duration::from_secs(5));
        assert_eq!(policy.delay_before_attempt(0), Duration::from_secs(1));
        assert_eq!(policy.delay_before_attempt(10), Duration::from_secs(5));
    }

    #[test]
    fn exponential_delay_doubles_until_cap() {
        let policy = RetryPolicy::exponential(5, Duration::from_secs(1), Duration::from_secs(5));
        assert_eq!(policy.delay_before_attempt(1), Duration::from_secs(2));
        assert_eq!(policy.delay_before_attempt(2), Duration::from_secs(4));
        assert_eq!(policy.delay_before_attempt(3), Duration::from_secs(5));
        // Saturating arithmetic: an absurd attempt number must not overflow.
        assert_eq!(policy.delay_before_attempt(u32::MAX), Duration::from_secs(5));
    }

    #[test]
    fn custom_should_retry_overrides_default() {
        let policy = RetryPolicy::linear(2, Duration::from_millis(1))
            .with_should_retry(Arc::new(|r| r.status() == StatusCode::NOT_FOUND));
        assert!(policy.should_retry_response(&response_with_status(404), false));
        assert!(!policy.should_retry_response(&response_with_status(503), false));
    }

    #[test]
    fn parse_retry_after_seconds() {
        let mut headers = HeaderMap::new();
        headers.insert(http::header::RETRY_AFTER, "3".parse().unwrap());
        assert_eq!(parse_retry_after(&headers), Some(Duration::from_secs(3)));
    }

    #[test]
    fn delay_after_response_uses_retry_after() {
        let mut headers = HeaderMap::new();
        headers.insert(http::header::RETRY_AFTER, "2".parse().unwrap());
        let policy = RetryPolicy::linear(1, Duration::from_millis(100)).with_jitter(false);
        assert_eq!(
            policy.delay_after_response(0, &headers),
            Duration::from_secs(2)
        );
    }

    #[test]
    fn delay_after_response_without_retry_after_uses_backoff() {
        let policy = RetryPolicy::linear(1, Duration::from_millis(100)).with_jitter(false);
        assert_eq!(
            policy.delay_after_response(0, &HeaderMap::new()),
            Duration::from_millis(100)
        );
    }

    #[test]
    fn jitter_stays_within_bounds() {
        let base = Duration::from_secs(4);
        for _ in 0..20 {
            let jittered = apply_jitter(base);
            assert!(jittered >= Duration::from_secs(2));
            assert!(jittered <= base);
        }
    }

    #[test]
    fn jitter_leaves_zero_untouched() {
        assert_eq!(apply_jitter(Duration::ZERO), Duration::ZERO);
    }

    #[test]
    fn parse_retry_after_invalid_is_none() {
        let mut headers = HeaderMap::new();
        headers.insert(http::header::RETRY_AFTER, "not-a-number".parse().unwrap());
        assert!(parse_retry_after(&headers).is_none());
    }

    #[test]
    fn exponential_uses_jitter_by_default() {
        let policy = RetryPolicy::exponential(3, Duration::from_secs(1), Duration::from_secs(8));
        assert!(policy.uses_jitter());
    }

    #[test]
    fn linear_jitter_disabled_by_default() {
        assert!(!RetryPolicy::linear(1, Duration::from_secs(1)).uses_jitter());
    }
}