use std::fmt;
/// Coarse classification of the [`sqlx::Error`] behind a [`ConnectorError`].
///
/// Produced by [`SqlxErrorKind::from_sqlx`]. Being a small fieldless `Copy`
/// enum, it also derives `Hash` so it can be used as a map/set key (e.g. to
/// bucket errors by kind).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SqlxErrorKind {
    /// No sqlx classification: the error was built from a plain message, came
    /// from `nautilus_core`, or is a sqlx variant without a dedicated kind.
    None,
    /// A database error that matched none of the more specific categories.
    Database,
    /// Unique-constraint violation (sqlx `is_unique_violation`).
    UniqueConstraint,
    /// Foreign-key violation (sqlx `is_foreign_key_violation`).
    ForeignKeyConstraint,
    /// Check-constraint violation (sqlx `is_check_violation`).
    CheckConstraint,
    /// NOT NULL violation, detected via SQLSTATE `23502` or the message text.
    NullConstraint,
    /// Deadlock, detected via SQLSTATE `40P01` or a message mentioning "deadlock".
    Deadlock,
    /// Serialization failure, detected via SQLSTATE `40001` or the message text.
    SerializationFailure,
    /// Mirrors `sqlx::Error::Io`.
    Io,
    /// Mirrors `sqlx::Error::Tls`.
    Tls,
    /// Mirrors `sqlx::Error::Protocol`.
    Protocol,
    /// Mirrors `sqlx::Error::RowNotFound`.
    RowNotFound,
    /// Mirrors `sqlx::Error::TypeNotFound`.
    TypeNotFound,
    /// Mirrors `sqlx::Error::ColumnIndexOutOfBounds`.
    ColumnIndexOutOfBounds,
    /// Mirrors `sqlx::Error::ColumnNotFound`.
    ColumnNotFound,
    /// Mirrors `sqlx::Error::ColumnDecode`.
    ColumnDecode,
    /// Mirrors `sqlx::Error::Decode`.
    Decode,
    /// Mirrors `sqlx::Error::PoolTimedOut`.
    PoolTimedOut,
    /// Mirrors `sqlx::Error::PoolClosed`.
    PoolClosed,
    /// Mirrors `sqlx::Error::WorkerCrashed`.
    WorkerCrashed,
    /// Mirrors `sqlx::Error::Configuration`.
    Configuration,
}
impl SqlxErrorKind {
pub fn from_sqlx(e: &sqlx::Error) -> Self {
match e {
sqlx::Error::Database(db_err) => {
if db_err.is_unique_violation() {
SqlxErrorKind::UniqueConstraint
} else if db_err.is_foreign_key_violation() {
SqlxErrorKind::ForeignKeyConstraint
} else if db_err.is_check_violation() {
SqlxErrorKind::CheckConstraint
} else {
let state = db_err.code();
let state = state.as_deref().unwrap_or("");
let msg = db_err.message();
if state == "23502"
|| msg.contains("NOT NULL constraint")
|| msg.contains("not-null constraint")
|| msg.contains("cannot be null")
{
SqlxErrorKind::NullConstraint
} else if state == "40P01" || msg.to_ascii_lowercase().contains("deadlock") {
SqlxErrorKind::Deadlock
} else if state == "40001"
|| msg.to_ascii_lowercase().contains("serialization failure")
|| msg.to_ascii_lowercase().contains("could not serialize")
{
SqlxErrorKind::SerializationFailure
} else {
SqlxErrorKind::Database
}
}
}
sqlx::Error::Io(_) => SqlxErrorKind::Io,
sqlx::Error::Tls(_) => SqlxErrorKind::Tls,
sqlx::Error::Protocol(_) => SqlxErrorKind::Protocol,
sqlx::Error::RowNotFound => SqlxErrorKind::RowNotFound,
sqlx::Error::TypeNotFound { .. } => SqlxErrorKind::TypeNotFound,
sqlx::Error::ColumnIndexOutOfBounds { .. } => SqlxErrorKind::ColumnIndexOutOfBounds,
sqlx::Error::ColumnNotFound(_) => SqlxErrorKind::ColumnNotFound,
sqlx::Error::ColumnDecode { .. } => SqlxErrorKind::ColumnDecode,
sqlx::Error::Decode(_) => SqlxErrorKind::Decode,
sqlx::Error::PoolTimedOut => SqlxErrorKind::PoolTimedOut,
sqlx::Error::PoolClosed => SqlxErrorKind::PoolClosed,
sqlx::Error::WorkerCrashed => SqlxErrorKind::WorkerCrashed,
sqlx::Error::Configuration(_) => SqlxErrorKind::Configuration,
#[allow(unreachable_patterns)]
_ => SqlxErrorKind::None,
}
}
}
/// Error type returned by the connector.
///
/// The `Database`, `Connection` and `RowDecode` variants each carry:
/// - the [`SqlxErrorKind`] of the originating sqlx error, or
///   [`SqlxErrorKind::None`] when built from a plain message;
/// - a human-readable description.
///
/// Errors from `nautilus_core` are wrapped unchanged in `Core`.
#[derive(Clone, Debug, PartialEq)]
pub enum ConnectorError {
    /// Displayed as `Database error: <description>`.
    Database(SqlxErrorKind, String),
    /// Displayed as `Connection error: <description>`.
    Connection(SqlxErrorKind, String),
    /// Displayed as `Row decode error: <description>`.
    RowDecode(SqlxErrorKind, String),
    /// Wrapped `nautilus_core` error; exposed through `Error::source`.
    Core(nautilus_core::Error),
}
impl ConnectorError {
pub fn database(e: sqlx::Error, context: &str) -> Self {
ConnectorError::Database(SqlxErrorKind::from_sqlx(&e), format!("{}: {}", context, e))
}
pub fn connection(e: sqlx::Error, context: &str) -> Self {
ConnectorError::Connection(SqlxErrorKind::from_sqlx(&e), format!("{}: {}", context, e))
}
pub fn row_decode(e: sqlx::Error, context: &str) -> Self {
ConnectorError::RowDecode(SqlxErrorKind::from_sqlx(&e), format!("{}: {}", context, e))
}
pub fn database_msg(msg: impl Into<String>) -> Self {
ConnectorError::Database(SqlxErrorKind::None, msg.into())
}
pub fn connection_msg(msg: impl Into<String>) -> Self {
ConnectorError::Connection(SqlxErrorKind::None, msg.into())
}
pub fn row_decode_msg(msg: impl Into<String>) -> Self {
ConnectorError::RowDecode(SqlxErrorKind::None, msg.into())
}
pub fn sqlx_kind(&self) -> SqlxErrorKind {
match self {
ConnectorError::Database(k, _)
| ConnectorError::Connection(k, _)
| ConnectorError::RowDecode(k, _) => *k,
ConnectorError::Core(_) => SqlxErrorKind::None,
}
}
}
impl fmt::Display for ConnectorError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
ConnectorError::Database(_, msg) => write!(f, "Database error: {}", msg),
ConnectorError::Connection(_, msg) => write!(f, "Connection error: {}", msg),
ConnectorError::RowDecode(_, msg) => write!(f, "Row decode error: {}", msg),
ConnectorError::Core(e) => write!(f, "Core error: {}", e),
}
}
}
impl std::error::Error for ConnectorError {
    /// Returns the wrapped `nautilus_core` error for `Core`.
    ///
    /// The sqlx-backed variants return `None`: they store only a rendered
    /// message and a kind, not the original `sqlx::Error`.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            ConnectorError::Core(inner) => Some(inner),
            ConnectorError::Database(..)
            | ConnectorError::Connection(..)
            | ConnectorError::RowDecode(..) => None,
        }
    }
}
impl From<nautilus_core::Error> for ConnectorError {
fn from(e: nautilus_core::Error) -> Self {
ConnectorError::Core(e)
}
}
/// Convenience result type for connector operations.
///
/// The error type defaults to [`ConnectorError`], so `Result<T>` keeps its
/// original meaning. A different error type can still be named as
/// `Result<T, E>` without shadowing `std::result::Result`.
pub type Result<T, E = ConnectorError> = std::result::Result<T, E>;
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_display() {
        assert_eq!(
            ConnectorError::database_msg("query failed").to_string(),
            "Database error: query failed"
        );
        assert_eq!(
            ConnectorError::connection_msg("refused").to_string(),
            "Connection error: refused"
        );
        assert_eq!(
            ConnectorError::row_decode_msg("invalid bool").to_string(),
            "Row decode error: invalid bool"
        );
    }

    #[test]
    fn test_sqlx_kind() {
        let err = ConnectorError::database_msg("test");
        assert_eq!(err.sqlx_kind(), SqlxErrorKind::None);
        let err = ConnectorError::Database(SqlxErrorKind::PoolTimedOut, "timeout".to_string());
        assert_eq!(err.sqlx_kind(), SqlxErrorKind::PoolTimedOut);

        // Message-only constructors never carry a sqlx classification.
        assert_eq!(ConnectorError::connection_msg("x").sqlx_kind(), SqlxErrorKind::None);
        assert_eq!(ConnectorError::row_decode_msg("x").sqlx_kind(), SqlxErrorKind::None);

        // Kinds stored in the other variants are returned unchanged.
        let err = ConnectorError::Connection(SqlxErrorKind::Io, "io".to_string());
        assert_eq!(err.sqlx_kind(), SqlxErrorKind::Io);
        let err = ConnectorError::RowDecode(SqlxErrorKind::ColumnDecode, "col".to_string());
        assert_eq!(err.sqlx_kind(), SqlxErrorKind::ColumnDecode);

        // Core errors have no sqlx classification.
        let core = ConnectorError::from(nautilus_core::Error::InvalidQuery("q".to_string()));
        assert_eq!(core.sqlx_kind(), SqlxErrorKind::None);
    }

    #[test]
    fn test_from_core_error() {
        let core_err = nautilus_core::Error::InvalidQuery("bad query".to_string());
        let conn_err = ConnectorError::from(core_err.clone());
        assert_eq!(conn_err, ConnectorError::Core(core_err));
        assert!(conn_err.to_string().contains("bad query"));
    }

    #[test]
    fn test_error_source() {
        // `Core` exposes the wrapped error as its source.
        let core_err = nautilus_core::Error::InvalidQuery("bad query".to_string());
        let wrapped = ConnectorError::from(core_err.clone());
        let source = std::error::Error::source(&wrapped)
            .and_then(|s| s.downcast_ref::<nautilus_core::Error>());
        assert_eq!(source, Some(&core_err));

        // Message-based variants have no underlying source.
        for err in [
            ConnectorError::database_msg("a"),
            ConnectorError::connection_msg("b"),
            ConnectorError::row_decode_msg("c"),
        ]
        .iter()
        {
            assert!(std::error::Error::source(err).is_none());
        }
    }
}