use serde::{Deserialize, Serialize};
/// Retry policy settings, (de)serializable through serde.
///
/// `delay_for_attempt` uses exponential backoff (`initial_delay_ms`,
/// `backoff_factor`, `max_delay_ms`) whenever `initial_delay_ms` is non-zero,
/// and falls back to the fixed `retry_delays` table otherwise.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetryConfig {
/// Not read anywhere in this file; presumably the maximum number of retry
/// attempts — confirm against callers. Has no serde default.
pub max_retries: u32,
/// Fixed delay table in seconds, indexed by attempt; only consulted by
/// `delay_for_attempt` when `initial_delay_ms` is 0. Has no serde default.
pub retry_delays: Vec<f64>,
/// HTTP status codes that `is_retryable_status` reports as retryable.
/// Has no serde default.
pub retryable_status_codes: Vec<u16>,
/// Base delay in milliseconds for exponential backoff; 0 switches
/// `delay_for_attempt` to the `retry_delays` table. Defaults to
/// `default_initial_delay_ms()` when absent from the input.
#[serde(default = "default_initial_delay_ms")]
pub initial_delay_ms: u64,
/// Multiplier raised to the attempt number to scale `initial_delay_ms`.
/// Defaults to `default_backoff_factor()` when absent from the input.
#[serde(default = "default_backoff_factor")]
pub backoff_factor: f64,
/// Cap in milliseconds applied to the backoff delay before jitter.
/// Defaults to `default_max_delay_ms()` when absent from the input.
#[serde(default = "default_max_delay_ms")]
pub max_delay_ms: u64,
}
/// Serde fallback for `RetryConfig::initial_delay_ms`, in milliseconds.
fn default_initial_delay_ms() -> u64 {
    2_000
}

/// Serde fallback for `RetryConfig::backoff_factor`.
fn default_backoff_factor() -> f64 {
    2.0
}

/// Serde fallback for `RetryConfig::max_delay_ms`, in milliseconds.
fn default_max_delay_ms() -> u64 {
    30_000
}

impl Default for RetryConfig {
    /// Builds the stock retry policy: 3 retries, a `[1.0, 2.0, 4.0]` second
    /// delay table, the status codes 429/500/502/503/504, and the same backoff
    /// parameters serde substitutes for fields missing from deserialized input.
    fn default() -> Self {
        Self {
            max_retries: 3,
            retry_delays: vec![1.0, 2.0, 4.0],
            retryable_status_codes: vec![429, 500, 502, 503, 504],
            // Reuse the serde fallbacks so `Default` and deserialization with
            // missing fields can never disagree on the backoff parameters.
            initial_delay_ms: default_initial_delay_ms(),
            backoff_factor: default_backoff_factor(),
            max_delay_ms: default_max_delay_ms(),
        }
    }
}
impl RetryConfig {
    /// Returns how long to wait before the retry numbered `attempt`.
    ///
    /// When `initial_delay_ms` is non-zero the delay is
    /// `initial_delay_ms * backoff_factor^attempt`, capped at `max_delay_ms`
    /// and then scaled by a random jitter factor in `[0.75, 1.25)`. Because the
    /// jitter is applied after the cap, the result can exceed `max_delay_ms`
    /// by up to 25%.
    ///
    /// When `initial_delay_ms` is 0 the delay is read from `retry_delays`
    /// (seconds), reusing the last entry for attempts past the end of the
    /// table. An empty table, or an entry that cannot be represented as a
    /// `Duration` (negative, NaN, infinite, too large), yields 4 seconds
    /// instead of panicking.
    pub fn delay_for_attempt(&self, attempt: u32) -> std::time::Duration {
        use std::time::Duration;

        const FALLBACK_SECS: f64 = 4.0;

        if self.initial_delay_ms > 0 {
            // Saturate instead of `as`-casting: a wrapped, negative exponent
            // would shrink the delay for very large attempt numbers.
            let exponent = i32::try_from(attempt).unwrap_or(i32::MAX);
            let delay_ms = self.initial_delay_ms as f64 * self.backoff_factor.powi(exponent);
            let capped_ms = delay_ms.min(self.max_delay_ms as f64);
            let jitter_factor = 0.75 + fastrand::f64() * 0.5;
            // Float-to-integer `as` saturates, so negative or NaN products become 0.
            let final_ms = (capped_ms * jitter_factor) as u64;
            return Duration::from_millis(final_ms);
        }

        let idx = (attempt as usize).min(self.retry_delays.len().saturating_sub(1));
        let secs = self.retry_delays.get(idx).copied().unwrap_or(FALLBACK_SECS);
        // `from_secs_f64` panics on negative / non-finite input; the table comes
        // from deserialized configuration, so fall back instead.
        Duration::try_from_secs_f64(secs)
            .unwrap_or_else(|_| Duration::from_secs_f64(FALLBACK_SECS))
    }

    /// Returns `true` when `status` is listed in `retryable_status_codes`.
    pub fn is_retryable_status(&self, status: u16) -> bool {
        self.retryable_status_codes.contains(&status)
    }
}
pub fn classify_retryable_error(
status: Option<u16>,
body: Option<&serde_json::Value>,
) -> Option<String> {
match status {
Some(429) => {
if let Some(body) = body
&& let Some(msg) = extract_error_message(body)
{
if msg.contains("rate_limit") || msg.contains("Rate") {
return Some("Rate Limited".to_string());
}
if msg.contains("too_many_requests") || msg.contains("Too Many") {
return Some("Too Many Requests".to_string());
}
}
return Some("Rate Limited".to_string());
}
Some(529) => return Some("Provider is overloaded".to_string()),
Some(503) => {
if let Some(body) = body
&& let Some(msg) = extract_error_message(body)
{
if msg.contains("overloaded") || msg.contains("Overloaded") {
return Some("Provider is overloaded".to_string());
}
if msg.contains("unavailable") || msg.contains("exhausted") {
return Some("Provider is overloaded".to_string());
}
}
return Some("Service Unavailable".to_string());
}
Some(500) => return Some("Internal Server Error".to_string()),
Some(502) => return Some("Bad Gateway".to_string()),
Some(504) => return Some("Gateway Timeout".to_string()),
_ => {}
}
if let Some(body) = body
&& let Some(msg) = extract_error_message(body)
&& (msg.contains("overloaded") || msg.contains("Overloaded"))
{
return Some("Provider is overloaded".to_string());
}
None
}
pub fn parse_retry_after(
retry_after: Option<&str>,
retry_after_ms: Option<&str>,
) -> Option<std::time::Duration> {
if let Some(ms_str) = retry_after_ms
&& let Ok(ms) = ms_str.parse::<u64>()
{
return Some(std::time::Duration::from_millis(ms));
}
let val = retry_after?;
if let Ok(secs) = val.parse::<f64>()
&& secs > 0.0
{
return Some(std::time::Duration::from_secs_f64(secs));
}
if val.contains(',')
&& val.contains("GMT")
&& let Ok(date) = httpdate::parse_http_date(val)
&& let Ok(duration) = date.duration_since(std::time::SystemTime::now())
{
return Some(duration);
}
None
}
pub(super) fn extract_error_message(body: &serde_json::Value) -> Option<String> {
if let Some(err) = body.get("error") {
if let Some(msg) = err.get("message").and_then(|v| v.as_str()) {
return Some(msg.to_string());
}
if let Some(code) = err.get("code").and_then(|v| v.as_str()) {
return Some(code.to_string());
}
if let Some(err_type) = err.get("type").and_then(|v| v.as_str()) {
return Some(err_type.to_string());
}
if let Some(msg) = err.as_str() {
return Some(msg.to_string());
}
}
if body.get("type").and_then(|v| v.as_str()) == Some("error")
&& let Some(err) = body.get("error")
&& let Some(msg) = err.get("message").and_then(|v| v.as_str())
{
return Some(msg.to_string());
}
body.get("message")
.and_then(|v| v.as_str())
.map(|s| s.to_string())
}
#[cfg(test)]
#[path = "retry_tests.rs"]
mod tests;