use crate::Result;
use crate::error::SubXError;
use tokio::time::{Duration, sleep};
/// Tuning parameters for [`retry_with_backoff`].
///
/// The delay before the retry that follows the `k`-th failed attempt (0-based) is
/// `base_delay * backoff_multiplier^k`, capped at `max_delay`.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Total number of attempts, including the first; must be at least 1.
    pub max_attempts: usize,
    /// Delay before the first retry.
    pub base_delay: Duration,
    /// Upper bound applied to every computed delay.
    pub max_delay: Duration,
    /// Factor the delay is multiplied by after each failed attempt.
    pub backoff_multiplier: f64,
}

impl Default for RetryConfig {
    /// Three attempts, starting with a 1 s delay that doubles per retry, capped at 30 s.
    fn default() -> Self {
        Self {
            max_attempts: 3,
            base_delay: Duration::from_secs(1),
            max_delay: Duration::from_secs(30),
            backoff_multiplier: 2.0,
        }
    }
}
/// Runs `operation` until it succeeds or `config.max_attempts` attempts have failed,
/// sleeping with exponential backoff between attempts.
///
/// `operation` is invoked once per attempt. After the `k`-th failure (0-based) the
/// function sleeps for `base_delay * backoff_multiplier^k`, capped at `max_delay`,
/// before trying again; no sleep follows the final attempt.
///
/// # Errors
///
/// Returns `SubXError::AiService` without invoking `operation` if
/// `config.max_attempts` is 0. Otherwise returns the error produced by the final
/// attempt; errors from earlier attempts are discarded.
pub async fn retry_with_backoff<F, Fut, T>(operation: F, config: &RetryConfig) -> Result<T>
where
    F: Fn() -> Fut,
    Fut: std::future::Future<Output = Result<T>>,
{
    if config.max_attempts == 0 {
        return Err(SubXError::AiService(
            "Retry configuration invalid: max_attempts must be at least 1".to_string(),
        ));
    }
    let last_attempt = config.max_attempts - 1;
    let mut attempt = 0;
    loop {
        match operation().await {
            Ok(result) => return Ok(result),
            Err(e) if attempt >= last_attempt => return Err(e),
            Err(_) => {
                sleep(backoff_delay(config, attempt)).await;
                attempt += 1;
            }
        }
    }
}

/// Computes the pause that follows failed attempt number `attempt` (0-based).
///
/// The arithmetic is done in floating-point seconds so sub-millisecond base delays
/// are not truncated to zero. Results at or above `max_delay` (including +inf) are
/// clamped to `max_delay`; NaN or non-positive results (e.g. from a negative
/// multiplier) become a zero delay, so `Duration::from_secs_f64` can never panic.
fn backoff_delay(config: &RetryConfig, attempt: usize) -> Duration {
    // Saturate instead of wrapping for absurdly large attempt counts.
    let exponent = attempt.min(i32::MAX as usize) as i32;
    let secs = config.base_delay.as_secs_f64() * config.backoff_multiplier.powi(exponent);
    if secs.is_nan() || secs <= 0.0 {
        Duration::ZERO
    } else if secs >= config.max_delay.as_secs_f64() {
        config.max_delay
    } else {
        Duration::from_secs_f64(secs)
    }
}
/// An HTTP client that can resend failed requests after a fixed pause.
///
/// Implementors only supply the retry budget and the delay; the provided
/// `make_request_with_retry` method performs the actual retry loop.
// `async fn` in a public trait trips the `async_fn_in_trait` lint because callers
// cannot require the returned future to be `Send`; that limitation is accepted here.
#[allow(async_fn_in_trait)]
pub trait HttpRetryClient {
    /// Total number of attempts `make_request_with_retry` may make (values below 1
    /// behave like 1).
    fn retry_attempts(&self) -> u32;

    /// Fixed pause, in milliseconds, between consecutive attempts.
    fn retry_delay_ms(&self) -> u64;

    /// Sends `request`, retrying on transport errors and non-success HTTP statuses
    /// according to `retry_attempts` and `retry_delay_ms`.
    async fn make_request_with_retry(
        &self,
        request: reqwest::RequestBuilder,
    ) -> Result<reqwest::Response> {
        let attempts = self.retry_attempts();
        let delay_ms = self.retry_delay_ms();
        make_http_request_with_retry_impl(request, attempts, delay_ms).await
    }
}
/// Sends `request`, retrying after a fixed `retry_delay_ms` pause on transport errors
/// and on any non-success HTTP status (including 4xx client errors), for at most
/// `retry_attempts` attempts in total; 0 is treated as 1.
///
/// Every attempt but the last sends a clone of `request`. The last attempt sends the
/// original builder, so no clone is needed when only one attempt is allowed. If the
/// request cannot be cloned (`try_clone` returns `None`, e.g. for streaming bodies),
/// it is sent exactly once instead of failing without ever being sent.
///
/// # Errors
///
/// Returns the final attempt's transport error, or the status error produced by
/// `error_for_status`, converted into the crate error type.
async fn make_http_request_with_retry_impl(
    request: reqwest::RequestBuilder,
    retry_attempts: u32,
    retry_delay_ms: u64,
) -> Result<reqwest::Response> {
    let max_attempts = retry_attempts.max(1);
    for _ in 1..max_attempts {
        // A request that cannot be cloned cannot be retried: stop looping and send
        // the original builder once below.
        let Some(attempt) = request.try_clone() else {
            break;
        };
        // Errors from non-final attempts are intentionally discarded; only the last
        // attempt's error is reported to the caller.
        if let Ok(response) = send_checked(attempt).await {
            return Ok(response);
        }
        sleep(Duration::from_millis(retry_delay_ms)).await;
    }
    send_checked(request).await.map_err(Into::into)
}

/// Sends a single request and converts non-success HTTP statuses into errors.
async fn send_checked(request: reqwest::RequestBuilder) -> reqwest::Result<reqwest::Response> {
    request.send().await?.error_for_status()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::SubXError;
    use std::sync::{Arc, Mutex};
    use std::time::Instant;

    #[tokio::test]
    async fn test_retry_success_on_second_attempt() {
        let config = RetryConfig {
            max_attempts: 3,
            base_delay: Duration::from_millis(10),
            max_delay: Duration::from_secs(1),
            backoff_multiplier: 2.0,
        };
        let attempt_count = Arc::new(Mutex::new(0));
        let attempt_count_clone = attempt_count.clone();
        let operation = || async {
            let mut count = attempt_count_clone.lock().unwrap();
            *count += 1;
            if *count == 1 {
                Err(SubXError::AiService("First attempt fails".to_string()))
            } else {
                Ok("Success on second attempt".to_string())
            }
        };
        let result = retry_with_backoff(operation, &config).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "Success on second attempt");
        assert_eq!(*attempt_count.lock().unwrap(), 2);
    }

    #[tokio::test]
    async fn test_retry_exhaust_max_attempts() {
        let config = RetryConfig {
            max_attempts: 2,
            base_delay: Duration::from_millis(10),
            max_delay: Duration::from_secs(1),
            backoff_multiplier: 2.0,
        };
        let attempt_count = Arc::new(Mutex::new(0));
        let attempt_count_clone = attempt_count.clone();
        let operation = || async {
            let mut count = attempt_count_clone.lock().unwrap();
            *count += 1;
            Err(SubXError::AiService("Always fails".to_string()))
        };
        let result: Result<String> = retry_with_backoff(operation, &config).await;
        // The error surfaced must be the one produced by the final attempt.
        match result {
            Err(SubXError::AiService(msg)) => assert_eq!(msg, "Always fails"),
            other => panic!("unexpected result: {:?}", other),
        }
        assert_eq!(*attempt_count.lock().unwrap(), 2);
    }

    #[tokio::test]
    async fn test_exponential_backoff_timing() {
        let config = RetryConfig {
            max_attempts: 3,
            base_delay: Duration::from_millis(50),
            max_delay: Duration::from_millis(200),
            backoff_multiplier: 2.0,
        };
        let attempt_times = Arc::new(Mutex::new(Vec::new()));
        let attempt_times_clone = attempt_times.clone();
        let operation = || async {
            attempt_times_clone.lock().unwrap().push(Instant::now());
            Err(SubXError::AiService(
                "Always fails for timing test".to_string(),
            ))
        };
        let _result: Result<String> = retry_with_backoff(operation, &config).await;
        let times = attempt_times.lock().unwrap();
        assert_eq!(times.len(), 3);
        // First retry waits base_delay (50ms).
        let delay1 = times[1].duration_since(times[0]);
        assert!(delay1 >= Duration::from_millis(30));
        assert!(delay1 <= Duration::from_millis(100));
        // Second retry waits base_delay * 2 (100ms); sleeps never end early.
        let delay2 = times[2].duration_since(times[1]);
        assert!(delay2 >= Duration::from_millis(80));
    }

    #[tokio::test]
    async fn test_max_delay_cap() {
        let config = RetryConfig {
            max_attempts: 5,
            base_delay: Duration::from_millis(100),
            max_delay: Duration::from_millis(200),
            backoff_multiplier: 3.0,
        };
        let attempt_times = Arc::new(Mutex::new(Vec::new()));
        let attempt_times_clone = attempt_times.clone();
        let operation = || async {
            attempt_times_clone.lock().unwrap().push(Instant::now());
            Err(SubXError::AiService("Always fails".to_string()))
        };
        let _result: Result<String> = retry_with_backoff(operation, &config).await;
        let times = attempt_times.lock().unwrap();
        assert_eq!(times.len(), 5);
        // Uncapped delays would be 100, 300, 900, 2700 ms; from the second retry on
        // they must be clamped to max_delay (200ms).
        for pair in times.windows(2).skip(1) {
            let gap = pair[1].duration_since(pair[0]);
            assert!(gap >= Duration::from_millis(150), "gap too short: {:?}", gap);
            assert!(gap <= Duration::from_millis(250), "gap not capped: {:?}", gap);
        }
    }

    #[tokio::test]
    async fn test_retry_rejects_zero_max_attempts() {
        let config = RetryConfig {
            max_attempts: 0,
            base_delay: Duration::from_millis(1),
            max_delay: Duration::from_millis(1),
            backoff_multiplier: 2.0,
        };
        let called = Arc::new(Mutex::new(false));
        let called_clone = called.clone();
        let operation = || {
            let called = called_clone.clone();
            async move {
                *called.lock().unwrap() = true;
                Ok::<_, SubXError>("should not run".to_string())
            }
        };
        let result: Result<String> = retry_with_backoff(operation, &config).await;
        assert!(result.is_err());
        assert!(!*called.lock().unwrap(), "operation must not be invoked");
        match result {
            Err(SubXError::AiService(msg)) => assert!(msg.contains("max_attempts")),
            other => panic!("unexpected result: {:?}", other),
        }
    }

    #[test]
    fn test_retry_config_validation() {
        let valid_config = RetryConfig {
            max_attempts: 3,
            base_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(1),
            backoff_multiplier: 2.0,
        };
        assert!(valid_config.base_delay <= valid_config.max_delay);
        assert!(valid_config.max_attempts > 0);
        assert!(valid_config.backoff_multiplier > 1.0);
    }

    #[tokio::test]
    async fn test_ai_service_integration_simulation() {
        let config = RetryConfig {
            max_attempts: 3,
            base_delay: Duration::from_millis(10),
            max_delay: Duration::from_secs(1),
            backoff_multiplier: 2.0,
        };
        let request_count = Arc::new(Mutex::new(0));
        let request_count_clone = request_count.clone();
        let mock_ai_request = || async {
            let mut count = request_count_clone.lock().unwrap();
            *count += 1;
            match *count {
                1 => Err(SubXError::AiService("Network timeout".to_string())),
                2 => Err(SubXError::AiService("Rate limit exceeded".to_string())),
                3 => Ok("AI analysis complete".to_string()),
                _ => unreachable!(),
            }
        };
        let result = retry_with_backoff(mock_ai_request, &config).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "AI analysis complete");
        assert_eq!(*request_count.lock().unwrap(), 3);
    }
}