diesel_async/transaction_manager.rs
1use diesel::connection::InstrumentationEvent;
2use diesel::connection::TransactionManagerStatus;
3use diesel::connection::{
4 InTransactionStatus, TransactionDepthChange, ValidTransactionManagerStatus,
5};
6use diesel::result::Error;
7use diesel::QueryResult;
8use std::borrow::Cow;
9use std::future::Future;
10use std::num::NonZeroU32;
11use std::sync::atomic::{AtomicBool, Ordering};
12use std::sync::Arc;
13
/// A helper trait that lets us assert additional bounds on `AsyncFnOnce`.
/// In particular it lets us put bounds on the future returned by the closure
/// while still keeping type inference for the closure working.
///
/// This mostly exists for the following reasons:
///
/// * `AsyncFnOnce::CallOnceFuture` is not stable, so you cannot use `AsyncFnOnce` alone
///   to assert that the returned future is `Send`
/// * Using only `FnOnce(T) -> impl Future` doesn't work as you cannot use `impl Future`
///   in that position
/// * Using `F: FnOnce(T) -> Fut, Fut: Future<…>` doesn't work as you run into diverging
///   lifetimes, as you essentially need to reuse the higher ranked lifetime from the closure
///   to also restrict the lifetime of the returned future
/// * Using a `trait Func<T>: FnOnce(T) -> <Self as Func<T>>::Output { type Output }` allows us
///   to restrict the lifetime and bounds (like `Future` and `Send`) on the output type of the
///   closure, but then fails in type inference due to rustc bugs. You get unhelpful errors like
///   ``implementation of `FnOnce` is not general enough``; this can be sidestepped by putting
///   types on the calling side of the closure, which is not nice API-wise
///
/// This workaround is still not optimal, as it still requires us to have this trait and show
/// it as a public bound. We try to avoid confusing users by also having an
/// `AsyncFnOnce(T) -> R` bound at the calling side. That bound hopefully shows up in rustdoc
/// as well and guides the user to do the "right thing".
pub trait AsyncFunc<T, R>:
    AsyncFnOnce(T) -> R + FnOnce(T) -> <Self as AsyncFunc<T, R>>::Fut
{
    /// The future produced when the closure is called through its `FnOnce(T)` form.
    /// It resolves to `R`, the same output as the `AsyncFnOnce(T) -> R` supertrait, and
    /// callers can put extra bounds on it (e.g. `AsyncFunc<…, Fut: Send>`).
    type Fut: Future<Output = R>;
}
41
// Blanket implementation: any closure that is both `AsyncFnOnce(T) -> R` and a plain
// `FnOnce(T) -> Fut` whose `Fut` resolves to `R` implements `AsyncFunc`. That future type
// is exposed as `AsyncFunc::Fut`, so there is no need to implement this trait manually.
impl<F, T, Fut, R> AsyncFunc<T, R> for F
where
    F: AsyncFnOnce(T) -> R + FnOnce(T) -> Fut,
    Fut: Future<Output = R>,
{
    type Fut = Fut;
}
49
50use crate::AsyncConnection;
51// TODO: refactor this to share more code with diesel
52
53/// Manages the internal transaction state for a connection.
54///
55/// You will not need to interact with this trait, unless you are writing an
56/// implementation of [`AsyncConnection`].
57pub trait TransactionManager<Conn: AsyncConnection>: Send {
58 /// Data stored as part of the connection implementation
59 /// to track the current transaction state of a connection
60 type TransactionStateData;
61
62 /// Begin a new transaction or savepoint
63 ///
64 /// If the transaction depth is greater than 0,
65 /// this should create a savepoint instead.
66 /// This function is expected to increment the transaction depth by 1.
67 fn begin_transaction(conn: &mut Conn) -> impl Future<Output = QueryResult<()>> + Send;
68
69 /// Rollback the inner-most transaction or savepoint
70 ///
71 /// If the transaction depth is greater than 1,
72 /// this should rollback to the most recent savepoint.
73 /// This function is expected to decrement the transaction depth by 1.
74 fn rollback_transaction(conn: &mut Conn) -> impl Future<Output = QueryResult<()>> + Send;
75
76 /// Commit the inner-most transaction or savepoint
77 ///
78 /// If the transaction depth is greater than 1,
79 /// this should release the most recent savepoint.
80 /// This function is expected to decrement the transaction depth by 1.
81 fn commit_transaction(conn: &mut Conn) -> impl Future<Output = QueryResult<()>> + Send;
82
83 /// Fetch the current transaction status as mutable
84 ///
85 /// Used to ensure that `begin_test_transaction` is not called when already
86 /// inside of a transaction, and that operations are not run in a `InError`
87 /// transaction manager.
88 #[doc(hidden)]
89 fn transaction_manager_status_mut(conn: &mut Conn) -> &mut TransactionManagerStatus;
90
91 /// Executes the given function inside of a database transaction
92 ///
93 /// Each implementation of this function needs to fulfill the documented
94 /// behaviour of [`AsyncConnection::transaction`]
95 fn transaction<'a, 'conn, F, R, E>(
96 conn: &'conn mut Conn,
97 callback: F,
98 ) -> impl Future<Output = Result<R, E>> + Send + 'conn
99 where
100 for<'r> F: AsyncFnOnce(&'r mut Conn) -> Result<R, E>
101 + AsyncFunc<&'r mut Conn, Result<R, E>, Fut: Send>
102 + Send
103 + 'a,
104 E: From<Error> + Send,
105 R: Send,
106 'a: 'conn,
107 {
108 async move {
109 let callback = callback;
110
111 Self::begin_transaction(conn).await?;
112 match callback(&mut *conn).await {
113 Ok(value) => {
114 Self::commit_transaction(conn).await?;
115 Ok(value)
116 }
117 Err(user_error) => match Self::rollback_transaction(conn).await {
118 Ok(()) => Err(user_error),
119 Err(Error::BrokenTransactionManager) => {
120 // In this case we are probably more interested by the
121 // original error, which likely caused this
122 Err(user_error)
123 }
124 Err(rollback_error) => Err(rollback_error.into()),
125 },
126 }
127 }
128 }
129
130 /// This methods checks if the connection manager is considered to be broken
131 /// by connection pool implementations
132 ///
133 /// A connection manager is considered to be broken by default if it either
134 /// contains an open transaction (because you don't want to have connections
135 /// with open transactions in your pool) or when the transaction manager is
136 /// in an error state.
137 #[doc(hidden)]
138 fn is_broken_transaction_manager(conn: &mut Conn) -> bool {
139 check_broken_transaction_state(conn)
140 }
141}
142
143fn check_broken_transaction_state<Conn>(conn: &mut Conn) -> bool
144where
145 Conn: AsyncConnection,
146{
147 match Conn::TransactionManager::transaction_manager_status_mut(conn).transaction_state() {
148 // all transactions are closed
149 // so we don't consider this connection broken
150 Ok(ValidTransactionManagerStatus {
151 in_transaction: None,
152 ..
153 }) => false,
154 // The transaction manager is in an error state
155 // Therefore we consider this connection broken
156 Err(_) => true,
157 // The transaction manager contains a open transaction
158 // we do consider this connection broken
159 // if that transaction was not opened by `begin_test_transaction`
160 Ok(ValidTransactionManagerStatus {
161 in_transaction: Some(s),
162 ..
163 }) => !s.test_transaction,
164 }
165}
166
/// An implementation of `TransactionManager` which can be used for backends
/// which use ANSI standard syntax for savepoints such as SQLite and PostgreSQL.
#[derive(Default, Debug)]
pub struct AnsiTransactionManager {
    /// Diesel's transaction status: the current depth and whether the manager
    /// is in an error state.
    pub(crate) status: TransactionManagerStatus,
    // This flag tracks whether we are currently executing transaction-related
    // SQL (BEGIN, COMMIT, ROLLBACK, savepoint statements).
    // If this flag is still set when the connection is returned to a pool,
    // someone dropped the transaction future while one of these commands was
    // executing. In that case we can no longer know the connection state, so
    // the connection must be treated as broken.
    //
    // This is ensured by wrapping the awaits on that SQL in
    // `AnsiTransactionManager::critical_transaction_block`.
    //
    // See https://github.com/weiznich/diesel_async/issues/198 for
    // details
    pub(crate) is_broken: Arc<AtomicBool>,
    // This flag tracks whether we are currently trying to commit the
    // transaction. This matters because a serialization failure raised by the
    // commit itself might not warrant attempting a rollback up the chain.
    // NOTE(review): within this file the flag is only written (set around the
    // commit SQL); presumably connection implementations read it — confirm.
    pub(crate) is_commit: bool,
}
193
194impl AnsiTransactionManager {
195 fn get_transaction_state<Conn>(
196 conn: &mut Conn,
197 ) -> QueryResult<&mut ValidTransactionManagerStatus>
198 where
199 Conn: AsyncConnection<TransactionManager = Self>,
200 {
201 conn.transaction_state().status.transaction_state()
202 }
203
204 /// Begin a transaction with custom SQL
205 ///
206 /// This is used by connections to implement more complex transaction APIs
207 /// to set things such as isolation levels.
208 /// Returns an error if already inside of a transaction.
209 pub async fn begin_transaction_sql<Conn>(conn: &mut Conn, sql: &str) -> QueryResult<()>
210 where
211 Conn: AsyncConnection<TransactionManager = Self>,
212 {
213 let is_broken = conn.transaction_state().is_broken.clone();
214 let state = Self::get_transaction_state(conn)?;
215 if let Some(_depth) = state.transaction_depth() {
216 return Err(Error::AlreadyInTransaction);
217 }
218 let instrumentation_depth = NonZeroU32::new(1);
219
220 conn.instrumentation()
221 .on_connection_event(InstrumentationEvent::begin_transaction(
222 instrumentation_depth.expect("We know that 1 is not zero"),
223 ));
224
225 // Keep remainder of this method in sync with `begin_transaction()`.
226 Self::critical_transaction_block(&is_broken, conn.batch_execute(sql)).await?;
227 Self::get_transaction_state(conn)?
228 .change_transaction_depth(TransactionDepthChange::IncreaseDepth)?;
229 Ok(())
230 }
231
232 // This function should be used to await any connection
233 // related future in our transaction manager implementation
234 //
235 // It takes care of tracking entering and exiting executing the future
236 // which in turn is used to determine if it's safe to still use
237 // the connection in the event of a canceled transaction execution
238 async fn critical_transaction_block<F>(is_broken: &AtomicBool, f: F) -> F::Output
239 where
240 F: std::future::Future,
241 {
242 let was_broken = is_broken.swap(true, Ordering::Relaxed);
243 debug_assert!(
244 !was_broken,
245 "Tried to execute a transaction SQL on transaction manager that was previously cancled"
246 );
247 let res = f.await;
248 is_broken.store(false, Ordering::Relaxed);
249 res
250 }
251}
252
253impl<Conn> TransactionManager<Conn> for AnsiTransactionManager
254where
255 Conn: AsyncConnection<TransactionManager = Self>,
256{
257 type TransactionStateData = Self;
258
259 async fn begin_transaction(conn: &mut Conn) -> QueryResult<()> {
260 let transaction_state = Self::get_transaction_state(conn)?;
261 let start_transaction_sql = match transaction_state.transaction_depth() {
262 None => Cow::from("BEGIN"),
263 Some(transaction_depth) => {
264 Cow::from(format!("SAVEPOINT diesel_savepoint_{transaction_depth}"))
265 }
266 };
267 let depth = transaction_state
268 .transaction_depth()
269 .and_then(|d| d.checked_add(1))
270 .unwrap_or(NonZeroU32::new(1).expect("It's not 0"));
271 conn.instrumentation()
272 .on_connection_event(InstrumentationEvent::begin_transaction(depth));
273 Self::critical_transaction_block(
274 &conn.transaction_state().is_broken.clone(),
275 conn.batch_execute(&start_transaction_sql),
276 )
277 .await?;
278 Self::get_transaction_state(conn)?
279 .change_transaction_depth(TransactionDepthChange::IncreaseDepth)?;
280
281 Ok(())
282 }
283
284 async fn rollback_transaction(conn: &mut Conn) -> QueryResult<()> {
285 let transaction_state = Self::get_transaction_state(conn)?;
286
287 let (
288 (rollback_sql, rolling_back_top_level),
289 requires_rollback_maybe_up_to_top_level_before_execute,
290 ) = match transaction_state.in_transaction {
291 Some(ref in_transaction) => (
292 match in_transaction.transaction_depth.get() {
293 1 => (Cow::Borrowed("ROLLBACK"), true),
294 depth_gt1 => (
295 Cow::Owned(format!(
296 "ROLLBACK TO SAVEPOINT diesel_savepoint_{}",
297 depth_gt1 - 1
298 )),
299 false,
300 ),
301 },
302 in_transaction.requires_rollback_maybe_up_to_top_level,
303 ),
304 None => return Err(Error::NotInTransaction),
305 };
306
307 let depth = transaction_state
308 .transaction_depth()
309 .expect("We know that we are in a transaction here");
310 conn.instrumentation()
311 .on_connection_event(InstrumentationEvent::rollback_transaction(depth));
312
313 let is_broken = conn.transaction_state().is_broken.clone();
314
315 match Self::critical_transaction_block(&is_broken, conn.batch_execute(&rollback_sql)).await
316 {
317 Ok(()) => {
318 match Self::get_transaction_state(conn)?
319 .change_transaction_depth(TransactionDepthChange::DecreaseDepth)
320 {
321 Ok(()) => {}
322 Err(Error::NotInTransaction) if rolling_back_top_level => {
323 // Transaction exit may have already been detected by connection
324 // implementation. It's fine.
325 }
326 Err(e) => return Err(e),
327 }
328 Ok(())
329 }
330 Err(rollback_error) => {
331 let tm_status = Self::transaction_manager_status_mut(conn);
332 match tm_status {
333 TransactionManagerStatus::Valid(ValidTransactionManagerStatus {
334 in_transaction:
335 Some(InTransactionStatus {
336 transaction_depth,
337 requires_rollback_maybe_up_to_top_level,
338 ..
339 }),
340 ..
341 }) if transaction_depth.get() > 1 => {
342 // A savepoint failed to rollback - we may still attempt to repair
343 // the connection by rolling back higher levels.
344
345 // To make it easier on the user (that they don't have to really
346 // look at actual transaction depth and can just rely on the number
347 // of times they have called begin/commit/rollback) we still
348 // decrement here:
349 *transaction_depth = NonZeroU32::new(transaction_depth.get() - 1)
350 .expect("Depth was checked to be > 1");
351 *requires_rollback_maybe_up_to_top_level = true;
352 if requires_rollback_maybe_up_to_top_level_before_execute {
353 // In that case, we tolerate that savepoint releases fail
354 // -> we should ignore errors
355 return Ok(());
356 }
357 }
358 TransactionManagerStatus::Valid(ValidTransactionManagerStatus {
359 in_transaction: None,
360 ..
361 }) => {
362 // we would have returned `NotInTransaction` if that was already the state
363 // before we made our call
364 // => Transaction manager status has been fixed by the underlying connection
365 // so we don't need to set_in_error
366 }
367 _ => tm_status.set_in_error(),
368 }
369 Err(rollback_error)
370 }
371 }
372 }
373
374 /// If the transaction fails to commit due to a `SerializationFailure` or a
375 /// `ReadOnlyTransaction` a rollback will be attempted. If the rollback succeeds,
376 /// the original error will be returned, otherwise the error generated by the rollback
377 /// will be returned. In the second case the connection will be considered broken
378 /// as it contains a uncommitted unabortable open transaction.
379 async fn commit_transaction(conn: &mut Conn) -> QueryResult<()> {
380 let transaction_state = Self::get_transaction_state(conn)?;
381 let transaction_depth = transaction_state.transaction_depth();
382 let (commit_sql, committing_top_level) = match transaction_depth {
383 None => return Err(Error::NotInTransaction),
384 Some(transaction_depth) if transaction_depth.get() == 1 => {
385 (Cow::Borrowed("COMMIT"), true)
386 }
387 Some(transaction_depth) => (
388 Cow::Owned(format!(
389 "RELEASE SAVEPOINT diesel_savepoint_{}",
390 transaction_depth.get() - 1
391 )),
392 false,
393 ),
394 };
395 let depth = transaction_state
396 .transaction_depth()
397 .expect("We know that we are in a transaction here");
398 conn.instrumentation()
399 .on_connection_event(InstrumentationEvent::commit_transaction(depth));
400
401 let is_broken = {
402 let transaction_state = conn.transaction_state();
403 transaction_state.is_commit = true;
404 transaction_state.is_broken.clone()
405 };
406
407 let res =
408 Self::critical_transaction_block(&is_broken, conn.batch_execute(&commit_sql)).await;
409
410 conn.transaction_state().is_commit = false;
411
412 match res {
413 Ok(()) => {
414 match Self::get_transaction_state(conn)?
415 .change_transaction_depth(TransactionDepthChange::DecreaseDepth)
416 {
417 Ok(()) => {}
418 Err(Error::NotInTransaction) if committing_top_level => {
419 // Transaction exit may have already been detected by connection.
420 // It's fine
421 }
422 Err(e) => return Err(e),
423 }
424 Ok(())
425 }
426 Err(commit_error) => {
427 if let TransactionManagerStatus::Valid(ValidTransactionManagerStatus {
428 in_transaction:
429 Some(InTransactionStatus {
430 requires_rollback_maybe_up_to_top_level: true,
431 ..
432 }),
433 ..
434 }) = conn.transaction_state().status
435 {
436 // rollback_transaction handles the critical block internally on its own
437 match Self::rollback_transaction(conn).await {
438 Ok(()) => {}
439 Err(rollback_error) => {
440 conn.transaction_state().status.set_in_error();
441 return Err(Error::RollbackErrorOnCommit {
442 rollback_error: Box::new(rollback_error),
443 commit_error: Box::new(commit_error),
444 });
445 }
446 }
447 } else {
448 Self::get_transaction_state(conn)?
449 .change_transaction_depth(TransactionDepthChange::DecreaseDepth)?;
450 }
451 Err(commit_error)
452 }
453 }
454 }
455
456 fn transaction_manager_status_mut(conn: &mut Conn) -> &mut TransactionManagerStatus {
457 &mut conn.transaction_state().status
458 }
459
460 fn is_broken_transaction_manager(conn: &mut Conn) -> bool {
461 conn.transaction_state().is_broken.load(Ordering::Relaxed)
462 || check_broken_transaction_state(conn)
463 }
464}