use std::time::Duration;
/// Error produced by the network helpers in this module, tagged with whether the
/// failed operation is considered worth retrying.
#[derive(Debug, Clone)]
pub enum NetworkError {
    /// A failure the classifiers below consider transient (may succeed on retry).
    Retryable { message: String },
    /// A failure the classifiers below consider permanent.
    NonRetryable { message: String },
}

impl NetworkError {
    /// Returns the human-readable description carried by either variant.
    pub fn message(&self) -> &str {
        let (Self::Retryable { message } | Self::NonRetryable { message }) = self;
        message
    }

    /// Returns `true` only for the [`NetworkError::Retryable`] variant.
    pub fn is_retryable(&self) -> bool {
        match self {
            Self::Retryable { .. } => true,
            Self::NonRetryable { .. } => false,
        }
    }
}

/// Displays just the carried message, without any variant prefix.
impl std::fmt::Display for NetworkError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.message())
    }
}

impl std::error::Error for NetworkError {}
/// Converts an HTTP error status (plus an optional response body) into a [`NetworkError`].
///
/// Statuses 408, 429, 500, 502, 503 and 504 are classified as retryable. Every other
/// status, explicitly including 401/403/404/413, is non-retryable.
///
/// When `body` is present, the message comes from `parse_error_body`. Otherwise it is
/// `"HTTP error <code>: <canonical reason>"`, with `"Unknown"` for unrecognized codes.
pub fn classify_response_status(status: reqwest::StatusCode, body: Option<String>) -> NetworkError {
    let message = match body {
        Some(text) => parse_error_body(status, &text),
        None => {
            let reason = status.canonical_reason().unwrap_or("Unknown");
            format!("HTTP error {}: {}", status.as_u16(), reason)
        }
    };

    let retryable = matches!(status.as_u16(), 408 | 429 | 500 | 502 | 503 | 504);
    if retryable {
        NetworkError::Retryable { message }
    } else {
        NetworkError::NonRetryable { message }
    }
}
fn parse_error_body(status: reqwest::StatusCode, body: &str) -> String {
if let Ok(val) = serde_json::from_str::<serde_json::Value>(body) {
let error_obj = val
.get("error")
.and_then(|e| e.as_object())
.or_else(|| val.as_object());
if let Some(obj) = error_obj {
let msg = obj
.get("message")
.and_then(|m| m.as_str())
.or_else(|| obj.get("error").and_then(|e| e.as_str()))
.or_else(|| val.get("error").and_then(|e| e.as_str()))
.unwrap_or_else(|| status.canonical_reason().unwrap_or("Unknown"));
let mut detail = obj
.get("detail")
.and_then(|d| d.as_str())
.or_else(|| obj.get("cause").and_then(|c| c.as_str()))
.map(|s| s.to_string());
if let Some(d) = detail.as_ref()
&& let Ok(inner_val) = serde_json::from_str::<serde_json::Value>(d)
&& let Some(inner_detail) = inner_val.get("detail").and_then(|m| m.as_str())
{
detail = Some(inner_detail.to_string());
}
let mut result = format!("HTTP {} {}", status.as_u16(), msg);
if let Some(d) = detail {
result.push_str(&format!(" ({})", d));
}
return result;
}
}
if !body.is_empty() && body.len() < 200 {
format!(
"HTTP {} {}: {}",
status.as_u16(),
status.canonical_reason().unwrap_or("Unknown"),
body
)
} else {
format!(
"HTTP {} {}",
status.as_u16(),
status.canonical_reason().unwrap_or("Unknown")
)
}
}
/// Maps a `reqwest` failure onto a [`NetworkError`].
///
/// Checks are applied in priority order:
/// 1. timeouts → retryable, `"Request timed out"`;
/// 2. connection failures → retryable, `"Connection failed"`;
/// 3. errors carrying an HTTP status → delegated to [`classify_response_status`]
///    without a body;
/// 4. anything else → non-retryable, using the error's own `Display` text.
fn classify_reqwest_error(error: &reqwest::Error) -> NetworkError {
    let transient = if error.is_timeout() {
        Some("Request timed out")
    } else if error.is_connect() {
        Some("Connection failed")
    } else {
        None
    };
    if let Some(text) = transient {
        return NetworkError::Retryable {
            message: text.to_string(),
        };
    }

    match error.status() {
        Some(status) => classify_response_status(status, None),
        None => NetworkError::NonRetryable {
            message: error.to_string(),
        },
    }
}
pub fn classify_anyhow_error(error: anyhow::Error) -> NetworkError {
if let Some(reqwest_error) = error.downcast_ref::<reqwest::Error>() {
return classify_reqwest_error(reqwest_error);
}
NetworkError::NonRetryable {
message: error.to_string(),
}
}
/// Exponential backoff delay for a given retry `attempt`: `2^attempt` seconds,
/// capped at 30 seconds (attempt 0 → 1s, 1 → 2s, … , 5 and above → 30s).
pub fn backoff_delay(attempt: u32) -> Duration {
    const MAX_BACKOFF_SECS: u64 = 30;
    // `checked_shl` yields None once the shift reaches 64 bits; the cap applies then too.
    let secs = match 1u64.checked_shl(attempt) {
        Some(power_of_two) => power_of_two.min(MAX_BACKOFF_SECS),
        None => MAX_BACKOFF_SECS,
    };
    Duration::from_secs(secs)
}
/// Asynchronously waits (via `tokio::time::sleep`) for the delay that
/// [`backoff_delay`] assigns to `attempt`.
pub async fn backoff_sleep(attempt: u32) {
    let wait = backoff_delay(attempt);
    tokio::time::sleep(wait).await;
}

/// Maximum retry count (3). Not read within this module; presumably consulted by
/// callers' retry loops — confirm against call sites.
pub const MAX_RETRIES: u32 = 3;
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_backoff_delay() {
        assert_eq!(backoff_delay(0), Duration::from_secs(1));
        assert_eq!(backoff_delay(1), Duration::from_secs(2));
        assert_eq!(backoff_delay(2), Duration::from_secs(4));
        assert_eq!(backoff_delay(3), Duration::from_secs(8));
        assert_eq!(backoff_delay(4), Duration::from_secs(16));
        assert_eq!(backoff_delay(5), Duration::from_secs(30));
        // Large attempts must saturate at the cap rather than overflow.
        assert_eq!(backoff_delay(u32::MAX), Duration::from_secs(30));
    }

    #[test]
    fn test_classify_status() {
        let retryable = classify_response_status(reqwest::StatusCode::INTERNAL_SERVER_ERROR, None);
        assert!(retryable.is_retryable());
        let non_retryable = classify_response_status(reqwest::StatusCode::UNAUTHORIZED, None);
        assert!(!non_retryable.is_retryable());
        let rate_limit = classify_response_status(reqwest::StatusCode::TOO_MANY_REQUESTS, None);
        assert!(rate_limit.is_retryable());
        let not_found = classify_response_status(reqwest::StatusCode::NOT_FOUND, None);
        assert!(!not_found.is_retryable());
        assert_eq!(not_found.message(), "HTTP error 404: Not Found");
    }

    #[test]
    fn test_classify_status_with_body() {
        let err = classify_response_status(
            reqwest::StatusCode::SERVICE_UNAVAILABLE,
            Some(r#"{"message":"down"}"#.to_string()),
        );
        assert!(err.is_retryable());
        assert_eq!(err.message(), "HTTP 503 down");
    }

    #[test]
    fn test_parse_error_body_json_shapes() {
        use reqwest::StatusCode;
        assert_eq!(
            parse_error_body(StatusCode::UNAUTHORIZED, r#"{"error":{"message":"bad key"}}"#),
            "HTTP 401 bad key"
        );
        assert_eq!(
            parse_error_body(
                StatusCode::INTERNAL_SERVER_ERROR,
                r#"{"message":"oops","detail":"more"}"#
            ),
            "HTTP 500 oops (more)"
        );
        assert_eq!(
            parse_error_body(StatusCode::BAD_REQUEST, r#"{"error":"boom","cause":"x"}"#),
            "HTTP 400 boom (x)"
        );
        assert_eq!(
            parse_error_body(
                StatusCode::INTERNAL_SERVER_ERROR,
                r#"{"message":"m","detail":"{\"detail\":\"inner\"}"}"#
            ),
            "HTTP 500 m (inner)"
        );
        assert_eq!(
            parse_error_body(StatusCode::INTERNAL_SERVER_ERROR, "{}"),
            "HTTP 500 Internal Server Error"
        );
    }

    #[test]
    fn test_parse_error_body_plain_text() {
        use reqwest::StatusCode;
        assert_eq!(
            parse_error_body(StatusCode::BAD_GATEWAY, "not json"),
            "HTTP 502 Bad Gateway: not json"
        );
        assert_eq!(parse_error_body(StatusCode::BAD_GATEWAY, ""), "HTTP 502 Bad Gateway");
        assert_eq!(
            parse_error_body(StatusCode::SERVICE_UNAVAILABLE, &"x".repeat(250)),
            "HTTP 503 Service Unavailable"
        );
    }
}