use std::time::Duration;
use aonyx_core::{AonyxError, Result};
/// Retry behaviour for outbound HTTP requests sent via [`send_with_retry`].
///
/// The [`Default`] policy allows 3 retries starting from a 500 ms backoff.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RetryPolicy {
    /// Maximum number of additional attempts after the first one; `0` disables retrying.
    pub max_retries: u32,
    /// Delay before the first retry, in milliseconds. Later retries double it,
    /// up to a cap (see [`backoff_ms`]).
    pub base_backoff_ms: u64,
}

impl Default for RetryPolicy {
    /// 3 retries with a 500 ms base backoff.
    fn default() -> Self {
        Self {
            max_retries: 3,
            base_backoff_ms: 500,
        }
    }
}
/// Reports whether an HTTP status code is worth retrying.
///
/// Only `429 Too Many Requests` and the whole 5xx range (500..=599) qualify;
/// every other code is treated as a final answer.
pub fn is_retriable_status(code: u16) -> bool {
    matches!(code, 429 | 500..=599)
}
/// Milliseconds to wait before retry number `attempt` (1-based).
///
/// The delay is `policy.base_backoff_ms * 2^(attempt - 1)`, capped at 30 000 ms.
/// Attempt `0` yields the same delay as attempt `1`. The multiplication
/// saturates, so a very large base cannot overflow.
pub fn backoff_ms(policy: RetryPolicy, attempt: u32) -> u64 {
    const CAP_MS: u64 = 30_000;
    // Clamping the exponent keeps the shift in range. It never changes the result:
    // any non-zero base times 2^16 already exceeds CAP_MS.
    let exponent = attempt.saturating_sub(1).min(16);
    let multiplier: u64 = 1 << exponent;
    let uncapped = policy.base_backoff_ms.saturating_mul(multiplier);
    if uncapped > CAP_MS {
        CAP_MS
    } else {
        uncapped
    }
}
/// Sends the request built by `builder`, retrying transient failures per `policy`.
///
/// Each attempt sends a fresh clone of `builder`. Another attempt is made after
/// [`backoff_ms`] milliseconds when either:
/// - the response status satisfies [`is_retriable_status`] (429 or 5xx), or
/// - sending fails with a non-builder transport error.
///
/// At most `policy.max_retries + 1` attempts are made in total. When retries are
/// exhausted, the last response is returned as `Ok` even if its status is
/// retriable, so the caller can inspect it.
///
/// If the request cannot be cloned, it is sent exactly once without retries.
/// reqwest reports a request as uncloneable when it has a streaming body or when
/// the builder already holds a construction error. In the second case, this
/// surfaces reqwest's own error instead of masking it.
///
/// `label` prefixes every error message to identify the caller.
///
/// # Errors
///
/// Returns [`AonyxError::Provider`] when the final attempt fails to send. It also
/// returns one immediately on a reqwest builder error, which retrying cannot fix.
pub async fn send_with_retry(
    builder: reqwest::RequestBuilder,
    policy: RetryPolicy,
    label: &str,
) -> Result<reqwest::Response> {
    let mut attempt = 0u32;
    loop {
        let request = match builder.try_clone() {
            Some(request) => request,
            None => {
                // Retrying is impossible, but a single attempt is not. Moving `builder`
                // here is fine because this path always returns.
                return builder
                    .send()
                    .await
                    .map_err(|e| AonyxError::Provider(format!("{label} send: {e}")));
            }
        };
        let retries_left = attempt < policy.max_retries;
        match request.send().await {
            // A retriable response is dropped here, before the backoff sleep, so its
            // connection is not held while waiting.
            Ok(resp) if retries_left && is_retriable_status(resp.status().as_u16()) => {}
            Ok(resp) => return Ok(resp),
            // Builder errors (e.g. unsupported URL scheme) are deterministic; don't retry them.
            Err(e) if retries_left && !e.is_builder() => {}
            Err(e) => return Err(AonyxError::Provider(format!("{label} send: {e}"))),
        }
        attempt += 1;
        tokio::time::sleep(Duration::from_millis(backoff_ms(policy, attempt))).await;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn retriable_status_covers_429_and_5xx_only() {
        for code in [429, 500, 503, 599] {
            assert!(is_retriable_status(code), "{code} should be retriable");
        }
        // Include the codes on either side of each retriable range.
        for code in [0, 200, 301, 400, 401, 404, 428, 430, 499, 600] {
            assert!(!is_retriable_status(code), "{code} should not be retriable");
        }
    }

    #[test]
    fn backoff_is_exponential_and_capped() {
        let p = RetryPolicy {
            max_retries: 10,
            base_backoff_ms: 500,
        };
        assert_eq!(backoff_ms(p, 1), 500);
        assert_eq!(backoff_ms(p, 2), 1000);
        assert_eq!(backoff_ms(p, 3), 2000);
        assert_eq!(backoff_ms(p, 4), 4000);
        assert_eq!(backoff_ms(p, 20), 30_000);
    }

    #[test]
    fn backoff_edge_cases() {
        let p = RetryPolicy {
            max_retries: 3,
            base_backoff_ms: 500,
        };
        // Attempt 0 behaves like the first retry.
        assert_eq!(backoff_ms(p, 0), 500);
        // The exponent clamp keeps extreme attempt numbers well-defined.
        assert_eq!(backoff_ms(p, u32::MAX), 30_000);

        let zero = RetryPolicy {
            max_retries: 3,
            base_backoff_ms: 0,
        };
        assert_eq!(backoff_ms(zero, 5), 0);

        // The multiplication must saturate rather than overflow.
        let huge = RetryPolicy {
            max_retries: 3,
            base_backoff_ms: u64::MAX,
        };
        assert_eq!(backoff_ms(huge, 1), 30_000);
        assert_eq!(backoff_ms(huge, 17), 30_000);
    }

    #[test]
    fn default_policy_is_three_retries() {
        let p = RetryPolicy::default();
        assert_eq!(p.max_retries, 3);
        assert_eq!(p.base_backoff_ms, 500);
    }
}