use crate::error::Error;
/// Coarse category of a failure, as produced by [`classify`].
///
/// Callers can branch on this instead of inspecting status codes or message
/// text themselves.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ErrorClass {
    /// An HTTP 400 whose message indicates the input exceeded the model's context.
    ContextOverflow,
    /// HTTP 429.
    RateLimited,
    /// HTTP 401 or 403.
    AuthError,
    /// A server-side failure status (e.g. 500, 502, 503, 529).
    ServerError,
    /// An HTTP 400 that was not recognised as a context overflow.
    InvalidRequest,
    /// A transport-level failure surfaced as `Error::Http`.
    Network,
    /// Anything not covered by the other classes.
    Unknown,
}
pub fn classify(error: &Error) -> ErrorClass {
let inner = match error {
Error::WithPartialUsage { source, .. } => source.as_ref(),
other => other,
};
match inner {
Error::Api { status, message } => classify_api(*status, message),
Error::Http(_) => ErrorClass::Network,
_ => ErrorClass::Unknown,
}
}
/// Classifies an API error response by HTTP status code.
///
/// The `message` body is consulted only for status 400, to separate context
/// overflows from other invalid requests. Unrecognised statuses map to
/// [`ErrorClass::Unknown`].
fn classify_api(status: u16, message: &str) -> ErrorClass {
    match status {
        401 | 403 => ErrorClass::AuthError,
        429 => ErrorClass::RateLimited,
        // 504 (Gateway Timeout) is an upstream failure just as transient as
        // 502/503, so it belongs with them rather than falling through to
        // `Unknown`. 529 is non-standard; presumably a provider "overloaded"
        // signal.
        500 | 502 | 503 | 504 | 529 => ErrorClass::ServerError,
        // A 400 is a context overflow only when the message says so;
        // otherwise it is an ordinary malformed request.
        400 if is_context_overflow(message) => ErrorClass::ContextOverflow,
        400 => ErrorClass::InvalidRequest,
        _ => ErrorClass::Unknown,
    }
}
/// Returns `true` when `message` reads like a "prompt exceeded the context"
/// error.
///
/// The check lowercases `message` and looks for any of a fixed set of known
/// phrases as a substring, so it is case-insensitive and tolerant of
/// surrounding text.
fn is_context_overflow(message: &str) -> bool {
    let haystack = message.to_lowercase();
    let overflow_markers = [
        "prompt is too long",
        "maximum context length",
        "context_length_exceeded",
        "context window",
        "too many tokens",
        "input is too long",
        "exceeds the model's maximum context",
        "request too large",
        "content too large",
    ];
    for &marker in overflow_markers.iter() {
        if haystack.contains(marker) {
            return true;
        }
    }
    false
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `Error::Api` carrying `status` and `message`.
    fn api(status: u16, message: &'static str) -> Error {
        Error::Api {
            status,
            message: message.into(),
        }
    }

    /// Asserts that an API error built from `status`/`message` classifies as `expected`.
    #[track_caller]
    fn expect_api(status: u16, message: &'static str, expected: ErrorClass) {
        assert_eq!(classify(&api(status, message)), expected);
    }

    /// Asserts that `err` classifies as `ErrorClass::Unknown`.
    #[track_caller]
    fn expect_unknown(err: Error) {
        assert_eq!(classify(&err), ErrorClass::Unknown);
    }

    // --- Authentication / rate limiting -------------------------------------

    #[test]
    fn classify_401_as_auth_error() {
        expect_api(401, "Unauthorized", ErrorClass::AuthError);
    }

    #[test]
    fn classify_403_as_auth_error() {
        expect_api(403, "Forbidden", ErrorClass::AuthError);
    }

    #[test]
    fn classify_429_as_rate_limited() {
        expect_api(429, "Too Many Requests", ErrorClass::RateLimited);
    }

    // --- Server-side failures ----------------------------------------------

    #[test]
    fn classify_500_as_server_error() {
        expect_api(500, "Internal Server Error", ErrorClass::ServerError);
    }

    #[test]
    fn classify_502_as_server_error() {
        expect_api(502, "Bad Gateway", ErrorClass::ServerError);
    }

    #[test]
    fn classify_503_as_server_error() {
        expect_api(503, "Service Unavailable", ErrorClass::ServerError);
    }

    #[test]
    fn classify_529_as_server_error() {
        expect_api(529, "Overloaded", ErrorClass::ServerError);
    }

    // --- 400: context overflow vs. invalid request -------------------------

    #[test]
    fn classify_400_prompt_too_long() {
        expect_api(400, "prompt is too long", ErrorClass::ContextOverflow);
    }

    #[test]
    fn classify_400_maximum_context_length() {
        expect_api(
            400,
            "This request exceeds the maximum context length",
            ErrorClass::ContextOverflow,
        );
    }

    #[test]
    fn classify_400_context_length_exceeded() {
        expect_api(400, "context_length_exceeded", ErrorClass::ContextOverflow);
    }

    #[test]
    fn classify_400_request_too_large() {
        expect_api(400, "request too large for this model", ErrorClass::ContextOverflow);
    }

    #[test]
    fn classify_400_content_too_large() {
        expect_api(400, "content too large", ErrorClass::ContextOverflow);
    }

    #[test]
    fn classify_400_max_tokens_parameter_is_not_overflow() {
        expect_api(
            400,
            "max_tokens: 4096 must be less than 2048",
            ErrorClass::InvalidRequest,
        );
    }

    #[test]
    fn classify_400_context_window() {
        expect_api(400, "exceeds the context window", ErrorClass::ContextOverflow);
    }

    #[test]
    fn classify_400_too_many_tokens() {
        expect_api(400, "too many tokens in the request", ErrorClass::ContextOverflow);
    }

    #[test]
    fn classify_400_input_too_long() {
        expect_api(400, "input is too long for model", ErrorClass::ContextOverflow);
    }

    #[test]
    fn classify_400_exceeds_model_maximum_context() {
        expect_api(
            400,
            "exceeds the model's maximum context length",
            ErrorClass::ContextOverflow,
        );
    }

    #[test]
    fn classify_400_case_insensitive() {
        expect_api(400, "PROMPT IS TOO LONG", ErrorClass::ContextOverflow);
    }

    #[test]
    fn classify_400_generic_as_invalid_request() {
        expect_api(
            400,
            "invalid parameter: temperature must be between 0 and 1",
            ErrorClass::InvalidRequest,
        );
    }

    // --- Transport errors ----------------------------------------------------

    #[test]
    fn classify_http_error_as_network() {
        // Connecting to port 1 on the unspecified address is expected to fail,
        // which yields a real `reqwest::Error` to wrap.
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("test runtime");
        let failure = runtime
            .block_on(reqwest::get("http://[::0]:1"))
            .expect_err("should fail");
        assert_eq!(classify(&Error::Http(failure)), ErrorClass::Network);
    }

    // --- Non-API variants fall back to Unknown -----------------------------

    #[test]
    fn classify_agent_error_as_unknown() {
        expect_unknown(Error::Agent("something went wrong".into()));
    }

    #[test]
    fn classify_max_turns_exceeded_as_unknown() {
        expect_unknown(Error::MaxTurnsExceeded(10));
    }

    #[test]
    fn classify_truncated_as_unknown() {
        expect_unknown(Error::Truncated);
    }

    #[test]
    fn classify_config_error_as_unknown() {
        expect_unknown(Error::Config("bad config".into()));
    }

    #[test]
    fn classify_mcp_error_as_unknown() {
        expect_unknown(Error::Mcp("connection refused".into()));
    }

    // --- WithPartialUsage wrapper is transparent -----------------------------

    #[test]
    fn classify_unwraps_with_partial_usage() {
        use crate::llm::types::TokenUsage;
        let usage = TokenUsage {
            input_tokens: 100,
            output_tokens: 50,
            ..Default::default()
        };
        let wrapped = api(429, "rate limited").with_partial_usage(usage);
        assert_eq!(classify(&wrapped), ErrorClass::RateLimited);
    }

    #[test]
    fn classify_unwraps_partial_usage_context_overflow() {
        use crate::llm::types::TokenUsage;
        let wrapped = api(400, "prompt is too long").with_partial_usage(TokenUsage::default());
        assert_eq!(classify(&wrapped), ErrorClass::ContextOverflow);
    }

    // --- Unrecognised status codes -----------------------------------------

    #[test]
    fn classify_unknown_status_as_unknown() {
        expect_api(418, "I'm a teapot", ErrorClass::Unknown);
    }
}