use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
use dashmap::DashMap;
use parking_lot::RwLock;
use crate::plugin::PluginId;
use crate::qname::QName;
/// Tuning parameters shared by every key tracked in a `CircuitBreaker`.
#[derive(Clone, Copy, Debug)]
pub struct BreakerConfig {
    /// Consecutive failures at which the breaker opens for a key.
    pub failure_threshold: u32,
    /// How long an open breaker rejects calls before `allow` lets them through again.
    pub cooldown: Duration,
}

impl Default for BreakerConfig {
    /// Opens after ten consecutive failures and stays open for thirty seconds.
    fn default() -> Self {
        const DEFAULT_THRESHOLD: u32 = 10;
        const DEFAULT_COOLDOWN_SECS: u64 = 30;
        Self {
            failure_threshold: DEFAULT_THRESHOLD,
            cooldown: Duration::from_secs(DEFAULT_COOLDOWN_SECS),
        }
    }
}
#[derive(Debug)]
struct BreakerState {
consecutive_failures: AtomicU64,
opened_at: RwLock<Option<Instant>>,
}
impl Default for BreakerState {
fn default() -> Self {
Self {
consecutive_failures: AtomicU64::new(0),
opened_at: RwLock::new(None),
}
}
}
/// Circuit breaker keyed by `(PluginId, QName)`.
///
/// Each key gets its own failure counter and open/closed state. That state is
/// created lazily by the first recorded failure.
#[derive(Debug)]
pub struct CircuitBreaker {
    /// Threshold and cooldown applied to every key.
    cfg: BreakerConfig,
    /// Per-key state. NOTE(review): nothing in this file removes entries.
    states: DashMap<(PluginId, QName), Arc<BreakerState>>,
}
impl CircuitBreaker {
#[must_use]
pub fn new(cfg: BreakerConfig) -> Self {
Self {
cfg,
states: DashMap::new(),
}
}
#[must_use]
pub fn allow(&self, plugin: &PluginId, qname: &QName) -> bool {
let key = (plugin.clone(), qname.clone());
let Some(state) = self.states.get(&key) else {
return true;
};
let opened_at = *state.opened_at.read();
match opened_at {
None => true,
Some(t) => {
if t.elapsed() >= self.cfg.cooldown {
*state.opened_at.write() = None;
state.consecutive_failures.store(0, Ordering::SeqCst);
true
} else {
false
}
}
}
}
pub fn record_success(&self, plugin: &PluginId, qname: &QName) {
let key = (plugin.clone(), qname.clone());
if let Some(state) = self.states.get(&key) {
state.consecutive_failures.store(0, Ordering::SeqCst);
*state.opened_at.write() = None;
}
}
pub fn record_failure(&self, plugin: &PluginId, qname: &QName) {
let key = (plugin.clone(), qname.clone());
let state = self
.states
.entry(key)
.or_insert_with(|| Arc::new(BreakerState::default()))
.clone();
let n = state.consecutive_failures.fetch_add(1, Ordering::SeqCst) + 1;
if n >= u64::from(self.cfg.failure_threshold) {
let mut opened = state.opened_at.write();
if opened.is_none() {
*opened = Some(Instant::now());
}
}
}
#[must_use]
pub fn failure_count(&self, plugin: &PluginId, qname: &QName) -> u64 {
let key = (plugin.clone(), qname.clone());
self.states
.get(&key)
.map(|s| s.consecutive_failures.load(Ordering::SeqCst))
.unwrap_or(0)
}
}
impl Default for CircuitBreaker {
    /// Builds a breaker using `BreakerConfig::default()`.
    fn default() -> Self {
        let cfg = BreakerConfig::default();
        Self::new(cfg)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Breaker with threshold 3 and a 50 ms cooldown, plus a plugin/qname key.
    fn fixture() -> (CircuitBreaker, PluginId, QName) {
        (
            CircuitBreaker::new(BreakerConfig {
                failure_threshold: 3,
                cooldown: Duration::from_millis(50),
            }),
            PluginId::new("test"),
            QName::builtin("doomed"),
        )
    }

    #[test]
    fn fresh_breaker_allows_calls() {
        let (b, p, q) = fixture();
        assert!(b.allow(&p, &q));
    }

    #[test]
    fn unknown_key_reports_zero_and_success_is_noop() {
        let (b, p, q) = fixture();
        assert_eq!(b.failure_count(&p, &q), 0);
        b.record_success(&p, &q);
        assert!(b.allow(&p, &q));
        assert_eq!(b.failure_count(&p, &q), 0);
    }

    #[test]
    fn failures_below_threshold_keep_breaker_closed() {
        let (b, p, q) = fixture();
        b.record_failure(&p, &q);
        b.record_failure(&p, &q);
        assert!(b.allow(&p, &q));
        assert_eq!(b.failure_count(&p, &q), 2);
    }

    #[test]
    fn breaker_opens_after_threshold_failures() {
        let (b, p, q) = fixture();
        for _ in 0..3 {
            b.record_failure(&p, &q);
        }
        assert!(!b.allow(&p, &q));
    }

    #[test]
    fn success_resets_failure_count() {
        let (b, p, q) = fixture();
        b.record_failure(&p, &q);
        b.record_failure(&p, &q);
        b.record_success(&p, &q);
        assert_eq!(b.failure_count(&p, &q), 0);
    }

    #[test]
    fn success_closes_open_breaker() {
        let (b, p, q) = fixture();
        for _ in 0..3 {
            b.record_failure(&p, &q);
        }
        assert!(!b.allow(&p, &q));
        b.record_success(&p, &q);
        assert!(b.allow(&p, &q));
        assert_eq!(b.failure_count(&p, &q), 0);
    }

    #[test]
    fn keys_are_tracked_independently() {
        let (b, p, q) = fixture();
        let other = QName::builtin("healthy");
        for _ in 0..3 {
            b.record_failure(&p, &q);
        }
        assert!(!b.allow(&p, &q));
        assert!(b.allow(&p, &other));
        assert_eq!(b.failure_count(&p, &other), 0);
    }

    #[test]
    fn breaker_half_opens_after_cooldown() {
        let (b, p, q) = fixture();
        for _ in 0..3 {
            b.record_failure(&p, &q);
        }
        assert!(!b.allow(&p, &q));
        std::thread::sleep(Duration::from_millis(60));
        assert!(b.allow(&p, &q));
        assert_eq!(b.failure_count(&p, &q), 0);
    }
}