use std::collections::VecDeque;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use tokio::sync::RwLock;
use crate::error::{Error, Result};
/// Operating mode of a circuit breaker.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CircuitState {
    /// Normal operation: `check_request` lets every request through.
    Closed,
    /// Tripped: `check_request` rejects every request.
    Open,
    /// Probing: `check_request` lets requests through until enough trial
    /// results have been recorded to decide whether to close or re-open.
    HalfOpen,
}
/// Tunable thresholds for a circuit breaker.
#[derive(Clone, Copy, Debug)]
pub struct CircuitBreakerConfig {
    /// Latency z-score above which a result recorded while the breaker is
    /// closed counts as an anomaly and opens it.
    pub failure_threshold_z_score: f64,
    /// Number of results recorded in the half-open state before the breaker
    /// decides whether to close or re-open.
    pub half_open_requests: usize,
    /// Minimum fraction of successful half-open results required to close
    /// the breaker again.
    pub success_rate_threshold: f64,
    /// Maximum number of samples retained in the rolling metrics window.
    pub history_size: usize,
}
impl Default for CircuitBreakerConfig {
fn default() -> Self {
Self {
failure_threshold_z_score: 2.5,
half_open_requests: 10,
success_rate_threshold: 0.8,
history_size: 1000,
}
}
}
/// Rolling window of request outcomes together with statistics derived from it.
///
/// `Default` yields empty windows and all statistics at `0.0`; it is derived
/// because every field's own default is exactly that value.
#[derive(Debug, Clone, Default)]
pub struct HealthMetrics {
    /// Latency samples in milliseconds, oldest at the front.
    pub latencies: VecDeque<f64>,
    /// Error flag for each sample (`true` means the request failed), pushed
    /// and evicted together with `latencies`.
    pub errors: VecDeque<bool>,
    /// Arithmetic mean of `latencies`.
    pub mean_latency: f64,
    /// Population standard deviation of `latencies`.
    pub std_dev_latency: f64,
    /// Fraction of entries in `errors` that are `true`.
    pub error_rate: f64,
}
impl HealthMetrics {
    /// Records one sample and refreshes the derived statistics.
    ///
    /// `latency_ms` and `error` are appended to their windows; the oldest
    /// samples are then evicted until at most `max_size` remain. A `max_size`
    /// of zero therefore leaves both windows empty.
    fn update(&mut self, latency_ms: f64, error: bool, max_size: usize) {
        self.latencies.push_back(latency_ms);
        self.errors.push_back(error);
        while self.latencies.len() > max_size {
            self.latencies.pop_front();
            self.errors.pop_front();
        }
        self.recalculate_stats();
    }

    /// Recomputes mean, population standard deviation and error rate from the
    /// current windows.
    ///
    /// An empty window resets its statistics to `0.0` instead of keeping the
    /// values computed from samples that are no longer present.
    fn recalculate_stats(&mut self) {
        if self.latencies.is_empty() {
            // Nothing to summarise: do not let stale numbers drive `is_anomaly`.
            self.mean_latency = 0.0;
            self.std_dev_latency = 0.0;
        } else {
            let sample_count = self.latencies.len() as f64;
            let mean = self.latencies.iter().sum::<f64>() / sample_count;
            let variance = self
                .latencies
                .iter()
                .map(|latency| (latency - mean).powi(2))
                .sum::<f64>()
                / sample_count;
            self.mean_latency = mean;
            self.std_dev_latency = variance.sqrt();
        }

        // `errors` is a public field, so guard the division separately rather
        // than assuming it is non-empty whenever `latencies` is.
        self.error_rate = if self.errors.is_empty() {
            0.0
        } else {
            let error_count = self.errors.iter().filter(|&&failed| failed).count();
            error_count as f64 / self.errors.len() as f64
        };
    }

    /// Returns `true` when `latency_ms` lies more than `z_score_threshold`
    /// standard deviations above the current mean.
    ///
    /// The check is one-sided: unusually fast responses are never anomalies.
    /// With a zero standard deviation no z-score can be formed and the result
    /// is `false`.
    fn is_anomaly(&self, latency_ms: f64, z_score_threshold: f64) -> bool {
        if self.std_dev_latency == 0.0 {
            return false;
        }
        let z_score = (latency_ms - self.mean_latency) / self.std_dev_latency;
        z_score > z_score_threshold
    }
}
/// Circuit breaker guarding calls to a single provider.
///
/// The state, the metrics and both half-open counters sit behind `Arc`s, so
/// handles produced by `Clone` observe and mutate the same breaker.
pub struct CircuitBreaker {
    /// Name of the guarded provider; included in the "breaker open" error message.
    pub provider: String,
    /// Current operating mode.
    state: Arc<RwLock<CircuitState>>,
    /// Rolling latency and error statistics fed by `record_result`.
    metrics: Arc<RwLock<HealthMetrics>>,
    /// Thresholds this breaker was created with.
    config: CircuitBreakerConfig,
    /// Number of results recorded since the breaker last entered the half-open state.
    half_open_count: Arc<AtomicU64>,
    /// How many of those half-open results were successful.
    half_open_success: Arc<AtomicU64>,
}
impl CircuitBreaker {
pub fn new(provider: impl Into<String>, config: CircuitBreakerConfig) -> Self {
Self {
provider: provider.into(),
state: Arc::new(RwLock::new(CircuitState::Closed)),
metrics: Arc::new(RwLock::new(HealthMetrics::default())),
config,
half_open_count: Arc::new(AtomicU64::new(0)),
half_open_success: Arc::new(AtomicU64::new(0)),
}
}
pub async fn record_result(&self, latency_ms: f64, success: bool) -> Result<()> {
let mut metrics = self.metrics.write().await;
metrics.update(latency_ms, !success, self.config.history_size);
let mut state = self.state.write().await;
match *state {
CircuitState::Closed => {
if !success || metrics.is_anomaly(latency_ms, self.config.failure_threshold_z_score)
{
*state = CircuitState::Open;
}
}
CircuitState::Open => {
}
CircuitState::HalfOpen => {
self.half_open_count.fetch_add(1, Ordering::Relaxed);
if success {
self.half_open_success.fetch_add(1, Ordering::Relaxed);
}
let total = self.half_open_count.load(Ordering::Acquire) as usize;
if total >= self.config.half_open_requests {
let success_count = self.half_open_success.load(Ordering::Acquire) as usize;
let success_rate = success_count as f64 / total as f64;
if success_rate >= self.config.success_rate_threshold {
*state = CircuitState::Closed;
self.half_open_count.store(0, Ordering::Release);
self.half_open_success.store(0, Ordering::Release);
} else {
*state = CircuitState::Open;
self.half_open_count.store(0, Ordering::Release);
self.half_open_success.store(0, Ordering::Release);
}
}
}
}
Ok(())
}
pub async fn check_request(&self) -> Result<()> {
let state = self.state.read().await;
match *state {
CircuitState::Closed => Ok(()),
CircuitState::Open => Err(Error::InvalidRequest(format!(
"Circuit breaker open for provider: {}",
self.provider
))),
CircuitState::HalfOpen => {
let count = self.half_open_count.load(Ordering::Acquire) as usize;
if count < self.config.half_open_requests {
Ok(())
} else {
Err(Error::InvalidRequest(
"Circuit breaker half-open, max test requests reached".to_string(),
))
}
}
}
}
pub async fn state(&self) -> CircuitState {
*self.state.read().await
}
pub async fn metrics(&self) -> HealthMetrics {
self.metrics.read().await.clone()
}
pub async fn open(&self) {
let mut state = self.state.write().await;
*state = CircuitState::Open;
}
pub async fn half_open(&self) {
let mut state = self.state.write().await;
*state = CircuitState::HalfOpen;
self.half_open_count.store(0, Ordering::Release);
self.half_open_success.store(0, Ordering::Release);
}
pub async fn reset(&self) {
let mut state = self.state.write().await;
*state = CircuitState::Closed;
self.half_open_count.store(0, Ordering::Release);
self.half_open_success.store(0, Ordering::Release);
}
}
impl Clone for CircuitBreaker {
    /// Produces another handle to the same breaker.
    ///
    /// The provider name and configuration are copied; state, metrics and the
    /// half-open counters are shared through their `Arc`s.
    fn clone(&self) -> Self {
        // Exhaustive destructuring: adding a field without handling it here
        // becomes a compile error.
        let Self {
            provider,
            state,
            metrics,
            config,
            half_open_count,
            half_open_success,
        } = self;
        Self {
            provider: provider.clone(),
            state: Arc::clone(state),
            metrics: Arc::clone(metrics),
            config: *config,
            half_open_count: Arc::clone(half_open_count),
            half_open_success: Arc::clone(half_open_success),
        }
    }
}
// Unit tests for the circuit breaker state machine and the rolling health metrics.
#[cfg(test)]
mod tests {
use super::*;
// The constructor stores the provider name it was given.
#[test]
fn test_circuit_breaker_creation() {
let config = CircuitBreakerConfig::default();
let breaker = CircuitBreaker::new("test_provider", config);
assert_eq!(breaker.provider, "test_provider");
}
// A new breaker admits requests, and successful results with similar
// latencies keep it closed.
#[tokio::test]
async fn test_circuit_breaker_closed_success() {
let config = CircuitBreakerConfig::default();
let breaker = CircuitBreaker::new("test_provider", config);
assert!(breaker.check_request().await.is_ok());
breaker.record_result(50.0, true).await.unwrap();
breaker.record_result(55.0, true).await.unwrap();
assert_eq!(breaker.state().await, CircuitState::Closed);
}
// A breaker that has seen failures is open and rejects requests.
// NOTE(review): `record_result` opens a closed breaker on any unsuccessful
// result, so the first iteration of the loop already opens it; every later
// call (including the 500 ms sample) hits the `Open` arm, which leaves the
// state unchanged. The latency-anomaly path is not exercised here.
#[tokio::test]
async fn test_circuit_breaker_opens_on_failure() {
let config = CircuitBreakerConfig {
failure_threshold_z_score: 2.5,
..Default::default()
};
let breaker = CircuitBreaker::new("test_provider", config);
for i in 0..10 {
breaker.record_result(50.0 + i as f64, false).await.unwrap();
}
breaker.record_result(500.0, true).await.unwrap();
assert_eq!(breaker.state().await, CircuitState::Open);
assert!(breaker.check_request().await.is_err());
}
// Open -> half-open via the manual transitions. `check_request` only reads
// the half-open counter, so repeated checks without recorded results all pass.
#[tokio::test]
async fn test_circuit_breaker_half_open() {
let config = CircuitBreakerConfig::default();
let breaker = CircuitBreaker::new("test_provider", config);
breaker.open().await;
assert_eq!(breaker.state().await, CircuitState::Open);
breaker.half_open().await;
assert_eq!(breaker.state().await, CircuitState::HalfOpen);
for _ in 0..5 {
assert!(breaker.check_request().await.is_ok());
}
}
// Four successes and one failure out of five half-open results give a
// success rate of 0.8, which meets the (inclusive) threshold and closes
// the breaker.
#[tokio::test]
async fn test_circuit_breaker_recovery() {
let config = CircuitBreakerConfig {
failure_threshold_z_score: 2.5,
half_open_requests: 5,
success_rate_threshold: 0.8,
..Default::default()
};
let breaker = CircuitBreaker::new("test_provider", config);
breaker.open().await;
breaker.half_open().await;
for _ in 0..4 {
breaker.record_result(50.0, true).await.unwrap();
}
breaker.record_result(60.0, false).await.unwrap();
assert_eq!(breaker.state().await, CircuitState::Closed);
}
// `reset` returns an open breaker to the closed state, which admits requests.
#[tokio::test]
async fn test_circuit_breaker_reset() {
let config = CircuitBreakerConfig::default();
let breaker = CircuitBreaker::new("test_provider", config);
breaker.open().await;
assert_eq!(breaker.state().await, CircuitState::Open);
breaker.reset().await;
assert_eq!(breaker.state().await, CircuitState::Closed);
assert!(breaker.check_request().await.is_ok());
}
// Three error-free samples: the mean is (50 + 60 + 55) / 3 = 55 and the
// error rate is zero.
#[test]
fn test_health_metrics_update() {
let mut metrics = HealthMetrics::default();
metrics.update(50.0, false, 1000);
metrics.update(60.0, false, 1000);
metrics.update(55.0, false, 1000);
assert_eq!(metrics.latencies.len(), 3);
assert!((metrics.mean_latency - 55.0).abs() < 0.1);
assert_eq!(metrics.error_rate, 0.0);
}
// One error among three samples gives an error rate of one third.
#[test]
fn test_health_metrics_error_tracking() {
let mut metrics = HealthMetrics::default();
metrics.update(50.0, false, 1000);
metrics.update(60.0, true, 1000);
metrics.update(55.0, false, 1000);
assert_eq!(metrics.errors.len(), 3);
assert!((metrics.error_rate - 1.0 / 3.0).abs() < 0.01);
}
// Twenty samples spread from 100.0 to 109.5 ms form the baseline; 500 ms is
// far above it, while 105 ms lies inside the observed range.
#[test]
fn test_anomaly_detection() {
let mut metrics = HealthMetrics::default();
for i in 0..20 {
let latency = 100.0 + (i as f64 * 0.5); metrics.update(latency, false, 1000);
}
let is_anomaly = metrics.is_anomaly(500.0, 2.5);
assert!(is_anomaly, "Should detect large deviation as anomaly");
let is_normal = metrics.is_anomaly(105.0, 2.5);
assert!(!is_normal, "Should not detect small variation as anomaly");
}
// Clones share state through `Arc`: opening one handle is visible through the other.
#[tokio::test]
async fn test_circuit_breaker_clone() {
let config = CircuitBreakerConfig::default();
let breaker1 = CircuitBreaker::new("test_provider", config);
let breaker2 = breaker1.clone();
breaker1.open().await;
assert_eq!(breaker2.state().await, CircuitState::Open);
}
// The breaker ends up open and rejects further requests.
// NOTE(review): the second argument of `record_result` is the success flag,
// so `i >= 10` marks the first ten 50 ms results as failures and the five
// 1000 ms results as successes. The breaker therefore opens on the very
// first call; confirm whether the flag was meant to be `i < 10`.
#[tokio::test]
async fn test_circuit_breaker_prevents_cascading_failure() {
let config = CircuitBreakerConfig {
failure_threshold_z_score: 2.0,
half_open_requests: 5,
success_rate_threshold: 0.8,
..Default::default()
};
let breaker = CircuitBreaker::new("flaky_provider", config);
for i in 0..15 {
let latency = if i < 10 { 50.0 } else { 1000.0 };
breaker.record_result(latency, i >= 10).await.unwrap();
}
assert_eq!(breaker.state().await, CircuitState::Open);
assert!(breaker.check_request().await.is_err());
}
}