use crate::config::RetryConfig;
use crate::error::LlmConnectorError;
use std::future::Future;
use std::time::Duration;
/// Retry wrapper for fallible async operations.
///
/// Holds the [`RetryConfig`] that `execute` reads its retry limit and backoff
/// settings from. The derived `Default` builds the middleware from
/// `RetryConfig::default()`.
#[derive(Debug, Clone, Default)]
pub struct RetryMiddleware {
/// Retry limit and backoff settings consulted after every failed attempt.
config: RetryConfig,
}
impl RetryMiddleware {
pub fn new(config: RetryConfig) -> Self {
Self { config }
}
pub async fn execute<F, Fut, T>(&self, mut operation: F) -> Result<T, LlmConnectorError>
where
F: FnMut() -> Fut,
Fut: Future<Output = Result<T, LlmConnectorError>>,
{
let mut attempts = 0;
let mut backoff_ms = self.config.initial_backoff_ms;
loop {
match operation().await {
Ok(result) => {
if attempts > 0 {
log::info!("Request succeeded after {} retries", attempts);
}
return Ok(result);
}
Err(e) => {
attempts += 1;
if !self.is_retriable(&e) {
log::debug!("Error is not retriable: {:?}", e);
return Err(e);
}
if attempts > self.config.max_retries {
log::warn!(
"Max retries ({}) exhausted for retriable error",
self.config.max_retries
);
return Err(LlmConnectorError::MaxRetriesExceeded(format!(
"Failed after {} attempts: {}",
attempts, e
)));
}
let jitter = (rand::random::<f32>() * 0.3 - 0.15) * backoff_ms as f32;
let actual_backoff = (backoff_ms as f32 + jitter).max(0.0) as u64;
let capped_backoff = actual_backoff.min(self.config.max_backoff_ms);
log::info!(
"Retry attempt {}/{} after {}ms (error: {})",
attempts,
self.config.max_retries,
capped_backoff,
e
);
tokio::time::sleep(Duration::from_millis(capped_backoff)).await;
backoff_ms = ((backoff_ms as f32 * self.config.backoff_multiplier) as u64)
.min(self.config.max_backoff_ms);
}
}
}
}
fn is_retriable(&self, error: &LlmConnectorError) -> bool {
matches!(
error,
LlmConnectorError::RateLimitError(_)
| LlmConnectorError::ServerError(_)
| LlmConnectorError::TimeoutError(_)
| LlmConnectorError::ConnectionError(_)
| LlmConnectorError::NetworkError(_)
)
}
pub fn config(&self) -> &RetryConfig {
&self.config
}
}
/// Fluent builder for a [`RetryConfig`] (or a ready-made [`RetryMiddleware`]).
///
/// Each field mirrors the `RetryConfig` field of the same name and is copied
/// across verbatim by `build`.
#[derive(Debug, Clone)]
pub struct RetryPolicyBuilder {
/// Number of retries allowed after the initial attempt.
max_retries: u32,
/// Backoff before the first retry, in milliseconds (before jitter).
initial_backoff_ms: u64,
/// Factor the backoff is multiplied by after each retry.
backoff_multiplier: f32,
/// Upper bound applied to every backoff delay, in milliseconds.
max_backoff_ms: u64,
}
impl RetryPolicyBuilder {
pub fn new() -> Self {
let default = RetryConfig::default();
Self {
max_retries: default.max_retries,
initial_backoff_ms: default.initial_backoff_ms,
backoff_multiplier: default.backoff_multiplier,
max_backoff_ms: default.max_backoff_ms,
}
}
pub fn max_retries(mut self, max_retries: u32) -> Self {
self.max_retries = max_retries;
self
}
pub fn initial_backoff_ms(mut self, ms: u64) -> Self {
self.initial_backoff_ms = ms;
self
}
pub fn backoff_multiplier(mut self, multiplier: f32) -> Self {
self.backoff_multiplier = multiplier;
self
}
pub fn max_backoff_ms(mut self, ms: u64) -> Self {
self.max_backoff_ms = ms;
self
}
pub fn build(self) -> RetryConfig {
RetryConfig {
max_retries: self.max_retries,
initial_backoff_ms: self.initial_backoff_ms,
backoff_multiplier: self.backoff_multiplier,
max_backoff_ms: self.max_backoff_ms,
}
}
pub fn build_middleware(self) -> RetryMiddleware {
RetryMiddleware::new(self.build())
}
}
impl Default for RetryPolicyBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    /// Config with short delays so tests that actually retry stay fast.
    fn fast_config(max_retries: u32) -> RetryConfig {
        RetryConfig {
            max_retries,
            initial_backoff_ms: 10,
            backoff_multiplier: 2.0,
            max_backoff_ms: 1000,
        }
    }

    /// A first-try success is passed through and the operation runs once.
    #[tokio::test]
    async fn test_retry_success_on_first_attempt() {
        let retry = RetryMiddleware::default();
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();
        let result = retry
            .execute(|| {
                let counter = counter_clone.clone();
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Ok::<_, LlmConnectorError>("success")
                }
            })
            .await;
        // Pin the returned value, not just the `Ok`-ness.
        assert!(matches!(result, Ok("success")));
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    /// Two retriable failures followed by a success: three calls, `Ok` result.
    #[tokio::test]
    async fn test_retry_success_after_failures() {
        let retry = RetryMiddleware::new(fast_config(3));
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();
        let result = retry
            .execute(|| {
                let counter = counter_clone.clone();
                async move {
                    let count = counter.fetch_add(1, Ordering::SeqCst);
                    if count < 2 {
                        Err(LlmConnectorError::ServerError(
                            "temporary error".to_string(),
                        ))
                    } else {
                        Ok("success")
                    }
                }
            })
            .await;
        assert!(matches!(result, Ok("success")));
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    /// A permanently failing retriable error is attempted `max_retries + 1`
    /// times and then reported as `MaxRetriesExceeded`.
    #[tokio::test]
    async fn test_retry_max_retries_exceeded() {
        let retry = RetryMiddleware::new(fast_config(2));
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();
        let result = retry
            .execute(|| {
                let counter = counter_clone.clone();
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Err::<(), _>(LlmConnectorError::ServerError("always fails".to_string()))
                }
            })
            .await;
        // The variant matters: the raw `ServerError` must not leak through.
        match result {
            Err(LlmConnectorError::MaxRetriesExceeded(message)) => {
                assert!(
                    message.contains("Failed after 3 attempts"),
                    "unexpected message: {}",
                    message
                );
            }
            other => panic!("expected MaxRetriesExceeded, got {:?}", other),
        }
        // One initial attempt plus two retries.
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    /// With no retries allowed, a retriable failure is reported immediately
    /// after the single attempt.
    #[tokio::test]
    async fn test_retry_zero_max_retries() {
        let retry = RetryMiddleware::new(fast_config(0));
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();
        let result = retry
            .execute(|| {
                let counter = counter_clone.clone();
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Err::<(), _>(LlmConnectorError::ServerError("always fails".to_string()))
                }
            })
            .await;
        assert!(matches!(
            result,
            Err(LlmConnectorError::MaxRetriesExceeded(_))
        ));
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    /// A non-retriable error is returned unchanged after a single attempt.
    #[tokio::test]
    async fn test_non_retriable_error() {
        let retry = RetryMiddleware::default();
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();
        let result = retry
            .execute(|| {
                let counter = counter_clone.clone();
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Err::<(), _>(LlmConnectorError::InvalidRequest("bad request".to_string()))
                }
            })
            .await;
        // The original variant must come back, not a wrapped one.
        assert!(matches!(
            result,
            Err(LlmConnectorError::InvalidRequest(_))
        ));
        // No retry may happen for a non-retriable error.
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    /// Every setter value ends up in the built config.
    #[test]
    fn test_retry_policy_builder() {
        let config = RetryPolicyBuilder::new()
            .max_retries(5)
            .initial_backoff_ms(500)
            .backoff_multiplier(3.0)
            .max_backoff_ms(10000)
            .build();
        assert_eq!(config.max_retries, 5);
        assert_eq!(config.initial_backoff_ms, 500);
        assert_eq!(config.backoff_multiplier, 3.0);
        assert_eq!(config.max_backoff_ms, 10000);
    }

    /// `build_middleware` hands the built config to the middleware.
    #[test]
    fn test_retry_policy_builder_builds_middleware() {
        let middleware = RetryPolicyBuilder::new()
            .max_retries(7)
            .max_backoff_ms(250)
            .build_middleware();
        assert_eq!(middleware.config().max_retries, 7);
        assert_eq!(middleware.config().max_backoff_ms, 250);
    }
}