1use std::time::Duration;
10
11use crate::error::{MktError, Result};
12
/// Upper bound, in seconds, applied to any server-provided retry hint before
/// it is turned into a sleep duration.
const MAX_HINT_SECS: u64 = 120;
16
17#[must_use]
22pub fn retry_after_secs(headers: &reqwest::header::HeaderMap) -> Option<u64> {
23 headers
24 .get(reqwest::header::RETRY_AFTER)?
25 .to_str()
26 .ok()?
27 .trim()
28 .parse()
29 .ok()
30}
31
/// Classifies an operation so the retry loop can decide which failures are
/// safe to retry.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OpKind {
    /// An operation retried on any error the error type reports as transient.
    Read,
    /// An operation retried only on rate limits and connection failures.
    Write,
}
41
/// Tunable limits for the retry loop.
#[derive(Debug, Clone)]
pub struct RetryPolicy {
    /// Total number of attempts, counting the initial call.
    pub max_attempts: u32,
    /// Backoff used before the first retry; doubled for each later retry.
    pub min_delay: Duration,
    /// Cap applied to the computed backoff before jitter is added.
    pub max_delay: Duration,
}
52
53impl RetryPolicy {
54 #[must_use]
56 pub const fn standard() -> Self {
57 Self {
58 max_attempts: 4,
59 min_delay: Duration::from_secs(1),
60 max_delay: Duration::from_secs(30),
61 }
62 }
63
64 #[must_use]
66 pub const fn none() -> Self {
67 Self {
68 max_attempts: 1,
69 min_delay: Duration::ZERO,
70 max_delay: Duration::ZERO,
71 }
72 }
73}
74
75impl Default for RetryPolicy {
76 fn default() -> Self {
77 Self::standard()
78 }
79}
80
81fn is_retryable(kind: OpKind, error: &MktError) -> bool {
83 match kind {
84 OpKind::Read => error.is_transient(),
85 OpKind::Write => match error {
86 MktError::RateLimited { .. } => true,
87 MktError::Http(e) => e.is_connect(),
88 _ => false,
89 },
90 }
91}
92
93fn retry_hint(error: &MktError) -> Option<Duration> {
95 let secs = match error {
96 MktError::RateLimited {
97 retry_after_secs, ..
98 } => Some(*retry_after_secs),
99 MktError::ApiError {
100 retry_after: Some(secs),
101 ..
102 } => Some(*secs),
103 _ => None,
104 }?;
105 Some(Duration::from_secs(secs.min(MAX_HINT_SECS)))
106}
107
/// Stretches `delay` by a pseudo-random 0–20% (in whole-percent steps).
///
/// The sub-second nanoseconds of the wall clock serve as a cheap entropy
/// source; if the clock reads before the Unix epoch, no jitter is added.
/// The sum saturates at [`Duration::MAX`] instead of panicking, matching the
/// saturating backoff arithmetic used by the caller.
fn with_jitter(delay: Duration) -> Duration {
    let nanos = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map_or(0, |d| d.subsec_nanos());
    // `nanos % 21` is in 0..=20, i.e. a jitter factor of 0.00..=0.20.
    let extra = delay.mul_f64(f64::from(nanos % 21) / 100.0);
    // Plain `+` on `Duration` panics on overflow; a huge backoff must not
    // crash the retry loop.
    delay.saturating_add(extra)
}
115
116pub async fn retry<T, F, Fut>(policy: &RetryPolicy, kind: OpKind, mut op: F) -> Result<T>
123where
124 F: FnMut() -> Fut,
125 Fut: Future<Output = Result<T>>,
126{
127 let mut attempt: u32 = 0;
128 loop {
129 attempt += 1;
130 let error = match op().await {
131 Ok(value) => return Ok(value),
132 Err(error) => error,
133 };
134 if attempt >= policy.max_attempts || !is_retryable(kind, &error) {
135 return Err(error);
136 }
137
138 let backoff = policy
139 .min_delay
140 .saturating_mul(2_u32.saturating_pow(attempt - 1))
141 .min(policy.max_delay);
142 let delay = retry_hint(&error).unwrap_or_else(|| with_jitter(backoff));
143 tracing::warn!(
144 attempt,
145 delay_ms = u64::try_from(delay.as_millis()).unwrap_or(u64::MAX),
146 error = %error,
147 "transient provider error; retrying"
148 );
149 tokio::time::sleep(delay).await;
150 }
151}
152
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]

    use std::sync::atomic::{AtomicU32, Ordering};

    use super::*;

    /// A 503 `ApiError` without a retry hint.
    fn transient_error() -> MktError {
        MktError::ApiError {
            provider: "test".into(),
            status: 503,
            message: "unavailable".into(),
            retry_after: None,
        }
    }

    /// A `RateLimited` error hinting a wait of `secs` seconds.
    fn rate_limited(secs: u64) -> MktError {
        MktError::RateLimited {
            provider: "test".into(),
            retry_after_secs: secs,
        }
    }

    /// An error the tests expect never to be retried.
    fn validation_error() -> MktError {
        MktError::ValidationError {
            field: "f".into(),
            message: "bad".into(),
        }
    }

    /// Drives `retry` with an operation that fails its first `failures` calls
    /// with `error_fn()` and then succeeds with the 1-based call number.
    ///
    /// Returns the final result together with how many times the operation
    /// was invoked.
    // The helper is generic over a closure with no `Send` bound; the tests
    // await it directly and never spawn it, so a non-`Send` future is fine.
    #[allow(clippy::future_not_send)]
    async fn run_counting(
        policy: &RetryPolicy,
        kind: OpKind,
        failures: u32,
        error_fn: impl Fn() -> MktError,
    ) -> (Result<u32>, u32) {
        let calls = AtomicU32::new(0);
        let result = retry(policy, kind, || {
            let n = calls.fetch_add(1, Ordering::SeqCst) + 1;
            let error = (n <= failures).then(&error_fn);
            async move { error.map_or_else(|| Ok(n), Err) }
        })
        .await;
        (result, calls.load(Ordering::SeqCst))
    }

    #[tokio::test(start_paused = true)]
    async fn read_retries_transient_until_success() {
        let (result, calls) =
            run_counting(&RetryPolicy::standard(), OpKind::Read, 2, transient_error).await;
        assert_eq!(result.unwrap(), 3);
        assert_eq!(calls, 3);
    }

    #[tokio::test(start_paused = true)]
    async fn exhausted_attempts_return_last_error() {
        let (result, calls) =
            run_counting(&RetryPolicy::standard(), OpKind::Read, 99, transient_error).await;
        assert!(result.unwrap_err().is_transient());
        assert_eq!(calls, 4, "standard policy makes 4 attempts");
    }

    #[tokio::test(start_paused = true)]
    async fn non_transient_errors_never_retry() {
        let (result, calls) =
            run_counting(&RetryPolicy::standard(), OpKind::Read, 99, validation_error).await;
        assert!(matches!(
            result.unwrap_err(),
            MktError::ValidationError { .. }
        ));
        assert_eq!(calls, 1);
    }

    #[tokio::test(start_paused = true)]
    async fn policy_none_makes_a_single_attempt() {
        let (result, calls) =
            run_counting(&RetryPolicy::none(), OpKind::Read, 99, transient_error).await;
        assert!(result.is_err());
        assert_eq!(calls, 1);
    }

    #[tokio::test(start_paused = true)]
    async fn writes_do_not_retry_server_errors() {
        let (result, calls) =
            run_counting(&RetryPolicy::standard(), OpKind::Write, 99, transient_error).await;
        assert!(result.is_err());
        assert_eq!(calls, 1, "a 503 may have executed the write");
    }

    #[tokio::test(start_paused = true)]
    async fn writes_retry_rate_limits() {
        let (result, calls) = run_counting(&RetryPolicy::standard(), OpKind::Write, 1, || {
            rate_limited(7)
        })
        .await;
        assert_eq!(result.unwrap(), 2);
        assert_eq!(calls, 2);
    }

    #[tokio::test(start_paused = true)]
    async fn server_hint_overrides_backoff() {
        let start = tokio::time::Instant::now();
        let (result, _) = run_counting(&RetryPolicy::standard(), OpKind::Read, 1, || {
            rate_limited(7)
        })
        .await;
        assert!(result.is_ok());
        let waited = start.elapsed();
        assert!(
            waited >= Duration::from_secs(7) && waited < Duration::from_secs(8),
            "should sleep the hinted 7s, slept {waited:?}"
        );
    }

    #[tokio::test(start_paused = true)]
    async fn absurd_hints_are_clamped() {
        let start = tokio::time::Instant::now();
        let (result, _) = run_counting(&RetryPolicy::standard(), OpKind::Read, 1, || {
            rate_limited(86_400)
        })
        .await;
        assert!(result.is_ok());
        let waited = start.elapsed();
        // The lower bound proves the clamped hint was actually used; the
        // upper bound alone would also pass if the hint were ignored.
        assert!(
            waited >= Duration::from_secs(MAX_HINT_SECS)
                && waited <= Duration::from_secs(MAX_HINT_SECS + 1),
            "hints are clamped to {MAX_HINT_SECS}s, slept {waited:?}"
        );
    }
}