1#[cfg(feature = "postgres")]
39use std::future::Future;
40use std::{error::Error as StdError, fmt};
41
/// Entry point for running work inside a database transaction on a borrowed pool.
///
/// `DB` is the pool type. This type only records the pool; backend-specific
/// operations live in dedicated `impl` blocks (e.g. the PostgreSQL one that is
/// compiled behind the `postgres` feature).
pub struct Transaction<'p, DB> {
    /// Pool the transaction will be started from.
    pool: &'p DB
}

impl<'p, DB> Transaction<'p, DB> {
    /// Creates a transaction builder that borrows `pool` for `'p`.
    ///
    /// No transaction is started here; the pool reference is merely stored.
    #[must_use]
    pub const fn new(pool: &'p DB) -> Self {
        Self {
            pool
        }
    }

    /// Returns the borrowed pool.
    ///
    /// The reference carries the original lifetime `'p` (not the lifetime of
    /// `&self`), so it may outlive this builder.
    #[must_use]
    pub const fn pool(&self) -> &'p DB {
        self.pool
    }
}
91
#[cfg(feature = "postgres")]
/// Handle to an open PostgreSQL transaction.
///
/// Instances are handed to the closures passed to `Transaction::run`
/// (borrowed mutably) and `Transaction::run_with_commit` (by value).
pub struct TransactionContext {
    /// The wrapped sqlx transaction.
    inner: sqlx::Transaction<'static, sqlx::Postgres>
}

#[cfg(feature = "postgres")]
impl TransactionContext {
    /// Wraps an already-begun sqlx transaction.
    #[doc(hidden)]
    #[must_use]
    pub const fn new(tx: sqlx::Transaction<'static, sqlx::Postgres>) -> Self {
        Self {
            inner: tx
        }
    }

    /// Returns a mutable reference to the wrapped sqlx transaction.
    pub const fn transaction(&mut self) -> &mut sqlx::Transaction<'static, sqlx::Postgres> {
        &mut self.inner
    }

    /// Commits the transaction, consuming the context.
    ///
    /// # Errors
    /// Returns the error reported by `sqlx::Transaction::commit`.
    pub async fn commit(self) -> Result<(), sqlx::Error> {
        let Self { inner } = self;
        inner.commit().await
    }

    /// Rolls the transaction back, consuming the context.
    ///
    /// # Errors
    /// Returns the error reported by `sqlx::Transaction::rollback`.
    pub async fn rollback(self) -> Result<(), sqlx::Error> {
        let Self { inner } = self;
        inner.rollback().await
    }
}
163
/// Error produced while driving a database transaction, tagged with the phase
/// that failed.
///
/// Every variant wraps the same underlying error type `E`. Use the `is_*`
/// predicates to inspect the phase and [`TransactionError::into_inner`] to
/// recover the wrapped error.
///
/// Note: the [`Display`](fmt::Display) output already embeds the wrapped error's
/// message, *and* [`source`](StdError::source) returns that same error. Reporters
/// that walk the source chain may therefore print the inner message twice.
#[derive(Debug)]
pub enum TransactionError<E> {
    /// Beginning the transaction failed.
    Begin(E),

    /// Committing the transaction failed.
    Commit(E),

    /// Rolling back the transaction failed.
    Rollback(E),

    /// The work performed inside the transaction failed.
    Operation(E)
}

impl<E: fmt::Display> fmt::Display for TransactionError<E> {
    /// Formats as `"<phase message>: <inner error>"`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Begin(e) => write!(f, "failed to begin transaction: {e}"),
            Self::Commit(e) => write!(f, "failed to commit transaction: {e}"),
            Self::Rollback(e) => write!(f, "failed to rollback transaction: {e}"),
            Self::Operation(e) => write!(f, "transaction operation failed: {e}")
        }
    }
}

impl<E: StdError + 'static> StdError for TransactionError<E> {
    /// Always returns the wrapped error, whatever the phase.
    fn source(&self) -> Option<&(dyn StdError + 'static)> {
        match self {
            Self::Begin(e) | Self::Commit(e) | Self::Rollback(e) | Self::Operation(e) => Some(e)
        }
    }
}

impl<E> TransactionError<E> {
    /// Returns `true` if beginning the transaction failed.
    #[must_use]
    pub const fn is_begin(&self) -> bool {
        matches!(self, Self::Begin(_))
    }

    /// Returns `true` if committing the transaction failed.
    #[must_use]
    pub const fn is_commit(&self) -> bool {
        matches!(self, Self::Commit(_))
    }

    /// Returns `true` if rolling back the transaction failed.
    #[must_use]
    pub const fn is_rollback(&self) -> bool {
        matches!(self, Self::Rollback(_))
    }

    /// Returns `true` if the work inside the transaction failed.
    #[must_use]
    pub const fn is_operation(&self) -> bool {
        matches!(self, Self::Operation(_))
    }

    /// Discards the phase tag and returns the wrapped error.
    #[must_use]
    pub fn into_inner(self) -> E {
        match self {
            Self::Begin(e) | Self::Commit(e) | Self::Rollback(e) | Self::Operation(e) => e
        }
    }
}
229
#[cfg(feature = "postgres")]
impl From<TransactionError<Self>> for sqlx::Error {
    /// Collapses a [`TransactionError`] into the wrapped `sqlx::Error`.
    ///
    /// The phase tag (begin/commit/rollback/operation) is discarded.
    fn from(error: TransactionError<Self>) -> Self {
        TransactionError::into_inner(error)
    }
}
236
#[cfg(any(feature = "postgres", test))]
/// Applies the commit-on-success policy to the outcome of a transaction body.
///
/// * `result` is `Ok(value)`: `commit_fn(ctx)` is awaited. A commit failure is
///   converted into `E` via `From` and returned; otherwise `Ok(value)` is returned.
/// * `result` is `Err(e)`: `commit_fn` is never called, `ctx` is simply dropped and
///   `Err(e)` is returned unchanged, so a commit error can never mask the
///   original failure.
async fn finalize_with_commit<C, T, E, CommitErr, Cf, Fut>(
    ctx: C,
    result: Result<T, E>,
    commit_fn: Cf
) -> Result<T, E>
where
    Cf: FnOnce(C) -> Fut,
    Fut: core::future::Future<Output = Result<(), CommitErr>>,
    E: From<CommitErr>
{
    // Early-return the body's own error before any commit is attempted.
    let value = result?;
    // `?` performs the `CommitErr -> E` conversion through `From`.
    commit_fn(ctx).await?;
    Ok(value)
}
270
#[cfg(feature = "postgres")]
impl Transaction<'_, sqlx::PgPool> {
    /// Begins a transaction, runs `f` against it, and commits if `f` succeeds.
    ///
    /// `f` only borrows the [`TransactionContext`], so it cannot call `commit` or
    /// `rollback` on it; this method decides the outcome.
    ///
    /// * `f` returns `Ok`: the transaction is committed and the value is returned.
    /// * `f` returns `Err`: no commit is attempted. The context is dropped and the
    ///   closure's error is returned unchanged. sqlx documents that an unfinished
    ///   transaction is rolled back when dropped.
    ///
    /// With the `tracing` feature enabled, the call runs inside a span tagged
    /// `op = "tx.run"` and errors are recorded in their `Debug` form.
    ///
    /// # Errors
    /// Returns `E::from(sqlx::Error)` if beginning or committing fails, or the
    /// error produced by `f`.
    #[cfg_attr(
        feature = "tracing",
        ::tracing::instrument(skip_all, fields(op = "tx.run"), err(Debug))
    )]
    pub async fn run<F, T, E>(self, f: F) -> Result<T, E>
    where
        F: AsyncFnOnce(&mut TransactionContext) -> Result<T, E>,
        E: From<sqlx::Error> + core::fmt::Debug
    {
        let mut ctx = TransactionContext::new(self.pool.begin().await?);
        let outcome = f(&mut ctx).await;
        finalize_with_commit(ctx, outcome, TransactionContext::commit).await
    }

    /// Begins a transaction and hands ownership of its [`TransactionContext`] to `f`.
    ///
    /// Nothing is committed automatically. `f` is responsible for calling
    /// [`TransactionContext::commit`] or [`TransactionContext::rollback`], and its
    /// result is returned as-is.
    ///
    /// With the `tracing` feature enabled, the call runs inside a span tagged
    /// `op = "tx.run_with_commit"` and errors are recorded in their `Debug` form.
    ///
    /// # Errors
    /// Returns `E::from(sqlx::Error)` if beginning the transaction fails; otherwise
    /// whatever `f` returns.
    #[cfg_attr(
        feature = "tracing",
        ::tracing::instrument(skip_all, fields(op = "tx.run_with_commit"), err(Debug))
    )]
    pub async fn run_with_commit<F, Fut, T, E>(self, f: F) -> Result<T, E>
    where
        F: FnOnce(TransactionContext) -> Fut + Send,
        Fut: Future<Output = Result<T, E>> + Send,
        E: From<sqlx::Error> + core::fmt::Debug
    {
        let tx = self.pool.begin().await?;
        f(TransactionContext::new(tx)).await
    }
}
399
#[cfg(test)]
mod tests {
    //! Unit tests for `TransactionError`, the `Transaction` builder and the
    //! commit-on-success policy implemented by `finalize_with_commit`.

    use std::error::Error;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;

    use super::*;

    #[test]
    fn transaction_error_display_begin() {
        let err: TransactionError<std::io::Error> =
            TransactionError::Begin(std::io::Error::other("test"));
        assert!(err.to_string().contains("begin"));
        assert!(err.to_string().contains("test"));
    }

    #[test]
    fn transaction_error_display_commit() {
        let err: TransactionError<std::io::Error> =
            TransactionError::Commit(std::io::Error::other("test"));
        assert!(err.to_string().contains("commit"));
    }

    #[test]
    fn transaction_error_display_rollback() {
        let err: TransactionError<std::io::Error> =
            TransactionError::Rollback(std::io::Error::other("test"));
        assert!(err.to_string().contains("rollback"));
    }

    #[test]
    fn transaction_error_display_operation() {
        let err: TransactionError<std::io::Error> =
            TransactionError::Operation(std::io::Error::other("test"));
        assert!(err.to_string().contains("operation"));
    }

    #[test]
    fn transaction_error_is_methods() {
        let begin: TransactionError<&str> = TransactionError::Begin("e");
        let commit: TransactionError<&str> = TransactionError::Commit("e");
        let rollback: TransactionError<&str> = TransactionError::Rollback("e");
        let operation: TransactionError<&str> = TransactionError::Operation("e");

        assert!(begin.is_begin());
        assert!(!begin.is_commit());
        assert!(!begin.is_rollback());
        assert!(!begin.is_operation());

        assert!(!commit.is_begin());
        assert!(commit.is_commit());
        assert!(!commit.is_rollback());
        assert!(!commit.is_operation());

        assert!(!rollback.is_begin());
        assert!(!rollback.is_commit());
        assert!(rollback.is_rollback());
        assert!(!rollback.is_operation());

        assert!(!operation.is_begin());
        assert!(!operation.is_commit());
        assert!(!operation.is_rollback());
        assert!(operation.is_operation());
    }

    #[test]
    fn transaction_error_into_inner() {
        let err: TransactionError<&str> = TransactionError::Operation("test");
        assert_eq!(err.into_inner(), "test");
    }

    #[test]
    fn transaction_error_into_inner_begin() {
        let err: TransactionError<&str> = TransactionError::Begin("begin_err");
        assert_eq!(err.into_inner(), "begin_err");
    }

    #[test]
    fn transaction_error_into_inner_commit() {
        let err: TransactionError<&str> = TransactionError::Commit("commit_err");
        assert_eq!(err.into_inner(), "commit_err");
    }

    #[test]
    fn transaction_error_into_inner_rollback() {
        let err: TransactionError<&str> = TransactionError::Rollback("rollback_err");
        assert_eq!(err.into_inner(), "rollback_err");
    }

    #[test]
    fn transaction_error_source_begin() {
        let err: TransactionError<std::io::Error> =
            TransactionError::Begin(std::io::Error::other("src"));
        assert!(err.source().is_some());
    }

    #[test]
    fn transaction_error_source_commit() {
        let err: TransactionError<std::io::Error> =
            TransactionError::Commit(std::io::Error::other("src"));
        assert!(err.source().is_some());
    }

    #[test]
    fn transaction_error_source_rollback() {
        let err: TransactionError<std::io::Error> =
            TransactionError::Rollback(std::io::Error::other("src"));
        assert!(err.source().is_some());
    }

    #[test]
    fn transaction_error_source_operation() {
        let err: TransactionError<std::io::Error> =
            TransactionError::Operation(std::io::Error::other("src"));
        assert!(err.source().is_some());
    }

    #[test]
    fn transaction_builder_new() {
        struct MockPool;
        let pool = MockPool;
        let tx = Transaction::new(&pool);
        let _ = tx.pool();
    }

    #[test]
    fn transaction_builder_pool_accessor() {
        struct MockPool {
            id: u32
        }
        let pool = MockPool {
            id: 42
        };
        let tx = Transaction::new(&pool);
        assert_eq!(tx.pool().id, 42);
    }

    #[test]
    fn transaction_error_debug() {
        let err: TransactionError<&str> = TransactionError::Begin("test");
        let debug_str = format!("{err:?}");
        assert!(debug_str.contains("Begin"));
        assert!(debug_str.contains("test"));
    }

    #[test]
    fn transaction_error_into_inner_all_variants() {
        let begin: TransactionError<String> = TransactionError::Begin("begin".to_string());
        let commit: TransactionError<String> = TransactionError::Commit("commit".to_string());
        let rollback: TransactionError<String> =
            TransactionError::Rollback("rollback".to_string());
        let operation: TransactionError<String> = TransactionError::Operation("op".to_string());

        assert_eq!(begin.into_inner(), "begin");
        assert_eq!(commit.into_inner(), "commit");
        assert_eq!(rollback.into_inner(), "rollback");
        assert_eq!(operation.into_inner(), "op");
    }

    #[test]
    fn transaction_error_source_all_variants() {
        let begin: TransactionError<std::io::Error> =
            TransactionError::Begin(std::io::Error::other("src"));
        let commit: TransactionError<std::io::Error> =
            TransactionError::Commit(std::io::Error::other("src"));
        let rollback: TransactionError<std::io::Error> =
            TransactionError::Rollback(std::io::Error::other("src"));
        let operation: TransactionError<std::io::Error> =
            TransactionError::Operation(std::io::Error::other("src"));

        assert!(begin.source().is_some());
        assert!(commit.source().is_some());
        assert!(rollback.source().is_some());
        assert!(operation.source().is_some());
    }

    #[test]
    fn transaction_error_display_all_variants() {
        let begin: TransactionError<std::io::Error> =
            TransactionError::Begin(std::io::Error::other("msg"));
        let commit: TransactionError<std::io::Error> =
            TransactionError::Commit(std::io::Error::other("msg"));
        let rollback: TransactionError<std::io::Error> =
            TransactionError::Rollback(std::io::Error::other("msg"));
        let operation: TransactionError<std::io::Error> =
            TransactionError::Operation(std::io::Error::other("msg"));

        let begin_str = begin.to_string();
        let commit_str = commit.to_string();
        let rollback_str = rollback.to_string();
        let operation_str = operation.to_string();

        assert!(begin_str.contains("begin"));
        assert!(commit_str.contains("commit"));
        assert!(rollback_str.contains("rollback"));
        assert!(operation_str.contains("operation"));
    }

    #[test]
    fn transaction_error_is_all_variants() {
        let begin: TransactionError<&str> = TransactionError::Begin("e");
        let commit: TransactionError<&str> = TransactionError::Commit("e");
        let rollback: TransactionError<&str> = TransactionError::Rollback("e");
        let operation: TransactionError<&str> = TransactionError::Operation("e");

        assert!(begin.is_begin());
        assert!(commit.is_commit());
        assert!(rollback.is_rollback());
        assert!(operation.is_operation());

        assert!(!begin.is_commit());
        assert!(!begin.is_rollback());
        assert!(!begin.is_operation());

        assert!(!commit.is_begin());
        assert!(!commit.is_rollback());
        assert!(!commit.is_operation());

        assert!(!rollback.is_begin());
        assert!(!rollback.is_commit());
        assert!(!rollback.is_operation());

        assert!(!operation.is_begin());
        assert!(!operation.is_commit());
        assert!(!operation.is_rollback());
    }

    #[test]
    fn transaction_builder_new_const() {
        struct MockPool;
        // Both items are evaluated at compile time, so this only builds if
        // `Transaction::new` and `Transaction::pool` are usable in const contexts.
        const TX: Transaction<'static, MockPool> = Transaction::new(&MockPool);
        const POOL: &MockPool = TX.pool();
        let _ = POOL;
    }

    /// Stand-in for a transaction context; `finalize_with_commit` only moves it
    /// into the commit closure.
    #[derive(Debug, PartialEq, Eq)]
    struct MockCtx;

    /// Error type returned by the mock commit closures.
    #[derive(Debug, PartialEq, Eq)]
    struct CommitErr(&'static str);

    /// Application-level error, distinguishing closure failures from commit failures.
    #[derive(Debug, PartialEq, Eq)]
    enum AppErr {
        Closure(&'static str),
        Commit(&'static str)
    }

    impl From<CommitErr> for AppErr {
        fn from(e: CommitErr) -> Self {
            Self::Commit(e.0)
        }
    }

    #[tokio::test]
    async fn finalize_commits_on_ok() {
        let committed = Arc::new(AtomicBool::new(false));
        let flag = Arc::clone(&committed);

        let result: Result<i32, AppErr> = finalize_with_commit::<_, _, _, CommitErr, _, _>(
            MockCtx,
            Ok::<i32, AppErr>(42),
            // The closure is `FnOnce`, so `flag` can be moved straight into the future.
            move |_ctx| async move {
                flag.store(true, Ordering::SeqCst);
                Ok::<(), CommitErr>(())
            }
        )
        .await;

        assert_eq!(result, Ok(42));
        assert!(committed.load(Ordering::SeqCst), "commit_fn must run on Ok");
    }

    #[tokio::test]
    async fn finalize_skips_commit_on_err() {
        let committed = Arc::new(AtomicBool::new(false));
        let flag = Arc::clone(&committed);

        let result: Result<i32, AppErr> = finalize_with_commit::<_, _, _, CommitErr, _, _>(
            MockCtx,
            Err::<i32, AppErr>(AppErr::Closure("nope")),
            move |_ctx| async move {
                flag.store(true, Ordering::SeqCst);
                Ok::<(), CommitErr>(())
            }
        )
        .await;

        assert_eq!(result, Err(AppErr::Closure("nope")));
        assert!(!committed.load(Ordering::SeqCst), "commit_fn must NOT run on Err");
    }

    #[tokio::test]
    async fn finalize_propagates_commit_error_on_ok() {
        let result: Result<i32, AppErr> = finalize_with_commit::<_, _, _, CommitErr, _, _>(
            MockCtx,
            Ok::<i32, AppErr>(42),
            |_ctx| async { Err::<(), CommitErr>(CommitErr("commit failed")) }
        )
        .await;

        assert_eq!(result, Err(AppErr::Commit("commit failed")));
    }

    #[tokio::test]
    async fn finalize_preserves_closure_value_on_ok() {
        let result: Result<String, AppErr> = finalize_with_commit::<_, _, _, CommitErr, _, _>(
            MockCtx,
            Ok::<String, AppErr>("payload".to_string()),
            |_ctx| async { Ok::<(), CommitErr>(()) }
        )
        .await;

        assert_eq!(result, Ok("payload".to_string()));
    }

    #[tokio::test]
    async fn finalize_does_not_swallow_closure_error_when_commit_also_would_fail() {
        let result: Result<(), AppErr> = finalize_with_commit::<_, _, _, CommitErr, _, _>(
            MockCtx,
            Err::<(), AppErr>(AppErr::Closure("original")),
            |_ctx| async { Err::<(), CommitErr>(CommitErr("never reached")) }
        )
        .await;

        assert_eq!(result, Err(AppErr::Closure("original")));
    }
}