#![deny(missing_docs)]
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
/// Coarse category of a failed provider API call.
///
/// Produced by [`classify`] from an HTTP status code and response body.
/// [`ErrorClass::is_retriable`] reports which categories are treated as
/// transient.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum ErrorClass {
    /// The request was rejected for exceeding a rate limit (e.g. HTTP 429,
    /// throttling, "too many requests").
    RateLimit,
    /// The provider reported it is overloaded or unavailable (e.g. HTTP 503).
    Overloaded,
    /// A generic server-side failure: an HTTP 5xx not covered by a more
    /// specific class.
    Server,
    /// The request timed out (e.g. HTTP 408 or a "timed out" message).
    Timeout,
    /// Authentication or authorization failed (e.g. HTTP 401/403, invalid API
    /// key).
    Auth,
    /// The request exceeded the model's context window.
    ContextWindow,
    /// The request was blocked by a content-policy / safety filter.
    ContentPolicy,
    /// The request was rejected as invalid (e.g. HTTP 400, validation errors).
    Malformed,
    /// The requested resource, such as a model, was not found (e.g. HTTP 404).
    NotFound,
    /// Quota, credit or billing status prevents the request (e.g. HTTP 402).
    BillingQuota,
    /// The failure did not match any known category.
    Unknown,
}

impl ErrorClass {
    /// Returns `true` for transient classes where retrying the same request
    /// may succeed: [`RateLimit`](Self::RateLimit),
    /// [`Overloaded`](Self::Overloaded), [`Server`](Self::Server) and
    /// [`Timeout`](Self::Timeout). Every other class returns `false`.
    #[must_use]
    pub const fn is_retriable(self) -> bool {
        // Exhaustive on purpose: adding a variant must force a decision here
        // instead of silently defaulting to "not retriable".
        match self {
            Self::RateLimit | Self::Overloaded | Self::Server | Self::Timeout => true,
            Self::Auth
            | Self::ContextWindow
            | Self::ContentPolicy
            | Self::Malformed
            | Self::NotFound
            | Self::BillingQuota
            | Self::Unknown => false,
        }
    }
}
pub fn classify(status: u16, body: &str) -> ErrorClass {
let lower = body.to_ascii_lowercase();
if has(&lower, "rate_limit") || has(&lower, "throttling") || has(&lower, "too many requests") {
return ErrorClass::RateLimit;
}
if has(&lower, "overloaded") || has(&lower, "serviceunavailable") {
return ErrorClass::Overloaded;
}
if has(&lower, "context_length_exceeded")
|| has(&lower, "maximum context length")
|| has(&lower, "exceeds the maximum")
|| has(&lower, "context window")
{
return ErrorClass::ContextWindow;
}
if has(&lower, "content_policy") || has(&lower, "content_filter") || has(&lower, "safety") {
return ErrorClass::ContentPolicy;
}
if has(&lower, "insufficient_quota") || has(&lower, "billing") || has(&lower, "credit") {
return ErrorClass::BillingQuota;
}
if has(&lower, "timeout") || has(&lower, "timed out") {
return ErrorClass::Timeout;
}
if has(&lower, "authentication")
|| has(&lower, "invalid api key")
|| has(&lower, "unauthorized")
|| has(&lower, "forbidden")
{
return ErrorClass::Auth;
}
if has(&lower, "not found") || has(&lower, "model not found") {
return ErrorClass::NotFound;
}
if has(&lower, "invalid_request")
|| has(&lower, "validationexception")
|| has(&lower, "bad request")
|| has(&lower, "malformed")
{
return ErrorClass::Malformed;
}
match status {
401 | 403 => ErrorClass::Auth,
402 => ErrorClass::BillingQuota,
404 => ErrorClass::NotFound,
408 => ErrorClass::Timeout,
429 => ErrorClass::RateLimit,
503 => ErrorClass::Overloaded,
500..=599 => ErrorClass::Server,
400 => ErrorClass::Malformed,
_ => ErrorClass::Unknown,
}
}
/// Substring test used by `classify`: `true` when `needle` occurs anywhere
/// in `haystack`. An empty `needle` always matches.
fn has(haystack: &str, needle: &str) -> bool {
    str::contains(haystack, needle)
}