use std::future::Future;
use std::pin::Pin;
use std::time::Duration;
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetryConfig {
pub max_retries: u32,
pub base_delay_ms: u64,
pub max_delay_ms: u64,
pub jitter_factor: f64,
}
impl Default for RetryConfig {
fn default() -> Self {
Self {
max_retries: 3,
base_delay_ms: 500,
max_delay_ms: 30_000,
jitter_factor: 0.5,
}
}
}
pub struct RetryManager {
config: RetryConfig,
}
impl RetryManager {
#[must_use]
pub fn new(config: RetryConfig) -> Self {
Self { config }
}
pub async fn with_retry<F, Fut, T>(&self, mut f: F) -> anyhow::Result<T>
where
F: FnMut() -> Fut,
Fut: Future<Output = anyhow::Result<T>>,
{
let mut attempt = 0u32;
loop {
match f().await {
Ok(value) => return Ok(value),
Err(e) => {
let is_retryable = e
.downcast_ref::<crate::error::HooshError>()
.map(|he| he.is_retryable())
.unwrap_or(false);
if !is_retryable || attempt >= self.config.max_retries {
if attempt > 0 {
tracing::warn!(
attempt,
max_retries = self.config.max_retries,
"retry exhausted or permanent error"
);
}
return Err(e);
}
let delay = self.compute_delay(attempt);
tracing::warn!(
attempt = attempt + 1,
max_retries = self.config.max_retries,
delay_ms = delay.as_millis() as u64,
error = %e,
"retrying after transient error"
);
tokio::time::sleep(delay).await;
attempt += 1;
}
}
}
}
#[must_use]
fn compute_delay(&self, attempt: u32) -> Duration {
let base = self.config.base_delay_ms as f64;
let exp_delay = base * 2.0f64.powi(attempt as i32);
let capped = exp_delay.min(self.config.max_delay_ms as f64);
let jitter = if self.config.jitter_factor > 0.0 {
let j: f64 = rand::random::<f64>() * self.config.jitter_factor;
1.0 + j
} else {
1.0
};
Duration::from_millis((capped * jitter) as u64)
}
#[must_use]
#[inline]
pub fn is_enabled(&self) -> bool {
self.config.max_retries > 0
}
}
pub async fn retry_with<F, T>(manager: &RetryManager, f: F) -> anyhow::Result<T>
where
F: FnMut() -> Pin<Box<dyn Future<Output = anyhow::Result<T>> + Send>>,
{
manager.with_retry(f).await
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::atomic::{AtomicU32, Ordering};
#[test]
fn default_config() {
let cfg = RetryConfig::default();
assert_eq!(cfg.max_retries, 3);
assert_eq!(cfg.base_delay_ms, 500);
assert_eq!(cfg.max_delay_ms, 30_000);
assert!((cfg.jitter_factor - 0.5).abs() < f64::EPSILON);
}
#[test]
fn delay_exponential() {
let manager = RetryManager::new(RetryConfig {
jitter_factor: 0.0, ..Default::default()
});
let d0 = manager.compute_delay(0);
let d1 = manager.compute_delay(1);
let d2 = manager.compute_delay(2);
assert_eq!(d0.as_millis(), 500);
assert_eq!(d1.as_millis(), 1000);
assert_eq!(d2.as_millis(), 2000);
}
#[test]
fn delay_capped_at_max() {
let manager = RetryManager::new(RetryConfig {
base_delay_ms: 10_000,
max_delay_ms: 15_000,
jitter_factor: 0.0,
..Default::default()
});
let d = manager.compute_delay(2);
assert_eq!(d.as_millis(), 15_000);
}
#[test]
fn is_enabled() {
let enabled = RetryManager::new(RetryConfig::default());
assert!(enabled.is_enabled());
let disabled = RetryManager::new(RetryConfig {
max_retries: 0,
..Default::default()
});
assert!(!disabled.is_enabled());
}
#[tokio::test]
async fn retry_succeeds_first_try() {
let manager = RetryManager::new(RetryConfig::default());
let result = manager
.with_retry(|| async { Ok::<_, anyhow::Error>(42) })
.await;
assert_eq!(result.unwrap(), 42);
}
#[tokio::test]
async fn retry_succeeds_after_failures() {
let manager = RetryManager::new(RetryConfig {
max_retries: 3,
base_delay_ms: 1, max_delay_ms: 10,
jitter_factor: 0.0,
});
let attempts = AtomicU32::new(0);
let result = manager
.with_retry(|| {
let count = attempts.fetch_add(1, Ordering::SeqCst);
async move {
if count < 2 {
Err(anyhow::anyhow!(crate::error::HooshError::Timeout(1000)))
} else {
Ok(99)
}
}
})
.await;
assert_eq!(result.unwrap(), 99);
assert_eq!(attempts.load(Ordering::SeqCst), 3); }
#[tokio::test]
async fn retry_stops_on_permanent_error() {
let manager = RetryManager::new(RetryConfig {
max_retries: 3,
base_delay_ms: 1,
max_delay_ms: 10,
jitter_factor: 0.0,
});
let attempts = AtomicU32::new(0);
let result: anyhow::Result<i32> = manager
.with_retry(|| {
attempts.fetch_add(1, Ordering::SeqCst);
async {
Err(anyhow::anyhow!(crate::error::HooshError::ModelNotFound(
"no".into()
)))
}
})
.await;
assert!(result.is_err());
assert_eq!(attempts.load(Ordering::SeqCst), 1); }
#[tokio::test]
async fn retry_exhausts_max_retries() {
let manager = RetryManager::new(RetryConfig {
max_retries: 2,
base_delay_ms: 1,
max_delay_ms: 10,
jitter_factor: 0.0,
});
let attempts = AtomicU32::new(0);
let result: anyhow::Result<i32> = manager
.with_retry(|| {
attempts.fetch_add(1, Ordering::SeqCst);
async { Err(anyhow::anyhow!(crate::error::HooshError::Timeout(1000))) }
})
.await;
assert!(result.is_err());
assert_eq!(attempts.load(Ordering::SeqCst), 3); }
#[test]
fn compute_delay_with_jitter() {
let manager = RetryManager::new(RetryConfig {
base_delay_ms: 100,
max_delay_ms: 10_000,
jitter_factor: 0.5,
..Default::default()
});
let delay = manager.compute_delay(0);
assert!(delay.as_millis() >= 100);
assert!(delay.as_millis() <= 150);
}
#[test]
fn compute_delay_zero_jitter() {
let manager = RetryManager::new(RetryConfig {
base_delay_ms: 200,
max_delay_ms: 10_000,
jitter_factor: 0.0,
..Default::default()
});
let d0 = manager.compute_delay(0);
let d1 = manager.compute_delay(1);
assert_eq!(d0.as_millis(), 200);
assert_eq!(d1.as_millis(), 400);
}
#[test]
fn compute_delay_large_attempt_capped() {
let manager = RetryManager::new(RetryConfig {
base_delay_ms: 500,
max_delay_ms: 5_000,
jitter_factor: 0.0,
..Default::default()
});
let d = manager.compute_delay(10);
assert_eq!(d.as_millis(), 5000);
}
#[tokio::test]
async fn retry_with_convenience_fn() {
let manager = RetryManager::new(RetryConfig {
max_retries: 1,
base_delay_ms: 1,
max_delay_ms: 10,
jitter_factor: 0.0,
});
let result = retry_with(&manager, || Box::pin(async { Ok::<_, anyhow::Error>(42) })).await;
assert_eq!(result.unwrap(), 42);
}
#[tokio::test]
async fn retry_unknown_error_not_retried() {
let manager = RetryManager::new(RetryConfig {
max_retries: 3,
base_delay_ms: 1,
max_delay_ms: 10,
jitter_factor: 0.0,
});
let attempts = AtomicU32::new(0);
let result: anyhow::Result<i32> = manager
.with_retry(|| {
attempts.fetch_add(1, Ordering::SeqCst);
async { Err(anyhow::anyhow!("generic error, not HooshError")) }
})
.await;
assert!(result.is_err());
assert_eq!(attempts.load(Ordering::SeqCst), 1); }
#[test]
fn retry_config_serde_roundtrip() {
let cfg = RetryConfig {
max_retries: 5,
base_delay_ms: 1000,
max_delay_ms: 60_000,
jitter_factor: 0.3,
};
let json = serde_json::to_string(&cfg).unwrap();
let back: RetryConfig = serde_json::from_str(&json).unwrap();
assert_eq!(back.max_retries, 5);
assert_eq!(back.base_delay_ms, 1000);
assert_eq!(back.max_delay_ms, 60_000);
assert!((back.jitter_factor - 0.3).abs() < f64::EPSILON);
}
}