use std::time::Duration;
use crate::error::{CloudError, Result, RetryError};
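/// Default maximum number of retries before an operation fails.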
pub const DEFAULT_MAX_RETRIES: usize = 3;
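/// Default delay before the first retry.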
pub const DEFAULT_INITIAL_BACKOFF: Duration = Duration::from_millis(100);
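/// Default upper bound on any single backoff delay.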
pub const DEFAULT_MAX_BACKOFF: Duration = Duration::from_secs(30);
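/// Default factor by which the backoff grows after each attempt.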
pub const DEFAULT_BACKOFF_MULTIPLIER: f64 = 2.0;
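/// Retry policy: bounded exponential backoff with optional jitter and an
/// optional circuit breaker. Build one fluently, e.g.
/// `RetryConfig::new().with_max_retries(5).with_jitter(false)`.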
#[derive(Debug, Clone)]
pub struct RetryConfig {
pub max_retries: usize,
pub initial_backoff: Duration,
pub max_backoff: Duration,
pub backoff_multiplier: f64,
pub jitter: bool,
pub circuit_breaker: bool,
pub circuit_breaker_threshold: usize,
pub circuit_breaker_timeout: Duration,
}
impl Default for RetryConfig {
fn default() -> Self {
Self {
max_retries: DEFAULT_MAX_RETRIES,
initial_backoff: DEFAULT_INITIAL_BACKOFF,
max_backoff: DEFAULT_MAX_BACKOFF,
backoff_multiplier: DEFAULT_BACKOFF_MULTIPLIER,
jitter: true,
circuit_breaker: true,
circuit_breaker_threshold: 5,
circuit_breaker_timeout: Duration::from_secs(60),
}
}
}
impl RetryConfig {
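    /// Creates the default retry policy.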
#[must_use]
pub fn new() -> Self {
Self::default()
}
#[must_use]
pub fn with_max_retries(mut self, max_retries: usize) -> Self {
self.max_retries = max_retries;
self
}
#[must_use]
pub fn with_initial_backoff(mut self, duration: Duration) -> Self {
self.initial_backoff = duration;
self
}
#[must_use]
pub fn with_max_backoff(mut self, duration: Duration) -> Self {
self.max_backoff = duration;
self
}
#[must_use]
pub fn with_backoff_multiplier(mut self, multiplier: f64) -> Self {
self.backoff_multiplier = multiplier;
self
}
#[must_use]
pub fn with_jitter(mut self, jitter: bool) -> Self {
self.jitter = jitter;
self
}
#[must_use]
pub fn with_circuit_breaker(mut self, enabled: bool) -> Self {
self.circuit_breaker = enabled;
self
}
}
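/// State of a `CircuitBreaker`: `Closed` admits requests, `Open` rejects
/// them, and `HalfOpen` admits probes after the open-state timeout elapses.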
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitState {
Closed,
Open,
HalfOpen,
}
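/// Trips open after `threshold` recorded failures, rejects requests for
/// `timeout`, then moves to half-open; a success there closes the circuit
/// and a further failure reopens it.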
#[derive(Debug)]
pub struct CircuitBreaker {
state: CircuitState,
failure_count: usize,
threshold: usize,
timeout: Duration,
last_failure: Option<std::time::Instant>,
}
impl CircuitBreaker {
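    /// Creates a closed circuit breaker with the given failure threshold and
    /// open-state timeout.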
#[must_use]
pub fn new(threshold: usize, timeout: Duration) -> Self {
Self {
state: CircuitState::Closed,
failure_count: 0,
threshold,
timeout,
last_failure: None,
}
}
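    /// Returns `Ok(())` if a request may proceed, or a
    /// `RetryError::CircuitBreakerOpen` error while the circuit is open.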
pub fn allow_request(&mut self) -> Result<()> {
match self.state {
CircuitState::Closed => Ok(()),
CircuitState::Open => {
if let Some(last_failure) = self.last_failure {
if last_failure.elapsed() >= self.timeout {
tracing::info!("Circuit breaker transitioning to half-open state");
self.state = CircuitState::HalfOpen;
Ok(())
} else {
Err(CloudError::Retry(RetryError::CircuitBreakerOpen {
message: "Circuit breaker is open".to_string(),
}))
}
                } else {
                    // An open circuit always has a recorded failure; admit the
                    // request defensively if that invariant is ever broken.
                    Ok(())
                }
}
CircuitState::HalfOpen => Ok(()),
}
}
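    /// Records a success, resetting the failure count and closing a
    /// half-open circuit.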
pub fn record_success(&mut self) {
match self.state {
CircuitState::Closed => {
self.failure_count = 0;
}
CircuitState::HalfOpen => {
tracing::info!("Circuit breaker transitioning to closed state");
self.state = CircuitState::Closed;
self.failure_count = 0;
}
CircuitState::Open => {}
}
}
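    /// Records a failure, opening the circuit once the failure count reaches
    /// the threshold.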
pub fn record_failure(&mut self) {
self.failure_count += 1;
self.last_failure = Some(std::time::Instant::now());
if self.failure_count >= self.threshold && self.state != CircuitState::Open {
tracing::warn!(
"Circuit breaker opening after {} failures",
self.failure_count
);
self.state = CircuitState::Open;
}
}
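    /// Returns the current circuit state.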
#[must_use]
pub fn state(&self) -> CircuitState {
self.state
}
}
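/// Token-bucket retry budget: every retry consumes one token, and tokens
/// refill at a fixed rate up to `max_tokens`, bounding the overall retry
/// rate so retries cannot amplify an outage.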
#[derive(Debug)]
pub struct RetryBudget {
tokens: usize,
max_tokens: usize,
refill_rate: f64,
last_refill: std::time::Instant,
}
impl RetryBudget {
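    /// Creates a full bucket of `max_tokens` tokens, refilled at
    /// `refill_rate` tokens per second.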
#[must_use]
pub fn new(max_tokens: usize, refill_rate: f64) -> Self {
Self {
tokens: max_tokens,
max_tokens,
refill_rate,
last_refill: std::time::Instant::now(),
}
}
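    /// Consumes one token, or returns `RetryError::BudgetExhausted` when the
    /// bucket is empty.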
pub fn try_consume(&mut self) -> Result<()> {
self.refill();
if self.tokens > 0 {
self.tokens -= 1;
Ok(())
} else {
Err(CloudError::Retry(RetryError::BudgetExhausted {
message: "Retry budget exhausted".to_string(),
}))
}
}
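    /// Credits tokens accrued since the last refill. `last_refill` only
    /// advances once a whole token accrues, so fractional progress carries
    /// over between calls.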
fn refill(&mut self) {
let elapsed = self.last_refill.elapsed();
let tokens_to_add = (elapsed.as_secs_f64() * self.refill_rate) as usize;
if tokens_to_add > 0 {
self.tokens = (self.tokens + tokens_to_add).min(self.max_tokens);
self.last_refill = std::time::Instant::now();
}
}
}
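/// Produces a sequence of exponentially growing delays, optionally widened
/// by up to 50% random jitter and capped at `max_backoff`.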
#[derive(Debug)]
pub struct Backoff {
config: RetryConfig,
attempt: usize,
}
impl Backoff {
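    /// Creates a backoff sequence starting at attempt 0.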
#[must_use]
pub fn new(config: RetryConfig) -> Self {
Self { config, attempt: 0 }
}
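    /// Returns the next delay, `initial_backoff * multiplier^attempt` (times
    /// a jitter factor in `[1.0, 1.5]` when enabled, capped at
    /// `max_backoff`), and advances the attempt counter.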
#[must_use]
pub fn next(&mut self) -> Duration {
        let base = self.config.initial_backoff.as_secs_f64()
            * self.config.backoff_multiplier.powi(self.attempt as i32);
let backoff = if self.config.jitter {
let jitter_factor = 1.0 + (rand() * 0.5);
base * jitter_factor
} else {
base
};
self.attempt += 1;
Duration::from_secs_f64(backoff.min(self.config.max_backoff.as_secs_f64()))
}
pub fn reset(&mut self) {
self.attempt = 0;
}
}
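/// Cheap pseudo-random value in `[0.0, 1.0]` used only for jitter; this is a
/// linear congruential generator, not a cryptographically secure source.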
fn rand() -> f64 {
    use std::sync::atomic::{AtomicU64, Ordering};
    // Non-zero seed so the first call does not always return 0.0.
    static SEED: AtomicU64 = AtomicU64::new(0x9E37_79B9_7F4A_7C15);
    // Advance the LCG atomically; `fetch_update` returns the previous state,
    // so apply the same step again to obtain the freshly stored value.
    let prev = SEED
        .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |s| {
            Some(s.wrapping_mul(1664525).wrapping_add(1013904223))
        })
        .expect("closure always returns Some");
    let next = prev.wrapping_mul(1664525).wrapping_add(1013904223);
    (next >> 32) as f64 / f64::from(u32::MAX)
}
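/// Returns `true` for transient failures worth retrying: timeouts, rate
/// limits, network and retryable-status HTTP errors, provider SDK errors,
/// and I/O errors.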
#[must_use]
pub fn is_retryable(error: &CloudError) -> bool {
match error {
CloudError::Timeout { .. } => true,
CloudError::RateLimitExceeded { .. } => true,
CloudError::Http(http_error) => match http_error {
crate::error::HttpError::Network { .. } => true,
            crate::error::HttpError::Status { status, .. } => {
                // Request timeout (408), too many requests (429), and 5xx
                // server errors are treated as transient.
                matches!(*status, 408 | 429 | 500 | 502 | 503 | 504)
            }
_ => false,
},
        CloudError::S3(s3_error) => {
            matches!(s3_error, crate::error::S3Error::Sdk { .. })
        }
        CloudError::Azure(azure_error) => {
            matches!(azure_error, crate::error::AzureError::Sdk { .. })
        }
        CloudError::Gcs(gcs_error) => {
            matches!(gcs_error, crate::error::GcsError::Sdk { .. })
        }
CloudError::Io(_) => true,
_ => false,
}
}
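/// Drives async operations through the retry policy, combining backoff, the
/// optional circuit breaker, and a retry budget.
///
/// A minimal usage sketch (`fetch_value` stands in for any fallible async
/// operation):
///
/// ```ignore
/// let mut executor = RetryExecutor::new(RetryConfig::new().with_max_retries(3));
/// let value = executor.execute(|| async { fetch_value().await }).await?;
/// ```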
#[cfg(feature = "async")]
pub struct RetryExecutor {
config: RetryConfig,
circuit_breaker: Option<CircuitBreaker>,
retry_budget: Option<RetryBudget>,
}
#[cfg(feature = "async")]
impl RetryExecutor {
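    /// Creates an executor from `config`, enabling the circuit breaker when
    /// the config asks for one.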
#[must_use]
pub fn new(config: RetryConfig) -> Self {
let circuit_breaker = if config.circuit_breaker {
Some(CircuitBreaker::new(
config.circuit_breaker_threshold,
config.circuit_breaker_timeout,
))
} else {
None
};
        // Default budget: bursts of up to 100 retries, refilled at 10 tokens
        // per second.
        let retry_budget = Some(RetryBudget::new(100, 10.0));
Self {
config,
circuit_breaker,
retry_budget,
}
}
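    /// Runs `operation` until it succeeds or retrying stops: a non-retryable
    /// error is returned immediately, an exhausted retry budget surfaces as
    /// `RetryError::BudgetExhausted`, and exceeding `max_retries` records a
    /// circuit-breaker failure and returns `RetryError::MaxRetriesExceeded`.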
pub async fn execute<F, Fut, T>(&mut self, mut operation: F) -> Result<T>
where
F: FnMut() -> Fut,
Fut: std::future::Future<Output = Result<T>>,
{
if let Some(ref mut cb) = self.circuit_breaker {
cb.allow_request()?;
}
let mut backoff = Backoff::new(self.config.clone());
let mut attempts = 0;
loop {
match operation().await {
Ok(result) => {
if let Some(ref mut cb) = self.circuit_breaker {
cb.record_success();
}
return Ok(result);
}
Err(error) => {
attempts += 1;
if !is_retryable(&error) {
tracing::warn!("Non-retryable error: {}", error);
return Err(error);
}
if attempts > self.config.max_retries {
tracing::error!("Max retries ({}) exceeded", self.config.max_retries);
if let Some(ref mut cb) = self.circuit_breaker {
cb.record_failure();
}
return Err(CloudError::Retry(RetryError::MaxRetriesExceeded {
attempts,
}));
}
if let Some(ref mut budget) = self.retry_budget {
budget.try_consume()?;
}
let delay = backoff.next();
tracing::warn!(
"Retry attempt {}/{} after {:?}: {}",
attempts,
self.config.max_retries,
delay,
error
);
tokio::time::sleep(delay).await;
}
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_retry_config_builder() {
let config = RetryConfig::new()
.with_max_retries(5)
.with_initial_backoff(Duration::from_millis(50))
.with_backoff_multiplier(3.0)
.with_jitter(false);
assert_eq!(config.max_retries, 5);
assert_eq!(config.initial_backoff, Duration::from_millis(50));
assert_eq!(config.backoff_multiplier, 3.0);
assert!(!config.jitter);
}
#[test]
fn test_circuit_breaker_closed() {
let mut cb = CircuitBreaker::new(3, Duration::from_secs(60));
        assert_eq!(cb.state(), CircuitState::Closed);
assert!(cb.allow_request().is_ok());
}
#[test]
fn test_circuit_breaker_opens() {
let mut cb = CircuitBreaker::new(3, Duration::from_secs(60));
cb.record_failure();
cb.record_failure();
cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Open);
assert!(cb.allow_request().is_err());
}
#[test]
fn test_circuit_breaker_half_open() {
let mut cb = CircuitBreaker::new(3, Duration::from_millis(10));
cb.record_failure();
cb.record_failure();
cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Open);
std::thread::sleep(Duration::from_millis(20));
assert!(cb.allow_request().is_ok());
        assert_eq!(cb.state(), CircuitState::HalfOpen);
cb.record_success();
        assert_eq!(cb.state(), CircuitState::Closed);
}
#[test]
fn test_retry_budget() {
let mut budget = RetryBudget::new(10, 100.0);
for _ in 0..10 {
assert!(budget.try_consume().is_ok());
}
assert!(budget.try_consume().is_err());
std::thread::sleep(Duration::from_millis(50));
assert!(budget.try_consume().is_ok());
}
#[test]
fn test_backoff_exponential() {
let config = RetryConfig::new()
.with_initial_backoff(Duration::from_millis(100))
.with_backoff_multiplier(2.0)
.with_jitter(false);
let mut backoff = Backoff::new(config);
let d1 = backoff.next();
let d2 = backoff.next();
let d3 = backoff.next();
assert!(d1 < d2);
assert!(d2 < d3);
}
#[test]
fn test_is_retryable() {
let timeout_error = CloudError::Timeout {
message: "timeout".to_string(),
};
assert!(is_retryable(&timeout_error));
let rate_limit_error = CloudError::RateLimitExceeded {
message: "rate limit".to_string(),
};
assert!(is_retryable(&rate_limit_error));
let not_found_error = CloudError::NotFound {
key: "test".to_string(),
};
        assert!(!is_retryable(&not_found_error));
}
#[cfg(feature = "async")]
#[tokio::test]
async fn test_retry_executor_success() {
use std::sync::atomic::{AtomicUsize, Ordering};
let config = RetryConfig::new().with_max_retries(3);
let mut executor = RetryExecutor::new(config);
let attempt = std::sync::Arc::new(AtomicUsize::new(0));
let attempt_clone = attempt.clone();
let result = executor
.execute(|| {
let attempt = attempt_clone.clone();
async move {
let current = attempt.fetch_add(1, Ordering::SeqCst) + 1;
if current < 2 {
Err(CloudError::Timeout {
message: "timeout".to_string(),
})
} else {
Ok(42)
}
}
})
.await;
assert!(result.is_ok());
assert_eq!(result.ok(), Some(42));
assert_eq!(attempt.load(Ordering::SeqCst), 2);
}
#[cfg(feature = "async")]
#[tokio::test]
async fn test_retry_executor_max_retries() {
use std::sync::atomic::{AtomicUsize, Ordering};
let config = RetryConfig::new().with_max_retries(2);
let mut executor = RetryExecutor::new(config);
let attempt = std::sync::Arc::new(AtomicUsize::new(0));
let attempt_clone = attempt.clone();
let result: Result<i32> = executor
.execute(|| {
let attempt = attempt_clone.clone();
async move {
attempt.fetch_add(1, Ordering::SeqCst);
Err(CloudError::Timeout {
message: "timeout".to_string(),
})
}
})
.await;
assert!(result.is_err());
        assert_eq!(attempt.load(Ordering::SeqCst), 3);
    }
}