1use std::sync::Arc;
4
5use thiserror::Error;
6
7#[derive(Debug, Error)]
9#[non_exhaustive]
10pub enum Error {
11 #[error("connection failed: {0}")]
13 Connection(String),
14
15 #[error("connection closed")]
17 ConnectionClosed,
18
19 #[error("authentication failed: {0}")]
21 Authentication(#[from] mssql_auth::AuthError),
22
23 #[error("TLS error: {0}")]
25 Tls(String),
26
27 #[error("protocol error: {0}")]
29 Protocol(String),
30
31 #[error("codec error: {0}")]
33 Codec(#[from] mssql_codec::CodecError),
34
35 #[error("type error: {0}")]
37 Type(#[from] mssql_types::TypeError),
38
39 #[error("query error: {0}")]
41 Query(String),
42
43 #[error("server error {number}: {message}")]
45 Server {
46 number: i32,
48 class: u8,
50 state: u8,
52 message: String,
54 server: Option<String>,
56 procedure: Option<String>,
58 line: u32,
60 },
61
62 #[error("transaction error: {0}")]
64 Transaction(String),
65
66 #[error("configuration error: {0}")]
68 Config(String),
69
70 #[error("connection timed out")]
72 ConnectTimeout,
73
74 #[error("TLS handshake timed out")]
76 TlsTimeout,
77
78 #[error("connection timed out")]
80 ConnectionTimeout,
81
82 #[error("command timed out")]
84 CommandTimeout,
85
86 #[error("routing required to {host}:{port}")]
88 Routing {
89 host: String,
91 port: u16,
93 },
94
95 #[error("too many redirects (max {max})")]
97 TooManyRedirects {
98 max: u8,
100 },
101
102 #[error("IO error: {0}")]
104 Io(Arc<std::io::Error>),
105
106 #[error("invalid identifier: {0}")]
108 InvalidIdentifier(String),
109
110 #[error("connection pool exhausted")]
112 PoolExhausted,
113
114 #[error("query cancellation failed: {0}")]
116 Cancel(String),
117
118 #[error("query cancelled")]
120 Cancelled,
121}
122
123impl From<mssql_tls::TlsError> for Error {
124 fn from(e: mssql_tls::TlsError) -> Self {
125 Error::Tls(e.to_string())
126 }
127}
128
129impl From<tds_protocol::ProtocolError> for Error {
130 fn from(e: tds_protocol::ProtocolError) -> Self {
131 Error::Protocol(e.to_string())
132 }
133}
134
135impl From<std::io::Error> for Error {
136 fn from(e: std::io::Error) -> Self {
137 Error::Io(Arc::new(e))
138 }
139}
140
141impl Error {
142 #[must_use]
158 pub fn is_transient(&self) -> bool {
159 match self {
160 Self::ConnectTimeout
161 | Self::TlsTimeout
162 | Self::ConnectionTimeout
163 | Self::CommandTimeout
164 | Self::ConnectionClosed
165 | Self::Routing { .. }
166 | Self::PoolExhausted
167 | Self::Io(_) => true,
168 Self::Server { number, .. } => Self::is_transient_server_error(*number),
169 _ => false,
170 }
171 }
172
173 #[must_use]
177 pub fn is_transient_server_error(number: i32) -> bool {
178 matches!(
179 number,
180 1205 | -2 | 10928 | 10929 | 40197 | 40501 | 40613 | 49918 | 49919 | 49920 | 4060 | 18456 )
193 }
194
195 #[must_use]
208 pub fn is_terminal(&self) -> bool {
209 match self {
210 Self::Config(_) | Self::InvalidIdentifier(_) => true,
211 Self::Server { number, .. } => Self::is_terminal_server_error(*number),
212 _ => false,
213 }
214 }
215
216 #[must_use]
220 pub fn is_terminal_server_error(number: i32) -> bool {
221 matches!(
222 number,
223 102 | 207 | 208 | 547 | 2627 | 2601 )
230 }
231
232 #[must_use]
237 pub fn is_protocol_error(&self) -> bool {
238 matches!(self, Self::Protocol(_))
239 }
240
241 #[must_use]
243 pub fn is_server_error(&self, number: i32) -> bool {
244 matches!(self, Self::Server { number: n, .. } if *n == number)
245 }
246
247 #[must_use]
255 pub fn class(&self) -> Option<u8> {
256 match self {
257 Self::Server { class, .. } => Some(*class),
258 _ => None,
259 }
260 }
261
262 #[must_use]
264 pub fn severity(&self) -> Option<u8> {
265 self.class()
266 }
267}
268
269pub type Result<T> = std::result::Result<T, Error>;
271
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Builds an `Error::Server` with the given number; every other field is a fixed
    /// placeholder (class 16, state 1, line 1).
    fn make_server_error(number: i32) -> Error {
        Error::Server {
            number,
            class: 16,
            state: 1,
            message: "Test error".to_string(),
            server: None,
            procedure: None,
            line: 1,
        }
    }

    /// Server numbers `is_transient_server_error` must accept.
    const TRANSIENT_SERVER_CODES: [i32; 12] = [
        1205, -2, 10928, 10929, 40197, 40501, 40613, 49918, 49919, 49920, 4060, 18456,
    ];

    /// Server numbers `is_terminal_server_error` must accept.
    const TERMINAL_SERVER_CODES: [i32; 6] = [102, 207, 208, 547, 2627, 2601];

    #[test]
    fn test_is_transient_connection_errors() {
        assert!(Error::ConnectTimeout.is_transient());
        assert!(Error::TlsTimeout.is_transient());
        assert!(Error::ConnectionTimeout.is_transient());
        assert!(Error::CommandTimeout.is_transient());
        assert!(Error::ConnectionClosed.is_transient());
        assert!(Error::PoolExhausted.is_transient());
        assert!(Error::Routing {
            host: "test".into(),
            port: 1433,
        }
        .is_transient());
    }

    #[test]
    fn test_is_transient_io_error() {
        let io_err = std::io::Error::new(std::io::ErrorKind::ConnectionReset, "reset");
        assert!(Error::Io(Arc::new(io_err)).is_transient());
    }

    #[test]
    fn test_io_error_conversion() {
        let err: Error = std::io::Error::new(std::io::ErrorKind::BrokenPipe, "pipe").into();
        assert!(matches!(err, Error::Io(_)));
        assert!(err.is_transient());
        assert_eq!(err.to_string(), "IO error: pipe");
    }

    #[test]
    fn test_is_transient_server_errors_deadlock() {
        assert!(make_server_error(1205).is_transient());
    }

    #[test]
    fn test_is_transient_server_errors_timeout() {
        assert!(make_server_error(-2).is_transient());
    }

    #[test]
    fn test_is_transient_server_errors_azure() {
        for code in [10928, 10929, 40197, 40501, 40613, 49918, 49919, 49920] {
            assert!(make_server_error(code).is_transient(), "{} should be transient", code);
        }
    }

    #[test]
    fn test_is_transient_server_errors_other() {
        assert!(make_server_error(4060).is_transient());
        assert!(make_server_error(18456).is_transient());
    }

    #[test]
    fn test_is_not_transient() {
        assert!(!Error::Config("bad config".into()).is_transient());
        assert!(!Error::Query("syntax error".into()).is_transient());
        assert!(!Error::InvalidIdentifier("bad id".into()).is_transient());
        assert!(!Error::Connection("refused".into()).is_transient());
        assert!(!Error::Tls("bad cert".into()).is_transient());
        assert!(!Error::Protocol("bad token".into()).is_transient());
        assert!(!Error::Transaction("no transaction".into()).is_transient());
        assert!(!Error::TooManyRedirects { max: 5 }.is_transient());
        assert!(!Error::Cancel("failed".into()).is_transient());
        assert!(!Error::Cancelled.is_transient());
        assert!(!make_server_error(102).is_transient());
    }

    #[test]
    fn test_is_terminal_server_errors() {
        for code in TERMINAL_SERVER_CODES {
            assert!(make_server_error(code).is_terminal(), "{} should be terminal", code);
        }
    }

    #[test]
    fn test_is_terminal_config_errors() {
        assert!(Error::Config("bad config".into()).is_terminal());
        assert!(Error::InvalidIdentifier("bad id".into()).is_terminal());
    }

    #[test]
    fn test_is_not_terminal() {
        assert!(!Error::ConnectionTimeout.is_terminal());
        assert!(!Error::Cancelled.is_terminal());
        assert!(!make_server_error(1205).is_terminal());
        assert!(!make_server_error(40501).is_terminal());
    }

    #[test]
    fn test_transient_and_terminal_codes_are_disjoint() {
        for code in TRANSIENT_SERVER_CODES {
            assert!(Error::is_transient_server_error(code), "{} should be transient", code);
            assert!(
                !Error::is_terminal_server_error(code),
                "{} is classified as both transient and terminal",
                code
            );
        }
        for code in TERMINAL_SERVER_CODES {
            assert!(
                !Error::is_transient_server_error(code),
                "{} is classified as both terminal and transient",
                code
            );
        }
    }

    #[test]
    fn test_transient_server_error_static() {
        assert!(Error::is_transient_server_error(1205));
        assert!(Error::is_transient_server_error(40501));
        assert!(!Error::is_transient_server_error(102));
    }

    #[test]
    fn test_terminal_server_error_static() {
        assert!(Error::is_terminal_server_error(102));
        assert!(Error::is_terminal_server_error(2627));
        assert!(!Error::is_terminal_server_error(1205));
    }

    #[test]
    fn test_is_protocol_error() {
        assert!(Error::Protocol("bad token".into()).is_protocol_error());
        assert!(!Error::Query("bad query".into()).is_protocol_error());
        assert!(!make_server_error(102).is_protocol_error());
    }

    #[test]
    fn test_error_class() {
        let err = make_server_error(102);
        assert_eq!(err.class(), Some(16));
        assert_eq!(err.severity(), Some(16));

        assert_eq!(Error::ConnectionTimeout.class(), None);
        assert_eq!(Error::ConnectionTimeout.severity(), None);
    }

    #[test]
    fn test_is_server_error() {
        let err = make_server_error(102);
        assert!(err.is_server_error(102));
        assert!(!err.is_server_error(103));

        assert!(!Error::ConnectionTimeout.is_server_error(102));
    }

    #[test]
    fn test_display_messages() {
        assert_eq!(make_server_error(102).to_string(), "server error 102: Test error");
        assert_eq!(
            Error::Routing {
                host: "db".into(),
                port: 1433,
            }
            .to_string(),
            "routing required to db:1433"
        );
        assert_eq!(
            Error::TooManyRedirects { max: 5 }.to_string(),
            "too many redirects (max 5)"
        );
        assert_eq!(Error::TlsTimeout.to_string(), "TLS handshake timed out");
        assert_eq!(Error::Cancelled.to_string(), "query cancelled");
    }
}