diesel_async/
transaction_manager.rs

1use diesel::connection::InstrumentationEvent;
2use diesel::connection::TransactionManagerStatus;
3use diesel::connection::{
4    InTransactionStatus, TransactionDepthChange, ValidTransactionManagerStatus,
5};
6use diesel::result::Error;
7use diesel::QueryResult;
8use scoped_futures::ScopedBoxFuture;
9use std::borrow::Cow;
10use std::future::Future;
11use std::num::NonZeroU32;
12use std::sync::atomic::{AtomicBool, Ordering};
13use std::sync::Arc;
14
15use crate::AsyncConnection;
16// TODO: refactor this to share more code with diesel
17
18/// Manages the internal transaction state for a connection.
19///
20/// You will not need to interact with this trait, unless you are writing an
21/// implementation of [`AsyncConnection`].
22pub trait TransactionManager<Conn: AsyncConnection>: Send {
23    /// Data stored as part of the connection implementation
24    /// to track the current transaction state of a connection
25    type TransactionStateData;
26
27    /// Begin a new transaction or savepoint
28    ///
29    /// If the transaction depth is greater than 0,
30    /// this should create a savepoint instead.
31    /// This function is expected to increment the transaction depth by 1.
32    fn begin_transaction(conn: &mut Conn) -> impl Future<Output = QueryResult<()>> + Send;
33
34    /// Rollback the inner-most transaction or savepoint
35    ///
36    /// If the transaction depth is greater than 1,
37    /// this should rollback to the most recent savepoint.
38    /// This function is expected to decrement the transaction depth by 1.
39    fn rollback_transaction(conn: &mut Conn) -> impl Future<Output = QueryResult<()>> + Send;
40
41    /// Commit the inner-most transaction or savepoint
42    ///
43    /// If the transaction depth is greater than 1,
44    /// this should release the most recent savepoint.
45    /// This function is expected to decrement the transaction depth by 1.
46    fn commit_transaction(conn: &mut Conn) -> impl Future<Output = QueryResult<()>> + Send;
47
48    /// Fetch the current transaction status as mutable
49    ///
50    /// Used to ensure that `begin_test_transaction` is not called when already
51    /// inside of a transaction, and that operations are not run in a `InError`
52    /// transaction manager.
53    #[doc(hidden)]
54    fn transaction_manager_status_mut(conn: &mut Conn) -> &mut TransactionManagerStatus;
55
56    /// Executes the given function inside of a database transaction
57    ///
58    /// Each implementation of this function needs to fulfill the documented
59    /// behaviour of [`AsyncConnection::transaction`]
60    fn transaction<'a, 'conn, F, R, E>(
61        conn: &'conn mut Conn,
62        callback: F,
63    ) -> impl Future<Output = Result<R, E>> + Send + 'conn
64    where
65        F: for<'r> FnOnce(&'r mut Conn) -> ScopedBoxFuture<'a, 'r, Result<R, E>> + Send + 'a,
66        E: From<Error> + Send,
67        R: Send,
68        'a: 'conn,
69    {
70        async move {
71            let callback = callback;
72
73            Self::begin_transaction(conn).await?;
74            match callback(&mut *conn).await {
75                Ok(value) => {
76                    Self::commit_transaction(conn).await?;
77                    Ok(value)
78                }
79                Err(user_error) => match Self::rollback_transaction(conn).await {
80                    Ok(()) => Err(user_error),
81                    Err(Error::BrokenTransactionManager) => {
82                        // In this case we are probably more interested by the
83                        // original error, which likely caused this
84                        Err(user_error)
85                    }
86                    Err(rollback_error) => Err(rollback_error.into()),
87                },
88            }
89        }
90    }
91
92    /// This methods checks if the connection manager is considered to be broken
93    /// by connection pool implementations
94    ///
95    /// A connection manager is considered to be broken by default if it either
96    /// contains an open transaction (because you don't want to have connections
97    /// with open transactions in your pool) or when the transaction manager is
98    /// in an error state.
99    #[doc(hidden)]
100    fn is_broken_transaction_manager(conn: &mut Conn) -> bool {
101        check_broken_transaction_state(conn)
102    }
103}
104
105fn check_broken_transaction_state<Conn>(conn: &mut Conn) -> bool
106where
107    Conn: AsyncConnection,
108{
109    match Conn::TransactionManager::transaction_manager_status_mut(conn).transaction_state() {
110        // all transactions are closed
111        // so we don't consider this connection broken
112        Ok(ValidTransactionManagerStatus {
113            in_transaction: None,
114            ..
115        }) => false,
116        // The transaction manager is in an error state
117        // Therefore we consider this connection broken
118        Err(_) => true,
119        // The transaction manager contains a open transaction
120        // we do consider this connection broken
121        // if that transaction was not opened by `begin_test_transaction`
122        Ok(ValidTransactionManagerStatus {
123            in_transaction: Some(s),
124            ..
125        }) => !s.test_transaction,
126    }
127}
128
/// An implementation of `TransactionManager` which can be used for backends
/// which use ANSI standard syntax for savepoints such as SQLite and PostgreSQL.
#[derive(Default, Debug)]
pub struct AnsiTransactionManager {
    // the transaction status of the owning connection; the transaction
    // depth is read from and updated through this value
    pub(crate) status: TransactionManagerStatus,
    // this boolean flag tracks whether we are currently in the process
    // of executing any transaction related SQL (BEGIN, COMMIT, ROLLBACK)
    // if we ever encounter a situation where this flag is set
    // while the connection is returned to a pool
    // that means the connection is broken as someone dropped the
    // transaction future while these commands were executed
    // and we cannot know the connection state anymore
    //
    // We ensure this by wrapping all calls to `.await`
    // into `AnsiTransactionManager::critical_transaction_block`
    // below
    //
    // See https://github.com/weiznich/diesel_async/issues/198 for
    // details
    pub(crate) is_broken: Arc<AtomicBool>,
    // this boolean flag tracks whether we are currently in the process
    // of trying to commit the transaction. this is useful because if we
    // are and we get a serialization failure, we might not want to attempt
    // a rollback up the chain.
    //
    // it is set right before the commit SQL is executed and cleared
    // again as soon as that execution has finished
    pub(crate) is_commit: bool,
}
155
156impl AnsiTransactionManager {
157    fn get_transaction_state<Conn>(
158        conn: &mut Conn,
159    ) -> QueryResult<&mut ValidTransactionManagerStatus>
160    where
161        Conn: AsyncConnection<TransactionManager = Self>,
162    {
163        conn.transaction_state().status.transaction_state()
164    }
165
166    /// Begin a transaction with custom SQL
167    ///
168    /// This is used by connections to implement more complex transaction APIs
169    /// to set things such as isolation levels.
170    /// Returns an error if already inside of a transaction.
171    pub async fn begin_transaction_sql<Conn>(conn: &mut Conn, sql: &str) -> QueryResult<()>
172    where
173        Conn: AsyncConnection<TransactionManager = Self>,
174    {
175        let is_broken = conn.transaction_state().is_broken.clone();
176        let state = Self::get_transaction_state(conn)?;
177        if let Some(_depth) = state.transaction_depth() {
178            return Err(Error::AlreadyInTransaction);
179        }
180        let instrumentation_depth = NonZeroU32::new(1);
181
182        conn.instrumentation()
183            .on_connection_event(InstrumentationEvent::begin_transaction(
184                instrumentation_depth.expect("We know that 1 is not zero"),
185            ));
186
187        // Keep remainder of this method in sync with `begin_transaction()`.
188        Self::critical_transaction_block(&is_broken, conn.batch_execute(sql)).await?;
189        Self::get_transaction_state(conn)?
190            .change_transaction_depth(TransactionDepthChange::IncreaseDepth)?;
191        Ok(())
192    }
193
194    // This function should be used to await any connection
195    // related future in our transaction manager implementation
196    //
197    // It takes care of tracking entering and exiting executing the future
198    // which in turn is used to determine if it's safe to still use
199    // the connection in the event of a canceled transaction execution
200    async fn critical_transaction_block<F>(is_broken: &AtomicBool, f: F) -> F::Output
201    where
202        F: std::future::Future,
203    {
204        let was_broken = is_broken.swap(true, Ordering::Relaxed);
205        debug_assert!(
206            !was_broken,
207            "Tried to execute a transaction SQL on transaction manager that was previously cancled"
208        );
209        let res = f.await;
210        is_broken.store(false, Ordering::Relaxed);
211        res
212    }
213}
214
215impl<Conn> TransactionManager<Conn> for AnsiTransactionManager
216where
217    Conn: AsyncConnection<TransactionManager = Self>,
218{
219    type TransactionStateData = Self;
220
221    async fn begin_transaction(conn: &mut Conn) -> QueryResult<()> {
222        let transaction_state = Self::get_transaction_state(conn)?;
223        let start_transaction_sql = match transaction_state.transaction_depth() {
224            None => Cow::from("BEGIN"),
225            Some(transaction_depth) => {
226                Cow::from(format!("SAVEPOINT diesel_savepoint_{transaction_depth}"))
227            }
228        };
229        let depth = transaction_state
230            .transaction_depth()
231            .and_then(|d| d.checked_add(1))
232            .unwrap_or(NonZeroU32::new(1).expect("It's not 0"));
233        conn.instrumentation()
234            .on_connection_event(InstrumentationEvent::begin_transaction(depth));
235        Self::critical_transaction_block(
236            &conn.transaction_state().is_broken.clone(),
237            conn.batch_execute(&start_transaction_sql),
238        )
239        .await?;
240        Self::get_transaction_state(conn)?
241            .change_transaction_depth(TransactionDepthChange::IncreaseDepth)?;
242
243        Ok(())
244    }
245
246    async fn rollback_transaction(conn: &mut Conn) -> QueryResult<()> {
247        let transaction_state = Self::get_transaction_state(conn)?;
248
249        let (
250            (rollback_sql, rolling_back_top_level),
251            requires_rollback_maybe_up_to_top_level_before_execute,
252        ) = match transaction_state.in_transaction {
253            Some(ref in_transaction) => (
254                match in_transaction.transaction_depth.get() {
255                    1 => (Cow::Borrowed("ROLLBACK"), true),
256                    depth_gt1 => (
257                        Cow::Owned(format!(
258                            "ROLLBACK TO SAVEPOINT diesel_savepoint_{}",
259                            depth_gt1 - 1
260                        )),
261                        false,
262                    ),
263                },
264                in_transaction.requires_rollback_maybe_up_to_top_level,
265            ),
266            None => return Err(Error::NotInTransaction),
267        };
268
269        let depth = transaction_state
270            .transaction_depth()
271            .expect("We know that we are in a transaction here");
272        conn.instrumentation()
273            .on_connection_event(InstrumentationEvent::rollback_transaction(depth));
274
275        let is_broken = conn.transaction_state().is_broken.clone();
276
277        match Self::critical_transaction_block(&is_broken, conn.batch_execute(&rollback_sql)).await
278        {
279            Ok(()) => {
280                match Self::get_transaction_state(conn)?
281                    .change_transaction_depth(TransactionDepthChange::DecreaseDepth)
282                {
283                    Ok(()) => {}
284                    Err(Error::NotInTransaction) if rolling_back_top_level => {
285                        // Transaction exit may have already been detected by connection
286                        // implementation. It's fine.
287                    }
288                    Err(e) => return Err(e),
289                }
290                Ok(())
291            }
292            Err(rollback_error) => {
293                let tm_status = Self::transaction_manager_status_mut(conn);
294                match tm_status {
295                    TransactionManagerStatus::Valid(ValidTransactionManagerStatus {
296                        in_transaction:
297                            Some(InTransactionStatus {
298                                transaction_depth,
299                                requires_rollback_maybe_up_to_top_level,
300                                ..
301                            }),
302                        ..
303                    }) if transaction_depth.get() > 1 => {
304                        // A savepoint failed to rollback - we may still attempt to repair
305                        // the connection by rolling back higher levels.
306
307                        // To make it easier on the user (that they don't have to really
308                        // look at actual transaction depth and can just rely on the number
309                        // of times they have called begin/commit/rollback) we still
310                        // decrement here:
311                        *transaction_depth = NonZeroU32::new(transaction_depth.get() - 1)
312                            .expect("Depth was checked to be > 1");
313                        *requires_rollback_maybe_up_to_top_level = true;
314                        if requires_rollback_maybe_up_to_top_level_before_execute {
315                            // In that case, we tolerate that savepoint releases fail
316                            // -> we should ignore errors
317                            return Ok(());
318                        }
319                    }
320                    TransactionManagerStatus::Valid(ValidTransactionManagerStatus {
321                        in_transaction: None,
322                        ..
323                    }) => {
324                        // we would have returned `NotInTransaction` if that was already the state
325                        // before we made our call
326                        // => Transaction manager status has been fixed by the underlying connection
327                        // so we don't need to set_in_error
328                    }
329                    _ => tm_status.set_in_error(),
330                }
331                Err(rollback_error)
332            }
333        }
334    }
335
336    /// If the transaction fails to commit due to a `SerializationFailure` or a
337    /// `ReadOnlyTransaction` a rollback will be attempted. If the rollback succeeds,
338    /// the original error will be returned, otherwise the error generated by the rollback
339    /// will be returned. In the second case the connection will be considered broken
340    /// as it contains a uncommitted unabortable open transaction.
341    async fn commit_transaction(conn: &mut Conn) -> QueryResult<()> {
342        let transaction_state = Self::get_transaction_state(conn)?;
343        let transaction_depth = transaction_state.transaction_depth();
344        let (commit_sql, committing_top_level) = match transaction_depth {
345            None => return Err(Error::NotInTransaction),
346            Some(transaction_depth) if transaction_depth.get() == 1 => {
347                (Cow::Borrowed("COMMIT"), true)
348            }
349            Some(transaction_depth) => (
350                Cow::Owned(format!(
351                    "RELEASE SAVEPOINT diesel_savepoint_{}",
352                    transaction_depth.get() - 1
353                )),
354                false,
355            ),
356        };
357        let depth = transaction_state
358            .transaction_depth()
359            .expect("We know that we are in a transaction here");
360        conn.instrumentation()
361            .on_connection_event(InstrumentationEvent::commit_transaction(depth));
362
363        let is_broken = {
364            let transaction_state = conn.transaction_state();
365            transaction_state.is_commit = true;
366            transaction_state.is_broken.clone()
367        };
368
369        let res =
370            Self::critical_transaction_block(&is_broken, conn.batch_execute(&commit_sql)).await;
371
372        conn.transaction_state().is_commit = false;
373
374        match res {
375            Ok(()) => {
376                match Self::get_transaction_state(conn)?
377                    .change_transaction_depth(TransactionDepthChange::DecreaseDepth)
378                {
379                    Ok(()) => {}
380                    Err(Error::NotInTransaction) if committing_top_level => {
381                        // Transaction exit may have already been detected by connection.
382                        // It's fine
383                    }
384                    Err(e) => return Err(e),
385                }
386                Ok(())
387            }
388            Err(commit_error) => {
389                if let TransactionManagerStatus::Valid(ValidTransactionManagerStatus {
390                    in_transaction:
391                        Some(InTransactionStatus {
392                            requires_rollback_maybe_up_to_top_level: true,
393                            ..
394                        }),
395                    ..
396                }) = conn.transaction_state().status
397                {
398                    // rollback_transaction handles the critical block internally on its own
399                    match Self::rollback_transaction(conn).await {
400                        Ok(()) => {}
401                        Err(rollback_error) => {
402                            conn.transaction_state().status.set_in_error();
403                            return Err(Error::RollbackErrorOnCommit {
404                                rollback_error: Box::new(rollback_error),
405                                commit_error: Box::new(commit_error),
406                            });
407                        }
408                    }
409                } else {
410                    Self::get_transaction_state(conn)?
411                        .change_transaction_depth(TransactionDepthChange::DecreaseDepth)?;
412                }
413                Err(commit_error)
414            }
415        }
416    }
417
418    fn transaction_manager_status_mut(conn: &mut Conn) -> &mut TransactionManagerStatus {
419        &mut conn.transaction_state().status
420    }
421
422    fn is_broken_transaction_manager(conn: &mut Conn) -> bool {
423        conn.transaction_state().is_broken.load(Ordering::Relaxed)
424            || check_broken_transaction_state(conn)
425    }
426}