use crate::error::{OrmError, OrmResult};
use std::sync::atomic::{AtomicU64, Ordering};
use tokio_postgres::Row;
use tokio_postgres::Statement;
use tokio_postgres::types::ToSql;
// Process-wide counter, starting at zero, that `__next_savepoint_name`
// increments to number the savepoint names it generates.
static SAVEPOINT_COUNTER: AtomicU64 = AtomicU64::new(0);
/// Runs an async block inside a database transaction.
///
/// The expansion opens a transaction on `$client`, binds it to `$tx`
/// (declared `mut`) and awaits `$body`, which must evaluate to a `Result`:
///
/// * `Ok(v)` – the transaction is committed and the macro yields `Ok(v)`.
/// * `Err(e)` – the transaction is rolled back and the macro yields `Err(e)`.
///   If the rollback fails as well, both messages are combined into
///   `OrmError::Other`.
///
/// Failures to open or to commit the transaction are mapped through
/// `OrmError::from_db_error` and propagated with `?`, so they leave the
/// *enclosing* function or async block instead of becoming the macro's value.
/// The expansion uses `.await`, so it has to appear in an async context.
#[macro_export]
macro_rules! transaction {
    ($client:expr, $tx:ident, $body:block) => {{
        let mut $tx = ($client)
            .transaction()
            .await
            .map_err($crate::OrmError::from_db_error)?;
        // The body's result is bound to a local before `$tx` is consumed by
        // `commit` / `rollback` below.
        let __pgorm_tx_outcome = async { $body }.await;
        match __pgorm_tx_outcome {
            Err(error) => match $tx.rollback().await {
                Err(rollback_err) => Err($crate::OrmError::Other(format!(
                    "{error} (rollback failed: {rollback_err})"
                ))),
                Ok(()) => Err(error),
            },
            Ok(output) => {
                $tx.commit()
                    .await
                    .map_err($crate::OrmError::from_db_error)?;
                Ok(output)
            }
        }
    }};
}
/// Runs an async block inside a savepoint of an existing transaction.
///
/// Two forms are accepted:
///
/// * `savepoint!(tx, name, sp, { .. })` – creates the savepoint `name` on `tx`
///   and binds the resulting handle to `sp` (declared `mut`).
/// * `savepoint!(tx, sp, { .. })` – same, with a name obtained from
///   `__next_savepoint_name()`.
///
/// `$body` must evaluate to a `Result`. On `Ok(v)` the handle is committed and
/// the macro yields `Ok(v)`; on `Err(e)` the handle is rolled back and the
/// macro yields `Err(e)`, or an `OrmError::Other` combining both messages when
/// the rollback fails too.
///
/// Errors from creating or committing the savepoint are mapped through
/// `OrmError::from_db_error` and propagated with `?` out of the enclosing
/// function or async block. The expansion uses `.await`.
#[macro_export]
macro_rules! savepoint {
    ($tx:expr, $name:expr, $sp:ident, $body:block) => {{
        let mut $sp = ($tx)
            .savepoint($name)
            .await
            .map_err($crate::OrmError::from_db_error)?;
        // The body's result is bound to a local before `$sp` is consumed by
        // `commit` / `rollback` below.
        let __pgorm_sp_outcome = async { $body }.await;
        match __pgorm_sp_outcome {
            Err(error) => match $sp.rollback().await {
                Err(rollback_err) => Err($crate::OrmError::Other(format!(
                    "{error} (savepoint rollback failed: {rollback_err})"
                ))),
                Ok(()) => Err(error),
            },
            Ok(output) => {
                $sp.commit()
                    .await
                    .map_err($crate::OrmError::from_db_error)?;
                Ok(output)
            }
        }
    }};
    ($tx:expr, $sp:ident, $body:block) => {{
        // No name supplied: generate one and defer to the named form.
        let __pgorm_generated_sp_name = $crate::__next_savepoint_name();
        $crate::savepoint!($tx, &__pgorm_generated_sp_name, $sp, $body)
    }};
}
/// Runs an async block as a "nested transaction" on top of `$tx`.
///
/// A savepoint name is taken from `__next_savepoint_name()` and the work is
/// handed to the named form of `savepoint!`, with `$inner` bound to the
/// savepoint handle inside `$body`. Result and error behaviour are therefore
/// those of `savepoint!`.
#[macro_export]
macro_rules! nested_transaction {
    ($tx:expr, $inner:ident, $body:block) => {{
        let __pgorm_nested_sp_name = $crate::__next_savepoint_name();
        $crate::savepoint!($tx, &__pgorm_nested_sp_name, $inner, $body)
    }};
}
/// Returns a fresh savepoint name of the form `pgorm_sp_<n>`.
///
/// `<n>` is the previous value of the process-wide `SAVEPOINT_COUNTER`, which
/// is incremented atomically on every call. Public only so that the exported
/// macros can reach it through `$crate`.
#[doc(hidden)]
pub fn __next_savepoint_name() -> String {
    let sequence = SAVEPOINT_COUNTER.fetch_add(1, Ordering::Relaxed);
    let mut name = String::from("pgorm_sp_");
    name.push_str(&sequence.to_string());
    name
}
/// A named savepoint that wraps a nested `tokio_postgres::Transaction`.
///
/// Finish it with `release` (commit) or `rollback`; both consume the value.
pub struct Savepoint<'a> {
/// The wrapped nested transaction; `release` and `rollback` take it out,
/// leaving `None`.
inner: Option<tokio_postgres::Transaction<'a>>,
/// The savepoint name supplied at construction, exposed through `name()`.
name: String,
}
impl<'a> Savepoint<'a> {
    /// Wraps an already-created nested transaction together with the name it
    /// was created under.
    fn new(inner: tokio_postgres::Transaction<'a>, name: String) -> Self {
        let inner = Some(inner);
        Self { inner, name }
    }

    /// Returns the savepoint's name.
    pub fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Commits the wrapped nested transaction and consumes the savepoint.
    ///
    /// Does nothing and returns `Ok(())` if the transaction was already taken.
    ///
    /// # Errors
    /// Returns the commit failure mapped through `OrmError::from_db_error`.
    pub async fn release(mut self) -> OrmResult<()> {
        // `take()` leaves `None` behind, so the `Drop` impl sees the savepoint
        // as finished.
        let Some(tx) = self.inner.take() else {
            return Ok(());
        };
        tx.commit().await.map_err(OrmError::from_db_error)?;
        Ok(())
    }

    /// Rolls back the wrapped nested transaction and consumes the savepoint.
    ///
    /// Does nothing and returns `Ok(())` if the transaction was already taken.
    ///
    /// # Errors
    /// Returns the rollback failure mapped through `OrmError::from_db_error`.
    pub async fn rollback(mut self) -> OrmResult<()> {
        let Some(tx) = self.inner.take() else {
            return Ok(());
        };
        tx.rollback().await.map_err(OrmError::from_db_error)?;
        Ok(())
    }
}
/// Emits a warning (only when the `tracing` feature is enabled) if a savepoint
/// is dropped while it still holds its transaction, i.e. without `release` or
/// `rollback` having been called.
///
/// NOTE(review): the wrapped transaction is then dropped normally; presumably
/// `tokio_postgres` rolls it back in its own `Drop` — confirm.
impl Drop for Savepoint<'_> {
    fn drop(&mut self) {
        let dropped_while_open = self.inner.is_some();
        if dropped_while_open {
            #[cfg(feature = "tracing")]
            tracing::warn!(
                "Savepoint '{}' dropped without explicit release or rollback",
                self.name,
            );
        }
    }
}
impl<'a> Savepoint<'a> {
    /// Borrows the wrapped transaction.
    ///
    /// # Errors
    /// Returns `OrmError::Other("savepoint already consumed")` when the
    /// transaction has already been taken out of the savepoint.
    fn active_tx(&self) -> OrmResult<&tokio_postgres::Transaction<'a>> {
        self.inner
            .as_ref()
            .ok_or_else(|| OrmError::Other("savepoint already consumed".to_string()))
    }
}

/// Forwards every `GenericClient` operation to the wrapped transaction.
///
/// Each fallible method fails with `OrmError::Other("savepoint already
/// consumed")` when no transaction is held; otherwise the result is whatever
/// the transaction's own `GenericClient` implementation returns.
impl crate::GenericClient for Savepoint<'_> {
    /// Runs `sql` with `params` on the wrapped transaction and returns all rows.
    async fn query(&self, sql: &str, params: &[&(dyn ToSql + Sync)]) -> OrmResult<Vec<Row>> {
        // Resolved in its own statement, as in every method below, so no
        // temporary from the lookup is held across the `.await`.
        let tx = self.active_tx()?;
        crate::GenericClient::query(tx, sql, params).await
    }

    /// Forwards to the wrapped transaction's `query_one`.
    async fn query_one(&self, sql: &str, params: &[&(dyn ToSql + Sync)]) -> OrmResult<Row> {
        let tx = self.active_tx()?;
        crate::GenericClient::query_one(tx, sql, params).await
    }

    /// Forwards to the wrapped transaction's `query_opt`.
    async fn query_opt(
        &self,
        sql: &str,
        params: &[&(dyn ToSql + Sync)],
    ) -> OrmResult<Option<Row>> {
        let tx = self.active_tx()?;
        crate::GenericClient::query_opt(tx, sql, params).await
    }

    /// Forwards to the wrapped transaction's `execute`.
    async fn execute(&self, sql: &str, params: &[&(dyn ToSql + Sync)]) -> OrmResult<u64> {
        let tx = self.active_tx()?;
        crate::GenericClient::execute(tx, sql, params).await
    }

    /// Returns the wrapped transaction's cancel token, or `None` when no
    /// transaction is held.
    fn cancel_token(&self) -> Option<tokio_postgres::CancelToken> {
        let tx = self.inner.as_ref()?;
        crate::GenericClient::cancel_token(tx)
    }

    /// Reports the wrapped transaction's answer, or `false` when no
    /// transaction is held.
    fn supports_prepared_statements(&self) -> bool {
        self.inner
            .as_ref()
            .is_some_and(|tx| crate::GenericClient::supports_prepared_statements(tx))
    }

    /// Forwards to the wrapped transaction's `prepare_statement`.
    async fn prepare_statement(&self, sql: &str) -> OrmResult<Statement> {
        let tx = self.active_tx()?;
        crate::GenericClient::prepare_statement(tx, sql).await
    }

    /// Forwards to the wrapped transaction's `query_prepared`.
    async fn query_prepared(
        &self,
        stmt: &Statement,
        params: &[&(dyn ToSql + Sync)],
    ) -> OrmResult<Vec<Row>> {
        let tx = self.active_tx()?;
        crate::GenericClient::query_prepared(tx, stmt, params).await
    }

    /// Forwards to the wrapped transaction's `execute_prepared`.
    async fn execute_prepared(
        &self,
        stmt: &Statement,
        params: &[&(dyn ToSql + Sync)],
    ) -> OrmResult<u64> {
        let tx = self.active_tx()?;
        crate::GenericClient::execute_prepared(tx, stmt, params).await
    }
}
/// Extension methods for opening a [`Savepoint`] on a transaction.
///
/// Both methods borrow the transaction mutably for as long as the returned
/// savepoint lives, and both return `Send` futures.
pub trait TransactionExt {
/// Creates a savepoint called `name` and returns it wrapped in a
/// [`Savepoint`].
///
/// # Errors
/// The implementations in this file return the driver error mapped through
/// `OrmError::from_db_error` when the savepoint cannot be created.
fn pgorm_savepoint(
&mut self,
name: &str,
) -> impl std::future::Future<Output = OrmResult<Savepoint<'_>>> + Send;
/// Creates a savepoint without a caller-supplied name. The implementations
/// in this file take the name from `__next_savepoint_name()`.
///
/// # Errors
/// Same as `pgorm_savepoint`.
fn pgorm_savepoint_anon(
&mut self,
) -> impl std::future::Future<Output = OrmResult<Savepoint<'_>>> + Send;
}
/// Savepoint support for a plain `tokio_postgres` transaction.
impl TransactionExt for tokio_postgres::Transaction<'_> {
    /// Creates the savepoint `name` on this transaction.
    ///
    /// # Errors
    /// Returns the driver error mapped through `OrmError::from_db_error`.
    async fn pgorm_savepoint(&mut self, name: &str) -> OrmResult<Savepoint<'_>> {
        let created = self.savepoint(name).await;
        let nested = created.map_err(OrmError::from_db_error)?;
        Ok(Savepoint::new(nested, String::from(name)))
    }

    /// Creates a savepoint whose name comes from `__next_savepoint_name()`.
    ///
    /// # Errors
    /// Returns the driver error mapped through `OrmError::from_db_error`.
    async fn pgorm_savepoint_anon(&mut self) -> OrmResult<Savepoint<'_>> {
        let generated = __next_savepoint_name();
        let created = self.savepoint(&generated).await;
        let nested = created.map_err(OrmError::from_db_error)?;
        // The generated name is moved into the guard, not cloned.
        Ok(Savepoint::new(nested, generated))
    }
}
/// Savepoint support for a pooled `deadpool_postgres` transaction (only with
/// the `pool` feature).
///
/// Both methods dereference to the underlying `tokio_postgres::Transaction`
/// and reuse its `TransactionExt` implementation, so the savepoint is created
/// on the raw driver transaction.
#[cfg(feature = "pool")]
impl TransactionExt for deadpool_postgres::Transaction<'_> {
    /// Creates the savepoint `name` on the underlying driver transaction.
    ///
    /// # Errors
    /// Returns the driver error mapped through `OrmError::from_db_error`.
    async fn pgorm_savepoint(&mut self, name: &str) -> OrmResult<Savepoint<'_>> {
        // The explicit target type selects the raw `tokio_postgres`
        // transaction that the pooled transaction dereferences to.
        let raw: &mut tokio_postgres::Transaction<'_> = &mut **self;
        TransactionExt::pgorm_savepoint(raw, name).await
    }

    /// Creates a savepoint whose name comes from `__next_savepoint_name()`.
    ///
    /// # Errors
    /// Returns the driver error mapped through `OrmError::from_db_error`.
    async fn pgorm_savepoint_anon(&mut self) -> OrmResult<Savepoint<'_>> {
        let raw: &mut tokio_postgres::Transaction<'_> = &mut **self;
        TransactionExt::pgorm_savepoint_anon(raw).await
    }
}