use std::sync::atomic::{AtomicU32, AtomicU64, AtomicU8, Ordering};
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
/// Failures recorded while closed (since the last success) that trip the default breaker.
const DEFAULT_FAILURE_THRESHOLD: u32 = 3;
/// How long, in milliseconds, the default breaker stays open before probing (30 s).
const DEFAULT_OPEN_DURATION_MS: u64 = 30_000;
/// Probe requests the default breaker admits while half-open.
const DEFAULT_HALF_OPEN_MAX: u32 = 1;

// Raw encodings of the breaker state, as stored in the `CircuitBreaker::state` atomic.
const STATE_CLOSED: u8 = 0;
const STATE_OPEN: u8 = 1;
const STATE_HALF_OPEN: u8 = 2;
/// Tuning knobs for a [`CircuitBreaker`].
///
/// Derives `PartialEq`/`Eq` so configurations can be compared (e.g. in tests or when
/// checking whether a reload actually changed anything).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CircuitConfig {
    /// Number of failures recorded while closed (the count is cleared by any success) that
    /// trips the circuit open. Values of 0 and 1 both trip on the first failure.
    pub failure_threshold: u32,
    /// How long, in milliseconds, the circuit stays open before a caller may move it to
    /// half-open.
    pub open_duration_ms: u64,
    /// Maximum number of probe requests admitted while half-open. With 0, no probe is ever
    /// admitted, so the circuit can only leave half-open by being re-opened or reset.
    pub half_open_max: u32,
}
impl Default for CircuitConfig {
fn default() -> Self {
Self {
failure_threshold: DEFAULT_FAILURE_THRESHOLD,
open_duration_ms: DEFAULT_OPEN_DURATION_MS,
half_open_max: DEFAULT_HALF_OPEN_MAX,
}
}
}
impl CircuitConfig {
pub fn strict() -> Self {
Self {
failure_threshold: 2,
open_duration_ms: 60000,
half_open_max: 1,
}
}
pub fn lenient() -> Self {
Self {
failure_threshold: 5,
open_duration_ms: 15000,
half_open_max: 2,
}
}
}
/// Point-in-time view of a breaker's state, as returned by `CircuitBreaker::state`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CircuitState {
    /// Requests are admitted; failures are being counted toward the threshold.
    Closed,
    /// Requests are rejected until the open period ends.
    Open {
        /// Estimated end of the open period, computed when the snapshot was taken.
        until: Instant,
    },
    /// A limited number of probe requests are admitted to test whether the dependency recovered.
    HalfOpen,
}
/// Circuit breaker whose state lives entirely in atomics, so every operation works through a
/// shared reference (`&self`).
///
/// Derives `Debug` (atomics print their current value) so a breaker can be logged or inspected.
#[derive(Debug)]
pub struct CircuitBreaker {
    /// Raw state: one of `STATE_CLOSED`, `STATE_OPEN` or `STATE_HALF_OPEN`.
    state: AtomicU8,
    /// Failures recorded while closed since the last success or reset.
    failure_count: AtomicU32,
    /// `current_time_ms()` reading taken when the circuit last opened (0 before the first opening).
    opened_at_ms: AtomicU64,
    /// Half-open admission counter compared against `config.half_open_max`; zeroed whenever the
    /// circuit opens or is reset.
    half_open_requests: AtomicU32,
    /// Thresholds and timings this breaker was created with.
    config: CircuitConfig,
}
impl CircuitBreaker {
pub fn new(config: CircuitConfig) -> Self {
Self {
state: AtomicU8::new(STATE_CLOSED),
failure_count: AtomicU32::new(0),
opened_at_ms: AtomicU64::new(0),
half_open_requests: AtomicU32::new(0),
config,
}
}
pub fn with_defaults() -> Self {
Self::new(CircuitConfig::default())
}
pub fn can_execute(&self) -> bool {
let current_state = self.state.load(Ordering::SeqCst);
match current_state {
STATE_CLOSED => true,
STATE_OPEN => {
let opened_at = self.opened_at_ms.load(Ordering::SeqCst);
let now_ms = current_time_ms();
let elapsed_ms = now_ms.saturating_sub(opened_at);
if elapsed_ms >= self.config.open_duration_ms {
if self
.state
.compare_exchange(
STATE_OPEN,
STATE_HALF_OPEN,
Ordering::SeqCst,
Ordering::SeqCst,
)
.is_ok()
{
self.half_open_requests.store(0, Ordering::SeqCst);
}
self.half_open_requests.fetch_add(1, Ordering::SeqCst)
< self.config.half_open_max
} else {
false
}
}
STATE_HALF_OPEN => {
self.half_open_requests.fetch_add(1, Ordering::SeqCst) < self.config.half_open_max
}
_ => false,
}
}
pub fn record_success(&self) {
let current_state = self.state.load(Ordering::SeqCst);
self.failure_count.store(0, Ordering::SeqCst);
if current_state == STATE_HALF_OPEN {
self.state.store(STATE_CLOSED, Ordering::SeqCst);
}
}
pub fn record_failure(&self) {
let current_state = self.state.load(Ordering::SeqCst);
match current_state {
STATE_CLOSED => {
let failures = self.failure_count.fetch_add(1, Ordering::SeqCst) + 1;
if failures >= self.config.failure_threshold {
self.open_circuit();
}
}
STATE_HALF_OPEN => {
self.open_circuit();
}
STATE_OPEN => {
}
_ => {}
}
}
pub fn record_rate_limited(&self) {
self.open_circuit();
}
pub fn record_service_unavailable(&self) {
self.open_circuit();
}
pub fn state(&self) -> CircuitState {
let current_state = self.state.load(Ordering::SeqCst);
match current_state {
STATE_CLOSED => CircuitState::Closed,
STATE_OPEN => {
let opened_at = self.opened_at_ms.load(Ordering::SeqCst);
let now_ms = current_time_ms();
let remaining_ms = self
.config
.open_duration_ms
.saturating_sub(now_ms.saturating_sub(opened_at));
CircuitState::Open {
until: Instant::now() + Duration::from_millis(remaining_ms),
}
}
STATE_HALF_OPEN => CircuitState::HalfOpen,
_ => CircuitState::Closed, }
}
pub fn is_open(&self) -> bool {
self.state.load(Ordering::SeqCst) == STATE_OPEN
}
pub fn is_closed(&self) -> bool {
self.state.load(Ordering::SeqCst) == STATE_CLOSED
}
pub fn failure_count(&self) -> u32 {
self.failure_count.load(Ordering::SeqCst)
}
pub fn reset(&self) {
self.state.store(STATE_CLOSED, Ordering::SeqCst);
self.failure_count.store(0, Ordering::SeqCst);
self.half_open_requests.store(0, Ordering::SeqCst);
}
fn open_circuit(&self) {
self.state.store(STATE_OPEN, Ordering::SeqCst);
self.opened_at_ms.store(current_time_ms(), Ordering::SeqCst);
self.half_open_requests.store(0, Ordering::SeqCst);
}
}
/// Current time in milliseconds since the Unix epoch, advanced by a monotonic clock.
///
/// The wall clock is sampled once, on the first call, to anchor the scale. Every reading
/// adds the monotonic [`Instant`] time elapsed since that anchor. Successive readings therefore
/// never decrease and are unaffected by system-clock adjustments made after the first call
/// (NTP steps, manual changes). Such adjustments would otherwise distort elapsed-time
/// arithmetic built on these readings.
///
/// If the system clock reads earlier than the Unix epoch when the anchor is taken, the anchor
/// falls back to 0. Values too large for `u64` saturate instead of being truncated.
fn current_time_ms() -> u64 {
    static ANCHOR: std::sync::OnceLock<(Instant, u64)> = std::sync::OnceLock::new();
    let (anchor_instant, anchor_wall_ms) = *ANCHOR.get_or_init(|| {
        let wall_ms = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|d| u64::try_from(d.as_millis()).unwrap_or(u64::MAX))
            .unwrap_or(0);
        (Instant::now(), wall_ms)
    });
    let elapsed_ms = u64::try_from(anchor_instant.elapsed().as_millis()).unwrap_or(u64::MAX);
    anchor_wall_ms.saturating_add(elapsed_ms)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// Shorthand for a breaker built from an explicit configuration.
    fn breaker(failure_threshold: u32, open_duration_ms: u64, half_open_max: u32) -> CircuitBreaker {
        CircuitBreaker::new(CircuitConfig {
            failure_threshold,
            open_duration_ms,
            half_open_max,
        })
    }

    /// Blocks the current thread for `ms` milliseconds.
    fn pause_ms(ms: u64) {
        thread::sleep(Duration::from_millis(ms));
    }

    #[test]
    fn test_default_config() {
        let CircuitConfig {
            failure_threshold,
            open_duration_ms,
            half_open_max,
        } = CircuitConfig::default();
        assert_eq!(failure_threshold, 3);
        assert_eq!(open_duration_ms, 30000);
        assert_eq!(half_open_max, 1);
    }

    #[test]
    fn test_initial_state_is_closed() {
        let cb = CircuitBreaker::with_defaults();
        assert!(cb.is_closed());
        assert!(!cb.is_open());
        assert!(cb.can_execute());
    }

    #[test]
    fn test_opens_after_failure_threshold() {
        let cb = breaker(3, 30000, 1);
        // The first two failures stay below the threshold of 3.
        for expected in 1..=2 {
            cb.record_failure();
            assert!(cb.is_closed());
            assert_eq!(cb.failure_count(), expected);
        }
        cb.record_failure();
        assert!(cb.is_open());
        assert!(!cb.can_execute());
    }

    #[test]
    fn test_success_resets_failure_count() {
        let cb = breaker(3, 30000, 1);
        (0..2).for_each(|_| cb.record_failure());
        assert_eq!(cb.failure_count(), 2);
        cb.record_success();
        assert_eq!(cb.failure_count(), 0);
        assert!(cb.is_closed());
    }

    #[test]
    fn test_rate_limited_immediately_opens() {
        let cb = CircuitBreaker::with_defaults();
        assert!(cb.is_closed());
        cb.record_rate_limited();
        assert!(cb.is_open());
    }

    #[test]
    fn test_service_unavailable_immediately_opens() {
        let cb = CircuitBreaker::with_defaults();
        assert!(cb.is_closed());
        cb.record_service_unavailable();
        assert!(cb.is_open());
    }

    #[test]
    fn test_transitions_to_half_open_after_timeout() {
        let cb = breaker(1, 50, 1);
        cb.record_failure();
        assert!(cb.is_open());
        assert!(!cb.can_execute());
        pause_ms(60);
        assert!(cb.can_execute());
        assert!(matches!(cb.state(), CircuitState::HalfOpen));
    }

    #[test]
    fn test_half_open_success_closes_circuit() {
        let cb = breaker(1, 10, 1);
        cb.record_failure();
        pause_ms(20);
        assert!(cb.can_execute());
        cb.record_success();
        assert!(cb.is_closed());
        assert!(cb.can_execute());
    }

    #[test]
    fn test_half_open_failure_reopens_circuit() {
        let cb = breaker(1, 10, 1);
        cb.record_failure();
        pause_ms(20);
        assert!(cb.can_execute());
        cb.record_failure();
        assert!(cb.is_open());
    }

    #[test]
    fn test_reset() {
        let cb = CircuitBreaker::with_defaults();
        cb.record_rate_limited();
        assert!(cb.is_open());
        cb.reset();
        assert!(cb.is_closed());
        assert_eq!(cb.failure_count(), 0);
        assert!(cb.can_execute());
    }

    #[test]
    fn test_half_open_limits_requests() {
        let cb = breaker(1, 10, 2);
        cb.record_failure();
        pause_ms(20);
        // Exactly two probes are admitted, the third is rejected.
        assert!(cb.can_execute());
        assert!(cb.can_execute());
        assert!(!cb.can_execute());
    }
}