1use std::sync::Arc;
2use std::time::Duration;
3
4use http::{HeaderMap, StatusCode};
5
6use crate::response::Response;
7
8pub type ShouldRetryFn = Arc<dyn Fn(&Response) -> bool + Send + Sync>;
10
11#[derive(Clone)]
16pub enum RetryPolicy {
17 Count {
19 attempts: u32,
20 should_retry: Option<ShouldRetryFn>,
21 },
22 Linear {
23 attempts: u32,
24 delay: Duration,
25 should_retry: Option<ShouldRetryFn>,
26 jitter: bool,
27 },
28 Exponential {
29 attempts: u32,
30 base_delay: Duration,
31 max_delay: Duration,
32 should_retry: Option<ShouldRetryFn>,
33 jitter: bool,
34 },
35}
36
37impl RetryPolicy {
38 pub fn count(attempts: u32) -> Self {
39 Self::Count {
40 attempts,
41 should_retry: None,
42 }
43 }
44
45 pub fn linear(attempts: u32, delay: Duration) -> Self {
46 Self::Linear {
47 attempts,
48 delay,
49 should_retry: None,
50 jitter: false,
51 }
52 }
53
54 pub fn exponential(attempts: u32, base_delay: Duration, max_delay: Duration) -> Self {
55 Self::Exponential {
56 attempts,
57 base_delay,
58 max_delay,
59 should_retry: None,
60 jitter: true,
61 }
62 }
63
64 pub fn with_jitter(mut self, jitter: bool) -> Self {
66 match &mut self {
67 Self::Linear { jitter: j, .. } | Self::Exponential { jitter: j, .. } => *j = jitter,
68 Self::Count { .. } => {}
69 }
70 self
71 }
72
73 pub fn with_should_retry(self, f: ShouldRetryFn) -> Self {
74 match self {
75 Self::Count { attempts, .. } => Self::Count {
76 attempts,
77 should_retry: Some(f),
78 },
79 Self::Linear {
80 attempts,
81 delay,
82 jitter,
83 ..
84 } => Self::Linear {
85 attempts,
86 delay,
87 should_retry: Some(f),
88 jitter,
89 },
90 Self::Exponential {
91 attempts,
92 base_delay,
93 max_delay,
94 jitter,
95 ..
96 } => Self::Exponential {
97 attempts,
98 base_delay,
99 max_delay,
100 should_retry: Some(f),
101 jitter,
102 },
103 }
104 }
105
106 pub(crate) fn max_attempts(&self) -> u32 {
107 match self {
108 Self::Count { attempts, .. }
109 | Self::Linear { attempts, .. }
110 | Self::Exponential { attempts, .. } => *attempts,
111 }
112 }
113
114 pub(crate) fn delay_before_attempt(&self, attempt: u32) -> Duration {
115 match self {
116 Self::Count { .. } => Duration::from_secs(1),
117 Self::Linear { delay, .. } => *delay,
118 Self::Exponential {
119 base_delay,
120 max_delay,
121 ..
122 } => {
123 let exp = base_delay.saturating_mul(2u32.saturating_pow(attempt));
124 exp.min(*max_delay)
125 }
126 }
127 }
128
129 pub(crate) fn delay_after_response(&self, attempt: u32, headers: &HeaderMap) -> Duration {
131 let base = self.delay_before_attempt(attempt);
132 let delay = parse_retry_after(headers).unwrap_or(base);
133 if self.uses_jitter() {
134 apply_jitter(delay)
135 } else {
136 delay
137 }
138 }
139
140 pub(crate) fn uses_jitter(&self) -> bool {
141 match self {
142 Self::Count { .. } => true,
143 Self::Linear { jitter, .. } | Self::Exponential { jitter, .. } => *jitter,
144 }
145 }
146
147 pub(crate) fn should_retry_response(
148 &self,
149 response: &Response,
150 transport_failed: bool,
151 ) -> bool {
152 if transport_failed {
153 return true;
154 }
155
156 let custom = match self {
157 Self::Count { should_retry, .. }
158 | Self::Linear { should_retry, .. }
159 | Self::Exponential { should_retry, .. } => should_retry.as_ref(),
160 };
161
162 if let Some(f) = custom {
163 return f(response);
164 }
165
166 default_should_retry(response.status())
167 }
168}
169
170pub fn default_should_retry(status: StatusCode) -> bool {
171 matches!(status.as_u16(), 408 | 429 | 502 | 503 | 504)
172}
173
174pub fn parse_retry_after(headers: &HeaderMap) -> Option<Duration> {
176 let value = headers.get(http::header::RETRY_AFTER)?.to_str().ok()?;
177 let secs = value.trim().parse::<u64>().ok()?;
178 Some(Duration::from_secs(secs))
179}
180
181fn apply_jitter(delay: Duration) -> Duration {
182 let nanos = delay.as_nanos().min(u128::from(u64::MAX)) as u64;
183 if nanos == 0 {
184 return delay;
185 }
186 let half = nanos / 2;
187 let span = nanos.saturating_sub(half).max(1);
188 Duration::from_nanos(half + fastrand::u64(..span))
189}
190
191pub(crate) use crate::cancel::sleep_or_cancel;
192
#[cfg(test)]
mod tests {
    use super::*;
    use crate::response::Response;
    use http::StatusCode;

    /// Builds a response with the given status, no headers and an empty body.
    fn response_with_status(status: u16) -> Response {
        Response::new(
            StatusCode::from_u16(status).unwrap(),
            http::HeaderMap::new(),
            bytes::Bytes::new(),
            None,
            #[cfg(feature = "json")]
            None,
        )
    }

    #[test]
    fn default_should_retry_codes() {
        assert!(default_should_retry(StatusCode::REQUEST_TIMEOUT));
        assert!(default_should_retry(StatusCode::TOO_MANY_REQUESTS));
        assert!(default_should_retry(StatusCode::BAD_GATEWAY));
        assert!(default_should_retry(StatusCode::SERVICE_UNAVAILABLE));
        assert!(default_should_retry(StatusCode::GATEWAY_TIMEOUT));
        assert!(!default_should_retry(StatusCode::OK));
        assert!(!default_should_retry(StatusCode::NOT_FOUND));
        assert!(!default_should_retry(StatusCode::INTERNAL_SERVER_ERROR));
    }

    #[test]
    fn count_policy_max_attempts() {
        assert_eq!(RetryPolicy::count(3).max_attempts(), 3);
    }

    #[test]
    fn count_delay_is_one_second() {
        let policy = RetryPolicy::count(3);
        assert_eq!(policy.delay_before_attempt(0), Duration::from_secs(1));
        assert_eq!(policy.delay_before_attempt(7), Duration::from_secs(1));
    }

    #[test]
    fn count_with_should_retry_stays_count() {
        let policy = RetryPolicy::count(2)
            .with_should_retry(Arc::new(|r| r.status() == StatusCode::NOT_FOUND));
        assert!(matches!(policy, RetryPolicy::Count { .. }));
        assert!(policy.should_retry_response(&response_with_status(404), false));
        assert!(!policy.should_retry_response(&response_with_status(503), false));
    }

    #[test]
    fn transport_failure_always_retries() {
        let policy = RetryPolicy::linear(2, Duration::from_millis(1))
            .with_should_retry(Arc::new(|_| false));
        assert!(policy.should_retry_response(&response_with_status(200), true));
        assert!(!policy.should_retry_response(&response_with_status(200), false));
    }

    #[test]
    fn linear_delay_is_constant() {
        let policy = RetryPolicy::linear(3, Duration::from_millis(500));
        assert_eq!(policy.delay_before_attempt(0), Duration::from_millis(500));
        assert_eq!(policy.delay_before_attempt(2), Duration::from_millis(500));
    }

    #[test]
    fn exponential_delay_doubles() {
        let policy = RetryPolicy::exponential(5, Duration::from_secs(1), Duration::from_secs(60));
        assert_eq!(policy.delay_before_attempt(1), Duration::from_secs(2));
        assert_eq!(policy.delay_before_attempt(2), Duration::from_secs(4));
        assert_eq!(policy.delay_before_attempt(3), Duration::from_secs(8));
    }

    #[test]
    fn exponential_delay_caps_at_max() {
        let policy = RetryPolicy::exponential(5, Duration::from_secs(1), Duration::from_secs(5));
        assert_eq!(policy.delay_before_attempt(0), Duration::from_secs(1));
        assert_eq!(policy.delay_before_attempt(10), Duration::from_secs(5));
    }

    #[test]
    fn exponential_delay_saturates_instead_of_overflowing() {
        let policy = RetryPolicy::exponential(5, Duration::from_secs(1), Duration::from_secs(5));
        assert_eq!(policy.delay_before_attempt(u32::MAX), Duration::from_secs(5));
    }

    #[test]
    fn custom_should_retry_overrides_default() {
        let policy = RetryPolicy::linear(2, Duration::from_millis(1))
            .with_should_retry(Arc::new(|r| r.status() == StatusCode::NOT_FOUND));
        assert!(policy.should_retry_response(&response_with_status(404), false));
        assert!(!policy.should_retry_response(&response_with_status(503), false));
    }

    #[test]
    fn parse_retry_after_seconds() {
        let mut headers = HeaderMap::new();
        headers.insert(http::header::RETRY_AFTER, "3".parse().unwrap());
        assert_eq!(parse_retry_after(&headers), Some(Duration::from_secs(3)));
    }

    #[test]
    fn parse_retry_after_missing_is_none() {
        assert!(parse_retry_after(&HeaderMap::new()).is_none());
    }

    #[test]
    fn delay_after_response_uses_retry_after() {
        let mut headers = HeaderMap::new();
        headers.insert(http::header::RETRY_AFTER, "2".parse().unwrap());
        let policy = RetryPolicy::linear(1, Duration::from_millis(100)).with_jitter(false);
        assert_eq!(
            policy.delay_after_response(0, &headers),
            Duration::from_secs(2)
        );
    }

    #[test]
    fn delay_after_response_without_header_uses_policy_delay() {
        let policy = RetryPolicy::linear(1, Duration::from_millis(100)).with_jitter(false);
        assert_eq!(
            policy.delay_after_response(0, &HeaderMap::new()),
            Duration::from_millis(100)
        );
    }

    #[test]
    fn jitter_stays_within_bounds() {
        let base = Duration::from_secs(4);
        for _ in 0..20 {
            let jittered = apply_jitter(base);
            assert!(jittered >= Duration::from_secs(2));
            assert!(jittered <= base);
        }
    }

    #[test]
    fn jitter_leaves_zero_unchanged() {
        assert_eq!(apply_jitter(Duration::from_secs(0)), Duration::from_secs(0));
    }

    #[test]
    fn parse_retry_after_invalid_is_none() {
        let mut headers = HeaderMap::new();
        headers.insert(http::header::RETRY_AFTER, "not-a-number".parse().unwrap());
        assert!(parse_retry_after(&headers).is_none());
    }

    #[test]
    fn exponential_uses_jitter_by_default() {
        let policy = RetryPolicy::exponential(3, Duration::from_secs(1), Duration::from_secs(8));
        assert!(policy.uses_jitter());
    }

    #[test]
    fn linear_jitter_disabled_by_default() {
        assert!(!RetryPolicy::linear(1, Duration::from_secs(1)).uses_jitter());
    }

    #[test]
    fn linear_jitter_can_be_enabled() {
        assert!(RetryPolicy::linear(1, Duration::from_secs(1))
            .with_jitter(true)
            .uses_jitter());
    }
}