1use std::time::Duration;
2
3use reqwest::{Method, StatusCode};
4
5#[derive(Clone, Debug, PartialEq, Eq)]
6pub struct RetryConfig {
7 retryable_methods: Vec<Method>,
8 max_retries: u32,
9 retry_on_429: bool,
10 respect_retry_after: bool,
11 base_backoff: Duration,
12 max_backoff: Duration,
13 total_retry_budget: Option<Duration>,
14}
15
/// Outcome of classifying a response against a `RetryConfig`.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum RetryDecision {
    /// The caller should not retry the request.
    DoNotRetry,
    /// The caller may retry after waiting for the contained delay.
    RetryAfter(Duration),
}
21
22impl Default for RetryConfig {
23 fn default() -> Self {
24 Self {
25 retryable_methods: vec![Method::GET],
26 max_retries: 3,
27 retry_on_429: false,
28 respect_retry_after: false,
29 base_backoff: Duration::from_millis(50),
30 max_backoff: Duration::from_secs(5),
31 total_retry_budget: None,
32 }
33 }
34}
35
36impl RetryConfig {
37 #[must_use]
38 pub fn with_retryable_methods<I>(mut self, methods: I) -> Self
39 where
40 I: IntoIterator<Item = Method>,
41 {
42 self.retryable_methods = methods.into_iter().collect();
43 self
44 }
45
46 #[must_use]
47 pub fn with_retry_on_429(mut self, retry_on_429: bool) -> Self {
48 self.retry_on_429 = retry_on_429;
49 self
50 }
51
52 #[must_use]
53 pub fn with_respect_retry_after(mut self, respect_retry_after: bool) -> Self {
54 self.respect_retry_after = respect_retry_after;
55 self
56 }
57
58 #[must_use]
59 pub fn with_max_retries(mut self, max_retries: u32) -> Self {
60 self.max_retries = max_retries;
61 self
62 }
63
64 #[must_use]
65 pub fn classify_response(
66 &self,
67 method: &Method,
68 status: StatusCode,
69 attempt: u32,
70 retry_after: Option<Duration>,
71 elapsed: Duration,
72 ) -> RetryDecision {
73 if attempt >= self.max_retries || !self.retryable_methods.iter().any(|item| item == method)
74 {
75 return RetryDecision::DoNotRetry;
76 }
77
78 let wait = if status == StatusCode::TOO_MANY_REQUESTS {
79 if !self.retry_on_429 {
80 return RetryDecision::DoNotRetry;
81 }
82
83 if self.respect_retry_after {
84 retry_after.unwrap_or_else(|| self.backoff(attempt + 1))
85 } else {
86 self.backoff(attempt + 1)
87 }
88 } else if status.is_server_error() {
89 self.backoff(attempt + 1)
90 } else {
91 return RetryDecision::DoNotRetry;
92 };
93
94 let wait = wait.min(self.max_backoff);
95
96 if let Some(total_retry_budget) = self.total_retry_budget {
97 let Some(remaining_budget) = total_retry_budget.checked_sub(elapsed) else {
98 return RetryDecision::DoNotRetry;
99 };
100 if remaining_budget.is_zero() {
101 return RetryDecision::DoNotRetry;
102 }
103 return RetryDecision::RetryAfter(wait.min(remaining_budget));
104 }
105
106 RetryDecision::RetryAfter(wait)
107 }
108
109 fn backoff(&self, attempt: u32) -> Duration {
110 let factor = 1u32
111 .checked_shl(attempt.saturating_sub(1))
112 .unwrap_or(u32::MAX);
113 let millis = self.base_backoff.as_millis();
114 let scaled = millis.saturating_mul(u128::from(factor));
115 let bounded = scaled.min(self.max_backoff.as_millis());
116 Duration::from_millis(u64::try_from(bounded).unwrap_or(u64::MAX))
117 }
118}