use thiserror::Error;
pub type Result<T> = std::result::Result<T, Error>;
#[derive(Error, Debug)]
pub enum Error {
#[error("Connection error: {0}")]
Connection(String),
#[error("Query error: {code} - {message}")]
Query { code: String, message: String },
#[error("Authentication error: {0}")]
Auth(String),
#[error("I/O error: {0}")]
Io(#[from] std::io::Error),
#[error("JSON error: {0}")]
Json(#[from] serde_json::Error),
#[error("QUIC error: {0}")]
Quic(String),
#[error("TLS error: {0}")]
Tls(String),
#[error("Invalid DSN: {0}")]
InvalidDsn(String),
#[error("Type error: {0}")]
Type(String),
#[error("Operation timed out")]
Timeout,
#[error("Pool error: {0}")]
Pool(String),
#[error("Validation error: {0}")]
Validation(String),
#[error("{0}")]
Other(String),
}
impl Error {
pub fn connection<S: Into<String>>(msg: S) -> Self {
Error::Connection(msg.into())
}
pub fn query<S: Into<String>>(msg: S) -> Self {
Error::Query {
code: "QUERY_ERROR".to_string(),
message: msg.into(),
}
}
pub fn protocol<S: Into<String>>(msg: S) -> Self {
Error::Connection(format!("Protocol error: {}", msg.into()))
}
pub fn transaction<S: Into<String>>(msg: S) -> Self {
Error::Connection(format!("Transaction error: {}", msg.into()))
}
pub fn timeout() -> Self {
Error::Timeout
}
pub fn auth<S: Into<String>>(msg: S) -> Self {
Error::Auth(msg.into())
}
pub fn quic<S: Into<String>>(msg: S) -> Self {
Error::Quic(msg.into())
}
pub fn tls<S: Into<String>>(msg: S) -> Self {
Error::Tls(msg.into())
}
pub fn type_error<S: Into<String>>(msg: S) -> Self {
Error::Type(msg.into())
}
pub fn pool<S: Into<String>>(msg: S) -> Self {
Error::Pool(msg.into())
}
pub fn validation<S: Into<String>>(msg: S) -> Self {
Error::Validation(msg.into())
}
pub fn invalid_dsn<S: Into<String>>(msg: S) -> Self {
Error::InvalidDsn(msg.into())
}
pub fn is_retryable(&self) -> bool {
match self {
Error::Connection(_) => true,
Error::Timeout => true,
Error::Quic(_) => true,
Error::Query { code, .. } => {
code == "40001" || code == "40P01" || code == "40502"
}
Error::Pool(_) => true,
_ => false,
}
}
pub fn code(&self) -> Option<&str> {
match self {
Error::Query { code, .. } => Some(code),
_ => None,
}
}
}
#[cfg(test)]
mod tests {
    //! Unit tests for `Error`: Display text, `From` conversions, helper
    //! constructors, retryability classification and `code()`.

    use super::*;
    use std::io;

    // ---- Display ---------------------------------------------------------

    #[test]
    fn test_error_display_connection() {
        let err = Error::Connection("connection refused".to_string());
        assert_eq!(err.to_string(), "Connection error: connection refused");
    }

    #[test]
    fn test_error_display_query() {
        let err = Error::Query {
            code: "42000".to_string(),
            message: "syntax error".to_string(),
        };
        assert_eq!(err.to_string(), "Query error: 42000 - syntax error");
    }

    #[test]
    fn test_error_display_auth() {
        let err = Error::Auth("invalid credentials".to_string());
        assert_eq!(err.to_string(), "Authentication error: invalid credentials");
    }

    #[test]
    fn test_error_display_io() {
        let io_err = io::Error::new(io::ErrorKind::NotFound, "file not found");
        let err = Error::Io(io_err);
        assert!(err.to_string().starts_with("I/O error:"));
    }

    #[test]
    fn test_error_display_json() {
        let json_err: serde_json::Error = serde_json::from_str::<i32>("invalid").unwrap_err();
        let err = Error::Json(json_err);
        assert!(err.to_string().starts_with("JSON error:"));
    }

    #[test]
    fn test_error_display_quic() {
        let err = Error::Quic("connection reset".to_string());
        assert_eq!(err.to_string(), "QUIC error: connection reset");
    }

    #[test]
    fn test_error_display_tls() {
        let err = Error::Tls("certificate expired".to_string());
        assert_eq!(err.to_string(), "TLS error: certificate expired");
    }

    #[test]
    fn test_error_display_invalid_dsn() {
        let err = Error::InvalidDsn("missing host".to_string());
        assert_eq!(err.to_string(), "Invalid DSN: missing host");
    }

    #[test]
    fn test_error_display_type() {
        let err = Error::Type("cannot convert int to string".to_string());
        assert_eq!(err.to_string(), "Type error: cannot convert int to string");
    }

    #[test]
    fn test_error_display_timeout() {
        let err = Error::Timeout;
        assert_eq!(err.to_string(), "Operation timed out");
    }

    #[test]
    fn test_error_display_pool() {
        let err = Error::Pool("pool exhausted".to_string());
        assert_eq!(err.to_string(), "Pool error: pool exhausted");
    }

    #[test]
    fn test_error_display_validation() {
        let err = Error::Validation("name must not be empty".to_string());
        assert_eq!(err.to_string(), "Validation error: name must not be empty");
    }

    #[test]
    fn test_error_display_other() {
        let err = Error::Other("unknown error".to_string());
        assert_eq!(err.to_string(), "unknown error");
    }

    // ---- From conversions ------------------------------------------------

    #[test]
    fn test_error_from_io() {
        let io_err = io::Error::new(io::ErrorKind::ConnectionRefused, "refused");
        let err: Error = io_err.into();
        assert!(matches!(err, Error::Io(_)));
    }

    #[test]
    fn test_error_from_json() {
        let json_err: serde_json::Error =
            serde_json::from_str::<i32>("not_a_number").unwrap_err();
        let err: Error = json_err.into();
        assert!(matches!(err, Error::Json(_)));
    }

    // ---- Helper constructors ---------------------------------------------

    #[test]
    fn test_error_helper_connection() {
        let err = Error::connection("test connection error");
        assert!(matches!(err, Error::Connection(msg) if msg == "test connection error"));
    }

    #[test]
    fn test_error_helper_query() {
        let err = Error::query("test query error");
        assert!(matches!(err, Error::Query { code, message }
            if code == "QUERY_ERROR" && message == "test query error"));
    }

    #[test]
    fn test_error_helper_protocol() {
        let err = Error::protocol("invalid frame");
        assert!(matches!(err, Error::Connection(msg) if msg == "Protocol error: invalid frame"));
    }

    #[test]
    fn test_error_helper_transaction() {
        let err = Error::transaction("rollback failed");
        assert!(
            matches!(err, Error::Connection(msg) if msg == "Transaction error: rollback failed")
        );
    }

    #[test]
    fn test_error_helper_timeout() {
        let err = Error::timeout();
        assert!(matches!(err, Error::Timeout));
    }

    #[test]
    fn test_error_helper_auth() {
        let err = Error::auth("bad token");
        assert!(matches!(err, Error::Auth(msg) if msg == "bad token"));
    }

    #[test]
    fn test_error_helper_quic() {
        let err = Error::quic("stream closed");
        assert!(matches!(err, Error::Quic(msg) if msg == "stream closed"));
    }

    #[test]
    fn test_error_helper_tls() {
        let err = Error::tls("handshake failed");
        assert!(matches!(err, Error::Tls(msg) if msg == "handshake failed"));
    }

    #[test]
    fn test_error_helper_type_error() {
        let err = Error::type_error("invalid cast");
        assert!(matches!(err, Error::Type(msg) if msg == "invalid cast"));
    }

    #[test]
    fn test_error_helper_pool() {
        let err = Error::pool("no connections available");
        assert!(matches!(err, Error::Pool(msg) if msg == "no connections available"));
    }

    #[test]
    fn test_error_helper_validation() {
        let err = Error::validation("bad input");
        assert!(matches!(err, Error::Validation(msg) if msg == "bad input"));
    }

    #[test]
    fn test_error_helper_invalid_dsn() {
        let err = Error::invalid_dsn("missing scheme");
        assert!(matches!(err, Error::InvalidDsn(msg) if msg == "missing scheme"));
    }

    #[test]
    fn test_error_string_conversion() {
        let err = Error::connection("test");
        assert!(matches!(err, Error::Connection(_)));
        let err = Error::connection(String::from("test"));
        assert!(matches!(err, Error::Connection(_)));
    }

    // ---- Retryability ----------------------------------------------------

    #[test]
    fn test_error_is_retryable_connection() {
        let err = Error::Connection("network error".to_string());
        assert!(err.is_retryable());
    }

    #[test]
    fn test_error_is_retryable_timeout() {
        let err = Error::Timeout;
        assert!(err.is_retryable());
    }

    #[test]
    fn test_error_is_retryable_quic() {
        let err = Error::Quic("reset".to_string());
        assert!(err.is_retryable());
    }

    #[test]
    fn test_error_is_retryable_pool() {
        let err = Error::Pool("exhausted".to_string());
        assert!(err.is_retryable());
    }

    #[test]
    fn test_error_is_retryable_serialization_failure() {
        let err = Error::Query {
            code: "40001".to_string(),
            message: "serialization failure".to_string(),
        };
        assert!(err.is_retryable());
    }

    #[test]
    fn test_error_is_retryable_deadlock() {
        let err = Error::Query {
            code: "40P01".to_string(),
            message: "deadlock detected".to_string(),
        };
        assert!(err.is_retryable());
    }

    #[test]
    fn test_error_is_retryable_transaction_deadlock() {
        let err = Error::Query {
            code: "40502".to_string(),
            message: "transaction deadlock".to_string(),
        };
        assert!(err.is_retryable());
    }

    #[test]
    fn test_error_not_retryable_syntax() {
        let err = Error::Query {
            code: "42000".to_string(),
            message: "syntax error".to_string(),
        };
        assert!(!err.is_retryable());
    }

    #[test]
    fn test_error_not_retryable_generic_query_helper() {
        // `Error::query` uses the placeholder code, which is never retryable.
        assert!(!Error::query("boom").is_retryable());
    }

    #[test]
    fn test_error_not_retryable_auth() {
        let err = Error::Auth("invalid".to_string());
        assert!(!err.is_retryable());
    }

    #[test]
    fn test_error_not_retryable_tls() {
        let err = Error::Tls("cert error".to_string());
        assert!(!err.is_retryable());
    }

    #[test]
    fn test_error_not_retryable_dsn() {
        let err = Error::InvalidDsn("bad format".to_string());
        assert!(!err.is_retryable());
    }

    #[test]
    fn test_error_not_retryable_type() {
        let err = Error::Type("cast failed".to_string());
        assert!(!err.is_retryable());
    }

    #[test]
    fn test_error_not_retryable_validation() {
        let err = Error::Validation("bad input".to_string());
        assert!(!err.is_retryable());
    }

    #[test]
    fn test_error_not_retryable_other() {
        let err = Error::Other("unknown".to_string());
        assert!(!err.is_retryable());
    }

    #[test]
    fn test_error_not_retryable_json() {
        let json_err = serde_json::from_str::<i32>("{").unwrap_err();
        assert!(!Error::Json(json_err).is_retryable());
    }

    #[test]
    fn test_error_not_retryable_io_not_found() {
        let io_err = io::Error::new(io::ErrorKind::NotFound, "missing");
        assert!(!Error::Io(io_err).is_retryable());
    }

    // ---- code() ----------------------------------------------------------

    #[test]
    fn test_error_code_query() {
        let err = Error::Query {
            code: "42000".to_string(),
            message: "syntax error".to_string(),
        };
        assert_eq!(err.code(), Some("42000"));
    }

    #[test]
    fn test_error_code_query_helper() {
        assert_eq!(Error::query("x").code(), Some("QUERY_ERROR"));
    }

    #[test]
    fn test_error_code_non_query() {
        let err = Error::Connection("test".to_string());
        assert_eq!(err.code(), None);
    }

    // ---- Result alias & Debug --------------------------------------------

    #[test]
    fn test_result_type_alias() {
        fn returns_result() -> Result<i32> {
            Ok(42)
        }
        assert_eq!(returns_result().unwrap(), 42);
    }

    #[test]
    fn test_result_type_alias_error() {
        fn returns_error() -> Result<i32> {
            Err(Error::Other("test".to_string()))
        }
        assert!(returns_error().is_err());
    }

    #[test]
    fn test_error_debug() {
        let err = Error::Connection("test".to_string());
        let debug_str = format!("{:?}", err);
        assert!(debug_str.contains("Connection"));
        assert!(debug_str.contains("test"));
    }
}