use parking_lot::RwLock;
use std::sync::atomic::{AtomicU32, Ordering};
use std::time::{Duration, Instant};
use thiserror::Error;
/// Position of a [`CircuitBreaker`] in its closed -> open -> half-open cycle.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CircuitState {
    /// Requests are admitted normally while consecutive failures are counted.
    Closed,
    /// Requests are rejected until `timeout` has elapsed since the last failure.
    Open,
    /// A limited number of trial requests are admitted to probe for recovery.
    HalfOpen,
}
/// Tuning parameters for a [`CircuitBreaker`].
#[derive(Clone, Debug)]
pub struct CircuitBreakerConfig {
    /// Consecutive failures (while closed) that trip the circuit open; a success resets the streak.
    pub failure_threshold: u32,
    /// Successful probes (while half-open) required to close the circuit again.
    pub success_threshold: u32,
    /// How long the circuit stays open after the most recent recorded failure.
    pub timeout: Duration,
    /// Probe slots available while half-open, in addition to the request that
    /// performs the open -> half-open transition.
    pub half_open_max_requests: u32,
}
impl Default for CircuitBreakerConfig {
fn default() -> Self {
Self {
failure_threshold: 5,
success_threshold: 2,
timeout: Duration::from_secs(60),
half_open_max_requests: 3,
}
}
}
#[derive(Debug, Error)]
pub enum CircuitBreakerError {
#[error("Circuit breaker is open")]
CircuitOpen,
#[error("Operation failed: {0}")]
OperationFailed(String),
#[error("Half-open limit exceeded")]
HalfOpenLimitExceeded,
}
/// Circuit breaker wrapping fallible operations (see `call` / `call_sync`).
///
/// All state lives behind locks and atomics and every method takes `&self`,
/// so one instance can be shared between tasks, e.g. via `Arc`.
pub struct CircuitBreaker {
    /// Thresholds and timeouts; never modified after construction.
    config: CircuitBreakerConfig,
    /// Current position in the closed/open/half-open cycle.
    state: RwLock<CircuitState>,

    // --- counters, reset on state transitions ---
    /// Consecutive failures observed while closed.
    failure_count: AtomicU32,
    /// Successful probes observed while half-open.
    success_count: AtomicU32,
    /// Probe slots consumed during the current half-open trial.
    half_open_requests: AtomicU32,

    // --- timestamps ---
    /// When the most recent failure was recorded; `None` until the first one.
    last_failure: RwLock<Option<Instant>>,
    /// When `state` was last written by a transition (or construction).
    last_state_change: RwLock<Instant>,
}
impl CircuitBreaker {
    /// Creates a breaker in [`CircuitState::Closed`] with all counters at zero.
    pub fn new(config: CircuitBreakerConfig) -> Self {
        Self {
            state: RwLock::new(CircuitState::Closed),
            failure_count: AtomicU32::new(0),
            success_count: AtomicU32::new(0),
            half_open_requests: AtomicU32::new(0),
            last_failure: RwLock::new(None),
            last_state_change: RwLock::new(Instant::now()),
            config,
        }
    }

    /// Runs the future `operation` through the breaker.
    ///
    /// If the breaker rejects the request, `operation` is dropped without being
    /// polled. Otherwise it is awaited and its outcome is recorded. If the
    /// returned future is itself dropped before `operation` completes, no
    /// outcome is recorded.
    ///
    /// # Errors
    /// - [`CircuitBreakerError::CircuitOpen`] while open and the timeout has not elapsed.
    /// - [`CircuitBreakerError::HalfOpenLimitExceeded`] while half-open with no free probe slot.
    /// - [`CircuitBreakerError::OperationFailed`] carrying the operation's error text.
    pub async fn call<F, T, E>(&self, operation: F) -> Result<T, CircuitBreakerError>
    where
        F: std::future::Future<Output = Result<T, E>>,
        E: std::fmt::Display,
    {
        self.check_state()?;
        let outcome = operation.await;
        self.record(outcome)
    }

    /// Synchronous counterpart of [`CircuitBreaker::call`]: runs `operation` only if admitted.
    ///
    /// # Errors
    /// Same as [`CircuitBreaker::call`].
    pub fn call_sync<F, T, E>(&self, operation: F) -> Result<T, CircuitBreakerError>
    where
        F: FnOnce() -> Result<T, E>,
        E: std::fmt::Display,
    {
        self.check_state()?;
        let outcome = operation();
        self.record(outcome)
    }

    /// Feeds an operation's outcome into the state machine and maps its error.
    fn record<T, E>(&self, outcome: Result<T, E>) -> Result<T, CircuitBreakerError>
    where
        E: std::fmt::Display,
    {
        match outcome {
            Ok(value) => {
                self.on_success();
                Ok(value)
            }
            Err(e) => {
                self.on_failure();
                Err(CircuitBreakerError::OperationFailed(e.to_string()))
            }
        }
    }

    /// Decides whether a new request may run.
    ///
    /// Also triggers the Open -> HalfOpen transition once the open timeout has
    /// elapsed. Half-open probe slots are only handed back when the breaker
    /// leaves HalfOpen. A probe that never reports an outcome therefore keeps
    /// its slot, and so does a breaker whose `success_threshold` exceeds the
    /// probe budget. In those cases the breaker stays HalfOpen until a failure
    /// reopens it or `reset` is called.
    fn check_state(&self) -> Result<(), CircuitBreakerError> {
        // Copy the state out so the read guard is released here; taking the
        // write lock below while still holding a read guard would deadlock.
        let observed = *self.state.read();
        match observed {
            CircuitState::Closed => Ok(()),
            CircuitState::Open if self.open_timeout_elapsed() => self.begin_half_open(),
            CircuitState::Open => Err(CircuitBreakerError::CircuitOpen),
            CircuitState::HalfOpen => self.acquire_half_open_slot(),
        }
    }

    /// Returns `true` once at least `config.timeout` has passed since the last
    /// recorded failure (`false` if no failure has been recorded).
    fn open_timeout_elapsed(&self) -> bool {
        let last_failure = *self.last_failure.read();
        match last_failure {
            Some(at) => at.elapsed() >= self.config.timeout,
            None => false,
        }
    }

    /// Performs the Open -> HalfOpen transition exactly once.
    ///
    /// Several callers can see an expired `Open` state at the same moment.
    /// Only the caller that still finds it `Open` (and still expired) under the
    /// write lock makes the transition. That caller is admitted as the first
    /// probe without consuming a slot. Callers that find the circuit already
    /// half-open compete for a slot instead of resetting the trial's counters
    /// a second time.
    fn begin_half_open(&self) -> Result<(), CircuitBreakerError> {
        let mut state = self.state.write();
        let current = *state;
        match current {
            CircuitState::Open if self.open_timeout_elapsed() => {
                *state = CircuitState::HalfOpen;
                self.failure_count.store(0, Ordering::SeqCst);
                self.success_count.store(0, Ordering::SeqCst);
                self.half_open_requests.store(0, Ordering::SeqCst);
                *self.last_state_change.write() = Instant::now();
                Ok(())
            }
            // A failure refreshed `last_failure` before we got the lock.
            CircuitState::Open => Err(CircuitBreakerError::CircuitOpen),
            CircuitState::HalfOpen => {
                drop(state);
                self.acquire_half_open_slot()
            }
            CircuitState::Closed => Ok(()),
        }
    }

    /// Atomically claims one of the `half_open_max_requests` probe slots.
    fn acquire_half_open_slot(&self) -> Result<(), CircuitBreakerError> {
        let limit = self.config.half_open_max_requests;
        // Check and increment in a single atomic step; a separate load followed
        // by `fetch_add` lets concurrent callers all pass the check.
        self.half_open_requests
            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |taken| {
                if taken < limit {
                    Some(taken + 1)
                } else {
                    None
                }
            })
            .map(|_| ())
            .map_err(|_| CircuitBreakerError::HalfOpenLimitExceeded)
    }

    /// Records a successful operation.
    fn on_success(&self) {
        let observed = *self.state.read();
        match observed {
            // Failures are counted consecutively: any success clears the streak.
            CircuitState::Closed => self.failure_count.store(0, Ordering::SeqCst),
            CircuitState::HalfOpen => {
                let successes = self.success_count.fetch_add(1, Ordering::SeqCst) + 1;
                if successes >= self.config.success_threshold {
                    self.transition_to_closed();
                }
            }
            // The request was admitted before the circuit tripped. Letting one
            // stale success close the circuit would bypass both `timeout` and
            // `success_threshold`, so it is ignored; recovery goes via HalfOpen.
            CircuitState::Open => {}
        }
    }

    /// Records a failed operation.
    fn on_failure(&self) {
        let observed = *self.state.read();
        match observed {
            CircuitState::Closed => {
                let failures = self.failure_count.fetch_add(1, Ordering::SeqCst) + 1;
                if failures >= self.config.failure_threshold {
                    self.transition_to_open();
                }
            }
            // Any failed probe ends the trial.
            CircuitState::HalfOpen => self.transition_to_open(),
            // Late failure: push back the moment the open timeout expires.
            CircuitState::Open => *self.last_failure.write() = Some(Instant::now()),
        }
    }

    /// Moves to `Closed` and clears every counter.
    fn transition_to_closed(&self) {
        *self.state.write() = CircuitState::Closed;
        self.failure_count.store(0, Ordering::SeqCst);
        self.success_count.store(0, Ordering::SeqCst);
        self.half_open_requests.store(0, Ordering::SeqCst);
        *self.last_state_change.write() = Instant::now();
    }

    /// Moves to `Open`, stamping `last_failure` so the open timeout restarts.
    fn transition_to_open(&self) {
        *self.state.write() = CircuitState::Open;
        *self.last_failure.write() = Some(Instant::now());
        self.success_count.store(0, Ordering::SeqCst);
        self.half_open_requests.store(0, Ordering::SeqCst);
        *self.last_state_change.write() = Instant::now();
    }

    /// Forces the breaker back to `Closed` and clears all counters.
    pub fn reset(&self) {
        self.transition_to_closed();
    }

    /// Current state snapshot.
    pub fn state(&self) -> CircuitState {
        *self.state.read()
    }

    /// Consecutive failures recorded while closed.
    pub fn failure_count(&self) -> u32 {
        self.failure_count.load(Ordering::SeqCst)
    }

    /// Successful probes recorded in the current half-open trial.
    pub fn success_count(&self) -> u32 {
        self.success_count.load(Ordering::SeqCst)
    }

    /// Time elapsed since the last state transition (or construction).
    pub fn time_since_state_change(&self) -> Duration {
        self.last_state_change.read().elapsed()
    }

    /// Time elapsed since the last recorded failure, or `None` if there was none.
    pub fn time_since_last_failure(&self) -> Option<Duration> {
        self.last_failure.read().map(|instant| instant.elapsed())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use tokio::time::sleep;

    /// Open timeout for tests that need the circuit to become eligible for half-open quickly.
    const SHORT_TIMEOUT: Duration = Duration::from_millis(100);
    /// Comfortably longer than `SHORT_TIMEOUT`.
    const PAST_TIMEOUT: Duration = Duration::from_millis(150);

    /// Runs one failing operation through the breaker.
    async fn fail(cb: &CircuitBreaker) -> Result<String, CircuitBreakerError> {
        cb.call(async { Err::<String, _>("error") }).await
    }

    /// Runs one successful operation through the breaker.
    async fn succeed(cb: &CircuitBreaker) -> Result<&'static str, CircuitBreakerError> {
        cb.call(async { Ok::<_, String>("success") }).await
    }

    /// Records `n` failures, ignoring the individual results.
    async fn fail_times(cb: &CircuitBreaker, n: u32) {
        for _ in 0..n {
            let _ = fail(cb).await;
        }
    }

    #[test]
    fn test_circuit_breaker_starts_closed() {
        let cb = CircuitBreaker::new(CircuitBreakerConfig::default());
        assert_eq!(cb.state(), CircuitState::Closed);
    }

    #[tokio::test]
    async fn test_successful_operation_keeps_circuit_closed() {
        let cb = CircuitBreaker::new(CircuitBreakerConfig::default());
        for _ in 0..10 {
            assert!(succeed(&cb).await.is_ok());
            assert_eq!(cb.state(), CircuitState::Closed);
        }
    }

    #[tokio::test]
    async fn test_circuit_opens_after_threshold_failures() {
        let cb = CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: 3,
            ..Default::default()
        });
        for _ in 0..2 {
            let _ = fail(&cb).await;
            assert_eq!(cb.state(), CircuitState::Closed);
        }
        let _ = fail(&cb).await;
        assert_eq!(cb.state(), CircuitState::Open);
    }

    #[tokio::test]
    async fn test_circuit_rejects_operations_when_open() {
        let cb = CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: 2,
            timeout: Duration::from_secs(10),
            ..Default::default()
        });
        fail_times(&cb, 2).await;
        assert_eq!(cb.state(), CircuitState::Open);
        assert!(matches!(
            succeed(&cb).await,
            Err(CircuitBreakerError::CircuitOpen)
        ));
    }

    #[tokio::test]
    async fn test_circuit_transitions_to_half_open_after_timeout() {
        let cb = CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: 2,
            timeout: SHORT_TIMEOUT,
            ..Default::default()
        });
        fail_times(&cb, 2).await;
        assert_eq!(cb.state(), CircuitState::Open);
        sleep(PAST_TIMEOUT).await;
        assert!(succeed(&cb).await.is_ok());
        assert_eq!(cb.state(), CircuitState::HalfOpen);
    }

    #[tokio::test]
    async fn test_half_open_closes_after_success_threshold() {
        let cb = CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: 2,
            success_threshold: 2,
            timeout: SHORT_TIMEOUT,
            ..Default::default()
        });
        fail_times(&cb, 2).await;
        sleep(PAST_TIMEOUT).await;
        let _ = succeed(&cb).await;
        assert_eq!(cb.state(), CircuitState::HalfOpen);
        let _ = succeed(&cb).await;
        assert_eq!(cb.state(), CircuitState::Closed);
    }

    #[tokio::test]
    async fn test_half_open_reopens_on_failure() {
        let cb = CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: 2,
            timeout: SHORT_TIMEOUT,
            ..Default::default()
        });
        fail_times(&cb, 2).await;
        sleep(PAST_TIMEOUT).await;
        let _ = succeed(&cb).await;
        assert_eq!(cb.state(), CircuitState::HalfOpen);
        let _ = fail(&cb).await;
        assert_eq!(cb.state(), CircuitState::Open);
    }

    #[tokio::test]
    async fn test_reset_closes_circuit() {
        let cb = CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: 2,
            ..Default::default()
        });
        fail_times(&cb, 2).await;
        assert_eq!(cb.state(), CircuitState::Open);
        cb.reset();
        assert_eq!(cb.state(), CircuitState::Closed);
        assert_eq!(cb.failure_count(), 0);
    }

    #[test]
    fn test_synchronous_operations() {
        let cb = CircuitBreaker::new(CircuitBreakerConfig::default());
        assert!(cb.call_sync(|| Ok::<_, String>("success")).is_ok());
        let result = cb.call_sync(|| Err::<String, _>("error"));
        assert!(matches!(
            result,
            Err(CircuitBreakerError::OperationFailed(ref msg)) if msg == "error"
        ));
    }

    #[tokio::test]
    async fn test_concurrent_operations() {
        let cb = Arc::new(CircuitBreaker::new(CircuitBreakerConfig::default()));
        let handles: Vec<_> = (0..10)
            .map(|i| {
                let cb = Arc::clone(&cb);
                tokio::spawn(async move {
                    cb.call(async move {
                        sleep(Duration::from_millis(10)).await;
                        Ok::<_, String>(format!("success {}", i))
                    })
                    .await
                })
            })
            .collect();
        for handle in handles {
            assert!(handle.await.unwrap().is_ok());
        }
        assert_eq!(cb.state(), CircuitState::Closed);
    }

    #[tokio::test]
    async fn test_half_open_max_requests_limit() {
        let cb = CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: 2,
            timeout: SHORT_TIMEOUT,
            half_open_max_requests: 1,
            // High enough that the probes below never close the circuit.
            success_threshold: 10,
        });
        fail_times(&cb, 2).await;
        assert_eq!(cb.state(), CircuitState::Open);
        sleep(PAST_TIMEOUT).await;
        // The request that performs Open -> HalfOpen does not consume a slot...
        assert!(succeed(&cb).await.is_ok());
        assert_eq!(cb.state(), CircuitState::HalfOpen);
        // ...this one takes the only slot...
        assert!(succeed(&cb).await.is_ok());
        // ...so the next one is turned away.
        assert!(matches!(
            succeed(&cb).await,
            Err(CircuitBreakerError::HalfOpenLimitExceeded)
        ));
    }

    #[tokio::test]
    async fn test_failure_count_increments() {
        let cb = CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: 5,
            ..Default::default()
        });
        assert_eq!(cb.failure_count(), 0);
        for expected in 1..=3 {
            let _ = fail(&cb).await;
            assert_eq!(cb.failure_count(), expected);
        }
    }

    #[tokio::test]
    async fn test_success_resets_failure_count_when_closed() {
        let cb = CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: 5,
            ..Default::default()
        });
        fail_times(&cb, 3).await;
        assert_eq!(cb.failure_count(), 3);
        let _ = succeed(&cb).await;
        assert_eq!(cb.failure_count(), 0);
    }

    #[tokio::test]
    async fn test_time_tracking() {
        let cb = CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: 2,
            ..Default::default()
        });
        fail_times(&cb, 2).await;
        assert_eq!(cb.state(), CircuitState::Open);
        sleep(Duration::from_millis(50)).await;
        assert!(cb.time_since_state_change() >= Duration::from_millis(50));
        let since_failure = cb
            .time_since_last_failure()
            .expect("a failure was just recorded");
        assert!(since_failure >= Duration::from_millis(50));
    }
}