use thiserror::Error;
#[derive(Error, Debug)]
pub enum DbError {
#[error("Configuration error: {0}")]
Config(String),
#[error("Connection error: {0}")]
Connection(String),
#[error("Query error: {0}")]
Query(String),
#[error("Transaction error: {0}")]
Transaction(String),
#[error("Lock error: {0}")]
Lock(String),
#[error("SQL injection detected: {0}")]
SqlInjection(String),
#[error("Invalid table name: {0}")]
InvalidTableName(String),
#[error("Invalid field name: {0}")]
InvalidFieldName(String),
#[error("IO error: {0}")]
Io(#[from] std::io::Error),
#[error("Parse error: {0}")]
Parse(String),
#[error("Pool error: {0}")]
Pool(String),
#[error("Timeout error: {0}")]
Timeout(String),
#[error("Not found: {0}")]
NotFound(String),
#[error("Database not initialized")]
NotInitialized,
#[error("Unknown error: {0}")]
Unknown(String),
}
/// Shorthand for a [`std::result::Result`] whose error type is [`DbError`].
pub type DbResult<T> = std::result::Result<T, DbError>;
impl From<String> for DbError {
fn from(s: String) -> Self {
DbError::Unknown(s)
}
}
impl From<&str> for DbError {
fn from(s: &str) -> Self {
DbError::Unknown(s.to_string())
}
}
impl<T> From<std::sync::PoisonError<T>> for DbError {
fn from(e: std::sync::PoisonError<T>) -> Self {
DbError::Lock(e.to_string())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::error::Error;
    use std::sync::{Arc, Mutex};

    #[test]
    fn display_messages_for_all_variants() {
        // Table-driven so a failure names the offending variant via `{err:?}`.
        let cases = [
            (DbError::Config("bad cfg".to_string()), "Configuration error: bad cfg"),
            (DbError::Connection("conn down".to_string()), "Connection error: conn down"),
            (DbError::Query("bad sql".to_string()), "Query error: bad sql"),
            (DbError::Transaction("rollback".to_string()), "Transaction error: rollback"),
            (DbError::Lock("mutex".to_string()), "Lock error: mutex"),
            (
                DbError::SqlInjection("DROP TABLE".to_string()),
                "SQL injection detected: DROP TABLE",
            ),
            (
                DbError::InvalidTableName("bad table".to_string()),
                "Invalid table name: bad table",
            ),
            (
                DbError::InvalidFieldName("bad field".to_string()),
                "Invalid field name: bad field",
            ),
            (DbError::Io(std::io::Error::other("disk fail")), "IO error: disk fail"),
            (DbError::Parse("invalid json".to_string()), "Parse error: invalid json"),
            (DbError::Pool("pool exhausted".to_string()), "Pool error: pool exhausted"),
            (DbError::Timeout("30s".to_string()), "Timeout error: 30s"),
            (DbError::NotFound("record".to_string()), "Not found: record"),
            (DbError::NotInitialized, "Database not initialized"),
            (DbError::Unknown("other".to_string()), "Unknown error: other"),
        ];
        for (err, expected) in cases {
            assert_eq!(err.to_string(), expected, "wrong Display output for {err:?}");
        }
    }

    #[test]
    fn from_string_maps_to_unknown() {
        let err: DbError = String::from("string failure").into();
        match err {
            DbError::Unknown(msg) => assert_eq!(msg, "string failure"),
            other => panic!("expected DbError::Unknown, got {other:?}"),
        }
    }

    #[test]
    fn from_str_maps_to_unknown() {
        let err: DbError = "str failure".into();
        match err {
            DbError::Unknown(msg) => assert_eq!(msg, "str failure"),
            other => panic!("expected DbError::Unknown, got {other:?}"),
        }
    }

    #[test]
    fn from_io_error_maps_to_io() {
        // Exercises the conversion generated by `#[from]` on `DbError::Io`.
        let err: DbError = std::io::Error::other("disk gone").into();
        match err {
            DbError::Io(inner) => assert_eq!(inner.to_string(), "disk gone"),
            other => panic!("expected DbError::Io, got {other:?}"),
        }
    }

    #[test]
    fn from_poison_error_maps_to_lock() {
        let mutex = Arc::new(Mutex::new(1));
        let mutex_in_thread = Arc::clone(&mutex);
        let handle = std::thread::spawn(move || {
            let _guard = mutex_in_thread.lock().expect("lock in worker");
            panic!("poison mutex");
        });
        // The mutex is only poisoned if the worker panicked while holding the
        // guard, so verify that precondition instead of discarding it.
        assert!(handle.join().is_err(), "worker thread should have panicked");
        let poison_err = mutex.lock().expect_err("lock should be poisoned");
        let err: DbError = poison_err.into();
        match err {
            DbError::Lock(msg) => assert!(
                msg.contains("poison"),
                "unexpected lock error message: {msg}"
            ),
            other => panic!("expected DbError::Lock, got {other:?}"),
        }
    }

    #[test]
    fn dbresult_alias_ok_and_err() {
        // Actually route values through `DbResult` so the alias is exercised.
        fn find(id: u32) -> DbResult<i32> {
            if id == 7 {
                Ok(7)
            } else {
                Err(DbError::NotFound(format!("id={id}")))
            }
        }
        assert_eq!(find(7).expect("id 7 exists"), 7);
        let err = find(1).expect_err("id 1 is missing");
        assert_eq!(err.to_string(), "Not found: id=1");
    }

    #[test]
    fn error_trait_behavior() {
        // `Send + Sync + 'static` is required to box the error as
        // `Box<dyn Error + Send + Sync>` or move it across threads.
        fn assert_is_error<E: Error + Send + Sync + 'static>(_err: &E) {}
        let query_err = DbError::Query("bad query".to_string());
        assert_is_error(&query_err);
        assert_eq!(query_err.to_string(), "Query error: bad query");
        assert!(query_err.source().is_none());
        let io_err = DbError::Io(std::io::Error::other("io root cause"));
        assert_eq!(io_err.to_string(), "IO error: io root cause");
        assert_eq!(
            io_err.source().expect("io source").to_string(),
            "io root cause"
        );
    }
}