1use std::sync::Arc;
4
5use thiserror::Error;
6
/// Errors produced by this crate.
///
/// The enum is `#[non_exhaustive]`: code outside this crate must include a
/// wildcard arm when matching on it. Use [`Error::is_transient`] and
/// [`Error::is_terminal`] to decide whether an operation is worth retrying.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum Error {
    /// A connection-level failure; the string describes the cause.
    #[error("connection failed: {0}")]
    Connection(String),

    /// The connection was closed.
    #[error("connection closed")]
    ConnectionClosed,

    /// Authentication failed. Created from `mssql_auth::AuthError` via `From`.
    #[error("authentication failed: {0}")]
    Authentication(#[from] mssql_auth::AuthError),

    /// A TLS failure wrapping `mssql_tls::TlsError` (only with the `tls` feature).
    /// Created via `From`.
    #[cfg(feature = "tls")]
    #[error("TLS error: {0}")]
    Tls(#[from] mssql_tls::TlsError),

    /// A TLS failure described by a message (used when the `tls` feature is
    /// disabled and `mssql_tls::TlsError` is therefore not available).
    #[cfg(not(feature = "tls"))]
    #[error("TLS error: {0}")]
    Tls(String),

    /// A typed TDS protocol error from `tds_protocol`, created via `From`.
    /// Displays with the same prefix as [`Error::Protocol`]; both are matched
    /// by [`Error::is_protocol_error`].
    #[error("protocol error: {0}")]
    ProtocolError(#[from] tds_protocol::ProtocolError),

    /// A protocol error described by a free-form message.
    #[error("protocol error: {0}")]
    Protocol(String),

    /// An error from `mssql_codec`, created via `From`.
    #[error("codec error: {0}")]
    Codec(#[from] mssql_codec::CodecError),

    /// An error from `mssql_types`, created via `From`.
    #[error("type error: {0}")]
    Type(#[from] mssql_types::TypeError),

    /// A query error described by a message.
    #[error("query error: {0}")]
    Query(String),

    /// An error reported by the server.
    ///
    /// The display form appends an optional `[server: …, procedure: …, line: …]`
    /// suffix built from the location fields (see `format_server_location`).
    #[error("server error {number} (severity {class}, state {state}): {message}{}", format_server_location(.server, .procedure, .line))]
    Server {
        /// Server error number; see [`Error::is_server_error`].
        number: i32,
        /// Severity class; also exposed through [`Error::class`] / [`Error::severity`].
        class: u8,
        /// Error state reported by the server.
        state: u8,
        /// Error message text reported by the server.
        message: String,
        /// Name of the server that raised the error; omitted from the display
        /// form when `None` or empty.
        server: Option<String>,
        /// Procedure in which the error occurred; omitted from the display form
        /// when `None` or empty.
        procedure: Option<String>,
        /// Line number reported with the error; omitted from the display form
        /// when `0`.
        line: u32,
    },

    /// Invalid configuration; the string describes the problem.
    #[error("configuration error: {0}")]
    Config(String),

    /// The TCP connection to `host:port` timed out.
    #[error("TCP connection timed out connecting to {host}:{port}")]
    ConnectTimeout {
        /// Target host.
        host: String,
        /// Target port.
        port: u16,
    },

    /// The TLS handshake with `host:port` timed out.
    #[error("TLS handshake timed out with {host}:{port}")]
    TlsTimeout {
        /// Target host.
        host: String,
        /// Target port.
        port: u16,
    },

    /// Login to `host:port` timed out.
    #[error("login timed out for {host}:{port}")]
    LoginTimeout {
        /// Target host.
        host: String,
        /// Target port.
        port: u16,
    },

    /// A command timed out.
    #[error("command timed out")]
    CommandTimeout,

    /// The server requires the client to be routed to `host:port`.
    #[error("routing required to {host}:{port}")]
    Routing {
        /// Host to route to.
        host: String,
        /// Port to route to.
        port: u16,
    },

    /// More redirects were encountered than the permitted maximum `max`.
    #[error("too many redirects (max {max})")]
    TooManyRedirects {
        /// The redirect limit that was exceeded.
        max: u8,
    },

    /// An I/O error, held in a cloneable [`SharedIoError`].
    /// Created from `std::io::Error` via `From`.
    #[error("IO error: {0}")]
    Io(#[source] SharedIoError),

    /// An identifier was rejected as invalid; the string describes it.
    #[error("invalid identifier: {0}")]
    InvalidIdentifier(String),

    /// The connection pool had no connection available.
    #[error("connection pool exhausted")]
    PoolExhausted,

    /// Cancelling a query failed; the string describes why.
    #[error("query cancellation failed: {0}")]
    Cancel(String),

    /// The query was cancelled.
    #[error("query cancelled")]
    Cancelled,
}
143
/// A [`std::io::Error`] held behind an [`Arc`] so it can be cloned.
///
/// `std::io::Error` does not implement `Clone`. Sharing it through an `Arc`
/// makes cloning this wrapper a reference-count bump. Use [`SharedIoError::kind`]
/// or [`SharedIoError::get_ref`] to inspect the wrapped error.
#[derive(Debug, Clone)]
pub struct SharedIoError(Arc<std::io::Error>);

impl SharedIoError {
    /// Returns the [`std::io::ErrorKind`] of the wrapped error.
    #[must_use]
    pub fn kind(&self) -> std::io::ErrorKind {
        self.0.kind()
    }

    /// Returns a reference to the wrapped [`std::io::Error`].
    ///
    /// Useful for details such as `raw_os_error()` that this wrapper does not
    /// expose directly.
    #[must_use]
    pub fn get_ref(&self) -> &std::io::Error {
        &self.0
    }
}
154
155impl std::fmt::Display for SharedIoError {
156 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
157 self.0.fmt(f)
158 }
159}
160
161impl std::error::Error for SharedIoError {
162 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
163 self.0.source()
164 }
165}
166
167impl From<std::io::Error> for Error {
168 fn from(e: std::io::Error) -> Self {
169 Error::Io(SharedIoError(Arc::new(e)))
170 }
171}
172
173impl Error {
174 #[must_use]
190 pub fn is_transient(&self) -> bool {
191 match self {
192 Self::ConnectTimeout { .. }
193 | Self::TlsTimeout { .. }
194 | Self::LoginTimeout { .. }
195 | Self::CommandTimeout
196 | Self::ConnectionClosed
197 | Self::Connection(_)
198 | Self::Routing { .. }
199 | Self::PoolExhausted
200 | Self::Io(_) => true,
201 Self::Server { number, .. } => Self::is_transient_server_error(*number),
202 _ => false,
203 }
204 }
205
206 #[must_use]
232 pub fn is_transient_server_error(number: i32) -> bool {
233 matches!(
234 number,
235 1205 | -2 | 10928 | 10929 | 40197 | 40501 | 40613 | 49918 | 49919 | 49920 | 4060 | 18456 )
248 }
249
250 #[must_use]
263 pub fn is_terminal(&self) -> bool {
264 match self {
265 Self::Config(_)
266 | Self::InvalidIdentifier(_)
267 | Self::Protocol(_)
268 | Self::ProtocolError(_)
269 | Self::Tls(_)
270 | Self::Authentication(_)
271 | Self::Cancel(_) => true,
272 Self::Server { number, .. } => Self::is_terminal_server_error(*number),
273 _ => false,
274 }
275 }
276
277 #[must_use]
281 pub fn is_terminal_server_error(number: i32) -> bool {
282 matches!(
283 number,
284 102 | 207 | 208 | 547 | 2627 | 2601 )
291 }
292
293 #[must_use]
298 pub fn is_protocol_error(&self) -> bool {
299 matches!(self, Self::Protocol(_) | Self::ProtocolError(_))
300 }
301
302 #[must_use]
307 pub fn is_tls_error(&self) -> bool {
308 matches!(self, Self::Tls(_) | Self::TlsTimeout { .. })
309 }
310
311 #[must_use]
313 pub fn is_authentication_error(&self) -> bool {
314 matches!(self, Self::Authentication(_))
315 }
316
317 #[must_use]
322 pub fn is_config_error(&self) -> bool {
323 matches!(self, Self::Config(_))
324 }
325
326 #[must_use]
328 pub fn is_server_error(&self, number: i32) -> bool {
329 matches!(self, Self::Server { number: n, .. } if *n == number)
330 }
331
332 #[must_use]
340 pub fn class(&self) -> Option<u8> {
341 match self {
342 Self::Server { class, .. } => Some(*class),
343 _ => None,
344 }
345 }
346
347 #[must_use]
349 pub fn severity(&self) -> Option<u8> {
350 self.class()
351 }
352}
353
/// Renders the optional location suffix of an `Error::Server` message.
///
/// Produces `" [server: …, procedure: …, line: …]"` and includes only the parts
/// that carry information:
/// - `server` and `procedure` when present and non-empty;
/// - `line` when non-zero.
///
/// Returns an empty string when no part applies. The parameters are references
/// because the `#[error(...)]` format on `Error::Server` passes the variant's
/// fields (`.server`, `.procedure`, `.line`) borrowed.
fn format_server_location(
    server: &Option<String>,
    procedure: &Option<String>,
    line: &u32,
) -> String {
    let server_part = server
        .as_deref()
        .filter(|name| !name.is_empty())
        .map(|srv| format!("server: {srv}"));
    let procedure_part = procedure
        .as_deref()
        .filter(|name| !name.is_empty())
        .map(|proc| format!("procedure: {proc}"));
    let line_part = if *line > 0 {
        Some(format!("line: {line}"))
    } else {
        None
    };

    let parts: Vec<String> = server_part
        .into_iter()
        .chain(procedure_part)
        .chain(line_part)
        .collect();

    match parts.as_slice() {
        [] => String::new(),
        _ => format!(" [{}]", parts.join(", ")),
    }
}
380
381pub type Result<T> = std::result::Result<T, Error>;
383
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Server error with the given number; every other field is a fixed
    /// placeholder (class 16, state 1, line 1).
    fn make_server_error(number: i32) -> Error {
        Error::Server {
            number,
            class: 16,
            state: 1,
            message: "Test error".to_string(),
            server: None,
            procedure: None,
            line: 1,
        }
    }

    /// `ConnectTimeout` for `test:1433`, used wherever a non-server error is needed.
    fn connect_timeout() -> Error {
        Error::ConnectTimeout {
            host: "test".into(),
            port: 1433,
        }
    }

    #[test]
    fn test_is_transient_connection_errors() {
        let transient = [
            connect_timeout(),
            Error::TlsTimeout {
                host: "test".into(),
                port: 1433,
            },
            Error::LoginTimeout {
                host: "test".into(),
                port: 1433,
            },
            Error::CommandTimeout,
            Error::ConnectionClosed,
            Error::PoolExhausted,
            Error::Routing {
                host: "test".into(),
                port: 1433,
            },
        ];
        for err in &transient {
            assert!(err.is_transient());
        }
    }

    #[test]
    fn test_is_transient_io_error() {
        let reset = std::io::Error::new(std::io::ErrorKind::ConnectionReset, "reset");
        let err = Error::Io(SharedIoError(Arc::new(reset)));
        assert!(err.is_transient());
    }

    #[test]
    fn test_is_transient_server_errors_deadlock() {
        assert!(make_server_error(1205).is_transient());
    }

    #[test]
    fn test_is_transient_server_errors_timeout() {
        assert!(make_server_error(-2).is_transient());
    }

    #[test]
    fn test_is_transient_server_errors_azure() {
        for number in [10928, 10929, 40197, 40501, 40613, 49918, 49919, 49920] {
            assert!(make_server_error(number).is_transient());
        }
    }

    #[test]
    fn test_is_transient_server_errors_other() {
        for number in [4060, 18456] {
            assert!(make_server_error(number).is_transient());
        }
    }

    #[test]
    fn test_is_not_transient() {
        let non_transient = [
            Error::Config("bad config".into()),
            Error::Query("syntax error".into()),
            Error::InvalidIdentifier("bad id".into()),
            make_server_error(102),
        ];
        for err in &non_transient {
            assert!(!err.is_transient());
        }
    }

    #[test]
    fn test_is_terminal_server_errors() {
        for number in [102, 207, 208, 547, 2627, 2601] {
            assert!(make_server_error(number).is_terminal());
        }
    }

    #[test]
    fn test_is_terminal_config_errors() {
        assert!(Error::Config("bad config".into()).is_terminal());
        assert!(Error::InvalidIdentifier("bad id".into()).is_terminal());
    }

    #[test]
    fn test_is_not_terminal() {
        let non_terminal = [
            connect_timeout(),
            make_server_error(1205),
            make_server_error(40501),
        ];
        for err in &non_terminal {
            assert!(!err.is_terminal());
        }
    }

    #[test]
    fn test_transient_server_error_static() {
        assert!(Error::is_transient_server_error(1205));
        assert!(Error::is_transient_server_error(40501));
        assert!(!Error::is_transient_server_error(102));
    }

    #[test]
    fn test_terminal_server_error_static() {
        assert!(Error::is_terminal_server_error(102));
        assert!(Error::is_terminal_server_error(2627));
        assert!(!Error::is_terminal_server_error(1205));
    }

    #[test]
    fn test_error_class() {
        let server_err = make_server_error(102);
        assert_eq!(server_err.class(), Some(16));
        assert_eq!(server_err.severity(), Some(16));
        assert_eq!(connect_timeout().class(), None);
    }

    #[test]
    fn test_is_server_error() {
        let server_err = make_server_error(102);
        assert!(server_err.is_server_error(102));
        assert!(!server_err.is_server_error(103));
        assert!(!connect_timeout().is_server_error(102));
    }
}