use std::fmt;
use std::time::Duration;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use crate::types::ToolCallId;
/// Result alias for this module whose error type defaults to [`Error`].
///
/// The `E` parameter can still be supplied explicitly when a different
/// error type is needed.
pub type Result<T, E = Error> = std::result::Result<T, E>;
/// Error type used as the default for this module's [`Result`] alias.
///
/// The `Display` text of each variant comes from its `#[error(...)]`
/// attribute; fields not named in that format string are carried as data
/// only and do not appear in the rendered message.
#[derive(Debug, Error)]
pub enum Error {
/// A provider failure, classified by `kind`.
///
/// `suggestion` is stored alongside the error but is not part of the
/// `Display` output (only `kind` is interpolated).
#[error("provider error: {kind}")]
Provider {
kind: ProviderErrorKind,
suggestion: Option<String>,
},
/// A tool invocation failed.
///
/// `name` and `message` are rendered in the `Display` output; `call_id`
/// is not. `source`, when present, is exposed as the underlying cause
/// through the `#[source]` attribute.
#[error("tool `{name}` failed: {message}")]
Tool {
name: String,
call_id: ToolCallId,
message: String,
#[source]
source: Option<Box<dyn std::error::Error + Send + Sync>>,
},
/// Context overflow, reporting `used` against `limit` (token counts, per
/// the message text).
#[error("context overflow: {used}/{limit} tokens")]
ContextOverflow { used: usize, limit: usize },
/// A configuration problem described by `message`.
#[error("config error: {message}")]
Config { message: String },
/// The operation was cancelled; carries no further detail.
#[error("cancelled")]
Cancelled,
/// Any other boxed error.
///
/// `#[from]` generates a `From<Box<dyn Error + Send + Sync>>` conversion,
/// so boxed errors can be propagated with `?`. The message is the inner
/// error's own `Display` output.
#[error("{0}")]
Internal(#[from] Box<dyn std::error::Error + Send + Sync>),
}
impl Error {
    /// Builds an [`Error::Tool`] with no underlying source error.
    ///
    /// `name` and `message` are converted into owned `String`s; `call_id`
    /// is stored as given.
    pub fn tool(
        name: impl Into<String>,
        call_id: ToolCallId,
        message: impl Into<String>,
    ) -> Self {
        let (name, message) = (name.into(), message.into());
        Self::Tool {
            name,
            call_id,
            message,
            source: None,
        }
    }

    /// Builds an [`Error::Tool`] that also records `source` as its cause.
    ///
    /// The source error is boxed and stored in the variant's `source` field.
    pub fn tool_with_source(
        name: impl Into<String>,
        call_id: ToolCallId,
        message: impl Into<String>,
        source: impl std::error::Error + Send + Sync + 'static,
    ) -> Self {
        let (name, message) = (name.into(), message.into());
        let cause: Box<dyn std::error::Error + Send + Sync> = Box::new(source);
        Self::Tool {
            name,
            call_id,
            message,
            source: Some(cause),
        }
    }

    /// Builds an [`Error::Provider`] from `kind`, always attaching `suggestion`.
    pub fn provider(kind: ProviderErrorKind, suggestion: impl Into<String>) -> Self {
        let suggestion = Some(suggestion.into());
        Self::Provider { kind, suggestion }
    }

    /// Returns `true` only for a provider error whose kind reports itself
    /// as retryable; every other variant yields `false`.
    pub fn retryable(&self) -> bool {
        matches!(self, Self::Provider { kind, .. } if kind.retryable())
    }

    /// Returns the retry delay reported by the provider error kind, or
    /// `None` for every non-provider variant.
    pub fn retry_after(&self) -> Option<Duration> {
        if let Self::Provider { kind, .. } = self {
            kind.retry_after()
        } else {
            None
        }
    }
}
/// Classification of a provider failure.
///
/// The `Display` text of each variant comes from its `#[error(...)]`
/// attribute. The type is serde-serializable; `retry_after` fields are
/// encoded as whole milliseconds through the `option_duration_millis`
/// adapter and deserialize to `None` when the field is absent from the
/// input.
#[derive(Debug, Error, Clone, Serialize, Deserialize)]
pub enum ProviderErrorKind {
    /// Authentication failed; `kind` narrows the cause and is rendered in
    /// parentheses after the message.
    #[error("authentication failed: {message} ({kind})")]
    Authentication {
        message: String,
        kind: AuthErrorKind,
    },
    /// The request was rate limited.
    #[error("rate limited: {message}")]
    RateLimit {
        message: String,
        /// Optional delay before retrying.
        // `default` is required here: with a custom `with` adapter serde no
        // longer treats a missing `Option` field as `None` on its own.
        #[serde(default, with = "option_duration_millis")]
        retry_after: Option<Duration>,
    },
    /// Quota was exceeded.
    #[error("quota exceeded: {message}")]
    QuotaExceeded { message: String },
    /// The request was rejected as invalid.
    #[error("invalid request: {message}")]
    InvalidRequest { message: String },
    /// Content was filtered.
    #[error("content filtered: {message}")]
    ContentFiltered { message: String },
    /// The provider reported an internal error with the given `status`.
    #[error("provider internal error ({status}): {message}")]
    ServerError {
        message: String,
        status: u16,
        /// Optional delay before retrying.
        // See `RateLimit::retry_after` for why `default` is needed.
        #[serde(default, with = "option_duration_millis")]
        retry_after: Option<Duration>,
    },
    /// A transport-level failure.
    #[error("transport error: {message}")]
    Transport { message: String },
    /// The response could not be accepted as valid.
    #[error("invalid response: {message}")]
    InvalidResponse { message: String },
    /// A provider failure that fits none of the other variants.
    #[error("unknown provider error: {message}")]
    Unknown { message: String },
}
impl ProviderErrorKind {
    /// Reports whether this kind of failure is retryable.
    ///
    /// Only `RateLimit` and `ServerError` are; every other variant is not.
    pub fn retryable(&self) -> bool {
        match self {
            Self::RateLimit { .. } | Self::ServerError { .. } => true,
            Self::Authentication { .. }
            | Self::QuotaExceeded { .. }
            | Self::InvalidRequest { .. }
            | Self::ContentFiltered { .. }
            | Self::Transport { .. }
            | Self::InvalidResponse { .. }
            | Self::Unknown { .. } => false,
        }
    }

    /// Returns how long to wait before retrying.
    ///
    /// For `RateLimit` and `ServerError` this is the stored `retry_after`
    /// value, falling back to 30 seconds and 20 seconds respectively when
    /// none is stored. All other variants return `None`.
    pub fn retry_after(&self) -> Option<Duration> {
        let (hint, fallback_secs) = match self {
            Self::RateLimit { retry_after, .. } => (*retry_after, 30),
            Self::ServerError { retry_after, .. } => (*retry_after, 20),
            _ => return None,
        };
        Some(hint.unwrap_or_else(|| Duration::from_secs(fallback_secs)))
    }
}
/// Sub-classification carried by `ProviderErrorKind::Authentication`.
///
/// Serialized by serde with snake_case variant names (for example
/// `InsufficientPermissions` becomes `insufficient_permissions`).
// NOTE(review): the per-variant meanings below are inferred from the variant
// names only — confirm against the code that constructs them.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum AuthErrorKind {
/// Presumably: no credential was supplied.
Missing,
/// Presumably: the supplied credential was rejected.
Invalid,
/// Presumably: the credential is no longer valid.
Expired,
/// Presumably: the credential lacks the required permissions.
InsufficientPermissions,
/// Cause not otherwise classified.
Unknown,
}
impl fmt::Display for AuthErrorKind {
    /// Writes the lowercase snake_case label for this kind.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::Missing => "missing",
            Self::Invalid => "invalid",
            Self::Expired => "expired",
            Self::InsufficientPermissions => "insufficient_permissions",
            Self::Unknown => "unknown",
        };
        f.write_str(label)
    }
}
/// Serde adapter that encodes an `Option<Duration>` as an optional count of
/// whole milliseconds, for use with `#[serde(with = "option_duration_millis")]`.
///
/// Both directions use `Option<u64>` as the wire representation so that what
/// `serialize` writes is exactly what `deserialize` expects to read.
mod option_duration_millis {
    use serde::{Deserialize, Deserializer, Serialize, Serializer};
    use std::time::Duration;

    /// Serializes `value` as `Option<u64>` milliseconds.
    ///
    /// Sub-millisecond precision is truncated. Durations whose millisecond
    /// count does not fit in a `u64` saturate at `u64::MAX`.
    pub fn serialize<S>(value: &Option<Duration>, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // `as_millis` yields a u128, which many serde formats cannot encode and
        // which `deserialize` (reading a u64) could not read back. Clamp first
        // so the narrowing cast below is lossless.
        let millis: Option<u64> =
            value.map(|d| d.as_millis().min(u128::from(u64::MAX)) as u64);
        // Going through `Option`'s own impl emits `serialize_some` /
        // `serialize_none`, mirroring `Option::deserialize` below.
        millis.serialize(serializer)
    }

    /// Deserializes an `Option<u64>` millisecond count into an
    /// `Option<Duration>`.
    pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<Duration>, D::Error>
    where
        D: Deserializer<'de>,
    {
        let millis = Option::<u64>::deserialize(deserializer)?;
        Ok(millis.map(Duration::from_millis))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a `RateLimit` kind.
    fn rate_limit(message: &str, retry_after: Option<Duration>) -> ProviderErrorKind {
        ProviderErrorKind::RateLimit {
            message: message.into(),
            retry_after,
        }
    }

    /// Shorthand for a `ServerError` kind.
    fn server_error(message: &str, status: u16, retry_after: Option<Duration>) -> ProviderErrorKind {
        ProviderErrorKind::ServerError {
            message: message.into(),
            status,
            retry_after,
        }
    }

    /// Shorthand for an `Authentication` kind with `AuthErrorKind::Invalid`.
    fn invalid_auth(message: &str) -> ProviderErrorKind {
        ProviderErrorKind::Authentication {
            message: message.into(),
            kind: AuthErrorKind::Invalid,
        }
    }

    /// Asserts that `kind` is not retryable.
    fn assert_not_retryable(kind: ProviderErrorKind) {
        assert!(!kind.retryable());
    }

    #[test]
    fn test_error_tool_construction() {
        let err = Error::tool("read_file", ToolCallId::new("call_1"), "not found");
        if let Error::Tool {
            name,
            call_id,
            message,
            source,
        } = &err
        {
            assert_eq!(name, "read_file");
            assert_eq!(call_id.as_str(), "call_1");
            assert_eq!(message, "not found");
            assert!(source.is_none());
        } else {
            panic!("expected Error::Tool");
        }
    }

    #[test]
    fn test_error_tool_with_source() {
        let err = Error::tool_with_source(
            "read_file",
            ToolCallId::new("call_2"),
            "failed",
            std::io::Error::new(std::io::ErrorKind::NotFound, "no such file"),
        );
        if let Error::Tool { source, .. } = &err {
            assert!(source.is_some());
        } else {
            panic!("expected Error::Tool");
        }
    }

    #[test]
    fn test_error_provider_construction() {
        let err = Error::provider(invalid_auth("bad key"), "run /login");
        if let Error::Provider { kind, suggestion } = &err {
            assert!(matches!(kind, ProviderErrorKind::Authentication { .. }));
            assert_eq!(suggestion.as_deref(), Some("run /login"));
        } else {
            panic!("expected Error::Provider");
        }
    }

    #[test]
    fn test_cancelled_not_retryable() {
        let cancelled = Error::Cancelled;
        assert!(!cancelled.retryable());
        assert!(cancelled.retry_after().is_none());
    }

    #[test]
    fn test_context_overflow_display() {
        let rendered = Error::ContextOverflow {
            used: 130000,
            limit: 128000,
        }
        .to_string();
        assert_eq!(rendered, "context overflow: 130000/128000 tokens");
    }

    #[test]
    fn test_config_error_display() {
        let rendered = Error::Config {
            message: "missing api_key".into(),
        }
        .to_string();
        assert!(rendered.contains("missing api_key"));
    }

    #[test]
    fn test_rate_limit_retryable() {
        let wait = Duration::from_secs(60);
        let kind = rate_limit("429", Some(wait));
        assert!(kind.retryable());
        assert_eq!(kind.retry_after(), Some(wait));
    }

    #[test]
    fn test_rate_limit_default_retry_after() {
        let kind = rate_limit("slow down", None);
        assert_eq!(kind.retry_after(), Some(Duration::from_secs(30)));
    }

    #[test]
    fn test_server_error_retryable() {
        let kind = server_error("internal", 500, None);
        assert!(kind.retryable());
        assert_eq!(kind.retry_after(), Some(Duration::from_secs(20)));
    }

    #[test]
    fn test_authentication_not_retryable() {
        let kind = invalid_auth("invalid");
        assert!(!kind.retryable());
        assert!(kind.retry_after().is_none());
    }

    #[test]
    fn test_quota_exceeded_not_retryable() {
        assert_not_retryable(ProviderErrorKind::QuotaExceeded {
            message: "out of credits".into(),
        });
    }

    #[test]
    fn test_invalid_request_not_retryable() {
        assert_not_retryable(ProviderErrorKind::InvalidRequest {
            message: "model not found".into(),
        });
    }

    #[test]
    fn test_content_filtered_not_retryable() {
        assert_not_retryable(ProviderErrorKind::ContentFiltered {
            message: "blocked".into(),
        });
    }

    #[test]
    fn test_transport_not_retryable() {
        assert_not_retryable(ProviderErrorKind::Transport {
            message: "dns failed".into(),
        });
    }

    #[test]
    fn test_invalid_response_not_retryable() {
        assert_not_retryable(ProviderErrorKind::InvalidResponse {
            message: "bad json".into(),
        });
    }

    #[test]
    fn test_unknown_not_retryable() {
        assert_not_retryable(ProviderErrorKind::Unknown {
            message: "???".into(),
        });
    }

    #[test]
    fn test_provider_error_kind_serde_roundtrip() {
        let original = rate_limit("429 too many", Some(Duration::from_millis(5000)));
        let json = serde_json::to_string(&original).unwrap();
        let restored: ProviderErrorKind = serde_json::from_str(&json).unwrap();
        let delay_survived = matches!(
            restored,
            ProviderErrorKind::RateLimit {
                retry_after: Some(d),
                ..
            } if d == Duration::from_millis(5000)
        );
        assert!(delay_survived);
    }

    #[test]
    fn test_server_error_serde_roundtrip() {
        let original = server_error("overloaded", 529, None);
        let json = serde_json::to_string(&original).unwrap();
        let restored: ProviderErrorKind = serde_json::from_str(&json).unwrap();
        let fields_survived = matches!(
            restored,
            ProviderErrorKind::ServerError {
                status: 529,
                retry_after: None,
                ..
            }
        );
        assert!(fields_survived);
    }

    #[test]
    fn test_auth_error_kind_display() {
        let expectations = [
            (AuthErrorKind::Missing, "missing"),
            (AuthErrorKind::Invalid, "invalid"),
            (AuthErrorKind::Expired, "expired"),
            (
                AuthErrorKind::InsufficientPermissions,
                "insufficient_permissions",
            ),
            (AuthErrorKind::Unknown, "unknown"),
        ];
        for (kind, expected) in expectations {
            assert_eq!(kind.to_string(), expected);
        }
    }

    #[test]
    fn test_auth_error_kind_serde_roundtrip() {
        let all_kinds = [
            AuthErrorKind::Missing,
            AuthErrorKind::Invalid,
            AuthErrorKind::Expired,
            AuthErrorKind::InsufficientPermissions,
            AuthErrorKind::Unknown,
        ];
        for kind in all_kinds {
            let json = serde_json::to_string(&kind).unwrap();
            let restored: AuthErrorKind = serde_json::from_str(&json).unwrap();
            assert_eq!(kind, restored);
        }
    }

    #[test]
    fn test_error_retryable_delegates() {
        let wait = Duration::from_secs(10);
        let retryable_err = Error::Provider {
            kind: rate_limit("wait", Some(wait)),
            suggestion: None,
        };
        assert!(retryable_err.retryable());
        assert_eq!(retryable_err.retry_after(), Some(wait));

        let non_retryable_err = Error::Provider {
            kind: invalid_auth("bad"),
            suggestion: None,
        };
        assert!(!non_retryable_err.retryable());
        assert!(non_retryable_err.retry_after().is_none());
    }

    #[test]
    fn test_non_provider_errors_not_retryable() {
        let non_provider = [
            Error::Cancelled,
            Error::Config {
                message: "x".into(),
            },
            Error::ContextOverflow { used: 1, limit: 1 },
            Error::tool("t", ToolCallId::new("c"), "m"),
        ];
        for err in non_provider {
            assert!(!err.retryable());
        }
    }

    #[test]
    fn test_error_display_formats() {
        let tool_rendered =
            Error::tool("bash", ToolCallId::new("c1"), "permission denied").to_string();
        assert_eq!(tool_rendered, "tool `bash` failed: permission denied");

        let provider_rendered = Error::Provider {
            kind: ProviderErrorKind::Transport {
                message: "connection reset".into(),
            },
            suggestion: Some("check your network".into()),
        }
        .to_string();
        assert_eq!(
            provider_rendered,
            "provider error: transport error: connection reset"
        );
    }
}