use backoff::{ExponentialBackoff, ExponentialBackoffBuilder};
use std::time::Duration;
use crate::error::LlmError;
use crate::types::ProviderType;
/// Runs fallible async LLM operations, retrying transient failures with exponential backoff.
///
/// The stored policy is cloned at the start of every `execute` call, so a single executor
/// can be shared and reused for many independent operations.
#[derive(Debug, Clone)]
pub struct BackoffRetryExecutor {
    /// Policy template; a fresh clone drives each `execute` call.
    backoff: ExponentialBackoff,
}
impl BackoffRetryExecutor {
pub fn new() -> Self {
Self {
backoff: ExponentialBackoff::default(),
}
}
pub fn with_backoff(backoff: ExponentialBackoff) -> Self {
Self { backoff }
}
pub fn for_provider(provider: &ProviderType) -> Self {
let backoff = match provider {
ProviderType::OpenAi => Self::openai_backoff(),
ProviderType::Anthropic => Self::anthropic_backoff(),
ProviderType::Gemini => Self::google_backoff(),
ProviderType::Ollama => Self::ollama_backoff(),
ProviderType::XAI => Self::openai_backoff(), ProviderType::Groq => Self::openai_backoff(), ProviderType::Custom(_) => Self::default_backoff(),
};
Self { backoff }
}
pub async fn execute<F, Fut, T>(&self, operation: F) -> Result<T, LlmError>
where
F: Fn() -> Fut + Send + Sync,
Fut: std::future::Future<Output = Result<T, LlmError>> + Send,
T: Send,
{
backoff::future::retry(self.backoff.clone(), || async {
match operation().await {
Ok(result) => Ok(result),
Err(error) => {
if Self::is_retryable(&error) {
Err(backoff::Error::Transient {
err: error,
retry_after: None,
})
} else {
Err(backoff::Error::Permanent(error))
}
}
}
})
.await
}
fn is_retryable(error: &LlmError) -> bool {
match error {
LlmError::ApiError { code, .. } => {
matches!(*code, 429 | 500..=599)
}
LlmError::RateLimitError(_) => true,
LlmError::TimeoutError(_) => true,
LlmError::ConnectionError(_) => true,
LlmError::HttpError(_) => true,
_ => false,
}
}
fn openai_backoff() -> ExponentialBackoff {
ExponentialBackoffBuilder::new()
.with_initial_interval(Duration::from_millis(1000))
.with_max_interval(Duration::from_secs(60))
.with_multiplier(2.0)
.with_max_elapsed_time(Some(Duration::from_secs(300))) .build()
}
fn anthropic_backoff() -> ExponentialBackoff {
ExponentialBackoffBuilder::new()
.with_initial_interval(Duration::from_millis(1000))
.with_max_interval(Duration::from_secs(60))
.with_multiplier(1.5)
.with_max_elapsed_time(Some(Duration::from_secs(300)))
.build()
}
fn google_backoff() -> ExponentialBackoff {
ExponentialBackoffBuilder::new()
.with_initial_interval(Duration::from_millis(1000))
.with_max_interval(Duration::from_secs(60))
.with_multiplier(1.5)
.with_max_elapsed_time(Some(Duration::from_secs(300)))
.build()
}
fn ollama_backoff() -> ExponentialBackoff {
ExponentialBackoffBuilder::new()
.with_initial_interval(Duration::from_millis(500))
.with_max_interval(Duration::from_secs(30))
.with_multiplier(1.5)
.with_max_elapsed_time(Some(Duration::from_secs(180))) .build()
}
fn default_backoff() -> ExponentialBackoff {
ExponentialBackoff::default()
}
}
impl Default for BackoffRetryExecutor {
fn default() -> Self {
Self::new()
}
}
/// Runs `operation` with retries governed by the default exponential backoff policy.
///
/// Convenience wrapper for `BackoffRetryExecutor::new().execute(operation)`.
///
/// # Errors
/// Propagates whatever error `BackoffRetryExecutor::execute` ultimately returns.
pub async fn retry_with_backoff<F, Fut, T>(operation: F) -> Result<T, LlmError>
where
    F: Fn() -> Fut + Send + Sync,
    Fut: std::future::Future<Output = Result<T, LlmError>> + Send,
    T: Send,
{
    BackoffRetryExecutor::new().execute(operation).await
}
/// Runs `operation` with retries governed by the backoff policy tuned for `provider`.
///
/// Convenience wrapper for `BackoffRetryExecutor::for_provider(provider).execute(operation)`.
///
/// # Errors
/// Propagates whatever error `BackoffRetryExecutor::execute` ultimately returns.
pub async fn retry_for_provider_backoff<F, Fut, T>(
    provider: &ProviderType,
    operation: F,
) -> Result<T, LlmError>
where
    F: Fn() -> Fut + Send + Sync,
    Fut: std::future::Future<Output = Result<T, LlmError>> + Send,
    T: Send,
{
    BackoffRetryExecutor::for_provider(provider)
        .execute(operation)
        .await
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    /// Millisecond-scale policy, so retry tests don't wait on the default policy's
    /// half-second, randomized initial interval.
    fn fast_backoff() -> ExponentialBackoff {
        ExponentialBackoffBuilder::new()
            .with_initial_interval(Duration::from_millis(1))
            .with_max_interval(Duration::from_millis(5))
            .with_max_elapsed_time(Some(Duration::from_millis(50)))
            .build()
    }

    #[tokio::test]
    async fn test_retry_success_on_second_attempt() {
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();
        let executor = BackoffRetryExecutor::with_backoff(fast_backoff());
        let result: Result<String, LlmError> = executor
            .execute(|| async {
                let count = counter_clone.fetch_add(1, Ordering::SeqCst);
                if count == 0 {
                    Err(LlmError::RateLimitError("Rate limited".to_string()))
                } else {
                    Ok("Success".to_string())
                }
            })
            .await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "Success");
        assert_eq!(counter.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_retryable_error_surfaces_after_budget_exhausted() {
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();
        let executor = BackoffRetryExecutor::with_backoff(fast_backoff());
        let result: Result<String, LlmError> = executor
            .execute(|| async {
                counter_clone.fetch_add(1, Ordering::SeqCst);
                Err(LlmError::RateLimitError("Rate limited".to_string()))
            })
            .await;
        // The last transient error is returned once the policy stops retrying.
        assert!(matches!(result, Err(LlmError::RateLimitError(_))));
        assert!(counter.load(Ordering::SeqCst) > 1);
    }

    #[tokio::test]
    async fn test_permanent_error_no_retry() {
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();
        let executor = BackoffRetryExecutor::new();
        let result: Result<String, LlmError> = executor
            .execute(|| async {
                counter_clone.fetch_add(1, Ordering::SeqCst);
                Err(LlmError::InvalidInput("Bad input".to_string()))
            })
            .await;
        assert!(result.is_err());
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_provider_specific_backoff() {
        let providers = [
            ProviderType::OpenAi,
            ProviderType::Anthropic,
            ProviderType::Gemini,
            ProviderType::Ollama,
            ProviderType::XAI,
            ProviderType::Groq,
        ];
        for provider in providers {
            let executor = BackoffRetryExecutor::for_provider(&provider);
            let result: Result<String, LlmError> = executor
                .execute(|| async { Ok("Success".to_string()) })
                .await;
            assert_eq!(result.unwrap(), "Success");
        }
    }
}