1use anyhow::{anyhow, Result};
13use rand::Rng;
14use serde::{Deserialize, Serialize};
15use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
16use std::sync::Arc;
17use std::time::Duration;
18
/// Settings for retry backoff and the optional circuit breaker.
///
/// Every field has a serde default, so a partial document deserializes with
/// the omitted fields filled in by the matching `default_*` function.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetryConfig {
    /// Number of retries after the first attempt; `RetryPolicy::execute`
    /// iterates `0..=max_retries`, i.e. up to `max_retries + 1` attempts.
    #[serde(default = "default_max_retries")]
    pub max_retries: u32,

    /// Base backoff delay in milliseconds, doubled per attempt before jitter.
    #[serde(default = "default_base_delay_ms")]
    pub base_delay_ms: u64,

    /// Upper bound, in milliseconds, applied to the backoff delay.
    #[serde(default = "default_max_delay_ms")]
    pub max_delay_ms: u64,

    /// When `true`, `RetryPolicy::new` builds a circuit breaker from the two
    /// fields below; when `false` it builds none.
    #[serde(default = "default_circuit_breaker_enabled")]
    pub circuit_breaker_enabled: bool,

    /// Consecutive-failure count at which the circuit breaker opens.
    #[serde(default = "default_circuit_breaker_threshold")]
    pub circuit_breaker_threshold: u32,

    /// Milliseconds the breaker stays open before it reports half-open.
    #[serde(default = "default_circuit_breaker_cooldown_ms")]
    pub circuit_breaker_cooldown_ms: u64,
}
46
/// Serde default for `RetryConfig::max_retries`.
fn default_max_retries() -> u32 {
    5
}
/// Serde default for `RetryConfig::base_delay_ms` (milliseconds).
fn default_base_delay_ms() -> u64 {
    100
}
/// Serde default for `RetryConfig::max_delay_ms` (milliseconds, i.e. 30 s).
fn default_max_delay_ms() -> u64 {
    30_000
}
/// Serde default for `RetryConfig::circuit_breaker_enabled`.
fn default_circuit_breaker_enabled() -> bool {
    true
}
/// Serde default for `RetryConfig::circuit_breaker_threshold`.
fn default_circuit_breaker_threshold() -> u32 {
    10
}
/// Serde default for `RetryConfig::circuit_breaker_cooldown_ms` (milliseconds, i.e. 60 s).
fn default_circuit_breaker_cooldown_ms() -> u64 {
    60_000
}
65
66impl Default for RetryConfig {
67 fn default() -> Self {
68 Self {
69 max_retries: default_max_retries(),
70 base_delay_ms: default_base_delay_ms(),
71 max_delay_ms: default_max_delay_ms(),
72 circuit_breaker_enabled: default_circuit_breaker_enabled(),
73 circuit_breaker_threshold: default_circuit_breaker_threshold(),
74 circuit_breaker_cooldown_ms: default_circuit_breaker_cooldown_ms(),
75 }
76 }
77}
78
/// Coarse error category produced by `classify_error` from an error's text.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ErrorKind {
    /// Server-side (5xx), network, timeout or throttling failure; retryable.
    Transient,
    /// HTTP 400 / "bad request"; not retryable.
    ClientError,
    /// HTTP 401/403 or a credential/permission failure; not retryable.
    AuthError,
    /// HTTP 404 / missing object or key; not retryable.
    NotFound,
    /// No known pattern matched; `is_retryable` treats this as retryable.
    Unknown,
}
93
94pub fn classify_error(error: &anyhow::Error) -> ErrorKind {
96 let error_str = error.to_string().to_lowercase();
97
98 if error_str.contains("500")
100 || error_str.contains("502")
101 || error_str.contains("503")
102 || error_str.contains("504")
103 || error_str.contains("internal server error")
104 || error_str.contains("bad gateway")
105 || error_str.contains("service unavailable")
106 || error_str.contains("gateway timeout")
107 {
108 return ErrorKind::Transient;
109 }
110
111 if error_str.contains("timeout")
113 || error_str.contains("timed out")
114 || error_str.contains("connection")
115 || error_str.contains("network")
116 || error_str.contains("socket")
117 || error_str.contains("reset")
118 || error_str.contains("broken pipe")
119 || error_str.contains("eof")
120 || error_str.contains("temporarily unavailable")
121 {
122 return ErrorKind::Transient;
123 }
124
125 if error_str.contains("throttl")
127 || error_str.contains("slowdown")
128 || error_str.contains("reduce your request rate")
129 || error_str.contains("request rate exceeded")
130 {
131 return ErrorKind::Transient;
132 }
133
134 if error_str.contains("dispatch failure") {
136 return ErrorKind::Transient;
137 }
138
139 if error_str.contains("service unavailable (injected)") {
141 return ErrorKind::Transient;
142 }
143
144 if error_str.contains("400") || error_str.contains("bad request") {
146 return ErrorKind::ClientError;
147 }
148
149 if error_str.contains("401")
151 || error_str.contains("403")
152 || error_str.contains("unauthorized")
153 || error_str.contains("forbidden")
154 || error_str.contains("access denied")
155 || error_str.contains("invalid credentials")
156 || error_str.contains("expired token")
157 {
158 return ErrorKind::AuthError;
159 }
160
161 if error_str.contains("404")
163 || error_str.contains("not found")
164 || error_str.contains("no such key")
165 {
166 return ErrorKind::NotFound;
167 }
168
169 ErrorKind::Unknown
170}
171
172pub fn is_retryable(error: &anyhow::Error) -> bool {
174 matches!(
175 classify_error(error),
176 ErrorKind::Transient | ErrorKind::Unknown
177 )
178}
179
/// State reported by `CircuitBreaker::state`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitState {
    /// Below the failure threshold (or never opened); requests are allowed.
    Closed,
    /// Threshold reached and the cooldown has not elapsed; requests are refused.
    Open,
    /// Threshold reached and the cooldown has elapsed; requests are allowed again.
    HalfOpen,
}
190
/// Callback invoked by `CircuitBreaker::record_failure` when the breaker
/// opens; the argument is the consecutive-failure count at that moment.
pub type OnCircuitOpen = Arc<dyn Fn(u32) + Send + Sync>;
193
/// Consecutive-failure circuit breaker backed by atomics, so it can be
/// updated through a shared reference.
pub struct CircuitBreaker {
    /// Failures recorded since the last success.
    consecutive_failures: AtomicU32,
    /// Failure count at which the breaker opens.
    threshold: u32,
    /// Unix-epoch milliseconds at which the breaker opened; `0` means "not open".
    opened_at_ms: AtomicU64,
    /// Milliseconds after opening before the breaker reports half-open.
    cooldown_ms: u64,
    /// Optional callback run when the breaker opens.
    on_open: Option<OnCircuitOpen>,
}
204
205impl std::fmt::Debug for CircuitBreaker {
206 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
207 f.debug_struct("CircuitBreaker")
208 .field("consecutive_failures", &self.consecutive_failures)
209 .field("threshold", &self.threshold)
210 .field("opened_at_ms", &self.opened_at_ms)
211 .field("cooldown_ms", &self.cooldown_ms)
212 .field("on_open", &self.on_open.as_ref().map(|_| "..."))
213 .finish()
214 }
215}
216
217impl CircuitBreaker {
218 pub fn new(threshold: u32, cooldown_ms: u64) -> Self {
220 Self {
221 consecutive_failures: AtomicU32::new(0),
222 threshold,
223 opened_at_ms: AtomicU64::new(0),
224 cooldown_ms,
225 on_open: None,
226 }
227 }
228
229 pub fn with_on_open(threshold: u32, cooldown_ms: u64, on_open: OnCircuitOpen) -> Self {
231 Self {
232 consecutive_failures: AtomicU32::new(0),
233 threshold,
234 opened_at_ms: AtomicU64::new(0),
235 cooldown_ms,
236 on_open: Some(on_open),
237 }
238 }
239
240 pub fn state(&self) -> CircuitState {
242 let failures = self.consecutive_failures.load(Ordering::Relaxed);
243 let opened_at = self.opened_at_ms.load(Ordering::Relaxed);
244
245 if failures < self.threshold {
246 return CircuitState::Closed;
247 }
248
249 if opened_at == 0 {
250 return CircuitState::Closed;
251 }
252
253 let now_ms = std::time::SystemTime::now()
254 .duration_since(std::time::UNIX_EPOCH)
255 .unwrap_or_default()
256 .as_millis() as u64;
257
258 if now_ms - opened_at >= self.cooldown_ms {
259 CircuitState::HalfOpen
260 } else {
261 CircuitState::Open
262 }
263 }
264
265 pub fn record_success(&self) {
267 self.consecutive_failures.store(0, Ordering::Relaxed);
268 self.opened_at_ms.store(0, Ordering::Relaxed);
269 }
270
271 pub fn record_failure(&self) {
273 let failures = self.consecutive_failures.fetch_add(1, Ordering::Relaxed) + 1;
274
275 if failures >= self.threshold && self.opened_at_ms.load(Ordering::Relaxed) == 0 {
276 let now_ms = std::time::SystemTime::now()
277 .duration_since(std::time::UNIX_EPOCH)
278 .unwrap_or_default()
279 .as_millis() as u64;
280 self.opened_at_ms.store(now_ms, Ordering::Relaxed);
281 tracing::warn!(
282 "Circuit breaker opened after {} consecutive failures",
283 failures
284 );
285 if let Some(ref callback) = self.on_open {
286 callback(failures);
287 }
288 }
289 }
290
291 pub fn should_allow(&self) -> bool {
293 match self.state() {
294 CircuitState::Closed => true,
295 CircuitState::HalfOpen => true,
296 CircuitState::Open => false,
297 }
298 }
299
300 pub fn consecutive_failures(&self) -> u32 {
302 self.consecutive_failures.load(Ordering::Relaxed)
303 }
304}
305
/// Retry executor: a `RetryConfig` plus an optional circuit breaker.
///
/// The breaker is held in an `Arc`, so clones of a policy share one breaker.
#[derive(Debug, Clone)]
pub struct RetryPolicy {
    /// Backoff and retry-count settings.
    config: RetryConfig,
    /// Breaker consulted by `execute`; `None` disables circuit breaking.
    circuit_breaker: Option<Arc<CircuitBreaker>>,
}
312
313impl RetryPolicy {
314 pub fn new(config: RetryConfig) -> Self {
316 let circuit_breaker = if config.circuit_breaker_enabled {
317 Some(Arc::new(CircuitBreaker::new(
318 config.circuit_breaker_threshold,
319 config.circuit_breaker_cooldown_ms,
320 )))
321 } else {
322 None
323 };
324
325 Self {
326 config,
327 circuit_breaker,
328 }
329 }
330
331 pub fn with_circuit_breaker(config: RetryConfig, cb: Arc<CircuitBreaker>) -> Self {
333 Self {
334 config,
335 circuit_breaker: Some(cb),
336 }
337 }
338
339 pub fn default_policy() -> Self {
341 Self::new(RetryConfig::default())
342 }
343
344 pub fn config(&self) -> &RetryConfig {
346 &self.config
347 }
348
349 pub fn circuit_breaker(&self) -> Option<&Arc<CircuitBreaker>> {
351 self.circuit_breaker.as_ref()
352 }
353
354 pub fn calculate_delay(&self, attempt: u32) -> Duration {
358 let base = self.config.base_delay_ms;
359 let cap = self.config.max_delay_ms;
360
361 let exp_delay = base.saturating_mul(1u64 << attempt.min(20));
362 let capped_delay = exp_delay.min(cap);
363
364 let jittered = if capped_delay > 0 {
365 rand::thread_rng().gen_range(0..=capped_delay)
366 } else {
367 0
368 };
369
370 Duration::from_millis(jittered)
371 }
372
373 pub async fn execute<F, Fut, T>(&self, operation: F) -> Result<T>
375 where
376 F: Fn() -> Fut,
377 Fut: std::future::Future<Output = Result<T>>,
378 {
379 if let Some(cb) = &self.circuit_breaker {
380 if !cb.should_allow() {
381 return Err(anyhow!(
382 "Circuit breaker open — refusing request after {} consecutive failures",
383 self.config.circuit_breaker_threshold
384 ));
385 }
386 }
387
388 let mut last_error: Option<anyhow::Error> = None;
389
390 for attempt in 0..=self.config.max_retries {
391 match operation().await {
392 Ok(result) => {
393 if let Some(cb) = &self.circuit_breaker {
394 cb.record_success();
395 }
396 return Ok(result);
397 }
398 Err(e) => {
399 let error_kind = classify_error(&e);
400 let retryable =
401 matches!(error_kind, ErrorKind::Transient | ErrorKind::Unknown);
402
403 if let Some(cb) = &self.circuit_breaker {
404 cb.record_failure();
405 }
406
407 if !retryable {
408 tracing::warn!(
409 "Non-retryable error (kind={:?}): {}",
410 error_kind,
411 e
412 );
413 return Err(e);
414 }
415
416 if attempt < self.config.max_retries {
417 let delay = self.calculate_delay(attempt);
418 tracing::debug!(
419 "Attempt {}/{} failed (kind={:?}), retrying in {:?}: {}",
420 attempt + 1,
421 self.config.max_retries + 1,
422 error_kind,
423 delay,
424 e
425 );
426 tokio::time::sleep(delay).await;
427 }
428
429 last_error = Some(e);
430 }
431 }
432 }
433
434 Err(last_error.unwrap_or_else(|| anyhow!("Retry failed with no error recorded")))
435 }
436
437 pub async fn execute_with_context<F, Fut, T>(&self, context: &str, operation: F) -> Result<T>
439 where
440 F: Fn() -> Fut,
441 Fut: std::future::Future<Output = Result<T>>,
442 {
443 self.execute(operation)
444 .await
445 .map_err(|e| anyhow!("{}: {}", context, e))
446 }
447}
448
/// Summary record for one retried operation.
///
/// NOTE(review): nothing in the visible source constructs this type; the
/// field descriptions below are inferred from names and types — confirm
/// against the callers that populate it.
#[derive(Debug, Clone)]
pub struct RetryOutcome {
    /// Presumably a label identifying the operation.
    pub operation: String,
    /// Presumably whether the operation eventually succeeded.
    pub success: bool,
    /// Presumably the number of attempts made.
    pub attempts: u32,
    /// Presumably the total time spent, including backoff sleeps — TODO confirm.
    pub total_duration: Duration,
    /// Presumably the final error text on failure.
    pub error: Option<String>,
    /// Presumably the classification of the final error.
    pub error_kind: Option<ErrorKind>,
}
459
// Unit tests for error classification, backoff, the circuit breaker and the
// retry executor.
#[cfg(test)]
mod tests {
    use super::*;

    // One representative message per category, including the injected-fault text.
    #[test]
    fn test_error_classification() {
        assert_eq!(
            classify_error(&anyhow!("500 Internal Server Error")),
            ErrorKind::Transient
        );
        assert_eq!(
            classify_error(&anyhow!("503 Service Unavailable")),
            ErrorKind::Transient
        );
        assert_eq!(
            classify_error(&anyhow!("Connection timeout")),
            ErrorKind::Transient
        );
        assert_eq!(
            classify_error(&anyhow!("dispatch failure")),
            ErrorKind::Transient
        );
        assert_eq!(
            classify_error(&anyhow!("Storage error: Service unavailable (injected)")),
            ErrorKind::Transient
        );
        assert_eq!(
            classify_error(&anyhow!("401 Unauthorized")),
            ErrorKind::AuthError
        );
        assert_eq!(
            classify_error(&anyhow!("403 Forbidden")),
            ErrorKind::AuthError
        );
        assert_eq!(
            classify_error(&anyhow!("Access Denied")),
            ErrorKind::AuthError
        );
        assert_eq!(
            classify_error(&anyhow!("400 Bad Request")),
            ErrorKind::ClientError
        );
        assert_eq!(
            classify_error(&anyhow!("404 Not Found")),
            ErrorKind::NotFound
        );
        assert_eq!(
            classify_error(&anyhow!("No such key")),
            ErrorKind::NotFound
        );
    }

    // Throttling messages are transient.
    #[test]
    fn test_error_classification_throttling() {
        assert_eq!(
            classify_error(&anyhow!("Request rate exceeded")),
            ErrorKind::Transient
        );
        assert_eq!(
            classify_error(&anyhow!("SlowDown: reduce your request rate")),
            ErrorKind::Transient
        );
        assert_eq!(
            classify_error(&anyhow!("throttling exception")),
            ErrorKind::Transient
        );
    }

    // Low-level network failures are transient.
    #[test]
    fn test_error_classification_network() {
        assert_eq!(
            classify_error(&anyhow!("connection reset by peer")),
            ErrorKind::Transient
        );
        assert_eq!(
            classify_error(&anyhow!("broken pipe")),
            ErrorKind::Transient
        );
        assert_eq!(
            classify_error(&anyhow!("unexpected eof")),
            ErrorKind::Transient
        );
        assert_eq!(
            classify_error(&anyhow!("network unreachable")),
            ErrorKind::Transient
        );
    }

    // Text that matches no pattern falls through to Unknown.
    #[test]
    fn test_error_classification_unknown() {
        assert_eq!(
            classify_error(&anyhow!("some random error")),
            ErrorKind::Unknown
        );
    }

    // Transient and unknown errors retry; client, auth and not-found do not.
    #[test]
    fn test_is_retryable() {
        assert!(is_retryable(&anyhow!("500 Internal Server Error")));
        assert!(is_retryable(&anyhow!("Connection timeout")));
        assert!(is_retryable(&anyhow!("dispatch failure")));
        assert!(is_retryable(&anyhow!("some unknown error")));
        assert!(!is_retryable(&anyhow!("401 Unauthorized")));
        assert!(!is_retryable(&anyhow!("403 Forbidden")));
        assert!(!is_retryable(&anyhow!("400 Bad Request")));
        assert!(!is_retryable(&anyhow!("404 Not Found")));
    }

    // Delays are random, so only the upper bound per attempt is asserted.
    #[test]
    fn test_backoff_calculation() {
        let policy = RetryPolicy::new(RetryConfig {
            base_delay_ms: 100,
            max_delay_ms: 30_000,
            ..Default::default()
        });

        for _ in 0..10 {
            let delay = policy.calculate_delay(0);
            assert!(delay <= Duration::from_millis(100));
        }

        for _ in 0..10 {
            let delay = policy.calculate_delay(1);
            assert!(delay <= Duration::from_millis(200));
        }

        for _ in 0..10 {
            let delay = policy.calculate_delay(20);
            assert!(delay <= Duration::from_millis(30_000));
        }
    }

    // Walks Closed -> Open -> HalfOpen -> Closed with a 100 ms cooldown.
    #[test]
    fn test_circuit_breaker_states() {
        let cb = CircuitBreaker::new(3, 100);

        assert_eq!(cb.state(), CircuitState::Closed);
        assert!(cb.should_allow());
        assert_eq!(cb.consecutive_failures(), 0);

        cb.record_failure();
        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Closed);
        assert!(cb.should_allow());
        assert_eq!(cb.consecutive_failures(), 2);

        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Open);
        assert!(!cb.should_allow());
        assert_eq!(cb.consecutive_failures(), 3);

        // Sleep past the cooldown so the breaker reports half-open.
        std::thread::sleep(Duration::from_millis(150));
        assert_eq!(cb.state(), CircuitState::HalfOpen);
        assert!(cb.should_allow());

        cb.record_success();
        assert_eq!(cb.state(), CircuitState::Closed);
        assert!(cb.should_allow());
        assert_eq!(cb.consecutive_failures(), 0);
    }

    // The callback fires on the failure that reaches the threshold and
    // receives the failure count.
    #[test]
    fn test_circuit_breaker_on_open_callback() {
        let called = Arc::new(AtomicU32::new(0));
        let called_clone = called.clone();
        let on_open: OnCircuitOpen = Arc::new(move |failures| {
            called_clone.store(failures, Ordering::Relaxed);
        });

        let cb = CircuitBreaker::with_on_open(2, 60_000, on_open);

        cb.record_failure();
        assert_eq!(called.load(Ordering::Relaxed), 0);

        cb.record_failure();
        assert_eq!(called.load(Ordering::Relaxed), 2);
    }

    // Pins the documented default values.
    #[test]
    fn test_retry_config_defaults() {
        let config = RetryConfig::default();
        assert_eq!(config.max_retries, 5);
        assert_eq!(config.base_delay_ms, 100);
        assert_eq!(config.max_delay_ms, 30_000);
        assert!(config.circuit_breaker_enabled);
        assert_eq!(config.circuit_breaker_threshold, 10);
        assert_eq!(config.circuit_breaker_cooldown_ms, 60_000);
    }

    // Fields missing from the JSON fall back to their serde defaults.
    #[test]
    fn test_retry_config_serde() {
        let json = r#"{"max_retries": 3, "base_delay_ms": 50}"#;
        let config: RetryConfig = serde_json::from_str(json).unwrap();
        assert_eq!(config.max_retries, 3);
        assert_eq!(config.base_delay_ms, 50);
        assert_eq!(config.max_delay_ms, 30_000);
        assert!(config.circuit_breaker_enabled);
    }

    // Disabling the breaker in config yields a policy without one.
    #[test]
    fn test_retry_policy_no_circuit_breaker() {
        let policy = RetryPolicy::new(RetryConfig {
            circuit_breaker_enabled: false,
            ..Default::default()
        });
        assert!(policy.circuit_breaker().is_none());
    }

    // An injected breaker is attached and starts with zero failures.
    #[test]
    fn test_retry_policy_with_custom_circuit_breaker() {
        let cb = Arc::new(CircuitBreaker::new(5, 30_000));
        let policy = RetryPolicy::with_circuit_breaker(RetryConfig::default(), cb.clone());
        assert!(policy.circuit_breaker().is_some());
        assert_eq!(policy.circuit_breaker().unwrap().consecutive_failures(), 0);
    }

    // An operation that succeeds immediately returns its value.
    #[tokio::test]
    async fn test_retry_success() {
        let policy = RetryPolicy::default_policy();
        let result: Result<i32> = policy.execute(|| async { Ok(42) }).await;
        assert_eq!(result.unwrap(), 42);
    }

    // Two transient failures followed by success: three attempts in total.
    #[tokio::test]
    async fn test_retry_transient_then_success() {
        let policy = RetryPolicy::new(RetryConfig {
            max_retries: 3,
            base_delay_ms: 10,
            ..Default::default()
        });

        let attempts = std::sync::atomic::AtomicU32::new(0);

        let result: Result<i32> = policy
            .execute(|| {
                let attempt = attempts.fetch_add(1, Ordering::Relaxed);
                async move {
                    if attempt < 2 {
                        Err(anyhow!("Service unavailable (injected)"))
                    } else {
                        Ok(42)
                    }
                }
            })
            .await;

        assert_eq!(result.unwrap(), 42);
        assert_eq!(attempts.load(Ordering::Relaxed), 3);
    }

    // Auth errors are returned after a single attempt.
    #[tokio::test]
    async fn test_retry_auth_error_no_retry() {
        let policy = RetryPolicy::new(RetryConfig {
            max_retries: 5,
            base_delay_ms: 10,
            ..Default::default()
        });

        let attempts = std::sync::atomic::AtomicU32::new(0);

        let result: Result<i32> = policy
            .execute(|| {
                attempts.fetch_add(1, Ordering::Relaxed);
                async { Err(anyhow!("401 Unauthorized")) }
            })
            .await;

        assert!(result.is_err());
        assert_eq!(attempts.load(Ordering::Relaxed), 1);
    }

    // Not-found errors are returned after a single attempt.
    #[tokio::test]
    async fn test_retry_not_found_no_retry() {
        let policy = RetryPolicy::new(RetryConfig {
            max_retries: 5,
            base_delay_ms: 10,
            ..Default::default()
        });

        let attempts = std::sync::atomic::AtomicU32::new(0);

        let result: Result<i32> = policy
            .execute(|| {
                attempts.fetch_add(1, Ordering::Relaxed);
                async { Err(anyhow!("404 Not Found")) }
            })
            .await;

        assert!(result.is_err());
        assert_eq!(attempts.load(Ordering::Relaxed), 1);
    }

    // With max_retries = 2 a persistently failing operation runs three times.
    #[tokio::test]
    async fn test_retry_exhausted() {
        let policy = RetryPolicy::new(RetryConfig {
            max_retries: 2,
            base_delay_ms: 10,
            circuit_breaker_enabled: false,
            ..Default::default()
        });

        let attempts = std::sync::atomic::AtomicU32::new(0);

        let result: Result<i32> = policy
            .execute(|| {
                attempts.fetch_add(1, Ordering::Relaxed);
                async { Err(anyhow!("Service unavailable (injected)")) }
            })
            .await;

        assert!(result.is_err());
        assert_eq!(attempts.load(Ordering::Relaxed), 3);
    }

    // A breaker opened beforehand makes execute fail even though the
    // operation itself would succeed.
    #[tokio::test]
    async fn test_retry_circuit_breaker_blocks() {
        let policy = RetryPolicy::new(RetryConfig {
            max_retries: 1,
            base_delay_ms: 10,
            circuit_breaker_enabled: true,
            circuit_breaker_threshold: 2,
            circuit_breaker_cooldown_ms: 60_000,
            ..Default::default()
        });

        let cb = policy.circuit_breaker().unwrap();
        cb.record_failure();
        cb.record_failure();

        let result: Result<i32> = policy.execute(|| async { Ok(42) }).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Circuit breaker open"));
    }

    // The error text carries both the context label and the original message.
    #[tokio::test]
    async fn test_execute_with_context() {
        let policy = RetryPolicy::new(RetryConfig {
            max_retries: 0,
            circuit_breaker_enabled: false,
            ..Default::default()
        });

        let result: Result<i32> = policy
            .execute_with_context("upload segment", || async {
                Err(anyhow!("Service unavailable (injected)"))
            })
            .await;

        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("upload segment"));
        assert!(err.contains("Service unavailable"));
    }
}