use prax_query::QueryError;
use thiserror::Error;
/// Convenience alias for a `Result` whose error type is [`PgError`].
pub type PgResult<T> = Result<T, PgError>;
/// Error type for this PostgreSQL module.
///
/// `Display` and `std::error::Error` are derived through `thiserror`; each
/// variant's `#[error(...)]` attribute defines its display message. The
/// `Pool` and `Postgres` variants also get `From` conversions via `#[from]`,
/// so the `?` operator can lift those library errors into `PgError`.
#[derive(Error, Debug)]
pub enum PgError {
/// An error reported by the `deadpool_postgres` connection pool.
#[error("pool error: {0}")]
Pool(#[from] deadpool_postgres::PoolError),
/// An error reported by `tokio_postgres`.
#[error("postgres error: {0}")]
Postgres(#[from] tokio_postgres::Error),
/// A configuration problem, described by the contained message.
#[error("configuration error: {0}")]
Config(String),
/// A connection problem, described by the contained message.
#[error("connection error: {0}")]
Connection(String),
/// A query problem, described by the contained message.
#[error("query error: {0}")]
Query(String),
/// A deserialization problem, described by the contained message.
#[error("deserialization error: {0}")]
Deserialization(String),
/// A type conversion problem, described by the contained message.
#[error("type conversion error: {0}")]
TypeConversion(String),
/// An operation timed out; the payload is rendered as a number of
/// milliseconds in the display message.
#[error("operation timed out after {0}ms")]
Timeout(u64),
/// An internal error, described by the contained message.
#[error("internal error: {0}")]
Internal(String),
}
impl PgError {
pub fn config(message: impl Into<String>) -> Self {
Self::Config(message.into())
}
pub fn connection(message: impl Into<String>) -> Self {
Self::Connection(message.into())
}
pub fn query(message: impl Into<String>) -> Self {
Self::Query(message.into())
}
pub fn deserialization(message: impl Into<String>) -> Self {
Self::Deserialization(message.into())
}
pub fn type_conversion(message: impl Into<String>) -> Self {
Self::TypeConversion(message.into())
}
pub fn is_connection_error(&self) -> bool {
matches!(self, Self::Pool(_) | Self::Connection(_))
}
pub fn is_timeout(&self) -> bool {
matches!(self, Self::Timeout(_))
}
}
/// Converts a [`PgError`] into a [`QueryError`].
///
/// Variant mapping:
/// - `Pool`, `Config`, `Connection` -> `QueryError::connection`
/// - `Postgres` with SQLSTATE `23505` or `23503` -> `QueryError::constraint_violation`
/// - `Postgres` with SQLSTATE `23502` -> `QueryError::invalid_input`
/// - any other `Postgres` error, and `Query` -> `QueryError::database`
/// - `Deserialization`, `TypeConversion` -> `QueryError::serialization`
/// - `Timeout` -> `QueryError::timeout`
/// - `Internal` -> `QueryError::internal`
impl From<PgError> for QueryError {
    fn from(err: PgError) -> Self {
        match err {
            PgError::Pool(pool_err) => QueryError::connection(pool_err.to_string()),
            PgError::Postgres(pg_err) => {
                // The SQLSTATE is absent when the error did not originate
                // from the database server itself.
                let sqlstate = pg_err.code().map(|state| state.code());
                let detail = pg_err.to_string();
                match sqlstate {
                    // 23505 = unique_violation, 23503 = foreign_key_violation.
                    // NOTE(review): the first argument is passed as an empty
                    // string; its meaning is defined by `prax_query` — confirm.
                    Some("23505" | "23503") => QueryError::constraint_violation("", detail),
                    // 23502 = not_null_violation.
                    Some("23502") => QueryError::invalid_input("", detail),
                    _ => QueryError::database(detail),
                }
            }
            PgError::Config(msg) | PgError::Connection(msg) => QueryError::connection(msg),
            PgError::Query(msg) => QueryError::database(msg),
            PgError::Deserialization(msg) | PgError::TypeConversion(msg) => {
                QueryError::serialization(msg)
            }
            PgError::Timeout(ms) => QueryError::timeout(ms),
            PgError::Internal(msg) => QueryError::internal(msg),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The helper constructors produce the expected variants, and the
    /// classification predicates recognise them.
    #[test]
    fn test_error_creation() {
        let config_err = PgError::config("invalid URL");
        assert!(matches!(config_err, PgError::Config(_)));

        let connection_err = PgError::connection("connection refused");
        assert!(connection_err.is_connection_error());

        let timeout_err = PgError::Timeout(5000);
        assert!(timeout_err.is_timeout());
    }

    /// A `PgError::Timeout` converts into a `QueryError` that reports a timeout.
    #[test]
    fn test_into_query_error() {
        let converted = QueryError::from(PgError::Timeout(1000));
        assert!(converted.is_timeout());
    }
}