#[cfg(feature = "diesel")]
use diesel::{
r2d2::PoolError,
result::{DatabaseErrorInformation, DatabaseErrorKind, Error},
};
#[cfg(feature = "tokio-postgres")]
use tokio_postgres::error::{Error as PgError, SqlState};
use crate::{custom::StatusCode, Problem};
/// Wraps a database-layer error in a `500 Internal Server Error` [`Problem`].
///
/// The problem's detail is always the fixed text "An unexpected error
/// occurred"; the original error is not rendered into it and is only attached
/// via `with_cause`.
///
/// `#[track_caller]` keeps caller-location propagation intact: any
/// `Location::caller()` lookups in the `Problem` builders (if they use one —
/// NOTE(review): confirm) observe the caller of this helper, not the helper.
#[track_caller]
fn sql_error<E>(cause: E) -> Problem
where
    E: std::error::Error + Send + Sync + 'static,
{
    let problem = Problem::from_status(StatusCode::INTERNAL_SERVER_ERROR);
    problem
        .with_detail("An unexpected error occurred")
        .with_cause(cause)
}
#[cfg(feature = "diesel")]
impl From<PoolError> for Problem {
    /// Converts an r2d2 connection-pool error into a generic
    /// `500 Internal Server Error` problem, keeping the pool error as its cause.
    #[track_caller]
    fn from(pool_err: PoolError) -> Self {
        sql_error::<PoolError>(pool_err)
    }
}
#[cfg(feature = "diesel")]
impl From<Error> for Problem {
    /// Converts a Diesel error into a `500 Internal Server Error` problem.
    ///
    /// The cause attached to the problem depends on the error:
    /// - `DatabaseErrorKind::UniqueViolation` becomes [`UniqueViolation`];
    /// - `DatabaseErrorKind::ForeignKeyViolation` becomes [`ForeignKeyViolation`];
    /// - `DatabaseErrorKind::__Unknown` carrying a constraint name becomes
    ///   [`CheckViolation`];
    /// - `Error::NotFound` becomes [`NoRowsFound`];
    /// - every other error is attached unchanged.
    #[track_caller]
    fn from(err: Error) -> Self {
        // All `sql_error` calls stay directly in this body (no closures) so
        // `#[track_caller]` keeps propagating the conversion site.
        match err {
            Error::DatabaseError(DatabaseErrorKind::UniqueViolation, info) => {
                sql_error(UniqueViolation(info.into()))
            }
            Error::DatabaseError(DatabaseErrorKind::ForeignKeyViolation, info) => {
                sql_error(ForeignKeyViolation(info.into()))
            }
            // NOTE(review): heuristic — any unrecognised database error that
            // names a constraint is reported as a check violation; confirm no
            // other constraint-related errors can reach this arm.
            Error::DatabaseError(DatabaseErrorKind::__Unknown, info)
                if info.constraint_name().is_some() =>
            {
                sql_error(CheckViolation(info.into()))
            }
            Error::NotFound => sql_error(NoRowsFound),
            // Includes database errors whose kind (or missing constraint
            // name) did not match above; the guard does not move `info`, so
            // the error is passed through intact.
            other => sql_error(other),
        }
    }
}
#[cfg(feature = "tokio-postgres")]
impl From<PgError> for Problem {
    /// Converts a `tokio_postgres` error into a `500 Internal Server Error`
    /// problem.
    ///
    /// If the error carries a server `DbError`, its SQLSTATE selects the
    /// attached cause:
    /// - `UNIQUE_VIOLATION` becomes [`UniqueViolation`];
    /// - `FOREIGN_KEY_VIOLATION` becomes [`ForeignKeyViolation`];
    /// - `CHECK_VIOLATION` becomes [`CheckViolation`].
    ///
    /// Any other error, including one without a `DbError`, is attached
    /// unchanged.
    #[track_caller]
    fn from(err: PgError) -> Self {
        if let Some(db_err) = err.as_db_error() {
            // Compare the SQLSTATE by reference; there is no need to clone it
            // just to inspect it.
            let code = db_err.code();
            if code == &SqlState::UNIQUE_VIOLATION {
                sql_error(UniqueViolation(db_err.into()))
            } else if code == &SqlState::FOREIGN_KEY_VIOLATION {
                sql_error(ForeignKeyViolation(db_err.into()))
            } else if code == &SqlState::CHECK_VIOLATION {
                sql_error(CheckViolation(db_err.into()))
            } else {
                // `code`/`db_err` are no longer used, so `err` can be moved.
                sql_error(err)
            }
        } else {
            sql_error(err)
        }
    }
}
/// Cause attached when a query expected to return a row returned none.
///
/// With the `diesel` feature, this is the cause used for Diesel's
/// `Error::NotFound`.
#[derive(Debug)]
pub struct NoRowsFound;

impl std::fmt::Display for NoRowsFound {
    /// Writes a fixed message; formatter width/fill flags are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("No rows were found where one was expected")
    }
}

impl std::error::Error for NoRowsFound {}
/// Cause attached when the database reports a unique-constraint violation.
///
/// Its `Display` output is exactly the database's own error message.
pub struct UniqueViolation(DbErrorInfo);

impl UniqueViolation {
    /// The error message reported by the database.
    pub fn message(&self) -> &str {
        let Self(info) = self;
        &info.message
    }

    /// The table involved, if the database reported one.
    pub fn table_name(&self) -> Option<&str> {
        let Self(info) = self;
        info.table_name.as_deref()
    }

    /// The violated constraint, if the database reported one.
    pub fn constraint_name(&self) -> Option<&str> {
        let Self(info) = self;
        info.constraint_name.as_deref()
    }
}

impl std::fmt::Debug for UniqueViolation {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let info = &self.0;
        f.debug_struct("UniqueViolation")
            .field("table_name", &info.table_name.as_deref())
            .field("constraint_name", &info.constraint_name.as_deref())
            .field("message", &info.message.as_str())
            .finish()
    }
}

impl std::fmt::Display for UniqueViolation {
    /// Writes the raw database message; width/fill flags are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0.message)
    }
}

impl std::error::Error for UniqueViolation {}
/// Cause attached when the database reports a foreign-key violation.
///
/// Its `Display` output is exactly the database's own error message.
pub struct ForeignKeyViolation(DbErrorInfo);

impl ForeignKeyViolation {
    /// The error message reported by the database.
    pub fn message(&self) -> &str {
        let Self(info) = self;
        &info.message
    }

    /// The table involved, if the database reported one.
    pub fn table_name(&self) -> Option<&str> {
        let Self(info) = self;
        info.table_name.as_deref()
    }

    /// The violated constraint, if the database reported one.
    pub fn constraint_name(&self) -> Option<&str> {
        let Self(info) = self;
        info.constraint_name.as_deref()
    }
}

impl std::fmt::Debug for ForeignKeyViolation {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let info = &self.0;
        f.debug_struct("ForeignKeyViolation")
            .field("table_name", &info.table_name.as_deref())
            .field("constraint_name", &info.constraint_name.as_deref())
            .field("message", &info.message.as_str())
            .finish()
    }
}

impl std::fmt::Display for ForeignKeyViolation {
    /// Writes the raw database message; width/fill flags are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0.message)
    }
}

impl std::error::Error for ForeignKeyViolation {}
/// Cause attached when the database reports a check-constraint violation.
///
/// With the `diesel` feature, this is also attached for an unrecognised
/// database error that carries a constraint name. Its `Display` output is
/// exactly the database's own error message.
pub struct CheckViolation(DbErrorInfo);

impl CheckViolation {
    /// The error message reported by the database.
    pub fn message(&self) -> &str {
        let Self(info) = self;
        &info.message
    }

    /// The table involved, if the database reported one.
    pub fn table_name(&self) -> Option<&str> {
        let Self(info) = self;
        info.table_name.as_deref()
    }

    /// The violated constraint, if the database reported one.
    pub fn constraint_name(&self) -> Option<&str> {
        let Self(info) = self;
        info.constraint_name.as_deref()
    }
}

impl std::fmt::Debug for CheckViolation {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let info = &self.0;
        f.debug_struct("CheckViolation")
            .field("table_name", &info.table_name.as_deref())
            .field("constraint_name", &info.constraint_name.as_deref())
            .field("message", &info.message.as_str())
            .finish()
    }
}

impl std::fmt::Display for CheckViolation {
    /// Writes the raw database message; width/fill flags are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0.message)
    }
}

impl std::error::Error for CheckViolation {}
/// Backend-neutral, owned copy of the error details exposed by the
/// violation types.
///
/// Every field is owned, so the violation types wrapping it are `'static`.
/// `sql_error` requires this before a value can be attached as a cause.
struct DbErrorInfo {
    /// Error message reported by the database.
    message: String,
    /// Table the error relates to, if the database reported one.
    table_name: Option<String>,
    /// Constraint the error relates to, if the database reported one.
    constraint_name: Option<String>,
}
#[cfg(feature = "diesel")]
impl From<Box<dyn DatabaseErrorInformation + Send + Sync>> for DbErrorInfo {
    /// Copies the message, table name and constraint name out of Diesel's
    /// boxed error information.
    fn from(info: Box<dyn DatabaseErrorInformation + Send + Sync>) -> Self {
        let message = String::from(info.message());
        let table_name = info.table_name().map(str::to_owned);
        let constraint_name = info.constraint_name().map(str::to_owned);
        DbErrorInfo {
            constraint_name,
            message,
            table_name,
        }
    }
}
#[cfg(feature = "tokio-postgres")]
impl From<&'_ tokio_postgres::error::DbError> for DbErrorInfo {
    /// Copies the message, table and constraint out of a `tokio_postgres`
    /// `DbError`.
    fn from(db_err: &'_ tokio_postgres::error::DbError) -> Self {
        DbErrorInfo {
            message: db_err.message().to_owned(),
            table_name: db_err.table().map(ToOwned::to_owned),
            constraint_name: db_err.constraint().map(ToOwned::to_owned),
        }
    }
}