1use thiserror::Error;
6
7#[derive(Error, Debug)]
12pub enum KeyComputeError {
13 #[error("authentication failed: {0}")]
16 AuthError(String),
17
18 #[error("permission denied: {0}")]
20 PermissionDenied(String),
21
22 #[error("rate limit exceeded: {0}")]
25 RateLimitExceeded(String),
26
27 #[error("routing failed: no available provider for model {0}")]
30 RoutingFailed(String),
31
32 #[error("upstream provider error: {0}")]
35 ProviderError(String),
36
37 #[error("provider timeout after {0}ms: {1}")]
39 ProviderTimeout(u64, String),
40
41 #[error("database error: {0}")]
44 DatabaseError(String),
45
46 #[error("configuration error: {0}")]
49 ConfigError(String),
50
51 #[error("internal error: {0}")]
54 Internal(String),
55
56 #[error("serialization error: {0}")]
58 SerializationError(String),
59
60 #[error("validation error: {0}")]
62 ValidationError(String),
63
64 #[error("not found: {0}")]
66 NotFound(String),
67
68 #[error("invalid request: {0}")]
70 InvalidRequest(String),
71
72 #[error("network error: {0}")]
75 NetworkError(String),
76
77 #[error("request timeout: {0}")]
79 Timeout(String),
80}
81
82pub type Result<T> = std::result::Result<T, KeyComputeError>;
84
85impl From<serde_json::Error> for KeyComputeError {
88 fn from(err: serde_json::Error) -> Self {
89 KeyComputeError::SerializationError(err.to_string())
90 }
91}
92
93impl From<std::io::Error> for KeyComputeError {
94 fn from(err: std::io::Error) -> Self {
95 KeyComputeError::Internal(err.to_string())
96 }
97}
98
99impl From<uuid::Error> for KeyComputeError {
100 fn from(err: uuid::Error) -> Self {
101 KeyComputeError::InvalidRequest(format!("Invalid UUID: {}", err))
102 }
103}
104
105impl From<chrono::ParseError> for KeyComputeError {
106 fn from(err: chrono::ParseError) -> Self {
107 KeyComputeError::InvalidRequest(format!("Invalid datetime format: {}", err))
108 }
109}
110
111impl KeyComputeError {
114 pub fn is_retryable(&self) -> bool {
118 matches!(
119 self,
120 KeyComputeError::ProviderError(_)
121 | KeyComputeError::ProviderTimeout(_, _)
122 | KeyComputeError::NetworkError(_)
123 | KeyComputeError::Timeout(_)
124 | KeyComputeError::DatabaseError(_)
125 )
126 }
127
128 pub fn category(&self) -> ErrorCategory {
130 match self {
131 KeyComputeError::AuthError(_) | KeyComputeError::PermissionDenied(_) => {
132 ErrorCategory::Auth
133 }
134 KeyComputeError::RateLimitExceeded(_) => ErrorCategory::RateLimit,
135 KeyComputeError::RoutingFailed(_) => ErrorCategory::Routing,
136 KeyComputeError::ProviderError(_) | KeyComputeError::ProviderTimeout(_, _) => {
137 ErrorCategory::Provider
138 }
139 KeyComputeError::DatabaseError(_) => ErrorCategory::Database,
140 KeyComputeError::ConfigError(_) => ErrorCategory::Config,
141 KeyComputeError::ValidationError(_) | KeyComputeError::InvalidRequest(_) => {
142 ErrorCategory::Validation
143 }
144 KeyComputeError::NotFound(_) => ErrorCategory::NotFound,
145 KeyComputeError::NetworkError(_) | KeyComputeError::Timeout(_) => {
146 ErrorCategory::Network
147 }
148 KeyComputeError::Internal(_) | KeyComputeError::SerializationError(_) => {
149 ErrorCategory::Internal
150 }
151 }
152 }
153}
154
/// Coarse classification of a [`KeyComputeError`].
///
/// The `Display` form is a fixed snake_case label (e.g. `"rate_limit_error"`).
/// The same label is available without allocating via [`ErrorCategory::as_str`].
/// `Hash` is derived so categories can be used as map/set keys, for example
/// for per-category counters.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ErrorCategory {
    /// Authentication or permission failures.
    Auth,
    /// Rate-limit rejections.
    RateLimit,
    /// No provider could be routed to.
    Routing,
    /// Upstream provider errors and provider timeouts.
    Provider,
    /// Database failures.
    Database,
    /// Configuration problems.
    Config,
    /// Validation failures and malformed requests.
    Validation,
    /// Missing items.
    NotFound,
    /// Network errors and request timeouts.
    Network,
    /// Internal and serialization failures.
    Internal,
}

impl ErrorCategory {
    /// Returns the snake_case label for this category, e.g. `"authentication_error"`.
    ///
    /// This is exactly the text produced by the `Display` impl.
    #[must_use]
    pub fn as_str(&self) -> &'static str {
        match self {
            ErrorCategory::Auth => "authentication_error",
            ErrorCategory::RateLimit => "rate_limit_error",
            ErrorCategory::Routing => "routing_error",
            ErrorCategory::Provider => "provider_error",
            ErrorCategory::Database => "database_error",
            ErrorCategory::Config => "config_error",
            ErrorCategory::Validation => "validation_error",
            ErrorCategory::NotFound => "not_found_error",
            ErrorCategory::Network => "network_error",
            ErrorCategory::Internal => "internal_error",
        }
    }
}

impl std::fmt::Display for ErrorCategory {
    /// Writes the label from [`ErrorCategory::as_str`].
    ///
    /// `Formatter::pad` is used instead of `write!` so that width, fill,
    /// alignment and precision flags (e.g. `{:<20}`) are honoured the same way
    /// they are for `str`. Plain `{}` output is unchanged.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(self.as_str())
    }
}
186
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_error_display() {
        let err = KeyComputeError::AuthError("invalid token".to_string());
        assert!(err.to_string().contains("authentication failed"));

        let err = KeyComputeError::RoutingFailed("gpt-4".to_string());
        assert!(err.to_string().contains("gpt-4"));
    }

    #[test]
    fn test_provider_timeout_display_includes_duration_and_detail() {
        let err = KeyComputeError::ProviderTimeout(3000, "openai".to_string());
        assert_eq!(err.to_string(), "provider timeout after 3000ms: openai");
    }

    #[test]
    fn test_is_retryable() {
        assert!(KeyComputeError::ProviderError("timeout".into()).is_retryable());
        assert!(KeyComputeError::ProviderTimeout(30_000, "slow upstream".into()).is_retryable());
        assert!(KeyComputeError::NetworkError("connection reset".into()).is_retryable());
        assert!(KeyComputeError::Timeout("deadline exceeded".into()).is_retryable());
        assert!(KeyComputeError::DatabaseError("deadlock".into()).is_retryable());

        assert!(!KeyComputeError::AuthError("invalid".into()).is_retryable());
        assert!(!KeyComputeError::PermissionDenied("forbidden".into()).is_retryable());
        assert!(!KeyComputeError::ValidationError("bad input".into()).is_retryable());
        assert!(!KeyComputeError::InvalidRequest("malformed".into()).is_retryable());
        assert!(!KeyComputeError::ConfigError("missing key".into()).is_retryable());
        assert!(!KeyComputeError::NotFound("missing".into()).is_retryable());
    }

    #[test]
    fn test_error_category() {
        // One entry per variant, so a mapping regression on any variant is caught.
        let cases = [
            (KeyComputeError::AuthError("x".into()), ErrorCategory::Auth),
            (KeyComputeError::PermissionDenied("x".into()), ErrorCategory::Auth),
            (KeyComputeError::RateLimitExceeded("x".into()), ErrorCategory::RateLimit),
            (KeyComputeError::RoutingFailed("x".into()), ErrorCategory::Routing),
            (KeyComputeError::ProviderError("x".into()), ErrorCategory::Provider),
            (KeyComputeError::ProviderTimeout(1, "x".into()), ErrorCategory::Provider),
            (KeyComputeError::DatabaseError("x".into()), ErrorCategory::Database),
            (KeyComputeError::ConfigError("x".into()), ErrorCategory::Config),
            (KeyComputeError::Internal("x".into()), ErrorCategory::Internal),
            (KeyComputeError::SerializationError("x".into()), ErrorCategory::Internal),
            (KeyComputeError::ValidationError("x".into()), ErrorCategory::Validation),
            (KeyComputeError::NotFound("x".into()), ErrorCategory::NotFound),
            (KeyComputeError::InvalidRequest("x".into()), ErrorCategory::Validation),
            (KeyComputeError::NetworkError("x".into()), ErrorCategory::Network),
            (KeyComputeError::Timeout("x".into()), ErrorCategory::Network),
        ];
        for (err, expected) in &cases {
            assert_eq!(err.category(), *expected, "{:?}", err);
        }
    }

    #[test]
    fn test_category_display() {
        let labels = [
            (ErrorCategory::Auth, "authentication_error"),
            (ErrorCategory::RateLimit, "rate_limit_error"),
            (ErrorCategory::Routing, "routing_error"),
            (ErrorCategory::Provider, "provider_error"),
            (ErrorCategory::Database, "database_error"),
            (ErrorCategory::Config, "config_error"),
            (ErrorCategory::Validation, "validation_error"),
            (ErrorCategory::NotFound, "not_found_error"),
            (ErrorCategory::Network, "network_error"),
            (ErrorCategory::Internal, "internal_error"),
        ];
        for (category, label) in &labels {
            assert_eq!(category.to_string(), *label);
        }
    }

    #[test]
    fn test_from_serde_json_error() {
        let json_err = serde_json::from_str::<serde_json::Value>("invalid").unwrap_err();
        let err: KeyComputeError = json_err.into();
        assert!(matches!(err, KeyComputeError::SerializationError(_)));
    }

    #[test]
    fn test_from_io_error() {
        let io_err = std::io::Error::new(std::io::ErrorKind::Other, "disk full");
        let err: KeyComputeError = io_err.into();
        assert!(matches!(&err, KeyComputeError::Internal(msg) if msg.contains("disk full")));
    }

    #[test]
    fn test_from_uuid_error() {
        let uuid_err = uuid::Uuid::parse_str("not-a-uuid").unwrap_err();
        let err: KeyComputeError = uuid_err.into();
        assert!(
            matches!(&err, KeyComputeError::InvalidRequest(msg) if msg.starts_with("Invalid UUID: "))
        );
    }

    #[test]
    fn test_from_chrono_error() {
        let chrono_err = chrono::DateTime::parse_from_rfc3339("not-a-date").unwrap_err();
        let err: KeyComputeError = chrono_err.into();
        assert!(matches!(
            &err,
            KeyComputeError::InvalidRequest(msg) if msg.starts_with("Invalid datetime format: ")
        ));
    }
}