entity_core/
transaction.rs1#[cfg(feature = "postgres")]
41use std::future::Future;
42use std::{error::Error as StdError, fmt};
43
/// Entry point for running work inside a database transaction begun on a borrowed pool.
///
/// The type is generic over the pool so it can be built and inspected without a database;
/// the methods that actually begin a transaction are implemented per backend (for example
/// on `Transaction<'_, sqlx::PgPool>` behind the `postgres` feature).
#[derive(Debug)]
pub struct Transaction<'p, DB> {
    /// Pool that transactions are begun on.
    pool: &'p DB
}
70
71impl<'p, DB> Transaction<'p, DB> {
72 pub const fn new(pool: &'p DB) -> Self {
84 Self {
85 pool
86 }
87 }
88
89 #[must_use]
91 pub const fn pool(&self) -> &'p DB {
92 self.pool
93 }
94}
95
/// Handle to an in-flight PostgreSQL transaction.
///
/// Instances are handed to the closures passed to `Transaction::run` (by mutable reference)
/// and `Transaction::run_with_commit` (by value).
#[cfg(feature = "postgres")]
pub struct TransactionContext {
    /// The underlying sqlx transaction, reachable through `transaction()`.
    tx: sqlx::Transaction<'static, sqlx::Postgres>
}
121
#[cfg(feature = "postgres")]
impl TransactionContext {
    /// Wraps a transaction that has already been begun.
    ///
    /// Hidden from the docs: contexts are normally created by the `Transaction` run methods.
    #[doc(hidden)]
    #[must_use]
    pub const fn new(tx: sqlx::Transaction<'static, sqlx::Postgres>) -> Self {
        TransactionContext {
            tx
        }
    }

    /// Mutable access to the underlying sqlx transaction, e.g. to execute queries on it.
    pub const fn transaction(&mut self) -> &mut sqlx::Transaction<'static, sqlx::Postgres> {
        &mut self.tx
    }

    /// Commits the transaction, consuming this context.
    ///
    /// # Errors
    ///
    /// Returns the `sqlx::Error` produced by the commit.
    pub async fn commit(self) -> Result<(), sqlx::Error> {
        let TransactionContext { tx } = self;
        tx.commit().await
    }

    /// Rolls the transaction back, consuming this context.
    ///
    /// # Errors
    ///
    /// Returns the `sqlx::Error` produced by the rollback.
    pub async fn rollback(self) -> Result<(), sqlx::Error> {
        let TransactionContext { tx } = self;
        tx.rollback().await
    }
}
167
/// Error raised while running a transaction, tagged with the phase that failed.
///
/// Every variant wraps the same underlying error type `E`; [`TransactionError::into_inner`]
/// recovers it regardless of phase. `Clone`, `PartialEq` and `Eq` are available whenever
/// `E` implements them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TransactionError<E> {
    /// Beginning the transaction failed.
    Begin(E),

    /// Committing the transaction failed.
    Commit(E),

    /// Rolling the transaction back failed.
    Rollback(E),

    /// The work performed inside the transaction failed.
    Operation(E)
}
185
186impl<E: fmt::Display> fmt::Display for TransactionError<E> {
187 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
188 match self {
189 Self::Begin(e) => write!(f, "failed to begin transaction: {e}"),
190 Self::Commit(e) => write!(f, "failed to commit transaction: {e}"),
191 Self::Rollback(e) => write!(f, "failed to rollback transaction: {e}"),
192 Self::Operation(e) => write!(f, "transaction operation failed: {e}")
193 }
194 }
195}
196
197impl<E: StdError + 'static> StdError for TransactionError<E> {
198 fn source(&self) -> Option<&(dyn StdError + 'static)> {
199 match self {
200 Self::Begin(e) | Self::Commit(e) | Self::Rollback(e) | Self::Operation(e) => Some(e)
201 }
202 }
203}
204
205impl<E> TransactionError<E> {
206 pub const fn is_begin(&self) -> bool {
208 matches!(self, Self::Begin(_))
209 }
210
211 pub const fn is_commit(&self) -> bool {
213 matches!(self, Self::Commit(_))
214 }
215
216 pub const fn is_rollback(&self) -> bool {
218 matches!(self, Self::Rollback(_))
219 }
220
221 pub const fn is_operation(&self) -> bool {
223 matches!(self, Self::Operation(_))
224 }
225
226 pub fn into_inner(self) -> E {
228 match self {
229 Self::Begin(e) | Self::Commit(e) | Self::Rollback(e) | Self::Operation(e) => e
230 }
231 }
232}
233
/// Collapses a phase-tagged sqlx error back into the plain `sqlx::Error`.
#[cfg(feature = "postgres")]
impl From<TransactionError<Self>> for sqlx::Error {
    fn from(err: TransactionError<Self>) -> Self {
        match err {
            TransactionError::Begin(inner)
            | TransactionError::Commit(inner)
            | TransactionError::Rollback(inner)
            | TransactionError::Operation(inner) => inner
        }
    }
}
240
#[cfg(feature = "postgres")]
impl Transaction<'_, sqlx::PgPool> {
    /// Runs `f` inside a freshly begun transaction, committing if it returns `Ok` and
    /// rolling back if it returns `Err`.
    ///
    /// `f` only borrows the [`TransactionContext`], so it cannot call
    /// [`TransactionContext::commit`] or [`TransactionContext::rollback`] itself (both consume
    /// the context); the outcome is decided here from `f`'s result.
    ///
    /// # Errors
    ///
    /// - Beginning or committing the transaction fails: the `sqlx::Error` converted into `E`.
    /// - `f` returns `Err(e)`: `e` is returned after a rollback has been attempted. A failure of
    ///   that rollback is not reported, so the caller always sees the error that triggered it.
    pub async fn run<F, T, E>(self, f: F) -> Result<T, E>
    where
        F: AsyncFnOnce(&mut TransactionContext) -> Result<T, E>,
        E: From<sqlx::Error>
    {
        // `?` converts `sqlx::Error` into `E` through the `E: From<sqlx::Error>` bound.
        let tx = self.pool.begin().await?;
        let mut ctx = TransactionContext::new(tx);

        let outcome = f(&mut ctx).await;
        match outcome {
            Ok(value) => {
                ctx.commit().await?;
                Ok(value)
            }
            Err(err) => {
                // Roll back explicitly and await it, so the transaction is closed before we
                // return instead of leaving it to `ctx`'s drop. The rollback's own error is
                // discarded on purpose: `err` is the failure the caller needs to see.
                let _ = ctx.rollback().await;
                Err(err)
            }
        }
    }

    /// Begins a transaction and hands ownership of its [`TransactionContext`] to `f`.
    ///
    /// Unlike `run`, this method neither commits nor rolls back: `f` is expected to finish the
    /// transaction itself via [`TransactionContext::commit`] or
    /// [`TransactionContext::rollback`]. If it does neither, the context is dropped with the
    /// transaction still open and sqlx's rollback-on-drop behaviour applies.
    ///
    /// # Errors
    ///
    /// Returns the `sqlx::Error` from beginning the transaction converted into `E`, or
    /// whatever `f` returns.
    pub async fn run_with_commit<F, Fut, T, E>(self, f: F) -> Result<T, E>
    where
        F: FnOnce(TransactionContext) -> Fut + Send,
        Fut: Future<Output = Result<T, E>> + Send,
        E: From<sqlx::Error>
    {
        let tx = self.pool.begin().await?;
        f(TransactionContext::new(tx)).await
    }
}
326
#[cfg(test)]
mod tests {
    use std::error::Error;

    use super::*;

    /// Constructors for every `TransactionError` variant, in declaration order, so each
    /// property can be checked once across all variants instead of per-variant copies.
    fn variants<E>() -> [fn(E) -> TransactionError<E>; 4] {
        [
            TransactionError::Begin,
            TransactionError::Commit,
            TransactionError::Rollback,
            TransactionError::Operation
        ]
    }

    #[test]
    fn transaction_error_display_is_exact_for_every_variant() {
        let expected = [
            "failed to begin transaction: msg",
            "failed to commit transaction: msg",
            "failed to rollback transaction: msg",
            "transaction operation failed: msg"
        ];
        for (ctor, message) in variants::<std::io::Error>().into_iter().zip(expected) {
            let err = ctor(std::io::Error::other("msg"));
            assert_eq!(err.to_string(), message);
        }
    }

    #[test]
    fn transaction_error_is_methods_match_only_their_variant() {
        for (index, ctor) in variants::<&str>().into_iter().enumerate() {
            let err = ctor("e");
            assert_eq!(err.is_begin(), index == 0, "{err:?}");
            assert_eq!(err.is_commit(), index == 1, "{err:?}");
            assert_eq!(err.is_rollback(), index == 2, "{err:?}");
            assert_eq!(err.is_operation(), index == 3, "{err:?}");
        }
    }

    #[test]
    fn transaction_error_into_inner_returns_payload_for_every_variant() {
        for ctor in variants::<String>() {
            assert_eq!(ctor("payload".to_string()).into_inner(), "payload");
        }

        let borrowed: TransactionError<&str> = TransactionError::Operation("test");
        assert_eq!(borrowed.into_inner(), "test");
    }

    #[test]
    fn transaction_error_source_is_payload_for_every_variant() {
        for ctor in variants::<std::io::Error>() {
            let err = ctor(std::io::Error::other("src"));
            let source = err
                .source()
                .expect("every variant exposes its payload as the source");
            assert_eq!(source.to_string(), "src");
        }
    }

    #[test]
    fn transaction_error_debug() {
        let err: TransactionError<&str> = TransactionError::Begin("test");
        assert_eq!(format!("{err:?}"), r#"Begin("test")"#);
    }

    #[test]
    fn transaction_builder_pool_accessor() {
        struct MockPool {
            id: u32
        }
        let pool = MockPool {
            id: 42
        };
        let tx = Transaction::new(&pool);
        assert!(std::ptr::eq(tx.pool(), &pool));
        assert_eq!(tx.pool().id, 42);
    }

    #[test]
    fn transaction_builder_accepts_zero_sized_pool() {
        struct MockPool;
        let pool = MockPool;
        let tx = Transaction::new(&pool);
        let _ = tx.pool();
    }

    /// Compile-time check that `run` accepts an async closure taking `&mut TransactionContext`.
    ///
    /// The function is never called and the future is never awaited, so no database is
    /// needed. `dead_code` is allowed because nothing calls it, and
    /// `no_effect_underscore_binding` because `_fut` exists only to force type-checking.
    #[cfg(feature = "postgres")]
    #[allow(dead_code, clippy::no_effect_underscore_binding)]
    fn _run_signature_accepts_mut_ref(pool: &sqlx::PgPool) {
        let _fut = Transaction::new(pool)
            .run(async |_ctx: &mut TransactionContext| Ok::<(), sqlx::Error>(()));
    }
}