use std::{fmt, time::Duration};
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct RetryMetadata {
pub attempts: u32,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub last_status: Option<u16>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub last_error: Option<String>,
}
use thiserror::Error;
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct FieldError {
pub field: Option<String>,
pub message: String,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct WorkflowValidationIssue {
pub code: String,
pub path: String,
pub message: String,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct WorkflowValidationError {
pub issues: Vec<WorkflowValidationIssue>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub request_id: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub retries: Option<RetryMetadata>,
}
impl fmt::Display for WorkflowValidationError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
if self.issues.is_empty() {
return write!(f, "workflow validation error");
}
let first = &self.issues[0];
if self.issues.len() == 1 {
return write!(f, "{}: {}", first.path, first.message);
}
write!(
f,
"{}: {} (and {} more)",
first.path,
first.message,
self.issues.len() - 1
)
}
}
impl std::error::Error for WorkflowValidationError {}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ValidationError {
pub message: String,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub field: Option<String>,
}
impl ValidationError {
pub fn new(message: impl Into<String>) -> Self {
Self {
message: message.into(),
field: None,
}
}
pub fn with_field(mut self, field: impl Into<String>) -> Self {
self.field = Some(field.into());
self
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn validation_error_formats_with_field() {
let err = ValidationError::new("is required").with_field("model");
assert_eq!(err.to_string(), "model: is required");
}
#[test]
fn api_error_keeps_status_and_body() {
let api_err = APIError {
status: 429,
code: Some("rate_limit".into()),
message: "too many requests".into(),
request_id: Some("req_123".into()),
fields: Vec::new(),
retries: Some(RetryMetadata {
attempts: 2,
last_status: Some(429),
last_error: None,
}),
raw_body: Some("{\"error\":\"rate limit\"}".into()),
};
assert_eq!(api_err.to_string(), "rate_limit (429): too many requests");
assert_eq!(api_err.status, 429);
assert!(api_err.raw_body.is_some());
}
}
impl fmt::Display for ValidationError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
if let Some(field) = &self.field {
write!(f, "{}: {}", field, self.message)
} else {
write!(f, "{}", self.message)
}
}
}
impl std::error::Error for ValidationError {}
impl From<String> for ValidationError {
fn from(message: String) -> Self {
Self::new(message)
}
}
impl From<&str> for ValidationError {
fn from(message: &str) -> Self {
Self::new(message)
}
}
mod error_codes {
pub const NOT_FOUND: &str = "NOT_FOUND";
pub const VALIDATION_ERROR: &str = "VALIDATION_ERROR";
pub const RATE_LIMIT: &str = "RATE_LIMIT";
pub const UNAUTHORIZED: &str = "UNAUTHORIZED";
pub const FORBIDDEN: &str = "FORBIDDEN";
pub const SERVICE_UNAVAILABLE: &str = "SERVICE_UNAVAILABLE";
pub const INVALID_INPUT: &str = "INVALID_INPUT";
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct APIError {
pub status: u16,
pub code: Option<String>,
pub message: String,
pub request_id: Option<String>,
#[serde(default)]
pub fields: Vec<FieldError>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub retries: Option<RetryMetadata>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub raw_body: Option<String>,
}
impl APIError {
pub fn new(status: u16, message: impl Into<String>) -> Self {
Self {
status,
code: None,
message: message.into(),
request_id: None,
fields: Vec::new(),
retries: None,
raw_body: None,
}
}
pub fn is_not_found(&self) -> bool {
self.code.as_deref() == Some(error_codes::NOT_FOUND)
}
pub fn is_validation(&self) -> bool {
matches!(
self.code.as_deref(),
Some(error_codes::VALIDATION_ERROR) | Some(error_codes::INVALID_INPUT)
)
}
pub fn is_rate_limit(&self) -> bool {
self.code.as_deref() == Some(error_codes::RATE_LIMIT)
}
pub fn is_unauthorized(&self) -> bool {
self.code.as_deref() == Some(error_codes::UNAUTHORIZED)
}
pub fn is_forbidden(&self) -> bool {
self.code.as_deref() == Some(error_codes::FORBIDDEN)
}
pub fn is_unavailable(&self) -> bool {
self.code.as_deref() == Some(error_codes::SERVICE_UNAVAILABLE)
}
}
impl fmt::Display for APIError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
if let Some(code) = &self.code {
write!(f, "{} ({}): {}", code, self.status, self.message)
} else {
write!(f, "{}: {}", self.status, self.message)
}
}
}
impl std::error::Error for APIError {}
pub type Result<T, E = Error> = std::result::Result<T, E>;
#[derive(Debug, Error)]
#[error("{kind}: {message}")]
pub struct TransportError {
pub kind: TransportErrorKind,
pub message: String,
#[source]
pub source: Option<reqwest::Error>,
pub retries: Option<RetryMetadata>,
}
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum TransportErrorKind {
Timeout,
Connect,
Request,
EmptyResponse,
Other,
}
impl TransportError {
pub fn connect(context: impl Into<String>, source: reqwest::Error) -> Error {
Error::Transport(Self {
kind: TransportErrorKind::Connect,
message: format!("{}: {}", context.into(), source),
source: Some(source),
retries: None,
})
}
pub fn other(message: impl Into<String>) -> Error {
Error::Transport(Self {
kind: TransportErrorKind::Other,
message: message.into(),
source: None,
retries: None,
})
}
pub fn http_failure(
context: impl Into<String>,
status: reqwest::StatusCode,
body: String,
) -> Error {
Error::Transport(Self {
kind: TransportErrorKind::Other,
message: format!("{} ({}): {}", context.into(), status, body),
source: None,
retries: None,
})
}
pub fn parse_response(context: impl Into<String>, source: reqwest::Error) -> Error {
Error::Transport(Self {
kind: TransportErrorKind::Other,
message: format!("failed to parse {}: {}", context.into(), source),
source: Some(source),
retries: None,
})
}
pub fn timeout(message: impl Into<String>) -> Error {
Error::Transport(Self {
kind: TransportErrorKind::Timeout,
message: message.into(),
source: None,
retries: None,
})
}
}
impl fmt::Display for TransportErrorKind {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let label = match self {
TransportErrorKind::Timeout => "timeout",
TransportErrorKind::Connect => "connect",
TransportErrorKind::Request => "request",
TransportErrorKind::EmptyResponse => "empty response",
TransportErrorKind::Other => "transport",
};
write!(f, "{label}")
}
}
#[derive(Debug, Error)]
pub enum Error {
#[error("{0}")]
Validation(#[from] ValidationError),
#[error("{0}")]
WorkflowValidation(#[from] WorkflowValidationError),
#[error("serialization error: {0}")]
Serialization(#[from] serde_json::Error),
#[error("{0}")]
Api(#[from] APIError),
#[error("{0}")]
Transport(#[from] TransportError),
#[error("stream backpressure: dropped {dropped} events")]
StreamBackpressure { dropped: usize },
#[error("stream closed")]
StreamClosed,
#[cfg(feature = "streaming")]
#[error("stream protocol error: {message}")]
StreamProtocol {
message: String,
raw_data: Option<String>,
},
#[cfg(feature = "streaming")]
#[error("expected NDJSON stream ({expected}), got Content-Type {received}")]
StreamContentType {
expected: &'static str,
received: String,
status: u16,
},
#[cfg(feature = "streaming")]
#[error("{0}")]
StreamTimeout(#[from] StreamTimeoutError),
#[error("agent exceeded maximum turns ({max_turns}) without completing")]
AgentMaxTurns { max_turns: usize },
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StreamTimeoutKind {
Ttft,
Idle,
Total,
}
impl fmt::Display for StreamTimeoutKind {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
StreamTimeoutKind::Ttft => write!(f, "ttft"),
StreamTimeoutKind::Idle => write!(f, "idle"),
StreamTimeoutKind::Total => write!(f, "total"),
}
}
}
#[derive(Debug, Error, Clone)]
#[error("stream {kind} timeout after {timeout:?}")]
pub struct StreamTimeoutError {
pub kind: StreamTimeoutKind,
pub timeout: Duration,
}