diesel_async/
transaction_manager.rs1use diesel::connection::InstrumentationEvent;
2use diesel::connection::TransactionManagerStatus;
3use diesel::connection::{
4 InTransactionStatus, TransactionDepthChange, ValidTransactionManagerStatus,
5};
6use diesel::result::Error;
7use diesel::QueryResult;
8use scoped_futures::ScopedBoxFuture;
9use std::borrow::Cow;
10use std::future::Future;
11use std::num::NonZeroU32;
12use std::sync::atomic::{AtomicBool, Ordering};
13use std::sync::Arc;
14
15use crate::AsyncConnection;
16pub trait TransactionManager<Conn: AsyncConnection>: Send {
23 type TransactionStateData;
26
27 fn begin_transaction(conn: &mut Conn) -> impl Future<Output = QueryResult<()>> + Send;
33
34 fn rollback_transaction(conn: &mut Conn) -> impl Future<Output = QueryResult<()>> + Send;
40
41 fn commit_transaction(conn: &mut Conn) -> impl Future<Output = QueryResult<()>> + Send;
47
48 #[doc(hidden)]
54 fn transaction_manager_status_mut(conn: &mut Conn) -> &mut TransactionManagerStatus;
55
56 fn transaction<'a, 'conn, F, R, E>(
61 conn: &'conn mut Conn,
62 callback: F,
63 ) -> impl Future<Output = Result<R, E>> + Send + 'conn
64 where
65 F: for<'r> FnOnce(&'r mut Conn) -> ScopedBoxFuture<'a, 'r, Result<R, E>> + Send + 'a,
66 E: From<Error> + Send,
67 R: Send,
68 'a: 'conn,
69 {
70 async move {
71 let callback = callback;
72
73 Self::begin_transaction(conn).await?;
74 match callback(&mut *conn).await {
75 Ok(value) => {
76 Self::commit_transaction(conn).await?;
77 Ok(value)
78 }
79 Err(user_error) => match Self::rollback_transaction(conn).await {
80 Ok(()) => Err(user_error),
81 Err(Error::BrokenTransactionManager) => {
82 Err(user_error)
85 }
86 Err(rollback_error) => Err(rollback_error.into()),
87 },
88 }
89 }
90 }
91
92 #[doc(hidden)]
100 fn is_broken_transaction_manager(conn: &mut Conn) -> bool {
101 check_broken_transaction_state(conn)
102 }
103}
104
105fn check_broken_transaction_state<Conn>(conn: &mut Conn) -> bool
106where
107 Conn: AsyncConnection,
108{
109 match Conn::TransactionManager::transaction_manager_status_mut(conn).transaction_state() {
110 Ok(ValidTransactionManagerStatus {
113 in_transaction: None,
114 ..
115 }) => false,
116 Err(_) => true,
119 Ok(ValidTransactionManagerStatus {
123 in_transaction: Some(s),
124 ..
125 }) => !s.test_transaction,
126 }
127}
128
/// Transaction manager state for connections that are driven with the
/// `BEGIN` / `SAVEPOINT` / `ROLLBACK` / `RELEASE SAVEPOINT` / `COMMIT`
/// statements issued by the `TransactionManager` implementation in this file.
#[derive(Default, Debug)]
pub struct AnsiTransactionManager {
    // Current transaction status (validity, open transaction, depth). This is
    // the value handed out by `transaction_manager_status_mut`.
    pub(crate) status: TransactionManagerStatus,
    // Set to `true` by `critical_transaction_block` right before a
    // transaction-control statement is awaited and reset to `false` once that
    // await has finished. `is_broken_transaction_manager` reports the manager
    // as broken while the flag is still set, i.e. when such an await never ran
    // to completion.
    pub(crate) is_broken: Arc<AtomicBool>,
    // Set to `true` by `commit_transaction` while the commit statement is
    // executed and reset to `false` afterwards. It is not read anywhere in
    // this file.
    // NOTE(review): presumably consumed by connection implementations to
    // recognise errors raised during a commit -- confirm against the callers.
    pub(crate) is_commit: bool,
}
155
156impl AnsiTransactionManager {
157 fn get_transaction_state<Conn>(
158 conn: &mut Conn,
159 ) -> QueryResult<&mut ValidTransactionManagerStatus>
160 where
161 Conn: AsyncConnection<TransactionManager = Self>,
162 {
163 conn.transaction_state().status.transaction_state()
164 }
165
166 pub async fn begin_transaction_sql<Conn>(conn: &mut Conn, sql: &str) -> QueryResult<()>
172 where
173 Conn: AsyncConnection<TransactionManager = Self>,
174 {
175 let is_broken = conn.transaction_state().is_broken.clone();
176 let state = Self::get_transaction_state(conn)?;
177 if let Some(_depth) = state.transaction_depth() {
178 return Err(Error::AlreadyInTransaction);
179 }
180 let instrumentation_depth = NonZeroU32::new(1);
181
182 conn.instrumentation()
183 .on_connection_event(InstrumentationEvent::begin_transaction(
184 instrumentation_depth.expect("We know that 1 is not zero"),
185 ));
186
187 Self::critical_transaction_block(&is_broken, conn.batch_execute(sql)).await?;
189 Self::get_transaction_state(conn)?
190 .change_transaction_depth(TransactionDepthChange::IncreaseDepth)?;
191 Ok(())
192 }
193
194 async fn critical_transaction_block<F>(is_broken: &AtomicBool, f: F) -> F::Output
201 where
202 F: std::future::Future,
203 {
204 let was_broken = is_broken.swap(true, Ordering::Relaxed);
205 debug_assert!(
206 !was_broken,
207 "Tried to execute a transaction SQL on transaction manager that was previously cancled"
208 );
209 let res = f.await;
210 is_broken.store(false, Ordering::Relaxed);
211 res
212 }
213}
214
215impl<Conn> TransactionManager<Conn> for AnsiTransactionManager
216where
217 Conn: AsyncConnection<TransactionManager = Self>,
218{
219 type TransactionStateData = Self;
220
221 async fn begin_transaction(conn: &mut Conn) -> QueryResult<()> {
222 let transaction_state = Self::get_transaction_state(conn)?;
223 let start_transaction_sql = match transaction_state.transaction_depth() {
224 None => Cow::from("BEGIN"),
225 Some(transaction_depth) => {
226 Cow::from(format!("SAVEPOINT diesel_savepoint_{transaction_depth}"))
227 }
228 };
229 let depth = transaction_state
230 .transaction_depth()
231 .and_then(|d| d.checked_add(1))
232 .unwrap_or(NonZeroU32::new(1).expect("It's not 0"));
233 conn.instrumentation()
234 .on_connection_event(InstrumentationEvent::begin_transaction(depth));
235 Self::critical_transaction_block(
236 &conn.transaction_state().is_broken.clone(),
237 conn.batch_execute(&start_transaction_sql),
238 )
239 .await?;
240 Self::get_transaction_state(conn)?
241 .change_transaction_depth(TransactionDepthChange::IncreaseDepth)?;
242
243 Ok(())
244 }
245
246 async fn rollback_transaction(conn: &mut Conn) -> QueryResult<()> {
247 let transaction_state = Self::get_transaction_state(conn)?;
248
249 let (
250 (rollback_sql, rolling_back_top_level),
251 requires_rollback_maybe_up_to_top_level_before_execute,
252 ) = match transaction_state.in_transaction {
253 Some(ref in_transaction) => (
254 match in_transaction.transaction_depth.get() {
255 1 => (Cow::Borrowed("ROLLBACK"), true),
256 depth_gt1 => (
257 Cow::Owned(format!(
258 "ROLLBACK TO SAVEPOINT diesel_savepoint_{}",
259 depth_gt1 - 1
260 )),
261 false,
262 ),
263 },
264 in_transaction.requires_rollback_maybe_up_to_top_level,
265 ),
266 None => return Err(Error::NotInTransaction),
267 };
268
269 let depth = transaction_state
270 .transaction_depth()
271 .expect("We know that we are in a transaction here");
272 conn.instrumentation()
273 .on_connection_event(InstrumentationEvent::rollback_transaction(depth));
274
275 let is_broken = conn.transaction_state().is_broken.clone();
276
277 match Self::critical_transaction_block(&is_broken, conn.batch_execute(&rollback_sql)).await
278 {
279 Ok(()) => {
280 match Self::get_transaction_state(conn)?
281 .change_transaction_depth(TransactionDepthChange::DecreaseDepth)
282 {
283 Ok(()) => {}
284 Err(Error::NotInTransaction) if rolling_back_top_level => {
285 }
288 Err(e) => return Err(e),
289 }
290 Ok(())
291 }
292 Err(rollback_error) => {
293 let tm_status = Self::transaction_manager_status_mut(conn);
294 match tm_status {
295 TransactionManagerStatus::Valid(ValidTransactionManagerStatus {
296 in_transaction:
297 Some(InTransactionStatus {
298 transaction_depth,
299 requires_rollback_maybe_up_to_top_level,
300 ..
301 }),
302 ..
303 }) if transaction_depth.get() > 1 => {
304 *transaction_depth = NonZeroU32::new(transaction_depth.get() - 1)
312 .expect("Depth was checked to be > 1");
313 *requires_rollback_maybe_up_to_top_level = true;
314 if requires_rollback_maybe_up_to_top_level_before_execute {
315 return Ok(());
318 }
319 }
320 TransactionManagerStatus::Valid(ValidTransactionManagerStatus {
321 in_transaction: None,
322 ..
323 }) => {
324 }
329 _ => tm_status.set_in_error(),
330 }
331 Err(rollback_error)
332 }
333 }
334 }
335
336 async fn commit_transaction(conn: &mut Conn) -> QueryResult<()> {
342 let transaction_state = Self::get_transaction_state(conn)?;
343 let transaction_depth = transaction_state.transaction_depth();
344 let (commit_sql, committing_top_level) = match transaction_depth {
345 None => return Err(Error::NotInTransaction),
346 Some(transaction_depth) if transaction_depth.get() == 1 => {
347 (Cow::Borrowed("COMMIT"), true)
348 }
349 Some(transaction_depth) => (
350 Cow::Owned(format!(
351 "RELEASE SAVEPOINT diesel_savepoint_{}",
352 transaction_depth.get() - 1
353 )),
354 false,
355 ),
356 };
357 let depth = transaction_state
358 .transaction_depth()
359 .expect("We know that we are in a transaction here");
360 conn.instrumentation()
361 .on_connection_event(InstrumentationEvent::commit_transaction(depth));
362
363 let is_broken = {
364 let transaction_state = conn.transaction_state();
365 transaction_state.is_commit = true;
366 transaction_state.is_broken.clone()
367 };
368
369 let res =
370 Self::critical_transaction_block(&is_broken, conn.batch_execute(&commit_sql)).await;
371
372 conn.transaction_state().is_commit = false;
373
374 match res {
375 Ok(()) => {
376 match Self::get_transaction_state(conn)?
377 .change_transaction_depth(TransactionDepthChange::DecreaseDepth)
378 {
379 Ok(()) => {}
380 Err(Error::NotInTransaction) if committing_top_level => {
381 }
384 Err(e) => return Err(e),
385 }
386 Ok(())
387 }
388 Err(commit_error) => {
389 if let TransactionManagerStatus::Valid(ValidTransactionManagerStatus {
390 in_transaction:
391 Some(InTransactionStatus {
392 requires_rollback_maybe_up_to_top_level: true,
393 ..
394 }),
395 ..
396 }) = conn.transaction_state().status
397 {
398 match Self::rollback_transaction(conn).await {
400 Ok(()) => {}
401 Err(rollback_error) => {
402 conn.transaction_state().status.set_in_error();
403 return Err(Error::RollbackErrorOnCommit {
404 rollback_error: Box::new(rollback_error),
405 commit_error: Box::new(commit_error),
406 });
407 }
408 }
409 } else {
410 Self::get_transaction_state(conn)?
411 .change_transaction_depth(TransactionDepthChange::DecreaseDepth)?;
412 }
413 Err(commit_error)
414 }
415 }
416 }
417
418 fn transaction_manager_status_mut(conn: &mut Conn) -> &mut TransactionManagerStatus {
419 &mut conn.transaction_state().status
420 }
421
422 fn is_broken_transaction_manager(conn: &mut Conn) -> bool {
423 conn.transaction_state().is_broken.load(Ordering::Relaxed)
424 || check_broken_transaction_state(conn)
425 }
426}