use crate::core::{Result, SolanaRecoverError};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
/// The three states of a circuit breaker.
///
/// The enum carries no data, so it is cheap to copy and can be compared,
/// hashed and used as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CircuitState {
    /// Normal operation: requests are passed through to the protected operation.
    Closed,
    /// The breaker has tripped: requests are rejected instead of being executed.
    Open,
    /// Trial period after the breaker was open: a limited number of requests
    /// are let through to find out whether the service has recovered.
    HalfOpen,
}
/// Tunable thresholds and timings for a `CircuitBreaker`.
#[derive(Debug, Clone)]
pub struct CircuitBreakerConfig {
    /// Number of failures counted while `Closed` that trips the breaker to `Open`.
    pub failure_threshold: u32,
    /// How long the breaker remains `Open` before a request may probe the service again.
    pub timeout: Duration,
    /// NOTE(review): not consulted by any breaker logic visible in this file;
    /// presumably reserved for callers or future use — confirm before relying on it.
    pub recovery_timeout: Duration,
    /// Number of successes required while `HalfOpen` to return to `Closed`.
    pub success_threshold: u32,
    /// Maximum number of requests admitted while the breaker is `HalfOpen`.
    pub max_half_open_requests: u32,
    /// When `true`, request types other than `"default"` get their own breaker.
    pub track_per_request_type: bool,
}

impl Default for CircuitBreakerConfig {
    /// Defaults: trip after 5 failures, stay open for 60 s, close again after
    /// 3 half-open successes, admit at most 10 half-open requests, and do not
    /// track request types separately.
    fn default() -> Self {
        let timeout = Duration::from_secs(60);
        let recovery_timeout = Duration::from_secs(30);

        Self {
            failure_threshold: 5,
            timeout,
            recovery_timeout,
            success_threshold: 3,
            max_half_open_requests: 10,
            track_per_request_type: false,
        }
    }
}
/// Live counters and timestamps describing a breaker's traffic.
///
/// Counters are atomics so they can be bumped through a shared reference;
/// the timestamps sit behind async `RwLock`s.
#[derive(Debug)]
pub struct CircuitBreakerMetrics {
/// Requests that were admitted and whose operation was run.
pub total_requests: AtomicU64,
/// Admitted requests whose operation returned `Ok`.
pub successful_requests: AtomicU64,
/// Admitted requests whose operation returned `Err`.
pub failed_requests: AtomicU64,
/// Requests turned away by the breaker without running the operation.
pub rejected_requests: AtomicU64,
/// Number of times the breaker was moved to the `Open` state.
pub circuit_open_count: AtomicU64,
/// When the most recent failure was recorded, if any.
pub last_failure_time: RwLock<Option<Instant>>,
/// When the most recent success was recorded, if any.
pub last_success_time: RwLock<Option<Instant>>,
}
impl Default for CircuitBreakerMetrics {
    /// Starts with every counter at zero and both timestamps unset.
    fn default() -> Self {
        // Small factories keep the field list free of repeated constructor noise.
        let zeroed = || AtomicU64::new(0);
        let unset = || RwLock::new(None);

        Self {
            total_requests: zeroed(),
            successful_requests: zeroed(),
            failed_requests: zeroed(),
            rejected_requests: zeroed(),
            circuit_open_count: zeroed(),
            last_failure_time: unset(),
            last_success_time: unset(),
        }
    }
}
impl CircuitBreakerMetrics {
pub fn success_rate(&self) -> f64 {
let total = self.total_requests.load(Ordering::Relaxed);
if total == 0 {
0.0
} else {
let successful = self.successful_requests.load(Ordering::Relaxed);
(successful as f64 / total as f64) * 100.0
}
}
pub fn failure_rate(&self) -> f64 {
let total = self.total_requests.load(Ordering::Relaxed);
if total == 0 {
0.0
} else {
let failed = self.failed_requests.load(Ordering::Relaxed);
(failed as f64 / total as f64) * 100.0
}
}
pub fn snapshot(&self) -> MetricsSnapshot {
MetricsSnapshot {
total_requests: self.total_requests.load(Ordering::Relaxed),
successful_requests: self.successful_requests.load(Ordering::Relaxed),
failed_requests: self.failed_requests.load(Ordering::Relaxed),
rejected_requests: self.rejected_requests.load(Ordering::Relaxed),
circuit_open_count: self.circuit_open_count.load(Ordering::Relaxed),
success_rate: self.success_rate(),
failure_rate: self.failure_rate(),
last_failure_time: None, last_success_time: None, }
}
pub async fn clone_metrics(&self) -> CircuitBreakerMetrics {
let last_failure = self.last_failure_time.read().await;
let last_success = self.last_success_time.read().await;
CircuitBreakerMetrics {
total_requests: AtomicU64::new(self.total_requests.load(Ordering::Relaxed)),
successful_requests: AtomicU64::new(self.successful_requests.load(Ordering::Relaxed)),
failed_requests: AtomicU64::new(self.failed_requests.load(Ordering::Relaxed)),
rejected_requests: AtomicU64::new(self.rejected_requests.load(Ordering::Relaxed)),
circuit_open_count: AtomicU64::new(self.circuit_open_count.load(Ordering::Relaxed)),
last_failure_time: RwLock::new(*last_failure),
last_success_time: RwLock::new(*last_success),
}
}
}
/// Plain-value copy of `CircuitBreakerMetrics` taken at one point in time.
#[derive(Debug, Clone)]
pub struct MetricsSnapshot {
/// Requests that were admitted and whose operation was run.
pub total_requests: u64,
/// Admitted requests whose operation returned `Ok`.
pub successful_requests: u64,
/// Admitted requests whose operation returned `Err`.
pub failed_requests: u64,
/// Requests turned away by the breaker without running the operation.
pub rejected_requests: u64,
/// Number of times the breaker was moved to the `Open` state.
pub circuit_open_count: u64,
/// Successful share of `total_requests`, as a percentage (0.0 when there were none).
pub success_rate: f64,
/// Failed share of `total_requests`, as a percentage (0.0 when there were none).
pub failure_rate: f64,
/// Time of the most recent failure if it was captured by the snapshot; `None` otherwise.
pub last_failure_time: Option<Instant>,
/// Time of the most recent success if it was captured by the snapshot; `None` otherwise.
pub last_success_time: Option<Instant>,
}
/// Async circuit breaker guarding a fallible operation.
///
/// Counters are atomics; the state and timestamps sit behind async `RwLock`s
/// so the breaker can be driven through a shared reference.
pub struct CircuitBreaker {
// Current state of the breaker.
state: Arc<RwLock<CircuitState>>,
// Thresholds and timings this breaker was built with.
config: CircuitBreakerConfig,
// Failures counted while `Closed`; cleared by a success in `Closed`.
failure_count: AtomicU32,
// Successes counted during the current `HalfOpen` period.
half_open_success_count: AtomicU32,
// Requests admitted during the current `HalfOpen` period.
half_open_request_count: AtomicU32,
// Breaker-local failure timestamp, kept separately from `metrics.last_failure_time`.
last_failure_time: Arc<RwLock<Option<Instant>>>,
// When the breaker last entered `Open`; cleared by `reset`.
circuit_open_time: Arc<RwLock<Option<Instant>>>,
// Shared traffic metrics, handed out via `get_metrics`.
metrics: Arc<CircuitBreakerMetrics>,
// Child breakers keyed by request type, used when `track_per_request_type` is enabled.
request_type_breakers: Arc<RwLock<std::collections::HashMap<String, Arc<CircuitBreaker>>>>,
}
impl CircuitBreaker {
pub fn new() -> Self {
Self::with_config(CircuitBreakerConfig::default())
}
pub fn with_config(config: CircuitBreakerConfig) -> Self {
Self {
state: Arc::new(RwLock::new(CircuitState::Closed)),
config,
failure_count: AtomicU32::new(0),
half_open_success_count: AtomicU32::new(0),
half_open_request_count: AtomicU32::new(0),
last_failure_time: Arc::new(RwLock::new(None)),
circuit_open_time: Arc::new(RwLock::new(None)),
metrics: Arc::new(CircuitBreakerMetrics::default()),
request_type_breakers: Arc::new(RwLock::new(std::collections::HashMap::new())),
}
}
pub async fn execute<F, T>(&self, operation: F) -> Result<T>
where
F: std::future::Future<Output = Result<T>>,
{
Box::pin(self.execute_with_type("default", operation)).await
}
pub async fn execute_with_type<F, T>(&self, request_type: &str, operation: F) -> Result<T>
where
F: std::future::Future<Output = Result<T>>,
{
if self.config.track_per_request_type && request_type != "default" {
let mut breakers = self.request_type_breakers.write().await;
let _breaker = breakers.entry(request_type.to_string())
.or_insert_with(|| Arc::new(CircuitBreaker::with_config(self.config.clone())));
drop(breakers);
let breakers = self.request_type_breakers.read().await;
if let Some(breaker) = breakers.get(request_type) {
return breaker.execute(operation).await;
}
}
let state = self.state.read().await;
let should_transition = match *state {
CircuitState::Open => {
let last_failure = self.last_failure_time.read().await;
if let Some(last) = *last_failure {
if last.elapsed() > self.config.timeout {
drop(last_failure);
true
} else {
self.metrics.rejected_requests.fetch_add(1, Ordering::Relaxed);
return Err(SolanaRecoverError::CircuitBreakerOpen(
"Circuit breaker is open - service unavailable".to_string()
));
}
} else {
false
}
}
_ => false,
};
let is_half_open = *state == CircuitState::HalfOpen;
drop(state);
if should_transition {
self.transition_to_half_open().await;
}
if is_half_open {
let half_open_requests = self.half_open_request_count.load(Ordering::Relaxed);
if half_open_requests >= self.config.max_half_open_requests {
self.metrics.rejected_requests.fetch_add(1, Ordering::Relaxed);
return Err(SolanaRecoverError::CircuitBreakerOpen(
"HalfOpen request limit exceeded".to_string()
));
}
}
self.metrics.total_requests.fetch_add(1, Ordering::Relaxed);
if is_half_open {
self.half_open_request_count.fetch_add(1, Ordering::Relaxed);
}
let result = operation.await;
match result {
Ok(value) => {
self.on_success().await;
Ok(value)
}
Err(error) => {
self.on_failure().await;
Err(error)
}
}
}
async fn on_success(&self) {
let mut state = self.state.write().await;
self.metrics.successful_requests.fetch_add(1, Ordering::Relaxed);
{
let mut last_success = self.metrics.last_success_time.write().await;
*last_success = Some(Instant::now());
}
match *state {
CircuitState::HalfOpen => {
let success_count = self.half_open_success_count.fetch_add(1, Ordering::Relaxed) + 1;
if success_count >= self.config.success_threshold {
*state = CircuitState::Closed;
self.failure_count.store(0, Ordering::Relaxed);
self.half_open_success_count.store(0, Ordering::Relaxed);
self.half_open_request_count.store(0, Ordering::Relaxed);
}
}
CircuitState::Closed => {
self.failure_count.store(0, Ordering::Relaxed);
}
CircuitState::Open => {
*state = CircuitState::Closed;
self.failure_count.store(0, Ordering::Relaxed);
}
}
}
async fn on_failure(&self) {
let mut state = self.state.write().await;
self.metrics.failed_requests.fetch_add(1, Ordering::Relaxed);
{
let mut last_failure = self.metrics.last_failure_time.write().await;
*last_failure = Some(Instant::now());
}
match *state {
CircuitState::Closed => {
let count = self.failure_count.fetch_add(1, Ordering::Relaxed) + 1;
if count >= self.config.failure_threshold {
*state = CircuitState::Open;
let mut circuit_open_time = self.circuit_open_time.write().await;
*circuit_open_time = Some(Instant::now());
drop(circuit_open_time);
self.metrics.circuit_open_count.fetch_add(1, Ordering::Relaxed);
}
}
CircuitState::HalfOpen => {
*state = CircuitState::Open;
let mut circuit_open_time = self.circuit_open_time.write().await;
*circuit_open_time = Some(Instant::now());
drop(circuit_open_time);
self.metrics.circuit_open_count.fetch_add(1, Ordering::Relaxed);
self.half_open_success_count.store(0, Ordering::Relaxed);
self.half_open_request_count.store(0, Ordering::Relaxed);
}
CircuitState::Open => {
}
}
}
async fn transition_to_half_open(&self) {
let mut state = self.state.write().await;
*state = CircuitState::HalfOpen;
self.half_open_success_count.store(0, Ordering::Relaxed);
self.half_open_request_count.store(0, Ordering::Relaxed);
}
pub async fn force_open(&self) {
let mut state = self.state.write().await;
*state = CircuitState::Open;
let mut circuit_open_time = self.circuit_open_time.write().await;
*circuit_open_time = Some(Instant::now());
drop(circuit_open_time);
self.metrics.circuit_open_count.fetch_add(1, Ordering::Relaxed);
}
pub async fn force_close(&self) {
let mut state = self.state.write().await;
*state = CircuitState::Closed;
self.failure_count.store(0, Ordering::Relaxed);
self.half_open_success_count.store(0, Ordering::Relaxed);
self.half_open_request_count.store(0, Ordering::Relaxed);
}
pub async fn get_state(&self) -> CircuitState {
self.state.read().await.clone()
}
pub fn get_metrics(&self) -> Arc<CircuitBreakerMetrics> {
Arc::clone(&self.metrics)
}
pub async fn reset(&self) {
let mut state = self.state.write().await;
*state = CircuitState::Closed;
self.failure_count.store(0, Ordering::Relaxed);
self.half_open_success_count.store(0, Ordering::Relaxed);
self.half_open_request_count.store(0, Ordering::Relaxed);
self.metrics.total_requests.store(0, Ordering::Relaxed);
self.metrics.successful_requests.store(0, Ordering::Relaxed);
self.metrics.failed_requests.store(0, Ordering::Relaxed);
self.metrics.rejected_requests.store(0, Ordering::Relaxed);
self.metrics.circuit_open_count.store(0, Ordering::Relaxed);
let mut last_failure_time = self.metrics.last_failure_time.write().await;
*last_failure_time = None;
drop(last_failure_time);
let mut last_success_time = self.metrics.last_success_time.write().await;
*last_success_time = None;
drop(last_success_time);
let mut last_failure_time2 = self.last_failure_time.write().await;
*last_failure_time2 = None;
drop(last_failure_time2);
let mut circuit_open_time = self.circuit_open_time.write().await;
*circuit_open_time = None;
drop(circuit_open_time);
}
pub async fn is_allowing_requests(&self) -> bool {
let state = self.state.read().await;
match *state {
CircuitState::Closed => true,
CircuitState::HalfOpen => {
let half_open_requests = self.half_open_request_count.load(Ordering::Relaxed);
half_open_requests < self.config.max_half_open_requests
}
CircuitState::Open => false,
}
}
pub async fn time_to_next_state(&self) -> Option<Duration> {
let state = self.state.read().await;
match *state {
CircuitState::Open => {
let last_failure = self.last_failure_time.read().await;
last_failure.map(|last| {
let elapsed = last.elapsed();
if elapsed < self.config.timeout {
self.config.timeout - elapsed
} else {
Duration::ZERO
}
})
}
_ => None,
}
}
}
pub struct CircuitBreakerManager {
breakers: Arc<RwLock<std::collections::HashMap<String, Arc<CircuitBreaker>>>>,
default_config: CircuitBreakerConfig,
}
impl CircuitBreakerManager {
pub fn new(default_config: CircuitBreakerConfig) -> Self {
Self {
breakers: Arc::new(RwLock::new(std::collections::HashMap::new())),
default_config,
}
}
pub async fn get_breaker(&self, service_name: &str) -> Arc<CircuitBreaker> {
let mut breakers = self.breakers.write().await;
breakers.entry(service_name.to_string())
.or_insert_with(|| Arc::new(CircuitBreaker::with_config(self.default_config.clone())))
.clone()
}
pub async fn execute<F, T>(&self, service_name: &str, operation: F) -> Result<T>
where
F: std::future::Future<Output = Result<T>>,
{
let breaker = self.get_breaker(service_name).await;
breaker.execute(operation).await
}
pub async fn get_all_states(&self) -> std::collections::HashMap<String, CircuitState> {
let breakers = self.breakers.read().await;
let mut states = std::collections::HashMap::new();
for (name, breaker) in breakers.iter() {
states.insert(name.clone(), breaker.get_state().await);
}
states
}
pub async fn reset_all(&self) {
let breakers = self.breakers.read().await;
for breaker in breakers.values() {
breaker.reset().await;
}
}
pub async fn close_all(&self) {
let breakers = self.breakers.read().await;
for breaker in breakers.values() {
breaker.force_close().await;
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use tokio::time::sleep;
#[tokio::test]
async fn test_circuit_breaker_basic_operation() {
let breaker = CircuitBreaker::with_config(CircuitBreakerConfig {
failure_threshold: 3,
timeout: Duration::from_millis(100),
recovery_timeout: Duration::from_millis(50),
success_threshold: 2,
max_half_open_requests: 5,
track_per_request_type: false,
});
let result = breaker.execute(async { Ok(42) }).await;
assert!(result.is_ok());
assert_eq!(result.unwrap(), 42);
assert_eq!(breaker.get_state().await, CircuitState::Closed);
}
#[tokio::test]
async fn test_circuit_breaker_failure_threshold() {
let breaker = CircuitBreaker::with_config(CircuitBreakerConfig {
failure_threshold: 2,
timeout: Duration::from_millis(100),
recovery_timeout: Duration::from_millis(50),
success_threshold: 2,
max_half_open_requests: 5,
track_per_request_type: false,
});
for _ in 0..2 {
let result: Result<String> = breaker.execute(async {
Err(SolanaRecoverError::NetworkError("Test error".to_string()))
}).await;
assert!(result.is_err());
}
assert_eq!(breaker.get_state().await, CircuitState::Open);
let _result = breaker.execute(async { Ok(42) }).await;
}
#[tokio::test]
async fn test_circuit_breaker_recovery() {
let breaker = CircuitBreaker::with_config(CircuitBreakerConfig {
failure_threshold: 2,
timeout: Duration::from_millis(50),
recovery_timeout: Duration::from_millis(50),
success_threshold: 2,
max_half_open_requests: 5,
track_per_request_type: false,
});
for _ in 0..2 {
let _: Result<String> = breaker.execute(async {
Err(SolanaRecoverError::NetworkError("Test error".to_string()))
}).await;
}
assert_eq!(breaker.get_state().await, CircuitState::Open);
sleep(Duration::from_millis(60)).await;
let result = breaker.execute(async { Ok(42) }).await;
assert!(result.is_ok());
assert_eq!(breaker.get_state().await, CircuitState::Closed);
}
#[tokio::test]
async fn test_circuit_breaker_manager() {
let manager = CircuitBreakerManager::new(CircuitBreakerConfig::default());
let result1 = manager.execute("service1", async { Ok(1) }).await;
let result2 = manager.execute("service2", async { Ok(2) }).await;
assert!(result1.is_ok());
assert!(result2.is_ok());
let states = manager.get_all_states().await;
assert_eq!(states.get("service1"), Some(&CircuitState::Closed));
assert_eq!(states.get("service2"), Some(&CircuitState::Closed));
}
#[tokio::test]
async fn test_circuit_breaker_metrics() {
let breaker = CircuitBreaker::new();
for i in 0..5 {
let result = if i < 3 {
Ok(i)
} else {
Err(SolanaRecoverError::NetworkError("Test error".to_string()))
};
let _ = breaker.execute(async { result }).await;
}
let metrics = breaker.get_metrics();
let snapshot = metrics.snapshot();
assert_eq!(snapshot.total_requests, 5);
assert_eq!(snapshot.successful_requests, 3);
assert_eq!(snapshot.failed_requests, 2);
assert_eq!(snapshot.success_rate, 60.0);
assert_eq!(snapshot.failure_rate, 40.0);
}
}