use std::error::Error as StdError;
use std::fmt;
use std::future::Future;
use tracing::debug;
use crate::internal::any::AnyTransaction;
pub use crate::transaction::hook::TransactionHook;
use crate::transaction::hook_closure::{ClosureHook, OnRollback, PostCommit, PreCommit};
use crate::transaction::hook_storage::HookStorage;
use crate::Error;
mod hook;
mod hook_closure;
mod hook_storage;
/// A database transaction: wraps an [`AnyTransaction`] and optionally carries hook storage.
///
/// The value is marked `#[must_use]`; it is consumed by `commit` or `rollback`.
#[must_use = "A transaction needs to be committed."]
pub struct Transaction {
/// The wrapped low-level transaction that `commit` and `rollback` delegate to.
pub(crate) sqlx: AnyTransaction,
/// Hook storage for this transaction. Starts as `None` (see `new`) and is created
/// on demand by `hooks` / `adv_hooks`.
hooks: Option<HookStorage>,
}
impl Transaction {
pub(crate) fn new(sqlx: AnyTransaction) -> Self {
Self { sqlx, hooks: None }
}
pub async fn commit(mut self) -> Result<(), TransactionError> {
let mut hooks = self.hooks.take();
if let Some(hooks) = hooks.as_mut() {
hooks.pre_commit(&mut self).await?;
if let Some(invalid_hooks) = self.hooks.as_mut() {
debug!("Some transaction hook added additional hooks during pre-commit. This is not supported and will be ignored.");
invalid_hooks.clear();
}
}
let result = self.sqlx.commit().await;
if let Some(hooks) = hooks.as_mut() {
if result.is_ok() {
hooks.post_commit();
hooks.clear();
}
}
result
.map_err(Error::SqlxError)
.map_err(TransactionError::Database)
}
pub async fn rollback(self) -> Result<(), Error> {
self.sqlx.rollback().await.map_err(Error::SqlxError)
}
}
impl Drop for HookStorage {
fn drop(&mut self) {
self.on_rollback();
}
}
impl Transaction {
pub fn hooks(&mut self) -> SimpleHooksApi<'_> {
SimpleHooksApi(self.hooks.get_or_insert_default())
}
pub fn adv_hooks(&mut self) -> AdvancedHooksApi<'_> {
AdvancedHooksApi(self.hooks.get_or_insert_default())
}
}
/// Closure-based hook registration API, returned by `Transaction::hooks`.
///
/// Holds a mutable borrow of the transaction's `HookStorage`.
pub struct SimpleHooksApi<'a>(&'a mut HookStorage);
impl SimpleHooksApi<'_> {
    /// Registers `hook` as a closure hook tagged [`PreCommit`].
    ///
    /// The closure produces a future resolving to `Result<(), TransactionError>`.
    /// Returns `self` so that registrations can be chained.
    pub fn pre_commit<F>(&mut self, hook: impl FnOnce() -> F + Send + 'static) -> &mut Self
    where
        F: Future<Output = Result<(), TransactionError>> + Send,
    {
        let wrapped = ClosureHook::new(hook, PreCommit);
        self.0.get_or_insert().push(wrapped);
        self
    }

    /// Registers `hook` as a closure hook tagged [`PostCommit`].
    ///
    /// Returns `self` so that registrations can be chained.
    pub fn post_commit(&mut self, hook: impl FnOnce() + Send + 'static) -> &mut Self {
        let wrapped = ClosureHook::new(hook, PostCommit);
        self.0.get_or_insert().push(wrapped);
        self
    }

    /// Registers `hook` as a closure hook tagged [`OnRollback`].
    ///
    /// Returns `self` so that registrations can be chained.
    pub fn on_rollback(&mut self, hook: impl FnOnce() + Send + 'static) -> &mut Self {
        let wrapped = ClosureHook::new(hook, OnRollback);
        self.0.get_or_insert().push(wrapped);
        self
    }
}
/// API for storing and retrieving typed [`TransactionHook`] values, returned by
/// `Transaction::adv_hooks`.
///
/// Holds a mutable borrow of the transaction's `HookStorage`.
pub struct AdvancedHooksApi<'a>(&'a mut HookStorage);
impl AdvancedHooksApi<'_> {
    /// Appends `hook` to the vector of `T` hooks held by the storage.
    pub fn push<T: TransactionHook>(&mut self, hook: T) {
        let hooks_of_type: &mut Vec<T> = self.get_all();
        hooks_of_type.push(hook);
    }

    /// Returns the first `T` hook, inserting `T::default()` first if there is none.
    pub fn get_or_insert_default<T: TransactionHook + Default>(&mut self) -> &mut T {
        self.get_or_insert_with(<T as Default>::default)
    }

    /// Returns the first `T` hook, inserting the value produced by `init` if there is none.
    ///
    /// `init` is only called when the vector of `T` hooks is empty.
    pub fn get_or_insert_with<T: TransactionHook>(&mut self, init: impl FnOnce() -> T) -> &mut T {
        let hooks_of_type = self.get_all::<T>();
        if hooks_of_type.is_empty() {
            hooks_of_type.push(init());
        }
        // Non-empty by construction, so index 0 always exists.
        &mut hooks_of_type[0]
    }

    /// Returns the vector of `T` hooks obtained from the storage's `get_or_insert`.
    pub fn get_all<T: TransactionHook>(&mut self) -> &mut Vec<T> {
        let hooks_of_type: &mut Vec<T> = self.0.get_or_insert();
        hooks_of_type
    }
}
/// Error type returned by `Transaction::commit` and `TransactionGuard::commit`.
///
/// NOTE(review): no `std::error::Error` impl for this type is visible in this
/// section — confirm whether one exists elsewhere or is needed by callers.
#[derive(Debug)]
pub enum TransactionError {
/// A database-level error; `Transaction::commit` produces this when the commit fails.
Database(Error),
/// A boxed [`HookError`]; presumably raised by a transaction hook — verify against `HookStorage`.
Hook(HookError),
}
/// Boxed dynamic error (`Send + Sync`) carried by [`TransactionError::Hook`].
pub type HookError = Box<dyn StdError + Send + Sync>;
impl From<Error> for TransactionError {
fn from(value: Error) -> Self {
Self::Database(value)
}
}
impl From<HookError> for TransactionError {
fn from(value: HookError) -> Self {
Self::Hook(value)
}
}
impl fmt::Display for TransactionError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
TransactionError::Database(x) => fmt::Display::fmt(x, f),
TransactionError::Hook(x) => fmt::Display::fmt(x, f),
}
}
}
/// Either an owned [`Transaction`] or a mutable borrow of one.
///
/// `commit` only commits the `Owned` variant; for `Borrowed` it returns `Ok(())`
/// without touching the transaction.
#[must_use = "The potentially owned transaction needs to be committed."]
pub enum TransactionGuard<'tr> {
/// The guard owns the transaction; `commit` commits it.
Owned(Transaction),
/// The guard borrows a transaction held elsewhere; `commit` leaves it untouched.
Borrowed(&'tr mut Transaction),
}
impl TransactionGuard<'_> {
pub fn get_transaction(&mut self) -> &mut Transaction {
match self {
TransactionGuard::Owned(tr) => tr,
TransactionGuard::Borrowed(tr) => tr,
}
}
pub async fn commit(self) -> Result<(), TransactionError> {
if let TransactionGuard::Owned(tr) = self {
tr.commit().await
} else {
Ok(())
}
}
}