use super::{ResilienceError, ResilienceResult};
use serde::{Deserialize, Serialize};
use std::time::Duration;
/// Describes how an operation is recovered when it fails.
///
/// Pairs a [`RecoveryStrategy`] with an optional [`RetryConfig`]; `RecoveryPolicy::execute`
/// only consults the retry configuration when the strategy is [`RecoveryStrategy::Retry`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct RecoveryPolicy {
    /// Retry settings for the `Retry` strategy; `None` means a single attempt.
    pub retry: Option<RetryConfig>,
    /// Which recovery approach to apply.
    pub strategy: RecoveryStrategy,
    /// Whether recoveries should be logged (not read by `RecoveryPolicy::execute` itself).
    pub log_recoveries: bool,
}
impl Default for RecoveryPolicy {
fn default() -> Self {
Self {
retry: Some(RetryConfig::default()),
strategy: RecoveryStrategy::Retry,
log_recoveries: true,
}
}
}
impl RecoveryPolicy {
pub fn production() -> Self {
Self {
retry: Some(RetryConfig::exponential(3, Duration::from_millis(100))),
strategy: RecoveryStrategy::Retry,
log_recoveries: true,
}
}
pub fn development() -> Self {
Self {
retry: Some(RetryConfig::fixed(2, Duration::from_millis(50))),
strategy: RecoveryStrategy::Retry,
log_recoveries: true,
}
}
pub fn none() -> Self {
Self {
retry: None,
strategy: RecoveryStrategy::FailFast,
log_recoveries: false,
}
}
pub fn with_retry(mut self, config: RetryConfig) -> Self {
self.retry = Some(config);
self
}
pub fn with_strategy(mut self, strategy: RecoveryStrategy) -> Self {
self.strategy = strategy;
self
}
pub async fn execute<F, Fut, T, E>(&self, f: F) -> ResilienceResult<T>
where
F: Fn() -> Fut,
Fut: std::future::Future<Output = Result<T, E>>,
E: Into<crate::error::KernelError> + std::fmt::Debug,
{
match self.strategy {
RecoveryStrategy::FailFast => f()
.await
.map_err(|e| ResilienceError::KernelError(e.into())),
RecoveryStrategy::Retry => {
if let Some(ref retry) = self.retry {
retry.execute(f).await
} else {
f().await
.map_err(|e| ResilienceError::KernelError(e.into()))
}
}
RecoveryStrategy::Skip => {
f().await
.map_err(|e| ResilienceError::KernelError(e.into()))
}
}
}
}
/// How a failed operation is handled by `RecoveryPolicy::execute`.
///
/// Serialized in `snake_case`: `"fail_fast"`, `"retry"`, `"skip"`.
#[derive(Clone, Copy, Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum RecoveryStrategy {
    /// Attempt the operation once and surface its error.
    FailFast,
    /// Retry according to the policy's `RetryConfig` (the default strategy).
    #[default]
    Retry,
    /// Currently handled exactly like `FailFast` by `RecoveryPolicy::execute`.
    Skip,
}
/// Parameters controlling how many times, and how far apart, an operation is retried.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct RetryConfig {
    /// Retries allowed after the first attempt (at most `max_retries + 1` attempts total).
    pub max_retries: u32,
    /// Base delay fed into the backoff curve.
    pub initial_delay: Duration,
    /// Upper bound on the backoff delay, applied before jitter.
    pub max_delay: Duration,
    /// How the delay grows from one attempt to the next.
    pub backoff: BackoffStrategy,
    /// Random spread as a fraction of the capped delay (the `jitter` builder clamps to `[0, 1]`).
    pub jitter: f64,
    /// When `false`, no failure is retried.
    pub retry_all_errors: bool,
}
impl Default for RetryConfig {
fn default() -> Self {
Self {
max_retries: 3,
initial_delay: Duration::from_millis(100),
max_delay: Duration::from_secs(10),
backoff: BackoffStrategy::Exponential { factor: 2.0 },
jitter: 0.1,
retry_all_errors: true,
}
}
}
impl RetryConfig {
pub fn exponential(max_retries: u32, initial_delay: Duration) -> Self {
Self {
max_retries,
initial_delay,
backoff: BackoffStrategy::Exponential { factor: 2.0 },
..Default::default()
}
}
pub fn fixed(max_retries: u32, delay: Duration) -> Self {
Self {
max_retries,
initial_delay: delay,
backoff: BackoffStrategy::Fixed,
..Default::default()
}
}
pub fn linear(max_retries: u32, initial_delay: Duration) -> Self {
Self {
max_retries,
initial_delay,
backoff: BackoffStrategy::Linear {
increment: initial_delay,
},
..Default::default()
}
}
pub fn max_retries(mut self, max: u32) -> Self {
self.max_retries = max;
self
}
pub fn initial_delay(mut self, delay: Duration) -> Self {
self.initial_delay = delay;
self
}
pub fn max_delay(mut self, delay: Duration) -> Self {
self.max_delay = delay;
self
}
pub fn jitter(mut self, jitter: f64) -> Self {
self.jitter = jitter.clamp(0.0, 1.0);
self
}
pub fn delay_for_attempt(&self, attempt: u32) -> Duration {
let base_delay = match self.backoff {
BackoffStrategy::Fixed => self.initial_delay,
BackoffStrategy::Linear { increment } => self.initial_delay + increment * attempt,
BackoffStrategy::Exponential { factor } => {
let multiplier = factor.powi(attempt as i32);
Duration::from_secs_f64(self.initial_delay.as_secs_f64() * multiplier)
}
};
let capped = base_delay.min(self.max_delay);
if self.jitter > 0.0 {
let jitter_range = capped.as_secs_f64() * self.jitter;
let jitter_amount = rand::random::<f64>() * jitter_range * 2.0 - jitter_range;
Duration::from_secs_f64((capped.as_secs_f64() + jitter_amount).max(0.0))
} else {
capped
}
}
pub async fn execute<F, Fut, T, E>(&self, f: F) -> ResilienceResult<T>
where
F: Fn() -> Fut,
Fut: std::future::Future<Output = Result<T, E>>,
E: Into<crate::error::KernelError> + std::fmt::Debug,
{
let mut last_error = None;
for attempt in 0..=self.max_retries {
match f().await {
Ok(result) => {
if attempt > 0 {
tracing::info!(attempt = attempt, "Operation succeeded after retry");
}
return Ok(result);
}
Err(e) => {
let kernel_error: crate::error::KernelError = e.into();
if !self.retry_all_errors || attempt >= self.max_retries {
tracing::warn!(
attempt = attempt,
error = ?kernel_error,
"Operation failed, no more retries"
);
return Err(ResilienceError::MaxRetriesExceeded {
retries: self.max_retries,
});
}
let delay = self.delay_for_attempt(attempt);
tracing::debug!(
attempt = attempt,
delay = ?delay,
error = ?kernel_error,
"Operation failed, retrying"
);
tokio::time::sleep(delay).await;
last_error = Some(kernel_error);
}
}
}
Err(last_error.map(ResilienceError::KernelError).unwrap_or(
ResilienceError::MaxRetriesExceeded {
retries: self.max_retries,
},
))
}
}
/// Shape of the delay curve between retries.
///
/// Serialized as an internally tagged object whose `type` is the `snake_case` variant name,
/// e.g. `{"type": "exponential", "factor": 2.0}`. For attempt index `n` (0-based), the
/// pre-cap delays are described per variant below.
#[derive(Clone, Copy, Debug, Deserialize, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum BackoffStrategy {
    /// Every retry waits `initial_delay`.
    Fixed,
    /// Attempt `n` waits `initial_delay + increment * n`.
    Linear { increment: Duration },
    /// Attempt `n` waits `initial_delay * factor^n`.
    Exponential { factor: f64 },
}
impl Default for BackoffStrategy {
fn default() -> Self {
Self::Exponential { factor: 2.0 }
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_retry_config_exponential() {
        let config = RetryConfig::exponential(3, Duration::from_millis(100));
        assert_eq!(config.max_retries, 3);
        assert_eq!(config.initial_delay, Duration::from_millis(100));
        let config = RetryConfig::exponential(3, Duration::from_millis(100)).jitter(0.0);
        assert_eq!(config.delay_for_attempt(0), Duration::from_millis(100));
        assert_eq!(config.delay_for_attempt(1), Duration::from_millis(200));
        assert_eq!(config.delay_for_attempt(2), Duration::from_millis(400));
    }

    #[test]
    fn test_retry_config_fixed() {
        let config = RetryConfig::fixed(5, Duration::from_millis(50)).jitter(0.0);
        assert_eq!(config.delay_for_attempt(0), Duration::from_millis(50));
        assert_eq!(config.delay_for_attempt(1), Duration::from_millis(50));
        assert_eq!(config.delay_for_attempt(5), Duration::from_millis(50));
    }

    #[test]
    fn test_retry_config_linear() {
        let config = RetryConfig::linear(3, Duration::from_millis(100)).jitter(0.0);
        assert_eq!(config.delay_for_attempt(0), Duration::from_millis(100));
        assert_eq!(config.delay_for_attempt(1), Duration::from_millis(200));
        assert_eq!(config.delay_for_attempt(2), Duration::from_millis(300));
    }

    #[test]
    fn test_max_delay_cap() {
        let config = RetryConfig::exponential(10, Duration::from_secs(1))
            .max_delay(Duration::from_secs(5))
            .jitter(0.0);
        assert_eq!(config.delay_for_attempt(10), Duration::from_secs(5));
    }

    #[test]
    fn test_retry_config_default() {
        let config = RetryConfig::default();
        assert_eq!(config.max_retries, 3);
        assert_eq!(config.initial_delay, Duration::from_millis(100));
        assert_eq!(config.max_delay, Duration::from_secs(10));
        assert!(config.retry_all_errors);
    }

    #[test]
    fn test_jitter_builder_clamps() {
        assert_eq!(RetryConfig::default().jitter(5.0).jitter, 1.0);
        assert_eq!(RetryConfig::default().jitter(-1.0).jitter, 0.0);
    }

    #[test]
    fn test_jitter_stays_within_range() {
        // 100 ms ± 50 % jitter must land in [50 ms, 150 ms); allow 1 ms of float slack.
        let config = RetryConfig::fixed(1, Duration::from_millis(100)).jitter(0.5);
        for _ in 0..100 {
            let delay = config.delay_for_attempt(0);
            assert!(
                delay >= Duration::from_millis(49) && delay <= Duration::from_millis(151),
                "delay out of range: {:?}",
                delay
            );
        }
    }

    #[test]
    fn test_recovery_policy() {
        let policy = RecoveryPolicy::production();
        assert!(policy.retry.is_some());
        assert_eq!(policy.strategy, RecoveryStrategy::Retry);
    }

    #[test]
    fn test_recovery_policy_development() {
        let policy = RecoveryPolicy::development();
        assert_eq!(policy.strategy, RecoveryStrategy::Retry);
        let retry = policy
            .retry
            .expect("development policy configures retries");
        assert_eq!(retry.max_retries, 2);
        assert_eq!(retry.initial_delay, Duration::from_millis(50));
    }

    #[test]
    fn test_recovery_policy_none() {
        let policy = RecoveryPolicy::none();
        assert!(policy.retry.is_none());
        assert_eq!(policy.strategy, RecoveryStrategy::FailFast);
    }

    #[test]
    fn test_with_strategy_keeps_retry() {
        let policy = RecoveryPolicy::default().with_strategy(RecoveryStrategy::Skip);
        assert_eq!(policy.strategy, RecoveryStrategy::Skip);
        assert!(policy.retry.is_some());
    }
}