use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use crate::error::{AiError, Result};
/// The state a [`CircuitBreaker`] can be in.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitState {
/// Operations are executed. Failures are recorded and the circuit opens once
/// the number of recent failures reaches the configured failure threshold.
Closed,
/// Operations are rejected with `AiError::ServiceUnavailable` without being run.
Open,
/// Entered from `Open` once the configured timeout has elapsed. Operations are
/// executed again: enough successes close the circuit, a single failure reopens it.
HalfOpen,
}
/// Tunable thresholds and durations for a [`CircuitBreaker`].
#[derive(Debug, Clone)]
pub struct CircuitBreakerConfig {
/// Number of recorded recent failures at which a closed circuit opens.
pub failure_threshold: u32,
/// How long the circuit stays open before the next call moves it to half-open.
pub timeout: Duration,
/// Number of successes in the half-open state required to close the circuit.
pub success_threshold: u32,
/// Failures older than this are dropped from the recent-failure list
/// whenever a new failure is recorded.
pub failure_window: Duration,
}
impl Default for CircuitBreakerConfig {
    /// Returns the stock configuration: open after 5 failures, close after 2
    /// half-open successes, and use one minute for both the open timeout and
    /// the failure window.
    fn default() -> Self {
        // Both durations share the same one-minute value.
        let one_minute = Duration::from_secs(60);

        Self {
            failure_window: one_minute,
            success_threshold: 2,
            timeout: one_minute,
            failure_threshold: 5,
        }
    }
}
impl CircuitBreakerConfig {
    /// Builds a configuration with the given failure threshold and open
    /// timeout; every other field takes its [`Default`] value.
    #[must_use]
    pub fn new(failure_threshold: u32, timeout: Duration) -> Self {
        Self {
            failure_threshold,
            timeout,
            ..Self::default()
        }
    }

    /// Returns the configuration with `success_threshold` replaced by `threshold`.
    #[must_use]
    pub fn with_success_threshold(self, threshold: u32) -> Self {
        Self {
            success_threshold: threshold,
            ..self
        }
    }

    /// Returns the configuration with `failure_window` replaced by `window`.
    #[must_use]
    pub fn with_failure_window(self, window: Duration) -> Self {
        Self {
            failure_window: window,
            ..self
        }
    }
}
/// Mutable bookkeeping of a circuit breaker, kept behind the breaker's lock.
#[derive(Debug)]
struct CircuitBreakerState {
    /// Current position in the closed / open / half-open cycle.
    state: CircuitState,
    /// Failure counter maintained by the owning breaker.
    failure_count: u32,
    /// Success counter maintained by the owning breaker.
    success_count: u32,
    /// When the circuit was last opened, or `None` if that has been cleared.
    opened_at: Option<Instant>,
    /// Timestamps of recorded failures, pruned by [`Self::clean_old_failures`].
    recent_failures: Vec<Instant>,
}

impl CircuitBreakerState {
    /// Creates the initial state: closed, all counters zero, no recorded failures.
    fn new() -> Self {
        Self {
            state: CircuitState::Closed,
            failure_count: 0,
            success_count: 0,
            opened_at: None,
            recent_failures: Vec::new(),
        }
    }

    /// Drops every recorded failure that is not strictly newer than `now - window`.
    ///
    /// If `now - window` cannot be represented as an [`Instant`] (for example
    /// a very large `window`), no recorded failure can be older than the
    /// window, so all entries are kept instead of panicking.
    fn clean_old_failures(&mut self, window: Duration) {
        if let Some(cutoff) = Instant::now().checked_sub(window) {
            self.recent_failures.retain(|&t| t > cutoff);
        }
    }
}
/// Circuit breaker that wraps fallible async operations and rejects calls
/// while its circuit is open.
pub struct CircuitBreaker {
/// Thresholds and durations that drive the state transitions.
config: CircuitBreakerConfig,
/// State and counters, guarded by an async read-write lock.
state: Arc<RwLock<CircuitBreakerState>>,
/// Label attached to every log event as the `circuit` field.
name: String,
}
impl CircuitBreaker {
    /// Creates a breaker in the closed state.
    ///
    /// `name` is attached to every log event as the `circuit` field.
    pub fn new(name: impl Into<String>, config: CircuitBreakerConfig) -> Self {
        Self {
            config,
            state: Arc::new(RwLock::new(CircuitBreakerState::new())),
            name: name.into(),
        }
    }

    /// Returns the state exactly as currently stored.
    ///
    /// This does not apply the open → half-open timeout; that transition is
    /// only performed at the start of [`CircuitBreaker::call`].
    pub async fn state(&self) -> CircuitState {
        let guard = self.state.read().await;
        guard.state
    }

    /// Runs `operation` through the breaker and records its outcome.
    ///
    /// The open → half-open timeout is applied first. If the circuit is then
    /// still open, `operation` is never invoked. In the closed and half-open
    /// states the operation is executed with no lock held, and its result is
    /// fed into the breaker's counters before being returned unchanged.
    ///
    /// Note that every caller arriving while the circuit is half-open runs its
    /// operation; the number of concurrent probes is not limited.
    ///
    /// # Errors
    ///
    /// Returns `AiError::ServiceUnavailable` when the circuit is open;
    /// otherwise propagates whatever error `operation` returns.
    pub async fn call<F, Fut, T>(&self, operation: F) -> Result<T>
    where
        F: FnOnce() -> Fut,
        Fut: std::future::Future<Output = Result<T>>,
    {
        self.check_timeout().await;

        // Copy the state out so the lock is released before `operation` runs.
        let current_state = {
            let guard = self.state.read().await;
            guard.state
        };

        match current_state {
            CircuitState::Open => {
                tracing::warn!(
                    circuit = %self.name,
                    "Circuit breaker is open, rejecting request"
                );
                Err(AiError::ServiceUnavailable)
            }
            CircuitState::Closed | CircuitState::HalfOpen => match operation().await {
                Ok(result) => {
                    self.on_success().await;
                    Ok(result)
                }
                Err(err) => {
                    self.on_failure().await;
                    Err(err)
                }
            },
        }
    }

    /// Moves an open circuit to half-open once `config.timeout` has elapsed
    /// since it was opened.
    ///
    /// The common case (circuit not open, or timeout not yet elapsed) is
    /// decided under the shared read lock so that concurrent callers do not
    /// serialise on the write lock for every request.
    async fn check_timeout(&self) {
        {
            let guard = self.state.read().await;
            let due = guard.state == CircuitState::Open
                && guard
                    .opened_at
                    .map_or(false, |opened_at| opened_at.elapsed() >= self.config.timeout);
            if !due {
                return;
            }
        }

        let mut state = self.state.write().await;
        // Re-check under the exclusive lock: another task may have changed the
        // state between releasing the read lock and acquiring the write lock.
        if state.state == CircuitState::Open {
            if let Some(opened_at) = state.opened_at {
                if opened_at.elapsed() >= self.config.timeout {
                    tracing::info!(
                        circuit = %self.name,
                        "Circuit breaker timeout elapsed, transitioning to half-open"
                    );
                    state.state = CircuitState::HalfOpen;
                    state.success_count = 0;
                }
            }
        }
    }

    /// Records a successful operation.
    ///
    /// In the half-open state the success counter is incremented and, once it
    /// reaches `config.success_threshold`, the circuit closes and all
    /// bookkeeping is cleared. In the closed and open states only
    /// `failure_count` is reset.
    async fn on_success(&self) {
        let mut state = self.state.write().await;
        match state.state {
            CircuitState::HalfOpen => {
                state.success_count += 1;
                if state.success_count >= self.config.success_threshold {
                    tracing::info!(
                        circuit = %self.name,
                        "Circuit breaker closing after {} successful requests",
                        state.success_count
                    );
                    state.state = CircuitState::Closed;
                    state.failure_count = 0;
                    state.success_count = 0;
                    state.recent_failures.clear();
                    state.opened_at = None;
                }
            }
            // A success can still arrive while open if its operation started
            // before the circuit tripped; it is handled like the closed case.
            CircuitState::Closed | CircuitState::Open => {
                state.failure_count = 0;
            }
        }
    }

    /// Records a failed operation.
    ///
    /// The failure timestamp is always stored and entries older than
    /// `config.failure_window` are pruned. A failure in the half-open state
    /// reopens the circuit immediately. In the closed state the circuit opens
    /// once the number of recent failures reaches `config.failure_threshold`.
    async fn on_failure(&self) {
        // One timestamp serves both the failure record and, if the circuit
        // trips, the moment it was opened.
        let now = Instant::now();

        let mut state = self.state.write().await;
        state.recent_failures.push(now);
        state.clean_old_failures(self.config.failure_window);

        match state.state {
            CircuitState::HalfOpen => {
                tracing::warn!(
                    circuit = %self.name,
                    "Circuit breaker reopening after failure in half-open state"
                );
                state.state = CircuitState::Open;
                state.opened_at = Some(now);
                state.success_count = 0;
            }
            CircuitState::Closed => {
                state.failure_count += 1;
                if state.recent_failures.len() >= self.config.failure_threshold as usize {
                    tracing::warn!(
                        circuit = %self.name,
                        failures = state.recent_failures.len(),
                        threshold = self.config.failure_threshold,
                        "Circuit breaker opening due to failure threshold"
                    );
                    state.state = CircuitState::Open;
                    state.opened_at = Some(now);
                }
            }
            CircuitState::Open => {
                tracing::debug!(
                    circuit = %self.name,
                    "Additional failure while circuit is open"
                );
            }
        }
    }

    /// Forces the breaker back to the closed state and clears all counters,
    /// recorded failures and the opened-at timestamp.
    pub async fn reset(&self) {
        let mut state = self.state.write().await;
        tracing::info!(circuit = %self.name, "Manually resetting circuit breaker");
        state.state = CircuitState::Closed;
        state.failure_count = 0;
        state.success_count = 0;
        state.recent_failures.clear();
        state.opened_at = None;
    }

    /// Returns a snapshot of the current state and counters.
    ///
    /// `recent_failure_count` is the number of stored failure timestamps; it
    /// saturates at `u32::MAX` rather than wrapping.
    pub async fn metrics(&self) -> CircuitBreakerMetrics {
        let state = self.state.read().await;
        // Clamp before narrowing so the cast below can never truncate.
        let recent = state.recent_failures.len().min(u32::MAX as usize);
        CircuitBreakerMetrics {
            state: state.state,
            failure_count: state.failure_count,
            success_count: state.success_count,
            recent_failure_count: recent as u32,
        }
    }
}
/// Point-in-time snapshot of a breaker's state and counters, as returned by
/// `CircuitBreaker::metrics`.
#[derive(Debug, Clone)]
pub struct CircuitBreakerMetrics {
/// State stored in the breaker when the snapshot was taken.
pub state: CircuitState,
/// Failures counted while closed; reset to zero by a success outside the
/// half-open state, when the circuit closes, and by a manual reset.
pub failure_count: u32,
/// Successes counted in the current half-open phase.
pub success_count: u32,
/// Number of stored failure timestamps. Pruning only happens when a new
/// failure is recorded, so this may include entries older than the failure window.
pub recent_failure_count: u32,
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::sleep;

    /// Drives `times` failing operations through `cb`, discarding each result.
    async fn fail_times(cb: &CircuitBreaker, times: usize) {
        for _ in 0..times {
            let _ = cb
                .call(|| async { Err::<i32, _>(AiError::ServiceUnavailable) })
                .await;
        }
    }

    /// Sends one operation that succeeds with `42` through `cb`.
    async fn succeed(cb: &CircuitBreaker) -> Result<i32> {
        cb.call(|| async { Ok::<_, AiError>(42) }).await
    }

    /// A successful call leaves a fresh breaker closed.
    #[tokio::test]
    async fn test_circuit_breaker_closed_state() {
        let cb = CircuitBreaker::new(
            "test",
            CircuitBreakerConfig::new(3, Duration::from_secs(1)),
        );
        assert_eq!(cb.state().await, CircuitState::Closed);

        assert!(succeed(&cb).await.is_ok());
        assert_eq!(cb.state().await, CircuitState::Closed);
    }

    /// Reaching the failure threshold opens the circuit and rejects later calls.
    #[tokio::test]
    async fn test_circuit_breaker_opens_on_failures() {
        let cb = CircuitBreaker::new(
            "test",
            CircuitBreakerConfig::new(3, Duration::from_secs(1))
                .with_failure_window(Duration::from_secs(10)),
        );

        fail_times(&cb, 3).await;
        assert_eq!(cb.state().await, CircuitState::Open);

        assert!(succeed(&cb).await.is_err());
    }

    /// After the timeout, successes are counted in half-open until the circuit closes.
    #[tokio::test]
    async fn test_circuit_breaker_half_open_transition() {
        let cb = CircuitBreaker::new(
            "test",
            CircuitBreakerConfig::new(2, Duration::from_millis(100)).with_success_threshold(2),
        );

        fail_times(&cb, 2).await;
        assert_eq!(cb.state().await, CircuitState::Open);

        sleep(Duration::from_millis(150)).await;

        assert!(succeed(&cb).await.is_ok());
        assert_eq!(cb.metrics().await.success_count, 1);

        assert!(succeed(&cb).await.is_ok());
        assert_eq!(cb.state().await, CircuitState::Closed);
    }

    /// A failure while half-open sends the circuit straight back to open.
    #[tokio::test]
    async fn test_circuit_breaker_half_open_reopens_on_failure() {
        let cb = CircuitBreaker::new(
            "test",
            CircuitBreakerConfig::new(2, Duration::from_millis(100)),
        );

        fail_times(&cb, 2).await;
        sleep(Duration::from_millis(150)).await;
        fail_times(&cb, 1).await;

        assert_eq!(cb.state().await, CircuitState::Open);
    }

    /// A manual reset closes an open circuit.
    #[tokio::test]
    async fn test_circuit_breaker_reset() {
        let cb = CircuitBreaker::new(
            "test",
            CircuitBreakerConfig::new(2, Duration::from_secs(10)),
        );

        fail_times(&cb, 2).await;
        assert_eq!(cb.state().await, CircuitState::Open);

        cb.reset().await;
        assert_eq!(cb.state().await, CircuitState::Closed);
    }

    /// Failures that have aged out of the window do not count toward the threshold.
    #[tokio::test]
    async fn test_failure_window() {
        let cb = CircuitBreaker::new(
            "test",
            CircuitBreakerConfig::new(3, Duration::from_secs(1))
                .with_failure_window(Duration::from_millis(200)),
        );

        fail_times(&cb, 2).await;
        sleep(Duration::from_millis(250)).await;
        assert_eq!(cb.state().await, CircuitState::Closed);

        fail_times(&cb, 1).await;
        assert_eq!(cb.state().await, CircuitState::Closed);
    }
}