use anyhow::Result;
use chrono::{DateTime, Utc};
use std::future::Future;
use std::time::{Duration, Instant};
use tracing::{debug, info, warn};
/// Tunable parameters for an exponential-backoff retry policy.
///
/// Obtain the stock policy through [`Default`] and adjust the public fields as needed.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Attempt index at which retrying stops (`should_retry` refuses once
    /// `attempt >= max_retries`).
    pub max_retries: u32,
    /// Delay in milliseconds used for the first retry; also the base of the exponential curve.
    pub base_delay_ms: u64,
    /// Ceiling in milliseconds applied to the exponentially grown delay.
    pub max_delay_ms: u64,
    /// Factor the delay is multiplied by for each additional attempt.
    pub backoff_multiplier: f64,
    /// Jitter expressed as a fraction of the capped delay (`0.1` means ±10%).
    pub jitter: f64,
    /// Optional overall time budget in milliseconds; `None` means no budget.
    pub total_timeout_ms: Option<u64>,
}

impl Default for RetryConfig {
    /// Stock policy: 3 retries, 1s base delay doubling up to 30s, 10% jitter, no time budget.
    fn default() -> Self {
        RetryConfig {
            total_timeout_ms: None,
            jitter: 0.1,
            backoff_multiplier: 2.0,
            max_delay_ms: 30_000,
            base_delay_ms: 1_000,
            max_retries: 3,
        }
    }
}
impl RetryConfig {
    /// Creates a configuration holding the default policy.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the attempt index at which retrying stops.
    pub fn with_max_retries(mut self, retries: u32) -> Self {
        self.max_retries = retries;
        self
    }

    /// Sets the delay (milliseconds) used for the first retry and as the exponential base.
    pub fn with_base_delay(mut self, ms: u64) -> Self {
        self.base_delay_ms = ms;
        self
    }

    /// Sets the upper bound (milliseconds) for any single delay.
    pub fn with_max_delay(mut self, ms: u64) -> Self {
        self.max_delay_ms = ms;
        self
    }

    /// Sets the factor the delay grows by per attempt.
    pub fn with_backoff_multiplier(mut self, mult: f64) -> Self {
        self.backoff_multiplier = mult;
        self
    }

    /// Sets the jitter fraction, clamped to `0.0..=1.0`. `NaN` is treated as "no jitter".
    pub fn with_jitter(mut self, jitter: f64) -> Self {
        // `NaN.clamp(..)` stays `NaN`, which would later collapse every computed delay to zero.
        self.jitter = if jitter.is_nan() {
            0.0
        } else {
            jitter.clamp(0.0, 1.0)
        };
        self
    }

    /// Sets the overall time budget in milliseconds.
    pub fn with_total_timeout(mut self, ms: u64) -> Self {
        self.total_timeout_ms = Some(ms);
        self
    }

    /// Decides whether another attempt should be made after `error`.
    ///
    /// `attempt` is the zero-based index of the attempt that just failed. Returns `false` when
    /// `attempt >= max_retries`, when the delays estimated for the attempts made so far already
    /// exceed the total timeout, or when the error kind is not retried.
    pub fn should_retry(&self, error: &RetryableError, attempt: u32) -> bool {
        if attempt >= self.max_retries {
            return false;
        }
        if let Some(timeout_ms) = self.total_timeout_ms {
            if attempt > 0
                && self.estimate_total_time(attempt) > Duration::from_millis(timeout_ms)
            {
                return false;
            }
        }
        // Exhaustive on purpose: adding a variant must force a decision here.
        match error {
            RetryableError::NetworkError
            | RetryableError::Timeout
            | RetryableError::RateLimitError { .. }
            | RetryableError::ServerError { .. }
            // A 503 is transient by definition and carries its own `retry_after`; it used to be
            // silently excluded, so such responses were never retried.
            | RetryableError::ServiceUnavailable { .. } => true,
            // Behaviour kept as before: this variant is not retried automatically.
            RetryableError::TemporaryFailure => false,
        }
    }

    /// Computes the delay to wait after the failure of zero-based `attempt`.
    ///
    /// Attempt 0 uses the base delay; later attempts use
    /// `base_delay_ms * backoff_multiplier^attempt`, capped at `max_delay_ms`, plus a
    /// deterministic pseudo-random jitter of up to `±jitter` of the capped value. The result
    /// never exceeds `max_delay_ms`.
    pub fn next_delay(&self, attempt: u32) -> Duration {
        if attempt == 0 {
            // Honour the cap even when the base delay is configured above it.
            return Duration::from_millis(self.base_delay_ms.min(self.max_delay_ms));
        }
        let max_ms = self.max_delay_ms as f64;
        // `attempt as i32` would wrap negative above `i32::MAX` and shrink the delay.
        let exponent = attempt.min(i32::MAX as u32) as i32;
        let exponential = self.base_delay_ms as f64 * self.backoff_multiplier.powi(exponent);
        let capped = exponential.min(max_ms);
        let jitter_range = capped * self.jitter;
        let jitter = (rand_simple(attempt) * 2.0 - 1.0) * jitter_range;
        // Clamp after adding jitter so positive jitter cannot push past the configured maximum.
        let final_delay = (capped + jitter).clamp(0.0, max_ms);
        // Float-to-int `as` saturates and maps NaN to 0.
        Duration::from_millis(final_delay as u64)
    }

    /// Sums the delays of attempts `0..attempts`, saturating instead of overflowing.
    fn estimate_total_time(&self, attempts: u32) -> Duration {
        let total_ms = (0..attempts)
            // Lossless: `next_delay` builds its result from a `u64` millisecond count.
            .map(|i| self.next_delay(i).as_millis() as u64)
            .fold(0u64, |acc, ms| acc.saturating_add(ms));
        Duration::from_millis(total_ms)
    }
}
/// Failure categories produced by network and HTTP calls.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RetryableError {
    /// Connection-level failure.
    NetworkError,
    /// The server asked the client to slow down.
    RateLimitError {
        /// Seconds to wait before retrying; `0` means no wait was specified.
        retry_after: u64,
    },
    /// The server answered with a 5xx status.
    ServerError {
        /// HTTP status code of the response.
        code: u16,
    },
    /// The request did not complete in time.
    Timeout,
    /// The server reported itself unavailable (HTTP 503).
    ServiceUnavailable {
        /// Seconds to wait before retrying; `0` means no wait was specified.
        retry_after: u64,
    },
    /// Unspecified transient failure.
    TemporaryFailure,
}

impl RetryableError {
    /// Classifies an HTTP status code, returning `None` for codes with no mapping here.
    ///
    /// 408 maps to [`RetryableError::Timeout`], 429 to a rate limit, 503 to
    /// [`RetryableError::ServiceUnavailable`], and every other code `>= 500` to
    /// [`RetryableError::ServerError`].
    pub fn from_status_code(code: u16) -> Option<Self> {
        match code {
            // 408 is "Request Timeout"; it used to be reported as a rate limit.
            408 => Some(Self::Timeout),
            429 => Some(Self::RateLimitError { retry_after: 0 }),
            503 => Some(Self::ServiceUnavailable { retry_after: 0 }),
            _ if code >= 500 => Some(Self::ServerError { code }),
            _ => None,
        }
    }

    /// Classifies a free-form transport error message.
    ///
    /// Messages mentioning "timeout" or "timed out" (case-insensitive) become
    /// [`RetryableError::Timeout`]; everything else is a [`RetryableError::NetworkError`].
    pub fn from_network_error(msg: &str) -> Self {
        let msg_lower = msg.to_lowercase();
        if msg_lower.contains("timeout") || msg_lower.contains("timed out") {
            Self::Timeout
        } else {
            Self::NetworkError
        }
    }

    /// Returns the wait in seconds carried by the error, if it is non-zero.
    pub fn suggested_delay(&self) -> Option<u64> {
        match self {
            Self::RateLimitError { retry_after } | Self::ServiceUnavailable { retry_after } => {
                Some(*retry_after).filter(|&secs| secs > 0)
            }
            Self::NetworkError
            | Self::ServerError { .. }
            | Self::Timeout
            | Self::TemporaryFailure => None,
        }
    }

    /// Returns `true` when the error carries a non-zero wait.
    pub fn has_required_wait(&self) -> bool {
        self.suggested_delay().is_some()
    }
}

impl std::fmt::Display for RetryableError {
    /// Writes a short description, appending `(retry after Ns)` when a wait is attached.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::NetworkError => "Network error",
            Self::RateLimitError { .. } => "Rate limit exceeded",
            Self::ServerError { code } => return write!(f, "Server error ({})", code),
            Self::Timeout => "Request timeout",
            Self::ServiceUnavailable { .. } => "Service unavailable",
            Self::TemporaryFailure => "Temporary failure",
        };
        match self.suggested_delay() {
            Some(secs) => write!(f, "{} (retry after {}s)", label, secs),
            None => f.write_str(label),
        }
    }
}
/// Bookkeeping for a retry that is waiting to be attempted again.
#[derive(Debug, Clone)]
pub struct RetryState {
/// Attempt number supplied to `RetryState::new`.
pub attempt: u32,
/// UTC instant at which the next retry is due; starts at creation time and is moved
/// forward by `set_next_retry`.
pub next_retry: DateTime<Utc>,
/// Error supplied to `RetryState::new`.
pub error: RetryableError,
/// Elapsed milliseconds; initialised to 0 by `new`.
/// NOTE(review): nothing visible in this file updates it — presumably maintained by callers.
pub elapsed_ms: u64,
/// Set by `abort`; an aborted state never reports ready.
pub aborted: bool,
}
impl RetryState {
    /// Creates a state for `attempt` caused by `error`, due immediately and not aborted.
    pub fn new(attempt: u32, error: RetryableError) -> Self {
        Self {
            attempt,
            next_retry: Utc::now(),
            error,
            elapsed_ms: 0,
            aborted: false,
        }
    }

    /// Schedules the next retry `delay` from now.
    ///
    /// A `delay` too large for `chrono::Duration` falls back to a zero offset.
    pub fn set_next_retry(&mut self, delay: Duration) {
        self.next_retry = Utc::now() + chrono::Duration::from_std(delay).unwrap_or_default();
    }

    /// Marks the retry as aborted.
    pub fn abort(&mut self) {
        self.aborted = true;
    }

    /// Whole seconds left until the retry is due (truncated toward zero, never negative).
    pub fn seconds_until_retry(&self) -> i64 {
        self.next_retry
            .signed_duration_since(Utc::now())
            .num_seconds()
            .max(0)
    }

    /// Returns `true` once the scheduled instant has been reached and the retry was not aborted.
    pub fn is_ready(&self) -> bool {
        // Compare instants directly: going through `seconds_until_retry` truncates the last
        // partial second to 0 and reported "ready" up to ~1s before the scheduled time.
        !self.aborted && Utc::now() >= self.next_retry
    }
}
/// Outcome of a retry loop such as `with_retry`.
pub enum RetryResult<T> {
/// The operation succeeded and produced this value.
Success(T),
/// Retrying stopped because the retry policy refused another attempt (attempt limit
/// reached, estimated time over budget, or an error kind that is not retried).
Exhausted {
/// Number of times the operation was invoked.
attempts: u32,
/// Error returned by the final invocation.
last_error: RetryableError,
},
/// The retry sequence was abandoned.
/// NOTE(review): `with_retry` in this file never constructs this variant — presumably
/// reserved for callers that support cancellation.
Aborted {
/// Number of attempts made before aborting.
attempts: u32,
},
/// The total time budget was used up while waiting between attempts.
TimedOut {
/// Number of times the operation was invoked.
attempts: u32,
/// Milliseconds elapsed since the retry loop started.
elapsed_ms: u64,
},
}
/// Runs `f` until it succeeds, the retry policy gives up, or the time budget is spent.
///
/// After each failure `config.should_retry` decides whether to continue. The wait before the
/// next attempt is the error's own `retry_after` when it carries one, otherwise
/// `config.next_delay(attempt)`.
///
/// When `config.total_timeout_ms` is set and the upcoming wait would reach or pass the budget,
/// [`RetryResult::TimedOut`] is returned immediately instead of sleeping first: the outcome is
/// the same as after the sleep, without the pointless wait. In that case `elapsed_ms` is the
/// time actually spent, which can be below the configured budget.
///
/// # Errors
///
/// The outer `Result` is always `Ok` in this implementation; failures are reported through the
/// [`RetryResult`] variants.
pub async fn with_retry<R, F, Fut>(
    config: &RetryConfig,
    f: F,
) -> Result<RetryResult<R>, RetryableError>
where
    F: Fn() -> Fut,
    Fut: Future<Output = Result<R, RetryableError>>,
{
    /// Whole milliseconds in `d`, saturating at `u64::MAX` instead of truncating.
    fn whole_millis(d: Duration) -> u64 {
        d.as_millis().min(u128::from(u64::MAX)) as u64
    }

    let start = Instant::now();
    let budget = config.total_timeout_ms.map(Duration::from_millis);
    let mut attempt = 0u32;
    loop {
        // Saturating: `max_retries` may legitimately be `u32::MAX`.
        debug!(
            "Attempt {} of {}",
            attempt.saturating_add(1),
            config.max_retries.saturating_add(1)
        );
        let error = match f().await {
            Ok(result) => return Ok(RetryResult::Success(result)),
            Err(error) => error,
        };
        if !config.should_retry(&error, attempt) {
            info!("Retry not possible for error: {}", error);
            return Ok(RetryResult::Exhausted {
                attempts: attempt.saturating_add(1),
                last_error: error,
            });
        }
        // A server-mandated wait takes precedence over the computed backoff.
        let delay = if error.has_required_wait() {
            Duration::from_secs(error.suggested_delay().unwrap_or(0))
        } else {
            config.next_delay(attempt)
        };
        if let Some(budget) = budget {
            let elapsed = start.elapsed();
            // Sleeping is guaranteed to end at or after `elapsed + delay`, so if that already
            // reaches the budget the post-sleep check below would fire anyway.
            let would_overrun = elapsed
                .checked_add(delay)
                .map_or(true, |total| total >= budget);
            if would_overrun {
                return Ok(RetryResult::TimedOut {
                    attempts: attempt.saturating_add(1),
                    elapsed_ms: whole_millis(elapsed),
                });
            }
        }
        info!("Retrying in {:?} due to: {}", delay, error);
        tokio::time::sleep(delay).await;
        // The sleep may overshoot, so the budget is re-checked before the next attempt.
        if let Some(budget) = budget {
            let elapsed = start.elapsed();
            if elapsed >= budget {
                return Ok(RetryResult::TimedOut {
                    attempts: attempt.saturating_add(1),
                    elapsed_ms: whole_millis(elapsed),
                });
            }
        }
        attempt += 1;
    }
}
/// Convenience wrapper around `with_retry` that flattens the outcome into a plain `Result`.
///
/// Success yields the value. An exhausted loop yields its last error, an aborted loop yields
/// `RetryableError::TemporaryFailure`, and a timed-out loop yields `RetryableError::Timeout`.
/// Each failure is logged at warn level.
pub async fn retry<R, F, Fut>(config: &RetryConfig, f: F) -> Result<R, RetryableError>
where
    F: Fn() -> Fut,
    Fut: Future<Output = Result<R, RetryableError>>,
{
    let outcome = with_retry(config, f).await?;
    // Every non-success arm logs and then evaluates to the error handed back to the caller.
    let failure = match outcome {
        RetryResult::Success(value) => return Ok(value),
        RetryResult::Exhausted {
            attempts,
            last_error,
        } => {
            warn!(
                "Retry exhausted after {} attempts: {}",
                attempts, last_error
            );
            last_error
        }
        RetryResult::Aborted { attempts } => {
            warn!("Retry aborted after {} attempts", attempts);
            RetryableError::TemporaryFailure
        }
        RetryResult::TimedOut {
            attempts,
            elapsed_ms,
        } => {
            warn!(
                "Retry timed out after {} attempts ({}ms)",
                attempts, elapsed_ms
            );
            RetryableError::Timeout
        }
    };
    Err(failure)
}
/// Countdown over a fixed duration with optional tick and completion callbacks.
pub struct CountdownTimer {
/// Remaining milliseconds as of the last `update`; set to 0 by `cancel`.
remaining_ms: RwLock<u64>,
/// Instant captured when the timer was created.
start_time: Instant,
/// Full countdown length in milliseconds.
total_ms: u64,
/// Callback invoked from `update` with the whole seconds remaining.
on_tick: Option<Box<dyn Fn(u64) + Send + Sync>>,
/// Callback invoked from `update` when no time remains.
on_complete: Option<Box<dyn Fn() + Send + Sync>>,
}
use parking_lot::RwLock;
impl CountdownTimer {
    /// Starts a countdown of `duration_ms` milliseconds from now.
    ///
    /// `on_tick` receives the whole seconds remaining; `on_complete` is called when the
    /// countdown reaches zero. Both are only ever invoked from [`CountdownTimer::update`].
    pub fn new(
        duration_ms: u64,
        on_tick: Option<Box<dyn Fn(u64) + Send + Sync>>,
        on_complete: Option<Box<dyn Fn() + Send + Sync>>,
    ) -> Self {
        Self {
            remaining_ms: RwLock::new(duration_ms),
            start_time: Instant::now(),
            total_ms: duration_ms,
            on_tick,
            on_complete,
        }
    }

    /// Recomputes the remaining time, fires callbacks, and returns whole seconds remaining.
    ///
    /// Once a non-zero-length countdown has reached zero or has been cancelled, further calls
    /// return 0 without touching state or callbacks. This keeps `cancel` from being undone by
    /// the next update and keeps `on_complete` from firing repeatedly. A timer created with a
    /// zero duration keeps the earlier behaviour and fires on every call.
    pub fn update(&self) -> u64 {
        let elapsed = self
            .start_time
            .elapsed()
            .as_millis()
            .min(u128::from(u64::MAX)) as u64;
        let remaining = self.total_ms.saturating_sub(elapsed);
        {
            // The guard is dropped before any callback runs, so callbacks may read the timer.
            let mut stored = self.remaining_ms.write();
            if *stored == 0 && self.total_ms > 0 {
                return 0;
            }
            *stored = remaining;
        }
        let seconds = remaining / 1000;
        // Tick only within the first 50ms after a whole-second boundary.
        // NOTE(review): presumably assumes `update` is polled at least every ~50ms — confirm.
        if remaining % 1000 < 50 {
            if let Some(ref cb) = self.on_tick {
                cb(seconds);
            }
        }
        if remaining == 0 {
            if let Some(ref cb) = self.on_complete {
                cb();
            }
        }
        seconds
    }

    /// Remaining milliseconds as of the last `update` (or 0 after `cancel`).
    pub fn remaining_ms(&self) -> u64 {
        *self.remaining_ms.read()
    }

    /// Returns `true` when the stored remaining time is zero.
    pub fn is_complete(&self) -> bool {
        *self.remaining_ms.read() == 0
    }

    /// Fraction of the countdown that has passed, in `0.0..=1.0`; `1.0` for a zero-length timer.
    pub fn progress(&self) -> f64 {
        if self.total_ms == 0 {
            return 1.0;
        }
        let remaining = *self.remaining_ms.read();
        1.0 - (remaining as f64 / self.total_ms as f64)
    }

    /// Remaining time rendered as whole seconds, e.g. `"3s"`.
    pub fn formatted(&self) -> String {
        format!("{}s", *self.remaining_ms.read() / 1000)
    }

    /// Stops the countdown by zeroing the remaining time; `on_complete` is not invoked.
    pub fn cancel(&self) {
        *self.remaining_ms.write() = 0;
    }
}
pub fn format_retry_message(
attempt: u32,
max_attempts: u32,
delay_seconds: u64,
error: &RetryableError,
can_cancel: bool,
) -> String {
let error_str: String = match error {
RetryableError::NetworkError => "Network error".to_string(),
RetryableError::RateLimitError { .. } => "Rate limit exceeded".to_string(),
RetryableError::ServerError { code } => format!("Server error ({})", code),
RetryableError::Timeout => "Request timeout".to_string(),
RetryableError::ServiceUnavailable { .. } => "Service unavailable".to_string(),
RetryableError::TemporaryFailure => "Temporary failure".to_string(),
};
let cancel_hint = if can_cancel { " (Esc to cancel)" } else { "" };
format!(
"Retrying ({}/{}) in {}s: {}{}",
attempt, max_attempts, delay_seconds, error_str, cancel_hint
)
}
/// Turns an error into a user-facing sentence with a suggested next step.
///
/// Rate-limit and service-unavailable errors mention the wait in seconds when `retry_after`
/// is non-zero and fall back to a generic hint otherwise.
pub fn format_error_message(error: &RetryableError) -> String {
    match error {
        RetryableError::NetworkError => {
            String::from("Connection failed. Check your network and try again.")
        }
        // A zero `retry_after` means no wait was supplied, so the generic wording is used.
        RetryableError::RateLimitError { retry_after: 0 } => {
            String::from("Rate limit exceeded. Try again in a few moments.")
        }
        RetryableError::RateLimitError { retry_after } => format!(
            "Rate limit hit. Wait {}s or reduce request frequency.",
            retry_after
        ),
        RetryableError::ServerError { code } => {
            format!("Server error ({}). Try again later.", code)
        }
        RetryableError::Timeout => {
            String::from("Request timed out. Try again or use a shorter request.")
        }
        RetryableError::ServiceUnavailable { retry_after: 0 } => {
            String::from("Service temporarily unavailable. Try again later.")
        }
        RetryableError::ServiceUnavailable { retry_after } => format!(
            "Service temporarily unavailable. Retry in {}s.",
            retry_after
        ),
        RetryableError::TemporaryFailure => String::from("Temporary failure. Try again."),
    }
}
/// Deterministic pseudo-random value in `[0.0, 1.0)` derived from `seed`.
///
/// Applies two rounds of a wrapping linear-congruential step and keeps the result modulo one
/// million, so the same seed always yields the same value.
fn rand_simple(seed: u32) -> f64 {
    const MULTIPLIER: u32 = 1_103_515_245;
    const INCREMENT: u32 = 12_345;
    const BUCKETS: u32 = 1_000_000;

    let step = |state: u32| state.wrapping_mul(MULTIPLIER).wrapping_add(INCREMENT);
    let mixed = step(step(seed));
    f64::from(mixed % BUCKETS) / f64::from(BUCKETS)
}
/// Unit tests for the retry configuration, error classification, state tracking, countdown
/// timer and message formatting helpers.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_retry_config_defaults() {
        let config = RetryConfig::default();
        assert_eq!(config.max_retries, 3);
        assert_eq!(config.base_delay_ms, 1000);
        assert_eq!(config.max_delay_ms, 30000);
        assert!((config.backoff_multiplier - 2.0).abs() < 0.001);
    }

    #[test]
    fn test_retry_config_builder() {
        let config = RetryConfig::new()
            .with_max_retries(5)
            .with_base_delay(500)
            .with_max_delay(60000)
            .with_backoff_multiplier(1.5);
        assert_eq!(config.max_retries, 5);
        assert_eq!(config.base_delay_ms, 500);
        assert_eq!(config.max_delay_ms, 60000);
        assert!((config.backoff_multiplier - 1.5).abs() < 0.001);
    }

    #[test]
    fn test_should_retry_within_limit() {
        let config = RetryConfig::default();
        assert!(config.should_retry(&RetryableError::NetworkError, 0));
        assert!(config.should_retry(&RetryableError::NetworkError, 1));
        assert!(config.should_retry(&RetryableError::NetworkError, 2));
    }

    #[test]
    fn test_should_retry_at_limit() {
        let config = RetryConfig::default();
        assert!(!config.should_retry(&RetryableError::NetworkError, 3));
    }

    #[test]
    fn test_should_retry_rate_limit() {
        let config = RetryConfig::default();
        assert!(config.should_retry(&RetryableError::RateLimitError { retry_after: 0 }, 0));
        assert!(config.should_retry(&RetryableError::RateLimitError { retry_after: 30 }, 0));
    }

    #[test]
    fn test_should_retry_server_error() {
        let config = RetryConfig::default();
        assert!(config.should_retry(&RetryableError::ServerError { code: 500 }, 0));
        assert!(config.should_retry(&RetryableError::ServerError { code: 502 }, 0));
        assert!(config.should_retry(&RetryableError::ServerError { code: 503 }, 0));
    }

    #[test]
    fn test_should_retry_timeout() {
        let config = RetryConfig::default();
        assert!(config.should_retry(&RetryableError::Timeout, 0));
    }

    #[test]
    fn test_next_delay_exponential() {
        let config = RetryConfig::default();
        let delay0 = config.next_delay(0);
        // Signed arithmetic: the unsigned subtraction used before would panic on underflow
        // instead of failing the assertion if the delay ever dropped below the base.
        let first_ms = delay0.as_millis() as i64;
        assert!((first_ms - config.base_delay_ms as i64).abs() <= 100);
        let delay1 = config.next_delay(1);
        assert!(delay1 > delay0);
        let delay2 = config.next_delay(2);
        assert!(delay2 > delay1);
    }

    #[test]
    fn test_next_delay_capped_at_max() {
        let config = RetryConfig::new()
            .with_base_delay(1000)
            .with_max_delay(5000)
            .with_backoff_multiplier(10.0);
        let delay = config.next_delay(10);
        assert!(delay.as_millis() as u64 <= config.max_delay_ms);
    }

    #[test]
    fn test_retryable_error_from_status_code() {
        assert_eq!(
            RetryableError::from_status_code(429),
            Some(RetryableError::RateLimitError { retry_after: 0 })
        );
        assert_eq!(
            RetryableError::from_status_code(500),
            Some(RetryableError::ServerError { code: 500 })
        );
        assert_eq!(
            RetryableError::from_status_code(502),
            Some(RetryableError::ServerError { code: 502 })
        );
        assert_eq!(
            RetryableError::from_status_code(503),
            Some(RetryableError::ServiceUnavailable { retry_after: 0 })
        );
        assert_eq!(RetryableError::from_status_code(400), None);
        assert_eq!(RetryableError::from_status_code(404), None);
    }

    #[test]
    fn test_retryable_error_suggested_delay() {
        let rate_limit = RetryableError::RateLimitError { retry_after: 30 };
        assert_eq!(rate_limit.suggested_delay(), Some(30));
        let service = RetryableError::ServiceUnavailable { retry_after: 60 };
        assert_eq!(service.suggested_delay(), Some(60));
        let network = RetryableError::NetworkError;
        assert_eq!(network.suggested_delay(), None);
        let timeout = RetryableError::Timeout;
        assert_eq!(timeout.suggested_delay(), None);
    }

    #[test]
    fn test_retryable_error_display() {
        assert_eq!(RetryableError::NetworkError.to_string(), "Network error");
        assert_eq!(
            RetryableError::RateLimitError { retry_after: 30 }.to_string(),
            "Rate limit exceeded (retry after 30s)"
        );
        assert_eq!(
            RetryableError::ServerError { code: 500 }.to_string(),
            "Server error (500)"
        );
        assert_eq!(RetryableError::Timeout.to_string(), "Request timeout");
        assert_eq!(
            RetryableError::ServiceUnavailable { retry_after: 0 }.to_string(),
            "Service unavailable"
        );
        assert_eq!(
            RetryableError::TemporaryFailure.to_string(),
            "Temporary failure"
        );
    }

    #[test]
    fn test_retry_state_creation() {
        let state = RetryState::new(0, RetryableError::NetworkError);
        assert_eq!(state.attempt, 0);
        assert!(!state.aborted);
    }

    #[test]
    fn test_retry_state_update_next_retry() {
        let mut state = RetryState::new(0, RetryableError::NetworkError);
        state.set_next_retry(Duration::from_secs(5));
        assert!(state.next_retry > Utc::now());
    }

    #[test]
    fn test_retry_state_abort() {
        let mut state = RetryState::new(0, RetryableError::NetworkError);
        state.abort();
        assert!(state.aborted);
    }

    #[test]
    fn test_countdown_timer_basic() {
        let timer = CountdownTimer::new(5000, None, None);
        assert!(!timer.is_complete());
        assert_eq!(timer.remaining_ms(), 5000);
    }

    #[test]
    fn test_countdown_timer_progress() {
        let timer = CountdownTimer::new(10000, None, None);
        assert!((timer.progress() - 0.0).abs() < 0.01);
    }

    #[test]
    fn test_countdown_timer_formatted() {
        let timer = CountdownTimer::new(3500, None, None);
        assert_eq!(timer.formatted(), "3s");
    }

    #[test]
    fn test_countdown_timer_cancel() {
        let timer = CountdownTimer::new(5000, None, None);
        timer.cancel();
        assert!(timer.is_complete());
        assert_eq!(timer.remaining_ms(), 0);
    }

    #[test]
    fn test_format_retry_message() {
        let error = RetryableError::NetworkError;
        let msg = format_retry_message(1, 3, 5, &error, true);
        assert!(msg.contains("1/3"));
        assert!(msg.contains("5s"));
        assert!(msg.contains("Network error"));
        assert!(msg.contains("Esc to cancel"));
    }

    #[test]
    fn test_format_retry_message_no_cancel() {
        let error = RetryableError::Timeout;
        let msg = format_retry_message(2, 3, 10, &error, false);
        assert!(msg.contains("2/3"));
        assert!(msg.contains("10s"));
        assert!(msg.contains("Request timeout"));
        assert!(!msg.contains("Esc"));
    }

    #[test]
    fn test_format_error_message() {
        assert_eq!(
            format_error_message(&RetryableError::NetworkError),
            "Connection failed. Check your network and try again."
        );
        assert_eq!(
            format_error_message(&RetryableError::RateLimitError { retry_after: 30 }),
            "Rate limit hit. Wait 30s or reduce request frequency."
        );
        assert_eq!(
            format_error_message(&RetryableError::ServerError { code: 500 }),
            "Server error (500). Try again later."
        );
    }

    #[test]
    fn test_retry_result_success() {
        let result = RetryResult::Success(42);
        match result {
            RetryResult::Success(val) => assert_eq!(val, 42),
            _ => panic!("Expected Success"),
        }
    }

    #[test]
    fn test_retry_result_exhausted() {
        let result: RetryResult<()> = RetryResult::Exhausted {
            attempts: 3,
            last_error: RetryableError::NetworkError,
        };
        match result {
            RetryResult::Exhausted {
                attempts,
                last_error,
            } => {
                assert_eq!(attempts, 3);
                assert!(matches!(last_error, RetryableError::NetworkError));
            }
            _ => panic!("Expected Exhausted"),
        }
    }

    #[test]
    fn test_retry_result_aborted() {
        let result: RetryResult<()> = RetryResult::Aborted { attempts: 2 };
        match result {
            RetryResult::Aborted { attempts } => assert_eq!(attempts, 2),
            _ => panic!("Expected Aborted"),
        }
    }

    #[test]
    fn test_retry_result_timed_out() {
        let result: RetryResult<()> = RetryResult::TimedOut {
            attempts: 3,
            elapsed_ms: 30000,
        };
        match result {
            RetryResult::TimedOut {
                attempts,
                elapsed_ms,
            } => {
                assert_eq!(attempts, 3);
                assert_eq!(elapsed_ms, 30000);
            }
            _ => panic!("Expected TimedOut"),
        }
    }

    #[test]
    fn test_jitter_variance() {
        let config = RetryConfig::default();
        let delay1 = config.next_delay(1);
        let delay2 = config.next_delay(1);
        let diff = (delay1.as_millis() as i64 - delay2.as_millis() as i64).abs();
        let base_expected = (config.base_delay_ms as f64 * config.backoff_multiplier) as i64;
        assert!(diff < base_expected / 2);
        // Two calls with the same attempt can produce identical values, so also pin the
        // jitter band itself (1ms of slack for the float-to-integer truncation).
        let nominal = config.base_delay_ms as f64 * config.backoff_multiplier;
        let band = nominal * config.jitter;
        assert!((delay1.as_millis() as f64 - nominal).abs() <= band + 1.0);
    }

    #[test]
    fn test_total_timeout_check() {
        let config = RetryConfig::new()
            .with_max_retries(10)
            .with_total_timeout(5000);
        // Past the attempt limit: rejected regardless of the timeout.
        assert!(!config.should_retry(&RetryableError::NetworkError, 20));

        // Below the attempt limit, so only the time budget can reject. Jitter is disabled to
        // make the estimates exact: 1000 + 2000 = 3000ms fits the 5000ms budget, while
        // 1000 + 2000 + 4000 = 7000ms does not.
        let exact = RetryConfig::new()
            .with_max_retries(10)
            .with_jitter(0.0)
            .with_total_timeout(5000);
        assert!(exact.should_retry(&RetryableError::NetworkError, 2));
        assert!(!exact.should_retry(&RetryableError::NetworkError, 3));
    }

    #[test]
    fn test_zero_jitter() {
        let config = RetryConfig::new().with_jitter(0.0);
        let delay1 = config.next_delay(1);
        let delay2 = config.next_delay(1);
        let diff = (delay1.as_millis() as i64 - delay2.as_millis() as i64).abs();
        assert!(diff <= 1);
    }

    #[test]
    fn test_has_required_wait() {
        let rate_limit = RetryableError::RateLimitError { retry_after: 30 };
        assert!(rate_limit.has_required_wait());
        let rate_limit_no_wait = RetryableError::RateLimitError { retry_after: 0 };
        assert!(!rate_limit_no_wait.has_required_wait());
        let service = RetryableError::ServiceUnavailable { retry_after: 60 };
        assert!(service.has_required_wait());
        let network = RetryableError::NetworkError;
        assert!(!network.has_required_wait());
    }

    #[test]
    fn test_retryable_error_from_network_error() {
        let timeout_err = RetryableError::from_network_error("connection timeout");
        assert_eq!(timeout_err, RetryableError::Timeout);
        let network_err = RetryableError::from_network_error("connection refused");
        assert_eq!(network_err, RetryableError::NetworkError);
    }
}