use thiserror::Error;
/// Error type enumerating the failure modes reported by this module.
///
/// `Display` text for every variant comes from its `#[error(...)]` attribute
/// (via the `thiserror::Error` derive); the positional placeholders `{0}`,
/// `{1}` are filled from the variant's tuple fields.
#[derive(Error, Debug)]
pub enum KeyComputeError {
/// Authentication failed; the payload is appended to the message as detail.
#[error("authentication failed: {0}")]
AuthError(String),
/// Permission was denied; the payload is appended to the message as detail.
#[error("permission denied: {0}")]
PermissionDenied(String),
/// A rate limit was exceeded; the payload is appended to the message as detail.
#[error("rate limit exceeded: {0}")]
RateLimitExceeded(String),
/// No provider was available; the payload is rendered as the model in the message.
#[error("routing failed: no available provider for model {0}")]
RoutingFailed(String),
/// An upstream provider reported an error; the payload is the detail text.
#[error("upstream provider error: {0}")]
ProviderError(String),
/// A provider timed out. The first field is rendered as a millisecond count,
/// the second as the detail text.
#[error("provider timeout after {0}ms: {1}")]
ProviderTimeout(u64, String),
/// A database error; the payload is the detail text.
#[error("database error: {0}")]
DatabaseError(String),
/// A configuration error; the payload is the detail text.
#[error("configuration error: {0}")]
ConfigError(String),
/// An internal error; the payload is the detail text.
#[error("internal error: {0}")]
Internal(String),
/// A serialization error; the payload is the detail text.
#[error("serialization error: {0}")]
SerializationError(String),
/// A validation error; the payload is the detail text.
#[error("validation error: {0}")]
ValidationError(String),
/// Something was not found; the payload is the detail text.
#[error("not found: {0}")]
NotFound(String),
/// The request was invalid; the payload is the detail text.
#[error("invalid request: {0}")]
InvalidRequest(String),
/// A network error; the payload is the detail text.
#[error("network error: {0}")]
NetworkError(String),
/// A request timed out; the payload is the detail text.
#[error("request timeout: {0}")]
Timeout(String),
}
/// Shorthand for `std::result::Result` with [`KeyComputeError`] fixed as the error type.
pub type Result<T> = std::result::Result<T, KeyComputeError>;
impl From<serde_json::Error> for KeyComputeError {
fn from(err: serde_json::Error) -> Self {
KeyComputeError::SerializationError(err.to_string())
}
}
impl From<std::io::Error> for KeyComputeError {
fn from(err: std::io::Error) -> Self {
KeyComputeError::Internal(err.to_string())
}
}
impl From<uuid::Error> for KeyComputeError {
fn from(err: uuid::Error) -> Self {
KeyComputeError::InvalidRequest(format!("Invalid UUID: {}", err))
}
}
/// Converts a `chrono::ParseError` into [`KeyComputeError::InvalidRequest`].
///
/// The stored message is the source error's text prefixed with
/// `Invalid datetime format: `.
impl From<chrono::ParseError> for KeyComputeError {
    fn from(source: chrono::ParseError) -> Self {
        let message = format!("Invalid datetime format: {}", source);
        Self::InvalidRequest(message)
    }
}
impl KeyComputeError {
    /// Reports whether this error belongs to the retryable set.
    ///
    /// Returns `true` for `ProviderError`, `ProviderTimeout`, `NetworkError`,
    /// `Timeout` and `DatabaseError`; `false` for every other variant.
    pub fn is_retryable(&self) -> bool {
        // Spelled out exhaustively so that adding a variant forces an explicit
        // decision here instead of silently falling into a default.
        match self {
            Self::ProviderError(..)
            | Self::ProviderTimeout(..)
            | Self::NetworkError(..)
            | Self::Timeout(..)
            | Self::DatabaseError(..) => true,
            Self::AuthError(..)
            | Self::PermissionDenied(..)
            | Self::RateLimitExceeded(..)
            | Self::RoutingFailed(..)
            | Self::ConfigError(..)
            | Self::Internal(..)
            | Self::SerializationError(..)
            | Self::ValidationError(..)
            | Self::NotFound(..)
            | Self::InvalidRequest(..) => false,
        }
    }

    /// Maps this error onto its [`ErrorCategory`].
    ///
    /// Several variants share a category: auth and permission errors map to
    /// `Auth`, provider errors and provider timeouts to `Provider`, validation
    /// and invalid-request errors to `Validation`, network errors and request
    /// timeouts to `Network`, internal and serialization errors to `Internal`.
    pub fn category(&self) -> ErrorCategory {
        match self {
            // Variants that have a category of their own.
            Self::RateLimitExceeded(..) => ErrorCategory::RateLimit,
            Self::RoutingFailed(..) => ErrorCategory::Routing,
            Self::DatabaseError(..) => ErrorCategory::Database,
            Self::ConfigError(..) => ErrorCategory::Config,
            Self::NotFound(..) => ErrorCategory::NotFound,
            // Variants that share a category with a sibling.
            Self::AuthError(..) | Self::PermissionDenied(..) => ErrorCategory::Auth,
            Self::ProviderError(..) | Self::ProviderTimeout(..) => ErrorCategory::Provider,
            Self::ValidationError(..) | Self::InvalidRequest(..) => ErrorCategory::Validation,
            Self::NetworkError(..) | Self::Timeout(..) => ErrorCategory::Network,
            Self::Internal(..) | Self::SerializationError(..) => ErrorCategory::Internal,
        }
    }
}
/// Coarse classification of an error.
///
/// The `Display` implementation renders each category as a fixed
/// snake_case label (for example `authentication_error`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ErrorCategory {
    /// Rendered as `authentication_error`.
    Auth,
    /// Rendered as `rate_limit_error`.
    RateLimit,
    /// Rendered as `routing_error`.
    Routing,
    /// Rendered as `provider_error`.
    Provider,
    /// Rendered as `database_error`.
    Database,
    /// Rendered as `config_error`.
    Config,
    /// Rendered as `validation_error`.
    Validation,
    /// Rendered as `not_found_error`.
    NotFound,
    /// Rendered as `network_error`.
    Network,
    /// Rendered as `internal_error`.
    Internal,
}

impl std::fmt::Display for ErrorCategory {
    /// Writes the fixed label for this category to the formatter.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Pick the label first, then emit it with a single write.
        let label = match self {
            Self::Auth => "authentication_error",
            Self::RateLimit => "rate_limit_error",
            Self::Routing => "routing_error",
            Self::Provider => "provider_error",
            Self::Database => "database_error",
            Self::Config => "config_error",
            Self::Validation => "validation_error",
            Self::NotFound => "not_found_error",
            Self::Network => "network_error",
            Self::Internal => "internal_error",
        };
        f.write_str(label)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The rendered message contains the fixed prefix and the payload.
    #[test]
    fn test_error_display() {
        let rendered = KeyComputeError::AuthError("invalid token".to_string()).to_string();
        assert!(rendered.contains("authentication failed"));

        let rendered = KeyComputeError::RoutingFailed("gpt-4".to_string()).to_string();
        assert!(rendered.contains("gpt-4"));
    }

    /// Provider, network and database errors are retryable; auth, validation
    /// and not-found errors are not.
    #[test]
    fn test_is_retryable() {
        let retryable = [
            KeyComputeError::ProviderError("timeout".into()),
            KeyComputeError::NetworkError("connection reset".into()),
            KeyComputeError::DatabaseError("deadlock".into()),
        ];
        for err in retryable.iter() {
            assert!(err.is_retryable());
        }

        let permanent = [
            KeyComputeError::AuthError("invalid".into()),
            KeyComputeError::ValidationError("bad input".into()),
            KeyComputeError::NotFound("missing".into()),
        ];
        for err in permanent.iter() {
            assert!(!err.is_retryable());
        }
    }

    /// Each sampled variant maps to its expected category.
    #[test]
    fn test_error_category() {
        let cases = [
            (KeyComputeError::AuthError("test".into()), ErrorCategory::Auth),
            (
                KeyComputeError::RateLimitExceeded("test".into()),
                ErrorCategory::RateLimit,
            ),
            (
                KeyComputeError::ProviderError("test".into()),
                ErrorCategory::Provider,
            ),
        ];
        for (err, expected) in cases.iter() {
            assert_eq!(err.category(), *expected);
        }
    }

    /// Category labels render as their snake_case strings.
    #[test]
    fn test_category_display() {
        let cases = [
            (ErrorCategory::Auth, "authentication_error"),
            (ErrorCategory::RateLimit, "rate_limit_error"),
        ];
        for (category, label) in cases.iter() {
            assert_eq!(category.to_string(), *label);
        }
    }

    /// A JSON parse failure converts into `SerializationError`.
    #[test]
    fn test_from_serde_json_error() {
        let parse_failure = serde_json::from_str::<serde_json::Value>("invalid").unwrap_err();
        let converted = KeyComputeError::from(parse_failure);
        assert!(matches!(converted, KeyComputeError::SerializationError(_)));
    }

    /// A UUID parse failure converts into `InvalidRequest`.
    #[test]
    fn test_from_uuid_error() {
        let parse_failure = uuid::Uuid::parse_str("not-a-uuid").unwrap_err();
        let converted = KeyComputeError::from(parse_failure);
        assert!(matches!(converted, KeyComputeError::InvalidRequest(_)));
    }
}