#[cfg(feature = "postgres")]
use std::future::Future;
use std::{error::Error as StdError, fmt};
/// Entry point for running work inside a database transaction on a borrowed pool.
///
/// The value only borrows the pool (`&'p DB`); constructing it performs no database work.
#[derive(Debug)]
pub struct Transaction<'p, DB> {
    pool: &'p DB,
}

impl<'p, DB> Transaction<'p, DB> {
    /// Wraps a borrowed pool. This only stores the reference.
    #[must_use]
    pub const fn new(pool: &'p DB) -> Self {
        Self { pool }
    }

    /// Returns the pool reference passed to [`new`](Self::new), with its original `'p` lifetime.
    #[must_use]
    pub const fn pool(&self) -> &'p DB {
        self.pool
    }
}
/// Handle to an open PostgreSQL transaction. It is handed by value to the closure passed to
/// `Transaction::run` / `Transaction::run_with_commit`.
///
/// The transaction is finished explicitly by consuming the handle with
/// [`commit`](Self::commit) or [`rollback`](Self::rollback).
/// NOTE(review): if the handle is dropped without either call, cleanup is left to
/// `sqlx::Transaction`'s own drop behaviour. sqlx documents this as a rollback; confirm for the
/// pinned sqlx version.
#[cfg(feature = "postgres")]
pub struct TransactionContext {
    tx: sqlx::Transaction<'static, sqlx::Postgres>,
}

#[cfg(feature = "postgres")]
impl TransactionContext {
    /// Wraps an already-begun transaction. It is hidden from docs and presumably meant only for
    /// the transaction runners.
    #[doc(hidden)]
    pub fn new(tx: sqlx::Transaction<'static, sqlx::Postgres>) -> Self {
        Self { tx }
    }

    /// Returns a mutable borrow of the underlying `sqlx` transaction.
    pub fn transaction(&mut self) -> &mut sqlx::Transaction<'static, sqlx::Postgres> {
        let Self { tx } = self;
        tx
    }

    /// Commits the transaction, consuming this handle.
    ///
    /// # Errors
    /// Returns the `sqlx::Error` produced by the commit, unchanged.
    pub async fn commit(self) -> Result<(), sqlx::Error> {
        let Self { tx } = self;
        tx.commit().await
    }

    /// Rolls the transaction back, consuming this handle.
    ///
    /// # Errors
    /// Returns the `sqlx::Error` produced by the rollback, unchanged.
    pub async fn rollback(self) -> Result<(), sqlx::Error> {
        let Self { tx } = self;
        tx.rollback().await
    }
}
/// Error produced while running a database transaction, tagged with the phase that failed.
///
/// Every variant wraps the same underlying error type `E`:
/// - [`inner`](Self::inner) borrows it.
/// - [`into_inner`](Self::into_inner) discards the phase and returns it.
/// - [`map`](Self::map) converts it while keeping the phase.
///
/// Note: `Display` embeds the wrapped error's message *and* `source()` returns that same error.
/// Reporters that walk the source chain will therefore print the inner message twice.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TransactionError<E> {
    /// Beginning the transaction failed.
    Begin(E),
    /// Committing the transaction failed.
    Commit(E),
    /// Rolling the transaction back failed.
    Rollback(E),
    /// The operation performed inside the transaction failed.
    Operation(E),
}

impl<E: fmt::Display> fmt::Display for TransactionError<E> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Begin(e) => write!(f, "failed to begin transaction: {e}"),
            Self::Commit(e) => write!(f, "failed to commit transaction: {e}"),
            Self::Rollback(e) => write!(f, "failed to rollback transaction: {e}"),
            Self::Operation(e) => write!(f, "transaction operation failed: {e}"),
        }
    }
}

impl<E: StdError + 'static> StdError for TransactionError<E> {
    /// Always returns the wrapped error, whichever phase failed.
    fn source(&self) -> Option<&(dyn StdError + 'static)> {
        match self {
            Self::Begin(e) | Self::Commit(e) | Self::Rollback(e) | Self::Operation(e) => Some(e),
        }
    }
}

impl<E> TransactionError<E> {
    /// `true` if beginning the transaction failed.
    #[must_use]
    pub const fn is_begin(&self) -> bool {
        matches!(self, Self::Begin(_))
    }

    /// `true` if committing the transaction failed.
    #[must_use]
    pub const fn is_commit(&self) -> bool {
        matches!(self, Self::Commit(_))
    }

    /// `true` if rolling the transaction back failed.
    #[must_use]
    pub const fn is_rollback(&self) -> bool {
        matches!(self, Self::Rollback(_))
    }

    /// `true` if the operation inside the transaction failed.
    #[must_use]
    pub const fn is_operation(&self) -> bool {
        matches!(self, Self::Operation(_))
    }

    /// Borrows the wrapped error without consuming `self`.
    #[must_use]
    pub const fn inner(&self) -> &E {
        match self {
            Self::Begin(e) | Self::Commit(e) | Self::Rollback(e) | Self::Operation(e) => e,
        }
    }

    /// Discards the phase tag and returns the wrapped error.
    pub fn into_inner(self) -> E {
        match self {
            Self::Begin(e) | Self::Commit(e) | Self::Rollback(e) | Self::Operation(e) => e,
        }
    }

    /// Converts the wrapped error with `f`, preserving which phase failed.
    pub fn map<U>(self, f: impl FnOnce(E) -> U) -> TransactionError<U> {
        match self {
            Self::Begin(e) => TransactionError::Begin(f(e)),
            Self::Commit(e) => TransactionError::Commit(f(e)),
            Self::Rollback(e) => TransactionError::Rollback(f(e)),
            Self::Operation(e) => TransactionError::Operation(f(e)),
        }
    }
}
/// Collapses a [`TransactionError`] back into the underlying `sqlx::Error`. The phase tag
/// (begin/commit/rollback/operation) is discarded.
#[cfg(feature = "postgres")]
impl From<TransactionError<sqlx::Error>> for sqlx::Error {
    fn from(wrapped: TransactionError<sqlx::Error>) -> Self {
        TransactionError::into_inner(wrapped)
    }
}
#[cfg(feature = "postgres")]
impl Transaction<'_, sqlx::PgPool> {
    /// Begins a transaction on the pool and hands it to `f` as a [`TransactionContext`].
    ///
    /// `f` takes ownership of the context, so `f` is responsible for finishing the transaction.
    /// It does so by calling [`TransactionContext::commit`] or [`TransactionContext::rollback`].
    /// This method does **not** commit on `f`'s behalf.
    /// NOTE(review): if `f` drops the context without finishing it, cleanup is left to
    /// `sqlx::Transaction`'s drop behaviour (documented by sqlx as a rollback). Confirm this for
    /// the pinned sqlx version.
    ///
    /// # Errors
    /// - If the transaction cannot be begun, returns `E::from` of that `sqlx::Error`.
    /// - Otherwise, returns whatever `f`'s future resolves to.
    pub async fn run<F, Fut, T, E>(self, f: F) -> Result<T, E>
    where
        F: FnOnce(TransactionContext) -> Fut + Send,
        Fut: Future<Output = Result<T, E>> + Send,
        E: From<sqlx::Error>,
    {
        // `?` converts the `sqlx::Error` through the `E: From<sqlx::Error>` bound.
        let tx = self.pool.begin().await?;
        f(TransactionContext::new(tx)).await
    }

    /// Identical to [`run`](Self::run).
    ///
    /// NOTE(review): despite its name, this does not commit automatically. The closure owns the
    /// [`TransactionContext`] and must call [`TransactionContext::commit`] itself. Auto-committing
    /// would require a different closure signature (e.g. borrowing the context).
    ///
    /// # Errors
    /// Same as [`run`](Self::run).
    pub async fn run_with_commit<F, Fut, T, E>(self, f: F) -> Result<T, E>
    where
        F: FnOnce(TransactionContext) -> Fut + Send,
        Fut: Future<Output = Result<T, E>> + Send,
        E: From<sqlx::Error>,
    {
        self.run(f).await
    }
}
#[cfg(test)]
mod tests {
    use std::error::Error;

    use super::*;

    /// Builds an `io::Error` whose `Display` output is exactly `msg`.
    fn io_err(msg: &str) -> std::io::Error {
        std::io::Error::other(String::from(msg))
    }

    /// Asserts that `err.source()` is the wrapped error itself (message `"src"`).
    fn assert_source_is_inner(err: &TransactionError<std::io::Error>) {
        let source = err.source().map(|s| s.to_string());
        assert_eq!(source.as_deref(), Some("src"));
    }

    #[test]
    fn transaction_error_display_begin() {
        let err: TransactionError<std::io::Error> = TransactionError::Begin(io_err("test"));
        assert_eq!(err.to_string(), "failed to begin transaction: test");
    }

    #[test]
    fn transaction_error_display_commit() {
        let err: TransactionError<std::io::Error> = TransactionError::Commit(io_err("test"));
        assert_eq!(err.to_string(), "failed to commit transaction: test");
    }

    #[test]
    fn transaction_error_display_rollback() {
        let err: TransactionError<std::io::Error> = TransactionError::Rollback(io_err("test"));
        assert_eq!(err.to_string(), "failed to rollback transaction: test");
    }

    #[test]
    fn transaction_error_display_operation() {
        let err: TransactionError<std::io::Error> = TransactionError::Operation(io_err("test"));
        assert_eq!(err.to_string(), "transaction operation failed: test");
    }

    #[test]
    fn transaction_error_is_methods() {
        // Each variant must answer `true` to exactly one predicate. Predicate order:
        // [is_begin, is_commit, is_rollback, is_operation].
        let cases: [(TransactionError<&str>, [bool; 4]); 4] = [
            (TransactionError::Begin("e"), [true, false, false, false]),
            (TransactionError::Commit("e"), [false, true, false, false]),
            (TransactionError::Rollback("e"), [false, false, true, false]),
            (TransactionError::Operation("e"), [false, false, false, true]),
        ];
        for (err, expected) in cases {
            let actual = [
                err.is_begin(),
                err.is_commit(),
                err.is_rollback(),
                err.is_operation(),
            ];
            assert_eq!(actual, expected, "predicates for {err:?}");
        }
    }

    #[test]
    fn transaction_error_into_inner() {
        let err: TransactionError<&str> = TransactionError::Operation("test");
        assert_eq!(err.into_inner(), "test");
    }

    #[test]
    fn transaction_error_into_inner_begin() {
        let err: TransactionError<&str> = TransactionError::Begin("begin_err");
        assert_eq!(err.into_inner(), "begin_err");
    }

    #[test]
    fn transaction_error_into_inner_commit() {
        let err: TransactionError<&str> = TransactionError::Commit("commit_err");
        assert_eq!(err.into_inner(), "commit_err");
    }

    #[test]
    fn transaction_error_into_inner_rollback() {
        let err: TransactionError<&str> = TransactionError::Rollback("rollback_err");
        assert_eq!(err.into_inner(), "rollback_err");
    }

    #[test]
    fn transaction_error_source_begin() {
        assert_source_is_inner(&TransactionError::Begin(io_err("src")));
    }

    #[test]
    fn transaction_error_source_commit() {
        assert_source_is_inner(&TransactionError::Commit(io_err("src")));
    }

    #[test]
    fn transaction_error_source_rollback() {
        assert_source_is_inner(&TransactionError::Rollback(io_err("src")));
    }

    #[test]
    fn transaction_error_source_operation() {
        assert_source_is_inner(&TransactionError::Operation(io_err("src")));
    }

    #[test]
    fn transaction_builder_new() {
        struct MockPool;
        let pool = MockPool;
        let tx = Transaction::new(&pool);
        let _ = tx.pool();
    }

    #[test]
    fn transaction_builder_pool_accessor() {
        struct MockPool {
            id: u32,
        }
        let pool = MockPool { id: 42 };
        let tx = Transaction::new(&pool);
        assert_eq!(tx.pool().id, 42);
        // The accessor returns the exact reference passed to `new`, not a copy.
        assert!(std::ptr::eq(tx.pool(), &pool));
    }

    #[test]
    fn transaction_error_debug() {
        let err: TransactionError<&str> = TransactionError::Begin("test");
        assert_eq!(format!("{err:?}"), r#"Begin("test")"#);
    }
}