use parking_lot::RwLock; use std::collections::HashMap;
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;
/// The position of a circuit breaker in its Closed → Open → HalfOpen cycle.
///
/// Derives `Hash` (in addition to `Eq`) so states can be used as keys in
/// hash-based collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CircuitState {
    /// Requests flow normally while failures are counted.
    Closed,
    /// Requests are rejected until the recovery timeout has elapsed.
    Open,
    /// A limited number of probe requests are admitted to test recovery.
    HalfOpen,
}
impl std::fmt::Display for CircuitState {
    /// Writes the bare variant name: `Closed`, `Open` or `HalfOpen`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match *self {
            CircuitState::Closed => "Closed",
            CircuitState::Open => "Open",
            CircuitState::HalfOpen => "HalfOpen",
        };
        f.write_str(label)
    }
}
use serde::{Deserialize, Serialize};
/// Tuning parameters for a [`CircuitBreaker`].
///
/// When serialized, `recovery_timeout` is encoded as a whole number of
/// milliseconds.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CircuitBreakerConfig {
    /// Failures (with no success in between) that trip a closed circuit open.
    pub failure_threshold: u32,
    /// How long an open circuit rejects requests before probing recovery.
    #[serde(
        serialize_with = "duration_ms::serialize",
        deserialize_with = "duration_ms::deserialize"
    )]
    pub recovery_timeout: Duration,
    /// Probe requests admitted, and successes required to close, while half-open.
    pub half_open_max_requests: u32,
}
mod duration_ms {
    //! Serde adapter that encodes a `Duration` as a whole number of milliseconds.
    use serde::{Deserialize, Deserializer, Serializer};
    use std::time::Duration;

    /// Serializes `duration` as a `u64` count of milliseconds.
    ///
    /// Sub-millisecond precision is dropped. Durations longer than `u64::MAX`
    /// milliseconds saturate to `u64::MAX` rather than being silently truncated
    /// to their low 64 bits by an `as` cast.
    pub fn serialize<S>(duration: &Duration, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // Clamp first so the narrowing cast below can never lose bits.
        let millis = duration.as_millis().min(u128::from(u64::MAX)) as u64;
        serializer.serialize_u64(millis)
    }

    /// Deserializes a `u64` millisecond count into a `Duration`.
    pub fn deserialize<'de, D>(deserializer: D) -> Result<Duration, D::Error>
    where
        D: Deserializer<'de>,
    {
        let millis = u64::deserialize(deserializer)?;
        Ok(Duration::from_millis(millis))
    }
}
impl Default for CircuitBreakerConfig {
fn default() -> Self {
Self {
failure_threshold: 5,
recovery_timeout: Duration::from_secs(30),
half_open_max_requests: 3,
}
}
}
/// Circuit breaker guarding calls to a single named service.
///
/// State lives behind a lock and counters are atomics, so one instance can be
/// shared (for example through the `Arc` handed out by a registry).
pub struct CircuitBreaker {
    /// Name of the protected service, used in log events and errors.
    service_name: String,
    /// Thresholds and timeouts governing state transitions.
    config: CircuitBreakerConfig,
    /// Current position in the Closed → Open → HalfOpen cycle.
    state: RwLock<CircuitState>,
    /// Failures counted since the counter was last cleared.
    failure_count: AtomicU32,
    /// Millisecond timestamp recorded when the circuit last opened; compared
    /// with the current time to decide when recovery may be probed.
    opened_at: AtomicU64,
    /// Probe-slot counter for the current half-open window, compared against
    /// `half_open_max_requests`.
    half_open_requests: AtomicU32,
    /// Successful probes recorded during the current half-open window.
    half_open_successes: AtomicU32,
}
impl CircuitBreaker {
    /// Creates a breaker for `service_name` that starts out
    /// [`CircuitState::Closed`] with all counters at zero.
    pub fn new(service_name: impl Into<String>, config: CircuitBreakerConfig) -> Self {
        Self {
            config,
            state: RwLock::new(CircuitState::Closed),
            failure_count: AtomicU32::new(0),
            opened_at: AtomicU64::new(0),
            half_open_requests: AtomicU32::new(0),
            half_open_successes: AtomicU32::new(0),
            service_name: service_name.into(),
        }
    }

    /// Returns the current state.
    ///
    /// An `Open` circuit whose recovery timeout has elapsed is promoted to
    /// `HalfOpen` first, so the returned value reflects that transition.
    pub fn state(&self) -> CircuitState {
        self.maybe_transition_to_half_open();
        *self.state.read()
    }

    /// Asks whether a request to the protected service may proceed.
    ///
    /// * `Closed`: always allowed.
    /// * `Open`: rejected. `retry_after` is the time left until probing starts.
    /// * `HalfOpen`: up to `half_open_max_requests` probes (at least one) are
    ///   admitted. Further requests are rejected with a one-second retry hint.
    ///
    /// The outcome of an admitted request should be reported through
    /// [`record_success`](Self::record_success) or
    /// [`record_failure`](Self::record_failure). A half-open circuit only
    /// closes or reopens through those calls.
    ///
    /// # Errors
    ///
    /// Returns [`CircuitBreakerError::CircuitOpen`] when the request is rejected.
    pub fn allow_request(&self) -> Result<(), CircuitBreakerError> {
        self.maybe_transition_to_half_open();
        let state = *self.state.read();
        match state {
            CircuitState::Closed => Ok(()),
            CircuitState::Open => {
                tracing::debug!(
                    service = %self.service_name,
                    "Circuit breaker OPEN - rejecting request"
                );
                Err(CircuitBreakerError::CircuitOpen {
                    service: self.service_name.clone(),
                    retry_after: self.time_until_half_open(),
                })
            }
            CircuitState::HalfOpen => {
                let limit = self.half_open_limit();
                // Claim a probe slot only while one is free. Rejected callers
                // therefore never inflate (and eventually wrap) the counter.
                let claimed = self.half_open_requests.fetch_update(
                    Ordering::SeqCst,
                    Ordering::SeqCst,
                    |taken| if taken < limit { Some(taken + 1) } else { None },
                );
                match claimed {
                    Ok(previous) => {
                        tracing::debug!(
                            service = %self.service_name,
                            request = previous + 1,
                            max = limit,
                            "Circuit breaker HALF-OPEN - allowing test request"
                        );
                        Ok(())
                    }
                    Err(_) => {
                        tracing::debug!(
                            service = %self.service_name,
                            "Circuit breaker HALF-OPEN - max test requests reached"
                        );
                        Err(CircuitBreakerError::CircuitOpen {
                            service: self.service_name.clone(),
                            retry_after: Some(Duration::from_secs(1)),
                        })
                    }
                }
            }
        }
    }

    /// Records a successful call.
    ///
    /// In `Closed` (and `Open`) this clears the failure counter. In `HalfOpen`
    /// it counts a successful probe. Once the half-open limit of probes has
    /// succeeded, the circuit closes.
    pub fn record_success(&self) {
        let state = *self.state.read();
        match state {
            CircuitState::Closed | CircuitState::Open => {
                self.failure_count.store(0, Ordering::SeqCst);
            }
            CircuitState::HalfOpen => {
                let successes = self
                    .half_open_successes
                    .fetch_add(1, Ordering::SeqCst)
                    .saturating_add(1);
                // Only close (and log) if the circuit is still half-open. Another
                // thread may have already closed or reopened it.
                if successes >= self.half_open_limit()
                    && self.transition(CircuitState::HalfOpen, CircuitState::Closed)
                {
                    tracing::info!(
                        service = %self.service_name,
                        "Circuit breaker CLOSED - service recovered"
                    );
                }
            }
        }
    }

    /// Records a failed call.
    ///
    /// In `Closed` the failure is counted, and reaching `failure_threshold`
    /// opens the circuit. In `HalfOpen` any failure reopens it. In `Open` the
    /// call is ignored.
    pub fn record_failure(&self) {
        let state = *self.state.read();
        match state {
            CircuitState::Closed => {
                let failures = self
                    .failure_count
                    .fetch_add(1, Ordering::SeqCst)
                    .saturating_add(1);
                // Several threads can cross the threshold at once. Only the one
                // that performs the Closed -> Open transition stamps `opened_at`
                // and logs, so the recovery deadline is not pushed back.
                if failures >= self.config.failure_threshold
                    && self.transition(CircuitState::Closed, CircuitState::Open)
                {
                    tracing::warn!(
                        service = %self.service_name,
                        failures = failures,
                        "Circuit breaker OPENED - too many failures"
                    );
                }
            }
            CircuitState::HalfOpen => {
                if self.transition(CircuitState::HalfOpen, CircuitState::Open) {
                    tracing::warn!(
                        service = %self.service_name,
                        "Circuit breaker REOPENED - test request failed"
                    );
                }
            }
            CircuitState::Open => {}
        }
    }

    /// Name of the protected service.
    pub fn service_name(&self) -> &str {
        &self.service_name
    }

    /// Failures counted since the counter was last cleared (by a success while
    /// closed or open, or by the circuit closing).
    pub fn failure_count(&self) -> u32 {
        self.failure_count.load(Ordering::SeqCst)
    }

    /// Number of probe requests admitted, and successes required, while
    /// half-open.
    ///
    /// Clamped to at least one. A limit of zero would reject every probe, and
    /// the circuit could then never leave `HalfOpen`.
    fn half_open_limit(&self) -> u32 {
        self.config.half_open_max_requests.max(1)
    }

    /// Moves the circuit from `from` to `to` under the write lock.
    ///
    /// Returns `false` and changes nothing if the circuit is no longer in
    /// `from`, for example because another thread transitioned it first.
    fn transition(&self, from: CircuitState, to: CircuitState) -> bool {
        let mut state = self.state.write();
        if *state != from {
            return false;
        }
        self.enter_state(&mut *state, to);
        true
    }

    /// Writes `next` into `slot` and applies that state's entry actions.
    ///
    /// `slot` must be the value guarded by the held `state` write lock.
    fn enter_state(&self, slot: &mut CircuitState, next: CircuitState) {
        *slot = next;
        match next {
            CircuitState::Open => self.opened_at.store(Self::now_millis(), Ordering::SeqCst),
            CircuitState::Closed => self.failure_count.store(0, Ordering::SeqCst),
            CircuitState::HalfOpen => {}
        }
        // Every transition starts a fresh half-open probe window.
        self.half_open_requests.store(0, Ordering::SeqCst);
        self.half_open_successes.store(0, Ordering::SeqCst);
    }

    /// Unconditionally closes the circuit and clears all counters.
    fn close(&self) {
        let mut state = self.state.write();
        self.enter_state(&mut *state, CircuitState::Closed);
    }

    /// Promotes an `Open` circuit to `HalfOpen` once `recovery_timeout` has
    /// elapsed since it opened.
    fn maybe_transition_to_half_open(&self) {
        // Cheap check under the shared lock first.
        let current = *self.state.read();
        if current != CircuitState::Open
            || self.elapsed_since_open() < self.config.recovery_timeout
        {
            return;
        }
        let mut state = self.state.write();
        // Re-check under the exclusive lock. Another thread may already have
        // moved the circuit on, or re-opened it with a fresh `opened_at`.
        if *state == CircuitState::Open
            && self.elapsed_since_open() >= self.config.recovery_timeout
        {
            self.enter_state(&mut *state, CircuitState::HalfOpen);
            tracing::info!(
                service = %self.service_name,
                "Circuit breaker HALF-OPEN - testing recovery"
            );
        }
    }

    /// Time remaining until an open circuit starts probing.
    ///
    /// Returns `None` when the circuit is not open or the timeout has already
    /// passed.
    fn time_until_half_open(&self) -> Option<Duration> {
        if *self.state.read() != CircuitState::Open {
            return None;
        }
        self.config
            .recovery_timeout
            .checked_sub(self.elapsed_since_open())
    }

    /// Time elapsed since `opened_at` was last stamped. Saturates at zero.
    fn elapsed_since_open(&self) -> Duration {
        let opened_at = self.opened_at.load(Ordering::SeqCst);
        Duration::from_millis(Self::now_millis().saturating_sub(opened_at))
    }

    /// Milliseconds on a monotonic clock, measured from a process-wide origin
    /// that is fixed on first use.
    ///
    /// Wall-clock `SystemTime` can jump when the system clock is adjusted.
    /// A backwards step would keep an open circuit from reaching its recovery
    /// timeout until the clock caught up. A forward step would half-open it
    /// early. `Instant` never goes backwards.
    fn now_millis() -> u64 {
        static ORIGIN: std::sync::OnceLock<std::time::Instant> = std::sync::OnceLock::new();
        let origin = ORIGIN.get_or_init(std::time::Instant::now);
        // Clamp before narrowing so the cast can never truncate.
        origin.elapsed().as_millis().min(u128::from(u64::MAX)) as u64
    }

    /// Manually forces the circuit closed and clears all counters.
    pub fn reset(&self) {
        self.close();
        tracing::info!(
            service = %self.service_name,
            "Circuit breaker manually RESET"
        );
    }
}
impl Clone for CircuitBreaker {
    /// Builds an independent breaker initialised from a snapshot of this one's
    /// state and counters. Later updates to either breaker do not affect the
    /// other.
    fn clone(&self) -> Self {
        let snapshot = |counter: &AtomicU32| AtomicU32::new(counter.load(Ordering::SeqCst));
        Self {
            service_name: self.service_name.clone(),
            config: self.config.clone(),
            state: RwLock::new(*self.state.read()),
            failure_count: snapshot(&self.failure_count),
            opened_at: AtomicU64::new(self.opened_at.load(Ordering::SeqCst)),
            half_open_requests: snapshot(&self.half_open_requests),
            half_open_successes: snapshot(&self.half_open_successes),
        }
    }
}
impl std::fmt::Debug for CircuitBreaker {
    /// Shows the service name, state and failure count.
    ///
    /// The state is read through `state()`, which can promote an expired `Open`
    /// circuit to `HalfOpen` as a side effect of formatting.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let current_state = self.state();
        let failures = self.failure_count();
        let mut out = f.debug_struct("CircuitBreaker");
        out.field("service", &self.service_name);
        out.field("state", &current_state);
        out.field("failure_count", &failures);
        out.finish()
    }
}
/// Lazily-populated collection of per-service [`CircuitBreaker`]s that all
/// start from the same configuration.
pub struct CircuitBreakerRegistry {
    /// Breakers keyed by service name, handed out as shared `Arc`s.
    breakers: RwLock<HashMap<String, Arc<CircuitBreaker>>>,
    /// Template configuration cloned into every newly created breaker.
    config: CircuitBreakerConfig,
}
impl CircuitBreakerRegistry {
    /// Creates an empty registry. Every breaker it later creates receives a
    /// clone of `config`.
    pub fn new(config: CircuitBreakerConfig) -> Self {
        let breakers = RwLock::new(HashMap::new());
        Self { config, breakers }
    }

    /// Returns the breaker for `service_name`, creating it on first use.
    ///
    /// Repeated calls with the same name return the same shared instance.
    pub fn get_or_create(&self, service_name: &str) -> Arc<CircuitBreaker> {
        // Fast path: shared lock only.
        let existing = self.breakers.read().get(service_name).cloned();
        if let Some(breaker) = existing {
            return breaker;
        }
        // Slow path: `entry` re-checks under the exclusive lock, so a breaker
        // inserted by another thread in the meantime is reused, not replaced.
        let mut map = self.breakers.write();
        let slot = map
            .entry(service_name.to_owned())
            .or_insert_with(|| Arc::new(CircuitBreaker::new(service_name, self.config.clone())));
        Arc::clone(slot)
    }

    /// Looks up an existing breaker without creating one.
    pub fn get(&self, service_name: &str) -> Option<Arc<CircuitBreaker>> {
        self.breakers.read().get(service_name).map(Arc::clone)
    }

    /// Snapshot of every registered breaker, in unspecified order.
    pub fn all(&self) -> Vec<Arc<CircuitBreaker>> {
        self.breakers.read().values().map(Arc::clone).collect()
    }

    /// Maps each registered service name to its breaker's current state.
    ///
    /// Uses `CircuitBreaker::state`, so expired `Open` breakers may be
    /// promoted to `HalfOpen` as a side effect.
    pub fn status(&self) -> HashMap<String, CircuitState> {
        let map = self.breakers.read();
        let mut snapshot = HashMap::with_capacity(map.len());
        for (name, breaker) in map.iter() {
            snapshot.insert(name.clone(), breaker.state());
        }
        snapshot
    }

    /// Manually resets every registered breaker to `Closed`.
    pub fn reset_all(&self) {
        self.breakers
            .read()
            .values()
            .for_each(|breaker| breaker.reset());
    }
}
impl Clone for CircuitBreakerRegistry {
    /// Copies the configuration and the current name-to-breaker map.
    ///
    /// The breakers themselves are shared (`Arc` clones), so their state is
    /// visible through both registries. Breakers created afterwards in one
    /// registry are not added to the other.
    fn clone(&self) -> Self {
        let map = self.breakers.read().clone();
        Self {
            breakers: RwLock::new(map),
            config: self.config.clone(),
        }
    }
}
impl std::fmt::Debug for CircuitBreakerRegistry {
    /// Shows the shared configuration and the registered service names (not
    /// the breakers' states).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let breakers = self.breakers.read();
        let services: Vec<&String> = breakers.keys().collect();
        f.debug_struct("CircuitBreakerRegistry")
            .field("config", &self.config)
            .field("services", &services)
            .finish()
    }
}
/// Reference-counted registry handle; cloning it shares the same registry.
pub type SharedCircuitBreakerRegistry = Arc<CircuitBreakerRegistry>;

/// Builds an empty registry whose breakers will use `config`, wrapped in an
/// `Arc` so it can be shared.
pub fn create_circuit_breaker_registry(
    config: CircuitBreakerConfig,
) -> SharedCircuitBreakerRegistry {
    let registry = CircuitBreakerRegistry::new(config);
    SharedCircuitBreakerRegistry::new(registry)
}
#[derive(Debug, Clone, thiserror::Error)]
pub enum CircuitBreakerError {
#[error("Circuit breaker open for service '{service}'. Retry after {retry_after:?}")]
CircuitOpen {
service: String,
retry_after: Option<Duration>,
},
}
impl CircuitBreakerError {
pub fn to_extensions(&self) -> HashMap<String, serde_json::Value> {
let mut extensions = HashMap::new();
match self {
CircuitBreakerError::CircuitOpen {
service,
retry_after,
} => {
extensions.insert("code".to_string(), serde_json::json!("SERVICE_UNAVAILABLE"));
extensions.insert("service".to_string(), serde_json::json!(service));
if let Some(retry) = retry_after {
extensions.insert("retryAfter".to_string(), serde_json::json!(retry.as_secs()));
}
}
}
extensions
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Config that trips on the first failure and allows probing after 10 ms.
    fn quick_recovery(half_open_max_requests: u32) -> CircuitBreakerConfig {
        CircuitBreakerConfig {
            failure_threshold: 1,
            recovery_timeout: Duration::from_millis(10),
            half_open_max_requests,
        }
    }

    /// Default config with a custom failure threshold.
    fn with_threshold(failure_threshold: u32) -> CircuitBreakerConfig {
        CircuitBreakerConfig {
            failure_threshold,
            ..Default::default()
        }
    }

    /// Records one failure, then sleeps past the 10 ms `quick_recovery` timeout.
    fn trip_and_wait(breaker: &CircuitBreaker) {
        breaker.record_failure();
        std::thread::sleep(Duration::from_millis(20));
    }

    #[test]
    fn test_circuit_starts_closed() {
        let breaker = CircuitBreaker::new("test", CircuitBreakerConfig::default());
        assert_eq!(breaker.state(), CircuitState::Closed);
        assert!(breaker.allow_request().is_ok());
    }

    #[test]
    fn test_circuit_opens_after_failures() {
        let breaker = CircuitBreaker::new(
            "test",
            CircuitBreakerConfig {
                failure_threshold: 3,
                recovery_timeout: Duration::from_secs(30),
                half_open_max_requests: 1,
            },
        );
        // The first two failures stay below the threshold.
        for _ in 0..2 {
            breaker.record_failure();
            assert_eq!(breaker.state(), CircuitState::Closed);
        }
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitState::Open);
        assert!(breaker.allow_request().is_err());
    }

    #[test]
    fn test_success_resets_failure_count() {
        let breaker = CircuitBreaker::new("test", with_threshold(3));
        breaker.record_failure();
        breaker.record_failure();
        assert_eq!(breaker.failure_count(), 2);
        breaker.record_success();
        assert_eq!(breaker.failure_count(), 0);
    }

    #[test]
    fn test_manual_reset() {
        let breaker = CircuitBreaker::new("test", with_threshold(1));
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitState::Open);
        breaker.reset();
        assert_eq!(breaker.state(), CircuitState::Closed);
        assert!(breaker.allow_request().is_ok());
    }

    #[test]
    fn test_half_open_transition() {
        let breaker = CircuitBreaker::new("test", quick_recovery(1));
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitState::Open);
        std::thread::sleep(Duration::from_millis(20));
        assert_eq!(breaker.state(), CircuitState::HalfOpen);
        assert!(breaker.allow_request().is_ok());
    }

    #[test]
    fn test_half_open_success_closes() {
        let breaker = CircuitBreaker::new("test", quick_recovery(1));
        trip_and_wait(&breaker);
        assert_eq!(breaker.state(), CircuitState::HalfOpen);
        breaker.record_success();
        assert_eq!(breaker.state(), CircuitState::Closed);
    }

    #[test]
    fn test_half_open_failure_reopens() {
        let breaker = CircuitBreaker::new("test", quick_recovery(1));
        trip_and_wait(&breaker);
        assert_eq!(breaker.state(), CircuitState::HalfOpen);
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitState::Open);
    }

    #[test]
    fn test_registry() {
        let registry = CircuitBreakerRegistry::new(CircuitBreakerConfig::default());
        let first = registry.get_or_create("service1");
        let second = registry.get_or_create("service2");
        let first_again = registry.get_or_create("service1");
        assert!(Arc::ptr_eq(&first, &first_again));
        assert!(!Arc::ptr_eq(&first, &second));
        assert_eq!(registry.all().len(), 2);
    }

    #[test]
    fn test_registry_status() {
        let registry = CircuitBreakerRegistry::new(with_threshold(1));
        let _healthy = registry.get_or_create("healthy");
        let unhealthy = registry.get_or_create("unhealthy");
        unhealthy.record_failure();
        let status = registry.status();
        assert_eq!(status.get("healthy"), Some(&CircuitState::Closed));
        assert_eq!(status.get("unhealthy"), Some(&CircuitState::Open));
    }

    #[test]
    fn test_error_extensions() {
        let err = CircuitBreakerError::CircuitOpen {
            service: "test".to_string(),
            retry_after: Some(Duration::from_secs(30)),
        };
        let ext = err.to_extensions();
        assert_eq!(
            ext.get("code"),
            Some(&serde_json::json!("SERVICE_UNAVAILABLE"))
        );
        assert_eq!(ext.get("service"), Some(&serde_json::json!("test")));
        assert_eq!(ext.get("retryAfter"), Some(&serde_json::json!(30)));
    }

    #[test]
    fn test_config_default() {
        let config = CircuitBreakerConfig::default();
        assert_eq!(config.failure_threshold, 5);
        assert_eq!(config.recovery_timeout, Duration::from_secs(30));
        assert_eq!(config.half_open_max_requests, 3);
    }

    #[test]
    fn test_circuit_state_display() {
        let expected = [
            (CircuitState::Closed, "Closed"),
            (CircuitState::Open, "Open"),
            (CircuitState::HalfOpen, "HalfOpen"),
        ];
        for (state, label) in &expected {
            assert_eq!(state.to_string(), *label);
        }
    }

    #[test]
    fn test_circuit_state_equality() {
        assert_eq!(CircuitState::Closed, CircuitState::Closed);
        assert_ne!(CircuitState::Closed, CircuitState::Open);
        assert_ne!(CircuitState::Open, CircuitState::HalfOpen);
    }

    #[test]
    fn test_circuit_breaker_clone() {
        let original = CircuitBreaker::new("test", CircuitBreakerConfig::default());
        original.record_failure();
        let copy = original.clone();
        assert_eq!(copy.failure_count(), original.failure_count());
        assert_eq!(copy.state(), original.state());
        assert_eq!(copy.service_name(), original.service_name());
    }

    #[test]
    fn test_circuit_breaker_debug() {
        let breaker = CircuitBreaker::new("test-service", CircuitBreakerConfig::default());
        let rendered = format!("{:?}", breaker);
        for needle in &["CircuitBreaker", "test-service", "Closed"] {
            assert!(rendered.contains(*needle));
        }
    }

    #[test]
    fn test_service_name() {
        let breaker = CircuitBreaker::new("my-service", CircuitBreakerConfig::default());
        assert_eq!(breaker.service_name(), "my-service");
    }

    #[test]
    fn test_half_open_max_requests_limit() {
        let breaker = CircuitBreaker::new("test", quick_recovery(2));
        trip_and_wait(&breaker);
        assert_eq!(breaker.state(), CircuitState::HalfOpen);
        assert!(breaker.allow_request().is_ok());
        assert!(breaker.allow_request().is_ok());
        assert!(breaker.allow_request().is_err());
    }

    #[test]
    fn test_registry_clone() {
        let original = CircuitBreakerRegistry::new(CircuitBreakerConfig::default());
        original.get_or_create("service1");
        let copy = original.clone();
        assert!(copy.get("service1").is_some());
    }

    #[test]
    fn test_registry_debug() {
        let registry = CircuitBreakerRegistry::new(CircuitBreakerConfig::default());
        registry.get_or_create("service1");
        let rendered = format!("{:?}", registry);
        assert!(rendered.contains("CircuitBreakerRegistry"));
        assert!(rendered.contains("service1"));
    }

    #[test]
    fn test_registry_get_nonexistent() {
        let registry = CircuitBreakerRegistry::new(CircuitBreakerConfig::default());
        assert!(registry.get("nonexistent").is_none());
    }

    #[test]
    fn test_registry_reset_all() {
        let registry = CircuitBreakerRegistry::new(with_threshold(1));
        let breakers = [
            registry.get_or_create("service1"),
            registry.get_or_create("service2"),
        ];
        for breaker in &breakers {
            breaker.record_failure();
        }
        for breaker in &breakers {
            assert_eq!(breaker.state(), CircuitState::Open);
        }
        registry.reset_all();
        for breaker in &breakers {
            assert_eq!(breaker.state(), CircuitState::Closed);
        }
    }

    #[test]
    fn test_create_circuit_breaker_registry() {
        let registry = create_circuit_breaker_registry(CircuitBreakerConfig::default());
        assert_eq!(registry.all().len(), 0);
    }

    #[test]
    fn test_error_without_retry_after() {
        let err = CircuitBreakerError::CircuitOpen {
            service: "test".to_string(),
            retry_after: None,
        };
        let ext = err.to_extensions();
        assert_eq!(
            ext.get("code"),
            Some(&serde_json::json!("SERVICE_UNAVAILABLE"))
        );
        assert!(!ext.contains_key("retryAfter"));
    }

    #[test]
    fn test_failure_count_tracking() {
        let breaker = CircuitBreaker::new("test", CircuitBreakerConfig::default());
        assert_eq!(breaker.failure_count(), 0);
        for expected in 1..=2 {
            breaker.record_failure();
            assert_eq!(breaker.failure_count(), expected);
        }
    }

    #[test]
    fn test_config_clone() {
        let original = CircuitBreakerConfig {
            failure_threshold: 10,
            recovery_timeout: Duration::from_secs(60),
            half_open_max_requests: 5,
        };
        let copy = original.clone();
        assert_eq!(original.failure_threshold, copy.failure_threshold);
        assert_eq!(original.recovery_timeout, copy.recovery_timeout);
        assert_eq!(original.half_open_max_requests, copy.half_open_max_requests);
    }
}