use std::time::Duration;
use camel_api::CamelError;
/// Errors that can arise when calling an LLM provider.
///
/// Every variant renders a short human-readable message through `Display`,
/// generated by `thiserror` from the `#[error(...)]` attributes below.
/// The enum is `#[non_exhaustive]`: code in other crates must keep a
/// wildcard arm when matching on it.
#[derive(Debug, Clone, thiserror::Error)]
#[non_exhaustive]
pub enum LlmError {
    /// The provider rate-limited the request.
    #[error("rate limited by provider")]
    RateLimit {
        /// Optional delay before retrying. It is not part of the `Display`
        /// output. (Presumably derived from a provider hint such as a
        /// `Retry-After` header — confirm at construction sites.)
        retry_after: Option<Duration>,
    },

    /// A usage quota was exceeded; `detail` is appended to the message.
    #[error("quota exceeded: {detail}")]
    QuotaExceeded {
        /// Free-form explanation included in the `Display` output.
        detail: String,
    },

    /// The request did not fit in the model's context window.
    #[error("context window exceeded: {max_tokens} tokens max")]
    ContextLengthExceeded {
        /// Token limit reported in the message.
        max_tokens: u32,
    },

    /// Authentication against the provider failed.
    #[error("authentication failed: {detail}")]
    AuthFailed {
        /// Free-form explanation included in the `Display` output.
        detail: String,
    },

    /// The requested model was not found (payload presumably names the model).
    #[error("model not found: {0}")]
    ModelNotFound(String),

    /// The model exists but is currently unavailable.
    #[error("model unavailable: {0}")]
    ModelUnavailable(String),

    /// The provider as a whole is unavailable.
    #[error("provider unavailable: {0}")]
    ProviderUnavailable(String),

    /// The request or response was blocked by a safety policy.
    #[error("content filtered by safety policy: {detail}")]
    ContentFiltered {
        /// Free-form explanation included in the `Display` output.
        detail: String,
    },

    /// A network-level failure occurred.
    #[error("network error: {0}")]
    Network(String),

    /// The call did not complete within the given duration
    /// (rendered with `Duration`'s `Debug` format, e.g. `30s`).
    #[error("timeout after {0:?}")]
    Timeout(Duration),

    /// The request was rejected as invalid.
    #[error("invalid request: {0}")]
    InvalidRequest(String),

    /// The requested capability is not supported.
    #[error("capability not supported: {0}")]
    UnsupportedCapability(String),

    /// The provider's response could not be interpreted.
    #[error("malformed provider response: {0}")]
    Protocol(String),

    /// A streaming response ended prematurely.
    #[error("stream interrupted: {0}")]
    StreamInterrupted(String),

    /// Any other provider-reported failure. Prefer building it with
    /// `LlmError::provider`, which truncates overly long messages.
    #[error("provider error: {0}")]
    Provider(String),
}
/// Maximum number of bytes of a provider message kept by [`truncate_for_display`]
/// before the `...[truncated]` marker is appended.
const MAX_PROVIDER_ERROR_BYTES: usize = 200;

/// Shortens `msg` for inclusion in an error message.
///
/// Messages of at most [`MAX_PROVIDER_ERROR_BYTES`] bytes are returned unchanged.
/// Longer messages are cut at the last UTF-8 character boundary at or before
/// that limit, and `...[truncated]` is appended. The result can therefore exceed
/// the limit by the length of the marker.
fn truncate_for_display(msg: &str) -> String {
    if msg.len() <= MAX_PROVIDER_ERROR_BYTES {
        return msg.to_string();
    }
    // Walk back to a char boundary so slicing can never split a multi-byte
    // character (which would panic). This is the stable equivalent of
    // `str::floor_char_boundary`, which is unavailable on older toolchains.
    // Index 0 is always a boundary, so the fallback is never actually used.
    let cut = (0..=MAX_PROVIDER_ERROR_BYTES)
        .rev()
        .find(|&idx| msg.is_char_boundary(idx))
        .unwrap_or(0);
    format!("{}...[truncated]", &msg[..cut])
}
impl LlmError {
    /// Builds an [`LlmError::Provider`] from any string-like message.
    ///
    /// The message is passed through `truncate_for_display`. Anything beyond
    /// `MAX_PROVIDER_ERROR_BYTES` bytes is cut on a UTF-8 character boundary
    /// and marked with `...[truncated]`, keeping the rendered error short.
    pub fn provider(msg: impl Into<String>) -> Self {
        let raw: String = msg.into();
        let shortened = truncate_for_display(raw.as_str());
        Self::Provider(shortened)
    }
}
/// Reports whether a failed LLM call is worth retrying.
///
/// Rate limiting, network failures, timeouts, and unavailable providers or
/// models are classified as retryable. Every other variant is classified as
/// non-retryable.
///
/// The match is deliberately exhaustive, with no `_` arm. Adding a variant to
/// [`LlmError`] therefore fails to compile until a retry decision is made here,
/// instead of silently defaulting to "not retryable". (`#[non_exhaustive]` does
/// not apply inside the defining crate, so this is allowed.)
#[must_use]
pub fn is_retryable(err: &LlmError) -> bool {
    match err {
        LlmError::RateLimit { .. }
        | LlmError::Network(_)
        | LlmError::ProviderUnavailable(_)
        | LlmError::ModelUnavailable(_)
        | LlmError::Timeout(_) => true,
        LlmError::QuotaExceeded { .. }
        | LlmError::ContextLengthExceeded { .. }
        | LlmError::AuthFailed { .. }
        | LlmError::ModelNotFound(_)
        | LlmError::ContentFiltered { .. }
        | LlmError::InvalidRequest(_)
        | LlmError::UnsupportedCapability(_)
        | LlmError::Protocol(_)
        | LlmError::StreamInterrupted(_)
        | LlmError::Provider(_) => false,
    }
}
/// Maps an [`LlmError`] onto the coarser `CamelError` categories.
///
/// - `AuthFailed` becomes `CamelError::Unauthenticated`.
/// - `Network` and `ProviderUnavailable` become `CamelError::Io`.
/// - Every other variant becomes `CamelError::ProcessorError`.
///
/// In all cases the full `Display` text of the LLM error is used as the message.
impl From<LlmError> for CamelError {
    fn from(e: LlmError) -> Self {
        let message = e.to_string();
        match e {
            LlmError::AuthFailed { .. } => Self::Unauthenticated(message),
            LlmError::Network(_) | LlmError::ProviderUnavailable(_) => Self::Io(message),
            _ => Self::ProcessorError(message),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    #[test]
    fn retryable_errors() {
        assert!(is_retryable(&LlmError::Network("conn reset".into())));
        assert!(is_retryable(&LlmError::Timeout(Duration::from_secs(30))));
        assert!(is_retryable(&LlmError::RateLimit { retry_after: None }));
        assert!(is_retryable(&LlmError::ProviderUnavailable("503".into())));
        assert!(is_retryable(&LlmError::ModelUnavailable(
            "overloaded".into()
        )));
    }

    #[test]
    fn non_retryable_errors() {
        assert!(!is_retryable(&LlmError::AuthFailed {
            detail: "bad key".into()
        }));
        assert!(!is_retryable(&LlmError::QuotaExceeded {
            detail: "billing".into()
        }));
        assert!(!is_retryable(&LlmError::ContextLengthExceeded {
            max_tokens: 4096
        }));
        assert!(!is_retryable(&LlmError::ModelNotFound("gpt-99".into())));
        assert!(!is_retryable(&LlmError::ContentFiltered {
            detail: "safety".into()
        }));
        assert!(!is_retryable(&LlmError::InvalidRequest("bad json".into())));
        assert!(!is_retryable(&LlmError::Protocol("decode".into())));
        assert!(!is_retryable(&LlmError::StreamInterrupted(
            "dropped".into()
        )));
        assert!(!is_retryable(&LlmError::UnsupportedCapability(
            "embed".into()
        )));
        assert!(!is_retryable(&LlmError::Provider("generic error".into())));
    }

    #[test]
    fn display_messages() {
        assert_eq!(
            LlmError::RateLimit { retry_after: None }.to_string(),
            "rate limited by provider"
        );
        assert_eq!(
            LlmError::Timeout(Duration::from_secs(30)).to_string(),
            "timeout after 30s"
        );
        assert_eq!(
            LlmError::ContextLengthExceeded { max_tokens: 4096 }.to_string(),
            "context window exceeded: 4096 tokens max"
        );
    }

    #[test]
    fn converts_to_camel_error() {
        let err: CamelError = LlmError::Timeout(Duration::from_secs(5)).into();
        assert!(err.to_string().contains("timeout"));
        assert!(matches!(
            CamelError::from(LlmError::Network("conn".into())),
            CamelError::Io(_)
        ));
        assert!(matches!(
            CamelError::from(LlmError::ProviderUnavailable("503".into())),
            CamelError::Io(_)
        ));
        assert!(matches!(
            CamelError::from(LlmError::AuthFailed {
                detail: "bad key".into()
            }),
            CamelError::Unauthenticated(_)
        ));
        assert!(matches!(
            CamelError::from(LlmError::InvalidRequest("bad json".into())),
            CamelError::ProcessorError(_)
        ));
    }

    #[test]
    fn provider_constructor_truncates_long_messages() {
        let long = "x".repeat(300);
        let err = LlmError::provider(long);
        let display = err.to_string();
        assert!(display.contains("provider error"));
        assert!(display.contains("[truncated]"));
        assert!(
            display.len() <= 200 + 40,
            "display too long: {} bytes",
            display.len()
        );
    }

    #[test]
    fn provider_constructor_keeps_short_messages() {
        let err = LlmError::provider("short");
        assert_eq!(err.to_string(), "provider error: short");
    }

    #[test]
    fn provider_constructor_truncates_on_char_boundary() {
        // 'é' is two bytes, so byte 200 of this message falls mid-character;
        // truncation must back off to byte 199 instead of panicking.
        let msg = format!("a{}", "é".repeat(150));
        let err = LlmError::provider(msg);
        let expected = format!("provider error: a{}...[truncated]", "é".repeat(99));
        assert_eq!(err.to_string(), expected);
    }
}