//! Retry policy for HTTP requests: which methods/statuses are retryable,
//! how many attempts are allowed, and how long to back off between them.
1use std::time::Duration;
2
3use reqwest::{Method, StatusCode};
4
/// Policy knobs controlling whether and how a failed HTTP request is retried.
///
/// Built via [`Default`] plus the `with_*` builder methods; consumed through
/// [`RetryConfig::classify_response`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct RetryConfig {
    /// HTTP methods eligible for retry at all; any other method is never retried.
    retryable_methods: Vec<Method>,
    /// Maximum number of retries; once `attempt` reaches this, the request fails.
    max_retries: u32,
    /// Whether a 429 Too Many Requests response may be retried.
    retry_on_429: bool,
    /// Whether a server-supplied Retry-After hint overrides the computed backoff
    /// (only consulted on 429 responses — see `classify_response`).
    respect_retry_after: bool,
    /// Backoff for the first retry; doubles each subsequent attempt.
    base_backoff: Duration,
    /// Upper bound on any single computed backoff interval.
    max_backoff: Duration,
    /// Optional cap on total elapsed time across all attempts; `None` = unlimited.
    total_retry_budget: Option<Duration>,
}
15
/// Verdict produced by [`RetryConfig::classify_response`] for a single response.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum RetryDecision {
    /// Give up and surface the response/error to the caller.
    DoNotRetry,
    /// Retry the request after sleeping for the contained duration.
    RetryAfter(Duration),
}
21
22impl Default for RetryConfig {
23    fn default() -> Self {
24        Self {
25            retryable_methods: vec![Method::GET],
26            max_retries: 3,
27            retry_on_429: false,
28            respect_retry_after: false,
29            base_backoff: Duration::from_millis(50),
30            max_backoff: Duration::from_secs(5),
31            total_retry_budget: None,
32        }
33    }
34}
35
36impl RetryConfig {
37    #[must_use]
38    pub fn with_retryable_methods<I>(mut self, methods: I) -> Self
39    where
40        I: IntoIterator<Item = Method>,
41    {
42        self.retryable_methods = methods.into_iter().collect();
43        self
44    }
45
46    #[must_use]
47    pub fn with_retry_on_429(mut self, retry_on_429: bool) -> Self {
48        self.retry_on_429 = retry_on_429;
49        self
50    }
51
52    #[must_use]
53    pub fn with_respect_retry_after(mut self, respect_retry_after: bool) -> Self {
54        self.respect_retry_after = respect_retry_after;
55        self
56    }
57
58    #[must_use]
59    pub fn with_max_retries(mut self, max_retries: u32) -> Self {
60        self.max_retries = max_retries;
61        self
62    }
63
64    #[must_use]
65    pub fn classify_response(
66        &self,
67        method: &Method,
68        status: StatusCode,
69        attempt: u32,
70        retry_after: Option<Duration>,
71        elapsed: Duration,
72    ) -> RetryDecision {
73        if attempt >= self.max_retries || !self.retryable_methods.iter().any(|item| item == method)
74        {
75            return RetryDecision::DoNotRetry;
76        }
77
78        let wait = if status == StatusCode::TOO_MANY_REQUESTS {
79            if !self.retry_on_429 {
80                return RetryDecision::DoNotRetry;
81            }
82
83            if self.respect_retry_after {
84                retry_after.unwrap_or_else(|| self.backoff(attempt + 1))
85            } else {
86                self.backoff(attempt + 1)
87            }
88        } else if status.is_server_error() {
89            self.backoff(attempt + 1)
90        } else {
91            return RetryDecision::DoNotRetry;
92        };
93
94        let wait = wait.min(self.max_backoff);
95
96        if let Some(total_retry_budget) = self.total_retry_budget {
97            let Some(remaining_budget) = total_retry_budget.checked_sub(elapsed) else {
98                return RetryDecision::DoNotRetry;
99            };
100            if remaining_budget.is_zero() {
101                return RetryDecision::DoNotRetry;
102            }
103            return RetryDecision::RetryAfter(wait.min(remaining_budget));
104        }
105
106        RetryDecision::RetryAfter(wait)
107    }
108
109    fn backoff(&self, attempt: u32) -> Duration {
110        let factor = 1u32
111            .checked_shl(attempt.saturating_sub(1))
112            .unwrap_or(u32::MAX);
113        let millis = self.base_backoff.as_millis();
114        let scaled = millis.saturating_mul(u128::from(factor));
115        let bounded = scaled.min(self.max_backoff.as_millis());
116        Duration::from_millis(u64::try_from(bounded).unwrap_or(u64::MAX))
117    }
118}