Skip to main content

entity_core/
transaction.rs

1// SPDX-FileCopyrightText: 2025-2026 RAprogramm <andrey.rozanov.vl@gmail.com>
2// SPDX-License-Identifier: MIT
3
4//! Transaction support for entity-derive.
5//!
6//! This module provides type-safe transaction management with automatic
7//! commit/rollback semantics. It uses a fluent builder pattern for composing
8//! multiple entity operations into a single transaction.
9//!
10//! # Overview
11//!
12//! - [`Transaction`] — Entry point for creating transactions
13//! - [`TransactionContext`] — Holds active transaction, provides repo access
14//! - [`TransactionError`] — Error wrapper for transaction operations
15//!
16//! # Example
17//!
18//! ```rust,ignore
19//! use entity_derive::prelude::*;
20//!
21//! async fn transfer(pool: &PgPool, from: Uuid, to: Uuid, amount: i64) -> Result<(), AppError> {
22//!     Transaction::new(pool)
23//!         .with_accounts()
24//!         .with_transfers()
25//!         .run(async |ctx| {
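//!             let from_acc = ctx.accounts().find_by_id(from).await?.ok_or(AppError::NotFound)?;
//!             let to_acc = ctx.accounts().find_by_id(to).await?.ok_or(AppError::NotFound)?;
//!
//!             // Debit the sender and credit the receiver in the same transaction.
//!             ctx.accounts().update(from, UpdateAccount {
//!                 balance: Some(from_acc.balance - amount),
//!                 ..Default::default()
//!             }).await?;
//!
//!             ctx.accounts().update(to, UpdateAccount {
//!                 balance: Some(to_acc.balance + amount),
//!                 ..Default::default()
//!             }).await?;
//!
//!             ctx.transfers().create(CreateTransfer { from, to, amount }).await?;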
34//!             Ok(())
35//!         })
36//!         .await
37//! }
38//! ```
39
40#[cfg(feature = "postgres")]
41use std::future::Future;
42use std::{error::Error as StdError, fmt};
43
/// Transaction builder for composing multi-entity operations.
///
/// Use [`Transaction::new`] to create a builder, chain `.with_*()` methods
/// to declare which entities you'll use, then call `.run()` to execute.
///
/// The builder only borrows the pool. It implements
/// [`Debug`](std::fmt::Debug) whenever `DB` does.
///
/// # Type Parameters
///
/// - `'p` — Pool lifetime
/// - `DB` — Database pool type (e.g., `PgPool`)
///
/// # Example
///
/// ```rust,ignore
/// Transaction::new(&pool)
///     .with_users()
///     .with_orders()
///     .run(async |ctx| {
///         let user = ctx.users().find_by_id(id).await?;
///         ctx.orders().create(order).await?;
///         Ok(())
///     })
///     .await?;
/// ```
#[derive(Debug)]
pub struct Transaction<'p, DB> {
    pool: &'p DB
}
70
71impl<'p, DB> Transaction<'p, DB> {
72    /// Create a new transaction builder.
73    ///
74    /// # Arguments
75    ///
76    /// * `pool` — Database connection pool
77    ///
78    /// # Example
79    ///
80    /// ```rust,ignore
81    /// let tx = Transaction::new(&pool);
82    /// ```
83    pub const fn new(pool: &'p DB) -> Self {
84        Self {
85            pool
86        }
87    }
88
89    /// Get reference to the underlying pool.
90    #[must_use]
91    pub const fn pool(&self) -> &'p DB {
92        self.pool
93    }
94}
95
/// Active transaction context with repository access.
///
/// This struct holds the database transaction and provides access to
/// entity repositories via extension traits generated by the macro.
///
/// # Automatic Rollback
///
/// If dropped without explicit commit, the transaction is automatically
/// rolled back via the underlying database transaction's Drop impl.
///
/// # Accessing Repositories
///
/// Each entity with `#[entity(transactions)]` generates an extension trait
/// that adds an accessor method:
///
/// ```rust,ignore
/// // For entity BankAccount, use:
/// ctx.bank_accounts().find_by_id(id).await?;
/// ctx.bank_accounts().create(dto).await?;
/// ctx.bank_accounts().update(id, dto).await?;
/// ```
///
/// # Debug Output
///
/// The `Debug` impl prints `TransactionContext { .. }`; the wrapped `sqlx`
/// transaction itself is not shown.
#[cfg(feature = "postgres")]
pub struct TransactionContext {
    tx: sqlx::Transaction<'static, sqlx::Postgres>
}

#[cfg(feature = "postgres")]
impl fmt::Debug for TransactionContext {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The inner `sqlx` transaction is deliberately elided so this impl does
        // not depend on whether (or how) `sqlx` formats it.
        f.debug_struct("TransactionContext").finish_non_exhaustive()
    }
}
121
#[cfg(feature = "postgres")]
impl TransactionContext {
    /// Wrap an already-open `sqlx` transaction in a context.
    ///
    /// In this module, contexts are created by `Transaction::run` and
    /// `Transaction::run_with_commit`; the constructor is hidden from docs.
    ///
    /// # Arguments
    ///
    /// * `tx` — Active database transaction
    #[doc(hidden)]
    #[must_use]
    pub const fn new(tx: sqlx::Transaction<'static, sqlx::Postgres>) -> Self {
        Self { tx }
    }

    /// Mutable access to the wrapped `sqlx` transaction.
    ///
    /// Use it to run custom queries inside this transaction. Repository
    /// adapters use it to execute their queries.
    pub const fn transaction(&mut self) -> &mut sqlx::Transaction<'static, sqlx::Postgres> {
        &mut self.tx
    }

    /// Commit every change made through this context.
    ///
    /// Takes the context by value, so it cannot be used afterwards.
    ///
    /// # Errors
    ///
    /// Returns the `sqlx::Error` produced by the underlying commit.
    pub async fn commit(self) -> Result<(), sqlx::Error> {
        let Self { tx } = self;
        tx.commit().await
    }

    /// Discard every change made through this context.
    ///
    /// Takes the context by value, so it cannot be used afterwards.
    ///
    /// # Errors
    ///
    /// Returns the `sqlx::Error` produced by the underlying rollback.
    pub async fn rollback(self) -> Result<(), sqlx::Error> {
        let Self { tx } = self;
        tx.rollback().await
    }
}
167
/// Error type for transaction operations.
///
/// Wraps an underlying error `E` and records which phase of the transaction
/// lifecycle produced it.
///
/// `Clone`, `PartialEq` and `Eq` are derived, so each is available exactly
/// when `E` provides it (for example, errors can be compared with `==` when
/// `E: PartialEq`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TransactionError<E> {
    /// Failed to begin transaction.
    Begin(E),

    /// Failed to commit transaction.
    Commit(E),

    /// Failed to rollback transaction.
    Rollback(E),

    /// Operation within transaction failed.
    Operation(E)
}
185
186impl<E: fmt::Display> fmt::Display for TransactionError<E> {
187    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
188        match self {
189            Self::Begin(e) => write!(f, "failed to begin transaction: {e}"),
190            Self::Commit(e) => write!(f, "failed to commit transaction: {e}"),
191            Self::Rollback(e) => write!(f, "failed to rollback transaction: {e}"),
192            Self::Operation(e) => write!(f, "transaction operation failed: {e}")
193        }
194    }
195}
196
197impl<E: StdError + 'static> StdError for TransactionError<E> {
198    fn source(&self) -> Option<&(dyn StdError + 'static)> {
199        match self {
200            Self::Begin(e) | Self::Commit(e) | Self::Rollback(e) | Self::Operation(e) => Some(e)
201        }
202    }
203}
204
205impl<E> TransactionError<E> {
206    /// Check if this is a begin error.
207    pub const fn is_begin(&self) -> bool {
208        matches!(self, Self::Begin(_))
209    }
210
211    /// Check if this is a commit error.
212    pub const fn is_commit(&self) -> bool {
213        matches!(self, Self::Commit(_))
214    }
215
216    /// Check if this is a rollback error.
217    pub const fn is_rollback(&self) -> bool {
218        matches!(self, Self::Rollback(_))
219    }
220
221    /// Check if this is an operation error.
222    pub const fn is_operation(&self) -> bool {
223        matches!(self, Self::Operation(_))
224    }
225
226    /// Get the inner error.
227    pub fn into_inner(self) -> E {
228        match self {
229            Self::Begin(e) | Self::Commit(e) | Self::Rollback(e) | Self::Operation(e) => e
230        }
231    }
232}
233
#[cfg(feature = "postgres")]
impl From<TransactionError<sqlx::Error>> for sqlx::Error {
    /// Strip the transaction-phase wrapper and hand back the underlying
    /// `sqlx::Error` unchanged.
    fn from(wrapped: TransactionError<sqlx::Error>) -> Self {
        match wrapped {
            TransactionError::Begin(inner)
            | TransactionError::Commit(inner)
            | TransactionError::Rollback(inner)
            | TransactionError::Operation(inner) => inner
        }
    }
}
240
/// Decide the fate of a finished transaction: commit when the closure
/// succeeded, otherwise just drop `ctx`.
///
/// Dropping `ctx` on the error path is how the backend rolls back. For the
/// `sqlx` context used in this module, that happens in `sqlx`'s `Drop` impl.
///
/// The helper is backend-agnostic so the decision can be unit-tested with a
/// mock `ctx` and a tracking `commit_fn`, without a live database:
///
/// - on `Ok`, `commit_fn(ctx)` is awaited exactly once
/// - on `Err`, `commit_fn` is never called
/// - a commit failure is converted into `E` via `From`
///
/// # Errors
///
/// Returns the closure's original error when `result` is `Err`. Returns the
/// converted commit error when `result` is `Ok` but `commit_fn` fails.
#[cfg(any(feature = "postgres", test))]
async fn finalize_with_commit<C, T, E, CommitErr, Cf, Fut>(
    ctx: C,
    result: Result<T, E>,
    commit_fn: Cf
) -> Result<T, E>
where
    Cf: FnOnce(C) -> Fut,
    Fut: core::future::Future<Output = Result<(), CommitErr>>,
    E: From<CommitErr>
{
    // Early return on `Err`: `commit_fn` is never invoked and `ctx` is
    // dropped on the way out.
    let value = result?;
    // `?` converts a commit failure through `E: From<CommitErr>`.
    commit_fn(ctx).await?;
    Ok(value)
}
274
// PostgreSQL implementation
#[cfg(feature = "postgres")]
impl Transaction<'_, sqlx::PgPool> {
    /// Run `f` inside a new `PostgreSQL` transaction, committing on success.
    ///
    /// A transaction is begun on the pool and lent to `f` as
    /// `&mut TransactionContext`. Because `run` keeps ownership of the context,
    /// it can commit the context itself once `f` finishes:
    ///
    /// - `Ok(value)` — the transaction is committed and `value` is returned.
    ///   If the commit fails, its `sqlx::Error` is converted with `E::from`
    ///   and returned instead, and `value` is discarded.
    /// - `Err(e)` — no commit is attempted. The context is dropped and `sqlx`
    ///   rolls the transaction back through its `Drop` implementation. `e` is
    ///   returned unchanged.
    ///
    /// # Type Parameters
    ///
    /// - `F` — Async closure receiving the transaction context
    /// - `T` — Success type
    /// - `E` — Error type (must be convertible from `sqlx::Error`)
    ///
    /// # Example
    ///
    /// ```rust,ignore
    /// Transaction::new(&pool)
    ///     .with_users()
    ///     .run(async |ctx| {
    ///         let user = ctx.users().create(dto).await?;
    ///         Ok(user)
    ///     })
    ///     .await?;
    /// ```
    ///
    /// # Errors
    ///
    /// Fails if beginning the transaction fails, if `f` fails, or if the
    /// final commit fails.
    pub async fn run<F, T, E>(self, f: F) -> Result<T, E>
    where
        F: AsyncFnOnce(&mut TransactionContext) -> Result<T, E>,
        E: From<sqlx::Error>
    {
        let mut ctx = TransactionContext::new(self.pool.begin().await?);
        let outcome = f(&mut ctx).await;
        finalize_with_commit(ctx, outcome, TransactionContext::commit).await
    }

    /// Run `f` inside a new `PostgreSQL` transaction, leaving commit to `f`.
    ///
    /// Unlike [`run`](Self::run), the context is passed **by value**, so `f`
    /// decides whether to call `ctx.commit().await` or `ctx.rollback().await`.
    /// Nothing is committed on `f`'s behalf. If it returns without committing,
    /// the transaction is rolled back when `ctx` is dropped.
    ///
    /// Prefer `run` unless the commit decision is conditional.
    ///
    /// # Example
    ///
    /// ```rust,ignore
    /// Transaction::new(&pool)
    ///     .with_users()
    ///     .run_with_commit(|mut ctx| async move {
    ///         let user = ctx.users().create(dto).await?;
    ///         ctx.commit().await?;
    ///         Ok(user)
    ///     })
    ///     .await?;
    /// ```
    ///
    /// # Errors
    ///
    /// Fails if beginning the transaction fails, or with whatever error `f`
    /// returns.
    pub async fn run_with_commit<F, Fut, T, E>(self, f: F) -> Result<T, E>
    where
        F: FnOnce(TransactionContext) -> Fut + Send,
        Fut: Future<Output = Result<T, E>> + Send,
        E: From<sqlx::Error>
    {
        let tx = self.pool.begin().await?;
        f(TransactionContext::new(tx)).await
    }
}
354
#[cfg(test)]
mod tests {
    use std::{
        error::Error,
        sync::atomic::{AtomicUsize, Ordering}
    };

    use super::*;

    fn io_err(msg: &'static str) -> std::io::Error {
        std::io::Error::other(msg)
    }

    /// One value of every `TransactionError` variant, in declaration order,
    /// each wrapping a fresh inner error produced by `make`.
    fn all_variants<E>(mut make: impl FnMut() -> E) -> [TransactionError<E>; 4] {
        [
            TransactionError::Begin(make()),
            TransactionError::Commit(make()),
            TransactionError::Rollback(make()),
            TransactionError::Operation(make())
        ]
    }

    #[test]
    fn transaction_error_display_all_variants() {
        let expected = [
            "failed to begin transaction: msg",
            "failed to commit transaction: msg",
            "failed to rollback transaction: msg",
            "transaction operation failed: msg"
        ];
        for (err, want) in all_variants(|| io_err("msg")).iter().zip(expected) {
            assert_eq!(err.to_string(), want);
        }
    }

    #[test]
    fn transaction_error_source_is_inner_for_all_variants() {
        for err in all_variants(|| io_err("src")) {
            let source = err.source().map(ToString::to_string);
            assert_eq!(source.as_deref(), Some("src"));
        }
    }

    #[test]
    fn transaction_error_is_methods() {
        let flags: Vec<[bool; 4]> = all_variants(|| "e")
            .iter()
            .map(|err| [err.is_begin(), err.is_commit(), err.is_rollback(), err.is_operation()])
            .collect();
        assert_eq!(
            flags,
            [
                [true, false, false, false],
                [false, true, false, false],
                [false, false, true, false],
                [false, false, false, true]
            ]
        );
    }

    #[test]
    fn transaction_error_into_inner_all_variants() {
        let errs = [
            TransactionError::Begin("begin".to_string()),
            TransactionError::Commit("commit".to_string()),
            TransactionError::Rollback("rollback".to_string()),
            TransactionError::Operation("op".to_string())
        ];
        let inners: Vec<String> = errs.into_iter().map(TransactionError::into_inner).collect();
        assert_eq!(inners, ["begin", "commit", "rollback", "op"]);
    }

    #[test]
    fn transaction_error_debug() {
        let err: TransactionError<&str> = TransactionError::Begin("test");
        assert_eq!(format!("{err:?}"), r#"Begin("test")"#);
    }

    #[test]
    fn transaction_builder_new_accepts_zero_sized_pool() {
        struct MockPool;
        let pool = MockPool;
        let tx = Transaction::new(&pool);
        assert!(std::ptr::eq(tx.pool(), &pool));
    }

    #[test]
    fn transaction_builder_pool_accessor() {
        struct MockPool {
            id: u32
        }
        let pool = MockPool {
            id: 42
        };
        let tx = Transaction::new(&pool);
        assert!(std::ptr::eq(tx.pool(), &pool));
        assert_eq!(tx.pool().id, 42);
    }

    // Regression tests for `finalize_with_commit`.
    //
    // `Transaction::run` once consumed `TransactionContext` and dropped it
    // before commit was ever called, so successful runs silently rolled back.
    // `run` now keeps ownership of `ctx` and commits on `Ok`. The
    // backend-agnostic decision lives in `finalize_with_commit`, which these
    // tests exercise with a mock context and a call-counting commit closure,
    // so no database is required.

    #[derive(Debug, PartialEq, Eq)]
    struct MockCtx;

    #[derive(Debug, PartialEq, Eq)]
    struct CommitErr(&'static str);

    #[derive(Debug, PartialEq, Eq)]
    enum AppErr {
        Closure(&'static str),
        Commit(&'static str)
    }

    impl From<CommitErr> for AppErr {
        fn from(e: CommitErr) -> Self {
            Self::Commit(e.0)
        }
    }

    #[tokio::test]
    async fn finalize_commits_exactly_once_on_ok() {
        let commits = AtomicUsize::new(0);
        let commits_ref = &commits;

        let result: Result<i32, AppErr> = finalize_with_commit::<_, _, _, CommitErr, _, _>(
            MockCtx,
            Ok::<i32, AppErr>(42),
            move |_ctx| async move {
                commits_ref.fetch_add(1, Ordering::SeqCst);
                Ok::<(), CommitErr>(())
            }
        )
        .await;

        assert_eq!(result, Ok(42));
        assert_eq!(
            commits.load(Ordering::SeqCst),
            1,
            "commit_fn must run exactly once on Ok"
        );
    }

    #[tokio::test]
    async fn finalize_skips_commit_on_err() {
        let commits = AtomicUsize::new(0);
        let commits_ref = &commits;

        let result: Result<i32, AppErr> = finalize_with_commit::<_, _, _, CommitErr, _, _>(
            MockCtx,
            Err::<i32, AppErr>(AppErr::Closure("nope")),
            move |_ctx| async move {
                commits_ref.fetch_add(1, Ordering::SeqCst);
                Ok::<(), CommitErr>(())
            }
        )
        .await;

        assert_eq!(result, Err(AppErr::Closure("nope")));
        assert_eq!(
            commits.load(Ordering::SeqCst),
            0,
            "commit_fn must NOT run on Err"
        );
    }

    #[tokio::test]
    async fn finalize_propagates_commit_error_on_ok() {
        let result: Result<i32, AppErr> = finalize_with_commit::<_, _, _, CommitErr, _, _>(
            MockCtx,
            Ok::<i32, AppErr>(42),
            |_ctx| async { Err::<(), CommitErr>(CommitErr("commit failed")) }
        )
        .await;

        assert_eq!(result, Err(AppErr::Commit("commit failed")));
    }

    #[tokio::test]
    async fn finalize_preserves_closure_value_on_ok() {
        // The Ok payload must survive the commit step: the return type is the
        // closure's success type, not the commit_fn result.
        let result: Result<String, AppErr> = finalize_with_commit::<_, _, _, CommitErr, _, _>(
            MockCtx,
            Ok::<String, AppErr>("payload".to_string()),
            |_ctx| async { Ok::<(), CommitErr>(()) }
        )
        .await;

        assert_eq!(result, Ok("payload".to_string()));
    }

    #[tokio::test]
    async fn finalize_does_not_swallow_closure_error_when_commit_also_would_fail() {
        // On Err, commit_fn is never called, so a faulty commit_fn cannot
        // hide the closure's original error.
        let result: Result<(), AppErr> = finalize_with_commit::<_, _, _, CommitErr, _, _>(
            MockCtx,
            Err::<(), AppErr>(AppErr::Closure("original")),
            |_ctx| async { Err::<(), CommitErr>(CommitErr("never reached")) }
        )
        .await;

        assert_eq!(result, Err(AppErr::Closure("original")));
    }
}