use crate::error::Error;
/// Coarse category assigned to a failed tool call by [`ToolErrorKind::classify`].
///
/// Of these categories, only [`Transient`](Self::Transient) and
/// [`RateLimit`](Self::RateLimit) are reported as retryable by
/// [`ToolErrorKind::is_retryable`].
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ToolErrorKind {
    /// Provider network/TLS/DNS failure, or an HTTP 5xx, 408 or 425 response.
    Transient,
    /// HTTP 429 response that carried a retry-after hint.
    RateLimit,
    /// HTTP 429 response without a retry-after hint, or an exceeded usage limit.
    Quota,
    /// HTTP 401/403 response, or an authentication error.
    Auth,
    /// Any provider HTTP status not mapped to another category.
    Permanent,
    /// Invalid request or (de)serialization failure.
    Validation,
    /// Every error not covered by the categories above.
    Internal,
}
impl ToolErrorKind {
    /// Sorts a crate [`Error`] into a [`ToolErrorKind`] category.
    ///
    /// Mapping:
    /// - provider `Network`, `Tls` or `Dns` failure → [`Self::Transient`]
    /// - provider HTTP 429 → [`Self::RateLimit`] when a retry-after value is
    ///   attached, otherwise [`Self::Quota`]
    /// - provider HTTP 401 or 403 → [`Self::Auth`]
    /// - provider HTTP 408, 425 or 500–599 → [`Self::Transient`]
    /// - any other provider HTTP status → [`Self::Permanent`]
    /// - [`Error::Auth`] → [`Self::Auth`]
    /// - [`Error::UsageLimitExceeded`] → [`Self::Quota`]
    /// - [`Error::InvalidRequest`] or [`Error::Serde`] → [`Self::Validation`]
    /// - anything else (including other provider error kinds) → [`Self::Internal`]
    #[must_use]
    pub fn classify(error: &Error) -> Self {
        use crate::error::ProviderErrorKind;

        match error {
            // HTTP responses are decided by status code; 429 is the only code
            // whose category also depends on whether `retry_after` is set.
            Error::Provider {
                kind: ProviderErrorKind::Http(status),
                retry_after,
                ..
            } => match *status {
                429 if retry_after.is_some() => Self::RateLimit,
                429 => Self::Quota,
                401 | 403 => Self::Auth,
                408 | 425 | 500..=599 => Self::Transient,
                _ => Self::Permanent,
            },
            Error::Provider {
                kind: ProviderErrorKind::Network | ProviderErrorKind::Tls | ProviderErrorKind::Dns,
                ..
            } => Self::Transient,
            Error::Auth(_) => Self::Auth,
            Error::UsageLimitExceeded(_) => Self::Quota,
            Error::InvalidRequest(_) | Error::Serde(_) => Self::Validation,
            _ => Self::Internal,
        }
    }

    /// Returns `true` only for [`Self::Transient`] and [`Self::RateLimit`];
    /// every other category is reported as not retryable.
    #[must_use]
    pub const fn is_retryable(self) -> bool {
        // Variants are listed explicitly (no wildcard) so that adding a new
        // category forces a decision about its retryability here.
        match self {
            Self::Transient | Self::RateLimit => true,
            Self::Quota | Self::Auth | Self::Permanent | Self::Validation | Self::Internal => false,
        }
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Shorthand used by the assertions below.
    fn kind_of(err: &Error) -> ToolErrorKind {
        ToolErrorKind::classify(err)
    }

    #[test]
    fn provider_network_classifies_as_transient() {
        let kind = kind_of(&Error::provider_network("connect refused"));
        assert_eq!(kind, ToolErrorKind::Transient);
        assert!(kind.is_retryable());
    }

    #[test]
    fn provider_dns_classifies_as_transient() {
        let kind = kind_of(&Error::provider_dns("no such host"));
        assert_eq!(kind, ToolErrorKind::Transient);
    }

    #[test]
    fn provider_5xx_classifies_as_transient() {
        for (status, body) in [(503, "down"), (502, "bad gateway")] {
            let kind = kind_of(&Error::provider_http(status, body));
            assert_eq!(kind, ToolErrorKind::Transient);
        }
    }

    #[test]
    fn http_408_and_425_classify_as_transient() {
        for (status, body) in [(408, "timeout"), (425, "too early")] {
            let kind = kind_of(&Error::provider_http(status, body));
            assert_eq!(kind, ToolErrorKind::Transient);
        }
    }

    #[test]
    fn http_429_with_retry_after_classifies_as_rate_limit() {
        let err = Error::provider_http(429, "slow down").with_retry_after(Duration::from_secs(5));
        let kind = kind_of(&err);
        assert_eq!(kind, ToolErrorKind::RateLimit);
        assert!(kind.is_retryable());
    }

    #[test]
    fn http_429_without_retry_after_classifies_as_quota() {
        let kind = kind_of(&Error::provider_http(429, "monthly cap reached"));
        assert_eq!(kind, ToolErrorKind::Quota);
        assert!(!kind.is_retryable());
    }

    #[test]
    fn http_401_403_classify_as_auth() {
        for (status, body) in [(401, "unauthorized"), (403, "forbidden")] {
            let kind = kind_of(&Error::provider_http(status, body));
            assert_eq!(kind, ToolErrorKind::Auth);
            assert!(!kind.is_retryable());
        }
    }

    #[test]
    fn http_4xx_other_classifies_as_permanent() {
        for (status, body) in [(404, "not found"), (422, "unprocessable")] {
            let kind = kind_of(&Error::provider_http(status, body));
            assert_eq!(kind, ToolErrorKind::Permanent);
            assert!(!kind.is_retryable());
        }
    }

    #[test]
    fn invalid_request_and_serde_classify_as_validation() {
        let invalid = Error::invalid_request("bad input");
        assert_eq!(kind_of(&invalid), ToolErrorKind::Validation);

        let parse_failure: serde_json::Error =
            serde_json::from_str::<i32>("not-a-number").unwrap_err();
        let wrapped: Error = parse_failure.into();
        assert_eq!(kind_of(&wrapped), ToolErrorKind::Validation);
    }

    #[test]
    fn config_classifies_as_internal() {
        let kind = kind_of(&Error::config("misconfigured"));
        assert_eq!(kind, ToolErrorKind::Internal);
    }

    #[test]
    fn usage_limit_exceeded_classifies_as_quota() {
        use crate::run_budget::UsageLimitBreach;

        let breach = UsageLimitBreach::Requests {
            limit: 10,
            observed: 11,
        };
        let kind = kind_of(&Error::UsageLimitExceeded(breach));
        assert_eq!(kind, ToolErrorKind::Quota);
    }
}