use std::time::Duration;
use crate::error::{MktError, Result};
/// Upper bound, in seconds, applied to a server-supplied retry hint before it
/// is turned into a sleep duration.
const MAX_HINT_SECS: u64 = 120;
/// Reads the `Retry-After` header as a whole number of seconds.
///
/// Returns `None` when the header is absent, when its value cannot be viewed
/// as a string, or when the trimmed value does not parse as a `u64`. A value
/// in HTTP-date form therefore also yields `None`.
#[must_use]
pub fn retry_after_secs(headers: &reqwest::header::HeaderMap) -> Option<u64> {
    let raw = headers.get(reqwest::header::RETRY_AFTER)?;
    let text = raw.to_str().ok()?;
    text.trim().parse::<u64>().ok()
}
/// Kind of operation being retried; it decides which errors are considered
/// safe to retry.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OpKind {
/// Retried on every error for which `MktError::is_transient` returns `true`.
Read,
/// Retried only on `MktError::RateLimited` or on an `MktError::Http` error
/// that reports a connect failure.
Write,
}
/// Attempt budget and backoff bounds used by the retry loop.
#[derive(Debug, Clone)]
pub struct RetryPolicy {
/// Total number of attempts, counting the first call.
pub max_attempts: u32,
/// Backoff before the first retry; doubled for each further retry.
pub min_delay: Duration,
/// Cap on the exponential backoff, applied before jitter is added.
/// Server-supplied hints are not limited by this field.
pub max_delay: Duration,
}
impl RetryPolicy {
    /// The default policy: four attempts in total, backoff starting at one
    /// second and capped at thirty seconds.
    #[must_use]
    pub const fn standard() -> Self {
        const FIRST_BACKOFF: Duration = Duration::from_secs(1);
        const BACKOFF_CAP: Duration = Duration::from_secs(30);
        Self {
            max_attempts: 4,
            min_delay: FIRST_BACKOFF,
            max_delay: BACKOFF_CAP,
        }
    }

    /// A policy that never retries: one attempt and zero delays.
    #[must_use]
    pub const fn none() -> Self {
        let no_wait = Duration::ZERO;
        Self {
            max_attempts: 1,
            min_delay: no_wait,
            max_delay: no_wait,
        }
    }
}
/// The default policy is [`RetryPolicy::standard`].
impl Default for RetryPolicy {
fn default() -> Self {
Self::standard()
}
}
fn is_retryable(kind: OpKind, error: &MktError) -> bool {
match kind {
OpKind::Read => error.is_transient(),
OpKind::Write => match error {
MktError::RateLimited { .. } => true,
MktError::Http(e) => e.is_connect(),
_ => false,
},
}
}
/// Extracts the server-provided wait time carried by `error`, if any.
///
/// `RateLimited` always carries a hint; `ApiError` carries one only when its
/// `retry_after` field is `Some`. All other errors yield `None`. The hint is
/// clamped to `MAX_HINT_SECS` seconds before being returned.
fn retry_hint(error: &MktError) -> Option<Duration> {
    let hinted_secs = match error {
        MktError::RateLimited {
            retry_after_secs, ..
        } => *retry_after_secs,
        MktError::ApiError {
            retry_after: Some(secs),
            ..
        } => *secs,
        _ => return None,
    };
    let clamped_secs = std::cmp::min(hinted_secs, MAX_HINT_SECS);
    Some(Duration::from_secs(clamped_secs))
}
/// Lengthens `delay` by a pseudo-random 0–20 % so that callers failing at the
/// same moment do not all retry in lockstep.
///
/// The result is always in `delay ..= 1.2 * delay`; jitter is only ever added.
/// The sum saturates at `Duration::MAX` instead of panicking on overflow.
fn with_jitter(delay: Duration) -> Duration {
    // The sub-second nanoseconds of the wall clock act as a cheap entropy
    // source. A clock set before the Unix epoch falls back to 0 (no jitter).
    let nanos = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map_or(0, |d| d.subsec_nanos());
    // `nanos % 21` lies in 0..=20, i.e. a whole-percent factor of 0.00..=0.20.
    let jitter = delay.mul_f64(f64::from(nanos % 21) / 100.0);
    // `Duration + Duration` panics on overflow; saturate so that an extreme
    // delay can never crash the retry loop.
    delay.saturating_add(jitter)
}
pub async fn retry<T, F, Fut>(policy: &RetryPolicy, kind: OpKind, mut op: F) -> Result<T>
where
F: FnMut() -> Fut,
Fut: Future<Output = Result<T>>,
{
let mut attempt: u32 = 0;
loop {
attempt += 1;
let error = match op().await {
Ok(value) => return Ok(value),
Err(error) => error,
};
if attempt >= policy.max_attempts || !is_retryable(kind, &error) {
return Err(error);
}
let backoff = policy
.min_delay
.saturating_mul(2_u32.saturating_pow(attempt - 1))
.min(policy.max_delay);
let delay = retry_hint(&error).unwrap_or_else(|| with_jitter(backoff));
tracing::warn!(
attempt,
delay_ms = u64::try_from(delay.as_millis()).unwrap_or(u64::MAX),
error = %error,
"transient provider error; retrying"
);
tokio::time::sleep(delay).await;
}
}
// Unit tests for the retry loop. Every async test uses `start_paused = true`
// and measures elapsed time with `tokio::time::Instant`, not the wall clock.
#[cfg(test)]
mod tests {
// `unwrap`/`unwrap_err` are used directly in the assertions below.
#![allow(clippy::unwrap_used)]
use std::sync::atomic::{AtomicU32, Ordering};
use super::*;
// An `ApiError` with status 503 and no retry hint. The tests below rely on
// `is_transient()` being true for it.
fn transient_error() -> MktError {
MktError::ApiError {
provider: "test".into(),
status: 503,
message: "unavailable".into(),
retry_after: None,
}
}
// A `RateLimited` error carrying a hint of `secs` seconds.
fn rate_limited(secs: u64) -> MktError {
MktError::RateLimited {
provider: "test".into(),
retry_after_secs: secs,
}
}
// A `ValidationError`; the tests expect it never to be retried.
fn validation_error() -> MktError {
MktError::ValidationError {
field: "f".into(),
message: "bad".into(),
}
}
// Drives `retry` with an operation that fails its first `failures` calls with
// `error_fn()` and then succeeds with the 1-based call number. Returns the
// outcome together with the total number of calls made.
// NOTE(review): the lint is presumably silenced because `error_fn` has no
// `Send`/`Sync` bound, which makes the returned future non-`Send` — confirm.
#[allow(clippy::future_not_send)] async fn run_counting(
policy: &RetryPolicy,
kind: OpKind,
failures: u32,
error_fn: impl Fn() -> MktError,
) -> (Result<u32>, u32) {
let calls = AtomicU32::new(0);
let result = retry(policy, kind, || {
// `n` is the 1-based number of this call.
let n = calls.fetch_add(1, Ordering::SeqCst) + 1;
// Fail while `n` is within the failure budget, succeed afterwards.
let error = (n <= failures).then(&error_fn);
async move { error.map_or_else(|| Ok(n), Err) }
})
.await;
(result, calls.load(Ordering::SeqCst))
}
// Two transient failures followed by a success: the third call's value is returned.
#[tokio::test(start_paused = true)]
async fn read_retries_transient_until_success() {
let (result, calls) =
run_counting(&RetryPolicy::standard(), OpKind::Read, 2, transient_error).await;
assert_eq!(result.unwrap(), 3);
assert_eq!(calls, 3);
}
// With more failures than the budget allows, the error comes back after 4 calls.
#[tokio::test(start_paused = true)]
async fn exhausted_attempts_return_last_error() {
let (result, calls) =
run_counting(&RetryPolicy::standard(), OpKind::Read, 99, transient_error).await;
assert!(result.unwrap_err().is_transient());
assert_eq!(calls, 4, "standard policy makes 4 attempts");
}
// A validation error is returned after the first call, even for a read.
#[tokio::test(start_paused = true)]
async fn non_transient_errors_never_retry() {
let (result, calls) =
run_counting(&RetryPolicy::standard(), OpKind::Read, 99, validation_error).await;
assert!(matches!(
result.unwrap_err(),
MktError::ValidationError { .. }
));
assert_eq!(calls, 1);
}
// `RetryPolicy::none()` stops after one call even for a transient error.
#[tokio::test(start_paused = true)]
async fn policy_none_makes_a_single_attempt() {
let (result, calls) =
run_counting(&RetryPolicy::none(), OpKind::Read, 99, transient_error).await;
assert!(result.is_err());
assert_eq!(calls, 1);
}
// A 503 on a write is not retried: the write may already have been executed.
#[tokio::test(start_paused = true)]
async fn writes_do_not_retry_server_errors() {
let (result, calls) =
run_counting(&RetryPolicy::standard(), OpKind::Write, 99, transient_error).await;
assert!(result.is_err());
assert_eq!(calls, 1, "a 503 may have executed the write");
}
// A rate-limited write is retried and succeeds on the second call.
#[tokio::test(start_paused = true)]
async fn writes_retry_rate_limits() {
let (result, calls) = run_counting(&RetryPolicy::standard(), OpKind::Write, 1, || {
rate_limited(7)
})
.await;
assert_eq!(result.unwrap(), 2);
assert_eq!(calls, 2);
}
// A 7 s hint replaces the computed backoff: the wait must be at least 7 s and under 8 s.
#[tokio::test(start_paused = true)]
async fn server_hint_overrides_backoff() {
let start = tokio::time::Instant::now();
let (result, _) = run_counting(&RetryPolicy::standard(), OpKind::Read, 1, || {
rate_limited(7)
})
.await;
assert!(result.is_ok());
let waited = start.elapsed();
assert!(
waited >= Duration::from_secs(7) && waited < Duration::from_secs(8),
"should sleep the hinted 7s, slept {waited:?}"
);
}
// A one-day hint must not be slept in full; the wait is bounded by `MAX_HINT_SECS` (+1 s slack).
#[tokio::test(start_paused = true)]
async fn absurd_hints_are_clamped() {
let start = tokio::time::Instant::now();
let (result, _) = run_counting(&RetryPolicy::standard(), OpKind::Read, 1, || {
rate_limited(86_400)
})
.await;
assert!(result.is_ok());
assert!(
start.elapsed() <= Duration::from_secs(MAX_HINT_SECS + 1),
"hints are clamped to {MAX_HINT_SECS}s"
);
}
}