Skip to main content

better_fetch/
retry.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use http::{HeaderMap, StatusCode};
5
6use crate::response::Response;
7
8/// Predicate for whether a response should be retried.
9pub type ShouldRetryFn = Arc<dyn Fn(&Response) -> bool + Send + Sync>;
10
11/// Retry policy configuration.
12///
13/// The `attempts` value is the maximum number of **retries after the initial request**.
14/// For example, `RetryPolicy::count(2)` performs up to three HTTP calls (one initial + two retries).
15#[derive(Clone)]
16pub enum RetryPolicy {
17    /// Shorthand for linear retry with `attempts` retries and a 1 second delay between attempts.
18    Count {
19        attempts: u32,
20        should_retry: Option<ShouldRetryFn>,
21    },
22    Linear {
23        attempts: u32,
24        delay: Duration,
25        should_retry: Option<ShouldRetryFn>,
26        jitter: bool,
27    },
28    Exponential {
29        attempts: u32,
30        base_delay: Duration,
31        max_delay: Duration,
32        should_retry: Option<ShouldRetryFn>,
33        jitter: bool,
34    },
35}
36
37impl RetryPolicy {
38    pub fn count(attempts: u32) -> Self {
39        Self::Count {
40            attempts,
41            should_retry: None,
42        }
43    }
44
45    pub fn linear(attempts: u32, delay: Duration) -> Self {
46        Self::Linear {
47            attempts,
48            delay,
49            should_retry: None,
50            jitter: false,
51        }
52    }
53
54    pub fn exponential(attempts: u32, base_delay: Duration, max_delay: Duration) -> Self {
55        Self::Exponential {
56            attempts,
57            base_delay,
58            max_delay,
59            should_retry: None,
60            jitter: true,
61        }
62    }
63
64    /// Enables randomized backoff jitter on linear or exponential policies.
65    pub fn with_jitter(mut self, jitter: bool) -> Self {
66        match &mut self {
67            Self::Linear { jitter: j, .. } | Self::Exponential { jitter: j, .. } => *j = jitter,
68            Self::Count { .. } => {}
69        }
70        self
71    }
72
73    pub fn with_should_retry(self, f: ShouldRetryFn) -> Self {
74        match self {
75            Self::Count { attempts, .. } => Self::Count {
76                attempts,
77                should_retry: Some(f),
78            },
79            Self::Linear {
80                attempts,
81                delay,
82                jitter,
83                ..
84            } => Self::Linear {
85                attempts,
86                delay,
87                should_retry: Some(f),
88                jitter,
89            },
90            Self::Exponential {
91                attempts,
92                base_delay,
93                max_delay,
94                jitter,
95                ..
96            } => Self::Exponential {
97                attempts,
98                base_delay,
99                max_delay,
100                should_retry: Some(f),
101                jitter,
102            },
103        }
104    }
105
106    pub(crate) fn max_attempts(&self) -> u32 {
107        match self {
108            Self::Count { attempts, .. }
109            | Self::Linear { attempts, .. }
110            | Self::Exponential { attempts, .. } => *attempts,
111        }
112    }
113
114    pub(crate) fn delay_before_attempt(&self, attempt: u32) -> Duration {
115        match self {
116            Self::Count { .. } => Duration::from_secs(1),
117            Self::Linear { delay, .. } => *delay,
118            Self::Exponential {
119                base_delay,
120                max_delay,
121                ..
122            } => {
123                let exp = base_delay.saturating_mul(2u32.saturating_pow(attempt));
124                exp.min(*max_delay)
125            }
126        }
127    }
128
129    /// Computes sleep duration using policy backoff, optional [`Retry-After`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After), and jitter.
130    pub(crate) fn delay_after_response(&self, attempt: u32, headers: &HeaderMap) -> Duration {
131        let base = self.delay_before_attempt(attempt);
132        let delay = parse_retry_after(headers).unwrap_or(base);
133        if self.uses_jitter() {
134            apply_jitter(delay)
135        } else {
136            delay
137        }
138    }
139
140    pub(crate) fn uses_jitter(&self) -> bool {
141        match self {
142            Self::Count { .. } => true,
143            Self::Linear { jitter, .. } | Self::Exponential { jitter, .. } => *jitter,
144        }
145    }
146
147    pub(crate) fn should_retry_response(
148        &self,
149        response: &Response,
150        transport_failed: bool,
151    ) -> bool {
152        if transport_failed {
153            return true;
154        }
155
156        let custom = match self {
157            Self::Count { should_retry, .. }
158            | Self::Linear { should_retry, .. }
159            | Self::Exponential { should_retry, .. } => should_retry.as_ref(),
160        };
161
162        if let Some(f) = custom {
163            return f(response);
164        }
165
166        default_should_retry(response.status())
167    }
168}
169
170pub fn default_should_retry(status: StatusCode) -> bool {
171    matches!(status.as_u16(), 408 | 429 | 502 | 503 | 504)
172}
173
174/// Parses `Retry-After` as a delay in seconds (integer values only).
175pub fn parse_retry_after(headers: &HeaderMap) -> Option<Duration> {
176    let value = headers.get(http::header::RETRY_AFTER)?.to_str().ok()?;
177    let secs = value.trim().parse::<u64>().ok()?;
178    Some(Duration::from_secs(secs))
179}
180
181fn apply_jitter(delay: Duration) -> Duration {
182    let nanos = delay.as_nanos().min(u128::from(u64::MAX)) as u64;
183    if nanos == 0 {
184        return delay;
185    }
186    let half = nanos / 2;
187    let span = nanos.saturating_sub(half).max(1);
188    Duration::from_nanos(half + fastrand::u64(..span))
189}
190
191pub(crate) use crate::cancel::sleep_or_cancel;
192
#[cfg(test)]
mod tests {
    use super::*;
    use crate::response::Response;
    use http::StatusCode;

    /// Builds an empty-bodied response with the given status code and no headers.
    fn response_with_status(status: u16) -> Response {
        Response::new(
            StatusCode::from_u16(status).unwrap(),
            http::HeaderMap::new(),
            bytes::Bytes::new(),
            None,
            #[cfg(feature = "json")]
            None,
        )
    }

    #[test]
    fn default_should_retry_codes() {
        for retryable in [
            StatusCode::REQUEST_TIMEOUT,
            StatusCode::TOO_MANY_REQUESTS,
            StatusCode::BAD_GATEWAY,
            StatusCode::SERVICE_UNAVAILABLE,
            StatusCode::GATEWAY_TIMEOUT,
        ] {
            assert!(default_should_retry(retryable), "{} should be retried", retryable);
        }
        for terminal in [
            StatusCode::OK,
            StatusCode::NOT_FOUND,
            StatusCode::INTERNAL_SERVER_ERROR,
        ] {
            assert!(!default_should_retry(terminal), "{} should not be retried", terminal);
        }
    }

    #[test]
    fn count_policy_max_attempts() {
        assert_eq!(RetryPolicy::count(3).max_attempts(), 3);
    }

    #[test]
    fn count_policy_uses_fixed_one_second_delay_with_jitter() {
        let policy = RetryPolicy::count(3);
        assert_eq!(policy.delay_before_attempt(0), Duration::from_secs(1));
        assert_eq!(policy.delay_before_attempt(7), Duration::from_secs(1));
        assert!(policy.uses_jitter());
        // `with_jitter` is documented as a no-op for `Count`.
        assert!(policy.with_jitter(false).uses_jitter());
    }

    #[test]
    fn count_with_should_retry_stays_count() {
        let policy = RetryPolicy::count(2)
            .with_should_retry(Arc::new(|r| r.status() == StatusCode::NOT_FOUND));
        assert!(matches!(policy, RetryPolicy::Count { .. }));
        assert!(policy.should_retry_response(&response_with_status(404), false));
        assert!(!policy.should_retry_response(&response_with_status(503), false));
    }

    #[test]
    fn transport_failure_always_retries() {
        let policy = RetryPolicy::count(1).with_should_retry(Arc::new(|_: &Response| false));
        assert!(policy.should_retry_response(&response_with_status(200), true));
        assert!(!policy.should_retry_response(&response_with_status(503), false));
    }

    #[test]
    fn linear_delay_is_constant() {
        let policy = RetryPolicy::linear(3, Duration::from_millis(500));
        assert_eq!(policy.delay_before_attempt(0), Duration::from_millis(500));
        assert_eq!(policy.delay_before_attempt(2), Duration::from_millis(500));
    }

    #[test]
    fn exponential_delay_caps_at_max() {
        let policy = RetryPolicy::exponential(5, Duration::from_secs(1), Duration::from_secs(5));
        assert_eq!(policy.delay_before_attempt(0), Duration::from_secs(1));
        assert_eq!(policy.delay_before_attempt(10), Duration::from_secs(5));
    }

    #[test]
    fn exponential_delay_doubles_until_cap() {
        let policy = RetryPolicy::exponential(5, Duration::from_secs(1), Duration::from_secs(5));
        assert_eq!(policy.delay_before_attempt(1), Duration::from_secs(2));
        assert_eq!(policy.delay_before_attempt(2), Duration::from_secs(4));
        assert_eq!(policy.delay_before_attempt(3), Duration::from_secs(5));
        // 2^40 saturates the u32 multiplier instead of overflowing.
        assert_eq!(policy.delay_before_attempt(40), Duration::from_secs(5));
    }

    #[test]
    fn custom_should_retry_overrides_default() {
        let policy = RetryPolicy::linear(2, Duration::from_millis(1))
            .with_should_retry(Arc::new(|r| r.status() == StatusCode::NOT_FOUND));
        assert!(policy.should_retry_response(&response_with_status(404), false));
        assert!(!policy.should_retry_response(&response_with_status(503), false));
    }

    #[test]
    fn parse_retry_after_seconds() {
        let mut headers = HeaderMap::new();
        headers.insert(http::header::RETRY_AFTER, "3".parse().unwrap());
        assert_eq!(parse_retry_after(&headers), Some(Duration::from_secs(3)));
    }

    #[test]
    fn parse_retry_after_missing_or_padded() {
        let mut headers = HeaderMap::new();
        assert!(parse_retry_after(&headers).is_none());
        headers.insert(http::header::RETRY_AFTER, " 5 ".parse().unwrap());
        assert_eq!(parse_retry_after(&headers), Some(Duration::from_secs(5)));
    }

    #[test]
    fn delay_after_response_uses_retry_after() {
        let mut headers = HeaderMap::new();
        headers.insert(http::header::RETRY_AFTER, "2".parse().unwrap());
        let policy = RetryPolicy::linear(1, Duration::from_millis(100)).with_jitter(false);
        assert_eq!(
            policy.delay_after_response(0, &headers),
            Duration::from_secs(2)
        );
    }

    #[test]
    fn delay_after_response_falls_back_to_policy_backoff() {
        let policy = RetryPolicy::linear(1, Duration::from_millis(100)).with_jitter(false);
        assert_eq!(
            policy.delay_after_response(3, &HeaderMap::new()),
            Duration::from_millis(100)
        );
    }

    #[test]
    fn jitter_stays_within_bounds() {
        let base = Duration::from_secs(4);
        for _ in 0..20 {
            let jittered = apply_jitter(base);
            assert!(jittered >= Duration::from_secs(2));
            assert!(jittered <= base);
        }
    }

    #[test]
    fn jitter_of_zero_is_zero() {
        assert_eq!(apply_jitter(Duration::ZERO), Duration::ZERO);
    }

    #[test]
    fn parse_retry_after_invalid_is_none() {
        let mut headers = HeaderMap::new();
        headers.insert(http::header::RETRY_AFTER, "not-a-number".parse().unwrap());
        assert!(parse_retry_after(&headers).is_none());
    }

    #[test]
    fn exponential_uses_jitter_by_default() {
        let policy = RetryPolicy::exponential(3, Duration::from_secs(1), Duration::from_secs(8));
        assert!(policy.uses_jitter());
    }

    #[test]
    fn linear_jitter_disabled_by_default() {
        assert!(!RetryPolicy::linear(1, Duration::from_secs(1)).uses_jitter());
    }
}
295    #[test]
296    fn linear_jitter_disabled_by_default() {
297        assert!(!RetryPolicy::linear(1, Duration::from_secs(1)).uses_jitter());
298    }
299}