Skip to main content

diesel_async/
transaction_manager.rs

1use diesel::connection::InstrumentationEvent;
2use diesel::connection::TransactionManagerStatus;
3use diesel::connection::{
4    InTransactionStatus, TransactionDepthChange, ValidTransactionManagerStatus,
5};
6use diesel::result::Error;
7use diesel::QueryResult;
8use std::borrow::Cow;
9use std::future::Future;
10use std::num::NonZeroU32;
11use std::sync::atomic::{AtomicBool, Ordering};
12use std::sync::Arc;
13
/// A helper trait that allows asserting additional bounds on `AsyncFnOnce`.
///
/// In particular, it lets us put bounds on the future returned by the closure
/// while still keeping type inference for the closure working.
///
/// This mostly exists for the following reasons:
///
/// * `AsyncFnOnce::CallOnceFuture` is not stable, so `AsyncFnOnce` alone cannot be used to
///   assert that the returned future is `Send`
/// * `FnOnce(T) -> impl Future` does not work, as `impl Future` is not allowed in that position
/// * `F: FnOnce(T) -> Fut, Fut: Future<…>` does not work, as it runs into diverging lifetimes:
///   the higher-ranked lifetime of the closure argument has to be reused to restrict the
///   lifetime of the returned future
/// * `trait Func<T>: FnOnce(T) -> <Self as Func<T>>::Output { type Output; }` allows
///   restricting the lifetime and bounds (like `Future` and `Send`) of the closure's output
///   type, but then fails in type inference due to rustc bugs. You get unhelpful errors like
///   ``implementation of `FnOnce` is not general enough``, which can only be side-stepped by
///   annotating types on the closure at the call site — not a nice API
///
/// This workaround is still not optimal, as it requires this trait to exist and to show up as
/// a public bound. To avoid confusing users, call sites additionally carry an
/// `AsyncFnOnce(T) -> R` bound that hopefully shows up in rustdoc as well and guides the user
/// towards the "right thing".
pub trait AsyncFunc<T, R>:
    AsyncFnOnce(T) -> R + FnOnce(T) -> <Self as AsyncFunc<T, R>>::Fut
{
    /// The concrete future returned when the closure is called with a `T`.
    type Fut: Future<Output = R>;
}

/// Blanket implementation: any type that is both `AsyncFnOnce(T) -> R` and
/// `FnOnce(T) -> Fut` with `Fut: Future<Output = R>` is an `AsyncFunc`,
/// exposing that `Fut` as its associated future type.
impl<F, T, Fut, R> AsyncFunc<T, R> for F
where
    F: AsyncFnOnce(T) -> R + FnOnce(T) -> Fut,
    Fut: Future<Output = R>,
{
    type Fut = Fut;
}
49
50use crate::AsyncConnection;
51// TODO: refactor this to share more code with diesel
52
53/// Manages the internal transaction state for a connection.
54///
55/// You will not need to interact with this trait, unless you are writing an
56/// implementation of [`AsyncConnection`].
57pub trait TransactionManager<Conn: AsyncConnection>: Send {
58    /// Data stored as part of the connection implementation
59    /// to track the current transaction state of a connection
60    type TransactionStateData;
61
62    /// Begin a new transaction or savepoint
63    ///
64    /// If the transaction depth is greater than 0,
65    /// this should create a savepoint instead.
66    /// This function is expected to increment the transaction depth by 1.
67    fn begin_transaction(conn: &mut Conn) -> impl Future<Output = QueryResult<()>> + Send;
68
69    /// Rollback the inner-most transaction or savepoint
70    ///
71    /// If the transaction depth is greater than 1,
72    /// this should rollback to the most recent savepoint.
73    /// This function is expected to decrement the transaction depth by 1.
74    fn rollback_transaction(conn: &mut Conn) -> impl Future<Output = QueryResult<()>> + Send;
75
76    /// Commit the inner-most transaction or savepoint
77    ///
78    /// If the transaction depth is greater than 1,
79    /// this should release the most recent savepoint.
80    /// This function is expected to decrement the transaction depth by 1.
81    fn commit_transaction(conn: &mut Conn) -> impl Future<Output = QueryResult<()>> + Send;
82
83    /// Fetch the current transaction status as mutable
84    ///
85    /// Used to ensure that `begin_test_transaction` is not called when already
86    /// inside of a transaction, and that operations are not run in a `InError`
87    /// transaction manager.
88    #[doc(hidden)]
89    fn transaction_manager_status_mut(conn: &mut Conn) -> &mut TransactionManagerStatus;
90
91    /// Executes the given function inside of a database transaction
92    ///
93    /// Each implementation of this function needs to fulfill the documented
94    /// behaviour of [`AsyncConnection::transaction`]
95    fn transaction<'a, 'conn, F, R, E>(
96        conn: &'conn mut Conn,
97        callback: F,
98    ) -> impl Future<Output = Result<R, E>> + Send + 'conn
99    where
100        for<'r> F: AsyncFnOnce(&'r mut Conn) -> Result<R, E>
101            + AsyncFunc<&'r mut Conn, Result<R, E>, Fut: Send>
102            + Send
103            + 'a,
104        E: From<Error> + Send,
105        R: Send,
106        'a: 'conn,
107    {
108        async move {
109            let callback = callback;
110
111            Self::begin_transaction(conn).await?;
112            match callback(&mut *conn).await {
113                Ok(value) => {
114                    Self::commit_transaction(conn).await?;
115                    Ok(value)
116                }
117                Err(user_error) => match Self::rollback_transaction(conn).await {
118                    Ok(()) => Err(user_error),
119                    Err(Error::BrokenTransactionManager) => {
120                        // In this case we are probably more interested by the
121                        // original error, which likely caused this
122                        Err(user_error)
123                    }
124                    Err(rollback_error) => Err(rollback_error.into()),
125                },
126            }
127        }
128    }
129
130    /// This methods checks if the connection manager is considered to be broken
131    /// by connection pool implementations
132    ///
133    /// A connection manager is considered to be broken by default if it either
134    /// contains an open transaction (because you don't want to have connections
135    /// with open transactions in your pool) or when the transaction manager is
136    /// in an error state.
137    #[doc(hidden)]
138    fn is_broken_transaction_manager(conn: &mut Conn) -> bool {
139        check_broken_transaction_state(conn)
140    }
141}
142
143fn check_broken_transaction_state<Conn>(conn: &mut Conn) -> bool
144where
145    Conn: AsyncConnection,
146{
147    match Conn::TransactionManager::transaction_manager_status_mut(conn).transaction_state() {
148        // all transactions are closed
149        // so we don't consider this connection broken
150        Ok(ValidTransactionManagerStatus {
151            in_transaction: None,
152            ..
153        }) => false,
154        // The transaction manager is in an error state
155        // Therefore we consider this connection broken
156        Err(_) => true,
157        // The transaction manager contains a open transaction
158        // we do consider this connection broken
159        // if that transaction was not opened by `begin_test_transaction`
160        Ok(ValidTransactionManagerStatus {
161            in_transaction: Some(s),
162            ..
163        }) => !s.test_transaction,
164    }
165}
166
167/// An implementation of `TransactionManager` which can be used for backends
168/// which use ANSI standard syntax for savepoints such as SQLite and PostgreSQL.
169#[derive(Default, Debug)]
170pub struct AnsiTransactionManager {
171    pub(crate) status: TransactionManagerStatus,
172    // this boolean flag tracks whether we are currently in the process
173    // of executing any transaction releated SQL (BEGIN, COMMIT, ROLLBACK)
174    // if we ever encounter a situation where this flag is set
175    // while the connection is returned to a pool
176    // that means the connection is broken as someone dropped the
177    // transaction future while these commands where executed
178    // and we cannot know the connection state anymore
179    //
180    // We ensure this by wrapping all calls to `.await`
181    // into `AnsiTransactionManager::critical_transaction_block`
182    // below
183    //
184    // See https://github.com/weiznich/diesel_async/issues/198 for
185    // details
186    pub(crate) is_broken: Arc<AtomicBool>,
187    // this boolean flag tracks whether we are currently in this process
188    // of trying to commit the transaction. this is useful because if we
189    // are and we get a serialization failure, we might not want to attempt
190    // a rollback up the chain.
191    pub(crate) is_commit: bool,
192}
193
194impl AnsiTransactionManager {
195    fn get_transaction_state<Conn>(
196        conn: &mut Conn,
197    ) -> QueryResult<&mut ValidTransactionManagerStatus>
198    where
199        Conn: AsyncConnection<TransactionManager = Self>,
200    {
201        conn.transaction_state().status.transaction_state()
202    }
203
204    /// Begin a transaction with custom SQL
205    ///
206    /// This is used by connections to implement more complex transaction APIs
207    /// to set things such as isolation levels.
208    /// Returns an error if already inside of a transaction.
209    pub async fn begin_transaction_sql<Conn>(conn: &mut Conn, sql: &str) -> QueryResult<()>
210    where
211        Conn: AsyncConnection<TransactionManager = Self>,
212    {
213        let is_broken = conn.transaction_state().is_broken.clone();
214        let state = Self::get_transaction_state(conn)?;
215        if let Some(_depth) = state.transaction_depth() {
216            return Err(Error::AlreadyInTransaction);
217        }
218        let instrumentation_depth = NonZeroU32::new(1);
219
220        conn.instrumentation()
221            .on_connection_event(InstrumentationEvent::begin_transaction(
222                instrumentation_depth.expect("We know that 1 is not zero"),
223            ));
224
225        // Keep remainder of this method in sync with `begin_transaction()`.
226        Self::critical_transaction_block(&is_broken, conn.batch_execute(sql)).await?;
227        Self::get_transaction_state(conn)?
228            .change_transaction_depth(TransactionDepthChange::IncreaseDepth)?;
229        Ok(())
230    }
231
232    // This function should be used to await any connection
233    // related future in our transaction manager implementation
234    //
235    // It takes care of tracking entering and exiting executing the future
236    // which in turn is used to determine if it's safe to still use
237    // the connection in the event of a canceled transaction execution
238    async fn critical_transaction_block<F>(is_broken: &AtomicBool, f: F) -> F::Output
239    where
240        F: std::future::Future,
241    {
242        let was_broken = is_broken.swap(true, Ordering::Relaxed);
243        debug_assert!(
244            !was_broken,
245            "Tried to execute a transaction SQL on transaction manager that was previously cancled"
246        );
247        let res = f.await;
248        is_broken.store(false, Ordering::Relaxed);
249        res
250    }
251}
252
253impl<Conn> TransactionManager<Conn> for AnsiTransactionManager
254where
255    Conn: AsyncConnection<TransactionManager = Self>,
256{
257    type TransactionStateData = Self;
258
259    async fn begin_transaction(conn: &mut Conn) -> QueryResult<()> {
260        let transaction_state = Self::get_transaction_state(conn)?;
261        let start_transaction_sql = match transaction_state.transaction_depth() {
262            None => Cow::from("BEGIN"),
263            Some(transaction_depth) => {
264                Cow::from(format!("SAVEPOINT diesel_savepoint_{transaction_depth}"))
265            }
266        };
267        let depth = transaction_state
268            .transaction_depth()
269            .and_then(|d| d.checked_add(1))
270            .unwrap_or(NonZeroU32::new(1).expect("It's not 0"));
271        conn.instrumentation()
272            .on_connection_event(InstrumentationEvent::begin_transaction(depth));
273        Self::critical_transaction_block(
274            &conn.transaction_state().is_broken.clone(),
275            conn.batch_execute(&start_transaction_sql),
276        )
277        .await?;
278        Self::get_transaction_state(conn)?
279            .change_transaction_depth(TransactionDepthChange::IncreaseDepth)?;
280
281        Ok(())
282    }
283
284    async fn rollback_transaction(conn: &mut Conn) -> QueryResult<()> {
285        let transaction_state = Self::get_transaction_state(conn)?;
286
287        let (
288            (rollback_sql, rolling_back_top_level),
289            requires_rollback_maybe_up_to_top_level_before_execute,
290        ) = match transaction_state.in_transaction {
291            Some(ref in_transaction) => (
292                match in_transaction.transaction_depth.get() {
293                    1 => (Cow::Borrowed("ROLLBACK"), true),
294                    depth_gt1 => (
295                        Cow::Owned(format!(
296                            "ROLLBACK TO SAVEPOINT diesel_savepoint_{}",
297                            depth_gt1 - 1
298                        )),
299                        false,
300                    ),
301                },
302                in_transaction.requires_rollback_maybe_up_to_top_level,
303            ),
304            None => return Err(Error::NotInTransaction),
305        };
306
307        let depth = transaction_state
308            .transaction_depth()
309            .expect("We know that we are in a transaction here");
310        conn.instrumentation()
311            .on_connection_event(InstrumentationEvent::rollback_transaction(depth));
312
313        let is_broken = conn.transaction_state().is_broken.clone();
314
315        match Self::critical_transaction_block(&is_broken, conn.batch_execute(&rollback_sql)).await
316        {
317            Ok(()) => {
318                match Self::get_transaction_state(conn)?
319                    .change_transaction_depth(TransactionDepthChange::DecreaseDepth)
320                {
321                    Ok(()) => {}
322                    Err(Error::NotInTransaction) if rolling_back_top_level => {
323                        // Transaction exit may have already been detected by connection
324                        // implementation. It's fine.
325                    }
326                    Err(e) => return Err(e),
327                }
328                Ok(())
329            }
330            Err(rollback_error) => {
331                let tm_status = Self::transaction_manager_status_mut(conn);
332                match tm_status {
333                    TransactionManagerStatus::Valid(ValidTransactionManagerStatus {
334                        in_transaction:
335                            Some(InTransactionStatus {
336                                transaction_depth,
337                                requires_rollback_maybe_up_to_top_level,
338                                ..
339                            }),
340                        ..
341                    }) if transaction_depth.get() > 1 => {
342                        // A savepoint failed to rollback - we may still attempt to repair
343                        // the connection by rolling back higher levels.
344
345                        // To make it easier on the user (that they don't have to really
346                        // look at actual transaction depth and can just rely on the number
347                        // of times they have called begin/commit/rollback) we still
348                        // decrement here:
349                        *transaction_depth = NonZeroU32::new(transaction_depth.get() - 1)
350                            .expect("Depth was checked to be > 1");
351                        *requires_rollback_maybe_up_to_top_level = true;
352                        if requires_rollback_maybe_up_to_top_level_before_execute {
353                            // In that case, we tolerate that savepoint releases fail
354                            // -> we should ignore errors
355                            return Ok(());
356                        }
357                    }
358                    TransactionManagerStatus::Valid(ValidTransactionManagerStatus {
359                        in_transaction: None,
360                        ..
361                    }) => {
362                        // we would have returned `NotInTransaction` if that was already the state
363                        // before we made our call
364                        // => Transaction manager status has been fixed by the underlying connection
365                        // so we don't need to set_in_error
366                    }
367                    _ => tm_status.set_in_error(),
368                }
369                Err(rollback_error)
370            }
371        }
372    }
373
374    /// If the transaction fails to commit due to a `SerializationFailure` or a
375    /// `ReadOnlyTransaction` a rollback will be attempted. If the rollback succeeds,
376    /// the original error will be returned, otherwise the error generated by the rollback
377    /// will be returned. In the second case the connection will be considered broken
378    /// as it contains a uncommitted unabortable open transaction.
379    async fn commit_transaction(conn: &mut Conn) -> QueryResult<()> {
380        let transaction_state = Self::get_transaction_state(conn)?;
381        let transaction_depth = transaction_state.transaction_depth();
382        let (commit_sql, committing_top_level) = match transaction_depth {
383            None => return Err(Error::NotInTransaction),
384            Some(transaction_depth) if transaction_depth.get() == 1 => {
385                (Cow::Borrowed("COMMIT"), true)
386            }
387            Some(transaction_depth) => (
388                Cow::Owned(format!(
389                    "RELEASE SAVEPOINT diesel_savepoint_{}",
390                    transaction_depth.get() - 1
391                )),
392                false,
393            ),
394        };
395        let depth = transaction_state
396            .transaction_depth()
397            .expect("We know that we are in a transaction here");
398        conn.instrumentation()
399            .on_connection_event(InstrumentationEvent::commit_transaction(depth));
400
401        let is_broken = {
402            let transaction_state = conn.transaction_state();
403            transaction_state.is_commit = true;
404            transaction_state.is_broken.clone()
405        };
406
407        let res =
408            Self::critical_transaction_block(&is_broken, conn.batch_execute(&commit_sql)).await;
409
410        conn.transaction_state().is_commit = false;
411
412        match res {
413            Ok(()) => {
414                match Self::get_transaction_state(conn)?
415                    .change_transaction_depth(TransactionDepthChange::DecreaseDepth)
416                {
417                    Ok(()) => {}
418                    Err(Error::NotInTransaction) if committing_top_level => {
419                        // Transaction exit may have already been detected by connection.
420                        // It's fine
421                    }
422                    Err(e) => return Err(e),
423                }
424                Ok(())
425            }
426            Err(commit_error) => {
427                if let TransactionManagerStatus::Valid(ValidTransactionManagerStatus {
428                    in_transaction:
429                        Some(InTransactionStatus {
430                            requires_rollback_maybe_up_to_top_level: true,
431                            ..
432                        }),
433                    ..
434                }) = conn.transaction_state().status
435                {
436                    // rollback_transaction handles the critical block internally on its own
437                    match Self::rollback_transaction(conn).await {
438                        Ok(()) => {}
439                        Err(rollback_error) => {
440                            conn.transaction_state().status.set_in_error();
441                            return Err(Error::RollbackErrorOnCommit {
442                                rollback_error: Box::new(rollback_error),
443                                commit_error: Box::new(commit_error),
444                            });
445                        }
446                    }
447                } else {
448                    Self::get_transaction_state(conn)?
449                        .change_transaction_depth(TransactionDepthChange::DecreaseDepth)?;
450                }
451                Err(commit_error)
452            }
453        }
454    }
455
456    fn transaction_manager_status_mut(conn: &mut Conn) -> &mut TransactionManagerStatus {
457        &mut conn.transaction_state().status
458    }
459
460    fn is_broken_transaction_manager(conn: &mut Conn) -> bool {
461        conn.transaction_state().is_broken.load(Ordering::Relaxed)
462            || check_broken_transaction_state(conn)
463    }
464}