use crate::neo_error::{Neo3Error, Neo3Result};
use std::{
sync::{
atomic::{AtomicU32, Ordering},
Arc,
},
time::{Duration, Instant},
};
use tokio::sync::RwLock;
/// Operating state of a circuit breaker.
///
/// Fieldless, so it also derives `Eq` and `Hash` and can be used as a map/set key.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
pub enum CircuitState {
    /// Requests are executed normally while failures are counted.
    #[default]
    Closed,
    /// Requests are rejected until the open timeout has elapsed.
    Open,
    /// Trial requests are admitted to test whether the guarded service recovered.
    HalfOpen,
}
/// Tuning knobs for a circuit breaker.
#[derive(Clone, Debug)]
pub struct CircuitBreakerConfig {
    /// Number of failures that trips a closed circuit to open.
    pub failure_threshold: u32,
    /// How long an open circuit rejects requests before a trial request is let through.
    pub timeout: Duration,
    /// Number of successful trial requests needed to close a half-open circuit.
    pub success_threshold: u32,
    /// Intended span over which failures count toward `failure_threshold`.
    pub failure_window: Duration,
    /// Limit on trial requests admitted while the circuit is half-open.
    pub half_open_max_requests: u32,
}

impl Default for CircuitBreakerConfig {
    /// Five failures open the circuit, which waits one minute before probing;
    /// three successful probes close it again.
    fn default() -> Self {
        let one_minute = Duration::from_secs(60);
        CircuitBreakerConfig {
            failure_threshold: 5,
            timeout: one_minute,
            success_threshold: 3,
            failure_window: one_minute,
            half_open_max_requests: 3,
        }
    }
}
/// Counters describing a circuit breaker's activity.
///
/// Derives `Clone` so a snapshot can be duplicated without copying each field by hand.
#[derive(Debug, Clone, Default)]
pub struct CircuitBreakerStats {
    /// Every call made through the breaker, including rejected ones.
    pub total_requests: u64,
    /// Calls whose operation returned `Ok`.
    pub successful_requests: u64,
    /// Calls whose operation returned `Err`.
    pub failed_requests: u64,
    /// Calls refused by the breaker without running the operation.
    pub rejected_requests: u64,
    /// Number of state changes performed by the breaker.
    pub state_transitions: u64,
    /// State recorded at the most recent transition.
    pub current_state: CircuitState,
    /// When the most recent failure was recorded, if any.
    pub last_failure_time: Option<Instant>,
    /// When the most recent success was recorded, if any.
    pub last_success_time: Option<Instant>,
}
/// Circuit breaker guarding calls to a fallible async operation.
///
/// * `Closed`: every request runs. Failures are counted in a window of
///   `failure_window` that starts at the first failure (a success clears the
///   count); reaching `failure_threshold` opens the circuit.
/// * `Open`: requests are rejected until `timeout` has elapsed since the circuit
///   opened; the next request after that moves it to `HalfOpen`.
/// * `HalfOpen`: at most `half_open_max_requests` probe requests run at a time.
///   `success_threshold` successful probes close the circuit; any failed probe
///   re-opens it.
///
/// Whenever locks are nested they are taken in the order `state`, then the
/// timestamp locks, then `stats`, and `state` is never re-acquired while held.
pub struct CircuitBreaker {
    config: CircuitBreakerConfig,
    state: Arc<RwLock<CircuitState>>,
    /// Failures counted in the current failure window while `Closed`.
    failure_count: AtomicU32,
    /// Successful probes since the circuit last entered `HalfOpen`.
    success_count: AtomicU32,
    /// Probe requests currently in flight while `HalfOpen`.
    half_open_requests: AtomicU32,
    last_failure_time: Arc<RwLock<Option<Instant>>>,
    last_success_time: Arc<RwLock<Option<Instant>>>,
    /// When the circuit last entered `Open`; `config.timeout` is measured from here.
    opened_at: Arc<RwLock<Option<Instant>>>,
    /// Start of the current failure-counting window while `Closed`.
    failure_window_start: Arc<RwLock<Option<Instant>>>,
    stats: Arc<RwLock<CircuitBreakerStats>>,
}

/// Result of asking the breaker whether a request may run.
enum Admission<'a> {
    /// The circuit is closed; the request runs normally.
    Allowed,
    /// The circuit is half-open and the request holds one of the probe slots.
    Probe(ProbeSlot<'a>),
    /// The circuit is open, or every half-open probe slot is taken.
    Rejected,
}

/// A claimed half-open probe slot, released when dropped. It is also released
/// when the caller's future is cancelled before the operation completes.
struct ProbeSlot<'a>(&'a AtomicU32);

impl Drop for ProbeSlot<'_> {
    fn drop(&mut self) {
        // Saturating decrement: a state transition may already have reset the counter
        // to zero, and wrapping to u32::MAX would block every future probe. A slot
        // released after the circuit changed state can only over-admit, never block.
        let _ = self
            .0
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |in_flight| in_flight.checked_sub(1));
    }
}

impl CircuitBreaker {
    /// Creates a closed circuit breaker with the given configuration.
    pub fn new(config: CircuitBreakerConfig) -> Self {
        Self {
            config,
            state: Arc::new(RwLock::new(CircuitState::Closed)),
            failure_count: AtomicU32::new(0),
            success_count: AtomicU32::new(0),
            half_open_requests: AtomicU32::new(0),
            last_failure_time: Arc::new(RwLock::new(None)),
            last_success_time: Arc::new(RwLock::new(None)),
            opened_at: Arc::new(RwLock::new(None)),
            failure_window_start: Arc::new(RwLock::new(None)),
            stats: Arc::new(RwLock::new(CircuitBreakerStats::default())),
        }
    }

    /// Runs `operation` through the breaker.
    ///
    /// The operation is only awaited when the breaker admits the request; its
    /// outcome is then recorded and drives the state transitions.
    ///
    /// # Errors
    ///
    /// Returns `Neo3Error::Network(NetworkError::RateLimitExceeded)` without
    /// awaiting `operation` when the circuit is open, or half-open with every
    /// probe slot taken. Otherwise returns the operation's own result.
    pub async fn call<F, T>(&self, operation: F) -> Neo3Result<T>
    where
        F: std::future::Future<Output = Neo3Result<T>>,
    {
        self.stats.write().await.total_requests += 1;
        // Kept alive until this call finishes (or is cancelled), then frees the probe slot.
        let _probe_slot = match self.admit().await {
            Admission::Allowed => None,
            Admission::Probe(slot) => Some(slot),
            Admission::Rejected => {
                self.stats.write().await.rejected_requests += 1;
                return Err(Neo3Error::Network(crate::neo_error::NetworkError::RateLimitExceeded));
            },
        };
        match operation.await {
            Ok(result) => {
                self.on_success().await;
                Ok(result)
            },
            Err(error) => {
                self.on_failure().await;
                Err(error)
            },
        }
    }

    /// Decides whether a request may run. An open circuit whose timeout has
    /// elapsed is moved to `HalfOpen` and the request becomes a probe.
    async fn admit(&self) -> Admission<'_> {
        {
            let state = self.state.read().await;
            match *state {
                CircuitState::Closed => return Admission::Allowed,
                // Holding the read lock blocks transitions, so the claimed slot belongs
                // to the current half-open period.
                CircuitState::HalfOpen => return self.try_acquire_probe_slot(),
                CircuitState::Open => {
                    if !self.open_timeout_elapsed().await {
                        return Admission::Rejected;
                    }
                },
            }
        }
        // The timeout looked elapsed: re-check under the write lock, because another
        // task may have changed the state after the read lock was released.
        let mut state = self.state.write().await;
        let current = state.clone();
        match current {
            CircuitState::Closed => Admission::Allowed,
            CircuitState::HalfOpen => self.try_acquire_probe_slot(),
            CircuitState::Open => {
                if self.open_timeout_elapsed().await {
                    self.enter_state(&mut state, CircuitState::HalfOpen).await;
                    self.try_acquire_probe_slot()
                } else {
                    Admission::Rejected
                }
            },
        }
    }

    /// Returns whether `config.timeout` has passed since the circuit opened.
    async fn open_timeout_elapsed(&self) -> bool {
        match *self.opened_at.read().await {
            Some(opened_at) => opened_at.elapsed() >= self.config.timeout,
            // No recorded opening time: there is nothing to wait for.
            None => true,
        }
    }

    /// Claims a probe slot if fewer than `half_open_max_requests` probes are in
    /// flight. At least one probe is always permitted, so a limit of zero cannot
    /// leave the circuit stuck half-open.
    fn try_acquire_probe_slot(&self) -> Admission<'_> {
        let limit = self.config.half_open_max_requests.max(1);
        let claimed = self
            .half_open_requests
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |in_flight| {
                (in_flight < limit).then_some(in_flight + 1)
            });
        match claimed {
            Ok(_) => Admission::Probe(ProbeSlot(&self.half_open_requests)),
            Err(_) => Admission::Rejected,
        }
    }

    /// Records a successful operation and applies the resulting state change.
    async fn on_success(&self) {
        let now = Instant::now();
        {
            let mut stats = self.stats.write().await;
            stats.successful_requests += 1;
            stats.last_success_time = Some(now);
        }
        *self.last_success_time.write().await = Some(now);

        let state = self.state.read().await;
        match *state {
            CircuitState::Closed => {
                // A success breaks the run of failures.
                self.failure_count.store(0, Ordering::Relaxed);
                *self.failure_window_start.write().await = None;
            },
            CircuitState::HalfOpen => {
                let successes = self.success_count.fetch_add(1, Ordering::Relaxed) + 1;
                if successes >= self.config.success_threshold {
                    drop(state);
                    self.transition_from(CircuitState::HalfOpen, CircuitState::Closed).await;
                }
            },
            // Late result of a request admitted before the circuit opened. One
            // straggling success must not close the circuit, so it is ignored.
            CircuitState::Open => {},
        }
    }

    /// Records a failed operation and applies the resulting state change.
    async fn on_failure(&self) {
        let now = Instant::now();
        {
            let mut stats = self.stats.write().await;
            stats.failed_requests += 1;
            stats.last_failure_time = Some(now);
        }
        *self.last_failure_time.write().await = Some(now);

        let state = self.state.read().await;
        match *state {
            CircuitState::Closed => {
                let failures = self.count_failure_in_window(now).await;
                if failures >= self.config.failure_threshold {
                    drop(state);
                    self.transition_from(CircuitState::Closed, CircuitState::Open).await;
                }
            },
            CircuitState::HalfOpen => {
                drop(state);
                self.transition_from(CircuitState::HalfOpen, CircuitState::Open).await;
            },
            // Already open: late failures do not extend the open period.
            CircuitState::Open => {},
        }
    }

    /// Counts one failure in the current failure window and returns the running count.
    ///
    /// The window starts at the first failure. A failure arriving `failure_window`
    /// or more after that start begins a new window with a fresh count.
    async fn count_failure_in_window(&self, now: Instant) -> u32 {
        let mut window_start = self.failure_window_start.write().await;
        let within_window = matches!(
            *window_start,
            Some(start) if now.saturating_duration_since(start) < self.config.failure_window
        );
        if !within_window {
            *window_start = Some(now);
            self.failure_count.store(0, Ordering::Relaxed);
        }
        self.failure_count.fetch_add(1, Ordering::Relaxed).saturating_add(1)
    }

    /// Moves from `from` to `to`, but only if the circuit is still in `from` once
    /// the write lock is held; another task may have transitioned it meanwhile.
    async fn transition_from(&self, from: CircuitState, to: CircuitState) {
        let mut state = self.state.write().await;
        if *state == from {
            self.enter_state(&mut state, to).await;
        }
    }

    /// Switches the write-locked `state` to `next`, resets the per-state counters
    /// and records the transition in the stats. Does nothing if already in `next`.
    async fn enter_state(&self, state: &mut CircuitState, next: CircuitState) {
        if *state == next {
            return;
        }
        match next {
            CircuitState::Closed => {
                self.failure_count.store(0, Ordering::Relaxed);
                *self.failure_window_start.write().await = None;
            },
            CircuitState::Open => {
                // The open timeout is measured from this moment.
                *self.opened_at.write().await = Some(Instant::now());
            },
            CircuitState::HalfOpen => {},
        }
        self.success_count.store(0, Ordering::Relaxed);
        self.half_open_requests.store(0, Ordering::Relaxed);
        *state = next.clone();
        let mut stats = self.stats.write().await;
        stats.state_transitions += 1;
        stats.current_state = next;
    }

    /// Returns the current circuit state.
    pub async fn get_state(&self) -> CircuitState {
        self.state.read().await.clone()
    }

    /// Returns a snapshot copy of the accumulated statistics.
    pub async fn get_stats(&self) -> CircuitBreakerStats {
        let stats = self.stats.read().await;
        CircuitBreakerStats {
            total_requests: stats.total_requests,
            successful_requests: stats.successful_requests,
            failed_requests: stats.failed_requests,
            rejected_requests: stats.rejected_requests,
            state_transitions: stats.state_transitions,
            current_state: stats.current_state.clone(),
            last_failure_time: stats.last_failure_time,
            last_success_time: stats.last_success_time,
        }
    }

    /// Returns the breaker to a fresh `Closed` state, clearing every counter,
    /// timestamp and statistic, including when it is already closed.
    pub async fn reset(&self) {
        // Hold the state lock throughout so no transition interleaves with the reset.
        let mut state = self.state.write().await;
        *state = CircuitState::Closed;
        self.failure_count.store(0, Ordering::Relaxed);
        self.success_count.store(0, Ordering::Relaxed);
        self.half_open_requests.store(0, Ordering::Relaxed);
        *self.failure_window_start.write().await = None;
        *self.opened_at.write().await = None;
        *self.last_failure_time.write().await = None;
        *self.last_success_time.write().await = None;
        *self.stats.write().await = CircuitBreakerStats::default();
    }

    /// Opens the circuit immediately. The open timeout runs from this call, so the
    /// breaker recovers through `HalfOpen` even if no failure was ever recorded.
    pub async fn force_open(&self) {
        let mut state = self.state.write().await;
        self.enter_state(&mut state, CircuitState::Open).await;
    }

    /// Returns `failed_requests / total_requests`, or `0.0` when no calls were made.
    ///
    /// Note that the denominator includes rejected calls.
    pub async fn get_failure_rate(&self) -> f64 {
        let stats = self.stats.read().await;
        if stats.total_requests == 0 {
            0.0
        } else {
            stats.failed_requests as f64 / stats.total_requests as f64
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::{sleep, Duration};

    /// An operation that always fails with a connection error.
    async fn failing_operation() -> Neo3Result<()> {
        Err(Neo3Error::Network(crate::neo_error::NetworkError::ConnectionFailed(
            "test".to_string(),
        )))
    }

    /// An operation that always succeeds.
    async fn succeeding_operation() -> Neo3Result<()> {
        Ok(())
    }

    #[tokio::test]
    async fn test_circuit_breaker_closed_state() {
        let breaker = CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: 3,
            ..Default::default()
        });
        for _ in 0..5 {
            assert!(breaker.call(succeeding_operation()).await.is_ok());
        }
        assert_eq!(breaker.get_state().await, CircuitState::Closed);
    }

    #[tokio::test]
    async fn test_circuit_breaker_opens_on_failures() {
        let breaker = CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: 3,
            ..Default::default()
        });
        for _ in 0..3 {
            assert!(breaker.call(failing_operation()).await.is_err());
        }
        assert_eq!(breaker.get_state().await, CircuitState::Open);
    }

    #[tokio::test]
    async fn test_circuit_breaker_half_open_transition() {
        let breaker = CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: 2,
            timeout: Duration::from_millis(100),
            ..Default::default()
        });
        for _ in 0..2 {
            let _ = breaker.call(failing_operation()).await;
        }
        assert_eq!(breaker.get_state().await, CircuitState::Open);

        // Wait past the open timeout so the next call is admitted as a probe.
        sleep(Duration::from_millis(150)).await;
        assert!(breaker.call(succeeding_operation()).await.is_ok());
        assert_eq!(breaker.get_state().await, CircuitState::HalfOpen);
    }

    #[tokio::test]
    async fn test_circuit_breaker_stats() {
        let breaker = CircuitBreaker::new(CircuitBreakerConfig::default());
        let _ = breaker.call(succeeding_operation()).await;
        let _ = breaker.call(failing_operation()).await;

        let snapshot = breaker.get_stats().await;
        assert_eq!(snapshot.total_requests, 2);
        assert_eq!(snapshot.successful_requests, 1);
        assert_eq!(snapshot.failed_requests, 1);
    }
}