use std::{
future::Future,
sync::{
atomic::{AtomicU32, AtomicU64, Ordering},
Arc,
},
time::{Duration, Instant},
};
use dashmap::DashMap;
use parking_lot::RwLock;
/// The three states a circuit breaker moves between.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitState {
    /// Calls are let through.
    Closed,
    /// Calls are rejected until the configured timeout has elapsed.
    Open,
    /// A limited number of trial calls are let through.
    HalfOpen,
}

/// Renders the state as a lowercase label: `closed`, `open` or `half-open`.
impl std::fmt::Display for CircuitState {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match *self {
            CircuitState::Closed => "closed",
            CircuitState::Open => "open",
            CircuitState::HalfOpen => "half-open",
        };
        // Written verbatim: formatter width/fill flags are not applied.
        f.write_str(label)
    }
}
/// Error returned when a circuit breaker refuses to admit a call.
#[derive(Debug, Clone)]
pub struct CircuitOpenError {
    /// Name of the breaker that rejected the call.
    pub circuit_name: String,
    /// Suggested delay before trying again.
    pub retry_after: Duration,
}

impl std::fmt::Display for CircuitOpenError {
    /// Formats as `circuit '<name>' is open, retry after <duration:?>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let CircuitOpenError {
            circuit_name,
            retry_after,
        } = self;
        f.write_fmt(format_args!(
            "circuit '{}' is open, retry after {:?}",
            circuit_name, retry_after
        ))
    }
}

impl std::error::Error for CircuitOpenError {}
/// Tunable parameters of a circuit breaker.
#[derive(Debug, Clone)]
pub struct CircuitBreakerConfig {
    /// Number of failures inside `failure_window` that opens the circuit.
    pub failure_threshold: u32,
    /// Number of half-open successes that closes the circuit again.
    pub success_threshold: u32,
    /// How long the circuit stays open before trial calls are allowed.
    pub timeout: Duration,
    /// Sliding window over which failures are counted.
    pub failure_window: Duration,
    /// Maximum number of trial calls in flight while half-open.
    pub half_open_requests: u32,
}

impl Default for CircuitBreakerConfig {
    /// 5 failures within 60 s open the circuit, it stays open for 30 s, and
    /// 3 successes (one trial call at a time) close it again.
    fn default() -> Self {
        let timeout = Duration::from_secs(30);
        let failure_window = Duration::from_secs(60);
        Self {
            failure_threshold: 5,
            success_threshold: 3,
            timeout,
            failure_window,
            half_open_requests: 1,
        }
    }
}

impl CircuitBreakerConfig {
    /// Default configuration with a custom `failure_threshold`.
    pub fn new(failure_threshold: u32) -> Self {
        let mut config = Self::default();
        config.failure_threshold = failure_threshold;
        config
    }

    /// Sets the number of half-open successes required to close the circuit.
    pub fn with_success_threshold(self, threshold: u32) -> Self {
        Self {
            success_threshold: threshold,
            ..self
        }
    }

    /// Sets how long the circuit stays open before trial calls are allowed.
    pub fn with_timeout(self, timeout: Duration) -> Self {
        Self { timeout, ..self }
    }

    /// Sets the window over which failures are counted.
    pub fn with_failure_window(self, window: Duration) -> Self {
        Self {
            failure_window: window,
            ..self
        }
    }

    /// Sets the number of concurrent half-open trial calls; `0` is raised to `1`.
    pub fn with_half_open_requests(self, requests: u32) -> Self {
        let half_open_requests = if requests == 0 { 1 } else { requests };
        Self {
            half_open_requests,
            ..self
        }
    }
}
/// Snapshot of a circuit breaker's counters and state, as produced by
/// `CircuitBreaker::get_stats`.
#[derive(Debug, Clone)]
pub struct CircuitBreakerStats {
/// State of the breaker when the snapshot was taken.
pub state: CircuitState,
/// Total number of recorded successes.
pub success_count: u64,
/// Total number of recorded failures.
pub failure_count: u64,
/// Total number of calls rejected by `check`.
pub rejected_count: u64,
/// Number of stored failure timestamps that fall inside the failure window.
pub failures_in_window: u32,
/// Current value of the half-open success counter.
pub half_open_successes: u32,
/// Time elapsed since the last state change.
pub time_in_state: Duration,
}
/// A single named circuit breaker.
///
/// State and failure timestamps sit behind `RwLock`s and the counters are
/// atomics, so every method takes `&self`.
pub struct CircuitBreaker {
// Name reported in `CircuitOpenError` and returned by `name()`.
name: String,
// Thresholds and timeouts governing state changes.
config: CircuitBreakerConfig,
// Current state.
state: RwLock<CircuitState>,
// Instant of the most recent state change.
state_changed_at: RwLock<Instant>,
// Timestamps of failures recorded while closed; cleared when the circuit closes.
failures: RwLock<Vec<Instant>>,
// Successes recorded while half-open; reset on entering half-open or closed.
half_open_successes: AtomicU32,
// Trial calls admitted while half-open and not yet reported back.
half_open_in_flight: AtomicU32,
// Lifetime totals exposed through `get_stats`.
success_count: AtomicU64,
failure_count: AtomicU64,
rejected_count: AtomicU64,
}
impl CircuitBreaker {
pub fn new(name: impl Into<String>, config: CircuitBreakerConfig) -> Self {
Self {
name: name.into(),
config,
state: RwLock::new(CircuitState::Closed),
state_changed_at: RwLock::new(Instant::now()),
failures: RwLock::new(Vec::new()),
half_open_successes: AtomicU32::new(0),
half_open_in_flight: AtomicU32::new(0),
success_count: AtomicU64::new(0),
failure_count: AtomicU64::new(0),
rejected_count: AtomicU64::new(0),
}
}
pub fn name(&self) -> &str {
&self.name
}
pub fn check(&self) -> Result<(), CircuitOpenError> {
self.maybe_transition_to_half_open();
let state = *self.state.read();
match state {
CircuitState::Closed => Ok(()),
CircuitState::Open => {
self.rejected_count.fetch_add(1, Ordering::Relaxed);
let elapsed = self.state_changed_at.read().elapsed();
let retry_after = self.config.timeout.saturating_sub(elapsed);
Err(CircuitOpenError {
circuit_name: self.name.clone(),
retry_after,
})
}
CircuitState::HalfOpen => {
let in_flight = self.half_open_in_flight.load(Ordering::Acquire);
if in_flight < self.config.half_open_requests {
self.half_open_in_flight.fetch_add(1, Ordering::AcqRel);
Ok(())
} else {
self.rejected_count.fetch_add(1, Ordering::Relaxed);
Err(CircuitOpenError {
circuit_name: self.name.clone(),
retry_after: Duration::from_millis(100),
})
}
}
}
}
pub fn record_success(&self) {
self.success_count.fetch_add(1, Ordering::Relaxed);
let state = *self.state.read();
if state == CircuitState::HalfOpen {
self.half_open_in_flight.fetch_sub(1, Ordering::AcqRel);
let successes = self.half_open_successes.fetch_add(1, Ordering::AcqRel) + 1;
if successes >= self.config.success_threshold {
self.transition_to(CircuitState::Closed);
}
}
}
pub fn record_failure(&self) {
self.failure_count.fetch_add(1, Ordering::Relaxed);
let state = *self.state.read();
match state {
CircuitState::Closed => {
let now = Instant::now();
let mut failures = self.failures.write();
failures.push(now);
let cutoff = now - self.config.failure_window;
failures.retain(|&t| t > cutoff);
if failures.len() as u32 >= self.config.failure_threshold {
drop(failures); self.transition_to(CircuitState::Open);
}
}
CircuitState::HalfOpen => {
self.half_open_in_flight.fetch_sub(1, Ordering::AcqRel);
self.transition_to(CircuitState::Open);
}
CircuitState::Open => {
}
}
}
pub async fn call<F, Fut, T, E>(&self, f: F) -> Result<T, CircuitBreakerError<E>>
where
F: FnOnce() -> Fut,
Fut: Future<Output = Result<T, E>>,
{
self.check().map_err(CircuitBreakerError::CircuitOpen)?;
match f().await {
Ok(result) => {
self.record_success();
Ok(result)
}
Err(e) => {
self.record_failure();
Err(CircuitBreakerError::Inner(e))
}
}
}
pub fn get_state(&self) -> CircuitState {
self.maybe_transition_to_half_open();
*self.state.read()
}
pub fn get_stats(&self) -> CircuitBreakerStats {
self.maybe_transition_to_half_open();
let state = *self.state.read();
let failures = self.failures.read();
let now = Instant::now();
let cutoff = now - self.config.failure_window;
let failures_in_window = failures.iter().filter(|&&t| t > cutoff).count() as u32;
CircuitBreakerStats {
state,
success_count: self.success_count.load(Ordering::Relaxed),
failure_count: self.failure_count.load(Ordering::Relaxed),
rejected_count: self.rejected_count.load(Ordering::Relaxed),
failures_in_window,
half_open_successes: self.half_open_successes.load(Ordering::Relaxed),
time_in_state: self.state_changed_at.read().elapsed(),
}
}
pub fn reset(&self) {
self.transition_to(CircuitState::Closed);
self.failures.write().clear();
}
fn transition_to(&self, new_state: CircuitState) {
let mut state = self.state.write();
let old_state = *state;
if old_state != new_state {
*state = new_state;
*self.state_changed_at.write() = Instant::now();
if new_state == CircuitState::HalfOpen || new_state == CircuitState::Closed {
self.half_open_successes.store(0, Ordering::Relaxed);
self.half_open_in_flight.store(0, Ordering::Relaxed);
}
if new_state == CircuitState::Closed {
self.failures.write().clear();
}
#[cfg(feature = "otel")]
tracing::info!(
circuit = %self.name,
old_state = %old_state,
new_state = %new_state,
"circuit breaker state changed"
);
}
}
fn maybe_transition_to_half_open(&self) {
let state = *self.state.read();
if state == CircuitState::Open {
let elapsed = self.state_changed_at.read().elapsed();
if elapsed >= self.config.timeout {
self.transition_to(CircuitState::HalfOpen);
}
}
}
}
/// Error returned by the breaker's `call`: either the breaker rejected the
/// call, or the wrapped operation itself failed.
#[derive(Debug)]
pub enum CircuitBreakerError<E> {
    /// The call was rejected before the operation ran.
    CircuitOpen(CircuitOpenError),
    /// The operation ran and returned this error.
    Inner(E),
}

impl<E: std::fmt::Display> std::fmt::Display for CircuitBreakerError<E> {
    /// Prints the wrapped error's own message, whichever variant it is.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::CircuitOpen(rejection) => write!(f, "{}", rejection),
            Self::Inner(failure) => write!(f, "{}", failure),
        }
    }
}

impl<E: std::error::Error + 'static> std::error::Error for CircuitBreakerError<E> {
    /// Exposes the wrapped error as the source for both variants.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        let cause: &(dyn std::error::Error + 'static) = match self {
            Self::CircuitOpen(rejection) => rejection,
            Self::Inner(failure) => failure,
        };
        Some(cause)
    }
}
/// Registry of circuit breakers addressed by name.
pub struct CircuitBreakerManager {
// Breakers keyed by name; handed out as shared `Arc` handles.
breakers: DashMap<String, Arc<CircuitBreaker>>,
// Configuration cloned for breakers created by `get_or_create`.
default_config: CircuitBreakerConfig,
}
impl CircuitBreakerManager {
    /// Creates an empty registry; breakers made by
    /// [`get_or_create`](Self::get_or_create) use a clone of `default_config`.
    pub fn new(default_config: CircuitBreakerConfig) -> Self {
        Self {
            breakers: DashMap::new(),
            default_config,
        }
    }

    /// Returns the breaker registered under `name`, creating it with the
    /// default configuration on first use.
    pub fn get_or_create(&self, name: &str) -> Arc<CircuitBreaker> {
        // Fast path: an existing breaker needs neither a `String` allocation
        // for the key nor the map's write access. The lookup guard is dropped
        // inside `get`, before `entry` is called.
        if let Some(existing) = self.get(name) {
            return existing;
        }
        // Slow path: `entry` settles the race if another thread inserted the
        // same name in the meantime.
        self.breakers
            .entry(name.to_owned())
            .or_insert_with(|| Arc::new(CircuitBreaker::new(name, self.default_config.clone())))
            .clone()
    }

    /// Returns the breaker registered under `name`, if any.
    pub fn get(&self, name: &str) -> Option<Arc<CircuitBreaker>> {
        self.breakers
            .get(name)
            .map(|found| Arc::clone(found.value()))
    }

    /// Registers a new breaker under `name` with its own `config` and returns
    /// it. Any breaker already stored under that name is replaced in the map;
    /// handles obtained earlier keep referring to the old instance.
    pub fn create_with_config(
        &self,
        name: &str,
        config: CircuitBreakerConfig,
    ) -> Arc<CircuitBreaker> {
        let breaker = Arc::new(CircuitBreaker::new(name, config));
        self.breakers
            .insert(name.to_owned(), Arc::clone(&breaker));
        breaker
    }

    /// Collects `(name, stats)` for every registered breaker.
    pub fn get_all_stats(&self) -> Vec<(String, CircuitBreakerStats)> {
        self.breakers
            .iter()
            .map(|entry| (entry.key().clone(), entry.value().get_stats()))
            .collect()
    }

    /// Resets every registered breaker.
    pub fn reset_all(&self) {
        for entry in self.breakers.iter() {
            entry.value().reset();
        }
    }

    /// Removes the breaker registered under `name`, if present.
    pub fn remove(&self, name: &str) {
        self.breakers.remove(name);
    }

    /// Removes all breakers.
    pub fn clear(&self) {
        self.breakers.clear();
    }

    /// Number of registered breakers.
    pub fn len(&self) -> usize {
        self.breakers.len()
    }

    /// True when no breaker is registered.
    pub fn is_empty(&self) -> bool {
        self.breakers.is_empty()
    }
}
impl Default for CircuitBreakerManager {
fn default() -> Self {
Self::new(CircuitBreakerConfig::default())
}
}
/// A family of circuit breakers, one per key of type `K`, all sharing one
/// configuration.
pub struct KeyedCircuitBreaker<K: std::hash::Hash + Eq + Clone + Send + Sync + 'static> {
// One breaker per key, created on first use.
breakers: DashMap<K, Arc<CircuitBreaker>>,
// Configuration cloned into every breaker.
config: CircuitBreakerConfig,
// Source of the numeric ids used in generated breaker names (`keyed-<id>`).
counter: std::sync::atomic::AtomicU64,
}
impl<K: std::hash::Hash + Eq + Clone + Send + Sync + 'static> KeyedCircuitBreaker<K> {
    /// Creates an empty set of breakers that will all use `config`.
    pub fn new(config: CircuitBreakerConfig) -> Self {
        Self {
            breakers: DashMap::new(),
            config,
            counter: std::sync::atomic::AtomicU64::new(0),
        }
    }

    /// Runs `CircuitBreaker::check` on the breaker for `key`, creating the
    /// breaker if needed.
    ///
    /// # Errors
    /// Returns [`CircuitOpenError`] when that breaker rejects the call.
    pub fn check(&self, key: &K) -> Result<(), CircuitOpenError> {
        self.get_or_create(key).check()
    }

    /// Records a success on the breaker for `key`, creating it if needed.
    pub fn record_success(&self, key: &K) {
        self.get_or_create(key).record_success()
    }

    /// Records a failure on the breaker for `key`, creating it if needed.
    pub fn record_failure(&self, key: &K) {
        self.get_or_create(key).record_failure()
    }

    /// Runs `f` under the breaker for `key`, creating the breaker if needed.
    ///
    /// # Errors
    /// See `CircuitBreaker::call`.
    pub async fn call<F, Fut, T, E>(&self, key: &K, f: F) -> Result<T, CircuitBreakerError<E>>
    where
        F: FnOnce() -> Fut,
        Fut: Future<Output = Result<T, E>>,
    {
        self.get_or_create(key).call(f).await
    }

    /// State of the breaker for `key`, or `None` if none exists yet.
    pub fn get_state(&self, key: &K) -> Option<CircuitState> {
        // `get` hands back an owned handle, so the map guard is released
        // before the breaker's own locks are taken.
        self.get(key).map(|breaker| breaker.get_state())
    }

    /// Stats of the breaker for `key`, or `None` if none exists yet.
    pub fn get_stats(&self, key: &K) -> Option<CircuitBreakerStats> {
        self.get(key).map(|breaker| breaker.get_stats())
    }

    /// Collects `(key, stats)` for every existing breaker.
    pub fn get_all_stats(&self) -> Vec<(K, CircuitBreakerStats)> {
        self.breakers
            .iter()
            .map(|entry| (entry.key().clone(), entry.value().get_stats()))
            .collect()
    }

    /// Resets the breaker for `key`; does nothing if none exists.
    pub fn reset(&self, key: &K) {
        if let Some(breaker) = self.get(key) {
            breaker.reset();
        }
    }

    /// Resets every existing breaker.
    pub fn reset_all(&self) {
        for entry in self.breakers.iter() {
            entry.value().reset();
        }
    }

    /// Removes the breaker for `key`, if present.
    pub fn remove(&self, key: &K) {
        self.breakers.remove(key);
    }

    /// Removes all breakers.
    pub fn clear(&self) {
        self.breakers.clear();
    }

    /// Number of existing breakers.
    pub fn len(&self) -> usize {
        self.breakers.len()
    }

    /// True when no breaker exists.
    pub fn is_empty(&self) -> bool {
        self.breakers.is_empty()
    }

    /// Shared handle to the breaker for `key`, if one exists.
    pub fn get(&self, key: &K) -> Option<Arc<CircuitBreaker>> {
        self.breakers
            .get(key)
            .map(|found| Arc::clone(found.value()))
    }

    /// Returns the breaker for `key`, creating it on first use with a
    /// generated name of the form `keyed-<id>`.
    fn get_or_create(&self, key: &K) -> Arc<CircuitBreaker> {
        // Fast path for the common case: no key clone and no write access to
        // the map. The lookup guard is dropped inside `get`, before `entry`
        // is called.
        if let Some(existing) = self.get(key) {
            return existing;
        }
        // Slow path: `entry` settles the race if another thread created the
        // breaker in the meantime; the id counter advances only on creation.
        self.breakers
            .entry(key.clone())
            .or_insert_with(|| {
                let id = self.counter.fetch_add(1, Ordering::Relaxed);
                Arc::new(CircuitBreaker::new(
                    format!("keyed-{}", id),
                    self.config.clone(),
                ))
            })
            .clone()
    }
}
impl<K: std::hash::Hash + Eq + Clone + Send + Sync + 'static> Default for KeyedCircuitBreaker<K> {
fn default() -> Self {
Self::new(CircuitBreakerConfig::default())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // ---- configuration -------------------------------------------------

    #[test]
    fn test_circuit_breaker_config_default() {
        let CircuitBreakerConfig {
            failure_threshold,
            success_threshold,
            timeout,
            ..
        } = CircuitBreakerConfig::default();
        assert_eq!(failure_threshold, 5);
        assert_eq!(success_threshold, 3);
        assert_eq!(timeout, Duration::from_secs(30));
    }

    #[test]
    fn test_circuit_breaker_config_builder() {
        let one_minute = Duration::from_secs(60);
        let cfg = CircuitBreakerConfig::new(10)
            .with_success_threshold(5)
            .with_timeout(one_minute)
            .with_half_open_requests(3);
        assert_eq!(cfg.failure_threshold, 10);
        assert_eq!(cfg.success_threshold, 5);
        assert_eq!(cfg.timeout, one_minute);
        assert_eq!(cfg.half_open_requests, 3);
    }

    // ---- single breaker ------------------------------------------------

    #[test]
    fn test_circuit_breaker_initial_state() {
        let breaker = CircuitBreaker::new("test", CircuitBreakerConfig::default());
        assert_eq!(breaker.get_state(), CircuitState::Closed);
    }

    #[test]
    fn test_circuit_breaker_opens_on_failures() {
        let breaker = CircuitBreaker::new("test", CircuitBreakerConfig::new(3));
        // The first two failures stay below the threshold of three.
        for _ in 0..2 {
            breaker.record_failure();
            assert_eq!(breaker.get_state(), CircuitState::Closed);
        }
        breaker.record_failure();
        assert_eq!(breaker.get_state(), CircuitState::Open);
    }

    #[test]
    fn test_circuit_breaker_check_when_open() {
        let breaker = CircuitBreaker::new("test", CircuitBreakerConfig::new(1));
        breaker.record_failure();
        assert_eq!(breaker.get_state(), CircuitState::Open);
        let outcome = breaker.check();
        assert!(outcome.is_err());
        let rejection = outcome.unwrap_err();
        assert_eq!(rejection.circuit_name, "test");
    }

    #[test]
    fn test_circuit_breaker_transitions_to_half_open() {
        let cfg = CircuitBreakerConfig::new(1).with_timeout(Duration::from_millis(10));
        let breaker = CircuitBreaker::new("test", cfg);
        breaker.record_failure();
        assert_eq!(breaker.get_state(), CircuitState::Open);
        // Wait past the open timeout.
        std::thread::sleep(Duration::from_millis(20));
        assert_eq!(breaker.get_state(), CircuitState::HalfOpen);
    }

    #[test]
    fn test_circuit_breaker_closes_on_success() {
        let cfg = CircuitBreakerConfig::new(1)
            .with_timeout(Duration::from_millis(10))
            .with_success_threshold(2);
        let breaker = CircuitBreaker::new("test", cfg);
        breaker.record_failure();
        std::thread::sleep(Duration::from_millis(20));
        assert_eq!(breaker.get_state(), CircuitState::HalfOpen);

        // First trial call succeeds: one success is not enough to close.
        breaker.check().unwrap();
        breaker.record_success();
        assert_eq!(breaker.get_state(), CircuitState::HalfOpen);

        // Second trial call reaches the success threshold.
        breaker.check().unwrap();
        breaker.record_success();
        assert_eq!(breaker.get_state(), CircuitState::Closed);
    }

    #[test]
    fn test_circuit_breaker_reopens_on_half_open_failure() {
        let cfg = CircuitBreakerConfig::new(1).with_timeout(Duration::from_millis(10));
        let breaker = CircuitBreaker::new("test", cfg);
        breaker.record_failure();
        std::thread::sleep(Duration::from_millis(20));
        assert_eq!(breaker.get_state(), CircuitState::HalfOpen);
        breaker.check().unwrap();
        breaker.record_failure();
        assert_eq!(breaker.get_state(), CircuitState::Open);
    }

    #[test]
    fn test_circuit_breaker_reset() {
        let breaker = CircuitBreaker::new("test", CircuitBreakerConfig::new(1));
        breaker.record_failure();
        assert_eq!(breaker.get_state(), CircuitState::Open);
        breaker.reset();
        assert_eq!(breaker.get_state(), CircuitState::Closed);
    }

    #[test]
    fn test_circuit_breaker_stats() {
        let breaker = CircuitBreaker::new("test", CircuitBreakerConfig::new(5));
        breaker.record_success();
        breaker.record_success();
        breaker.record_failure();
        let snapshot = breaker.get_stats();
        assert_eq!(snapshot.state, CircuitState::Closed);
        assert_eq!(snapshot.success_count, 2);
        assert_eq!(snapshot.failure_count, 1);
        assert_eq!(snapshot.failures_in_window, 1);
    }

    #[tokio::test]
    async fn test_circuit_breaker_call_success() {
        let breaker = CircuitBreaker::new("test", CircuitBreakerConfig::default());
        let outcome = breaker
            .call(|| async { Ok::<_, std::io::Error>("success") })
            .await;
        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap(), "success");
        assert_eq!(breaker.get_stats().success_count, 1);
    }

    #[tokio::test]
    async fn test_circuit_breaker_call_failure() {
        let breaker = CircuitBreaker::new("test", CircuitBreakerConfig::default());
        let outcome: Result<(), CircuitBreakerError<std::io::Error>> = breaker
            .call(|| async { Err(std::io::Error::new(std::io::ErrorKind::Other, "failed")) })
            .await;
        assert!(outcome.is_err());
        assert_eq!(breaker.get_stats().failure_count, 1);
    }

    // ---- manager -------------------------------------------------------

    #[test]
    fn test_circuit_breaker_manager() {
        let manager = CircuitBreakerManager::default();
        let first = manager.get_or_create("service1");
        let second = manager.get_or_create("service2");
        let first_again = manager.get_or_create("service1");
        assert_eq!(first.name(), "service1");
        assert_eq!(second.name(), "service2");
        // The same name must yield the same shared instance.
        assert!(Arc::ptr_eq(&first, &first_again));
    }

    #[test]
    fn test_circuit_breaker_manager_custom_config() {
        let manager = CircuitBreakerManager::default();
        let breaker = manager.create_with_config("custom", CircuitBreakerConfig::new(10));
        assert_eq!(breaker.name(), "custom");
        assert_eq!(breaker.get_state(), CircuitState::Closed);
    }

    #[test]
    fn test_circuit_breaker_manager_get_all_stats() {
        let manager = CircuitBreakerManager::default();
        manager.get_or_create("a").record_success();
        manager.get_or_create("b").record_failure();
        let all = manager.get_all_stats();
        assert_eq!(all.len(), 2);
    }

    #[test]
    fn test_circuit_breaker_manager_reset_all() {
        let manager = CircuitBreakerManager::new(CircuitBreakerConfig::new(1));
        let breaker = manager.get_or_create("test");
        breaker.record_failure();
        assert_eq!(breaker.get_state(), CircuitState::Open);
        manager.reset_all();
        assert_eq!(breaker.get_state(), CircuitState::Closed);
    }

    #[test]
    fn test_circuit_state_display() {
        assert_eq!(format!("{}", CircuitState::Closed), "closed");
        assert_eq!(format!("{}", CircuitState::Open), "open");
        assert_eq!(format!("{}", CircuitState::HalfOpen), "half-open");
    }

    // ---- keyed breaker -------------------------------------------------

    #[test]
    fn test_keyed_circuit_breaker_basic() {
        let keyed = KeyedCircuitBreaker::<String>::new(CircuitBreakerConfig::default());
        let (key1, key2) = (String::from("key1"), String::from("key2"));
        assert!(keyed.check(&key1).is_ok());
        assert!(keyed.check(&key2).is_ok());
        assert_eq!(keyed.len(), 2);
    }

    #[test]
    fn test_keyed_circuit_breaker_isolation() {
        let keyed = KeyedCircuitBreaker::<String>::new(CircuitBreakerConfig::new(1));
        let (key1, key2) = (String::from("key1"), String::from("key2"));
        keyed.record_failure(&key1);
        // Only the failing key is tripped.
        assert!(keyed.check(&key1).is_err());
        assert!(keyed.check(&key2).is_ok());
    }

    #[test]
    fn test_keyed_circuit_breaker_stats() {
        let keyed = KeyedCircuitBreaker::<String>::new(CircuitBreakerConfig::default());
        let (a, b) = (String::from("a"), String::from("b"));
        keyed.record_success(&a);
        keyed.record_success(&a);
        keyed.record_failure(&b);
        let stats_a = keyed.get_stats(&a).unwrap();
        assert_eq!(stats_a.success_count, 2);
        let stats_b = keyed.get_stats(&b).unwrap();
        assert_eq!(stats_b.failure_count, 1);
        let everything = keyed.get_all_stats();
        assert_eq!(everything.len(), 2);
    }

    #[test]
    fn test_keyed_circuit_breaker_reset() {
        let keyed = KeyedCircuitBreaker::<String>::new(CircuitBreakerConfig::new(1));
        let key = String::from("key");
        keyed.record_failure(&key);
        assert!(keyed.check(&key).is_err());
        keyed.reset(&key);
        assert!(keyed.check(&key).is_ok());
    }

    #[test]
    fn test_keyed_circuit_breaker_reset_all() {
        let keyed = KeyedCircuitBreaker::<String>::new(CircuitBreakerConfig::new(1));
        let (a, b) = (String::from("a"), String::from("b"));
        keyed.record_failure(&a);
        keyed.record_failure(&b);
        assert!(keyed.check(&a).is_err());
        assert!(keyed.check(&b).is_err());
        keyed.reset_all();
        assert!(keyed.check(&a).is_ok());
        assert!(keyed.check(&b).is_ok());
    }

    #[tokio::test]
    async fn test_keyed_circuit_breaker_call() {
        let keyed = KeyedCircuitBreaker::<String>::new(CircuitBreakerConfig::default());
        let key = String::from("test");
        let outcome = keyed
            .call(&key, || async { Ok::<_, std::io::Error>("success") })
            .await;
        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap(), "success");
    }

    #[test]
    fn test_keyed_circuit_breaker_remove() {
        let keyed = KeyedCircuitBreaker::<String>::new(CircuitBreakerConfig::default());
        let key = String::from("key");
        keyed.check(&key).ok();
        assert_eq!(keyed.len(), 1);
        keyed.remove(&key);
        assert_eq!(keyed.len(), 0);
    }

    #[test]
    fn test_keyed_circuit_breaker_clear() {
        let keyed = KeyedCircuitBreaker::<String>::new(CircuitBreakerConfig::default());
        keyed.check(&String::from("a")).ok();
        keyed.check(&String::from("b")).ok();
        keyed.check(&String::from("c")).ok();
        assert_eq!(keyed.len(), 3);
        keyed.clear();
        assert!(keyed.is_empty());
    }

    #[test]
    fn test_keyed_circuit_breaker_default() {
        let keyed = KeyedCircuitBreaker::<u64>::default();
        assert!(keyed.check(&1).is_ok());
        assert!(keyed.check(&2).is_ok());
    }
}