1use std::fmt;
2
3use thiserror::Error;
4
/// Details of a request that was rejected because it overflowed a model's context.
///
/// Its `Display` form is the message followed by a parenthesised summary of the provider,
/// the model, and whichever token counts are known.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ContextOverflowError {
    /// Name of the provider that reported the overflow.
    pub provider: String,
    /// Model identifier, if known; displayed as `unknown-model` when absent.
    pub model: Option<String>,
    /// Number of tokens the request asked for, if reported.
    pub requested_tokens: Option<u32>,
    /// Maximum number of tokens allowed, if reported.
    pub max_tokens: Option<u32>,
    /// Human-readable description of the failure.
    pub message: String,
}

impl ContextOverflowError {
    /// Builds a new overflow description.
    ///
    /// `provider` and `message` accept anything convertible into `String`; the optional
    /// fields are stored as given.
    pub fn new(
        provider: impl Into<String>,
        model: Option<String>,
        requested_tokens: Option<u32>,
        max_tokens: Option<u32>,
        message: impl Into<String>,
    ) -> Self {
        Self {
            provider: provider.into(),
            model,
            requested_tokens,
            max_tokens,
            message: message.into(),
        }
    }
}

impl fmt::Display for ContextOverflowError {
    /// Formats as `"<message> (provider=<p>, model=<m>[, requested=<r>][, max=<x>])"`.
    ///
    /// Each token count is printed independently whenever it is known. A count is never
    /// dropped just because the other one is missing.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let model = self.model.as_deref().unwrap_or("unknown-model");
        write!(f, "{} (provider={}, model={}", self.message, self.provider, model)?;
        if let Some(requested) = self.requested_tokens {
            write!(f, ", requested={}", requested)?;
        }
        if let Some(max) = self.max_tokens {
            write!(f, ", max={}", max)?;
        }
        f.write_str(")")
    }
}
39
40#[doc = include_str!("docs/llm_error.md")]
41#[derive(Debug, Error, Clone)]
42pub enum LlmError {
43 #[error("{0} environment variable not set")]
45 MissingApiKey(String),
46 #[error("Invalid API key: {0}")]
48 InvalidApiKey(String),
49 #[error("Failed to create HTTP client: {0}")]
51 HttpClientCreation(String),
52 #[error("API request failed: {0}")]
54 ApiRequest(String),
55 #[error("API error: {0}")]
57 ApiError(String),
58 #[error("Rate limited: {0}")]
60 RateLimited(String),
61 #[error("Server error (status {status:?}): {message}")]
65 ServerError { status: Option<u16>, message: String },
66 #[error("Request timed out: {0}")]
68 Timeout(String),
69 #[error("Network error: {0}")]
71 Network(String),
72 #[error("Stream interrupted: {0}")]
74 StreamInterrupted(String),
75 #[error("Context overflow: {0}")]
77 ContextOverflow(ContextOverflowError),
78 #[error("IO error reading stream: {0}")]
80 IoError(String),
81 #[error("JSON parsing error: {0}")]
83 JsonParsing(String),
84 #[error("Failed to parse tool parameters for {tool_name}: {error}")]
86 ToolParameterParsing { tool_name: String, error: String },
87 #[error("OAuth error: {0}")]
89 OAuthError(String),
90 #[error("Unsupported content: {0}")]
92 UnsupportedContent(String),
93 #[error("{0}")]
95 Other(String),
96}
97
98impl LlmError {
99 pub fn is_retryable(&self) -> bool {
103 matches!(
104 self,
105 LlmError::RateLimited(_)
106 | LlmError::ServerError { .. }
107 | LlmError::Timeout(_)
108 | LlmError::Network(_)
109 | LlmError::StreamInterrupted(_)
110 )
111 }
112}
113
114impl From<reqwest::Error> for LlmError {
115 fn from(error: reqwest::Error) -> Self {
116 if error.is_timeout() {
117 return LlmError::Timeout(error.to_string());
118 }
119 if error.is_connect() || error.is_request() {
120 return LlmError::Network(error.to_string());
121 }
122 match error.status().map(|s| s.as_u16()) {
123 Some(429) => LlmError::RateLimited(error.to_string()),
124 Some(s) if (500..600).contains(&s) => LlmError::ServerError { status: Some(s), message: error.to_string() },
125 _ => LlmError::ApiRequest(error.to_string()),
126 }
127 }
128}
129
130impl From<serde_json::Error> for LlmError {
131 fn from(error: serde_json::Error) -> Self {
132 LlmError::JsonParsing(error.to_string())
133 }
134}
135
136impl From<std::io::Error> for LlmError {
137 fn from(error: std::io::Error) -> Self {
138 LlmError::IoError(error.to_string())
139 }
140}
141
142impl From<reqwest::header::InvalidHeaderValue> for LlmError {
143 fn from(error: reqwest::header::InvalidHeaderValue) -> Self {
144 LlmError::InvalidApiKey(error.to_string())
145 }
146}
147
148impl From<async_openai::error::OpenAIError> for LlmError {
149 fn from(error: async_openai::error::OpenAIError) -> Self {
150 use async_openai::error::OpenAIError;
151 match error {
152 OpenAIError::Reqwest(e) => LlmError::from(e),
153 OpenAIError::StreamError(e) => LlmError::StreamInterrupted(e.to_string()),
154 OpenAIError::ApiError(api_err) => LlmError::ApiError(api_err.to_string()),
155 OpenAIError::JSONDeserialize(e, _) => LlmError::JsonParsing(e.to_string()),
156 OpenAIError::FileSaveError(s) | OpenAIError::FileReadError(s) => LlmError::IoError(s),
157 OpenAIError::InvalidArgument(s) => LlmError::Other(s),
158 }
159 }
160}
161
#[cfg(feature = "oauth")]
impl From<crate::oauth::OAuthError> for LlmError {
    /// Wraps an OAuth failure (only compiled with the `oauth` feature), keeping its message.
    fn from(error: crate::oauth::OAuthError) -> Self {
        let rendered = error.to_string();
        Self::OAuthError(rendered)
    }
}
168
169pub type Result<T> = std::result::Result<T, LlmError>;
170
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the transient/permanent split of every `LlmError` variant.
    #[test]
    fn is_retryable() {
        let transient = [
            LlmError::RateLimited("rl".into()),
            LlmError::ServerError { status: Some(503), message: "x".into() },
            LlmError::ServerError { status: None, message: "stream-level".into() },
            LlmError::Timeout("t".into()),
            LlmError::Network("n".into()),
            LlmError::StreamInterrupted("s".into()),
        ];
        for err in &transient {
            assert!(err.is_retryable(), "expected retryable: {:?}", err);
        }

        let permanent = [
            LlmError::ApiError("x".into()),
            LlmError::ApiRequest("x".into()),
            LlmError::MissingApiKey("x".into()),
            LlmError::InvalidApiKey("x".into()),
            LlmError::HttpClientCreation("x".into()),
            LlmError::IoError("x".into()),
            LlmError::JsonParsing("x".into()),
            LlmError::ToolParameterParsing { tool_name: "t".into(), error: "e".into() },
            LlmError::OAuthError("x".into()),
            LlmError::UnsupportedContent("x".into()),
            LlmError::Other("x".into()),
            LlmError::ContextOverflow(ContextOverflowError::new("p", None, None, None, "m")),
        ];
        for err in &permanent {
            assert!(!err.is_retryable(), "expected non-retryable: {:?}", err);
        }
    }
}