Skip to main content

rivet_cli/pipeline/
retry.rs

1/// Classifies transient errors into retry categories.
2/// Returns (is_transient, needs_reconnect, extra_delay_ms)
3///
4/// Checks Postgres SQLSTATE codes and MySQL error codes first, then falls back
5/// to string matching for errors that don't carry structured codes (e.g. IO,
6/// cloud credential errors).
7pub fn classify_error(err: &anyhow::Error) -> (bool, bool, u64) {
8    // --- Postgres: check SQLSTATE via the `postgres::Error` downcasted type ---
9    if let Some(pg) = err.downcast_ref::<postgres::Error>() {
10        if let Some(db) = pg.as_db_error() {
11            return classify_pg_sqlstate(db.code());
12        }
13        // Connection-level (non-DB) postgres errors → reconnect
14        if pg.is_closed() {
15            return (true, true, 0);
16        }
17    }
18
19    // --- MySQL: check numeric error code ---
20    if let Some(result) = err
21        .downcast_ref::<mysql::Error>()
22        .and_then(classify_mysql_error)
23    {
24        return result;
25    }
26
27    // --- Fallback: string-based classification ---
28    let msg = format!("{:#}", err).to_lowercase();
29
30    // Auth / credential errors are never transient — fix config, not retry
31    if msg.contains("loading credential")
32        || msg.contains("loadcredential")
33        || msg.contains("metadata.google.internal")
34        || msg.contains("permission denied")
35        || msg.contains("access denied")
36        || msg.contains("invalid_grant")
37        || msg.contains("token has been expired or revoked")
38    {
39        return (false, false, 0);
40    }
41
42    // Network errors -- need reconnect
43    if msg.contains("connection reset")
44        || msg.contains("broken pipe")
45        || msg.contains("connection refused")
46        || msg.contains("no route to host")
47        || msg.contains("network is unreachable")
48        || msg.contains("name resolution")
49        || msg.contains("dns")
50        || msg.contains("ssl handshake")
51        || msg.contains("i/o timeout")
52        || msg.contains("unexpected eof")
53        || msg.contains("closed the connection unexpectedly")
54        || msg.contains("got an error reading communication packets")
55    {
56        return (true, true, 0);
57    }
58
59    // MySQL specific -- need reconnect
60    if msg.contains("gone away")
61        || msg.contains("lost connection")
62        || msg.contains("the server closed the connection")
63        || msg.contains("can't connect to mysql server")
64    {
65        return (true, true, 0);
66    }
67
68    // Timeout errors -- retry on same connection
69    if msg.contains("timed out")
70        || msg.contains("timeout")
71        || msg.contains("canceling statement")
72        || msg.contains("lock wait timeout")
73        || msg.contains("execution time exceeded")
74    {
75        return (true, false, 0);
76    }
77
78    // Capacity errors -- retry with longer delay
79    if msg.contains("too many connections")
80        || msg.contains("the database system is starting up")
81        || msg.contains("the database system is shutting down")
82    {
83        return (true, true, 15_000);
84    }
85
86    // Deadlock/serialization -- retry once, same connection
87    if msg.contains("deadlock") || msg.contains("could not serialize access") {
88        return (true, false, 1_000);
89    }
90
91    // Not transient
92    (false, false, 0)
93}
94
95/// Classify a Postgres SQLSTATE code.
96/// Reference: <https://www.postgresql.org/docs/current/errcodes-appendix.html>
97fn classify_pg_sqlstate(code: &postgres::error::SqlState) -> (bool, bool, u64) {
98    use postgres::error::SqlState;
99
100    // Class 08 — Connection Exception → reconnect
101    if *code == SqlState::CONNECTION_EXCEPTION
102        || *code == SqlState::CONNECTION_DOES_NOT_EXIST
103        || *code == SqlState::CONNECTION_FAILURE
104        || *code == SqlState::SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION
105        || *code == SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION
106        || code.code().starts_with("08")
107    {
108        return (true, true, 0);
109    }
110
111    // 57P01 admin_shutdown, 57P02 crash_shutdown, 57P03 cannot_connect_now
112    if *code == SqlState::ADMIN_SHUTDOWN
113        || *code == SqlState::CRASH_SHUTDOWN
114        || *code == SqlState::CANNOT_CONNECT_NOW
115    {
116        return (true, true, 15_000);
117    }
118
119    // 53300 too_many_connections
120    if *code == SqlState::TOO_MANY_CONNECTIONS {
121        return (true, true, 15_000);
122    }
123
124    // 40001 serialization_failure, 40P01 deadlock_detected
125    if *code == SqlState::T_R_SERIALIZATION_FAILURE {
126        return (true, false, 1_000);
127    }
128    if *code == SqlState::T_R_DEADLOCK_DETECTED {
129        return (true, false, 1_000);
130    }
131
132    // 57014 query_canceled (statement_timeout)
133    if *code == SqlState::QUERY_CANCELED {
134        return (true, false, 0);
135    }
136
137    // Class 53 — Insufficient Resources (disk full, out of memory)
138    if code.code().starts_with("53") {
139        return (true, false, 5_000);
140    }
141
142    // 28xxx — Invalid Authorization → permanent
143    if code.code().starts_with("28") {
144        return (false, false, 0);
145    }
146
147    // 42xxx — Syntax Error or Access Rule Violation → permanent
148    if code.code().starts_with("42") {
149        return (false, false, 0);
150    }
151
152    // All other SQLSTATE codes → not transient by default
153    (false, false, 0)
154}
155
156/// Classify a MySQL error by numeric code.
157/// Reference: <https://dev.mysql.com/doc/mysql-errors/8.0/en/server-error-reference.html>
158fn classify_mysql_error(err: &mysql::Error) -> Option<(bool, bool, u64)> {
159    match err {
160        mysql::Error::MySqlError(me) => {
161            match me.code {
162                // ER_LOCK_DEADLOCK
163                1213 => Some((true, false, 1_000)),
164                // ER_LOCK_WAIT_TIMEOUT
165                1205 => Some((true, false, 0)),
166                // ER_CON_COUNT_ERROR (too many connections)
167                1040 => Some((true, true, 15_000)),
168                // ER_SERVER_SHUTDOWN
169                1053 => Some((true, true, 15_000)),
170                // ER_ACCESS_DENIED_ERROR, ER_DBACCESS_DENIED_ERROR
171                1045 | 1044 => Some((false, false, 0)),
172                // ER_BAD_DB_ERROR, ER_NO_SUCH_TABLE, ER_PARSE_ERROR
173                1049 | 1146 | 1064 => Some((false, false, 0)),
174                _ => None,
175            }
176        }
177        mysql::Error::IoError(_) => Some((true, true, 0)),
178        _ => None,
179    }
180}
181
#[cfg(test)]
/// Test-only shorthand: reports just the `is_transient` flag from `classify_error`.
pub(crate) fn is_transient(err: &anyhow::Error) -> bool {
    let (transient, _needs_reconnect, _extra_delay_ms) = classify_error(err);
    transient
}
186
#[cfg(test)]
mod tests {
    use super::*;

    /// Classifies a plain, message-only error (no structured driver code).
    fn classify(msg: &str) -> (bool, bool, u64) {
        classify_error(&anyhow::anyhow!("{}", msg))
    }

    #[test]
    fn test_is_transient_matches() {
        for msg in ["statement timed out", "connection reset"] {
            assert!(is_transient(&anyhow::anyhow!("{}", msg)));
        }
    }

    #[test]
    fn test_is_transient_rejects() {
        for msg in ["syntax error", "permission denied", "table not found"] {
            assert!(!is_transient(&anyhow::anyhow!("{}", msg)));
        }
    }

    #[test]
    fn test_classify_network_errors_need_reconnect() {
        let network_messages = [
            "connection refused",
            "no route to host",
            "network is unreachable",
            "broken pipe",
            "unexpected eof",
            "MySQL server has gone away",
            "lost connection to server",
            "can't connect to mysql server",
            "the server closed the connection",
            "got an error reading communication packets",
            "ssl handshake failed",
        ];
        for msg in network_messages {
            let (transient, reconnect, _) = classify(msg);
            assert!(transient, "should be transient: {}", msg);
            assert!(reconnect, "should need reconnect: {}", msg);
        }
    }

    #[test]
    fn test_classify_timeout_no_reconnect() {
        let (transient, reconnect, _) = classify("statement timed out");
        assert!(transient);
        assert!(!reconnect, "timeout should not require reconnect");

        let (transient, reconnect, _) = classify("lock wait timeout exceeded");
        assert!(transient);
        assert!(!reconnect);
    }

    #[test]
    fn test_classify_capacity_errors_extra_delay() {
        let (transient, reconnect, delay) = classify("too many connections");
        assert!(transient);
        assert!(reconnect);
        assert!(
            delay >= 10_000,
            "capacity errors should have extra delay, got: {}ms",
            delay
        );

        let (transient, _, delay) = classify("the database system is starting up");
        assert!(transient);
        assert!(delay >= 10_000);
    }

    #[test]
    fn test_classify_deadlock_retryable() {
        let (transient, reconnect, delay) = classify("deadlock detected");
        assert!(transient);
        assert!(!reconnect, "deadlock should not require reconnect");
        assert!(delay >= 1_000, "deadlock should have small extra delay");
    }

    #[test]
    fn test_classify_permanent_errors() {
        let permanent_messages = [
            "syntax error",
            "permission denied",
            "relation does not exist",
            "column not found",
        ];
        for msg in permanent_messages {
            let (transient, _, _) = classify(msg);
            assert!(!transient, "should NOT be transient: {}", msg);
        }
    }

    #[test]
    fn test_classify_credential_errors_not_transient() {
        let credential_messages = [
            "loading credential to sign http request",
            "error sending request for url (http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/token): dns error",
            "invalid_grant: Token has been expired or revoked",
            "Access Denied: no permission",
        ];
        for msg in credential_messages {
            let (transient, _, _) = classify(msg);
            assert!(
                !transient,
                "credential error should NOT be transient: {}",
                msg
            );
        }
    }
}