use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
/// Operating mode of a circuit breaker.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CircuitState {
    /// Calls pass through; consecutive failures are counted towards tripping the breaker.
    Closed,
    /// Calls are rejected until the configured timeout has elapsed.
    Open,
    /// Trial calls pass through; enough successes close the breaker, any failure reopens it.
    HalfOpen,
}
/// Tuning knobs for a circuit breaker.
#[derive(Clone, Debug)]
pub struct CircuitBreakerConfig {
    /// Consecutive failures (while closed) after which the breaker opens.
    pub failure_threshold: u32,
    /// How long the breaker stays open before it admits trial calls.
    pub timeout: Duration,
    /// Successful trial calls (while half-open) required to close the breaker again.
    pub success_threshold: u32,
}

impl Default for CircuitBreakerConfig {
    /// Five failures to open, a 60-second cool-down, and two successes to close.
    #[inline]
    fn default() -> Self {
        let failure_threshold = 5;
        let timeout = Duration::from_secs(60);
        let success_threshold = 2;
        Self {
            failure_threshold,
            timeout,
            success_threshold,
        }
    }
}
/// An async circuit breaker that guards calls to an unreliable operation.
///
/// All methods take `&self`; the mutable bookkeeping lives behind a `tokio` `RwLock`.
// `Debug` is derived so the breaker can be embedded in other `Debug` types; every field
// (including the lock and its contents) already implements `Debug`.
#[derive(Debug)]
pub struct CircuitBreaker {
    /// Identifier attached to the breaker's log events.
    name: String,
    /// Thresholds and cool-down driving the state machine.
    config: CircuitBreakerConfig,
    /// Current state plus failure/success counters.
    state: Arc<RwLock<CircuitBreakerState>>,
}
/// Mutable bookkeeping kept behind the breaker's lock.
#[derive(Debug)]
struct CircuitBreakerState {
    /// Current mode of the breaker.
    circuit_state: CircuitState,
    /// Consecutive failures recorded while closed (cleared by a success while closed).
    failure_count: u32,
    /// Successful trial calls recorded while half-open.
    success_count: u32,
    /// Time of the failure that opened the breaker, refreshed by failures while open;
    /// `None` until the breaker first opens and after a manual reset.
    last_failure_time: Option<Instant>,
}
impl CircuitBreaker {
    /// Creates a breaker named `name` that starts closed with all counters zeroed.
    #[must_use]
    pub fn new(name: impl Into<String>, config: CircuitBreakerConfig) -> Self {
        Self {
            name: name.into(),
            config,
            state: Arc::new(RwLock::new(CircuitBreakerState {
                circuit_state: CircuitState::Closed,
                failure_count: 0,
                success_count: 0,
                last_failure_time: None,
            })),
        }
    }

    /// Returns a snapshot of the breaker's current state.
    ///
    /// An open breaker whose timeout has already elapsed still reports
    /// [`CircuitState::Open`] here; the move to half-open only happens on the next
    /// [`call`](Self::call).
    #[must_use]
    pub async fn state(&self) -> CircuitState {
        self.state.read().await.circuit_state
    }

    /// Returns the name this breaker was created with.
    #[must_use]
    #[inline]
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Runs `f` through the breaker and records its outcome.
    ///
    /// If the breaker is open and its timeout has not elapsed, `f` is never invoked.
    /// Otherwise `f` is awaited (without holding the state lock) and its success or
    /// failure updates the breaker.
    ///
    /// # Errors
    ///
    /// - [`CircuitBreakerError::CircuitOpen`] if the call was rejected without running `f`.
    /// - [`CircuitBreakerError::CallFailed`] wrapping the error returned by `f`.
    pub async fn call<F, Fut, T, E>(&self, f: F) -> Result<T, CircuitBreakerError<E>>
    where
        F: FnOnce() -> Fut,
        Fut: std::future::Future<Output = Result<T, E>>,
    {
        // The turbofish is required: `?` alone cannot infer the error type parameter.
        self.check_state::<E>().await?;
        match f().await {
            Ok(result) => {
                self.on_success().await;
                Ok(result)
            }
            Err(e) => {
                self.on_failure().await;
                Err(CircuitBreakerError::CallFailed(e))
            }
        }
    }

    /// Decides whether a call may proceed, moving an open breaker to half-open once
    /// its timeout has elapsed.
    ///
    /// Closed and half-open breakers admit every call; there is no limit on how many
    /// trial calls may be in flight while half-open.
    async fn check_state<E>(&self) -> Result<(), CircuitBreakerError<E>> {
        let mut state = self.state.write().await;
        match state.circuit_state {
            CircuitState::Closed | CircuitState::HalfOpen => Ok(()),
            CircuitState::Open => {
                // The timestamp is always recorded when the breaker opens; if it were ever
                // missing, stay open rather than letting traffic through.
                let cooled_down = matches!(
                    state.last_failure_time,
                    Some(last_failure) if last_failure.elapsed() >= self.config.timeout
                );
                if !cooled_down {
                    return Err(CircuitBreakerError::CircuitOpen);
                }
                state.circuit_state = CircuitState::HalfOpen;
                state.success_count = 0;
                tracing::info!(
                    circuit_breaker = %self.name,
                    "Circuit breaker transitioning to half-open"
                );
                Ok(())
            }
        }
    }

    /// Records a successful call.
    ///
    /// While closed, this clears the failure streak. While half-open, it counts towards
    /// `success_threshold` and closes the breaker once that threshold is reached.
    /// Successes reported while open are ignored. For example, they can come from calls
    /// admitted before another call tripped the breaker.
    async fn on_success(&self) {
        let mut state = self.state.write().await;
        match state.circuit_state {
            CircuitState::Closed => {
                state.failure_count = 0;
            }
            CircuitState::HalfOpen => {
                state.success_count += 1;
                if state.success_count >= self.config.success_threshold {
                    state.circuit_state = CircuitState::Closed;
                    state.failure_count = 0;
                    state.success_count = 0;
                    tracing::info!(
                        circuit_breaker = %self.name,
                        "Circuit breaker closed after recovery"
                    );
                }
            }
            CircuitState::Open => {}
        }
    }

    /// Records a failed call.
    ///
    /// - While closed, it extends the failure streak and opens the breaker once
    ///   `failure_threshold` is reached.
    /// - While half-open, any failure reopens the breaker immediately.
    /// - While already open, it refreshes the failure timestamp, which postpones the
    ///   next half-open attempt.
    async fn on_failure(&self) {
        let mut state = self.state.write().await;
        match state.circuit_state {
            CircuitState::Closed => {
                state.failure_count += 1;
                if state.failure_count >= self.config.failure_threshold {
                    state.circuit_state = CircuitState::Open;
                    state.last_failure_time = Some(Instant::now());
                    tracing::warn!(
                        circuit_breaker = %self.name,
                        failures = state.failure_count,
                        "Circuit breaker opened due to failures"
                    );
                }
            }
            CircuitState::HalfOpen => {
                state.circuit_state = CircuitState::Open;
                state.last_failure_time = Some(Instant::now());
                state.success_count = 0;
                tracing::warn!(
                    circuit_breaker = %self.name,
                    "Circuit breaker reopened after failed recovery attempt"
                );
            }
            CircuitState::Open => {
                state.last_failure_time = Some(Instant::now());
            }
        }
    }

    /// Forces the breaker closed and clears both counters and the failure timestamp.
    pub async fn reset(&self) {
        let mut state = self.state.write().await;
        state.circuit_state = CircuitState::Closed;
        state.failure_count = 0;
        state.success_count = 0;
        state.last_failure_time = None;
        tracing::info!(circuit_breaker = %self.name, "Circuit breaker manually reset");
    }
}
/// Error returned by a call routed through a circuit breaker.
#[derive(Debug, thiserror::Error)]
pub enum CircuitBreakerError<E> {
    /// The breaker was open, so the wrapped operation was not run.
    #[error("Circuit breaker is open")]
    CircuitOpen,

    /// The wrapped operation ran and returned this error.
    #[error("Call failed: {0}")]
    CallFailed(E),
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a config with `success_threshold: 2`, the value every test here relies on.
    fn config(failure_threshold: u32, timeout: Duration) -> CircuitBreakerConfig {
        CircuitBreakerConfig {
            failure_threshold,
            timeout,
            success_threshold: 2,
        }
    }

    /// Drives `n` failing calls through `breaker`.
    async fn fail_times(breaker: &CircuitBreaker, n: usize) {
        for _ in 0..n {
            let _ = breaker.call(|| async { Err::<(), _>("error") }).await;
        }
    }

    #[test]
    fn test_default_config() {
        let config = CircuitBreakerConfig::default();
        assert_eq!(config.failure_threshold, 5);
        assert_eq!(config.timeout, Duration::from_secs(60));
        assert_eq!(config.success_threshold, 2);
    }

    #[tokio::test]
    async fn test_circuit_breaker_closed_state() {
        let breaker = CircuitBreaker::new("test", config(3, Duration::from_secs(1)));
        assert_eq!(breaker.name(), "test");
        assert_eq!(breaker.state().await, CircuitState::Closed);
        let result = breaker.call(|| async { Ok::<_, String>("success") }).await;
        assert!(matches!(result, Ok("success")));
        assert_eq!(breaker.state().await, CircuitState::Closed);
    }

    #[tokio::test]
    async fn test_circuit_breaker_opens_after_failures() {
        let breaker = CircuitBreaker::new("test", config(3, Duration::from_secs(1)));
        fail_times(&breaker, 2).await;
        // One short of the threshold: still closed.
        assert_eq!(breaker.state().await, CircuitState::Closed);
        fail_times(&breaker, 1).await;
        assert_eq!(breaker.state().await, CircuitState::Open);
    }

    #[tokio::test]
    async fn test_success_resets_failure_streak() {
        let breaker = CircuitBreaker::new("test", config(2, Duration::from_secs(10)));
        fail_times(&breaker, 1).await;
        let _ = breaker.call(|| async { Ok::<_, String>(()) }).await;
        fail_times(&breaker, 1).await;
        // Failures are not consecutive, so the threshold of 2 was never reached.
        assert_eq!(breaker.state().await, CircuitState::Closed);
    }

    #[tokio::test]
    async fn test_circuit_breaker_rejects_when_open() {
        let breaker = CircuitBreaker::new("test", config(2, Duration::from_secs(10)));
        fail_times(&breaker, 2).await;
        let invoked = std::cell::Cell::new(false);
        let flag = &invoked;
        let result = breaker
            .call(|| async move {
                flag.set(true);
                Ok::<_, String>("success")
            })
            .await;
        assert!(matches!(result, Err(CircuitBreakerError::CircuitOpen)));
        // A rejected call must not run the wrapped operation at all.
        assert!(!invoked.get());
    }

    #[tokio::test]
    async fn test_circuit_breaker_half_open() {
        let breaker = CircuitBreaker::new("test", config(2, Duration::from_millis(100)));
        fail_times(&breaker, 2).await;
        assert_eq!(breaker.state().await, CircuitState::Open);
        tokio::time::sleep(Duration::from_millis(150)).await;
        let result = breaker.call(|| async { Ok::<_, String>("success") }).await;
        assert!(result.is_ok());
        // One success is below `success_threshold` (2), so the breaker must still be probing.
        assert_eq!(breaker.state().await, CircuitState::HalfOpen);
    }

    #[tokio::test]
    async fn test_circuit_breaker_recovery() {
        let breaker = CircuitBreaker::new("test", config(2, Duration::from_millis(100)));
        fail_times(&breaker, 2).await;
        tokio::time::sleep(Duration::from_millis(150)).await;
        for _ in 0..2 {
            let _ = breaker.call(|| async { Ok::<_, String>("success") }).await;
        }
        assert_eq!(breaker.state().await, CircuitState::Closed);
    }

    #[tokio::test]
    async fn test_circuit_breaker_reopens_on_half_open_failure() {
        let breaker = CircuitBreaker::new("test", config(2, Duration::from_millis(100)));
        fail_times(&breaker, 2).await;
        tokio::time::sleep(Duration::from_millis(150)).await;
        let result = breaker.call(|| async { Err::<(), _>("still down") }).await;
        assert!(matches!(
            result,
            Err(CircuitBreakerError::CallFailed("still down"))
        ));
        assert_eq!(breaker.state().await, CircuitState::Open);
    }

    #[tokio::test]
    async fn test_circuit_breaker_reset() {
        let breaker = CircuitBreaker::new("test", config(2, Duration::from_secs(10)));
        fail_times(&breaker, 2).await;
        assert_eq!(breaker.state().await, CircuitState::Open);
        breaker.reset().await;
        assert_eq!(breaker.state().await, CircuitState::Closed);
    }
}