1use rand::Rng;
6use serde::{Deserialize, Serialize};
7use std::future::Future;
8use std::time::Duration;
9use tokio::time::sleep;
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct RetryConfig {
14 #[serde(default = "default_max_retries")]
16 pub max_retries: u32,
17 #[serde(default = "default_base_delay")]
19 pub base_delay: u64,
20 #[serde(default = "default_max_delay")]
22 pub max_delay: u64,
23 #[serde(default = "default_exponential_backoff")]
25 pub exponential_backoff: bool,
26 #[serde(default = "default_jitter")]
28 pub jitter: f64,
29 #[serde(default = "default_retryable_errors")]
31 pub retryable_errors: Vec<String>,
32 #[serde(default = "default_retryable_status_codes")]
34 pub retryable_status_codes: Vec<u16>,
35}
36
/// Serde default for `RetryConfig::max_retries`.
fn default_max_retries() -> u32 {
    4
}

/// Serde default for `RetryConfig::base_delay`.
fn default_base_delay() -> u64 {
    1_000
}

/// Serde default for `RetryConfig::max_delay`.
fn default_max_delay() -> u64 {
    30_000
}

/// Serde default for `RetryConfig::exponential_backoff`.
fn default_exponential_backoff() -> bool {
    true
}

/// Serde default for `RetryConfig::jitter`.
fn default_jitter() -> f64 {
    0.1
}
52
/// Built-in list of error-message fragments that are treated as retryable.
///
/// Serves as the serde default for `RetryConfig::retryable_errors` and as the
/// fallback `is_retryable_error` uses when the configured list is empty.
fn default_retryable_errors() -> Vec<String> {
    const FRAGMENTS: [&str; 9] = [
        "ECONNRESET",
        "ETIMEDOUT",
        "ENOTFOUND",
        "ECONNREFUSED",
        "ENETUNREACH",
        "overloaded_error",
        "rate_limit_error",
        "api_error",
        "timeout",
    ];

    FRAGMENTS
        .iter()
        .map(|fragment| (*fragment).to_string())
        .collect()
}
66
/// Built-in list of status codes that are treated as retryable.
///
/// Serves as the serde default for `RetryConfig::retryable_status_codes` and
/// as the fallback `is_retryable_error` uses when the configured list is empty.
fn default_retryable_status_codes() -> Vec<u16> {
    const CODES: [u16; 6] = [408, 429, 500, 502, 503, 504];
    CODES.to_vec()
}
70
71impl Default for RetryConfig {
72 fn default() -> Self {
73 DEFAULT_RETRY_CONFIG.clone()
74 }
75}
76
77pub const DEFAULT_RETRY_CONFIG: RetryConfig = RetryConfig {
79 max_retries: 4,
80 base_delay: 1000,
81 max_delay: 30000,
82 exponential_backoff: true,
83 jitter: 0.1,
84 retryable_errors: Vec::new(), retryable_status_codes: Vec::new(), };
87
88pub fn calculate_retry_delay(attempt: u32, config: &RetryConfig) -> u64 {
90 let mut delay = config.base_delay;
91
92 if config.exponential_backoff {
93 delay = config.base_delay * 2u64.pow(attempt);
94 }
95
96 if config.jitter > 0.0 {
98 let jitter_amount = (delay as f64 * config.jitter) as i64;
99 let random_jitter = rand::thread_rng().gen_range(-jitter_amount..=jitter_amount);
100 delay = (delay as i64 + random_jitter).max(0) as u64;
101 }
102
103 delay.min(config.max_delay)
105}
106
107pub fn is_retryable_error(error: &str, status_code: Option<u16>, config: &RetryConfig) -> bool {
109 let retryable_errors = if config.retryable_errors.is_empty() {
110 default_retryable_errors()
111 } else {
112 config.retryable_errors.clone()
113 };
114
115 let retryable_status_codes = if config.retryable_status_codes.is_empty() {
116 default_retryable_status_codes()
117 } else {
118 config.retryable_status_codes.clone()
119 };
120
121 for code in &retryable_errors {
123 if error.contains(code) {
124 return true;
125 }
126 }
127
128 if let Some(status) = status_code {
130 if retryable_status_codes.contains(&status) {
131 return true;
132 }
133 }
134
135 false
136}
137
/// Error returned once a retried operation has definitively failed.
#[derive(Debug, Clone)]
pub struct RetryError<E> {
    /// The error produced by the final attempt.
    pub last_error: E,
    /// Total number of attempts made, including the first one.
    pub attempts: u32,
}

impl<E> std::fmt::Display for RetryError<E>
where
    E: std::fmt::Display,
{
    /// Renders the attempt count followed by the underlying error's message.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            last_error,
            attempts,
        } = self;
        f.write_fmt(format_args!(
            "Failed after {} attempts: {}",
            attempts, last_error
        ))
    }
}

impl<E> std::error::Error for RetryError<E>
where
    E: std::error::Error + 'static,
{
    /// Exposes the final underlying error as the cause of this error.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        let cause: &(dyn std::error::Error + 'static) = &self.last_error;
        Some(cause)
    }
}
162
163pub async fn with_retry<T, E, F, Fut>(
165 operation: F,
166 config: &RetryConfig,
167 is_retryable: impl Fn(&E) -> bool,
168 on_retry: Option<impl Fn(u32, &E, u64)>,
169) -> Result<T, RetryError<E>>
170where
171 F: Fn() -> Fut,
172 Fut: Future<Output = Result<T, E>>,
173{
174 let mut last_error: Option<E> = None;
175
176 for attempt in 0..=config.max_retries {
177 match operation().await {
178 Ok(result) => return Ok(result),
179 Err(error) => {
180 if attempt == config.max_retries {
182 return Err(RetryError {
183 last_error: error,
184 attempts: attempt + 1,
185 });
186 }
187
188 if !is_retryable(&error) {
190 return Err(RetryError {
191 last_error: error,
192 attempts: attempt + 1,
193 });
194 }
195
196 let delay = calculate_retry_delay(attempt, config);
198
199 if let Some(ref callback) = on_retry {
201 callback(attempt + 1, &error, delay);
202 }
203
204 last_error = Some(error);
205
206 sleep(Duration::from_millis(delay)).await;
208 }
209 }
210 }
211
212 Err(RetryError {
213 last_error: last_error.unwrap(),
214 attempts: config.max_retries + 1,
215 })
216}
217
218pub async fn retry<T, E, F, Fut>(operation: F, config: &RetryConfig) -> Result<T, RetryError<E>>
220where
221 F: Fn() -> Fut,
222 Fut: Future<Output = Result<T, E>>,
223 E: std::fmt::Display,
224{
225 with_retry(
226 operation,
227 config,
228 |e| is_retryable_error(&e.to_string(), None, config),
229 None::<fn(u32, &E, u64)>,
230 )
231 .await
232}