use serde::{Deserialize, Serialize};
use std::fmt;
use thiserror::Error;
/// Error type of this module; each variant carries structured details about
/// one failure category.
///
/// The `Display` text of every variant comes from its `#[error(...)]`
/// attribute. With the serde attributes below, the enum is (de)serialized in
/// adjacently tagged form: the variant name under `"type"` and the variant's
/// fields under `"data"`.
#[derive(Debug, Error, Clone, Serialize, Deserialize)]
#[serde(tag = "type", content = "data")]
pub enum AIError {
/// Failure reported for a call to an LLM provider.
///
/// * `status_code` - when present, drives `AIError::kind` (`429` is a rate
///   limit, `>= 500` a server error, any other code `>= 400` a client error)
///   and is returned as-is by `AIError::status_code`.
/// * `retry_after` - returned unchanged by `AIError::retry_delay_secs`, so it
///   is interpreted as a number of seconds there.
#[error("LLM API error: {message}")]
LLMError {
provider: String,
message: String,
status_code: Option<u16>,
retry_after: Option<u64>,
},
/// A rate limit was hit.
///
/// * `reset_at` - when present, the instant `AIError::retry_delay_secs`
///   counts down to.
/// * `limit_type`, `current`, `limit` - not read by any method in this file;
///   presumably which quota was hit and usage versus allowance (confirm with
///   the code that constructs this variant).
#[error("Rate limit exceeded: {message}")]
RateLimitError {
limit_type: String, current: usize,
limit: usize,
reset_at: Option<chrono::DateTime<chrono::Utc>>,
message: String,
},
/// An operation exceeded its time budget.
///
/// The message shows `operation` and `timeout_secs`; `elapsed_secs` is
/// carried along but not included in the message.
#[error("Timeout: {operation} took longer than {timeout_secs}s")]
TimeoutError {
operation: String,
timeout_secs: u64,
elapsed_secs: f64,
},
/// Authentication failed; `AIError::status_code` reports `401` for it.
#[error("Authentication error: {message}")]
AuthError { provider: String, message: String },
/// Input validation failed; `AIError::status_code` reports `400` for it.
///
/// `expected` and `actual` are optional and not part of the message.
#[error("Validation error: {message}")]
ValidationError {
field: String,
message: String,
expected: Option<String>,
actual: Option<String>,
},
/// A network-level failure.
///
/// `retryable` decides whether `AIError::kind` reports `Transient` or
/// `Permanent`, and whether `AIError::retry_delay_secs` suggests a delay.
#[error("Network error: {message}")]
NetworkError {
url: Option<String>,
message: String,
retryable: bool,
},
/// A circuit breaker rejected the call.
///
/// `retry_after_secs` is what `AIError::retry_delay_secs` returns for this
/// variant.
#[error("Circuit breaker is open: {reason}")]
CircuitBreakerError {
service: String,
reason: String,
retry_after_secs: u64,
},
/// A resource ran out; only `resource` appears in the message.
#[error("Resource exhausted: {resource}")]
ResourceExhausted {
resource: String, current: usize,
limit: usize,
},
/// A model-related failure.
///
/// `error_type` is a free-form string; `AIError::kind` recognizes
/// `"not_found"` and `"unsupported"` (permanent) and `"loading_failed"`
/// (transient), and treats anything else as a server error.
#[error("Model error: {message}")]
ModelError {
model_name: String,
message: String,
error_type: String, },
/// Parsing failed.
///
/// `line` and `column` are optional; presumably the position of the failure
/// in the parsed input (not read by any method in this file).
#[error("Parse error: {message}")]
ParseError {
format: String, message: String,
line: Option<usize>,
column: Option<usize>,
},
/// An internal failure; `stacktrace` is optional and not part of the message.
#[error("Internal error: {message}")]
InternalError {
component: String,
message: String,
stacktrace: Option<String>,
},
}
/// Coarse classification of an `AIError`, as computed by `AIError::kind`.
///
/// `AIError::is_retryable` treats `Transient`, `RateLimit` and `Server` as
/// retryable and every other kind as not retryable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum AIErrorKind {
/// Timeouts, network errors flagged retryable, model errors of type
/// `"loading_failed"`, and LLM errors with no status code or a code below 400.
Transient,
/// Network errors not flagged retryable, and model errors of type
/// `"not_found"` or `"unsupported"`.
Permanent,
/// `RateLimitError`, or an LLM error with status code 429.
RateLimit,
/// Validation and parse errors, and LLM errors with a 400-499 status code
/// other than 429.
Client,
/// Resource exhaustion, internal errors, model errors of any other type, and
/// LLM errors with a status code of 500 or above.
Server,
/// `AuthError`.
Auth,
/// `CircuitBreakerError`.
CircuitBreaker,
}
impl AIError {
    /// Classifies this error into a coarse [`AIErrorKind`].
    ///
    /// * `LLMError` is classified from its status code: `429` is `RateLimit`,
    ///   `>= 500` is `Server`, any other code `>= 400` is `Client`, and a
    ///   missing or lower code is `Transient`.
    /// * `NetworkError` follows its `retryable` flag (`Transient` or `Permanent`).
    /// * `ModelError` is classified from its `error_type` string.
    /// * Every other variant maps to one fixed kind.
    pub fn kind(&self) -> AIErrorKind {
        match self {
            Self::LLMError { status_code, .. } => match status_code {
                // 429 must be tested before the generic 4xx arm below.
                Some(429) => AIErrorKind::RateLimit,
                Some(code) if *code >= 500 => AIErrorKind::Server,
                Some(code) if *code >= 400 => AIErrorKind::Client,
                _ => AIErrorKind::Transient,
            },
            Self::RateLimitError { .. } => AIErrorKind::RateLimit,
            Self::TimeoutError { .. } => AIErrorKind::Transient,
            Self::AuthError { .. } => AIErrorKind::Auth,
            Self::ValidationError { .. } => AIErrorKind::Client,
            Self::NetworkError { retryable, .. } => {
                if *retryable {
                    AIErrorKind::Transient
                } else {
                    AIErrorKind::Permanent
                }
            }
            Self::CircuitBreakerError { .. } => AIErrorKind::CircuitBreaker,
            Self::ResourceExhausted { .. } => AIErrorKind::Server,
            Self::ModelError { error_type, .. } => match error_type.as_str() {
                "not_found" | "unsupported" => AIErrorKind::Permanent,
                "loading_failed" => AIErrorKind::Transient,
                _ => AIErrorKind::Server,
            },
            Self::ParseError { .. } => AIErrorKind::Client,
            Self::InternalError { .. } => AIErrorKind::Server,
        }
    }

    /// Returns `true` when [`kind`](Self::kind) is `Transient`, `RateLimit` or
    /// `Server`.
    pub fn is_retryable(&self) -> bool {
        matches!(
            self.kind(),
            AIErrorKind::Transient | AIErrorKind::RateLimit | AIErrorKind::Server
        )
    }

    /// Suggests how many seconds to wait before retrying, if this error carries
    /// or implies a delay.
    ///
    /// * `RateLimitError`: time left until `reset_at`, rounded **up** to whole
    ///   seconds (0 if the reset time has already passed); `None` without a
    ///   `reset_at`.
    /// * `LLMError`: its `retry_after` value, unchanged.
    /// * `CircuitBreakerError`: its `retry_after_secs`.
    /// * `TimeoutError`: a fixed 5 seconds.
    /// * `NetworkError`: a fixed 3 seconds when flagged retryable, else `None`.
    /// * Every other variant: `None`.
    pub fn retry_delay_secs(&self) -> Option<u64> {
        match self {
            Self::RateLimitError { reset_at, .. } => reset_at.map(|reset| {
                let remaining_ms = (reset - chrono::Utc::now()).num_milliseconds();
                // Clamp a reset time in the past to zero; after clamping the
                // value is non-negative, so the cast is lossless.
                let remaining_ms = remaining_ms.max(0) as u64;
                // Round up: truncating would report 0 s while part of a second
                // is still left, prompting a retry before the limit resets.
                (remaining_ms + 999) / 1000
            }),
            Self::LLMError { retry_after, .. } => *retry_after,
            Self::CircuitBreakerError {
                retry_after_secs, ..
            } => Some(*retry_after_secs),
            Self::TimeoutError { .. } => Some(5),
            Self::NetworkError {
                retryable: true, ..
            } => Some(3),
            // Listed explicitly (no `_` arm) so that adding a variant forces a
            // decision here.
            Self::NetworkError {
                retryable: false, ..
            }
            | Self::AuthError { .. }
            | Self::ValidationError { .. }
            | Self::ResourceExhausted { .. }
            | Self::ModelError { .. }
            | Self::ParseError { .. }
            | Self::InternalError { .. } => None,
        }
    }

    /// Returns `true` for an `LLMError` with status code 503, for any
    /// `CircuitBreakerError`, and for any `NetworkError` (regardless of its
    /// `retryable` flag); `false` otherwise.
    pub fn is_service_unavailable(&self) -> bool {
        matches!(
            self,
            Self::LLMError {
                status_code: Some(503),
                ..
            } | Self::CircuitBreakerError { .. }
                | Self::NetworkError { .. }
        )
    }

    /// Returns the status code associated with this error, if any.
    ///
    /// `LLMError` reports the code it carries; the other variants map to fixed
    /// codes (429, 401, 400, 504, 503, 507) or to `None`.
    pub fn status_code(&self) -> Option<u16> {
        match self {
            Self::LLMError { status_code, .. } => *status_code,
            Self::RateLimitError { .. } => Some(429),
            Self::AuthError { .. } => Some(401),
            Self::ValidationError { .. } => Some(400),
            Self::TimeoutError { .. } => Some(504),
            Self::CircuitBreakerError { .. } => Some(503),
            Self::ResourceExhausted { .. } => Some(507),
            // Listed explicitly (no `_` arm) so that adding a variant forces a
            // decision here.
            Self::NetworkError { .. }
            | Self::ModelError { .. }
            | Self::ParseError { .. }
            | Self::InternalError { .. } => None,
        }
    }
}
/// Metadata describing where and when an error occurred.
///
/// Built with `ErrorContext::new` and the chained `with_request_id`,
/// `with_user_id` and `add_info` methods.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ErrorContext {
/// Optional request identifier; `None` until `with_request_id` is called.
pub request_id: Option<String>,
/// Optional user identifier; `None` until `with_user_id` is called.
pub user_id: Option<String>,
/// UTC time stamp; `ErrorContext::new` sets it to the current time.
pub timestamp: chrono::DateTime<chrono::Utc>,
/// Component name passed to `ErrorContext::new`.
pub component: String,
/// Operation name passed to `ErrorContext::new`.
pub operation: String,
/// Free-form key/value pairs added through `add_info`; starts empty.
pub additional_info: std::collections::HashMap<String, String>,
}
impl ErrorContext {
    /// Creates a context for the given component and operation, stamped with
    /// the current UTC time.
    ///
    /// `request_id` and `user_id` start as `None`; `additional_info` starts
    /// empty.
    pub fn new(component: &str, operation: &str) -> Self {
        let timestamp = chrono::Utc::now();
        Self {
            component: component.to_owned(),
            operation: operation.to_owned(),
            timestamp,
            request_id: None,
            user_id: None,
            additional_info: std::collections::HashMap::new(),
        }
    }

    /// Returns the context with `request_id` set, replacing any previous value.
    pub fn with_request_id(self, request_id: String) -> Self {
        Self {
            request_id: Some(request_id),
            ..self
        }
    }

    /// Returns the context with `user_id` set, replacing any previous value.
    pub fn with_user_id(self, user_id: String) -> Self {
        Self {
            user_id: Some(user_id),
            ..self
        }
    }

    /// Returns the context with `key` mapped to `value` in `additional_info`;
    /// an existing entry for `key` is overwritten.
    pub fn add_info(self, key: String, value: String) -> Self {
        let mut additional_info = self.additional_info;
        additional_info.insert(key, value);
        Self {
            additional_info,
            ..self
        }
    }
}
/// An `AIError` paired with the `ErrorContext` describing where it occurred.
#[derive(Debug, Clone)]
pub struct ContextualError {
/// The underlying error.
pub error: AIError,
/// Context attached to the error.
pub context: ErrorContext,
}
impl fmt::Display for ContextualError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"{} (component: {}, operation: {}, timestamp: {})",
self.error, self.context.component, self.context.operation, self.context.timestamp
)
}
}
// Marks `ContextualError` as a standard error type. No trait method is
// overridden, so `source()` keeps its default and returns `None`; the wrapped
// error's message is already included in this type's `Display` output.
impl std::error::Error for ContextualError {}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `LLMError` for provider "openai" with no `retry_after`.
    fn llm_error(message: &str, status: u16) -> AIError {
        AIError::LLMError {
            provider: String::from("openai"),
            message: String::from(message),
            status_code: Some(status),
            retry_after: None,
        }
    }

    /// A rate-limit error is classified as `RateLimit` and is retryable.
    #[test]
    fn test_error_kind_classification() {
        let rate_limited = AIError::RateLimitError {
            limit_type: String::from("requests"),
            current: 100,
            limit: 60,
            reset_at: None,
            message: String::from("Too many requests"),
        };

        assert_eq!(rate_limited.kind(), AIErrorKind::RateLimit);
        assert!(rate_limited.is_retryable());
    }

    /// Status 500 is a retryable server error; status 400 is a non-retryable
    /// client error.
    #[test]
    fn test_llm_error_status_codes() {
        let server_error = llm_error("Internal error", 500);
        assert_eq!(server_error.kind(), AIErrorKind::Server);
        assert!(server_error.is_retryable());

        let client_error = llm_error("Bad request", 400);
        assert_eq!(client_error.kind(), AIErrorKind::Client);
        assert!(!client_error.is_retryable());
    }

    /// A circuit-breaker error reports its own `retry_after_secs` as the delay.
    #[test]
    fn test_retry_delay() {
        let breaker_open = AIError::CircuitBreakerError {
            service: String::from("openai"),
            reason: String::from("Too many failures"),
            retry_after_secs: 60,
        };

        assert_eq!(breaker_open.retry_delay_secs(), Some(60));
    }

    /// The builder methods populate every field they are responsible for.
    #[test]
    fn test_error_context() {
        let ctx = ErrorContext::new("chat", "generate_response")
            .with_request_id(String::from("req-123"))
            .with_user_id(String::from("user-456"))
            .add_info(String::from("model"), String::from("gpt-4"));

        assert_eq!(ctx.component, "chat");
        assert_eq!(ctx.operation, "generate_response");
        assert_eq!(ctx.request_id.as_deref(), Some("req-123"));
        assert_eq!(ctx.user_id.as_deref(), Some("user-456"));
        assert_eq!(
            ctx.additional_info.get("model").map(String::as_str),
            Some("gpt-4")
        );
    }
}