use crate::error::Error;
/// Coarse category assigned to a failed tool call.
///
/// [`ToolErrorKind::classify`] derives the category from an `Error`, and
/// [`ToolErrorKind::is_retryable`] reports which categories may be retried.
/// Serde reads and writes the variant names in `snake_case`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
// Non-exhaustive: matches outside this crate need a wildcard arm.
#[non_exhaustive]
pub enum ToolErrorKind {
/// Provider network, TLS or DNS failure, or HTTP 408, 425 or 5xx. Retryable.
Transient,
/// HTTP 429 that carries a `retry_after` value. Retryable.
RateLimit,
/// HTTP 429 without a `retry_after` value, or `Error::UsageLimitExceeded`.
/// Not retryable.
Quota,
/// HTTP 401 or 403, or `Error::Auth`. Not retryable.
Auth,
/// Any other provider HTTP status. Not retryable.
Permanent,
/// `Error::InvalidRequest` or `Error::Serde`. Not retryable.
Validation,
/// Fallback for every error that `classify` does not recognise. Not retryable.
Internal,
}
impl ToolErrorKind {
#[must_use]
pub fn classify(error: &Error) -> Self {
use crate::error::ProviderErrorKind;
match error {
Error::ToolErrorTerminal { kind, .. } => *kind,
Error::Provider {
kind: ProviderErrorKind::Network | ProviderErrorKind::Tls | ProviderErrorKind::Dns,
..
} => Self::Transient,
Error::Provider {
kind: ProviderErrorKind::Http(429),
retry_after,
..
} => {
if retry_after.is_some() {
Self::RateLimit
} else {
Self::Quota
}
}
Error::Provider {
kind: ProviderErrorKind::Http(status),
..
} => {
if *status == 401 || *status == 403 {
Self::Auth
} else if (500..600).contains(status) || *status == 408 || *status == 425 {
Self::Transient
} else {
Self::Permanent
}
}
Error::Auth(_) => Self::Auth,
Error::UsageLimitExceeded(_) => Self::Quota,
Error::InvalidRequest(_) | Error::Serde(_) => Self::Validation,
_ => Self::Internal,
}
}
#[must_use]
pub const fn is_retryable(self) -> bool {
matches!(self, Self::Transient | Self::RateLimit)
}
#[must_use]
pub const fn wire_id(self) -> &'static str {
match self {
Self::Transient => "transient",
Self::RateLimit => "rate_limit",
Self::Quota => "quota",
Self::Auth => "auth",
Self::Permanent => "permanent",
Self::Validation => "validation",
Self::Internal => "internal",
}
}
#[must_use]
pub(crate) const fn bit_index(self) -> u32 {
match self {
Self::Transient => 0,
Self::RateLimit => 1,
Self::Quota => 2,
Self::Auth => 3,
Self::Permanent => 4,
Self::Validation => 5,
Self::Internal => 6,
}
}
}
/// Prints the kind as the identifier returned by [`ToolErrorKind::wire_id`].
impl std::fmt::Display for ToolErrorKind {
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = self.wire_id();
        out.write_str(label)
    }
}
/// Set of [`ToolErrorKind`] values stored as a bitmask.
///
/// Bit `n` of the inner `u16` is set when the kind whose `bit_index()` is `n`
/// is a member. The derived `Default` is the empty set (all bits clear).
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub struct ToolErrorKindSet(u16);
impl ToolErrorKindSet {
#[cfg(test)]
const CAPACITY_BITS: u32 = u16::BITS;
#[must_use]
pub const fn empty() -> Self {
Self(0)
}
#[must_use]
pub const fn singleton(kind: ToolErrorKind) -> Self {
Self(1u16 << kind.bit_index())
}
#[must_use]
pub const fn with(self, kind: ToolErrorKind) -> Self {
Self(self.0 | (1u16 << kind.bit_index()))
}
#[must_use]
pub const fn without(self, kind: ToolErrorKind) -> Self {
Self(self.0 & !(1u16 << kind.bit_index()))
}
#[must_use]
pub const fn union(self, other: Self) -> Self {
Self(self.0 | other.0)
}
#[must_use]
pub const fn contains(self, kind: ToolErrorKind) -> bool {
(self.0 >> kind.bit_index()) & 1 == 1
}
#[must_use]
pub const fn is_empty(self) -> bool {
self.0 == 0
}
}
#[cfg(test)]
// `unwrap_err` is used below to obtain a real `serde_json::Error`.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Shorthand for `ToolErrorKind::classify`.
    fn kind_of(error: &Error) -> ToolErrorKind {
        ToolErrorKind::classify(error)
    }

    /// A provider network failure is transient and retryable.
    #[test]
    fn provider_network_classifies_as_transient() {
        let refused = Error::provider_network("connect refused");
        assert_eq!(kind_of(&refused), ToolErrorKind::Transient);
        assert!(kind_of(&refused).is_retryable());
    }

    /// A provider DNS failure is transient.
    #[test]
    fn provider_dns_classifies_as_transient() {
        let unresolved = Error::provider_dns("no such host");
        assert_eq!(kind_of(&unresolved), ToolErrorKind::Transient);
    }

    /// HTTP 5xx responses are transient.
    #[test]
    fn provider_5xx_classifies_as_transient() {
        let unavailable = Error::provider_http(503, "down");
        assert_eq!(kind_of(&unavailable), ToolErrorKind::Transient);

        let bad_gateway = Error::provider_http(502, "bad gateway");
        assert_eq!(kind_of(&bad_gateway), ToolErrorKind::Transient);
    }

    /// HTTP 408 and 425 are transient even though they are 4xx statuses.
    #[test]
    fn http_408_and_425_classify_as_transient() {
        let timed_out = Error::provider_http(408, "timeout");
        assert_eq!(kind_of(&timed_out), ToolErrorKind::Transient);

        let too_early = Error::provider_http(425, "too early");
        assert_eq!(kind_of(&too_early), ToolErrorKind::Transient);
    }

    /// HTTP 429 with a retry hint is a retryable rate limit.
    #[test]
    fn http_429_with_retry_after_classifies_as_rate_limit() {
        let throttled =
            Error::provider_http(429, "slow down").with_retry_after(Duration::from_secs(5));
        assert_eq!(kind_of(&throttled), ToolErrorKind::RateLimit);
        assert!(kind_of(&throttled).is_retryable());
    }

    /// HTTP 429 without a retry hint is a non-retryable quota failure.
    #[test]
    fn http_429_without_retry_after_classifies_as_quota() {
        let capped = Error::provider_http(429, "monthly cap reached");
        assert_eq!(kind_of(&capped), ToolErrorKind::Quota);
        assert!(!kind_of(&capped).is_retryable());
    }

    /// HTTP 401 and 403 are auth failures; auth is not retryable.
    #[test]
    fn http_401_403_classify_as_auth() {
        let unauthorized = Error::provider_http(401, "unauthorized");
        assert_eq!(kind_of(&unauthorized), ToolErrorKind::Auth);

        let forbidden = Error::provider_http(403, "forbidden");
        assert_eq!(kind_of(&forbidden), ToolErrorKind::Auth);
        assert!(!kind_of(&forbidden).is_retryable());
    }

    /// Remaining 4xx statuses are permanent and not retryable.
    #[test]
    fn http_4xx_other_classifies_as_permanent() {
        let not_found = Error::provider_http(404, "not found");
        assert_eq!(kind_of(&not_found), ToolErrorKind::Permanent);

        let unprocessable = Error::provider_http(422, "unprocessable");
        assert_eq!(kind_of(&unprocessable), ToolErrorKind::Permanent);
        assert!(!kind_of(&unprocessable).is_retryable());
    }

    /// Invalid-request and serde errors are validation failures.
    #[test]
    fn invalid_request_and_serde_classify_as_validation() {
        let rejected = Error::invalid_request("bad input");
        assert_eq!(kind_of(&rejected), ToolErrorKind::Validation);

        let parse_failure: serde_json::Error =
            serde_json::from_str::<i32>("not-a-number").unwrap_err();
        let converted: Error = parse_failure.into();
        assert_eq!(kind_of(&converted), ToolErrorKind::Validation);
    }

    /// Errors without a dedicated arm fall through to `Internal`.
    #[test]
    fn config_classifies_as_internal() {
        let misconfigured = Error::config("misconfigured");
        assert_eq!(kind_of(&misconfigured), ToolErrorKind::Internal);
    }

    /// A terminal tool error reports the kind it was built with, including
    /// when one terminal error wraps another.
    #[test]
    fn tool_error_terminal_unwraps_to_inner_kind() {
        let inner = Error::provider_http(401, "unauthorized");
        let wrapped = Error::tool_error_terminal(ToolErrorKind::Auth, "my_tool", inner);
        assert_eq!(kind_of(&wrapped), ToolErrorKind::Auth);

        let twice = Error::tool_error_terminal(ToolErrorKind::Auth, "parent_tool", wrapped);
        assert_eq!(kind_of(&twice), ToolErrorKind::Auth);
    }

    /// Every variant gets a distinct bit that fits in the set's backing integer.
    #[test]
    fn bit_indices_are_unique_and_fit_in_set_width() {
        use std::collections::HashSet;

        // The wildcard-free match stops compiling when a variant is added,
        // which is the reminder to extend `every_variant` below as well.
        fn dispatch_bit_index(k: ToolErrorKind) -> u32 {
            match k {
                ToolErrorKind::Transient
                | ToolErrorKind::RateLimit
                | ToolErrorKind::Quota
                | ToolErrorKind::Auth
                | ToolErrorKind::Permanent
                | ToolErrorKind::Validation
                | ToolErrorKind::Internal => k.bit_index(),
            }
        }

        let every_variant = [
            ToolErrorKind::Transient,
            ToolErrorKind::RateLimit,
            ToolErrorKind::Quota,
            ToolErrorKind::Auth,
            ToolErrorKind::Permanent,
            ToolErrorKind::Validation,
            ToolErrorKind::Internal,
        ];

        let mut seen = HashSet::new();
        for k in every_variant {
            let bi = dispatch_bit_index(k);
            assert!(
                bi < ToolErrorKindSet::CAPACITY_BITS,
                "{k:?}.bit_index() = {bi} exceeds ToolErrorKindSet capacity \
                 ({cap} bits) — widen the backing integer in lockstep",
                cap = ToolErrorKindSet::CAPACITY_BITS,
            );
            assert!(seen.insert(bi), "duplicate bit_index {bi} for {k:?}");
        }
    }

    /// The set can be built in a `const` context.
    #[test]
    fn tool_error_kind_set_const_construction() {
        const MEMBERS: ToolErrorKindSet = ToolErrorKindSet::empty()
            .with(ToolErrorKind::Auth)
            .with(ToolErrorKind::Quota)
            .with(ToolErrorKind::Permanent);

        assert!(MEMBERS.contains(ToolErrorKind::Auth));
        assert!(MEMBERS.contains(ToolErrorKind::Quota));
        assert!(MEMBERS.contains(ToolErrorKind::Permanent));
        assert!(!MEMBERS.contains(ToolErrorKind::Transient));
        assert!(!MEMBERS.contains(ToolErrorKind::Internal));
        assert!(!MEMBERS.is_empty());
        assert!(ToolErrorKindSet::empty().is_empty());
    }

    /// `union` merges members and `without` removes only the named kind.
    #[test]
    fn tool_error_kind_set_without_and_union() {
        let auth_only = ToolErrorKindSet::singleton(ToolErrorKind::Auth);
        let quota_only = ToolErrorKindSet::singleton(ToolErrorKind::Quota);

        let merged = auth_only.union(quota_only);
        assert!(merged.contains(ToolErrorKind::Auth));
        assert!(merged.contains(ToolErrorKind::Quota));

        let trimmed = merged.without(ToolErrorKind::Auth);
        assert!(!trimmed.contains(ToolErrorKind::Auth));
        assert!(trimmed.contains(ToolErrorKind::Quota));
    }

    /// A usage-limit breach is a quota failure.
    #[test]
    fn usage_limit_exceeded_classifies_as_quota() {
        use crate::run_budget::UsageLimitBreach;

        let breach = Error::UsageLimitExceeded(UsageLimitBreach::Requests {
            limit: 10,
            observed: 11,
        });
        assert_eq!(kind_of(&breach), ToolErrorKind::Quota);
    }
}