use std::sync::Arc;
use thiserror::Error;
/// Errors produced by this SQL Server client.
///
/// The enum is `#[non_exhaustive]`, so code outside this crate must include a
/// wildcard arm when matching on it. [`Error::is_transient`] and
/// [`Error::is_terminal`] classify errors for retry decisions.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum Error {
    /// Establishing or using the connection failed; carries a description.
    #[error("connection failed: {0}")]
    Connection(String),

    /// The connection was closed.
    #[error("connection closed")]
    ConnectionClosed,

    /// Authentication failed (wraps the `mssql_auth` error).
    #[error("authentication failed: {0}")]
    Authentication(#[from] mssql_auth::AuthError),

    /// A TLS error, wrapping the `mssql_tls` error when the `tls` feature is on.
    #[cfg(feature = "tls")]
    #[error("TLS error: {0}")]
    Tls(#[from] mssql_tls::TlsError),

    /// A TLS error described by a message (the `tls` feature is off).
    #[cfg(not(feature = "tls"))]
    #[error("TLS error: {0}")]
    Tls(String),

    /// An error reported by the `tds_protocol` layer.
    #[error("protocol error: {0}")]
    ProtocolError(#[from] tds_protocol::ProtocolError),

    /// A protocol error described by a message.
    #[error("protocol error: {0}")]
    Protocol(String),

    /// An error from the `mssql_codec` layer.
    ///
    /// Codec errors converted through `From<CodecError>` that report an
    /// oversized message become [`Error::ResponseTooLarge`] instead.
    // `#[source]` keeps the codec error in the `source()` chain; it was lost
    // when `#[from]` was replaced by the hand-written `From` impl.
    #[error("codec error: {0}")]
    Codec(#[source] mssql_codec::CodecError),

    /// A response of `size` bytes exceeded the configured `limit`.
    #[error(
        "response too large: {size} bytes exceeds the configured {limit}-byte cap; \
         paginate, narrow the SELECT, or raise Config::max_response_size"
    )]
    ResponseTooLarge {
        /// Size of the offending response, in bytes.
        size: usize,
        /// Configured maximum response size, in bytes.
        limit: usize,
    },

    /// A type error from the `mssql_types` crate.
    #[error("type error: {0}")]
    Type(#[from] mssql_types::TypeError),

    /// A query-level error described by a message.
    #[error("query error: {0}")]
    Query(String),

    /// An error reported by the server.
    ///
    /// `Display` appends a ` [server: …, procedure: …, line: …]` suffix for
    /// whichever location fields are present.
    #[error("server error {number} (severity {class}, state {state}): {message}{}", format_server_location(.server, .procedure, .line))]
    Server {
        /// Server error number (see [`Error::is_server_error`]).
        number: i32,
        /// Severity class (also returned by [`Error::severity`]).
        class: u8,
        /// Error state.
        state: u8,
        /// Error message text.
        message: String,
        /// Server name, if reported; omitted from `Display` when `None` or empty.
        server: Option<String>,
        /// Procedure name, if reported; omitted from `Display` when `None` or empty.
        procedure: Option<String>,
        /// Line number; `0` is omitted from `Display`.
        line: u32,
    },

    /// Invalid client configuration.
    #[error("configuration error: {0}")]
    Config(String),

    /// The TCP connection to `host:port` timed out.
    #[error("TCP connection timed out connecting to {host}:{port}")]
    ConnectTimeout {
        /// Target host.
        host: String,
        /// Target port.
        port: u16,
    },

    /// The TLS handshake with `host:port` timed out.
    #[error("TLS handshake timed out with {host}:{port}")]
    TlsTimeout {
        /// Target host.
        host: String,
        /// Target port.
        port: u16,
    },

    /// Login to `host:port` timed out.
    #[error("login timed out for {host}:{port}")]
    LoginTimeout {
        /// Target host.
        host: String,
        /// Target port.
        port: u16,
    },

    /// A command timed out.
    #[error("command timed out")]
    CommandTimeout,

    /// Routing to another endpoint (`host:port`) is required.
    #[error("routing required to {host}:{port}")]
    Routing {
        /// Host to route to.
        host: String,
        /// Port to route to.
        port: u16,
    },

    /// More than `max` redirects were encountered.
    #[error("too many redirects (max {max})")]
    TooManyRedirects {
        /// The redirect limit that was exceeded.
        max: u8,
    },

    /// An I/O error, shared behind an `Arc` so it can be cloned.
    #[error("IO error: {0}")]
    Io(#[source] SharedIoError),

    /// An identifier was rejected as invalid.
    #[error("invalid identifier: {0}")]
    InvalidIdentifier(String),

    /// No connection was available from the pool.
    #[error("connection pool exhausted")]
    PoolExhausted,

    /// Cancelling a query failed; carries a description.
    #[error("query cancellation failed: {0}")]
    Cancel(String),

    /// The query was cancelled.
    #[error("query cancelled")]
    Cancelled,

    /// Resolving a named instance through SQL Browser failed.
    #[error("SQL Browser resolution failed for instance '{instance}': {reason}")]
    BrowserResolution {
        /// The instance name being resolved.
        instance: String,
        /// Why resolution failed.
        reason: String,
    },

    /// A FILESTREAM error (Windows with the `filestream` feature only).
    #[cfg(all(windows, feature = "filestream"))]
    #[error("FILESTREAM error: {0}")]
    FileStream(String),

    /// An encryption error (`always-encrypted` feature only).
    #[cfg(feature = "always-encrypted")]
    #[error("encryption error: {0}")]
    Encryption(String),
}
/// A clonable, reference-counted [`std::io::Error`].
///
/// `std::io::Error` does not implement `Clone`; storing it behind an [`Arc`]
/// lets this wrapper (and errors containing it) be cloned cheaply. It formats
/// exactly like the wrapped error and forwards its `source()`.
#[derive(Debug, Clone)]
pub struct SharedIoError(Arc<std::io::Error>);

impl SharedIoError {
    /// Returns the [`std::io::ErrorKind`] of the wrapped I/O error.
    ///
    /// The inner error is otherwise unreachable (private field, and `source()`
    /// forwards past it), so this is how callers inspect the failure category.
    #[must_use]
    pub fn kind(&self) -> std::io::ErrorKind {
        self.0.kind()
    }

    /// Borrows the wrapped [`std::io::Error`].
    #[must_use]
    pub fn as_io_error(&self) -> &std::io::Error {
        &self.0
    }
}

impl std::fmt::Display for SharedIoError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Fully qualified so the call does not depend on which of
        // `Display`/`Debug` happens to be in scope.
        std::fmt::Display::fmt(&*self.0, f)
    }
}

impl std::error::Error for SharedIoError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        // Transparent wrapper: our Display already equals the io error's, so
        // report the io error's own source rather than the io error itself.
        std::error::Error::source(&*self.0)
    }
}
/// Converts a codec error into an [`Error`].
///
/// `MessageTooLarge` is promoted to [`Error::ResponseTooLarge`] with the same
/// `size` and `limit`, so callers see the actionable message. Every other codec
/// error is wrapped unchanged in [`Error::Codec`].
impl From<mssql_codec::CodecError> for Error {
    fn from(codec_err: mssql_codec::CodecError) -> Self {
        if let mssql_codec::CodecError::MessageTooLarge { size, limit } = codec_err {
            return Self::ResponseTooLarge { size, limit };
        }
        Self::Codec(codec_err)
    }
}
/// Wraps a [`std::io::Error`] in [`Error::Io`], sharing it behind an [`Arc`].
impl From<std::io::Error> for Error {
    fn from(io_err: std::io::Error) -> Self {
        let shared = SharedIoError(Arc::new(io_err));
        Self::Io(shared)
    }
}
/// Converts an encryption error into [`Error::Encryption`].
///
/// Only the error's `Display` text is kept; the original value is not retained
/// as a `source()`.
#[cfg(feature = "always-encrypted")]
impl From<mssql_auth::EncryptionError> for Error {
    fn from(enc_err: mssql_auth::EncryptionError) -> Self {
        let message = enc_err.to_string();
        Self::Encryption(message)
    }
}
impl Error {
    /// Returns `true` if retrying the failed operation may succeed.
    ///
    /// Timeouts, closed or failed connections, routing requests, pool
    /// exhaustion and I/O errors are transient. For [`Error::Server`] the
    /// decision is delegated to [`Error::is_transient_server_error`].
    // The match is exhaustive on purpose (no `_` arm): adding a variant must
    // force an explicit retry classification instead of silently being `false`.
    #[must_use]
    pub fn is_transient(&self) -> bool {
        match self {
            Self::ConnectTimeout { .. }
            | Self::TlsTimeout { .. }
            | Self::LoginTimeout { .. }
            | Self::CommandTimeout
            | Self::ConnectionClosed
            | Self::Connection(_)
            | Self::Routing { .. }
            | Self::PoolExhausted
            | Self::Io(_) => true,
            Self::Server { number, .. } => Self::is_transient_server_error(*number),
            Self::Authentication(_)
            | Self::Tls(_)
            | Self::ProtocolError(_)
            | Self::Protocol(_)
            | Self::Codec(_)
            | Self::ResponseTooLarge { .. }
            | Self::Type(_)
            | Self::Query(_)
            | Self::Config(_)
            | Self::TooManyRedirects { .. }
            | Self::InvalidIdentifier(_)
            | Self::Cancel(_)
            | Self::Cancelled
            | Self::BrowserResolution { .. } => false,
            #[cfg(all(windows, feature = "filestream"))]
            Self::FileStream(_) => false,
            #[cfg(feature = "always-encrypted")]
            Self::Encryption(_) => false,
        }
    }

    /// Returns `true` if a server error `number` is considered transient.
    #[must_use]
    pub fn is_transient_server_error(number: i32) -> bool {
        matches!(
            number,
            // 1205: deadlock victim; -2: timeout.
            1205 | -2
            // Azure SQL resource-limit / availability errors.
            | 10928 | 10929 | 40197 | 40501 | 40613 | 49918 | 49919 | 49920
            // 4060 (cannot open database) and 18456 (login failed) are retried
            // as well. NOTE(review): presumably to ride out failovers — confirm.
            | 4060 | 18456
        )
    }

    /// Returns `true` if retrying cannot help (configuration, identifier,
    /// protocol, TLS, authentication and cancellation failures).
    ///
    /// For [`Error::Server`] the decision is delegated to
    /// [`Error::is_terminal_server_error`]. An error may be neither transient
    /// nor terminal.
    // Exhaustive for the same reason as `is_transient`.
    #[must_use]
    pub fn is_terminal(&self) -> bool {
        match self {
            Self::Config(_)
            | Self::InvalidIdentifier(_)
            | Self::Protocol(_)
            | Self::ProtocolError(_)
            | Self::Tls(_)
            | Self::Authentication(_)
            | Self::Cancel(_) => true,
            Self::Server { number, .. } => Self::is_terminal_server_error(*number),
            Self::Connection(_)
            | Self::ConnectionClosed
            | Self::Codec(_)
            | Self::ResponseTooLarge { .. }
            | Self::Type(_)
            | Self::Query(_)
            | Self::ConnectTimeout { .. }
            | Self::TlsTimeout { .. }
            | Self::LoginTimeout { .. }
            | Self::CommandTimeout
            | Self::Routing { .. }
            | Self::TooManyRedirects { .. }
            | Self::Io(_)
            | Self::PoolExhausted
            | Self::Cancelled
            | Self::BrowserResolution { .. } => false,
            #[cfg(all(windows, feature = "filestream"))]
            Self::FileStream(_) => false,
            #[cfg(feature = "always-encrypted")]
            Self::Encryption(_) => false,
        }
    }

    /// Returns `true` if a server error `number` is considered terminal.
    #[must_use]
    pub fn is_terminal_server_error(number: i32) -> bool {
        matches!(
            number,
            // Syntax error, invalid column name, invalid object name.
            102 | 207 | 208
            // Constraint conflict, unique-constraint / unique-index violations.
            | 547 | 2627 | 2601
        )
    }

    /// Returns `true` for [`Error::Protocol`] and [`Error::ProtocolError`].
    #[must_use]
    pub fn is_protocol_error(&self) -> bool {
        matches!(self, Self::Protocol(_) | Self::ProtocolError(_))
    }

    /// Returns `true` for [`Error::Tls`] and [`Error::TlsTimeout`].
    #[must_use]
    pub fn is_tls_error(&self) -> bool {
        matches!(self, Self::Tls(_) | Self::TlsTimeout { .. })
    }

    /// Returns `true` for [`Error::Authentication`].
    #[must_use]
    pub fn is_authentication_error(&self) -> bool {
        matches!(self, Self::Authentication(_))
    }

    /// Returns `true` for [`Error::Config`].
    #[must_use]
    pub fn is_config_error(&self) -> bool {
        matches!(self, Self::Config(_))
    }

    /// Returns `true` if this is an [`Error::Server`] with exactly `number`.
    #[must_use]
    pub fn is_server_error(&self, number: i32) -> bool {
        matches!(self, Self::Server { number: n, .. } if *n == number)
    }

    /// Returns the severity class of an [`Error::Server`], `None` otherwise.
    #[must_use]
    pub fn class(&self) -> Option<u8> {
        match self {
            Self::Server { class, .. } => Some(*class),
            _ => None,
        }
    }

    /// Alias for [`Error::class`].
    #[must_use]
    pub fn severity(&self) -> Option<u8> {
        self.class()
    }
}
/// Builds the optional location suffix used by `Error::Server`'s `Display`.
///
/// Each present piece is rendered as `server: …`, `procedure: …` and
/// `line: …`, in that order. Empty names and a zero line are skipped. The
/// pieces are joined as ` [a, b, c]`. When nothing is present the result is
/// the empty string.
///
/// Takes references because the thiserror `#[error]` attribute passes field
/// references (`.server`, `.procedure`, `.line`).
fn format_server_location(
    server: &Option<String>,
    procedure: &Option<String>,
    line: &u32,
) -> String {
    let server_part = server
        .as_deref()
        .filter(|name| !name.is_empty())
        .map(|name| format!("server: {name}"));
    let procedure_part = procedure
        .as_deref()
        .filter(|name| !name.is_empty())
        .map(|name| format!("procedure: {name}"));
    let line_part = (*line > 0).then(|| format!("line: {line}"));

    let pieces: Vec<String> = server_part
        .into_iter()
        .chain(procedure_part)
        .chain(line_part)
        .collect();

    if pieces.is_empty() {
        return String::new();
    }
    format!(" [{}]", pieces.join(", "))
}
pub type Result<T> = std::result::Result<T, Error>;
#[cfg(test)]
#[allow(clippy::unwrap_used)] // tests may unwrap freely
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Builds an `Error::Server` with the given number and fixed values for
    /// every other field (class 16, state 1, no server/procedure, line 1).
    fn make_server_error(number: i32) -> Error {
        Error::Server {
            number,
            class: 16,
            state: 1,
            message: "Test error".to_string(),
            server: None,
            procedure: None,
            line: 1,
        }
    }

    #[test]
    fn test_is_transient_connection_errors() {
        assert!(
            Error::ConnectTimeout {
                host: "test".into(),
                port: 1433
            }
            .is_transient()
        );
        assert!(
            Error::TlsTimeout {
                host: "test".into(),
                port: 1433
            }
            .is_transient()
        );
        assert!(
            Error::LoginTimeout {
                host: "test".into(),
                port: 1433
            }
            .is_transient()
        );
        assert!(Error::CommandTimeout.is_transient());
        assert!(Error::ConnectionClosed.is_transient());
        assert!(Error::Connection("refused".into()).is_transient());
        assert!(Error::PoolExhausted.is_transient());
        assert!(
            Error::Routing {
                host: "test".into(),
                port: 1433,
            }
            .is_transient()
        );
    }

    #[test]
    fn test_is_transient_io_error() {
        let io_err = std::io::Error::new(std::io::ErrorKind::ConnectionReset, "reset");
        assert!(Error::Io(SharedIoError(Arc::new(io_err))).is_transient());
    }

    #[test]
    fn test_is_transient_server_errors_deadlock() {
        assert!(make_server_error(1205).is_transient());
    }

    #[test]
    fn test_is_transient_server_errors_timeout() {
        assert!(make_server_error(-2).is_transient());
    }

    #[test]
    fn test_is_transient_server_errors_azure() {
        for code in [10928, 10929, 40197, 40501, 40613, 49918, 49919, 49920] {
            assert!(
                make_server_error(code).is_transient(),
                "server error {} should be transient",
                code
            );
        }
    }

    #[test]
    fn test_is_transient_server_errors_other() {
        for code in [4060, 18456] {
            assert!(
                make_server_error(code).is_transient(),
                "server error {} should be transient",
                code
            );
        }
    }

    #[test]
    fn test_is_not_transient() {
        assert!(!Error::Config("bad config".into()).is_transient());
        assert!(!Error::Query("syntax error".into()).is_transient());
        assert!(!Error::InvalidIdentifier("bad id".into()).is_transient());
        assert!(!make_server_error(102).is_transient());
    }

    #[test]
    fn test_is_terminal_server_errors() {
        for code in [102, 207, 208, 547, 2627, 2601] {
            assert!(
                make_server_error(code).is_terminal(),
                "server error {} should be terminal",
                code
            );
        }
    }

    #[test]
    fn test_is_terminal_config_errors() {
        assert!(Error::Config("bad config".into()).is_terminal());
        assert!(Error::InvalidIdentifier("bad id".into()).is_terminal());
    }

    #[test]
    fn test_is_not_terminal() {
        assert!(
            !Error::ConnectTimeout {
                host: "test".into(),
                port: 1433
            }
            .is_terminal()
        );
        assert!(!make_server_error(1205).is_terminal());
        assert!(!make_server_error(40501).is_terminal());
    }

    #[test]
    fn test_transient_server_error_static() {
        assert!(Error::is_transient_server_error(1205));
        assert!(Error::is_transient_server_error(40501));
        assert!(!Error::is_transient_server_error(102));
    }

    #[test]
    fn test_terminal_server_error_static() {
        assert!(Error::is_terminal_server_error(102));
        assert!(Error::is_terminal_server_error(2627));
        assert!(!Error::is_terminal_server_error(1205));
    }

    #[test]
    fn test_error_class() {
        let err = make_server_error(102);
        assert_eq!(err.class(), Some(16));
        assert_eq!(err.severity(), Some(16));
        assert_eq!(
            Error::ConnectTimeout {
                host: "test".into(),
                port: 1433
            }
            .class(),
            None
        );
    }

    #[test]
    fn test_is_server_error() {
        let err = make_server_error(102);
        assert!(err.is_server_error(102));
        assert!(!err.is_server_error(103));
        assert!(
            !Error::ConnectTimeout {
                host: "test".into(),
                port: 1433
            }
            .is_server_error(102)
        );
    }

    #[test]
    fn test_server_error_display_includes_location() {
        assert_eq!(
            make_server_error(102).to_string(),
            "server error 102 (severity 16, state 1): Test error [line: 1]"
        );

        let err = Error::Server {
            number: 50000,
            class: 16,
            state: 2,
            message: "boom".to_string(),
            server: Some("srv".to_string()),
            procedure: Some("sp_test".to_string()),
            line: 7,
        };
        assert_eq!(
            err.to_string(),
            "server error 50000 (severity 16, state 2): boom [server: srv, procedure: sp_test, line: 7]"
        );
    }

    #[test]
    fn test_format_server_location_omits_missing_parts() {
        assert_eq!(format_server_location(&None, &None, &0), "");
        assert_eq!(
            format_server_location(&Some(String::new()), &Some(String::new()), &0),
            ""
        );
        assert_eq!(
            format_server_location(&None, &Some("p".to_string()), &0),
            " [procedure: p]"
        );
    }

    #[test]
    fn test_io_error_conversion() {
        let err: Error =
            std::io::Error::new(std::io::ErrorKind::ConnectionReset, "reset").into();
        assert!(matches!(err, Error::Io(_)));
        assert_eq!(err.to_string(), "IO error: reset");
        assert!(std::error::Error::source(&err).is_some());
        assert!(err.is_transient());
    }

    #[test]
    fn test_codec_message_too_large_maps_to_response_too_large() {
        let err: Error = mssql_codec::CodecError::MessageTooLarge { size: 10, limit: 5 }.into();
        assert!(matches!(
            err,
            Error::ResponseTooLarge {
                size: 10,
                limit: 5
            }
        ));
        assert!(!err.is_transient());
    }

    #[test]
    fn test_category_predicates() {
        assert!(Error::Protocol("bad packet".into()).is_protocol_error());
        assert!(
            Error::TlsTimeout {
                host: "test".into(),
                port: 1433
            }
            .is_tls_error()
        );
        assert!(Error::Config("bad".into()).is_config_error());

        let query = Error::Query("q".into());
        assert!(!query.is_protocol_error());
        assert!(!query.is_tls_error());
        assert!(!query.is_config_error());
        assert!(!query.is_authentication_error());
    }
}