use crate::core::error::{Error, Result};
use crate::lock_safe;
use crate::utils::rand_compat::{thread_rng, GenRangeCompat};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
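/// Tuning knobs for [`RetryMechanism`]: the attempt budget, how the delay
/// between attempts grows, and which error-message prefixes are retryable.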
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetryConfig {
pub max_attempts: u32,
pub base_delay_ms: u64,
pub max_delay_ms: u64,
pub backoff_strategy: BackoffStrategy,
pub jitter: bool,
pub backoff_multiplier: f64,
pub retryable_errors: Vec<String>,
}
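/// How the delay between consecutive retry attempts grows.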
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum BackoffStrategy {
Fixed,
Exponential,
Linear,
Custom(Vec<u64>),
}
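/// Thresholds and windows governing when a [`CircuitBreaker`] opens, how long
/// it stays open, and how it probes for recovery while half-open.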
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CircuitBreakerConfig {
pub failure_threshold: u32,
pub failure_window_seconds: u64,
pub minimum_calls: u32,
pub timeout_seconds: u64,
pub success_threshold_percentage: f64,
pub half_open_max_calls: u32,
}
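/// The three classic circuit-breaker states: `Closed` admits all calls,
/// `Open` rejects them, and `HalfOpen` admits a limited number of probes.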
#[derive(Debug, Clone, PartialEq)]
pub enum CircuitState {
Closed,
Open,
HalfOpen,
}
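/// Running call counters for a single circuit breaker.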
#[derive(Debug, Clone)]
pub struct CircuitStats {
pub total_calls: u64,
pub successful_calls: u64,
pub failed_calls: u64,
pub rejected_calls: u64,
pub last_failure_time: Option<Instant>,
pub state_changed_time: Instant,
}
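/// A sliding-window circuit breaker. All methods take `&self`; shared state
/// lives behind mutexes that are always acquired in the order
/// state -> stats -> failure_times / half_open_calls to avoid deadlock.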
#[derive(Debug)]
pub struct CircuitBreaker {
config: CircuitBreakerConfig,
state: Arc<Mutex<CircuitState>>,
stats: Arc<Mutex<CircuitStats>>,
failure_times: Arc<Mutex<Vec<Instant>>>,
half_open_calls: Arc<Mutex<u32>>,
}
impl Default for RetryConfig {
fn default() -> Self {
Self {
max_attempts: 3,
base_delay_ms: 100,
max_delay_ms: 30_000,
backoff_strategy: BackoffStrategy::Exponential,
jitter: true,
backoff_multiplier: 2.0,
retryable_errors: vec![
"ConnectionError".to_string(),
"TimeoutError".to_string(),
"TemporaryFailure".to_string(),
"ServiceUnavailable".to_string(),
"ThrottlingError".to_string(),
],
}
}
}
impl Default for CircuitBreakerConfig {
fn default() -> Self {
Self {
failure_threshold: 5,
failure_window_seconds: 60,
minimum_calls: 10,
timeout_seconds: 60,
success_threshold_percentage: 50.0,
half_open_max_calls: 3,
}
}
}
impl CircuitBreaker {
pub fn new(config: CircuitBreakerConfig) -> Self {
Self {
config,
state: Arc::new(Mutex::new(CircuitState::Closed)),
stats: Arc::new(Mutex::new(CircuitStats {
total_calls: 0,
successful_calls: 0,
failed_calls: 0,
rejected_calls: 0,
last_failure_time: None,
state_changed_time: Instant::now(),
})),
failure_times: Arc::new(Mutex::new(Vec::new())),
half_open_calls: Arc::new(Mutex::new(0)),
}
}
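    /// Returns whether a call may proceed, transitioning `Open` -> `HalfOpen`
    /// once the configured timeout has elapsed.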
    pub fn can_execute(&self) -> Result<bool> {
        // Lock order: state -> stats -> half_open_calls (see the struct docs).
        let mut state = lock_safe!(self.state, "circuit breaker state lock")?;
        let now = Instant::now();
        match *state {
            CircuitState::Closed => Ok(true),
            CircuitState::Open => {
                let mut stats = lock_safe!(self.stats, "circuit breaker stats lock")?;
                let timeout_elapsed = now.duration_since(stats.state_changed_time).as_secs()
                    >= self.config.timeout_seconds;
                if timeout_elapsed {
                    // The cooldown has elapsed: move to half-open and admit a
                    // limited number of probe calls. The state lock stays held
                    // so other callers never observe a half-open state with
                    // stale bookkeeping.
                    *state = CircuitState::HalfOpen;
                    stats.state_changed_time = now;
                    drop(stats);
                    *lock_safe!(self.half_open_calls, "circuit breaker half open calls lock")? = 0;
                    Ok(true)
                } else {
                    Ok(false)
                }
            }
            CircuitState::HalfOpen => {
                let half_open_calls = *lock_safe!(
                    self.half_open_calls,
                    "circuit breaker half open calls check"
                )?;
                Ok(half_open_calls < self.config.half_open_max_calls)
            }
        }
    }
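    /// Records a successful call; enough half-open probe successes close the
    /// circuit again.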
    pub fn record_success(&self) -> Result<()> {
        // Lock order: state -> stats -> half_open_calls, matching can_execute
        // and record_failure so the breaker cannot deadlock with itself.
        let mut state = lock_safe!(self.state, "circuit breaker state lock for success")?;
        let mut stats = lock_safe!(self.stats, "circuit breaker stats lock for success")?;
        stats.total_calls += 1;
        stats.successful_calls += 1;
        if *state == CircuitState::HalfOpen {
            let mut half_open_calls = lock_safe!(
                self.half_open_calls,
                "circuit breaker half open calls for success"
            )?;
            *half_open_calls += 1;
            if *half_open_calls >= self.config.half_open_max_calls {
                // A failure while half-open reopens the circuit immediately,
                // so every probe counted here succeeded; the percentage check
                // simply keeps the threshold configurable.
                let success_rate =
                    (*half_open_calls as f64 / self.config.half_open_max_calls as f64) * 100.0;
                if success_rate >= self.config.success_threshold_percentage {
                    drop(half_open_calls);
                    *state = CircuitState::Closed;
                    stats.state_changed_time = Instant::now();
                    lock_safe!(self.failure_times, "circuit breaker failure times clear")?.clear();
                }
            }
        }
        Ok(())
    }
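    /// Records a failed call; enough failures inside the sliding window open
    /// the circuit, and any half-open failure reopens it immediately.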
    pub fn record_failure(&self) -> Result<()> {
        let now = Instant::now();
        // Lock order: state -> stats -> failure_times / half_open_calls.
        let mut state = lock_safe!(self.state, "circuit breaker state lock for failure")?;
        let mut stats = lock_safe!(self.stats, "circuit breaker stats lock for failure")?;
        stats.total_calls += 1;
        stats.failed_calls += 1;
        stats.last_failure_time = Some(now);
        let mut failure_times =
            lock_safe!(self.failure_times, "circuit breaker failure times lock")?;
        failure_times.push(now);
        // Only failures inside the sliding window count toward the threshold.
        // checked_sub avoids a panic when the window reaches past the earliest
        // representable Instant (e.g. very early in system uptime).
        if let Some(window_start) =
            now.checked_sub(Duration::from_secs(self.config.failure_window_seconds))
        {
            failure_times.retain(|&time| time >= window_start);
        }
        let failure_count = failure_times.len() as u32;
        drop(failure_times);
        match *state {
            CircuitState::Closed => {
                if stats.total_calls >= u64::from(self.config.minimum_calls)
                    && failure_count >= self.config.failure_threshold
                {
                    *state = CircuitState::Open;
                    stats.state_changed_time = now;
                }
            }
            CircuitState::HalfOpen => {
                // Any failure during the probe window immediately reopens the
                // circuit and resets the probe counter.
                *state = CircuitState::Open;
                stats.state_changed_time = now;
                *lock_safe!(
                    self.half_open_calls,
                    "circuit breaker half open calls reset"
                )? = 0;
            }
            CircuitState::Open => {
                // Already open; the failure is reflected in the stats only.
            }
        }
        Ok(())
    }
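    /// Records a call that was rejected because the circuit was open.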
pub fn record_rejection(&self) -> Result<()> {
lock_safe!(self.stats, "circuit breaker stats lock for rejection")?.rejected_calls += 1;
Ok(())
}
pub fn state(&self) -> Result<CircuitState> {
Ok(lock_safe!(self.state, "circuit breaker state lock for state query")?.clone())
}
pub fn stats(&self) -> Result<CircuitStats> {
Ok(lock_safe!(self.stats, "circuit breaker stats lock for stats query")?.clone())
}
}
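/// Runs fallible operations repeatedly according to a [`RetryConfig`].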
#[derive(Debug)]
pub struct RetryMechanism {
config: RetryConfig,
}
impl RetryMechanism {
pub fn new(config: RetryConfig) -> Self {
Self { config }
}
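    /// Runs `operation` until it succeeds, returns a non-retryable error, or
    /// exhausts `max_attempts`. Retryability is decided by matching the
    /// error's `Display` output against the configured prefixes.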
pub async fn execute<F, T, E>(&self, mut operation: F) -> Result<T>
where
F: FnMut() -> std::result::Result<T, E>,
E: std::fmt::Display + std::fmt::Debug,
{
        let mut attempt = 0;
        loop {
            attempt += 1;
            match operation() {
                Ok(result) => return Ok(result),
                Err(error) => {
                    let error_str = error.to_string();
                    // An error is retryable when its message starts with one
                    // of the configured error names.
                    let is_retryable = self
                        .config
                        .retryable_errors
                        .iter()
                        .any(|retryable| error_str.starts_with(retryable));
                    if !is_retryable || attempt >= self.config.max_attempts {
                        return Err(Error::OperationFailed(format!(
                            "Operation failed after {} attempts. Last error: {}",
                            attempt, error_str
                        )));
                    }
                    let delay = self.calculate_delay(attempt);
                    // NOTE: this blocks the current thread. If the surrounding
                    // code runs on an async executor, substitute the runtime's
                    // async sleep here.
                    std::thread::sleep(Duration::from_millis(delay));
                }
            }
        }
}
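    /// Delay in milliseconds before the next attempt (`attempt` is 1-based),
    /// capped at `max_delay_ms` for the exponential and linear strategies.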
fn calculate_delay(&self, attempt: u32) -> u64 {
let base_delay = match &self.config.backoff_strategy {
BackoffStrategy::Fixed => self.config.base_delay_ms,
BackoffStrategy::Exponential => {
let exp_delay = (self.config.base_delay_ms as f64
* self.config.backoff_multiplier.powi((attempt - 1) as i32))
as u64;
std::cmp::min(exp_delay, self.config.max_delay_ms)
}
BackoffStrategy::Linear => {
let linear_delay = self.config.base_delay_ms * attempt as u64;
std::cmp::min(linear_delay, self.config.max_delay_ms)
}
BackoffStrategy::Custom(delays) => {
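                // Use the delay configured for this attempt; past the end of
                // the list, fall back to the maximum delay.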
if attempt > 0 && (attempt as usize - 1) < delays.len() {
delays[attempt as usize - 1]
} else {
self.config.max_delay_ms
}
}
};
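        // Up to 10% additive jitter decorrelates concurrent retries; note the
        // jittered value can exceed `max_delay_ms` by that margin.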
if self.config.jitter {
let jitter_amount = (base_delay as f64 * 0.1) as u64;
let jitter = thread_rng().gen_range(0..=jitter_amount);
base_delay + jitter
} else {
base_delay
}
}
}
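/// Facade that lazily creates one [`CircuitBreaker`] per service name and
/// pairs it with a per-service (or default) [`RetryConfig`].
///
/// A sketch of typical usage, assuming an async caller and a hypothetical
/// fallible `fetch_data` function:
///
/// ```ignore
/// let manager = ResilienceManager::new();
/// let value = manager
///     .execute_with_resilience("billing", || fetch_data())
///     .await?;
/// ```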
#[derive(Debug)]
pub struct ResilienceManager {
circuit_breakers: Arc<Mutex<HashMap<String, Arc<CircuitBreaker>>>>,
retry_configs: Arc<Mutex<HashMap<String, RetryConfig>>>,
default_retry_config: RetryConfig,
default_circuit_config: CircuitBreakerConfig,
}
impl ResilienceManager {
pub fn new() -> Self {
Self {
circuit_breakers: Arc::new(Mutex::new(HashMap::new())),
retry_configs: Arc::new(Mutex::new(HashMap::new())),
default_retry_config: RetryConfig::default(),
default_circuit_config: CircuitBreakerConfig::default(),
}
}
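    /// Returns the circuit breaker for `service_name`, creating it with the
    /// default configuration on first use.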
pub fn get_circuit_breaker(&self, service_name: &str) -> Result<Arc<CircuitBreaker>> {
let mut breakers = lock_safe!(
self.circuit_breakers,
"resilience manager circuit breakers lock"
)?;
        // Lazily create one breaker per service with the default configuration.
        let breaker = breakers
            .entry(service_name.to_string())
            .or_insert_with(|| Arc::new(CircuitBreaker::new(self.default_circuit_config.clone())));
        Ok(Arc::clone(breaker))
}
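    /// Returns the retry policy registered for `service_name`, or the default.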
pub fn get_retry_config(&self, service_name: &str) -> Result<RetryConfig> {
let configs = lock_safe!(self.retry_configs, "resilience manager retry configs lock")?;
Ok(configs
.get(service_name)
.cloned()
.unwrap_or_else(|| self.default_retry_config.clone()))
}
pub fn set_retry_config(&self, service_name: &str, config: RetryConfig) -> Result<()> {
let mut configs = lock_safe!(
self.retry_configs,
"resilience manager retry configs lock for set"
)?;
configs.insert(service_name.to_string(), config);
Ok(())
}
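    /// Runs `operation` behind the service's circuit breaker and retry policy.
    /// The breaker is consulted once up front, and a single success or failure
    /// is recorded for the outcome of the whole retried sequence.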
pub async fn execute_with_resilience<F, T, E>(
&self,
service_name: &str,
operation: F,
) -> Result<T>
where
F: Fn() -> std::result::Result<T, E> + Send + Sync,
E: std::fmt::Display + std::fmt::Debug + Send + Sync,
{
let circuit_breaker = self.get_circuit_breaker(service_name)?;
let retry_config = self.get_retry_config(service_name)?;
let retry_mechanism = RetryMechanism::new(retry_config);
if !circuit_breaker.can_execute()? {
circuit_breaker.record_rejection()?;
return Err(Error::ConnectionError(format!(
"Circuit breaker is open for service: {}",
service_name
)));
}
let result = retry_mechanism.execute(|| operation()).await;
match &result {
Ok(_) => circuit_breaker.record_success()?,
Err(_) => circuit_breaker.record_failure()?,
}
result
}
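    /// Returns a per-service health snapshot for every known circuit breaker.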
pub fn get_health_status(&self) -> Result<HashMap<String, ServiceHealth>> {
let breakers = lock_safe!(
self.circuit_breakers,
"resilience manager circuit breakers lock for health"
)?;
let mut health_status = HashMap::new();
for (service_name, breaker) in breakers.iter() {
let stats = breaker.stats()?;
let state = breaker.state()?;
let health = ServiceHealth {
service_name: service_name.clone(),
state,
total_calls: stats.total_calls,
successful_calls: stats.successful_calls,
failed_calls: stats.failed_calls,
rejected_calls: stats.rejected_calls,
success_rate: if stats.total_calls > 0 {
(stats.successful_calls as f64 / stats.total_calls as f64) * 100.0
} else {
0.0
},
last_failure_time: stats.last_failure_time,
};
health_status.insert(service_name.clone(), health);
}
Ok(health_status)
}
}
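/// Point-in-time health snapshot of one service's circuit breaker, as
/// returned by [`ResilienceManager::get_health_status`].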
#[derive(Debug, Clone)]
pub struct ServiceHealth {
pub service_name: String,
pub state: CircuitState,
pub total_calls: u64,
pub successful_calls: u64,
pub failed_calls: u64,
pub rejected_calls: u64,
pub success_rate: f64,
pub last_failure_time: Option<Instant>,
}
impl Default for ResilienceManager {
fn default() -> Self {
Self::new()
}
}
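/// Extension trait so any compatible closure can be run through a
/// [`ResilienceManager`] by naming the target service.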
#[allow(async_fn_in_trait)]
pub trait ResilientOperation<T> {
async fn execute_resilient(self, manager: &ResilienceManager, service_name: &str) -> Result<T>;
}
impl<F, T, E> ResilientOperation<T> for F
where
F: Fn() -> std::result::Result<T, E> + Send + Sync,
E: std::fmt::Display + std::fmt::Debug + Send + Sync,
{
async fn execute_resilient(self, manager: &ResilienceManager, service_name: &str) -> Result<T> {
manager.execute_with_resilience(service_name, self).await
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::atomic::{AtomicU32, Ordering};
#[test]
fn test_circuit_breaker_basic_operations() {
let config = CircuitBreakerConfig {
failure_threshold: 3,
minimum_calls: 5,
..Default::default()
};
let cb = CircuitBreaker::new(config);
assert_eq!(
cb.state().expect("operation should succeed"),
CircuitState::Closed
);
assert!(cb.can_execute().expect("operation should succeed"));
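        // Three successes keep the breaker closed and build up total_calls.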
for _ in 0..3 {
cb.record_success().expect("operation should succeed");
}
assert_eq!(
cb.state().expect("operation should succeed"),
CircuitState::Closed
);
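        // Two failures: minimum_calls (5) is now met, but failure_threshold
        // (3) is not, so the breaker stays closed.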
for _ in 0..2 {
cb.record_failure().expect("operation should succeed");
}
assert_eq!(
cb.state().expect("operation should succeed"),
CircuitState::Closed
);
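        // A third failure within the window trips the breaker open.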
cb.record_failure().expect("operation should succeed");
assert_eq!(
cb.state().expect("operation should succeed"),
CircuitState::Open
);
assert!(!cb.can_execute().expect("operation should succeed"));
}
    /// Drives a future that contains no await points to completion with a
    /// single poll; `RetryMechanism::execute` never yields, so this suffices
    /// to test it without pulling in an async runtime.
    fn block_on_ready<F: std::future::Future>(fut: F) -> F::Output {
        use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};
        fn raw_waker() -> RawWaker {
            RawWaker::new(std::ptr::null(), &VTABLE)
        }
        static VTABLE: RawWakerVTable =
            RawWakerVTable::new(|_| raw_waker(), |_| {}, |_| {}, |_| {});
        let waker = unsafe { Waker::from_raw(raw_waker()) };
        let mut cx = Context::from_waker(&waker);
        let mut fut = std::pin::pin!(fut);
        match fut.as_mut().poll(&mut cx) {
            Poll::Ready(output) => output,
            Poll::Pending => unreachable!("retry future has no await points"),
        }
    }
    #[test]
    fn test_retry_mechanism() {
        let config = RetryConfig {
            max_attempts: 3,
            base_delay_ms: 10,
            backoff_strategy: BackoffStrategy::Fixed,
            jitter: false,
            ..Default::default()
        };
        let retry = RetryMechanism::new(config);
        let attempt_count = Arc::new(AtomicU32::new(0));
        let attempt_count_clone = attempt_count.clone();
        // Fail twice with a retryable ("TimeoutError...") message, then
        // succeed on the third and final attempt.
        let result = block_on_ready(retry.execute(move || {
            let attempt = attempt_count_clone.fetch_add(1, Ordering::SeqCst) + 1;
            if attempt < 3 {
                Err("TimeoutError: simulated".to_string())
            } else {
                Ok("Success")
            }
        }));
        assert_eq!(result.expect("operation should succeed"), "Success");
        assert_eq!(attempt_count.load(Ordering::SeqCst), 3);
    }
#[test]
fn test_resilience_manager() {
let manager = ResilienceManager::new();
        let cb1 = manager
            .get_circuit_breaker("test_service")
            .expect("operation should succeed");
        let cb2 = manager
            .get_circuit_breaker("test_service")
            .expect("operation should succeed");
        // Repeated lookups for the same service return the same breaker.
        assert!(Arc::ptr_eq(&cb1, &cb2));
let config = RetryConfig {
max_attempts: 5,
..Default::default()
};
manager
.set_retry_config("test_service", config.clone())
.expect("operation should succeed");
let retrieved_config = manager
.get_retry_config("test_service")
.expect("operation should succeed");
assert_eq!(retrieved_config.max_attempts, 5);
}
}