1use std::time::Duration;
7
8use crate::error::{CloudError, Result, RetryError};
9
/// Default maximum number of retry attempts after the initial try.
pub const DEFAULT_MAX_RETRIES: usize = 3;

/// Default delay before the first retry attempt.
pub const DEFAULT_INITIAL_BACKOFF: Duration = Duration::from_millis(100);

/// Default upper bound applied to any single backoff delay.
pub const DEFAULT_MAX_BACKOFF: Duration = Duration::from_secs(30);

/// Default factor by which the backoff delay grows on each attempt.
pub const DEFAULT_BACKOFF_MULTIPLIER: f64 = 2.0;
/// Configuration controlling retry behavior: attempt limits, exponential
/// backoff shape, jitter, and the optional circuit breaker.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Maximum number of retries after the initial attempt.
    pub max_retries: usize,
    /// Delay before the first retry; grows by `backoff_multiplier` per attempt.
    pub initial_backoff: Duration,
    /// Hard cap applied to every computed backoff delay.
    pub max_backoff: Duration,
    /// Multiplier applied to the delay on each successive attempt.
    pub backoff_multiplier: f64,
    /// When true, each delay is randomly stretched (up to +50%) to spread
    /// concurrent retriers apart.
    pub jitter: bool,
    /// When true, a `CircuitBreaker` guards the retried operation.
    pub circuit_breaker: bool,
    /// Consecutive failures required to open the circuit breaker.
    pub circuit_breaker_threshold: usize,
    /// How long the breaker stays open before allowing a half-open probe.
    pub circuit_breaker_timeout: Duration,
}
42
43impl Default for RetryConfig {
44 fn default() -> Self {
45 Self {
46 max_retries: DEFAULT_MAX_RETRIES,
47 initial_backoff: DEFAULT_INITIAL_BACKOFF,
48 max_backoff: DEFAULT_MAX_BACKOFF,
49 backoff_multiplier: DEFAULT_BACKOFF_MULTIPLIER,
50 jitter: true,
51 circuit_breaker: true,
52 circuit_breaker_threshold: 5,
53 circuit_breaker_timeout: Duration::from_secs(60),
54 }
55 }
56}
57
58impl RetryConfig {
59 #[must_use]
61 pub fn new() -> Self {
62 Self::default()
63 }
64
65 #[must_use]
67 pub fn with_max_retries(mut self, max_retries: usize) -> Self {
68 self.max_retries = max_retries;
69 self
70 }
71
72 #[must_use]
74 pub fn with_initial_backoff(mut self, duration: Duration) -> Self {
75 self.initial_backoff = duration;
76 self
77 }
78
79 #[must_use]
81 pub fn with_max_backoff(mut self, duration: Duration) -> Self {
82 self.max_backoff = duration;
83 self
84 }
85
86 #[must_use]
88 pub fn with_backoff_multiplier(mut self, multiplier: f64) -> Self {
89 self.backoff_multiplier = multiplier;
90 self
91 }
92
93 #[must_use]
95 pub fn with_jitter(mut self, jitter: bool) -> Self {
96 self.jitter = jitter;
97 self
98 }
99
100 #[must_use]
102 pub fn with_circuit_breaker(mut self, enabled: bool) -> Self {
103 self.circuit_breaker = enabled;
104 self
105 }
106}
107
/// The three classic circuit-breaker states.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitState {
    /// Requests flow normally; failures are being counted.
    Closed,
    /// Requests are rejected until the timeout elapses.
    Open,
    /// A probe request is allowed through; success closes the circuit.
    HalfOpen,
}
118
/// Counts consecutive failures and short-circuits requests once a threshold
/// is crossed, re-probing after a timeout (see `allow_request`).
#[derive(Debug)]
pub struct CircuitBreaker {
    // Current breaker state.
    state: CircuitState,
    // Failures recorded since the count was last reset by a success.
    failure_count: usize,
    // Failure count at which the breaker opens.
    threshold: usize,
    // How long the breaker stays open before allowing a half-open probe.
    timeout: Duration,
    // Instant of the most recent failure; None until the first failure.
    last_failure: Option<std::time::Instant>,
}
133
134impl CircuitBreaker {
135 #[must_use]
137 pub fn new(threshold: usize, timeout: Duration) -> Self {
138 Self {
139 state: CircuitState::Closed,
140 failure_count: 0,
141 threshold,
142 timeout,
143 last_failure: None,
144 }
145 }
146
147 pub fn allow_request(&mut self) -> Result<()> {
149 match self.state {
150 CircuitState::Closed => Ok(()),
151 CircuitState::Open => {
152 if let Some(last_failure) = self.last_failure {
154 if last_failure.elapsed() >= self.timeout {
155 tracing::info!("Circuit breaker transitioning to half-open state");
156 self.state = CircuitState::HalfOpen;
157 Ok(())
158 } else {
159 Err(CloudError::Retry(RetryError::CircuitBreakerOpen {
160 message: "Circuit breaker is open".to_string(),
161 }))
162 }
163 } else {
164 Ok(())
165 }
166 }
167 CircuitState::HalfOpen => Ok(()),
168 }
169 }
170
171 pub fn record_success(&mut self) {
173 match self.state {
174 CircuitState::Closed => {
175 self.failure_count = 0;
176 }
177 CircuitState::HalfOpen => {
178 tracing::info!("Circuit breaker transitioning to closed state");
179 self.state = CircuitState::Closed;
180 self.failure_count = 0;
181 }
182 CircuitState::Open => {}
183 }
184 }
185
186 pub fn record_failure(&mut self) {
188 self.failure_count += 1;
189 self.last_failure = Some(std::time::Instant::now());
190
191 if self.failure_count >= self.threshold && self.state != CircuitState::Open {
192 tracing::warn!(
193 "Circuit breaker opening after {} failures",
194 self.failure_count
195 );
196 self.state = CircuitState::Open;
197 }
198 }
199
200 #[must_use]
202 pub fn state(&self) -> CircuitState {
203 self.state
204 }
205}
206
/// Token-bucket limiter for retry volume: each retry consumes a token, and
/// tokens are replenished over time at a fixed rate.
#[derive(Debug)]
pub struct RetryBudget {
    // Tokens currently available.
    tokens: usize,
    // Bucket capacity; refills never exceed this.
    max_tokens: usize,
    // Tokens added back per second of elapsed time.
    refill_rate: f64,
    // When the bucket was last topped up.
    last_refill: std::time::Instant,
}
219
220impl RetryBudget {
221 #[must_use]
223 pub fn new(max_tokens: usize, refill_rate: f64) -> Self {
224 Self {
225 tokens: max_tokens,
226 max_tokens,
227 refill_rate,
228 last_refill: std::time::Instant::now(),
229 }
230 }
231
232 pub fn try_consume(&mut self) -> Result<()> {
234 self.refill();
235
236 if self.tokens > 0 {
237 self.tokens -= 1;
238 Ok(())
239 } else {
240 Err(CloudError::Retry(RetryError::BudgetExhausted {
241 message: "Retry budget exhausted".to_string(),
242 }))
243 }
244 }
245
246 fn refill(&mut self) {
248 let elapsed = self.last_refill.elapsed();
249 let tokens_to_add = (elapsed.as_secs_f64() * self.refill_rate) as usize;
250
251 if tokens_to_add > 0 {
252 self.tokens = (self.tokens + tokens_to_add).min(self.max_tokens);
253 self.last_refill = std::time::Instant::now();
254 }
255 }
256}
257
/// Produces successive exponential-backoff delays from a `RetryConfig`.
#[derive(Debug)]
pub struct Backoff {
    // Supplies the base delay, multiplier, cap, and jitter flag.
    config: RetryConfig,
    // Zero-based index of the next delay to hand out.
    attempt: usize,
}
266
267impl Backoff {
268 #[must_use]
270 pub fn new(config: RetryConfig) -> Self {
271 Self { config, attempt: 0 }
272 }
273
274 #[must_use]
276 pub fn next(&mut self) -> Duration {
277 let base = self.config.initial_backoff.as_secs_f64().mul_add(
278 self.config.backoff_multiplier.powi(self.attempt as i32),
279 0.0,
280 );
281
282 let backoff = if self.config.jitter {
283 let jitter_factor = 1.0 + (rand() * 0.5);
285 base * jitter_factor
286 } else {
287 base
288 };
289
290 self.attempt += 1;
291
292 Duration::from_secs_f64(backoff.min(self.config.max_backoff.as_secs_f64()))
293 }
294
295 pub fn reset(&mut self) {
297 self.attempt = 0;
298 }
299}
300
/// Returns a pseudo-random value in `[0.0, 1.0]` for backoff jitter.
///
/// Uses Knuth's MMIX 64-bit LCG. The previous version applied 32-bit LCG
/// constants to a u64 and extracted the high 32 bits, so the first call
/// always returned exactly 0.0; it also updated the seed with a separate
/// load/store pair, a non-atomic read-modify-write that could lose updates
/// under concurrent callers.
fn rand() -> f64 {
    use std::sync::atomic::{AtomicU64, Ordering};

    // Knuth MMIX constants: full-period 64-bit generator whose high bits are
    // the best-distributed, so we take them for the output.
    const MULTIPLIER: u64 = 6_364_136_223_846_793_005;
    const INCREMENT: u64 = 1_442_695_040_888_963_407;
    static STATE: AtomicU64 = AtomicU64::new(0x853C_49E6_748F_EA9B);

    let step = |s: u64| s.wrapping_mul(MULTIPLIER).wrapping_add(INCREMENT);

    // `fetch_update` makes the read-modify-write a single atomic operation.
    let prev = STATE
        .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |s| Some(step(s)))
        .expect("update closure always returns Some");

    // Scale the high 32 bits of the new state into [0, 1].
    f64::from((step(prev) >> 32) as u32) / f64::from(u32::MAX)
}
313
314#[must_use]
316pub fn is_retryable(error: &CloudError) -> bool {
317 match error {
318 CloudError::Timeout { .. } => true,
319 CloudError::RateLimitExceeded { .. } => true,
320 CloudError::Http(http_error) => match http_error {
321 crate::error::HttpError::Network { .. } => true,
322 crate::error::HttpError::Status { status, .. } => {
323 matches!(
325 *status,
326 500 | 502 | 503 | 504 | 408 | 429 )
328 }
329 _ => false,
330 },
331 CloudError::S3(s3_error) => match s3_error {
332 crate::error::S3Error::Sdk { .. } => true,
333 _ => false,
334 },
335 CloudError::Azure(azure_error) => match azure_error {
336 crate::error::AzureError::Sdk { .. } => true,
337 _ => false,
338 },
339 CloudError::Gcs(gcs_error) => match gcs_error {
340 crate::error::GcsError::Sdk { .. } => true,
341 _ => false,
342 },
343 CloudError::Io(_) => true,
344 _ => false,
345 }
346}
347
/// Drives an async operation through retries with exponential backoff,
/// an optional circuit breaker, and a retry budget.
#[cfg(feature = "async")]
pub struct RetryExecutor {
    // Retry/backoff settings shared by every `execute` call.
    config: RetryConfig,
    // Present only when `config.circuit_breaker` is enabled.
    circuit_breaker: Option<CircuitBreaker>,
    // Caps total retry volume across `execute` calls on this executor.
    retry_budget: Option<RetryBudget>,
}
358
#[cfg(feature = "async")]
impl RetryExecutor {
    /// Tokens available in a new executor's retry budget.
    /// (Previously unexplained magic numbers crammed onto one line.)
    const BUDGET_MAX_TOKENS: usize = 100;
    /// Tokens returned to the retry budget per second.
    const BUDGET_REFILL_RATE: f64 = 10.0;

    /// Creates an executor from `config`; a circuit breaker is attached only
    /// when enabled in the config, while a retry budget is always attached.
    #[must_use]
    pub fn new(config: RetryConfig) -> Self {
        let circuit_breaker = if config.circuit_breaker {
            Some(CircuitBreaker::new(
                config.circuit_breaker_threshold,
                config.circuit_breaker_timeout,
            ))
        } else {
            None
        };

        Self {
            config,
            circuit_breaker,
            retry_budget: Some(RetryBudget::new(
                Self::BUDGET_MAX_TOKENS,
                Self::BUDGET_REFILL_RATE,
            )),
        }
    }

    /// Runs `operation`, retrying retryable failures up to
    /// `config.max_retries` times with backoff between attempts.
    ///
    /// # Errors
    /// Returns the operation's own error when it is non-retryable, or a
    /// `RetryError` (`MaxRetriesExceeded`, `CircuitBreakerOpen`,
    /// `BudgetExhausted`) from the retry machinery itself.
    pub async fn execute<F, Fut, T>(&mut self, mut operation: F) -> Result<T>
    where
        F: FnMut() -> Fut,
        Fut: std::future::Future<Output = Result<T>>,
    {
        // NOTE: the breaker gates entry once per `execute` call, not once
        // per attempt — preserved from the original behavior.
        if let Some(ref mut cb) = self.circuit_breaker {
            cb.allow_request()?;
        }

        let mut backoff = Backoff::new(self.config.clone());
        let mut attempts = 0;

        loop {
            match operation().await {
                Ok(result) => {
                    if let Some(ref mut cb) = self.circuit_breaker {
                        cb.record_success();
                    }
                    return Ok(result);
                }
                Err(error) => {
                    attempts += 1;

                    if !is_retryable(&error) {
                        tracing::warn!("Non-retryable error: {}", error);
                        return Err(error);
                    }

                    if attempts > self.config.max_retries {
                        tracing::error!("Max retries ({}) exceeded", self.config.max_retries);
                        // Only the final give-up is reported to the breaker.
                        if let Some(ref mut cb) = self.circuit_breaker {
                            cb.record_failure();
                        }
                        return Err(CloudError::Retry(RetryError::MaxRetriesExceeded {
                            attempts,
                        }));
                    }

                    // An exhausted budget aborts even mid-retry sequence.
                    if let Some(ref mut budget) = self.retry_budget {
                        budget.try_consume()?;
                    }

                    let delay = backoff.next();
                    tracing::warn!(
                        "Retry attempt {}/{} after {:?}: {}",
                        attempts,
                        self.config.max_retries,
                        delay,
                        error
                    );

                    tokio::time::sleep(delay).await;
                }
            }
        }
    }
}
446
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_retry_config_builder() {
        let config = RetryConfig::new()
            .with_max_retries(5)
            .with_initial_backoff(Duration::from_millis(50))
            .with_backoff_multiplier(3.0)
            .with_jitter(false);

        assert_eq!(config.max_retries, 5);
        assert_eq!(config.initial_backoff, Duration::from_millis(50));
        assert_eq!(config.backoff_multiplier, 3.0);
        assert!(!config.jitter);
    }

    #[test]
    fn test_circuit_breaker_closed() {
        let mut cb = CircuitBreaker::new(3, Duration::from_secs(60));
        assert_eq!(cb.state, CircuitState::Closed);
        assert!(cb.allow_request().is_ok());
    }

    #[test]
    fn test_circuit_breaker_opens() {
        let mut cb = CircuitBreaker::new(3, Duration::from_secs(60));

        cb.record_failure();
        cb.record_failure();
        cb.record_failure();

        assert_eq!(cb.state, CircuitState::Open);
        assert!(cb.allow_request().is_err());
    }

    #[test]
    fn test_circuit_breaker_half_open() {
        let mut cb = CircuitBreaker::new(3, Duration::from_millis(10));

        cb.record_failure();
        cb.record_failure();
        cb.record_failure();
        assert_eq!(cb.state, CircuitState::Open);

        // Wait past the breaker timeout so the next request half-opens it.
        std::thread::sleep(Duration::from_millis(20));

        assert!(cb.allow_request().is_ok());
        assert_eq!(cb.state, CircuitState::HalfOpen);

        cb.record_success();
        assert_eq!(cb.state, CircuitState::Closed);
    }

    #[test]
    fn test_retry_budget() {
        let mut budget = RetryBudget::new(10, 100.0);

        for _ in 0..10 {
            assert!(budget.try_consume().is_ok());
        }

        assert!(budget.try_consume().is_err());

        // 100 tokens/s * 50 ms = ~5 tokens refilled.
        std::thread::sleep(Duration::from_millis(50));

        assert!(budget.try_consume().is_ok());
    }

    #[test]
    fn test_backoff_exponential() {
        let config = RetryConfig::new()
            .with_initial_backoff(Duration::from_millis(100))
            .with_backoff_multiplier(2.0)
            .with_jitter(false);

        let mut backoff = Backoff::new(config);

        let d1 = backoff.next();
        let d2 = backoff.next();
        let d3 = backoff.next();

        assert!(d1 < d2);
        assert!(d2 < d3);
    }

    #[test]
    fn test_is_retryable() {
        let timeout_error = CloudError::Timeout {
            message: "timeout".to_string(),
        };
        assert!(is_retryable(&timeout_error));

        let rate_limit_error = CloudError::RateLimitExceeded {
            message: "rate limit".to_string(),
        };
        assert!(is_retryable(&rate_limit_error));

        let not_found_error = CloudError::NotFound {
            key: "test".to_string(),
        };
        // Fixed: the source contained `¬_found_error`, an HTML-entity
        // mojibake of `&not_found_error` (`&not` -> `¬`), which does not
        // compile.
        assert!(!is_retryable(&not_found_error));
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn test_retry_executor_success() {
        use std::sync::atomic::{AtomicUsize, Ordering};

        let config = RetryConfig::new().with_max_retries(3);
        let mut executor = RetryExecutor::new(config);

        let attempt = std::sync::Arc::new(AtomicUsize::new(0));
        let attempt_clone = attempt.clone();
        let result = executor
            .execute(|| {
                let attempt = attempt_clone.clone();
                async move {
                    let current = attempt.fetch_add(1, Ordering::SeqCst) + 1;
                    if current < 2 {
                        Err(CloudError::Timeout {
                            message: "timeout".to_string(),
                        })
                    } else {
                        Ok(42)
                    }
                }
            })
            .await;

        assert!(result.is_ok());
        assert_eq!(result.ok(), Some(42));
        assert_eq!(attempt.load(Ordering::SeqCst), 2);
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn test_retry_executor_max_retries() {
        use std::sync::atomic::{AtomicUsize, Ordering};

        let config = RetryConfig::new().with_max_retries(2);
        let mut executor = RetryExecutor::new(config);

        let attempt = std::sync::Arc::new(AtomicUsize::new(0));
        let attempt_clone = attempt.clone();
        let result: Result<i32> = executor
            .execute(|| {
                let attempt = attempt_clone.clone();
                async move {
                    attempt.fetch_add(1, Ordering::SeqCst);
                    Err(CloudError::Timeout {
                        message: "timeout".to_string(),
                    })
                }
            })
            .await;

        assert!(result.is_err());
        // Initial attempt + 2 retries = 3 invocations.
        assert_eq!(attempt.load(Ordering::SeqCst), 3);
    }
}
619}