use std::fmt;
#[derive(Debug)]
pub enum MsSqlError {
Connection(ConnectionError),
Lsn(LsnError),
PrimaryKey(PrimaryKeyError),
InvalidIdentifier(String),
Query(String),
Config(String),
Other(String),
}
impl fmt::Display for MsSqlError {
    /// Renders a category prefix followed by the wrapped detail.
    /// `Other` has an empty prefix, so it prints as the bare message.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (prefix, detail): (&str, &dyn fmt::Display) = match self {
            Self::Connection(inner) => ("Connection error: ", inner),
            Self::Lsn(inner) => ("LSN error: ", inner),
            Self::PrimaryKey(inner) => ("Primary key error: ", inner),
            Self::InvalidIdentifier(text) => ("Invalid SQL identifier: ", text),
            Self::Query(text) => ("Query error: ", text),
            Self::Config(text) => ("Configuration error: ", text),
            Self::Other(text) => ("", text),
        };
        write!(f, "{prefix}{detail}")
    }
}
impl std::error::Error for MsSqlError {}
/// Connection-level failures, each carrying a human-readable detail message.
#[derive(Debug)]
pub enum ConnectionError {
    /// A connection attempt failed for a reason not covered by another variant.
    Failed(String),
    /// An existing connection was lost.
    Lost(String),
    /// The connection timed out.
    Timeout(String),
    /// Authentication failed.
    AuthenticationFailed(String),
    /// The network was unreachable.
    NetworkUnreachable(String),
    /// The connection was refused.
    Refused(String),
    /// The connection is considered unhealthy after a run of errors.
    /// NOTE(review): the producer of this variant is not in this file.
    Unhealthy {
        /// How many errors occurred in a row.
        consecutive_errors: u32,
        /// Message of the most recent error in that run.
        last_error: String,
    },
}
impl fmt::Display for ConnectionError {
    /// Formats as `<headline>: <detail>`; `Unhealthy` instead reports the
    /// consecutive-error count alongside the last error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (headline, detail) = match self {
            Self::Failed(detail) => ("Failed to connect", detail),
            Self::Lost(detail) => ("Connection lost", detail),
            Self::Timeout(detail) => ("Connection timed out", detail),
            Self::AuthenticationFailed(detail) => ("Authentication failed", detail),
            Self::NetworkUnreachable(detail) => ("Network unreachable", detail),
            Self::Refused(detail) => ("Connection refused", detail),
            Self::Unhealthy {
                consecutive_errors,
                last_error,
            } => {
                return write!(
                    f,
                    "Connection unhealthy after {consecutive_errors} consecutive errors: {last_error}"
                );
            }
        };
        write!(f, "{headline}: {detail}")
    }
}
/// LSN-related failures, each carrying a human-readable detail message.
///
/// `Invalid` and `OutOfRange` are the variants `MsSqlError::is_recoverable_lsn_error`
/// treats as recoverable.
#[derive(Debug)]
pub enum LsnError {
    /// The LSN is invalid.
    Invalid(String),
    /// The LSN is out of range.
    OutOfRange(String),
    /// The LSN could not be parsed.
    ParseFailed(String),
    /// No LSN is available.
    NotAvailable(String),
}
impl fmt::Display for LsnError {
    /// Formats as `<summary>: <detail>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (summary, detail) = match self {
            Self::Invalid(detail) => ("Invalid LSN", detail),
            Self::OutOfRange(detail) => ("LSN out of range", detail),
            Self::ParseFailed(detail) => ("Failed to parse LSN", detail),
            Self::NotAvailable(detail) => ("LSN not available", detail),
        };
        write!(f, "{summary}: {detail}")
    }
}
/// Failures resolving the primary key of a row.
#[derive(Debug)]
pub enum PrimaryKeyError {
    /// No primary key is configured for `table`.
    NotConfigured { table: String },
    /// The configured key `column` was not found in a row of `table`.
    ColumnNotFound {
        table: String,
        column: String,
    },
    /// All primary key values (for the listed `columns`) are NULL in a row of `table`.
    AllNull {
        table: String,
        columns: Vec<String>,
    },
}
impl fmt::Display for PrimaryKeyError {
    /// Produces a message naming the affected table, followed by a hint about
    /// the `table_keys` configuration or the consequence of the failure.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::AllNull { table, columns } => write!(
                f,
                "All primary key values are NULL for table '{table}' (columns: {columns:?}). \
                 Cannot generate a stable element ID."
            ),
            Self::ColumnNotFound { table, column } => write!(
                f,
                "Primary key column '{column}' not found in row for table '{table}'. \
                 Check that the column name in 'table_keys' matches the actual column name."
            ),
            Self::NotConfigured { table } => write!(
                f,
                "No primary key configured for table '{table}'. \
                 Add a 'table_keys' configuration entry to specify the primary key columns."
            ),
        }
    }
}
impl MsSqlError {
pub fn is_connection_error(&self) -> bool {
matches!(self, Self::Connection(_))
}
pub fn is_recoverable_lsn_error(&self) -> bool {
matches!(
self,
Self::Lsn(LsnError::Invalid(_) | LsnError::OutOfRange(_))
)
}
pub fn from_connection_error(error: impl ToString) -> Self {
let error_str = error.to_string().to_lowercase();
if error_str.contains("timed out") || error_str.contains("timeout") {
Self::Connection(ConnectionError::Timeout(error.to_string()))
} else if error_str.contains("refused") {
Self::Connection(ConnectionError::Refused(error.to_string()))
} else if error_str.contains("unreachable") {
Self::Connection(ConnectionError::NetworkUnreachable(error.to_string()))
} else if error_str.contains("authentication") || error_str.contains("login") {
Self::Connection(ConnectionError::AuthenticationFailed(error.to_string()))
} else if error_str.contains("reset")
|| error_str.contains("broken pipe")
|| error_str.contains("closed")
|| error_str.contains("eof")
{
Self::Connection(ConnectionError::Lost(error.to_string()))
} else {
Self::Connection(ConnectionError::Failed(error.to_string()))
}
}
pub fn classify(error: &anyhow::Error) -> Option<MsSqlErrorKind> {
if let Some(mssql_err) = error.downcast_ref::<MsSqlError>() {
return Some(match mssql_err {
MsSqlError::Connection(_) => MsSqlErrorKind::Connection,
MsSqlError::Lsn(LsnError::Invalid(_) | LsnError::OutOfRange(_)) => {
MsSqlErrorKind::RecoverableLsn
}
MsSqlError::Lsn(_) => MsSqlErrorKind::Other,
MsSqlError::PrimaryKey(_) => MsSqlErrorKind::Other,
MsSqlError::InvalidIdentifier(_) => MsSqlErrorKind::Other,
MsSqlError::Query(_) => MsSqlErrorKind::Other,
MsSqlError::Config(_) => MsSqlErrorKind::Other,
MsSqlError::Other(_) => MsSqlErrorKind::Other,
});
}
let error_str = error.to_string().to_lowercase();
if error_str.contains("connection")
|| error_str.contains("broken pipe")
|| error_str.contains("reset by peer")
|| error_str.contains("timed out")
|| error_str.contains("network")
|| error_str.contains("socket")
|| error_str.contains("eof")
|| error_str.contains("closed")
|| error_str.contains("refused")
|| error_str.contains("unreachable")
{
return Some(MsSqlErrorKind::Connection);
}
if error_str.contains("lsn")
&& (error_str.contains("invalid") || error_str.contains("out of range"))
{
return Some(MsSqlErrorKind::RecoverableLsn);
}
None
}
}
/// Coarse category that `MsSqlError::classify` assigns to an error.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MsSqlErrorKind {
    /// A connection-level failure, either typed or recognised from the message text.
    Connection,
    /// An `Invalid`/`OutOfRange` LSN error, either typed or recognised from the message text.
    RecoverableLsn,
    /// A typed `MsSqlError` that is neither of the above.
    Other,
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_connection_error_display() {
        let error = MsSqlError::Connection(ConnectionError::Lost("connection reset".to_string()));
        let rendered = error.to_string();
        assert!(rendered.contains("Connection lost"), "unexpected message: {rendered}");
        assert!(error.is_connection_error());
    }

    #[test]
    fn test_lsn_error_display() {
        let error = MsSqlError::Lsn(LsnError::OutOfRange("LSN too old".to_string()));
        assert!(error.to_string().contains("out of range"));
        assert!(error.is_recoverable_lsn_error());
    }

    #[test]
    fn test_primary_key_error_display() {
        let rendered = MsSqlError::PrimaryKey(PrimaryKeyError::NotConfigured {
            table: "orders".to_string(),
        })
        .to_string();
        for expected in ["No primary key configured", "orders"] {
            assert!(rendered.contains(expected), "missing {expected:?} in {rendered}");
        }
    }

    #[test]
    fn test_classify_connection_error() {
        for message in ["connection reset by peer", "broken pipe", "network unreachable"] {
            let error = anyhow::anyhow!(message);
            assert_eq!(
                MsSqlError::classify(&error),
                Some(MsSqlErrorKind::Connection),
                "message: {message}"
            );
        }
    }

    #[test]
    fn test_classify_lsn_error() {
        let error = anyhow::anyhow!("The specified LSN is invalid or out of range");
        assert_eq!(
            MsSqlError::classify(&error),
            Some(MsSqlErrorKind::RecoverableLsn)
        );
    }

    #[test]
    fn test_classify_unknown_error() {
        let error = anyhow::anyhow!("some random error");
        assert!(MsSqlError::classify(&error).is_none());
    }

    #[test]
    fn test_from_connection_error() {
        // Unwrap the outer `Connection` variant so each case can match on the inner kind.
        let inner_of = |message: &str| match MsSqlError::from_connection_error(message) {
            MsSqlError::Connection(inner) => inner,
            other => panic!("expected a connection error for {message:?}, got {other:?}"),
        };
        assert!(matches!(inner_of("connection timed out"), ConnectionError::Timeout(_)));
        assert!(matches!(inner_of("connection refused"), ConnectionError::Refused(_)));
        assert!(matches!(inner_of("broken pipe"), ConnectionError::Lost(_)));
    }
}