1use crate::error::Error;
4
/// Coarse category assigned to an [`Error`] by [`classify`].
///
/// The type is a small `Copy` tag: it can be compared, hashed and used as a
/// map key (for example to keep per-class counters).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ErrorClass {
    /// A 400 response whose message matches a known "context too large" phrase.
    ContextOverflow,
    /// The API answered with HTTP 429.
    RateLimited,
    /// The API answered with HTTP 401 or 403.
    AuthError,
    /// The API reported a server-side failure (for example HTTP 500 or 503).
    ServerError,
    /// A 400 response that is not a context overflow.
    InvalidRequest,
    /// The error came from the HTTP layer (`Error::Http`) rather than an API response.
    Network,
    /// Anything that does not fit one of the categories above.
    Unknown,
}
24
25pub fn classify(error: &Error) -> ErrorClass {
30 let inner = match error {
32 Error::WithPartialUsage { source, .. } => source.as_ref(),
33 other => other,
34 };
35
36 match inner {
37 Error::Api { status, message } => classify_api(*status, message),
38 Error::Http(_) => ErrorClass::Network,
39 _ => ErrorClass::Unknown,
40 }
41}
42
43fn classify_api(status: u16, message: &str) -> ErrorClass {
44 match status {
45 401 | 403 => ErrorClass::AuthError,
46 429 => ErrorClass::RateLimited,
47 500 | 502 | 503 | 529 => ErrorClass::ServerError,
48 400 => {
49 if is_context_overflow(message) {
50 ErrorClass::ContextOverflow
51 } else {
52 ErrorClass::InvalidRequest
53 }
54 }
55 _ => ErrorClass::Unknown,
56 }
57}
58
/// Returns `true` when `message` contains, case-insensitively, one of the
/// phrases used to report that a request does not fit in the model's context.
fn is_context_overflow(message: &str) -> bool {
    // All markers are lowercase; the message is lowercased once before matching.
    const OVERFLOW_MARKERS: [&str; 9] = [
        "prompt is too long",
        "maximum context length",
        "context_length_exceeded",
        "context window",
        "too many tokens",
        "input is too long",
        "exceeds the model's maximum context",
        "request too large",
        "content too large",
    ];

    let haystack = message.to_lowercase();
    for marker in OVERFLOW_MARKERS.iter() {
        if haystack.contains(*marker) {
            return true;
        }
    }
    false
}
78
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::types::TokenUsage;

    /// Builds an `Error::Api` with the given status and message.
    fn api_error(status: u16, message: &'static str) -> Error {
        Error::Api {
            status,
            message: message.into(),
        }
    }

    /// Classifies a freshly built `Error::Api`.
    fn class_of_api(status: u16, message: &'static str) -> ErrorClass {
        classify(&api_error(status, message))
    }

    // --- Authentication / authorization ---

    #[test]
    fn classify_401_as_auth_error() {
        assert_eq!(class_of_api(401, "Unauthorized"), ErrorClass::AuthError);
    }

    #[test]
    fn classify_403_as_auth_error() {
        assert_eq!(class_of_api(403, "Forbidden"), ErrorClass::AuthError);
    }

    // --- Rate limiting ---

    #[test]
    fn classify_429_as_rate_limited() {
        assert_eq!(
            class_of_api(429, "Too Many Requests"),
            ErrorClass::RateLimited
        );
    }

    // --- Server errors ---

    #[test]
    fn classify_500_as_server_error() {
        assert_eq!(
            class_of_api(500, "Internal Server Error"),
            ErrorClass::ServerError
        );
    }

    #[test]
    fn classify_502_as_server_error() {
        assert_eq!(class_of_api(502, "Bad Gateway"), ErrorClass::ServerError);
    }

    #[test]
    fn classify_503_as_server_error() {
        assert_eq!(
            class_of_api(503, "Service Unavailable"),
            ErrorClass::ServerError
        );
    }

    #[test]
    fn classify_529_as_server_error() {
        assert_eq!(class_of_api(529, "Overloaded"), ErrorClass::ServerError);
    }

    // --- 400: context overflow phrases ---

    #[test]
    fn classify_400_prompt_too_long() {
        assert_eq!(
            class_of_api(400, "prompt is too long"),
            ErrorClass::ContextOverflow
        );
    }

    #[test]
    fn classify_400_maximum_context_length() {
        assert_eq!(
            class_of_api(400, "This request exceeds the maximum context length"),
            ErrorClass::ContextOverflow
        );
    }

    #[test]
    fn classify_400_context_length_exceeded() {
        assert_eq!(
            class_of_api(400, "context_length_exceeded"),
            ErrorClass::ContextOverflow
        );
    }

    #[test]
    fn classify_400_request_too_large() {
        assert_eq!(
            class_of_api(400, "request too large for this model"),
            ErrorClass::ContextOverflow
        );
    }

    #[test]
    fn classify_400_content_too_large() {
        assert_eq!(
            class_of_api(400, "content too large"),
            ErrorClass::ContextOverflow
        );
    }

    // A complaint about the `max_tokens` parameter is a plain invalid request,
    // not a context overflow.
    #[test]
    fn classify_400_max_tokens_parameter_is_not_overflow() {
        assert_eq!(
            class_of_api(400, "max_tokens: 4096 must be less than 2048"),
            ErrorClass::InvalidRequest
        );
    }

    #[test]
    fn classify_400_context_window() {
        assert_eq!(
            class_of_api(400, "exceeds the context window"),
            ErrorClass::ContextOverflow
        );
    }

    #[test]
    fn classify_400_too_many_tokens() {
        assert_eq!(
            class_of_api(400, "too many tokens in the request"),
            ErrorClass::ContextOverflow
        );
    }

    #[test]
    fn classify_400_input_too_long() {
        assert_eq!(
            class_of_api(400, "input is too long for model"),
            ErrorClass::ContextOverflow
        );
    }

    #[test]
    fn classify_400_exceeds_model_maximum_context() {
        assert_eq!(
            class_of_api(400, "exceeds the model's maximum context length"),
            ErrorClass::ContextOverflow
        );
    }

    #[test]
    fn classify_400_case_insensitive() {
        assert_eq!(
            class_of_api(400, "PROMPT IS TOO LONG"),
            ErrorClass::ContextOverflow
        );
    }

    // --- 400: everything else ---

    #[test]
    fn classify_400_generic_as_invalid_request() {
        assert_eq!(
            class_of_api(
                400,
                "invalid parameter: temperature must be between 0 and 1"
            ),
            ErrorClass::InvalidRequest
        );
    }

    // --- Transport failures ---

    #[test]
    fn classify_http_error_as_network() {
        // Provoke a real `reqwest::Error` by requesting an address that is
        // expected to fail.
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("test runtime");
        let request = reqwest::get("http://[::0]:1");
        let transport_failure = runtime.block_on(request).expect_err("should fail");

        assert_eq!(
            classify(&Error::Http(transport_failure)),
            ErrorClass::Network
        );
    }

    // --- Variants that carry no API status ---

    #[test]
    fn classify_agent_error_as_unknown() {
        let err = Error::Agent("something went wrong".into());
        assert_eq!(classify(&err), ErrorClass::Unknown);
    }

    #[test]
    fn classify_max_turns_exceeded_as_unknown() {
        assert_eq!(
            classify(&Error::MaxTurnsExceeded(10)),
            ErrorClass::Unknown
        );
    }

    #[test]
    fn classify_truncated_as_unknown() {
        assert_eq!(classify(&Error::Truncated), ErrorClass::Unknown);
    }

    #[test]
    fn classify_config_error_as_unknown() {
        let err = Error::Config("bad config".into());
        assert_eq!(classify(&err), ErrorClass::Unknown);
    }

    #[test]
    fn classify_mcp_error_as_unknown() {
        let err = Error::Mcp("connection refused".into());
        assert_eq!(classify(&err), ErrorClass::Unknown);
    }

    // --- Partial-usage wrapper is transparent ---

    #[test]
    fn classify_unwraps_with_partial_usage() {
        let usage = TokenUsage {
            input_tokens: 100,
            output_tokens: 50,
            ..Default::default()
        };
        let wrapped = api_error(429, "rate limited").with_partial_usage(usage);
        assert_eq!(classify(&wrapped), ErrorClass::RateLimited);
    }

    #[test]
    fn classify_unwraps_partial_usage_context_overflow() {
        let wrapped =
            api_error(400, "prompt is too long").with_partial_usage(TokenUsage::default());
        assert_eq!(classify(&wrapped), ErrorClass::ContextOverflow);
    }

    // --- Unrecognised status codes ---

    #[test]
    fn classify_unknown_status_as_unknown() {
        assert_eq!(class_of_api(418, "I'm a teapot"), ErrorClass::Unknown);
    }
}