entity_core/
transaction.rs

1// SPDX-FileCopyrightText: 2025-2026 RAprogramm <andrey.rozanov.vl@gmail.com>
2// SPDX-License-Identifier: MIT
3
4//! Transaction support for entity-derive.
5//!
6//! This module provides type-safe transaction management with automatic
7//! commit/rollback semantics. It uses the builder pattern for composing
8//! multiple repositories into a single transaction context.
9//!
10//! # Overview
11//!
12//! - [`Transaction`] — Entry point for creating transactions
13//! - [`TransactionContext`] — Holds active transaction and repository adapters
14//! - [`TransactionError`] — Error wrapper for transaction operations
15//!
16//! # Example
17//!
18//! ```rust,ignore
19//! use entity_derive::prelude::*;
20//!
21//! async fn transfer(pool: &PgPool, from: Uuid, to: Uuid, amount: Decimal) -> Result<(), AppError> {
22//!     Transaction::new(pool)
23//!         .with_accounts()
24//!         .with_transfers()
25//!         .run(|mut ctx| async move {
26//!             let from_acc = ctx.accounts().find_by_id(from).await?.ok_or(AppError::NotFound)?;
27//!
28//!             ctx.accounts().update(from, UpdateAccount {
29//!                 balance: Some(from_acc.balance - amount),
30//!             }).await?;
31//!
32//!             ctx.transfers().create(CreateTransfer { from, to, amount }).await?;
33//!             Ok(())
34//!         })
35//!         .await
36//! }
37//! ```
38
39use std::{error::Error as StdError, fmt, future::Future, marker::PhantomData};
40
/// Transaction builder for composing repositories.
///
/// Use [`Transaction::new`] to create a builder, chain the generated
/// `.with_*()` methods to add repositories, and finally call `.run()` to
/// execute.
///
/// Building performs no database work: the transaction is only opened when
/// `.run()` is called and awaited, so discarding a builder is almost
/// certainly a mistake — hence `#[must_use]`.
///
/// # Type Parameters
///
/// - `'p` — Lifetime of the borrowed pool
/// - `DB` — Database pool type (e.g., `PgPool`)
/// - `Repos` — Repository adapters accumulated via the builder (a tuple in
///   generated code)
#[must_use = "a transaction builder does nothing until `.run()` is called and awaited"]
pub struct Transaction<'p, DB, Repos = ()> {
    pool:   &'p DB,
    _repos: PhantomData<Repos>
}
54
55impl<'p, DB> Transaction<'p, DB, ()> {
56    /// Create a new transaction builder.
57    ///
58    /// # Arguments
59    ///
60    /// * `pool` — Database connection pool
61    ///
62    /// # Example
63    ///
64    /// ```rust,ignore
65    /// let tx = Transaction::new(&pool);
66    /// ```
67    pub const fn new(pool: &'p DB) -> Self {
68        Self {
69            pool,
70            _repos: PhantomData
71        }
72    }
73}
74
75impl<'p, DB, Repos> Transaction<'p, DB, Repos> {
76    /// Get reference to the underlying pool.
77    pub const fn pool(&self) -> &'p DB {
78        self.pool
79    }
80
81    /// Transform repository tuple type.
82    ///
83    /// Used internally by generated `with_*` methods.
84    #[doc(hidden)]
85    pub const fn with_repo<NewRepos>(self) -> Transaction<'p, DB, NewRepos> {
86        Transaction {
87            pool:   self.pool,
88            _repos: PhantomData
89        }
90    }
91}
92
/// A transaction in progress, bundled with the repository adapters that use it.
///
/// The context owns the database transaction (`tx`) by value, next to the
/// repository adapters (`repos`). Custom queries borrow the transaction via
/// `transaction()`; adapters are reached via `repos()` / `repos_mut()`.
///
/// # Automatic Rollback
///
/// The context has no `Drop` logic of its own: dropping it drops the
/// transaction, so an uncommitted transaction is rolled back exactly when the
/// transaction type does that on drop (as `sqlx::Transaction` does).
///
/// # Type Parameters
///
/// - `'t` — Transaction lifetime (kept only as a marker)
/// - `Tx` — Transaction type (e.g., `sqlx::Transaction<'t, Postgres>`)
/// - `Repos` — Repository adapters (a tuple in generated code)
pub struct TransactionContext<'t, Tx, Repos> {
    tx:        Tx,
    repos:     Repos,
    _lifetime: PhantomData<&'t ()>
}
113
114impl<'t, Tx, Repos> TransactionContext<'t, Tx, Repos> {
115    /// Create a new transaction context.
116    ///
117    /// # Arguments
118    ///
119    /// * `tx` — Active database transaction
120    /// * `repos` — Repository adapters tuple
121    #[doc(hidden)]
122    pub const fn new(tx: Tx, repos: Repos) -> Self {
123        Self {
124            tx,
125            repos,
126            _lifetime: PhantomData
127        }
128    }
129
130    /// Get mutable reference to the underlying transaction.
131    ///
132    /// Use this for custom queries within the transaction.
133    pub fn transaction(&mut self) -> &mut Tx {
134        &mut self.tx
135    }
136
137    /// Get reference to repository adapters.
138    pub const fn repos(&self) -> &Repos {
139        &self.repos
140    }
141
142    /// Get mutable reference to repository adapters.
143    pub fn repos_mut(&mut self) -> &mut Repos {
144        &mut self.repos
145    }
146}
147
/// Failure raised while driving a transaction.
///
/// Each variant records which stage failed and carries the wrapped error `E`
/// (the database driver's error type in the shipped implementations).
#[derive(Debug)]
pub enum TransactionError<E> {
    /// Starting the transaction failed.
    Begin(E),

    /// Committing the transaction failed.
    Commit(E),

    /// Rolling the transaction back failed.
    Rollback(E),

    /// An operation executed inside the transaction failed.
    Operation(E)
}
165
166impl<E: fmt::Display> fmt::Display for TransactionError<E> {
167    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
168        match self {
169            Self::Begin(e) => write!(f, "failed to begin transaction: {e}"),
170            Self::Commit(e) => write!(f, "failed to commit transaction: {e}"),
171            Self::Rollback(e) => write!(f, "failed to rollback transaction: {e}"),
172            Self::Operation(e) => write!(f, "transaction operation failed: {e}")
173        }
174    }
175}
176
177impl<E: StdError + 'static> StdError for TransactionError<E> {
178    fn source(&self) -> Option<&(dyn StdError + 'static)> {
179        match self {
180            Self::Begin(e) | Self::Commit(e) | Self::Rollback(e) | Self::Operation(e) => Some(e)
181        }
182    }
183}
184
185impl<E> TransactionError<E> {
186    /// Check if this is a begin error.
187    pub const fn is_begin(&self) -> bool {
188        matches!(self, Self::Begin(_))
189    }
190
191    /// Check if this is a commit error.
192    pub const fn is_commit(&self) -> bool {
193        matches!(self, Self::Commit(_))
194    }
195
196    /// Check if this is a rollback error.
197    pub const fn is_rollback(&self) -> bool {
198        matches!(self, Self::Rollback(_))
199    }
200
201    /// Check if this is an operation error.
202    pub const fn is_operation(&self) -> bool {
203        matches!(self, Self::Operation(_))
204    }
205
206    /// Get the inner error.
207    pub fn into_inner(self) -> E {
208        match self {
209            Self::Begin(e) | Self::Commit(e) | Self::Rollback(e) | Self::Operation(e) => e
210        }
211    }
212}
213
/// Trait for types that can begin a transaction.
///
/// Implemented for database pools to enable transaction creation.
// `async_fn_in_trait` is silenced here. Consequence: generic callers cannot
// require the future returned by `begin` to be `Send`.
#[allow(async_fn_in_trait)]
pub trait Transactional: Sized + Send + Sync {
    /// Error reported when a transaction cannot be started.
    type Error: StdError + Send + Sync;

    /// Transaction handle returned by [`begin`](Self::begin); it may borrow
    /// from the pool for `'t`.
    type Transaction<'t>: Send
    where
        Self: 't;

    /// Begin a new transaction on `self`.
    async fn begin(&self) -> Result<Self::Transaction<'_>, Self::Error>;
}
230
/// Trait for transaction types that can be committed or rolled back.
///
/// Both operations take the transaction by value, so a given transaction can
/// be finished at most once.
// `async_fn_in_trait` is silenced here. Consequence: generic callers cannot
// require the futures returned by `commit` / `rollback` to be `Send`.
#[allow(async_fn_in_trait)]
pub trait TransactionOps: Sized + Send {
    /// Error reported when finishing the transaction fails.
    type Error: StdError + Send + Sync;

    /// Commit the transaction, consuming it.
    async fn commit(self) -> Result<(), Self::Error>;

    /// Roll the transaction back, consuming it.
    async fn rollback(self) -> Result<(), Self::Error>;
}
243
244/// Trait for executing operations within a transaction.
245///
246/// This trait is implemented on [`Transaction`] with specific repository
247/// combinations, enabling type-safe execution.
248#[allow(async_fn_in_trait)]
249pub trait TransactionRunner<'p, Repos>: Sized {
250    /// Transaction type.
251    type Tx: TransactionOps;
252
253    /// Database error type.
254    type DbError: StdError + Send + Sync;
255
256    /// Execute a closure within the transaction.
257    ///
258    /// Automatically commits on `Ok`, rolls back on `Err` or panic.
259    ///
260    /// # Type Parameters
261    ///
262    /// - `F` — Closure type
263    /// - `Fut` — Future returned by closure
264    /// - `T` — Success type
265    /// - `E` — Error type (must be convertible from database error)
266    async fn run<F, Fut, T, E>(self, f: F) -> Result<T, E>
267    where
268        F: FnOnce(TransactionContext<'_, Self::Tx, Repos>) -> Fut + Send,
269        Fut: Future<Output = Result<T, E>> + Send,
270        E: From<TransactionError<Self::DbError>>;
271}
272
273// sqlx implementations (requires database for testing)
274// LCOV_EXCL_START
#[cfg(feature = "postgres")]
mod postgres_impl {
    use sqlx::{PgPool, Postgres};

    use super::*;

    impl Transactional for PgPool {
        type Transaction<'t> = sqlx::Transaction<'t, Postgres>;
        type Error = sqlx::Error;

        async fn begin(&self) -> Result<Self::Transaction<'_>, Self::Error> {
            sqlx::pool::Pool::begin(self).await
        }
    }

    impl TransactionOps for sqlx::Transaction<'_, Postgres> {
        type Error = sqlx::Error;

        async fn commit(self) -> Result<(), Self::Error> {
            sqlx::Transaction::commit(self).await
        }

        async fn rollback(self) -> Result<(), Self::Error> {
            sqlx::Transaction::rollback(self).await
        }
    }

    impl<'p, Repos: Send> Transaction<'p, PgPool, Repos> {
        /// Execute a closure within a PostgreSQL transaction.
        ///
        /// Opens a transaction on the pool, builds the repository adapters
        /// with `Repos::default()`, and passes both to `f` as a
        /// [`TransactionContext`].
        ///
        /// # Commit and rollback
        ///
        /// `f` takes the context — and with it the only handle to the
        /// transaction — by value, so `run` itself never commits: it returns
        /// `f`'s result unchanged. A transaction still owned by the context
        /// when the context is dropped is rolled back by sqlx.
        ///
        /// # Errors
        ///
        /// Returns `E::from(TransactionError::Begin(_))` if the transaction
        /// cannot be started; otherwise whatever `f` returns.
        ///
        /// # Example
        ///
        /// ```rust,ignore
        /// Transaction::new(&pool)
        ///     .with_users()
        ///     .run(|mut ctx| async move {
        ///         ctx.users().create(dto).await
        ///     })
        ///     .await?;
        /// ```
        pub async fn run<F, Fut, T, E>(self, f: F) -> Result<T, E>
        where
            // `Pool::begin` yields a `'static` transaction, so the context is
            // `'static` too. A higher-ranked `for<'t>` bound would force `Fut`
            // to be independent of `'t`, ruling out futures that capture the
            // context (the `|ctx| async move { ... }` pattern). Any closure
            // accepted by the old higher-ranked bound still satisfies this one.
            F: FnOnce(
                    TransactionContext<'static, sqlx::Transaction<'static, Postgres>, Repos>
                ) -> Fut
                + Send,
            Fut: Future<Output = Result<T, E>> + Send,
            E: From<TransactionError<sqlx::Error>>,
            Repos: Default
        {
            // Call the inherent sqlx method explicitly so the `'static`
            // transaction type is not confused with `Transactional::begin`.
            let tx = sqlx::pool::Pool::begin(self.pool)
                .await
                .map_err(TransactionError::Begin)?;
            let ctx = TransactionContext::new(tx, Repos::default());

            f(ctx).await
        }
    }
}
337// LCOV_EXCL_STOP
338
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn transaction_error_display_begin() {
        let rendered = TransactionError::Begin(std::io::Error::other("test")).to_string();
        assert!(rendered.contains("begin"));
        assert!(rendered.contains("test"));
    }

    #[test]
    fn transaction_error_display_commit() {
        let rendered = TransactionError::Commit(std::io::Error::other("commit_err")).to_string();
        assert!(rendered.contains("commit"));
        assert!(rendered.contains("commit_err"));
    }

    #[test]
    fn transaction_error_display_rollback() {
        let rendered =
            TransactionError::Rollback(std::io::Error::other("rollback_err")).to_string();
        assert!(rendered.contains("rollback"));
        assert!(rendered.contains("rollback_err"));
    }

    #[test]
    fn transaction_error_display_operation() {
        let rendered = TransactionError::Operation(std::io::Error::other("op_err")).to_string();
        assert!(rendered.contains("operation"));
        assert!(rendered.contains("op_err"));
    }

    #[test]
    fn transaction_error_is_methods() {
        let begin = TransactionError::Begin("e");
        assert!(begin.is_begin());
        assert!(!begin.is_commit());
        assert!(!begin.is_rollback());
        assert!(!begin.is_operation());

        let commit = TransactionError::Commit("e");
        assert!(commit.is_commit());
        assert!(!commit.is_begin());

        let rollback = TransactionError::Rollback("e");
        assert!(rollback.is_rollback());
        assert!(!rollback.is_begin());

        let operation = TransactionError::Operation("e");
        assert!(operation.is_operation());
        assert!(!operation.is_begin());
    }

    #[test]
    fn transaction_error_into_inner() {
        assert_eq!(TransactionError::Operation("inner").into_inner(), "inner");
    }

    #[test]
    fn transaction_error_into_inner_all_variants() {
        let cases = [
            (TransactionError::Begin("b"), "b"),
            (TransactionError::Commit("c"), "c"),
            (TransactionError::Rollback("r"), "r"),
            (TransactionError::Operation("o"), "o")
        ];
        for (err, expected) in cases {
            assert_eq!(err.into_inner(), expected);
        }
    }

    #[test]
    fn transaction_error_source() {
        let errors: [TransactionError<std::io::Error>; 4] = [
            TransactionError::Begin(std::io::Error::other("source_err")),
            TransactionError::Commit(std::io::Error::other("c")),
            TransactionError::Rollback(std::io::Error::other("r")),
            TransactionError::Operation(std::io::Error::other("o"))
        ];
        for err in &errors {
            assert!(err.source().is_some());
        }
    }

    #[test]
    fn transaction_builder_new() {
        struct MockPool(u8);
        let pool = MockPool(7);
        let builder: Transaction<'_, MockPool, ()> = Transaction::new(&pool);
        assert_eq!(builder.pool().0, 7);
    }

    #[test]
    fn transaction_builder_with_repo() {
        struct MockPool(u8);
        let pool = MockPool(9);
        let builder: Transaction<'_, MockPool, ()> = Transaction::new(&pool);
        let widened: Transaction<'_, MockPool, i32> = builder.with_repo();
        assert_eq!(widened.pool().0, 9);
    }

    #[test]
    fn transaction_context_new() {
        let ctx = TransactionContext::new("mock_tx", (1, 2, 3));
        assert_eq!(ctx.repos(), &(1, 2, 3));
    }

    #[test]
    fn transaction_context_transaction() {
        let mut ctx = TransactionContext::new(String::from("mock_tx"), ());
        assert_eq!(ctx.transaction(), "mock_tx");
    }

    #[test]
    fn transaction_context_repos_mut() {
        let mut ctx = TransactionContext::new("mock_tx", vec![1, 2, 3]);
        ctx.repos_mut().push(4);
        assert_eq!(ctx.repos().as_slice(), &[1, 2, 3, 4]);
    }
}