1use std::sync::Arc;
2use std::time::Duration;
3
4use http::StatusCode;
5
6use crate::response::Response;
7
8pub type ShouldRetryFn = Arc<dyn Fn(&Response) -> bool + Send + Sync>;
10
11#[derive(Clone)]
16pub enum RetryPolicy {
17 Count(u32),
19 Linear {
20 attempts: u32,
21 delay: Duration,
22 should_retry: Option<ShouldRetryFn>,
23 },
24 Exponential {
25 attempts: u32,
26 base_delay: Duration,
27 max_delay: Duration,
28 should_retry: Option<ShouldRetryFn>,
29 },
30}
31
32impl RetryPolicy {
33 pub fn count(attempts: u32) -> Self {
34 Self::Count(attempts)
35 }
36
37 pub fn linear(attempts: u32, delay: Duration) -> Self {
38 Self::Linear {
39 attempts,
40 delay,
41 should_retry: None,
42 }
43 }
44
45 pub fn exponential(attempts: u32, base_delay: Duration, max_delay: Duration) -> Self {
46 Self::Exponential {
47 attempts,
48 base_delay,
49 max_delay,
50 should_retry: None,
51 }
52 }
53
54 pub fn with_should_retry(self, f: ShouldRetryFn) -> Self {
55 match self {
56 Self::Count(attempts) => Self::Linear {
57 attempts,
58 delay: Duration::from_secs(1),
59 should_retry: Some(f),
60 },
61 Self::Linear {
62 attempts,
63 delay,
64 should_retry: _,
65 } => Self::Linear {
66 attempts,
67 delay,
68 should_retry: Some(f),
69 },
70 Self::Exponential {
71 attempts,
72 base_delay,
73 max_delay,
74 should_retry: _,
75 } => Self::Exponential {
76 attempts,
77 base_delay,
78 max_delay,
79 should_retry: Some(f),
80 },
81 }
82 }
83
84 pub(crate) fn max_attempts(&self) -> u32 {
85 match self {
86 Self::Count(n)
87 | Self::Linear { attempts: n, .. }
88 | Self::Exponential { attempts: n, .. } => *n,
89 }
90 }
91
92 pub(crate) fn delay_before_attempt(&self, attempt: u32) -> Duration {
93 match self {
94 Self::Count(_) => Duration::from_secs(1),
95 Self::Linear { delay, .. } => *delay,
96 Self::Exponential {
97 base_delay,
98 max_delay,
99 ..
100 } => {
101 let exp = base_delay.saturating_mul(2u32.saturating_pow(attempt));
102 exp.min(*max_delay)
103 }
104 }
105 }
106
107 pub(crate) fn should_retry_response(
108 &self,
109 response: &Response,
110 transport_failed: bool,
111 ) -> bool {
112 if transport_failed {
113 return true;
114 }
115
116 let custom = match self {
117 Self::Linear { should_retry, .. } | Self::Exponential { should_retry, .. } => {
118 should_retry.as_ref()
119 }
120 Self::Count(_) => None,
121 };
122
123 if let Some(f) = custom {
124 return f(response);
125 }
126
127 default_should_retry(response.status())
128 }
129}
130
131pub fn default_should_retry(status: StatusCode) -> bool {
132 matches!(status.as_u16(), 429 | 502 | 503 | 504)
133}
134
135pub(crate) async fn sleep_before_retry(delay: Duration) {
136 tokio::time::sleep(delay).await;
137}
138
#[cfg(test)]
mod tests {
    use super::*;
    use crate::response::Response;
    use http::StatusCode;

    /// Builds a minimal response with the given status and an empty body.
    fn response_with_status(status: u16) -> Response {
        Response::new(
            StatusCode::from_u16(status).unwrap(),
            http::HeaderMap::new(),
            bytes::Bytes::new(),
            None,
            #[cfg(feature = "json")]
            None,
        )
    }

    #[test]
    fn default_should_retry_codes() {
        for code in [429u16, 502, 503, 504] {
            let status = StatusCode::from_u16(code).unwrap();
            assert!(default_should_retry(status), "{code} should be retried");
        }
        for code in [200u16, 400, 404, 500] {
            let status = StatusCode::from_u16(code).unwrap();
            assert!(!default_should_retry(status), "{code} should not be retried");
        }
    }

    #[test]
    fn count_policy_max_attempts() {
        assert_eq!(RetryPolicy::count(3).max_attempts(), 3);
    }

    #[test]
    fn count_policy_uses_one_second_delay_and_default_predicate() {
        let policy = RetryPolicy::count(3);
        assert_eq!(policy.delay_before_attempt(0), Duration::from_secs(1));
        assert_eq!(policy.delay_before_attempt(7), Duration::from_secs(1));
        assert!(policy.should_retry_response(&response_with_status(503), false));
        assert!(!policy.should_retry_response(&response_with_status(404), false));
    }

    #[test]
    fn linear_delay_is_constant() {
        let policy = RetryPolicy::linear(3, Duration::from_millis(500));
        assert_eq!(policy.delay_before_attempt(0), Duration::from_millis(500));
        assert_eq!(policy.delay_before_attempt(2), Duration::from_millis(500));
    }

    #[test]
    fn exponential_delay_caps_at_max() {
        let policy = RetryPolicy::exponential(5, Duration::from_secs(1), Duration::from_secs(5));
        assert_eq!(policy.delay_before_attempt(0), Duration::from_secs(1));
        assert_eq!(policy.delay_before_attempt(10), Duration::from_secs(5));
    }

    #[test]
    fn exponential_delay_doubles_per_attempt() {
        let policy =
            RetryPolicy::exponential(5, Duration::from_millis(100), Duration::from_secs(10));
        assert_eq!(policy.delay_before_attempt(1), Duration::from_millis(200));
        assert_eq!(policy.delay_before_attempt(2), Duration::from_millis(400));
        assert_eq!(policy.delay_before_attempt(3), Duration::from_millis(800));
    }

    #[test]
    fn exponential_delay_saturates_instead_of_overflowing() {
        let policy = RetryPolicy::exponential(5, Duration::from_secs(1), Duration::from_secs(30));
        assert_eq!(policy.delay_before_attempt(u32::MAX), Duration::from_secs(30));
    }

    #[test]
    fn custom_should_retry_overrides_default() {
        let policy = RetryPolicy::linear(2, Duration::from_millis(1))
            .with_should_retry(Arc::new(|r| r.status() == StatusCode::NOT_FOUND));
        assert!(policy.should_retry_response(&response_with_status(404), false));
        assert!(!policy.should_retry_response(&response_with_status(503), false));
    }

    #[test]
    fn count_with_should_retry_preserves_attempts_and_delay() {
        let policy = RetryPolicy::count(4).with_should_retry(Arc::new(|_: &Response| false));
        assert_eq!(policy.max_attempts(), 4);
        assert_eq!(policy.delay_before_attempt(3), Duration::from_secs(1));
        assert!(!policy.should_retry_response(&response_with_status(503), false));
    }

    #[test]
    fn with_should_retry_replaces_previous_predicate() {
        let policy =
            RetryPolicy::exponential(3, Duration::from_millis(10), Duration::from_secs(1))
                .with_should_retry(Arc::new(|_: &Response| false))
                .with_should_retry(Arc::new(|_: &Response| true));
        assert_eq!(policy.max_attempts(), 3);
        assert!(policy.should_retry_response(&response_with_status(200), false));
    }

    #[test]
    fn transport_failure_always_retries() {
        let never = RetryPolicy::linear(2, Duration::from_millis(1))
            .with_should_retry(Arc::new(|_: &Response| false));
        assert!(never.should_retry_response(&response_with_status(200), true));
        assert!(RetryPolicy::count(1).should_retry_response(&response_with_status(404), true));
    }
}