1#[derive(Debug)]
3pub enum PgError {
4 Io(std::io::Error),
6 Protocol(String),
8 Auth(String),
10 Server(Box<ServerError>),
12 ConnectionClosed,
14 NoRows,
16 TypeConversion(String),
18 StatementNotCached,
20 BufferOverflow,
22 WouldBlock,
24 Timeout,
26 PoolTimeout,
28 PoolExhausted,
30 PoolValidationFailed,
32}
33
/// The fields of a server error report, one struct member per field tag.
///
/// `PgError::from_fields` fills this in from `(tag, value)` pairs; the tag that
/// feeds each member is noted below. Members whose tag was not present keep
/// their default (`""` for the three mandatory strings, `None` otherwise).
///
/// The type is `Clone`, comparable and `Default`-constructible so callers and
/// tests can write `ServerError { code: "...".into(), ..Default::default() }`
/// instead of spelling out all seventeen members.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ServerError {
    /// Severity text (tag `S`), e.g. `ERROR`.
    pub severity: String,
    /// SQLSTATE code (tag `C`); this is what `PgError::sql_state` returns.
    pub code: String,
    /// Primary error message (tag `M`).
    pub message: String,
    /// Optional detail text (tag `D`).
    pub detail: Option<String>,
    /// Optional hint text (tag `H`).
    pub hint: Option<String>,
    /// Error position (tag `P`); `None` when absent or not parseable as `i32`.
    pub position: Option<i32>,
    /// Internal position (tag `p`); `None` when absent or not parseable as `i32`.
    pub internal_position: Option<i32>,
    /// Internal query text (tag `q`).
    pub internal_query: Option<String>,
    /// Context in which the error occurred (tag `W`).
    pub where_: Option<String>,
    /// Schema name (tag `s`).
    pub schema_name: Option<String>,
    /// Table name (tag `t`).
    pub table_name: Option<String>,
    /// Column name (tag `c`).
    pub column_name: Option<String>,
    /// Data type name (tag `d`).
    pub data_type_name: Option<String>,
    /// Constraint name (tag `n`).
    pub constraint_name: Option<String>,
    /// Source file reported by the server (tag `F`).
    pub file: Option<String>,
    /// Source line reported by the server (tag `L`), kept as text.
    pub line: Option<String>,
    /// Source routine reported by the server (tag `R`).
    pub routine: Option<String>,
}
55
/// Coarse category returned by `PgError::classify`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ErrorClass {
    /// May succeed if attempted again; the only class `retry` retries.
    Transient,

    /// Will not succeed on retry (protocol/auth failures and every SQLSTATE
    /// that is not recognised as transient).
    Permanent,

    /// Raised on the client side: would-block, type conversion, buffer
    /// overflow, statement-cache miss, or no rows.
    Client,

    /// Raised by the connection pool: checkout timeout, exhaustion, or a
    /// failed validation.
    Pool,
}
68
69impl PgError {
70 pub fn classify(&self) -> ErrorClass {
72 match self {
73 PgError::Io(_) | PgError::ConnectionClosed | PgError::Timeout => ErrorClass::Transient,
74 PgError::WouldBlock => ErrorClass::Client,
77 PgError::Server(err) => classify_sql_state(&err.code),
78 PgError::PoolTimeout | PgError::PoolExhausted | PgError::PoolValidationFailed => {
79 ErrorClass::Pool
80 }
81 PgError::TypeConversion(_)
82 | PgError::BufferOverflow
83 | PgError::StatementNotCached
84 | PgError::NoRows => ErrorClass::Client,
85 PgError::Protocol(_) | PgError::Auth(_) => ErrorClass::Permanent,
86 }
87 }
88
89 pub fn is_transient(&self) -> bool {
91 self.classify() == ErrorClass::Transient
92 }
93
94 pub fn sql_state(&self) -> Option<&str> {
96 match self {
97 PgError::Server(err) => Some(&err.code),
98 _ => None,
99 }
100 }
101
102 pub fn hint(&self) -> Option<&str> {
104 match self {
105 PgError::Server(err) => err.hint.as_deref(),
106 _ => None,
107 }
108 }
109
110 pub fn detail(&self) -> Option<&str> {
112 match self {
113 PgError::Server(err) => err.detail.as_deref(),
114 _ => None,
115 }
116 }
117
118 pub fn from_fields(fields: &[(u8, String)]) -> Self {
120 let mut severity = String::new();
121 let mut code = String::new();
122 let mut message = String::new();
123 let mut detail = None;
124 let mut hint = None;
125 let mut position = None;
126 let mut internal_position = None;
127 let mut internal_query = None;
128 let mut where_ = None;
129 let mut schema_name = None;
130 let mut table_name = None;
131 let mut column_name = None;
132 let mut data_type_name = None;
133 let mut constraint_name = None;
134 let mut file = None;
135 let mut line = None;
136 let mut routine = None;
137
138 for (field_type, value) in fields {
139 match field_type {
140 b'S' => severity = value.clone(),
141 b'C' => code = value.clone(),
142 b'M' => message = value.clone(),
143 b'D' => detail = Some(value.clone()),
144 b'H' => hint = Some(value.clone()),
145 b'P' => position = value.parse().ok(),
146 b'p' => internal_position = value.parse().ok(),
147 b'q' => internal_query = Some(value.clone()),
148 b'W' => where_ = Some(value.clone()),
149 b's' => schema_name = Some(value.clone()),
150 b't' => table_name = Some(value.clone()),
151 b'c' => column_name = Some(value.clone()),
152 b'd' => data_type_name = Some(value.clone()),
153 b'n' => constraint_name = Some(value.clone()),
154 b'F' => file = Some(value.clone()),
155 b'L' => line = Some(value.clone()),
156 b'R' => routine = Some(value.clone()),
157 _ => {}
158 }
159 }
160
161 PgError::Server(Box::new(ServerError {
162 severity,
163 code,
164 message,
165 detail,
166 hint,
167 position,
168 internal_position,
169 internal_query,
170 where_,
171 schema_name,
172 table_name,
173 column_name,
174 data_type_name,
175 constraint_name,
176 file,
177 line,
178 routine,
179 }))
180 }
181}
182
183fn classify_sql_state(code: &str) -> ErrorClass {
185 match code {
186 c if c.starts_with("40") => ErrorClass::Transient,
188 c if c.starts_with("08") => ErrorClass::Transient,
190 c if c.starts_with("53") => ErrorClass::Transient,
192 c if c.starts_with("57") => ErrorClass::Transient,
194 c if c.starts_with("42") => ErrorClass::Permanent,
196 c if c.starts_with("23") => ErrorClass::Permanent,
198 c if c.starts_with("28") => ErrorClass::Permanent,
200 _ => ErrorClass::Permanent,
202 }
203}
204
205impl From<std::io::Error> for PgError {
206 fn from(e: std::io::Error) -> Self {
207 if e.kind() == std::io::ErrorKind::WouldBlock {
208 PgError::WouldBlock
209 } else {
210 PgError::Io(e)
211 }
212 }
213}
214
215impl std::fmt::Display for PgError {
216 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
217 match self {
218 PgError::Io(e) => write!(f, "I/O error: {}", e),
219 PgError::Protocol(msg) => write!(f, "Protocol error: {}", msg),
220 PgError::Auth(msg) => write!(f, "Auth error: {}", msg),
221 PgError::Server(err) => {
222 write!(f, "PG {}: {} ({})", err.severity, err.message, err.code)?;
223 if let Some(d) = &err.detail {
224 write!(f, "\n Detail: {}", d)?;
225 }
226 if let Some(h) = &err.hint {
227 write!(f, "\n Hint: {}", h)?;
228 }
229 Ok(())
230 }
231 PgError::ConnectionClosed => write!(f, "Connection closed"),
232 PgError::NoRows => write!(f, "No rows returned"),
233 PgError::TypeConversion(msg) => write!(f, "Type conversion: {}", msg),
234 PgError::StatementNotCached => write!(f, "Statement not in cache"),
235 PgError::BufferOverflow => write!(f, "Buffer overflow"),
236 PgError::WouldBlock => write!(f, "Would block"),
237 PgError::Timeout => write!(f, "I/O operation timed out"),
238 PgError::PoolTimeout => write!(f, "Pool: connection checkout timed out"),
239 PgError::PoolExhausted => write!(f, "Pool: all connections are in use"),
240 PgError::PoolValidationFailed => write!(f, "Pool: connection failed validation"),
241 }
242 }
243}
244
245impl std::error::Error for PgError {}
246
247pub type PgResult<T> = Result<T, PgError>;
248
249pub fn retry<F, T>(max_retries: u32, mut f: F) -> PgResult<T>
251where
252 F: FnMut() -> PgResult<T>,
253{
254 let mut attempts = 0;
255 loop {
256 match f() {
257 Ok(val) => return Ok(val),
258 Err(e) if e.is_transient() && attempts < max_retries => {
259 attempts += 1;
260 let delay = std::time::Duration::from_millis(1 << attempts.min(10));
262 std::thread::sleep(delay);
263 }
264 Err(e) => return Err(e),
265 }
266 }
267}
268
#[cfg(test)]
mod tests {
    use super::*;

    // ---------------------------------------------------------------------
    // Shared helpers
    // ---------------------------------------------------------------------

    /// A server error record with severity `ERROR`, message `test`, the given
    /// SQLSTATE code, and every optional member absent.
    fn bare_server_error(code: &str) -> ServerError {
        ServerError {
            severity: "ERROR".to_string(),
            code: code.to_string(),
            message: "test".to_string(),
            detail: None,
            hint: None,
            position: None,
            internal_position: None,
            internal_query: None,
            where_: None,
            schema_name: None,
            table_name: None,
            column_name: None,
            data_type_name: None,
            constraint_name: None,
            file: None,
            line: None,
            routine: None,
        }
    }

    /// Wraps [`bare_server_error`] in the `Server` variant.
    fn server_err(code: &str) -> PgError {
        PgError::Server(Box::new(bare_server_error(code)))
    }

    /// An `Io` variant built directly (bypassing the `From` conversion).
    fn io_failure(kind: std::io::ErrorKind, text: &'static str) -> PgError {
        PgError::Io(std::io::Error::new(kind, text))
    }

    /// The transient error used by the retry tests.
    fn connection_reset() -> PgError {
        io_failure(std::io::ErrorKind::ConnectionReset, "reset")
    }

    /// One `(tag, value)` pair for `from_fields`.
    fn field(tag: u8, text: &str) -> (u8, String) {
        (tag, text.to_string())
    }

    // ---------------------------------------------------------------------
    // Classification of non-server variants
    // ---------------------------------------------------------------------

    #[test]
    fn test_io_error_is_transient() {
        let broken = io_failure(std::io::ErrorKind::BrokenPipe, "broken pipe");
        assert_eq!(broken.classify(), ErrorClass::Transient);
        assert!(broken.is_transient());
    }

    #[test]
    fn test_connection_closed_is_transient() {
        let closed = PgError::ConnectionClosed;
        assert_eq!(closed.classify(), ErrorClass::Transient);
        assert!(closed.is_transient());
    }

    #[test]
    fn test_timeout_is_transient() {
        let timed_out = PgError::Timeout;
        assert_eq!(timed_out.classify(), ErrorClass::Transient);
        assert!(timed_out.is_transient());
    }

    #[test]
    fn test_wouldblock_is_client_not_transient() {
        let would_block = PgError::WouldBlock;
        assert_eq!(would_block.classify(), ErrorClass::Client);
        assert!(!would_block.is_transient());
    }

    #[test]
    fn test_type_conversion_is_client() {
        let conversion = PgError::TypeConversion("bad".to_string());
        assert_eq!(conversion.classify(), ErrorClass::Client);
        assert!(!conversion.is_transient());
    }

    #[test]
    fn test_buffer_overflow_is_client() {
        let overflow = PgError::BufferOverflow;
        assert_eq!(overflow.classify(), ErrorClass::Client);
    }

    #[test]
    fn test_no_rows_is_client() {
        let no_rows = PgError::NoRows;
        assert_eq!(no_rows.classify(), ErrorClass::Client);
    }

    #[test]
    fn test_statement_not_cached_is_client() {
        let not_cached = PgError::StatementNotCached;
        assert_eq!(not_cached.classify(), ErrorClass::Client);
    }

    #[test]
    fn test_pool_timeout_is_pool_class() {
        let pool_timeout = PgError::PoolTimeout;
        assert_eq!(pool_timeout.classify(), ErrorClass::Pool);
        assert!(!pool_timeout.is_transient());
    }

    #[test]
    fn test_pool_exhausted_is_pool_class() {
        let exhausted = PgError::PoolExhausted;
        assert_eq!(exhausted.classify(), ErrorClass::Pool);
        assert!(!exhausted.is_transient());
    }

    #[test]
    fn test_pool_validation_failed_is_pool_class() {
        let invalid = PgError::PoolValidationFailed;
        assert_eq!(invalid.classify(), ErrorClass::Pool);
        assert!(!invalid.is_transient());
    }

    #[test]
    fn test_protocol_error_is_permanent() {
        let protocol = PgError::Protocol("bad".to_string());
        assert_eq!(protocol.classify(), ErrorClass::Permanent);
        assert!(!protocol.is_transient());
    }

    #[test]
    fn test_auth_error_is_permanent() {
        let auth = PgError::Auth("denied".to_string());
        assert_eq!(auth.classify(), ErrorClass::Permanent);
        assert!(!auth.is_transient());
    }

    // ---------------------------------------------------------------------
    // Classification by SQLSTATE
    // ---------------------------------------------------------------------

    #[test]
    fn test_sqlstate_40001_serialization_failure_transient() {
        let serialization_failure = server_err("40001");
        assert!(serialization_failure.is_transient());
    }

    #[test]
    fn test_sqlstate_40p01_deadlock_transient() {
        let deadlock = server_err("40P01");
        assert!(deadlock.is_transient());
    }

    #[test]
    fn test_sqlstate_08006_connection_failure_transient() {
        let connection_failure = server_err("08006");
        assert!(connection_failure.is_transient());
    }

    #[test]
    fn test_sqlstate_53300_too_many_connections_transient() {
        let too_many_connections = server_err("53300");
        assert!(too_many_connections.is_transient());
    }

    #[test]
    fn test_sqlstate_57014_query_canceled_transient() {
        let query_canceled = server_err("57014");
        assert!(query_canceled.is_transient());
    }

    #[test]
    fn test_sqlstate_42601_syntax_error_permanent() {
        let syntax_error = server_err("42601");
        assert_eq!(syntax_error.classify(), ErrorClass::Permanent);
        assert!(!syntax_error.is_transient());
    }

    #[test]
    fn test_sqlstate_23505_unique_violation_permanent() {
        let unique_violation = server_err("23505");
        assert_eq!(unique_violation.classify(), ErrorClass::Permanent);
    }

    #[test]
    fn test_sqlstate_28000_invalid_authorization_permanent() {
        let invalid_authorization = server_err("28000");
        assert_eq!(invalid_authorization.classify(), ErrorClass::Permanent);
    }

    #[test]
    fn test_sqlstate_unknown_default_permanent() {
        let unknown = server_err("99999");
        assert_eq!(unknown.classify(), ErrorClass::Permanent);
    }

    // ---------------------------------------------------------------------
    // sql_state accessor
    // ---------------------------------------------------------------------

    #[test]
    fn test_sql_state_returns_code() {
        let syntax_error = server_err("42601");
        assert_eq!(syntax_error.sql_state(), Some("42601"));
    }

    #[test]
    fn test_sql_state_non_server_is_none() {
        let non_server = [
            PgError::WouldBlock,
            PgError::Timeout,
            PgError::PoolExhausted,
            PgError::ConnectionClosed,
        ];
        for error in non_server.iter() {
            assert_eq!(error.sql_state(), None);
        }
    }

    // ---------------------------------------------------------------------
    // from_fields
    // ---------------------------------------------------------------------

    #[test]
    fn test_from_fields_complete() {
        let fields = [
            field(b'S', "ERROR"),
            field(b'C', "42601"),
            field(b'M', "syntax error at position 5"),
            field(b'D', "near SELECT"),
            field(b'H', "check your query"),
            field(b'P', "5"),
            field(b's', "public"),
            field(b't', "users"),
            field(b'n', "users_pkey"),
        ];

        match PgError::from_fields(&fields) {
            PgError::Server(err) => {
                assert_eq!(err.severity, "ERROR");
                assert_eq!(err.code, "42601");
                assert_eq!(err.message, "syntax error at position 5");
                assert_eq!(err.detail.as_deref(), Some("near SELECT"));
                assert_eq!(err.hint.as_deref(), Some("check your query"));
                assert_eq!(err.position, Some(5));
                assert_eq!(err.schema_name.as_deref(), Some("public"));
                assert_eq!(err.table_name.as_deref(), Some("users"));
                assert_eq!(err.constraint_name.as_deref(), Some("users_pkey"));
            }
            _ => panic!("Expected Server variant"),
        }
    }

    #[test]
    fn test_from_fields_minimal() {
        let fields = [
            field(b'S', "ERROR"),
            field(b'C', "99999"),
            field(b'M', "unknown error"),
        ];

        match PgError::from_fields(&fields) {
            PgError::Server(err) => {
                assert!(err.detail.is_none());
                assert!(err.hint.is_none());
                assert!(err.position.is_none());
            }
            _ => panic!("Expected Server variant"),
        }
    }

    #[test]
    fn test_from_fields_unknown_field_ignored() {
        let fields = [
            field(b'S', "ERROR"),
            field(b'C', "00000"),
            field(b'M', "ok"),
            field(b'Z', "ignored"),
        ];
        let built = PgError::from_fields(&fields);
        assert!(matches!(built, PgError::Server(_)));
    }

    // ---------------------------------------------------------------------
    // Display
    // ---------------------------------------------------------------------

    #[test]
    fn test_display_server_includes_message_code_detail() {
        let mut inner = bare_server_error("42601");
        inner.message = "syntax error here".to_string();
        inner.detail = Some("bad token".to_string());
        let e = PgError::Server(Box::new(inner));

        let s = e.to_string();
        assert!(s.contains("syntax error here"), "missing message: {}", s);
        assert!(s.contains("42601"), "missing code: {}", s);
        assert!(s.contains("bad token"), "missing detail: {}", s);
    }

    #[test]
    fn test_display_server_no_detail_or_hint() {
        let rendered = server_err("42601").to_string();
        assert!(!rendered.is_empty());
    }

    #[test]
    fn test_display_all_non_server_variants() {
        // Formatting must not panic for any of these variants.
        let samples = [
            PgError::ConnectionClosed,
            PgError::NoRows,
            PgError::BufferOverflow,
            PgError::WouldBlock,
            PgError::Timeout,
            PgError::PoolTimeout,
            PgError::PoolExhausted,
            PgError::PoolValidationFailed,
            PgError::StatementNotCached,
            PgError::TypeConversion("type error".to_string()),
            PgError::Protocol("protocol error".to_string()),
            PgError::Auth("auth error".to_string()),
        ];
        for sample in samples.iter() {
            let _ = format!("{}", sample);
        }
    }

    // ---------------------------------------------------------------------
    // From<std::io::Error>
    // ---------------------------------------------------------------------

    #[test]
    fn test_from_io_wouldblock_becomes_wouldblock() {
        let converted: PgError =
            std::io::Error::new(std::io::ErrorKind::WouldBlock, "would block").into();
        assert!(matches!(converted, PgError::WouldBlock));
    }

    #[test]
    fn test_from_io_other_becomes_io_variant() {
        let converted: PgError =
            std::io::Error::new(std::io::ErrorKind::ConnectionReset, "reset").into();
        assert!(matches!(converted, PgError::Io(_)));
    }

    #[test]
    fn test_from_io_broken_pipe_is_not_wouldblock() {
        let converted: PgError =
            std::io::Error::new(std::io::ErrorKind::BrokenPipe, "pipe").into();
        assert!(!matches!(converted, PgError::WouldBlock));
    }

    // ---------------------------------------------------------------------
    // retry
    // ---------------------------------------------------------------------

    #[test]
    fn test_retry_succeeds_immediately() {
        let outcome: PgResult<i32> = retry(3, || Ok(42));
        assert_eq!(outcome.unwrap(), 42);
    }

    #[test]
    fn test_retry_no_retries_on_success() {
        let mut invocations = 0;
        let outcome: PgResult<i32> = retry(3, || {
            invocations += 1;
            Ok(1)
        });
        assert_eq!(outcome.unwrap(), 1);
        assert_eq!(invocations, 1);
    }

    #[test]
    fn test_retry_permanent_error_not_retried() {
        let mut invocations = 0;
        let outcome: PgResult<i32> = retry(5, || {
            invocations += 1;
            Err(PgError::Protocol("bad".to_string()))
        });
        assert!(outcome.is_err());
        assert_eq!(invocations, 1, "Permanent errors must not be retried");
    }

    #[test]
    fn test_retry_client_error_not_retried() {
        let mut invocations = 0;
        let outcome: PgResult<i32> = retry(5, || {
            invocations += 1;
            Err(PgError::WouldBlock)
        });
        assert!(outcome.is_err());
        assert_eq!(invocations, 1, "WouldBlock must not be retried");
    }

    #[test]
    fn test_retry_zero_max_retries_no_sleep_no_retry() {
        let mut invocations = 0;
        let outcome: PgResult<i32> = retry(0, || {
            invocations += 1;
            Err(connection_reset())
        });
        assert!(outcome.is_err());
        assert_eq!(invocations, 1);
    }

    #[test]
    fn test_retry_transient_error_retried_up_to_limit() {
        let mut invocations = 0;
        let outcome: PgResult<i32> = retry(2, || {
            invocations += 1;
            Err(connection_reset())
        });
        assert!(outcome.is_err());
        // One initial call plus two retries.
        assert_eq!(invocations, 3);
    }

    #[test]
    fn test_retry_succeeds_on_second_attempt() {
        let mut invocations = 0;
        let outcome: PgResult<i32> = retry(3, || {
            invocations += 1;
            if invocations >= 2 {
                Ok(99)
            } else {
                Err(connection_reset())
            }
        });
        assert_eq!(outcome.unwrap(), 99);
        assert_eq!(invocations, 2);
    }

    #[test]
    fn test_pool_errors_not_retried() {
        let cases: [(fn() -> PgError, &str); 3] = [
            (|| PgError::PoolTimeout, "PoolTimeout must not be retried"),
            (|| PgError::PoolExhausted, "PoolExhausted must not be retried"),
            (
                || PgError::PoolValidationFailed,
                "PoolValidationFailed must not be retried",
            ),
        ];

        for (make_error, complaint) in cases.iter() {
            let mut invocations = 0;
            let _ = retry(5, || {
                invocations += 1;
                Err::<(), PgError>(make_error())
            });
            assert_eq!(invocations, 1, "{}", complaint);
        }
    }
}