use std::fmt::{self, Debug, Formatter};
use std::future::{self, Future};
use std::ops::{Deref, DerefMut};
use futures_core::future::BoxFuture;
use crate::database::Database;
use crate::error::Error;
use crate::pool::MaybePoolConnection;
use crate::sql_str::{AssertSqlSafe, SqlSafeStr, SqlStr};
/// Driver-level hooks for starting and finishing transactions on a raw
/// connection.
///
/// Every operation is an associated function that receives the connection
/// explicitly; the trait itself carries no state. [`Transaction`] dispatches
/// to these functions through `DB::TransactionManager`.
pub trait TransactionManager {
/// The database driver whose `Connection` type these operations act on.
type Database: Database;
/// Starts a transaction on `conn`.
///
/// The returned future borrows `conn` until it completes and resolves to
/// `Ok(())` on success or an [`Error`] on failure.
///
/// NOTE(review): `statement` is presumably a caller-supplied statement to run
/// in place of the driver's default "begin" statement when `Some`; the exact
/// handling (including behavior for nested transactions) is defined by each
/// implementation — confirm there.
fn begin(
conn: &mut <Self::Database as Database>::Connection,
statement: Option<SqlStr>,
) -> impl Future<Output = Result<(), Error>> + Send + '_;
/// Commits the work of the current transaction on `conn`.
///
/// The returned future borrows `conn` until it completes and resolves to
/// `Ok(())` on success or an [`Error`] on failure.
fn commit(
conn: &mut <Self::Database as Database>::Connection,
) -> impl Future<Output = Result<(), Error>> + Send + '_;
/// Rolls back the work of the current transaction on `conn`.
///
/// The returned future borrows `conn` until it completes and resolves to
/// `Ok(())` on success or an [`Error`] on failure.
fn rollback(
conn: &mut <Self::Database as Database>::Connection,
) -> impl Future<Output = Result<(), Error>> + Send + '_;
/// Synchronous counterpart of [`rollback`](Self::rollback): returns nothing
/// and cannot be awaited.
///
/// This is what `Transaction`'s `Drop` implementation calls when a
/// transaction goes out of scope while still open, where awaiting is not
/// possible. How and when the rollback actually reaches the database is up
/// to the implementation.
fn start_rollback(conn: &mut <Self::Database as Database>::Connection);
/// Reports the transaction nesting depth currently tracked for `conn`.
///
/// NOTE(review): the ANSI SQL helpers in this module treat depth `0` as
/// "no transaction yet" when beginning and depth `1` as the outermost
/// transaction when committing or rolling back; implementations presumably
/// follow the same numbering — confirm.
fn get_transaction_depth(conn: &<Self::Database as Database>::Connection) -> usize;
}
/// Guard object representing a transaction that has been started on a
/// connection.
///
/// The guard dereferences to the underlying `DB::Connection`, so it can be
/// used wherever a connection is expected. Finish it explicitly with `commit`
/// or `rollback`; if it is dropped while still open, a rollback is started
/// through `DB::TransactionManager::start_rollback`.
pub struct Transaction<'c, DB: Database> {
    /// The connection the transaction runs on (wrapped in
    /// `MaybePoolConnection`, parameterized by the `'c` lifetime).
    connection: MaybePoolConnection<'c, DB>,
    /// `true` until `commit` or `rollback` has completed successfully;
    /// consulted on drop to decide whether a rollback must be started.
    open: bool,
}
impl<'c, DB> Transaction<'c, DB>
where
    DB: Database,
{
    /// Starts a transaction on `conn` and returns a guard for it.
    ///
    /// `statement` is forwarded unchanged to
    /// `DB::TransactionManager::begin`.
    ///
    /// # Errors
    ///
    /// Returns whatever error `DB::TransactionManager::begin` reports. In
    /// that case no `Transaction` value is ever created, so no rollback is
    /// triggered for a transaction that was never opened.
    #[doc(hidden)]
    pub fn begin(
        conn: impl Into<MaybePoolConnection<'c, DB>>,
        statement: Option<SqlStr>,
    ) -> BoxFuture<'c, Result<Self, Error>> {
        let mut conn = conn.into();

        Box::pin(async move {
            // Begin on the bare connection *before* building the guard. If the
            // guard existed with `open: true` while this call failed, its
            // `Drop` impl would call `start_rollback` even though nothing was
            // started here — which, for a nested call, could undo the
            // enclosing transaction instead.
            DB::TransactionManager::begin(&mut conn, statement).await?;

            Ok(Self {
                connection: conn,
                open: true,
            })
        })
    }

    /// Commits the transaction, consuming the guard.
    ///
    /// # Errors
    ///
    /// Returns the error from `DB::TransactionManager::commit`. On failure
    /// the guard is still marked open when it is dropped, so a rollback is
    /// started by the `Drop` implementation.
    pub async fn commit(mut self) -> Result<(), Error> {
        DB::TransactionManager::commit(&mut self.connection).await?;
        // Only reached on success: disarm the rollback-on-drop.
        self.open = false;

        Ok(())
    }

    /// Rolls the transaction back, consuming the guard.
    ///
    /// # Errors
    ///
    /// Returns the error from `DB::TransactionManager::rollback`. On failure
    /// the guard is still marked open when it is dropped, so the `Drop`
    /// implementation starts another rollback attempt.
    pub async fn rollback(mut self) -> Result<(), Error> {
        DB::TransactionManager::rollback(&mut self.connection).await?;
        // Only reached on success: disarm the rollback-on-drop.
        self.open = false;

        Ok(())
    }
}
impl<DB: Database> Debug for Transaction<'_, DB> {
    /// Formats the guard as the bare struct name `Transaction`; no fields are
    /// included in the output.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let mut repr = f.debug_struct("Transaction");
        repr.finish()
    }
}
impl<DB: Database> Deref for Transaction<'_, DB> {
    type Target = DB::Connection;

    /// Gives shared access to the connection the transaction runs on.
    #[inline]
    fn deref(&self) -> &Self::Target {
        // Explicitly reborrow through the `MaybePoolConnection` wrapper.
        &*self.connection
    }
}
impl<DB: Database> DerefMut for Transaction<'_, DB> {
    /// Gives exclusive access to the connection the transaction runs on.
    #[inline]
    fn deref_mut(&mut self) -> &mut Self::Target {
        // Explicitly reborrow through the `MaybePoolConnection` wrapper.
        &mut *self.connection
    }
}
impl<DB> AsMut<DB::Connection> for Transaction<'_, DB>
where
    DB: Database,
{
    /// Borrows the underlying connection mutably.
    fn as_mut(&mut self) -> &mut DB::Connection {
        let conn: &mut DB::Connection = &mut self.connection;
        conn
    }
}
impl<'t, DB> crate::acquire::Acquire<'t> for &'t mut Transaction<'_, DB>
where
    DB: Database,
{
    type Database = DB;

    type Connection = &'t mut <DB as Database>::Connection;

    /// Hands out the transaction's own connection; the returned future is
    /// already resolved with `Ok`.
    #[inline]
    fn acquire(self) -> BoxFuture<'t, Result<Self::Connection, Error>> {
        let conn: Self::Connection = &mut **self;
        Box::pin(future::ready(Ok(conn)))
    }

    /// Begins a new `Transaction` on top of this one by calling
    /// `Transaction::begin` on the same connection with no custom statement.
    #[inline]
    fn begin(self) -> BoxFuture<'t, Result<Transaction<'t, DB>, Error>> {
        let conn: &'t mut <DB as Database>::Connection = &mut **self;
        Transaction::begin(conn, None)
    }
}
impl<DB: Database> Drop for Transaction<'_, DB> {
    /// Starts a rollback if the guard is dropped while still open, i.e.
    /// without a successful `commit` or `rollback`.
    fn drop(&mut self) {
        // Already committed or rolled back: nothing left to undo.
        if !self.open {
            return;
        }

        // `drop` cannot await, so only the synchronous `start_rollback` hook
        // is available here.
        DB::TransactionManager::start_rollback(&mut self.connection);
    }
}
/// Builds the ANSI SQL statement that opens a transaction at nesting `depth`.
///
/// * `depth == 0` yields `BEGIN`.
/// * Any other depth yields `SAVEPOINT _sqlx_savepoint_{depth}`.
pub fn begin_ansi_transaction_sql(depth: usize) -> SqlStr {
    match depth {
        0 => "BEGIN".into_sql_str(),
        // The statement consists only of a fixed prefix and an integer, which
        // is why it is wrapped in `AssertSqlSafe`.
        level => AssertSqlSafe(format!("SAVEPOINT _sqlx_savepoint_{level}")).into_sql_str(),
    }
}
/// Builds the ANSI SQL statement that commits the transaction at nesting
/// `depth`.
///
/// * `depth == 1` (the outermost transaction) yields `COMMIT`.
/// * `depth >= 2` yields `RELEASE SAVEPOINT _sqlx_savepoint_{depth - 1}`,
///   matching the name produced by `begin_ansi_transaction_sql(depth - 1)`.
/// * `depth == 0` also yields `COMMIT`. Previously `depth - 1` underflowed
///   here, panicking in debug builds and producing a bogus savepoint name in
///   release builds.
pub fn commit_ansi_transaction_sql(depth: usize) -> SqlStr {
    match depth {
        0 | 1 => "COMMIT".into_sql_str(),
        nested => {
            // `nested >= 2` in this arm, so the subtraction cannot underflow.
            let savepoint = nested - 1;
            AssertSqlSafe(format!("RELEASE SAVEPOINT _sqlx_savepoint_{savepoint}")).into_sql_str()
        }
    }
}
/// Builds the ANSI SQL statement that rolls back the transaction at nesting
/// `depth`.
///
/// * `depth == 1` (the outermost transaction) yields `ROLLBACK`.
/// * `depth >= 2` yields `ROLLBACK TO SAVEPOINT _sqlx_savepoint_{depth - 1}`,
///   matching the name produced by `begin_ansi_transaction_sql(depth - 1)`.
/// * `depth == 0` also yields `ROLLBACK`. Previously `depth - 1` underflowed
///   here, panicking in debug builds and producing a bogus savepoint name in
///   release builds.
pub fn rollback_ansi_transaction_sql(depth: usize) -> SqlStr {
    match depth {
        0 | 1 => "ROLLBACK".into_sql_str(),
        nested => {
            // `nested >= 2` in this arm, so the subtraction cannot underflow.
            let savepoint = nested - 1;
            AssertSqlSafe(format!(
                "ROLLBACK TO SAVEPOINT _sqlx_savepoint_{savepoint}"
            ))
            .into_sql_str()
        }
    }
}