1#[cfg(feature = "postgres")]
39use std::future::Future;
40use std::{error::Error as StdError, fmt};
41
/// A lightweight handle that borrows a connection pool for running transactions.
///
/// The handle stores nothing but a shared reference to the pool. `DB` is
/// unconstrained here, so any type can be borrowed; database-specific runner
/// methods are provided by separate, feature-gated `impl` blocks.
pub struct Transaction<'p, DB> {
    /// The borrowed pool, valid for the lifetime `'p`.
    pool: &'p DB
}

impl<'p, DB> Transaction<'p, DB> {
    /// Creates a handle that borrows `pool`.
    ///
    /// This is a `const fn`, so a handle can also be built in constant contexts.
    /// Marked `#[must_use]` because a constructed-but-unused handle does nothing.
    #[must_use]
    pub const fn new(pool: &'p DB) -> Self {
        Self {
            pool
        }
    }

    /// Returns the borrowed pool with its original lifetime `'p` (not tied to
    /// the borrow of `self`).
    #[must_use]
    pub const fn pool(&self) -> &'p DB {
        self.pool
    }
}
91
/// An open PostgreSQL transaction handed to the work closures of the
/// `Transaction` runners.
///
/// It owns a `sqlx::Transaction<'static, sqlx::Postgres>`. Use
/// [`TransactionContext::transaction`] to reach it, and finish it with
/// [`TransactionContext::commit`] or [`TransactionContext::rollback`]. If the
/// context is dropped unfinished, cleanup is left to sqlx's `Drop` handling of
/// `Transaction` (a rollback, per the sqlx documentation).
#[cfg(feature = "postgres")]
pub struct TransactionContext {
    /// The owned, still-open sqlx transaction.
    tx: sqlx::Transaction<'static, sqlx::Postgres>
}

#[cfg(feature = "postgres")]
impl TransactionContext {
    /// Wraps a transaction that has already been started.
    ///
    /// Marked `#[doc(hidden)]`, so it is not part of the documented API.
    #[doc(hidden)]
    #[must_use]
    pub const fn new(tx: sqlx::Transaction<'static, sqlx::Postgres>) -> Self {
        Self {
            tx
        }
    }

    /// Borrows the underlying sqlx transaction mutably, for example to run
    /// queries against it.
    pub const fn transaction(&mut self) -> &mut sqlx::Transaction<'static, sqlx::Postgres> {
        &mut self.tx
    }

    /// Commits the transaction, consuming the context.
    ///
    /// # Errors
    ///
    /// Returns the `sqlx::Error` produced by `sqlx::Transaction::commit`.
    pub async fn commit(self) -> Result<(), sqlx::Error> {
        let Self { tx } = self;
        tx.commit().await
    }

    /// Rolls the transaction back, consuming the context.
    ///
    /// # Errors
    ///
    /// Returns the `sqlx::Error` produced by `sqlx::Transaction::rollback`.
    pub async fn rollback(self) -> Result<(), sqlx::Error> {
        let Self { tx } = self;
        tx.rollback().await
    }
}
163
/// An error tagged with the transaction stage at which it occurred.
///
/// Every variant wraps exactly one inner error of type `E`.
#[derive(Debug)]
pub enum TransactionError<E> {
    /// Starting the transaction failed.
    Begin(E),

    /// Committing the transaction failed.
    Commit(E),

    /// Rolling the transaction back failed.
    Rollback(E),

    /// The work performed inside the transaction failed.
    Operation(E)
}
181
182impl<E: fmt::Display> fmt::Display for TransactionError<E> {
183 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
184 match self {
185 Self::Begin(e) => write!(f, "failed to begin transaction: {e}"),
186 Self::Commit(e) => write!(f, "failed to commit transaction: {e}"),
187 Self::Rollback(e) => write!(f, "failed to rollback transaction: {e}"),
188 Self::Operation(e) => write!(f, "transaction operation failed: {e}")
189 }
190 }
191}
192
193impl<E: StdError + 'static> StdError for TransactionError<E> {
194 fn source(&self) -> Option<&(dyn StdError + 'static)> {
195 match self {
196 Self::Begin(e) | Self::Commit(e) | Self::Rollback(e) | Self::Operation(e) => Some(e)
197 }
198 }
199}
200
201impl<E> TransactionError<E> {
202 pub const fn is_begin(&self) -> bool {
204 matches!(self, Self::Begin(_))
205 }
206
207 pub const fn is_commit(&self) -> bool {
209 matches!(self, Self::Commit(_))
210 }
211
212 pub const fn is_rollback(&self) -> bool {
214 matches!(self, Self::Rollback(_))
215 }
216
217 pub const fn is_operation(&self) -> bool {
219 matches!(self, Self::Operation(_))
220 }
221
222 pub fn into_inner(self) -> E {
224 match self {
225 Self::Begin(e) | Self::Commit(e) | Self::Rollback(e) | Self::Operation(e) => e
226 }
227 }
228}
229
/// Collapses a stage-tagged sqlx error back into a plain `sqlx::Error`.
///
/// The stage (begin/commit/rollback/operation) is discarded; only the wrapped
/// `sqlx::Error` is kept.
#[cfg(feature = "postgres")]
impl From<TransactionError<Self>> for sqlx::Error {
    fn from(err: TransactionError<Self>) -> Self {
        match err {
            TransactionError::Begin(inner)
            | TransactionError::Commit(inner)
            | TransactionError::Rollback(inner)
            | TransactionError::Operation(inner) => inner
        }
    }
}
236
/// Finishes a unit of work by committing only when it succeeded.
///
/// - If `result` is `Err`, it is returned unchanged. `commit_fn` is never
///   called, and `ctx` is simply dropped when this function returns.
/// - If `result` is `Ok`, `commit_fn(ctx)` is awaited. A commit failure is
///   converted into `E` via `From`. Otherwise the original value is returned.
///
/// Generic over the context and commit function so it can be unit-tested
/// without a database.
#[cfg(any(feature = "postgres", test))]
async fn finalize_with_commit<C, T, E, CommitErr, Cf, Fut>(
    ctx: C,
    result: Result<T, E>,
    commit_fn: Cf
) -> Result<T, E>
where
    Cf: FnOnce(C) -> Fut,
    Fut: core::future::Future<Output = Result<(), CommitErr>>,
    E: From<CommitErr>
{
    // The closure's own error wins: bail out before touching `commit_fn`.
    let value = result?;
    commit_fn(ctx).await?;
    Ok(value)
}
270
/// PostgreSQL transaction runners for a `Transaction` borrowing an `sqlx::PgPool`.
#[cfg(feature = "postgres")]
impl Transaction<'_, sqlx::PgPool> {
    /// Runs `f` inside a freshly started transaction and commits if it succeeds.
    ///
    /// `f` receives a mutable [`TransactionContext`]:
    /// - If `f` returns `Ok`, the transaction is committed and the value is returned.
    /// - If `f` returns `Err`, no commit is attempted and the error is returned
    ///   unchanged. The context is dropped without an explicit rollback call.
    ///
    /// # Errors
    ///
    /// - `E::from(sqlx::Error)` if the transaction cannot be started or committed.
    /// - Any error returned by `f`.
    pub async fn run<F, T, E>(self, f: F) -> Result<T, E>
    where
        F: AsyncFnOnce(&mut TransactionContext) -> Result<T, E>,
        E: From<sqlx::Error>
    {
        let mut ctx = TransactionContext::new(self.pool.begin().await?);
        let outcome = f(&mut ctx).await;
        finalize_with_commit(ctx, outcome, TransactionContext::commit).await
    }

    /// Starts a transaction and hands ownership of its context to `f`.
    ///
    /// This method never commits. `f` is responsible for calling
    /// `commit`/`rollback` on the context, and its result is returned as-is.
    ///
    /// # Errors
    ///
    /// - `E::from(sqlx::Error)` if the transaction cannot be started.
    /// - Any error returned by the future produced by `f`.
    pub async fn run_with_commit<F, Fut, T, E>(self, f: F) -> Result<T, E>
    where
        F: FnOnce(TransactionContext) -> Fut + Send,
        Fut: Future<Output = Result<T, E>> + Send,
        E: From<sqlx::Error>
    {
        let started = self.pool.begin().await?;
        f(TransactionContext::new(started)).await
    }
}
391
#[cfg(test)]
mod tests {
    use std::error::Error;
    use std::sync::atomic::{AtomicBool, Ordering};

    use super::*;

    /// Builds one `TransactionError` per variant, in declaration order
    /// (Begin, Commit, Rollback, Operation). Each wraps `io::Error::other(msg)`.
    fn io_error_of_each_variant(msg: &str) -> [TransactionError<std::io::Error>; 4] {
        [
            TransactionError::Begin(std::io::Error::other(msg.to_owned())),
            TransactionError::Commit(std::io::Error::other(msg.to_owned())),
            TransactionError::Rollback(std::io::Error::other(msg.to_owned())),
            TransactionError::Operation(std::io::Error::other(msg.to_owned()))
        ]
    }

    #[test]
    fn transaction_error_display_messages() {
        let rendered: Vec<String> =
            io_error_of_each_variant("boom").iter().map(ToString::to_string).collect();
        assert_eq!(
            rendered,
            [
                "failed to begin transaction: boom",
                "failed to commit transaction: boom",
                "failed to rollback transaction: boom",
                "transaction operation failed: boom"
            ]
        );
    }

    #[test]
    fn transaction_error_source_is_wrapped_error() {
        for err in io_error_of_each_variant("src") {
            let source = err.source().expect("every variant exposes its inner error");
            let inner = source
                .downcast_ref::<std::io::Error>()
                .expect("source should be the wrapped io::Error");
            assert_eq!(inner.to_string(), "src");
        }
    }

    #[test]
    fn transaction_error_is_methods() {
        // Exactly one predicate must hold for each variant.
        let cases = [
            (TransactionError::Begin("e"), [true, false, false, false]),
            (TransactionError::Commit("e"), [false, true, false, false]),
            (TransactionError::Rollback("e"), [false, false, true, false]),
            (TransactionError::Operation("e"), [false, false, false, true])
        ];
        for (err, expected) in cases {
            let actual = [err.is_begin(), err.is_commit(), err.is_rollback(), err.is_operation()];
            assert_eq!(actual, expected, "predicates for {err:?}");
        }
    }

    #[test]
    fn transaction_error_into_inner_all_variants() {
        let cases = [
            (TransactionError::Begin(String::from("begin")), "begin"),
            (TransactionError::Commit(String::from("commit")), "commit"),
            (TransactionError::Rollback(String::from("rollback")), "rollback"),
            (TransactionError::Operation(String::from("op")), "op")
        ];
        for (err, expected) in cases {
            assert_eq!(err.into_inner(), expected);
        }
    }

    #[test]
    fn transaction_error_debug() {
        let err: TransactionError<&str> = TransactionError::Begin("test");
        assert_eq!(format!("{err:?}"), r#"Begin("test")"#);
    }

    #[test]
    fn transaction_builder_pool_accessor() {
        struct MockPool {
            id: u32
        }
        let pool = MockPool {
            id: 42
        };
        let tx = Transaction::new(&pool);
        assert_eq!(tx.pool().id, 42);
        assert!(std::ptr::eq(tx.pool(), &pool), "pool() must return the borrowed pool itself");
    }

    #[test]
    fn transaction_builder_new_is_const() {
        struct MockPool;
        // These items are evaluated at compile time, so this test only builds
        // while `Transaction::new` and `Transaction::pool` remain `const fn`.
        const TX: Transaction<'static, MockPool> = Transaction::new(&MockPool);
        const POOL: &MockPool = TX.pool();
        let _ = POOL;
    }

    /// Stand-in for a real transaction context in `finalize_with_commit` tests.
    #[derive(Debug, PartialEq, Eq)]
    struct MockCtx;

    #[derive(Debug, PartialEq, Eq)]
    struct CommitErr(&'static str);

    #[derive(Debug, PartialEq, Eq)]
    enum AppErr {
        Closure(&'static str),
        Commit(&'static str)
    }

    impl From<CommitErr> for AppErr {
        fn from(e: CommitErr) -> Self {
            Self::Commit(e.0)
        }
    }

    #[tokio::test]
    async fn finalize_commits_on_ok() {
        let committed = AtomicBool::new(false);
        let flag = &committed;

        let result: Result<i32, AppErr> = finalize_with_commit::<_, _, _, CommitErr, _, _>(
            MockCtx,
            Ok::<i32, AppErr>(42),
            move |_ctx| async move {
                flag.store(true, Ordering::SeqCst);
                Ok::<(), CommitErr>(())
            }
        )
        .await;

        assert_eq!(result, Ok(42));
        assert!(committed.load(Ordering::SeqCst), "commit_fn must run on Ok");
    }

    #[tokio::test]
    async fn finalize_skips_commit_on_err() {
        let committed = AtomicBool::new(false);
        let flag = &committed;

        let result: Result<i32, AppErr> = finalize_with_commit::<_, _, _, CommitErr, _, _>(
            MockCtx,
            Err::<i32, AppErr>(AppErr::Closure("nope")),
            move |_ctx| async move {
                flag.store(true, Ordering::SeqCst);
                Ok::<(), CommitErr>(())
            }
        )
        .await;

        assert_eq!(result, Err(AppErr::Closure("nope")));
        assert!(!committed.load(Ordering::SeqCst), "commit_fn must NOT run on Err");
    }

    #[tokio::test]
    async fn finalize_propagates_commit_error_on_ok() {
        let result: Result<i32, AppErr> = finalize_with_commit::<_, _, _, CommitErr, _, _>(
            MockCtx,
            Ok::<i32, AppErr>(42),
            |_ctx| async { Err::<(), CommitErr>(CommitErr("commit failed")) }
        )
        .await;

        assert_eq!(result, Err(AppErr::Commit("commit failed")));
    }

    #[tokio::test]
    async fn finalize_preserves_closure_value_on_ok() {
        let result: Result<String, AppErr> = finalize_with_commit::<_, _, _, CommitErr, _, _>(
            MockCtx,
            Ok::<String, AppErr>("payload".to_string()),
            |_ctx| async { Ok::<(), CommitErr>(()) }
        )
        .await;

        assert_eq!(result, Ok("payload".to_string()));
    }

    #[tokio::test]
    async fn finalize_does_not_swallow_closure_error_when_commit_also_would_fail() {
        // The commit closure would fail, but it must never be reached on Err.
        let result: Result<(), AppErr> = finalize_with_commit::<_, _, _, CommitErr, _, _>(
            MockCtx,
            Err::<(), AppErr>(AppErr::Closure("original")),
            |_ctx| async { Err::<(), CommitErr>(CommitErr("never reached")) }
        )
        .await;

        assert_eq!(result, Err(AppErr::Closure("original")));
    }
}