use crate::config::RetryConfig;
use crate::error::Error;
use rand::Rng;
use std::future::Future;
use std::time::Duration;
/// Decides whether a failed operation is worth attempting again.
///
/// [`download_with_retry`] consults this to separate transient failures
/// (retried with backoff) from permanent ones (returned immediately).
pub trait IsRetryable {
    /// Returns `true` if the failure is transient and a retry might succeed.
    fn is_retryable(&self) -> bool;
}
impl IsRetryable for Error {
fn is_retryable(&self) -> bool {
match self {
Error::Network(e) => {
e.is_timeout() || e.is_connect()
}
Error::Io(e) => matches!(
e.kind(),
std::io::ErrorKind::TimedOut
| std::io::ErrorKind::ConnectionRefused
| std::io::ErrorKind::ConnectionReset
| std::io::ErrorKind::ConnectionAborted
| std::io::ErrorKind::NotConnected
| std::io::ErrorKind::BrokenPipe
| std::io::ErrorKind::Interrupted
),
Error::Nntp(msg) => {
msg.contains("timeout")
|| msg.contains("busy")
|| msg.contains("connection")
|| msg.contains("temporary")
|| msg.contains("503") || msg.contains("400") }
Error::Download(_) => false,
Error::PostProcess(_) => false,
Error::Database(_) | Error::Sqlx(_) => false,
Error::Config { .. } => false,
Error::InvalidNzb(_) => false,
Error::NotFound(_) => false,
Error::ShuttingDown => false,
Error::Serialization(_) => false,
Error::ApiServerError(_) => false,
Error::FolderWatch(_) => false,
Error::Duplicate(_) => false,
Error::InsufficientSpace { .. } => false,
Error::DiskSpaceCheckFailed(_) => false,
Error::ExternalTool(msg) => {
msg.contains("timeout") || msg.contains("busy") || msg.contains("temporary")
}
Error::NotSupported(_) => false,
Error::Other(_) => false,
}
}
}
/// Runs `operation`, retrying it with exponential backoff while it fails with
/// a retryable error.
///
/// The operation is called once and then retried up to `config.max_attempts`
/// more times, so it runs at most `config.max_attempts + 1` times in total.
/// A retry only happens when the returned error's
/// [`IsRetryable::is_retryable`] is `true`. Before each retry the task sleeps
/// for the current delay, stretched by `add_jitter` when `config.jitter` is
/// set. The delay starts at `config.initial_delay`, is multiplied by
/// `config.backoff_multiplier` after every retry, and is capped at
/// `config.max_delay` (jitter is applied on top of the capped value).
///
/// # Errors
///
/// Returns the operation's error unchanged when it is not retryable, or the
/// last error once the retry budget is spent.
pub async fn download_with_retry<F, Fut, T, E>(
    config: &RetryConfig,
    mut operation: F,
) -> Result<T, E>
where
    F: FnMut() -> Fut,
    Fut: Future<Output = Result<T, E>>,
    E: IsRetryable + std::fmt::Display,
{
    let mut attempt = 0;
    let mut delay = config.initial_delay;
    loop {
        match operation().await {
            Ok(result) => {
                if attempt > 0 {
                    tracing::info!(attempts = attempt + 1, "Operation succeeded after retry");
                }
                return Ok(result);
            }
            Err(e) if e.is_retryable() && attempt < config.max_attempts => {
                attempt += 1;
                // Decide the real wait before logging so `delay_ms` reports the
                // duration actually slept, not the pre-jitter base delay.
                let wait = if config.jitter {
                    add_jitter(delay)
                } else {
                    delay
                };
                tracing::warn!(
                    error = %e,
                    attempt = attempt,
                    max_attempts = config.max_attempts,
                    delay_ms = wait.as_millis(),
                    "Operation failed, retrying"
                );
                tokio::time::sleep(wait).await;
                delay = next_backoff_delay(delay, config.backoff_multiplier, config.max_delay);
            }
            Err(e) => {
                if e.is_retryable() {
                    tracing::error!(
                        error = %e,
                        attempts = attempt + 1,
                        "Operation failed after all retry attempts exhausted"
                    );
                } else {
                    tracing::error!(
                        error = %e,
                        "Operation failed with non-retryable error"
                    );
                }
                return Err(e);
            }
        }
    }
}

/// Returns the delay to use after `delay`: `delay * multiplier`, capped at
/// `max_delay`.
///
/// `Duration::from_secs_f64` panics when its input is negative, NaN, or too
/// large for a `Duration`. A misconfigured multiplier (negative, NaN, or huge)
/// produces exactly such values, so those products fall back to `max_delay`
/// instead of crashing the retry loop.
fn next_backoff_delay(delay: Duration, multiplier: f64, max_delay: Duration) -> Duration {
    Duration::try_from_secs_f64(delay.as_secs_f64() * multiplier)
        .map_or(max_delay, |scaled| scaled.min(max_delay))
}
/// Stretches `delay` by a random factor drawn uniformly from `[1.0, 2.0]`.
///
/// Jitter only ever lengthens the wait: the result is roughly within
/// `[delay, 2 * delay]` (up to float rounding), and a zero delay stays zero.
/// If the stretched value does not fit in a `Duration` (possible only for
/// enormous delays), it saturates at `Duration::MAX` instead of panicking as
/// `Duration::from_secs_f64` would.
fn add_jitter(delay: Duration) -> Duration {
    let jitter_factor: f64 = rand::thread_rng().gen_range(0.0..=1.0);
    let jittered_secs = delay.as_secs_f64() * (1.0 + jitter_factor);
    // Inputs are non-negative and finite, so overflow is the only failure mode.
    Duration::try_from_secs_f64(jittered_secs).unwrap_or(Duration::MAX)
}
// Tests unwrap/expect freely: a panic is exactly the failure signal wanted here.
#[allow(clippy::unwrap_used, clippy::expect_used)]
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Minimal error type whose retryability is fixed per variant.
    #[derive(Debug)]
    enum TestError {
        Transient,
        Permanent,
    }

    impl std::fmt::Display for TestError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            match self {
                TestError::Transient => write!(f, "transient error"),
                TestError::Permanent => write!(f, "permanent error"),
            }
        }
    }

    impl IsRetryable for TestError {
        fn is_retryable(&self) -> bool {
            matches!(self, TestError::Transient)
        }
    }

    #[tokio::test]
    async fn test_success_no_retry() {
        let config = RetryConfig::default();
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();
        let result = download_with_retry(&config, || {
            let counter = counter_clone.clone();
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Ok::<_, TestError>(42)
            }
        })
        .await;
        assert_eq!(result.unwrap(), 42);
        assert_eq!(counter.load(Ordering::SeqCst), 1, "should only call once");
    }

    #[tokio::test]
    async fn test_retry_transient_then_succeed() {
        let config = RetryConfig {
            max_attempts: 3,
            initial_delay: Duration::from_millis(10),
            max_delay: Duration::from_secs(1),
            backoff_multiplier: 2.0,
            jitter: false,
        };
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();
        let result = download_with_retry(&config, || {
            let counter = counter_clone.clone();
            async move {
                let count = counter.fetch_add(1, Ordering::SeqCst);
                if count < 2 {
                    Err(TestError::Transient)
                } else {
                    Ok(42)
                }
            }
        })
        .await;
        assert_eq!(result.unwrap(), 42);
        assert_eq!(
            counter.load(Ordering::SeqCst),
            3,
            "should retry twice before success"
        );
    }

    #[tokio::test]
    async fn test_retry_exhausted() {
        let config = RetryConfig {
            max_attempts: 2,
            initial_delay: Duration::from_millis(10),
            max_delay: Duration::from_secs(1),
            backoff_multiplier: 2.0,
            jitter: false,
        };
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();
        let result = download_with_retry(&config, || {
            let counter = counter_clone.clone();
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Err::<i32, _>(TestError::Transient)
            }
        })
        .await;
        assert!(result.is_err());
        assert_eq!(
            counter.load(Ordering::SeqCst),
            3,
            "should try initial + 2 retries"
        );
    }

    #[tokio::test]
    async fn test_permanent_error_no_retry() {
        let config = RetryConfig::default();
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();
        let result = download_with_retry(&config, || {
            let counter = counter_clone.clone();
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Err::<i32, _>(TestError::Permanent)
            }
        })
        .await;
        assert!(result.is_err());
        assert_eq!(
            counter.load(Ordering::SeqCst),
            1,
            "should not retry permanent error"
        );
    }

    #[tokio::test]
    async fn test_exponential_backoff() {
        // Expected waits: 10ms + 20ms + 40ms = 70ms.
        let config = RetryConfig {
            max_attempts: 3,
            initial_delay: Duration::from_millis(10),
            max_delay: Duration::from_secs(1),
            backoff_multiplier: 2.0,
            jitter: false,
        };
        let start = std::time::Instant::now();
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();
        let _result = download_with_retry(&config, || {
            let counter = counter_clone.clone();
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Err::<i32, _>(TestError::Transient)
            }
        })
        .await;
        let elapsed = start.elapsed();
        assert_eq!(
            counter.load(Ordering::SeqCst),
            4,
            "should try initial + 3 retries"
        );
        assert!(
            elapsed >= Duration::from_millis(70),
            "should wait at least 70ms, waited {:?}",
            elapsed
        );
        assert!(
            elapsed < Duration::from_secs(2),
            "should not wait too long, waited {:?}",
            elapsed
        );
    }

    // Pure function of the RNG; no async runtime needed.
    #[test]
    fn test_jitter_adds_randomness() {
        let delay = Duration::from_millis(100);
        let jittered1 = add_jitter(delay);
        let jittered2 = add_jitter(delay);
        assert!(jittered1 >= delay);
        assert!(jittered1 <= delay * 2);
        assert!(jittered2 >= delay);
        assert!(jittered2 <= delay * 2);
    }

    #[tokio::test]
    async fn test_max_delay_cap() {
        // Millisecond-scale values keep the test fast while still separating the
        // capped schedule (10 + 30 * 4 = 130ms) from the uncapped one
        // (10 + 100 + 1000 + 10000 + 100000 ms, i.e. well over a second).
        let config = RetryConfig {
            max_attempts: 5,
            initial_delay: Duration::from_millis(10),
            max_delay: Duration::from_millis(30),
            backoff_multiplier: 10.0,
            jitter: false,
        };
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();
        let start = std::time::Instant::now();
        let _result = download_with_retry(&config, || {
            let counter = counter_clone.clone();
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Err::<i32, _>(TestError::Transient)
            }
        })
        .await;
        let elapsed = start.elapsed();
        assert_eq!(
            counter.load(Ordering::SeqCst),
            6,
            "should try initial + 5 retries"
        );
        assert!(
            elapsed >= Duration::from_millis(130),
            "should wait at least 130ms with max_delay cap, waited {:?}",
            elapsed
        );
        assert!(
            elapsed < Duration::from_secs(1),
            "delays should be capped at max_delay, waited {:?}",
            elapsed
        );
    }

    #[tokio::test]
    async fn test_individual_retry_delays_never_exceed_max_delay() {
        let config = RetryConfig {
            max_attempts: 4,
            initial_delay: Duration::from_millis(50),
            max_delay: Duration::from_millis(200),
            backoff_multiplier: 10.0,
            jitter: false,
        };
        let timestamps = Arc::new(tokio::sync::Mutex::new(Vec::new()));
        let ts_clone = timestamps.clone();
        let _result = download_with_retry(&config, || {
            let ts = ts_clone.clone();
            async move {
                ts.lock().await.push(std::time::Instant::now());
                Err::<i32, _>(TestError::Transient)
            }
        })
        .await;
        let ts = timestamps.lock().await;
        assert_eq!(ts.len(), 5, "should have initial + 4 retries = 5 calls");
        // 200ms cap plus scheduling tolerance.
        let max_allowed = Duration::from_millis(350);
        for i in 1..ts.len() {
            let gap = ts[i].duration_since(ts[i - 1]);
            assert!(
                gap <= max_allowed,
                "delay between attempt {} and {} was {:?}, which exceeds max_delay (200ms) + tolerance ({:?})",
                i,
                i + 1,
                gap,
                max_allowed
            );
        }
        let gap_3_to_4 = ts[3].duration_since(ts[2]);
        let gap_4_to_5 = ts[4].duration_since(ts[3]);
        assert!(
            gap_3_to_4 >= Duration::from_millis(150),
            "third delay should be ~200ms (capped), was {:?}",
            gap_3_to_4
        );
        assert!(
            gap_4_to_5 >= Duration::from_millis(150),
            "fourth delay should be ~200ms (capped), was {:?}",
            gap_4_to_5
        );
    }

    #[test]
    fn test_error_is_retryable_io() {
        let timeout_err = Error::Io(std::io::Error::new(std::io::ErrorKind::TimedOut, "timeout"));
        assert!(timeout_err.is_retryable());
        let connection_refused = Error::Io(std::io::Error::new(
            std::io::ErrorKind::ConnectionRefused,
            "refused",
        ));
        assert!(connection_refused.is_retryable());
        let not_found = Error::Io(std::io::Error::new(
            std::io::ErrorKind::NotFound,
            "not found",
        ));
        assert!(!not_found.is_retryable());
    }

    #[test]
    fn test_error_is_retryable_nntp() {
        let timeout = Error::Nntp("connection timeout".to_string());
        assert!(timeout.is_retryable());
        let busy = Error::Nntp("server busy (400)".to_string());
        assert!(busy.is_retryable());
        let auth_failed = Error::Nntp("authentication failed".to_string());
        assert!(!auth_failed.is_retryable());
    }

    #[test]
    fn test_error_is_retryable_permanent() {
        use crate::error::{DatabaseError, DownloadError};
        assert!(
            !Error::Config {
                message: "bad config".to_string(),
                key: None,
            }
            .is_retryable()
        );
        assert!(
            !Error::Database(DatabaseError::QueryFailed("db error".to_string())).is_retryable()
        );
        assert!(!Error::InvalidNzb("bad nzb".to_string()).is_retryable());
        assert!(!Error::NotFound("not found".to_string()).is_retryable());
        assert!(!Error::Download(DownloadError::NotFound { id: 123 }).is_retryable());
    }

    #[test]
    fn add_jitter_stays_within_bounds_over_many_iterations() {
        let delay = Duration::from_millis(50);
        for i in 0..200 {
            let jittered = add_jitter(delay);
            assert!(
                jittered >= delay,
                "iteration {i}: jittered {jittered:?} < base delay {delay:?}"
            );
            assert!(
                jittered <= delay * 2,
                "iteration {i}: jittered {jittered:?} > 2x base delay {:?}",
                delay * 2
            );
        }
    }

    #[test]
    fn add_jitter_on_zero_delay_returns_zero() {
        let jittered = add_jitter(Duration::ZERO);
        assert_eq!(
            jittered,
            Duration::ZERO,
            "jitter on zero delay should remain zero"
        );
    }

    #[tokio::test]
    async fn zero_max_attempts_fails_on_first_transient_error() {
        let config = RetryConfig {
            max_attempts: 0,
            initial_delay: Duration::from_millis(1),
            max_delay: Duration::from_secs(1),
            backoff_multiplier: 2.0,
            jitter: false,
        };
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();
        let result = download_with_retry(&config, || {
            let counter = counter_clone.clone();
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Err::<i32, _>(TestError::Transient)
            }
        })
        .await;
        assert!(
            matches!(result, Err(TestError::Transient)),
            "should return the transient error without retrying"
        );
        assert_eq!(
            counter.load(Ordering::SeqCst),
            1,
            "should call the operation exactly once (no retries when max_attempts=0)"
        );
    }

    #[tokio::test]
    async fn backoff_delays_increase_exponentially() {
        let config = RetryConfig {
            max_attempts: 3,
            initial_delay: Duration::from_millis(50),
            max_delay: Duration::from_secs(10),
            backoff_multiplier: 2.0,
            jitter: false,
        };
        let timestamps = Arc::new(tokio::sync::Mutex::new(Vec::new()));
        let ts_clone = timestamps.clone();
        let _result = download_with_retry(&config, || {
            let ts = ts_clone.clone();
            async move {
                ts.lock().await.push(std::time::Instant::now());
                Err::<i32, _>(TestError::Transient)
            }
        })
        .await;
        let ts = timestamps.lock().await;
        assert_eq!(ts.len(), 4, "initial + 3 retries = 4 calls");
        let gap1 = ts[1].duration_since(ts[0]);
        let gap2 = ts[2].duration_since(ts[1]);
        let gap3 = ts[3].duration_since(ts[2]);
        assert!(
            gap1 >= Duration::from_millis(40),
            "first delay should be ~50ms, was {:?}",
            gap1
        );
        assert!(
            gap2 >= Duration::from_millis(80),
            "second delay should be ~100ms, was {:?}",
            gap2
        );
        assert!(
            gap3 >= Duration::from_millis(160),
            "third delay should be ~200ms, was {:?}",
            gap3
        );
        let ratio = gap2.as_secs_f64() / gap1.as_secs_f64();
        assert!(
            (1.5..=2.5).contains(&ratio),
            "gap2/gap1 ratio should be ~2.0, was {ratio:.2}"
        );
    }

    #[tokio::test]
    async fn jitter_enabled_produces_delay_within_expected_range() {
        let config = RetryConfig {
            max_attempts: 1,
            initial_delay: Duration::from_millis(50),
            max_delay: Duration::from_secs(10),
            backoff_multiplier: 2.0,
            jitter: true,
        };
        let start = std::time::Instant::now();
        let _result =
            download_with_retry(&config, || async { Err::<i32, _>(TestError::Transient) }).await;
        let elapsed = start.elapsed();
        assert!(
            elapsed >= Duration::from_millis(40),
            "should wait at least the base delay, waited {:?}",
            elapsed
        );
        assert!(
            elapsed < Duration::from_secs(2),
            "should not wait longer than expected, waited {:?}",
            elapsed
        );
    }

    #[test]
    fn io_connection_reset_is_retryable() {
        let err = Error::Io(std::io::Error::new(
            std::io::ErrorKind::ConnectionReset,
            "reset by peer",
        ));
        assert!(
            err.is_retryable(),
            "ConnectionReset should be retryable for transient network glitches"
        );
    }

    #[test]
    fn io_connection_aborted_is_retryable() {
        let err = Error::Io(std::io::Error::new(
            std::io::ErrorKind::ConnectionAborted,
            "aborted",
        ));
        assert!(err.is_retryable());
    }

    #[test]
    fn io_not_connected_is_retryable() {
        let err = Error::Io(std::io::Error::new(
            std::io::ErrorKind::NotConnected,
            "not connected",
        ));
        assert!(err.is_retryable());
    }

    #[test]
    fn io_broken_pipe_is_retryable() {
        let err = Error::Io(std::io::Error::new(
            std::io::ErrorKind::BrokenPipe,
            "broken pipe",
        ));
        assert!(err.is_retryable());
    }

    #[test]
    fn io_interrupted_is_retryable() {
        let err = Error::Io(std::io::Error::new(
            std::io::ErrorKind::Interrupted,
            "interrupted",
        ));
        assert!(err.is_retryable());
    }

    #[test]
    fn io_permission_denied_is_not_retryable() {
        let err = Error::Io(std::io::Error::new(
            std::io::ErrorKind::PermissionDenied,
            "denied",
        ));
        assert!(
            !err.is_retryable(),
            "PermissionDenied is permanent, not transient"
        );
    }

    #[test]
    fn nntp_503_service_unavailable_is_retryable() {
        let err = Error::Nntp("503 service temporarily unavailable".to_string());
        assert!(err.is_retryable());
    }

    #[test]
    fn nntp_400_server_busy_is_retryable() {
        let err = Error::Nntp("400 server too busy".to_string());
        assert!(err.is_retryable());
    }

    #[test]
    fn nntp_temporary_failure_is_retryable() {
        let err = Error::Nntp("temporary failure, please retry".to_string());
        assert!(err.is_retryable());
    }

    #[test]
    fn nntp_unknown_error_without_keywords_is_not_retryable() {
        let err = Error::Nntp("430 no such article".to_string());
        assert!(
            !err.is_retryable(),
            "NNTP error without transient keywords should not be retried"
        );
    }

    #[test]
    fn external_tool_timeout_is_retryable() {
        let err = Error::ExternalTool("timeout waiting for par2".to_string());
        assert!(err.is_retryable());
    }

    #[test]
    fn external_tool_busy_is_retryable() {
        let err = Error::ExternalTool("process busy, try again".to_string());
        assert!(err.is_retryable());
    }

    #[test]
    fn external_tool_temporary_is_retryable() {
        let err = Error::ExternalTool("temporary failure in unrar".to_string());
        assert!(err.is_retryable());
    }

    #[test]
    fn external_tool_not_found_is_not_retryable() {
        let err = Error::ExternalTool("par2 not found in PATH".to_string());
        assert!(
            !err.is_retryable(),
            "missing binary is permanent, not transient"
        );
    }

    #[test]
    fn post_process_error_is_never_retryable() {
        use crate::error::PostProcessError;
        let err = Error::PostProcess(PostProcessError::ExtractionFailed {
            archive: std::path::PathBuf::from("test.rar"),
            reason: "CRC error".to_string(),
        });
        assert!(!err.is_retryable(), "post-processing errors are permanent");
    }

    #[test]
    fn shutting_down_is_not_retryable() {
        assert!(
            !Error::ShuttingDown.is_retryable(),
            "shutdown should not trigger retries"
        );
    }

    #[test]
    fn serialization_error_is_not_retryable() {
        let err = Error::Serialization(serde_json::from_str::<String>("bad json").unwrap_err());
        assert!(!err.is_retryable());
    }

    #[test]
    fn api_server_error_is_not_retryable() {
        let err = Error::ApiServerError("bind failed".to_string());
        assert!(!err.is_retryable());
    }

    #[test]
    fn folder_watch_error_is_not_retryable() {
        let err = Error::FolderWatch("inotify error".to_string());
        assert!(!err.is_retryable());
    }

    #[test]
    fn duplicate_error_is_not_retryable() {
        let err = Error::Duplicate("already exists".to_string());
        assert!(!err.is_retryable());
    }

    #[test]
    fn insufficient_space_is_not_retryable() {
        let err = Error::InsufficientSpace {
            required: 1_000_000,
            available: 500,
        };
        assert!(
            !err.is_retryable(),
            "disk space issues require user action, not retries"
        );
    }

    #[test]
    fn disk_space_check_failed_is_not_retryable() {
        let err = Error::DiskSpaceCheckFailed("statvfs failed".to_string());
        assert!(!err.is_retryable());
    }

    #[test]
    fn not_supported_is_not_retryable() {
        let err = Error::NotSupported("feature unavailable".to_string());
        assert!(!err.is_retryable());
    }

    #[test]
    fn other_error_is_not_retryable() {
        let err = Error::Other("unknown problem".to_string());
        assert!(!err.is_retryable());
    }
}