//! Error types for the `llm` module: [`LlmError`] with retry
//! classification, [`ContextOverflowError`], and `From` conversions for
//! transport, serialization, and provider-SDK errors.
use std::fmt;

use thiserror::Error;
/// Details about a prompt that exceeded a model's context window.
///
/// `model` and the token counts are optional because not every provider
/// reports them in its overflow response.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ContextOverflowError {
    pub provider: String,
    pub model: Option<String>,
    pub requested_tokens: Option<u32>,
    pub max_tokens: Option<u32>,
    pub message: String,
}

impl ContextOverflowError {
    /// Builds a new overflow error. `provider` and `message` accept anything
    /// convertible into a `String`.
    pub fn new(
        provider: impl Into<String>,
        model: Option<String>,
        requested_tokens: Option<u32>,
        max_tokens: Option<u32>,
        message: impl Into<String>,
    ) -> Self {
        let provider = provider.into();
        let message = message.into();
        Self {
            provider,
            model,
            requested_tokens,
            max_tokens,
            message,
        }
    }
}

impl fmt::Display for ContextOverflowError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let model_name = self.model.as_deref().unwrap_or("unknown-model");
        // Token counts are only printed when both sides are known.
        if let (Some(requested), Some(max)) = (self.requested_tokens, self.max_tokens) {
            write!(
                f,
                "{} (provider={}, model={}, requested={}, max={})",
                self.message, self.provider, model_name, requested, max
            )
        } else {
            write!(f, "{} (provider={}, model={})", self.message, self.provider, model_name)
        }
    }
}
39
#[doc = include_str!("docs/llm_error.md")]
#[derive(Debug, Error, Clone)]
pub enum LlmError {
    /// Environment variable holding the API key is not set or invalid.
    #[error("{0} environment variable not set")]
    MissingApiKey(String),
    /// Invalid API key format (e.g. a key that cannot be encoded as an HTTP
    /// header value — see `From<reqwest::header::InvalidHeaderValue>`).
    #[error("Invalid API key: {0}")]
    InvalidApiKey(String),
    /// HTTP client creation failed.
    #[error("Failed to create HTTP client: {0}")]
    HttpClientCreation(String),
    /// API request failed. Catch-all for `reqwest` failures not classified
    /// as timeout / network / 429 / 5xx. Not retryable.
    #[error("API request failed: {0}")]
    ApiRequest(String),
    /// API returned an error response.
    #[error("API error: {0}")]
    ApiError(String),
    /// HTTP 429 / provider-flagged rate limit. Retryable.
    #[error("Rate limited: {0}")]
    RateLimited(String),
    /// HTTP 5xx or provider-flagged server error. Retryable. `status` is
    /// `None` when the signal originates from a stream-level event (e.g.
    /// Anthropic SSE `overloaded_error`) rather than an HTTP response.
    #[error("Server error (status {status:?}): {message}")]
    ServerError { status: Option<u16>, message: String },
    /// Request timeout (no bytes received within client deadline). Retryable.
    #[error("Request timed out: {0}")]
    Timeout(String),
    /// Transport-level connection failure (DNS, TCP reset, TLS, request build). Retryable.
    #[error("Network error: {0}")]
    Network(String),
    /// Stream began but errored or terminated prematurely. Retryable.
    #[error("Stream interrupted: {0}")]
    StreamInterrupted(String),
    /// API rejected the request because the prompt exceeded the model's
    /// context window. Not retryable: the same prompt would overflow again.
    #[error("Context overflow: {0}")]
    ContextOverflow(ContextOverflowError),
    /// IO error while reading stream.
    #[error("IO error reading stream: {0}")]
    IoError(String),
    /// JSON parsing/serialization error.
    #[error("JSON parsing error: {0}")]
    JsonParsing(String),
    /// Tool parameter parsing error.
    #[error("Failed to parse tool parameters for {tool_name}: {error}")]
    ToolParameterParsing { tool_name: String, error: String },
    /// OAuth authentication error.
    #[error("OAuth error: {0}")]
    OAuthError(String),
    /// The message contained only content types this provider doesn't support.
    #[error("Unsupported content: {0}")]
    UnsupportedContent(String),
    /// Generic error for other cases.
    #[error("{0}")]
    Other(String),
}
97
98impl LlmError {
99    /// Whether this error class is worth retrying. Transient transport / server
100    /// failures return `true`; permanent failures (auth, schema, context size)
101    /// return `false`.
102    pub fn is_retryable(&self) -> bool {
103        matches!(
104            self,
105            LlmError::RateLimited(_)
106                | LlmError::ServerError { .. }
107                | LlmError::Timeout(_)
108                | LlmError::Network(_)
109                | LlmError::StreamInterrupted(_)
110        )
111    }
112}
113
impl From<reqwest::Error> for LlmError {
    /// Maps a transport-level `reqwest` failure onto the retryability
    /// buckets used by [`LlmError::is_retryable`].
    fn from(error: reqwest::Error) -> Self {
        // Timeouts are checked first so they always classify as retryable
        // `Timeout`, regardless of what other predicates might also match.
        if error.is_timeout() {
            return LlmError::Timeout(error.to_string());
        }
        // Connection-level (DNS/TCP/TLS) and request-construction failures
        // are treated as retryable `Network` errors.
        if error.is_connect() || error.is_request() {
            return LlmError::Network(error.to_string());
        }
        match error.status().map(|s| s.as_u16()) {
            // 429: provider rate limit — retryable.
            Some(429) => LlmError::RateLimited(error.to_string()),
            // 5xx: server-side failure — retryable, status code preserved.
            Some(s) if (500..600).contains(&s) => LlmError::ServerError { status: Some(s), message: error.to_string() },
            // Anything else (4xx, or no status at all) is a non-retryable
            // generic request error.
            _ => LlmError::ApiRequest(error.to_string()),
        }
    }
}
129
impl From<serde_json::Error> for LlmError {
    // Any serde_json failure (deserialize or serialize) collapses into the
    // non-retryable `JsonParsing` variant; only the message is kept.
    fn from(error: serde_json::Error) -> Self {
        LlmError::JsonParsing(error.to_string())
    }
}
135
impl From<std::io::Error> for LlmError {
    // IO failures map to the non-retryable `IoError` variant; the original
    // `std::io::ErrorKind` is not preserved, only the message.
    fn from(error: std::io::Error) -> Self {
        LlmError::IoError(error.to_string())
    }
}
141
impl From<reqwest::header::InvalidHeaderValue> for LlmError {
    // An API key that cannot be encoded as an HTTP header value is reported
    // as an invalid key rather than a generic request failure.
    fn from(error: reqwest::header::InvalidHeaderValue) -> Self {
        LlmError::InvalidApiKey(error.to_string())
    }
}
147
impl From<async_openai::error::OpenAIError> for LlmError {
    /// Flattens `async_openai` errors into the shared [`LlmError`] taxonomy.
    fn from(error: async_openai::error::OpenAIError) -> Self {
        use async_openai::error::OpenAIError;
        match error {
            // Transport failures reuse the `From<reqwest::Error>` classification.
            OpenAIError::Reqwest(e) => LlmError::from(e),
            // Stream-level failures are retryable (see `is_retryable`).
            OpenAIError::StreamError(e) => LlmError::StreamInterrupted(e.to_string()),
            OpenAIError::ApiError(api_err) => LlmError::ApiError(api_err.to_string()),
            // The second tuple field (presumably the raw payload — verify
            // against async_openai docs) is dropped; only the parse message
            // is kept.
            OpenAIError::JSONDeserialize(e, _) => LlmError::JsonParsing(e.to_string()),
            OpenAIError::FileSaveError(s) | OpenAIError::FileReadError(s) => LlmError::IoError(s),
            OpenAIError::InvalidArgument(s) => LlmError::Other(s),
        }
    }
}
161
// Only compiled with the `oauth` feature; stringifies the OAuth error into
// the dedicated `OAuthError` variant.
#[cfg(feature = "oauth")]
impl From<crate::oauth::OAuthError> for LlmError {
    fn from(error: crate::oauth::OAuthError) -> Self {
        LlmError::OAuthError(error.to_string())
    }
}
168
/// Convenience alias: a `Result` whose error type defaults to [`LlmError`].
pub type Result<T> = std::result::Result<T, LlmError>;
170
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn is_retryable() {
        // Transient transport / server failures: all retryable.
        let retryable = [
            LlmError::RateLimited("rl".into()),
            LlmError::ServerError { status: Some(503), message: "x".into() },
            LlmError::ServerError { status: None, message: "stream-level".into() },
            LlmError::Timeout("t".into()),
            LlmError::Network("n".into()),
            LlmError::StreamInterrupted("s".into()),
        ];
        for err in retryable {
            assert!(err.is_retryable(), "{err} should be retryable");
        }

        // Permanent failures (auth, schema, context size, ...): never retryable.
        let permanent = [
            LlmError::ApiError("x".into()),
            LlmError::ApiRequest("x".into()),
            LlmError::MissingApiKey("x".into()),
            LlmError::InvalidApiKey("x".into()),
            LlmError::HttpClientCreation("x".into()),
            LlmError::IoError("x".into()),
            LlmError::JsonParsing("x".into()),
            LlmError::ToolParameterParsing { tool_name: "t".into(), error: "e".into() },
            LlmError::OAuthError("x".into()),
            LlmError::UnsupportedContent("x".into()),
            LlmError::Other("x".into()),
            LlmError::ContextOverflow(ContextOverflowError::new("p", None, None, None, "m")),
        ];
        for err in permanent {
            assert!(!err.is_retryable(), "{err} should not be retryable");
        }
    }
}