use std::sync::atomic::{AtomicI64, AtomicU32, AtomicU8, Ordering};
use tracing::{debug, error, info, warn};
/// Lifecycle state of a circuit breaker.
///
/// The explicit discriminants are the values kept in the breaker's `AtomicU8`
/// state field. `#[repr(u8)]` pins the enum's representation to that same width,
/// so `state as u8` and `CircuitState::from(byte)` are exact inverses for the
/// three valid states.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum CircuitState {
    /// Requests are admitted; failures are counted towards opening.
    Closed = 0,
    /// Requests are rejected until the open timeout has elapsed.
    Open = 1,
    /// Requests are admitted as recovery probes; successes are counted towards closing.
    HalfOpen = 2,
}

impl From<u8> for CircuitState {
    /// Decodes a stored discriminant.
    ///
    /// 0, 1 and 2 map to `Closed`, `Open` and `HalfOpen` respectively. Any other
    /// byte also decodes as `Closed`.
    fn from(value: u8) -> Self {
        match value {
            0 => Self::Closed,
            1 => Self::Open,
            2 => Self::HalfOpen,
            // Unknown encodings fall back to `Closed`, the state a new breaker starts in.
            _ => Self::Closed,
        }
    }
}
/// Tuning knobs for a circuit breaker.
#[derive(Debug, Clone)]
pub struct CircuitBreakerConfig {
    /// Number of failures in the closed state at which the circuit opens.
    /// A success in the closed state resets the running count.
    pub failure_threshold: u32,
    /// Milliseconds that must pass after the most recent failure before an open
    /// circuit admits a probe request.
    pub timeout_ms: i64,
    /// Number of successes in the half-open state required to close the circuit.
    pub success_threshold: u32,
}

impl Default for CircuitBreakerConfig {
    /// Opens after 5 failures, stays open for 30 seconds, and closes after 2
    /// successful probes.
    fn default() -> Self {
        let timeout_ms = 30_000;
        let (failure_threshold, success_threshold) = (5, 2);
        Self {
            failure_threshold,
            timeout_ms,
            success_threshold,
        }
    }
}
/// Circuit breaker whose mutable state lives entirely in atomic fields, so every
/// operation takes `&self` and a single instance can be shared between tasks
/// (the tests share one behind an `Arc`).
pub struct CircuitBreaker {
/// Failures recorded so far; cleared by a success in the closed state and when
/// the breaker closes after recovery.
failure_count: AtomicU32,
/// Successes recorded in the half-open state; cleared on entering half-open and
/// when the breaker closes.
success_count: AtomicU32,
/// Wall-clock time of the most recent failure in milliseconds
/// (`chrono::Utc::now().timestamp_millis()`); initialised to 0.
last_failure: AtomicI64,
/// Current `CircuitState`, stored as its `u8` discriminant.
state: AtomicU8,
/// Thresholds and open timeout this breaker was constructed with.
config: CircuitBreakerConfig,
/// Label attached to every log event emitted by this breaker.
name: String,
}
impl CircuitBreaker {
    /// Creates a circuit breaker that starts in [`CircuitState::Closed`] with all
    /// counters at zero.
    ///
    /// `name` labels every log event emitted by this breaker. `config` supplies the
    /// thresholds and the open-state timeout. Construction is logged at `info`.
    pub fn new(name: String, config: CircuitBreakerConfig) -> Self {
        info!(
            name = %name,
            failure_threshold = config.failure_threshold,
            timeout_ms = config.timeout_ms,
            success_threshold = config.success_threshold,
            "Creating circuit breaker"
        );
        Self {
            failure_count: AtomicU32::new(0),
            success_count: AtomicU32::new(0),
            last_failure: AtomicI64::new(0),
            state: AtomicU8::new(CircuitState::Closed as u8),
            config,
            name,
        }
    }

    /// Reads the current state with `Acquire` ordering.
    ///
    /// This pairs with the `Release` half of every state write. A caller that
    /// observes a new state therefore also observes the `last_failure` value
    /// written before that state was published.
    fn load_state(&self) -> CircuitState {
        CircuitState::from(self.state.load(Ordering::Acquire))
    }

    /// Atomically moves the breaker from `expected` to `next`.
    ///
    /// Returns `true` only for the caller whose compare-and-swap succeeded.
    /// Callers gate a transition's side effects (counter resets, logging) on this
    /// result, so those effects run once per transition. It also means a caller
    /// acting on a stale view of the state cannot overwrite a newer one.
    fn transition(&self, expected: CircuitState, next: CircuitState) -> bool {
        self.state
            .compare_exchange(
                expected as u8,
                next as u8,
                Ordering::AcqRel,
                Ordering::Acquire,
            )
            .is_ok()
    }

    /// Decides whether a request may proceed right now.
    ///
    /// * `Closed`, `HalfOpen`: always allowed.
    /// * `Open`: rejected until strictly more than `timeout_ms` milliseconds of
    ///   wall-clock time have passed since the last recorded failure. After that
    ///   the breaker moves to `HalfOpen` and the request is allowed. If several
    ///   callers race past the timeout, exactly one performs the transition
    ///   (resetting the success counter and logging). The others find the breaker
    ///   already out of `Open` and are allowed too.
    pub fn should_allow_request(&self) -> bool {
        match self.load_state() {
            CircuitState::Closed => {
                debug!(name = %self.name, state = "closed", "Request allowed");
                true
            }
            CircuitState::Open => {
                let now_ms = chrono::Utc::now().timestamp_millis();
                let last_failure = self.last_failure.load(Ordering::Acquire);
                let time_since_failure = now_ms - last_failure;
                if time_since_failure <= self.config.timeout_ms {
                    debug!(
                        name = %self.name,
                        state = "open",
                        remaining_timeout_ms = self.config.timeout_ms - time_since_failure,
                        "Request blocked by open circuit"
                    );
                    return false;
                }
                if self.transition(CircuitState::Open, CircuitState::HalfOpen) {
                    self.success_count.store(0, Ordering::Relaxed);
                    info!(
                        name = %self.name,
                        timeout_ms = self.config.timeout_ms,
                        time_since_failure_ms = time_since_failure,
                        "Circuit breaker transitioning to half-open state"
                    );
                }
                // A strong compare_exchange fails only when the state is no longer
                // `Open`. Another caller has therefore already moved it to
                // `HalfOpen` (or on to `Closed`), and both of those admit requests.
                true
            }
            CircuitState::HalfOpen => {
                debug!(name = %self.name, state = "half_open", "Test request allowed");
                true
            }
        }
    }

    /// Records a successful operation.
    ///
    /// * `Closed`: clears the failure counter, so only failures without an
    ///   intervening success accumulate towards `failure_threshold`.
    /// * `HalfOpen`: counts the success. On reaching `success_threshold` the
    ///   breaker closes and both counters are cleared. Closing is a
    ///   compare-and-swap from `HalfOpen`, so a success that races with a failed
    ///   probe cannot close a breaker that has just re-opened.
    /// * `Open`: unexpected, since requests are rejected while open; logged at `warn`.
    pub fn record_success(&self) {
        match self.load_state() {
            CircuitState::Closed => {
                self.failure_count.store(0, Ordering::Relaxed);
                debug!(name = %self.name, "Success recorded in closed state");
            }
            CircuitState::HalfOpen => {
                let successes = self
                    .success_count
                    .fetch_add(1, Ordering::Relaxed)
                    .saturating_add(1);
                if successes < self.config.success_threshold {
                    debug!(
                        name = %self.name,
                        successes = successes,
                        threshold = self.config.success_threshold,
                        "Recovery in progress"
                    );
                } else if self.transition(CircuitState::HalfOpen, CircuitState::Closed) {
                    self.failure_count.store(0, Ordering::Relaxed);
                    self.success_count.store(0, Ordering::Relaxed);
                    info!(
                        name = %self.name,
                        success_threshold = self.config.success_threshold,
                        "Circuit breaker closed after successful recovery"
                    );
                } else {
                    debug!(
                        name = %self.name,
                        "Success ignored; circuit left half-open concurrently"
                    );
                }
            }
            CircuitState::Open => {
                warn!(name = %self.name, "Success recorded while circuit is open - unexpected state");
            }
        }
    }

    /// Records a failed operation. `error` is included in the log output.
    ///
    /// Always increments the failure counter and stamps `last_failure` with the
    /// current wall-clock time in milliseconds. It then acts on the state:
    /// * `Closed`: opens the breaker once the counter reaches `failure_threshold`.
    /// * `HalfOpen`: the recovery probe failed, so the breaker re-opens.
    /// * `Open`: no transition. The fresh timestamp restarts the open timeout.
    ///
    /// Transitions are compare-and-swaps. When several failures race across the
    /// threshold, only one of them opens the breaker and logs at `error`.
    pub fn record_failure(&self, error: &str) {
        let failures = self
            .failure_count
            .fetch_add(1, Ordering::Relaxed)
            .saturating_add(1);
        let now_ms = chrono::Utc::now().timestamp_millis();
        // Stored before any transition below. The `Release` half of that
        // transition publishes this timestamp to readers that `Acquire` the new
        // state, so they never pair `Open` with a stale timestamp.
        self.last_failure.store(now_ms, Ordering::Release);
        match self.load_state() {
            CircuitState::Closed => {
                if failures < self.config.failure_threshold {
                    warn!(
                        name = %self.name,
                        failure_count = failures,
                        threshold = self.config.failure_threshold,
                        error = %error,
                        "Failure recorded in closed state"
                    );
                } else if self.transition(CircuitState::Closed, CircuitState::Open) {
                    error!(
                        name = %self.name,
                        failure_count = failures,
                        threshold = self.config.failure_threshold,
                        timeout_ms = self.config.timeout_ms,
                        error = %error,
                        "Circuit breaker opened due to failure threshold"
                    );
                } else {
                    // Another failure crossed the threshold first and already opened it.
                    debug!(
                        name = %self.name,
                        failure_count = failures,
                        error = %error,
                        "Failure recorded while circuit is already open"
                    );
                }
            }
            CircuitState::HalfOpen => {
                if self.transition(CircuitState::HalfOpen, CircuitState::Open) {
                    error!(
                        name = %self.name,
                        error = %error,
                        "Circuit breaker reopened after failed recovery attempt"
                    );
                } else {
                    debug!(
                        name = %self.name,
                        error = %error,
                        "Failure recorded after circuit left half-open concurrently"
                    );
                }
            }
            CircuitState::Open => {
                debug!(
                    name = %self.name,
                    failure_count = failures,
                    error = %error,
                    "Failure recorded while circuit is already open"
                );
            }
        }
    }

    /// Returns the current state.
    pub fn get_state(&self) -> CircuitState {
        self.load_state()
    }

    /// Returns the current value of the failure counter.
    pub fn get_failure_count(&self) -> u32 {
        self.failure_count.load(Ordering::Relaxed)
    }

    /// Returns the current value of the half-open success counter.
    pub fn get_success_count(&self) -> u32 {
        self.success_count.load(Ordering::Relaxed)
    }

    /// Test helper: wall-clock milliseconds elapsed since the last recorded failure.
    /// The result is measured from the initial value 0 if no failure has been recorded.
    #[cfg(test)]
    pub fn time_since_last_failure_ms(&self) -> i64 {
        let now_ms = chrono::Utc::now().timestamp_millis();
        now_ms - self.last_failure.load(Ordering::Acquire)
    }

    /// Test helper: forces the breaker closed and clears both counters.
    #[cfg(test)]
    pub fn force_closed(&self) {
        self.failure_count.store(0, Ordering::Relaxed);
        self.success_count.store(0, Ordering::Relaxed);
        self.state
            .store(CircuitState::Closed as u8, Ordering::Release);
    }

    /// Test helper: forces the breaker open as if a failure happened just now.
    #[cfg(test)]
    pub fn force_open(&self) {
        // Timestamp first, so a reader that observes `Open` never sees a stale
        // `last_failure` and skips the timeout.
        self.last_failure
            .store(chrono::Utc::now().timestamp_millis(), Ordering::Relaxed);
        self.state
            .store(CircuitState::Open as u8, Ordering::Release);
    }
}
impl CircuitBreaker {
    /// Runs `operation` under the protection of this circuit breaker.
    ///
    /// When `should_allow_request` refuses the call, `operation` is never invoked.
    /// Otherwise the outcome is recorded and then handed back. A success goes to
    /// `record_success`. A failure goes to `record_failure` together with the
    /// error's `Display` text.
    ///
    /// # Errors
    ///
    /// * [`CircuitBreakerError::CircuitOpen`] if the request was rejected.
    /// * [`CircuitBreakerError::OperationFailed`] wrapping the operation's own error.
    pub async fn execute<F, Fut, T>(&self, operation: F) -> Result<T, CircuitBreakerError>
    where
        F: FnOnce() -> Fut,
        Fut: std::future::Future<Output = Result<T, Box<dyn std::error::Error + Send + Sync>>>,
    {
        let outcome = if self.should_allow_request() {
            operation().await
        } else {
            return Err(CircuitBreakerError::CircuitOpen);
        };

        match &outcome {
            Ok(_) => self.record_success(),
            Err(cause) => self.record_failure(&cause.to_string()),
        }

        outcome.map_err(CircuitBreakerError::OperationFailed)
    }
}
/// Test-only adapter that fixes the operation's success type `T` at the type
/// level and forwards every call to an inner `CircuitBreaker`.
#[cfg(test)]
pub struct CircuitBreakerWrapper<T> {
    circuit_breaker: CircuitBreaker,
    _phantom: std::marker::PhantomData<T>,
}

#[cfg(test)]
impl<T> CircuitBreakerWrapper<T> {
    /// Builds a wrapper around a newly constructed breaker.
    pub fn new(name: String, config: CircuitBreakerConfig) -> Self {
        let circuit_breaker = CircuitBreaker::new(name, config);
        Self {
            _phantom: std::marker::PhantomData,
            circuit_breaker,
        }
    }

    /// Shared access to the wrapped breaker.
    fn inner(&self) -> &CircuitBreaker {
        &self.circuit_breaker
    }

    /// Delegates to the inner breaker's `execute`, with `T` fixed by the wrapper.
    pub async fn execute<F, Fut>(&self, operation: F) -> Result<T, CircuitBreakerError>
    where
        F: FnOnce() -> Fut,
        Fut: std::future::Future<Output = Result<T, Box<dyn std::error::Error + Send + Sync>>>,
    {
        self.inner().execute(operation).await
    }

    /// Current state of the inner breaker.
    pub fn get_state(&self) -> CircuitState {
        self.inner().get_state()
    }

    /// Current failure counter of the inner breaker.
    pub fn get_failure_count(&self) -> u32 {
        self.inner().get_failure_count()
    }

    /// Forces the inner breaker closed.
    pub fn force_closed(&self) {
        self.inner().force_closed()
    }

    /// Forces the inner breaker open.
    pub fn force_open(&self) {
        self.inner().force_open()
    }
}
/// Errors returned by `CircuitBreaker::execute`.
#[derive(Debug, thiserror::Error)]
pub enum CircuitBreakerError {
/// The breaker rejected the request, and the operation was never invoked.
#[error("Circuit breaker is open, requests are blocked")]
CircuitOpen,
/// The operation ran and returned this error. When produced by `execute`, the
/// error has also been passed to `record_failure`. The `#[from]` conversion
/// can build this variant elsewhere without recording anything.
#[error("Operation failed: {0}")]
OperationFailed(#[from] Box<dyn std::error::Error + Send + Sync>),
}
// Unit tests for the circuit breaker state machine. Several tests wait on real
// `tokio::time::sleep` delays that must exceed the configured `timeout_ms`. They
// therefore depend on wall-clock progress rather than a mocked clock.
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Arc;
use tokio::time::{sleep, Duration};
// A fresh breaker is closed and admits requests; a success leaves it closed with no failures.
#[tokio::test]
async fn test_circuit_breaker_closed_state() {
let config = CircuitBreakerConfig {
failure_threshold: 3,
timeout_ms: 1000,
success_threshold: 2,
};
let cb = CircuitBreaker::new("test".to_string(), config);
assert_eq!(cb.get_state(), CircuitState::Closed);
assert!(cb.should_allow_request());
cb.record_success();
assert_eq!(cb.get_state(), CircuitState::Closed);
assert_eq!(cb.get_failure_count(), 0);
}
// The breaker stays closed below `failure_threshold` and opens exactly when it is reached.
#[tokio::test]
async fn test_circuit_breaker_open_on_failures() {
let config = CircuitBreakerConfig {
failure_threshold: 3,
timeout_ms: 1000,
success_threshold: 2,
};
let cb = CircuitBreaker::new("test".to_string(), config);
cb.record_failure("error 1");
assert_eq!(cb.get_state(), CircuitState::Closed);
assert!(cb.should_allow_request());
cb.record_failure("error 2");
assert_eq!(cb.get_state(), CircuitState::Closed);
assert!(cb.should_allow_request());
cb.record_failure("error 3");
assert_eq!(cb.get_state(), CircuitState::Open);
// The 1000 ms timeout has not elapsed, so requests are blocked.
assert!(!cb.should_allow_request());
assert_eq!(cb.get_failure_count(), 3);
}
// After sleeping past `timeout_ms` (150 ms > 100 ms), the next request check moves Open -> HalfOpen.
#[tokio::test]
async fn test_circuit_breaker_half_open_transition() {
let config = CircuitBreakerConfig {
failure_threshold: 2,
timeout_ms: 100, success_threshold: 2,
};
let cb = CircuitBreaker::new("test".to_string(), config);
cb.record_failure("error 1");
cb.record_failure("error 2");
assert_eq!(cb.get_state(), CircuitState::Open);
assert!(!cb.should_allow_request());
sleep(Duration::from_millis(150)).await;
assert!(cb.should_allow_request());
assert_eq!(cb.get_state(), CircuitState::HalfOpen);
}
// In HalfOpen, `success_threshold` (2) successes close the breaker and clear the failure count.
#[tokio::test]
async fn test_circuit_breaker_recovery_to_closed() {
let config = CircuitBreakerConfig {
failure_threshold: 2,
timeout_ms: 50,
success_threshold: 2,
};
let cb = CircuitBreaker::new("test".to_string(), config);
cb.record_failure("error 1");
cb.record_failure("error 2");
assert_eq!(cb.get_state(), CircuitState::Open);
// 60 ms sleep must exceed the 50 ms timeout; the check uses a strict `>` comparison.
sleep(Duration::from_millis(60)).await;
assert!(cb.should_allow_request());
assert_eq!(cb.get_state(), CircuitState::HalfOpen);
cb.record_success();
// One success is below the threshold of 2, so the breaker stays half-open.
assert_eq!(cb.get_state(), CircuitState::HalfOpen);
cb.record_success();
assert_eq!(cb.get_state(), CircuitState::Closed);
assert_eq!(cb.get_failure_count(), 0);
}
// A failure while half-open re-opens the breaker and restarts the timeout.
#[tokio::test]
async fn test_circuit_breaker_half_open_failure() {
let config = CircuitBreakerConfig {
failure_threshold: 2,
timeout_ms: 50,
success_threshold: 2,
};
let cb = CircuitBreaker::new("test".to_string(), config);
cb.record_failure("error 1");
cb.record_failure("error 2");
sleep(Duration::from_millis(60)).await;
assert!(cb.should_allow_request());
assert_eq!(cb.get_state(), CircuitState::HalfOpen);
cb.record_failure("recovery failed");
assert_eq!(cb.get_state(), CircuitState::Open);
assert!(!cb.should_allow_request());
}
// `execute` records outcomes and, once two failures open the breaker, rejects the next call.
#[tokio::test]
async fn test_circuit_breaker_execute_method() {
let config = CircuitBreakerConfig {
failure_threshold: 2,
timeout_ms: 100,
success_threshold: 1,
};
let cb = CircuitBreaker::new("test".to_string(), config);
let result = cb.execute(|| async { Ok("success".to_string()) }).await;
assert!(result.is_ok());
assert_eq!(result.unwrap(), "success");
let result: Result<String, _> = cb.execute(|| async { Err("error 1".into()) }).await;
assert!(result.is_err());
let result: Result<String, _> = cb.execute(|| async { Err("error 2".into()) }).await;
assert!(result.is_err());
let result: Result<String, _> = cb
.execute(|| async { Ok("should not execute".to_string()) })
.await;
// The breaker is now open, so the operation must be short-circuited with `CircuitOpen`.
match result {
Err(CircuitBreakerError::CircuitOpen) => {} other => panic!("Expected CircuitOpen error, got: {:?}", other),
}
}
// Same scenario as above, driven through the test-only `CircuitBreakerWrapper`.
#[tokio::test]
async fn test_circuit_breaker_wrapper() {
let config = CircuitBreakerConfig {
failure_threshold: 2,
timeout_ms: 100,
success_threshold: 1,
};
let wrapper: CircuitBreakerWrapper<String> =
CircuitBreakerWrapper::new("test".to_string(), config);
let result = wrapper
.execute(|| async { Ok("success".to_string()) })
.await;
assert!(result.is_ok());
assert_eq!(result.unwrap(), "success");
let result = wrapper.execute(|| async { Err("error 1".into()) }).await;
assert!(result.is_err());
let result = wrapper.execute(|| async { Err("error 2".into()) }).await;
assert!(result.is_err());
let result = wrapper
.execute(|| async { Ok("should not execute".to_string()) })
.await;
// Two failures reached `failure_threshold`, so this call must be rejected.
match result {
Err(CircuitBreakerError::CircuitOpen) => {} other => panic!("Expected CircuitOpen error, got: {:?}", other),
}
}
// The helper reports a small, non-negative elapsed time right after a failure.
#[tokio::test]
async fn test_time_since_last_failure() {
let config = CircuitBreakerConfig::default();
let cb = CircuitBreaker::new("test".to_string(), config);
cb.record_failure("test error");
let time_since = cb.time_since_last_failure_ms();
assert!(time_since >= 0);
// The failure was just recorded, so well under a second should have passed.
assert!(time_since < 1000); }
// Byte -> state decoding, including the fallback for an out-of-range byte.
#[test]
fn test_circuit_state_conversion() {
assert_eq!(CircuitState::from(0), CircuitState::Closed);
assert_eq!(CircuitState::from(1), CircuitState::Open);
assert_eq!(CircuitState::from(2), CircuitState::HalfOpen);
// Unknown values decode as Closed.
assert_eq!(CircuitState::from(99), CircuitState::Closed); }
// Ten tasks record failures concurrently against a threshold of 5. The breaker must end up open.
// The failure count is only checked as a lower bound.
#[tokio::test]
async fn test_concurrent_access() {
let config = CircuitBreakerConfig {
failure_threshold: 5,
timeout_ms: 100,
success_threshold: 2,
};
let cb = Arc::new(CircuitBreaker::new("concurrent_test".to_string(), config));
let mut handles = Vec::new();
for i in 0..10 {
let cb_clone = cb.clone();
let handle = tokio::spawn(async move {
cb_clone.record_failure(&format!("concurrent error {}", i));
});
handles.push(handle);
}
for handle in handles {
handle.await.unwrap();
}
assert_eq!(cb.get_state(), CircuitState::Open);
assert!(cb.get_failure_count() >= 5);
}
}