1use std::time::Duration;
4
5use serde::{Deserialize, Serialize};
6
/// Convenience alias for `std::result::Result` whose error type defaults to this
/// module's `Error`; a different error type can still be supplied as `E`.
pub type Result<T, E = Error> = std::result::Result<T, E>;
9
10#[derive(Debug, thiserror::Error)]
17#[non_exhaustive]
18pub enum Error {
19 #[error("API error ({status}): {message}")]
21 #[non_exhaustive]
22 Api {
23 status: http::StatusCode,
25 request_id: Option<String>,
28 kind: ApiErrorKind,
30 message: String,
32 retry_after: Option<Duration>,
34 },
35 #[cfg(any(feature = "async", feature = "sync"))]
37 #[cfg_attr(docsrs, doc(cfg(any(feature = "async", feature = "sync"))))]
38 #[error("network error: {0}")]
39 Network(#[from] reqwest::Error),
40 #[error("decode error: {0}")]
42 Decode(#[from] serde_json::Error),
43 #[cfg(feature = "streaming")]
45 #[cfg_attr(docsrs, doc(cfg(feature = "streaming")))]
46 #[error("stream error: {0}")]
47 Stream(#[from] StreamError),
48 #[error("invalid configuration: {0}")]
50 InvalidConfig(String),
51 #[error("IO error: {0}")]
53 Io(#[from] std::io::Error),
54 #[error("agent loop exceeded max iterations ({max})")]
57 MaxIterationsExceeded {
58 max: u32,
60 },
61 #[error("agent loop exceeded cost budget: ${spent_usd:.4} > ${budget_usd:.4}")]
65 CostBudgetExceeded {
66 budget_usd: f64,
68 spent_usd: f64,
70 },
71 #[error("agent loop cancelled")]
73 Cancelled,
74 #[error("agent loop stopped by approval gate at tool '{tool_name}': {reason}")]
77 ToolApprovalStopped {
78 tool_name: String,
80 reason: String,
82 },
83 #[cfg(feature = "async")]
86 #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
87 #[error("request signing failed: {0}")]
88 Signing(Box<dyn std::error::Error + Send + Sync + 'static>),
89}
90
91impl Error {
92 pub fn is_retryable(&self) -> bool {
97 match self {
98 Error::Api { status, .. } => {
99 matches!(
100 status.as_u16(),
101 408 | 425 | 429 | 500 | 502 | 503 | 504 | 529
102 )
103 }
104 #[cfg(any(feature = "async", feature = "sync"))]
105 Error::Network(e) => e.is_timeout() || e.is_connect(),
106 #[cfg(feature = "streaming")]
107 Error::Stream(_) => false,
108 Error::Decode(_)
109 | Error::InvalidConfig(_)
110 | Error::Io(_)
111 | Error::MaxIterationsExceeded { .. }
112 | Error::CostBudgetExceeded { .. }
113 | Error::Cancelled
114 | Error::ToolApprovalStopped { .. } => false,
115 #[cfg(feature = "async")]
116 Error::Signing(_) => false,
117 }
118 }
119
120 pub fn request_id(&self) -> Option<&str> {
122 match self {
123 Error::Api { request_id, .. } => request_id.as_deref(),
124 _ => None,
125 }
126 }
127
128 pub fn retry_after(&self) -> Option<Duration> {
130 match self {
131 Error::Api { retry_after, .. } => *retry_after,
132 _ => None,
133 }
134 }
135
136 pub fn status(&self) -> Option<http::StatusCode> {
138 match self {
139 Error::Api { status, .. } => Some(*status),
140 _ => None,
141 }
142 }
143
144 #[allow(dead_code)]
152 pub(crate) fn from_response(
153 status: http::StatusCode,
154 request_id: Option<String>,
155 retry_after_header: Option<&str>,
156 body: &[u8],
157 ) -> Error {
158 let retry_after = retry_after_header.and_then(parse_retry_after);
159 let payload = serde_json::from_slice::<ErrorEnvelope>(body).map_or_else(
160 |_| ApiErrorPayload {
161 kind: ApiErrorKind::ApiError,
162 message: String::from_utf8_lossy(body).into_owned(),
163 },
164 |e| e.error,
165 );
166 Error::Api {
167 status,
168 request_id,
169 kind: payload.kind,
170 message: payload.message,
171 retry_after,
172 }
173 }
174}
175
/// Parses a `Retry-After` header value into a delay.
///
/// Only the delay-seconds form is understood: an unsigned integer as accepted
/// by `u64::from_str`, with surrounding whitespace allowed. Anything else yields
/// `None`, including the HTTP-date form, negatives, fractions and empty input.
// NOTE(review): the allow presumably covers builds where no caller is compiled
// in; confirm before removing it.
#[allow(dead_code)]
pub(crate) fn parse_retry_after(header: &str) -> Option<Duration> {
    let seconds: u64 = header.trim().parse().ok()?;
    Some(Duration::from_secs(seconds))
}
184
/// Top-level shape of an API error response body: `{"error": {...}}`.
///
/// Only the nested `error` object is read. Other top-level keys, such as a
/// `"type": "error"` tag, are skipped because no `deny_unknown_fields` is set.
#[derive(Deserialize)]
// NOTE(review): presumably allowed because its only visible reader
// (`Error::from_response`) may be uncompiled in some feature sets; confirm.
#[allow(dead_code)]
struct ErrorEnvelope {
    /// The error details.
    error: ApiErrorPayload,
}
192
/// The `error` object inside an API error response body: a machine-readable
/// category plus a human-readable message.
///
/// On the wire this is `{"type": "<kind>", "message": "<message>"}`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
pub struct ApiErrorPayload {
    /// Error category, carried under the JSON key `type`.
    #[serde(rename = "type")]
    pub kind: ApiErrorKind,
    /// Human-readable description of the error.
    pub message: String,
}
210
/// Category of an API error, as carried in the `type` field of an error payload.
///
/// Variants map to snake_case wire strings (`RateLimitError` <-> `"rate_limit_error"`).
/// Any unrecognised string deserializes to `Other`, so new server-side error
/// types do not break decoding.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum ApiErrorKind {
    /// Wire value `"invalid_request_error"`.
    InvalidRequestError,
    /// Wire value `"authentication_error"`.
    AuthenticationError,
    /// Wire value `"permission_error"`.
    PermissionError,
    /// Wire value `"not_found_error"`.
    NotFoundError,
    /// Wire value `"rate_limit_error"`.
    RateLimitError,
    /// Wire value `"api_error"`. `Error::from_response` also uses it when the
    /// response body is not a recognisable error envelope.
    ApiError,
    /// Wire value `"overloaded_error"`.
    OverloadedError,
    /// Catch-all for unrecognised wire values (via `#[serde(other)]`).
    /// It serializes back as `"other"`, so the original string is not preserved.
    #[serde(other)]
    Other,
}
239
/// Errors from the streaming layer (only built with the `streaming` feature).
#[cfg(feature = "streaming")]
#[cfg_attr(docsrs, doc(cfg(feature = "streaming")))]
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum StreamError {
    /// Stream data could not be parsed; the string describes the failure.
    #[error("stream parse error: {0}")]
    Parse(String),
    /// The stream connection was lost; the string describes the failure.
    #[error("stream connection lost: {0}")]
    Connection(String),
    /// The server emitted an error event on the stream.
    ///
    /// Display renders `kind` with its `Debug` form (e.g. `OverloadedError`),
    /// not the snake_case wire string.
    #[error("server emitted error event: {kind:?}: {message}")]
    Server {
        /// Error category reported by the event.
        kind: ApiErrorKind,
        /// Human-readable message reported by the event.
        message: String,
    },
}
264
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;
    use serde_json::json;

    #[test]
    fn api_error_payload_round_trips() {
        let payload = ApiErrorPayload {
            kind: ApiErrorKind::OverloadedError,
            message: "server overloaded".into(),
        };
        let v = serde_json::to_value(&payload).unwrap();
        assert_eq!(
            v,
            json!({"type": "overloaded_error", "message": "server overloaded"})
        );
        let parsed: ApiErrorPayload = serde_json::from_value(v).unwrap();
        assert_eq!(parsed, payload);
    }

    #[test]
    fn api_error_kind_round_trips_known_variants() {
        for (variant, wire) in [
            (ApiErrorKind::InvalidRequestError, "invalid_request_error"),
            (ApiErrorKind::AuthenticationError, "authentication_error"),
            (ApiErrorKind::PermissionError, "permission_error"),
            (ApiErrorKind::NotFoundError, "not_found_error"),
            (ApiErrorKind::RateLimitError, "rate_limit_error"),
            (ApiErrorKind::ApiError, "api_error"),
            (ApiErrorKind::OverloadedError, "overloaded_error"),
        ] {
            let v = serde_json::to_value(variant).unwrap();
            assert_eq!(v, json!(wire));
            let parsed: ApiErrorKind = serde_json::from_value(v).unwrap();
            assert_eq!(parsed, variant);
        }
    }

    #[test]
    fn api_error_kind_unknown_falls_to_other() {
        let parsed: ApiErrorKind = serde_json::from_str("\"future_error_type\"").unwrap();
        assert_eq!(parsed, ApiErrorKind::Other);
    }

    #[test]
    fn api_error_kind_other_serializes_as_other() {
        // `#[serde(other)]` only affects deserialization; serializing the
        // catch-all uses the snake_case variant name.
        let v = serde_json::to_value(ApiErrorKind::Other).unwrap();
        assert_eq!(v, json!("other"));
    }

    /// Minimal `Error::Api` with the given status, for retry-classification tests.
    fn api_error(status: u16) -> Error {
        Error::Api {
            status: http::StatusCode::from_u16(status).unwrap(),
            request_id: None,
            kind: ApiErrorKind::ApiError,
            message: "x".into(),
            retry_after: None,
        }
    }

    #[test]
    fn is_retryable_for_transient_statuses() {
        for s in [408u16, 425, 429, 500, 502, 503, 504, 529] {
            assert!(api_error(s).is_retryable(), "{s} should retry");
        }
    }

    #[test]
    fn is_not_retryable_for_client_errors() {
        for s in [400u16, 401, 403, 404, 422] {
            assert!(!api_error(s).is_retryable(), "{s} should not retry");
        }
    }

    #[test]
    fn is_not_retryable_for_decode_invalidconfig_io() {
        let decode = Error::Decode(serde_json::from_str::<u32>("\"oops\"").unwrap_err());
        assert!(!decode.is_retryable());

        let cfg = Error::InvalidConfig("missing api key".into());
        assert!(!cfg.is_retryable());

        let io = Error::Io(std::io::Error::other("bad"));
        assert!(!io.is_retryable());
    }

    #[test]
    fn agent_loop_errors_are_not_retryable() {
        let errors = [
            Error::MaxIterationsExceeded { max: 3 },
            Error::CostBudgetExceeded {
                budget_usd: 1.0,
                spent_usd: 1.5,
            },
            Error::Cancelled,
            Error::ToolApprovalStopped {
                tool_name: "shell".into(),
                reason: "denied".into(),
            },
        ];
        for err in errors {
            assert!(!err.is_retryable(), "{err} should not retry");
        }
    }

    #[test]
    fn parse_retry_after_seconds() {
        assert_eq!(parse_retry_after("120"), Some(Duration::from_secs(120)));
        assert_eq!(parse_retry_after(" 5 "), Some(Duration::from_secs(5)));
        assert_eq!(parse_retry_after("0"), Some(Duration::from_secs(0)));
    }

    #[test]
    fn parse_retry_after_rejects_garbage() {
        assert_eq!(parse_retry_after("not a number"), None);
        assert_eq!(parse_retry_after("Wed, 21 Oct 2015 07:28:00 GMT"), None);
        assert_eq!(parse_retry_after(""), None);
    }

    #[test]
    fn from_response_decodes_typed_error_envelope() {
        let body =
            br#"{"type": "error", "error": {"type": "rate_limit_error", "message": "slow down"}}"#;
        let err = Error::from_response(
            http::StatusCode::TOO_MANY_REQUESTS,
            Some("req_abc".into()),
            Some("12"),
            body,
        );
        match err {
            Error::Api {
                status,
                request_id,
                kind,
                message,
                retry_after,
            } => {
                assert_eq!(status, http::StatusCode::TOO_MANY_REQUESTS);
                assert_eq!(request_id.as_deref(), Some("req_abc"));
                assert_eq!(kind, ApiErrorKind::RateLimitError);
                assert_eq!(message, "slow down");
                assert_eq!(retry_after, Some(Duration::from_secs(12)));
            }
            other => panic!("expected Api, got {other:?}"),
        }
    }

    #[test]
    fn from_response_maps_unknown_error_type_to_other() {
        let body =
            br#"{"type": "error", "error": {"type": "brand_new_error", "message": "hmm"}}"#;
        let err = Error::from_response(http::StatusCode::BAD_REQUEST, None, None, body);
        match err {
            Error::Api { kind, message, .. } => {
                assert_eq!(kind, ApiErrorKind::Other);
                assert_eq!(message, "hmm");
            }
            other => panic!("expected Api, got {other:?}"),
        }
    }

    #[test]
    fn from_response_falls_back_for_non_json_body() {
        let body = b"<html>oops</html>";
        let err = Error::from_response(http::StatusCode::BAD_GATEWAY, None, None, body);
        match err {
            Error::Api {
                status,
                kind,
                message,
                retry_after,
                ..
            } => {
                assert_eq!(status, http::StatusCode::BAD_GATEWAY);
                assert_eq!(kind, ApiErrorKind::ApiError);
                assert_eq!(message, "<html>oops</html>");
                assert_eq!(retry_after, None);
            }
            other => panic!("expected Api, got {other:?}"),
        }
    }

    #[test]
    fn from_response_falls_back_for_json_without_envelope() {
        // Valid JSON that lacks the `error` object is treated like any other
        // unrecognised body: the raw text becomes the message.
        let body = br#"{"message": "no envelope"}"#;
        let err = Error::from_response(http::StatusCode::BAD_REQUEST, None, None, body);
        match err {
            Error::Api { kind, message, .. } => {
                assert_eq!(kind, ApiErrorKind::ApiError);
                assert_eq!(message, r#"{"message": "no envelope"}"#);
            }
            other => panic!("expected Api, got {other:?}"),
        }
    }

    #[test]
    fn from_response_decodes_invalid_utf8_body_lossily() {
        let body = b"\xff oops";
        let err = Error::from_response(http::StatusCode::BAD_GATEWAY, None, None, body);
        match err {
            Error::Api { message, .. } => assert_eq!(message, "\u{fffd} oops"),
            other => panic!("expected Api, got {other:?}"),
        }
    }

    #[test]
    fn from_response_ignores_http_date_retry_after() {
        let err = Error::from_response(
            http::StatusCode::SERVICE_UNAVAILABLE,
            None,
            Some("Wed, 21 Oct 2015 07:28:00 GMT"),
            b"",
        );
        assert_eq!(err.retry_after(), None);
        assert_eq!(err.status(), Some(http::StatusCode::SERVICE_UNAVAILABLE));
    }

    #[test]
    fn accessors_return_request_id_and_retry_after() {
        let err = Error::Api {
            status: http::StatusCode::INTERNAL_SERVER_ERROR,
            request_id: Some("rid".into()),
            kind: ApiErrorKind::ApiError,
            message: "boom".into(),
            retry_after: Some(Duration::from_secs(3)),
        };
        assert_eq!(err.request_id(), Some("rid"));
        assert_eq!(err.retry_after(), Some(Duration::from_secs(3)));
        assert_eq!(err.status(), Some(http::StatusCode::INTERNAL_SERVER_ERROR));

        let cfg = Error::InvalidConfig("nope".into());
        assert_eq!(cfg.request_id(), None);
        assert_eq!(cfg.retry_after(), None);
        assert_eq!(cfg.status(), None);
    }

    #[test]
    fn display_impl_includes_status_and_message() {
        let err = api_error(503);
        let s = format!("{err}");
        assert!(s.contains("503"), "{s}");
        assert!(s.contains('x'), "{s}");
    }

    #[test]
    fn display_impl_formats_agent_loop_errors() {
        assert_eq!(
            Error::MaxIterationsExceeded { max: 10 }.to_string(),
            "agent loop exceeded max iterations (10)"
        );
        // The `$` before each placeholder is a literal currency sign.
        assert_eq!(
            Error::CostBudgetExceeded {
                budget_usd: 1.0,
                spent_usd: 1.5,
            }
            .to_string(),
            "agent loop exceeded cost budget: $1.5000 > $1.0000"
        );
        assert_eq!(
            Error::ToolApprovalStopped {
                tool_name: "shell".into(),
                reason: "denied by user".into(),
            }
            .to_string(),
            "agent loop stopped by approval gate at tool 'shell': denied by user"
        );
    }

    #[cfg(feature = "streaming")]
    #[test]
    fn stream_errors_are_not_retryable() {
        let err = Error::Stream(StreamError::Connection("dropped".into()));
        assert!(!err.is_retryable());
    }
}