1use std::future::Future;
6use std::time::Duration;
7
/// Tuning knobs for [`retry_with_backoff`]: how many times to retry and how long to wait.
///
/// Delays grow as `base_delay_ms * exponential_base^attempt`, are capped at
/// `max_delay_ms`, and are optionally randomized when `jitter` is set.
#[derive(Debug, Clone, PartialEq)]
pub struct RetryPolicy {
    /// Retries allowed after the initial attempt (total attempts = `max_retries + 1`).
    pub max_retries: u32,
    /// Un-jittered delay before the first retry, in milliseconds.
    pub base_delay_ms: u64,
    /// Upper bound on any single delay, in milliseconds.
    pub max_delay_ms: u64,
    /// Factor the delay is multiplied by for each successive retry.
    pub exponential_base: f64,
    /// When `true`, each delay is scaled by a random factor to de-synchronize retries.
    pub jitter: bool,
}

impl Default for RetryPolicy {
    /// Three retries, starting at 1 s and doubling each time, capped at 60 s, with jitter.
    fn default() -> Self {
        Self {
            max_retries: 3,
            base_delay_ms: 1_000,
            max_delay_ms: 60_000,
            exponential_base: 2.0,
            jitter: true,
        }
    }
}
34
35fn calculate_delay(policy: &RetryPolicy, attempt: u32) -> Duration {
37 let mut delay = policy.base_delay_ms as f64 * policy.exponential_base.powi(attempt as i32);
38
39 if policy.jitter {
41 use rand::Rng;
42 let mut rng = rand::thread_rng();
43 delay *= 0.5 + rng.gen::<f64>();
44 }
45
46 let delay_ms = (delay as u64).min(policy.max_delay_ms);
48 Duration::from_millis(delay_ms)
49}
50
51pub async fn retry_with_backoff<T, E, F, Fut>(mut f: F, policy: RetryPolicy) -> Result<T, E>
53where
54 F: FnMut() -> Fut,
55 Fut: Future<Output = Result<T, E>>,
56 E: std::fmt::Debug,
57{
58 let mut last_error: Option<E> = None;
59
60 for attempt in 0..=policy.max_retries {
61 match f().await {
62 Ok(result) => return Ok(result),
63 Err(err) => {
64 last_error = Some(err);
65
66 if attempt < policy.max_retries {
67 let delay = calculate_delay(&policy, attempt);
68 tokio::time::sleep(delay).await;
69 }
70 }
71 }
72 }
73
74 Err(last_error.unwrap())
75}
76
/// HTTP status codes treated as retryable when the caller passes `None`.
const DEFAULT_RETRYABLE_STATUS_CODES: &[u16] = &[429, 500, 502, 503, 504];

/// Lower-case markers of transient network failures, matched case-insensitively.
const NETWORK_ERROR_MARKERS: &[&str] = &[
    "econnrefused",
    "etimedout",
    "enotfound",
    "connection refused",
    "timeout",
    "network error",
];

/// Heuristically decides whether an error message describes a transient, retryable failure.
///
/// The error is retryable if any of these hold:
/// - its text, compared case-insensitively, mentions a known network failure;
/// - its text mentions "rate limit" (also case-insensitive);
/// - it contains HTTP status 429 or any code from `status_codes` as a standalone number.
///
/// `status_codes` defaults to `DEFAULT_RETRYABLE_STATUS_CODES` when `None`. A standalone
/// number means a whole run of digits, so `500` matches "HTTP 500" but not "1500ms".
pub fn is_retryable_error(error: &str, status_codes: Option<&[u16]>) -> bool {
    let codes = status_codes.unwrap_or(DEFAULT_RETRYABLE_STATUS_CODES);
    // Lower-case once rather than once per marker.
    let lowered = error.to_lowercase();

    if NETWORK_ERROR_MARKERS
        .iter()
        .any(|marker| lowered.contains(*marker))
    {
        return true;
    }

    // 429 is always retryable, even when a custom code list omits it.
    if lowered.contains("rate limit") || mentions_status_code(error, 429) {
        return true;
    }

    codes.iter().any(|&code| mentions_status_code(error, code))
}

/// Returns `true` if `code` appears in `text` as a complete run of ASCII digits, so `500`
/// matches "status=500" or "HTTP 500" but not "1500ms" or "request id 45003".
fn mentions_status_code(text: &str, code: u16) -> bool {
    let wanted = code.to_string();
    text.split(|c: char| !c.is_ascii_digit())
        .any(|digits| digits == wanted)
}
114
115pub fn parse_retry_after(header: &str) -> Option<u64> {
117 if let Ok(seconds) = header.parse::<u64>() {
119 return Some(seconds);
120 }
121
122 if let Ok(date) = chrono::DateTime::parse_from_rfc2822(header) {
124 let now = chrono::Utc::now();
125 let diff = date.signed_duration_since(now);
126 if diff.num_seconds() > 0 {
127 return Some(diff.num_seconds() as u64);
128 }
129 }
130
131 None
132}
133
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_is_retryable_error() {
        assert!(is_retryable_error("rate limit exceeded", None));
        assert!(is_retryable_error("status code 429", None));
        assert!(is_retryable_error("connection refused", None));
        assert!(is_retryable_error("ETIMEDOUT", None));
        assert!(!is_retryable_error("invalid request", None));
    }

    #[test]
    fn test_is_retryable_error_custom_codes() {
        let teapot: &[u16] = &[418];
        assert!(is_retryable_error("upstream returned 418", Some(teapot)));
        assert!(!is_retryable_error("upstream returned 503", Some(teapot)));

        // Network failures are retryable regardless of the status-code list.
        let empty: &[u16] = &[];
        assert!(is_retryable_error("network error", Some(empty)));
    }

    #[test]
    fn test_parse_retry_after_seconds() {
        assert_eq!(parse_retry_after("60"), Some(60));
        assert_eq!(parse_retry_after("0"), Some(0));
    }

    #[test]
    fn test_parse_retry_after_rejects_invalid_and_past() {
        assert_eq!(parse_retry_after(""), None);
        assert_eq!(parse_retry_after("soon"), None);
        assert_eq!(parse_retry_after("-5"), None);
        assert_eq!(parse_retry_after("Sun, 06 Nov 1994 08:49:37 GMT"), None);
    }

    #[test]
    fn test_parse_retry_after_future_date() {
        let target = chrono::Utc::now() + chrono::Duration::seconds(120);
        let secs = parse_retry_after(&target.to_rfc2822()).expect("future date should parse");
        assert!((100..=120).contains(&secs), "got {}", secs);
    }

    #[test]
    fn test_calculate_delay() {
        let policy = RetryPolicy {
            jitter: false,
            ..Default::default()
        };

        assert_eq!(calculate_delay(&policy, 0).as_millis(), 1000);
        assert_eq!(calculate_delay(&policy, 1).as_millis(), 2000);
        assert_eq!(calculate_delay(&policy, 2).as_millis(), 4000);
    }

    #[test]
    fn test_calculate_delay_is_capped() {
        let policy = RetryPolicy {
            jitter: false,
            ..Default::default()
        };

        assert_eq!(calculate_delay(&policy, 10).as_millis(), 60_000);
        assert_eq!(calculate_delay(&policy, 1_000).as_millis(), 60_000);
    }

    #[test]
    fn test_calculate_delay_jitter_bounds() {
        let policy = RetryPolicy::default();
        for _ in 0..100 {
            let first = calculate_delay(&policy, 0).as_millis();
            assert!((500..1500).contains(&first), "got {}", first);
            assert!(calculate_delay(&policy, 10).as_millis() <= 60_000);
        }
    }
}