use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use tracing::{debug, warn};
/// State reported by a circuit breaker.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitState {
    /// Requests are let through and failures are counted.
    Closed,
    /// Requests are rejected until the recovery timeout has elapsed.
    Open,
    /// Requests are let through again as probes; successes are counted
    /// towards closing the circuit and a failure re-opens it.
    HalfOpen,
}

impl std::fmt::Display for CircuitState {
    /// Writes the upper-case label of the state: `CLOSED`, `OPEN` or `HALF-OPEN`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match *self {
            Self::Closed => "CLOSED",
            Self::Open => "OPEN",
            Self::HalfOpen => "HALF-OPEN",
        };
        f.write_str(label)
    }
}
/// Tunable settings for a single circuit breaker.
#[derive(Debug, Clone)]
pub struct CircuitBreakerConfig {
    /// Number of counted failures at which a closed circuit is opened.
    pub failure_threshold: u32,
    /// How long an open circuit rejects requests before it is treated as half-open.
    pub recovery_timeout: Duration,
    /// Number of half-open successes required to close the circuit again.
    pub success_threshold: u32,
    /// Longest gap since the previous failure for which the failure streak
    /// keeps counting; a longer gap restarts the count.
    pub failure_window: Duration,
    /// Name used in log messages, metrics and rejection errors.
    pub name: String,
}

impl Default for CircuitBreakerConfig {
    /// Defaults: name `"default"`, 5 failures to open, 30 s recovery timeout,
    /// 2 successes to close, 60 s failure window.
    fn default() -> Self {
        Self {
            name: String::from("default"),
            failure_threshold: 5,
            success_threshold: 2,
            recovery_timeout: Duration::from_secs(30),
            failure_window: Duration::from_secs(60),
        }
    }
}

impl CircuitBreakerConfig {
    /// Starts from the default settings with only the name replaced.
    pub fn with_name(name: impl Into<String>) -> Self {
        let name = name.into();
        Self {
            name,
            ..Self::default()
        }
    }

    /// Returns the config with `failure_threshold` replaced.
    pub fn failure_threshold(self, threshold: u32) -> Self {
        Self {
            failure_threshold: threshold,
            ..self
        }
    }

    /// Returns the config with `recovery_timeout` replaced.
    pub fn recovery_timeout(self, timeout: Duration) -> Self {
        Self {
            recovery_timeout: timeout,
            ..self
        }
    }

    /// Returns the config with `success_threshold` replaced.
    pub fn success_threshold(self, threshold: u32) -> Self {
        Self {
            success_threshold: threshold,
            ..self
        }
    }

    /// Returns the config with `failure_window` replaced.
    pub fn failure_window(self, window: Duration) -> Self {
        Self {
            failure_window: window,
            ..self
        }
    }
}
/// Error describing a request that was rejected because the circuit is open.
#[derive(Debug, Clone)]
pub struct CircuitOpenError {
    /// Name of the breaker that rejected the request.
    pub service_name: String,
    /// Remaining time before the breaker starts letting requests through again.
    pub time_until_retry: Duration,
}

impl std::fmt::Display for CircuitOpenError {
    /// Formats the rejection as a single line naming the service and the
    /// remaining wait (the duration uses its `Debug` form, e.g. `5s`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            service_name,
            time_until_retry,
        } = self;
        write!(
            f,
            "Circuit breaker open for '{}': service unavailable, retry in {:?}",
            service_name, time_until_retry
        )
    }
}

impl std::error::Error for CircuitOpenError {}
/// Mutable bookkeeping of one breaker; `CircuitBreaker` keeps it behind an `RwLock`.
struct CircuitBreakerState {
/// Stored state. Half-open is never stored by the code in this file: it is
/// derived by `effective_state` from a stored `Open` whose recovery timeout
/// has elapsed.
state: CircuitState,
/// Failures counted while closed; cleared by a success in the closed state
/// and whenever the circuit is closed.
failure_count: u32,
/// Successes counted while half-open.
success_count: u32,
/// When the most recent failure was counted in the closed state, if any.
last_failure_time: Option<Instant>,
/// When the stored state was last (re)set; initialised at construction.
last_state_change: Instant,
}
/// Handle to a circuit breaker.
///
/// Every field is an `Arc`, so clones share the same configuration, state
/// and counters.
#[derive(Clone)]
pub struct CircuitBreaker {
/// Immutable settings supplied at construction.
config: Arc<CircuitBreakerConfig>,
/// Lock-protected state machine bookkeeping.
state: Arc<RwLock<CircuitBreakerState>>,
/// Incremented once per `call` invocation.
total_calls: Arc<AtomicU64>,
/// Incremented once per `record_failure` invocation.
total_failures: Arc<AtomicU64>,
/// Incremented each time `check` rejects a request.
total_rejections: Arc<AtomicU64>,
}
impl CircuitBreaker {
/// Creates a breaker in the closed state with all counters and metrics at zero.
pub fn new(config: CircuitBreakerConfig) -> Self {
    let shared_config = Arc::new(config);
    let initial = CircuitBreakerState {
        state: CircuitState::Closed,
        failure_count: 0,
        success_count: 0,
        last_failure_time: None,
        last_state_change: Instant::now(),
    };
    // Each metric gets its own independently shared counter.
    let zeroed = || Arc::new(AtomicU64::new(0));
    Self {
        config: shared_config,
        state: Arc::new(RwLock::new(initial)),
        total_calls: zeroed(),
        total_failures: zeroed(),
        total_rejections: zeroed(),
    }
}
/// Returns the breaker's current effective state.
///
/// Takes the state lock for reading; an open circuit whose recovery timeout
/// has elapsed is reported as half-open.
pub async fn state(&self) -> CircuitState {
    let snapshot = self.state.read().await;
    let current = self.effective_state(&snapshot);
    current
}
/// Returns a copy of the breaker's name and counters.
///
/// The counters are read one after another with relaxed ordering, so the
/// three values are not taken as one atomic snapshot.
pub fn metrics(&self) -> CircuitBreakerMetrics {
    let name = self.config.name.clone();
    let total_calls = self.total_calls.load(Ordering::Relaxed);
    let total_failures = self.total_failures.load(Ordering::Relaxed);
    let total_rejections = self.total_rejections.load(Ordering::Relaxed);
    CircuitBreakerMetrics {
        name,
        total_calls,
        total_failures,
        total_rejections,
    }
}
/// Decides whether a request may proceed right now.
///
/// Returns `Ok(())` while the circuit is closed or half-open.
///
/// # Errors
///
/// Returns [`CircuitOpenError`] while the circuit is open. The error carries
/// the breaker name and the time left until the recovery timeout elapses.
/// Each rejection increments the rejection counter.
pub async fn check(&self) -> Result<(), CircuitOpenError> {
    let guard = self.state.read().await;
    if self.effective_state(&guard) != CircuitState::Open {
        return Ok(());
    }

    self.total_rejections.fetch_add(1, Ordering::Relaxed);
    // Saturating: the timeout may elapse between the state check and this read.
    let open_for = guard.last_state_change.elapsed();
    let remaining = self.config.recovery_timeout.saturating_sub(open_for);
    Err(CircuitOpenError {
        service_name: self.config.name.clone(),
        time_until_retry: remaining,
    })
}
/// Runs `operation` under the protection of the breaker.
///
/// The call is counted, then `check` decides whether it may proceed. If it
/// may, `operation` is awaited and its outcome is recorded through
/// `record_success` or `record_failure`.
///
/// # Errors
///
/// * [`CircuitBreakerError::Open`] when the circuit is open; `operation` is
///   then dropped without being awaited.
/// * [`CircuitBreakerError::ServiceError`] wrapping the error returned by
///   `operation`.
pub async fn call<F, T, E>(&self, operation: F) -> Result<T, CircuitBreakerError<E>>
where
    F: std::future::Future<Output = Result<T, E>>,
    E: std::error::Error,
{
    self.total_calls.fetch_add(1, Ordering::Relaxed);
    if let Err(rejection) = self.check().await {
        return Err(CircuitBreakerError::Open(rejection));
    }

    let outcome = operation.await;
    // Only a bool is kept across the awaits below, not a borrow of `outcome`.
    let succeeded = outcome.is_ok();
    if succeeded {
        self.record_success().await;
    } else {
        self.record_failure().await;
    }
    outcome.map_err(CircuitBreakerError::ServiceError)
}
/// Records one successful request.
///
/// * Closed: clears the failure streak.
/// * Half-open: counts the success and closes the circuit once
///   `success_threshold` successes have been counted.
/// * Open: ignored.
pub async fn record_success(&self) {
    let mut guard = self.state.write().await;
    let name = &self.config.name;
    match self.effective_state(&guard) {
        // Nothing is tracked while the circuit is still rejecting requests.
        CircuitState::Open => {}
        CircuitState::Closed => {
            if guard.failure_count != 0 {
                debug!(
                    "Circuit breaker '{}': success in closed state, resetting failure count",
                    name
                );
            }
            guard.failure_count = 0;
        }
        CircuitState::HalfOpen => {
            guard.success_count += 1;
            let successes = guard.success_count;
            let required = self.config.success_threshold;
            debug!(
                "Circuit breaker '{}': success in half-open state ({}/{})",
                name, successes, required
            );
            if successes < required {
                return;
            }
            debug!(
                "Circuit breaker '{}': closing circuit after {} successful requests",
                name, successes
            );
            guard.state = CircuitState::Closed;
            guard.failure_count = 0;
            guard.success_count = 0;
            guard.last_state_change = Instant::now();
        }
    }
}
/// Records one failed request; the failure metric is incremented in every state.
///
/// * Closed: extends the failure streak and opens the circuit once
///   `failure_threshold` is reached. The streak restarts if the previous
///   failure is older than `failure_window`.
/// * Half-open: re-opens the circuit immediately.
/// * Open: no state change.
pub async fn record_failure(&self) {
    self.total_failures.fetch_add(1, Ordering::Relaxed);
    let mut guard = self.state.write().await;
    let name = &self.config.name;
    match self.effective_state(&guard) {
        CircuitState::Open => {}
        CircuitState::HalfOpen => {
            warn!(
                "Circuit breaker '{}': failure in half-open state, re-opening circuit",
                name
            );
            guard.state = CircuitState::Open;
            guard.success_count = 0;
            guard.last_state_change = Instant::now();
        }
        CircuitState::Closed => {
            let window = self.config.failure_window;
            // A previous failure outside the window no longer counts towards the streak.
            let streak_is_stale =
                matches!(guard.last_failure_time, Some(previous) if previous.elapsed() > window);
            if streak_is_stale {
                guard.failure_count = 0;
            }
            guard.failure_count += 1;
            guard.last_failure_time = Some(Instant::now());

            let failures = guard.failure_count;
            let limit = self.config.failure_threshold;
            debug!(
                "Circuit breaker '{}': failure in closed state ({}/{})",
                name, failures, limit
            );
            if failures >= limit {
                warn!(
                    "Circuit breaker '{}': opening circuit after {} failures",
                    name, failures
                );
                guard.state = CircuitState::Open;
                guard.last_state_change = Instant::now();
            }
        }
    }
}
/// Derives the state callers should observe from the stored state.
///
/// A stored `Open` is reported as `HalfOpen` once `recovery_timeout` has
/// elapsed since the last state change; every other stored state is
/// returned unchanged. Nothing is written back to `state`.
fn effective_state(&self, state: &CircuitBreakerState) -> CircuitState {
    let stored = state.state;
    if stored != CircuitState::Open {
        return stored;
    }

    let open_for = state.last_state_change.elapsed();
    if open_for < self.config.recovery_timeout {
        CircuitState::Open
    } else {
        CircuitState::HalfOpen
    }
}
pub async fn force_close(&self) {
let mut state = self.state.write().await;
warn!("Circuit breaker '{}': manually closing circuit", self.config.name);
state.state = CircuitState::Closed;
state.failure_count = 0;
state.success_count = 0;
state.last_state_change = Instant::now();
}
/// Opens the circuit regardless of its current state and restarts the
/// recovery timeout from now. A warning is logged.
///
/// The circuit still becomes half-open once `recovery_timeout` has elapsed,
/// exactly as after an automatic open.
pub async fn force_open(&self) {
    let mut state = self.state.write().await;
    warn!("Circuit breaker '{}': manually opening circuit", self.config.name);
    state.state = CircuitState::Open;
    // Discard successes counted in an interrupted half-open phase. Otherwise
    // they would carry over into the next half-open phase and the circuit
    // would close after fewer than `success_threshold` probes. The automatic
    // re-open in `record_failure` resets the counter the same way.
    state.success_count = 0;
    state.last_state_change = Instant::now();
}
}
/// Error returned by `CircuitBreaker::call`.
#[derive(Debug)]
pub enum CircuitBreakerError<E> {
/// The breaker rejected the call before the operation was awaited.
Open(CircuitOpenError),
/// The operation was awaited and returned this error.
ServiceError(E),
}
impl<E: std::fmt::Display> std::fmt::Display for CircuitBreakerError<E> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
CircuitBreakerError::Open(e) => write!(f, "{}", e),
CircuitBreakerError::ServiceError(e) => write!(f, "{}", e),
}
}
}
impl<E: std::error::Error + 'static> std::error::Error for CircuitBreakerError<E> {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
CircuitBreakerError::Open(e) => Some(e),
CircuitBreakerError::ServiceError(e) => Some(e),
}
}
}
/// Copy of a breaker's counters as produced by `CircuitBreaker::metrics`.
#[derive(Debug, Clone)]
pub struct CircuitBreakerMetrics {
/// Name taken from the breaker's configuration.
pub name: String,
/// Number of `CircuitBreaker::call` invocations, including rejected ones.
pub total_calls: u64,
/// Number of `record_failure` invocations.
pub total_failures: u64,
/// Number of requests rejected by `check` while the circuit was open.
pub total_rejections: u64,
}
/// Collection of circuit breakers keyed by name.
///
/// The map sits behind an `Arc<RwLock<..>>`, so clones of the registry share
/// the same set of breakers.
#[derive(Clone, Default)]
pub struct CircuitBreakerRegistry {
/// Registered breakers by name.
breakers: Arc<RwLock<std::collections::HashMap<String, CircuitBreaker>>>,
}
impl CircuitBreakerRegistry {
pub fn new() -> Self {
Self::default()
}
/// Stores `breaker` under `name`, replacing any breaker already registered
/// under that name.
pub async fn register(&self, name: impl Into<String>, breaker: CircuitBreaker) {
    let key = name.into();
    self.breakers.write().await.insert(key, breaker);
}
/// Looks up the breaker registered under `name`.
///
/// Returns a clone of the handle, which shares state with the registered
/// breaker, or `None` if nothing is registered under that name.
pub async fn get(&self, name: &str) -> Option<CircuitBreaker> {
    self.breakers.read().await.get(name).cloned()
}
/// Returns the breaker registered under `name`, creating and registering one
/// with the default settings (apart from the name) if none exists.
pub async fn get_or_create(&self, name: impl Into<String>) -> CircuitBreaker {
    let key = name.into();

    // Common case: the breaker already exists and a read lock is enough.
    {
        let registered = self.breakers.read().await;
        if let Some(existing) = registered.get(&key) {
            return existing.clone();
        }
    }

    // The read lock was released above, so another task may have inserted the
    // same name in the meantime; `entry` keeps whichever breaker is already
    // there.
    let fresh = CircuitBreaker::new(CircuitBreakerConfig::with_name(&key));
    let mut registered = self.breakers.write().await;
    registered.entry(key).or_insert(fresh).clone()
}
/// Collects the metrics of every registered breaker, in the map's iteration order.
pub async fn all_metrics(&self) -> Vec<CircuitBreakerMetrics> {
    let registered = self.breakers.read().await;
    let mut collected = Vec::with_capacity(registered.len());
    for breaker in registered.values() {
        collected.push(breaker.metrics());
    }
    collected
}
pub async fn all_states(&self) -> Vec<(String, CircuitState)> {
let breakers = self.breakers.read().await;
let mut states = Vec::new();
for (name, breaker) in breakers.iter() {
let state = breaker.state().await;
states.push((name.clone(), state));
}
states
}
}
pub mod presets {
use super::*;
pub fn redis() -> CircuitBreakerConfig {
CircuitBreakerConfig::with_name("redis")
.failure_threshold(3)
.recovery_timeout(Duration::from_secs(10))
.success_threshold(2)
.failure_window(Duration::from_secs(30))
}
pub fn s3() -> CircuitBreakerConfig {
CircuitBreakerConfig::with_name("s3")
.failure_threshold(5)
.recovery_timeout(Duration::from_secs(30))
.success_threshold(2)
.failure_window(Duration::from_secs(60))
}
pub fn email() -> CircuitBreakerConfig {
CircuitBreakerConfig::with_name("email")
.failure_threshold(3)
.recovery_timeout(Duration::from_secs(60))
.success_threshold(1)
.failure_window(Duration::from_secs(120))
}
pub fn database() -> CircuitBreakerConfig {
CircuitBreakerConfig::with_name("database")
.failure_threshold(10)
.recovery_timeout(Duration::from_secs(15))
.success_threshold(3)
.failure_window(Duration::from_secs(60))
}
}
// Unit tests for the breaker state machine, the metrics and the registry.
// The half-open tests rely on real elapsed time: they configure a 10 ms
// recovery timeout and sleep for 20 ms.
#[cfg(test)]
mod tests {
use super::*;
// A freshly constructed breaker reports the closed state.
#[tokio::test]
async fn test_circuit_starts_closed() {
let breaker = CircuitBreaker::new(CircuitBreakerConfig::with_name("test"));
assert_eq!(breaker.state().await, CircuitState::Closed);
}
// With a threshold of 3, the circuit stays closed after two failures and
// opens on the third.
#[tokio::test]
async fn test_circuit_opens_after_failures() {
let config = CircuitBreakerConfig::with_name("test").failure_threshold(3);
let breaker = CircuitBreaker::new(config);
breaker.record_failure().await;
breaker.record_failure().await;
assert_eq!(breaker.state().await, CircuitState::Closed);
breaker.record_failure().await;
assert_eq!(breaker.state().await, CircuitState::Open);
}
// `check` returns an error while the circuit is open; the 60 s recovery
// timeout keeps it open for the duration of the test.
#[tokio::test]
async fn test_circuit_rejects_when_open() {
let config = CircuitBreakerConfig::with_name("test")
.failure_threshold(1)
.recovery_timeout(Duration::from_secs(60));
let breaker = CircuitBreaker::new(config);
breaker.record_failure().await;
assert!(breaker.check().await.is_err());
}
// An open circuit is reported as half-open once the recovery timeout has
// elapsed.
#[tokio::test]
async fn test_circuit_transitions_to_half_open() {
let config = CircuitBreakerConfig::with_name("test")
.failure_threshold(1)
.recovery_timeout(Duration::from_millis(10));
let breaker = CircuitBreaker::new(config);
breaker.record_failure().await;
assert_eq!(breaker.state().await, CircuitState::Open);
// Sleep past the 10 ms recovery timeout.
tokio::time::sleep(Duration::from_millis(20)).await;
assert_eq!(breaker.state().await, CircuitState::HalfOpen);
}
// In the half-open state the circuit closes only after `success_threshold`
// (here 2) successes.
#[tokio::test]
async fn test_circuit_closes_after_successes_in_half_open() {
let config = CircuitBreakerConfig::with_name("test")
.failure_threshold(1)
.recovery_timeout(Duration::from_millis(10))
.success_threshold(2);
let breaker = CircuitBreaker::new(config);
breaker.record_failure().await;
tokio::time::sleep(Duration::from_millis(20)).await;
assert_eq!(breaker.state().await, CircuitState::HalfOpen);
// The first success is not enough to close the circuit.
breaker.record_success().await;
assert_eq!(breaker.state().await, CircuitState::HalfOpen);
breaker.record_success().await;
assert_eq!(breaker.state().await, CircuitState::Closed);
}
// A failure in the half-open state re-opens the circuit.
#[tokio::test]
async fn test_failure_in_half_open_reopens_circuit() {
let config = CircuitBreakerConfig::with_name("test")
.failure_threshold(1)
.recovery_timeout(Duration::from_millis(10));
let breaker = CircuitBreaker::new(config);
breaker.record_failure().await;
tokio::time::sleep(Duration::from_millis(20)).await;
assert_eq!(breaker.state().await, CircuitState::HalfOpen);
breaker.record_failure().await;
assert_eq!(breaker.state().await, CircuitState::Open);
}
// A success in the closed state clears the failure streak, so three further
// failures are needed to open the circuit.
#[tokio::test]
async fn test_success_resets_failure_count() {
let config = CircuitBreakerConfig::with_name("test").failure_threshold(3);
let breaker = CircuitBreaker::new(config);
breaker.record_failure().await;
breaker.record_failure().await;
breaker.record_success().await;
breaker.record_failure().await;
breaker.record_failure().await;
assert_eq!(breaker.state().await, CircuitState::Closed);
breaker.record_failure().await;
assert_eq!(breaker.state().await, CircuitState::Open);
}
// `force_close` closes an open circuit.
#[tokio::test]
async fn test_force_close() {
let config = CircuitBreakerConfig::with_name("test").failure_threshold(1);
let breaker = CircuitBreaker::new(config);
breaker.record_failure().await;
assert_eq!(breaker.state().await, CircuitState::Open);
breaker.force_close().await;
assert_eq!(breaker.state().await, CircuitState::Closed);
}
// `force_open` opens a closed circuit.
#[tokio::test]
async fn test_force_open() {
let breaker = CircuitBreaker::new(CircuitBreakerConfig::with_name("test"));
assert_eq!(breaker.state().await, CircuitState::Closed);
breaker.force_open().await;
assert_eq!(breaker.state().await, CircuitState::Open);
}
// Metrics carry the configured name and count recorded failures; a recorded
// success leaves the failure total unchanged.
#[tokio::test]
async fn test_metrics() {
let config = CircuitBreakerConfig::with_name("test").failure_threshold(5);
let breaker = CircuitBreaker::new(config);
breaker.record_failure().await;
breaker.record_success().await;
let metrics = breaker.metrics();
assert_eq!(metrics.name, "test");
assert_eq!(metrics.total_failures, 1);
}
// Registered breakers can be looked up by name; unknown names yield `None`.
#[tokio::test]
async fn test_registry() {
let registry = CircuitBreakerRegistry::new();
let redis_breaker = CircuitBreaker::new(presets::redis());
registry.register("redis", redis_breaker).await;
let s3_breaker = CircuitBreaker::new(presets::s3());
registry.register("s3", s3_breaker).await;
assert!(registry.get("redis").await.is_some());
assert!(registry.get("s3").await.is_some());
assert!(registry.get("nonexistent").await.is_none());
}
// Two `get_or_create` calls with the same name return handles that share
// state: a failure recorded through one is visible through the other.
#[tokio::test]
async fn test_registry_get_or_create() {
let registry = CircuitBreakerRegistry::new();
let breaker1 = registry.get_or_create("test").await;
let breaker2 = registry.get_or_create("test").await;
breaker1.record_failure().await;
assert_eq!(breaker1.metrics().total_failures, 1);
assert_eq!(breaker2.metrics().total_failures, 1);
}
}