1#[cfg(feature = "postgres")]
41use std::future::Future;
42use std::{error::Error as StdError, fmt};
43
/// Builder for running work inside a database transaction on a borrowed pool.
///
/// Constructing a `Transaction` only stores the pool reference; no database
/// work happens until one of the backend-specific `run*` methods is called.
pub struct Transaction<'p, DB> {
    /// Pool (or other database handle) that transactions are begun on.
    pool: &'p DB
}

impl<'p, DB> Transaction<'p, DB> {
    /// Creates a transaction builder that borrows `pool`.
    ///
    /// This is a `const fn` that performs no I/O; it only records the reference.
    /// The returned value does nothing on its own, hence `#[must_use]`.
    #[must_use]
    pub const fn new(pool: &'p DB) -> Self {
        Self {
            pool
        }
    }

    /// Returns the pool this builder was created with.
    #[must_use]
    pub const fn pool(&self) -> &'p DB {
        self.pool
    }
}
95
#[cfg(feature = "postgres")]
/// Handle to an open Postgres transaction.
///
/// `Transaction::run` lends it to the caller's closure by `&mut`, while
/// `Transaction::run_with_commit` hands over ownership.
pub struct TransactionContext {
    /// The wrapped sqlx transaction.
    tx: sqlx::Transaction<'static, sqlx::Postgres>
}

#[cfg(feature = "postgres")]
impl TransactionContext {
    /// Wraps a transaction that has already been begun.
    #[doc(hidden)]
    #[must_use]
    pub const fn new(tx: sqlx::Transaction<'static, sqlx::Postgres>) -> Self {
        Self { tx }
    }

    /// Returns a mutable reference to the wrapped sqlx transaction.
    pub const fn transaction(&mut self) -> &mut sqlx::Transaction<'static, sqlx::Postgres> {
        let Self { tx } = self;
        tx
    }

    /// Commits the transaction, consuming this context.
    ///
    /// # Errors
    ///
    /// Returns whatever `sqlx::Error` the commit produces, unchanged.
    pub async fn commit(self) -> Result<(), sqlx::Error> {
        let Self { tx } = self;
        tx.commit().await
    }

    /// Rolls the transaction back, consuming this context.
    ///
    /// # Errors
    ///
    /// Returns whatever `sqlx::Error` the rollback produces, unchanged.
    pub async fn rollback(self) -> Result<(), sqlx::Error> {
        let Self { tx } = self;
        tx.rollback().await
    }
}
167
/// Error from a transaction helper, tagged with the phase that failed.
///
/// Every variant wraps the same inner error type `E`. Use
/// [`TransactionError::as_inner`] to borrow it or [`TransactionError::into_inner`]
/// to take it, regardless of phase.
#[derive(Debug)]
pub enum TransactionError<E> {
    /// Beginning the transaction failed.
    Begin(E),

    /// Committing the transaction failed.
    Commit(E),

    /// Rolling the transaction back failed.
    Rollback(E),

    /// The work performed inside the transaction failed.
    Operation(E)
}

impl<E: fmt::Display> fmt::Display for TransactionError<E> {
    /// Formats as `<phase message>: <inner error>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Begin(e) => write!(f, "failed to begin transaction: {e}"),
            Self::Commit(e) => write!(f, "failed to commit transaction: {e}"),
            Self::Rollback(e) => write!(f, "failed to rollback transaction: {e}"),
            Self::Operation(e) => write!(f, "transaction operation failed: {e}")
        }
    }
}

impl<E: StdError + 'static> StdError for TransactionError<E> {
    // NOTE(review): `Display` above already embeds the inner error's message and
    // `source` returns that same error, so reporters that walk the source chain
    // may print the inner message twice. Changing either side alters observable
    // output, so both are kept as they are.
    fn source(&self) -> Option<&(dyn StdError + 'static)> {
        Some(self.as_inner())
    }
}

impl<E> TransactionError<E> {
    /// Returns `true` if this is [`TransactionError::Begin`].
    #[must_use]
    pub const fn is_begin(&self) -> bool {
        matches!(self, Self::Begin(_))
    }

    /// Returns `true` if this is [`TransactionError::Commit`].
    #[must_use]
    pub const fn is_commit(&self) -> bool {
        matches!(self, Self::Commit(_))
    }

    /// Returns `true` if this is [`TransactionError::Rollback`].
    #[must_use]
    pub const fn is_rollback(&self) -> bool {
        matches!(self, Self::Rollback(_))
    }

    /// Returns `true` if this is [`TransactionError::Operation`].
    #[must_use]
    pub const fn is_operation(&self) -> bool {
        matches!(self, Self::Operation(_))
    }

    /// Borrows the wrapped error without consuming `self`, whatever the phase.
    #[must_use]
    pub const fn as_inner(&self) -> &E {
        match self {
            Self::Begin(e) | Self::Commit(e) | Self::Rollback(e) | Self::Operation(e) => e
        }
    }

    /// Consumes `self` and returns the wrapped error, discarding the phase.
    #[must_use]
    pub fn into_inner(self) -> E {
        match self {
            Self::Begin(e) | Self::Commit(e) | Self::Rollback(e) | Self::Operation(e) => e
        }
    }
}
233
#[cfg(feature = "postgres")]
/// Collapses a phase-tagged sqlx error back into a plain `sqlx::Error`.
///
/// The phase (begin/commit/rollback/operation) is discarded and only the
/// wrapped error survives. This lets `?` propagate a
/// `TransactionError<sqlx::Error>` out of functions returning `sqlx::Error`.
impl From<TransactionError<Self>> for sqlx::Error {
    fn from(err: TransactionError<Self>) -> Self {
        TransactionError::into_inner(err)
    }
}
240
#[cfg(any(feature = "postgres", test))]
/// Finishes a transaction, given the outcome of the work done inside it.
///
/// - `Ok(value)`: `commit_fn(ctx)` is awaited. If it fails, the commit error is
///   converted into `E` and returned instead of `value`.
/// - `Err(e)`: `commit_fn` is never called and `ctx` is simply dropped, so `e`
///   is returned untouched. Nothing here rolls back explicitly; any cleanup is
///   left to `C`'s drop behaviour.
///
/// It is generic over the context and error types so the commit/skip decision
/// can be unit-tested without a database.
async fn finalize_with_commit<C, T, E, CommitErr, Cf, Fut>(
    ctx: C,
    result: Result<T, E>,
    commit_fn: Cf
) -> Result<T, E>
where
    Cf: FnOnce(C) -> Fut,
    Fut: core::future::Future<Output = Result<(), CommitErr>>,
    E: From<CommitErr>
{
    // Return the work's own error before `ctx` is ever handed to `commit_fn`.
    let value = result?;
    // `?` lifts `CommitErr` into `E` through the `E: From<CommitErr>` bound.
    commit_fn(ctx).await?;
    Ok(value)
}
274
#[cfg(feature = "postgres")]
impl Transaction<'_, sqlx::PgPool> {
    /// Runs `f` inside a new transaction and commits if it succeeds.
    ///
    /// The closure borrows the context mutably. When `f` returns `Ok`, the
    /// transaction is committed and the value is returned. When it returns
    /// `Err`, no commit is attempted, the context is dropped, and the closure's
    /// error is returned unchanged.
    ///
    /// # Errors
    ///
    /// - Failure to begin the transaction, converted into `E`.
    /// - Whatever error `f` returns.
    /// - Failure to commit, converted into `E`.
    pub async fn run<F, T, E>(self, f: F) -> Result<T, E>
    where
        F: AsyncFnOnce(&mut TransactionContext) -> Result<T, E>,
        E: From<sqlx::Error>
    {
        let mut ctx = self.pool.begin().await.map(TransactionContext::new)?;
        let outcome = f(&mut ctx).await;
        finalize_with_commit(ctx, outcome, TransactionContext::commit).await
    }

    /// Begins a transaction and hands ownership of its context to `f`.
    ///
    /// Unlike [`run`](Self::run), this never commits on its own. `f` decides
    /// whether to call `commit` or `rollback` on the context. NOTE(review): if
    /// `f` does neither, the context is simply dropped; presumably sqlx's
    /// drop behaviour then rolls back. Confirm before relying on it.
    ///
    /// # Errors
    ///
    /// - Failure to begin the transaction, converted into `E`.
    /// - Whatever error `f`'s future returns.
    pub async fn run_with_commit<F, Fut, T, E>(self, f: F) -> Result<T, E>
    where
        F: FnOnce(TransactionContext) -> Fut + Send,
        Fut: Future<Output = Result<T, E>> + Send,
        E: From<sqlx::Error>
    {
        let ctx = self.pool.begin().await.map(TransactionContext::new)?;
        f(ctx).await
    }
}
354
#[cfg(test)]
mod tests {
    use std::error::Error;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, Ordering};

    use super::*;

    /// One instance of every `TransactionError` variant, each wrapping an
    /// `io::Error` with message `msg` and paired with the phase word its
    /// `Display` output must contain.
    fn io_variants(msg: &str) -> [(TransactionError<std::io::Error>, &'static str); 4] {
        [
            (TransactionError::Begin(std::io::Error::other(msg)), "begin"),
            (TransactionError::Commit(std::io::Error::other(msg)), "commit"),
            (TransactionError::Rollback(std::io::Error::other(msg)), "rollback"),
            (TransactionError::Operation(std::io::Error::other(msg)), "operation")
        ]
    }

    #[test]
    fn transaction_error_display_names_phase_and_inner_error() {
        for (err, phase) in io_variants("msg") {
            let text = err.to_string();
            assert!(text.contains(phase), "{text:?} should mention {phase:?}");
            assert!(text.contains("msg"), "{text:?} should include the inner error");
        }
    }

    #[test]
    fn transaction_error_source_is_inner_error() {
        for (err, phase) in io_variants("src") {
            let source = err.source().expect("every variant exposes a source");
            assert_eq!(source.to_string(), "src", "{phase} variant");
        }
    }

    #[test]
    fn transaction_error_is_methods() {
        let cases: [(TransactionError<&str>, [bool; 4]); 4] = [
            (TransactionError::Begin("e"), [true, false, false, false]),
            (TransactionError::Commit("e"), [false, true, false, false]),
            (TransactionError::Rollback("e"), [false, false, true, false]),
            (TransactionError::Operation("e"), [false, false, false, true])
        ];
        for (err, expected) in cases {
            let actual = [
                err.is_begin(),
                err.is_commit(),
                err.is_rollback(),
                err.is_operation()
            ];
            assert_eq!(actual, expected, "{err:?}");
        }
    }

    #[test]
    fn transaction_error_into_inner_all_variants() {
        let cases = [
            (TransactionError::Begin("begin_err".to_string()), "begin_err"),
            (TransactionError::Commit("commit_err".to_string()), "commit_err"),
            (TransactionError::Rollback("rollback_err".to_string()), "rollback_err"),
            (TransactionError::Operation("op_err".to_string()), "op_err")
        ];
        for (err, expected) in cases {
            assert_eq!(err.into_inner(), expected);
        }
    }

    #[test]
    fn transaction_error_debug() {
        let err: TransactionError<&str> = TransactionError::Begin("test");
        let debug_str = format!("{err:?}");
        assert!(debug_str.contains("Begin"));
        assert!(debug_str.contains("test"));
    }

    #[test]
    fn transaction_builder_pool_accessor() {
        struct MockPool {
            id: u32
        }
        let pool = MockPool {
            id: 42
        };
        let tx = Transaction::new(&pool);
        assert_eq!(tx.pool().id, 42);
        assert!(
            std::ptr::eq(tx.pool(), &pool),
            "pool() must return the borrowed pool itself"
        );
    }

    #[test]
    fn transaction_builder_new_const() {
        // Evaluated at compile time: this stops building if `new` or `pool`
        // ever lose their `const`.
        const TX: Transaction<'static, u32> = Transaction::new(&7);
        const POOL: u32 = *TX.pool();
        assert_eq!(POOL, 7);
    }

    /// Stand-in for a transaction context; `finalize_with_commit` only moves it.
    #[derive(Debug, PartialEq, Eq)]
    struct MockCtx;

    #[derive(Debug, PartialEq, Eq)]
    struct CommitErr(&'static str);

    #[derive(Debug, PartialEq, Eq)]
    enum AppErr {
        Closure(&'static str),
        Commit(&'static str)
    }

    impl From<CommitErr> for AppErr {
        fn from(e: CommitErr) -> Self {
            Self::Commit(e.0)
        }
    }

    #[tokio::test]
    async fn finalize_commits_on_ok() {
        let committed = Arc::new(AtomicBool::new(false));
        let flag = Arc::clone(&committed);

        let result: Result<i32, AppErr> = finalize_with_commit::<_, _, _, CommitErr, _, _>(
            MockCtx,
            Ok::<i32, AppErr>(42),
            move |_ctx| async move {
                flag.store(true, Ordering::SeqCst);
                Ok::<(), CommitErr>(())
            }
        )
        .await;

        assert_eq!(result, Ok(42));
        assert!(committed.load(Ordering::SeqCst), "commit_fn must run on Ok");
    }

    #[tokio::test]
    async fn finalize_skips_commit_on_err() {
        let committed = Arc::new(AtomicBool::new(false));
        let flag = Arc::clone(&committed);

        let result: Result<i32, AppErr> = finalize_with_commit::<_, _, _, CommitErr, _, _>(
            MockCtx,
            Err::<i32, AppErr>(AppErr::Closure("nope")),
            move |_ctx| async move {
                flag.store(true, Ordering::SeqCst);
                Ok::<(), CommitErr>(())
            }
        )
        .await;

        assert_eq!(result, Err(AppErr::Closure("nope")));
        assert!(
            !committed.load(Ordering::SeqCst),
            "commit_fn must NOT run on Err"
        );
    }

    #[tokio::test]
    async fn finalize_propagates_commit_error_on_ok() {
        let result: Result<i32, AppErr> = finalize_with_commit::<_, _, _, CommitErr, _, _>(
            MockCtx,
            Ok::<i32, AppErr>(42),
            |_ctx| async { Err::<(), CommitErr>(CommitErr("commit failed")) }
        )
        .await;

        assert_eq!(result, Err(AppErr::Commit("commit failed")));
    }

    #[tokio::test]
    async fn finalize_preserves_closure_value_on_ok() {
        let result: Result<String, AppErr> = finalize_with_commit::<_, _, _, CommitErr, _, _>(
            MockCtx,
            Ok::<String, AppErr>("payload".to_string()),
            |_ctx| async { Ok::<(), CommitErr>(()) }
        )
        .await;

        assert_eq!(result, Ok("payload".to_string()));
    }

    #[tokio::test]
    async fn finalize_does_not_swallow_closure_error_when_commit_also_would_fail() {
        // The commit closure would fail, but it must never be reached on the Err path.
        let result: Result<(), AppErr> = finalize_with_commit::<_, _, _, CommitErr, _, _>(
            MockCtx,
            Err::<(), AppErr>(AppErr::Closure("original")),
            |_ctx| async { Err::<(), CommitErr>(CommitErr("never reached")) }
        )
        .await;

        assert_eq!(result, Err(AppErr::Closure("original")));
    }
}