Skip to main content

entity_core/
transaction.rs

1// SPDX-FileCopyrightText: 2025-2026 RAprogramm <andrey.rozanov.vl@gmail.com>
2// SPDX-License-Identifier: MIT
3
4//! Transaction support for entity-derive.
5//!
6//! This module provides type-safe transaction management with automatic
7//! commit/rollback semantics. It uses a fluent builder pattern for composing
8//! multiple entity operations into a single transaction.
9//!
10//! # Overview
11//!
12//! - [`Transaction`] — Entry point for creating transactions
13//! - [`TransactionContext`] — Holds active transaction, provides repo access
14//! - [`TransactionError`] — Error wrapper for transaction operations
15//!
16//! # Example
17//!
18//! ```rust,ignore
19//! use entity_derive::prelude::*;
20//!
21//! async fn transfer(pool: &PgPool, from: Uuid, to: Uuid, amount: i64) -> Result<(), AppError> {
22//!     Transaction::new(pool)
23//!         .with_accounts()
24//!         .with_transfers()
25//!         .run(async |ctx| {
26//!             let from_acc = ctx.accounts().find_by_id(from).await?.ok_or(AppError::NotFound)?;
27//!
28//!             ctx.accounts().update(from, UpdateAccount {
29//!                 balance: Some(from_acc.balance - amount),
30//!                 ..Default::default()
31//!             }).await?;
32//!
33//!             ctx.transfers().create(CreateTransfer { from, to, amount }).await?;
34//!             Ok(())
35//!         })
36//!         .await
37//! }
38//! ```
39
40#[cfg(feature = "postgres")]
41use std::future::Future;
42use std::{error::Error as StdError, fmt};
43
/// Transaction builder for composing multi-entity operations.
///
/// Start with [`Transaction::new`], chain the `.with_*()` methods to declare
/// which entities the transaction will touch, then call `.run()` to execute
/// the whole unit of work.
///
/// The builder only borrows the pool; it holds no other state.
///
/// # Type Parameters
///
/// - `'p` — Lifetime of the borrowed pool
/// - `DB` — Database pool type (e.g., `PgPool`)
///
/// # Example
///
/// ```rust,ignore
/// Transaction::new(&pool)
///     .with_users()
///     .with_orders()
///     .run(async |ctx| {
///         let user = ctx.users().find_by_id(id).await?;
///         ctx.orders().create(order).await?;
///         Ok(())
///     })
///     .await?;
/// ```
pub struct Transaction<'p, DB> {
    pool: &'p DB
}

impl<'p, DB> Transaction<'p, DB> {
    /// Create a new transaction builder.
    ///
    /// Nothing is sent to the database here; the builder merely stores the
    /// pool reference, so a builder that is created and then discarded does
    /// nothing. `#[must_use]` makes that mistake visible at compile time.
    ///
    /// # Arguments
    ///
    /// * `pool` — Database connection pool
    ///
    /// # Example
    ///
    /// ```rust,ignore
    /// let tx = Transaction::new(&pool);
    /// ```
    #[must_use]
    pub const fn new(pool: &'p DB) -> Self {
        Self {
            pool
        }
    }

    /// Get reference to the underlying pool.
    ///
    /// The returned reference carries the pool lifetime `'p`, not the
    /// lifetime of the `&self` borrow, so it may outlive the builder.
    #[must_use]
    pub const fn pool(&self) -> &'p DB {
        self.pool
    }
}
95
/// Handle to an in-flight database transaction.
///
/// The context owns the `sqlx` transaction and is the value that
/// macro-generated extension traits hang their repository accessors on.
///
/// # Automatic Rollback
///
/// A context that is dropped without an explicit commit is rolled back by
/// the `Drop` implementation of the wrapped database transaction.
///
/// # Accessing Repositories
///
/// Every entity annotated with `#[entity(transactions)]` gets an extension
/// trait that contributes one accessor method:
///
/// ```rust,ignore
/// // For entity BankAccount, use:
/// ctx.bank_accounts().find_by_id(id).await?;
/// ctx.bank_accounts().create(dto).await?;
/// ctx.bank_accounts().update(id, dto).await?;
/// ```
#[cfg(feature = "postgres")]
pub struct TransactionContext {
    tx: sqlx::Transaction<'static, sqlx::Postgres>
}

#[cfg(feature = "postgres")]
impl TransactionContext {
    /// Wrap an already-started database transaction.
    ///
    /// # Arguments
    ///
    /// * `tx` — Active database transaction
    #[doc(hidden)]
    #[must_use]
    pub const fn new(tx: sqlx::Transaction<'static, sqlx::Postgres>) -> Self {
        Self { tx }
    }

    /// Borrow the wrapped transaction mutably.
    ///
    /// Intended for hand-written queries that must run inside the
    /// transaction and for repository adapters that execute through it.
    pub const fn transaction(&mut self) -> &mut sqlx::Transaction<'static, sqlx::Postgres> {
        &mut self.tx
    }

    /// Commit the transaction.
    ///
    /// Takes the context by value, so it cannot be used afterwards.
    ///
    /// # Errors
    ///
    /// Returns the `sqlx::Error` reported by the database transaction.
    pub async fn commit(self) -> Result<(), sqlx::Error> {
        let Self { tx } = self;
        tx.commit().await
    }

    /// Rollback the transaction.
    ///
    /// Takes the context by value, so it cannot be used afterwards.
    ///
    /// # Errors
    ///
    /// Returns the `sqlx::Error` reported by the database transaction.
    pub async fn rollback(self) -> Result<(), sqlx::Error> {
        let Self { tx } = self;
        tx.rollback().await
    }
}
167
/// Error type for transaction operations.
///
/// Wraps an underlying error `E` and records which phase of the transaction
/// produced it. All four variants carry the same payload type, so the phase
/// is the only information the wrapper adds.
#[derive(Debug)]
pub enum TransactionError<E> {
    /// Failed to begin transaction.
    Begin(E),

    /// Failed to commit transaction.
    Commit(E),

    /// Failed to rollback transaction.
    Rollback(E),

    /// Operation within transaction failed.
    Operation(E)
}

impl<E: fmt::Display> fmt::Display for TransactionError<E> {
    /// Writes a phase-specific prefix followed by the wrapped error's own
    /// `Display` output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Begin(e) => write!(f, "failed to begin transaction: {e}"),
            Self::Commit(e) => write!(f, "failed to commit transaction: {e}"),
            Self::Rollback(e) => write!(f, "failed to rollback transaction: {e}"),
            Self::Operation(e) => write!(f, "transaction operation failed: {e}")
        }
    }
}

impl<E: StdError + 'static> StdError for TransactionError<E> {
    /// Exposes the wrapped error as the source; every variant has one, so
    /// this never returns `None`.
    fn source(&self) -> Option<&(dyn StdError + 'static)> {
        match self {
            Self::Begin(e) | Self::Commit(e) | Self::Rollback(e) | Self::Operation(e) => Some(e)
        }
    }
}

impl<E> TransactionError<E> {
    /// Check if this is a begin error.
    #[must_use]
    pub const fn is_begin(&self) -> bool {
        matches!(self, Self::Begin(_))
    }

    /// Check if this is a commit error.
    #[must_use]
    pub const fn is_commit(&self) -> bool {
        matches!(self, Self::Commit(_))
    }

    /// Check if this is a rollback error.
    #[must_use]
    pub const fn is_rollback(&self) -> bool {
        matches!(self, Self::Rollback(_))
    }

    /// Check if this is an operation error.
    #[must_use]
    pub const fn is_operation(&self) -> bool {
        matches!(self, Self::Operation(_))
    }

    /// Get the inner error.
    ///
    /// Consumes the wrapper and discards the phase information, returning
    /// only the wrapped `E`.
    #[must_use]
    pub fn into_inner(self) -> E {
        match self {
            Self::Begin(e) | Self::Commit(e) | Self::Rollback(e) | Self::Operation(e) => e
        }
    }
}
233
// Lets `?` turn a `TransactionError<sqlx::Error>` back into a plain
// `sqlx::Error`; the phase tag is dropped in the process.
#[cfg(feature = "postgres")]
impl From<TransactionError<Self>> for sqlx::Error {
    fn from(wrapped: TransactionError<Self>) -> Self {
        TransactionError::into_inner(wrapped)
    }
}
240
241// PostgreSQL implementation
#[cfg(feature = "postgres")]
impl Transaction<'_, sqlx::PgPool> {
    /// Run an async closure inside a `PostgreSQL` transaction.
    ///
    /// A transaction is begun on the pool, the closure is invoked with a
    /// mutable borrow of the [`TransactionContext`], and the transaction is
    /// committed explicitly if the closure yields `Ok`. If the closure
    /// yields `Err`, the context is dropped uncommitted and `sqlx` rolls the
    /// transaction back through its `Drop` implementation.
    ///
    /// The context is lent (`&mut`) rather than handed over so that `run`
    /// still owns it afterwards and is able to call `commit().await`.
    ///
    /// # Type Parameters
    ///
    /// - `F` — Async closure
    /// - `T` — Success type
    /// - `E` — Error type (must be convertible from `sqlx::Error`)
    ///
    /// # Example
    ///
    /// ```rust,ignore
    /// Transaction::new(&pool)
    ///     .with_users()
    ///     .run(async |ctx| {
    ///         let user = ctx.users().create(dto).await?;
    ///         Ok(user)
    ///     })
    ///     .await?;
    /// ```
    ///
    /// # Errors
    ///
    /// Returns the closure's error unchanged, or the `sqlx::Error` from
    /// `begin` / `commit` converted into `E`.
    pub async fn run<F, T, E>(self, f: F) -> Result<T, E>
    where
        F: AsyncFnOnce(&mut TransactionContext) -> Result<T, E>,
        E: From<sqlx::Error>
    {
        let mut ctx = TransactionContext::new(self.pool.begin().await?);

        // Returning early on `Err` drops `ctx` without a commit, which is
        // what triggers the rollback.
        let value = f(&mut ctx).await?;

        ctx.commit().await?;
        Ok(value)
    }

    /// Run a closure inside a transaction, leaving the commit to the closure.
    ///
    /// In contrast to [`run`](Self::run), the [`TransactionContext`] is moved
    /// into the closure, which is then responsible for calling
    /// `ctx.commit().await` (or `ctx.rollback().await`). A context that is
    /// dropped without a commit is rolled back.
    ///
    /// Reach for this only when the decision to commit is conditional;
    /// `run` is the simpler choice otherwise.
    ///
    /// # Example
    ///
    /// ```rust,ignore
    /// Transaction::new(&pool)
    ///     .run_with_commit(|mut ctx| async move {
    ///         let user = ctx.users().create(dto).await?;
    ///         ctx.commit().await?;
    ///         Ok(user)
    ///     })
    ///     .await?;
    /// ```
    ///
    /// # Errors
    ///
    /// Returns the closure's result unchanged, or the `sqlx::Error` from
    /// `begin` converted into `E`.
    pub async fn run_with_commit<F, Fut, T, E>(self, f: F) -> Result<T, E>
    where
        F: FnOnce(TransactionContext) -> Fut + Send,
        Fut: Future<Output = Result<T, E>> + Send,
        E: From<sqlx::Error>
    {
        let started = self.pool.begin().await?;
        f(TransactionContext::new(started)).await
    }
}
326
#[cfg(test)]
mod tests {
    use std::error::Error;

    use super::*;

    /// Builds an `io::Error` carrying `msg`, used as a stand-in inner error.
    fn io_err(msg: &'static str) -> std::io::Error {
        std::io::Error::other(msg)
    }

    /// Results of the four predicates, ordered begin/commit/rollback/operation.
    fn predicates<E>(err: &TransactionError<E>) -> [bool; 4] {
        [err.is_begin(), err.is_commit(), err.is_rollback(), err.is_operation()]
    }

    #[test]
    fn transaction_error_display_begin() {
        let text = TransactionError::Begin(io_err("test")).to_string();
        assert!(text.contains("begin"));
        assert!(text.contains("test"));
    }

    #[test]
    fn transaction_error_display_commit() {
        let text = TransactionError::Commit(io_err("test")).to_string();
        assert!(text.contains("commit"));
    }

    #[test]
    fn transaction_error_display_rollback() {
        let text = TransactionError::Rollback(io_err("test")).to_string();
        assert!(text.contains("rollback"));
    }

    #[test]
    fn transaction_error_display_operation() {
        let text = TransactionError::Operation(io_err("test")).to_string();
        assert!(text.contains("operation"));
    }

    #[test]
    fn transaction_error_is_methods() {
        assert_eq!(predicates(&TransactionError::Begin("e")), [true, false, false, false]);
        assert_eq!(predicates(&TransactionError::Commit("e")), [false, true, false, false]);
        assert_eq!(predicates(&TransactionError::Rollback("e")), [false, false, true, false]);
        assert_eq!(predicates(&TransactionError::Operation("e")), [false, false, false, true]);
    }

    #[test]
    fn transaction_error_into_inner() {
        let err: TransactionError<&str> = TransactionError::Operation("test");
        assert_eq!(err.into_inner(), "test");
    }

    #[test]
    fn transaction_error_into_inner_begin() {
        let err: TransactionError<&str> = TransactionError::Begin("begin_err");
        assert_eq!(err.into_inner(), "begin_err");
    }

    #[test]
    fn transaction_error_into_inner_commit() {
        let err: TransactionError<&str> = TransactionError::Commit("commit_err");
        assert_eq!(err.into_inner(), "commit_err");
    }

    #[test]
    fn transaction_error_into_inner_rollback() {
        let err: TransactionError<&str> = TransactionError::Rollback("rollback_err");
        assert_eq!(err.into_inner(), "rollback_err");
    }

    #[test]
    fn transaction_error_source_begin() {
        let err = TransactionError::Begin(io_err("src"));
        assert!(err.source().is_some());
    }

    #[test]
    fn transaction_error_source_commit() {
        let err = TransactionError::Commit(io_err("src"));
        assert!(err.source().is_some());
    }

    #[test]
    fn transaction_error_source_rollback() {
        let err = TransactionError::Rollback(io_err("src"));
        assert!(err.source().is_some());
    }

    #[test]
    fn transaction_error_source_operation() {
        let err = TransactionError::Operation(io_err("src"));
        assert!(err.source().is_some());
    }

    #[test]
    fn transaction_builder_new() {
        struct MockPool;
        let pool = MockPool;
        let tx = Transaction::new(&pool);
        let _ = tx.pool();
    }

    #[test]
    fn transaction_builder_pool_accessor() {
        struct MockPool {
            id: u32
        }
        let pool = MockPool {
            id: 42
        };
        let tx = Transaction::new(&pool);
        assert_eq!(tx.pool().id, 42);
    }

    #[test]
    fn transaction_error_debug() {
        let err: TransactionError<&str> = TransactionError::Begin("test");
        let debug_str = format!("{err:?}");
        assert!(debug_str.contains("Begin"));
        assert!(debug_str.contains("test"));
    }

    #[test]
    fn transaction_error_into_inner_all_variants() {
        let begin = TransactionError::Begin("begin".to_string());
        let commit = TransactionError::Commit("commit".to_string());
        let rollback = TransactionError::Rollback("rollback".to_string());
        let operation = TransactionError::Operation("op".to_string());

        assert_eq!(begin.into_inner(), "begin");
        assert_eq!(commit.into_inner(), "commit");
        assert_eq!(rollback.into_inner(), "rollback");
        assert_eq!(operation.into_inner(), "op");
    }

    #[test]
    fn transaction_error_source_all_variants() {
        let variants = [
            TransactionError::Begin(io_err("src")),
            TransactionError::Commit(io_err("src")),
            TransactionError::Rollback(io_err("src")),
            TransactionError::Operation(io_err("src"))
        ];

        for err in &variants {
            assert!(err.source().is_some());
        }
    }

    #[test]
    fn transaction_error_display_all_variants() {
        let begin_str = TransactionError::Begin(io_err("msg")).to_string();
        let commit_str = TransactionError::Commit(io_err("msg")).to_string();
        let rollback_str = TransactionError::Rollback(io_err("msg")).to_string();
        let operation_str = TransactionError::Operation(io_err("msg")).to_string();

        assert!(begin_str.contains("begin"));
        assert!(commit_str.contains("commit"));
        assert!(rollback_str.contains("rollback"));
        assert!(operation_str.contains("operation"));
    }

    #[test]
    fn transaction_error_is_all_variants() {
        let begin: TransactionError<&str> = TransactionError::Begin("e");
        let commit: TransactionError<&str> = TransactionError::Commit("e");
        let rollback: TransactionError<&str> = TransactionError::Rollback("e");
        let operation: TransactionError<&str> = TransactionError::Operation("e");

        assert!(begin.is_begin());
        assert!(commit.is_commit());
        assert!(rollback.is_rollback());
        assert!(operation.is_operation());

        assert!(!begin.is_commit());
        assert!(!begin.is_rollback());
        assert!(!begin.is_operation());

        assert!(!commit.is_begin());
        assert!(!commit.is_rollback());
        assert!(!commit.is_operation());

        assert!(!rollback.is_begin());
        assert!(!rollback.is_commit());
        assert!(!rollback.is_operation());

        assert!(!operation.is_begin());
        assert!(!operation.is_commit());
        assert!(!operation.is_rollback());
    }

    #[test]
    fn transaction_builder_new_const() {
        struct MockPool;
        // Evaluating `new` inside a `const` item is what actually proves the
        // constructor is usable in const contexts.
        const TX: Transaction<'static, MockPool> = Transaction::new(&MockPool);
        let _ = TX.pool();
    }

    // Compile-time regression guard for the `run()` signature.
    //
    // Earlier versions accepted `FnOnce(TransactionContext) -> Fut` and
    // dropped the context on Ok, which silently rolled back the transaction.
    // The fix is to take `&mut TransactionContext` so `run` keeps ownership
    // and can call `commit().await` explicitly on Ok.
    //
    // This function is never called (no real pool) — it only has to
    // type-check with an async closure receiving `&mut TransactionContext`.
    // The `allow` covers exactly that: the function is intentionally dead
    // and `_fut` is intentionally bound without being awaited.
    #[cfg(feature = "postgres")]
    #[allow(dead_code, clippy::no_effect_underscore_binding)]
    fn _run_signature_accepts_mut_ref(pool: &sqlx::PgPool) {
        let _fut = Transaction::new(pool)
            .run(async |_ctx: &mut TransactionContext| Ok::<(), sqlx::Error>(()));
    }
}