use std::time::Duration;
use tokio::time::{sleep, timeout};
use crate::sources::SourceError;
#[derive(Debug, Clone, Copy)]
pub struct RetryConfig {
pub max_attempts: u32,
pub initial_delay: Duration,
pub max_delay: Duration,
pub backoff_multiplier: f64,
pub max_total_time: Duration,
}
impl Default for RetryConfig {
fn default() -> Self {
Self {
max_attempts: 3,
initial_delay: Duration::from_secs(1),
max_delay: Duration::from_secs(60),
backoff_multiplier: 2.0,
max_total_time: Duration::from_secs(120),
}
}
}
#[derive(Debug, Clone, PartialEq)]
pub enum TransientError {
Network,
RateLimit(Option<u64>),
ServerError,
ServiceUnavailable,
GatewayTimeout,
TooManyRequests,
Timeout,
}
impl TransientError {
pub fn from_reqwest_error(err: &reqwest::Error) -> Option<Self> {
if err.is_timeout() {
return Some(TransientError::Timeout);
}
if err.is_connect() {
return Some(TransientError::Network);
}
if let Some(status) = err.status() {
if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
return Some(TransientError::TooManyRequests);
}
if status == reqwest::StatusCode::SERVICE_UNAVAILABLE {
return Some(TransientError::ServiceUnavailable);
}
if status == reqwest::StatusCode::GATEWAY_TIMEOUT {
return Some(TransientError::GatewayTimeout);
}
if status.is_server_error() {
return Some(TransientError::ServerError);
}
}
None
}
pub fn from_source_error(err: &SourceError) -> Option<Self> {
match err {
SourceError::RateLimit => Some(TransientError::RateLimit(None)),
SourceError::Network(_) => Some(TransientError::Network),
SourceError::Api(msg) => {
let msg_lower = msg.to_lowercase();
if msg_lower.contains("timeout") {
Some(TransientError::Timeout)
} else if msg_lower.contains("service unavailable")
|| msg_lower.contains("temporarily unavailable")
{
Some(TransientError::ServiceUnavailable)
} else {
None
}
}
_ => None,
}
}
pub fn recommended_delay(&self) -> Duration {
match self {
TransientError::RateLimit(Some(seconds)) => Duration::from_secs(*seconds + 1),
TransientError::RateLimit(None) => Duration::from_secs(61),
TransientError::TooManyRequests => Duration::from_secs(61),
TransientError::ServiceUnavailable => Duration::from_secs(10),
TransientError::GatewayTimeout => Duration::from_secs(5),
TransientError::Timeout => Duration::from_secs(2),
TransientError::Network => Duration::from_secs(2),
TransientError::ServerError => Duration::from_secs(2),
}
}
}
pub enum RetryResult<T> {
Success(T),
TransientFailure(SourceError, TransientError, u32),
PermanentFailure(SourceError),
}
pub async fn with_retry<T, F, Fut>(config: RetryConfig, operation: F) -> Result<T, SourceError>
where
F: FnMut() -> Fut,
Fut: std::future::Future<Output = Result<T, SourceError>>,
{
let mut attempts = 0;
let mut total_elapsed = Duration::ZERO;
let mut operation = operation;
loop {
attempts += 1;
match timeout(config.max_total_time, operation()).await {
Ok(Ok(result)) => {
if attempts > 1 {
tracing::info!(
"Operation succeeded on attempt {} after {} transient failures",
attempts,
attempts - 1
);
}
return Ok(result);
}
Ok(Err(error)) => {
if let Some(transient) = TransientError::from_source_error(&error) {
let delay = if attempts == 1 {
config.initial_delay
} else {
let exp_delay = config.initial_delay.as_secs_f64()
* config.backoff_multiplier.powf(attempts as f64 - 1.0);
let delay_secs = exp_delay.min(config.max_delay.as_secs_f64());
Duration::from_secs_f64(delay_secs)
};
let delay = std::cmp::max(delay, transient.recommended_delay());
total_elapsed += delay;
if attempts >= config.max_attempts || total_elapsed >= config.max_total_time {
tracing::warn!(
"Operation failed after {} attempts (total elapsed: {:?}): {}",
attempts,
total_elapsed,
error
);
return Err(error);
}
tracing::debug!(
"Transient error on attempt {}: {:?}, retrying in {:?}",
attempts,
transient,
delay
);
sleep(delay).await;
continue;
} else {
return Err(error);
}
}
Err(_) => {
let error = SourceError::Network("Operation timed out".to_string());
if attempts >= config.max_attempts {
return Err(error);
}
let delay = config.initial_delay;
total_elapsed += delay;
tracing::debug!(
"Operation timed out, attempt {}/{}",
attempts,
config.max_attempts
);
sleep(delay).await;
}
}
}
}
pub async fn with_retry_detailed<T, F, Fut>(config: RetryConfig, operation: F) -> RetryResult<T>
where
F: FnMut() -> Fut,
Fut: std::future::Future<Output = Result<T, SourceError>>,
{
let mut attempts = 0;
let mut total_elapsed = Duration::ZERO;
let mut operation = operation;
loop {
attempts += 1;
match timeout(config.max_total_time, operation()).await {
Ok(Ok(result)) => {
return RetryResult::Success(result);
}
Ok(Err(error)) => {
if let Some(transient) = TransientError::from_source_error(&error) {
let delay = if attempts == 1 {
config.initial_delay
} else {
let exp_delay = config.initial_delay.as_secs_f64()
* config.backoff_multiplier.powf(attempts as f64 - 1.0);
Duration::from_secs_f64(exp_delay.min(config.max_delay.as_secs_f64()))
};
let delay = std::cmp::max(delay, transient.recommended_delay());
total_elapsed += delay;
if attempts >= config.max_attempts || total_elapsed >= config.max_total_time {
return RetryResult::TransientFailure(error, transient, attempts);
}
sleep(delay).await;
continue;
} else {
return RetryResult::PermanentFailure(error);
}
}
Err(_) => {
let error = SourceError::Network("Operation timed out".to_string());
if attempts >= config.max_attempts {
return RetryResult::TransientFailure(error, TransientError::Timeout, attempts);
}
let delay = config.initial_delay;
total_elapsed += delay;
sleep(delay).await;
}
}
}
}
pub fn api_retry_config() -> RetryConfig {
RetryConfig {
max_attempts: 5,
initial_delay: Duration::from_secs(2),
max_delay: Duration::from_secs(120),
backoff_multiplier: 2.0,
max_total_time: Duration::from_secs(300),
}
}
pub fn strict_rate_limit_retry_config() -> RetryConfig {
RetryConfig {
max_attempts: 3,
initial_delay: Duration::from_secs(2),
max_delay: Duration::from_secs(120),
backoff_multiplier: 2.0,
max_total_time: Duration::from_secs(180),
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::cell::RefCell;
use std::rc::Rc;
#[tokio::test]
async fn test_retry_success_first_try() {
let config = RetryConfig::default();
let call_count = Rc::new(RefCell::new(0));
let result = {
let call_count = call_count.clone();
with_retry(config, move || {
let call_count = call_count.clone();
async move {
*call_count.borrow_mut() += 1;
Ok("success")
}
})
}
.await;
assert_eq!(result.unwrap(), "success");
assert_eq!(*call_count.borrow(), 1);
}
#[tokio::test]
async fn test_retry_success_after_failures() {
let config = RetryConfig {
max_attempts: 4, initial_delay: Duration::from_millis(10),
max_delay: Duration::from_millis(100),
backoff_multiplier: 2.0,
max_total_time: Duration::from_secs(10),
};
let call_count = Rc::new(RefCell::new(0));
let result = {
let call_count = call_count.clone();
with_retry(config, move || {
let call_count = call_count.clone();
async move {
*call_count.borrow_mut() += 1;
let count = *call_count.borrow();
if count < 3 {
Err(SourceError::Network("temporary error".to_string()))
} else {
Ok("success")
}
}
})
}
.await;
assert_eq!(result.unwrap(), "success");
assert_eq!(*call_count.borrow(), 3);
}
#[tokio::test]
async fn test_retry_returns_permanent_error() {
let config = RetryConfig {
max_attempts: 5,
initial_delay: Duration::from_millis(10),
max_delay: Duration::from_millis(50),
backoff_multiplier: 2.0,
max_total_time: Duration::from_secs(5),
};
let call_count = Rc::new(RefCell::new(0));
let result: Result<&str, SourceError> = {
let call_count = call_count.clone();
with_retry(config, move || {
let call_count = call_count.clone();
async move {
*call_count.borrow_mut() += 1;
Err(SourceError::NotFound("not found".to_string()))
}
})
}
.await;
assert!(result.is_err());
if let Err(e) = result {
match e {
SourceError::NotFound(_) => {} _ => panic!("Expected NotFound error"),
}
}
assert_eq!(*call_count.borrow(), 1); }
#[test]
fn test_transient_error_detection() {
let rate_limit_error = SourceError::RateLimit;
assert!(TransientError::from_source_error(&rate_limit_error).is_some());
let network_error = SourceError::Network("connection refused".to_string());
assert!(TransientError::from_source_error(&network_error).is_some());
let parse_error = SourceError::Parse("invalid json".to_string());
assert!(TransientError::from_source_error(&parse_error).is_none());
}
#[test]
fn test_recommended_delay() {
assert_eq!(
TransientError::RateLimit(Some(30)).recommended_delay(),
Duration::from_secs(31)
);
assert_eq!(
TransientError::RateLimit(None).recommended_delay(),
Duration::from_secs(61)
);
assert_eq!(
TransientError::ServiceUnavailable.recommended_delay(),
Duration::from_secs(10)
);
assert_eq!(
TransientError::Network.recommended_delay(),
Duration::from_secs(2)
);
}
}