Skip to main content

entity_core/
transaction.rs

1// SPDX-FileCopyrightText: 2025-2026 RAprogramm <andrey.rozanov.vl@gmail.com>
2// SPDX-License-Identifier: MIT
3
4//! Transaction support for entity-derive.
5//!
6//! This module provides type-safe transaction management with automatic
7//! commit/rollback semantics. It uses a fluent builder pattern for composing
8//! multiple entity operations into a single transaction.
9//!
10//! # Overview
11//!
12//! - [`Transaction`] — Entry point for creating transactions
13//! - [`TransactionContext`] — Holds active transaction, provides repo access
14//! - [`TransactionError`] — Error wrapper for transaction operations
15//!
16//! # Example
17//!
18//! ```rust,ignore
19//! use entity_derive::prelude::*;
20//!
21//! async fn transfer(pool: &PgPool, from: Uuid, to: Uuid, amount: i64) -> Result<(), AppError> {
22//!     Transaction::new(pool)
23//!         .run(async |ctx| {
24//!             let from_acc = ctx.accounts().find_by_id(from).await?.ok_or(AppError::NotFound)?;
25//!
26//!             ctx.accounts().update(from, UpdateAccount {
27//!                 balance: Some(from_acc.balance - amount),
28//!                 ..Default::default()
29//!             }).await?;
30//!
31//!             ctx.transfers().create(CreateTransfer { from, to, amount }).await?;
32//!             Ok(())
33//!         })
34//!         .await
35//! }
36//! ```
37
38#[cfg(feature = "postgres")]
39use std::future::Future;
40use std::{error::Error as StdError, fmt};
41
/// Transaction builder for composing multi-entity operations.
///
/// Create a builder with [`Transaction::new`], then call `.run()` (or
/// `.run_with_commit()` when the closure must decide about committing
/// itself) to execute a closure inside a single database transaction.
/// Entity repositories are reached through the context handed to that
/// closure (e.g. `ctx.users()`), not through the builder.
///
/// # Type Parameters
///
/// - `'p` — Pool lifetime
/// - `DB` — Database pool type (e.g., `PgPool`)
///
/// # Example
///
/// ```rust,ignore
/// Transaction::new(&pool)
///     .run(async |ctx| {
///         let user = ctx.users().find_by_id(id).await?;
///         ctx.orders().create(order).await?;
///         Ok(())
///     })
///     .await?;
/// ```
pub struct Transaction<'p, DB> {
    pool: &'p DB
}

impl<'p, DB> Transaction<'p, DB> {
    /// Create a new transaction builder.
    ///
    /// Only the pool reference is stored; no database work happens until the
    /// builder is executed.
    ///
    /// # Arguments
    ///
    /// * `pool` — Database connection pool
    ///
    /// # Example
    ///
    /// ```rust,ignore
    /// let tx = Transaction::new(&pool);
    /// ```
    #[must_use]
    pub const fn new(pool: &'p DB) -> Self {
        Self {
            pool
        }
    }

    /// Get reference to the underlying pool.
    ///
    /// The returned reference carries the pool lifetime `'p`, not the
    /// lifetime of `&self`, so it can outlive the builder.
    #[must_use]
    pub const fn pool(&self) -> &'p DB {
        self.pool
    }
}
91
/// Active transaction context with repository access.
///
/// Wraps a live `PostgreSQL` transaction. Repositories are reached through
/// extension traits that the derive macro generates for every entity marked
/// `#[entity(transactions)]`; each such trait adds one accessor method:
///
/// ```rust,ignore
/// // For entity BankAccount, use:
/// ctx.bank_accounts().find_by_id(id).await?;
/// ctx.bank_accounts().create(dto).await?;
/// ctx.bank_accounts().update(id, dto).await?;
/// ```
///
/// # Automatic Rollback
///
/// Dropping the context without calling [`commit`](Self::commit) rolls the
/// transaction back; the rollback is delegated to the `Drop` implementation
/// of the wrapped `sqlx::Transaction`.
#[cfg(feature = "postgres")]
pub struct TransactionContext {
    tx: sqlx::Transaction<'static, sqlx::Postgres>
}

#[cfg(feature = "postgres")]
impl TransactionContext {
    /// Wrap an already-begun transaction.
    ///
    /// Hidden from the public docs; `Transaction::run` and
    /// `Transaction::run_with_commit` call this right after `begin` succeeds.
    ///
    /// # Arguments
    ///
    /// * `tx` — Active database transaction
    #[doc(hidden)]
    #[must_use]
    pub const fn new(tx: sqlx::Transaction<'static, sqlx::Postgres>) -> Self {
        Self {
            tx
        }
    }

    /// Borrow the wrapped `sqlx` transaction mutably.
    ///
    /// Pass it as the executor for hand-written queries, or use it from
    /// repository adapters so their statements run inside this transaction.
    pub const fn transaction(&mut self) -> &mut sqlx::Transaction<'static, sqlx::Postgres> {
        &mut self.tx
    }

    /// Commit every change made through this context.
    ///
    /// Takes `self` by value, so the context cannot be used afterwards.
    ///
    /// # Errors
    ///
    /// Returns the `sqlx::Error` reported by the database, unchanged.
    pub async fn commit(self) -> Result<(), sqlx::Error> {
        let Self { tx } = self;
        tx.commit().await
    }

    /// Discard every change made through this context.
    ///
    /// Takes `self` by value, so the context cannot be used afterwards.
    ///
    /// # Errors
    ///
    /// Returns the `sqlx::Error` reported by the database, unchanged.
    pub async fn rollback(self) -> Result<(), sqlx::Error> {
        let Self { tx } = self;
        tx.rollback().await
    }
}
163
/// Error type for transaction operations.
///
/// Tags an inner error `E` with the transaction phase in which it occurred.
/// The phase can be queried with the `is_*` predicates. The inner error can
/// be borrowed with [`inner`](Self::inner) or recovered with
/// [`into_inner`](Self::into_inner).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TransactionError<E> {
    /// Failed to begin transaction.
    Begin(E),

    /// Failed to commit transaction.
    Commit(E),

    /// Failed to rollback transaction.
    Rollback(E),

    /// Operation within transaction failed.
    Operation(E)
}

impl<E: fmt::Display> fmt::Display for TransactionError<E> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Begin(e) => write!(f, "failed to begin transaction: {e}"),
            Self::Commit(e) => write!(f, "failed to commit transaction: {e}"),
            Self::Rollback(e) => write!(f, "failed to rollback transaction: {e}"),
            Self::Operation(e) => write!(f, "transaction operation failed: {e}")
        }
    }
}

// NOTE: `Display` above already embeds the inner error's message, and
// `source()` also returns that same error. Reporters that walk the
// `source()` chain will therefore print the inner message twice. Both
// behaviours are observable, so changing either is a breaking change.
impl<E: StdError + 'static> StdError for TransactionError<E> {
    fn source(&self) -> Option<&(dyn StdError + 'static)> {
        Some(self.inner())
    }
}

impl<E> TransactionError<E> {
    /// Returns `true` for the [`Begin`](Self::Begin) variant.
    #[must_use]
    pub const fn is_begin(&self) -> bool {
        matches!(self, Self::Begin(_))
    }

    /// Returns `true` for the [`Commit`](Self::Commit) variant.
    #[must_use]
    pub const fn is_commit(&self) -> bool {
        matches!(self, Self::Commit(_))
    }

    /// Returns `true` for the [`Rollback`](Self::Rollback) variant.
    #[must_use]
    pub const fn is_rollback(&self) -> bool {
        matches!(self, Self::Rollback(_))
    }

    /// Returns `true` for the [`Operation`](Self::Operation) variant.
    #[must_use]
    pub const fn is_operation(&self) -> bool {
        matches!(self, Self::Operation(_))
    }

    /// Borrow the inner error without consuming `self`.
    #[must_use]
    pub const fn inner(&self) -> &E {
        match self {
            Self::Begin(e) | Self::Commit(e) | Self::Rollback(e) | Self::Operation(e) => e
        }
    }

    /// Consume the wrapper and return the inner error, discarding the phase.
    #[must_use]
    pub fn into_inner(self) -> E {
        match self {
            Self::Begin(e) | Self::Commit(e) | Self::Rollback(e) | Self::Operation(e) => e
        }
    }
}
229
/// Collapse a [`TransactionError`] over `sqlx::Error` back into the plain
/// `sqlx::Error`, discarding which transaction phase failed.
#[cfg(feature = "postgres")]
impl From<TransactionError<Self>> for sqlx::Error {
    fn from(wrapped: TransactionError<Self>) -> Self {
        TransactionError::into_inner(wrapped)
    }
}
236
/// Decide the fate of a finished transaction: commit on `Ok`, abandon on `Err`.
///
/// Kept independent of any database backend so the commit/rollback decision
/// can be unit-tested without a live connection. Tests supply a mock `ctx`
/// and a tracking `commit_fn` to check that:
///
/// - `commit_fn` is invoked exactly once when `result` is `Ok`
/// - `commit_fn` is never invoked when `result` is `Err`; the context is
///   simply dropped, which for `sqlx` transactions means rollback
/// - an error from `commit_fn` is converted with `E::from` and returned in
///   place of the success value
///
/// # Errors
///
/// Returns the closure's original error unchanged on `Err`, or the converted
/// commit error when `commit_fn` fails on `Ok`.
#[cfg(any(feature = "postgres", test))]
async fn finalize_with_commit<C, T, E, CommitErr, Cf, Fut>(
    ctx: C,
    result: Result<T, E>,
    commit_fn: Cf
) -> Result<T, E>
where
    Cf: FnOnce(C) -> Fut,
    Fut: core::future::Future<Output = Result<(), CommitErr>>,
    E: From<CommitErr>
{
    // On `Err` we return right here; `ctx` is dropped without a commit.
    let value = result?;
    // `?` performs the `CommitErr -> E` conversion through `From`.
    commit_fn(ctx).await?;
    Ok(value)
}
270
// PostgreSQL implementation
#[cfg(feature = "postgres")]
impl Transaction<'_, sqlx::PgPool> {
    /// Execute a closure within a `PostgreSQL` transaction.
    ///
    /// A transaction is begun on the pool, and the closure receives a mutable
    /// borrow of the resulting [`TransactionContext`]. Once the closure's
    /// future completes:
    ///
    /// - `Ok(value)`: the transaction is committed explicitly and `value` is
    ///   returned. A failing commit is returned as `E` instead.
    /// - `Err(e)`: `e` is returned unchanged and the context is dropped, so
    ///   `sqlx` rolls the transaction back through its `Drop` implementation.
    ///
    /// The closure borrows the context rather than taking it by value. This
    /// lets `run` keep ownership and call `commit().await` after a
    /// successful closure.
    ///
    /// # Type Parameters
    ///
    /// - `F` — Async closure
    /// - `T` — Success type
    /// - `E` — Error type (must be convertible from `sqlx::Error`)
    ///
    /// # Example
    ///
    /// ```rust,ignore
    /// Transaction::new(&pool)
    ///     .run(async |ctx| {
    ///         let user = ctx.users().create(dto).await?;
    ///         Ok(user)
    ///     })
    ///     .await?;
    /// ```
    ///
    /// # Errors
    ///
    /// Propagates any error from `begin`, from the closure, or from `commit`.
    pub async fn run<F, T, E>(self, f: F) -> Result<T, E>
    where
        F: AsyncFnOnce(&mut TransactionContext) -> Result<T, E>,
        E: From<sqlx::Error>
    {
        // `?` converts a `begin` failure into `E` through `From<sqlx::Error>`.
        let mut ctx = TransactionContext::new(self.pool.begin().await?);
        let outcome = f(&mut ctx).await;
        finalize_with_commit(ctx, outcome, TransactionContext::commit).await
    }

    /// Execute a closure within a transaction with explicit commit.
    ///
    /// Unlike [`run`](Self::run), the closure receives `TransactionContext`
    /// by value and is responsible for calling `ctx.commit().await` (or
    /// `ctx.rollback().await`) itself. A context that is dropped without
    /// being committed rolls the transaction back.
    ///
    /// Reach for this only when commit must be conditional; otherwise
    /// prefer `run`.
    ///
    /// # Example
    ///
    /// ```rust,ignore
    /// Transaction::new(&pool)
    ///     .run_with_commit(|mut ctx| async move {
    ///         let user = ctx.users().create(dto).await?;
    ///         ctx.commit().await?;
    ///         Ok(user)
    ///     })
    ///     .await?;
    /// ```
    ///
    /// # Errors
    ///
    /// Propagates any error from `begin` or from the closure.
    pub async fn run_with_commit<F, Fut, T, E>(self, f: F) -> Result<T, E>
    where
        F: FnOnce(TransactionContext) -> Fut + Send,
        Fut: Future<Output = Result<T, E>> + Send,
        E: From<sqlx::Error>
    {
        let tx = self.pool.begin().await?;
        f(TransactionContext::new(tx)).await
    }
}
349
#[cfg(test)]
mod tests {
    use std::{
        error::Error,
        sync::{
            Arc,
            atomic::{AtomicBool, Ordering}
        }
    };

    use super::*;

    #[test]
    fn transaction_error_display_begin() {
        let err: TransactionError<std::io::Error> =
            TransactionError::Begin(std::io::Error::other("test"));
        assert!(err.to_string().contains("begin"));
        assert!(err.to_string().contains("test"));
    }

    #[test]
    fn transaction_error_display_commit() {
        let err: TransactionError<std::io::Error> =
            TransactionError::Commit(std::io::Error::other("test"));
        assert!(err.to_string().contains("commit"));
    }

    #[test]
    fn transaction_error_display_rollback() {
        let err: TransactionError<std::io::Error> =
            TransactionError::Rollback(std::io::Error::other("test"));
        assert!(err.to_string().contains("rollback"));
    }

    #[test]
    fn transaction_error_display_operation() {
        let err: TransactionError<std::io::Error> =
            TransactionError::Operation(std::io::Error::other("test"));
        assert!(err.to_string().contains("operation"));
    }

    #[test]
    fn transaction_error_is_methods() {
        let begin: TransactionError<&str> = TransactionError::Begin("e");
        let commit: TransactionError<&str> = TransactionError::Commit("e");
        let rollback: TransactionError<&str> = TransactionError::Rollback("e");
        let operation: TransactionError<&str> = TransactionError::Operation("e");

        assert!(begin.is_begin());
        assert!(!begin.is_commit());
        assert!(!begin.is_rollback());
        assert!(!begin.is_operation());

        assert!(!commit.is_begin());
        assert!(commit.is_commit());
        assert!(!commit.is_rollback());
        assert!(!commit.is_operation());

        assert!(!rollback.is_begin());
        assert!(!rollback.is_commit());
        assert!(rollback.is_rollback());
        assert!(!rollback.is_operation());

        assert!(!operation.is_begin());
        assert!(!operation.is_commit());
        assert!(!operation.is_rollback());
        assert!(operation.is_operation());
    }

    #[test]
    fn transaction_error_into_inner() {
        let err: TransactionError<&str> = TransactionError::Operation("test");
        assert_eq!(err.into_inner(), "test");
    }

    #[test]
    fn transaction_error_into_inner_begin() {
        let err: TransactionError<&str> = TransactionError::Begin("begin_err");
        assert_eq!(err.into_inner(), "begin_err");
    }

    #[test]
    fn transaction_error_into_inner_commit() {
        let err: TransactionError<&str> = TransactionError::Commit("commit_err");
        assert_eq!(err.into_inner(), "commit_err");
    }

    #[test]
    fn transaction_error_into_inner_rollback() {
        let err: TransactionError<&str> = TransactionError::Rollback("rollback_err");
        assert_eq!(err.into_inner(), "rollback_err");
    }

    #[test]
    fn transaction_error_source_begin() {
        let err: TransactionError<std::io::Error> =
            TransactionError::Begin(std::io::Error::other("src"));
        assert!(err.source().is_some());
    }

    #[test]
    fn transaction_error_source_commit() {
        let err: TransactionError<std::io::Error> =
            TransactionError::Commit(std::io::Error::other("src"));
        assert!(err.source().is_some());
    }

    #[test]
    fn transaction_error_source_rollback() {
        let err: TransactionError<std::io::Error> =
            TransactionError::Rollback(std::io::Error::other("src"));
        assert!(err.source().is_some());
    }

    #[test]
    fn transaction_error_source_operation() {
        let err: TransactionError<std::io::Error> =
            TransactionError::Operation(std::io::Error::other("src"));
        assert!(err.source().is_some());
    }

    #[test]
    fn transaction_builder_new() {
        // Non-zero-sized so the pointer-identity check is meaningful.
        struct MockPool {
            _tag: u8
        }
        let pool = MockPool {
            _tag: 0
        };
        let tx = Transaction::new(&pool);
        assert!(std::ptr::eq(tx.pool(), &pool));
    }

    #[test]
    fn transaction_builder_pool_accessor() {
        struct MockPool {
            id: u32
        }
        let pool = MockPool {
            id: 42
        };
        let tx = Transaction::new(&pool);
        assert_eq!(tx.pool().id, 42);
    }

    #[test]
    fn transaction_error_debug() {
        let err: TransactionError<&str> = TransactionError::Begin("test");
        let debug_str = format!("{err:?}");
        assert!(debug_str.contains("Begin"));
        assert!(debug_str.contains("test"));
    }

    #[test]
    fn transaction_error_into_inner_all_variants() {
        let begin: TransactionError<String> = TransactionError::Begin("begin".to_string());
        let commit: TransactionError<String> = TransactionError::Commit("commit".to_string());
        let rollback: TransactionError<String> =
            TransactionError::Rollback("rollback".to_string());
        let operation: TransactionError<String> = TransactionError::Operation("op".to_string());

        assert_eq!(begin.into_inner(), "begin");
        assert_eq!(commit.into_inner(), "commit");
        assert_eq!(rollback.into_inner(), "rollback");
        assert_eq!(operation.into_inner(), "op");
    }

    #[test]
    fn transaction_error_source_all_variants() {
        let begin: TransactionError<std::io::Error> =
            TransactionError::Begin(std::io::Error::other("src"));
        let commit: TransactionError<std::io::Error> =
            TransactionError::Commit(std::io::Error::other("src"));
        let rollback: TransactionError<std::io::Error> =
            TransactionError::Rollback(std::io::Error::other("src"));
        let operation: TransactionError<std::io::Error> =
            TransactionError::Operation(std::io::Error::other("src"));

        assert!(begin.source().is_some());
        assert!(commit.source().is_some());
        assert!(rollback.source().is_some());
        assert!(operation.source().is_some());
    }

    #[test]
    fn transaction_error_display_all_variants() {
        let begin: TransactionError<std::io::Error> =
            TransactionError::Begin(std::io::Error::other("msg"));
        let commit: TransactionError<std::io::Error> =
            TransactionError::Commit(std::io::Error::other("msg"));
        let rollback: TransactionError<std::io::Error> =
            TransactionError::Rollback(std::io::Error::other("msg"));
        let operation: TransactionError<std::io::Error> =
            TransactionError::Operation(std::io::Error::other("msg"));

        let begin_str = begin.to_string();
        let commit_str = commit.to_string();
        let rollback_str = rollback.to_string();
        let operation_str = operation.to_string();

        assert!(begin_str.contains("begin"));
        assert!(commit_str.contains("commit"));
        assert!(rollback_str.contains("rollback"));
        assert!(operation_str.contains("operation"));
    }

    #[test]
    fn transaction_builder_new_const() {
        // Evaluated at compile time: proves `Transaction::new` and
        // `Transaction::pool` remain usable in const contexts.
        const POOL: &u32 = &7;
        const TX: Transaction<'static, u32> = Transaction::new(POOL);
        const PTR: &u32 = TX.pool();
        assert_eq!(*PTR, 7);
    }

    // Regression tests for `finalize_with_commit`.
    //
    // `Transaction::run` must keep ownership of `TransactionContext` so it can
    // commit once the closure returns `Ok`. An earlier version dropped the
    // context before commit was ever called, so successful runs silently
    // rolled back. The backend-agnostic commit/rollback decision lives in
    // `finalize_with_commit`. These tests cover it end-to-end with a mock
    // context and a tracking commit closure, so no database is required.

    #[derive(Debug, PartialEq, Eq)]
    struct MockCtx;

    #[derive(Debug, PartialEq, Eq)]
    struct CommitErr(&'static str);

    #[derive(Debug, PartialEq, Eq)]
    enum AppErr {
        Closure(&'static str),
        Commit(&'static str)
    }

    impl From<CommitErr> for AppErr {
        fn from(e: CommitErr) -> Self {
            Self::Commit(e.0)
        }
    }

    #[tokio::test]
    async fn finalize_commits_on_ok() {
        let committed = Arc::new(AtomicBool::new(false));
        let flag = Arc::clone(&committed);

        let result: Result<i32, AppErr> = finalize_with_commit::<_, _, _, CommitErr, _, _>(
            MockCtx,
            Ok::<i32, AppErr>(42),
            move |_ctx| async move {
                flag.store(true, Ordering::SeqCst);
                Ok::<(), CommitErr>(())
            }
        )
        .await;

        assert_eq!(result, Ok(42));
        assert!(committed.load(Ordering::SeqCst), "commit_fn must run on Ok");
    }

    #[tokio::test]
    async fn finalize_skips_commit_on_err() {
        let committed = Arc::new(AtomicBool::new(false));
        let flag = Arc::clone(&committed);

        let result: Result<i32, AppErr> = finalize_with_commit::<_, _, _, CommitErr, _, _>(
            MockCtx,
            Err::<i32, AppErr>(AppErr::Closure("nope")),
            move |_ctx| async move {
                flag.store(true, Ordering::SeqCst);
                Ok::<(), CommitErr>(())
            }
        )
        .await;

        assert_eq!(result, Err(AppErr::Closure("nope")));
        assert!(
            !committed.load(Ordering::SeqCst),
            "commit_fn must NOT run on Err"
        );
    }

    #[tokio::test]
    async fn finalize_propagates_commit_error_on_ok() {
        let result: Result<i32, AppErr> = finalize_with_commit::<_, _, _, CommitErr, _, _>(
            MockCtx,
            Ok::<i32, AppErr>(42),
            |_ctx| async { Err::<(), CommitErr>(CommitErr("commit failed")) }
        )
        .await;

        assert_eq!(result, Err(AppErr::Commit("commit failed")));
    }

    #[tokio::test]
    async fn finalize_preserves_closure_value_on_ok() {
        // The Ok payload must survive the commit step: the return type is the
        // closure's success type, not the commit_fn result.
        let result: Result<String, AppErr> = finalize_with_commit::<_, _, _, CommitErr, _, _>(
            MockCtx,
            Ok::<String, AppErr>("payload".to_string()),
            |_ctx| async { Ok::<(), CommitErr>(()) }
        )
        .await;

        assert_eq!(result, Ok("payload".to_string()));
    }

    #[tokio::test]
    async fn finalize_does_not_swallow_closure_error_when_commit_also_would_fail() {
        // On Err, commit_fn is never called, so a faulty commit_fn cannot
        // hide the closure's original error.
        let result: Result<(), AppErr> = finalize_with_commit::<_, _, _, CommitErr, _, _>(
            MockCtx,
            Err::<(), AppErr>(AppErr::Closure("original")),
            |_ctx| async { Err::<(), CommitErr>(CommitErr("never reached")) }
        )
        .await;

        assert_eq!(result, Err(AppErr::Closure("original")));
    }
}