use crate::{ZoeyError, Result};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
use tracing::{debug, error, warn};
/// State of a [`CircuitBreaker`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitState {
/// Calls are executed; failures are counted toward the failure threshold.
Closed,
/// Calls are rejected until the configured timeout has elapsed since the circuit opened.
Open,
/// Calls are executed as trials; enough successes close the circuit again.
HalfOpen,
}
/// Circuit breaker that stops executing calls after repeated failures and
/// admits them again once a timeout has elapsed.
pub struct CircuitBreaker {
/// Current state of the circuit.
state: Arc<RwLock<CircuitState>>,
/// Number of counted failures at which the circuit opens.
failure_threshold: usize,
/// Number of half-open successes required to close the circuit.
success_threshold: usize,
/// Minimum time the circuit stays open before it may become half-open.
timeout: Duration,
/// Failures counted so far; cleared by a success while closed, on closing, and on reset.
failure_count: Arc<RwLock<usize>>,
/// Successes counted while half-open; cleared when the circuit becomes half-open and on reset.
success_count: Arc<RwLock<usize>>,
/// Instant at which the circuit was last opened; `None` initially and after a reset.
last_failure_time: Arc<RwLock<Option<Instant>>>,
}
impl CircuitBreaker {
    /// Creates a breaker that starts in [`CircuitState::Closed`].
    ///
    /// * `failure_threshold` - number of counted failures at which the circuit opens.
    /// * `success_threshold` - number of half-open successes required to close it again.
    /// * `timeout` - how long the circuit stays open before a trial call is admitted.
    pub fn new(failure_threshold: usize, success_threshold: usize, timeout: Duration) -> Self {
        Self {
            state: Arc::new(RwLock::new(CircuitState::Closed)),
            failure_threshold,
            success_threshold,
            timeout,
            failure_count: Arc::new(RwLock::new(0)),
            success_count: Arc::new(RwLock::new(0)),
            last_failure_time: Arc::new(RwLock::new(None)),
        }
    }

    /// Runs `f` through the breaker and records its outcome.
    ///
    /// # Errors
    ///
    /// Returns an error without polling `f` when the circuit is open and the
    /// timeout has not elapsed yet. Otherwise returns the error produced by
    /// `f`, converted to a [`ZoeyError`] through its `Display` output.
    ///
    /// # Panics
    ///
    /// Panics if one of the internal locks is poisoned.
    pub async fn call<F, T, E>(&self, f: F) -> Result<T>
    where
        F: std::future::Future<Output = std::result::Result<T, E>>,
        E: std::fmt::Display,
    {
        // Admission is decided in a synchronous helper so that no lock guard
        // can be alive across the `.await` below.
        self.admit()?;

        match f.await {
            Ok(result) => {
                self.on_success();
                Ok(result)
            }
            Err(e) => {
                self.on_failure();
                Err(ZoeyError::other(e.to_string()))
            }
        }
    }

    /// Decides whether a call may run, moving an open circuit to half-open
    /// once the timeout has elapsed.
    ///
    /// Lock discipline shared by every method of this type: the `state` write
    /// lock serializes all transitions, and every other lock is only taken
    /// for a single statement. No guard other than `state` is ever held while
    /// another lock is acquired, so the locks cannot form a wait cycle.
    fn admit(&self) -> Result<()> {
        // Fast path: closed and half-open circuits admit every call.
        if *self.state.read().unwrap() != CircuitState::Open {
            return Ok(());
        }

        {
            let mut state = self.state.write().unwrap();
            // Re-check under the write lock: another caller may have moved
            // the circuit out of `Open` in the meantime.
            if *state != CircuitState::Open {
                return Ok(());
            }

            let opened_at = *self.last_failure_time.read().unwrap();
            // An open circuit without a timestamp should not occur; treating
            // it as cooled down lets the breaker recover instead of staying
            // open forever.
            let cooled_down = opened_at.map_or(true, |at| at.elapsed() >= self.timeout);
            if !cooled_down {
                return Err(ZoeyError::other("Circuit breaker is open"));
            }

            *state = CircuitState::HalfOpen;
            *self.success_count.write().unwrap() = 0;
        }

        debug!("Circuit breaker transitioning to half-open");
        Ok(())
    }

    /// Records a successful call: counts it toward closing a half-open
    /// circuit, or clears the failure count of a closed one.
    fn on_success(&self) {
        let closed_now = {
            let mut state = self.state.write().unwrap();
            let current = *state;
            match current {
                CircuitState::HalfOpen => {
                    let successes = {
                        let mut count = self.success_count.write().unwrap();
                        *count += 1;
                        *count
                    };
                    if successes >= self.success_threshold {
                        *state = CircuitState::Closed;
                        *self.failure_count.write().unwrap() = 0;
                        true
                    } else {
                        false
                    }
                }
                CircuitState::Closed => {
                    *self.failure_count.write().unwrap() = 0;
                    false
                }
                // A call that was admitted earlier finished after the circuit
                // opened; its result no longer influences the breaker.
                CircuitState::Open => false,
            }
        };

        if closed_now {
            debug!("Circuit breaker closed");
        }
    }

    /// Records a failed call and opens the circuit when the failure threshold
    /// is reached, or immediately when the failure happened while half-open.
    fn on_failure(&self) {
        let opened_after = {
            let mut state = self.state.write().unwrap();
            let current = *state;
            if current == CircuitState::Open {
                None
            } else {
                let failures = {
                    let mut count = self.failure_count.write().unwrap();
                    *count += 1;
                    *count
                };
                // A failed trial call re-opens the circuit regardless of the
                // counter; a closed circuit opens at the threshold.
                if current == CircuitState::HalfOpen || failures >= self.failure_threshold {
                    *state = CircuitState::Open;
                    *self.last_failure_time.write().unwrap() = Some(Instant::now());
                    Some(failures)
                } else {
                    None
                }
            }
        };

        if let Some(failures) = opened_after {
            warn!("Circuit breaker opened after {} failures", failures);
        }
    }

    /// Returns the current state of the circuit.
    ///
    /// # Panics
    ///
    /// Panics if the state lock is poisoned.
    pub fn state(&self) -> CircuitState {
        *self.state.read().unwrap()
    }

    /// Forces the breaker back to [`CircuitState::Closed`] and clears all
    /// counters and the recorded opening time.
    ///
    /// # Panics
    ///
    /// Panics if one of the internal locks is poisoned.
    pub fn reset(&self) {
        {
            // Hold the state lock for the whole reset so a concurrent
            // transition cannot interleave with the individual writes.
            let mut state = self.state.write().unwrap();
            *state = CircuitState::Closed;
            *self.failure_count.write().unwrap() = 0;
            *self.success_count.write().unwrap() = 0;
            *self.last_failure_time.write().unwrap() = None;
        }
        debug!("Circuit breaker reset");
    }
}
/// Result category of a health check.
///
/// The derived ordering follows declaration order:
/// `Healthy < Degraded < Unhealthy`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum HealthStatus {
/// The check succeeded.
Healthy,
/// The check succeeded but took longer than the slow-response limit.
Degraded,
/// The check failed.
Unhealthy,
}
/// Recorded outcome of a single named health check.
#[derive(Debug, Clone)]
pub struct HealthCheck {
/// Name the check was registered under.
pub name: String,
/// Status determined for the check.
pub status: HealthStatus,
/// Optional detail text accompanying the result.
pub message: Option<String>,
/// Instant at which this result was recorded.
pub last_check: Instant,
/// Time the check took to complete, in milliseconds.
pub response_time_ms: u64,
}
/// Runs named health checks and keeps the most recent result for each name.
pub struct HealthChecker {
/// Latest recorded result per check name.
checks: Arc<RwLock<HashMap<String, HealthCheck>>>,
}
impl HealthChecker {
    /// Response time, in milliseconds, above which a successful check is
    /// reported as [`HealthStatus::Degraded`].
    const SLOW_RESPONSE_MS: u64 = 1000;

    /// Creates a checker with no recorded results.
    pub fn new() -> Self {
        Self {
            checks: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Awaits `f`, classifies the outcome and stores it under `name`,
    /// replacing any earlier result with the same name.
    ///
    /// * `Ok` within the slow-response limit is `Healthy`.
    /// * `Ok` above the limit is `Degraded`; the stored message notes the
    ///   measured time.
    /// * `Err` is `Unhealthy`; the error is logged and its text is stored as
    ///   the message.
    ///
    /// Returns the status that was recorded.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub async fn check<F, T, E>(&self, name: &str, f: F) -> HealthStatus
    where
        F: std::future::Future<Output = std::result::Result<T, E>>,
        E: std::fmt::Display,
    {
        let start = Instant::now();
        let outcome = f.await;
        // Saturate rather than silently truncate the `u128` millisecond count.
        let response_time_ms = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX);

        let (status, message) = match outcome {
            Ok(_) if response_time_ms > Self::SLOW_RESPONSE_MS => (
                HealthStatus::Degraded,
                Some(format!(
                    "Slow response: {} ms (limit {} ms)",
                    response_time_ms,
                    Self::SLOW_RESPONSE_MS
                )),
            ),
            Ok(_) => (HealthStatus::Healthy, None),
            Err(e) => {
                error!("Health check failed for {}: {}", name, e);
                // Keep the failure reason with the stored result instead of
                // only writing it to the log.
                (HealthStatus::Unhealthy, Some(e.to_string()))
            }
        };

        let check = HealthCheck {
            name: name.to_string(),
            status,
            message,
            last_check: Instant::now(),
            response_time_ms,
        };
        self.checks.write().unwrap().insert(name.to_string(), check);
        status
    }

    /// Returns the worst status among all recorded checks, or `Healthy` when
    /// nothing has been recorded.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn overall_health(&self) -> HealthStatus {
        // Relies on the derived ordering `Healthy < Degraded < Unhealthy`.
        self.checks
            .read()
            .unwrap()
            .values()
            .map(|check| check.status)
            .max()
            .unwrap_or(HealthStatus::Healthy)
    }

    /// Returns a copy of every recorded result, in unspecified order.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn get_all_checks(&self) -> Vec<HealthCheck> {
        self.checks.read().unwrap().values().cloned().collect()
    }

    /// Returns a copy of the result recorded under `name`, if any.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn get_check(&self, name: &str) -> Option<HealthCheck> {
        self.checks.read().unwrap().get(name).cloned()
    }
}
impl Default for HealthChecker {
fn default() -> Self {
Self::new()
}
}
/// Settings for [`retry_with_backoff`].
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Number of retries allowed after the first attempt.
    pub max_retries: usize,
    /// Delay before the first retry.
    pub initial_delay: Duration,
    /// Upper bound for the delay between retries.
    pub max_delay: Duration,
    /// Factor the delay is multiplied by after each retry.
    pub multiplier: f64,
}

impl Default for RetryConfig {
    /// Three retries, starting at 100 ms and doubling up to 10 s.
    fn default() -> Self {
        let initial_delay = Duration::from_millis(100);
        let max_delay = Duration::from_secs(10);
        RetryConfig {
            max_retries: 3,
            initial_delay,
            max_delay,
            multiplier: 2.0,
        }
    }
}
/// Calls `f` until it succeeds or the retry budget in `config` is exhausted,
/// sleeping between attempts with a growing delay.
///
/// `f` is invoked once initially and then up to `config.max_retries` more
/// times. The delay starts at `config.initial_delay`, is multiplied by
/// `config.multiplier` after every retry, and never exceeds
/// `config.max_delay`.
///
/// # Errors
///
/// Returns an error carrying the last failure's `Display` text when every
/// attempt has failed.
pub async fn retry_with_backoff<F, T, E>(config: RetryConfig, mut f: F) -> Result<T>
where
    F: FnMut() -> std::pin::Pin<
        Box<dyn std::future::Future<Output = std::result::Result<T, E>> + Send>,
    >,
    E: std::fmt::Display,
{
    /// Scales `current` by `multiplier`, capped at `max_delay`.
    ///
    /// Works on fractional seconds so sub-millisecond delays are not rounded
    /// down to zero. A NaN or overflowing product yields `max_delay`; a
    /// non-positive product yields a zero delay.
    fn next_delay(current: Duration, multiplier: f64, max_delay: Duration) -> Duration {
        let scaled = current.as_secs_f64() * multiplier;
        if scaled.is_nan() {
            return max_delay;
        }
        if scaled <= 0.0 {
            return Duration::ZERO;
        }
        // `try_from_secs_f64` only fails here when the product is infinite or
        // too large for a `Duration`, which is above any cap by definition.
        Duration::try_from_secs_f64(scaled).map_or(max_delay, |d| d.min(max_delay))
    }

    let mut attempt = 0usize;
    // The cap applies to the first sleep as well, not only to later ones.
    let mut delay = config.initial_delay.min(config.max_delay);

    loop {
        match f().await {
            Ok(result) => return Ok(result),
            Err(e) => {
                attempt += 1;
                if attempt > config.max_retries {
                    error!("All {} retry attempts failed", config.max_retries);
                    return Err(ZoeyError::other(format!(
                        "Retry failed after {} attempts: {}",
                        config.max_retries, e
                    )));
                }
                warn!("Attempt {} failed: {}. Retrying in {:?}", attempt, e, delay);
                tokio::time::sleep(delay).await;
                delay = next_delay(delay, config.multiplier, config.max_delay);
            }
        }
    }
}
// Unit tests for the circuit breaker, health checker and retry helper.
#[cfg(test)]
mod tests {
use super::*;
// A successful call leaves a fresh breaker in the closed state.
#[tokio::test]
async fn test_circuit_breaker_closed() {
let cb = CircuitBreaker::new(3, 2, Duration::from_secs(5));
assert_eq!(cb.state(), CircuitState::Closed);
let result = cb.call(async { Ok::<_, String>(42) }).await;
assert!(result.is_ok());
assert_eq!(cb.state(), CircuitState::Closed);
}
// Reaching the failure threshold opens the breaker, which then rejects
// further calls while the 5 s timeout has not elapsed.
#[tokio::test]
async fn test_circuit_breaker_opens() {
let cb = CircuitBreaker::new(3, 2, Duration::from_secs(5));
for _ in 0..3 {
let _ = cb.call(async { Err::<(), _>("error") }).await;
}
assert_eq!(cb.state(), CircuitState::Open);
let result = cb.call(async { Ok::<_, String>(42) }).await;
assert!(result.is_err());
}
// One passing and one failing check: the failing one determines the
// overall health.
#[tokio::test]
async fn test_health_checker() {
let checker = HealthChecker::new();
let status = checker
.check("test_service", async { Ok::<_, String>(()) })
.await;
assert_eq!(status, HealthStatus::Healthy);
let status = checker
.check("failing_service", async { Err::<(), _>("error") })
.await;
assert_eq!(status, HealthStatus::Unhealthy);
assert_eq!(checker.overall_health(), HealthStatus::Unhealthy);
}
// The operation fails on its first invocation and succeeds on the second,
// so one retry is enough.
#[tokio::test]
async fn test_retry_success() {
let config = RetryConfig {
max_retries: 3,
initial_delay: Duration::from_millis(10),
max_delay: Duration::from_millis(100),
multiplier: 2.0,
};
let mut attempts = 0;
let result = retry_with_backoff(config, || {
// Incremented on every invocation; the future below captures the
// updated value by copy.
attempts += 1;
Box::pin(async move {
if attempts < 2 {
Err("not yet")
} else {
Ok(42)
}
})
})
.await;
assert!(result.is_ok());
assert_eq!(result.unwrap(), 42);
}
// An operation that always fails exhausts the retries and yields an error.
#[tokio::test]
async fn test_retry_failure() {
let config = RetryConfig {
max_retries: 2,
initial_delay: Duration::from_millis(10),
max_delay: Duration::from_millis(100),
multiplier: 2.0,
};
let result =
retry_with_backoff(config, || Box::pin(async { Err::<(), _>("always fails") })).await;
assert!(result.is_err());
}
}