use dashmap::DashMap;
use std::sync::Arc;
use tracing::{debug, trace, warn};
use zentinel_common::ids::Scope;
use zentinel_common::types::{CircuitBreakerConfig, CircuitBreakerState};
use zentinel_common::CircuitBreaker;
/// Registry of circuit breakers, one per `(scope, upstream)` pair.
///
/// Breakers are created lazily on first lookup and then cached. The
/// configuration for a new breaker is the first per-scope override found along
/// the scope's `chain()`, falling back to `default_config`.
pub struct ScopedCircuitBreakerManager {
/// Cached breakers keyed by `"<scope label>:<upstream id>"`.
breakers: DashMap<String, Arc<CircuitBreaker>>,
/// Per-scope configuration overrides registered via `set_scope_config`.
scope_configs: DashMap<Scope, CircuitBreakerConfig>,
/// Configuration used when no scope in the lookup chain has an override.
default_config: CircuitBreakerConfig,
}
impl ScopedCircuitBreakerManager {
    /// Creates an empty manager that falls back to `CircuitBreakerConfig::default()`
    /// for scopes without an explicit override.
    pub fn new() -> Self {
        Self::with_default_config(CircuitBreakerConfig::default())
    }

    /// Creates an empty manager that falls back to `config` for scopes without an
    /// explicit override.
    pub fn with_default_config(config: CircuitBreakerConfig) -> Self {
        Self {
            breakers: DashMap::new(),
            scope_configs: DashMap::new(),
            default_config: config,
        }
    }

    /// Registers (or replaces) the configuration override for `scope`.
    ///
    /// Configuration is only consulted when a breaker is first created, so
    /// breakers that already exist keep the configuration they were built with.
    pub fn set_scope_config(&self, scope: Scope, config: CircuitBreakerConfig) {
        debug!(
            scope = ?scope,
            failure_threshold = config.failure_threshold,
            success_threshold = config.success_threshold,
            timeout_seconds = config.timeout_seconds,
            "Configured circuit breaker for scope"
        );
        self.scope_configs.insert(scope, config);
    }

    /// Resolves the configuration for `scope`: the first override found while
    /// iterating `scope.chain()`, or the manager's default if none matches.
    fn get_effective_config(&self, scope: &Scope) -> CircuitBreakerConfig {
        for s in scope.chain() {
            if let Some(config) = self.scope_configs.get(&s) {
                return config.value().clone();
            }
        }
        self.default_config.clone()
    }

    /// Returns the breaker for `upstream_id` within `scope`, creating it on first use.
    ///
    /// A new breaker is built with the effective configuration for `scope` at
    /// creation time and is named after its registry key.
    pub fn get_breaker(&self, scope: &Scope, upstream_id: &str) -> Arc<CircuitBreaker> {
        let key = Self::make_key(scope, upstream_id);

        // Fast path: an existing breaker only needs a shard read lock. The `Ref`
        // guard is a temporary of this `let` statement and is dropped before
        // `entry` below asks for the same shard's write lock.
        let existing = self.breakers.get(&key).map(|entry| Arc::clone(entry.value()));
        if let Some(breaker) = existing {
            return breaker;
        }

        // Slow path: `entry` re-checks under the write lock, so callers racing to
        // create the same breaker all end up sharing one instance.
        let entry = self.breakers.entry(key.clone()).or_insert_with(|| {
            let config = self.get_effective_config(scope);
            trace!(
                scope = ?scope,
                upstream_id = upstream_id,
                "Creating circuit breaker"
            );
            // The breaker's name is exactly its registry key.
            Arc::new(CircuitBreaker::with_name(config, key))
        });
        Arc::clone(entry.value())
    }

    // The async methods below never `.await`; they are `async` in signature only.

    /// Returns the result of the breaker's `is_closed()` check for `upstream_id`
    /// in `scope`.
    ///
    /// NOTE(review): `CircuitBreaker::is_closed` appears to also move an open
    /// breaker to half-open once its timeout has elapsed (`test_success_recovery`
    /// relies on that); confirm against `zentinel_common`.
    pub async fn is_allowed(&self, scope: &Scope, upstream_id: &str) -> bool {
        self.get_breaker(scope, upstream_id).is_closed()
    }

    /// Records a successful call against the breaker for `upstream_id` in `scope`.
    pub async fn record_success(&self, scope: &Scope, upstream_id: &str) {
        self.get_breaker(scope, upstream_id).record_success();
    }

    /// Records a failed call against the breaker for `upstream_id` in `scope`.
    pub async fn record_failure(&self, scope: &Scope, upstream_id: &str) {
        self.get_breaker(scope, upstream_id).record_failure();
    }

    /// Returns the current state of the breaker for `upstream_id` in `scope`.
    pub async fn state(&self, scope: &Scope, upstream_id: &str) -> CircuitBreakerState {
        self.get_breaker(scope, upstream_id).state()
    }

    /// Resets the breaker for `upstream_id` in `scope` (creating it first if needed).
    pub async fn reset(&self, scope: &Scope, upstream_id: &str) {
        self.get_breaker(scope, upstream_id).reset();
    }

    /// Resets every breaker registered directly under `scope`.
    ///
    /// Breakers are matched by the `"<scope label>:"` key prefix, so breakers of
    /// other scopes, including `Service` scopes nested inside a `Namespace`, are
    /// left untouched.
    pub async fn reset_scope(&self, scope: &Scope) {
        let prefix = format!("{}:", scope_to_label(scope));
        // Single pass: reset matching breakers in place rather than collecting
        // their keys and looking each one up a second time.
        for entry in self.breakers.iter() {
            if entry.key().starts_with(prefix.as_str()) {
                entry.value().reset();
            }
        }
    }

    /// Drops every cached breaker and every scope override. The default
    /// configuration is kept.
    pub fn clear(&self) {
        self.breakers.clear();
        self.scope_configs.clear();
    }

    /// Number of breakers currently cached.
    pub fn breaker_count(&self) -> usize {
        self.breakers.len()
    }

    /// Number of scope configuration overrides currently registered.
    pub fn scope_count(&self) -> usize {
        self.scope_configs.len()
    }

    /// Returns a snapshot of every cached breaker, in unspecified (map) order.
    pub async fn get_all_states(&self) -> Vec<ScopedBreakerStatus> {
        let mut statuses = Vec::with_capacity(self.breakers.len());
        statuses.extend(self.breakers.iter().map(|entry| {
            let breaker = entry.value();
            let state = breaker.state();
            let consecutive_failures = breaker.consecutive_failures();
            let key = entry.key().clone();
            // Keys come from `make_key`, so the first ':' separates the scope
            // label from the upstream id (which may itself contain ':').
            let (scope_label, upstream) = match key.split_once(':') {
                Some((label, upstream)) => (label.to_string(), upstream.to_string()),
                None => ("global".to_string(), key.clone()),
            };
            ScopedBreakerStatus {
                key,
                scope_label,
                upstream,
                state,
                consecutive_failures,
            }
        }));
        statuses
    }

    /// Builds the registry key `"<scope label>:<upstream id>"`.
    ///
    /// NOTE(review): `scope_to_label` renders `Scope::Global` as `"global"` and
    /// `Scope::Namespace(ns)` as `ns`, so a namespace literally named `"global"`
    /// shares keys (and therefore breakers) with the global scope. Likewise
    /// `Namespace("a/b")` collides with `Service { namespace: "a", service: "b" }`.
    fn make_key(scope: &Scope, upstream_id: &str) -> String {
        format!("{}:{}", scope_to_label(scope), upstream_id)
    }
}
impl Default for ScopedCircuitBreakerManager {
fn default() -> Self {
Self::new()
}
}
/// Point-in-time snapshot of one breaker, as produced by
/// `ScopedCircuitBreakerManager::get_all_states`.
#[derive(Debug, Clone)]
pub struct ScopedBreakerStatus {
/// Full registry key, `"<scope label>:<upstream id>"`.
pub key: String,
/// Scope label: the part of `key` before the first `:`.
pub scope_label: String,
/// Upstream id: the part of `key` after the first `:`.
pub upstream: String,
/// Breaker state at the moment the snapshot was taken.
pub state: CircuitBreakerState,
/// Consecutive-failure count reported by the breaker at snapshot time.
pub consecutive_failures: u64,
}
impl ScopedBreakerStatus {
    /// `true` when the snapshot recorded `CircuitBreakerState::Open`.
    pub fn is_open(&self) -> bool {
        matches!(self.state, CircuitBreakerState::Open)
    }

    /// `true` when the snapshot recorded `CircuitBreakerState::HalfOpen`.
    pub fn is_half_open(&self) -> bool {
        matches!(self.state, CircuitBreakerState::HalfOpen)
    }

    /// `true` when the snapshot recorded `CircuitBreakerState::Closed`.
    pub fn is_closed(&self) -> bool {
        matches!(self.state, CircuitBreakerState::Closed)
    }
}
/// Renders a scope as the label used in breaker keys and breaker names.
///
/// `Global` becomes `"global"`, a namespace renders as its bare name, and a
/// service renders as `"<namespace>/<service>"`. The mapping is not injective:
/// `Namespace("global")` yields the same label as `Global`.
fn scope_to_label(scope: &Scope) -> String {
    match scope {
        Scope::Global => String::from("global"),
        Scope::Namespace(name) => name.to_owned(),
        Scope::Service { namespace, service } => {
            [namespace.as_str(), service.as_str()].join("/")
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Config with the given failure threshold and fixed values for every other knob.
    fn test_config(failure_threshold: u32) -> CircuitBreakerConfig {
        CircuitBreakerConfig {
            failure_threshold,
            success_threshold: 2,
            timeout_seconds: 1,
            half_open_max_requests: 2,
        }
    }

    /// Shorthand for `Scope::Namespace`.
    fn namespace_scope(name: &str) -> Scope {
        Scope::Namespace(name.to_string())
    }

    /// Shorthand for `Scope::Service`.
    fn service_scope(ns: &str, svc: &str) -> Scope {
        Scope::Service {
            namespace: ns.to_string(),
            service: svc.to_string(),
        }
    }

    #[tokio::test]
    async fn test_scope_isolation() {
        let manager = ScopedCircuitBreakerManager::new();
        manager.set_scope_config(Scope::Global, test_config(5));
        manager.set_scope_config(namespace_scope("api"), test_config(3));

        let api = namespace_scope("api");
        for _ in 0..3 {
            manager.record_failure(&api, "backend").await;
        }

        // Tripping the namespace breaker must leave the global one untouched.
        assert!(!manager.is_allowed(&api, "backend").await);
        assert!(manager.is_allowed(&Scope::Global, "backend").await);
    }

    #[tokio::test]
    async fn test_scope_chain_config_fallback() {
        let manager = ScopedCircuitBreakerManager::new();
        manager.set_scope_config(namespace_scope("api"), test_config(2));

        // No service-level override: the namespace threshold (2) should apply.
        let payments = service_scope("api", "payments");
        for _ in 0..2 {
            manager.record_failure(&payments, "backend").await;
        }

        assert!(!manager.is_allowed(&payments, "backend").await);
    }

    #[tokio::test]
    async fn test_service_specific_config() {
        let manager = ScopedCircuitBreakerManager::new();
        let payments = service_scope("api", "payments");
        manager.set_scope_config(payments.clone(), test_config(1));

        manager.record_failure(&payments, "backend").await;

        assert!(!manager.is_allowed(&payments, "backend").await);
    }

    #[tokio::test]
    async fn test_reset_single_breaker() {
        let manager = ScopedCircuitBreakerManager::new();
        manager.set_scope_config(Scope::Global, test_config(1));
        let global = Scope::Global;

        manager.record_failure(&global, "backend").await;
        assert!(!manager.is_allowed(&global, "backend").await);

        manager.reset(&global, "backend").await;
        assert!(manager.is_allowed(&global, "backend").await);
    }

    #[tokio::test]
    async fn test_reset_scope() {
        let manager = ScopedCircuitBreakerManager::new();
        let api = namespace_scope("api");
        manager.set_scope_config(api.clone(), test_config(1));
        let upstreams = ["backend1", "backend2"];

        for upstream in upstreams {
            manager.record_failure(&api, upstream).await;
        }
        for upstream in upstreams {
            assert!(!manager.is_allowed(&api, upstream).await);
        }

        manager.reset_scope(&api).await;

        for upstream in upstreams {
            assert!(manager.is_allowed(&api, upstream).await);
        }
    }

    #[tokio::test]
    async fn test_get_all_states() {
        let manager = ScopedCircuitBreakerManager::new();
        manager.set_scope_config(Scope::Global, test_config(5));
        for upstream in ["backend1", "backend2"] {
            let _ = manager.get_breaker(&Scope::Global, upstream);
        }

        let statuses = manager.get_all_states().await;

        assert_eq!(statuses.len(), 2);
        assert!(statuses.iter().all(ScopedBreakerStatus::is_closed));
    }

    #[tokio::test]
    async fn test_success_recovery() {
        let config = CircuitBreakerConfig {
            failure_threshold: 2,
            success_threshold: 2,
            timeout_seconds: 0,
            half_open_max_requests: 5,
        };
        let manager = ScopedCircuitBreakerManager::with_default_config(config);
        let scope = Scope::Global;
        let upstream = "backend";

        for _ in 0..2 {
            manager.record_failure(&scope, upstream).await;
        }
        assert_eq!(
            manager.state(&scope, upstream).await,
            CircuitBreakerState::Open
        );

        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
        // The test expects this poll to move the breaker from open to half-open.
        let _ = manager.is_allowed(&scope, upstream).await;
        assert_eq!(
            manager.state(&scope, upstream).await,
            CircuitBreakerState::HalfOpen
        );

        for _ in 0..2 {
            manager.record_success(&scope, upstream).await;
        }
        assert_eq!(
            manager.state(&scope, upstream).await,
            CircuitBreakerState::Closed
        );
    }

    #[test]
    fn test_scope_to_label() {
        assert_eq!(scope_to_label(&Scope::Global), "global");
        assert_eq!(scope_to_label(&namespace_scope("api")), "api");
        assert_eq!(
            scope_to_label(&service_scope("api", "payments")),
            "api/payments"
        );
    }

    #[test]
    fn test_clear() {
        let manager = ScopedCircuitBreakerManager::new();
        manager.set_scope_config(Scope::Global, test_config(5));
        let _ = manager.get_breaker(&Scope::Global, "backend");
        assert_eq!((manager.breaker_count(), manager.scope_count()), (1, 1));

        manager.clear();

        assert_eq!((manager.breaker_count(), manager.scope_count()), (0, 0));
    }
}