use std::sync::atomic::{AtomicU32, AtomicU64, AtomicU8, Ordering};
use std::time::{Duration, Instant};
/// Operating state of a circuit breaker.
///
/// The explicit discriminants are the values kept in the breaker's atomic
/// state cell and emitted by the state gauge (`0`, `1`, `2`).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(u8)]
pub enum CircuitState {
    /// Requests are admitted; failures are being counted.
    Closed = 0,
    /// Requests are rejected until the reset timeout has elapsed.
    Open = 1,
    /// Requests are admitted as probes; successes are being counted.
    HalfOpen = 2,
}

impl CircuitState {
    /// Returns the lowercase label for this state: `"closed"`, `"open"` or
    /// `"half_open"`.
    #[must_use]
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::HalfOpen => "half_open",
            Self::Open => "open",
            Self::Closed => "closed",
        }
    }
}

impl From<u8> for CircuitState {
    /// Decodes a raw discriminant.
    ///
    /// `1` and `2` map to `Open` and `HalfOpen`; every other value,
    /// including `0`, decodes as `Closed`.
    fn from(value: u8) -> Self {
        if value == Self::Open as u8 {
            Self::Open
        } else if value == Self::HalfOpen as u8 {
            Self::HalfOpen
        } else {
            Self::Closed
        }
    }
}
/// Tunable parameters for a [`CircuitBreaker`].
#[derive(Clone, Debug)]
pub struct CircuitBreakerConfig {
    /// Failure count (since the last success while closed) at which the
    /// circuit opens.
    pub failure_threshold: u32,
    /// How long the circuit stays open before a request is admitted again
    /// and the breaker moves to half-open.
    pub reset_timeout: Duration,
    /// Number of successes recorded while half-open that closes the circuit.
    pub half_open_requests: u32,
    /// Name used in log events and as the `name` label of exported metrics.
    pub name: String,
}

impl Default for CircuitBreakerConfig {
    /// 5 failures to open, 30 s reset timeout, 3 half-open successes to
    /// close, named `"inference"`.
    fn default() -> Self {
        Self::preset(5, 30, 3)
    }
}

impl CircuitBreakerConfig {
    /// Builds a config with the given threshold and timeout; the remaining
    /// fields take their [`Default`] values.
    #[must_use]
    pub fn new(failure_threshold: u32, reset_timeout: Duration) -> Self {
        Self {
            failure_threshold,
            reset_timeout,
            ..Self::default()
        }
    }

    /// Preset that opens after 3 failures, waits 60 s before probing and
    /// needs 5 half-open successes to close.
    #[must_use]
    pub fn strict() -> Self {
        Self::preset(3, 60, 5)
    }

    /// Preset that opens after 10 failures, waits 15 s before probing and
    /// needs 2 half-open successes to close.
    #[must_use]
    pub fn lenient() -> Self {
        Self::preset(10, 15, 2)
    }

    /// Returns the config with `name` replaced.
    #[must_use]
    pub fn with_name(self, name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            ..self
        }
    }

    /// Shared constructor behind the presets: timeout given in whole
    /// seconds, name fixed to `"inference"`.
    fn preset(failure_threshold: u32, reset_secs: u64, half_open_requests: u32) -> Self {
        Self {
            failure_threshold,
            reset_timeout: Duration::from_secs(reset_secs),
            half_open_requests,
            name: String::from("inference"),
        }
    }
}
/// Circuit breaker whose entire mutable state lives in atomics, so every
/// method takes `&self`.
#[derive(Debug)]
pub struct CircuitBreaker {
/// Current state, stored as the `u8` discriminant of `CircuitState`.
state: AtomicU8,
/// Failures counted while closed; cleared by a success while closed and
/// whenever the circuit closes.
failure_count: AtomicU32,
/// Successes counted while half-open.
half_open_successes: AtomicU32,
/// Milliseconds since `start_instant` at which the circuit last opened
/// (never stored as less than 1); `0` is treated as "no open timestamp".
opened_at: AtomicU64,
/// Captured at construction; the reference point for `opened_at`.
start_instant: Instant,
/// Thresholds, timeout and name supplied at construction.
config: CircuitBreakerConfig,
/// Number of `allow_request` calls.
total_requests: AtomicU64,
/// Number of `allow_request` calls that returned `false`.
rejected_requests: AtomicU64,
/// Number of `record_success` calls.
successful_requests: AtomicU64,
/// Number of `record_failure` calls.
failed_requests: AtomicU64,
/// Number of state changes in which the new state differed from the old.
state_transitions: AtomicU64,
}
impl CircuitBreaker {
#[must_use]
pub fn new(config: CircuitBreakerConfig) -> Self {
Self {
state: AtomicU8::new(CircuitState::Closed as u8),
failure_count: AtomicU32::new(0),
half_open_successes: AtomicU32::new(0),
opened_at: AtomicU64::new(0),
start_instant: Instant::now(),
config,
total_requests: AtomicU64::new(0),
rejected_requests: AtomicU64::new(0),
successful_requests: AtomicU64::new(0),
failed_requests: AtomicU64::new(0),
state_transitions: AtomicU64::new(0),
}
}
#[must_use]
pub fn with_defaults() -> Self {
Self::new(CircuitBreakerConfig::default())
}
#[must_use]
pub fn state(&self) -> CircuitState {
CircuitState::from(self.state.load(Ordering::Acquire))
}
#[must_use]
pub fn name(&self) -> &str {
&self.config.name
}
#[must_use]
pub fn failure_count(&self) -> u32 {
self.failure_count.load(Ordering::Relaxed)
}
#[must_use]
pub fn allow_request(&self) -> bool {
self.total_requests.fetch_add(1, Ordering::Relaxed);
let current_state = self.state();
match current_state {
CircuitState::Closed => true,
CircuitState::Open => {
if self.should_attempt_reset() {
self.transition_to(CircuitState::HalfOpen);
true
} else {
self.rejected_requests.fetch_add(1, Ordering::Relaxed);
false
}
},
CircuitState::HalfOpen => {
true
},
}
}
pub fn record_success(&self) {
self.successful_requests.fetch_add(1, Ordering::Relaxed);
let current_state = self.state();
match current_state {
CircuitState::Closed => {
self.failure_count.store(0, Ordering::Relaxed);
},
CircuitState::HalfOpen => {
let successes = self.half_open_successes.fetch_add(1, Ordering::Relaxed) + 1;
if successes >= self.config.half_open_requests {
self.transition_to(CircuitState::Closed);
}
},
CircuitState::Open => {
},
}
}
pub fn record_failure(&self) {
self.failed_requests.fetch_add(1, Ordering::Relaxed);
let current_state = self.state();
match current_state {
CircuitState::Closed => {
let failures = self.failure_count.fetch_add(1, Ordering::Relaxed) + 1;
if failures >= self.config.failure_threshold {
self.transition_to(CircuitState::Open);
}
},
CircuitState::HalfOpen => {
self.transition_to(CircuitState::Open);
},
CircuitState::Open => {
},
}
}
pub fn reset(&self) {
self.transition_to(CircuitState::Closed);
self.failure_count.store(0, Ordering::Relaxed);
self.half_open_successes.store(0, Ordering::Relaxed);
}
#[must_use]
pub fn metrics(&self) -> CircuitBreakerMetrics {
CircuitBreakerMetrics {
name: self.config.name.clone(),
state: self.state(),
failure_count: self.failure_count.load(Ordering::Relaxed),
total_requests: self.total_requests.load(Ordering::Relaxed),
rejected_requests: self.rejected_requests.load(Ordering::Relaxed),
successful_requests: self.successful_requests.load(Ordering::Relaxed),
failed_requests: self.failed_requests.load(Ordering::Relaxed),
state_transitions: self.state_transitions.load(Ordering::Relaxed),
}
}
#[must_use]
pub fn render_prometheus_metrics(&self) -> String {
let metrics = self.metrics();
let name = &metrics.name;
format!(
r#"# HELP infernum_circuit_breaker_state Current circuit breaker state (0=closed, 1=open, 2=half_open)
# TYPE infernum_circuit_breaker_state gauge
infernum_circuit_breaker_state{{name="{name}"}} {}
# HELP infernum_circuit_breaker_failures Current consecutive failure count
# TYPE infernum_circuit_breaker_failures gauge
infernum_circuit_breaker_failures{{name="{name}"}} {}
# HELP infernum_circuit_breaker_requests_total Total requests through circuit breaker
# TYPE infernum_circuit_breaker_requests_total counter
infernum_circuit_breaker_requests_total{{name="{name}"}} {}
# HELP infernum_circuit_breaker_rejected_total Requests rejected due to open circuit
# TYPE infernum_circuit_breaker_rejected_total counter
infernum_circuit_breaker_rejected_total{{name="{name}"}} {}
# HELP infernum_circuit_breaker_transitions_total State transitions
# TYPE infernum_circuit_breaker_transitions_total counter
infernum_circuit_breaker_transitions_total{{name="{name}"}} {}
"#,
metrics.state as u8,
metrics.failure_count,
metrics.total_requests,
metrics.rejected_requests,
metrics.state_transitions,
)
}
fn should_attempt_reset(&self) -> bool {
let opened_at = self.opened_at.load(Ordering::Acquire);
if opened_at == 0 {
return false;
}
let now = self.start_instant.elapsed().as_millis() as u64;
let elapsed = now.saturating_sub(opened_at);
elapsed >= self.config.reset_timeout.as_millis() as u64
}
fn transition_to(&self, new_state: CircuitState) {
let old_state = self.state.swap(new_state as u8, Ordering::AcqRel);
if old_state != new_state as u8 {
self.state_transitions.fetch_add(1, Ordering::Relaxed);
match new_state {
CircuitState::Open => {
let now = self.start_instant.elapsed().as_millis() as u64;
self.opened_at.store(now.max(1), Ordering::Release);
},
CircuitState::Closed => {
self.failure_count.store(0, Ordering::Relaxed);
self.half_open_successes.store(0, Ordering::Relaxed);
self.opened_at.store(0, Ordering::Relaxed);
},
CircuitState::HalfOpen => {
self.half_open_successes.store(0, Ordering::Relaxed);
},
}
tracing::info!(
circuit_breaker = %self.config.name,
old_state = CircuitState::from(old_state).as_str(),
new_state = new_state.as_str(),
"Circuit breaker state transition"
);
}
}
}
/// Snapshot of a circuit breaker's state and counters, as produced by
/// `CircuitBreaker::metrics`.
#[derive(Debug, Clone)]
pub struct CircuitBreakerMetrics {
/// Configured breaker name.
pub name: String,
/// State at the time of the snapshot.
pub state: CircuitState,
/// Failure count at the time of the snapshot.
pub failure_count: u32,
/// Number of `allow_request` calls.
pub total_requests: u64,
/// Number of `allow_request` calls that were rejected.
pub rejected_requests: u64,
/// Number of `record_success` calls.
pub successful_requests: u64,
/// Number of `record_failure` calls.
pub failed_requests: u64,
/// Number of state changes.
pub state_transitions: u64,
}
/// Error value naming an open circuit, with an optional retry hint.
#[derive(Clone, Debug)]
pub struct CircuitOpenError {
    /// Name of the circuit that is open.
    pub circuit_name: String,
    /// Suggested wait before retrying, when one is available.
    pub retry_after: Option<Duration>,
}

impl std::fmt::Display for CircuitOpenError {
    /// Writes `circuit breaker '<name>' is open`, followed by
    /// `, retry after <duration>` when a retry hint is present.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            circuit_name,
            retry_after,
        } = self;
        write!(f, "circuit breaker '{circuit_name}' is open")?;
        match retry_after {
            Some(delay) => write!(f, ", retry after {delay:?}"),
            None => Ok(()),
        }
    }
}

impl std::error::Error for CircuitOpenError {}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// Builds a config named `"test"` from the three tunables.
    fn test_config(
        failure_threshold: u32,
        reset_timeout: Duration,
        half_open_requests: u32,
    ) -> CircuitBreakerConfig {
        CircuitBreakerConfig {
            failure_threshold,
            reset_timeout,
            half_open_requests,
            name: "test".to_string(),
        }
    }

    #[test]
    fn test_circuit_state_default() {
        let breaker = CircuitBreaker::with_defaults();
        assert_eq!(breaker.state(), CircuitState::Closed);
    }

    #[test]
    fn test_circuit_allows_requests_when_closed() {
        let breaker = CircuitBreaker::with_defaults();
        assert!(breaker.allow_request());
        assert!(breaker.allow_request());
        assert!(breaker.allow_request());
    }

    #[test]
    fn test_circuit_opens_after_failures() {
        let breaker = CircuitBreaker::new(test_config(3, Duration::from_secs(30), 2));
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitState::Closed);
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitState::Closed);
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitState::Open);
    }

    #[test]
    fn test_circuit_rejects_when_open() {
        let breaker = CircuitBreaker::new(test_config(1, Duration::from_secs(60), 1));
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitState::Open);
        assert!(!breaker.allow_request());
        assert!(!breaker.allow_request());
    }

    #[test]
    fn test_success_resets_failure_count() {
        let breaker = CircuitBreaker::new(test_config(3, Duration::from_secs(30), 2));
        breaker.record_failure();
        breaker.record_failure();
        assert_eq!(breaker.failure_count(), 2);
        breaker.record_success();
        assert_eq!(breaker.failure_count(), 0);
        assert_eq!(breaker.state(), CircuitState::Closed);
    }

    #[test]
    fn test_half_open_transitions_to_closed_on_success() {
        let breaker = CircuitBreaker::new(test_config(1, Duration::from_millis(50), 2));
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitState::Open);
        // Sleep well past the 50 ms reset timeout so the next request probes.
        thread::sleep(Duration::from_millis(100));
        assert!(breaker.allow_request());
        assert_eq!(breaker.state(), CircuitState::HalfOpen);
        breaker.record_success();
        assert_eq!(breaker.state(), CircuitState::HalfOpen);
        breaker.record_success();
        assert_eq!(breaker.state(), CircuitState::Closed);
    }

    #[test]
    fn test_half_open_returns_to_open_on_failure() {
        let breaker = CircuitBreaker::new(test_config(1, Duration::from_millis(50), 3));
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitState::Open);
        thread::sleep(Duration::from_millis(100));
        assert!(breaker.allow_request());
        assert_eq!(breaker.state(), CircuitState::HalfOpen);
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitState::Open);
    }

    #[test]
    fn test_manual_reset() {
        let breaker = CircuitBreaker::new(test_config(1, Duration::from_secs(60), 1));
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitState::Open);
        breaker.reset();
        assert_eq!(breaker.state(), CircuitState::Closed);
        assert_eq!(breaker.failure_count(), 0);
    }

    #[test]
    fn test_metrics() {
        let breaker = CircuitBreaker::with_defaults();
        let _ = breaker.allow_request();
        breaker.record_success();
        let _ = breaker.allow_request();
        breaker.record_failure();
        let metrics = breaker.metrics();
        assert_eq!(metrics.total_requests, 2);
        assert_eq!(metrics.successful_requests, 1);
        assert_eq!(metrics.failed_requests, 1);
    }

    #[test]
    fn test_prometheus_metrics_output() {
        let breaker = CircuitBreaker::with_defaults();
        let output = breaker.render_prometheus_metrics();
        assert!(output.contains("infernum_circuit_breaker_state"));
        assert!(output.contains("infernum_circuit_breaker_failures"));
        assert!(output.contains("infernum_circuit_breaker_requests_total"));
        assert!(output.contains("infernum_circuit_breaker_rejected_total"));
        assert!(output.contains("infernum_circuit_breaker_transitions_total"));
        // The default breaker name must appear as the label value.
        assert!(output.contains(r#"{name="inference"}"#));
    }

    #[test]
    fn test_config_presets() {
        let strict = CircuitBreakerConfig::strict();
        assert_eq!(strict.failure_threshold, 3);
        assert_eq!(strict.reset_timeout, Duration::from_secs(60));
        let lenient = CircuitBreakerConfig::lenient();
        assert_eq!(lenient.failure_threshold, 10);
        assert_eq!(lenient.reset_timeout, Duration::from_secs(15));
    }

    #[test]
    fn test_config_with_name() {
        let config = CircuitBreakerConfig::default().with_name("gpu-inference");
        assert_eq!(config.name, "gpu-inference");
    }

    #[test]
    fn test_circuit_state_as_str() {
        assert_eq!(CircuitState::Closed.as_str(), "closed");
        assert_eq!(CircuitState::Open.as_str(), "open");
        assert_eq!(CircuitState::HalfOpen.as_str(), "half_open");
    }

    #[test]
    fn test_state_transitions_counted() {
        let breaker = CircuitBreaker::new(test_config(1, Duration::from_millis(50), 1));
        breaker.record_failure();
        assert_eq!(breaker.metrics().state_transitions, 1);
        thread::sleep(Duration::from_millis(100));
        let _ = breaker.allow_request();
        assert_eq!(breaker.metrics().state_transitions, 2);
        breaker.record_success();
        assert_eq!(breaker.metrics().state_transitions, 3);
    }

    #[test]
    fn test_concurrent_access() {
        use std::sync::Arc;

        const THREADS: u64 = 10;
        const CALLS_PER_THREAD: u64 = 100;

        let breaker = Arc::new(CircuitBreaker::with_defaults());
        let handles: Vec<_> = (0..THREADS)
            .map(|_| {
                let b = Arc::clone(&breaker);
                thread::spawn(move || {
                    for _ in 0..CALLS_PER_THREAD {
                        if b.allow_request() {
                            if rand_bool() {
                                b.record_success();
                            } else {
                                b.record_failure();
                            }
                        }
                    }
                })
            })
            .collect();
        for handle in handles {
            handle.join().expect("Thread panicked");
        }

        let metrics = breaker.metrics();
        // Every call is counted exactly once, so no increment may be lost.
        assert_eq!(metrics.total_requests, THREADS * CALLS_PER_THREAD);
        // Each call was either rejected or reported as exactly one outcome.
        assert_eq!(
            metrics.rejected_requests + metrics.successful_requests + metrics.failed_requests,
            metrics.total_requests
        );
    }

    /// Cheap std-only coin flip: hashes the wall clock and the current
    /// thread id and returns whether the hash is even.
    fn rand_bool() -> bool {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};
        use std::time::SystemTime;

        let mut hasher = DefaultHasher::new();
        SystemTime::now().hash(&mut hasher);
        thread::current().id().hash(&mut hasher);
        hasher.finish() % 2 == 0
    }
}