Skip to main content

llm_stack/
error.rs

1//! Unified error type for all LLM operations.
2//!
3//! Every provider maps its native errors into [`LlmError`], giving
4//! callers a single type to match against regardless of which backend
5//! is in use. Variants carry enough context for retry logic, user-facing
6//! messages, and diagnostics.
7//!
8//! # Retryability
9//!
//! Several variants include a `retryable` flag that providers set based
//! on the upstream response (e.g. HTTP 429 or 503). Middleware layers
//! can inspect this flag to decide whether to retry automatically;
//! [`LlmError::is_retryable`] implements the same check as this example:
13//!
14//! ```rust
15//! use llm_stack::LlmError;
16//!
17//! fn should_retry(err: &LlmError) -> bool {
18//!     match err {
19//!         LlmError::Http { retryable, .. } => *retryable,
20//!         LlmError::Provider { retryable, .. } => *retryable,
21//!         LlmError::Timeout { .. } => true,
22//!         _ => false,
23//!     }
24//! }
25//! ```
26
27use serde_json::Value;
28
29/// The unified error type returned by all provider operations.
30///
31/// Variants are `#[non_exhaustive]` — new error kinds may be added in
32/// minor releases without breaking downstream matches (always include a
33/// wildcard arm).
34#[derive(Debug, thiserror::Error)]
35#[non_exhaustive]
36pub enum LlmError {
37    /// An HTTP-level failure (transport error, unexpected status code).
38    ///
39    /// `status` is `None` when the request never received a response
40    /// (e.g. DNS failure, connection reset).
41    #[error("HTTP error (status={status:?}): {message}")]
42    Http {
43        /// The HTTP status code, if one was received.
44        status: Option<http::StatusCode>,
45        /// A human-readable description of the failure.
46        message: String,
47        /// Whether the caller should retry this request.
48        retryable: bool,
49    },
50
51    /// The API key or token was rejected.
52    #[error("Authentication error: {0}")]
53    Auth(String),
54
55    /// The request was malformed (missing fields, invalid parameters).
56    #[error("Invalid request: {0}")]
57    InvalidRequest(String),
58
59    /// A provider-specific error that doesn't map to another variant.
60    #[error("Provider error ({code}): {message}")]
61    Provider {
62        /// Provider-defined error code (e.g. `"overloaded"`).
63        code: String,
64        /// Human-readable error description.
65        message: String,
66        /// Whether the caller should retry this request.
67        retryable: bool,
68    },
69
70    /// The response body could not be parsed.
71    #[error("Response format error: {message}")]
72    ResponseFormat {
73        /// What went wrong during parsing.
74        message: String,
75        /// The raw response body, for diagnostics.
76        raw: String,
77    },
78
79    /// A structured-output response failed JSON Schema validation.
80    #[error("Schema validation error: {message}")]
81    SchemaValidation {
82        /// Concatenated validation error messages.
83        message: String,
84        /// The schema the value was validated against.
85        schema: Value,
86        /// The value that failed validation.
87        actual: Value,
88    },
89
90    /// A tool invocation raised an error.
91    #[error("Tool execution error ({tool_name}): {source}")]
92    ToolExecution {
93        /// The name of the tool that failed.
94        tool_name: String,
95        /// The underlying error.
96        source: Box<dyn std::error::Error + Send + Sync>,
97    },
98
99    /// A retry policy exhausted its budget without a successful response.
100    #[error("Retry exhausted after {attempts} attempts: {last_error}")]
101    RetryExhausted {
102        /// How many attempts were made.
103        attempts: u32,
104        /// The error from the final attempt.
105        #[source]
106        last_error: Box<LlmError>,
107    },
108
109    /// The operation exceeded its deadline.
110    #[error("Operation timed out after {elapsed_ms}ms")]
111    Timeout {
112        /// Milliseconds elapsed before the timeout fired.
113        elapsed_ms: u64,
114    },
115
116    /// A nested tool loop exceeded the maximum allowed depth.
117    ///
118    /// This occurs when `tool_loop` is called recursively (e.g., a tool
119    /// spawning a sub-agent) and the nesting depth exceeds `max_depth`
120    /// in [`ToolLoopConfig`](crate::tool::ToolLoopConfig).
121    #[error("max nesting depth exceeded (current: {current}, limit: {limit})")]
122    MaxDepthExceeded {
123        /// The depth at which the error was raised.
124        current: u32,
125        /// The configured maximum depth.
126        limit: u32,
127    },
128}
129
130impl LlmError {
131    /// Returns `true` if the error is transient and the request may succeed on retry.
132    ///
133    /// Useful for retry interceptors. This checks the `retryable` flag
134    /// on applicable variants and treats timeouts as always retryable.
135    ///
136    /// # Example
137    ///
138    /// ```rust
139    /// use llm_stack::LlmError;
140    ///
141    /// let err = LlmError::Timeout { elapsed_ms: 5000 };
142    /// assert!(err.is_retryable());
143    ///
144    /// let err = LlmError::Auth("bad key".into());
145    /// assert!(!err.is_retryable());
146    /// ```
147    pub fn is_retryable(&self) -> bool {
148        match self {
149            Self::Http { retryable, .. } | Self::Provider { retryable, .. } => *retryable,
150            Self::Timeout { .. } => true,
151            _ => false,
152        }
153    }
154}
155
156impl From<serde_json::Error> for LlmError {
157    fn from(err: serde_json::Error) -> Self {
158        Self::ResponseFormat {
159            message: err.to_string(),
160            raw: String::new(),
161        }
162    }
163}
164
#[cfg(test)]
mod tests {
    use super::*;
    use std::error::Error;

    /// Builds an `Http` error with a received status code.
    fn http_error(status: http::StatusCode, message: &str, retryable: bool) -> LlmError {
        LlmError::Http {
            status: Some(status),
            message: message.into(),
            retryable,
        }
    }

    #[test]
    fn test_error_display_http() {
        let err = http_error(http::StatusCode::TOO_MANY_REQUESTS, "rate limited", true);
        let display = format!("{err}");
        assert!(display.contains("429"));
        assert!(display.contains("rate limited"));
    }

    #[test]
    fn test_error_display_auth() {
        let err = LlmError::Auth("bad key".into());
        assert!(format!("{err}").contains("bad key"));
    }

    #[test]
    fn test_error_display_invalid_request() {
        let err = LlmError::InvalidRequest("missing model".into());
        assert!(format!("{err}").contains("missing model"));
    }

    #[test]
    fn test_error_display_provider() {
        let err = LlmError::Provider {
            code: "overloaded".into(),
            message: "server busy".into(),
            retryable: true,
        };
        let display = format!("{err}");
        assert!(display.contains("overloaded"));
        assert!(display.contains("server busy"));
    }

    #[test]
    fn test_error_display_response_format() {
        let err = LlmError::ResponseFormat {
            message: "not json".into(),
            raw: "hello".into(),
        };
        assert!(format!("{err}").contains("not json"));
    }

    #[test]
    fn test_error_display_schema_validation() {
        let err = LlmError::SchemaValidation {
            message: "missing field".into(),
            schema: serde_json::json!({"type": "object"}),
            actual: serde_json::json!({}),
        };
        assert!(format!("{err}").contains("missing field"));
    }

    #[test]
    fn test_error_display_tool_execution() {
        let err = LlmError::ToolExecution {
            tool_name: "calculator".into(),
            source: Box::new(std::io::Error::other("boom")),
        };
        let display = format!("{err}");
        assert!(display.contains("calculator"));
        assert!(display.contains("boom"));
    }

    #[test]
    fn test_error_display_retry_exhausted() {
        let inner = http_error(http::StatusCode::INTERNAL_SERVER_ERROR, "server error", true);
        let err = LlmError::RetryExhausted {
            attempts: 3,
            last_error: Box::new(inner),
        };
        let display = format!("{err}");
        // Pin the attempt count in context rather than any stray '3'.
        assert!(display.contains("after 3 attempts"));
        assert!(display.contains("server error"));
    }

    #[test]
    fn test_error_display_timeout() {
        let err = LlmError::Timeout { elapsed_ms: 5000 };
        assert!(format!("{err}").contains("5000ms"));
    }

    #[test]
    fn test_error_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<LlmError>();
    }

    #[test]
    fn test_error_retryable_http() {
        let retryable = http_error(http::StatusCode::TOO_MANY_REQUESTS, "rate limited", true);
        assert!(retryable.is_retryable());

        let fatal = http_error(http::StatusCode::BAD_REQUEST, "bad request", false);
        assert!(!fatal.is_retryable());

        // A transport failure with no status still honours the flag.
        let transport = LlmError::Http {
            status: None,
            message: "connection reset".into(),
            retryable: true,
        };
        assert!(transport.is_retryable());
    }

    #[test]
    fn test_error_retryable_provider() {
        let fatal = LlmError::Provider {
            code: "bad_request".into(),
            message: "invalid".into(),
            retryable: false,
        };
        assert!(!fatal.is_retryable());

        let transient = LlmError::Provider {
            code: "overloaded".into(),
            message: "server busy".into(),
            retryable: true,
        };
        assert!(transient.is_retryable());
    }

    #[test]
    fn test_error_timeout_is_retryable() {
        assert!(LlmError::Timeout { elapsed_ms: 1 }.is_retryable());
    }

    #[test]
    fn test_error_retry_exhausted_not_retryable() {
        // Even if the last attempt was itself retryable, the budget is spent.
        let inner = http_error(http::StatusCode::SERVICE_UNAVAILABLE, "unavailable", true);
        let err = LlmError::RetryExhausted {
            attempts: 5,
            last_error: Box::new(inner),
        };
        assert!(!err.is_retryable());
    }

    #[test]
    fn test_error_non_transient_variants_not_retryable() {
        let errors = [
            LlmError::Auth("bad key".into()),
            LlmError::InvalidRequest("missing model".into()),
            LlmError::ResponseFormat {
                message: "not json".into(),
                raw: String::new(),
            },
            LlmError::SchemaValidation {
                message: "missing field".into(),
                schema: serde_json::json!({"type": "object"}),
                actual: serde_json::json!({}),
            },
            LlmError::ToolExecution {
                tool_name: "calculator".into(),
                source: Box::new(std::io::Error::other("boom")),
            },
        ];
        for err in &errors {
            assert!(!err.is_retryable(), "expected non-retryable: {err}");
        }
    }

    #[test]
    fn test_error_retry_exhausted_nests() {
        let inner = LlmError::Auth("expired".into());
        let err = LlmError::RetryExhausted {
            attempts: 2,
            last_error: Box::new(inner),
        };
        assert!(matches!(
            &err,
            LlmError::RetryExhausted { last_error, .. }
                if matches!(last_error.as_ref(), LlmError::Auth(_))
        ));
    }

    #[test]
    fn test_error_retry_exhausted_source_chain() {
        let inner = LlmError::Auth("expired".into());
        let err = LlmError::RetryExhausted {
            attempts: 3,
            last_error: Box::new(inner),
        };
        let source = err.source().expect("RetryExhausted should have a source");
        assert!(format!("{source}").contains("expired"));
    }

    #[test]
    fn test_error_source_trait() {
        let err = LlmError::ToolExecution {
            tool_name: "test".into(),
            source: Box::new(std::io::Error::new(std::io::ErrorKind::NotFound, "gone")),
        };
        let source = err.source().expect("ToolExecution should have a source");
        assert!(source.to_string().contains("gone"));
    }

    #[test]
    fn test_from_serde_json_error() {
        let json_err = serde_json::from_str::<serde_json::Value>("not valid json").unwrap_err();
        let llm_err: LlmError = json_err.into();
        let LlmError::ResponseFormat { message, raw } = llm_err else {
            panic!("expected LlmError::ResponseFormat");
        };
        assert!(!message.is_empty());
        assert!(raw.is_empty());
    }

    #[test]
    fn test_error_display_max_depth_exceeded() {
        let err = LlmError::MaxDepthExceeded {
            current: 3,
            limit: 3,
        };
        let display = format!("{err}");
        assert!(display.contains("max nesting depth exceeded"));
        assert!(display.contains("current: 3"));
        assert!(display.contains("limit: 3"));
    }

    #[test]
    fn test_error_max_depth_not_retryable() {
        let err = LlmError::MaxDepthExceeded {
            current: 2,
            limit: 2,
        };
        assert!(!err.is_retryable());
    }
}