use std::sync::atomic::{AtomicU64, AtomicU8, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::Mutex;
use tracing::{debug, error, info, warn};
/// Policy describing how [`Retryer::retry`] bounds and spaces out repeated attempts.
#[derive(Debug, Clone, PartialEq)]
pub enum RetryStrategy {
    /// Run the operation once; a failure is reported straight away with no waiting.
    Immediate,
    /// Wait the same interval before every retry.
    ///
    /// `Retryer::retry` uses `u32::MAX` as the retry limit for this variant.
    Fixed(Duration),
    /// Bounded retries with a wait that starts at `base_interval`.
    Linear {
        /// Wait used before the first retry.
        base_interval: Duration,
        /// Cap applied to the wait as it grows.
        max_interval: Duration,
        /// Number of retries allowed before giving up.
        max_retries: u32,
    },
    /// Bounded retries with a doubling wait and optional random jitter.
    Exponential {
        /// Wait used before the first retry.
        base_interval: Duration,
        /// Cap applied to the wait as it grows.
        max_interval: Duration,
        /// Number of retries allowed before giving up.
        max_retries: u32,
        /// When `true`, each wait is scaled by `0.5 + r` (with `r` a random `f64`)
        /// and then capped at `max_interval`.
        with_jitter: bool,
    },
}
impl Default for RetryStrategy {
fn default() -> Self {
RetryStrategy::Exponential {
base_interval: Duration::from_millis(100),
max_interval: Duration::from_secs(30),
max_retries: 5,
with_jitter: true,
}
}
}
/// Outcome of a retried operation.
#[derive(Debug, Clone)]
pub enum RetryResult<T> {
    /// The operation produced a value.
    Success(T),
    /// Retries ran out. The variant carries a `T`; the `From` conversion to
    /// `Result<T, String>` in this file discards it and reports `"Retry exhausted"`.
    Exhausted(T),
    /// The operation failed; the payload is the error rendered as text.
    Error(String),
}
impl<T> RetryResult<T> {
    /// Returns `true` only for [`RetryResult::Success`].
    pub fn is_success(&self) -> bool {
        self.success().is_some()
    }

    /// Borrows the success value, or returns `None` for every other variant.
    pub fn success(&self) -> Option<&T> {
        if let RetryResult::Success(value) = self {
            Some(value)
        } else {
            None
        }
    }
}
/// An operation that hands out a boxed future resolving to `Result<T, E>`.
///
/// NOTE(review): nothing in the visible part of this file implements or calls this
/// trait; presumably each `execute` call is meant to start a fresh attempt — confirm.
pub trait RetryableOp<T, E> {
    /// Builds the future for one run of the operation. The future is pinned, boxed
    /// and `Send`.
    fn execute(
        &self,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<T, E>> + Send>>;
}
/// Runs operations under a [`RetryStrategy`] and keeps running counters.
///
/// The counters live behind `Arc`, so clones of a `Retryer` share them.
#[derive(Clone)]
pub struct Retryer {
    /// Policy consulted on every failed attempt.
    strategy: RetryStrategy,
    /// Retry counter reported by `stats()`.
    total_retries: Arc<AtomicU64>,
    /// Number of `retry` calls that ended in success.
    success_count: Arc<AtomicU64>,
    /// Number of `retry` calls that ended in failure.
    exhausted_count: Arc<AtomicU64>,
}
impl Retryer {
pub fn new(strategy: RetryStrategy) -> Self {
Self {
strategy,
total_retries: Arc::new(AtomicU64::new(0)),
success_count: Arc::new(AtomicU64::new(0)),
exhausted_count: Arc::new(AtomicU64::new(0)),
}
}
pub fn default() -> Self {
Self::new(RetryStrategy::default())
}
pub async fn retry<F, T, E>(&self, mut op: F) -> RetryResult<T>
where
F: std::future::Future<Output = Result<T, E>>,
E: std::fmt::Display,
{
let (max_retries, base_interval, max_interval, with_jitter) = match &self.strategy {
RetryStrategy::Immediate => {
return match op.await {
Ok(v) => {
self.success_count.fetch_add(1, Ordering::Relaxed);
RetryResult::Success(v)
}
Err(e) => {
self.exhausted_count.fetch_add(1, Ordering::Relaxed);
RetryResult::Error(e.to_string())
}
};
}
RetryStrategy::Fixed(interval) => (u32::MAX, *interval, *interval, false),
RetryStrategy::Linear {
base_interval,
max_interval,
max_retries,
} => (*max_retries, *base_interval, *max_interval, false),
RetryStrategy::Exponential {
base_interval,
max_interval,
max_retries,
with_jitter,
} => (*max_retries, *base_interval, *max_interval, *with_jitter),
};
let mut attempt = 0u32;
let mut interval = base_interval;
loop {
match op.await {
Ok(v) => {
self.success_count.fetch_add(1, Ordering::Relaxed);
return RetryResult::Success(v);
}
Err(e) => {
attempt += 1;
self.total_retries.fetch_add(1, Ordering::Relaxed);
if attempt > max_retries {
self.exhausted_count.fetch_add(1, Ordering::Relaxed);
error!("Retry exhausted after {} attempts: {}", attempt, e);
return RetryResult::Exhausted(e);
}
let delay = if with_jitter {
let jitter: f64 = rand::random();
let factor = 0.5 + jitter;
let mut d = interval.mul_f64(factor);
if d > max_interval {
d = max_interval;
}
d
} else {
interval
};
debug!(
"Retry attempt {}/{} after error: {}, waiting {:?}",
attempt, max_retries, e, delay
);
tokio::time::sleep(delay).await;
interval = (interval * 2).min(max_interval);
}
}
}
pub fn stats(&self) -> RetryStats {
let total = self.total_retries.load(Ordering::Relaxed);
let success = self.success_count.load(Ordering::Relaxed);
let exhausted = self.exhausted_count.load(Ordering::Relaxed);
RetryStats {
total_retries: total,
success_count: success,
exhausted_count: exhausted,
success_rate: if success + exhausted > 0 {
success as f64 / (success + exhausted) as f64
} else {
0.0
},
}
}
}
/// Point-in-time copy of a [`Retryer`]'s counters, as returned by `Retryer::stats`.
#[derive(Debug, Clone)]
pub struct RetryStats {
    /// Value of the retryer's `total_retries` counter.
    pub total_retries: u64,
    /// Number of `retry` calls that ended in success.
    pub success_count: u64,
    /// Number of `retry` calls that ended in failure.
    pub exhausted_count: u64,
    /// `success_count / (success_count + exhausted_count)`, or `0.0` when both are zero.
    pub success_rate: f64,
}
/// The three states of the circuit breaker.
///
/// The enum has only unit variants, so it is `Copy` and `Eq`: a state can be read
/// out of a lock guard by value and compared directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitBreakerState {
    /// Calls are let through.
    Closed,
    /// Calls are rejected until `open_duration` has elapsed.
    Open,
    /// A limited number of probe calls is let through.
    HalfOpen,
}
/// Entry recorded in the circuit breaker's event history.
#[derive(Debug, Clone)]
pub enum CircuitBreakerEvent {
    /// The breaker moved from the first state to the second.
    StateTransition(CircuitBreakerState, CircuitBreakerState),
    /// A guarded operation returned `Ok`.
    Success,
    /// A guarded operation returned `Err`.
    Failure,
    /// A call was refused without running the operation.
    Rejected,
}
/// Tunables for [`CircuitBreaker`].
#[derive(Debug, Clone)]
pub struct CircuitBreakerConfig {
    /// Failures since the last success (while closed) that open the breaker.
    pub failure_threshold: u64,
    /// Successes counted while half-open that close the breaker again.
    pub success_threshold: u64,
    /// How long the breaker stays open before a probe call is allowed.
    pub open_duration: Duration,
    /// Maximum number of calls admitted while half-open.
    pub half_open_max_calls: u64,
}
impl Default for CircuitBreakerConfig {
    /// Opens after 5 failures, closes after 3 half-open successes, stays open for
    /// 60 seconds and admits up to 10 half-open calls.
    fn default() -> Self {
        let open_duration = Duration::from_secs(60);
        Self {
            failure_threshold: 5,
            success_threshold: 3,
            open_duration,
            half_open_max_calls: 10,
        }
    }
}
/// Circuit breaker guarding fallible async operations.
///
/// Every piece of mutable state sits behind an `Arc`, so clones share one breaker.
#[derive(Clone)]
pub struct CircuitBreaker {
    /// Thresholds and timings.
    config: CircuitBreakerConfig,
    /// Current state.
    state: Arc<Mutex<CircuitBreakerState>>,
    /// Instant of the most recent state change.
    state_since: Arc<Mutex<Instant>>,
    /// Failure counter compared against `failure_threshold`.
    failure_count: Arc<AtomicU64>,
    /// Success counter compared against `success_threshold`.
    success_count: Arc<AtomicU64>,
    /// Calls admitted since the breaker last became half-open.
    half_open_calls: Arc<AtomicU64>,
    /// Every invocation of `call`, rejected ones included.
    total_calls: Arc<AtomicU64>,
    /// Invocations refused without running the operation.
    rejected_calls: Arc<AtomicU64>,
    /// Operations that returned `Ok`.
    success_calls: Arc<AtomicU64>,
    /// Operations that returned `Err`.
    failed_calls: Arc<AtomicU64>,
    /// Recent events; `log_event` keeps at most 100 entries.
    events: Arc<Mutex<Vec<CircuitBreakerEvent>>>,
}
impl CircuitBreaker {
    /// Upper bound on the event history; the oldest entry is dropped beyond it.
    const MAX_EVENTS: usize = 100;

    /// Creates a closed breaker with all counters at zero and an empty event history.
    pub fn new(config: CircuitBreakerConfig) -> Self {
        Self {
            config,
            state: Arc::new(Mutex::new(CircuitBreakerState::Closed)),
            state_since: Arc::new(Mutex::new(Instant::now())),
            failure_count: Arc::new(AtomicU64::new(0)),
            success_count: Arc::new(AtomicU64::new(0)),
            half_open_calls: Arc::new(AtomicU64::new(0)),
            total_calls: Arc::new(AtomicU64::new(0)),
            rejected_calls: Arc::new(AtomicU64::new(0)),
            success_calls: Arc::new(AtomicU64::new(0)),
            failed_calls: Arc::new(AtomicU64::new(0)),
            events: Arc::new(Mutex::new(Vec::new())),
        }
    }

    /// Creates a breaker from [`CircuitBreakerConfig::default`].
    pub fn default() -> Self {
        Self::new(CircuitBreakerConfig::default())
    }

    /// Decides whether the current call may run, and reserves a half-open slot for it.
    ///
    /// * `Closed` – always admitted.
    /// * `Open` – admitted only once `open_duration` has elapsed; the breaker then
    ///   becomes `HalfOpen` and this call counts as its first probe.
    /// * `HalfOpen` – admitted while fewer than `half_open_max_calls` were admitted.
    ///
    /// The check and the slot reservation happen under the state lock, so two
    /// concurrent callers cannot both take the last half-open slot.
    async fn check_permission(&self) -> bool {
        let mut state = self.state.lock().await;
        let now = Instant::now();
        match *state {
            CircuitBreakerState::Closed => true,
            CircuitBreakerState::Open => {
                let opened_at = *self.state_since.lock().await;
                if now.duration_since(opened_at) < self.config.open_duration {
                    return false;
                }
                *state = CircuitBreakerState::HalfOpen;
                *self.state_since.lock().await = now;
                // The admitted call is the first half-open probe.
                self.half_open_calls.store(1, Ordering::Relaxed);
                self.failure_count.store(0, Ordering::Relaxed);
                self.success_count.store(0, Ordering::Relaxed);
                self.log_event(CircuitBreakerEvent::StateTransition(
                    CircuitBreakerState::Open,
                    CircuitBreakerState::HalfOpen,
                ))
                .await;
                debug!("Circuit breaker: Open -> HalfOpen");
                true
            }
            CircuitBreakerState::HalfOpen => {
                let admitted = self.half_open_calls.load(Ordering::Relaxed);
                if admitted < self.config.half_open_max_calls {
                    self.half_open_calls.fetch_add(1, Ordering::Relaxed);
                    true
                } else {
                    false
                }
            }
        }
    }

    /// Appends `event` to the history, dropping the oldest entry once the history
    /// holds `MAX_EVENTS` items.
    ///
    /// This awaits the lock: `tokio::sync::Mutex::blocking_lock` must not be used
    /// here because every caller runs inside an async context, where it panics.
    async fn log_event(&self, event: CircuitBreakerEvent) {
        let mut events = self.events.lock().await;
        if events.len() >= Self::MAX_EVENTS {
            events.remove(0);
        }
        events.push(event);
    }

    /// Runs `op` if the breaker admits the call and records the outcome.
    ///
    /// # Errors
    /// * [`CircuitBreakerError::Rejected`] – the breaker refused the call; `op` was
    ///   dropped without being polled.
    /// * [`CircuitBreakerError::Inner`] – `op` ran and returned this error.
    pub async fn call<F, T, E>(&self, op: F) -> Result<T, CircuitBreakerError<E>>
    where
        F: std::future::Future<Output = Result<T, E>>,
        E: std::fmt::Display,
    {
        self.total_calls.fetch_add(1, Ordering::Relaxed);
        if !self.check_permission().await {
            self.rejected_calls.fetch_add(1, Ordering::Relaxed);
            self.log_event(CircuitBreakerEvent::Rejected).await;
            return Err(CircuitBreakerError::Rejected);
        }
        match op.await {
            Ok(value) => {
                self.on_success().await;
                Ok(value)
            }
            Err(err) => {
                self.on_failure().await;
                Err(CircuitBreakerError::Inner(err))
            }
        }
    }

    /// Records a successful operation; closes the breaker once `success_threshold`
    /// successes were counted while half-open.
    async fn on_success(&self) {
        self.success_calls.fetch_add(1, Ordering::Relaxed);
        self.log_event(CircuitBreakerEvent::Success).await;
        let mut state = self.state.lock().await;
        // `fetch_add` returns the previous value.
        let successes = self.success_count.fetch_add(1, Ordering::Relaxed) + 1;
        match *state {
            CircuitBreakerState::Closed => {
                // A success while closed ends the current run of failures.
                self.failure_count.store(0, Ordering::Relaxed);
            }
            CircuitBreakerState::HalfOpen => {
                if successes >= self.config.success_threshold {
                    *state = CircuitBreakerState::Closed;
                    *self.state_since.lock().await = Instant::now();
                    self.log_event(CircuitBreakerEvent::StateTransition(
                        CircuitBreakerState::HalfOpen,
                        CircuitBreakerState::Closed,
                    ))
                    .await;
                    debug!("Circuit breaker: HalfOpen -> Closed");
                }
            }
            CircuitBreakerState::Open => {}
        }
    }

    /// Records a failed operation; opens the breaker when `failure_threshold` is
    /// reached while closed, or on any failure while half-open.
    async fn on_failure(&self) {
        self.failed_calls.fetch_add(1, Ordering::Relaxed);
        self.log_event(CircuitBreakerEvent::Failure).await;
        let mut state = self.state.lock().await;
        // `fetch_add` returns the previous value.
        let failures = self.failure_count.fetch_add(1, Ordering::Relaxed) + 1;
        match *state {
            CircuitBreakerState::Closed => {
                if failures >= self.config.failure_threshold {
                    *state = CircuitBreakerState::Open;
                    *self.state_since.lock().await = Instant::now();
                    self.log_event(CircuitBreakerEvent::StateTransition(
                        CircuitBreakerState::Closed,
                        CircuitBreakerState::Open,
                    ))
                    .await;
                    warn!("Circuit breaker: Closed -> Open ({} failures)", failures);
                }
            }
            CircuitBreakerState::HalfOpen => {
                *state = CircuitBreakerState::Open;
                *self.state_since.lock().await = Instant::now();
                self.log_event(CircuitBreakerEvent::StateTransition(
                    CircuitBreakerState::HalfOpen,
                    CircuitBreakerState::Open,
                ))
                .await;
                warn!("Circuit breaker: HalfOpen -> Open (failure in half-open state)");
            }
            CircuitBreakerState::Open => {}
        }
    }

    /// Returns the current state.
    pub async fn state(&self) -> CircuitBreakerState {
        let guard = self.state.lock().await;
        guard.clone()
    }

    /// Returns how long the breaker has been in its current state.
    pub async fn state_duration(&self) -> Duration {
        let since = *self.state_since.lock().await;
        since.elapsed()
    }

    /// Snapshot of the call counters and of the current failure/success counters.
    pub fn stats(&self) -> CircuitBreakerStats {
        CircuitBreakerStats {
            total_calls: self.total_calls.load(Ordering::Relaxed),
            rejected_calls: self.rejected_calls.load(Ordering::Relaxed),
            success_calls: self.success_calls.load(Ordering::Relaxed),
            failed_calls: self.failed_calls.load(Ordering::Relaxed),
            failure_count: self.failure_count.load(Ordering::Relaxed),
            success_count: self.success_count.load(Ordering::Relaxed),
        }
    }

    /// Forces the breaker back to `Closed`, zeroes the threshold counters and clears
    /// the event history. The lifetime call counters are left untouched.
    pub async fn reset(&self) {
        let mut state = self.state.lock().await;
        *state = CircuitBreakerState::Closed;
        *self.state_since.lock().await = Instant::now();
        self.failure_count.store(0, Ordering::Relaxed);
        self.success_count.store(0, Ordering::Relaxed);
        self.half_open_calls.store(0, Ordering::Relaxed);
        self.events.lock().await.clear();
        debug!("Circuit breaker manually reset");
    }

    /// Forces the breaker to `Open` and restarts the open timer.
    ///
    /// The recorded transition starts from the state the breaker was actually in.
    pub async fn force_open(&self) {
        let mut state = self.state.lock().await;
        let previous = std::mem::replace(&mut *state, CircuitBreakerState::Open);
        *self.state_since.lock().await = Instant::now();
        self.log_event(CircuitBreakerEvent::StateTransition(
            previous,
            CircuitBreakerState::Open,
        ))
        .await;
        debug!("Circuit breaker force opened");
    }
}
/// Error returned by calls that go through the circuit breaker.
#[derive(Debug, Clone)]
pub enum CircuitBreakerError<E> {
    /// The breaker refused the call.
    Rejected,
    /// Failure reported with a text message; displayed as
    /// `"Circuit breaker is open: <message>"`.
    CircuitOpen(String),
    /// An error value of the guarded operation's own error type.
    Inner(E),
}
impl<E: std::fmt::Display> std::fmt::Display for CircuitBreakerError<E> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
CircuitBreakerError::Rejected => write!(f, "Circuit breaker is open, call rejected"),
CircuitBreakerError::CircuitOpen(msg) => {
write!(f, "Circuit breaker is open: {}", msg)
}
CircuitBreakerError::Inner(e) => write!(f, "Inner error: {}", e),
}
}
}
impl<E: std::fmt::Display> std::error::Error for CircuitBreakerError<E> {}
/// Point-in-time copy of a [`CircuitBreaker`]'s counters, as returned by
/// `CircuitBreaker::stats`.
#[derive(Debug, Clone, Default)]
pub struct CircuitBreakerStats {
    /// Every invocation of `call`, rejected ones included.
    pub total_calls: u64,
    /// Invocations refused without running the operation.
    pub rejected_calls: u64,
    /// Operations that returned `Ok`.
    pub success_calls: u64,
    /// Operations that returned `Err`.
    pub failed_calls: u64,
    /// Current value of the failure counter used for the open threshold.
    pub failure_count: u64,
    /// Current value of the success counter used for the close threshold.
    pub success_count: u64,
}
/// Pairs a [`Retryer`] with a [`CircuitBreaker`].
#[derive(Clone)]
pub struct ResilientClient {
    /// Retries the operation.
    retryer: Retryer,
    /// Guards the retried operation.
    circuit_breaker: CircuitBreaker,
}
impl ResilientClient {
    /// Builds a client from a retry strategy and a circuit-breaker configuration.
    pub fn new(retry_strategy: RetryStrategy, circuit_config: CircuitBreakerConfig) -> Self {
        Self {
            retryer: Retryer::new(retry_strategy),
            circuit_breaker: CircuitBreaker::new(circuit_config),
        }
    }

    /// Builds a client from [`RetryStrategy::default`] and
    /// [`CircuitBreakerConfig::default`].
    pub fn default() -> Self {
        Self::new(RetryStrategy::default(), CircuitBreakerConfig::default())
    }

    /// Runs `op` with retries, the whole retry sequence counting as one call through
    /// the circuit breaker.
    ///
    /// `op` is a factory invoked once per attempt; each invocation must return a
    /// fresh future, because a completed future cannot be awaited again.
    ///
    /// # Errors
    /// The retryer reports failures as text, so the error type is
    /// `CircuitBreakerError<String>`:
    /// * `Rejected` – the breaker refused the call and `op` was never invoked.
    /// * otherwise the breaker's error for a failed operation, carrying the
    ///   retryer's message.
    pub async fn execute<F, Fut, T, E>(&self, op: F) -> Result<T, CircuitBreakerError<String>>
    where
        F: FnMut() -> Fut,
        Fut: std::future::Future<Output = Result<T, E>>,
        E: std::fmt::Display + Send + 'static,
    {
        self.circuit_breaker
            .call(async {
                let outcome: Result<T, String> = self.retryer.retry(op).await.into();
                outcome
            })
            .await
    }

    /// Combined snapshot of the retryer's and the circuit breaker's counters.
    pub fn stats(&self) -> ResilientClientStats {
        let retry_stats = self.retryer.stats();
        let cb_stats = self.circuit_breaker.stats();
        ResilientClientStats {
            retry_total_retries: retry_stats.total_retries,
            retry_success_count: retry_stats.success_count,
            retry_exhausted_count: retry_stats.exhausted_count,
            retry_success_rate: retry_stats.success_rate,
            cb_total_calls: cb_stats.total_calls,
            cb_rejected_calls: cb_stats.rejected_calls,
            cb_success_calls: cb_stats.success_calls,
            cb_failed_calls: cb_stats.failed_calls,
        }
    }

    /// Returns the circuit breaker's current state.
    pub async fn circuit_state(&self) -> CircuitBreakerState {
        self.circuit_breaker.state().await
    }
}
impl<T> From<RetryResult<T>> for Result<T, String> {
fn from(result: RetryResult<T>) -> Self {
match result {
RetryResult::Success(v) => Ok(v),
RetryResult::Exhausted(_) => Err("Retry exhausted".to_string()),
RetryResult::Error(e) => Err(e),
}
}
}
/// Flattened snapshot of a [`ResilientClient`]'s retry and circuit-breaker counters,
/// as returned by `ResilientClient::stats`.
#[derive(Debug, Clone, Default)]
pub struct ResilientClientStats {
    /// Retryer: `total_retries` counter.
    pub retry_total_retries: u64,
    /// Retryer: calls that ended in success.
    pub retry_success_count: u64,
    /// Retryer: calls that ended in failure.
    pub retry_exhausted_count: u64,
    /// Retryer: success rate as reported by `Retryer::stats`.
    pub retry_success_rate: f64,
    /// Circuit breaker: every invocation, rejected ones included.
    pub cb_total_calls: u64,
    /// Circuit breaker: invocations refused without running the operation.
    pub cb_rejected_calls: u64,
    /// Circuit breaker: operations that returned `Ok`.
    pub cb_success_calls: u64,
    /// Circuit breaker: operations that returned `Err`.
    pub cb_failed_calls: u64,
}