Skip to main content

heartbit_core/llm/
error_class.rs

//! Error classification for LLM API errors — distinguishes retryable from fatal conditions.
2
3use crate::error::Error;
4
/// Coarse, actionable category assigned to an LLM provider error.
///
/// Produced by [`classify`]; callers can branch on the category instead of
/// inspecting status codes or message text themselves.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ErrorClass {
    /// The conversation no longer fits in the model's context window.
    ContextOverflow,
    /// The provider answered with HTTP 429. `RetryingProvider` has already
    /// retried by the time this is observed.
    RateLimited,
    /// Credentials were rejected (HTTP 401 or 403).
    AuthError,
    /// The provider failed on its side (HTTP 500, 502, 503 or 529).
    ServerError,
    /// An HTTP 400 that is not a context overflow.
    InvalidRequest,
    /// A transport-level failure (`Error::Http`) such as TCP, DNS, TLS or a
    /// timeout. Considered transient — `RetryingProvider` keys off the same
    /// signal.
    Network,
    /// Anything not recognized above; no specific recovery is suggested.
    Unknown,
}
24
25/// Classify an [`Error`] into an actionable [`ErrorClass`].
26///
27/// Primarily useful for `Error::Api` errors where the HTTP status code and
28/// message body determine recovery strategy.
29pub fn classify(error: &Error) -> ErrorClass {
30    // Unwrap WithPartialUsage to classify the inner error.
31    let inner = match error {
32        Error::WithPartialUsage { source, .. } => source.as_ref(),
33        other => other,
34    };
35
36    match inner {
37        Error::Api { status, message } => classify_api(*status, message),
38        Error::Http(_) => ErrorClass::Network,
39        _ => ErrorClass::Unknown,
40    }
41}
42
43fn classify_api(status: u16, message: &str) -> ErrorClass {
44    match status {
45        401 | 403 => ErrorClass::AuthError,
46        429 => ErrorClass::RateLimited,
47        500 | 502 | 503 | 529 => ErrorClass::ServerError,
48        400 => {
49            if is_context_overflow(message) {
50                ErrorClass::ContextOverflow
51            } else {
52                ErrorClass::InvalidRequest
53            }
54        }
55        _ => ErrorClass::Unknown,
56    }
57}
58
/// Report whether `message` reads like a context-window overflow error.
///
/// The message is lowercased once and then searched for each known phrase as
/// a plain substring, so matching is case-insensitive and needs no regex
/// dependency.
fn is_context_overflow(message: &str) -> bool {
    // Phrases are stored lowercase because they are compared against the
    // lowercased message.
    const OVERFLOW_MARKERS: &[&str] = &[
        "prompt is too long",
        "maximum context length",
        "context_length_exceeded",
        "context window",
        "too many tokens",
        "input is too long",
        "exceeds the model's maximum context",
        "request too large",
        "content too large",
    ];

    let haystack = message.to_lowercase();
    for marker in OVERFLOW_MARKERS.iter() {
        if haystack.contains(*marker) {
            return true;
        }
    }
    false
}
78
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::types::TokenUsage;

    /// Builds an `Error::Api` with the given status code and message.
    fn api_error(status: u16, message: &'static str) -> Error {
        Error::Api {
            status,
            message: message.into(),
        }
    }

    // --- Auth errors ---

    #[test]
    fn classify_401_as_auth_error() {
        assert_eq!(
            classify(&api_error(401, "Unauthorized")),
            ErrorClass::AuthError
        );
    }

    #[test]
    fn classify_403_as_auth_error() {
        assert_eq!(
            classify(&api_error(403, "Forbidden")),
            ErrorClass::AuthError
        );
    }

    // --- Rate limited ---

    #[test]
    fn classify_429_as_rate_limited() {
        assert_eq!(
            classify(&api_error(429, "Too Many Requests")),
            ErrorClass::RateLimited
        );
    }

    // --- Server errors ---

    #[test]
    fn classify_500_as_server_error() {
        assert_eq!(
            classify(&api_error(500, "Internal Server Error")),
            ErrorClass::ServerError
        );
    }

    #[test]
    fn classify_502_as_server_error() {
        assert_eq!(
            classify(&api_error(502, "Bad Gateway")),
            ErrorClass::ServerError
        );
    }

    #[test]
    fn classify_503_as_server_error() {
        assert_eq!(
            classify(&api_error(503, "Service Unavailable")),
            ErrorClass::ServerError
        );
    }

    #[test]
    fn classify_529_as_server_error() {
        assert_eq!(
            classify(&api_error(529, "Overloaded")),
            ErrorClass::ServerError
        );
    }

    // --- Context overflow (400 with overflow patterns) ---

    #[test]
    fn classify_400_prompt_too_long() {
        assert_eq!(
            classify(&api_error(400, "prompt is too long")),
            ErrorClass::ContextOverflow
        );
    }

    #[test]
    fn classify_400_maximum_context_length() {
        assert_eq!(
            classify(&api_error(
                400,
                "This request exceeds the maximum context length"
            )),
            ErrorClass::ContextOverflow
        );
    }

    #[test]
    fn classify_400_context_length_exceeded() {
        assert_eq!(
            classify(&api_error(400, "context_length_exceeded")),
            ErrorClass::ContextOverflow
        );
    }

    #[test]
    fn classify_400_request_too_large() {
        assert_eq!(
            classify(&api_error(400, "request too large for this model")),
            ErrorClass::ContextOverflow
        );
    }

    #[test]
    fn classify_400_content_too_large() {
        assert_eq!(
            classify(&api_error(400, "content too large")),
            ErrorClass::ContextOverflow
        );
    }

    /// A 400 that mentions `max_tokens` may be plain parameter validation
    /// (e.g. "max_tokens: 4096 must be less than ...") rather than a context
    /// overflow, so it must stay an `InvalidRequest`.
    #[test]
    fn classify_400_max_tokens_parameter_is_not_overflow() {
        assert_eq!(
            classify(&api_error(400, "max_tokens: 4096 must be less than 2048")),
            ErrorClass::InvalidRequest
        );
    }

    #[test]
    fn classify_400_context_window() {
        assert_eq!(
            classify(&api_error(400, "exceeds the context window")),
            ErrorClass::ContextOverflow
        );
    }

    #[test]
    fn classify_400_too_many_tokens() {
        assert_eq!(
            classify(&api_error(400, "too many tokens in the request")),
            ErrorClass::ContextOverflow
        );
    }

    #[test]
    fn classify_400_input_too_long() {
        assert_eq!(
            classify(&api_error(400, "input is too long for model")),
            ErrorClass::ContextOverflow
        );
    }

    #[test]
    fn classify_400_exceeds_model_maximum_context() {
        assert_eq!(
            classify(&api_error(400, "exceeds the model's maximum context length")),
            ErrorClass::ContextOverflow
        );
    }

    #[test]
    fn classify_400_case_insensitive() {
        assert_eq!(
            classify(&api_error(400, "PROMPT IS TOO LONG")),
            ErrorClass::ContextOverflow
        );
    }

    // --- Invalid request (400 without overflow pattern) ---

    #[test]
    fn classify_400_generic_as_invalid_request() {
        assert_eq!(
            classify(&api_error(
                400,
                "invalid parameter: temperature must be between 0 and 1"
            )),
            ErrorClass::InvalidRequest
        );
    }

    // --- HTTP / network errors ---

    #[test]
    fn classify_http_error_as_network() {
        // Obtain a genuine `reqwest::Error` by issuing a request that is
        // expected to fail, then wrap it in `Error::Http`.
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("test runtime");
        let transport_failure = runtime
            .block_on(reqwest::get("http://[::0]:1"))
            .expect_err("should fail");
        assert_eq!(
            classify(&Error::Http(transport_failure)),
            ErrorClass::Network
        );
    }

    // --- Other error variants ---

    #[test]
    fn classify_agent_error_as_unknown() {
        assert_eq!(
            classify(&Error::Agent("something went wrong".into())),
            ErrorClass::Unknown
        );
    }

    #[test]
    fn classify_max_turns_exceeded_as_unknown() {
        assert_eq!(
            classify(&Error::MaxTurnsExceeded(10)),
            ErrorClass::Unknown
        );
    }

    #[test]
    fn classify_truncated_as_unknown() {
        assert_eq!(classify(&Error::Truncated), ErrorClass::Unknown);
    }

    #[test]
    fn classify_config_error_as_unknown() {
        assert_eq!(
            classify(&Error::Config("bad config".into())),
            ErrorClass::Unknown
        );
    }

    #[test]
    fn classify_mcp_error_as_unknown() {
        assert_eq!(
            classify(&Error::Mcp("connection refused".into())),
            ErrorClass::Unknown
        );
    }

    // --- WithPartialUsage unwrapping ---

    #[test]
    fn classify_unwraps_with_partial_usage() {
        let usage = TokenUsage {
            input_tokens: 100,
            output_tokens: 50,
            ..Default::default()
        };
        let wrapped = api_error(429, "rate limited").with_partial_usage(usage);
        assert_eq!(classify(&wrapped), ErrorClass::RateLimited);
    }

    #[test]
    fn classify_unwraps_partial_usage_context_overflow() {
        let wrapped =
            api_error(400, "prompt is too long").with_partial_usage(TokenUsage::default());
        assert_eq!(classify(&wrapped), ErrorClass::ContextOverflow);
    }

    // --- Unknown status codes ---

    #[test]
    fn classify_unknown_status_as_unknown() {
        assert_eq!(
            classify(&api_error(418, "I'm a teapot")),
            ErrorClass::Unknown
        );
    }
}