use anyhow::{anyhow, Result};
use serde::{Deserialize, Serialize};
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};

pub mod circuit_breaker;
pub mod error;
pub mod fallback;
pub mod retry;
pub mod timeout;

pub use circuit_breaker::{CircuitBreaker, CircuitBreakerConfig, CircuitState};
pub use error::{AIError, AIErrorKind, ErrorContext};
pub use fallback::{FallbackChain, FallbackStrategy};
pub use retry::{ExponentialBackoff, RetryPolicy, RetryStrategy};
pub use timeout::{TimeoutConfig, TimeoutGuard};
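
/// Tunables for every resilience feature, grouped in one serializable struct.
///
/// A minimal construction sketch (the `resilience` crate path is an
/// assumption; adjust the import to match this crate's actual name, and
/// treat the values as illustrative, not recommended):
///
/// ```ignore
/// use resilience::ResilienceConfig; // assumed crate path
///
/// // Start from the defaults and override only what you need.
/// let config = ResilienceConfig {
///     max_retries: 5,
///     request_timeout_secs: 10,
///     ..Default::default()
/// };
/// ```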
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResilienceConfig {
    /// Retry failed operations with exponential backoff.
    pub enable_retry: bool,
    /// Maximum number of retries after the initial attempt.
    pub max_retries: usize,
    /// Delay before the first retry, in milliseconds.
    pub initial_retry_delay_ms: u64,
    /// Upper bound on the backoff delay, in milliseconds.
    pub max_retry_delay_ms: u64,
    /// Reject calls outright while the circuit breaker is open.
    pub enable_circuit_breaker: bool,
    /// Number of failures required to open the circuit.
    pub circuit_breaker_threshold: usize,
    /// How long the circuit stays open before probing again, in seconds.
    pub circuit_breaker_timeout_secs: u64,
    /// Degrade to reduced functionality instead of failing outright.
    pub enable_graceful_degradation: bool,
    /// Overall per-request timeout, in seconds.
    pub request_timeout_secs: u64,
    /// Fall back to alternative strategies when the primary call fails.
    pub enable_fallback: bool,
}

impl Default for ResilienceConfig {
    fn default() -> Self {
        Self {
            enable_retry: true,
            max_retries: 3,
            initial_retry_delay_ms: 1000,
            max_retry_delay_ms: 30000,
            enable_circuit_breaker: true,
            circuit_breaker_threshold: 5,
            circuit_breaker_timeout_secs: 60,
            enable_graceful_degradation: true,
            request_timeout_secs: 30,
            enable_fallback: true,
        }
    }
}
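
/// Coordinates retries, circuit breaking, timeouts, and fallbacks for
/// AI-provider calls, and records metrics about every attempt.
///
/// A minimal usage sketch (the `resilience` crate path is an assumption;
/// adjust the import to match this crate's actual name):
///
/// ```ignore
/// use resilience::{ResilienceConfig, ResilienceManager};
///
/// #[tokio::main]
/// async fn main() -> anyhow::Result<()> {
///     let manager = ResilienceManager::new(ResilienceConfig::default())?;
///     // The closure is retried according to the configured policy.
///     let answer = manager.execute(|| Ok::<_, anyhow::Error>(42)).await?;
///     assert_eq!(answer, 42);
///     Ok(())
/// }
/// ```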
pub struct ResilienceManager {
    config: ResilienceConfig,
    retry_policy: RetryPolicy,
    circuit_breaker: Arc<RwLock<CircuitBreaker>>,
    fallback_chain: FallbackChain,
    timeout_config: TimeoutConfig,
    metrics: Arc<RwLock<ResilienceMetrics>>,
}
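
/// Counters accumulated across all calls routed through a [`ResilienceManager`].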
#[derive(Debug, Clone, Default)]
pub struct ResilienceMetrics {
    pub total_requests: u64,
    pub successful_requests: u64,
    pub failed_requests: u64,
    pub retried_requests: u64,
    pub circuit_breaker_opens: u64,
    pub fallback_invocations: u64,
    pub timeout_errors: u64,
    /// Cumulative running average of request latency, in milliseconds.
    pub avg_response_time_ms: f64,
}

impl ResilienceManager {
    /// Build a manager from `config`, deriving the retry policy, circuit
    /// breaker, and timeout settings from its fields.
    pub fn new(config: ResilienceConfig) -> Result<Self> {
        let retry_policy = RetryPolicy::new(
            RetryStrategy::ExponentialBackoff(ExponentialBackoff {
                initial_delay: Duration::from_millis(config.initial_retry_delay_ms),
                max_delay: Duration::from_millis(config.max_retry_delay_ms),
                multiplier: 2.0,
                jitter: true,
            }),
            config.max_retries,
        );

        let circuit_breaker = Arc::new(RwLock::new(CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: config.circuit_breaker_threshold,
            success_threshold: 2,
            timeout: Duration::from_secs(config.circuit_breaker_timeout_secs),
            half_open_max_calls: 1,
        })));

        let timeout_config = TimeoutConfig {
            request_timeout: Duration::from_secs(config.request_timeout_secs),
            read_timeout: Duration::from_secs(config.request_timeout_secs / 2),
            connect_timeout: Duration::from_secs(10),
        };

        Ok(Self {
            config,
            retry_policy,
            circuit_breaker,
            fallback_chain: FallbackChain::new(),
            timeout_config,
            metrics: Arc::new(RwLock::new(ResilienceMetrics::default())),
        })
    }
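
    /// Run `operation` through the full resilience pipeline: the circuit
    /// breaker is consulted first, then the call is retried per the retry
    /// policy, and the outcome is recorded in both the breaker and the
    /// metrics.
    ///
    /// A hedged sketch of a call site (`manager` built as in the type-level
    /// example above):
    ///
    /// ```ignore
    /// let value = manager
    ///     .execute(|| Ok::<_, anyhow::Error>(2 + 2))
    ///     .await?;
    /// assert_eq!(value, 4);
    /// ```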
    pub async fn execute<F, T>(&self, mut operation: F) -> Result<T>
    where
        F: FnMut() -> Result<T> + Send,
    {
        let start = Instant::now();

        // Count the request up front so breaker-rejected calls are visible too.
        {
            let mut metrics = self
                .metrics
                .write()
                .map_err(|e| anyhow!("Failed to acquire write lock: {}", e))?;
            metrics.total_requests += 1;
        }

        // Fail fast while the circuit is open; the guard is dropped at the end
        // of this block, so the lock is never held across an await point.
        if self.config.enable_circuit_breaker {
            let mut breaker = self
                .circuit_breaker
                .write()
                .map_err(|e| anyhow!("Failed to acquire write lock: {}", e))?;
            if !breaker.can_execute() {
                let mut metrics = self
                    .metrics
                    .write()
                    .map_err(|e| anyhow!("Failed to acquire write lock: {}", e))?;
                metrics.failed_requests += 1;
                return Err(anyhow!("Circuit breaker is open"));
            }
        }

        let result = if self.config.enable_retry {
            self.execute_with_retry(&mut operation).await
        } else {
            operation()
        };

        // Record the outcome with the breaker and fold the latency into the
        // cumulative running average.
        {
            let mut breaker = self
                .circuit_breaker
                .write()
                .map_err(|e| anyhow!("Failed to acquire write lock: {}", e))?;
            let mut metrics = self
                .metrics
                .write()
                .map_err(|e| anyhow!("Failed to acquire write lock: {}", e))?;
            match &result {
                Ok(_) => {
                    breaker.record_success();
                    metrics.successful_requests += 1;
                }
                Err(_) => {
                    breaker.record_failure();
                    metrics.failed_requests += 1;
                }
            }
            let elapsed = start.elapsed().as_millis() as f64;
            metrics.avg_response_time_ms =
                (metrics.avg_response_time_ms * (metrics.total_requests - 1) as f64 + elapsed)
                    / metrics.total_requests as f64;
        }

        result
    }
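
    /// Retry `operation` until it succeeds or the policy's attempt budget is
    /// exhausted, sleeping for the strategy-computed backoff between tries.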
    async fn execute_with_retry<F, T>(&self, operation: &mut F) -> Result<T>
    where
        F: FnMut() -> Result<T> + Send,
    {
        let mut attempt = 0;
        loop {
            match operation() {
                Ok(result) => return Ok(result),
                Err(_) if attempt < self.retry_policy.max_attempts => {
                    attempt += 1;
                    {
                        let mut metrics = self
                            .metrics
                            .write()
                            .map_err(|err| anyhow!("Failed to acquire write lock: {}", err))?;
                        metrics.retried_requests += 1;
                    }
                    let delay = self.retry_policy.strategy.calculate_delay(attempt);
                    tokio::time::sleep(delay).await;
                }
                Err(e) => return Err(e),
            }
        }
    }
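
    /// Await `operation`, aborting with an error if it outlives the
    /// configured request timeout.
    ///
    /// A hedged sketch (`fetch_remote` is a hypothetical async helper
    /// returning `anyhow::Result<_>`):
    ///
    /// ```ignore
    /// let data = manager
    ///     .execute_with_timeout(async { fetch_remote().await })
    ///     .await?;
    /// ```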
    pub async fn execute_with_timeout<F, T>(&self, operation: F) -> Result<T>
    where
        F: std::future::Future<Output = Result<T>> + Send,
    {
        match tokio::time::timeout(self.timeout_config.request_timeout, operation).await {
            Ok(result) => result,
            Err(_) => {
                let mut metrics = self
                    .metrics
                    .write()
                    .map_err(|e| anyhow!("Failed to acquire write lock: {}", e))?;
                metrics.timeout_errors += 1;
                Err(anyhow!("Operation timed out"))
            }
        }
    }
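
    /// Return a point-in-time snapshot of the accumulated metrics.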
    pub fn metrics(&self) -> Result<ResilienceMetrics> {
        let metrics = self
            .metrics
            .read()
            .map_err(|e| anyhow!("Failed to acquire read lock: {}", e))?;
        Ok(metrics.clone())
    }
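
    /// Report the circuit breaker's current state.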
    pub fn circuit_state(&self) -> Result<CircuitState> {
        let breaker = self
            .circuit_breaker
            .read()
            .map_err(|e| anyhow!("Failed to acquire read lock: {}", e))?;
        Ok(breaker.state())
    }
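
    /// Reset the circuit breaker and zero out all metrics.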
    pub fn reset(&self) -> Result<()> {
        let mut breaker = self
            .circuit_breaker
            .write()
            .map_err(|e| anyhow!("Failed to acquire write lock: {}", e))?;
        breaker.reset();

        let mut metrics = self
            .metrics
            .write()
            .map_err(|e| anyhow!("Failed to acquire write lock: {}", e))?;
        *metrics = ResilienceMetrics::default();
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_resilience_manager_success() {
        let manager = ResilienceManager::new(ResilienceConfig::default()).expect("should succeed");
        let result = manager.execute(|| Ok::<_, anyhow::Error>(42)).await;
        assert!(result.is_ok());
        assert_eq!(result.expect("should succeed"), 42);

        let metrics = manager.metrics().expect("should succeed");
        assert_eq!(metrics.total_requests, 1);
        assert_eq!(metrics.successful_requests, 1);
    }

    #[tokio::test]
    async fn test_resilience_manager_retry() {
        let manager = ResilienceManager::new(ResilienceConfig {
            max_retries: 2,
            initial_retry_delay_ms: 10,
            ..Default::default()
        })
        .expect("should succeed");

        let mut call_count = 0;
        let result = manager
            .execute(|| {
                call_count += 1;
                if call_count < 2 {
                    Err(anyhow!("Temporary failure"))
                } else {
                    Ok(42)
                }
            })
            .await;
        assert!(result.is_ok());
        // The first call fails and is retried once before succeeding.
        assert_eq!(call_count, 2);
    }

    #[tokio::test]
    async fn test_circuit_breaker_opens() {
        let manager = ResilienceManager::new(ResilienceConfig {
            circuit_breaker_threshold: 2,
            enable_retry: false,
            ..Default::default()
        })
        .expect("should succeed");

        // Two failures reach the threshold and open the circuit.
        for _ in 0..2 {
            let _ = manager.execute(|| Err::<(), _>(anyhow!("Error"))).await;
        }
        let state = manager.circuit_state().expect("should succeed");
        assert_eq!(state, CircuitState::Open);
    }
}