1#[cfg(feature = "postgres")]
39use std::future::Future;
40use std::{error::Error as StdError, fmt};
41
/// Entry point for running work inside a database transaction on `DB`.
///
/// A `Transaction` only borrows the pool. Constructing one acquires no
/// connection; backend-specific methods begin the actual transaction.
#[derive(Debug)]
pub struct Transaction<'p, DB> {
    pool: &'p DB
}

impl<'p, DB> Transaction<'p, DB> {
    /// Creates a transaction runner bound to `pool`.
    ///
    /// This only stores the reference and is usable in `const` contexts.
    #[must_use]
    pub const fn new(pool: &'p DB) -> Self {
        Self {
            pool
        }
    }

    /// Returns the pool this runner was created with.
    #[must_use]
    pub const fn pool(&self) -> &'p DB {
        self.pool
    }
}
91
/// Owns an in-flight sqlx PostgreSQL transaction.
///
/// Instances are built from a freshly begun `sqlx::Transaction` and handed to
/// the closures given to the PostgreSQL runners on [`Transaction`].
#[cfg(feature = "postgres")]
pub struct TransactionContext {
    tx: sqlx::Transaction<'static, sqlx::Postgres>
}

#[cfg(feature = "postgres")]
impl TransactionContext {
    /// Wraps an already-begun sqlx transaction.
    ///
    /// Hidden from rustdoc; the runners in this module call it right after
    /// `begin()` succeeds.
    #[doc(hidden)]
    #[must_use]
    pub const fn new(tx: sqlx::Transaction<'static, sqlx::Postgres>) -> Self {
        Self { tx }
    }

    /// Borrows the underlying sqlx transaction mutably.
    ///
    /// Presumably this is used as the executor for queries that must run
    /// inside the transaction.
    pub const fn transaction(&mut self) -> &mut sqlx::Transaction<'static, sqlx::Postgres> {
        &mut self.tx
    }

    /// Consumes the context and commits the transaction.
    ///
    /// # Errors
    ///
    /// Returns the `sqlx::Error` reported by the commit.
    pub async fn commit(self) -> Result<(), sqlx::Error> {
        let Self { tx } = self;
        tx.commit().await
    }

    /// Consumes the context and rolls the transaction back.
    ///
    /// # Errors
    ///
    /// Returns the `sqlx::Error` reported by the rollback.
    pub async fn rollback(self) -> Result<(), sqlx::Error> {
        let Self { tx } = self;
        tx.rollback().await
    }
}
163
/// Error produced while running work inside a transaction.
///
/// Every variant wraps the same underlying error type `E`. The variant records
/// which stage of the transaction lifecycle produced the failure.
#[derive(Debug)]
pub enum TransactionError<E> {
    /// Starting the transaction failed.
    Begin(E),

    /// Committing the transaction failed.
    Commit(E),

    /// Rolling the transaction back failed.
    Rollback(E),

    /// The operation executed inside the transaction failed.
    Operation(E)
}

/// Formats as a stage-specific prefix followed by the wrapped error's own
/// `Display` output.
///
/// NOTE: the wrapped error is also returned from [`StdError::source`]. Reporters
/// that print the whole error chain will therefore show its message twice.
impl<E: fmt::Display> fmt::Display for TransactionError<E> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Begin(e) => write!(f, "failed to begin transaction: {e}"),
            Self::Commit(e) => write!(f, "failed to commit transaction: {e}"),
            Self::Rollback(e) => write!(f, "failed to rollback transaction: {e}"),
            Self::Operation(e) => write!(f, "transaction operation failed: {e}")
        }
    }
}

/// Reports the wrapped error as the source of this one, whatever the stage.
impl<E: StdError + 'static> StdError for TransactionError<E> {
    fn source(&self) -> Option<&(dyn StdError + 'static)> {
        Some(self.inner())
    }
}

impl<E> TransactionError<E> {
    /// Returns `true` if the transaction failed to begin.
    #[must_use]
    pub const fn is_begin(&self) -> bool {
        matches!(self, Self::Begin(_))
    }

    /// Returns `true` if the commit failed.
    #[must_use]
    pub const fn is_commit(&self) -> bool {
        matches!(self, Self::Commit(_))
    }

    /// Returns `true` if the rollback failed.
    #[must_use]
    pub const fn is_rollback(&self) -> bool {
        matches!(self, Self::Rollback(_))
    }

    /// Returns `true` if the operation run inside the transaction failed.
    #[must_use]
    pub const fn is_operation(&self) -> bool {
        matches!(self, Self::Operation(_))
    }

    /// Borrows the wrapped error without consuming `self`, regardless of stage.
    #[must_use]
    pub const fn inner(&self) -> &E {
        match self {
            Self::Begin(e) | Self::Commit(e) | Self::Rollback(e) | Self::Operation(e) => e
        }
    }

    /// Consumes `self` and returns the wrapped error, discarding the stage.
    #[must_use]
    pub fn into_inner(self) -> E {
        match self {
            Self::Begin(e) | Self::Commit(e) | Self::Rollback(e) | Self::Operation(e) => e
        }
    }
}
229
/// Collapses a staged transaction error back into the plain `sqlx::Error`.
///
/// The stage tag (begin / commit / rollback / operation) is discarded, which
/// lets `?` propagate a `TransactionError<sqlx::Error>` from functions
/// returning `Result<_, sqlx::Error>`.
#[cfg(feature = "postgres")]
impl From<TransactionError<Self>> for sqlx::Error {
    fn from(wrapped: TransactionError<Self>) -> Self {
        TransactionError::into_inner(wrapped)
    }
}
236
/// Commits `ctx` through `commit_fn` if `result` is `Ok`; otherwise hands the
/// error back untouched.
///
/// - On `Ok(value)`, `commit_fn(ctx)` is awaited. A commit failure is converted
///   into `E` via `From<CommitErr>` and returned in place of `value`.
/// - On `Err(e)`, `commit_fn` is never called, `ctx` is simply dropped, and `e`
///   is returned as-is. The closure's error is never replaced by a commit error.
///
/// Kept generic over the context and commit function so this logic can be
/// unit-tested without a database.
#[cfg(any(feature = "postgres", test))]
async fn finalize_with_commit<C, T, E, CommitErr, Cf, Fut>(
    ctx: C,
    result: Result<T, E>,
    commit_fn: Cf
) -> Result<T, E>
where
    Cf: FnOnce(C) -> Fut,
    Fut: core::future::Future<Output = Result<(), CommitErr>>,
    E: From<CommitErr>
{
    // Early return on the error path: `ctx` is dropped without committing.
    let value = result?;
    // `?` applies `E: From<CommitErr>` to a failed commit.
    commit_fn(ctx).await?;
    Ok(value)
}
270
/// PostgreSQL runners for [`Transaction`].
#[cfg(feature = "postgres")]
impl Transaction<'_, sqlx::PgPool> {
    /// Runs `f` inside a new transaction and commits if it succeeds.
    ///
    /// A transaction is begun on the pool and `f` receives a mutable
    /// [`TransactionContext`].
    ///
    /// - If `f` returns `Ok`, the transaction is committed and the value is
    ///   returned.
    /// - If `f` returns `Err`, no commit is attempted. The context is dropped
    ///   and the error is returned unchanged. Per sqlx's documentation, an
    ///   uncommitted `sqlx::Transaction` is rolled back on drop.
    ///
    /// # Errors
    ///
    /// - A failure to begin, converted via `E: From<sqlx::Error>`.
    /// - The error returned by `f`.
    /// - A failure to commit, converted via `E: From<sqlx::Error>`.
    pub async fn run<F, T, E>(self, f: F) -> Result<T, E>
    where
        F: AsyncFnOnce(&mut TransactionContext) -> Result<T, E>,
        E: From<sqlx::Error>
    {
        // `?` performs the `sqlx::Error -> E` conversion itself.
        let mut ctx = TransactionContext::new(self.pool.begin().await?);
        let outcome = f(&mut ctx).await;
        finalize_with_commit(ctx, outcome, TransactionContext::commit).await
    }

    /// Begins a transaction and hands ownership of its context to `f`.
    ///
    /// Nothing is committed here. `f` is responsible for calling
    /// [`TransactionContext::commit`] or [`TransactionContext::rollback`]. If it
    /// does neither, the context is dropped, and per sqlx that rolls the
    /// transaction back. Both `f` and its future must be `Send`.
    ///
    /// # Errors
    ///
    /// - A failure to begin, converted via `E: From<sqlx::Error>`.
    /// - Whatever `f` returns.
    pub async fn run_with_commit<F, Fut, T, E>(self, f: F) -> Result<T, E>
    where
        F: FnOnce(TransactionContext) -> Fut + Send,
        Fut: Future<Output = Result<T, E>> + Send,
        E: From<sqlx::Error>
    {
        let begun = self.pool.begin().await?;
        f(TransactionContext::new(begun)).await
    }
}
349
/// Unit tests for the backend-independent pieces of this module:
/// `TransactionError`, the `Transaction` builder and `finalize_with_commit`.
#[cfg(test)]
mod tests {
    use std::error::Error;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, Ordering};

    use super::*;

    #[test]
    fn transaction_error_display_begin() {
        let err: TransactionError<std::io::Error> =
            TransactionError::Begin(std::io::Error::other("test"));
        assert!(err.to_string().contains("begin"));
        assert!(err.to_string().contains("test"));
    }

    #[test]
    fn transaction_error_display_commit() {
        let err: TransactionError<std::io::Error> =
            TransactionError::Commit(std::io::Error::other("test"));
        assert!(err.to_string().contains("commit"));
    }

    #[test]
    fn transaction_error_display_rollback() {
        let err: TransactionError<std::io::Error> =
            TransactionError::Rollback(std::io::Error::other("test"));
        assert!(err.to_string().contains("rollback"));
    }

    #[test]
    fn transaction_error_display_operation() {
        let err: TransactionError<std::io::Error> =
            TransactionError::Operation(std::io::Error::other("test"));
        assert!(err.to_string().contains("operation"));
    }

    #[test]
    fn transaction_error_is_methods() {
        let begin: TransactionError<&str> = TransactionError::Begin("e");
        let commit: TransactionError<&str> = TransactionError::Commit("e");
        let rollback: TransactionError<&str> = TransactionError::Rollback("e");
        let operation: TransactionError<&str> = TransactionError::Operation("e");

        assert!(begin.is_begin());
        assert!(!begin.is_commit());
        assert!(!begin.is_rollback());
        assert!(!begin.is_operation());

        assert!(!commit.is_begin());
        assert!(commit.is_commit());
        assert!(!commit.is_rollback());
        assert!(!commit.is_operation());

        assert!(!rollback.is_begin());
        assert!(!rollback.is_commit());
        assert!(rollback.is_rollback());
        assert!(!rollback.is_operation());

        assert!(!operation.is_begin());
        assert!(!operation.is_commit());
        assert!(!operation.is_rollback());
        assert!(operation.is_operation());
    }

    #[test]
    fn transaction_error_into_inner() {
        let err: TransactionError<&str> = TransactionError::Operation("test");
        assert_eq!(err.into_inner(), "test");
    }

    #[test]
    fn transaction_error_into_inner_begin() {
        let err: TransactionError<&str> = TransactionError::Begin("begin_err");
        assert_eq!(err.into_inner(), "begin_err");
    }

    #[test]
    fn transaction_error_into_inner_commit() {
        let err: TransactionError<&str> = TransactionError::Commit("commit_err");
        assert_eq!(err.into_inner(), "commit_err");
    }

    #[test]
    fn transaction_error_into_inner_rollback() {
        let err: TransactionError<&str> = TransactionError::Rollback("rollback_err");
        assert_eq!(err.into_inner(), "rollback_err");
    }

    #[test]
    fn transaction_error_source_begin() {
        let err: TransactionError<std::io::Error> =
            TransactionError::Begin(std::io::Error::other("src"));
        assert!(err.source().is_some());
    }

    #[test]
    fn transaction_error_source_commit() {
        let err: TransactionError<std::io::Error> =
            TransactionError::Commit(std::io::Error::other("src"));
        assert!(err.source().is_some());
    }

    #[test]
    fn transaction_error_source_rollback() {
        let err: TransactionError<std::io::Error> =
            TransactionError::Rollback(std::io::Error::other("src"));
        assert!(err.source().is_some());
    }

    #[test]
    fn transaction_error_source_operation() {
        let err: TransactionError<std::io::Error> =
            TransactionError::Operation(std::io::Error::other("src"));
        assert!(err.source().is_some());
    }

    #[test]
    fn transaction_builder_new() {
        struct MockPool;
        let pool = MockPool;
        let tx = Transaction::new(&pool);
        let _ = tx.pool();
    }

    #[test]
    fn transaction_builder_pool_accessor() {
        struct MockPool {
            id: u32
        }
        let pool = MockPool {
            id: 42
        };
        let tx = Transaction::new(&pool);
        assert_eq!(tx.pool().id, 42);
    }

    #[test]
    fn transaction_error_debug() {
        let err: TransactionError<&str> = TransactionError::Begin("test");
        // Inlined format argument: no clippy `uninlined_format_args` allowance needed.
        let debug_str = format!("{err:?}");
        assert!(debug_str.contains("Begin"));
        assert!(debug_str.contains("test"));
    }

    #[test]
    fn transaction_error_into_inner_all_variants() {
        let begin: TransactionError<String> = TransactionError::Begin("begin".to_string());
        let commit: TransactionError<String> = TransactionError::Commit("commit".to_string());
        let rollback: TransactionError<String> =
            TransactionError::Rollback("rollback".to_string());
        let operation: TransactionError<String> = TransactionError::Operation("op".to_string());

        assert_eq!(begin.into_inner(), "begin");
        assert_eq!(commit.into_inner(), "commit");
        assert_eq!(rollback.into_inner(), "rollback");
        assert_eq!(operation.into_inner(), "op");
    }

    #[test]
    fn transaction_error_source_all_variants() {
        let begin: TransactionError<std::io::Error> =
            TransactionError::Begin(std::io::Error::other("src"));
        let commit: TransactionError<std::io::Error> =
            TransactionError::Commit(std::io::Error::other("src"));
        let rollback: TransactionError<std::io::Error> =
            TransactionError::Rollback(std::io::Error::other("src"));
        let operation: TransactionError<std::io::Error> =
            TransactionError::Operation(std::io::Error::other("src"));

        assert!(begin.source().is_some());
        assert!(commit.source().is_some());
        assert!(rollback.source().is_some());
        assert!(operation.source().is_some());
    }

    #[test]
    fn transaction_error_display_all_variants() {
        let begin: TransactionError<std::io::Error> =
            TransactionError::Begin(std::io::Error::other("msg"));
        let commit: TransactionError<std::io::Error> =
            TransactionError::Commit(std::io::Error::other("msg"));
        let rollback: TransactionError<std::io::Error> =
            TransactionError::Rollback(std::io::Error::other("msg"));
        let operation: TransactionError<std::io::Error> =
            TransactionError::Operation(std::io::Error::other("msg"));

        let begin_str = begin.to_string();
        let commit_str = commit.to_string();
        let rollback_str = rollback.to_string();
        let operation_str = operation.to_string();

        assert!(begin_str.contains("begin"));
        assert!(commit_str.contains("commit"));
        assert!(rollback_str.contains("rollback"));
        assert!(operation_str.contains("operation"));
    }

    #[test]
    fn transaction_error_is_all_variants() {
        let begin: TransactionError<&str> = TransactionError::Begin("e");
        let commit: TransactionError<&str> = TransactionError::Commit("e");
        let rollback: TransactionError<&str> = TransactionError::Rollback("e");
        let operation: TransactionError<&str> = TransactionError::Operation("e");

        assert!(begin.is_begin());
        assert!(commit.is_commit());
        assert!(rollback.is_rollback());
        assert!(operation.is_operation());

        assert!(!begin.is_commit());
        assert!(!begin.is_rollback());
        assert!(!begin.is_operation());

        assert!(!commit.is_begin());
        assert!(!commit.is_rollback());
        assert!(!commit.is_operation());

        assert!(!rollback.is_begin());
        assert!(!rollback.is_commit());
        assert!(!rollback.is_operation());

        assert!(!operation.is_begin());
        assert!(!operation.is_commit());
        assert!(!operation.is_rollback());
    }

    #[test]
    fn transaction_builder_new_const() {
        struct MockPool;
        let pool = MockPool;
        let tx = Transaction::new(&pool);
        let _ = tx;
    }

    /// Stand-in for a real transaction context; `finalize_with_commit` only moves it.
    #[derive(Debug, PartialEq, Eq)]
    struct MockCtx;

    /// Error produced by the mock commit function.
    #[derive(Debug, PartialEq, Eq)]
    struct CommitErr(&'static str);

    /// Caller-side error type. Distinguishes a closure failure from a commit failure.
    #[derive(Debug, PartialEq, Eq)]
    enum AppErr {
        Closure(&'static str),
        Commit(&'static str)
    }

    impl From<CommitErr> for AppErr {
        fn from(e: CommitErr) -> Self {
            Self::Commit(e.0)
        }
    }

    #[tokio::test]
    async fn finalize_commits_on_ok() {
        let committed = Arc::new(AtomicBool::new(false));
        let flag = Arc::clone(&committed);

        let result: Result<i32, AppErr> = finalize_with_commit::<_, _, _, CommitErr, _, _>(
            MockCtx,
            Ok::<i32, AppErr>(42),
            move |_ctx| async move {
                flag.store(true, Ordering::SeqCst);
                Ok::<(), CommitErr>(())
            }
        )
        .await;

        assert_eq!(result, Ok(42));
        assert!(committed.load(Ordering::SeqCst), "commit_fn must run on Ok");
    }

    #[tokio::test]
    async fn finalize_skips_commit_on_err() {
        let committed = Arc::new(AtomicBool::new(false));
        let flag = Arc::clone(&committed);

        let result: Result<i32, AppErr> = finalize_with_commit::<_, _, _, CommitErr, _, _>(
            MockCtx,
            Err::<i32, AppErr>(AppErr::Closure("nope")),
            move |_ctx| async move {
                flag.store(true, Ordering::SeqCst);
                Ok::<(), CommitErr>(())
            }
        )
        .await;

        assert_eq!(result, Err(AppErr::Closure("nope")));
        assert!(
            !committed.load(Ordering::SeqCst),
            "commit_fn must NOT run on Err"
        );
    }

    #[tokio::test]
    async fn finalize_propagates_commit_error_on_ok() {
        let result: Result<i32, AppErr> = finalize_with_commit::<_, _, _, CommitErr, _, _>(
            MockCtx,
            Ok::<i32, AppErr>(42),
            |_ctx| async { Err::<(), CommitErr>(CommitErr("commit failed")) }
        )
        .await;

        assert_eq!(result, Err(AppErr::Commit("commit failed")));
    }

    #[tokio::test]
    async fn finalize_preserves_closure_value_on_ok() {
        let result: Result<String, AppErr> = finalize_with_commit::<_, _, _, CommitErr, _, _>(
            MockCtx,
            Ok::<String, AppErr>("payload".to_string()),
            |_ctx| async { Ok::<(), CommitErr>(()) }
        )
        .await;

        assert_eq!(result, Ok("payload".to_string()));
    }

    #[tokio::test]
    async fn finalize_does_not_swallow_closure_error_when_commit_also_would_fail() {
        // The commit closure would fail, but it must never run on the error path.
        let result: Result<(), AppErr> = finalize_with_commit::<_, _, _, CommitErr, _, _>(
            MockCtx,
            Err::<(), AppErr>(AppErr::Closure("original")),
            |_ctx| async { Err::<(), CommitErr>(CommitErr("never reached")) }
        )
        .await;

        assert_eq!(result, Err(AppErr::Closure("original")));
    }
}