use crate::{Error, Result};
use std::future::Future;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use tracing::{debug, info, warn};
/// Operating mode of a [`CircuitBreaker`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitState {
    /// Normal operation: every request is passed through to the wrapped operation.
    Closed,
    /// Tripped: requests are rejected with `Error::CircuitBreakerOpen` until
    /// `CircuitBreakerConfig::timeout` has elapsed since the most recent recorded failure.
    Open,
    /// Probing after the open timeout: a successful call closes the breaker and a
    /// recoverable failure reopens it.
    HalfOpen,
}
/// Tuning knobs for a [`CircuitBreaker`].
#[derive(Debug, Clone)]
pub struct CircuitBreakerConfig {
    /// Number of consecutive recoverable failures, recorded while closed, that trips the
    /// breaker open.
    pub failure_threshold: u32,
    /// How long the breaker stays open, measured from the most recent recorded failure,
    /// before a request is let through in the half-open state.
    pub timeout: Duration,
    /// Length of the half-open trial window.
    /// NOTE(review): see `CircuitBreaker::should_allow_request` for how (or whether) this
    /// is enforced.
    pub half_open_test_period: Duration,
    /// Starting delay for `CircuitBreaker::calculate_backoff`; doubled for each attempt.
    pub base_delay: Duration,
    /// Upper bound on any delay returned by `CircuitBreaker::calculate_backoff`.
    pub max_delay: Duration,
}
impl Default for CircuitBreakerConfig {
    /// Baseline tuning: trip after 5 consecutive failures, stay open for 30 s, a 10 s
    /// half-open test period, and retry backoff growing from 100 ms up to 1.6 s.
    fn default() -> Self {
        let base_delay = Duration::from_millis(100);
        Self {
            failure_threshold: 5,
            timeout: Duration::from_millis(30_000),
            half_open_test_period: Duration::from_millis(10_000),
            // 100 ms doubled four times.
            max_delay: base_delay * 16,
            base_delay,
        }
    }
}
/// Counters describing a [`CircuitBreaker`]'s history; all start at zero.
#[derive(Debug, Clone, Default)]
pub struct CircuitBreakerStats {
    /// Requests that were admitted and handed to the operation (rejections excluded).
    pub total_calls: u64,
    /// Admitted calls whose operation returned `Ok`.
    pub successful_calls: u64,
    /// Admitted calls that returned a recoverable error (per `Error::is_recoverable`);
    /// non-recoverable errors are counted neither here nor in `successful_calls`.
    pub failed_calls: u64,
    /// Requests refused by the breaker without running the operation.
    pub rejected_calls: u64,
    /// Current streak of recoverable failures. It is cleared by a success while closed or
    /// half-open (not while open) and by `CircuitBreaker::reset`.
    pub consecutive_failures: u32,
    /// Number of transitions into the open state (from closed or half-open).
    pub circuit_opened_count: u32,
}
/// Mutable state shared behind the breaker's lock.
struct CircuitBreakerState {
    /// Current operating mode.
    state: CircuitState,
    /// Cumulative counters reported by `CircuitBreaker::stats`.
    stats: CircuitBreakerStats,
    /// When the most recent recoverable failure was recorded; the open timeout is
    /// measured from here.
    last_failure_time: Option<Instant>,
    /// When the current half-open trial began; `None` outside the half-open state.
    half_open_started: Option<Instant>,
}
impl Default for CircuitBreakerState {
fn default() -> Self {
Self {
state: CircuitState::Closed,
stats: CircuitBreakerStats::default(),
last_failure_time: None,
half_open_started: None,
}
}
}
/// Circuit breaker that wraps fallible async operations (see [`CircuitBreaker::call`]) and
/// stops invoking them after repeated recoverable failures.
pub struct CircuitBreaker {
    /// Thresholds and timings supplied to [`CircuitBreaker::new`].
    config: CircuitBreakerConfig,
    /// Mode, timestamps and counters, behind a tokio read-write lock so every method can
    /// work through `&self`.
    state: Arc<RwLock<CircuitBreakerState>>,
}
impl CircuitBreaker {
    /// Creates a breaker in the [`CircuitState::Closed`] state with zeroed statistics.
    pub fn new(config: CircuitBreakerConfig) -> Self {
        info!(
            "Initializing circuit breaker: threshold={}, timeout={:?}",
            config.failure_threshold, config.timeout
        );
        Self {
            config,
            state: Arc::new(RwLock::new(CircuitBreakerState::default())),
        }
    }

    /// Runs `operation` through the breaker and records its outcome.
    ///
    /// The state lock is *not* held while the operation runs, so concurrent callers are
    /// not serialized. If the returned future is dropped before `operation` finishes,
    /// no outcome is recorded.
    ///
    /// # Errors
    ///
    /// Returns [`Error::CircuitBreakerOpen`] without invoking `operation` when the breaker
    /// rejects the request. Otherwise it returns exactly what `operation` produced.
    pub async fn call<F, Fut, T>(&self, operation: F) -> Result<T>
    where
        F: FnOnce() -> Fut,
        Fut: Future<Output = Result<T>>,
    {
        let should_proceed = self.should_allow_request().await?;
        if !should_proceed {
            let mut state = self.state.write().await;
            state.stats.rejected_calls += 1;
            debug!("Circuit breaker rejecting request - circuit is open");
            return Err(Error::CircuitBreakerOpen);
        }
        {
            let mut state = self.state.write().await;
            state.stats.total_calls += 1;
        }
        let result = operation().await;
        self.on_result(&result).await;
        result
    }

    /// Decides whether a request may proceed, performing the Open -> HalfOpen transition
    /// once `config.timeout` has elapsed since the last recorded failure.
    ///
    /// While half-open, only one trial request is admitted per
    /// `config.half_open_test_period`. Further requests are rejected until that trial
    /// reports back (closing or reopening the breaker), or until the period elapses (e.g.
    /// the trial hung or was cancelled). A zero test period admits every half-open request.
    ///
    /// Never returns `Err`; the `Result` is kept for interface stability.
    async fn should_allow_request(&self) -> Result<bool> {
        let mut state = self.state.write().await;
        // Sample the clock only once the lock is held, so `now` cannot predate timestamps
        // written by whoever held the lock before us.
        let now = Instant::now();
        let allowed = match state.state {
            CircuitState::Closed => true,
            CircuitState::Open => {
                let timeout_elapsed = matches!(
                    state.last_failure_time,
                    Some(last) if now.saturating_duration_since(last) >= self.config.timeout
                );
                if timeout_elapsed {
                    info!("Circuit breaker transitioning to half-open state");
                    state.state = CircuitState::HalfOpen;
                    // This request is the first half-open trial.
                    state.half_open_started = Some(now);
                }
                timeout_elapsed
            }
            CircuitState::HalfOpen => {
                let trial_in_flight = matches!(
                    state.half_open_started,
                    Some(started)
                        if now.saturating_duration_since(started)
                            < self.config.half_open_test_period
                );
                if !trial_in_flight {
                    debug!("Circuit breaker admitting half-open trial request");
                    state.half_open_started = Some(now);
                }
                !trial_in_flight
            }
        };
        Ok(allowed)
    }

    /// Updates counters and state from the outcome of an admitted call.
    ///
    /// Only recoverable errors (per `Error::is_recoverable`) count as failures.
    async fn on_result<T>(&self, result: &Result<T>) {
        let mut state = self.state.write().await;
        match result {
            Ok(_) => {
                state.stats.successful_calls += 1;
                self.on_success(&mut state);
            }
            Err(e) if e.is_recoverable() => {
                state.stats.failed_calls += 1;
                self.on_failure(&mut state);
            }
            Err(e) => {
                debug!("Non-recoverable error, not affecting circuit: {}", e);
                // The trial gave no verdict on the downstream's health; free the
                // half-open slot so the next request can probe again immediately.
                if state.state == CircuitState::HalfOpen {
                    state.half_open_started = None;
                }
            }
        }
    }

    /// Handles a successful call: closes a half-open breaker, or clears the failure
    /// streak while closed.
    fn on_success(&self, state: &mut CircuitBreakerState) {
        match state.state {
            CircuitState::HalfOpen => {
                info!("Circuit breaker closing after successful recovery test");
                state.state = CircuitState::Closed;
                state.stats.consecutive_failures = 0;
                state.last_failure_time = None;
                state.half_open_started = None;
            }
            CircuitState::Closed => {
                if state.stats.consecutive_failures > 0 {
                    debug!(
                        "Resetting consecutive failures from {}",
                        state.stats.consecutive_failures
                    );
                    state.stats.consecutive_failures = 0;
                }
            }
            CircuitState::Open => {
                // Possible when a call admitted before the breaker tripped finishes late.
                warn!("Unexpected success in open state");
            }
        }
    }

    /// Records a recoverable failure. It trips a closed breaker once
    /// `config.failure_threshold` consecutive failures accumulate, and reopens a half-open
    /// breaker immediately.
    fn on_failure(&self, state: &mut CircuitBreakerState) {
        state.stats.consecutive_failures += 1;
        // Also restarts the open timeout when the breaker (re)opens below.
        state.last_failure_time = Some(Instant::now());
        debug!(
            "Circuit breaker recorded failure {}/{}",
            state.stats.consecutive_failures, self.config.failure_threshold
        );
        match state.state {
            CircuitState::Closed => {
                if state.stats.consecutive_failures >= self.config.failure_threshold {
                    warn!(
                        "Circuit breaker opening after {} consecutive failures",
                        state.stats.consecutive_failures
                    );
                    state.state = CircuitState::Open;
                    state.stats.circuit_opened_count += 1;
                }
            }
            CircuitState::HalfOpen => {
                warn!("Circuit breaker reopening after failure in half-open state");
                state.state = CircuitState::Open;
                state.stats.circuit_opened_count += 1;
                state.half_open_started = None;
            }
            CircuitState::Open => {
                // Already open; the refreshed `last_failure_time` extends the open period.
            }
        }
    }

    /// Returns the current operating mode.
    pub async fn state(&self) -> CircuitState {
        let state = self.state.read().await;
        state.state
    }

    /// Returns a snapshot copy of the breaker's counters.
    pub async fn stats(&self) -> CircuitBreakerStats {
        let state = self.state.read().await;
        state.stats.clone()
    }

    /// Exponential backoff delay `base_delay * 2^attempt`, capped at `max_delay`
    /// (`attempt == 0` yields `base_delay`, or `max_delay` if that is smaller).
    ///
    /// Doubling stops as soon as the cap is reached and saturates instead of overflowing,
    /// so any `attempt` value is safe. Sub-millisecond precision of `base_delay` is kept.
    #[must_use]
    pub fn calculate_backoff(&self, attempt: u32) -> Duration {
        let max = self.config.max_delay;
        let mut delay = self.config.base_delay;
        // A non-zero delay reaches any cap (at worst `Duration::MAX`) within ~100
        // doublings, and a zero delay never grows, so this loop ends quickly.
        for _ in 0..attempt {
            if delay >= max || delay.is_zero() {
                break;
            }
            delay = delay.saturating_mul(2);
        }
        delay.min(max)
    }

    /// Forces the breaker back to [`CircuitState::Closed`], clearing the failure streak
    /// and timestamps. Cumulative counters such as `total_calls` and
    /// `circuit_opened_count` are kept.
    pub async fn reset(&self) {
        let mut state = self.state.write().await;
        info!("Circuit breaker reset to closed state");
        state.state = CircuitState::Closed;
        state.stats.consecutive_failures = 0;
        state.last_failure_time = None;
        state.half_open_started = None;
    }
}