use diesel::connection::InstrumentationEvent;
use diesel::connection::TransactionManagerStatus;
use diesel::connection::{
InTransactionStatus, TransactionDepthChange, ValidTransactionManagerStatus,
};
use diesel::result::Error;
use diesel::QueryResult;
use std::borrow::Cow;
use std::future::Future;
use std::num::NonZeroU32;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
/// Helper trait that gives a name to the future produced by an async closure.
///
/// An implementor is both `AsyncFnOnce(Arg) -> Ret` and `FnOnce(Arg) -> Fut`,
/// which lets bounds constrain the concrete future type (for example
/// `AsyncFunc<.., Fut: Send>`).
pub trait AsyncFunc<Arg, Ret>:
    AsyncFnOnce(Arg) -> Ret + FnOnce(Arg) -> <Self as AsyncFunc<Arg, Ret>>::Fut
{
    /// The future returned when the closure is called; it resolves to `Ret`.
    type Fut: Future<Output = Ret>;
}

/// Blanket implementation: every type that is `AsyncFnOnce(Arg) -> Ret` and
/// `FnOnce(Arg) -> Pending`, where `Pending: Future<Output = Ret>`, implements
/// [`AsyncFunc`] with `Pending` as its `Fut`.
impl<Callback, Arg, Pending, Ret> AsyncFunc<Arg, Ret> for Callback
where
    Callback: AsyncFnOnce(Arg) -> Ret + FnOnce(Arg) -> Pending,
    Pending: Future<Output = Ret>,
{
    type Fut = Pending;
}
use crate::AsyncConnection;
pub trait TransactionManager<Conn: AsyncConnection>: Send {
type TransactionStateData;
fn begin_transaction(conn: &mut Conn) -> impl Future<Output = QueryResult<()>> + Send;
fn rollback_transaction(conn: &mut Conn) -> impl Future<Output = QueryResult<()>> + Send;
fn commit_transaction(conn: &mut Conn) -> impl Future<Output = QueryResult<()>> + Send;
#[doc(hidden)]
fn transaction_manager_status_mut(conn: &mut Conn) -> &mut TransactionManagerStatus;
fn transaction<'a, 'conn, F, R, E>(
conn: &'conn mut Conn,
callback: F,
) -> impl Future<Output = Result<R, E>> + Send + 'conn
where
for<'r> F: AsyncFnOnce(&'r mut Conn) -> Result<R, E>
+ AsyncFunc<&'r mut Conn, Result<R, E>, Fut: Send>
+ Send
+ 'a,
E: From<Error> + Send,
R: Send,
'a: 'conn,
{
async move {
let callback = callback;
Self::begin_transaction(conn).await?;
match callback(&mut *conn).await {
Ok(value) => {
Self::commit_transaction(conn).await?;
Ok(value)
}
Err(user_error) => match Self::rollback_transaction(conn).await {
Ok(()) => Err(user_error),
Err(Error::BrokenTransactionManager) => {
Err(user_error)
}
Err(rollback_error) => Err(rollback_error.into()),
},
}
}
}
#[doc(hidden)]
fn is_broken_transaction_manager(conn: &mut Conn) -> bool {
check_broken_transaction_state(conn)
}
}
fn check_broken_transaction_state<Conn>(conn: &mut Conn) -> bool
where
Conn: AsyncConnection,
{
match Conn::TransactionManager::transaction_manager_status_mut(conn).transaction_state() {
Ok(ValidTransactionManagerStatus {
in_transaction: None,
..
}) => false,
Err(_) => true,
Ok(ValidTransactionManagerStatus {
in_transaction: Some(s),
..
}) => !s.test_transaction,
}
}
/// Transaction manager that drives transactions by executing SQL statements
/// (`BEGIN` / `COMMIT` / `ROLLBACK` and `diesel_savepoint_{n}` savepoints).
#[derive(Default, Debug)]
pub struct AnsiTransactionManager {
    /// Transaction depth / validity bookkeeping for this connection.
    pub(crate) status: TransactionManagerStatus,
    /// Set to `true` while a transaction statement executes inside
    /// `critical_transaction_block`, and reset to `false` once it completes.
    /// If that future is dropped before completion the flag stays `true`, and
    /// `is_broken_transaction_manager` then reports the connection as broken.
    /// Callers in this file clone the `Arc` before executing SQL so the flag
    /// can be borrowed while the connection itself is mutably borrowed.
    pub(crate) is_broken: Arc<AtomicBool>,
    /// `true` while a `COMMIT` / `RELEASE SAVEPOINT` statement is executing.
    /// NOTE(review): not read in the code shown here; presumably inspected by
    /// connection implementations elsewhere in the crate — confirm.
    pub(crate) is_commit: bool,
}
impl AnsiTransactionManager {
fn get_transaction_state<Conn>(
conn: &mut Conn,
) -> QueryResult<&mut ValidTransactionManagerStatus>
where
Conn: AsyncConnection<TransactionManager = Self>,
{
conn.transaction_state().status.transaction_state()
}
pub async fn begin_transaction_sql<Conn>(conn: &mut Conn, sql: &str) -> QueryResult<()>
where
Conn: AsyncConnection<TransactionManager = Self>,
{
let is_broken = conn.transaction_state().is_broken.clone();
let state = Self::get_transaction_state(conn)?;
if let Some(_depth) = state.transaction_depth() {
return Err(Error::AlreadyInTransaction);
}
let instrumentation_depth = NonZeroU32::new(1);
conn.instrumentation()
.on_connection_event(InstrumentationEvent::begin_transaction(
instrumentation_depth.expect("We know that 1 is not zero"),
));
Self::critical_transaction_block(&is_broken, conn.batch_execute(sql)).await?;
Self::get_transaction_state(conn)?
.change_transaction_depth(TransactionDepthChange::IncreaseDepth)?;
Ok(())
}
async fn critical_transaction_block<F>(is_broken: &AtomicBool, f: F) -> F::Output
where
F: std::future::Future,
{
let was_broken = is_broken.swap(true, Ordering::Relaxed);
debug_assert!(
!was_broken,
"Tried to execute a transaction SQL on transaction manager that was previously cancled"
);
let res = f.await;
is_broken.store(false, Ordering::Relaxed);
res
}
}
/// SQL-statement based [`TransactionManager`].
///
/// The outermost transaction uses `BEGIN` / `COMMIT` / `ROLLBACK`. Nested
/// transactions use savepoints: the savepoint created while the depth is `n`
/// is named `diesel_savepoint_{n}`. It is later released or rolled back under
/// that same name, when the depth is `n + 1`.
impl<Conn> TransactionManager<Conn> for AnsiTransactionManager
where
    Conn: AsyncConnection<TransactionManager = Self>,
{
    type TransactionStateData = Self;

    /// Starts a transaction (`BEGIN`), or a savepoint if one is already open.
    ///
    /// The depth is only incremented after the statement succeeds. If the
    /// statement fails, its error is returned and the depth is unchanged.
    async fn begin_transaction(conn: &mut Conn) -> QueryResult<()> {
        let transaction_state = Self::get_transaction_state(conn)?;
        let start_transaction_sql = match transaction_state.transaction_depth() {
            None => Cow::from("BEGIN"),
            Some(transaction_depth) => {
                Cow::from(format!("SAVEPOINT diesel_savepoint_{transaction_depth}"))
            }
        };
        // Depth reported to instrumentation is the depth this call moves to:
        // current + 1, or 1 when nothing is open (also 1 if the addition
        // would overflow).
        let depth = transaction_state
            .transaction_depth()
            .and_then(|d| d.checked_add(1))
            .unwrap_or(NonZeroU32::new(1).expect("It's not 0"));
        conn.instrumentation()
            .on_connection_event(InstrumentationEvent::begin_transaction(depth));
        // The `is_broken` Arc is cloned so the flag can be borrowed while
        // `batch_execute` holds the mutable borrow of `conn`.
        Self::critical_transaction_block(
            &conn.transaction_state().is_broken.clone(),
            conn.batch_execute(&start_transaction_sql),
        )
        .await?;
        Self::get_transaction_state(conn)?
            .change_transaction_depth(TransactionDepthChange::IncreaseDepth)?;
        Ok(())
    }

    /// Rolls back the innermost open transaction.
    ///
    /// At depth 1 this executes `ROLLBACK`. At depth `n > 1` it executes
    /// `ROLLBACK TO SAVEPOINT diesel_savepoint_{n - 1}`.
    ///
    /// # Errors
    ///
    /// * `Error::NotInTransaction` if no transaction is open.
    /// * The error from executing the rollback statement. Exception: the
    ///   failure is swallowed and `Ok(())` is returned when both of these hold:
    ///   - after the failure the status still shows depth > 1;
    ///   - `requires_rollback_maybe_up_to_top_level` was already set before
    ///     this call.
    async fn rollback_transaction(conn: &mut Conn) -> QueryResult<()> {
        let transaction_state = Self::get_transaction_state(conn)?;
        // Pick the SQL and whether this is the outermost transaction, and
        // snapshot the "rollback up to top level" flag before executing.
        let (
            (rollback_sql, rolling_back_top_level),
            requires_rollback_maybe_up_to_top_level_before_execute,
        ) = match transaction_state.in_transaction {
            Some(ref in_transaction) => (
                match in_transaction.transaction_depth.get() {
                    1 => (Cow::Borrowed("ROLLBACK"), true),
                    depth_gt1 => (
                        Cow::Owned(format!(
                            "ROLLBACK TO SAVEPOINT diesel_savepoint_{}",
                            depth_gt1 - 1
                        )),
                        false,
                    ),
                },
                in_transaction.requires_rollback_maybe_up_to_top_level,
            ),
            None => return Err(Error::NotInTransaction),
        };
        let depth = transaction_state
            .transaction_depth()
            .expect("We know that we are in a transaction here");
        conn.instrumentation()
            .on_connection_event(InstrumentationEvent::rollback_transaction(depth));
        let is_broken = conn.transaction_state().is_broken.clone();
        match Self::critical_transaction_block(&is_broken, conn.batch_execute(&rollback_sql)).await
        {
            Ok(()) => {
                match Self::get_transaction_state(conn)?
                    .change_transaction_depth(TransactionDepthChange::DecreaseDepth)
                {
                    Ok(()) => {}
                    Err(Error::NotInTransaction) if rolling_back_top_level => {
                        // Tolerated: after a top-level ROLLBACK the status may
                        // already report no open transaction.
                        // NOTE(review): presumably reset by the connection
                        // implementation — confirm.
                    }
                    Err(e) => return Err(e),
                }
                Ok(())
            }
            Err(rollback_error) => {
                let tm_status = Self::transaction_manager_status_mut(conn);
                match tm_status {
                    TransactionManagerStatus::Valid(ValidTransactionManagerStatus {
                        in_transaction:
                            Some(InTransactionStatus {
                                transaction_depth,
                                requires_rollback_maybe_up_to_top_level,
                                ..
                            }),
                        ..
                    }) if transaction_depth.get() > 1 => {
                        // A savepoint rollback failed. Still step the depth down
                        // by one and record that an outer rollback may be needed.
                        *transaction_depth = NonZeroU32::new(transaction_depth.get() - 1)
                            .expect("Depth was checked to be > 1");
                        *requires_rollback_maybe_up_to_top_level = true;
                        // If the flag was already set before this call, the
                        // failure is ignored.
                        if requires_rollback_maybe_up_to_top_level_before_execute {
                            return Ok(());
                        }
                    }
                    TransactionManagerStatus::Valid(ValidTransactionManagerStatus {
                        in_transaction: None,
                        ..
                    }) => {
                        // A transaction was open when this call started
                        // (otherwise `NotInTransaction` was returned above), so
                        // the status was reset during execution. Leave it as is
                        // rather than marking it in error.
                    }
                    // Remaining cases: a transaction still at depth 1, or a
                    // non-`Valid` status. Mark the manager as being in error.
                    _ => tm_status.set_in_error(),
                }
                Err(rollback_error)
            }
        }
    }

    /// Commits the innermost open transaction.
    ///
    /// At depth 1 this executes `COMMIT`. At depth `n > 1` it executes
    /// `RELEASE SAVEPOINT diesel_savepoint_{n - 1}`. `is_commit` is `true`
    /// while the statement executes.
    ///
    /// # Errors
    ///
    /// * `Error::NotInTransaction` if no transaction is open.
    /// * If the statement fails and `requires_rollback_maybe_up_to_top_level`
    ///   is set, a rollback is attempted:
    ///   - if the rollback also fails, the status is set in error and
    ///     `Error::RollbackErrorOnCommit` (carrying both errors) is returned;
    ///   - otherwise the commit error is returned.
    /// * If the statement fails and that flag is not set, the depth is
    ///   decremented and the commit error is returned.
    async fn commit_transaction(conn: &mut Conn) -> QueryResult<()> {
        let transaction_state = Self::get_transaction_state(conn)?;
        let transaction_depth = transaction_state.transaction_depth();
        let (commit_sql, committing_top_level) = match transaction_depth {
            None => return Err(Error::NotInTransaction),
            Some(transaction_depth) if transaction_depth.get() == 1 => {
                (Cow::Borrowed("COMMIT"), true)
            }
            Some(transaction_depth) => (
                Cow::Owned(format!(
                    "RELEASE SAVEPOINT diesel_savepoint_{}",
                    transaction_depth.get() - 1
                )),
                false,
            ),
        };
        let depth = transaction_state
            .transaction_depth()
            .expect("We know that we are in a transaction here");
        conn.instrumentation()
            .on_connection_event(InstrumentationEvent::commit_transaction(depth));
        // Mark the commit as in flight and grab the shared `is_broken` flag.
        let is_broken = {
            let transaction_state = conn.transaction_state();
            transaction_state.is_commit = true;
            transaction_state.is_broken.clone()
        };
        let res =
            Self::critical_transaction_block(&is_broken, conn.batch_execute(&commit_sql)).await;
        // Only reached if the statement future completes; a dropped future
        // leaves `is_commit` set to `true`.
        conn.transaction_state().is_commit = false;
        match res {
            Ok(()) => {
                match Self::get_transaction_state(conn)?
                    .change_transaction_depth(TransactionDepthChange::DecreaseDepth)
                {
                    Ok(()) => {}
                    Err(Error::NotInTransaction) if committing_top_level => {
                        // Tolerated: after a top-level COMMIT the status may
                        // already report no open transaction.
                        // NOTE(review): presumably reset by the connection
                        // implementation — confirm.
                    }
                    Err(e) => return Err(e),
                }
                Ok(())
            }
            Err(commit_error) => {
                // An earlier savepoint rollback failed: try to roll back
                // instead of leaving the transaction half-finished.
                if let TransactionManagerStatus::Valid(ValidTransactionManagerStatus {
                    in_transaction:
                        Some(InTransactionStatus {
                            requires_rollback_maybe_up_to_top_level: true,
                            ..
                        }),
                    ..
                }) = conn.transaction_state().status
                {
                    match Self::rollback_transaction(conn).await {
                        Ok(()) => {}
                        Err(rollback_error) => {
                            conn.transaction_state().status.set_in_error();
                            return Err(Error::RollbackErrorOnCommit {
                                rollback_error: Box::new(rollback_error),
                                commit_error: Box::new(commit_error),
                            });
                        }
                    }
                } else {
                    Self::get_transaction_state(conn)?
                        .change_transaction_depth(TransactionDepthChange::DecreaseDepth)?;
                }
                Err(commit_error)
            }
        }
    }

    /// Mutable access to the `status` field stored on the connection.
    fn transaction_manager_status_mut(conn: &mut Conn) -> &mut TransactionManagerStatus {
        &mut conn.transaction_state().status
    }

    /// Returns `true` in either of these cases:
    /// * the `is_broken` flag is still set (a transaction statement started
    ///   in `critical_transaction_block` did not complete);
    /// * `check_broken_transaction_state` reports the status as broken.
    fn is_broken_transaction_manager(conn: &mut Conn) -> bool {
        conn.transaction_state().is_broken.load(Ordering::Relaxed)
            || check_broken_transaction_state(conn)
    }
}