Skip to main content

chopin_pg/
error.rs

1/// Errors returned by chopin-pg operations.
2#[derive(Debug)]
3pub enum PgError {
4    /// I/O error from the underlying socket.
5    Io(std::io::Error),
6    /// Protocol violation or unexpected message from server.
7    Protocol(String),
8    /// Authentication failure.
9    Auth(String),
10    /// Server-sent error response with rich diagnostic fields.
11    Server(Box<ServerError>),
12    /// Connection is closed or in an invalid state.
13    ConnectionClosed,
14    /// Query returned no rows when one was expected.
15    NoRows,
16    /// Type conversion error.
17    TypeConversion(String),
18    /// Statement not found in cache.
19    StatementNotCached,
20    /// Buffer overflow — message too large.
21    BufferOverflow,
22    /// Would block — operation cannot complete without blocking.
23    WouldBlock,
24    /// I/O operation timed out (application-level timeout).
25    Timeout,
26    /// Pool: timed out waiting for a connection.
27    PoolTimeout,
28    /// Pool: all connections are in use.
29    PoolExhausted,
30    /// Pool: connection failed validation.
31    PoolValidationFailed,
32}
33
/// Server-sent error response with rich diagnostic fields.
///
/// Each field corresponds to one field of a PostgreSQL `ErrorResponse` /
/// `NoticeResponse` message; the protocol's one-byte field tag is noted on each field.
/// The protocol always sends severity, code and message; every other field is optional
/// and is `None` when the server did not send it.
///
/// `Clone`, `PartialEq`/`Eq` and `Default` are derived so diagnostics can be stored,
/// compared, and built with struct-update syntax, e.g.
/// `ServerError { code: "40001".into(), ..Default::default() }`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ServerError {
    /// `S`: severity such as `ERROR`, `FATAL` or `PANIC` (or `WARNING`, `NOTICE`, …
    /// for notices); the server may localize this text.
    pub severity: String,
    /// `C`: SQLSTATE code, e.g. `"42601"`.
    pub code: String,
    /// `M`: primary human-readable message.
    pub message: String,
    /// `D`: optional secondary message carrying more detail.
    pub detail: Option<String>,
    /// `H`: optional suggestion about how to fix the problem.
    pub hint: Option<String>,
    /// `P`: error cursor position in the original query string
    /// (1-based, measured in characters).
    pub position: Option<i32>,
    /// `p`: like `position`, but indexing into `internal_query`.
    pub internal_position: Option<i32>,
    /// `q`: text of the internally generated command that failed.
    pub internal_query: Option<String>,
    /// `W`: context in which the error occurred (e.g. a function call-stack traceback).
    /// Named `where_` because `where` is a Rust keyword.
    pub where_: Option<String>,
    /// `s`: schema name of the object the error is associated with.
    pub schema_name: Option<String>,
    /// `t`: table name of the object the error is associated with.
    pub table_name: Option<String>,
    /// `c`: column name of the object the error is associated with.
    pub column_name: Option<String>,
    /// `d`: data type name of the object the error is associated with.
    pub data_type_name: Option<String>,
    /// `n`: constraint name of the object the error is associated with.
    pub constraint_name: Option<String>,
    /// `F`: server source-code file that reported the error.
    pub file: Option<String>,
    /// `L`: line number in that server source file, kept as the text the server sent.
    pub line: Option<String>,
    /// `R`: server source-code routine that reported the error.
    pub routine: Option<String>,
}
55
/// Coarse category of a [`PgError`], used to decide whether to retry an operation.
///
/// Only [`ErrorClass::Transient`] errors are retried by [`retry`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ErrorClass {
    /// Likely to succeed on a fresh attempt: deadlocks, serialization failures,
    /// dropped connections, I/O timeouts.
    Transient,
    /// Will fail again if repeated unchanged, e.g. syntax errors or permission problems.
    Permanent,
    /// Caused on the client side: invalid parameters, failed type conversion and similar.
    Client,
    /// Raised by the connection pool: checkout timeout, exhaustion, failed validation.
    Pool,
}
68
69impl PgError {
70    /// Classify this error for retry decisions.
71    pub fn classify(&self) -> ErrorClass {
72        match self {
73            PgError::Io(_) | PgError::ConnectionClosed | PgError::Timeout => ErrorClass::Transient,
74            // WouldBlock is a flow-control signal, not a transient failure.
75            // It should not trigger retry with backoff.
76            PgError::WouldBlock => ErrorClass::Client,
77            PgError::Server(err) => classify_sql_state(&err.code),
78            PgError::PoolTimeout | PgError::PoolExhausted | PgError::PoolValidationFailed => {
79                ErrorClass::Pool
80            }
81            PgError::TypeConversion(_)
82            | PgError::BufferOverflow
83            | PgError::StatementNotCached
84            | PgError::NoRows => ErrorClass::Client,
85            PgError::Protocol(_) | PgError::Auth(_) => ErrorClass::Permanent,
86        }
87    }
88
89    /// Returns true if this error is transient and the operation can be retried.
90    pub fn is_transient(&self) -> bool {
91        self.classify() == ErrorClass::Transient
92    }
93
94    /// Get the SQLSTATE code, if this is a server error.
95    pub fn sql_state(&self) -> Option<&str> {
96        match self {
97            PgError::Server(err) => Some(&err.code),
98            _ => None,
99        }
100    }
101
102    /// Get the hint from the server, if available.
103    pub fn hint(&self) -> Option<&str> {
104        match self {
105            PgError::Server(err) => err.hint.as_deref(),
106            _ => None,
107        }
108    }
109
110    /// Get the detail from the server, if available.
111    pub fn detail(&self) -> Option<&str> {
112        match self {
113            PgError::Server(err) => err.detail.as_deref(),
114            _ => None,
115        }
116    }
117
118    /// Build a Server error from parsed error/notice fields.
119    pub fn from_fields(fields: &[(u8, String)]) -> Self {
120        let mut severity = String::new();
121        let mut code = String::new();
122        let mut message = String::new();
123        let mut detail = None;
124        let mut hint = None;
125        let mut position = None;
126        let mut internal_position = None;
127        let mut internal_query = None;
128        let mut where_ = None;
129        let mut schema_name = None;
130        let mut table_name = None;
131        let mut column_name = None;
132        let mut data_type_name = None;
133        let mut constraint_name = None;
134        let mut file = None;
135        let mut line = None;
136        let mut routine = None;
137
138        for (field_type, value) in fields {
139            match field_type {
140                b'S' => severity = value.clone(),
141                b'C' => code = value.clone(),
142                b'M' => message = value.clone(),
143                b'D' => detail = Some(value.clone()),
144                b'H' => hint = Some(value.clone()),
145                b'P' => position = value.parse().ok(),
146                b'p' => internal_position = value.parse().ok(),
147                b'q' => internal_query = Some(value.clone()),
148                b'W' => where_ = Some(value.clone()),
149                b's' => schema_name = Some(value.clone()),
150                b't' => table_name = Some(value.clone()),
151                b'c' => column_name = Some(value.clone()),
152                b'd' => data_type_name = Some(value.clone()),
153                b'n' => constraint_name = Some(value.clone()),
154                b'F' => file = Some(value.clone()),
155                b'L' => line = Some(value.clone()),
156                b'R' => routine = Some(value.clone()),
157                _ => {}
158            }
159        }
160
161        PgError::Server(Box::new(ServerError {
162            severity,
163            code,
164            message,
165            detail,
166            hint,
167            position,
168            internal_position,
169            internal_query,
170            where_,
171            schema_name,
172            table_name,
173            column_name,
174            data_type_name,
175            constraint_name,
176            file,
177            line,
178            routine,
179        }))
180    }
181}
182
183/// Classify a SQLSTATE code.
184fn classify_sql_state(code: &str) -> ErrorClass {
185    match code {
186        // Class 40 — Transaction Rollback (deadlock, serialization failure)
187        c if c.starts_with("40") => ErrorClass::Transient,
188        // Class 08 — Connection Exception
189        c if c.starts_with("08") => ErrorClass::Transient,
190        // Class 53 — Insufficient Resources
191        c if c.starts_with("53") => ErrorClass::Transient,
192        // Class 57 — Operator Intervention (crash recovery, etc.)
193        c if c.starts_with("57") => ErrorClass::Transient,
194        // Class 42 — Syntax Error / Access Rule Violation
195        c if c.starts_with("42") => ErrorClass::Permanent,
196        // Class 23 — Integrity Constraint Violation
197        c if c.starts_with("23") => ErrorClass::Permanent,
198        // Class 28 — Invalid Authorization
199        c if c.starts_with("28") => ErrorClass::Permanent,
200        // Default to permanent
201        _ => ErrorClass::Permanent,
202    }
203}
204
205impl From<std::io::Error> for PgError {
206    fn from(e: std::io::Error) -> Self {
207        if e.kind() == std::io::ErrorKind::WouldBlock {
208            PgError::WouldBlock
209        } else {
210            PgError::Io(e)
211        }
212    }
213}
214
215impl std::fmt::Display for PgError {
216    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
217        match self {
218            PgError::Io(e) => write!(f, "I/O error: {}", e),
219            PgError::Protocol(msg) => write!(f, "Protocol error: {}", msg),
220            PgError::Auth(msg) => write!(f, "Auth error: {}", msg),
221            PgError::Server(err) => {
222                write!(f, "PG {}: {} ({})", err.severity, err.message, err.code)?;
223                if let Some(d) = &err.detail {
224                    write!(f, "\n  Detail: {}", d)?;
225                }
226                if let Some(h) = &err.hint {
227                    write!(f, "\n  Hint: {}", h)?;
228                }
229                Ok(())
230            }
231            PgError::ConnectionClosed => write!(f, "Connection closed"),
232            PgError::NoRows => write!(f, "No rows returned"),
233            PgError::TypeConversion(msg) => write!(f, "Type conversion: {}", msg),
234            PgError::StatementNotCached => write!(f, "Statement not in cache"),
235            PgError::BufferOverflow => write!(f, "Buffer overflow"),
236            PgError::WouldBlock => write!(f, "Would block"),
237            PgError::Timeout => write!(f, "I/O operation timed out"),
238            PgError::PoolTimeout => write!(f, "Pool: connection checkout timed out"),
239            PgError::PoolExhausted => write!(f, "Pool: all connections are in use"),
240            PgError::PoolValidationFailed => write!(f, "Pool: connection failed validation"),
241        }
242    }
243}
244
245impl std::error::Error for PgError {}
246
247pub type PgResult<T> = Result<T, PgError>;
248
249/// Retry helper: executes an operation with exponential backoff on transient errors.
250pub fn retry<F, T>(max_retries: u32, mut f: F) -> PgResult<T>
251where
252    F: FnMut() -> PgResult<T>,
253{
254    let mut attempts = 0;
255    loop {
256        match f() {
257            Ok(val) => return Ok(val),
258            Err(e) if e.is_transient() && attempts < max_retries => {
259                attempts += 1;
260                // Exponential backoff: 1ms, 2ms, 4ms, 8ms, ...
261                let delay = std::time::Duration::from_millis(1 << attempts.min(10));
262                std::thread::sleep(delay);
263            }
264            Err(e) => return Err(e),
265        }
266    }
267}
268
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::{Error as IoError, ErrorKind};

    // ─── Fixtures ─────────────────────────────────────────────────────────────

    /// A `PgError::Io` wrapping an `io::Error` of the given kind.
    fn io_err(kind: ErrorKind, msg: &str) -> PgError {
        PgError::Io(IoError::new(kind, msg))
    }

    /// A `PgError::Server` with severity `ERROR` and the given SQLSTATE, message and
    /// detail; every other diagnostic field is empty.
    fn server_with(code: &str, message: &str, detail: Option<&str>) -> PgError {
        PgError::Server(Box::new(ServerError {
            severity: "ERROR".to_string(),
            code: code.to_string(),
            message: message.to_string(),
            detail: detail.map(String::from),
            hint: None,
            position: None,
            internal_position: None,
            internal_query: None,
            where_: None,
            schema_name: None,
            table_name: None,
            column_name: None,
            data_type_name: None,
            constraint_name: None,
            file: None,
            line: None,
            routine: None,
        }))
    }

    /// Shorthand: an `ERROR` with message "test" and the given SQLSTATE.
    fn server_err(code: &str) -> PgError {
        server_with(code, "test", None)
    }

    /// Owned `(tag, value)` pairs in the shape `PgError::from_fields` expects.
    fn fields(pairs: &[(u8, &str)]) -> Vec<(u8, String)> {
        pairs
            .iter()
            .map(|&(tag, value)| (tag, value.to_string()))
            .collect()
    }

    /// Unwrap the `Server` payload, failing the test for any other variant.
    fn expect_server(e: PgError) -> Box<ServerError> {
        match e {
            PgError::Server(err) => err,
            _ => panic!("Expected Server variant"),
        }
    }

    /// Drive `retry` with an operation that always fails with `make()`.
    /// Returns whether the final result was an error and how often the operation ran.
    fn run_failing(max_retries: u32, make: impl Fn() -> PgError) -> (bool, u32) {
        let mut calls = 0u32;
        let outcome = retry(max_retries, || {
            calls += 1;
            Err::<(), PgError>(make())
        });
        (outcome.is_err(), calls)
    }

    // ─── Variant Classification ───────────────────────────────────────────────

    #[test]
    fn test_io_error_is_transient() {
        let e = io_err(ErrorKind::BrokenPipe, "broken pipe");
        assert_eq!(e.classify(), ErrorClass::Transient);
        assert!(e.is_transient());
    }

    #[test]
    fn test_connection_closed_is_transient() {
        let e = PgError::ConnectionClosed;
        assert_eq!(e.classify(), ErrorClass::Transient);
        assert!(e.is_transient());
    }

    #[test]
    fn test_timeout_is_transient() {
        let e = PgError::Timeout;
        assert_eq!(e.classify(), ErrorClass::Transient);
        assert!(e.is_transient());
    }

    // WouldBlock must be Client, not Transient: otherwise retry() would back off and
    // re-poll non-blocking I/O instead of handing control back to the caller.
    #[test]
    fn test_wouldblock_is_client_not_transient() {
        let e = PgError::WouldBlock;
        assert_eq!(e.classify(), ErrorClass::Client);
        assert!(!e.is_transient());
    }

    #[test]
    fn test_type_conversion_is_client() {
        let e = PgError::TypeConversion("bad".to_string());
        assert_eq!(e.classify(), ErrorClass::Client);
        assert!(!e.is_transient());
    }

    #[test]
    fn test_buffer_overflow_is_client() {
        assert_eq!(PgError::BufferOverflow.classify(), ErrorClass::Client);
    }

    #[test]
    fn test_no_rows_is_client() {
        assert_eq!(PgError::NoRows.classify(), ErrorClass::Client);
    }

    #[test]
    fn test_statement_not_cached_is_client() {
        assert_eq!(PgError::StatementNotCached.classify(), ErrorClass::Client);
    }

    #[test]
    fn test_pool_timeout_is_pool_class() {
        let e = PgError::PoolTimeout;
        assert_eq!(e.classify(), ErrorClass::Pool);
        assert!(!e.is_transient());
    }

    #[test]
    fn test_pool_exhausted_is_pool_class() {
        let e = PgError::PoolExhausted;
        assert_eq!(e.classify(), ErrorClass::Pool);
        assert!(!e.is_transient());
    }

    #[test]
    fn test_pool_validation_failed_is_pool_class() {
        let e = PgError::PoolValidationFailed;
        assert_eq!(e.classify(), ErrorClass::Pool);
        assert!(!e.is_transient());
    }

    #[test]
    fn test_protocol_error_is_permanent() {
        let e = PgError::Protocol("bad".to_string());
        assert_eq!(e.classify(), ErrorClass::Permanent);
        assert!(!e.is_transient());
    }

    #[test]
    fn test_auth_error_is_permanent() {
        let e = PgError::Auth("denied".to_string());
        assert_eq!(e.classify(), ErrorClass::Permanent);
        assert!(!e.is_transient());
    }

    // ─── SQLSTATE Classification ──────────────────────────────────────────────

    #[test]
    fn test_sqlstate_40001_serialization_failure_transient() {
        assert!(server_err("40001").is_transient());
    }

    #[test]
    fn test_sqlstate_40p01_deadlock_transient() {
        assert!(server_err("40P01").is_transient());
    }

    #[test]
    fn test_sqlstate_08006_connection_failure_transient() {
        assert!(server_err("08006").is_transient());
    }

    #[test]
    fn test_sqlstate_53300_too_many_connections_transient() {
        // Class 53 = Insufficient Resources
        assert!(server_err("53300").is_transient());
    }

    #[test]
    fn test_sqlstate_57014_query_canceled_transient() {
        // Class 57 = Operator Intervention
        assert!(server_err("57014").is_transient());
    }

    #[test]
    fn test_sqlstate_42601_syntax_error_permanent() {
        let e = server_err("42601");
        assert_eq!(e.classify(), ErrorClass::Permanent);
        assert!(!e.is_transient());
    }

    #[test]
    fn test_sqlstate_23505_unique_violation_permanent() {
        assert_eq!(server_err("23505").classify(), ErrorClass::Permanent);
    }

    #[test]
    fn test_sqlstate_28000_invalid_authorization_permanent() {
        assert_eq!(server_err("28000").classify(), ErrorClass::Permanent);
    }

    #[test]
    fn test_sqlstate_unknown_default_permanent() {
        // Codes outside the known classes fall back to Permanent.
        assert_eq!(server_err("99999").classify(), ErrorClass::Permanent);
    }

    // ─── sql_state() Accessor ─────────────────────────────────────────────────

    #[test]
    fn test_sql_state_returns_code() {
        assert_eq!(server_err("42601").sql_state(), Some("42601"));
    }

    #[test]
    fn test_sql_state_non_server_is_none() {
        let non_server = [
            PgError::WouldBlock,
            PgError::Timeout,
            PgError::PoolExhausted,
            PgError::ConnectionClosed,
        ];
        for e in non_server.iter() {
            assert_eq!(e.sql_state(), None);
        }
    }

    // ─── from_fields() ────────────────────────────────────────────────────────

    #[test]
    fn test_from_fields_complete() {
        let raw = fields(&[
            (b'S', "ERROR"),
            (b'C', "42601"),
            (b'M', "syntax error at position 5"),
            (b'D', "near SELECT"),
            (b'H', "check your query"),
            (b'P', "5"),
            (b's', "public"),
            (b't', "users"),
            (b'n', "users_pkey"),
        ]);
        let err = expect_server(PgError::from_fields(&raw));
        assert_eq!(err.severity, "ERROR");
        assert_eq!(err.code, "42601");
        assert_eq!(err.message, "syntax error at position 5");
        assert_eq!(err.detail.as_deref(), Some("near SELECT"));
        assert_eq!(err.hint.as_deref(), Some("check your query"));
        assert_eq!(err.position, Some(5));
        assert_eq!(err.schema_name.as_deref(), Some("public"));
        assert_eq!(err.table_name.as_deref(), Some("users"));
        assert_eq!(err.constraint_name.as_deref(), Some("users_pkey"));
    }

    #[test]
    fn test_from_fields_minimal() {
        let raw = fields(&[(b'S', "ERROR"), (b'C', "99999"), (b'M', "unknown error")]);
        let err = expect_server(PgError::from_fields(&raw));
        assert!(err.detail.is_none());
        assert!(err.hint.is_none());
        assert!(err.position.is_none());
    }

    #[test]
    fn test_from_fields_unknown_field_ignored() {
        // The unknown tag 'Z' must be skipped without affecting the result.
        let raw = fields(&[
            (b'S', "ERROR"),
            (b'C', "00000"),
            (b'M', "ok"),
            (b'Z', "ignored"),
        ]);
        assert!(matches!(PgError::from_fields(&raw), PgError::Server(_)));
    }

    // ─── Display Format ───────────────────────────────────────────────────────

    #[test]
    fn test_display_server_includes_message_code_detail() {
        let s = server_with("42601", "syntax error here", Some("bad token")).to_string();
        assert!(s.contains("syntax error here"), "missing message: {}", s);
        assert!(s.contains("42601"), "missing code: {}", s);
        assert!(s.contains("bad token"), "missing detail: {}", s);
    }

    #[test]
    fn test_display_server_no_detail_or_hint() {
        // Only checks that formatting succeeds and produces text.
        assert!(!server_err("42601").to_string().is_empty());
    }

    #[test]
    fn test_display_all_non_server_variants() {
        // Every non-server variant must format without panicking.
        let variants = [
            PgError::ConnectionClosed,
            PgError::NoRows,
            PgError::BufferOverflow,
            PgError::WouldBlock,
            PgError::Timeout,
            PgError::PoolTimeout,
            PgError::PoolExhausted,
            PgError::PoolValidationFailed,
            PgError::StatementNotCached,
            PgError::TypeConversion("type error".to_string()),
            PgError::Protocol("protocol error".to_string()),
            PgError::Auth("auth error".to_string()),
        ];
        for e in variants.iter() {
            let _ = e.to_string();
        }
    }

    // ─── From<io::Error> ─────────────────────────────────────────────────────

    #[test]
    fn test_from_io_wouldblock_becomes_wouldblock() {
        let converted = PgError::from(IoError::new(ErrorKind::WouldBlock, "would block"));
        assert!(matches!(converted, PgError::WouldBlock));
    }

    #[test]
    fn test_from_io_other_becomes_io_variant() {
        let converted = PgError::from(IoError::new(ErrorKind::ConnectionReset, "reset"));
        assert!(matches!(converted, PgError::Io(_)));
    }

    #[test]
    fn test_from_io_broken_pipe_is_not_wouldblock() {
        let converted = PgError::from(IoError::new(ErrorKind::BrokenPipe, "pipe"));
        assert!(!matches!(converted, PgError::WouldBlock));
    }

    // ─── retry() ─────────────────────────────────────────────────────────────

    #[test]
    fn test_retry_succeeds_immediately() {
        assert_eq!(retry(3, || Ok::<i32, PgError>(42)).unwrap(), 42);
    }

    #[test]
    fn test_retry_no_retries_on_success() {
        let mut calls = 0;
        let result = retry(3, || {
            calls += 1;
            Ok::<i32, PgError>(1)
        });
        assert_eq!(result.unwrap(), 1);
        assert_eq!(calls, 1);
    }

    #[test]
    fn test_retry_permanent_error_not_retried() {
        // A Protocol error must not be retried, so retry() does not waste time on it.
        let (failed, calls) = run_failing(5, || PgError::Protocol("bad".to_string()));
        assert!(failed);
        assert_eq!(calls, 1, "Permanent errors must not be retried");
    }

    #[test]
    fn test_retry_client_error_not_retried() {
        // Regression test: WouldBlock must not be retried.
        let (failed, calls) = run_failing(5, || PgError::WouldBlock);
        assert!(failed);
        assert_eq!(calls, 1, "WouldBlock must not be retried");
    }

    #[test]
    fn test_retry_zero_max_retries_no_sleep_no_retry() {
        let (failed, calls) = run_failing(0, || io_err(ErrorKind::ConnectionReset, "reset"));
        assert!(failed);
        assert_eq!(calls, 1);
    }

    #[test]
    fn test_retry_transient_error_retried_up_to_limit() {
        let (failed, calls) = run_failing(2, || io_err(ErrorKind::ConnectionReset, "reset"));
        assert!(failed);
        // One initial attempt plus two retries.
        assert_eq!(calls, 3);
    }

    #[test]
    fn test_retry_succeeds_on_second_attempt() {
        let mut calls = 0;
        let result = retry(3, || {
            calls += 1;
            if calls >= 2 {
                Ok::<i32, PgError>(99)
            } else {
                Err(io_err(ErrorKind::ConnectionReset, "reset"))
            }
        });
        assert_eq!(result.unwrap(), 99);
        assert_eq!(calls, 2);
    }

    #[test]
    fn test_pool_errors_not_retried() {
        let cases: [(fn() -> PgError, &str); 3] = [
            (|| PgError::PoolTimeout, "PoolTimeout must not be retried"),
            (|| PgError::PoolExhausted, "PoolExhausted must not be retried"),
            (
                || PgError::PoolValidationFailed,
                "PoolValidationFailed must not be retried",
            ),
        ];
        for &(make, why) in cases.iter() {
            let (_, calls) = run_failing(5, make);
            assert_eq!(calls, 1, "{}", why);
        }
    }
}