use crate::telemetry::LlmErrorClass;
/// Typed failure for an LLM request.
///
/// Every variant maps to exactly one telemetry `LlmErrorClass` through
/// [`LlmError::classify`]. The `Display` text of each variant is the short, single-line
/// summary given in its `#[error]` attribute. The three wrapper variants (`Transport`,
/// `Parse`, `Other`) mark their boxed inner error as `#[source]`, so it is reachable via
/// [`std::error::Error::source`]; [`LlmError::display_chain`] walks that chain to render
/// every cause.
///
/// NOTE(review): the `status` fields look like HTTP status codes (the tests use 429, 402
/// and 503) — confirm against the code that constructs these values.
#[derive(Debug, thiserror::Error)]
pub enum LlmError {
    /// The request was rejected as rate limited.
    ///
    /// Displays as `rate limited`; neither field appears in the message.
    #[error("rate limited")]
    RateLimit {
        /// Optional retry delay attached to the error; the `_ms` suffix implies
        /// milliseconds (assumption — confirm with the producer). Not used by
        /// `classify`.
        retry_after_ms: Option<u64>,
        /// Status code reported with the rejection; returned by `classify`.
        status: u16,
    },
    /// The request was refused with a payment-required status (per the variant name).
    ///
    /// Displays as `payment required`; `status` is returned by `classify` but is not
    /// part of the message.
    #[error("payment required")]
    PaymentRequired { status: u16 },
    /// A server-side failure; `status` is embedded in the message and returned by
    /// `classify`.
    #[error("server error (status {status})")]
    ServerError { status: u16 },
    /// The input was larger than the model's token limit.
    ///
    /// `tokens` is the input token count and `limit` the model's token limit; both
    /// appear in the message. `classify` reports no status for this variant.
    #[error("context overflow ({tokens} input tokens exceeded {limit}-token model limit)")]
    ContextOverflow { tokens: u32, limit: u32 },
    /// Transport-level failure (presumably network / I/O — the tests wrap a
    /// connection-reset `io::Error`). Displays only `transport`; the boxed error is the
    /// `source()`.
    #[error("transport")]
    Transport(#[source] Box<dyn std::error::Error + Send + Sync + 'static>),
    /// Failure while parsing (presumably decoding a response body). Displays only
    /// `parse`; the boxed error is the `source()`.
    #[error("parse")]
    Parse(#[source] Box<dyn std::error::Error + Send + Sync + 'static>),
    /// Catch-all for failures that fit no other variant. Displays only `other`; the
    /// boxed error is the `source()`.
    #[error("other")]
    Other(#[source] Box<dyn std::error::Error + Send + Sync + 'static>),
}
impl LlmError {
    /// Renders this error followed by every error in its `source()` chain.
    ///
    /// The first line is this error's own `Display` text. Each further cause is
    /// appended after a `"\n caused by: "` separator, following `source()` until it
    /// returns `None`. Only the variants declared with a `#[source]` field
    /// (`Transport`, `Parse`, `Other`) have a source, so every other variant renders
    /// as a single line.
    pub fn display_chain(&self) -> String {
        use std::error::Error as _;
        use std::fmt::Write as _;

        let mut out = self.to_string();
        let mut cursor = self.source();
        while let Some(layer) = cursor {
            // Format directly into `out` rather than allocating a temporary `String`
            // per layer. Appending to a `String` cannot fail by itself; an `Err` could
            // only come from a misbehaving `Display` impl, and ignoring it keeps the
            // text written so far instead of panicking the way `to_string()` would.
            let _ = write!(out, "\n caused by: {layer}");
            cursor = layer.source();
        }
        out
    }

    /// Maps this error to its telemetry class and, when the variant carries one, its
    /// status code.
    ///
    /// Returns `(class, Some(status))` for `RateLimit`, `PaymentRequired` and
    /// `ServerError`, and `(class, None)` for every other variant. The match is
    /// deliberately exhaustive (no `_` arm) so adding a variant forces a decision here.
    pub fn classify(&self) -> (LlmErrorClass, Option<u16>) {
        match self {
            // `retry_after_ms` plays no part in classification.
            LlmError::RateLimit { status, .. } => (LlmErrorClass::RateLimit, Some(*status)),
            LlmError::PaymentRequired { status } => {
                (LlmErrorClass::PaymentRequired, Some(*status))
            }
            LlmError::ServerError { status } => (LlmErrorClass::ServerError, Some(*status)),
            LlmError::ContextOverflow { .. } => (LlmErrorClass::ContextOverflow, None),
            LlmError::Transport(_) => (LlmErrorClass::Transport, None),
            LlmError::Parse(_) => (LlmErrorClass::Parse, None),
            LlmError::Other(_) => (LlmErrorClass::Other, None),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Table-driven check that `classify` yields the expected class for every variant
    /// and forwards a status code only for the variants that carry one.
    #[test]
    fn classify_covers_every_variant_with_correct_status() {
        let table: [(LlmError, LlmErrorClass, Option<u16>); 7] = [
            (
                LlmError::RateLimit {
                    retry_after_ms: Some(1500),
                    status: 429,
                },
                LlmErrorClass::RateLimit,
                Some(429),
            ),
            (
                LlmError::PaymentRequired { status: 402 },
                LlmErrorClass::PaymentRequired,
                Some(402),
            ),
            (
                LlmError::ServerError { status: 503 },
                LlmErrorClass::ServerError,
                Some(503),
            ),
            (
                LlmError::ContextOverflow {
                    tokens: 9000,
                    limit: 8192,
                },
                LlmErrorClass::ContextOverflow,
                None,
            ),
            (
                LlmError::Transport(Box::new(std::io::Error::new(
                    std::io::ErrorKind::ConnectionReset,
                    "reset",
                ))),
                LlmErrorClass::Transport,
                None,
            ),
            (
                LlmError::Parse(Box::new(std::io::Error::other("bad json"))),
                LlmErrorClass::Parse,
                None,
            ),
            (
                LlmError::Other(Box::new(std::io::Error::other("misc"))),
                LlmErrorClass::Other,
                None,
            ),
        ];

        for (case, expected_class, expected_status) in table {
            let (actual_class, actual_status) = case.classify();
            assert_eq!(
                actual_class,
                expected_class,
                "variant {err:?} classified wrong: got {got_class:?}, want {want_class:?}",
                err = case,
                got_class = actual_class,
                want_class = expected_class,
            );
            assert_eq!(
                actual_status,
                expected_status,
                "variant {err:?} status wrong: got {got_status:?}, want {want_status:?}",
                err = case,
                got_status = actual_status,
                want_status = expected_status,
            );
        }
    }

    /// The overflow variant must surface both counts in its message and keep them
    /// reachable as fields.
    #[test]
    fn context_overflow_carries_tokens_and_limit() {
        let err = LlmError::ContextOverflow {
            tokens: 12_500,
            limit: 8_192,
        };

        let display = err.to_string();
        assert!(display.contains("12500"), "display omits tokens: {display}");
        assert!(display.contains("8192"), "display omits limit: {display}");

        let LlmError::ContextOverflow { tokens, limit } = err else {
            unreachable!("matched variant must extract fields");
        };
        assert_eq!(tokens, 12_500);
        assert_eq!(limit, 8_192);
    }

    /// Converting into `anyhow::Error` must not erase the concrete type.
    #[test]
    fn anyhow_wrap_preserves_typed_error() {
        let wrapped = anyhow::Error::from(LlmError::ServerError { status: 502 });
        assert!(matches!(
            wrapped.downcast_ref::<LlmError>(),
            Some(LlmError::ServerError { status: 502 })
        ));
    }

    /// A variant with no source renders exactly its own message.
    #[test]
    fn display_chain_single_layer_emits_one_line() {
        let err = LlmError::ServerError { status: 503 };
        let chain = err.display_chain();
        assert_eq!(chain, "server error (status 503)");
        assert!(
            !chain.contains("caused by"),
            "single-layer variant must not include a `caused by` line: {chain}"
        );
    }

    /// One wrapped cause produces one `caused by` line.
    #[test]
    fn display_chain_two_layers_renders_caused_by() {
        let peer_reset =
            std::io::Error::new(std::io::ErrorKind::ConnectionReset, "reset by peer");
        assert_eq!(
            LlmError::Transport(Box::new(peer_reset)).display_chain(),
            "transport\n caused by: reset by peer"
        );
    }

    /// Nested causes are all rendered, outermost first.
    #[test]
    fn display_chain_three_layers_walks_full_source_tree() {
        // Intermediate error that only forwards `source()` to the boxed root.
        #[derive(Debug)]
        struct Middle(Box<dyn std::error::Error + Send + Sync + 'static>);

        impl std::fmt::Display for Middle {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                f.write_str("mid layer")
            }
        }

        impl std::error::Error for Middle {
            fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
                Some(&*self.0)
            }
        }

        let parse = LlmError::Parse(Box::new(Middle(Box::new(std::io::Error::other(
            "bad json byte",
        )))));
        assert_eq!(
            parse.display_chain(),
            "parse\n caused by: mid layer\n caused by: bad json byte"
        );
    }

    /// Structured variants keep their inline fields in the rendered chain.
    #[test]
    fn display_chain_preserves_inline_detail_on_structured_variants() {
        let overflow = LlmError::ContextOverflow {
            tokens: 9000,
            limit: 8192,
        };
        let chain = overflow.display_chain();
        assert!(chain.contains("9000"), "tokens missing: {chain}");
        assert!(chain.contains("8192"), "limit missing: {chain}");
    }
}