use std::time::Duration;
use std::future::Future;
use tokio::time::sleep;
use tracing::{warn, debug};
/// Settings that control how `with_retry` repeats a failing operation.
#[derive(Debug, Clone)]
pub struct RetryConfig {
/// Number of failed calls after which `with_retry` gives up and returns the
/// last error. The operation is always called at least once.
pub max_attempts: u32,
/// Delay before the first retry.
pub initial_delay: Duration,
/// Cap applied to each delay computed for subsequent retries.
pub max_delay: Duration,
/// Factor the previous delay is multiplied by to obtain the next one.
pub exponential_base: f32,
/// When `true`, each computed delay is additionally scaled by a random
/// factor drawn from `0.5..1.5`.
pub jitter: bool,
}
impl Default for RetryConfig {
fn default() -> Self {
Self {
max_attempts: 3,
initial_delay: Duration::from_secs(1),
max_delay: Duration::from_secs(60),
exponential_base: 2.0,
jitter: true,
}
}
}
/// Runs `f` repeatedly until it returns `Ok` or `config.max_attempts` calls have failed.
///
/// Every `Err` is treated as retryable; this function does not consult
/// [`Retryable`]. Between attempts it sleeps for the current backoff delay,
/// which starts at `config.initial_delay` (clamped to `config.max_delay`) and
/// is advanced by `calculate_next_delay` after each sleep.
///
/// `operation_name` is only used to label the log messages.
///
/// # Errors
///
/// Returns the error produced by the final attempt once `config.max_attempts`
/// is reached. The operation is always invoked at least once, even when
/// `max_attempts` is `0`.
pub async fn with_retry<F, Fut, T, E>(
    config: &RetryConfig,
    operation_name: &str,
    mut f: F,
) -> Result<T, E>
where
    F: FnMut() -> Fut,
    Fut: Future<Output = Result<T, E>>,
    E: std::fmt::Display,
{
    let mut attempt: u32 = 0;
    // `max_delay` is the ceiling for every pause, including the very first
    // one; without the clamp an `initial_delay` larger than `max_delay` would
    // be slept in full before the cap ever applied.
    let mut delay = config.initial_delay.min(config.max_delay);

    loop {
        attempt += 1;

        let err = match f().await {
            Ok(result) => {
                // Only worth a log line when at least one retry was needed.
                if attempt > 1 {
                    debug!(
                        "Operation '{}' succeeded after {} attempts",
                        operation_name, attempt
                    );
                }
                return Ok(result);
            }
            Err(err) => err,
        };

        if attempt >= config.max_attempts {
            warn!(
                "Operation '{}' failed after {} attempts: {}",
                operation_name, attempt, err
            );
            return Err(err);
        }

        warn!(
            "Operation '{}' failed (attempt {}/{}): {}. Retrying in {:?}",
            operation_name, attempt, config.max_attempts, err, delay
        );
        sleep(delay).await;
        delay = calculate_next_delay(delay, config);
    }
}
/// Computes the delay to use after `current`, never exceeding `config.max_delay`.
///
/// The result is `current * config.exponential_base`, additionally scaled by a
/// random factor in `0.5..1.5` when `config.jitter` is set. The caller in this
/// file passes the previously returned delay back in as `current`.
///
/// The arithmetic is done in `f64` seconds so that a bad configuration cannot
/// panic the way `Duration::mul_f32` would:
/// * a NaN or negative product (e.g. a negative or NaN `exponential_base`)
///   leaves the delay unchanged, still capped at `max_delay`;
/// * a product at or beyond `max_delay`, including infinity, yields `max_delay`.
fn calculate_next_delay(current: Duration, config: &RetryConfig) -> Duration {
    let mut factor = f64::from(config.exponential_base);
    if config.jitter {
        use rand::Rng;
        let mut rng = rand::thread_rng();
        let jitter_factor = rng.gen_range(0.5_f64..1.5_f64);
        factor *= jitter_factor;
    }

    let scaled_secs = current.as_secs_f64() * factor;

    // Invalid growth factor: keep the current delay instead of panicking.
    if scaled_secs.is_nan() || scaled_secs < 0.0 {
        return current.min(config.max_delay);
    }
    // Saturate before converting so an oversized value never reaches
    // `Duration::from_secs_f64`, which panics on overflow.
    if scaled_secs >= config.max_delay.as_secs_f64() {
        return config.max_delay;
    }

    // The trailing `min` guards against float rounding nudging the converted
    // value just past the cap for very large durations.
    Duration::from_secs_f64(scaled_secs).min(config.max_delay)
}
/// Coarse retry decision reported by `Retryable::retry_policy`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RetryPolicy {
/// Returned by the default `Retryable::retry_policy` when `is_retryable()` is `true`.
Always,
/// Returned by the default `Retryable::retry_policy` when `is_retryable()` is `false`.
Never,
/// Not produced by the default `Retryable::retry_policy`.
/// NOTE(review): presumably for implementors that decide per occurrence — confirm intended meaning.
Conditional,
}
/// Implemented by error types that can say whether retrying might help.
pub trait Retryable {
    /// Returns `true` when the failure is worth retrying.
    fn is_retryable(&self) -> bool;

    /// Maps [`Retryable::is_retryable`] onto a [`RetryPolicy`].
    ///
    /// The default yields [`RetryPolicy::Always`] for retryable values and
    /// [`RetryPolicy::Never`] otherwise; implementors may override it.
    fn retry_policy(&self) -> RetryPolicy {
        if !self.is_retryable() {
            return RetryPolicy::Never;
        }
        RetryPolicy::Always
    }
}
impl Retryable for crate::UbiquityError {
    /// Classifies each error variant as retryable or not.
    ///
    /// `MeshError` and `LLMError` are decided by a case-sensitive substring
    /// search of their message; every other variant has a fixed verdict. The
    /// match deliberately has no wildcard arm, so adding a variant to the enum
    /// forces a decision here.
    fn is_retryable(&self) -> bool {
        match self {
            // Message-dependent variants.
            Self::MeshError(msg) => ["connect", "timeout", "refused"]
                .iter()
                .any(|&needle| msg.contains(needle)),
            Self::LLMError(msg) => ["rate limit", "timeout", "temporary"]
                .iter()
                .any(|&needle| msg.contains(needle)),

            // Always retryable.
            Self::SocketError(_)
            | Self::RateLimitError(_)
            | Self::Timeout(_)
            | Self::Network(_)
            | Self::ResourceExhausted(_)
            | Self::CloudExecution(_) => true,

            // Never retryable.
            Self::ConsciousnessTooLow(_)
            | Self::CoherenceLoss(_)
            | Self::AgentNotFound(_)
            | Self::SerializationError(_)
            | Self::ConfigError(_)
            | Self::TaskExecutionError(_)
            | Self::DatabaseError(_)
            | Self::AuthenticationError(_)
            | Self::Other(_)
            | Self::CommandExecution(_)
            | Self::NotFound(_)
            | Self::Internal(_)
            | Self::Serialization(_)
            | Self::Configuration(_) => false,
        }
    }
}
/// Chainable builder that assembles a `RetryConfig` and then runs an
/// operation with it via `RetryBuilder::run`.
pub struct RetryBuilder {
// Configuration being assembled; starts from `RetryConfig::default()` in `new`.
config: RetryConfig,
}
impl RetryBuilder {
    /// Creates a builder seeded with [`RetryConfig::default`].
    #[must_use]
    pub fn new() -> Self {
        Self {
            config: RetryConfig::default(),
        }
    }

    /// Sets [`RetryConfig::max_attempts`].
    #[must_use]
    pub fn max_attempts(mut self, attempts: u32) -> Self {
        self.config.max_attempts = attempts;
        self
    }

    /// Sets [`RetryConfig::initial_delay`].
    #[must_use]
    pub fn initial_delay(mut self, delay: Duration) -> Self {
        self.config.initial_delay = delay;
        self
    }

    /// Sets [`RetryConfig::max_delay`].
    #[must_use]
    pub fn max_delay(mut self, delay: Duration) -> Self {
        self.config.max_delay = delay;
        self
    }

    /// Sets [`RetryConfig::exponential_base`].
    #[must_use]
    pub fn exponential_base(mut self, base: f32) -> Self {
        self.config.exponential_base = base;
        self
    }

    /// Sets [`RetryConfig::jitter`].
    #[must_use]
    pub fn with_jitter(mut self, jitter: bool) -> Self {
        self.config.jitter = jitter;
        self
    }

    /// Consumes the builder and runs `f` through [`with_retry`] using the
    /// assembled configuration.
    ///
    /// `operation_name` is forwarded unchanged to [`with_retry`].
    ///
    /// # Errors
    ///
    /// Returns whatever error [`with_retry`] returns.
    pub async fn run<F, Fut, T, E>(
        self,
        operation_name: &str,
        f: F,
    ) -> Result<T, E>
    where
        F: FnMut() -> Fut,
        Fut: Future<Output = Result<T, E>>,
        E: std::fmt::Display,
    {
        with_retry(&self.config, operation_name, f).await
    }
}

impl Default for RetryBuilder {
    /// Equivalent to [`RetryBuilder::new`]; provided so the builder works with
    /// `Default`-based APIs such as `..Default::default()` and `unwrap_or_default`.
    fn default() -> Self {
        Self::new()
    }
}
// Unit tests for `with_retry`. Each test counts invocations of the operation
// through a shared atomic counter and checks both the returned value and the
// number of calls made.
#[cfg(test)]
mod tests {
use super::*;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
// An operation that succeeds immediately must be called exactly once.
#[tokio::test]
async fn test_retry_success_first_attempt() {
let config = RetryConfig::default();
let attempts = Arc::new(AtomicU32::new(0));
let attempts_clone = attempts.clone();
let result = with_retry(&config, "test", || async {
attempts_clone.fetch_add(1, Ordering::SeqCst);
Ok::<_, &'static str>("success")
}).await;
assert_eq!(result, Ok("success"));
assert_eq!(attempts.load(Ordering::SeqCst), 1);
}
// Two failures followed by a success: the default three attempts are enough.
#[tokio::test]
async fn test_retry_success_after_failures() {
// Short initial delay keeps the real sleeps between attempts brief.
let config = RetryConfig {
initial_delay: Duration::from_millis(10),
..Default::default()
};
let attempts = Arc::new(AtomicU32::new(0));
let attempts_clone = attempts.clone();
let result = with_retry(&config, "test", || async {
// `fetch_add` returns the previous value, so calls 0 and 1 fail and call 2 succeeds.
let attempt = attempts_clone.fetch_add(1, Ordering::SeqCst);
if attempt < 2 {
Err("temporary failure")
} else {
Ok("success")
}
}).await;
assert_eq!(result, Ok("success"));
assert_eq!(attempts.load(Ordering::SeqCst), 3);
}
// An operation that always fails is called `max_attempts` times and the last
// error is returned.
#[tokio::test]
async fn test_retry_exhausted() {
let config = RetryConfig {
max_attempts: 2,
initial_delay: Duration::from_millis(10),
..Default::default()
};
let attempts = Arc::new(AtomicU32::new(0));
let attempts_clone = attempts.clone();
let result = with_retry(&config, "test", || async {
attempts_clone.fetch_add(1, Ordering::SeqCst);
Err::<(), _>("persistent failure")
}).await;
assert_eq!(result, Err("persistent failure"));
assert_eq!(attempts.load(Ordering::SeqCst), 2);
}
}