Skip to main content

rorm_db/transaction/
mod.rs

1//! This module holds the definition of transactions
2
3use std::error::Error as StdError;
4use std::fmt;
5use std::future::Future;
6
7use tracing::debug;
8
9use crate::internal::any::AnyTransaction;
10pub use crate::transaction::hook::TransactionHook;
11use crate::transaction::hook_closure::{ClosureHook, OnRollback, PostCommit, PreCommit};
12use crate::transaction::hook_storage::HookStorage;
13use crate::Error;
14
15mod hook;
16mod hook_closure;
17mod hook_storage;
18
/// A database transaction.
///
/// Transactions provide a safe way to execute multiple SQL operations one after
/// another, with the option of going back to the starting point so that none of
/// them leave a change in the database.
///
/// Can be obtained using [`Database::start_transaction`](crate::Database::start_transaction).
#[must_use = "A transaction needs to be committed."]
pub struct Transaction {
    /// The underlying driver-level transaction.
    pub(crate) sqlx: AnyTransaction,
    /// Hooks registered through [`Transaction::hooks`] or [`Transaction::adv_hooks`].
    ///
    /// `None` until one of the hook APIs is first accessed.
    /// [`Transaction::commit`] takes it out before running the hooks.
    hooks: Option<HookStorage>,
}
29
30impl Transaction {
31    pub(crate) fn new(sqlx: AnyTransaction) -> Self {
32        Self { sqlx, hooks: None }
33    }
34
35    /// This function commits the transaction.
36    pub async fn commit(mut self) -> Result<(), TransactionError> {
37        let mut hooks = self.hooks.take();
38
39        if let Some(hooks) = hooks.as_mut() {
40            hooks.pre_commit(&mut self).await?;
41
42            if let Some(invalid_hooks) = self.hooks.as_mut() {
43                debug!("Some transaction hook added additional hooks during pre-commit. This is not supported and will be ignored.");
44
45                // Prevent `Drop` impl from calling `on_rollback`.
46                invalid_hooks.clear();
47            }
48        }
49
50        let result = self.sqlx.commit().await;
51
52        if let Some(hooks) = hooks.as_mut() {
53            if result.is_ok() {
54                hooks.post_commit();
55
56                // Prevent `Drop` impl from calling `on_rollback`.
57                hooks.clear();
58            }
59        }
60
61        result
62            .map_err(Error::SqlxError)
63            .map_err(TransactionError::Database)
64    }
65
66    /// Use this function to abort the transaction.
67    pub async fn rollback(self) -> Result<(), Error> {
68        self.sqlx.rollback().await.map_err(Error::SqlxError)
69    }
70}
71
72// This impl should be on `Transaction` itself.
73// However, the `sqlx` field has to be consumed by ownership
74// which prevents `Transaction` from implementing `Drop`.
75impl Drop for HookStorage {
76    fn drop(&mut self) {
77        // `Transaction::commit` will clear all hooks, so this call would become a no-op.
78        self.on_rollback();
79    }
80}
81
82impl Transaction {
83    /// Accesses the simple API for adding hooks to the transaction
84    ///
85    /// If you reach the API's limits, consider [`Transaction::adv_hooks`].
86    pub fn hooks(&mut self) -> SimpleHooksApi<'_> {
87        SimpleHooksApi(self.hooks.get_or_insert_default())
88    }
89
90    /// Accesses the advanced API for adding hooks to the transaction
91    ///
92    /// If you're new to transaction hooks, consider [`Transaction::hooks`].
93    pub fn adv_hooks(&mut self) -> AdvancedHooksApi<'_> {
94        AdvancedHooksApi(self.hooks.get_or_insert_default())
95    }
96}
97
98/// Simple API for adding hooks to [`Transaction`]s
99///
100/// A hook is a closure which is called before or after a transaction has been commited.
101pub struct SimpleHooksApi<'a>(&'a mut HookStorage);
102impl SimpleHooksApi<'_> {
103    /// Adds an async closure which is run before the transaction is commited.
104    ///
105    /// Note, the transaction could still fail due to a database error or a hook error.
106    pub fn pre_commit<F>(&mut self, hook: impl FnOnce() -> F + Send + 'static) -> &mut Self
107    where
108        F: Future<Output = Result<(), TransactionError>> + Send,
109    {
110        self.0
111            .get_or_insert()
112            .push(ClosureHook::new(hook, PreCommit));
113        self
114    }
115
116    /// Adds a closure which is run before the transaction has been commited.
117    pub fn post_commit(&mut self, hook: impl FnOnce() + Send + 'static) -> &mut Self {
118        self.0
119            .get_or_insert()
120            .push(ClosureHook::new(hook, PostCommit));
121        self
122    }
123
124    /// Adds a closure which is run when the transaction is rolled back.
125    ///
126    /// It MAY be called before, during or after the actual database operation.
127    pub fn on_rollback(&mut self, hook: impl FnOnce() + Send + 'static) -> &mut Self {
128        self.0
129            .get_or_insert()
130            .push(ClosureHook::new(hook, OnRollback));
131        self
132    }
133}
134
135/// Advanced API for adding hooks to [`Transaction`]s
136///
137/// A [`TransactionHook`] is a type which is called before and after a transaction has been finished.
138///
139/// A `Transaction` can store many instances of many `TransactionHook` types.
140///
141/// This API provides convenience methods for two common patters:
142/// - [`push`](Self::push) for adding many instances (potentially of the same type)
143/// - [`get_or_insert_default`](Self::get_or_insert_default) and [`get_or_insert_with`](Self::get_or_insert_with)
144///   when you only want a single instance of your hook type but want to extend it several times.
145///
146/// If these APIs are not flexible enough, you can use [`get_all`](Self::get_all) to access the raw
147/// storage of `TransactionHook`s of a single type.
148pub struct AdvancedHooksApi<'a>(&'a mut HookStorage);
149impl AdvancedHooksApi<'_> {
150    /// Adds a hook which is called if the transaction has been finished.
151    pub fn push<T: TransactionHook>(&mut self, hook: T) {
152        self.get_all().push(hook);
153    }
154
155    /// Gets the hook of type `T`.
156    ///
157    /// Adds its [`Default`] value if no value has been added yet.
158    pub fn get_or_insert_default<T: TransactionHook + Default>(&mut self) -> &mut T {
159        self.get_or_insert_with(T::default)
160    }
161
162    /// Gets the hook of type `T`.
163    ///
164    /// Calls `init` to add a value if no value has been added yet.
165    pub fn get_or_insert_with<T: TransactionHook>(&mut self, init: impl FnOnce() -> T) -> &mut T {
166        let vec = self.get_all();
167        if vec.is_empty() {
168            vec.push(init());
169        }
170        &mut vec[0]
171    }
172
173    /// Gets all hooks of type `T`.
174    pub fn get_all<T: TransactionHook>(&mut self) -> &mut Vec<T> {
175        self.0.get_or_insert()
176    }
177}
178
/// Error for committing a [`Transaction`]
///
/// Returned by [`Transaction::commit`] and [`TransactionGuard::commit`].
#[derive(Debug)]
pub enum TransactionError {
    /// Error returned by the database
    Database(Error),

    /// Arbitrary error returned by a hook
    Hook(HookError),
}
/// Arbitrary error returned by a hook
///
/// Boxed so that hooks can report any error type. The `Send + Sync` bounds
/// allow the error to be moved and shared across threads.
pub type HookError = Box<dyn StdError + Send + Sync>;
190
191impl From<Error> for TransactionError {
192    fn from(value: Error) -> Self {
193        Self::Database(value)
194    }
195}
196impl From<HookError> for TransactionError {
197    fn from(value: HookError) -> Self {
198        Self::Hook(value)
199    }
200}
201impl fmt::Display for TransactionError {
202    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
203        match self {
204            TransactionError::Database(x) => fmt::Display::fmt(x, f),
205            TransactionError::Hook(x) => fmt::Display::fmt(x, f),
206        }
207    }
208}
209impl StdError for TransactionError {
210    fn source(&self) -> Option<&(dyn StdError + 'static)> {
211        match self {
212            TransactionError::Database(x) => Some(x),
213            TransactionError::Hook(x) => Some(x.as_ref()),
214        }
215    }
216}
217
/// Either an owned or borrowed [`Transaction`].
///
/// Used for "guarding" a piece of code which has to be run in a transaction
/// (see [`Executor::ensure_transaction`](crate::executor::Executor::ensure_transaction)).
/// The code works on [`TransactionGuard::get_transaction`] and finishes with
/// [`TransactionGuard::commit`], which only commits a transaction that is owned.
#[must_use = "The potentially owned transaction needs to be committed."]
pub enum TransactionGuard<'tr> {
    /// An owned transaction, committed by [`TransactionGuard::commit`]
    Owned(Transaction),

    /// A borrowed transaction, left for its owner to commit
    Borrowed(&'tr mut Transaction),
}
230
231impl TransactionGuard<'_> {
232    /// Get a reference to the guarded transaction
233    pub fn get_transaction(&mut self) -> &mut Transaction {
234        match self {
235            TransactionGuard::Owned(tr) => tr,
236            TransactionGuard::Borrowed(tr) => tr,
237        }
238    }
239
240    /// Consume the guard, committing the potentially owned transaction.
241    pub async fn commit(self) -> Result<(), TransactionError> {
242        if let TransactionGuard::Owned(tr) = self {
243            tr.commit().await
244        } else {
245            Ok(())
246        }
247    }
248}