use std::borrow::Cow;
use std::time::Duration;
use crate::auth::AuthError;
/// Result alias whose error type is fixed to this module's [`Error`].
pub type Result<T> = core::result::Result<T, Error>;
/// Error type returned through this module's [`Result`] alias.
///
/// `Display` and `std::error::Error` are derived by `thiserror` from the
/// `#[error(...)]` attributes on each variant. The enum is `#[non_exhaustive]`,
/// so matches outside the defining crate need a wildcard arm.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum Error {
/// Invalid-request failure; the payload is the detail appended to the
/// `invalid request:` message.
#[error("invalid request: {0}")]
InvalidRequest(Cow<'static, str>),
/// Configuration failure; the payload is the detail appended to the
/// `config error:` message.
#[error("config error: {0}")]
Config(Cow<'static, str>),
/// Failure attributed to a provider, displayed as `provider {kind}: {message}`.
#[error("provider {kind}: {message}")]
Provider {
/// Category of the failure (network, TLS, DNS, or an HTTP status).
kind: ProviderErrorKind,
/// Human-readable detail shown after the kind in the display message.
message: String,
/// Optional retry delay; set via `Error::with_retry_after` and read by
/// `Error::envelope` to fill `retry_after_secs`.
// NOTE(review): this field is read by `Error::envelope`, so the
// `dead_code` allowance may no longer be needed — confirm before removing.
#[allow(dead_code)]
retry_after: Option<Duration>,
/// Optional underlying cause, exposed through `std::error::Error::source`
/// by the `#[source]` attribute.
#[source]
source: Option<Box<dyn std::error::Error + Send + Sync + 'static>>,
},
/// The operation was cancelled.
#[error("operation cancelled")]
Cancelled,
/// A deadline elapsed before the operation finished.
#[error("deadline exceeded")]
DeadlineExceeded,
/// Dispatch stopped so a human can review; neither field appears in the
/// display message.
#[error("dispatch interrupted for human review")]
Interrupted {
/// Kind of interruption, as defined by `crate::interruption`.
kind: crate::interruption::InterruptionKind,
/// JSON payload attached to the interruption (schema not defined here).
payload: serde_json::Value,
},
/// A model retry was requested; the display message includes the attempt
/// number but not the hint.
#[error("model retry requested (attempt {attempt})")]
ModelRetry {
/// Hint text wrapped in `RenderedForLlm`.
hint: crate::llm_facing::RenderedForLlm<String>,
/// Attempt number reported in the display message.
attempt: u32,
},
/// JSON (de)serialization failure; `#[from]` generates
/// `From<serde_json::Error>` so `?` converts automatically, and
/// `transparent` forwards display and source to the inner error.
#[error(transparent)]
Serde(#[from] serde_json::Error),
/// Authentication failure forwarded transparently from [`AuthError`].
/// There is no `#[from]` here, so this derive does not generate a
/// `From<AuthError>` conversion.
#[error(transparent)]
Auth(AuthError),
/// A usage limit was breached; the display message is the breach's own
/// `Display` output.
#[error("{0}")]
UsageLimitExceeded(crate::run_budget::UsageLimitBreach),
}
impl Error {
pub fn invalid_request(msg: impl Into<Cow<'static, str>>) -> Self {
Self::InvalidRequest(msg.into())
}
pub fn config(msg: impl Into<Cow<'static, str>>) -> Self {
Self::Config(msg.into())
}
pub const fn model_retry(
hint: crate::llm_facing::RenderedForLlm<String>,
attempt: u32,
) -> Self {
Self::ModelRetry { hint, attempt }
}
pub fn provider_http(status: u16, message: impl Into<String>) -> Self {
Self::Provider {
kind: http_or_network(status),
message: message.into(),
retry_after: None,
source: None,
}
}
pub fn provider_http_from<E>(status: u16, err: E) -> Self
where
E: std::error::Error + Send + Sync + 'static,
{
Self::Provider {
kind: http_or_network(status),
message: err.to_string(),
retry_after: None,
source: Some(Box::new(err)),
}
}
pub fn provider_network(message: impl Into<String>) -> Self {
Self::Provider {
kind: ProviderErrorKind::Network,
message: message.into(),
retry_after: None,
source: None,
}
}
pub fn provider_network_from<E>(err: E) -> Self
where
E: std::error::Error + Send + Sync + 'static,
{
Self::Provider {
kind: ProviderErrorKind::Network,
message: err.to_string(),
retry_after: None,
source: Some(Box::new(err)),
}
}
pub fn provider_tls(message: impl Into<String>) -> Self {
Self::Provider {
kind: ProviderErrorKind::Tls,
message: message.into(),
retry_after: None,
source: None,
}
}
pub fn provider_tls_from<E>(err: E) -> Self
where
E: std::error::Error + Send + Sync + 'static,
{
Self::Provider {
kind: ProviderErrorKind::Tls,
message: err.to_string(),
retry_after: None,
source: Some(Box::new(err)),
}
}
pub fn provider_dns(message: impl Into<String>) -> Self {
Self::Provider {
kind: ProviderErrorKind::Dns,
message: message.into(),
retry_after: None,
source: None,
}
}
pub fn provider_dns_from<E>(err: E) -> Self
where
E: std::error::Error + Send + Sync + 'static,
{
Self::Provider {
kind: ProviderErrorKind::Dns,
message: err.to_string(),
retry_after: None,
source: Some(Box::new(err)),
}
}
#[must_use]
pub fn with_retry_after(mut self, duration: Duration) -> Self {
if let Self::Provider {
ref mut retry_after,
..
} = self
{
*retry_after = Some(duration);
}
self
}
#[must_use]
pub fn with_source<E>(mut self, err: E) -> Self
where
E: std::error::Error + Send + Sync + 'static,
{
if let Self::Provider { ref mut source, .. } = self {
*source = Some(Box::new(err));
}
self
}
pub fn envelope(&self) -> ErrorEnvelope {
let (wire_code, wire_class) = self.wire_signal();
let (retry_after_secs, provider_status) = match self {
Self::Provider {
kind, retry_after, ..
} => (
retry_after.map(|d| d.as_secs()),
match kind {
ProviderErrorKind::Http(status) => Some(*status),
_ => None,
},
),
_ => (None, None),
};
ErrorEnvelope {
wire_code,
wire_class,
retry_after_secs,
provider_status,
}
}
fn wire_signal(&self) -> (&'static str, ErrorClass) {
match self {
Self::InvalidRequest(_) => ("invalid_request", ErrorClass::Client),
Self::Config(_) => ("config_error", ErrorClass::Server),
Self::Provider { kind, .. } => match kind {
ProviderErrorKind::Network => ("transport_failure", ErrorClass::Server),
ProviderErrorKind::Tls => ("tls_failure", ErrorClass::Server),
ProviderErrorKind::Dns => ("dns_failure", ErrorClass::Server),
ProviderErrorKind::Http(status) => match *status {
429 => ("rate_limited", ErrorClass::Client),
401 | 403 => ("upstream_unauthorized", ErrorClass::Client),
s if (400..500).contains(&s) => ("upstream_invalid", ErrorClass::Client),
s if (500..600).contains(&s) => ("upstream_unavailable", ErrorClass::Server),
_ => ("upstream_error", ErrorClass::Server),
},
},
Self::Auth(_) => ("auth_failed", ErrorClass::Client),
Self::Cancelled => ("cancelled", ErrorClass::Client),
Self::DeadlineExceeded => ("deadline_exceeded", ErrorClass::Server),
Self::Interrupted { .. } => ("interrupted", ErrorClass::Client),
Self::ModelRetry { .. } => ("model_retry_exhausted", ErrorClass::Client),
Self::Serde(_) => ("serde", ErrorClass::Server),
Self::UsageLimitExceeded(_) => ("quota_exceeded", ErrorClass::Client),
}
}
}
/// Compact, serializable summary of an [`Error`], produced by `Error::envelope`.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, serde::Serialize)]
#[non_exhaustive]
pub struct ErrorEnvelope {
/// Static code identifying the failure (for example `"rate_limited"`).
pub wire_code: &'static str,
/// Whether the failure is classified as client-side or server-side.
pub wire_class: ErrorClass,
/// Retry delay in seconds, when the error carries one; omitted from the
/// serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub retry_after_secs: Option<u64>,
/// HTTP status for provider errors of kind `Http`; omitted from the
/// serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub provider_status: Option<u16>,
}
/// Coarse classification attached to every [`ErrorEnvelope`].
///
/// Serialized with lowercase variant names (`client` / `server`).
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, serde::Serialize)]
#[serde(rename_all = "lowercase")]
#[non_exhaustive]
pub enum ErrorClass {
/// Failure classified as client-side.
Client,
/// Failure classified as server-side.
Server,
}
/// Writes the class as its lowercase label: `client` or `server`.
impl std::fmt::Display for ErrorClass {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Pick the label first so there is a single write call.
        let label = match *self {
            Self::Client => "client",
            Self::Server => "server",
        };
        f.write_str(label)
    }
}
/// Classifies a numeric status for the HTTP provider constructors.
///
/// Values in `400..=599` are kept as `ProviderErrorKind::Http(status)`;
/// everything else is reported as `ProviderErrorKind::Network`.
const fn http_or_network(status: u16) -> ProviderErrorKind {
    match status {
        400..=599 => ProviderErrorKind::Http(status),
        _ => ProviderErrorKind::Network,
    }
}
/// Category of an [`Error::Provider`] failure.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
#[non_exhaustive]
pub enum ProviderErrorKind {
/// Network-level failure; also used by `http_or_network` for statuses
/// outside `400..600`.
Network,
/// TLS failure.
Tls,
/// DNS failure.
Dns,
/// HTTP response carrying the given status code.
Http(u16),
}
/// Writes `network`, `tls` or `dns` for the fixed kinds and
/// `returned <status>` for an HTTP status.
impl std::fmt::Display for ProviderErrorKind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            // The only arm that needs formatting, so it returns directly.
            Self::Http(code) => return write!(f, "returned {code}"),
            Self::Network => "network",
            Self::Tls => "tls",
            Self::Dns => "dns",
        };
        f.write_str(label)
    }
}