use thiserror::Error;
/// Error type for provider operations in this module.
///
/// Use [`ProviderError::is_retryable`] to decide whether a failed call may be
/// retried, and [`ProviderError::is_model_unsupported`] to detect model errors.
#[derive(Debug, Error)]
pub enum ProviderError {
/// Failure reported by `reqwest`; produced automatically via `From<reqwest::Error>`.
#[error("HTTP request failed: {0}")]
HttpError(#[from] reqwest::Error),
/// Non-success response from the provider API.
///
/// Displays as `API error (<status>) [<error_type>]: <message>`; the bracketed
/// part is omitted when `error_type` is `None` or an empty string.
#[error(
"API error ({status}){}: {message}",
error_type
.as_ref()
.filter(|t| !t.is_empty())
.map(|t| format!(" [{}]", t))
.unwrap_or_default()
)]
ApiError {
// HTTP status code; classified by 4xx/5xx ranges in `is_retryable`.
status: u16,
// Error text from the response; may be a raw HTML page (see `is_html_error_body`).
message: String,
// Provider-supplied error category, when one was present (presumably parsed
// from the response body — confirm against the callers that build this).
error_type: Option<String>,
},
/// The API key was reported as invalid.
#[error("Invalid API key")]
InvalidApiKey,
/// Rate limit hit. The text is scanned by `retry_after` for a suggested delay.
#[error("Rate limit exceeded: {0}")]
RateLimitExceeded(String),
/// The request was rejected as invalid; the text carries the detail.
#[error("Invalid request: {0}")]
InvalidRequest(String),
/// The requested model was not found (the text presumably names the model).
#[error("Model not found: {0}")]
ModelNotFound(String),
/// The context length was exceeded; the value is a token count.
#[error("Context length exceeded: {0} tokens")]
ContextLengthExceeded(u32),
/// The provider does not support streaming responses.
#[error("Streaming not supported by this provider")]
StreamingNotSupported,
/// The provider does not support tools.
#[error("Tools not supported by this provider")]
ToolsNotSupported,
/// JSON (de)serialization failure; produced automatically via `From<serde_json::Error>`.
#[error("JSON error: {0}")]
JsonError(#[from] serde_json::Error),
/// Failure while handling a streamed response; treated as retryable.
#[error("Streaming error: {0}")]
StreamError(String),
/// The request timed out; the value is the timeout in seconds.
#[error("Request timed out after {0}s")]
Timeout(u64),
/// Any other internal failure, described by the text.
#[error("Internal error: {0}")]
Internal(String),
}
impl ProviderError {
    /// Returns `true` when the failure is plausibly transient, so the same request
    /// may succeed if sent again.
    ///
    /// Retryable cases:
    /// - transport errors, timeouts, stream errors and [`ProviderError::RateLimitExceeded`];
    /// - API responses with a 5xx status;
    /// - API responses with status 429 (Too Many Requests), i.e. the same rate-limit
    ///   condition as `RateLimitExceeded` surfaced as a raw HTTP status;
    /// - 4xx responses whose body is an HTML page (see [`is_html_error_body`]);
    /// - 400 responses that look like a transient proxy failure
    ///   (see [`is_transient_proxy_400`]).
    ///
    /// Everything else returns `false`.
    pub fn is_retryable(&self) -> bool {
        match self {
            ProviderError::HttpError(_)
            | ProviderError::RateLimitExceeded(_)
            | ProviderError::Timeout(_)
            | ProviderError::StreamError(_) => true,
            ProviderError::ApiError { status, .. } if *status >= 500 => true,
            // A 429 is a rate limit just like `RateLimitExceeded`. The
            // `RetryableError::retry_after` impl parses a back-off delay out of 429
            // responses, which is only useful if they are classified as retryable.
            ProviderError::ApiError { status: 429, .. } => true,
            // An HTML body on a 4xx presumably comes from an intermediary (proxy/CDN
            // error page) rather than a structured provider rejection.
            ProviderError::ApiError {
                status, message, ..
            } if (400..500).contains(status) && is_html_error_body(message) => true,
            ProviderError::ApiError {
                status: 400,
                message,
                error_type,
            } => is_transient_proxy_400(message, error_type.as_deref()),
            _ => false,
        }
    }

    /// Returns the HTTP status of an [`ProviderError::ApiError`], or `None` for
    /// every other variant.
    pub fn status_code(&self) -> Option<u16> {
        match self {
            ProviderError::ApiError { status, .. } => Some(*status),
            _ => None,
        }
    }

    /// Heuristically reports whether the error means the requested model is unknown
    /// or unsupported.
    ///
    /// Matches [`ProviderError::ModelNotFound`], or an [`ProviderError::ApiError`] where either:
    /// - `error_type` is (ASCII case-insensitively) one of `modelerror`,
    ///   `model_error`, `model_not_found` or `invalid_model`; or
    /// - the message mentions `model` together with `not supported`, `not found`
    ///   or `unsupported` (ASCII case-insensitive).
    pub fn is_model_unsupported(&self) -> bool {
        match self {
            ProviderError::ModelNotFound(_) => true,
            ProviderError::ApiError {
                error_type,
                message,
                ..
            } => {
                let type_hit = error_type.as_deref().is_some_and(|t| {
                    matches!(
                        t.to_ascii_lowercase().as_str(),
                        "modelerror" | "model_error" | "model_not_found" | "invalid_model"
                    )
                });
                let msg = message.to_ascii_lowercase();
                let msg_hit = msg.contains("model")
                    && (msg.contains("not supported")
                        || msg.contains("not found")
                        || msg.contains("unsupported"));
                type_hit || msg_hit
            }
            _ => false,
        }
    }
}
/// Reports whether `message` looks like an HTML document rather than a plain or
/// JSON error payload.
///
/// Leading whitespace is skipped. Only the first 256 characters that follow are
/// inspected, compared ASCII-case-insensitively against common HTML opening tags.
pub(crate) fn is_html_error_body(message: &str) -> bool {
    const HTML_MARKERS: [&str; 4] = ["<!doctype", "<html", "<head", "<body"];

    let body = message.trim_start();
    // Byte offset of the 257th character marks the end of the inspected window;
    // shorter bodies are inspected whole.
    let window = match body.char_indices().nth(256) {
        Some((cut, _)) => &body[..cut],
        None => body,
    };
    let window = window.to_ascii_lowercase();
    HTML_MARKERS.iter().any(|&marker| window.contains(marker))
}
/// Decides whether an HTTP 400 looks like a transient proxy/upstream hiccup rather
/// than a genuine client error.
///
/// A non-empty `error_type` always yields `false`. Otherwise the trimmed message
/// qualifies if it is empty, or if it contains (ASCII case-insensitively) one of a
/// fixed set of transient-sounding phrases.
pub(crate) fn is_transient_proxy_400(message: &str, error_type: Option<&str>) -> bool {
    let has_typed_error = matches!(error_type, Some(kind) if !kind.is_empty());
    if has_typed_error {
        return false;
    }

    let normalized = message.trim().to_ascii_lowercase();
    normalized.is_empty()
        || [
            "provider returned error",
            "upstream error",
            "internal error",
            "temporary",
            "try again",
            "bad gateway",
        ]
        .iter()
        .any(|&hint| normalized.contains(hint))
}
/// Shorthand for [`std::result::Result`] with [`ProviderError`] as the error type.
pub type Result<T> = std::result::Result<T, ProviderError>;
impl crate::utils::retry::RetryableError for ProviderError {
    /// Uses the same classification as the inherent `ProviderError::is_retryable`.
    fn is_retryable(&self) -> bool {
        provider_error_is_retryable(self)
    }

    /// Extracts a suggested back-off from rate-limit errors.
    ///
    /// Only `RateLimitExceeded` and `ApiError` with status 429 are considered. Their
    /// text is scanned by `parse_retry_seconds`. The result is capped at 30 seconds.
    /// Returns `None` for other variants or when no delay can be parsed.
    fn retry_after(&self) -> Option<std::time::Duration> {
        const MAX_DELAY_SECS: u64 = 30;

        let text = match self {
            ProviderError::RateLimitExceeded(detail) => detail.as_str(),
            ProviderError::ApiError {
                status: 429,
                message,
                ..
            } => message.as_str(),
            _ => return None,
        };
        let seconds = parse_retry_seconds(text)?;
        Some(std::time::Duration::from_secs(seconds.min(MAX_DELAY_SECS)))
    }
}
fn provider_error_is_retryable(e: &ProviderError) -> bool {
e.is_retryable()
}
/// Renders `err` as a short, human-readable reason.
///
/// Transport failures are summarized by [`describe_reqwest_error`]. Every other
/// variant uses its `Display` text.
pub fn user_facing_reason(err: &ProviderError) -> String {
    if let ProviderError::HttpError(inner) = err {
        describe_reqwest_error(inner)
    } else {
        err.to_string()
    }
}
/// Summarizes a `reqwest` failure as a short label, suffixed with ` (<host>)` when
/// the request URL has a host.
///
/// Timeouts reported by `reqwest` itself become `request timed out`. Otherwise the
/// innermost error in the `source()` chain (or `e` itself if it has no source) is
/// matched, ASCII-case-insensitively, against known phrases. Categories are checked
/// in this order: DNS, refused, reset, unreachable, no route, timeout, TLS.
/// Unrecognized causes fall back to their first 140 characters.
pub(crate) fn describe_reqwest_error(e: &reqwest::Error) -> String {
    // Ordered (needles, label) table; the first entry with a matching needle wins.
    const CLASSES: &[(&[&str], &str)] = &[
        (
            &[
                "dns",
                "lookup address",
                "nodename nor servname",
                "name or service not known",
                "no such host",
                "failed to resolve",
                "could not resolve",
            ],
            "DNS lookup failed",
        ),
        (&["connection refused"], "connection refused"),
        (&["connection reset"], "connection reset by peer"),
        (&["network is unreachable"], "network unreachable"),
        (&["no route to host"], "no route to host"),
        (&["timed out", "timeout"], "timed out"),
        (&["certificate", "tls", "ssl", "handshake"], "TLS/certificate error"),
    ];

    let host_suffix = match e.url().and_then(|u| u.host_str()) {
        Some(host) => format!(" ({host})"),
        None => String::new(),
    };
    if e.is_timeout() {
        return format!("request timed out{host_suffix}");
    }

    // The innermost cause usually carries the OS/resolver message worth showing.
    let detail = std::iter::successors(std::error::Error::source(e), |&cause| cause.source())
        .last()
        .map_or_else(|| e.to_string(), |cause| cause.to_string());

    let lowered = detail.to_ascii_lowercase();
    let label = CLASSES
        .iter()
        .find(|(needles, _)| needles.iter().any(|&needle| lowered.contains(needle)))
        .map(|&(_, label)| label);

    match label {
        Some(label) => format!("{label}{host_suffix}"),
        None => {
            let head: String = detail.chars().take(140).collect();
            format!("{head}{host_suffix}")
        }
    }
}
fn parse_retry_seconds(msg: &str) -> Option<u64> {
use regex::Regex;
let patterns = [
r"(\d+)\s*seconds?",
r"(\d+)\s*s\b",
r"retry in (\d+)",
r"wait (\d+)",
];
for pattern in patterns {
if let Ok(re) = Regex::new(pattern)
&& let Some(captures) = re.captures(msg)
&& let Some(num_str) = captures.get(1)
&& let Ok(secs) = num_str.as_str().parse::<u64>()
{
return Some(secs);
}
}
None
}