use std::future::Future;
use std::time::{Duration, Instant};
use rand_core::{OsRng, RngCore};
use crate::error::AnvilError;
/// Tuning knobs for [`run`]: how many times to call the operation, how long
/// to back off between calls, and how much total time retries may consume.
///
/// Build one with [`RetryPolicy::default`] and the chainable setters.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct RetryPolicy {
    /// Maximum number of calls `run` makes to the operation (`0` is treated
    /// as `1`).
    pub attempts: u32,
    /// Delay before the first retry; also bounds the random jitter, which is
    /// drawn from `0..=base / 2` milliseconds.
    pub base: Duration,
    /// Multiplier applied to the delay for each further retry.
    pub factor: u32,
    /// Ceiling on the exponential part of the delay. Jitter is added after
    /// this clamp, so an actual sleep may exceed it by up to `base / 2`.
    pub cap: Duration,
    /// `run` gives up once time elapsed since its first call plus the next
    /// backoff delay would exceed this window.
    pub max_window: Duration,
    /// Optional connect timeout. Not read anywhere in this module.
    /// NOTE(review): presumably applied by the caller when opening the
    /// connection — confirm.
    pub connect_timeout: Option<Duration>,
}
impl Default for RetryPolicy {
    /// Three attempts; a 250 ms base delay that doubles on each retry, with
    /// the exponential part capped at 8 s; a 30 s overall retry window; and
    /// no connect timeout.
    fn default() -> Self {
        let base = Duration::from_millis(250);
        let cap = Duration::from_secs(8);
        let max_window = Duration::from_secs(30);
        Self {
            attempts: 3,
            base,
            factor: 2,
            cap,
            max_window,
            connect_timeout: None,
        }
    }
}
impl RetryPolicy {
    /// Returns this policy with `attempts` set to `n`: the maximum number of
    /// calls `run` makes (`0` is treated as `1`).
    #[must_use]
    pub fn attempts(self, n: u32) -> Self {
        Self { attempts: n, ..self }
    }

    /// Returns this policy with the base backoff delay set to `d`.
    #[must_use]
    pub fn base(self, d: Duration) -> Self {
        Self { base: d, ..self }
    }

    /// Returns this policy with the per-retry growth multiplier set to `f`.
    #[must_use]
    pub fn factor(self, f: u32) -> Self {
        Self { factor: f, ..self }
    }

    /// Returns this policy with the ceiling on the exponential delay set
    /// to `d`.
    #[must_use]
    pub fn cap(self, d: Duration) -> Self {
        Self { cap: d, ..self }
    }

    /// Returns this policy with the overall retry window set to `d`.
    #[must_use]
    pub fn max_window(self, d: Duration) -> Self {
        Self { max_window: d, ..self }
    }

    /// Returns this policy with `connect_timeout` set to `d` (`None` clears
    /// it).
    #[must_use]
    pub fn connect_timeout(self, d: Option<Duration>) -> Self {
        Self {
            connect_timeout: d,
            ..self
        }
    }
}
/// Outcome of [`classify`]: whether [`run`] may call the operation again
/// after a failure.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum Disposition {
    /// Transient failure; `run` retries if attempts and the time window allow.
    Retry,
    /// Permanent failure; `run` returns the error immediately.
    Fatal,
}
#[must_use]
pub fn classify(err: &AnvilError) -> Disposition {
if err.is_authentication_failed()
|| err.is_host_key_mismatch()
|| err.is_no_key_found()
|| err.is_key_encrypted()
{
return Disposition::Fatal;
}
if err.is_io() {
if let Some(kind) = err.io_kind() {
return classify_io_kind(kind);
}
}
Disposition::Fatal
}
/// Maps an I/O error kind to a retry decision.
///
/// Connection refusal, timeouts, unreachable host/network, `NotFound` (which
/// this module's tests treat as a DNS NXDOMAIN) and an unavailable address
/// are retried. Any other kind is fatal.
fn classify_io_kind(kind: std::io::ErrorKind) -> Disposition {
    use std::io::ErrorKind;
    let transient = matches!(
        kind,
        ErrorKind::ConnectionRefused
            | ErrorKind::TimedOut
            | ErrorKind::HostUnreachable
            | ErrorKind::NetworkUnreachable
            | ErrorKind::NotFound
            | ErrorKind::AddrNotAvailable
    );
    if transient {
        Disposition::Retry
    } else {
        Disposition::Fatal
    }
}
/// Record of one failed call that [`run`] went on to retry.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct RetryAttempt {
    /// 1-based index of the failed call.
    pub attempt: u32,
    /// The failure's `error_code()`.
    pub reason: String,
    /// Time elapsed since `run` started, measured when the failure was seen.
    pub elapsed: Duration,
}
pub async fn run<F, Fut, T>(
policy: &RetryPolicy,
mut op: F,
) -> Result<(T, Vec<RetryAttempt>), AnvilError>
where
F: FnMut() -> Fut,
Fut: Future<Output = Result<T, AnvilError>>,
{
let started_at = Instant::now();
let mut history: Vec<RetryAttempt> = Vec::new();
let attempts = policy.attempts.max(1);
for attempt in 1..=attempts {
if attempt > 1 {
let delay = backoff_delay(policy, attempt - 1);
if started_at.elapsed() + delay > policy.max_window {
if let Some(last) = history.last() {
tracing::warn!(
target: crate::log::CAT_RETRY,
attempt = last.attempt,
reason = %last.reason,
elapsed_ms = u64::try_from(last.elapsed.as_millis()).unwrap_or(u64::MAX),
max_window_ms = u64::try_from(policy.max_window.as_millis()).unwrap_or(u64::MAX),
"retry max_window exhausted; giving up",
);
}
return Err(history_to_terminal_error(&history));
}
tokio::time::sleep(delay).await;
}
match op().await {
Ok(value) => return Ok((value, history)),
Err(e) => {
let reason = e.error_code().to_owned();
let elapsed = started_at.elapsed();
let disposition = classify(&e);
if disposition == Disposition::Fatal || attempt == attempts {
history.push(RetryAttempt {
attempt,
reason: reason.clone(),
elapsed,
});
tracing::warn!(
target: crate::log::CAT_RETRY,
attempt,
reason = %reason,
elapsed_ms = u64::try_from(elapsed.as_millis()).unwrap_or(u64::MAX),
disposition = if disposition == Disposition::Fatal { "fatal" } else { "exhausted" },
"retry loop terminating",
);
return Err(e);
}
history.push(RetryAttempt {
attempt,
reason: reason.clone(),
elapsed,
});
tracing::warn!(
target: crate::log::CAT_RETRY,
attempt,
reason = %reason,
elapsed_ms = u64::try_from(elapsed.as_millis()).unwrap_or(u64::MAX),
"retrying after transient error",
);
}
}
}
Err(history_to_terminal_error(&history))
}
/// Computes the sleep before retry number `step` (`1` is the pause before the
/// second call).
///
/// The delay is `base * factor^(step - 1)`, clamped to `cap`, plus a uniform
/// random jitter in `0..=base / 2` milliseconds drawn from `OsRng`.
///
/// Jitter is added after the clamp, so the result may exceed `cap` by up to
/// `base / 2`. All arithmetic is in whole milliseconds: sub-millisecond parts
/// of `base` and `cap` are truncated. The arithmetic saturates, so extreme
/// policies cannot overflow. A `step` of `0` is treated as `1`.
fn backoff_delay(policy: &RetryPolicy, step: u32) -> Duration {
    let to_ms = |d: Duration| u64::try_from(d.as_millis()).unwrap_or(u64::MAX);
    let base_ms = to_ms(policy.base);
    let cap_ms = to_ms(policy.cap);
    // `saturating_sub` keeps step 0 from underflowing. Plain subtraction
    // panics in debug builds, and in release wraps to u32::MAX, which jumps
    // straight to `cap`.
    let exponent = step.saturating_sub(1);
    let growth = u64::from(policy.factor).saturating_pow(exponent);
    let core_ms = base_ms.saturating_mul(growth).min(cap_ms);
    let jitter_max_ms = base_ms / 2;
    let jitter_ms = if jitter_max_ms == 0 {
        0
    } else {
        let mut buf = [0u8; 8];
        OsRng.fill_bytes(&mut buf);
        // `jitter_max_ms <= u64::MAX / 2`, so `+ 1` cannot overflow. Modulo
        // bias is at most (jitter_max_ms + 1) / 2^64, which is negligible here.
        u64::from_le_bytes(buf) % (jitter_max_ms + 1)
    };
    Duration::from_millis(core_ms.saturating_add(jitter_ms))
}
/// Builds the error `run` returns when the retry window is exhausted.
///
/// The message names the most recent failure's error code, or `unknown` when
/// `history` is empty. The error is constructed via
/// `AnvilError::invalid_config`.
fn history_to_terminal_error(history: &[RetryAttempt]) -> AnvilError {
    let last_reason = match history.last() {
        Some(entry) => entry.reason.as_str(),
        None => "unknown",
    };
    let message = format!("retry exhausted (max_window reached); last error: {last_reason}");
    AnvilError::invalid_config(message)
}
/// Unit tests for the retry policy, error classification, the `run` driver,
/// and backoff computation.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_policy_matches_prd() {
        let p = RetryPolicy::default();
        assert_eq!(p.attempts, 3);
        assert_eq!(p.base, Duration::from_millis(250));
        assert_eq!(p.factor, 2);
        assert_eq!(p.cap, Duration::from_secs(8));
        assert_eq!(p.max_window, Duration::from_secs(30));
        assert_eq!(p.connect_timeout, None);
    }

    #[test]
    fn builder_setters_are_chainable() {
        let p = RetryPolicy::default()
            .attempts(5)
            .base(Duration::from_millis(100))
            .factor(3)
            .cap(Duration::from_secs(2))
            .max_window(Duration::from_secs(10))
            .connect_timeout(Some(Duration::from_secs(5)));
        assert_eq!(p.attempts, 5);
        assert_eq!(p.base, Duration::from_millis(100));
        assert_eq!(p.factor, 3);
        assert_eq!(p.cap, Duration::from_secs(2));
        assert_eq!(p.max_window, Duration::from_secs(10));
        assert_eq!(p.connect_timeout, Some(Duration::from_secs(5)));
    }

    #[test]
    fn auth_failure_is_fatal() {
        let err = AnvilError::authentication_failed();
        assert_eq!(classify(&err), Disposition::Fatal);
    }

    #[test]
    fn host_key_mismatch_is_fatal() {
        let err = AnvilError::host_key_mismatch("SHA256:abc");
        assert_eq!(classify(&err), Disposition::Fatal);
    }

    #[test]
    fn no_key_found_is_fatal() {
        let err = AnvilError::no_key_found();
        assert_eq!(classify(&err), Disposition::Fatal);
    }

    #[test]
    fn io_connection_refused_is_retry() {
        assert_eq!(
            classify_io_kind(std::io::ErrorKind::ConnectionRefused),
            Disposition::Retry,
        );
    }

    #[test]
    fn io_timed_out_is_retry() {
        assert_eq!(
            classify_io_kind(std::io::ErrorKind::TimedOut),
            Disposition::Retry,
        );
    }

    #[test]
    fn io_not_found_is_retry_for_dns_nxdomain() {
        assert_eq!(
            classify_io_kind(std::io::ErrorKind::NotFound),
            Disposition::Retry,
        );
    }

    #[test]
    fn io_permission_denied_is_fatal() {
        assert_eq!(
            classify_io_kind(std::io::ErrorKind::PermissionDenied),
            Disposition::Fatal,
        );
    }

    #[tokio::test]
    async fn run_succeeds_on_first_try_with_empty_history() {
        let p = RetryPolicy::default().attempts(3);
        let (value, history) = run(&p, || async { Ok::<_, AnvilError>(42_u32) })
            .await
            .expect("must succeed");
        assert_eq!(value, 42);
        assert!(history.is_empty());
    }

    #[tokio::test]
    async fn run_bails_immediately_on_fatal() {
        let p = RetryPolicy::default().attempts(5);
        let (err_count, _) = run_count_calls(&p, |_n| {
            futures::future::ready::<Result<u32, AnvilError>>(Err(
                AnvilError::authentication_failed(),
            ))
        })
        .await;
        assert_eq!(err_count, 1);
    }

    #[tokio::test]
    async fn run_retries_transient_errors_and_records_history() {
        let p = RetryPolicy::default()
            .attempts(3)
            .base(Duration::from_millis(1))
            .cap(Duration::from_millis(2))
            .max_window(Duration::from_secs(60));
        let calls = std::sync::atomic::AtomicU32::new(0);
        let result: Result<(u32, Vec<RetryAttempt>), AnvilError> = run(&p, || async {
            let n = calls.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
            if n < 2 {
                Err(AnvilError::new(crate::error::AnvilErrorKind::Io(
                    std::io::Error::from(std::io::ErrorKind::ConnectionRefused),
                )))
            } else {
                Ok::<_, AnvilError>(99)
            }
        })
        .await;
        let (value, history) = result.expect("third attempt must succeed");
        assert_eq!(value, 99);
        assert_eq!(history.len(), 2);
        assert_eq!(history[0].attempt, 1);
        assert_eq!(history[1].attempt, 2);
        for entry in &history {
            assert_eq!(
                entry.reason, "GENERAL_ERROR",
                "expected GENERAL_ERROR (io variant), got: {}",
                entry.reason,
            );
        }
    }

    #[tokio::test]
    async fn run_attempts_caps_after_exhausting_count() {
        let p = RetryPolicy::default()
            .attempts(2)
            .base(Duration::from_millis(1))
            .cap(Duration::from_millis(1))
            .max_window(Duration::from_secs(60));
        let result: Result<(u32, Vec<RetryAttempt>), AnvilError> = run(&p, || async {
            Err(AnvilError::new(crate::error::AnvilErrorKind::Io(
                std::io::Error::from(std::io::ErrorKind::TimedOut),
            )))
        })
        .await;
        let err = result.expect_err("must exhaust");
        assert!(err.is_io());
    }

    /// If sleeping before the next attempt would overrun `max_window`, `run`
    /// must give up without calling `op` again, even with attempts left.
    #[tokio::test]
    async fn run_stops_when_next_delay_would_exceed_max_window() {
        // base 1 ms => jitter range is 0, so the first backoff is exactly
        // 1 ms, which already exceeds a zero-length window.
        let p = RetryPolicy::default()
            .attempts(5)
            .base(Duration::from_millis(1))
            .cap(Duration::from_millis(1))
            .max_window(Duration::ZERO);
        let (calls, result) = run_count_calls(&p, |_n| {
            futures::future::ready::<Result<u32, AnvilError>>(Err(AnvilError::new(
                crate::error::AnvilErrorKind::Io(std::io::Error::from(
                    std::io::ErrorKind::ConnectionRefused,
                )),
            )))
        })
        .await;
        assert_eq!(calls, 1, "no retry may start once max_window is exhausted");
        assert!(result.is_err(), "window exhaustion must surface as an error");
    }

    /// Runs `op` under `policy` via `run`, passing each call its 0-based
    /// index. Returns how many times `op` was called plus the final result
    /// (history dropped).
    async fn run_count_calls<F, Fut>(
        policy: &RetryPolicy,
        mut op: F,
    ) -> (u32, Result<u32, AnvilError>)
    where
        F: FnMut(u32) -> Fut,
        Fut: Future<Output = Result<u32, AnvilError>>,
    {
        let calls = std::sync::atomic::AtomicU32::new(0);
        let result: Result<(u32, Vec<RetryAttempt>), AnvilError> = run(policy, || {
            let n = calls.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
            op(n)
        })
        .await;
        let count = calls.load(std::sync::atomic::Ordering::SeqCst);
        let final_result = result.map(|(v, _)| v);
        (count, final_result)
    }

    #[test]
    fn backoff_delay_grows_exponentially_until_cap() {
        let p = RetryPolicy::default()
            .base(Duration::from_millis(10))
            .factor(2)
            .cap(Duration::from_millis(40));
        let d1 = backoff_delay(&p, 1);
        let d2 = backoff_delay(&p, 2);
        let d3 = backoff_delay(&p, 3);
        let d4 = backoff_delay(&p, 4);
        assert!(d1.as_millis() >= 10 && d1.as_millis() <= 15);
        assert!(d2.as_millis() >= 20 && d2.as_millis() <= 25);
        assert!(d3.as_millis() >= 40 && d3.as_millis() <= 45);
        assert!(d4.as_millis() >= 40 && d4.as_millis() <= 45);
    }

    #[test]
    fn backoff_jitter_stays_within_documented_window() {
        let p = RetryPolicy::default()
            .base(Duration::from_millis(10))
            .factor(1)
            .cap(Duration::from_millis(10));
        for _ in 0..1000 {
            let d = backoff_delay(&p, 1);
            let ms = d.as_millis();
            assert!(
                (10..=15).contains(&ms),
                "delay {ms}ms outside [10,15]ms jitter window",
            );
        }
    }
}