use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use serde::{Deserialize, Serialize};
/// Async circuit breaker that tracks the outcome of guarded operations and
/// rejects calls while the circuit is open.
///
/// Cloning produces a handle that shares the same underlying state (it lives
/// behind an `Arc`), but the `initialized` flag is copied by value, so each
/// clone tracks its own initialization.
#[derive(Debug, Clone)]
pub struct CircuitBreaker {
// Shared, lock-protected counters and current circuit state.
state: Arc<RwLock<CircuitBreakerState>>,
// Thresholds and tuning values supplied at construction time.
config: BreakerConfig,
// Set by `initialize`, cleared by `shutdown`; most methods refuse to run while false.
initialized: bool,
}
impl CircuitBreaker {
pub fn new() -> Self {
Self {
state: Arc::new(RwLock::new(CircuitBreakerState::new())),
config: BreakerConfig::default(),
initialized: false,
}
}
pub fn with_config(config: BreakerConfig) -> Self {
Self {
state: Arc::new(RwLock::new(CircuitBreakerState::new())),
config,
initialized: false,
}
}
pub async fn initialize(&mut self) -> Result<(), BreakerError> {
let mut state = self.state.write().await;
state.reset();
self.initialized = true;
Ok(())
}
pub async fn shutdown(&mut self) -> Result<(), BreakerError> {
self.initialized = false;
Ok(())
}
pub fn is_initialized(&self) -> bool {
self.initialized
}
pub async fn can_execute(&self) -> Result<bool, BreakerError> {
if !self.initialized {
return Err(BreakerError::NotInitialized);
}
let state = self.state.read().await;
Ok(state.can_execute())
}
pub async fn execute<F, T, E>(&self, operation: F) -> Result<T, BreakerError>
where
F: Fn() -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<T, E>> + Send>>,
E: std::error::Error + Send + Sync + 'static,
{
if !self.initialized {
return Err(BreakerError::NotInitialized);
}
if !self.can_execute().await? {
return Err(BreakerError::CircuitOpen);
}
let result = operation().await;
let mut state = self.state.write().await;
match result {
Ok(_) => {
state.record_success();
Ok(result.unwrap())
}
Err(_) => {
state.record_failure(self.config.failure_threshold);
Err(BreakerError::OperationFailed)
}
}
}
pub async fn record_success(&self) -> Result<(), BreakerError> {
if !self.initialized {
return Err(BreakerError::NotInitialized);
}
let mut state = self.state.write().await;
state.record_success();
Ok(())
}
pub async fn record_failure(&self) -> Result<(), BreakerError> {
if !self.initialized {
return Err(BreakerError::NotInitialized);
}
let mut state = self.state.write().await;
state.record_failure(self.config.failure_threshold);
Ok(())
}
pub async fn get_state(&self) -> CircuitState {
let state = self.state.read().await;
state.get_state()
}
pub async fn get_status(&self) -> Result<CircuitBreakerStatus, BreakerError> {
if !self.initialized {
return Err(BreakerError::NotInitialized);
}
let state = self.state.read().await;
Ok(state.get_status())
}
pub async fn reset(&self) -> Result<(), BreakerError> {
if !self.initialized {
return Err(BreakerError::NotInitialized);
}
let mut state = self.state.write().await;
state.reset();
Ok(())
}
}
#[derive(Debug, Clone)]
struct CircuitBreakerState {
state: CircuitState,
failure_count: usize,
success_count: usize,
last_failure_time: Option<Instant>,
last_success_time: Option<Instant>,
}
impl CircuitBreakerState {
fn new() -> Self {
Self {
state: CircuitState::Closed,
failure_count: 0,
success_count: 0,
last_failure_time: None,
last_success_time: None,
}
}
fn can_execute(&self) -> bool {
match self.state {
CircuitState::Closed => true,
CircuitState::Open => {
if let Some(last_failure) = self.last_failure_time {
last_failure.elapsed() >= Duration::from_secs(60) } else {
false
}
}
CircuitState::HalfOpen => true,
}
}
fn record_success(&mut self) {
self.success_count += 1;
self.failure_count = 0;
self.last_success_time = Some(Instant::now());
if self.state == CircuitState::HalfOpen && self.success_count >= 3 {
self.state = CircuitState::Closed;
self.success_count = 0;
}
}
fn record_failure(&mut self, failure_threshold: usize) {
self.failure_count += 1;
self.success_count = 0;
self.last_failure_time = Some(Instant::now());
if self.failure_count >= failure_threshold {
self.state = CircuitState::Open;
} else if self.state == CircuitState::HalfOpen {
self.state = CircuitState::Open;
}
}
fn get_state(&self) -> CircuitState {
self.state.clone()
}
fn get_status(&self) -> CircuitBreakerStatus {
CircuitBreakerStatus {
state: self.state.clone(),
failure_count: self.failure_count,
success_count: self.success_count,
last_failure_time: self.last_failure_time,
last_success_time: self.last_success_time,
}
}
fn reset(&mut self) {
self.state = CircuitState::Closed;
self.failure_count = 0;
self.success_count = 0;
self.last_failure_time = None;
self.last_success_time = None;
}
}
/// The three states a circuit breaker can be in.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum CircuitState {
/// Calls are allowed.
Closed,
/// Calls are rejected until a cool-down since the last failure has elapsed.
Open,
/// Calls are allowed; a failure reopens the circuit and enough successes close it.
HalfOpen,
}
/// Point-in-time snapshot of a circuit breaker's state and counters.
///
/// The two timestamp fields are excluded from serde output and are not read
/// from serde input.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct CircuitBreakerStatus {
/// Circuit state at the time the snapshot was taken.
pub state: CircuitState,
/// Value of the failure counter at snapshot time.
pub failure_count: usize,
/// Value of the success counter at snapshot time.
pub success_count: usize,
/// When the most recent failure was recorded, if any.
#[serde(skip_serializing, skip_deserializing)]
pub last_failure_time: Option<Instant>,
/// When the most recent success was recorded, if any.
#[serde(skip_serializing, skip_deserializing)]
pub last_success_time: Option<Instant>,
}
/// Tuning parameters for a `CircuitBreaker`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct BreakerConfig {
/// Failure count at which the circuit opens.
pub failure_threshold: usize,
/// Presumably the number of half-open successes required to close the
/// circuit again — verify against the breaker logic.
pub success_threshold: usize,
/// Presumably how long an open circuit waits before admitting calls again —
/// verify against the breaker logic.
pub timeout: Duration,
/// NOTE(review): intended meaning is not evident from the code in this file —
/// confirm how (and whether) this flag is consumed.
pub enabled: bool,
}
impl Default for BreakerConfig {
    /// Default tuning: thresholds of 5 failures and 3 successes, a 60-second
    /// timeout, and the `enabled` flag set.
    fn default() -> Self {
        let failure_threshold = 5;
        let success_threshold = 3;
        let timeout = Duration::from_secs(60);

        Self {
            failure_threshold,
            success_threshold,
            timeout,
            enabled: true,
        }
    }
}
/// Errors reported by the circuit breaker.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum BreakerError {
/// A method that requires initialization was called on an uninitialized breaker.
NotInitialized,
/// The call was rejected because the circuit does not currently admit calls.
CircuitOpen,
/// The guarded operation returned an error; the original error is not carried.
OperationFailed,
/// A configuration problem described by the contained message.
/// NOTE(review): not constructed anywhere in the code visible here.
ConfigurationError(String),
}
impl std::fmt::Display for BreakerError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
BreakerError::NotInitialized => write!(f, "Circuit breaker not initialized"),
BreakerError::CircuitOpen => write!(f, "Circuit breaker is open"),
BreakerError::OperationFailed => write!(f, "Operation failed"),
BreakerError::ConfigurationError(msg) => write!(f, "Configuration error: {}", msg),
}
}
}
// Marker implementation: every `Error` method keeps its default, so `source()` returns `None`.
impl std::error::Error for BreakerError {}
#[cfg(test)]
mod tests {
    use super::*;

    /// Boxed future shape that `CircuitBreaker::execute` expects from its closure.
    type BoxedOutcome =
        std::pin::Pin<Box<dyn std::future::Future<Output = Result<i32, std::io::Error>> + Send>>;

    /// Operation that always resolves to `Ok(42)`.
    fn succeeding_op() -> BoxedOutcome {
        Box::pin(async { Ok::<i32, std::io::Error>(42) })
    }

    /// Operation that always resolves to an I/O error.
    fn failing_op() -> BoxedOutcome {
        Box::pin(async {
            Err::<i32, std::io::Error>(std::io::Error::new(
                std::io::ErrorKind::Other,
                "Operation failed",
            ))
        })
    }

    /// Builds a default breaker and initializes it.
    async fn ready_breaker() -> CircuitBreaker {
        let mut breaker = CircuitBreaker::new();
        breaker.initialize().await.unwrap();
        breaker
    }

    /// Runs the failing operation through `execute` the given number of times.
    async fn fail_times(breaker: &CircuitBreaker, attempts: usize) {
        for _ in 0..attempts {
            let _ = breaker.execute(failing_op).await;
        }
    }

    #[tokio::test]
    async fn test_circuit_breaker_creation() {
        assert!(!CircuitBreaker::new().is_initialized());
    }

    #[tokio::test]
    async fn test_circuit_breaker_initialization() {
        let mut breaker = CircuitBreaker::new();
        assert!(breaker.initialize().await.is_ok());
        assert!(breaker.is_initialized());
    }

    #[tokio::test]
    async fn test_circuit_breaker_shutdown() {
        let mut breaker = ready_breaker().await;
        assert!(breaker.shutdown().await.is_ok());
        assert!(!breaker.is_initialized());
    }

    #[tokio::test]
    async fn test_circuit_breaker_can_execute() {
        let breaker = ready_breaker().await;
        assert!(breaker.can_execute().await.unwrap());
    }

    #[tokio::test]
    async fn test_circuit_breaker_success() {
        let breaker = ready_breaker().await;

        let outcome = breaker.execute(succeeding_op).await;
        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap(), 42);

        let snapshot = breaker.get_status().await.unwrap();
        assert_eq!(snapshot.state, CircuitState::Closed);
        assert_eq!(snapshot.failure_count, 0);
        assert_eq!(snapshot.success_count, 1);
    }

    #[tokio::test]
    async fn test_circuit_breaker_failure() {
        let breaker = ready_breaker().await;

        let outcome = breaker.execute(failing_op).await;
        assert!(outcome.is_err());
        assert!(matches!(outcome.unwrap_err(), BreakerError::OperationFailed));

        let snapshot = breaker.get_status().await.unwrap();
        assert_eq!(snapshot.state, CircuitState::Closed);
        assert_eq!(snapshot.failure_count, 1);
        assert_eq!(snapshot.success_count, 0);
    }

    #[tokio::test]
    async fn test_circuit_breaker_opens_after_failures() {
        let breaker = ready_breaker().await;
        fail_times(&breaker, 5).await;

        let snapshot = breaker.get_status().await.unwrap();
        assert_eq!(snapshot.state, CircuitState::Open);
        assert_eq!(snapshot.failure_count, 5);

        // A healthy operation must still be rejected while the circuit is open.
        let outcome = breaker.execute(succeeding_op).await;
        assert!(outcome.is_err());
        assert!(matches!(outcome.unwrap_err(), BreakerError::CircuitOpen));
    }

    #[tokio::test]
    async fn test_circuit_breaker_reset() {
        let breaker = ready_breaker().await;
        fail_times(&breaker, 5).await;

        let tripped = breaker.get_status().await.unwrap();
        assert_eq!(tripped.state, CircuitState::Open);

        breaker.reset().await.unwrap();

        let cleared = breaker.get_status().await.unwrap();
        assert_eq!(cleared.state, CircuitState::Closed);
        assert_eq!(cleared.failure_count, 0);
        assert_eq!(cleared.success_count, 0);
        assert!(breaker.can_execute().await.unwrap());
    }

    #[tokio::test]
    async fn test_circuit_breaker_record_success() {
        let breaker = ready_breaker().await;
        breaker.record_success().await.unwrap();

        let snapshot = breaker.get_status().await.unwrap();
        assert_eq!(snapshot.success_count, 1);
        assert_eq!(snapshot.failure_count, 0);
    }

    #[tokio::test]
    async fn test_circuit_breaker_record_failure() {
        let breaker = ready_breaker().await;
        breaker.record_failure().await.unwrap();

        let snapshot = breaker.get_status().await.unwrap();
        assert_eq!(snapshot.failure_count, 1);
        assert_eq!(snapshot.success_count, 0);
    }

    #[tokio::test]
    async fn test_circuit_breaker_with_config() {
        let mut breaker = CircuitBreaker::with_config(BreakerConfig {
            failure_threshold: 3,
            success_threshold: 2,
            timeout: Duration::from_secs(30),
            enabled: true,
        });
        breaker.initialize().await.unwrap();

        for _ in 0..3 {
            breaker.record_failure().await.unwrap();
        }

        let snapshot = breaker.get_status().await.unwrap();
        assert_eq!(snapshot.state, CircuitState::Open);
        assert_eq!(snapshot.failure_count, 3);
    }

    #[tokio::test]
    async fn test_circuit_breaker_not_initialized() {
        let breaker = CircuitBreaker::new();

        let admission = breaker.can_execute().await;
        assert!(admission.is_err());
        assert!(matches!(admission.unwrap_err(), BreakerError::NotInitialized));

        let outcome = breaker.execute(succeeding_op).await;
        assert!(outcome.is_err());
        assert!(matches!(outcome.unwrap_err(), BreakerError::NotInitialized));
    }

    #[test]
    fn test_breaker_config_default() {
        let defaults = BreakerConfig::default();
        assert_eq!(defaults.failure_threshold, 5);
        assert_eq!(defaults.success_threshold, 3);
        assert_eq!(defaults.timeout, Duration::from_secs(60));
        assert!(defaults.enabled);
    }

    #[test]
    fn test_circuit_state() {
        assert_eq!(CircuitState::Closed, CircuitState::Closed);
        assert_eq!(CircuitState::Open, CircuitState::Open);
        assert_eq!(CircuitState::HalfOpen, CircuitState::HalfOpen);
    }

    #[test]
    fn test_breaker_error_display() {
        let rendered = BreakerError::CircuitOpen.to_string();
        assert!(rendered.contains("Circuit breaker is open"));
    }
}