use thiserror::Error;
/// Shorthand for a `Result` whose error side is [`OrmError`].
pub type OrmResult<T> = std::result::Result<T, OrmError>;
/// Error type returned by the ORM layer.
///
/// Raw `tokio_postgres` errors can be mapped onto the specific constraint /
/// concurrency / connection variants with [`OrmError::from_db_error`].
#[derive(Debug, Error)]
pub enum OrmError {
    /// Connection-level failure. [`OrmError::from_db_error`] produces this for
    /// connection-exception SQLSTATEs and for errors on a closed connection.
    #[error("Connection error: {0}")]
    Connection(String),
    /// Any other `tokio_postgres` error.
    ///
    /// `#[from]` derives `From<tokio_postgres::Error>`, so the `?` operator wraps
    /// errors here as-is, WITHOUT SQLSTATE classification; call
    /// [`OrmError::from_db_error`] to get the specific variants below.
    #[error("Query error: {0}")]
    Query(#[from] tokio_postgres::Error),
    /// A value in result column `column` could not be decoded; `message` explains why.
    #[error("Decode error on column '{column}': {message}")]
    Decode { column: String, message: String },
    /// Serialization error with a free-form message.
    /// NOTE(review): presumably value (de)serialization; not to be confused with
    /// `SerializationFailure` (SQLSTATE 40001) — confirm with callers.
    #[error("Serialization error: {0}")]
    Serialization(String),
    /// Connection-pool error (only with the `pool` feature); holds the pool
    /// error's display text.
    #[cfg(feature = "pool")]
    #[error("Pool error: {0}")]
    Pool(String),
    /// Migration error (only with the `migrate` feature); holds the migration
    /// error's display text.
    #[cfg(feature = "migrate")]
    #[error("Migration error: {0}")]
    Migration(String),
    /// Catch-all error; its message is displayed verbatim.
    #[error("{0}")]
    Other(String),
    /// Not-found condition; the message is caller-supplied context.
    #[error("Not found: {0}")]
    NotFound(String),
    /// More rows were produced than the `expected` count.
    #[error("Too many rows: expected {expected}, got {got}")]
    TooManyRows { expected: usize, got: usize },
    /// Unique constraint violation (SQLSTATE 23505). When built by
    /// `from_db_error` the text is `"<constraint>: <server message>"`.
    #[error("Unique constraint violation: {0}")]
    UniqueViolation(String),
    /// Foreign key violation (SQLSTATE 23503). When built by `from_db_error`
    /// the text is `"<constraint>: <server message>"`.
    #[error("Foreign key violation: {0}")]
    ForeignKeyViolation(String),
    /// Check constraint violation (SQLSTATE 23514). When built by
    /// `from_db_error` the text is `"<constraint>: <server message>"`.
    #[error("Check constraint violation: {0}")]
    CheckViolation(String),
    /// Transaction serialization failure (SQLSTATE 40001); holds the server message.
    #[error("Serialization failure: {0}")]
    SerializationFailure(String),
    /// Deadlock detected by the server (SQLSTATE 40P01); holds the server message.
    #[error("Deadlock detected: {0}")]
    DeadlockDetected(String),
    /// Validation failure with a caller-supplied message.
    #[error("Validation error: {0}")]
    Validation(String),
    /// A query exceeded a time limit; displayed using `Duration`'s `Debug` format.
    #[error("Query timeout after {0:?}")]
    Timeout(std::time::Duration),
    /// Version conflict on a record.
    /// NOTE(review): presumably raised by optimistic-locking updates when the
    /// stored version no longer matches `expected_version` — confirm at call sites.
    #[error("Stale record: {table} with id {id} (expected version {expected_version})")]
    StaleRecord {
        // Table the record belongs to.
        table: &'static str,
        // Record identifier, rendered as a string.
        id: String,
        // Version the caller expected to find.
        expected_version: i64,
    },
}
impl OrmError {
    /// Creates an [`OrmError::Decode`] error for `column` with the given `message`.
    pub fn decode(column: impl Into<String>, message: impl Into<String>) -> Self {
        Self::Decode {
            column: column.into(),
            message: message.into(),
        }
    }

    /// Creates an [`OrmError::NotFound`] error carrying `message`.
    pub fn not_found(message: impl Into<String>) -> Self {
        Self::NotFound(message.into())
    }

    /// Creates an [`OrmError::TooManyRows`] error from the expected and actual row counts.
    pub fn too_many_rows(expected: usize, got: usize) -> Self {
        Self::TooManyRows { expected, got }
    }

    /// Creates an [`OrmError::Validation`] error carrying `message`.
    pub fn validation(message: impl Into<String>) -> Self {
        Self::Validation(message.into())
    }

    /// Creates an [`OrmError::StaleRecord`] error.
    ///
    /// `id` is stored in its `ToString` form, so any displayable key type works.
    pub fn stale_record(table: &'static str, id: impl ToString, expected_version: i64) -> Self {
        Self::StaleRecord {
            table,
            id: id.to_string(),
            expected_version,
        }
    }

    /// Returns `true` for errors describing a condition the caller can handle:
    /// not-found / too-many-rows, constraint violations, serialization failures,
    /// deadlocks, stale records, timeouts and validation errors.
    ///
    /// Constraint and concurrency conditions are detected by SQLSTATE. They are
    /// therefore recognised both in their classified variants and in a raw
    /// [`OrmError::Query`] produced by `?` via the derived `From` impl.
    pub fn is_recoverable(&self) -> bool {
        matches!(
            self,
            Self::NotFound(_)
                | Self::TooManyRows { .. }
                | Self::StaleRecord { .. }
                | Self::Timeout(_)
                | Self::Validation(_)
        ) || matches!(
            self.sqlstate(),
            // unique, foreign key, check, serialization failure, deadlock
            Some("23505" | "23503" | "23514" | "40001" | "40P01")
        )
    }

    /// Returns `true` when retrying the transaction may succeed: serialization
    /// failure (40001) or deadlock (40P01).
    ///
    /// Checks the SQLSTATE rather than only the variant, so an unclassified
    /// [`OrmError::Query`] carrying one of these codes is also reported.
    pub fn is_retryable(&self) -> bool {
        matches!(self.sqlstate(), Some("40001" | "40P01"))
    }

    /// Returns `true` for a unique constraint violation (SQLSTATE 23505).
    ///
    /// This holds whether the error is the classified variant or a raw [`OrmError::Query`].
    pub fn is_unique_violation(&self) -> bool {
        self.sqlstate() == Some("23505")
    }

    /// Returns `true` if this is an [`OrmError::NotFound`].
    pub fn is_not_found(&self) -> bool {
        matches!(self, Self::NotFound(_))
    }

    /// Returns `true` if this is an [`OrmError::TooManyRows`].
    pub fn is_too_many_rows(&self) -> bool {
        matches!(self, Self::TooManyRows { .. })
    }

    /// Returns `true` if this is an [`OrmError::Timeout`].
    pub fn is_timeout(&self) -> bool {
        matches!(self, Self::Timeout(_))
    }

    /// Returns `true` if this is an [`OrmError::StaleRecord`].
    pub fn is_stale_record(&self) -> bool {
        matches!(self, Self::StaleRecord { .. })
    }

    /// Returns the SQLSTATE code associated with this error, if any.
    ///
    /// - For [`OrmError::Query`], this is the server-reported code when the error
    ///   carries a database error.
    /// - For the classified constraint and concurrency variants, this is their
    ///   canonical code.
    /// - For every other variant, this is `None`.
    pub fn sqlstate(&self) -> Option<&str> {
        match self {
            Self::Query(e) => e.as_db_error().map(|db| db.code().code()),
            Self::UniqueViolation(_) => Some("23505"),
            Self::ForeignKeyViolation(_) => Some("23503"),
            Self::CheckViolation(_) => Some("23514"),
            Self::SerializationFailure(_) => Some("40001"),
            Self::DeadlockDetected(_) => Some("40P01"),
            _ => None,
        }
    }

    /// Classifies a `tokio_postgres` error into the most specific variant.
    ///
    /// - Constraint violations (23505 / 23503 / 23514) keep the constraint name
    ///   (or `"unknown"`) plus the server message.
    /// - Serialization failures (40001) and deadlocks (40P01) keep the server message.
    /// - Any SQLSTATE in class 08 ("Connection Exception") becomes [`OrmError::Connection`].
    /// - A closed connection also becomes [`OrmError::Connection`].
    /// - Everything else is wrapped unchanged in [`OrmError::Query`].
    pub fn from_db_error(err: tokio_postgres::Error) -> Self {
        if let Some(db_err) = err.as_db_error() {
            let code = db_err.code().code();
            let constraint = db_err.constraint().unwrap_or("unknown");
            let message = db_err.message();
            match code {
                "23505" => return Self::UniqueViolation(format!("{constraint}: {message}")),
                "23503" => return Self::ForeignKeyViolation(format!("{constraint}: {message}")),
                "23514" => return Self::CheckViolation(format!("{constraint}: {message}")),
                "40001" => return Self::SerializationFailure(message.to_string()),
                "40P01" => return Self::DeadlockDetected(message.to_string()),
                // The whole class 08 (08000, 08001, 08003, 08004, 08006, 08007, 08P01)
                // denotes connection exceptions, not just the three most common codes.
                _ if code.starts_with("08") => return Self::Connection(message.to_string()),
                _ => {}
            }
        }
        if err.is_closed() {
            return Self::Connection(err.to_string());
        }
        Self::Query(err)
    }
}
#[cfg(feature = "pool")]
impl From<deadpool_postgres::PoolError> for OrmError {
    /// Converts a pool error into [`OrmError::Pool`], keeping only its `Display` text.
    fn from(pool_err: deadpool_postgres::PoolError) -> Self {
        let text = pool_err.to_string();
        OrmError::Pool(text)
    }
}
#[cfg(feature = "migrate")]
impl From<refinery::Error> for OrmError {
    /// Converts a migration error into [`OrmError::Migration`], keeping only its `Display` text.
    fn from(migration_err: refinery::Error) -> Self {
        let text = migration_err.to_string();
        OrmError::Migration(text)
    }
}
/// Emits a crate-internal warning.
///
/// With the `tracing` feature the message goes to `tracing::warn!` under the
/// `pgorm` target; without it, the message is printed to stderr.
pub(crate) fn pgorm_warn(msg: &str) {
    #[cfg(not(feature = "tracing"))]
    {
        eprintln!("{}", msg);
    }
    #[cfg(feature = "tracing")]
    {
        tracing::warn!(target: "pgorm", "{}", msg);
    }
}