use std::time::Duration;
use crate::retry::AttemptOutcome;
/// Maximum number of retries (4).
///
/// NOTE(review): this constant is consumed outside this view — confirm whether it
/// bounds retries after the first attempt or the total number of attempts.
pub(crate) const MAX_RETRIES: u32 = 4;
/// Error body reported by the provider, deserialised with serde.
///
/// Both fields are `#[serde(default)]`, so a body missing either key still
/// deserialises: `code` becomes `None` and `message` becomes an empty string.
#[derive(serde::Deserialize)]
pub(crate) struct ApiError {
/// Provider error code, kept as a raw JSON value because it may be either a
/// string (e.g. `"rate_limited"`) or a number (e.g. `429`).
#[serde(default)]
pub(crate) code: Option<serde_json::Value>,
/// Error message text; empty when the key is absent.
#[serde(default)]
pub(crate) message: String,
}
impl ApiError {
pub(crate) fn code_string(&self) -> String {
match &self.code {
Some(serde_json::Value::String(s)) => s.clone(),
Some(other) => other.to_string(),
None => "unknown".to_string(),
}
}
}
/// Sleeps for an exponentially growing delay before retry number `attempt`.
///
/// The delay is `1000 ms * 2^attempt` plus a random jitter in `0..500` ms. For
/// example, attempt 4 waits 16 s plus jitter.
///
/// The arithmetic saturates instead of overflowing:
/// - `2u64.pow(attempt)` overflows once `attempt >= 64`, and the multiplication by
///   1000 overflows once `attempt >= 54`.
/// - With plain arithmetic, debug builds panic and release builds silently wrap.
///   Wrapping can produce a near-zero delay, which defeats the backoff.
/// - Saturating keeps the delay non-decreasing in `attempt`.
pub(crate) async fn backoff(attempt: u32) {
    let base_ms = 1000u64.saturating_mul(2u64.saturating_pow(attempt));
    // Random jitter so concurrent callers that failed together don't retry in lockstep.
    let jitter = fastrand::u64(0..500);
    let sleep_ms = base_ms.saturating_add(jitter);
    tracing::debug!(attempt, sleep_ms, "exponential backoff");
    tokio::time::sleep(Duration::from_millis(sleep_ms)).await;
}
/// Classifies an HTTP response status for retry purposes.
///
/// These are [`AttemptOutcome::Transient`]:
/// - 408 Request Timeout
/// - 425 Too Early
/// - 429 Too Many Requests
/// - every 5xx server error
///
/// Every other status is [`AttemptOutcome::HardFailure`]. That includes the
/// explicit client errors 400/401/403/404 and any non-error status.
pub(crate) fn status_retry_class(status: reqwest::StatusCode) -> AttemptOutcome {
    let retryable_client_status = matches!(status.as_u16(), 408 | 425 | 429);
    if retryable_client_status || status.is_server_error() {
        AttemptOutcome::Transient
    } else {
        AttemptOutcome::HardFailure
    }
}
pub(crate) fn provider_error_retry_class(api_err: &ApiError) -> AttemptOutcome {
let code = api_err.code_string();
if let Ok(numeric) = code.parse::<u16>() {
return if numeric == 429 || (500..=599).contains(&numeric) {
AttemptOutcome::Transient
} else {
AttemptOutcome::HardFailure
};
}
match code.as_str() {
"rate_limit_exceeded" | "rate_limited" | "server_error" | "service_unavailable" => {
AttemptOutcome::Transient
}
_ => AttemptOutcome::HardFailure,
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that every status in `statuses` is classified as `expected`.
    fn expect_status_class(statuses: &[reqwest::StatusCode], expected: AttemptOutcome) {
        for &status in statuses {
            assert_eq!(status_retry_class(status), expected, "status {}", status);
        }
    }

    /// Builds an `ApiError` by deserialising `{"code": code, "message": "x"}`.
    fn api_error(code: serde_json::Value) -> ApiError {
        let body = serde_json::json!({ "code": code, "message": "x" });
        serde_json::from_value(body).expect("valid ApiError fixture")
    }

    /// Asserts that every provider `code` in `codes` is classified as `expected`.
    fn expect_provider_class(codes: &[serde_json::Value], expected: AttemptOutcome) {
        for code in codes {
            let outcome = provider_error_retry_class(&api_error(code.clone()));
            assert_eq!(outcome, expected, "code {}", code);
        }
    }

    #[test]
    fn status_retry_class_maps_client_errors_to_hard_failure() {
        expect_status_class(
            &[
                reqwest::StatusCode::UNAUTHORIZED,
                reqwest::StatusCode::BAD_REQUEST,
                reqwest::StatusCode::NOT_FOUND,
            ],
            AttemptOutcome::HardFailure,
        );
    }

    #[test]
    fn status_retry_class_maps_rate_limit_and_server_errors_to_transient() {
        expect_status_class(
            &[
                reqwest::StatusCode::TOO_MANY_REQUESTS,
                reqwest::StatusCode::SERVICE_UNAVAILABLE,
                reqwest::StatusCode::BAD_GATEWAY,
            ],
            AttemptOutcome::Transient,
        );
    }

    #[test]
    fn status_retry_class_treats_403_as_hard_failure() {
        expect_status_class(&[reqwest::StatusCode::FORBIDDEN], AttemptOutcome::HardFailure);
    }

    #[test]
    fn status_retry_class_treats_408_and_425_as_transient() {
        let too_early = reqwest::StatusCode::from_u16(425).expect("425 is a valid status");
        expect_status_class(
            &[reqwest::StatusCode::REQUEST_TIMEOUT, too_early],
            AttemptOutcome::Transient,
        );
    }

    #[test]
    fn status_retry_class_defaults_unrecognised_status_to_hard_failure() {
        expect_status_class(&[reqwest::StatusCode::IM_A_TEAPOT], AttemptOutcome::HardFailure);
    }

    #[test]
    fn provider_error_retry_class_treats_numeric_429_and_5xx_as_transient() {
        expect_provider_class(
            &[serde_json::json!(429), serde_json::json!(503)],
            AttemptOutcome::Transient,
        );
    }

    #[test]
    fn provider_error_retry_class_treats_numeric_400_as_hard_failure() {
        expect_provider_class(&[serde_json::json!(400)], AttemptOutcome::HardFailure);
    }

    #[test]
    fn provider_error_retry_class_treats_known_transient_codes_as_transient() {
        expect_provider_class(
            &[
                serde_json::json!("rate_limited"),
                serde_json::json!("server_error"),
            ],
            AttemptOutcome::Transient,
        );
    }

    #[test]
    fn provider_error_retry_class_treats_context_length_exceeded_as_hard_failure() {
        expect_provider_class(
            &[serde_json::json!("context_length_exceeded")],
            AttemptOutcome::HardFailure,
        );
    }
}