use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
/// Observable state of a [`CircuitBreaker`].
///
/// The explicit `u8` discriminants are the encoding stored in the breaker's
/// atomic state field (`state as u8`) and decoded by `TryFrom<u8>`; keep the
/// two in sync if variants are ever added or reordered.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum CircuitState {
    /// Requests flow normally; failures are being counted.
    Closed = 0,
    /// Requests are rejected until the open window elapses.
    Open = 1,
    /// The open window has elapsed; requests are let through as probes.
    HalfOpen = 2,
}
/// Thresholds and timing that drive a [`CircuitBreaker`]'s state machine.
#[derive(Debug, Clone)]
pub struct CircuitBreakerConfig {
    /// Consecutive failures, counted while closed, that open the circuit.
    pub failure_threshold: usize,
    /// Successful probes, counted while half-open, that close the circuit again.
    pub success_threshold: usize,
    /// How long the circuit stays open before it starts admitting probes.
    pub open_duration: Duration,
}

impl Default for CircuitBreakerConfig {
    /// Same values as `CircuitBreaker::default_for`: open after 5 failures,
    /// stay open for 60 seconds, close after 3 successful probes.
    fn default() -> Self {
        Self {
            failure_threshold: 5,
            success_threshold: 3,
            open_duration: Duration::from_secs(60),
        }
    }
}
/// Outcome of [`CircuitBreaker::execute`].
#[derive(Debug, Clone)]
pub enum CircuitResult<T> {
    /// The operation ran and returned a value.
    Success(T),
    /// The operation ran and failed; carries the error's `Display` text.
    Failure(String),
    /// The operation was not run because the circuit is open.
    Rejected(String),
    /// Non-value outcome carrying an informational message.
    RetryAllowed(String),
}

impl<T> CircuitResult<T> {
    /// Returns `true` for [`CircuitResult::Success`].
    pub fn is_success(&self) -> bool {
        matches!(self, CircuitResult::Success(_))
    }

    /// Returns `true` for [`CircuitResult::Rejected`].
    pub fn is_rejected(&self) -> bool {
        matches!(self, CircuitResult::Rejected(_))
    }

    /// Extracts the value of a [`CircuitResult::Success`].
    ///
    /// # Panics
    ///
    /// Panics with the carried message for any other variant. The panic is
    /// reported at the caller's location, like `Option::unwrap`.
    #[track_caller]
    pub fn unwrap(self) -> T {
        match self {
            CircuitResult::Success(v) => v,
            CircuitResult::Failure(e) => panic!("unwrap on Failure: {}", e),
            CircuitResult::Rejected(e) => panic!("unwrap on Rejected: {}", e),
            CircuitResult::RetryAllowed(e) => panic!("unwrap on RetryAllowed: {}", e),
        }
    }
}
#[derive(Debug)]
pub struct CircuitBreaker {
name: String,
state: std::sync::atomic::AtomicU8,
failure_count: Arc<AtomicUsize>,
success_count: Arc<AtomicUsize>,
open_since_ms: std::sync::atomic::AtomicU64,
config: CircuitBreakerConfig,
}
impl CircuitBreaker {
pub fn new(name: &str, failure_threshold: usize, open_duration: Duration) -> Self {
Self {
name: name.to_string(),
state: std::sync::atomic::AtomicU8::new(CircuitState::Closed as u8),
failure_count: Arc::new(AtomicUsize::new(0)),
success_count: Arc::new(AtomicUsize::new(0)),
open_since_ms: std::sync::atomic::AtomicU64::new(0),
config: CircuitBreakerConfig {
failure_threshold,
success_threshold: 3,
open_duration,
},
}
}
pub fn default_for(name: &str) -> Self {
Self::new(name, 5, Duration::from_secs(60))
}
pub fn state(&self) -> CircuitState {
let state = self.state.load(Ordering::SeqCst);
let state = CircuitState::try_from(state).unwrap_or(CircuitState::Closed);
if state == CircuitState::Open {
if let Some(since) = self.open_time() {
if since.elapsed() >= self.config.open_duration {
return CircuitState::HalfOpen;
}
}
}
state
}
fn open_time(&self) -> Option<Instant> {
let ts = self.open_since_ms.load(Ordering::SeqCst);
if ts == 0 {
None
} else {
Some(Instant::now() - Duration::from_millis(ts))
}
}
pub fn record_success(&self) {
let state = self.state();
match state {
CircuitState::Closed => {
self.failure_count.store(0, Ordering::SeqCst);
}
CircuitState::HalfOpen => {
let count = self.success_count.fetch_add(1, Ordering::SeqCst) + 1;
if count >= self.config.success_threshold {
self.state
.store(CircuitState::Closed as u8, Ordering::SeqCst);
self.failure_count.store(0, Ordering::SeqCst);
self.success_count.store(0, Ordering::SeqCst);
self.open_since_ms.store(0, Ordering::SeqCst);
tracing::info!(
"[circuit-breaker] {}: circuit closed (recovered)",
self.name
);
}
}
CircuitState::Open => {
}
}
}
pub fn record_failure(&self) {
let state = self.state();
match state {
CircuitState::Closed => {
let count = self.failure_count.fetch_add(1, Ordering::SeqCst) + 1;
if count >= self.config.failure_threshold {
self.state.store(CircuitState::Open as u8, Ordering::SeqCst);
self.open_since_ms.store(
Instant::now().elapsed().as_millis().try_into().unwrap_or(0),
Ordering::SeqCst,
);
tracing::warn!(
"[circuit-breaker] {}: circuit opened ({} failures)",
self.name,
count
);
}
}
CircuitState::HalfOpen => {
self.state.store(CircuitState::Open as u8, Ordering::SeqCst);
self.success_count.store(0, Ordering::SeqCst);
tracing::warn!(
"[circuit-breaker] {}: circuit reopened (failure in half-open)",
self.name
);
}
CircuitState::Open => {
}
}
}
pub fn can_request(&self) -> bool {
let state = self.state();
match state {
CircuitState::Closed | CircuitState::HalfOpen => true,
CircuitState::Open => false,
}
}
pub async fn execute<F, T, E>(&self, operation: F) -> CircuitResult<T>
where
F: std::future::Future<Output = Result<T, E>>,
E: std::fmt::Display,
{
let state = self.state();
match state {
CircuitState::Closed => match operation.await {
Ok(result) => {
self.record_success();
CircuitResult::Success(result)
}
Err(e) => {
self.record_failure();
CircuitResult::Failure(e.to_string())
}
},
CircuitState::Open => CircuitResult::Rejected(format!(
"circuit is open for {} (source may be temporarily unavailable)",
self.name
)),
CircuitState::HalfOpen => {
match operation.await {
Ok(_result) => {
self.record_success();
CircuitResult::RetryAllowed("half-open: success".to_string())
}
Err(e) => {
self.record_failure();
CircuitResult::Failure(e.to_string())
}
}
}
}
}
pub fn reset(&self) {
self.state
.store(CircuitState::Closed as u8, Ordering::SeqCst);
self.failure_count.store(0, Ordering::SeqCst);
self.success_count.store(0, Ordering::SeqCst);
self.open_since_ms.store(0, Ordering::SeqCst);
}
}
impl TryFrom<u8> for CircuitState {
type Error = ();
fn try_from(value: u8) -> Result<Self, Self::Error> {
match value {
0 => Ok(CircuitState::Closed),
1 => Ok(CircuitState::Open),
2 => Ok(CircuitState::HalfOpen),
_ => Err(()),
}
}
}
#[derive(Debug, Default)]
pub struct CircuitBreakerManager {
breakers: Arc<std::sync::RwLock<std::collections::HashMap<String, Arc<CircuitBreaker>>>>,
}
impl CircuitBreakerManager {
pub fn new() -> Self {
Self {
breakers: Arc::new(std::sync::RwLock::new(std::collections::HashMap::new())),
}
}
pub fn get(&self, source_id: &str) -> Arc<CircuitBreaker> {
{
let read_guard = self.breakers.read().expect("RwLock poisoned");
if let Some(breaker) = read_guard.get(source_id) {
return Arc::clone(breaker);
}
}
{
let mut write_guard = self.breakers.write().expect("RwLock poisoned");
if let Some(breaker) = write_guard.get(source_id) {
return Arc::clone(breaker);
}
let breaker = Arc::new(CircuitBreaker::default_for(source_id));
write_guard.insert(source_id.to_string(), Arc::clone(&breaker));
breaker
}
}
pub fn reset_all(&self) {
let guard = self.breakers.write().expect("RwLock poisoned");
for breaker in guard.values() {
breaker.reset();
}
}
pub fn status(&self) -> Vec<(String, CircuitState, bool)> {
let guard = self.breakers.read().expect("RwLock poisoned");
guard
.iter()
.map(|(name, breaker)| (name.clone(), breaker.state(), breaker.can_request()))
.collect()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Builds an `Arc`-wrapped breaker named "test" with the given failure
    /// threshold and a one-minute open window.
    fn shared_breaker(failure_threshold: usize) -> Arc<CircuitBreaker> {
        Arc::new(CircuitBreaker::new(
            "test",
            failure_threshold,
            Duration::from_secs(60),
        ))
    }

    #[tokio::test]
    async fn test_circuit_breaker_closed_by_default() {
        let cb = CircuitBreaker::default_for("test");
        assert_eq!(cb.state(), CircuitState::Closed);
        assert!(cb.can_request());
    }

    #[tokio::test]
    async fn test_circuit_breaker_opens_after_failures() {
        let cb = shared_breaker(3);
        for _ in 0..2 {
            cb.record_failure();
        }
        // Two failures stay below the threshold of three.
        assert_eq!(cb.state(), CircuitState::Closed);
        assert!(cb.can_request());

        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Open);
        assert!(!cb.can_request());
    }

    #[tokio::test]
    async fn test_circuit_breaker_success_resets() {
        let cb = shared_breaker(3);
        for _ in 0..2 {
            cb.record_failure();
        }
        assert_eq!(cb.failure_count.load(Ordering::SeqCst), 2);

        cb.record_success();
        assert_eq!(cb.failure_count.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn test_circuit_breaker_execute_success() {
        let cb = shared_breaker(3);
        let outcome = cb.execute(async { Ok::<i32, &str>(42) }).await;
        assert!(outcome.is_success());
        assert_eq!(outcome.unwrap(), 42);
    }

    #[tokio::test]
    async fn test_circuit_breaker_execute_rejected() {
        let cb = shared_breaker(1);
        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Open);

        let outcome = cb.execute(async { Ok::<i32, &str>(42) }).await;
        assert!(outcome.is_rejected());
    }

    #[test]
    fn test_manager() {
        let manager = CircuitBreakerManager::new();
        let first = manager.get("source1");
        let second = manager.get("source2");
        let first_again = manager.get("source1");
        assert!(Arc::ptr_eq(&first, &first_again));
        assert!(!Arc::ptr_eq(&first, &second));
    }

    #[test]
    fn test_manager_status() {
        let manager = CircuitBreakerManager::new();
        for id in ["arxiv", "semantic"] {
            let _ = manager.get(id);
        }
        let snapshot = manager.status();
        assert_eq!(snapshot.len(), 2);
        assert!(snapshot
            .iter()
            .all(|(_, state, _)| *state == CircuitState::Closed));
    }
}