1use std::sync::Arc;
7use std::time::Duration;
8
9use http::{HeaderMap, StatusCode};
10
11use crate::response::Response;
12
/// Caller-supplied predicate that decides whether a [`Response`] should be retried.
///
/// Stored behind an `Arc` so cloning a [`RetryPolicy`] only bumps a reference
/// count; the `Send + Sync` bounds let the predicate be shared across threads.
pub type ShouldRetryFn = Arc<dyn Fn(&Response) -> bool + Send + Sync>;
15
/// Strategy for retrying a request: how many attempts, how long to wait
/// between them, and which responses qualify for a retry.
///
/// Every variant carries an attempt budget (see [`RetryPolicy::max_attempts`])
/// and an optional custom predicate that, when set, replaces
/// [`default_should_retry`] for deciding whether a response is retryable.
#[derive(Clone)]
pub enum RetryPolicy {
    /// Fixed attempt budget with a one-second delay between attempts.
    ///
    /// This variant has no `jitter` flag; `uses_jitter` always reports `true`
    /// for it and `with_jitter` leaves it unchanged.
    Count {
        /// Attempt budget reported by [`RetryPolicy::max_attempts`].
        attempts: u32,
        /// Custom retry predicate; `None` falls back to [`default_should_retry`].
        should_retry: Option<ShouldRetryFn>,
    },
    /// Constant delay between attempts.
    Linear {
        /// Attempt budget reported by [`RetryPolicy::max_attempts`].
        attempts: u32,
        /// Delay used before every attempt, regardless of the attempt number.
        delay: Duration,
        /// Custom retry predicate; `None` falls back to [`default_should_retry`].
        should_retry: Option<ShouldRetryFn>,
        /// Whether the delay is randomised (`false` when built via `linear`).
        jitter: bool,
    },
    /// Delay that doubles with each attempt, bounded above by `max_delay`.
    Exponential {
        /// Attempt budget reported by [`RetryPolicy::max_attempts`].
        attempts: u32,
        /// Delay for attempt 0; attempt `n` uses `base_delay * 2^n` (saturating).
        base_delay: Duration,
        /// Upper bound applied to the doubled delay.
        max_delay: Duration,
        /// Custom retry predicate; `None` falls back to [`default_should_retry`].
        should_retry: Option<ShouldRetryFn>,
        /// Whether the delay is randomised (`true` when built via `exponential`).
        jitter: bool,
    },
}
61
62impl RetryPolicy {
63 pub fn count(attempts: u32) -> Self {
65 Self::Count {
66 attempts,
67 should_retry: None,
68 }
69 }
70
71 pub fn linear(attempts: u32, delay: Duration) -> Self {
73 Self::Linear {
74 attempts,
75 delay,
76 should_retry: None,
77 jitter: false,
78 }
79 }
80
81 pub fn exponential(attempts: u32, base_delay: Duration, max_delay: Duration) -> Self {
83 Self::Exponential {
84 attempts,
85 base_delay,
86 max_delay,
87 should_retry: None,
88 jitter: true,
89 }
90 }
91
92 pub fn with_jitter(mut self, jitter: bool) -> Self {
94 match &mut self {
95 Self::Linear { jitter: j, .. } | Self::Exponential { jitter: j, .. } => *j = jitter,
96 Self::Count { .. } => {}
97 }
98 self
99 }
100
101 pub fn with_should_retry(self, f: ShouldRetryFn) -> Self {
103 match self {
104 Self::Count { attempts, .. } => Self::Count {
105 attempts,
106 should_retry: Some(f),
107 },
108 Self::Linear {
109 attempts,
110 delay,
111 jitter,
112 ..
113 } => Self::Linear {
114 attempts,
115 delay,
116 should_retry: Some(f),
117 jitter,
118 },
119 Self::Exponential {
120 attempts,
121 base_delay,
122 max_delay,
123 jitter,
124 ..
125 } => Self::Exponential {
126 attempts,
127 base_delay,
128 max_delay,
129 should_retry: Some(f),
130 jitter,
131 },
132 }
133 }
134
135 pub fn max_attempts(&self) -> u32 {
137 match self {
138 Self::Count { attempts, .. }
139 | Self::Linear { attempts, .. }
140 | Self::Exponential { attempts, .. } => *attempts,
141 }
142 }
143
144 pub(crate) fn delay_before_attempt(&self, attempt: u32) -> Duration {
145 match self {
146 Self::Count { .. } => Duration::from_secs(1),
147 Self::Linear { delay, .. } => *delay,
148 Self::Exponential {
149 base_delay,
150 max_delay,
151 ..
152 } => {
153 let exp = base_delay.saturating_mul(2u32.saturating_pow(attempt));
154 exp.min(*max_delay)
155 }
156 }
157 }
158
159 pub(crate) fn delay_after_response(&self, attempt: u32, headers: &HeaderMap) -> Duration {
161 let base = self.delay_before_attempt(attempt);
162 let delay = parse_retry_after(headers).unwrap_or(base);
163 if self.uses_jitter() {
164 apply_jitter(delay)
165 } else {
166 delay
167 }
168 }
169
170 pub(crate) fn uses_jitter(&self) -> bool {
171 match self {
172 Self::Count { .. } => true,
173 Self::Linear { jitter, .. } | Self::Exponential { jitter, .. } => *jitter,
174 }
175 }
176
177 pub(crate) fn should_retry_response(
178 &self,
179 response: &Response,
180 transport_failed: bool,
181 ) -> bool {
182 if transport_failed {
183 return true;
184 }
185
186 let custom = match self {
187 Self::Count { should_retry, .. }
188 | Self::Linear { should_retry, .. }
189 | Self::Exponential { should_retry, .. } => should_retry.as_ref(),
190 };
191
192 if let Some(f) = custom {
193 return f(response);
194 }
195
196 default_should_retry(response.status())
197 }
198}
199
200pub fn default_should_retry(status: StatusCode) -> bool {
202 matches!(status.as_u16(), 408 | 429 | 502 | 503 | 504)
203}
204
205pub fn parse_retry_after(headers: &HeaderMap) -> Option<Duration> {
207 let value = headers.get(http::header::RETRY_AFTER)?.to_str().ok()?;
208 let secs = value.trim().parse::<u64>().ok()?;
209 Some(Duration::from_secs(secs))
210}
211
212fn apply_jitter(delay: Duration) -> Duration {
213 let nanos = delay.as_nanos().min(u128::from(u64::MAX)) as u64;
214 if nanos == 0 {
215 return delay;
216 }
217 let half = nanos / 2;
218 let span = nanos.saturating_sub(half).max(1);
219 Duration::from_nanos(half + fastrand::u64(..span))
220}
221
222pub(crate) use crate::cancel::sleep_or_cancel;
223
#[cfg(test)]
mod tests {
    // Unit tests for `RetryPolicy` and the free helpers in this module.
    use super::*;
    use crate::response::Response;
    use http::StatusCode;

    // Builds a `Response` with the given status and empty headers/body.
    // The trailing `None` is only passed when the `json` feature is enabled,
    // presumably matching a feature-gated parameter of `Response::new`.
    fn response_with_status(status: u16) -> Response {
        Response::new(
            StatusCode::from_u16(status).unwrap(),
            http::HeaderMap::new(),
            bytes::Bytes::new(),
            None,
            #[cfg(feature = "json")]
            None,
        )
    }

    // Spot-checks both retryable and non-retryable default statuses.
    #[test]
    fn default_should_retry_codes() {
        assert!(default_should_retry(StatusCode::REQUEST_TIMEOUT));
        assert!(default_should_retry(StatusCode::TOO_MANY_REQUESTS));
        assert!(default_should_retry(StatusCode::SERVICE_UNAVAILABLE));
        assert!(!default_should_retry(StatusCode::NOT_FOUND));
    }

    #[test]
    fn count_policy_max_attempts() {
        assert_eq!(RetryPolicy::count(3).max_attempts(), 3);
    }

    // Adding a predicate must keep the `Count` variant, and the predicate
    // fully replaces the default: 503 is no longer retried.
    #[test]
    fn count_with_should_retry_stays_count() {
        let policy = RetryPolicy::count(2)
            .with_should_retry(Arc::new(|r| r.status() == StatusCode::NOT_FOUND));
        assert!(matches!(policy, RetryPolicy::Count { .. }));
        assert!(policy.should_retry_response(&response_with_status(404), false));
        assert!(!policy.should_retry_response(&response_with_status(503), false));
    }

    #[test]
    fn linear_delay_is_constant() {
        let policy = RetryPolicy::linear(3, Duration::from_millis(500));
        assert_eq!(policy.delay_before_attempt(0), Duration::from_millis(500));
        assert_eq!(policy.delay_before_attempt(2), Duration::from_millis(500));
    }

    // Attempt 10 would be 1s * 2^10 = 1024s uncapped; `max_delay` clamps it to 5s.
    #[test]
    fn exponential_delay_caps_at_max() {
        let policy = RetryPolicy::exponential(5, Duration::from_secs(1), Duration::from_secs(5));
        assert_eq!(policy.delay_before_attempt(0), Duration::from_secs(1));
        assert_eq!(policy.delay_before_attempt(10), Duration::from_secs(5));
    }

    // Same override check as above, but for the `Linear` variant.
    #[test]
    fn custom_should_retry_overrides_default() {
        let policy = RetryPolicy::linear(2, Duration::from_millis(1))
            .with_should_retry(Arc::new(|r| r.status() == StatusCode::NOT_FOUND));
        assert!(policy.should_retry_response(&response_with_status(404), false));
        assert!(!policy.should_retry_response(&response_with_status(503), false));
    }

    #[test]
    fn parse_retry_after_seconds() {
        let mut headers = HeaderMap::new();
        headers.insert(http::header::RETRY_AFTER, "3".parse().unwrap());
        assert_eq!(parse_retry_after(&headers), Some(Duration::from_secs(3)));
    }

    // Jitter is disabled so the Retry-After value can be compared exactly.
    #[test]
    fn delay_after_response_uses_retry_after() {
        let mut headers = HeaderMap::new();
        headers.insert(http::header::RETRY_AFTER, "2".parse().unwrap());
        let policy = RetryPolicy::linear(1, Duration::from_millis(100)).with_jitter(false);
        assert_eq!(
            policy.delay_after_response(0, &headers),
            Duration::from_secs(2)
        );
    }

    // Jittered values must land in [base / 2, base]; sampled several times
    // because the result is random.
    #[test]
    fn jitter_stays_within_bounds() {
        let base = Duration::from_secs(4);
        for _ in 0..20 {
            let jittered = apply_jitter(base);
            assert!(jittered >= Duration::from_secs(2));
            assert!(jittered <= base);
        }
    }

    #[test]
    fn parse_retry_after_invalid_is_none() {
        let mut headers = HeaderMap::new();
        headers.insert(http::header::RETRY_AFTER, "not-a-number".parse().unwrap());
        assert!(parse_retry_after(&headers).is_none());
    }

    #[test]
    fn exponential_uses_jitter_by_default() {
        let policy = RetryPolicy::exponential(3, Duration::from_secs(1), Duration::from_secs(8));
        assert!(policy.uses_jitter());
    }

    #[test]
    fn linear_jitter_disabled_by_default() {
        assert!(!RetryPolicy::linear(1, Duration::from_secs(1)).uses_jitter());
    }
}