rorm_db/transaction/
mod.rs1use std::error::Error as StdError;
4use std::fmt;
5use std::future::Future;
6
7use tracing::debug;
8
9use crate::internal::any::AnyTransaction;
10pub use crate::transaction::hook::TransactionHook;
11use crate::transaction::hook_closure::{ClosureHook, OnRollback, PostCommit, PreCommit};
12use crate::transaction::hook_storage::HookStorage;
13use crate::Error;
14
15mod hook;
16mod hook_closure;
17mod hook_storage;
18
19#[must_use = "A transaction needs to be committed."]
25pub struct Transaction {
26 pub(crate) sqlx: AnyTransaction,
27 hooks: Option<HookStorage>,
28}
29
30impl Transaction {
31 pub(crate) fn new(sqlx: AnyTransaction) -> Self {
32 Self { sqlx, hooks: None }
33 }
34
35 pub async fn commit(mut self) -> Result<(), TransactionError> {
37 let mut hooks = self.hooks.take();
38
39 if let Some(hooks) = hooks.as_mut() {
40 hooks.pre_commit(&mut self).await?;
41
42 if let Some(invalid_hooks) = self.hooks.as_mut() {
43 debug!("Some transaction hook added additional hooks during pre-commit. This is not supported and will be ignored.");
44
45 invalid_hooks.clear();
47 }
48 }
49
50 let result = self.sqlx.commit().await;
51
52 if let Some(hooks) = hooks.as_mut() {
53 if result.is_ok() {
54 hooks.post_commit();
55
56 hooks.clear();
58 }
59 }
60
61 result
62 .map_err(Error::SqlxError)
63 .map_err(TransactionError::Database)
64 }
65
66 pub async fn rollback(self) -> Result<(), Error> {
68 self.sqlx.rollback().await.map_err(Error::SqlxError)
69 }
70}
71
72impl Drop for HookStorage {
76 fn drop(&mut self) {
77 self.on_rollback();
79 }
80}
81
82impl Transaction {
83 pub fn hooks(&mut self) -> SimpleHooksApi<'_> {
87 SimpleHooksApi(self.hooks.get_or_insert_default())
88 }
89
90 pub fn adv_hooks(&mut self) -> AdvancedHooksApi<'_> {
94 AdvancedHooksApi(self.hooks.get_or_insert_default())
95 }
96}
97
98pub struct SimpleHooksApi<'a>(&'a mut HookStorage);
102impl SimpleHooksApi<'_> {
103 pub fn pre_commit<F>(&mut self, hook: impl FnOnce() -> F + Send + 'static) -> &mut Self
107 where
108 F: Future<Output = Result<(), TransactionError>> + Send,
109 {
110 self.0
111 .get_or_insert()
112 .push(ClosureHook::new(hook, PreCommit));
113 self
114 }
115
116 pub fn post_commit(&mut self, hook: impl FnOnce() + Send + 'static) -> &mut Self {
118 self.0
119 .get_or_insert()
120 .push(ClosureHook::new(hook, PostCommit));
121 self
122 }
123
124 pub fn on_rollback(&mut self, hook: impl FnOnce() + Send + 'static) -> &mut Self {
128 self.0
129 .get_or_insert()
130 .push(ClosureHook::new(hook, OnRollback));
131 self
132 }
133}
134
135pub struct AdvancedHooksApi<'a>(&'a mut HookStorage);
149impl AdvancedHooksApi<'_> {
150 pub fn push<T: TransactionHook>(&mut self, hook: T) {
152 self.get_all().push(hook);
153 }
154
155 pub fn get_or_insert_default<T: TransactionHook + Default>(&mut self) -> &mut T {
159 self.get_or_insert_with(T::default)
160 }
161
162 pub fn get_or_insert_with<T: TransactionHook>(&mut self, init: impl FnOnce() -> T) -> &mut T {
166 let vec = self.get_all();
167 if vec.is_empty() {
168 vec.push(init());
169 }
170 &mut vec[0]
171 }
172
173 pub fn get_all<T: TransactionHook>(&mut self) -> &mut Vec<T> {
175 self.0.get_or_insert()
176 }
177}
178
179#[derive(Debug)]
181pub enum TransactionError {
182 Database(Error),
184
185 Hook(HookError),
187}
188pub type HookError = Box<dyn StdError + Send + Sync>;
190
191impl From<Error> for TransactionError {
192 fn from(value: Error) -> Self {
193 Self::Database(value)
194 }
195}
196impl From<HookError> for TransactionError {
197 fn from(value: HookError) -> Self {
198 Self::Hook(value)
199 }
200}
201impl fmt::Display for TransactionError {
202 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
203 match self {
204 TransactionError::Database(x) => fmt::Display::fmt(x, f),
205 TransactionError::Hook(x) => fmt::Display::fmt(x, f),
206 }
207 }
208}
209impl StdError for TransactionError {
210 fn source(&self) -> Option<&(dyn StdError + 'static)> {
211 match self {
212 TransactionError::Database(x) => Some(x),
213 TransactionError::Hook(x) => Some(x.as_ref()),
214 }
215 }
216}
217
218#[must_use = "The potentially owned transaction needs to be committed."]
223pub enum TransactionGuard<'tr> {
224 Owned(Transaction),
226
227 Borrowed(&'tr mut Transaction),
229}
230
231impl TransactionGuard<'_> {
232 pub fn get_transaction(&mut self) -> &mut Transaction {
234 match self {
235 TransactionGuard::Owned(tr) => tr,
236 TransactionGuard::Borrowed(tr) => tr,
237 }
238 }
239
240 pub async fn commit(self) -> Result<(), TransactionError> {
242 if let TransactionGuard::Owned(tr) = self {
243 tr.commit().await
244 } else {
245 Ok(())
246 }
247 }
248}