Skip to main content

better_fetch/
retry.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use http::StatusCode;
5
6use crate::response::Response;
7
8/// Predicate for whether a response should be retried.
9pub type ShouldRetryFn = Arc<dyn Fn(&Response) -> bool + Send + Sync>;
10
11/// Retry policy configuration.
12///
13/// The `attempts` value is the maximum number of **retries after the initial request**.
14/// For example, `RetryPolicy::count(2)` performs up to three HTTP calls (one initial + two retries).
15#[derive(Clone)]
16pub enum RetryPolicy {
17    /// Shorthand for linear retry with `attempts` retries and a 1 second delay between attempts.
18    Count(u32),
19    Linear {
20        attempts: u32,
21        delay: Duration,
22        should_retry: Option<ShouldRetryFn>,
23    },
24    Exponential {
25        attempts: u32,
26        base_delay: Duration,
27        max_delay: Duration,
28        should_retry: Option<ShouldRetryFn>,
29    },
30}
31
32impl RetryPolicy {
33    pub fn count(attempts: u32) -> Self {
34        Self::Count(attempts)
35    }
36
37    pub fn linear(attempts: u32, delay: Duration) -> Self {
38        Self::Linear {
39            attempts,
40            delay,
41            should_retry: None,
42        }
43    }
44
45    pub fn exponential(attempts: u32, base_delay: Duration, max_delay: Duration) -> Self {
46        Self::Exponential {
47            attempts,
48            base_delay,
49            max_delay,
50            should_retry: None,
51        }
52    }
53
54    pub fn with_should_retry(self, f: ShouldRetryFn) -> Self {
55        match self {
56            Self::Count(attempts) => Self::Linear {
57                attempts,
58                delay: Duration::from_secs(1),
59                should_retry: Some(f),
60            },
61            Self::Linear {
62                attempts,
63                delay,
64                should_retry: _,
65            } => Self::Linear {
66                attempts,
67                delay,
68                should_retry: Some(f),
69            },
70            Self::Exponential {
71                attempts,
72                base_delay,
73                max_delay,
74                should_retry: _,
75            } => Self::Exponential {
76                attempts,
77                base_delay,
78                max_delay,
79                should_retry: Some(f),
80            },
81        }
82    }
83
84    pub(crate) fn max_attempts(&self) -> u32 {
85        match self {
86            Self::Count(n)
87            | Self::Linear { attempts: n, .. }
88            | Self::Exponential { attempts: n, .. } => *n,
89        }
90    }
91
92    pub(crate) fn delay_before_attempt(&self, attempt: u32) -> Duration {
93        match self {
94            Self::Count(_) => Duration::from_secs(1),
95            Self::Linear { delay, .. } => *delay,
96            Self::Exponential {
97                base_delay,
98                max_delay,
99                ..
100            } => {
101                let exp = base_delay.saturating_mul(2u32.saturating_pow(attempt));
102                exp.min(*max_delay)
103            }
104        }
105    }
106
107    pub(crate) fn should_retry_response(
108        &self,
109        response: &Response,
110        transport_failed: bool,
111    ) -> bool {
112        if transport_failed {
113            return true;
114        }
115
116        let custom = match self {
117            Self::Linear { should_retry, .. } | Self::Exponential { should_retry, .. } => {
118                should_retry.as_ref()
119            }
120            Self::Count(_) => None,
121        };
122
123        if let Some(f) = custom {
124            return f(response);
125        }
126
127        default_should_retry(response.status())
128    }
129}
130
131pub fn default_should_retry(status: StatusCode) -> bool {
132    matches!(status.as_u16(), 429 | 502 | 503 | 504)
133}
134
135pub(crate) async fn sleep_before_retry(delay: Duration) {
136    tokio::time::sleep(delay).await;
137}
138
#[cfg(test)]
mod tests {
    use super::*;
    use crate::response::Response;
    use http::StatusCode;

    /// Builds an empty-bodied response carrying `status`.
    fn response_with_status(status: u16) -> Response {
        Response::new(
            StatusCode::from_u16(status).unwrap(),
            http::HeaderMap::new(),
            bytes::Bytes::new(),
            None,
            #[cfg(feature = "json")]
            None,
        )
    }

    #[test]
    fn default_should_retry_codes() {
        assert!(default_should_retry(StatusCode::TOO_MANY_REQUESTS));
        assert!(default_should_retry(StatusCode::BAD_GATEWAY));
        assert!(default_should_retry(StatusCode::SERVICE_UNAVAILABLE));
        assert!(default_should_retry(StatusCode::GATEWAY_TIMEOUT));
        assert!(!default_should_retry(StatusCode::NOT_FOUND));
        assert!(!default_should_retry(StatusCode::INTERNAL_SERVER_ERROR));
        assert!(!default_should_retry(StatusCode::OK));
    }

    #[test]
    fn count_policy_max_attempts() {
        assert_eq!(RetryPolicy::count(3).max_attempts(), 3);
    }

    #[test]
    fn count_policy_uses_one_second_delay() {
        let policy = RetryPolicy::count(3);
        assert_eq!(policy.delay_before_attempt(0), Duration::from_secs(1));
        assert_eq!(policy.delay_before_attempt(2), Duration::from_secs(1));
    }

    #[test]
    fn linear_delay_is_constant() {
        let policy = RetryPolicy::linear(3, Duration::from_millis(500));
        assert_eq!(policy.delay_before_attempt(0), Duration::from_millis(500));
        assert_eq!(policy.delay_before_attempt(2), Duration::from_millis(500));
    }

    #[test]
    fn exponential_delay_caps_at_max() {
        let policy = RetryPolicy::exponential(5, Duration::from_secs(1), Duration::from_secs(5));
        assert_eq!(policy.delay_before_attempt(0), Duration::from_secs(1));
        assert_eq!(policy.delay_before_attempt(10), Duration::from_secs(5));
    }

    #[test]
    fn exponential_delay_doubles_until_cap() {
        let policy = RetryPolicy::exponential(5, Duration::from_secs(1), Duration::from_secs(5));
        assert_eq!(policy.delay_before_attempt(1), Duration::from_secs(2));
        assert_eq!(policy.delay_before_attempt(2), Duration::from_secs(4));
        assert_eq!(policy.delay_before_attempt(3), Duration::from_secs(5));
        // Huge attempt numbers saturate instead of overflowing.
        assert_eq!(
            policy.delay_before_attempt(u32::MAX),
            Duration::from_secs(5)
        );
    }

    #[test]
    fn custom_should_retry_overrides_default() {
        let policy = RetryPolicy::linear(2, Duration::from_millis(1))
            .with_should_retry(Arc::new(|r| r.status() == StatusCode::NOT_FOUND));
        assert!(policy.should_retry_response(&response_with_status(404), false));
        assert!(!policy.should_retry_response(&response_with_status(503), false));
    }

    #[test]
    fn with_should_retry_turns_count_into_linear() {
        let policy = RetryPolicy::count(2).with_should_retry(Arc::new(|_: &Response| false));
        assert!(matches!(
            policy,
            RetryPolicy::Linear {
                attempts: 2,
                should_retry: Some(_),
                ..
            }
        ));
        assert_eq!(policy.delay_before_attempt(0), Duration::from_secs(1));
        assert!(!policy.should_retry_response(&response_with_status(503), false));
    }

    #[test]
    fn transport_failure_always_retries() {
        let never = RetryPolicy::exponential(3, Duration::from_millis(1), Duration::from_millis(4))
            .with_should_retry(Arc::new(|_: &Response| false));
        assert!(never.should_retry_response(&response_with_status(200), true));
        assert!(RetryPolicy::count(1).should_retry_response(&response_with_status(200), true));
    }
}