use anyhow::Result;
use backoff::{future::retry, ExponentialBackoff, ExponentialBackoffBuilder};
use std::time::Duration;
/// Upper bound, in seconds, on the total time spent retrying; `create_backoff`
/// passes it to the backoff builder as the maximum elapsed time.
const MAX_RETRY_TIMEOUT_SECS: u64 = 120;
/// Reports whether `error` looks like a transient failure that is worth retrying.
///
/// The error is rendered with anyhow's alternate format (`{:#}`), which includes
/// every cause in the chain rather than only the outermost message, so a
/// retryable root cause is still recognised after a caller wraps it with
/// `.context(...)`. The rendered text is lowercased and searched for rate-limit
/// markers, 5xx status codes, and network-level keywords.
///
/// Matching is plain substring search: a message that merely contains one of the
/// markers (for example the digits `500`) is classified as retryable.
pub fn is_retryable_error(error: &anyhow::Error) -> bool {
    // Lowercase markers; the message is lowercased before comparison.
    const RETRYABLE_MARKERS: &[&str] = &[
        "429",
        "rate_limit",
        "rate limit",
        "500",
        "502",
        "503",
        "504",
        "timeout",
        "connection",
        "network",
        "dns",
        "overloaded",
    ];

    // `{:#}` joins the whole cause chain ("outer: inner: root") into one string.
    let error_msg = format!("{:#}", error).to_lowercase();
    RETRYABLE_MARKERS
        .iter()
        .any(|&marker| error_msg.contains(marker))
}
/// Reports whether `error` looks like a failure that retrying cannot fix.
///
/// The error is rendered with anyhow's alternate format (`{:#}`), which includes
/// every cause in the chain rather than only the outermost message, so an
/// authentication, quota, or bad-request cause is still recognised after a caller
/// wraps it with `.context(...)`. The rendered text is lowercased before matching.
///
/// Matching is plain substring search: a message that merely contains one of the
/// markers (for example the digits `400`) is classified as permanent.
pub fn is_permanent_error(error: &anyhow::Error) -> bool {
    // Lowercase markers; the message is lowercased before comparison.
    const PERMANENT_MARKERS: &[&str] = &[
        "401",
        "403",
        "invalid api key",
        "insufficient quota",
        "quota exceeded",
        "invalid request",
        "model not found",
        "400",
    ];

    // `{:#}` joins the whole cause chain ("outer: inner: root") into one string.
    let error_msg = format!("{:#}", error).to_lowercase();
    PERMANENT_MARKERS
        .iter()
        .any(|&marker| error_msg.contains(marker))
}
/// Builds the exponential backoff policy used for retries.
///
/// The initial interval is 500 ms, the multiplier is 2.0, the interval is capped
/// at 30 s, and the maximum elapsed time is `MAX_RETRY_TIMEOUT_SECS` seconds.
/// Every other builder setting is left at its default.
pub fn create_backoff() -> ExponentialBackoff {
    let first_delay = Duration::from_millis(500);
    let delay_cap = Duration::from_secs(30);
    let give_up_after = Duration::from_secs(MAX_RETRY_TIMEOUT_SECS);

    ExponentialBackoffBuilder::new()
        .with_max_elapsed_time(Some(give_up_after))
        .with_max_interval(delay_cap)
        .with_multiplier(2.0)
        .with_initial_interval(first_delay)
        .build()
}
pub async fn retry_async<F, Fut, T>(operation: F) -> Result<T>
where
F: Fn() -> Fut + Send + Sync,
Fut: std::future::Future<Output = Result<T>> + Send,
{
let backoff = create_backoff();
retry(backoff, || async {
match operation().await {
Ok(result) => Ok(result),
Err(error) => {
if is_permanent_error(&error) {
Err(backoff::Error::permanent(error))
} else if is_retryable_error(&error) {
Err(backoff::Error::transient(error))
} else {
Err(backoff::Error::permanent(error))
}
}
}
})
.await
}
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::anyhow;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// Rate-limit, server, and network messages are retryable; auth failures are not.
    #[test]
    fn test_is_retryable_error() {
        let retryable = [
            "429 Rate limit exceeded",
            "500 Internal server error",
            "Connection timeout",
            "Network error",
            "Model overloaded",
        ];
        for &message in &retryable {
            assert!(
                is_retryable_error(&anyhow!("{}", message)),
                "expected retryable: {}",
                message
            );
        }

        let not_retryable = ["401 Unauthorized", "Invalid API key"];
        for &message in &not_retryable {
            assert!(
                !is_retryable_error(&anyhow!("{}", message)),
                "expected not retryable: {}",
                message
            );
        }
    }

    /// Auth, quota, and bad-request messages are permanent; 429 and 500 are not.
    #[test]
    fn test_is_permanent_error() {
        let permanent = [
            "401 Unauthorized",
            "Invalid API key",
            "Insufficient quota",
            "400 Bad request",
        ];
        for &message in &permanent {
            assert!(
                is_permanent_error(&anyhow!("{}", message)),
                "expected permanent: {}",
                message
            );
        }

        let not_permanent = ["429 Rate limit", "500 Server error"];
        for &message in &not_permanent {
            assert!(
                !is_permanent_error(&anyhow!("{}", message)),
                "expected not permanent: {}",
                message
            );
        }
    }

    /// Two rate-limit failures followed by a success yield the success after three calls.
    #[tokio::test]
    async fn test_retry_success() {
        let attempts = Arc::new(AtomicUsize::new(0));

        let outcome = retry_async(|| {
            let counter = Arc::clone(&attempts);
            async move {
                // `fetch_add` returns the previous value; +1 gives this attempt's number.
                let attempt_number = counter.fetch_add(1, Ordering::SeqCst) + 1;
                if attempt_number < 3 {
                    Err(anyhow!("429 Rate limit"))
                } else {
                    Ok("success".to_string())
                }
            }
        })
        .await;

        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap(), "success");
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    /// A permanent error is returned after a single attempt.
    #[tokio::test]
    async fn test_retry_permanent_error() {
        let attempts = Arc::new(AtomicUsize::new(0));

        let outcome: Result<String, _> = retry_async(|| {
            let counter = Arc::clone(&attempts);
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Err(anyhow!("401 Unauthorized"))
            }
        })
        .await;

        assert!(outcome.is_err());
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
    }
}