Skip to main content

entity_core/
transaction.rs

1// SPDX-FileCopyrightText: 2025-2026 RAprogramm <andrey.rozanov.vl@gmail.com>
2// SPDX-License-Identifier: MIT
3
4//! Transaction support for entity-derive.
5//!
6//! This module provides type-safe transaction management with automatic
7//! commit/rollback semantics. It uses a fluent builder pattern for composing
8//! multiple entity operations into a single transaction.
9//!
10//! # Overview
11//!
12//! - [`Transaction`] — Entry point for creating transactions
13//! - [`TransactionContext`] — Holds active transaction, provides repo access
14//! - [`TransactionError`] — Error wrapper for transaction operations
15//!
16//! # Example
17//!
18//! ```rust,ignore
19//! use entity_derive::prelude::*;
20//!
21//! async fn transfer(pool: &PgPool, from: Uuid, to: Uuid, amount: i64) -> Result<(), AppError> {
22//!     Transaction::new(pool)
23//!         .with_accounts()
24//!         .with_transfers()
25//!         .run(|mut ctx| async move {
26//!             let from_acc = ctx.accounts().find_by_id(from).await?.ok_or(AppError::NotFound)?;
27//!
28//!             ctx.accounts().update(from, UpdateAccount {
29//!                 balance: Some(from_acc.balance - amount),
30//!                 ..Default::default()
31//!             }).await?;
32//!
33//!             ctx.transfers().create(CreateTransfer { from, to, amount }).await?;
34//!             Ok(())
35//!         })
36//!         .await
37//! }
38//! ```
39
40#[cfg(feature = "postgres")]
41use std::future::Future;
42use std::{error::Error as StdError, fmt};
43
/// Transaction builder for composing multi-entity operations.
///
/// Use [`Transaction::new`] to create a builder, chain `.with_*()` methods
/// to declare which entities you'll use, then call `.run()` to execute.
///
/// [`Transaction::new`] only stores a reference to the pool; it does not
/// start a transaction or acquire a connection.
///
/// # Type Parameters
///
/// - `'p` — Pool lifetime
/// - `DB` — Database pool type (e.g., `PgPool`)
///
/// # Example
///
/// ```rust,ignore
/// Transaction::new(&pool)
///     .with_users()
///     .with_orders()
///     .run(|mut ctx| async move {
///         let user = ctx.users().find_by_id(id).await?;
///         ctx.orders().create(order).await?;
///         Ok(())
///     })
///     .await?;
/// ```
#[derive(Debug)]
pub struct Transaction<'p, DB> {
    /// Pool the transaction will be started from.
    pool: &'p DB
}

impl<'p, DB> Transaction<'p, DB> {
    /// Create a new transaction builder.
    ///
    /// This only records the pool reference; nothing is sent to the database
    /// here.
    ///
    /// # Arguments
    ///
    /// * `pool` — Database connection pool
    ///
    /// # Example
    ///
    /// ```rust,ignore
    /// let tx = Transaction::new(&pool);
    /// ```
    #[must_use = "a transaction builder does nothing until it is run"]
    pub const fn new(pool: &'p DB) -> Self {
        Self {
            pool
        }
    }

    /// Get reference to the underlying pool.
    #[must_use]
    pub const fn pool(&self) -> &'p DB {
        self.pool
    }
}
95
/// Active transaction context with repository access.
///
/// This struct holds the database transaction and provides access to
/// entity repositories via extension traits generated by the macro.
///
/// # Finishing the Transaction
///
/// - [`commit`](Self::commit) and [`rollback`](Self::rollback) end the
///   transaction explicitly and consume the context.
/// - A context created by [`Transaction::run`] that is dropped while its
///   transaction is still open hands the transaction back to `run`, which
///   commits it when the closure returned `Ok` and rolls it back otherwise.
///   Keep the context inside the closure: if it escapes (e.g. is returned or
///   stored elsewhere), `run` cannot commit it and it is rolled back when it
///   is eventually dropped.
/// - Any other context dropped with an open transaction is rolled back by
///   the underlying `sqlx::Transaction`, which rolls back on drop.
///
/// # Accessing Repositories
///
/// Each entity with `#[entity(transactions)]` generates an extension trait
/// that adds an accessor method:
///
/// ```rust,ignore
/// // For entity BankAccount, use:
/// ctx.bank_accounts().find_by_id(id).await?;
/// ctx.bank_accounts().create(dto).await?;
/// ctx.bank_accounts().update(id, dto).await?;
/// ```
#[cfg(feature = "postgres")]
pub struct TransactionContext {
    /// The open transaction. It becomes `None` only when `commit`, `rollback`
    /// or `Drop` takes it, and each of those consumes the context.
    tx: Option<sqlx::Transaction<'static, sqlx::Postgres>>,

    /// Channel back to [`Transaction::run`]; `None` for contexts built with
    /// [`TransactionContext::new`].
    release: Option<std::sync::mpsc::Sender<sqlx::Transaction<'static, sqlx::Postgres>>>
}

#[cfg(feature = "postgres")]
impl TransactionContext {
    /// Create a new transaction context.
    ///
    /// A context built this way is not tracked by [`Transaction::run`]: if it
    /// is dropped before [`commit`](Self::commit) or
    /// [`rollback`](Self::rollback), the transaction is rolled back.
    ///
    /// # Arguments
    ///
    /// * `tx` — Active database transaction
    #[doc(hidden)]
    #[must_use]
    pub const fn new(tx: sqlx::Transaction<'static, sqlx::Postgres>) -> Self {
        Self {
            tx: Some(tx),
            release: None
        }
    }

    /// Create a context that sends its still-open transaction through
    /// `release` when dropped, so [`Transaction::run`] can commit it.
    const fn with_release(
        tx: sqlx::Transaction<'static, sqlx::Postgres>,
        release: std::sync::mpsc::Sender<sqlx::Transaction<'static, sqlx::Postgres>>
    ) -> Self {
        Self {
            tx: Some(tx),
            release: Some(release)
        }
    }

    /// Get mutable reference to the underlying transaction.
    ///
    /// Use this for custom queries within the transaction or
    /// for repository adapters to execute queries.
    ///
    /// # Panics
    ///
    /// Only if the context no longer holds a transaction. That cannot happen
    /// through the public API, because every method that releases the
    /// transaction consumes the context.
    pub const fn transaction(&mut self) -> &mut sqlx::Transaction<'static, sqlx::Postgres> {
        match &mut self.tx {
            Some(tx) => tx,
            None => panic!("TransactionContext has no open transaction")
        }
    }

    /// Commit the transaction.
    ///
    /// Consumes self and commits all changes.
    ///
    /// # Errors
    ///
    /// Propagates any `sqlx::Error` from the database transaction.
    pub async fn commit(mut self) -> Result<(), sqlx::Error> {
        match self.tx.take() {
            Some(tx) => tx.commit().await,
            // Unreachable in practice (see `transaction`); nothing left to commit.
            None => Ok(())
        }
    }

    /// Rollback the transaction.
    ///
    /// Consumes self and rolls back all changes.
    ///
    /// # Errors
    ///
    /// Propagates any `sqlx::Error` from the database transaction.
    pub async fn rollback(mut self) -> Result<(), sqlx::Error> {
        match self.tx.take() {
            Some(tx) => tx.rollback().await,
            // Unreachable in practice (see `transaction`); nothing left to roll back.
            None => Ok(())
        }
    }
}

#[cfg(feature = "postgres")]
impl Drop for TransactionContext {
    /// Hand a still-open transaction back to [`Transaction::run`] if this
    /// context came from it; otherwise let the transaction drop (rollback).
    fn drop(&mut self) {
        if let (Some(tx), Some(release)) = (self.tx.take(), self.release.take()) {
            // If `run` is no longer listening, `send` returns the transaction
            // inside the error, and dropping that error rolls it back.
            let _ = release.send(tx);
        }
    }
}

/// Error type for transaction operations.
///
/// Wraps database errors and provides context about the transaction state.
#[derive(Debug)]
pub enum TransactionError<E> {
    /// Failed to begin transaction.
    Begin(E),

    /// Failed to commit transaction.
    Commit(E),

    /// Failed to rollback transaction.
    Rollback(E),

    /// Operation within transaction failed.
    Operation(E)
}

impl<E: fmt::Display> fmt::Display for TransactionError<E> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Begin(e) => write!(f, "failed to begin transaction: {e}"),
            Self::Commit(e) => write!(f, "failed to commit transaction: {e}"),
            Self::Rollback(e) => write!(f, "failed to rollback transaction: {e}"),
            Self::Operation(e) => write!(f, "transaction operation failed: {e}")
        }
    }
}

impl<E: StdError + 'static> StdError for TransactionError<E> {
    /// Every variant exposes its wrapped error as the source.
    fn source(&self) -> Option<&(dyn StdError + 'static)> {
        match self {
            Self::Begin(e) | Self::Commit(e) | Self::Rollback(e) | Self::Operation(e) => Some(e)
        }
    }
}

impl<E> TransactionError<E> {
    /// Check if this is a begin error.
    #[must_use]
    pub const fn is_begin(&self) -> bool {
        matches!(self, Self::Begin(_))
    }

    /// Check if this is a commit error.
    #[must_use]
    pub const fn is_commit(&self) -> bool {
        matches!(self, Self::Commit(_))
    }

    /// Check if this is a rollback error.
    #[must_use]
    pub const fn is_rollback(&self) -> bool {
        matches!(self, Self::Rollback(_))
    }

    /// Check if this is an operation error.
    #[must_use]
    pub const fn is_operation(&self) -> bool {
        matches!(self, Self::Operation(_))
    }

    /// Get the inner error, discarding which stage it came from.
    #[must_use]
    pub fn into_inner(self) -> E {
        match self {
            Self::Begin(e) | Self::Commit(e) | Self::Rollback(e) | Self::Operation(e) => e
        }
    }
}

#[cfg(feature = "postgres")]
impl From<TransactionError<Self>> for sqlx::Error {
    fn from(err: TransactionError<Self>) -> Self {
        err.into_inner()
    }
}

// PostgreSQL implementation
#[cfg(feature = "postgres")]
impl Transaction<'_, sqlx::PgPool> {
    /// Execute a closure within a `PostgreSQL` transaction.
    ///
    /// The transaction is committed when the closure returns `Ok` and rolled
    /// back when it returns `Err`, or when the future returned by `run` is
    /// dropped before completing. If the closure already called
    /// `ctx.commit()` or `ctx.rollback()`, `run` leaves that outcome alone.
    ///
    /// The context must stay inside the closure's future. `run` commits the
    /// transaction the context hands back when that future is dropped; a
    /// context that escapes (e.g. is returned in `T`) is not committed.
    ///
    /// # Type Parameters
    ///
    /// - `F` — Closure type
    /// - `Fut` — Future returned by closure
    /// - `T` — Success type
    /// - `E` — Error type (must be convertible from `sqlx::Error`)
    ///
    /// # Example
    ///
    /// ```rust,ignore
    /// Transaction::new(&pool)
    ///     .with_users()
    ///     .run(|mut ctx| async move {
    ///         let user = ctx.users().create(dto).await?;
    ///         Ok(user)
    ///     })
    ///     .await?;
    /// ```
    ///
    /// # Errors
    ///
    /// Returns the closure's error (the transaction is then rolled back), or
    /// a `sqlx::Error` converted into `E` if beginning or committing the
    /// transaction fails.
    pub async fn run<F, Fut, T, E>(self, f: F) -> Result<T, E>
    where
        F: FnOnce(TransactionContext) -> Fut + Send,
        Fut: Future<Output = Result<T, E>> + Send,
        E: From<sqlx::Error>
    {
        let tx = self.pool.begin().await?;
        let (release, released) = std::sync::mpsc::channel();
        let ctx = TransactionContext::with_release(tx, release);

        // `.await` drops the closure's future as soon as it completes, so a
        // transaction the context still owned is already sitting in
        // `released`. On `Err`, the early return drops `released` together
        // with that pending transaction, which rolls it back.
        let value = f(ctx).await?;

        // Bind before awaiting so no borrow of the receiver spans `.await`.
        let pending = released.try_recv().ok();
        if let Some(tx) = pending {
            tx.commit().await?;
        }
        Ok(value)
    }

    /// Execute a closure within a transaction with explicit commit.
    ///
    /// Unlike `run`, this method requires the closure to explicitly
    /// commit the transaction by calling `ctx.commit()`; a context dropped
    /// without committing rolls the transaction back.
    ///
    /// # Example
    ///
    /// ```rust,ignore
    /// Transaction::new(&pool)
    ///     .run_with_commit(|mut ctx| async move {
    ///         let user = ctx.users().create(dto).await?;
    ///         ctx.commit().await?;
    ///         Ok(user)
    ///     })
    ///     .await?;
    /// ```
    ///
    /// # Errors
    ///
    /// Propagates any error from the closure or database transaction.
    pub async fn run_with_commit<F, Fut, T, E>(self, f: F) -> Result<T, E>
    where
        F: FnOnce(TransactionContext) -> Fut + Send,
        Fut: Future<Output = Result<T, E>> + Send,
        E: From<sqlx::Error>
    {
        let tx = self.pool.begin().await?;
        f(TransactionContext::new(tx)).await
    }
}
316
#[cfg(test)]
mod tests {
    use std::error::Error;

    use super::*;

    /// Shorthand for the `io::Error` used as the wrapped error in these tests.
    fn io_err(msg: &str) -> std::io::Error {
        std::io::Error::other(msg.to_owned())
    }

    #[test]
    fn transaction_error_display_begin() {
        let err = TransactionError::Begin(io_err("test"));
        assert_eq!(err.to_string(), "failed to begin transaction: test");
    }

    #[test]
    fn transaction_error_display_commit() {
        let err = TransactionError::Commit(io_err("test"));
        assert_eq!(err.to_string(), "failed to commit transaction: test");
    }

    #[test]
    fn transaction_error_display_rollback() {
        let err = TransactionError::Rollback(io_err("test"));
        assert_eq!(err.to_string(), "failed to rollback transaction: test");
    }

    #[test]
    fn transaction_error_display_operation() {
        let err = TransactionError::Operation(io_err("test"));
        assert_eq!(err.to_string(), "transaction operation failed: test");
    }

    #[test]
    fn transaction_error_is_methods() {
        // Each row: (error, [is_begin, is_commit, is_rollback, is_operation]).
        let cases: [(TransactionError<&str>, [bool; 4]); 4] = [
            (TransactionError::Begin("e"), [true, false, false, false]),
            (TransactionError::Commit("e"), [false, true, false, false]),
            (TransactionError::Rollback("e"), [false, false, true, false]),
            (TransactionError::Operation("e"), [false, false, false, true])
        ];

        for (err, expected) in cases {
            let actual = [err.is_begin(), err.is_commit(), err.is_rollback(), err.is_operation()];
            assert_eq!(actual, expected, "{err:?}");
        }
    }

    #[test]
    fn transaction_error_into_inner_all_variants() {
        let cases: [(TransactionError<&str>, &str); 4] = [
            (TransactionError::Begin("begin"), "begin"),
            (TransactionError::Commit("commit"), "commit"),
            (TransactionError::Rollback("rollback"), "rollback"),
            (TransactionError::Operation("op"), "op")
        ];

        for (err, expected) in cases {
            assert_eq!(err.into_inner(), expected);
        }
    }

    #[test]
    fn transaction_error_source_all_variants() {
        let errors = [
            TransactionError::Begin(io_err("src")),
            TransactionError::Commit(io_err("src")),
            TransactionError::Rollback(io_err("src")),
            TransactionError::Operation(io_err("src"))
        ];

        for err in errors {
            let source = err.source().expect("every variant exposes its inner error");
            assert_eq!(source.to_string(), "src");
        }
    }

    #[test]
    fn transaction_error_debug() {
        let err: TransactionError<&str> = TransactionError::Begin("test");
        assert_eq!(format!("{err:?}"), r#"Begin("test")"#);
    }

    #[test]
    fn transaction_builder_pool_accessor() {
        struct MockPool {
            id: u32
        }
        let pool = MockPool {
            id: 42
        };
        let tx = Transaction::new(&pool);
        assert_eq!(tx.pool().id, 42);
        assert!(std::ptr::eq(tx.pool(), &pool));
    }

    #[test]
    fn transaction_builder_new_const() {
        // Evaluating in `const` items proves `new` is usable at compile time.
        const POOL: &u32 = &7;
        const TX: Transaction<'static, u32> = Transaction::new(POOL);
        assert_eq!(*TX.pool(), 7);
    }
}