#[cfg(feature = "postgres")]
use std::future::Future;
use std::{error::Error as StdError, fmt};
/// Lightweight handle that binds transactional work to a borrowed pool.
///
/// The handle stores nothing but a shared reference to `DB`; it neither owns
/// the pool nor performs any database work on construction.
pub struct Transaction<'p, DB> {
    /// Borrowed pool the transaction is associated with.
    pool: &'p DB
}

impl<'p, DB> Transaction<'p, DB> {
    /// Creates a handle that borrows `pool` for the lifetime `'p`.
    ///
    /// Construction has no side effects, so dropping the returned value
    /// immediately is a no-op. `#[must_use]` lets the compiler flag that
    /// mistake, matching the attribute already present on [`Self::pool`].
    #[must_use]
    pub const fn new(pool: &'p DB) -> Self {
        Self { pool }
    }

    /// Returns the pool reference this handle was created with.
    ///
    /// The returned reference carries the original `'p` lifetime rather than
    /// the lifetime of `&self`, so it stays usable after the handle is gone.
    #[must_use]
    pub const fn pool(&self) -> &'p DB {
        self.pool
    }
}
#[cfg(feature = "postgres")]
/// Owned wrapper around an open `sqlx` PostgreSQL transaction.
///
/// The wrapper is consumed by [`TransactionContext::commit`] or
/// [`TransactionContext::rollback`]; until then statements can be issued
/// through [`TransactionContext::transaction`].
pub struct TransactionContext {
    /// The underlying `sqlx` transaction handle.
    tx: sqlx::Transaction<'static, sqlx::Postgres>
}

#[cfg(feature = "postgres")]
impl TransactionContext {
    /// Wraps an already started `sqlx` transaction.
    ///
    /// Marked `#[doc(hidden)]`, so it is not part of the rendered public docs.
    #[doc(hidden)]
    #[must_use]
    pub const fn new(tx: sqlx::Transaction<'static, sqlx::Postgres>) -> Self {
        Self { tx }
    }

    /// Borrows the wrapped transaction mutably so queries can be executed on it.
    pub const fn transaction(&mut self) -> &mut sqlx::Transaction<'static, sqlx::Postgres> {
        &mut self.tx
    }

    /// Consumes the context and commits the wrapped transaction.
    ///
    /// # Errors
    ///
    /// Returns whatever error `sqlx::Transaction::commit` reports.
    pub async fn commit(self) -> Result<(), sqlx::Error> {
        let Self { tx } = self;
        tx.commit().await
    }

    /// Consumes the context and rolls the wrapped transaction back.
    ///
    /// # Errors
    ///
    /// Returns whatever error `sqlx::Transaction::rollback` reports.
    pub async fn rollback(self) -> Result<(), sqlx::Error> {
        let Self { tx } = self;
        tx.rollback().await
    }
}
/// Error wrapper that records which transaction phase an error `E` belongs to.
///
/// Every variant carries exactly one inner error of the same type `E`; the
/// variant only adds the phase information.
#[derive(Debug)]
pub enum TransactionError<E> {
    /// The error is attributed to beginning the transaction.
    Begin(E),
    /// The error is attributed to committing the transaction.
    Commit(E),
    /// The error is attributed to rolling the transaction back.
    Rollback(E),
    /// The error is attributed to an operation inside the transaction.
    Operation(E)
}

impl<E: fmt::Display> fmt::Display for TransactionError<E> {
    /// Writes a phase-specific prefix followed by the inner error's own message.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Begin(e) => write!(f, "failed to begin transaction: {e}"),
            Self::Commit(e) => write!(f, "failed to commit transaction: {e}"),
            Self::Rollback(e) => write!(f, "failed to rollback transaction: {e}"),
            Self::Operation(e) => write!(f, "transaction operation failed: {e}")
        }
    }
}

impl<E: StdError + 'static> StdError for TransactionError<E> {
    /// Exposes the wrapped error as the source; this is `Some` for every variant.
    fn source(&self) -> Option<&(dyn StdError + 'static)> {
        match self {
            Self::Begin(e) | Self::Commit(e) | Self::Rollback(e) | Self::Operation(e) => Some(e)
        }
    }
}

impl<E> TransactionError<E> {
    /// Returns `true` only for [`TransactionError::Begin`].
    #[must_use]
    pub const fn is_begin(&self) -> bool {
        matches!(self, Self::Begin(_))
    }

    /// Returns `true` only for [`TransactionError::Commit`].
    #[must_use]
    pub const fn is_commit(&self) -> bool {
        matches!(self, Self::Commit(_))
    }

    /// Returns `true` only for [`TransactionError::Rollback`].
    #[must_use]
    pub const fn is_rollback(&self) -> bool {
        matches!(self, Self::Rollback(_))
    }

    /// Returns `true` only for [`TransactionError::Operation`].
    #[must_use]
    pub const fn is_operation(&self) -> bool {
        matches!(self, Self::Operation(_))
    }

    /// Consumes the wrapper and returns the inner error, discarding the phase.
    #[must_use]
    pub fn into_inner(self) -> E {
        match self {
            Self::Begin(e) | Self::Commit(e) | Self::Rollback(e) | Self::Operation(e) => e
        }
    }
}
#[cfg(feature = "postgres")]
/// Collapses a phase-tagged error back into the plain `sqlx::Error` it wraps.
impl From<TransactionError<Self>> for sqlx::Error {
    /// Drops the phase information and returns the inner driver error unchanged.
    fn from(err: TransactionError<Self>) -> Self {
        TransactionError::into_inner(err)
    }
}
#[cfg(any(feature = "postgres", test))]
/// Finishes a unit of work by committing only when the work succeeded.
///
/// * `ctx` - the value handed to `commit_fn`; it is only consumed on success.
/// * `result` - the outcome of the work performed with `ctx`.
/// * `commit_fn` - called exactly once with `ctx` when `result` is `Ok`.
///
/// On `Ok(value)` the commit future is awaited; a commit failure is converted
/// into `E` through `From<CommitErr>` and returned instead of `value`.
/// On `Err(e)` the error is returned as-is and `commit_fn` is never invoked;
/// `ctx` is simply dropped.
async fn finalize_with_commit<C, T, E, CommitErr, Cf, Fut>(
    ctx: C,
    result: Result<T, E>,
    commit_fn: Cf
) -> Result<T, E>
where
    Cf: FnOnce(C) -> Fut,
    Fut: core::future::Future<Output = Result<(), CommitErr>>,
    E: From<CommitErr>
{
    // A failed closure short-circuits here, before any commit is attempted.
    let value = result?;
    // `?` performs the `CommitErr -> E` conversion.
    commit_fn(ctx).await?;
    Ok(value)
}
#[cfg(feature = "postgres")]
impl Transaction<'_, sqlx::PgPool> {
    /// Runs `f` inside a freshly begun transaction and commits when it succeeds.
    ///
    /// The closure receives mutable access to the [`TransactionContext`]. If it
    /// returns `Ok`, the context is committed and the value is returned. If it
    /// returns `Err`, that error is returned and no commit is attempted; this
    /// method issues no explicit rollback, the context is just dropped (sqlx
    /// documents that an unfinished transaction is rolled back on drop).
    ///
    /// # Errors
    ///
    /// Returns the closure's error, or a `sqlx::Error` converted into `E` when
    /// beginning or committing the transaction fails.
    pub async fn run<F, T, E>(self, f: F) -> Result<T, E>
    where
        F: AsyncFnOnce(&mut TransactionContext) -> Result<T, E>,
        E: From<sqlx::Error>
    {
        let mut ctx = TransactionContext::new(self.pool.begin().await?);
        let outcome = f(&mut ctx).await;
        finalize_with_commit(ctx, outcome, TransactionContext::commit).await
    }

    /// Begins a transaction and hands ownership of its context to `f`.
    ///
    /// Unlike [`Self::run`], this method neither commits nor rolls back: the
    /// closure owns the [`TransactionContext`] and decides how to finish it.
    /// The closure's result is returned unchanged.
    ///
    /// # Errors
    ///
    /// Returns the closure's error, or a `sqlx::Error` converted into `E` when
    /// the transaction cannot be started.
    pub async fn run_with_commit<F, Fut, T, E>(self, f: F) -> Result<T, E>
    where
        F: FnOnce(TransactionContext) -> Fut + Send,
        Fut: Future<Output = Result<T, E>> + Send,
        E: From<sqlx::Error>
    {
        let ctx = TransactionContext::new(self.pool.begin().await?);
        f(ctx).await
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the error wrapper, the pool handle and the commit helper.

    use std::{
        error::Error,
        sync::{
            Arc,
            atomic::{AtomicBool, Ordering}
        }
    };

    use super::*;

    // --- `Display` ---------------------------------------------------------

    #[test]
    fn transaction_error_display_begin() {
        let err: TransactionError<std::io::Error> =
            TransactionError::Begin(std::io::Error::other("test"));
        assert!(err.to_string().contains("begin"));
        assert!(err.to_string().contains("test"));
    }

    #[test]
    fn transaction_error_display_commit() {
        let err: TransactionError<std::io::Error> =
            TransactionError::Commit(std::io::Error::other("test"));
        assert!(err.to_string().contains("commit"));
    }

    #[test]
    fn transaction_error_display_rollback() {
        let err: TransactionError<std::io::Error> =
            TransactionError::Rollback(std::io::Error::other("test"));
        assert!(err.to_string().contains("rollback"));
    }

    #[test]
    fn transaction_error_display_operation() {
        let err: TransactionError<std::io::Error> =
            TransactionError::Operation(std::io::Error::other("test"));
        assert!(err.to_string().contains("operation"));
    }

    // --- variant predicates ------------------------------------------------

    #[test]
    fn transaction_error_is_methods() {
        let begin: TransactionError<&str> = TransactionError::Begin("e");
        let commit: TransactionError<&str> = TransactionError::Commit("e");
        let rollback: TransactionError<&str> = TransactionError::Rollback("e");
        let operation: TransactionError<&str> = TransactionError::Operation("e");

        assert!(begin.is_begin());
        assert!(!begin.is_commit());
        assert!(!begin.is_rollback());
        assert!(!begin.is_operation());

        assert!(!commit.is_begin());
        assert!(commit.is_commit());
        assert!(!commit.is_rollback());
        assert!(!commit.is_operation());

        assert!(!rollback.is_begin());
        assert!(!rollback.is_commit());
        assert!(rollback.is_rollback());
        assert!(!rollback.is_operation());

        assert!(!operation.is_begin());
        assert!(!operation.is_commit());
        assert!(!operation.is_rollback());
        assert!(operation.is_operation());
    }

    // --- `into_inner` ------------------------------------------------------

    #[test]
    fn transaction_error_into_inner() {
        let err: TransactionError<&str> = TransactionError::Operation("test");
        assert_eq!(err.into_inner(), "test");
    }

    #[test]
    fn transaction_error_into_inner_begin() {
        let err: TransactionError<&str> = TransactionError::Begin("begin_err");
        assert_eq!(err.into_inner(), "begin_err");
    }

    #[test]
    fn transaction_error_into_inner_commit() {
        let err: TransactionError<&str> = TransactionError::Commit("commit_err");
        assert_eq!(err.into_inner(), "commit_err");
    }

    #[test]
    fn transaction_error_into_inner_rollback() {
        let err: TransactionError<&str> = TransactionError::Rollback("rollback_err");
        assert_eq!(err.into_inner(), "rollback_err");
    }

    // --- `Error::source` ---------------------------------------------------

    #[test]
    fn transaction_error_source_begin() {
        let err: TransactionError<std::io::Error> =
            TransactionError::Begin(std::io::Error::other("src"));
        assert!(err.source().is_some());
    }

    #[test]
    fn transaction_error_source_commit() {
        let err: TransactionError<std::io::Error> =
            TransactionError::Commit(std::io::Error::other("src"));
        assert!(err.source().is_some());
    }

    #[test]
    fn transaction_error_source_rollback() {
        let err: TransactionError<std::io::Error> =
            TransactionError::Rollback(std::io::Error::other("src"));
        assert!(err.source().is_some());
    }

    #[test]
    fn transaction_error_source_operation() {
        let err: TransactionError<std::io::Error> =
            TransactionError::Operation(std::io::Error::other("src"));
        assert!(err.source().is_some());
    }

    // --- `Transaction` handle ----------------------------------------------

    #[test]
    fn transaction_builder_new() {
        struct MockPool;
        let pool = MockPool;
        let tx = Transaction::new(&pool);
        let _ = tx.pool();
    }

    #[test]
    fn transaction_builder_pool_accessor() {
        struct MockPool {
            id: u32
        }
        let pool = MockPool {
            id: 42
        };
        let tx = Transaction::new(&pool);
        assert_eq!(tx.pool().id, 42);
    }

    // --- `Debug` -----------------------------------------------------------

    #[test]
    fn transaction_error_debug() {
        let err: TransactionError<&str> = TransactionError::Begin("test");
        // Inlined format argument: no lint suppression is needed for this module.
        let debug_str = format!("{err:?}");
        assert!(debug_str.contains("Begin"));
        assert!(debug_str.contains("test"));
    }

    // --- all-variant sweeps ------------------------------------------------

    #[test]
    fn transaction_error_into_inner_all_variants() {
        let begin: TransactionError<String> = TransactionError::Begin("begin".to_string());
        let commit: TransactionError<String> = TransactionError::Commit("commit".to_string());
        let rollback: TransactionError<String> =
            TransactionError::Rollback("rollback".to_string());
        let operation: TransactionError<String> = TransactionError::Operation("op".to_string());

        assert_eq!(begin.into_inner(), "begin");
        assert_eq!(commit.into_inner(), "commit");
        assert_eq!(rollback.into_inner(), "rollback");
        assert_eq!(operation.into_inner(), "op");
    }

    #[test]
    fn transaction_error_source_all_variants() {
        let begin: TransactionError<std::io::Error> =
            TransactionError::Begin(std::io::Error::other("src"));
        let commit: TransactionError<std::io::Error> =
            TransactionError::Commit(std::io::Error::other("src"));
        let rollback: TransactionError<std::io::Error> =
            TransactionError::Rollback(std::io::Error::other("src"));
        let operation: TransactionError<std::io::Error> =
            TransactionError::Operation(std::io::Error::other("src"));

        assert!(begin.source().is_some());
        assert!(commit.source().is_some());
        assert!(rollback.source().is_some());
        assert!(operation.source().is_some());
    }

    #[test]
    fn transaction_error_display_all_variants() {
        let begin: TransactionError<std::io::Error> =
            TransactionError::Begin(std::io::Error::other("msg"));
        let commit: TransactionError<std::io::Error> =
            TransactionError::Commit(std::io::Error::other("msg"));
        let rollback: TransactionError<std::io::Error> =
            TransactionError::Rollback(std::io::Error::other("msg"));
        let operation: TransactionError<std::io::Error> =
            TransactionError::Operation(std::io::Error::other("msg"));

        let begin_str = begin.to_string();
        let commit_str = commit.to_string();
        let rollback_str = rollback.to_string();
        let operation_str = operation.to_string();

        assert!(begin_str.contains("begin"));
        assert!(commit_str.contains("commit"));
        assert!(rollback_str.contains("rollback"));
        assert!(operation_str.contains("operation"));
    }

    #[test]
    fn transaction_error_is_all_variants() {
        let begin: TransactionError<&str> = TransactionError::Begin("e");
        let commit: TransactionError<&str> = TransactionError::Commit("e");
        let rollback: TransactionError<&str> = TransactionError::Rollback("e");
        let operation: TransactionError<&str> = TransactionError::Operation("e");

        assert!(begin.is_begin());
        assert!(commit.is_commit());
        assert!(rollback.is_rollback());
        assert!(operation.is_operation());

        assert!(!begin.is_commit());
        assert!(!begin.is_rollback());
        assert!(!begin.is_operation());

        assert!(!commit.is_begin());
        assert!(!commit.is_rollback());
        assert!(!commit.is_operation());

        assert!(!rollback.is_begin());
        assert!(!rollback.is_commit());
        assert!(!rollback.is_operation());

        assert!(!operation.is_begin());
        assert!(!operation.is_commit());
        assert!(!operation.is_rollback());
    }

    #[test]
    fn transaction_builder_new_const() {
        struct MockPool;
        let pool = MockPool;
        let tx = Transaction::new(&pool);
        let _ = tx;
    }

    // --- `finalize_with_commit` --------------------------------------------

    /// Stand-in for a transaction context; carries no state.
    #[derive(Debug, PartialEq, Eq)]
    struct MockCtx;

    /// Error produced by the fake commit step.
    #[derive(Debug, PartialEq, Eq)]
    struct CommitErr(&'static str);

    /// Caller-side error that distinguishes closure failures from commit failures.
    #[derive(Debug, PartialEq, Eq)]
    enum AppErr {
        Closure(&'static str),
        Commit(&'static str)
    }

    impl From<CommitErr> for AppErr {
        fn from(e: CommitErr) -> Self {
            Self::Commit(e.0)
        }
    }

    #[tokio::test]
    async fn finalize_commits_on_ok() {
        let committed = Arc::new(AtomicBool::new(false));
        let flag = Arc::clone(&committed);
        let result: Result<i32, AppErr> = finalize_with_commit::<_, _, _, CommitErr, _, _>(
            MockCtx,
            Ok::<i32, AppErr>(42),
            move |_ctx| async move {
                flag.store(true, Ordering::SeqCst);
                Ok::<(), CommitErr>(())
            }
        )
        .await;
        assert_eq!(result, Ok(42));
        assert!(committed.load(Ordering::SeqCst), "commit_fn must run on Ok");
    }

    #[tokio::test]
    async fn finalize_skips_commit_on_err() {
        let committed = Arc::new(AtomicBool::new(false));
        let flag = Arc::clone(&committed);
        let result: Result<i32, AppErr> = finalize_with_commit::<_, _, _, CommitErr, _, _>(
            MockCtx,
            Err::<i32, AppErr>(AppErr::Closure("nope")),
            move |_ctx| async move {
                flag.store(true, Ordering::SeqCst);
                Ok::<(), CommitErr>(())
            }
        )
        .await;
        assert_eq!(result, Err(AppErr::Closure("nope")));
        assert!(
            !committed.load(Ordering::SeqCst),
            "commit_fn must NOT run on Err"
        );
    }

    #[tokio::test]
    async fn finalize_propagates_commit_error_on_ok() {
        let result: Result<i32, AppErr> = finalize_with_commit::<_, _, _, CommitErr, _, _>(
            MockCtx,
            Ok::<i32, AppErr>(42),
            |_ctx| async { Err::<(), CommitErr>(CommitErr("commit failed")) }
        )
        .await;
        assert_eq!(result, Err(AppErr::Commit("commit failed")));
    }

    #[tokio::test]
    async fn finalize_preserves_closure_value_on_ok() {
        let result: Result<String, AppErr> = finalize_with_commit::<_, _, _, CommitErr, _, _>(
            MockCtx,
            Ok::<String, AppErr>("payload".to_string()),
            |_ctx| async { Ok::<(), CommitErr>(()) }
        )
        .await;
        assert_eq!(result, Ok("payload".to_string()));
    }

    #[tokio::test]
    async fn finalize_does_not_swallow_closure_error_when_commit_also_would_fail() {
        let result: Result<(), AppErr> = finalize_with_commit::<_, _, _, CommitErr, _, _>(
            MockCtx,
            Err::<(), AppErr>(AppErr::Closure("original")),
            |_ctx| async { Err::<(), CommitErr>(CommitErr("never reached")) }
        )
        .await;
        assert_eq!(result, Err(AppErr::Closure("original")));
    }
}