entity_core/
transaction.rs

// SPDX-FileCopyrightText: 2025-2026 RAprogramm <andrey.rozanov.vl@gmail.com>
// SPDX-License-Identifier: MIT

//! Transaction support for entity-derive.
//!
//! This module provides type-safe transaction management built on a fluent
//! builder pattern for composing multiple entity operations into a single
//! transaction. A transaction that is dropped without being committed is
//! rolled back automatically; committing is always an explicit
//! `ctx.commit()` call.
//!
//! # Overview
//!
//! - [`Transaction`] — Entry point for creating transactions
//! - [`TransactionContext`] — Holds the active transaction and provides repository access
//! - [`TransactionError`] — Error wrapper for transaction operations
//!
//! # Example
//!
//! The closure owns the transaction context, so it commits explicitly before
//! returning `Ok`; returning early with an error drops the context and rolls
//! the transaction back.
//!
//! ```rust,ignore
//! use entity_derive::prelude::*;
//!
//! async fn transfer(pool: &PgPool, from: Uuid, to: Uuid, amount: i64) -> Result<(), AppError> {
//!     Transaction::new(pool)
//!         .with_accounts()
//!         .with_transfers()
//!         .run(|mut ctx| async move {
//!             let from_acc = ctx.accounts().find_by_id(from).await?.ok_or(AppError::NotFound)?;
//!
//!             ctx.accounts().update(from, UpdateAccount {
//!                 balance: Some(from_acc.balance - amount),
//!                 ..Default::default()
//!             }).await?;
//!
//!             ctx.transfers().create(CreateTransfer { from, to, amount }).await?;
//!             ctx.commit().await?;
//!             Ok(())
//!         })
//!         .await
//! }
//! ```

40#[cfg(feature = "postgres")]
41use std::future::Future;
42use std::{error::Error as StdError, fmt};
43
/// Transaction builder for composing multi-entity operations.
///
/// Use [`Transaction::new`] to create a builder, chain `.with_*()` methods
/// to declare which entities you'll use, then call `.run()` to execute.
///
/// # Committing
///
/// The closure passed to `.run()` receives the `TransactionContext` by value.
/// Its changes are persisted only if it calls `ctx.commit()`. A context that
/// is dropped without being committed is rolled back.
///
/// # Type Parameters
///
/// - `'p` — Pool lifetime
/// - `DB` — Database pool type (e.g., `PgPool`)
///
/// # Example
///
/// ```rust,ignore
/// Transaction::new(&pool)
///     .with_users()
///     .with_orders()
///     .run(|mut ctx| async move {
///         let user = ctx.users().find_by_id(id).await?;
///         ctx.orders().create(order).await?;
///         ctx.commit().await?;
///         Ok(())
///     })
///     .await?;
/// ```
pub struct Transaction<'p, DB> {
    pool: &'p DB
}

71impl<'p, DB> Transaction<'p, DB> {
72    /// Create a new transaction builder.
73    ///
74    /// # Arguments
75    ///
76    /// * `pool` — Database connection pool
77    ///
78    /// # Example
79    ///
80    /// ```rust,ignore
81    /// let tx = Transaction::new(&pool);
82    /// ```
83    pub const fn new(pool: &'p DB) -> Self {
84        Self {
85            pool
86        }
87    }
88
89    /// Get reference to the underlying pool.
90    pub const fn pool(&self) -> &'p DB {
91        self.pool
92    }
93}
94
/// Active PostgreSQL transaction together with repository access.
///
/// Owns the open database transaction. The extension traits generated by the
/// macro add repository accessor methods to this type.
///
/// # Automatic Rollback
///
/// A context that is dropped without an explicit `commit` is rolled back by
/// the underlying database transaction's `Drop` impl.
///
/// # Accessing Repositories
///
/// Every entity annotated with `#[entity(transactions)]` generates an
/// extension trait that contributes one accessor method:
///
/// ```rust,ignore
/// // For entity BankAccount, use:
/// ctx.bank_accounts().find_by_id(id).await?;
/// ctx.bank_accounts().create(dto).await?;
/// ctx.bank_accounts().update(id, dto).await?;
/// ```
#[cfg(feature = "postgres")]
pub struct TransactionContext {
    tx: sqlx::Transaction<'static, sqlx::Postgres>
}

#[cfg(feature = "postgres")]
impl TransactionContext {
    /// Wraps an already-open database transaction in a context.
    ///
    /// `Transaction::run` and `Transaction::run_with_commit` use this to build
    /// the context they hand to the user's closure.
    ///
    /// # Arguments
    ///
    /// * `tx` — Active database transaction
    #[doc(hidden)]
    pub fn new(tx: sqlx::Transaction<'static, sqlx::Postgres>) -> Self {
        TransactionContext { tx }
    }

    /// Mutable access to the underlying `sqlx` transaction.
    ///
    /// Use it to run custom queries inside the same transaction; it is also the
    /// handle repository adapters are meant to execute their queries through.
    pub fn transaction(&mut self) -> &mut sqlx::Transaction<'static, sqlx::Postgres> {
        let Self { tx } = self;
        tx
    }

    /// Commits every change made through this context.
    ///
    /// Takes `self` by value, so the context cannot be used afterwards.
    ///
    /// # Errors
    ///
    /// Returns the `sqlx::Error` reported by the database if the commit fails.
    pub async fn commit(self) -> Result<(), sqlx::Error> {
        let Self { tx } = self;
        tx.commit().await
    }

    /// Discards every change made through this context.
    ///
    /// Takes `self` by value, so the context cannot be used afterwards.
    ///
    /// # Errors
    ///
    /// Returns the `sqlx::Error` reported by the database if the rollback fails.
    pub async fn rollback(self) -> Result<(), sqlx::Error> {
        let Self { tx } = self;
        tx.rollback().await
    }
}

/// Error returned by transaction operations.
///
/// Each variant records the phase that failed and carries the underlying
/// error `E` unchanged.
#[derive(Debug)]
pub enum TransactionError<E> {
    /// Starting the transaction failed.
    Begin(E),

    /// Committing the transaction failed.
    Commit(E),

    /// Rolling the transaction back failed.
    Rollback(E),

    /// An operation executed inside the transaction failed.
    Operation(E)
}

176impl<E: fmt::Display> fmt::Display for TransactionError<E> {
177    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
178        match self {
179            Self::Begin(e) => write!(f, "failed to begin transaction: {e}"),
180            Self::Commit(e) => write!(f, "failed to commit transaction: {e}"),
181            Self::Rollback(e) => write!(f, "failed to rollback transaction: {e}"),
182            Self::Operation(e) => write!(f, "transaction operation failed: {e}")
183        }
184    }
185}
186
187impl<E: StdError + 'static> StdError for TransactionError<E> {
188    fn source(&self) -> Option<&(dyn StdError + 'static)> {
189        match self {
190            Self::Begin(e) | Self::Commit(e) | Self::Rollback(e) | Self::Operation(e) => Some(e)
191        }
192    }
193}
194
195impl<E> TransactionError<E> {
196    /// Check if this is a begin error.
197    pub const fn is_begin(&self) -> bool {
198        matches!(self, Self::Begin(_))
199    }
200
201    /// Check if this is a commit error.
202    pub const fn is_commit(&self) -> bool {
203        matches!(self, Self::Commit(_))
204    }
205
206    /// Check if this is a rollback error.
207    pub const fn is_rollback(&self) -> bool {
208        matches!(self, Self::Rollback(_))
209    }
210
211    /// Check if this is an operation error.
212    pub const fn is_operation(&self) -> bool {
213        matches!(self, Self::Operation(_))
214    }
215
216    /// Get the inner error.
217    pub fn into_inner(self) -> E {
218        match self {
219            Self::Begin(e) | Self::Commit(e) | Self::Rollback(e) | Self::Operation(e) => e
220        }
221    }
222}
223
/// Lets `?` turn a `TransactionError<sqlx::Error>` back into a plain
/// `sqlx::Error`. Which phase failed (begin/commit/rollback/operation) is
/// not preserved by this conversion.
#[cfg(feature = "postgres")]
impl From<TransactionError<sqlx::Error>> for sqlx::Error {
    fn from(err: TransactionError<sqlx::Error>) -> Self {
        TransactionError::into_inner(err)
    }
}

// PostgreSQL implementation
#[cfg(feature = "postgres")]
impl<'p> Transaction<'p, sqlx::PgPool> {
    /// Execute a closure within a PostgreSQL transaction.
    ///
    /// Begins a transaction on the pool and wraps it in a [`TransactionContext`].
    /// The context is handed to `f` by value, and this method returns whatever
    /// `f` resolves to.
    ///
    /// # Committing
    ///
    /// This method does **not** commit on the closure's behalf. The context is
    /// moved into `f`, so once `f` finishes there is nothing left here to
    /// commit. Call `ctx.commit().await?` before returning `Ok`.
    ///
    /// A context dropped without being committed is rolled back by the
    /// underlying transaction's `Drop` impl. This applies whether `f` returned
    /// `Ok` or `Err`.
    ///
    /// # Type Parameters
    ///
    /// - `F` — Closure type
    /// - `Fut` — Future returned by closure
    /// - `T` — Success type
    /// - `E` — Error type (must be convertible from `sqlx::Error`)
    ///
    /// # Errors
    ///
    /// Returns `E::from(sqlx::Error)` if the transaction cannot be started.
    /// Otherwise it returns the closure's error unchanged.
    ///
    /// # Example
    ///
    /// ```rust,ignore
    /// Transaction::new(&pool)
    ///     .with_users()
    ///     .run(|mut ctx| async move {
    ///         let user = ctx.users().create(dto).await?;
    ///         ctx.commit().await?;
    ///         Ok(user)
    ///     })
    ///     .await?;
    /// ```
    pub async fn run<F, Fut, T, E>(self, f: F) -> Result<T, E>
    where
        F: FnOnce(TransactionContext) -> Fut + Send,
        Fut: Future<Output = Result<T, E>> + Send,
        E: From<sqlx::Error>
    {
        let tx = self.pool.begin().await.map_err(E::from)?;
        // `f` takes ownership of the context, so its result is returned as-is:
        // whether anything is persisted is decided inside `f` via `ctx.commit()`.
        // NOTE(review): committing automatically on `Ok` would require `f` to
        // borrow the context (e.g. `FnOnce(&mut TransactionContext)`) instead
        // of owning it. That is a breaking API change, so it is not done here.
        f(TransactionContext::new(tx)).await
    }

    /// Execute a closure within a transaction that the closure commits itself.
    ///
    /// This behaves exactly like [`run`](Self::run). The closure receives the
    /// context by value and must call `ctx.commit()` for its changes to
    /// persist. Otherwise the transaction is rolled back when the context is
    /// dropped.
    ///
    /// # Errors
    ///
    /// Returns `E::from(sqlx::Error)` if the transaction cannot be started.
    /// Otherwise it returns the closure's error unchanged.
    ///
    /// # Example
    ///
    /// ```rust,ignore
    /// Transaction::new(&pool)
    ///     .run_with_commit(|mut ctx| async move {
    ///         let user = ctx.users().create(dto).await?;
    ///         ctx.commit().await?;
    ///         Ok(user)
    ///     })
    ///     .await?;
    /// ```
    pub async fn run_with_commit<F, Fut, T, E>(self, f: F) -> Result<T, E>
    where
        F: FnOnce(TransactionContext) -> Fut + Send,
        Fut: Future<Output = Result<T, E>> + Send,
        E: From<sqlx::Error>
    {
        let tx = self.pool.begin().await.map_err(E::from)?;
        let ctx = TransactionContext::new(tx);
        f(ctx).await
    }
}

#[cfg(test)]
mod tests {
    use std::error::Error;

    use super::*;

    /// Asserts that `err.source()` is the wrapped `io::Error` carrying "src".
    fn assert_source_is_inner(err: &TransactionError<std::io::Error>) {
        let source = err
            .source()
            .expect("every variant exposes its inner error as the source");
        assert_eq!(source.to_string(), "src");
        assert!(source.downcast_ref::<std::io::Error>().is_some());
    }

    #[test]
    fn transaction_error_display_begin() {
        let err: TransactionError<std::io::Error> =
            TransactionError::Begin(std::io::Error::other("test"));
        assert_eq!(err.to_string(), "failed to begin transaction: test");
    }

    #[test]
    fn transaction_error_display_commit() {
        let err: TransactionError<std::io::Error> =
            TransactionError::Commit(std::io::Error::other("test"));
        assert_eq!(err.to_string(), "failed to commit transaction: test");
    }

    #[test]
    fn transaction_error_display_rollback() {
        let err: TransactionError<std::io::Error> =
            TransactionError::Rollback(std::io::Error::other("test"));
        assert_eq!(err.to_string(), "failed to rollback transaction: test");
    }

    #[test]
    fn transaction_error_display_operation() {
        let err: TransactionError<std::io::Error> =
            TransactionError::Operation(std::io::Error::other("test"));
        assert_eq!(err.to_string(), "transaction operation failed: test");
    }

    #[test]
    fn transaction_error_is_methods() {
        let begin: TransactionError<&str> = TransactionError::Begin("e");
        let commit: TransactionError<&str> = TransactionError::Commit("e");
        let rollback: TransactionError<&str> = TransactionError::Rollback("e");
        let operation: TransactionError<&str> = TransactionError::Operation("e");

        assert!(begin.is_begin());
        assert!(!begin.is_commit());
        assert!(!begin.is_rollback());
        assert!(!begin.is_operation());

        assert!(!commit.is_begin());
        assert!(commit.is_commit());
        assert!(!commit.is_rollback());
        assert!(!commit.is_operation());

        assert!(!rollback.is_begin());
        assert!(!rollback.is_commit());
        assert!(rollback.is_rollback());
        assert!(!rollback.is_operation());

        assert!(!operation.is_begin());
        assert!(!operation.is_commit());
        assert!(!operation.is_rollback());
        assert!(operation.is_operation());
    }

    #[test]
    fn transaction_error_into_inner() {
        let err: TransactionError<&str> = TransactionError::Operation("test");
        assert_eq!(err.into_inner(), "test");
    }

    #[test]
    fn transaction_error_into_inner_begin() {
        let err: TransactionError<&str> = TransactionError::Begin("begin_err");
        assert_eq!(err.into_inner(), "begin_err");
    }

    #[test]
    fn transaction_error_into_inner_commit() {
        let err: TransactionError<&str> = TransactionError::Commit("commit_err");
        assert_eq!(err.into_inner(), "commit_err");
    }

    #[test]
    fn transaction_error_into_inner_rollback() {
        let err: TransactionError<&str> = TransactionError::Rollback("rollback_err");
        assert_eq!(err.into_inner(), "rollback_err");
    }

    #[test]
    fn transaction_error_source_begin() {
        assert_source_is_inner(&TransactionError::Begin(std::io::Error::other("src")));
    }

    #[test]
    fn transaction_error_source_commit() {
        assert_source_is_inner(&TransactionError::Commit(std::io::Error::other("src")));
    }

    #[test]
    fn transaction_error_source_rollback() {
        assert_source_is_inner(&TransactionError::Rollback(std::io::Error::other("src")));
    }

    #[test]
    fn transaction_error_source_operation() {
        assert_source_is_inner(&TransactionError::Operation(std::io::Error::other("src")));
    }

    #[cfg(feature = "postgres")]
    #[test]
    fn transaction_error_into_sqlx_error_keeps_inner() {
        let err: sqlx::Error = TransactionError::Commit(sqlx::Error::PoolTimedOut).into();
        assert!(matches!(err, sqlx::Error::PoolTimedOut));
    }

    #[test]
    fn transaction_builder_new() {
        struct MockPool;
        let pool = MockPool;
        let tx = Transaction::new(&pool);
        let _ = tx.pool();
    }

    #[test]
    fn transaction_builder_pool_accessor() {
        struct MockPool {
            id: u32
        }
        let pool = MockPool {
            id: 42
        };
        let tx = Transaction::new(&pool);
        assert_eq!(tx.pool().id, 42);
    }

    #[test]
    fn transaction_error_debug() {
        let err: TransactionError<&str> = TransactionError::Begin("test");
        assert_eq!(format!("{err:?}"), "Begin(\"test\")");
    }
}