Skip to main content

mkt_core/http/
retry.rs

1//! Retry with exponential backoff for transient provider errors.
2//!
3//! Providers wrap their HTTP calls in [`retry`]: reads repeat on any
4//! transient failure, writes only when the request provably did not
5//! execute (rate-limited before processing, or the connection was never
6//! established), so a timed-out create cannot duplicate spend. Server
7//! `Retry-After` hints take precedence over the computed backoff.
8
9use std::time::Duration;
10
11use crate::error::{MktError, Result};
12
/// Longest single wait, in seconds, that a server `Retry-After` hint can
/// impose. Longer hints are clamped down to this value so one retry never
/// stalls the CLI for many minutes; persistent rate limiting then surfaces
/// through the documented rate-limit exit code instead.
const MAX_HINT_SECS: u64 = 2 * 60;
16
17/// Parse a `Retry-After` response header into seconds.
18///
19/// Only the delta-seconds form is honored; the HTTP-date form is rare on
20/// ad APIs and falls back to the caller's default.
21#[must_use]
22pub fn retry_after_secs(headers: &reqwest::header::HeaderMap) -> Option<u64> {
23    headers
24        .get(reqwest::header::RETRY_AFTER)?
25        .to_str()
26        .ok()?
27        .trim()
28        .parse()
29        .ok()
30}
31
/// Classifies an operation by how safe it is to repeat after an attempt fails.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OpKind {
    /// Idempotent read: every transient failure may be retried.
    Read,
    /// Mutating call: only failures that occurred before the API could act
    /// on the request (rate limiting, failure to connect) may be retried.
    Write,
}
41
/// Exponential backoff configuration.
#[derive(Debug, Clone)]
pub struct RetryPolicy {
    /// Upper bound on attempts, counting the initial one (`1` disables retrying).
    pub max_attempts: u32,
    /// Wait before the first retry; each subsequent retry doubles it.
    pub min_delay: Duration,
    /// Cap applied to the doubled backoff.
    pub max_delay: Duration,
}

impl RetryPolicy {
    /// Production policy: up to 4 attempts, waiting 1s, 2s and 4s between
    /// them (+ jitter), with the computed backoff capped at 30s.
    #[must_use]
    pub const fn standard() -> Self {
        let min_delay = Duration::from_millis(1_000);
        let max_delay = Duration::from_millis(30_000);
        Self {
            max_attempts: 4,
            min_delay,
            max_delay,
        }
    }

    /// A policy that never retries: exactly one attempt, zero delays.
    /// Intended for tests and latency-sensitive call sites.
    #[must_use]
    pub const fn none() -> Self {
        Self {
            min_delay: Duration::ZERO,
            max_delay: Duration::ZERO,
            max_attempts: 1,
        }
    }
}

impl Default for RetryPolicy {
    /// Equivalent to [`RetryPolicy::standard`].
    fn default() -> Self {
        Self::standard()
    }
}
80
81/// Whether `error` is worth retrying for this kind of operation.
82fn is_retryable(kind: OpKind, error: &MktError) -> bool {
83    match kind {
84        OpKind::Read => error.is_transient(),
85        OpKind::Write => match error {
86            MktError::RateLimited { .. } => true,
87            MktError::Http(e) => e.is_connect(),
88            _ => false,
89        },
90    }
91}
92
93/// The server-suggested wait, when the error carries one.
94fn retry_hint(error: &MktError) -> Option<Duration> {
95    let secs = match error {
96        MktError::RateLimited {
97            retry_after_secs, ..
98        } => Some(*retry_after_secs),
99        MktError::ApiError {
100            retry_after: Some(secs),
101            ..
102        } => Some(*secs),
103        _ => None,
104    }?;
105    Some(Duration::from_secs(secs.min(MAX_HINT_SECS)))
106}
107
/// Stretch `delay` by a pseudo-random 0–20% so clients that failed at the
/// same moment do not retry in lockstep.
///
/// The jitter fraction is derived from the sub-second nanoseconds of the
/// system clock (no RNG dependency); a clock set before the Unix epoch
/// yields no jitter. The sum saturates at `Duration::MAX` rather than
/// panicking. `RetryPolicy`'s delay fields are public, so a caller may
/// configure very large delays (e.g. `Duration::MAX` as "no cap").
fn with_jitter(delay: Duration) -> Duration {
    let nanos = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map_or(0, |since_epoch| since_epoch.subsec_nanos());
    // `nanos % 21` lies in 0..=20, i.e. a 0%..=20% stretch.
    let fraction = f64::from(nanos % 21) / 100.0;
    // `Duration + Duration` panics on overflow; saturate instead.
    delay.saturating_add(delay.mul_f64(fraction))
}
115
116/// Run `op`, retrying per `policy` while failures stay retryable.
117///
118/// # Errors
119///
120/// Returns the last error once attempts are exhausted or the failure is
121/// not retryable for this [`OpKind`].
122pub async fn retry<T, F, Fut>(policy: &RetryPolicy, kind: OpKind, mut op: F) -> Result<T>
123where
124    F: FnMut() -> Fut,
125    Fut: Future<Output = Result<T>>,
126{
127    let mut attempt: u32 = 0;
128    loop {
129        attempt += 1;
130        let error = match op().await {
131            Ok(value) => return Ok(value),
132            Err(error) => error,
133        };
134        if attempt >= policy.max_attempts || !is_retryable(kind, &error) {
135            return Err(error);
136        }
137
138        let backoff = policy
139            .min_delay
140            .saturating_mul(2_u32.saturating_pow(attempt - 1))
141            .min(policy.max_delay);
142        let delay = retry_hint(&error).unwrap_or_else(|| with_jitter(backoff));
143        tracing::warn!(
144            attempt,
145            delay_ms = u64::try_from(delay.as_millis()).unwrap_or(u64::MAX),
146            error = %error,
147            "transient provider error; retrying"
148        );
149        tokio::time::sleep(delay).await;
150    }
151}
152
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]

    use std::sync::atomic::{AtomicU32, Ordering};

    use super::*;

    /// A 503 with no server hint: transient for reads, unsafe to repeat for writes.
    fn transient_error() -> MktError {
        MktError::ApiError {
            provider: "test".into(),
            status: 503,
            message: "unavailable".into(),
            retry_after: None,
        }
    }

    /// A 503 that carries a `Retry-After` hint of `secs` seconds.
    fn hinted_server_error(secs: u64) -> MktError {
        MktError::ApiError {
            provider: "test".into(),
            status: 503,
            message: "unavailable".into(),
            retry_after: Some(secs),
        }
    }

    fn rate_limited(secs: u64) -> MktError {
        MktError::RateLimited {
            provider: "test".into(),
            retry_after_secs: secs,
        }
    }

    fn validation_error() -> MktError {
        MktError::ValidationError {
            field: "f".into(),
            message: "bad".into(),
        }
    }

    /// Drive [`retry`] with an op whose first `failures` calls return
    /// `error_fn()` and whose later calls succeed with their 1-based call
    /// number. Returns the final result and how many times the op ran.
    #[allow(clippy::future_not_send)] // single-threaded test helper
    async fn run_counting(
        policy: &RetryPolicy,
        kind: OpKind,
        failures: u32,
        error_fn: impl Fn() -> MktError,
    ) -> (Result<u32>, u32) {
        let calls = AtomicU32::new(0);
        let result = retry(policy, kind, || {
            let n = calls.fetch_add(1, Ordering::SeqCst) + 1;
            let error = (n <= failures).then(&error_fn);
            async move { error.map_or_else(|| Ok(n), Err) }
        })
        .await;
        (result, calls.load(Ordering::SeqCst))
    }

    #[tokio::test(start_paused = true)]
    async fn read_retries_transient_until_success() {
        let (result, calls) =
            run_counting(&RetryPolicy::standard(), OpKind::Read, 2, transient_error).await;
        assert_eq!(result.unwrap(), 3);
        assert_eq!(calls, 3);
    }

    #[tokio::test(start_paused = true)]
    async fn exhausted_attempts_return_last_error() {
        let (result, calls) =
            run_counting(&RetryPolicy::standard(), OpKind::Read, 99, transient_error).await;
        assert!(result.unwrap_err().is_transient());
        assert_eq!(calls, 4, "standard policy makes 4 attempts");
    }

    #[tokio::test(start_paused = true)]
    async fn non_transient_errors_never_retry() {
        let (result, calls) =
            run_counting(&RetryPolicy::standard(), OpKind::Read, 99, validation_error).await;
        assert!(matches!(
            result.unwrap_err(),
            MktError::ValidationError { .. }
        ));
        assert_eq!(calls, 1);
    }

    #[tokio::test(start_paused = true)]
    async fn policy_none_makes_a_single_attempt() {
        let (result, calls) =
            run_counting(&RetryPolicy::none(), OpKind::Read, 99, transient_error).await;
        assert!(result.is_err());
        assert_eq!(calls, 1);
    }

    #[tokio::test(start_paused = true)]
    async fn zero_max_attempts_still_runs_once() {
        let policy = RetryPolicy {
            max_attempts: 0,
            ..RetryPolicy::standard()
        };
        let (result, calls) = run_counting(&policy, OpKind::Read, 99, transient_error).await;
        assert!(result.is_err());
        assert_eq!(calls, 1);
    }

    #[tokio::test(start_paused = true)]
    async fn backoff_doubles_between_read_attempts() {
        let start = tokio::time::Instant::now();
        let (result, calls) =
            run_counting(&RetryPolicy::standard(), OpKind::Read, 2, transient_error).await;
        assert_eq!(result.unwrap(), 3);
        assert_eq!(calls, 3);
        // 1s then 2s of backoff, each stretched by at most 20% jitter.
        let waited = start.elapsed();
        assert!(
            waited >= Duration::from_secs(3) && waited < Duration::from_secs(4),
            "expected ~3s of backoff, slept {waited:?}"
        );
    }

    #[tokio::test(start_paused = true)]
    async fn writes_do_not_retry_server_errors() {
        let (result, calls) =
            run_counting(&RetryPolicy::standard(), OpKind::Write, 99, transient_error).await;
        assert!(result.is_err());
        assert_eq!(calls, 1, "a 503 may have executed the write");
    }

    #[tokio::test(start_paused = true)]
    async fn writes_ignore_hints_on_server_errors() {
        let (result, calls) = run_counting(&RetryPolicy::standard(), OpKind::Write, 99, || {
            hinted_server_error(1)
        })
        .await;
        assert!(result.is_err());
        assert_eq!(
            calls, 1,
            "a Retry-After hint does not make a possibly-executed write safe"
        );
    }

    #[tokio::test(start_paused = true)]
    async fn writes_retry_rate_limits() {
        let (result, calls) = run_counting(&RetryPolicy::standard(), OpKind::Write, 1, || {
            rate_limited(7)
        })
        .await;
        assert_eq!(result.unwrap(), 2);
        assert_eq!(calls, 2);
    }

    #[tokio::test(start_paused = true)]
    async fn server_hint_overrides_backoff() {
        let start = tokio::time::Instant::now();
        let (result, _) = run_counting(&RetryPolicy::standard(), OpKind::Read, 1, || {
            rate_limited(7)
        })
        .await;
        assert!(result.is_ok());
        let waited = start.elapsed();
        assert!(
            waited >= Duration::from_secs(7) && waited < Duration::from_secs(8),
            "should sleep the hinted 7s, slept {waited:?}"
        );
    }

    #[tokio::test(start_paused = true)]
    async fn api_error_hint_overrides_backoff() {
        let start = tokio::time::Instant::now();
        let (result, _) = run_counting(&RetryPolicy::standard(), OpKind::Read, 1, || {
            hinted_server_error(5)
        })
        .await;
        assert!(result.is_ok());
        let waited = start.elapsed();
        assert!(
            waited >= Duration::from_secs(5) && waited < Duration::from_secs(6),
            "should sleep the hinted 5s, slept {waited:?}"
        );
    }

    #[tokio::test(start_paused = true)]
    async fn absurd_hints_are_clamped() {
        let start = tokio::time::Instant::now();
        let (result, _) = run_counting(&RetryPolicy::standard(), OpKind::Read, 1, || {
            rate_limited(86_400)
        })
        .await;
        assert!(result.is_ok());
        // Both bounds: the lower one proves the hint was honored (clamped),
        // not silently replaced by the 1s computed backoff.
        let waited = start.elapsed();
        assert!(
            waited >= Duration::from_secs(MAX_HINT_SECS)
                && waited <= Duration::from_secs(MAX_HINT_SECS + 1),
            "hints are clamped to {MAX_HINT_SECS}s, slept {waited:?}"
        );
    }
}