use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use tracing::{debug, info, warn};
use crate::error::{ClusterError, Result};
use crate::raft::OxirsNodeId;
/// State of a [`CircuitBreaker`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CircuitState {
/// Requests are admitted; failures are tracked to decide when to open the circuit.
Closed,
/// Requests are rejected until the configured timeout has elapsed since the circuit opened.
Open,
/// Recovery probing: a limited number of requests is admitted; a failure re-opens the
/// circuit and enough successes close it again.
HalfOpen,
}
/// Tuning parameters for a [`CircuitBreaker`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CircuitBreakerConfig {
/// Consecutive failures (while `Closed`) that open the circuit; any success resets the streak.
pub failure_threshold: u32,
/// Successes required while `HalfOpen` to close the circuit again.
pub success_threshold: u32,
/// Milliseconds the circuit stays `Open` before `can_execute` lets a request through and
/// moves the breaker to `HalfOpen`.
pub timeout_ms: u64,
/// Length, in seconds, of the sliding window of recent outcomes used by the adaptive
/// failure-rate check.
pub window_size_secs: u64,
/// Requests `can_execute` admits while `HalfOpen`, not counting the request that triggered
/// the `Open` -> `HalfOpen` transition.
pub half_open_requests: u32,
/// When `true`, the circuit may also open based on the failure ratio inside the sliding
/// window, not only on the consecutive-failure count.
pub adaptive_thresholds: bool,
/// Failure ratio (failures / outcomes in the window, 0.0..=1.0) at or above which the
/// adaptive check opens the circuit.
pub min_failure_rate: f64,
}
impl Default for CircuitBreakerConfig {
fn default() -> Self {
Self {
failure_threshold: 5,
success_threshold: 2,
timeout_ms: 5000,
window_size_secs: 60,
half_open_requests: 3,
adaptive_thresholds: true,
min_failure_rate: 0.5,
}
}
}
/// Cumulative counters exposed by [`CircuitBreaker::get_stats`].
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CircuitBreakerStats {
/// Outcomes recorded via `record_success` / `record_failure` (rejected requests excluded).
pub total_requests: u64,
/// Outcomes recorded via `record_failure`.
pub total_failures: u64,
/// Outcomes recorded via `record_success`.
pub total_successes: u64,
/// Requests refused by `can_execute`.
pub rejected_requests: u64,
/// Number of transitions into the `Open` state.
pub times_opened: u64,
/// Number of `HalfOpen` -> `Closed` recoveries (a manual `reset` is not counted).
pub times_closed: u64,
/// Lifetime `total_failures / total_requests` (not limited to the sliding window).
pub current_failure_rate: f64,
/// Running average response time in milliseconds, updated by `record_success`.
pub avg_response_time_ms: f64,
}
/// Mutable state of a [`CircuitBreaker`], kept behind its `RwLock`.
#[derive(Debug, Clone)]
struct CircuitBreakerState {
/// Current position in the state machine.
state: CircuitState,
/// Failures counted in the current phase; while `Closed` every success resets it, so it is
/// a consecutive-failure streak.
failure_count: u32,
/// Successes counted towards `success_threshold` while recovering.
success_count: u32,
/// When the circuit last opened; `None` before the first opening and after `reset`.
opened_at: Option<Instant>,
/// `(timestamp, succeeded)` outcomes inside the sliding window, pruned on every record.
recent_results: Vec<(Instant, bool)>,
/// Cumulative counters returned by `get_stats`.
stats: CircuitBreakerStats,
/// Requests admitted by `can_execute` during the current `HalfOpen` phase.
half_open_request_count: u32,
}
impl Default for CircuitBreakerState {
fn default() -> Self {
Self {
state: CircuitState::Closed,
failure_count: 0,
success_count: 0,
opened_at: None,
recent_results: Vec::new(),
stats: CircuitBreakerStats::default(),
half_open_request_count: 0,
}
}
}
/// An async circuit breaker.
///
/// Cloning is cheap: clones share the same underlying state through `Arc<RwLock<..>>`, so a
/// clone observes and affects the same circuit as the original.
#[derive(Clone)]
pub struct CircuitBreaker {
/// Immutable tuning parameters.
config: CircuitBreakerConfig,
/// Shared mutable state.
state: Arc<RwLock<CircuitBreakerState>>,
}
impl CircuitBreaker {
pub fn new(config: CircuitBreakerConfig) -> Self {
Self {
config,
state: Arc::new(RwLock::new(CircuitBreakerState::default())),
}
}
pub async fn can_execute(&self) -> Result<()> {
let mut state = self.state.write().await;
match state.state {
CircuitState::Closed => Ok(()),
CircuitState::Open => {
if let Some(opened_at) = state.opened_at {
let elapsed = opened_at.elapsed();
if elapsed.as_millis() >= self.config.timeout_ms as u128 {
state.state = CircuitState::HalfOpen;
state.success_count = 0;
state.failure_count = 0;
state.half_open_request_count = 0;
info!("Circuit breaker transitioned to HalfOpen state");
Ok(())
} else {
state.stats.rejected_requests += 1;
Err(ClusterError::CircuitOpen)
}
} else {
state.stats.rejected_requests += 1;
Err(ClusterError::CircuitOpen)
}
}
CircuitState::HalfOpen => {
if state.half_open_request_count < self.config.half_open_requests {
state.half_open_request_count += 1;
Ok(())
} else {
state.stats.rejected_requests += 1;
Err(ClusterError::CircuitOpen)
}
}
}
}
pub async fn record_success(&self, response_time_ms: f64) {
let mut state = self.state.write().await;
state.stats.total_requests += 1;
state.stats.total_successes += 1;
let total_responses = state.stats.total_successes + state.stats.total_failures;
if total_responses > 1 {
state.stats.avg_response_time_ms = (state.stats.avg_response_time_ms
* (total_responses - 1) as f64
+ response_time_ms)
/ total_responses as f64;
} else {
state.stats.avg_response_time_ms = response_time_ms;
}
state.recent_results.push((Instant::now(), true));
self.cleanup_old_results(&mut state);
match state.state {
CircuitState::Closed => {
state.failure_count = 0;
}
CircuitState::HalfOpen => {
state.success_count += 1;
if state.success_count >= self.config.success_threshold {
state.state = CircuitState::Closed;
state.failure_count = 0;
state.success_count = 0;
state.stats.times_closed += 1;
info!("Circuit breaker closed after successful recovery");
}
}
CircuitState::Open => {
state.state = CircuitState::HalfOpen;
state.success_count = 1;
state.failure_count = 0;
}
}
self.update_failure_rate(&mut state);
}
pub async fn record_failure(&self) {
let mut state = self.state.write().await;
state.stats.total_requests += 1;
state.stats.total_failures += 1;
state.recent_results.push((Instant::now(), false));
self.cleanup_old_results(&mut state);
match state.state {
CircuitState::Closed => {
state.failure_count += 1;
if self.should_open_circuit(&state) {
state.state = CircuitState::Open;
state.opened_at = Some(Instant::now());
state.stats.times_opened += 1;
warn!(
"Circuit breaker opened after {} failures (rate: {:.2}%)",
state.failure_count,
state.stats.current_failure_rate * 100.0
);
}
}
CircuitState::HalfOpen => {
state.state = CircuitState::Open;
state.opened_at = Some(Instant::now());
state.failure_count = 1;
state.success_count = 0;
state.stats.times_opened += 1;
warn!("Circuit breaker re-opened after failure in HalfOpen state");
}
CircuitState::Open => {
state.failure_count += 1;
}
}
self.update_failure_rate(&mut state);
}
pub async fn get_state(&self) -> CircuitState {
self.state.read().await.state
}
pub async fn get_stats(&self) -> CircuitBreakerStats {
self.state.read().await.stats.clone()
}
pub async fn reset(&self) {
let mut state = self.state.write().await;
state.state = CircuitState::Closed;
state.failure_count = 0;
state.success_count = 0;
state.opened_at = None;
state.recent_results.clear();
info!("Circuit breaker reset to Closed state");
}
fn should_open_circuit(&self, state: &CircuitBreakerState) -> bool {
if state.failure_count >= self.config.failure_threshold {
return true;
}
if self.config.adaptive_thresholds {
let recent_failures = state
.recent_results
.iter()
.filter(|(_, success)| !success)
.count();
let total_recent = state.recent_results.len();
if total_recent > 0 {
let failure_rate = recent_failures as f64 / total_recent as f64;
if failure_rate >= self.config.min_failure_rate {
return true;
}
}
}
false
}
fn update_failure_rate(&self, state: &mut CircuitBreakerState) {
if state.stats.total_requests > 0 {
state.stats.current_failure_rate =
state.stats.total_failures as f64 / state.stats.total_requests as f64;
}
}
fn cleanup_old_results(&self, state: &mut CircuitBreakerState) {
let window = Duration::from_secs(self.config.window_size_secs);
let now = Instant::now();
state
.recent_results
.retain(|(timestamp, _)| now.duration_since(*timestamp) < window);
}
}
/// Lazily creates and owns circuit breakers keyed by node id and by operation name, all built
/// from `default_config`.
pub struct CircuitBreakerManager {
/// One breaker per cluster node.
node_breakers: Arc<RwLock<HashMap<OxirsNodeId, CircuitBreaker>>>,
/// One breaker per named operation.
operation_breakers: Arc<RwLock<HashMap<String, CircuitBreaker>>>,
/// Configuration used for every breaker created on demand.
default_config: CircuitBreakerConfig,
}
impl CircuitBreakerManager {
    /// Creates a manager with no breakers; they are created on first use with `default_config`.
    pub fn new(default_config: CircuitBreakerConfig) -> Self {
        Self {
            node_breakers: Arc::new(RwLock::new(HashMap::new())),
            operation_breakers: Arc::new(RwLock::new(HashMap::new())),
            default_config,
        }
    }

    /// Returns the breaker for `node_id`, creating it with the default config on first use.
    ///
    /// The returned handle shares state with the stored breaker.
    pub async fn get_node_breaker(&self, node_id: OxirsNodeId) -> CircuitBreaker {
        // Fast path: existing breakers only need the shared lock.
        {
            let breakers = self.node_breakers.read().await;
            if let Some(breaker) = breakers.get(&node_id) {
                return breaker.clone();
            }
        }
        // Slow path: `entry` re-checks under the exclusive lock, so a breaker inserted by a
        // concurrent caller between the two locks is reused rather than replaced.
        self.node_breakers
            .write()
            .await
            .entry(node_id)
            .or_insert_with(|| CircuitBreaker::new(self.default_config.clone()))
            .clone()
    }

    /// Returns the breaker for `operation`, creating it with the default config on first use.
    ///
    /// The returned handle shares state with the stored breaker.
    pub async fn get_operation_breaker(&self, operation: &str) -> CircuitBreaker {
        {
            let breakers = self.operation_breakers.read().await;
            if let Some(breaker) = breakers.get(operation) {
                return breaker.clone();
            }
        }
        self.operation_breakers
            .write()
            .await
            .entry(operation.to_string())
            .or_insert_with(|| CircuitBreaker::new(self.default_config.clone()))
            .clone()
    }

    /// Runs `f` under the breaker for `node_id`.
    ///
    /// # Errors
    ///
    /// Returns [`ClusterError::CircuitOpen`] without polling `f` when the circuit rejects the
    /// request; otherwise returns `f`'s own result.
    pub async fn execute_with_node_breaker<F, T>(&self, node_id: OxirsNodeId, f: F) -> Result<T>
    where
        F: std::future::Future<Output = Result<T>>,
    {
        let breaker = self.get_node_breaker(node_id).await;
        Self::run_guarded(&breaker, f).await
    }

    /// Runs `f` under the breaker for `operation`.
    ///
    /// # Errors
    ///
    /// Returns [`ClusterError::CircuitOpen`] without polling `f` when the circuit rejects the
    /// request; otherwise returns `f`'s own result.
    pub async fn execute_with_operation_breaker<F, T>(&self, operation: &str, f: F) -> Result<T>
    where
        F: std::future::Future<Output = Result<T>>,
    {
        let breaker = self.get_operation_breaker(operation).await;
        Self::run_guarded(&breaker, f).await
    }

    /// Returns the current state of every node breaker created so far.
    pub async fn get_all_node_states(&self) -> HashMap<OxirsNodeId, CircuitState> {
        let breakers = self.node_breakers.read().await;
        let mut states = HashMap::with_capacity(breakers.len());
        for (node_id, breaker) in breakers.iter() {
            states.insert(*node_id, breaker.get_state().await);
        }
        states
    }

    /// Returns the current state of every operation breaker created so far.
    pub async fn get_all_operation_states(&self) -> HashMap<String, CircuitState> {
        let breakers = self.operation_breakers.read().await;
        let mut states = HashMap::with_capacity(breakers.len());
        for (operation, breaker) in breakers.iter() {
            states.insert(operation.clone(), breaker.get_state().await);
        }
        states
    }

    /// Returns the statistics for `node_id`, or `None` if no breaker exists for it yet.
    ///
    /// This does not create a breaker.
    pub async fn get_node_stats(&self, node_id: OxirsNodeId) -> Option<CircuitBreakerStats> {
        // Clone the handle so the map lock is released before awaiting the breaker's own lock.
        let breaker = self.node_breakers.read().await.get(&node_id).cloned()?;
        Some(breaker.get_stats().await)
    }

    /// Returns statistics for every node breaker created so far.
    pub async fn get_all_node_stats(&self) -> HashMap<OxirsNodeId, CircuitBreakerStats> {
        let breakers = self.node_breakers.read().await;
        let mut stats = HashMap::with_capacity(breakers.len());
        for (node_id, breaker) in breakers.iter() {
            stats.insert(*node_id, breaker.get_stats().await);
        }
        stats
    }

    /// Resets every node and operation breaker to `Closed` (statistics are kept).
    pub async fn reset_all(&self) {
        debug!("Resetting all circuit breakers");
        {
            let node_breakers = self.node_breakers.read().await;
            for breaker in node_breakers.values() {
                breaker.reset().await;
            }
        }
        {
            let operation_breakers = self.operation_breakers.read().await;
            for breaker in operation_breakers.values() {
                breaker.reset().await;
            }
        }
        info!("All circuit breakers reset");
    }

    /// Runs `f` under `breaker`.
    ///
    /// If the circuit does not admit the request, `f` is rejected up front and never polled.
    /// Otherwise the outcome is reported back to the breaker, with the latency in
    /// milliseconds on success.
    async fn run_guarded<F, T>(breaker: &CircuitBreaker, f: F) -> Result<T>
    where
        F: std::future::Future<Output = Result<T>>,
    {
        breaker.can_execute().await?;
        let start = Instant::now();
        match f.await {
            Ok(result) => {
                // Fractional milliseconds: `as_millis()` would truncate sub-millisecond calls to 0.
                let elapsed_ms = start.elapsed().as_secs_f64() * 1000.0;
                breaker.record_success(elapsed_ms).await;
                Ok(result)
            }
            Err(e) => {
                breaker.record_failure().await;
                Err(e)
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_circuit_breaker_initial_state() {
        let breaker = CircuitBreaker::new(CircuitBreakerConfig::default());
        assert_eq!(breaker.get_state().await, CircuitState::Closed);
    }

    #[tokio::test]
    async fn test_circuit_breaker_opens_on_failures() {
        let breaker = CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: 3,
            ..CircuitBreakerConfig::default()
        });
        for _ in 0..3 {
            breaker.record_failure().await;
        }
        assert_eq!(breaker.get_state().await, CircuitState::Open);
    }

    #[tokio::test]
    async fn test_circuit_breaker_half_open_transition() {
        let breaker = CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: 2,
            timeout_ms: 100,
            ..CircuitBreakerConfig::default()
        });
        breaker.record_failure().await;
        breaker.record_failure().await;
        assert_eq!(breaker.get_state().await, CircuitState::Open);
        tokio::time::sleep(Duration::from_millis(150)).await;
        let result = breaker.can_execute().await;
        assert!(result.is_ok());
        assert_eq!(breaker.get_state().await, CircuitState::HalfOpen);
    }

    #[tokio::test]
    async fn test_circuit_breaker_closes_on_success() {
        let breaker = CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: 2,
            success_threshold: 2,
            timeout_ms: 100,
            ..CircuitBreakerConfig::default()
        });
        breaker.record_failure().await;
        breaker.record_failure().await;
        assert_eq!(breaker.get_state().await, CircuitState::Open);
        tokio::time::sleep(Duration::from_millis(150)).await;
        breaker.can_execute().await.unwrap();
        breaker.record_success(10.0).await;
        breaker.record_success(10.0).await;
        assert_eq!(breaker.get_state().await, CircuitState::Closed);
    }

    #[tokio::test]
    async fn test_circuit_breaker_rejects_while_open() {
        let breaker = CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: 2,
            // Long enough that the timeout cannot elapse during the test.
            timeout_ms: 60_000,
            ..CircuitBreakerConfig::default()
        });
        breaker.record_failure().await;
        breaker.record_failure().await;
        assert_eq!(breaker.get_state().await, CircuitState::Open);
        let result = breaker.can_execute().await;
        assert!(matches!(result, Err(ClusterError::CircuitOpen)));
        assert_eq!(breaker.get_stats().await.rejected_requests, 1);
    }

    #[tokio::test]
    async fn test_circuit_breaker_reopens_on_half_open_failure() {
        let breaker = CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: 2,
            timeout_ms: 100,
            ..CircuitBreakerConfig::default()
        });
        breaker.record_failure().await;
        breaker.record_failure().await;
        tokio::time::sleep(Duration::from_millis(150)).await;
        breaker.can_execute().await.unwrap();
        assert_eq!(breaker.get_state().await, CircuitState::HalfOpen);
        breaker.record_failure().await;
        assert_eq!(breaker.get_state().await, CircuitState::Open);
        assert_eq!(breaker.get_stats().await.times_opened, 2);
    }

    #[tokio::test]
    async fn test_circuit_breaker_reset() {
        let breaker = CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: 2,
            timeout_ms: 60_000,
            ..CircuitBreakerConfig::default()
        });
        breaker.record_failure().await;
        breaker.record_failure().await;
        assert_eq!(breaker.get_state().await, CircuitState::Open);
        breaker.reset().await;
        assert_eq!(breaker.get_state().await, CircuitState::Closed);
        assert!(breaker.can_execute().await.is_ok());
        // Reset clears the state machine but keeps cumulative statistics.
        assert_eq!(breaker.get_stats().await.total_failures, 2);
    }

    #[tokio::test]
    async fn test_circuit_breaker_stats() {
        let breaker = CircuitBreaker::new(CircuitBreakerConfig::default());
        breaker.record_success(10.0).await;
        breaker.record_success(20.0).await;
        breaker.record_failure().await;
        let stats = breaker.get_stats().await;
        assert_eq!(stats.total_requests, 3);
        assert_eq!(stats.total_successes, 2);
        assert_eq!(stats.total_failures, 1);
        assert_eq!(stats.current_failure_rate, 1.0 / 3.0);
    }

    #[tokio::test]
    async fn test_circuit_breaker_manager() {
        let manager = CircuitBreakerManager::new(CircuitBreakerConfig::default());
        let breaker = manager.get_node_breaker(1).await;
        assert_eq!(breaker.get_state().await, CircuitState::Closed);
        let breaker = manager.get_operation_breaker("query").await;
        assert_eq!(breaker.get_state().await, CircuitState::Closed);
    }

    #[tokio::test]
    async fn test_circuit_breaker_execute_with_protection() {
        let manager = CircuitBreakerManager::new(CircuitBreakerConfig::default());
        let result = manager
            .execute_with_node_breaker(1, async { Ok::<_, ClusterError>(42) })
            .await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 42);
        let result = manager
            .execute_with_node_breaker(2, async {
                Err::<i32, _>(ClusterError::Network("test error".to_string()))
            })
            .await;
        assert!(result.is_err());
        assert_eq!(manager.get_node_stats(2).await.unwrap().total_failures, 1);
        // Looking up stats must not create a breaker for an unknown node.
        assert!(manager.get_node_stats(3).await.is_none());
    }
}