use crate::error::{NetError, NetResult};
use parking_lot::RwLock;
use std::sync::Arc;
use std::time::{Duration, Instant};
/// Operating state of a [`CircuitBreaker`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitState {
    /// Requests are admitted; consecutive failures are counted toward tripping the circuit.
    Closed,
    /// Requests are rejected until the configured `timeout` has elapsed since opening.
    Open,
    /// A limited number of probe requests are admitted to test whether the backend recovered.
    HalfOpen,
}
/// Tuning parameters for a [`CircuitBreaker`].
#[derive(Debug, Clone)]
pub struct CircuitBreakerConfig {
    /// Consecutive failures (within the current window) that open a closed circuit.
    pub failure_threshold: usize,
    /// Consecutive successes required while half-open to close the circuit again.
    pub success_threshold: usize,
    /// How long the circuit stays open before the next admission check moves it to half-open.
    pub timeout: Duration,
    /// Length of the window after which the consecutive-failure count is cleared.
    /// Expiry is checked when a request asks for admission.
    pub window_duration: Duration,
    /// Limit on probe requests admitted while the circuit is half-open.
    pub half_open_max_requests: usize,
}
impl Default for CircuitBreakerConfig {
fn default() -> Self {
Self {
failure_threshold: 5,
success_threshold: 2,
timeout: Duration::from_secs(60),
window_duration: Duration::from_secs(60),
half_open_max_requests: 3,
}
}
}
/// Snapshot of a [`CircuitBreaker`]'s counters, returned by `CircuitBreaker::stats`.
#[derive(Debug, Clone, Default)]
pub struct CircuitBreakerStats {
    /// Outcomes reported via `record_success` / `record_failure`
    /// (not the number of admission checks).
    pub total_requests: u64,
    /// Outcomes reported via `record_failure`.
    pub total_failures: u64,
    /// Outcomes reported via `record_success`.
    pub total_successes: u64,
    /// Number of transitions into the open state.
    pub times_opened: u64,
    /// Number of transitions into the closed state (a manual `reset` is not counted).
    pub times_closed: u64,
    /// Failures since the last success, close, reset, or failure-window expiry.
    pub consecutive_failures: usize,
    /// Successes since the last failure, close, entry into half-open, or reset.
    pub consecutive_successes: usize,
    /// Time of the most recent open / half-open / closed transition (`reset` does not update it).
    pub last_state_change: Option<Instant>,
}
/// Mutable bookkeeping behind a [`CircuitBreaker`]'s lock.
#[derive(Debug)]
struct CircuitBreakerState {
    /// Current state of the circuit.
    state: CircuitState,
    /// Counters exposed through `CircuitBreaker::stats`.
    stats: CircuitBreakerStats,
    /// Set when the circuit opens and cleared when it closes or is reset; compared against
    /// `timeout` to decide when an open circuit may move to half-open.
    opened_at: Option<Instant>,
    /// Start of the current failure-counting window; restarted when `window_duration`
    /// elapses (checked on admission) and on reset.
    window_start: Instant,
    /// Probe counter used to enforce `half_open_max_requests`; reset to 0 on every state
    /// transition and on reset.
    half_open_requests: usize,
}
/// Circuit breaker that stops sending requests to a failing dependency and probes for recovery.
///
/// The mutable state lives behind an `Arc<RwLock<..>>`. Cloning the breaker therefore yields
/// another handle to the *same* circuit: all clones observe and drive one shared state.
#[derive(Debug, Clone)]
pub struct CircuitBreaker {
    /// Immutable tuning parameters.
    config: CircuitBreakerConfig,
    /// Shared, lock-protected circuit state.
    state: Arc<RwLock<CircuitBreakerState>>,
}
impl CircuitBreaker {
    /// Creates a closed circuit breaker with [`CircuitBreakerConfig::default`] settings.
    pub fn new() -> Self {
        Self::with_config(CircuitBreakerConfig::default())
    }

    /// Creates a closed circuit breaker using `config`.
    pub fn with_config(config: CircuitBreakerConfig) -> Self {
        let state = CircuitBreakerState {
            state: CircuitState::Closed,
            stats: CircuitBreakerStats::default(),
            opened_at: None,
            window_start: Instant::now(),
            half_open_requests: 0,
        };
        Self {
            config,
            state: Arc::new(RwLock::new(state)),
        }
    }

    /// Decides whether a new request may proceed, admitting it if so.
    ///
    /// * `Closed`: always admitted.
    /// * `Open`: rejected until `config.timeout` has elapsed since the circuit opened. The
    ///   first call after that moves the circuit to `HalfOpen` and is admitted as a probe.
    /// * `HalfOpen`: admitted while fewer than `half_open_max_requests` probes are
    ///   outstanding. At least one probe is always permitted.
    ///
    /// Before deciding, the consecutive-failure count is cleared if `window_duration` has
    /// expired, so that sparse failures do not accumulate indefinitely.
    ///
    /// An admitted request should later be reported through
    /// [`record_success`](Self::record_success) or [`record_failure`](Self::record_failure).
    ///
    /// # Errors
    ///
    /// Returns `NetError::ServerUnavailable` when the circuit is open, or when it is
    /// half-open and the probe limit has been reached.
    pub fn is_request_allowed(&self) -> NetResult<()> {
        let mut state = self.state.write();

        if state.window_start.elapsed() > self.config.window_duration {
            state.stats.consecutive_failures = 0;
            state.window_start = Instant::now();
        }

        match state.state {
            CircuitState::Closed => Ok(()),
            CircuitState::Open => {
                let probe_due = match state.opened_at {
                    Some(opened_at) => opened_at.elapsed() >= self.config.timeout,
                    None => false,
                };
                if probe_due {
                    self.transition_to_half_open(&mut state);
                    // The request that triggers the transition is the first probe.
                    state.half_open_requests += 1;
                    Ok(())
                } else {
                    Err(NetError::ServerUnavailable(
                        "Circuit breaker is open".to_string(),
                    ))
                }
            }
            CircuitState::HalfOpen => {
                if state.half_open_requests < self.half_open_limit() {
                    state.half_open_requests += 1;
                    Ok(())
                } else {
                    Err(NetError::ServerUnavailable(
                        "Circuit breaker half-open limit reached".to_string(),
                    ))
                }
            }
        }
    }

    /// Records a successful request outcome.
    ///
    /// Updates the totals and clears the consecutive-failure count.
    ///
    /// * In `HalfOpen`, the probe's slot is released. Without this, a `success_threshold`
    ///   larger than `half_open_max_requests` could never be reached, and the breaker would
    ///   reject requests forever. The circuit closes once `success_threshold` consecutive
    ///   successes have been seen.
    /// * In `Open`, the success is counted but does not change the state. No request is
    ///   admitted while open, so such a success is stale. Closing on it would skip the
    ///   timeout and the half-open recovery check.
    pub fn record_success(&self) {
        let mut state = self.state.write();
        state.stats.total_requests += 1;
        state.stats.total_successes += 1;
        state.stats.consecutive_failures = 0;
        state.stats.consecutive_successes += 1;

        match state.state {
            CircuitState::HalfOpen => {
                state.half_open_requests = state.half_open_requests.saturating_sub(1);
                if state.stats.consecutive_successes >= self.config.success_threshold {
                    self.transition_to_closed(&mut state);
                }
            }
            CircuitState::Open | CircuitState::Closed => {}
        }
    }

    /// Records a failed request outcome.
    ///
    /// Updates the totals and clears the consecutive-success count.
    ///
    /// * A closed circuit opens once `failure_threshold` consecutive failures accumulate.
    /// * Any failure while half-open re-opens the circuit.
    pub fn record_failure(&self) {
        let mut state = self.state.write();
        state.stats.total_requests += 1;
        state.stats.total_failures += 1;
        state.stats.consecutive_failures += 1;
        state.stats.consecutive_successes = 0;

        match state.state {
            CircuitState::Closed => {
                if state.stats.consecutive_failures >= self.config.failure_threshold {
                    self.transition_to_open(&mut state);
                }
            }
            CircuitState::HalfOpen => {
                self.transition_to_open(&mut state);
            }
            CircuitState::Open => {}
        }
    }

    /// Returns the current circuit state.
    pub fn state(&self) -> CircuitState {
        self.state.read().state
    }

    /// Returns a snapshot of the breaker's counters.
    pub fn stats(&self) -> CircuitBreakerStats {
        self.state.read().stats.clone()
    }

    /// Forces the circuit back to `Closed` and restarts the failure window.
    ///
    /// The consecutive counters are cleared. Totals, `times_closed` and `last_state_change`
    /// are left untouched.
    pub fn reset(&self) {
        let mut state = self.state.write();
        state.state = CircuitState::Closed;
        state.stats.consecutive_failures = 0;
        state.stats.consecutive_successes = 0;
        state.opened_at = None;
        state.half_open_requests = 0;
        state.window_start = Instant::now();
    }

    /// Effective half-open probe limit.
    ///
    /// At least one probe is always allowed, so a zero limit cannot wedge the breaker in
    /// `HalfOpen`.
    fn half_open_limit(&self) -> usize {
        self.config.half_open_max_requests.max(1)
    }

    /// Moves to `Closed`, clearing consecutive counters and half-open bookkeeping.
    fn transition_to_closed(&self, state: &mut CircuitBreakerState) {
        state.state = CircuitState::Closed;
        state.stats.times_closed += 1;
        state.stats.last_state_change = Some(Instant::now());
        state.stats.consecutive_failures = 0;
        state.stats.consecutive_successes = 0;
        state.opened_at = None;
        state.half_open_requests = 0;
    }

    /// Moves to `Open` and stamps `opened_at` so the timeout can be measured.
    fn transition_to_open(&self, state: &mut CircuitBreakerState) {
        state.state = CircuitState::Open;
        state.stats.times_opened += 1;
        state.stats.last_state_change = Some(Instant::now());
        state.opened_at = Some(Instant::now());
        state.half_open_requests = 0;
    }

    /// Moves to `HalfOpen` with a fresh probe budget and success streak.
    fn transition_to_half_open(&self, state: &mut CircuitBreakerState) {
        state.state = CircuitState::HalfOpen;
        state.stats.last_state_change = Some(Instant::now());
        state.stats.consecutive_successes = 0;
        state.half_open_requests = 0;
    }
}
impl Default for CircuitBreaker {
fn default() -> Self {
Self::new()
}
}
/// Runs `operation` under the protection of `circuit_breaker`.
///
/// The breaker is consulted first. If it rejects the request, `operation` is dropped without
/// being polled and the rejection is returned converted into `E`. Otherwise the future is
/// awaited, its outcome is reported to the breaker, and the outcome is returned unchanged.
///
/// NOTE(review): if the returned future is dropped while `operation` is still pending, no
/// outcome is recorded for the admitted request.
///
/// # Errors
///
/// Returns `E::from(NetError)` when the breaker rejects the request. Otherwise returns
/// whatever error `operation` produced.
pub async fn with_circuit_breaker<F, T, E>(
    circuit_breaker: &CircuitBreaker,
    operation: F,
) -> Result<T, E>
where
    F: std::future::Future<Output = Result<T, E>>,
    E: From<NetError>,
{
    if let Err(rejection) = circuit_breaker.is_request_allowed() {
        return Err(E::from(rejection));
    }

    let outcome = operation.await;
    if outcome.is_ok() {
        circuit_breaker.record_success();
    } else {
        circuit_breaker.record_failure();
    }
    outcome
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// Open timeout used by tests that need to reach the half-open state quickly.
    const SHORT_TIMEOUT: Duration = Duration::from_millis(100);

    /// Asks for admission (ignoring rejections) and then reports a failure, `count` times.
    fn fail_times(cb: &CircuitBreaker, count: usize) {
        for _ in 0..count {
            cb.is_request_allowed().ok();
            cb.record_failure();
        }
    }

    /// Sleeps comfortably past `SHORT_TIMEOUT` so the next admission check can half-open.
    fn wait_past_short_timeout() {
        thread::sleep(Duration::from_millis(150));
    }

    /// Config with the given failure threshold and `SHORT_TIMEOUT`; everything else default.
    fn short_timeout_config(failure_threshold: usize) -> CircuitBreakerConfig {
        CircuitBreakerConfig {
            failure_threshold,
            timeout: SHORT_TIMEOUT,
            ..Default::default()
        }
    }

    #[test]
    fn test_circuit_breaker_default() {
        assert_eq!(CircuitBreaker::new().state(), CircuitState::Closed);
    }

    #[test]
    fn test_circuit_breaker_closed_to_open() {
        let cb = CircuitBreaker::with_config(CircuitBreakerConfig {
            failure_threshold: 3,
            ..Default::default()
        });
        assert_eq!(cb.state(), CircuitState::Closed);

        // Every request is admitted while closed; the third failure trips the breaker.
        (0..3).for_each(|_| {
            assert!(cb.is_request_allowed().is_ok());
            cb.record_failure();
        });

        assert_eq!(cb.state(), CircuitState::Open);
        assert!(cb.is_request_allowed().is_err());
    }

    #[test]
    fn test_circuit_breaker_open_to_half_open() {
        let cb = CircuitBreaker::with_config(short_timeout_config(2));
        fail_times(&cb, 2);
        assert_eq!(cb.state(), CircuitState::Open);

        wait_past_short_timeout();
        assert!(cb.is_request_allowed().is_ok());
        assert_eq!(cb.state(), CircuitState::HalfOpen);
    }

    #[test]
    fn test_circuit_breaker_half_open_to_closed() {
        let cb = CircuitBreaker::with_config(CircuitBreakerConfig {
            success_threshold: 2,
            ..short_timeout_config(2)
        });
        fail_times(&cb, 2);
        assert_eq!(cb.state(), CircuitState::Open);

        wait_past_short_timeout();
        cb.is_request_allowed().ok();
        assert_eq!(cb.state(), CircuitState::HalfOpen);

        cb.record_success();
        cb.record_success();
        assert_eq!(cb.state(), CircuitState::Closed);
    }

    #[test]
    fn test_circuit_breaker_half_open_to_open() {
        let cb = CircuitBreaker::with_config(short_timeout_config(2));
        fail_times(&cb, 2);

        wait_past_short_timeout();
        cb.is_request_allowed().ok();
        assert_eq!(cb.state(), CircuitState::HalfOpen);

        // A single failed probe re-opens the circuit.
        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Open);
    }

    #[test]
    fn test_circuit_breaker_stats() {
        let cb = CircuitBreaker::new();
        cb.is_request_allowed().ok();
        cb.record_success();
        cb.is_request_allowed().ok();
        cb.record_failure();

        let CircuitBreakerStats {
            total_requests,
            total_successes,
            total_failures,
            ..
        } = cb.stats();
        assert_eq!(total_requests, 2);
        assert_eq!(total_successes, 1);
        assert_eq!(total_failures, 1);
    }

    #[test]
    fn test_circuit_breaker_reset() {
        let cb = CircuitBreaker::with_config(CircuitBreakerConfig {
            failure_threshold: 2,
            ..Default::default()
        });
        fail_times(&cb, 2);
        assert_eq!(cb.state(), CircuitState::Open);

        cb.reset();
        assert_eq!(cb.state(), CircuitState::Closed);
        assert!(cb.is_request_allowed().is_ok());
    }

    #[test]
    fn test_half_open_request_limit() {
        let cb = CircuitBreaker::with_config(CircuitBreakerConfig {
            half_open_max_requests: 2,
            ..short_timeout_config(1)
        });
        fail_times(&cb, 1);
        wait_past_short_timeout();

        // Two probes fit under the limit; the third is rejected.
        assert!(cb.is_request_allowed().is_ok());
        assert!(cb.is_request_allowed().is_ok());
        assert!(cb.is_request_allowed().is_err());
    }

    #[tokio::test]
    async fn test_with_circuit_breaker_success() {
        let cb = CircuitBreaker::new();
        let outcome = with_circuit_breaker(&cb, async { Ok::<i32, NetError>(42) }).await;
        assert_eq!(outcome.ok(), Some(42));
        assert_eq!(cb.stats().total_successes, 1);
    }

    #[tokio::test]
    async fn test_with_circuit_breaker_failure() {
        let cb = CircuitBreaker::new();
        let failing = async { Err::<i32, NetError>(NetError::Timeout("test".to_string())) };
        let outcome = with_circuit_breaker(&cb, failing).await;
        assert!(outcome.is_err());
        assert_eq!(cb.stats().total_failures, 1);
    }
}