use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use reqwest::header::HeaderMap;
use reqwest::{Method, StatusCode};
/// Default base delay in milliseconds, presumably intended as `base_ms` for
/// [`compute_delay`] (callers are not visible here — confirm).
pub const DEFAULT_BASE_MS: u64 = 500;
/// Default ceiling on a single computed delay, in milliseconds (30 s); presumably
/// intended as `cap_ms` for [`compute_delay`].
pub const DEFAULT_CAP_MS: u64 = 30 * 1_000;
/// Default overall retry budget in milliseconds (60 s).
/// NOTE(review): not referenced in this file; assumed to be enforced by a caller's retry loop.
pub const DEFAULT_MAX_ELAPSED_MS: u64 = 60 * 1_000;
/// Computes the backoff delay for attempt number `attempt`.
///
/// The delay is `base_ms * 2^min(attempt, 20)` plus a jitter term drawn from
/// `[0, base_ms)`, clamped to at most `cap_ms`. The multiply and add saturate, so huge
/// attempts or bases cannot overflow; the jitter is lost once the cap is reached.
pub fn compute_delay(attempt: u32, base_ms: u64, cap_ms: u64) -> Duration {
    // Clamp the exponent before shifting: an unclamped `1u64 << 64` would overflow.
    let growth = 1u64 << attempt.min(20);
    let exponential_ms = base_ms.saturating_mul(growth);
    let jittered_ms = exponential_ms.saturating_add(jitter_ms(base_ms));
    Duration::from_millis(jittered_ms.min(cap_ms))
}
/// Process-wide xorshift64 state backing [`jitter_ms`]. Zero means "not yet seeded";
/// once seeded it never returns to zero, because xorshift64 maps non-zero states to
/// non-zero states.
static JITTER_STATE: AtomicU64 = AtomicU64::new(0);

/// Advances a xorshift64 generator (shifts 13, 7, 17) by one step and returns the new
/// state, which is also written back through `state`.
///
/// Zero is a fixed point of this map, so the state must be seeded with a non-zero value.
#[inline]
fn xorshift64(state: &mut u64) -> u64 {
    let mut x = *state;
    x ^= x << 13;
    x ^= x >> 7;
    x ^= x << 17;
    *state = x;
    x
}

/// Returns the current jitter state, seeding it from the system clock on first use.
///
/// If several threads race to seed, the first `compare_exchange` wins and the others
/// adopt its value, so the returned state is always non-zero.
fn seed_if_needed() -> u64 {
    let current = JITTER_STATE.load(Ordering::Relaxed);
    if current != 0 {
        return current;
    }
    // Mix sub-second nanos with the 64-bit golden-ratio constant, fold in the seconds,
    // and force a non-zero result (zero would lock xorshift64 at zero forever).
    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|d| (d.subsec_nanos() as u64).wrapping_mul(0x9E37_79B9_7F4A_7C15) ^ d.as_secs())
        .unwrap_or(0x9E37_79B9_7F4A_7C15);
    let seed = if nanos == 0 { 1 } else { nanos };
    match JITTER_STATE.compare_exchange(0, seed, Ordering::Relaxed, Ordering::Relaxed) {
        Ok(_) => seed,
        Err(existing) => existing,
    }
}

/// Atomically advances [`JITTER_STATE`] by one xorshift64 step and returns the new value.
///
/// The read-modify-write goes through a CAS loop. With a separate load and store, two
/// threads could read the same state and hand out identical jitter. That would
/// re-synchronise exactly the retries the jitter is meant to spread apart.
fn next_jitter_word() -> u64 {
    let mut current = seed_if_needed();
    loop {
        let mut next = current;
        let value = xorshift64(&mut next);
        match JITTER_STATE.compare_exchange_weak(
            current,
            next,
            Ordering::Relaxed,
            Ordering::Relaxed,
        ) {
            Ok(_) => return value,
            // Another thread advanced the state first; retry from what it stored.
            Err(observed) => current = observed,
        }
    }
}

/// Returns a pseudo-random value in `[0, bound)`, or 0 when `bound` is 0.
///
/// Rejection sampling removes the modulo bias of a plain `x % bound`, making the result
/// (approximately) uniform. Not cryptographically secure.
fn jitter_ms(bound: u64) -> u64 {
    if bound == 0 {
        return 0;
    }
    // Largest multiple of `bound` that fits in u64; draws at or above it are rejected.
    let threshold = u64::MAX - (u64::MAX % bound);
    loop {
        let x = next_jitter_word();
        if x < threshold {
            return x % bound;
        }
    }
}
pub fn is_retriable(status: StatusCode) -> bool {
matches!(status.as_u16(), 429 | 500 | 502 | 503 | 504)
}
pub fn is_post_retriable(status: StatusCode, idempotent: bool) -> bool {
if idempotent {
return is_retriable(status);
}
matches!(status.as_u16(), 429 | 503)
}
/// Top-level retry decision for a request made with `method` that received `status`.
///
/// The status must be in the [`is_retriable`] set. POST requests additionally have to
/// pass [`is_post_retriable`], which is stricter unless `idempotent` is true. Every other
/// method is retried on any retriable status.
pub fn should_retry(method: &Method, status: StatusCode, idempotent: bool) -> bool {
    is_retriable(status) && (*method != Method::POST || is_post_retriable(status, idempotent))
}
pub fn parse_retry_after(headers: &HeaderMap, body: &str) -> Option<Duration> {
if let Some(val) = headers.get(reqwest::header::RETRY_AFTER) {
if let Ok(s) = val.to_str() {
let trimmed = s.trim();
if let Ok(secs) = trimmed.parse::<u64>() {
return Some(Duration::from_secs(secs));
}
if let Some(delta) = parse_http_date_delta(trimmed) {
return Some(delta);
}
}
}
if !body.is_empty() {
#[derive(serde::Deserialize)]
struct Body {
#[serde(default)]
retry_after_us: Option<u64>,
#[serde(default)]
error: Option<String>,
}
if let Ok(b) = serde_json::from_str::<Body>(body) {
if let Some(us) = b.retry_after_us {
return Some(Duration::from_micros(us));
}
if let Some(msg) = b.error {
if let Some(secs) = parse_retry_after_english(&msg) {
return Some(Duration::from_secs(secs));
}
}
}
}
None
}
/// Extracts a delay in whole seconds from free-form text such as
/// `"rate limited — retry after 3s"`.
///
/// Matching is ASCII case-insensitive and uses the first occurrence of `"retry after"`,
/// optionally followed by a colon. The number may be fractional (`1.5`). It may carry a
/// unit: `ms`, `m`/`min`/`minute(s)`, `h`/`hr`/`hour(s)` or their common spellings.
/// Anything else, including no unit, is taken as seconds. The result is rounded **up**,
/// so a sub-second hint such as `500ms` becomes 1 second rather than 500 seconds or zero.
///
/// Returns `None` when no number follows the phrase or the value does not fit in `u64`.
fn parse_retry_after_english(msg: &str) -> Option<u64> {
    let lower = msg.to_ascii_lowercase();
    let rest = lower.split("retry after").nth(1)?;
    // Accept "retry after 5" as well as "retry after: 5".
    let rest = rest.trim_start_matches(|c: char| c.is_whitespace() || c == ':');

    // The number is a run of digits with at most one decimal point.
    let mut seen_dot = false;
    let number_len = rest
        .find(|c: char| {
            if c == '.' && !seen_dot {
                seen_dot = true;
                false
            } else {
                !c.is_ascii_digit()
            }
        })
        .unwrap_or(rest.len());
    let (number, tail) = rest.split_at(number_len);
    let value: f64 = number.parse().ok()?;

    // The unit is the alphabetic word (if any) right after the number.
    let unit = tail
        .trim_start()
        .split(|c: char| !c.is_ascii_alphabetic())
        .next()
        .unwrap_or("");
    // Divide for milliseconds (exact for whole seconds) instead of multiplying by 0.001.
    let seconds = match unit {
        "ms" | "msec" | "msecs" | "millisecond" | "milliseconds" => value / 1_000.0,
        "m" | "min" | "mins" | "minute" | "minutes" => value * 60.0,
        "h" | "hr" | "hrs" | "hour" | "hours" => value * 3_600.0,
        _ => value,
    };
    let seconds = seconds.ceil();
    if !seconds.is_finite() || seconds >= u64::MAX as f64 {
        return None;
    }
    Some(seconds as u64)
}
/// Interprets `s` as an HTTP-date `Retry-After` value and returns the delay from now
/// until that instant.
///
/// Accepts the three formats RFC 9110 requires recipients to handle: IMF-fixdate
/// (`Sun, 06 Nov 1994 08:49:37 GMT`), the obsolete RFC 850 form
/// (`Sunday, 06-Nov-94 08:49:37 GMT`), and ANSI C `asctime()` (`Sun Nov  6 08:49:37 1994`).
/// A date already in the past yields `Some(Duration::ZERO)`. An unparseable value, or a
/// system clock set before the Unix epoch, yields `None`.
fn parse_http_date_delta(s: &str) -> Option<Duration> {
    let target = http_date_to_unix(s)?;
    let now = SystemTime::now().duration_since(UNIX_EPOCH).ok()?;
    let now_secs = i64::try_from(now.as_secs()).ok()?;
    // A date in the past means "retry now": negative deltas clamp to zero.
    let delta_secs = u64::try_from(target - now_secs).unwrap_or(0);
    Some(Duration::from_secs(delta_secs))
}

/// Converts an HTTP-date to whole seconds since the Unix epoch (UTC), or `None` if it is
/// malformed. The weekday name is not cross-checked against the date.
fn http_date_to_unix(s: &str) -> Option<i64> {
    const MONTHS: [&str; 12] = [
        "Jan", "Feb", "Mar", "Apr", "May", "Jun", "Jul", "Aug", "Sep", "Oct", "Nov", "Dec",
    ];
    let parts: Vec<&str> = s.split_whitespace().collect();
    let (day, month, year, time) = match parts.as_slice() {
        // IMF-fixdate: "Sun, 06 Nov 1994 08:49:37 GMT"
        [weekday, day, month, year, time, "GMT"] if weekday.ends_with(',') => {
            (*day, *month, parse_date_field(year)?, *time)
        }
        // Obsolete RFC 850 form: "Sunday, 06-Nov-94 08:49:37 GMT"
        [weekday, date, time, "GMT"] if weekday.ends_with(',') => {
            let mut fields = date.split('-');
            let (day, month, yy) = (fields.next()?, fields.next()?, fields.next()?);
            if fields.next().is_some() || yy.len() != 2 {
                return None;
            }
            let yy = parse_date_field(yy)?;
            // Fixed two-digit-year pivot: 70-99 -> 19xx, 00-69 -> 20xx.
            let year = if yy >= 70 { 1900 + yy } else { 2000 + yy };
            (day, month, year, *time)
        }
        // ANSI C asctime() form: "Sun Nov  6 08:49:37 1994"
        [weekday, month, day, time, year] if !weekday.ends_with(',') => {
            (*day, *month, parse_date_field(year)?, *time)
        }
        _ => return None,
    };

    let month = MONTHS.iter().position(|&name| name == month)? as i64 + 1;
    let day = parse_date_field(day)?;
    if !(1..=days_in_month(year, month)).contains(&day) {
        return None;
    }

    let mut clock = time.split(':');
    let hour = parse_date_field(clock.next()?)?;
    let minute = parse_date_field(clock.next()?)?;
    let second = parse_date_field(clock.next()?)?;
    // Second 60 is permitted for a leap second.
    if clock.next().is_some() || hour > 23 || minute > 59 || second > 60 {
        return None;
    }
    Some(days_from_civil(year, month, day) * 86_400 + hour * 3_600 + minute * 60 + second)
}

/// Parses 1 to 4 ASCII digits, the shape of every numeric HTTP-date field. Bounding the
/// length also keeps the date arithmetic far from overflow.
fn parse_date_field(s: &str) -> Option<i64> {
    if s.is_empty() || s.len() > 4 || !s.bytes().all(|b| b.is_ascii_digit()) {
        return None;
    }
    s.parse().ok()
}

/// Number of days in `month` (1-12) of the proleptic Gregorian `year`.
fn days_in_month(year: i64, month: i64) -> i64 {
    let leap = (year % 4 == 0 && year % 100 != 0) || year % 400 == 0;
    match month {
        2 if leap => 29,
        2 => 28,
        4 | 6 | 9 | 11 => 30,
        _ => 31,
    }
}

/// Days since 1970-01-01 for a proleptic Gregorian date (Howard Hinnant's
/// `days_from_civil` algorithm).
fn days_from_civil(year: i64, month: i64, day: i64) -> i64 {
    // Count years from March so the leap day falls at the end of the cycle.
    let year = if month <= 2 { year - 1 } else { year };
    let era = year.div_euclid(400);
    let year_of_era = year.rem_euclid(400);
    let month_from_march = (month + 9) % 12;
    let day_of_year = (153 * month_from_march + 2) / 5 + day - 1;
    let day_of_era = year_of_era * 365 + year_of_era / 4 - year_of_era / 100 + day_of_year;
    era * 146_097 + day_of_era - 719_468
}
// Unit tests for the retry helpers: backoff computation, status/method classification,
// Retry-After hint parsing, and the jitter generator.
#[cfg(test)]
mod tests {
use super::*;
use reqwest::header::HeaderValue;
// Attempt 0 with base 500: 500 ms of backoff plus jitter from [0, 500), so the
// result must land in [500, 1000).
#[test]
fn compute_delay_is_bounded() {
let d = compute_delay(0, 500, 30_000);
assert!(d >= Duration::from_millis(500));
assert!(d < Duration::from_millis(1_000));
}
// A large attempt grows far past the cap; the result must be clamped to exactly cap_ms.
#[test]
fn compute_delay_caps_at_cap_ms() {
let d = compute_delay(30, 500, 30_000);
assert_eq!(d, Duration::from_millis(30_000));
}
// A non-idempotent POST must not be retried on 500.
#[test]
fn post_without_idempotency_not_retried_on_500() {
assert!(!should_retry(
&Method::POST,
StatusCode::INTERNAL_SERVER_ERROR,
false
));
}
// Marking the POST idempotent re-enables the full retriable set, including 500.
#[test]
fn post_with_idempotency_retried_on_500() {
assert!(should_retry(
&Method::POST,
StatusCode::INTERNAL_SERVER_ERROR,
true
));
}
// 429 is retried for POST even without an idempotency guarantee.
#[test]
fn post_always_retried_on_429() {
assert!(should_retry(
&Method::POST,
StatusCode::TOO_MANY_REQUESTS,
false
));
}
// Non-POST methods are retried on any retriable status, e.g. 502 and 504.
#[test]
fn get_retried_on_502_504() {
assert!(should_retry(&Method::GET, StatusCode::BAD_GATEWAY, false));
assert!(should_retry(
&Method::GET,
StatusCode::GATEWAY_TIMEOUT,
false
));
}
// 401 and 403 are not in the retriable set.
#[test]
fn auth_errors_not_retried() {
assert!(!should_retry(&Method::GET, StatusCode::UNAUTHORIZED, false));
assert!(!should_retry(&Method::GET, StatusCode::FORBIDDEN, false));
}
// A numeric Retry-After header is read as whole seconds.
#[test]
fn parse_retry_after_from_header_seconds() {
let mut h = HeaderMap::new();
h.insert("retry-after", HeaderValue::from_static("7"));
assert_eq!(parse_retry_after(&h, ""), Some(Duration::from_secs(7)));
}
// With no header, a JSON `retry_after_us` field is read as microseconds.
#[test]
fn parse_retry_after_from_body_us() {
let h = HeaderMap::new();
let body = r#"{"error":"rate limited","retry_after_us":1500000}"#;
assert_eq!(
parse_retry_after(&h, body),
Some(Duration::from_micros(1_500_000))
);
}
// Falls back to an English "retry after Ns" phrase inside the JSON `error` string.
#[test]
fn parse_retry_after_from_english_body() {
let h = HeaderMap::new();
let body = r#"{"error":"rate limited — retry after 3s"}"#;
assert_eq!(parse_retry_after(&h, body), Some(Duration::from_secs(3)));
}
// No header, no `retry_after_us`, and no phrase in `error` gives no hint.
#[test]
fn parse_retry_after_returns_none_without_hint() {
let h = HeaderMap::new();
assert_eq!(parse_retry_after(&h, r#"{"error":"other"}"#), None);
}
// Statistical smoke test: 1000 draws from [0, 100) must stay in range. Their mean must
// land in [45, 55], a wide window around the expected 49.5. They must also cover at
// least 20 distinct values.
#[test]
fn jitter_is_roughly_uniform_over_base() {
let base: u64 = 100;
let iters: usize = 1_000;
let mut samples: Vec<u64> = Vec::with_capacity(iters);
for _ in 0..iters {
samples.push(jitter_ms(base));
}
assert!(samples.iter().all(|&s| s < base), "sample out of range");
let sum: u64 = samples.iter().sum();
let mean = sum as f64 / iters as f64;
assert!(
(45.0..=55.0).contains(&mean),
"mean {mean} outside [45, 55]"
);
let distinct: std::collections::BTreeSet<_> = samples.iter().copied().collect();
assert!(distinct.len() >= 20, "only {} distinct", distinct.len());
}
// Consecutive draws from [0, 1000) should rarely repeat; at most 20 of the 256
// consecutive pairs may be identical.
#[test]
fn jitter_decorrelates_consecutive_calls() {
let base: u64 = 1_000;
let mut prev = jitter_ms(base);
let mut identical_pairs = 0usize;
for _ in 0..256 {
let next = jitter_ms(base);
if next == prev {
identical_pairs += 1;
}
prev = next;
}
assert!(
identical_pairs <= 20,
"too many identical consecutive samples: {identical_pairs}"
);
}
// A zero bound short-circuits to zero instead of dividing by zero.
#[test]
fn jitter_zero_bound_returns_zero() {
assert_eq!(jitter_ms(0), 0);
}
}