1use thiserror::Error;
4
5pub type Result<T> = std::result::Result<T, Error>;
7
8#[derive(Error, Debug)]
10pub enum Error {
11 #[error("Connection error: {0}")]
13 Connection(String),
14
15 #[error("Query error: {code} - {message}")]
17 Query { code: String, message: String },
18
19 #[error("Authentication error: {0}")]
21 Auth(String),
22
23 #[error("I/O error: {0}")]
25 Io(#[from] std::io::Error),
26
27 #[error("JSON error: {0}")]
29 Json(#[from] serde_json::Error),
30
31 #[error("QUIC error: {0}")]
33 Quic(String),
34
35 #[error("TLS error: {0}")]
37 Tls(String),
38
39 #[error("Invalid DSN: {0}")]
41 InvalidDsn(String),
42
43 #[error("Type error: {0}")]
45 Type(String),
46
47 #[error("Operation timed out")]
49 Timeout,
50
51 #[error("Pool error: {0}")]
53 Pool(String),
54
55 #[error("Validation error: {0}")]
57 Validation(String),
58
59 #[error("{0}")]
61 Other(String),
62}
63
64impl Error {
65 pub fn connection<S: Into<String>>(msg: S) -> Self {
67 Error::Connection(msg.into())
68 }
69
70 pub fn query<S: Into<String>>(msg: S) -> Self {
72 Error::Query {
73 code: "QUERY_ERROR".to_string(),
74 message: msg.into(),
75 }
76 }
77
78 pub fn protocol<S: Into<String>>(msg: S) -> Self {
80 Error::Connection(format!("Protocol error: {}", msg.into()))
81 }
82
83 pub fn transaction<S: Into<String>>(msg: S) -> Self {
85 Error::Connection(format!("Transaction error: {}", msg.into()))
86 }
87
88 pub fn timeout() -> Self {
90 Error::Timeout
91 }
92
93 pub fn auth<S: Into<String>>(msg: S) -> Self {
95 Error::Auth(msg.into())
96 }
97
98 pub fn quic<S: Into<String>>(msg: S) -> Self {
100 Error::Quic(msg.into())
101 }
102
103 pub fn tls<S: Into<String>>(msg: S) -> Self {
105 Error::Tls(msg.into())
106 }
107
108 pub fn type_error<S: Into<String>>(msg: S) -> Self {
110 Error::Type(msg.into())
111 }
112
113 pub fn pool<S: Into<String>>(msg: S) -> Self {
115 Error::Pool(msg.into())
116 }
117
118 pub fn validation<S: Into<String>>(msg: S) -> Self {
120 Error::Validation(msg.into())
121 }
122
123 pub fn invalid_dsn<S: Into<String>>(msg: S) -> Self {
125 Error::InvalidDsn(msg.into())
126 }
127
128 pub fn is_retryable(&self) -> bool {
136 match self {
137 Error::Connection(_) => true,
138 Error::Timeout => true,
139 Error::Quic(_) => true,
140 Error::Query { code, .. } => {
141 code == "40001" || code == "40P01" || code == "40502"
143 }
144 Error::Pool(_) => true,
145 _ => false,
146 }
147 }
148
149 pub fn code(&self) -> Option<&str> {
151 match self {
152 Error::Query { code, .. } => Some(code),
153 _ => None,
154 }
155 }
156}
157
#[cfg(test)]
mod tests {
    //! Unit tests for `Error`: `Display` text, `From` conversions, helper
    //! constructors, retry classification and `code()`.

    use super::*;
    use std::io;

    // ---- Display ---------------------------------------------------------

    #[test]
    fn test_error_display_connection() {
        let err = Error::Connection("connection refused".to_string());
        assert_eq!(err.to_string(), "Connection error: connection refused");
    }

    #[test]
    fn test_error_display_query() {
        let err = Error::Query {
            code: "42000".to_string(),
            message: "syntax error".to_string(),
        };
        assert_eq!(err.to_string(), "Query error: 42000 - syntax error");
    }

    #[test]
    fn test_error_display_auth() {
        let err = Error::Auth("invalid credentials".to_string());
        assert_eq!(err.to_string(), "Authentication error: invalid credentials");
    }

    #[test]
    fn test_error_display_io() {
        let io_err = io::Error::new(io::ErrorKind::NotFound, "file not found");
        let err = Error::Io(io_err);
        assert!(err.to_string().starts_with("I/O error:"));
    }

    #[test]
    fn test_error_display_json() {
        let json_err: serde_json::Error = serde_json::from_str::<i32>("invalid").unwrap_err();
        let err = Error::Json(json_err);
        assert!(err.to_string().starts_with("JSON error:"));
    }

    #[test]
    fn test_error_display_quic() {
        let err = Error::Quic("connection reset".to_string());
        assert_eq!(err.to_string(), "QUIC error: connection reset");
    }

    #[test]
    fn test_error_display_tls() {
        let err = Error::Tls("certificate expired".to_string());
        assert_eq!(err.to_string(), "TLS error: certificate expired");
    }

    #[test]
    fn test_error_display_invalid_dsn() {
        let err = Error::InvalidDsn("missing host".to_string());
        assert_eq!(err.to_string(), "Invalid DSN: missing host");
    }

    #[test]
    fn test_error_display_type() {
        let err = Error::Type("cannot convert int to string".to_string());
        assert_eq!(err.to_string(), "Type error: cannot convert int to string");
    }

    #[test]
    fn test_error_display_timeout() {
        let err = Error::Timeout;
        assert_eq!(err.to_string(), "Operation timed out");
    }

    #[test]
    fn test_error_display_pool() {
        let err = Error::Pool("pool exhausted".to_string());
        assert_eq!(err.to_string(), "Pool error: pool exhausted");
    }

    #[test]
    fn test_error_display_validation() {
        let err = Error::Validation("name is empty".to_string());
        assert_eq!(err.to_string(), "Validation error: name is empty");
    }

    #[test]
    fn test_error_display_other() {
        let err = Error::Other("unknown error".to_string());
        assert_eq!(err.to_string(), "unknown error");
    }

    // ---- From conversions ------------------------------------------------

    #[test]
    fn test_error_from_io() {
        let io_err = io::Error::new(io::ErrorKind::ConnectionRefused, "refused");
        let err: Error = io_err.into();
        assert!(matches!(err, Error::Io(_)));
    }

    #[test]
    fn test_error_from_json() {
        let json_err: serde_json::Error = serde_json::from_str::<i32>("not_a_number").unwrap_err();
        let err: Error = json_err.into();
        assert!(matches!(err, Error::Json(_)));
    }

    // ---- Helper constructors ---------------------------------------------

    #[test]
    fn test_error_helper_connection() {
        let err = Error::connection("test connection error");
        assert!(matches!(err, Error::Connection(msg) if msg == "test connection error"));
    }

    #[test]
    fn test_error_helper_query() {
        let err = Error::query("test query error");
        assert!(matches!(err, Error::Query { code, message }
            if code == "QUERY_ERROR" && message == "test query error"));
    }

    #[test]
    fn test_error_helper_protocol() {
        let err = Error::protocol("invalid frame");
        assert!(matches!(err, Error::Connection(msg) if msg.contains("Protocol error")));
    }

    #[test]
    fn test_error_helper_transaction() {
        let err = Error::transaction("rollback failed");
        assert!(matches!(err, Error::Connection(msg) if msg.contains("Transaction error")));
    }

    #[test]
    fn test_error_helper_timeout() {
        let err = Error::timeout();
        assert!(matches!(err, Error::Timeout));
    }

    #[test]
    fn test_error_helper_auth() {
        let err = Error::auth("bad token");
        assert!(matches!(err, Error::Auth(msg) if msg == "bad token"));
    }

    #[test]
    fn test_error_helper_quic() {
        let err = Error::quic("stream closed");
        assert!(matches!(err, Error::Quic(msg) if msg == "stream closed"));
    }

    #[test]
    fn test_error_helper_tls() {
        let err = Error::tls("handshake failed");
        assert!(matches!(err, Error::Tls(msg) if msg == "handshake failed"));
    }

    #[test]
    fn test_error_helper_type_error() {
        let err = Error::type_error("invalid cast");
        assert!(matches!(err, Error::Type(msg) if msg == "invalid cast"));
    }

    #[test]
    fn test_error_helper_pool() {
        let err = Error::pool("no connections available");
        assert!(matches!(err, Error::Pool(msg) if msg == "no connections available"));
    }

    #[test]
    fn test_error_helper_validation() {
        let err = Error::validation("field required");
        assert!(matches!(err, Error::Validation(msg) if msg == "field required"));
    }

    #[test]
    fn test_error_helper_invalid_dsn() {
        let err = Error::invalid_dsn("missing port");
        assert!(matches!(err, Error::InvalidDsn(msg) if msg == "missing port"));
    }

    // ---- Retry classification --------------------------------------------

    #[test]
    fn test_error_is_retryable_connection() {
        let err = Error::Connection("network error".to_string());
        assert!(err.is_retryable());
    }

    #[test]
    fn test_error_is_retryable_timeout() {
        let err = Error::Timeout;
        assert!(err.is_retryable());
    }

    #[test]
    fn test_error_is_retryable_quic() {
        let err = Error::Quic("reset".to_string());
        assert!(err.is_retryable());
    }

    #[test]
    fn test_error_is_retryable_pool() {
        let err = Error::Pool("exhausted".to_string());
        assert!(err.is_retryable());
    }

    #[test]
    fn test_error_is_retryable_protocol_and_transaction_helpers() {
        // Both helpers wrap their message in `Error::Connection`, which is retryable.
        assert!(Error::protocol("bad frame").is_retryable());
        assert!(Error::transaction("commit failed").is_retryable());
    }

    #[test]
    fn test_error_is_retryable_serialization_failure() {
        let err = Error::Query {
            code: "40001".to_string(),
            message: "serialization failure".to_string(),
        };
        assert!(err.is_retryable());
    }

    #[test]
    fn test_error_is_retryable_deadlock() {
        let err = Error::Query {
            code: "40P01".to_string(),
            message: "deadlock detected".to_string(),
        };
        assert!(err.is_retryable());
    }

    #[test]
    fn test_error_is_retryable_transaction_deadlock() {
        let err = Error::Query {
            code: "40502".to_string(),
            message: "transaction deadlock".to_string(),
        };
        assert!(err.is_retryable());
    }

    #[test]
    fn test_error_not_retryable_syntax() {
        let err = Error::Query {
            code: "42000".to_string(),
            message: "syntax error".to_string(),
        };
        assert!(!err.is_retryable());
    }

    #[test]
    fn test_error_not_retryable_query_helper() {
        // The generic "QUERY_ERROR" code is not in the retryable set.
        let err = Error::query("boom");
        assert!(!err.is_retryable());
    }

    #[test]
    fn test_error_not_retryable_auth() {
        let err = Error::Auth("invalid".to_string());
        assert!(!err.is_retryable());
    }

    #[test]
    fn test_error_not_retryable_tls() {
        let err = Error::Tls("cert error".to_string());
        assert!(!err.is_retryable());
    }

    #[test]
    fn test_error_not_retryable_dsn() {
        let err = Error::InvalidDsn("bad format".to_string());
        assert!(!err.is_retryable());
    }

    #[test]
    fn test_error_not_retryable_type() {
        let err = Error::Type("cast failed".to_string());
        assert!(!err.is_retryable());
    }

    #[test]
    fn test_error_not_retryable_validation_and_other() {
        assert!(!Error::Validation("bad".to_string()).is_retryable());
        assert!(!Error::Other("misc".to_string()).is_retryable());
    }

    #[test]
    fn test_error_not_retryable_io_and_json() {
        let io_err: Error = io::Error::new(io::ErrorKind::NotFound, "missing").into();
        assert!(!io_err.is_retryable());
        let json_err: Error = serde_json::from_str::<i32>("x").unwrap_err().into();
        assert!(!json_err.is_retryable());
    }

    // ---- code() ------------------------------------------------------------

    #[test]
    fn test_error_code_query() {
        let err = Error::Query {
            code: "42000".to_string(),
            message: "syntax error".to_string(),
        };
        assert_eq!(err.code(), Some("42000"));
    }

    #[test]
    fn test_error_code_query_helper() {
        let err = Error::query("boom");
        assert_eq!(err.code(), Some("QUERY_ERROR"));
    }

    #[test]
    fn test_error_code_non_query() {
        let err = Error::Connection("test".to_string());
        assert_eq!(err.code(), None);
    }

    // ---- Misc ----------------------------------------------------------------

    #[test]
    fn test_result_type_alias() {
        fn returns_result() -> Result<i32> {
            Ok(42)
        }
        assert_eq!(returns_result().unwrap(), 42);
    }

    #[test]
    fn test_result_type_alias_error() {
        fn returns_error() -> Result<i32> {
            Err(Error::Other("test".to_string()))
        }
        assert!(returns_error().is_err());
    }

    #[test]
    fn test_error_string_conversion() {
        // `&str` input.
        let err = Error::connection("test");
        assert!(matches!(err, Error::Connection(_)));

        // Owned `String` input.
        let err = Error::connection(String::from("test"));
        assert!(matches!(err, Error::Connection(_)));
    }

    #[test]
    fn test_error_debug() {
        let err = Error::Connection("test".to_string());
        let debug_str = format!("{:?}", err);
        assert!(debug_str.contains("Connection"));
        assert!(debug_str.contains("test"));
    }
}