1use anyhow::Result;
4use std::time::Duration;
5use tokio::time::sleep;
6
/// Exponential-backoff retry policy consumed by [`retry_async`] and [`smart_retry_async`].
#[derive(Clone, Debug)]
pub struct RetryConfig {
    /// Retries allowed after the first attempt (the operation runs at most `max_retries + 1` times).
    pub max_retries: usize,
    /// Base delay the backoff sequence starts from.
    pub initial_delay: Duration,
    /// Upper bound applied to the backed-off delay.
    pub max_delay: Duration,
    /// Factor by which the delay grows for each further retry.
    pub backoff_multiplier: f64,
}
19
20impl RetryConfig {
21 pub fn new() -> Self {
23 Self {
24 max_retries: 2,
25 initial_delay: Duration::from_secs(1),
26 max_delay: Duration::from_secs(30),
27 backoff_multiplier: 2.0,
28 }
29 }
30
31 pub fn with_settings(
33 max_retries: usize,
34 initial_delay: Duration,
35 max_delay: Duration,
36 backoff_multiplier: f64,
37 ) -> Self {
38 Self {
39 max_retries,
40 initial_delay,
41 max_delay,
42 backoff_multiplier,
43 }
44 }
45
46 pub fn no_retry() -> Self {
48 Self {
49 max_retries: 0,
50 initial_delay: Duration::from_secs(0),
51 max_delay: Duration::from_secs(0),
52 backoff_multiplier: 1.0,
53 }
54 }
55
56 fn calculate_delay(&self, attempt: usize) -> Duration {
58 if attempt == 0 {
59 return self.initial_delay;
60 }
61
62 let delay_secs = self.initial_delay.as_secs_f64()
63 * self.backoff_multiplier.powi(attempt as i32);
64
65 let delay = Duration::from_secs_f64(delay_secs);
66
67 if delay > self.max_delay {
69 self.max_delay
70 } else {
71 delay
72 }
73 }
74}
75
76impl Default for RetryConfig {
77 fn default() -> Self {
78 Self::new()
79 }
80}
81
82pub async fn retry_async<F, Fut, T>(config: &RetryConfig, mut f: F) -> Result<T>
84where
85 F: FnMut() -> Fut,
86 Fut: std::future::Future<Output = Result<T>>,
87{
88 let mut last_error = None;
89
90 for attempt in 0..=config.max_retries {
91 match f().await {
92 Ok(result) => {
93 if attempt > 0 {
94 tracing::info!("Retry successful after {} attempt(s)", attempt);
95 }
96 return Ok(result);
97 }
98 Err(e) => {
99 last_error = Some(e);
100
101 if attempt < config.max_retries {
102 let delay = config.calculate_delay(attempt);
103 tracing::warn!(
104 "Attempt {} failed, retrying in {:?}...",
105 attempt + 1,
106 delay
107 );
108 sleep(delay).await;
109 } else {
110 tracing::error!("All {} retry attempts failed", config.max_retries + 1);
111 }
112 }
113 }
114 }
115
116 Err(last_error.unwrap())
118}
119
/// Coarse classification of a failure, used by [`smart_retry_async`] to decide whether to retry.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ErrorType {
    /// The failure may not recur; the operation is worth retrying.
    Transient,
    /// The failure is expected to recur; retrying is abandoned immediately.
    Permanent,
}
128
129pub fn classify_error(error: &anyhow::Error) -> ErrorType {
131 let error_msg = error.to_string().to_lowercase();
132
133 if error_msg.contains("timeout")
135 || error_msg.contains("connection")
136 || error_msg.contains("temporarily unavailable")
137 || error_msg.contains("too many open files")
138 || error_msg.contains("resource temporarily unavailable")
139 {
140 return ErrorType::Transient;
141 }
142
143 if error_msg.contains("file not found")
145 || error_msg.contains("permission denied")
146 || error_msg.contains("invalid")
147 || error_msg.contains("unsupported")
148 || error_msg.contains("corrupted")
149 {
150 return ErrorType::Permanent;
151 }
152
153 ErrorType::Transient
155}
156
157pub async fn smart_retry_async<F, Fut, T>(config: &RetryConfig, mut f: F) -> Result<T>
159where
160 F: FnMut() -> Fut,
161 Fut: std::future::Future<Output = Result<T>>,
162{
163 let mut last_error = None;
164
165 for attempt in 0..=config.max_retries {
166 match f().await {
167 Ok(result) => {
168 if attempt > 0 {
169 tracing::info!("Smart retry successful after {} attempt(s)", attempt);
170 }
171 return Ok(result);
172 }
173 Err(e) => {
174 let error_type = classify_error(&e);
175
176 if error_type == ErrorType::Permanent {
177 tracing::error!("Permanent error detected, not retrying: {}", e);
178 return Err(e);
179 }
180
181 last_error = Some(e);
182
183 if attempt < config.max_retries {
184 let delay = config.calculate_delay(attempt);
185 tracing::warn!(
186 "Transient error (attempt {}), retrying in {:?}...",
187 attempt + 1,
188 delay
189 );
190 sleep(delay).await;
191 } else {
192 tracing::error!(
193 "All {} retry attempts failed for transient error",
194 config.max_retries + 1
195 );
196 }
197 }
198 }
199 }
200
201 Err(last_error.unwrap())
202}
203
#[cfg(test)]
mod tests {
    //! Unit tests for the retry policy, the backoff calculation and both retry drivers.

    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// Policy with millisecond-scale delays so the async tests finish quickly.
    fn fast_config(max_retries: usize) -> RetryConfig {
        RetryConfig::with_settings(
            max_retries,
            Duration::from_millis(10),
            Duration::from_millis(100),
            2.0,
        )
    }

    #[test]
    fn test_retry_config_creation() {
        let RetryConfig {
            max_retries,
            initial_delay,
            backoff_multiplier,
            ..
        } = RetryConfig::new();
        assert_eq!(max_retries, 2);
        assert_eq!(initial_delay, Duration::from_secs(1));
        assert_eq!(backoff_multiplier, 2.0);
    }

    #[test]
    fn test_retry_config_no_retry() {
        assert_eq!(RetryConfig::no_retry().max_retries, 0);
    }

    #[test]
    fn test_calculate_delay() {
        let doubling = RetryConfig::new();
        for (attempt, &secs) in [1u64, 2, 4, 8].iter().enumerate() {
            assert_eq!(doubling.calculate_delay(attempt), Duration::from_secs(secs));
        }

        let capped = RetryConfig::with_settings(
            5,
            Duration::from_secs(1),
            Duration::from_secs(5),
            2.0,
        );
        assert_eq!(capped.calculate_delay(10), Duration::from_secs(5));
    }

    #[tokio::test]
    async fn test_retry_async_success_first_try() {
        let config = RetryConfig::new();
        let calls = Arc::new(AtomicUsize::new(0));

        let outcome: Result<i32> = retry_async(&config, || {
            let calls = Arc::clone(&calls);
            async move {
                calls.fetch_add(1, Ordering::Relaxed);
                Ok::<i32, anyhow::Error>(42)
            }
        })
        .await;

        assert_eq!(outcome.unwrap(), 42);
        assert_eq!(calls.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn test_retry_async_success_after_retries() {
        let config = fast_config(3);
        let calls = Arc::new(AtomicUsize::new(0));

        let outcome = retry_async(&config, || {
            let calls = Arc::clone(&calls);
            async move {
                // Fail the first two calls, succeed on the third.
                if calls.fetch_add(1, Ordering::Relaxed) < 2 {
                    anyhow::bail!("Transient error");
                }
                Ok::<i32, anyhow::Error>(42)
            }
        })
        .await;

        assert_eq!(outcome.unwrap(), 42);
        assert_eq!(calls.load(Ordering::Relaxed), 3);
    }

    #[tokio::test]
    async fn test_retry_async_all_fail() {
        let config = fast_config(2);
        let calls = Arc::new(AtomicUsize::new(0));

        let outcome: Result<i32> = retry_async(&config, || {
            let calls = Arc::clone(&calls);
            async move {
                calls.fetch_add(1, Ordering::Relaxed);
                anyhow::bail!("Always fails")
            }
        })
        .await;

        // One initial attempt plus two retries.
        assert!(outcome.is_err());
        assert_eq!(calls.load(Ordering::Relaxed), 3);
    }

    #[test]
    fn test_classify_error() {
        let cases = [
            ("Connection timeout", ErrorType::Transient),
            ("File not found", ErrorType::Permanent),
            ("Some random error", ErrorType::Transient),
        ];
        for &(message, expected) in cases.iter() {
            assert_eq!(classify_error(&anyhow::anyhow!(message)), expected);
        }
    }

    #[tokio::test]
    async fn test_smart_retry_permanent_error() {
        let config = RetryConfig::new();
        let calls = Arc::new(AtomicUsize::new(0));

        let outcome: Result<i32> = smart_retry_async(&config, || {
            let calls = Arc::clone(&calls);
            async move {
                calls.fetch_add(1, Ordering::Relaxed);
                anyhow::bail!("File not found")
            }
        })
        .await;

        // A permanent error must short-circuit after the very first call.
        assert!(outcome.is_err());
        assert_eq!(calls.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn test_smart_retry_transient_error() {
        let config = fast_config(2);
        let calls = Arc::new(AtomicUsize::new(0));

        let outcome = smart_retry_async(&config, || {
            let calls = Arc::clone(&calls);
            async move {
                // Transient failures on the first two calls, success on the third.
                if calls.fetch_add(1, Ordering::Relaxed) < 2 {
                    anyhow::bail!("Connection timeout");
                }
                Ok::<i32, anyhow::Error>(42)
            }
        })
        .await;

        assert_eq!(outcome.unwrap(), 42);
        assert_eq!(calls.load(Ordering::Relaxed), 3);
    }
}