use crate::config::CircuitBreakerConfig;
use crate::metrics::{self, CircuitStateValue};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::RwLock;
use tracing::{info, warn};
#[derive(Debug, Clone)]
pub enum CircuitState {
Closed {
failures: u32,
requests: u32,
},
Open {
open_until: Instant,
},
HalfOpen {
successes: u32,
},
}
pub type CircuitStore = Arc<RwLock<HashMap<String, CircuitState>>>;
pub fn new_circuit_store() -> CircuitStore {
Arc::new(RwLock::new(HashMap::new()))
}
#[derive(Debug)]
pub struct CircuitOpenError {
pub open_until: Instant,
}
impl std::fmt::Display for CircuitOpenError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "circuit open until {:?}", self.open_until)
}
}
impl std::error::Error for CircuitOpenError {}
pub async fn allow_request(
store: &CircuitStore,
source_id: &str,
config: &CircuitBreakerConfig,
) -> Result<(), CircuitOpenError> {
if !config.enabled {
return Ok(());
}
let mut g = store.write().await;
let state = g.get(source_id).cloned();
let now = Instant::now();
let (allowed, new_state) = match state {
None => (
true,
Some(CircuitState::Closed {
failures: 0,
requests: 0,
}),
),
Some(CircuitState::Closed { .. }) => (true, None),
Some(CircuitState::Open { open_until }) => {
if now >= open_until {
info!(source = %source_id, "circuit half-open, allowing probe");
(true, Some(CircuitState::HalfOpen { successes: 0 }))
} else {
(false, None)
}
}
Some(CircuitState::HalfOpen { successes: _ }) => (true, None),
};
if let Some(s) = new_state {
g.insert(source_id.to_string(), s.clone());
metrics::set_circuit_state(source_id, circuit_state_to_value(&s));
}
if !allowed && let Some(CircuitState::Open { open_until }) = state {
warn!(source = %source_id, "request rejected: circuit open");
return Err(CircuitOpenError { open_until });
}
Ok(())
}
pub async fn record_result(
store: &CircuitStore,
source_id: &str,
config: &CircuitBreakerConfig,
success: bool,
) {
if !config.enabled {
return;
}
let mut g = store.write().await;
let state = g.get(source_id).cloned().unwrap_or(CircuitState::Closed {
failures: 0,
requests: 0,
});
let now = Instant::now();
let open_duration_secs = config
.reset_timeout_secs
.map(|r| std::cmp::min(config.half_open_timeout_secs, r))
.unwrap_or(config.half_open_timeout_secs);
let new_state = match state {
CircuitState::Closed { failures, requests } => {
let requests = requests + 1;
let failures = if success { 0 } else { failures + 1 };
let open_on_count = failures >= config.failure_threshold;
let open_on_rate = config.minimum_requests.is_some()
&& config.failure_rate_threshold.is_some()
&& requests >= config.minimum_requests.unwrap_or(0)
&& (failures as f64 / requests as f64)
>= config.failure_rate_threshold.unwrap_or(0.0);
if !success && (open_on_count || open_on_rate) {
let open_until = now + std::time::Duration::from_secs(open_duration_secs);
warn!(
source = %source_id,
failures,
requests,
open_until_secs = open_duration_secs,
"circuit opened"
);
CircuitState::Open { open_until }
} else {
CircuitState::Closed { failures, requests }
}
}
CircuitState::Open { open_until } => CircuitState::Open { open_until },
CircuitState::HalfOpen { successes } => {
if success {
let s = successes + 1;
if s >= config.success_threshold {
info!(source = %source_id, "circuit closed after half-open success");
CircuitState::Closed {
failures: 0,
requests: 0,
}
} else {
CircuitState::HalfOpen { successes: s }
}
} else {
let open_until = now + std::time::Duration::from_secs(open_duration_secs);
warn!(source = %source_id, "circuit re-opened from half-open");
CircuitState::Open { open_until }
}
}
};
g.insert(source_id.to_string(), new_state.clone());
metrics::set_circuit_state(source_id, circuit_state_to_value(&new_state));
}
fn circuit_state_to_value(s: &CircuitState) -> CircuitStateValue {
match s {
CircuitState::Closed { .. } => CircuitStateValue::Closed,
CircuitState::Open { .. } => CircuitStateValue::Open,
CircuitState::HalfOpen { .. } => CircuitStateValue::HalfOpen,
}
}
#[allow(dead_code)]
pub fn is_circuit_open_error(err: &anyhow::Error) -> bool {
err.downcast_ref::<CircuitOpenError>().is_some()
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_circuit_closed_opens_after_threshold() {
let store = new_circuit_store();
let config = CircuitBreakerConfig {
enabled: true,
failure_threshold: 3,
success_threshold: 2,
half_open_timeout_secs: 60,
reset_timeout_secs: None,
failure_rate_threshold: None,
minimum_requests: None,
};
for _ in 0..3 {
allow_request(&store, "s1", &config).await.unwrap();
record_result(&store, "s1", &config, false).await;
}
let err = allow_request(&store, "s1", &config).await.unwrap_err();
assert!(err.open_until > Instant::now());
}
#[tokio::test]
async fn test_circuit_success_resets_failures() {
let store = new_circuit_store();
let config = CircuitBreakerConfig {
enabled: true,
failure_threshold: 3,
success_threshold: 2,
half_open_timeout_secs: 60,
reset_timeout_secs: None,
failure_rate_threshold: None,
minimum_requests: None,
};
allow_request(&store, "s1", &config).await.unwrap();
record_result(&store, "s1", &config, false).await;
allow_request(&store, "s1", &config).await.unwrap();
record_result(&store, "s1", &config, true).await; allow_request(&store, "s1", &config).await.unwrap();
record_result(&store, "s1", &config, false).await;
allow_request(&store, "s1", &config).await.unwrap();
}
#[tokio::test]
async fn test_circuit_reset_timeout_caps_open_duration() {
let store = new_circuit_store();
let config = CircuitBreakerConfig {
enabled: true,
failure_threshold: 2,
success_threshold: 1,
half_open_timeout_secs: 600,
reset_timeout_secs: Some(1),
failure_rate_threshold: None,
minimum_requests: None,
};
for _ in 0..2 {
allow_request(&store, "s1", &config).await.unwrap();
record_result(&store, "s1", &config, false).await;
}
let err = allow_request(&store, "s1", &config).await.unwrap_err();
let open_secs = err
.open_until
.saturating_duration_since(Instant::now())
.as_secs();
assert!(
open_secs <= 2,
"open duration capped by reset_timeout_secs=1, got {}s",
open_secs
);
}
#[tokio::test]
async fn test_circuit_opens_on_failure_rate_threshold() {
let store = new_circuit_store();
let config = CircuitBreakerConfig {
enabled: true,
failure_threshold: 100,
success_threshold: 1,
half_open_timeout_secs: 60,
reset_timeout_secs: None,
failure_rate_threshold: Some(0.5),
minimum_requests: Some(10),
};
for i in 0..10 {
allow_request(&store, "s1", &config).await.unwrap();
record_result(&store, "s1", &config, i < 5).await;
}
let err = allow_request(&store, "s1", &config).await.unwrap_err();
assert!(err.open_until > Instant::now());
}
#[tokio::test]
async fn test_circuit_does_not_open_on_rate_before_minimum_requests() {
let store = new_circuit_store();
let config = CircuitBreakerConfig {
enabled: true,
failure_threshold: 100,
success_threshold: 1,
half_open_timeout_secs: 60,
reset_timeout_secs: None,
failure_rate_threshold: Some(0.5),
minimum_requests: Some(20),
};
for _ in 0..10 {
allow_request(&store, "s1", &config).await.unwrap();
record_result(&store, "s1", &config, false).await;
}
allow_request(&store, "s1", &config).await.unwrap();
}
}