1use std::time::Duration;
4use tokio::time::{sleep, timeout};
5
6use crate::sources::SourceError;
7
/// Tuning knobs for the retry helpers ([`with_retry`] / [`with_retry_detailed`]).
#[derive(Debug, Clone, Copy)]
pub struct RetryConfig {
    /// Maximum number of attempts, counting the first try.
    pub max_attempts: u32,
    /// Delay before the first retry; base value for the exponential backoff.
    pub initial_delay: Duration,
    /// Cap applied to the exponentially growing backoff delay.
    pub max_delay: Duration,
    /// Factor the backoff grows by per attempt.
    pub backoff_multiplier: f64,
    /// Overall retry budget; also used as the per-attempt timeout.
    pub max_total_time: Duration,
}

/// Defaults: 3 attempts, 1 s initial delay doubling up to 60 s, all within a
/// 120 s total budget.
impl Default for RetryConfig {
    fn default() -> Self {
        RetryConfig {
            max_attempts: 3,
            initial_delay: Duration::from_secs(1),
            max_delay: Duration::from_secs(60),
            backoff_multiplier: 2.0,
            max_total_time: Duration::from_secs(120),
        }
    }
}
34
/// Classification of errors that are worth retrying.
///
/// Produced by [`TransientError::from_reqwest_error`] and
/// [`TransientError::from_source_error`]; each variant maps to a minimum
/// retry delay via [`TransientError::recommended_delay`].
#[derive(Debug, Clone, PartialEq)]
pub enum TransientError {
    /// Connection-level failure (e.g. connect error, refused connection).
    Network,
    /// Rate limited; optionally carries a wait hint in seconds.
    /// NOTE(review): the converters in this file only ever produce `None` —
    /// the `Some` hint is presumably filled in by callers; verify.
    RateLimit(Option<u64>),
    /// Generic 5xx server error (other than the specific variants below).
    ServerError,
    /// HTTP 503 Service Unavailable.
    ServiceUnavailable,
    /// HTTP 504 Gateway Timeout.
    GatewayTimeout,
    /// HTTP 429 Too Many Requests.
    TooManyRequests,
    /// Request or operation timed out.
    Timeout,
}
53
54impl TransientError {
55 pub fn from_reqwest_error(err: &reqwest::Error) -> Option<Self> {
57 if err.is_timeout() {
58 return Some(TransientError::Timeout);
59 }
60 if err.is_connect() {
61 return Some(TransientError::Network);
62 }
63
64 if let Some(status) = err.status() {
65 if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
66 return Some(TransientError::TooManyRequests);
67 }
68
69 if status == reqwest::StatusCode::SERVICE_UNAVAILABLE {
70 return Some(TransientError::ServiceUnavailable);
71 }
72
73 if status == reqwest::StatusCode::GATEWAY_TIMEOUT {
74 return Some(TransientError::GatewayTimeout);
75 }
76
77 if status.is_server_error() {
78 return Some(TransientError::ServerError);
79 }
80 }
81
82 None
83 }
84
85 pub fn from_source_error(err: &SourceError) -> Option<Self> {
87 match err {
88 SourceError::RateLimit => Some(TransientError::RateLimit(None)),
89 SourceError::Network(_) => Some(TransientError::Network),
90 SourceError::Api(msg) => {
91 let msg_lower = msg.to_lowercase();
93 if msg_lower.contains("timeout") {
94 Some(TransientError::Timeout)
95 } else if msg_lower.contains("service unavailable")
96 || msg_lower.contains("temporarily unavailable")
97 {
98 Some(TransientError::ServiceUnavailable)
99 } else {
100 None
101 }
102 }
103 _ => None,
104 }
105 }
106
107 pub fn recommended_delay(&self) -> Duration {
109 match self {
110 TransientError::RateLimit(Some(seconds)) => Duration::from_secs(*seconds + 1),
111 TransientError::RateLimit(None) => Duration::from_secs(61),
112 TransientError::TooManyRequests => Duration::from_secs(61),
113 TransientError::ServiceUnavailable => Duration::from_secs(10),
114 TransientError::GatewayTimeout => Duration::from_secs(5),
115 TransientError::Timeout => Duration::from_secs(2),
116 TransientError::Network => Duration::from_secs(2),
117 TransientError::ServerError => Duration::from_secs(2),
118 }
119 }
120}
121
/// Detailed outcome of [`with_retry_detailed`], distinguishing how the
/// operation concluded instead of collapsing everything into a `Result`.
pub enum RetryResult<T> {
    /// The operation eventually succeeded.
    Success(T),
    /// Retries were exhausted on a transient error; carries the last error,
    /// its transient classification, and the number of attempts made.
    TransientFailure(SourceError, TransientError, u32),
    /// A non-retryable error occurred; no further attempts were made.
    PermanentFailure(SourceError),
}
131
132pub async fn with_retry<T, F, Fut>(config: RetryConfig, operation: F) -> Result<T, SourceError>
143where
144 F: FnMut() -> Fut,
145 Fut: std::future::Future<Output = Result<T, SourceError>>,
146{
147 let mut attempts = 0;
148 let mut total_elapsed = Duration::ZERO;
149 let mut operation = operation;
150
151 loop {
152 attempts += 1;
153
154 match timeout(config.max_total_time, operation()).await {
155 Ok(Ok(result)) => {
156 if attempts > 1 {
158 tracing::info!(
159 "Operation succeeded on attempt {} after {} transient failures",
160 attempts,
161 attempts - 1
162 );
163 }
164 return Ok(result);
165 }
166 Ok(Err(error)) => {
167 if let Some(transient) = TransientError::from_source_error(&error) {
169 let delay = if attempts == 1 {
171 config.initial_delay
172 } else {
173 let exp_delay = config.initial_delay.as_secs_f64()
174 * config.backoff_multiplier.powf(attempts as f64 - 1.0);
175 let delay_secs = exp_delay.min(config.max_delay.as_secs_f64());
176 Duration::from_secs_f64(delay_secs)
177 };
178
179 let delay = std::cmp::max(delay, transient.recommended_delay());
181
182 total_elapsed += delay;
183
184 if attempts >= config.max_attempts || total_elapsed >= config.max_total_time {
185 tracing::warn!(
186 "Operation failed after {} attempts (total elapsed: {:?}): {}",
187 attempts,
188 total_elapsed,
189 error
190 );
191 return Err(error);
192 }
193
194 tracing::debug!(
195 "Transient error on attempt {}: {:?}, retrying in {:?}",
196 attempts,
197 transient,
198 delay
199 );
200
201 sleep(delay).await;
202 continue;
203 } else {
204 return Err(error);
206 }
207 }
208 Err(_) => {
209 let error = SourceError::Network("Operation timed out".to_string());
211 if attempts >= config.max_attempts {
212 return Err(error);
213 }
214
215 let delay = config.initial_delay;
216 total_elapsed += delay;
217
218 tracing::debug!(
219 "Operation timed out, attempt {}/{}",
220 attempts,
221 config.max_attempts
222 );
223 sleep(delay).await;
224 }
225 }
226 }
227}
228
229pub async fn with_retry_detailed<T, F, Fut>(config: RetryConfig, operation: F) -> RetryResult<T>
233where
234 F: FnMut() -> Fut,
235 Fut: std::future::Future<Output = Result<T, SourceError>>,
236{
237 let mut attempts = 0;
238 let mut total_elapsed = Duration::ZERO;
239 let mut operation = operation;
240
241 loop {
242 attempts += 1;
243
244 match timeout(config.max_total_time, operation()).await {
245 Ok(Ok(result)) => {
246 return RetryResult::Success(result);
247 }
248 Ok(Err(error)) => {
249 if let Some(transient) = TransientError::from_source_error(&error) {
250 let delay = if attempts == 1 {
251 config.initial_delay
252 } else {
253 let exp_delay = config.initial_delay.as_secs_f64()
254 * config.backoff_multiplier.powf(attempts as f64 - 1.0);
255 Duration::from_secs_f64(exp_delay.min(config.max_delay.as_secs_f64()))
256 };
257
258 let delay = std::cmp::max(delay, transient.recommended_delay());
259 total_elapsed += delay;
260
261 if attempts >= config.max_attempts || total_elapsed >= config.max_total_time {
262 return RetryResult::TransientFailure(error, transient, attempts);
263 }
264
265 sleep(delay).await;
266 continue;
267 } else {
268 return RetryResult::PermanentFailure(error);
269 }
270 }
271 Err(_) => {
272 let error = SourceError::Network("Operation timed out".to_string());
273 if attempts >= config.max_attempts {
274 return RetryResult::TransientFailure(error, TransientError::Timeout, attempts);
275 }
276
277 let delay = config.initial_delay;
278 total_elapsed += delay;
279 sleep(delay).await;
280 }
281 }
282 }
283}
284
285pub fn api_retry_config() -> RetryConfig {
287 RetryConfig {
288 max_attempts: 5,
289 initial_delay: Duration::from_secs(2),
290 max_delay: Duration::from_secs(120),
291 backoff_multiplier: 2.0,
292 max_total_time: Duration::from_secs(300),
293 }
294}
295
296pub fn strict_rate_limit_retry_config() -> RetryConfig {
298 RetryConfig {
299 max_attempts: 3,
300 initial_delay: Duration::from_secs(2),
301 max_delay: Duration::from_secs(120),
302 backoff_multiplier: 2.0,
303 max_total_time: Duration::from_secs(180),
304 }
305}
306
#[cfg(test)]
mod tests {
    use super::*;
    use std::cell::RefCell;
    use std::rc::Rc;

    /// An operation that succeeds immediately is called exactly once.
    #[tokio::test]
    async fn test_retry_success_first_try() {
        let config = RetryConfig::default();
        let call_count = Rc::new(RefCell::new(0));

        let result = {
            let call_count = call_count.clone();
            with_retry(config, move || {
                let call_count = call_count.clone();
                async move {
                    *call_count.borrow_mut() += 1;
                    Ok("success")
                }
            })
        }
        .await;

        assert_eq!(result.unwrap(), "success");
        assert_eq!(*call_count.borrow(), 1);
    }

    /// Transient (network) failures are retried until the operation succeeds.
    ///
    /// NOTE(review): `Network` errors have a 2 s recommended delay that
    /// overrides the 10 ms configured here, so this test sleeps ~4 s.
    #[tokio::test]
    async fn test_retry_success_after_failures() {
        let config = RetryConfig {
            max_attempts: 4,
            initial_delay: Duration::from_millis(10),
            max_delay: Duration::from_millis(100),
            backoff_multiplier: 2.0,
            max_total_time: Duration::from_secs(10),
        };
        let call_count = Rc::new(RefCell::new(0));

        let result = {
            let call_count = call_count.clone();
            with_retry(config, move || {
                let call_count = call_count.clone();
                async move {
                    *call_count.borrow_mut() += 1;
                    let count = *call_count.borrow();
                    if count < 3 {
                        Err(SourceError::Network("temporary error".to_string()))
                    } else {
                        Ok("success")
                    }
                }
            })
        }
        .await;

        assert_eq!(result.unwrap(), "success");
        assert_eq!(*call_count.borrow(), 3);
    }

    /// Permanent errors are returned immediately without any retries.
    #[tokio::test]
    async fn test_retry_returns_permanent_error() {
        let config = RetryConfig {
            max_attempts: 5,
            initial_delay: Duration::from_millis(10),
            max_delay: Duration::from_millis(50),
            backoff_multiplier: 2.0,
            max_total_time: Duration::from_secs(5),
        };
        let call_count = Rc::new(RefCell::new(0));

        let result: Result<&str, SourceError> = {
            let call_count = call_count.clone();
            with_retry(config, move || {
                let call_count = call_count.clone();
                async move {
                    *call_count.borrow_mut() += 1;
                    Err(SourceError::NotFound("not found".to_string()))
                }
            })
        }
        .await;

        assert!(result.is_err());
        if let Err(e) = result {
            match e {
                SourceError::NotFound(_) => {}
                _ => panic!("Expected NotFound error"),
            }
        }
        assert_eq!(*call_count.borrow(), 1);
    }

    /// Rate-limit and network errors classify as transient; parse errors do not.
    #[test]
    fn test_transient_error_detection() {
        let rate_limit_error = SourceError::RateLimit;
        assert!(TransientError::from_source_error(&rate_limit_error).is_some());

        let network_error = SourceError::Network("connection refused".to_string());
        assert!(TransientError::from_source_error(&network_error).is_some());

        let parse_error = SourceError::Parse("invalid json".to_string());
        assert!(TransientError::from_source_error(&parse_error).is_none());
    }

    /// Spot-check the per-error minimum retry delays.
    #[test]
    fn test_recommended_delay() {
        assert_eq!(
            TransientError::RateLimit(Some(30)).recommended_delay(),
            Duration::from_secs(31)
        );

        assert_eq!(
            TransientError::RateLimit(None).recommended_delay(),
            Duration::from_secs(61)
        );

        assert_eq!(
            TransientError::ServiceUnavailable.recommended_delay(),
            Duration::from_secs(10)
        );

        assert_eq!(
            TransientError::Network.recommended_delay(),
            Duration::from_secs(2)
        );
    }
}