use std::time::Duration;
use serde::{Deserialize, Serialize};
pub type Result<T, E = Error> = std::result::Result<T, E>;
/// Every error this crate can report.
///
/// The enum is `#[non_exhaustive]`: match it with a wildcard arm so that new
/// variants can be added without breaking callers.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum Error {
    /// The API answered with an error response.
    #[error("API error ({status}): {message}")]
    #[non_exhaustive]
    Api {
        /// HTTP status code of the response.
        status: http::StatusCode,
        /// Request identifier, when one was captured.
        request_id: Option<String>,
        /// Error category taken from the response body; when built by
        /// `Error::from_response` from an undecodable body this is
        /// [`ApiErrorKind::ApiError`].
        kind: ApiErrorKind,
        /// Error message from the response body (or the raw body text when it
        /// could not be decoded).
        message: String,
        /// Delay the server asked the client to wait before retrying, if any.
        retry_after: Option<Duration>,
    },
    /// A transport-level failure reported by `reqwest`.
    #[cfg(any(feature = "async", feature = "sync"))]
    #[cfg_attr(docsrs, doc(cfg(any(feature = "async", feature = "sync"))))]
    #[error("network error: {0}")]
    Network(#[from] reqwest::Error),
    /// A `serde_json` (de)serialization failure.
    #[error("decode error: {0}")]
    Decode(#[from] serde_json::Error),
    /// A failure while consuming a streaming response.
    #[cfg(feature = "streaming")]
    #[cfg_attr(docsrs, doc(cfg(feature = "streaming")))]
    #[error("stream error: {0}")]
    Stream(#[from] StreamError),
    /// The client configuration is invalid; the string says what is wrong.
    #[error("invalid configuration: {0}")]
    InvalidConfig(String),
    /// An I/O failure.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
    /// An agent loop reached its iteration cap of `max`.
    #[error("agent loop exceeded max iterations ({max})")]
    MaxIterationsExceeded {
        /// The configured iteration cap.
        max: u32,
    },
    /// An agent loop spent more than its budget (both amounts in US dollars).
    #[error("agent loop exceeded cost budget: ${spent_usd:.4} > ${budget_usd:.4}")]
    CostBudgetExceeded {
        /// The configured budget, in USD.
        budget_usd: f64,
        /// The amount spent when the loop stopped, in USD.
        spent_usd: f64,
    },
    /// An agent loop was cancelled.
    #[error("agent loop cancelled")]
    Cancelled,
    /// An approval gate stopped the agent loop at tool `tool_name`.
    #[error("agent loop stopped by approval gate at tool '{tool_name}': {reason}")]
    ToolApprovalStopped {
        /// Name of the tool whose invocation was stopped.
        tool_name: String,
        /// Reason given by the approval gate.
        reason: String,
    },
    /// Signing an outgoing request failed.
    ///
    /// The signer's error is exposed through `std::error::Error::source`, like
    /// the inner errors of the other wrapping variants.
    #[cfg(feature = "async")]
    #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
    #[error("request signing failed: {0}")]
    Signing(#[source] Box<dyn std::error::Error + Send + Sync + 'static>),
}
impl Error {
pub fn is_retryable(&self) -> bool {
match self {
Error::Api { status, .. } => {
matches!(
status.as_u16(),
408 | 425 | 429 | 500 | 502 | 503 | 504 | 529
)
}
#[cfg(any(feature = "async", feature = "sync"))]
Error::Network(e) => e.is_timeout() || e.is_connect(),
#[cfg(feature = "streaming")]
Error::Stream(_) => false,
Error::Decode(_)
| Error::InvalidConfig(_)
| Error::Io(_)
| Error::MaxIterationsExceeded { .. }
| Error::CostBudgetExceeded { .. }
| Error::Cancelled
| Error::ToolApprovalStopped { .. } => false,
#[cfg(feature = "async")]
Error::Signing(_) => false,
}
}
pub fn request_id(&self) -> Option<&str> {
match self {
Error::Api { request_id, .. } => request_id.as_deref(),
_ => None,
}
}
pub fn retry_after(&self) -> Option<Duration> {
match self {
Error::Api { retry_after, .. } => *retry_after,
_ => None,
}
}
pub fn status(&self) -> Option<http::StatusCode> {
match self {
Error::Api { status, .. } => Some(*status),
_ => None,
}
}
#[allow(dead_code)]
pub(crate) fn from_response(
status: http::StatusCode,
request_id: Option<String>,
retry_after_header: Option<&str>,
body: &[u8],
) -> Error {
let retry_after = retry_after_header.and_then(parse_retry_after);
let payload = serde_json::from_slice::<ErrorEnvelope>(body).map_or_else(
|_| ApiErrorPayload {
kind: ApiErrorKind::ApiError,
message: String::from_utf8_lossy(body).into_owned(),
},
|e| e.error,
);
Error::Api {
status,
request_id,
kind: payload.kind,
message: payload.message,
retry_after,
}
}
}
/// Parses a `Retry-After` header value expressed as a number of seconds.
///
/// Surrounding whitespace is ignored. Returns `None` when the trimmed value does
/// not parse as a `u64`; this includes the HTTP-date form of the header, which
/// is not supported here.
// NOTE(review): `dead_code` is presumably allowed for builds that compile no
// caller - confirm.
#[allow(dead_code)]
pub(crate) fn parse_retry_after(header: &str) -> Option<Duration> {
    let seconds: u64 = header.trim().parse().ok()?;
    Some(Duration::from_secs(seconds))
}
/// Top-level shape of an API error response body, e.g.
/// `{"type": "error", "error": {"type": "...", "message": "..."}}`.
///
/// Only the nested `error` object is decoded. Other top-level keys (such as
/// `type`) are ignored, because serde skips unknown fields by default.
#[derive(Deserialize)]
// NOTE(review): presumably silences dead-code warnings in builds where nothing
// calls `Error::from_response` - confirm.
#[allow(dead_code)]
struct ErrorEnvelope {
    /// The error payload found under the `error` key.
    error: ApiErrorPayload,
}
/// The `error` object of an API error response body.
///
/// On the wire it is `{"type": <kind>, "message": <message>}`. The struct is
/// `#[non_exhaustive]`, so code outside this crate obtains it by deserializing
/// rather than with a struct literal.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
pub struct ApiErrorPayload {
    /// Error category, carried under the JSON key `type`.
    #[serde(rename = "type")]
    pub kind: ApiErrorKind,
    /// Human-readable error description.
    pub message: String,
}
/// Category of an API error, as carried in the `type` field of an error payload.
///
/// Variants (de)serialize as their snake_case names, e.g. `RateLimitError` is
/// `"rate_limit_error"`. Unrecognized strings deserialize to
/// [`ApiErrorKind::Other`], so new server-side error types do not make decoding
/// fail.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum ApiErrorKind {
    /// Wire name `invalid_request_error`.
    InvalidRequestError,
    /// Wire name `authentication_error`.
    AuthenticationError,
    /// Wire name `permission_error`.
    PermissionError,
    /// Wire name `not_found_error`.
    NotFoundError,
    /// Wire name `rate_limit_error`.
    RateLimitError,
    /// Wire name `api_error`.
    ApiError,
    /// Wire name `overloaded_error`.
    OverloadedError,
    /// Catch-all for error types this crate does not recognize.
    ///
    /// Unknown strings deserialize here; serializing it yields `"other"`, so the
    /// original string is not preserved.
    #[serde(other)]
    Other,
}
/// Errors raised while consuming a streaming response.
///
/// Only available with the `streaming` feature. `Error::is_retryable` treats
/// every stream error as non-retryable.
#[cfg(feature = "streaming")]
#[cfg_attr(docsrs, doc(cfg(feature = "streaming")))]
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum StreamError {
    /// Data from the stream could not be parsed; the string describes the problem.
    #[error("stream parse error: {0}")]
    Parse(String),
    /// The stream's connection was lost; the string describes the problem.
    #[error("stream connection lost: {0}")]
    Connection(String),
    /// The server emitted an error event in the stream.
    // Display uses `{kind:?}`, i.e. the Rust variant name (e.g. `OverloadedError`),
    // not the snake_case wire name.
    #[error("server emitted error event: {kind:?}: {message}")]
    Server {
        /// Error category reported by the event.
        kind: ApiErrorKind,
        /// Error message reported by the event.
        message: String,
    },
}
#[cfg(test)]
mod tests {
    //! Unit tests for error construction, classification, and (de)serialization.

    use super::*;
    use pretty_assertions::assert_eq;
    use serde_json::json;

    #[test]
    fn api_error_payload_round_trips() {
        let payload = ApiErrorPayload {
            kind: ApiErrorKind::OverloadedError,
            message: "server overloaded".into(),
        };
        let v = serde_json::to_value(&payload).unwrap();
        assert_eq!(
            v,
            json!({"type": "overloaded_error", "message": "server overloaded"})
        );
        let parsed: ApiErrorPayload = serde_json::from_value(v).unwrap();
        assert_eq!(parsed, payload);
    }

    #[test]
    fn api_error_kind_round_trips_known_variants() {
        for (variant, wire) in [
            (ApiErrorKind::InvalidRequestError, "invalid_request_error"),
            (ApiErrorKind::AuthenticationError, "authentication_error"),
            (ApiErrorKind::PermissionError, "permission_error"),
            (ApiErrorKind::NotFoundError, "not_found_error"),
            (ApiErrorKind::RateLimitError, "rate_limit_error"),
            (ApiErrorKind::ApiError, "api_error"),
            (ApiErrorKind::OverloadedError, "overloaded_error"),
        ] {
            let v = serde_json::to_value(variant).unwrap();
            assert_eq!(v, json!(wire));
            let parsed: ApiErrorKind = serde_json::from_value(v).unwrap();
            assert_eq!(parsed, variant);
        }
    }

    #[test]
    fn api_error_kind_unknown_falls_to_other() {
        let parsed: ApiErrorKind = serde_json::from_str("\"future_error_type\"").unwrap();
        assert_eq!(parsed, ApiErrorKind::Other);
    }

    /// Builds a minimal `Error::Api` with the given status and message `"x"`.
    fn api_error(status: u16) -> Error {
        Error::Api {
            status: http::StatusCode::from_u16(status).unwrap(),
            request_id: None,
            kind: ApiErrorKind::ApiError,
            message: "x".into(),
            retry_after: None,
        }
    }

    #[test]
    fn is_retryable_for_transient_statuses() {
        for s in [408u16, 425, 429, 500, 502, 503, 504, 529] {
            assert!(api_error(s).is_retryable(), "{s} should retry");
        }
    }

    #[test]
    fn is_not_retryable_for_client_errors() {
        for s in [400u16, 401, 403, 404, 422] {
            assert!(!api_error(s).is_retryable(), "{s} should not retry");
        }
    }

    #[test]
    fn is_not_retryable_for_decode_invalidconfig_io() {
        let decode = Error::Decode(serde_json::from_str::<u32>("\"oops\"").unwrap_err());
        assert!(!decode.is_retryable());
        let cfg = Error::InvalidConfig("missing api key".into());
        assert!(!cfg.is_retryable());
        let io = Error::Io(std::io::Error::other("bad"));
        assert!(!io.is_retryable());
    }

    #[test]
    fn agent_loop_errors_are_not_retryable() {
        let errs = [
            Error::MaxIterationsExceeded { max: 3 },
            Error::CostBudgetExceeded {
                budget_usd: 1.0,
                spent_usd: 2.0,
            },
            Error::Cancelled,
            Error::ToolApprovalStopped {
                tool_name: "bash".into(),
                reason: "denied".into(),
            },
        ];
        for err in errs {
            assert!(!err.is_retryable(), "{err} should not retry");
        }
    }

    #[test]
    fn parse_retry_after_seconds() {
        assert_eq!(parse_retry_after("120"), Some(Duration::from_secs(120)));
        assert_eq!(parse_retry_after(" 5 "), Some(Duration::from_secs(5)));
        assert_eq!(parse_retry_after("0"), Some(Duration::from_secs(0)));
    }

    #[test]
    fn parse_retry_after_rejects_garbage() {
        assert_eq!(parse_retry_after("not a number"), None);
        assert_eq!(parse_retry_after("Wed, 21 Oct 2015 07:28:00 GMT"), None);
        assert_eq!(parse_retry_after(""), None);
    }

    #[test]
    fn parse_retry_after_rejects_negative_and_fractional_seconds() {
        assert_eq!(parse_retry_after("-1"), None);
        assert_eq!(parse_retry_after("1.5"), None);
    }

    #[test]
    fn from_response_decodes_typed_error_envelope() {
        let body =
            br#"{"type": "error", "error": {"type": "rate_limit_error", "message": "slow down"}}"#;
        let err = Error::from_response(
            http::StatusCode::TOO_MANY_REQUESTS,
            Some("req_abc".into()),
            Some("12"),
            body,
        );
        match err {
            Error::Api {
                status,
                request_id,
                kind,
                message,
                retry_after,
            } => {
                assert_eq!(status, http::StatusCode::TOO_MANY_REQUESTS);
                assert_eq!(request_id.as_deref(), Some("req_abc"));
                assert_eq!(kind, ApiErrorKind::RateLimitError);
                assert_eq!(message, "slow down");
                assert_eq!(retry_after, Some(Duration::from_secs(12)));
            }
            other => panic!("expected Api, got {other:?}"),
        }
    }

    #[test]
    fn from_response_falls_back_for_non_json_body() {
        let body = b"<html>oops</html>";
        let err = Error::from_response(http::StatusCode::BAD_GATEWAY, None, None, body);
        match err {
            Error::Api {
                status,
                kind,
                message,
                retry_after,
                ..
            } => {
                assert_eq!(status, http::StatusCode::BAD_GATEWAY);
                assert_eq!(kind, ApiErrorKind::ApiError);
                assert_eq!(message, "<html>oops</html>");
                assert_eq!(retry_after, None);
            }
            other => panic!("expected Api, got {other:?}"),
        }
    }

    #[test]
    fn from_response_falls_back_when_envelope_is_missing() {
        let body = br#"{"message": "no envelope"}"#;
        let err = Error::from_response(http::StatusCode::BAD_REQUEST, None, None, body);
        match err {
            Error::Api { kind, message, .. } => {
                assert_eq!(kind, ApiErrorKind::ApiError);
                assert_eq!(message, r#"{"message": "no envelope"}"#);
            }
            other => panic!("expected Api, got {other:?}"),
        }
    }

    #[test]
    fn from_response_maps_unknown_error_type_to_other() {
        let body = br#"{"type": "error", "error": {"type": "brand_new_error", "message": "new"}}"#;
        let err = Error::from_response(http::StatusCode::BAD_REQUEST, None, None, body);
        match err {
            Error::Api { kind, message, .. } => {
                assert_eq!(kind, ApiErrorKind::Other);
                assert_eq!(message, "new");
            }
            other => panic!("expected Api, got {other:?}"),
        }
    }

    #[test]
    fn from_response_ignores_unparseable_retry_after() {
        let err = Error::from_response(
            http::StatusCode::SERVICE_UNAVAILABLE,
            None,
            Some("Wed, 21 Oct 2015 07:28:00 GMT"),
            b"",
        );
        assert_eq!(err.retry_after(), None);
        assert_eq!(err.status(), Some(http::StatusCode::SERVICE_UNAVAILABLE));
    }

    #[test]
    fn accessors_return_request_id_and_retry_after() {
        let err = Error::Api {
            status: http::StatusCode::INTERNAL_SERVER_ERROR,
            request_id: Some("rid".into()),
            kind: ApiErrorKind::ApiError,
            message: "boom".into(),
            retry_after: Some(Duration::from_secs(3)),
        };
        assert_eq!(err.request_id(), Some("rid"));
        assert_eq!(err.retry_after(), Some(Duration::from_secs(3)));
        assert_eq!(err.status(), Some(http::StatusCode::INTERNAL_SERVER_ERROR));
        let cfg = Error::InvalidConfig("nope".into());
        assert_eq!(cfg.request_id(), None);
        assert_eq!(cfg.retry_after(), None);
        assert_eq!(cfg.status(), None);
    }

    #[test]
    fn accessors_return_none_for_api_error_without_optional_fields() {
        let err = api_error(500);
        assert_eq!(err.request_id(), None);
        assert_eq!(err.retry_after(), None);
    }

    #[test]
    fn display_impl_includes_status_and_message() {
        let err = api_error(503);
        let s = format!("{err}");
        assert!(s.contains("503"), "{s}");
        assert!(s.contains('x'), "{s}");
    }

    #[test]
    fn display_formats_cost_budget_with_four_decimals() {
        let err = Error::CostBudgetExceeded {
            budget_usd: 1.0,
            spent_usd: 1.23456,
        };
        assert_eq!(
            err.to_string(),
            "agent loop exceeded cost budget: $1.2346 > $1.0000"
        );
    }

    #[cfg(feature = "streaming")]
    #[test]
    fn stream_errors_are_not_retryable() {
        let err = Error::Stream(StreamError::Connection("dropped".into()));
        assert!(!err.is_retryable());
    }
}