use dashmap::DashMap;
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;
/// Observable state of a [`CircuitBreaker`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitState {
    /// Calls are allowed; failures are being counted.
    Closed,
    /// Calls are rejected until the configured cool-down has elapsed.
    Open,
    /// Probe calls are allowed to decide whether the breaker closes again.
    HalfOpen,
}

impl std::fmt::Display for CircuitState {
    /// Writes the variant name: `Closed`, `Open` or `HalfOpen`.
    ///
    /// Goes through [`std::fmt::Formatter::pad`] so width, alignment and precision
    /// flags (e.g. `{:<8}` in tabular log output) are honoured. `write!` with a
    /// literal silently ignores them.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            CircuitState::Closed => "Closed",
            CircuitState::Open => "Open",
            CircuitState::HalfOpen => "HalfOpen",
        };
        f.pad(name)
    }
}
/// Tuning knobs for a [`CircuitBreaker`].
#[derive(Debug, Clone)]
pub struct CircuitBreakerConfig {
    /// Consecutive failures while `Closed` that trip the breaker to `Open`.
    pub failure_threshold: u32,
    /// Successful probes while `HalfOpen` needed to close the breaker again.
    pub success_threshold: u32,
    /// How long the breaker stays `Open` before it moves to `HalfOpen`.
    pub open_duration: Duration,
}

impl Default for CircuitBreakerConfig {
    /// Trip after 5 failures, close after 2 probe successes, 30 second cool-down.
    fn default() -> Self {
        let open_duration = Duration::from_millis(30_000);
        CircuitBreakerConfig {
            success_threshold: 2,
            failure_threshold: 5,
            open_duration,
        }
    }
}
pub struct CircuitBreaker {
key: String,
config: CircuitBreakerConfig,
state: AtomicU32,
failure_count: AtomicU32,
success_count: AtomicU32,
open_since_ms: AtomicU64,
}
impl CircuitBreaker {
pub fn new(key: String, config: CircuitBreakerConfig) -> Self {
Self {
key,
config,
state: AtomicU32::new(0), failure_count: AtomicU32::new(0),
success_count: AtomicU32::new(0),
open_since_ms: AtomicU64::new(0),
}
}
pub fn state(&self) -> CircuitState {
let raw = self.state.load(Ordering::Acquire);
match raw {
1 => {
if self.should_transition_to_half_open() {
self.transition_to(CircuitState::HalfOpen);
CircuitState::HalfOpen
} else {
CircuitState::Open
}
}
2 => CircuitState::HalfOpen,
_ => CircuitState::Closed,
}
}
pub fn key(&self) -> &str {
&self.key
}
pub fn allow_request(&self) -> bool {
match self.state() {
CircuitState::Closed | CircuitState::HalfOpen => true,
CircuitState::Open => false,
}
}
pub fn record_success(&self) {
match self.state() {
CircuitState::Closed => {
self.failure_count.store(0, Ordering::Release);
}
CircuitState::HalfOpen => {
let count = self.success_count.fetch_add(1, Ordering::AcqRel) + 1;
if count >= self.config.success_threshold {
tracing::info!(
key = %self.key,
"熔断器恢复: HalfOpen → Closed (连续成功 {} 次)",
count
);
self.transition_to(CircuitState::Closed);
}
}
CircuitState::Open => {
}
}
}
pub fn record_failure(&self) {
match self.state() {
CircuitState::Closed => {
let count = self.failure_count.fetch_add(1, Ordering::AcqRel) + 1;
if count >= self.config.failure_threshold {
tracing::warn!(
key = %self.key,
"熔断器触发: Closed → Open (连续失败 {} 次,阈值 {})",
count,
self.config.failure_threshold
);
self.transition_to(CircuitState::Open);
}
}
CircuitState::HalfOpen => {
tracing::warn!(
key = %self.key,
"熔断器重新打开: HalfOpen → Open (探测请求失败)"
);
self.transition_to(CircuitState::Open);
}
CircuitState::Open => {
}
}
}
pub fn reset(&self) {
tracing::info!(key = %self.key, "熔断器手动重置 → Closed");
self.transition_to(CircuitState::Closed);
}
pub fn failure_count(&self) -> u32 {
self.failure_count.load(Ordering::Acquire)
}
fn transition_to(&self, new_state: CircuitState) {
let raw = match new_state {
CircuitState::Closed => {
self.failure_count.store(0, Ordering::Release);
self.success_count.store(0, Ordering::Release);
self.open_since_ms.store(0, Ordering::Release);
0
}
CircuitState::Open => {
self.success_count.store(0, Ordering::Release);
let epoch_ms = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_millis() as u64;
self.open_since_ms.store(epoch_ms, Ordering::Release);
1
}
CircuitState::HalfOpen => {
self.success_count.store(0, Ordering::Release);
2
}
};
self.state.store(raw, Ordering::Release);
}
fn should_transition_to_half_open(&self) -> bool {
let open_since = self.open_since_ms.load(Ordering::Acquire);
if open_since == 0 {
return false;
}
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_millis() as u64;
let elapsed_ms = now.saturating_sub(open_since);
elapsed_ms >= self.config.open_duration.as_millis() as u64
}
}
/// Process-wide registry of breakers.
///
/// Keys are either a service name or `"{service}::{instance}"` for per-instance
/// breakers. The map is created on first access.
static CIRCUIT_BREAKERS: std::sync::LazyLock<DashMap<String, Arc<CircuitBreaker>>> =
    std::sync::LazyLock::new(DashMap::new);
/// Returns the shared breaker for `service_name`, creating it with `config` on first use.
///
/// `config` only takes effect when the breaker is created. Later calls return the
/// existing breaker with its original configuration.
///
/// NOTE(review): service-level keys share the registry with instance-level keys
/// (`"{service}::{instance}"`). A service literally named `"a::b"` would therefore
/// alias the breaker for instance `b` of service `a`.
pub fn get_circuit_breaker(
    service_name: &str,
    config: CircuitBreakerConfig,
) -> Arc<CircuitBreaker> {
    // Fast path: a shared-lock lookup that neither allocates a key nor takes the
    // shard's write lock (which `entry` always does). The read guard is dropped at
    // the end of this statement, before `entry` below could need the same shard.
    let existing = CIRCUIT_BREAKERS
        .get(service_name)
        .map(|slot| Arc::clone(slot.value()));
    if let Some(breaker) = existing {
        return breaker;
    }
    // Slow path: `entry` re-checks under the write lock, so a concurrent creator wins
    // cleanly and both callers end up with the same `Arc`.
    CIRCUIT_BREAKERS
        .entry(service_name.to_string())
        .or_insert_with(|| Arc::new(CircuitBreaker::new(service_name.to_string(), config)))
        .value()
        .clone()
}
/// Returns the shared breaker for one instance of a service, creating it with
/// `config` on first use.
///
/// The registry key is `"{service_name}::{instance_id}"`. `config` only takes effect
/// when the breaker is created.
pub fn get_instance_circuit_breaker(
    service_name: &str,
    instance_id: &str,
    config: CircuitBreakerConfig,
) -> Arc<CircuitBreaker> {
    let key = format!("{}::{}", service_name, instance_id);
    // Fast path: shared-lock lookup. The guard is dropped at the end of this
    // statement, before `entry` takes the shard's write lock.
    let existing = CIRCUIT_BREAKERS
        .get(key.as_str())
        .map(|slot| Arc::clone(slot.value()));
    if let Some(breaker) = existing {
        return breaker;
    }
    CIRCUIT_BREAKERS
        .entry(key.clone())
        .or_insert_with(|| Arc::new(CircuitBreaker::new(key, config)))
        .value()
        .clone()
}
/// Looks up an already-registered service-level breaker without creating one.
pub fn get_existing_circuit_breaker(service_name: &str) -> Option<Arc<CircuitBreaker>> {
    let slot = CIRCUIT_BREAKERS.get(service_name)?;
    Some(Arc::clone(slot.value()))
}
/// Looks up an already-registered per-instance breaker without creating one.
///
/// The registry key is `"{service_name}::{instance_id}"`.
pub fn get_existing_instance_circuit_breaker(
    service_name: &str,
    instance_id: &str,
) -> Option<Arc<CircuitBreaker>> {
    let registry_key = [service_name, instance_id].join("::");
    CIRCUIT_BREAKERS
        .get(registry_key.as_str())
        .map(|slot| Arc::clone(slot.value()))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_initial_state_is_closed() {
        let cb = CircuitBreaker::new("test".into(), CircuitBreakerConfig::default());
        assert_eq!(cb.state(), CircuitState::Closed);
        assert!(cb.allow_request());
    }

    #[test]
    fn test_closed_to_open_on_failures() {
        let cb = CircuitBreaker::new(
            "test".into(),
            CircuitBreakerConfig {
                failure_threshold: 3,
                ..Default::default()
            },
        );
        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Closed);
        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Closed);
        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Open);
        assert!(!cb.allow_request());
    }

    #[test]
    fn test_success_resets_failure_count() {
        let cb = CircuitBreaker::new(
            "test".into(),
            CircuitBreakerConfig {
                failure_threshold: 3,
                ..Default::default()
            },
        );
        cb.record_failure();
        cb.record_failure();
        cb.record_success();
        assert_eq!(cb.failure_count(), 0);
        assert_eq!(cb.state(), CircuitState::Closed);
    }

    #[test]
    fn test_half_open_to_closed_on_success() {
        let cb = CircuitBreaker::new(
            "test".into(),
            CircuitBreakerConfig {
                failure_threshold: 1,
                success_threshold: 2,
                open_duration: Duration::from_millis(0),
            },
        );
        cb.record_failure();
        // A zero cool-down makes the breaker eligible for probing immediately.
        assert_eq!(cb.state(), CircuitState::HalfOpen);
        cb.record_success();
        assert_eq!(cb.state(), CircuitState::HalfOpen);
        cb.record_success();
        assert_eq!(cb.state(), CircuitState::Closed);
    }

    #[test]
    fn test_half_open_to_open_on_failure() {
        let mut cb = CircuitBreaker::new(
            "test".into(),
            CircuitBreakerConfig {
                failure_threshold: 1,
                success_threshold: 3,
                open_duration: Duration::from_millis(0),
            },
        );
        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::HalfOpen);
        // Lengthen the cool-down (a private field, visible to this child module) so
        // the re-opened breaker is observed as Open rather than promoted again.
        cb.config.open_duration = Duration::from_secs(9999);
        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Open);
        assert!(!cb.allow_request());
    }

    #[test]
    fn test_open_holds_during_cool_down() {
        let cb = CircuitBreaker::new(
            "test2".into(),
            CircuitBreakerConfig {
                failure_threshold: 1,
                success_threshold: 3,
                open_duration: Duration::from_secs(9999),
            },
        );
        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Open);
        assert!(!cb.allow_request());
    }

    #[test]
    fn test_manual_reset() {
        let cb = CircuitBreaker::new(
            "test".into(),
            CircuitBreakerConfig {
                failure_threshold: 1,
                open_duration: Duration::from_secs(9999),
                ..Default::default()
            },
        );
        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Open);
        cb.reset();
        assert_eq!(cb.state(), CircuitState::Closed);
        assert!(cb.allow_request());
    }

    #[test]
    fn test_state_display() {
        assert_eq!(CircuitState::Closed.to_string(), "Closed");
        assert_eq!(CircuitState::Open.to_string(), "Open");
        assert_eq!(CircuitState::HalfOpen.to_string(), "HalfOpen");
    }

    #[test]
    fn test_global_circuit_breaker_pool() {
        let cb1 = get_circuit_breaker("pool-test-svc", CircuitBreakerConfig::default());
        let cb2 = get_circuit_breaker("pool-test-svc", CircuitBreakerConfig::default());
        assert!(Arc::ptr_eq(&cb1, &cb2));
    }

    #[test]
    fn test_existing_lookup_does_not_create() {
        assert!(get_existing_circuit_breaker("lookup-test-svc").is_none());
        let created = get_circuit_breaker("lookup-test-svc", CircuitBreakerConfig::default());
        let found = get_existing_circuit_breaker("lookup-test-svc")
            .expect("breaker was registered just above");
        assert!(Arc::ptr_eq(&created, &found));
        assert!(get_existing_instance_circuit_breaker("lookup-test-svc", "10.0.0.9:1").is_none());
    }

    #[test]
    fn test_instance_level_circuit_breaker() {
        let cb1 = get_instance_circuit_breaker(
            "inst-test-svc",
            "10.0.0.1:9090",
            CircuitBreakerConfig {
                failure_threshold: 1,
                open_duration: Duration::from_secs(9999),
                ..Default::default()
            },
        );
        let cb2 = get_instance_circuit_breaker(
            "inst-test-svc",
            "10.0.0.2:9090",
            CircuitBreakerConfig {
                failure_threshold: 1,
                open_duration: Duration::from_secs(9999),
                ..Default::default()
            },
        );
        assert!(!Arc::ptr_eq(&cb1, &cb2));
        cb1.record_failure();
        assert_eq!(cb1.state(), CircuitState::Open);
        assert_eq!(cb2.state(), CircuitState::Closed);
        let cb1_again = get_instance_circuit_breaker(
            "inst-test-svc",
            "10.0.0.1:9090",
            CircuitBreakerConfig::default(),
        );
        assert!(Arc::ptr_eq(&cb1, &cb1_again));
    }
}