use crate::error::{LettaError, LettaResult};
use std::future::Future;
use std::time::Duration;
use tokio::time::sleep;
/// Tuning knobs for the exponential-backoff retry helpers in this module.
///
/// Construct one with [`RetryConfig::new`] (or [`Default`]) and refine it with the
/// `with_*` builder methods.
#[derive(Clone, Debug)]
pub struct RetryConfig {
    /// Total number of times the operation may be invoked, the first call included.
    pub max_attempts: u32,
    /// Delay used before the first retry; later delays grow from this value.
    pub initial_backoff: Duration,
    /// Cap on the exponentially growing delay. Server-supplied `retry_after` hints
    /// are used as-is and are not subject to this cap.
    pub max_backoff: Duration,
    /// Factor the delay is multiplied by for each successive retry.
    pub backoff_multiplier: f64,
    /// When `true`, each computed delay is scaled by a random factor drawn from `0.75..1.25`.
    pub jitter: bool,
}
impl Default for RetryConfig {
fn default() -> Self {
Self {
max_attempts: 3,
initial_backoff: Duration::from_millis(500),
max_backoff: Duration::from_secs(30),
backoff_multiplier: 2.0,
jitter: true,
}
}
}
impl RetryConfig {
    /// Creates a configuration with the [`Default`] settings.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the total number of attempts (the initial call counts as one).
    #[must_use]
    pub fn with_max_attempts(mut self, attempts: u32) -> Self {
        self.max_attempts = attempts;
        self
    }

    /// Sets the delay used before the first retry.
    #[must_use]
    pub fn with_initial_backoff(mut self, duration: Duration) -> Self {
        self.initial_backoff = duration;
        self
    }

    /// Sets the upper bound for computed backoff delays.
    #[must_use]
    pub fn with_max_backoff(mut self, duration: Duration) -> Self {
        self.max_backoff = duration;
        self
    }

    /// Sets the factor the delay grows by on each successive retry.
    #[must_use]
    pub fn with_backoff_multiplier(mut self, multiplier: f64) -> Self {
        self.backoff_multiplier = multiplier;
        self
    }

    /// Enables or disables random jitter on computed delays.
    #[must_use]
    pub fn with_jitter(mut self, jitter: bool) -> Self {
        self.jitter = jitter;
        self
    }

    /// Computes the delay to wait after zero-based attempt number `attempt` failed.
    ///
    /// The delay is `initial_backoff * backoff_multiplier^attempt`, capped at
    /// `max_backoff`. With `jitter` enabled, it is then scaled by a random factor in
    /// `0.75..1.25` and capped again, so `max_backoff` is always a hard upper bound.
    fn calculate_backoff(&self, attempt: u32) -> Duration {
        let max_ms = self.max_backoff.as_millis() as f64;
        // Saturate instead of wrapping: `attempt as i32` would turn very large attempt
        // counts into negative exponents and collapse the delay toward zero.
        let exponent = attempt.min(i32::MAX as u32) as i32;
        let base_ms =
            self.initial_backoff.as_millis() as f64 * self.backoff_multiplier.powi(exponent);
        let mut backoff_ms = base_ms.min(max_ms);
        if self.jitter {
            use rand::Rng;
            let jitter_factor: f64 = rand::thread_rng().gen_range(0.75..1.25);
            // Re-apply the cap: scaling up by jitter must not exceed `max_backoff`.
            backoff_ms = (backoff_ms * jitter_factor).min(max_ms);
        }
        // `as u64` saturates (NaN/negative -> 0), so this conversion cannot panic.
        Duration::from_millis(backoff_ms as u64)
    }
}
/// Classifies errors for the retry helpers in this module.
pub trait Retryable {
    /// Returns `true` when the failure is considered transient, so the operation may be retried.
    fn is_retryable(&self) -> bool;

    /// Delay suggested by the error itself before the next attempt, if any.
    ///
    /// The default implementation returns `None`; `retry_with_config` then falls back
    /// to its computed exponential backoff.
    fn retry_after(&self) -> Option<Duration> {
        Option::None
    }
}
impl Retryable for LettaError {
    /// Rate limits, request timeouts, HTTP transport failures (timeout, connect or
    /// request errors), and API responses with status 408, 429, 500, 502, 503 or 504
    /// are retryable. Every other error is not.
    fn is_retryable(&self) -> bool {
        match self {
            LettaError::RateLimit { .. } | LettaError::RequestTimeout { .. } => true,
            LettaError::Http(source) => {
                source.is_timeout() || source.is_connect() || source.is_request()
            }
            // Note that 501 is excluded from the 5xx range.
            LettaError::Api { status, .. } => matches!(*status, 408 | 429 | 500 | 502..=504),
            _ => false,
        }
    }

    /// For rate-limit errors carrying a `retry_after` value, converts those seconds
    /// into a [`Duration`]. All other errors yield `None`.
    fn retry_after(&self) -> Option<Duration> {
        if let LettaError::RateLimit {
            retry_after: Some(seconds),
            ..
        } = self
        {
            Some(Duration::from_secs(*seconds))
        } else {
            None
        }
    }
}
/// Runs `operation`, retrying transient failures according to `config`.
///
/// The operation is invoked up to `config.max_attempts` times, and always at least
/// once, even when `max_attempts` is 0. Between attempts it sleeps for the error's
/// [`Retryable::retry_after`] hint when present. Otherwise it sleeps for the
/// exponential backoff derived from `config`. Each scheduled retry is reported on
/// stderr.
///
/// # Errors
///
/// Returns the first error whose [`Retryable::is_retryable`] is `false`, or the error
/// from the final attempt once the attempt budget is exhausted.
pub async fn retry_with_config<T, F, Fut>(config: &RetryConfig, mut operation: F) -> LettaResult<T>
where
    F: FnMut() -> Fut,
    Fut: Future<Output = LettaResult<T>>,
{
    // A zero budget used to skip the operation entirely and return an unrelated
    // `Config` error; clamp so the caller's operation always runs at least once.
    let max_attempts = config.max_attempts.max(1);
    let mut attempt: u32 = 0;
    loop {
        let error = match operation().await {
            Ok(value) => return Ok(value),
            Err(error) => error,
        };
        // `attempt + 1` cannot overflow: we only get here while `attempt < max_attempts`.
        if !error.is_retryable() || attempt + 1 >= max_attempts {
            return Err(error);
        }
        let backoff = error
            .retry_after()
            .unwrap_or_else(|| config.calculate_backoff(attempt));
        eprintln!(
            "Retry attempt {} after {:?} due to: {:?}",
            attempt + 1,
            backoff,
            error
        );
        sleep(backoff).await;
        attempt += 1;
    }
}
/// Runs `operation` with the default retry settings ([`RetryConfig::default`]).
///
/// Shorthand for `retry_with_config(&RetryConfig::default(), operation)`.
pub async fn retry<T, F, Fut>(operation: F) -> LettaResult<T>
where
    F: Fn() -> Fut,
    Fut: Future<Output = LettaResult<T>>,
{
    let defaults = RetryConfig::default();
    retry_with_config(&defaults, operation).await
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;
    use url::Url;

    #[test]
    fn test_retry_config_builder() {
        let config = RetryConfig::new()
            .with_max_attempts(5)
            .with_initial_backoff(Duration::from_secs(1))
            .with_max_backoff(Duration::from_secs(60))
            .with_backoff_multiplier(3.0)
            .with_jitter(false);
        assert_eq!(config.max_attempts, 5);
        assert_eq!(config.initial_backoff, Duration::from_secs(1));
        assert_eq!(config.max_backoff, Duration::from_secs(60));
        assert_eq!(config.backoff_multiplier, 3.0);
        assert!(!config.jitter);
    }

    #[test]
    fn test_backoff_calculation() {
        let config = RetryConfig::new()
            .with_initial_backoff(Duration::from_millis(100))
            .with_backoff_multiplier(2.0)
            .with_max_backoff(Duration::from_millis(1000))
            .with_jitter(false);
        assert_eq!(config.calculate_backoff(0), Duration::from_millis(100));
        assert_eq!(config.calculate_backoff(1), Duration::from_millis(200));
        assert_eq!(config.calculate_backoff(2), Duration::from_millis(400));
        assert_eq!(config.calculate_backoff(3), Duration::from_millis(800));
        // Capped by `max_backoff`.
        assert_eq!(config.calculate_backoff(4), Duration::from_millis(1000));
    }

    #[test]
    fn test_error_retryability() {
        let error = LettaError::RateLimit {
            retry_after: Some(60),
        };
        assert!(error.is_retryable());
        assert_eq!(error.retry_after(), Some(Duration::from_secs(60)));

        let error = LettaError::Api {
            status: 503,
            message: "Service unavailable".to_string(),
            code: None,
            body: crate::error::ErrorBody::Text(String::new()),
            url: Some(Url::parse("http://example.com/path").unwrap()),
            method: Some("GET".to_string()),
        };
        assert!(error.is_retryable());

        let error = LettaError::NotFound {
            resource_type: "agent".to_string(),
            id: "123".to_string(),
        };
        assert!(!error.is_retryable());
    }

    #[tokio::test]
    async fn test_retry_success_on_second_attempt() {
        let attempt_count = Arc::new(AtomicU32::new(0));
        let counter = Arc::clone(&attempt_count);
        let result = retry(|| {
            let counter = Arc::clone(&counter);
            async move {
                // First call fails with a retryable error; subsequent calls succeed.
                if counter.fetch_add(1, Ordering::SeqCst) == 0 {
                    Err(LettaError::RequestTimeout { seconds: 60 })
                } else {
                    Ok("Success".to_string())
                }
            }
        })
        .await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "Success");
        assert_eq!(attempt_count.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_retry_non_retryable_error() {
        let attempt_count = Arc::new(AtomicU32::new(0));
        let counter = Arc::clone(&attempt_count);
        let result = retry(|| {
            let counter = Arc::clone(&counter);
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Err::<String, _>(LettaError::Auth {
                    message: "Invalid API key".to_string(),
                })
            }
        })
        .await;
        assert!(result.is_err());
        // Non-retryable errors must not trigger a second call.
        assert_eq!(attempt_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_retry_exhausted_attempts() {
        let config = RetryConfig::new()
            .with_max_attempts(2)
            .with_initial_backoff(Duration::from_millis(10));
        let attempt_count = Arc::new(AtomicU32::new(0));
        let counter = Arc::clone(&attempt_count);
        let result = retry_with_config(&config, || {
            let counter = Arc::clone(&counter);
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Err::<String, _>(LettaError::RequestTimeout { seconds: 60 })
            }
        })
        .await;
        assert!(result.is_err());
        assert_eq!(attempt_count.load(Ordering::SeqCst), 2);
    }
}