use thiserror::Error;
/// Errors produced by provider operations.
///
/// Display text for every variant comes from its `#[error]` attribute.
/// `ProviderError::is_retryable` decides which variants/statuses are treated as transient.
#[derive(Debug, Error)]
pub enum ProviderError {
/// Transport-level failure reported by `reqwest`.
/// `#[from]` derives `From<reqwest::Error>` and exposes it as the error source.
#[error("HTTP request failed: {0}")]
HttpError(#[from] reqwest::Error),
/// Error response from the provider's API.
///
/// Displays as `API error (<status>) [<error_type>]: <message>`; the ` [<error_type>]`
/// part is omitted when `error_type` is `None` or an empty string.
#[error(
"API error ({status}){}: {message}",
error_type
.as_ref()
.filter(|t| !t.is_empty())
.map(|t| format!(" [{}]", t))
.unwrap_or_default()
)]
ApiError {
/// HTTP status code of the failed response.
status: u16,
/// Error text. The retry heuristics also inspect it (e.g. for HTML markup), so it may
/// be a raw response body rather than a curated message.
message: String,
/// Optional error category string (e.g. `rate_limit_error`, `model_not_found`).
error_type: Option<String>,
},
/// The API key was rejected. Never retried.
#[error("Invalid API key")]
InvalidApiKey,
/// Rate limiting. The payload is a free-form message that `retry_after` scans for a
/// wait hint. Always retryable.
#[error("Rate limit exceeded: {0}")]
RateLimitExceeded(String),
/// The request was rejected as invalid; the payload explains why.
#[error("Invalid request: {0}")]
InvalidRequest(String),
/// The requested model is unknown. The payload is presumably the model name — TODO confirm.
#[error("Model not found: {0}")]
ModelNotFound(String),
/// The context length was exceeded; the payload is a token count.
#[error("Context length exceeded: {0} tokens")]
ContextLengthExceeded(u32),
/// The provider has no streaming support.
#[error("Streaming not supported by this provider")]
StreamingNotSupported,
/// The provider has no tool support.
#[error("Tools not supported by this provider")]
ToolsNotSupported,
/// JSON (de)serialization failure; `#[from]` derives `From<serde_json::Error>`.
#[error("JSON error: {0}")]
JsonError(#[from] serde_json::Error),
/// Failure while handling a streaming response.
#[error("Streaming error: {0}")]
StreamError(String),
/// The request timed out; the payload is the timeout in seconds. Always retryable.
#[error("Request timed out after {0}s")]
Timeout(u64),
/// Unexpected internal failure.
#[error("Internal error: {0}")]
Internal(String),
}
impl ProviderError {
    /// Returns `true` when retrying the same request may succeed.
    ///
    /// Treated as retryable:
    /// - `HttpError`, `RateLimitExceeded`, and `Timeout`;
    /// - `ApiError` with a status of 500 or above;
    /// - `ApiError` with 408 (Request Timeout) or 429 (Too Many Requests). Both are
    ///   transient by HTTP definition, and the `RetryableError::retry_after` impl extracts a
    ///   wait hint from 429 messages, which only matters if those errors are retried;
    /// - `ApiError` with any 4xx status whose message looks like an HTML page
    ///   (`is_html_error_body`) — presumably an intermediary's error page rather than the
    ///   provider's own error payload;
    /// - `ApiError` with status 400 that `is_transient_proxy_400` accepts.
    ///
    /// Every other variant or status returns `false`.
    pub fn is_retryable(&self) -> bool {
        match self {
            ProviderError::HttpError(_)
            | ProviderError::RateLimitExceeded(_)
            | ProviderError::Timeout(_) => true,
            ProviderError::ApiError { status, .. } if *status >= 500 => true,
            // Without this arm a JSON-bodied 429 fell through to `false`, so the delay
            // parsed by `retry_after` for exactly this case was never used.
            ProviderError::ApiError {
                status: 408 | 429,
                ..
            } => true,
            ProviderError::ApiError {
                status, message, ..
            } if (400..500).contains(status) && is_html_error_body(message) => true,
            ProviderError::ApiError {
                status: 400,
                message,
                error_type,
            } => is_transient_proxy_400(message, error_type.as_deref()),
            _ => false,
        }
    }

    /// The HTTP status carried by an `ApiError`, or `None` for every other variant.
    pub fn status_code(&self) -> Option<u16> {
        match self {
            ProviderError::ApiError { status, .. } => Some(*status),
            _ => None,
        }
    }

    /// Heuristically detects "this model is not available" errors.
    ///
    /// Returns `true` for `ModelNotFound`, and for an `ApiError` when either:
    /// - its `error_type` equals `modelerror`, `model_error`, `model_not_found` or
    ///   `invalid_model` (ASCII case-insensitive); or
    /// - its message contains `model` together with `not supported`, `not found` or
    ///   `unsupported` (ASCII case-insensitive).
    pub fn is_model_unsupported(&self) -> bool {
        const MODEL_ERROR_TYPES: &[&str] =
            &["modelerror", "model_error", "model_not_found", "invalid_model"];
        const UNSUPPORTED_PHRASES: &[&str] = &["not supported", "not found", "unsupported"];

        match self {
            ProviderError::ModelNotFound(_) => true,
            ProviderError::ApiError {
                error_type,
                message,
                ..
            } => {
                // Compare in place; no lowercased copy of the type string is needed.
                let type_hit = error_type.as_deref().is_some_and(|t| {
                    MODEL_ERROR_TYPES
                        .iter()
                        .any(|known| t.eq_ignore_ascii_case(known))
                });
                let msg = message.to_ascii_lowercase();
                let msg_hit = msg.contains("model")
                    && UNSUPPORTED_PHRASES.iter().any(|phrase| msg.contains(phrase));
                type_hit || msg_hit
            }
            _ => false,
        }
    }
}
/// Returns `true` when `message` looks like an HTML document rather than a text/JSON error.
///
/// Leading whitespace is skipped. Only the first 256 characters are inspected, ASCII
/// case-insensitively, for `<!doctype`, `<html`, `<head` or `<body`.
pub(crate) fn is_html_error_body(message: &str) -> bool {
    const HTML_MARKERS: [&str; 4] = ["<!doctype", "<html", "<head", "<body"];
    const PREFIX_CHARS: usize = 256;

    let body = message.trim_start();
    // Byte offset just past the first PREFIX_CHARS characters (always a char boundary).
    let cut = body
        .char_indices()
        .nth(PREFIX_CHARS)
        .map_or(body.len(), |(offset, _)| offset);
    let prefix = body[..cut].to_ascii_lowercase();

    HTML_MARKERS.iter().any(|marker| prefix.contains(marker))
}
/// Decides whether an HTTP 400 response is probably a transient relay/upstream hiccup.
///
/// Returns `false` whenever a non-empty `error_type` is present. Otherwise returns `true`
/// if the trimmed message is empty, or if it contains (ASCII case-insensitively) one of a
/// small set of hints such as "provider returned error", "try again" or "bad gateway".
pub(crate) fn is_transient_proxy_400(message: &str, error_type: Option<&str>) -> bool {
    // A populated error type is taken as a deliberate rejection, not a hiccup.
    if matches!(error_type, Some(kind) if !kind.is_empty()) {
        return false;
    }

    let normalized = message.trim().to_ascii_lowercase();
    normalized.is_empty()
        || [
            "provider returned error",
            "upstream error",
            "internal error",
            "temporary",
            "try again",
            "bad gateway",
        ]
        .iter()
        .any(|hint| normalized.contains(hint))
}
/// Shorthand for `std::result::Result` with [`ProviderError`] as the error type.
pub type Result<T> = std::result::Result<T, ProviderError>;
impl crate::utils::retry::RetryableError for ProviderError {
    /// Same classification as the inherent `ProviderError::is_retryable`.
    fn is_retryable(&self) -> bool {
        provider_error_is_retryable(self)
    }

    /// Wait hint parsed from a rate-limit message, clamped to at most 30 seconds.
    ///
    /// Only `RateLimitExceeded` and `ApiError` with status 429 are inspected; everything
    /// else (and any message without a recognizable delay) yields `None`.
    fn retry_after(&self) -> Option<std::time::Duration> {
        let hint = match self {
            ProviderError::RateLimitExceeded(text) => text,
            ProviderError::ApiError {
                status: 429,
                message,
                ..
            } => message,
            _ => return None,
        };
        let secs = parse_retry_seconds(hint)?;
        Some(std::time::Duration::from_secs(secs.min(30)))
    }

    /// Whether the failure looks like the endpoint is unreachable.
    ///
    /// Only `HttpError` qualifies. It does so when `reqwest` flags a connect error, or when
    /// any error in its `source()` chain (the top-level error itself excluded) has a
    /// message matching `looks_like_connection_failure`.
    fn is_hard_down(&self) -> bool {
        match self {
            ProviderError::HttpError(err) => {
                err.is_connect()
                    || std::iter::successors(std::error::Error::source(err), |cause| {
                        std::error::Error::source(*cause)
                    })
                    .any(|cause| looks_like_connection_failure(&cause.to_string()))
            }
            _ => false,
        }
    }
}
/// Returns `true` when `msg` reads like a name-resolution or socket connection failure.
///
/// Matching is an ASCII case-insensitive substring search against fixed marker phrases.
pub(crate) fn looks_like_connection_failure(msg: &str) -> bool {
    // Name-resolution failures.
    const RESOLVER_MARKERS: &[&str] = &[
        "dns error",
        "failed to lookup address",
        "name or service not known",
        "nodename nor servname",
        "no such host",
        "could not resolve",
        "name resolution",
    ];
    // Connection-level failures.
    const SOCKET_MARKERS: &[&str] = &[
        "connection refused",
        "network is unreachable",
        "no route to host",
        "connection reset",
    ];

    let lowered = msg.to_ascii_lowercase();
    RESOLVER_MARKERS
        .iter()
        .chain(SOCKET_MARKERS)
        .any(|marker| lowered.contains(marker))
}
fn provider_error_is_retryable(e: &ProviderError) -> bool {
e.is_retryable()
}
fn parse_retry_seconds(msg: &str) -> Option<u64> {
use regex::Regex;
let patterns = [
r"(\d+)\s*seconds?",
r"(\d+)\s*s\b",
r"retry in (\d+)",
r"wait (\d+)",
];
for pattern in patterns {
if let Ok(re) = Regex::new(pattern)
&& let Some(captures) = re.captures(msg)
&& let Some(num_str) = captures.get(1)
&& let Ok(secs) = num_str.as_str().parse::<u64>()
{
return Some(secs);
}
}
None
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `ApiError` with the given fields.
    fn api_error(status: u16, message: &str, error_type: Option<&str>) -> ProviderError {
        ProviderError::ApiError {
            status,
            message: message.to_string(),
            error_type: error_type.map(str::to_string),
        }
    }

    #[test]
    fn test_error_retryable() {
        let rate_limit = ProviderError::RateLimitExceeded("Try again later".to_string());
        assert!(rate_limit.is_retryable());
        let invalid_key = ProviderError::InvalidApiKey;
        assert!(!invalid_key.is_retryable());
        let server_error = ProviderError::ApiError {
            status: 500,
            message: "Internal Server Error".to_string(),
            error_type: None,
        };
        assert!(server_error.is_retryable());
        let client_error = ProviderError::ApiError {
            status: 400,
            message: "Bad Request".to_string(),
            error_type: None,
        };
        assert!(!client_error.is_retryable());
    }

    #[test]
    fn test_status_code() {
        let error = ProviderError::ApiError {
            status: 429,
            message: "Too many requests".to_string(),
            error_type: Some("rate_limit_error".to_string()),
        };
        assert_eq!(error.status_code(), Some(429));
        let invalid_key = ProviderError::InvalidApiKey;
        assert_eq!(invalid_key.status_code(), None);
    }

    #[test]
    fn test_html_body_on_4xx_is_retryable() {
        let html = api_error(403, "<!DOCTYPE html><html><body>Forbidden</body></html>", None);
        assert!(html.is_retryable());
        let json = api_error(403, "{\"error\":\"forbidden\"}", None);
        assert!(!json.is_retryable());
    }

    #[test]
    fn test_transient_proxy_400_is_retryable() {
        assert!(api_error(400, "", None).is_retryable());
        assert!(api_error(400, "Provider returned error", None).is_retryable());
        assert!(
            !api_error(400, "Provider returned error", Some("invalid_request_error"))
                .is_retryable()
        );
        assert!(!api_error(400, "max_tokens must be positive", None).is_retryable());
    }

    #[test]
    fn test_api_error_display() {
        assert_eq!(api_error(500, "boom", None).to_string(), "API error (500): boom");
        assert_eq!(api_error(500, "boom", Some("")).to_string(), "API error (500): boom");
        assert_eq!(
            api_error(429, "slow down", Some("rate_limit_error")).to_string(),
            "API error (429) [rate_limit_error]: slow down"
        );
    }

    #[test]
    fn test_model_unsupported() {
        assert!(ProviderError::ModelNotFound("x".to_string()).is_model_unsupported());
        assert!(api_error(400, "whatever", Some("Model_Not_Found")).is_model_unsupported());
        assert!(api_error(400, "The model gpt-x is not supported", None).is_model_unsupported());
        assert!(!api_error(400, "Invalid temperature", None).is_model_unsupported());
        assert!(!ProviderError::InvalidApiKey.is_model_unsupported());
    }

    #[test]
    fn test_retry_after() {
        use crate::utils::retry::RetryableError;
        use std::time::Duration;

        let rate_limited = ProviderError::RateLimitExceeded("Please try again in 20s".to_string());
        assert_eq!(
            RetryableError::retry_after(&rate_limited),
            Some(Duration::from_secs(20))
        );
        // Hints above the cap are clamped to 30 seconds.
        let long_wait = api_error(429, "Rate limit hit, retry in 120 seconds", None);
        assert_eq!(
            RetryableError::retry_after(&long_wait),
            Some(Duration::from_secs(30))
        );
        // Only rate-limit errors carry a hint.
        let server = api_error(500, "retry in 5 seconds", None);
        assert_eq!(RetryableError::retry_after(&server), None);
        let no_hint = ProviderError::RateLimitExceeded("slow down".to_string());
        assert_eq!(RetryableError::retry_after(&no_hint), None);
    }

    #[test]
    fn test_looks_like_connection_failure() {
        assert!(looks_like_connection_failure(
            "tcp connect error: Connection refused (os error 111)"
        ));
        assert!(looks_like_connection_failure(
            "dns error: failed to lookup address information: nodename nor servname provided"
        ));
        assert!(!looks_like_connection_failure("operation timed out"));
    }
}