use std::fmt;
use thiserror::Error;
/// Details about a request that did not fit a model's context limit.
///
/// Carried by [`LlmError::ContextOverflow`]. All token counts are optional because not every
/// provider reports them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ContextOverflowError {
    /// Name of the provider that reported the overflow.
    pub provider: String,
    /// Model name, if known; rendered as `unknown-model` when `None`.
    pub model: Option<String>,
    /// Token count of the rejected request, if reported.
    pub requested_tokens: Option<u32>,
    /// Maximum token count accepted, if reported.
    pub max_tokens: Option<u32>,
    /// Human-readable description; printed first by `Display`.
    pub message: String,
}

impl ContextOverflowError {
    /// Builds a new overflow description.
    ///
    /// `provider` and `message` accept anything convertible into `String`. The remaining
    /// fields are stored as given.
    pub fn new(
        provider: impl Into<String>,
        model: Option<String>,
        requested_tokens: Option<u32>,
        max_tokens: Option<u32>,
        message: impl Into<String>,
    ) -> Self {
        Self { provider: provider.into(), model, requested_tokens, max_tokens, message: message.into() }
    }
}

impl fmt::Display for ContextOverflowError {
    /// Renders `message (provider=…, model=…[, requested=…][, max=…])`.
    ///
    /// Each token count is included whenever it is present. Previously a lone `requested` or
    /// `max` value was silently dropped unless both were known.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let model = self.model.as_deref().unwrap_or("unknown-model");
        write!(f, "{} (provider={}, model={}", self.message, self.provider, model)?;
        if let Some(requested) = self.requested_tokens {
            write!(f, ", requested={}", requested)?;
        }
        if let Some(max) = self.max_tokens {
            write!(f, ", max={}", max)?;
        }
        f.write_str(")")
    }
}
#[doc = include_str!("docs/llm_error.md")]
#[derive(Debug, Error, Clone)]
pub enum LlmError {
/// A required API-key environment variable is not set; holds the variable's name.
#[error("{0} environment variable not set")]
MissingApiKey(String),
/// The API key is invalid. The `From<reqwest::header::InvalidHeaderValue>` conversion
/// also maps header-encoding failures here.
#[error("Invalid API key: {0}")]
InvalidApiKey(String),
/// Constructing the HTTP client failed.
#[error("Failed to create HTTP client: {0}")]
HttpClientCreation(String),
/// Fallback for `reqwest::Error`s that are not a timeout, connect/request error,
/// HTTP 429, or HTTP 5xx. Not retryable.
#[error("API request failed: {0}")]
ApiRequest(String),
/// An error reported by the provider API (e.g. converted from `async_openai`'s `ApiError`).
/// Not retryable.
#[error("API error: {0}")]
ApiError(String),
/// The provider rate-limited the request (HTTP 429 when converted from `reqwest::Error`).
/// Retryable.
#[error("Rate limited: {0}")]
RateLimited(String),
/// Server-side failure. `status` is the HTTP status (5xx when converted from
/// `reqwest::Error`), or `None` when no status is available. Retryable.
#[error("Server error (status {status:?}): {message}")]
ServerError { status: Option<u16>, message: String },
/// The request timed out. Retryable.
#[error("Request timed out: {0}")]
Timeout(String),
/// A connection or request-level transport failure. Retryable.
#[error("Network error: {0}")]
Network(String),
/// An error occurred while consuming a streaming response. Retryable.
#[error("Stream interrupted: {0}")]
StreamInterrupted(String),
/// The request exceeded the model's context limit; details in [`ContextOverflowError`].
/// Not retryable.
#[error("Context overflow: {0}")]
ContextOverflow(ContextOverflowError),
/// An I/O failure (from `std::io::Error` or `async_openai` file read/save errors),
/// stored as its message only.
#[error("IO error reading stream: {0}")]
IoError(String),
/// JSON (de)serialization failure, stored as its message only.
#[error("JSON parsing error: {0}")]
JsonParsing(String),
/// Arguments for the tool `tool_name` could not be parsed; `error` holds the parser message.
#[error("Failed to parse tool parameters for {tool_name}: {error}")]
ToolParameterParsing { tool_name: String, error: String },
/// OAuth failure (converted from `crate::oauth::OAuthError` when the `oauth` feature is enabled).
#[error("OAuth error: {0}")]
OAuthError(String),
/// Content that cannot be handled by this layer or the target provider.
#[error("Unsupported content: {0}")]
UnsupportedContent(String),
/// Catch-all; the message is displayed verbatim. `async_openai` `InvalidArgument` maps here.
#[error("{0}")]
Other(String),
}
impl LlmError {
pub fn is_retryable(&self) -> bool {
matches!(
self,
LlmError::RateLimited(_)
| LlmError::ServerError { .. }
| LlmError::Timeout(_)
| LlmError::Network(_)
| LlmError::StreamInterrupted(_)
)
}
}
impl From<reqwest::Error> for LlmError {
fn from(error: reqwest::Error) -> Self {
if error.is_timeout() {
return LlmError::Timeout(error.to_string());
}
if error.is_connect() || error.is_request() {
return LlmError::Network(error.to_string());
}
match error.status().map(|s| s.as_u16()) {
Some(429) => LlmError::RateLimited(error.to_string()),
Some(s) if (500..600).contains(&s) => LlmError::ServerError { status: Some(s), message: error.to_string() },
_ => LlmError::ApiRequest(error.to_string()),
}
}
}
impl From<serde_json::Error> for LlmError {
fn from(error: serde_json::Error) -> Self {
LlmError::JsonParsing(error.to_string())
}
}
impl From<std::io::Error> for LlmError {
fn from(error: std::io::Error) -> Self {
LlmError::IoError(error.to_string())
}
}
impl From<reqwest::header::InvalidHeaderValue> for LlmError {
    /// Treats a header value that cannot be encoded as an invalid API key.
    ///
    /// NOTE(review): every `InvalidHeaderValue` lands in `InvalidApiKey`. Confirm that credentials
    /// are the only user-supplied header values built with `?` in this crate.
    fn from(invalid: reqwest::header::InvalidHeaderValue) -> Self {
        let reason = invalid.to_string();
        Self::InvalidApiKey(reason)
    }
}
impl From<async_openai::error::OpenAIError> for LlmError {
fn from(error: async_openai::error::OpenAIError) -> Self {
use async_openai::error::OpenAIError;
match error {
OpenAIError::Reqwest(e) => LlmError::from(e),
OpenAIError::StreamError(e) => LlmError::StreamInterrupted(e.to_string()),
OpenAIError::ApiError(api_err) => LlmError::ApiError(api_err.to_string()),
OpenAIError::JSONDeserialize(e, _) => LlmError::JsonParsing(e.to_string()),
OpenAIError::FileSaveError(s) | OpenAIError::FileReadError(s) => LlmError::IoError(s),
OpenAIError::InvalidArgument(s) => LlmError::Other(s),
}
}
}
#[cfg(feature = "oauth")]
impl From<crate::oauth::OAuthError> for LlmError {
    /// Wraps an OAuth failure, keeping only its rendered message.
    fn from(oauth_err: crate::oauth::OAuthError) -> Self {
        let detail = oauth_err.to_string();
        Self::OAuthError(detail)
    }
}
/// Result alias whose error type defaults to [`LlmError`].
///
/// `Result<T>` means `std::result::Result<T, LlmError>`. Because the error parameter is a
/// default rather than fixed, code that glob-imports this alias (shadowing `std`'s `Result`)
/// can still write `Result<T, OtherError>`.
pub type Result<T, E = LlmError> = std::result::Result<T, E>;
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks `is_retryable` against a representative instance of every variant.
    #[test]
    fn is_retryable() {
        let cases: Vec<(LlmError, bool)> = vec![
            // Transient failures.
            (LlmError::RateLimited("rl".into()), true),
            (LlmError::ServerError { status: Some(503), message: "x".into() }, true),
            (LlmError::ServerError { status: None, message: "stream-level".into() }, true),
            (LlmError::Timeout("t".into()), true),
            (LlmError::Network("n".into()), true),
            (LlmError::StreamInterrupted("s".into()), true),
            // Permanent failures.
            (LlmError::ApiError("x".into()), false),
            (LlmError::ApiRequest("x".into()), false),
            (LlmError::MissingApiKey("x".into()), false),
            (LlmError::InvalidApiKey("x".into()), false),
            (LlmError::HttpClientCreation("x".into()), false),
            (LlmError::IoError("x".into()), false),
            (LlmError::JsonParsing("x".into()), false),
            (LlmError::ToolParameterParsing { tool_name: "t".into(), error: "e".into() }, false),
            (LlmError::OAuthError("x".into()), false),
            (LlmError::UnsupportedContent("x".into()), false),
            (LlmError::Other("x".into()), false),
            (LlmError::ContextOverflow(ContextOverflowError::new("p", None, None, None, "m")), false),
        ];
        for (err, expected) in &cases {
            assert_eq!(err.is_retryable(), *expected, "unexpected is_retryable() for {:?}", err);
        }
    }
}