1use std::time::Duration;
54
55use serde::{Deserialize, Serialize};
56
57pub type Result<T, E = Error> = std::result::Result<T, E>;
59
60#[derive(Debug, thiserror::Error)]
67#[non_exhaustive]
68pub enum Error {
69 #[error("API error ({status}): {message}")]
71 #[non_exhaustive]
72 Api {
73 status: http::StatusCode,
75 request_id: Option<String>,
78 kind: ApiErrorKind,
80 message: String,
82 retry_after: Option<Duration>,
84 },
85 #[cfg(any(feature = "async", feature = "sync"))]
87 #[cfg_attr(docsrs, doc(cfg(any(feature = "async", feature = "sync"))))]
88 #[error("network error: {0}")]
89 Network(#[from] reqwest::Error),
90 #[error("decode error: {0}")]
92 Decode(#[from] serde_json::Error),
93 #[cfg(feature = "streaming")]
95 #[cfg_attr(docsrs, doc(cfg(feature = "streaming")))]
96 #[error("stream error: {0}")]
97 Stream(#[from] StreamError),
98 #[error("invalid configuration: {0}")]
100 InvalidConfig(String),
101 #[error("IO error: {0}")]
103 Io(#[from] std::io::Error),
104 #[error("agent loop exceeded max iterations ({max})")]
107 MaxIterationsExceeded {
108 max: u32,
110 },
111 #[error("agent loop exceeded cost budget: ${spent_usd:.4} > ${budget_usd:.4}")]
115 CostBudgetExceeded {
116 budget_usd: f64,
118 spent_usd: f64,
120 },
121 #[error("agent loop cancelled")]
123 Cancelled,
124 #[error("agent loop stopped by approval gate at tool '{tool_name}': {reason}")]
127 ToolApprovalStopped {
128 tool_name: String,
130 reason: String,
132 },
133 #[cfg(feature = "async")]
136 #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
137 #[error("request signing failed: {0}")]
138 Signing(Box<dyn std::error::Error + Send + Sync + 'static>),
139}
140
141impl Error {
142 pub fn is_retryable(&self) -> bool {
147 match self {
148 Error::Api { status, .. } => {
149 matches!(
150 status.as_u16(),
151 408 | 425 | 429 | 500 | 502 | 503 | 504 | 529
152 )
153 }
154 #[cfg(any(feature = "async", feature = "sync"))]
155 Error::Network(e) => e.is_timeout() || e.is_connect(),
156 #[cfg(feature = "streaming")]
157 Error::Stream(_) => false,
158 Error::Decode(_)
159 | Error::InvalidConfig(_)
160 | Error::Io(_)
161 | Error::MaxIterationsExceeded { .. }
162 | Error::CostBudgetExceeded { .. }
163 | Error::Cancelled
164 | Error::ToolApprovalStopped { .. } => false,
165 #[cfg(feature = "async")]
166 Error::Signing(_) => false,
167 }
168 }
169
170 pub fn request_id(&self) -> Option<&str> {
172 match self {
173 Error::Api { request_id, .. } => request_id.as_deref(),
174 _ => None,
175 }
176 }
177
178 pub fn retry_after(&self) -> Option<Duration> {
180 match self {
181 Error::Api { retry_after, .. } => *retry_after,
182 _ => None,
183 }
184 }
185
186 pub fn status(&self) -> Option<http::StatusCode> {
188 match self {
189 Error::Api { status, .. } => Some(*status),
190 _ => None,
191 }
192 }
193
194 #[allow(dead_code)]
202 pub(crate) fn from_response(
203 status: http::StatusCode,
204 request_id: Option<String>,
205 retry_after_header: Option<&str>,
206 body: &[u8],
207 ) -> Error {
208 let retry_after = retry_after_header.and_then(parse_retry_after);
209 let payload = serde_json::from_slice::<ErrorEnvelope>(body).map_or_else(
210 |_| ApiErrorPayload {
211 kind: ApiErrorKind::ApiError,
212 message: String::from_utf8_lossy(body).into_owned(),
213 },
214 |e| e.error,
215 );
216 Error::Api {
217 status,
218 request_id,
219 kind: payload.kind,
220 message: payload.message,
221 retry_after,
222 }
223 }
224}
225
/// Interprets a retry-after header value as a whole number of seconds.
///
/// Surrounding whitespace is ignored. Returns `None` when the trimmed value
/// does not parse as a `u64`; this includes the empty string, fractional
/// values and date-formatted values.
// NOTE(review): `dead_code` is allowed presumably because no caller is
// compiled under some feature combinations - confirm.
#[allow(dead_code)]
pub(crate) fn parse_retry_after(header: &str) -> Option<Duration> {
    let seconds: u64 = header.trim().parse().ok()?;
    Some(Duration::from_secs(seconds))
}
234
235#[derive(Deserialize)]
238#[allow(dead_code)]
239struct ErrorEnvelope {
240 error: ApiErrorPayload,
241}
242
243#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
252#[non_exhaustive]
253pub struct ApiErrorPayload {
254 #[serde(rename = "type")]
256 pub kind: ApiErrorKind,
257 pub message: String,
259}
260
261#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
268#[serde(rename_all = "snake_case")]
269#[non_exhaustive]
270pub enum ApiErrorKind {
271 InvalidRequestError,
273 AuthenticationError,
275 PermissionError,
277 NotFoundError,
279 RateLimitError,
281 ApiError,
283 OverloadedError,
285 #[serde(other)]
287 Other,
288}
289
/// Failures specific to streaming.
///
/// Only compiled when the `streaming` feature is enabled. Converts into
/// [`Error::Stream`] through `From`.
#[cfg(feature = "streaming")]
#[cfg_attr(docsrs, doc(cfg(feature = "streaming")))]
#[derive(thiserror::Error, Debug)]
#[non_exhaustive]
pub enum StreamError {
    /// Stream data could not be parsed; the string describes the problem.
    #[error("stream parse error: {0}")]
    Parse(String),
    /// The connection carrying the stream was lost; the string describes
    /// the problem.
    #[error("stream connection lost: {0}")]
    Connection(String),
    /// The server sent an error event on the stream.
    #[error("server emitted error event: {kind:?}: {message}")]
    Server {
        /// Category of the reported error.
        kind: ApiErrorKind,
        /// Human-readable description of the reported error.
        message: String,
    },
}
314
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;
    use serde_json::json;

    /// Builds an `Error::Api` where only the status matters to the test.
    fn api_error(status: u16) -> Error {
        Error::Api {
            status: http::StatusCode::from_u16(status).unwrap(),
            request_id: None,
            kind: ApiErrorKind::ApiError,
            message: "x".into(),
            retry_after: None,
        }
    }

    // ---- serde round trips -------------------------------------------------

    #[test]
    fn api_error_payload_round_trips() {
        let payload = ApiErrorPayload {
            kind: ApiErrorKind::OverloadedError,
            message: "server overloaded".into(),
        };
        let v = serde_json::to_value(&payload).unwrap();
        assert_eq!(
            v,
            json!({"type": "overloaded_error", "message": "server overloaded"})
        );
        let parsed: ApiErrorPayload = serde_json::from_value(v).unwrap();
        assert_eq!(parsed, payload);
    }

    #[test]
    fn api_error_kind_round_trips_known_variants() {
        for (variant, wire) in [
            (ApiErrorKind::InvalidRequestError, "invalid_request_error"),
            (ApiErrorKind::AuthenticationError, "authentication_error"),
            (ApiErrorKind::PermissionError, "permission_error"),
            (ApiErrorKind::NotFoundError, "not_found_error"),
            (ApiErrorKind::RateLimitError, "rate_limit_error"),
            (ApiErrorKind::ApiError, "api_error"),
            (ApiErrorKind::OverloadedError, "overloaded_error"),
        ] {
            let v = serde_json::to_value(variant).unwrap();
            assert_eq!(v, json!(wire));
            let parsed: ApiErrorKind = serde_json::from_value(v).unwrap();
            assert_eq!(parsed, variant);
        }
    }

    #[test]
    fn api_error_kind_unknown_falls_to_other() {
        let parsed: ApiErrorKind = serde_json::from_str("\"future_error_type\"").unwrap();
        assert_eq!(parsed, ApiErrorKind::Other);
    }

    // ---- is_retryable ------------------------------------------------------

    #[test]
    fn is_retryable_for_transient_statuses() {
        for s in [408u16, 425, 429, 500, 502, 503, 504, 529] {
            assert!(api_error(s).is_retryable(), "{s} should retry");
        }
    }

    #[test]
    fn is_not_retryable_for_client_errors() {
        for s in [400u16, 401, 403, 404, 422] {
            assert!(!api_error(s).is_retryable(), "{s} should not retry");
        }
    }

    #[test]
    fn is_not_retryable_for_decode_invalidconfig_io() {
        let decode = Error::Decode(serde_json::from_str::<u32>("\"oops\"").unwrap_err());
        assert!(!decode.is_retryable());

        let cfg = Error::InvalidConfig("missing api key".into());
        assert!(!cfg.is_retryable());

        let io = Error::Io(std::io::Error::other("bad"));
        assert!(!io.is_retryable());
    }

    #[test]
    fn is_not_retryable_for_agent_loop_errors() {
        let errors = [
            Error::MaxIterationsExceeded { max: 3 },
            Error::CostBudgetExceeded {
                budget_usd: 1.0,
                spent_usd: 1.5,
            },
            Error::Cancelled,
            Error::ToolApprovalStopped {
                tool_name: "shell".into(),
                reason: "denied".into(),
            },
        ];
        for err in errors {
            assert!(!err.is_retryable(), "{err} should not retry");
        }
    }

    #[cfg(feature = "streaming")]
    #[test]
    fn stream_errors_are_not_retryable() {
        let err = Error::Stream(StreamError::Connection("dropped".into()));
        assert!(!err.is_retryable());
    }

    // ---- parse_retry_after -------------------------------------------------

    #[test]
    fn parse_retry_after_seconds() {
        assert_eq!(parse_retry_after("120"), Some(Duration::from_secs(120)));
        assert_eq!(parse_retry_after("  5 "), Some(Duration::from_secs(5)));
        assert_eq!(parse_retry_after("0"), Some(Duration::from_secs(0)));
    }

    #[test]
    fn parse_retry_after_rejects_garbage() {
        assert_eq!(parse_retry_after("not a number"), None);
        // The date form of the header is intentionally not interpreted.
        assert_eq!(parse_retry_after("Wed, 21 Oct 2015 07:28:00 GMT"), None);
        assert_eq!(parse_retry_after(""), None);
    }

    // ---- from_response -----------------------------------------------------

    #[test]
    fn from_response_decodes_typed_error_envelope() {
        let body =
            br#"{"type": "error", "error": {"type": "rate_limit_error", "message": "slow down"}}"#;
        let err = Error::from_response(
            http::StatusCode::TOO_MANY_REQUESTS,
            Some("req_abc".into()),
            Some("12"),
            body,
        );
        match err {
            Error::Api {
                status,
                request_id,
                kind,
                message,
                retry_after,
            } => {
                assert_eq!(status, http::StatusCode::TOO_MANY_REQUESTS);
                assert_eq!(request_id.as_deref(), Some("req_abc"));
                assert_eq!(kind, ApiErrorKind::RateLimitError);
                assert_eq!(message, "slow down");
                assert_eq!(retry_after, Some(Duration::from_secs(12)));
            }
            other => panic!("expected Api, got {other:?}"),
        }
    }

    #[test]
    fn from_response_falls_back_for_non_json_body() {
        let body = b"<html>oops</html>";
        let err = Error::from_response(http::StatusCode::BAD_GATEWAY, None, None, body);
        match err {
            Error::Api {
                status,
                kind,
                message,
                retry_after,
                ..
            } => {
                assert_eq!(status, http::StatusCode::BAD_GATEWAY);
                assert_eq!(kind, ApiErrorKind::ApiError);
                assert_eq!(message, "<html>oops</html>");
                assert_eq!(retry_after, None);
            }
            other => panic!("expected Api, got {other:?}"),
        }
    }

    #[test]
    fn from_response_falls_back_for_json_without_envelope() {
        // Valid JSON, but there is no `error` member, so the envelope does not
        // decode and the raw body becomes the message.
        let body = br#"{"message": "no envelope"}"#;
        let err = Error::from_response(
            http::StatusCode::INTERNAL_SERVER_ERROR,
            None,
            Some("soon"),
            body,
        );
        match err {
            Error::Api {
                kind,
                message,
                retry_after,
                ..
            } => {
                assert_eq!(kind, ApiErrorKind::ApiError);
                assert_eq!(message, r#"{"message": "no envelope"}"#);
                // An unparseable header value is dropped rather than reported.
                assert_eq!(retry_after, None);
            }
            other => panic!("expected Api, got {other:?}"),
        }
    }

    #[test]
    fn from_response_maps_unknown_kind_to_other() {
        let body = br#"{"error": {"type": "brand_new_error", "message": "huh"}}"#;
        let err = Error::from_response(http::StatusCode::BAD_REQUEST, None, None, body);
        match err {
            Error::Api { kind, message, .. } => {
                assert_eq!(kind, ApiErrorKind::Other);
                assert_eq!(message, "huh");
            }
            other => panic!("expected Api, got {other:?}"),
        }
    }

    // ---- accessors and Display ---------------------------------------------

    #[test]
    fn accessors_return_request_id_and_retry_after() {
        let err = Error::Api {
            status: http::StatusCode::INTERNAL_SERVER_ERROR,
            request_id: Some("rid".into()),
            kind: ApiErrorKind::ApiError,
            message: "boom".into(),
            retry_after: Some(Duration::from_secs(3)),
        };
        assert_eq!(err.request_id(), Some("rid"));
        assert_eq!(err.retry_after(), Some(Duration::from_secs(3)));
        assert_eq!(err.status(), Some(http::StatusCode::INTERNAL_SERVER_ERROR));

        let cfg = Error::InvalidConfig("nope".into());
        assert_eq!(cfg.request_id(), None);
        assert_eq!(cfg.retry_after(), None);
        assert_eq!(cfg.status(), None);
    }

    #[test]
    fn display_impl_includes_status_and_message() {
        let err = api_error(503);
        let s = format!("{err}");
        assert!(s.starts_with("API error ("), "{s}");
        assert!(s.contains("503"), "{s}");
        // Anchor the message at the end so a stray 'x' elsewhere cannot
        // satisfy the check.
        assert!(s.ends_with("): x"), "{s}");
    }

    #[test]
    fn display_impl_formats_cost_budget_with_four_decimals() {
        let err = Error::CostBudgetExceeded {
            budget_usd: 1.0,
            spent_usd: 1.5,
        };
        assert_eq!(
            err.to_string(),
            "agent loop exceeded cost budget: $1.5000 > $1.0000"
        );
    }
}