entity_core/
transaction.rs1#[cfg(feature = "postgres")]
41use std::future::Future;
42use std::{error::Error as StdError, fmt};
43
/// Builder-style handle for transactional work against a pool of type `DB`.
///
/// The struct only borrows the pool for `'p`; it does not own a connection or an
/// open transaction.
pub struct Transaction<'p, DB> {
    /// Borrowed pool, handed back unchanged by the `pool()` accessor.
    pool: &'p DB
}
70
71impl<'p, DB> Transaction<'p, DB> {
72 pub const fn new(pool: &'p DB) -> Self {
84 Self {
85 pool
86 }
87 }
88
89 #[must_use]
91 pub const fn pool(&self) -> &'p DB {
92 self.pool
93 }
94}
95
/// Owned wrapper around an open `sqlx` PostgreSQL transaction.
///
/// Only compiled when the `postgres` feature is enabled.
#[cfg(feature = "postgres")]
pub struct TransactionContext {
    /// The wrapped `sqlx` transaction (`'static`, Postgres backend).
    tx: sqlx::Transaction<'static, sqlx::Postgres>
}
121
#[cfg(feature = "postgres")]
impl TransactionContext {
    /// Wraps an already-started `sqlx` transaction.
    ///
    /// Marked `#[doc(hidden)]`, so it does not appear in rendered documentation.
    #[doc(hidden)]
    #[must_use]
    pub const fn new(tx: sqlx::Transaction<'static, sqlx::Postgres>) -> Self {
        Self { tx }
    }

    /// Gives mutable access to the wrapped `sqlx` transaction.
    pub const fn transaction(&mut self) -> &mut sqlx::Transaction<'static, sqlx::Postgres> {
        &mut self.tx
    }

    /// Consumes the context and commits the wrapped transaction.
    ///
    /// # Errors
    ///
    /// Returns the `sqlx::Error` produced by the underlying `commit()` call.
    pub async fn commit(self) -> Result<(), sqlx::Error> {
        let Self { tx: inner } = self;
        inner.commit().await
    }

    /// Consumes the context and rolls the wrapped transaction back.
    ///
    /// # Errors
    ///
    /// Returns the `sqlx::Error` produced by the underlying `rollback()` call.
    pub async fn rollback(self) -> Result<(), sqlx::Error> {
        let Self { tx: inner } = self;
        inner.rollback().await
    }
}
167
/// Error produced by transactional work, tagged with the phase that failed.
///
/// Every variant wraps the underlying error value of type `E`.
#[derive(Debug)]
pub enum TransactionError<E> {
    /// The transaction could not be begun.
    Begin(E),

    /// The transaction could not be committed.
    Commit(E),

    /// The transaction could not be rolled back.
    Rollback(E),

    /// An operation performed inside the transaction failed.
    Operation(E)
}
185
186impl<E: fmt::Display> fmt::Display for TransactionError<E> {
187 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
188 match self {
189 Self::Begin(e) => write!(f, "failed to begin transaction: {e}"),
190 Self::Commit(e) => write!(f, "failed to commit transaction: {e}"),
191 Self::Rollback(e) => write!(f, "failed to rollback transaction: {e}"),
192 Self::Operation(e) => write!(f, "transaction operation failed: {e}")
193 }
194 }
195}
196
197impl<E: StdError + 'static> StdError for TransactionError<E> {
198 fn source(&self) -> Option<&(dyn StdError + 'static)> {
199 match self {
200 Self::Begin(e) | Self::Commit(e) | Self::Rollback(e) | Self::Operation(e) => Some(e)
201 }
202 }
203}
204
205impl<E> TransactionError<E> {
206 pub const fn is_begin(&self) -> bool {
208 matches!(self, Self::Begin(_))
209 }
210
211 pub const fn is_commit(&self) -> bool {
213 matches!(self, Self::Commit(_))
214 }
215
216 pub const fn is_rollback(&self) -> bool {
218 matches!(self, Self::Rollback(_))
219 }
220
221 pub const fn is_operation(&self) -> bool {
223 matches!(self, Self::Operation(_))
224 }
225
226 pub fn into_inner(self) -> E {
228 match self {
229 Self::Begin(e) | Self::Commit(e) | Self::Rollback(e) | Self::Operation(e) => e
230 }
231 }
232}
233
/// Unwraps a `TransactionError<sqlx::Error>` back into the plain `sqlx::Error`,
/// dropping the information about which phase failed.
#[cfg(feature = "postgres")]
impl From<TransactionError<Self>> for sqlx::Error {
    fn from(err: TransactionError<Self>) -> Self {
        TransactionError::into_inner(err)
    }
}
240
#[cfg(feature = "postgres")]
impl Transaction<'_, sqlx::PgPool> {
    /// Begins a PostgreSQL transaction on the borrowed pool and passes it to `f`.
    ///
    /// The [`TransactionContext`] is moved into the closure, so this method has no
    /// handle left to commit or roll back with once `f` finishes. The closure must
    /// call [`TransactionContext::commit`] itself if its changes should persist.
    /// Per the `sqlx` documentation, a transaction dropped without a commit is
    /// rolled back.
    ///
    /// # Errors
    ///
    /// Returns the `sqlx::Error` from `begin()` converted into `E`, or whatever
    /// error the future returned by `f` resolves to.
    pub async fn run<F, Fut, T, E>(self, f: F) -> Result<T, E>
    where
        F: FnOnce(TransactionContext) -> Fut + Send,
        Fut: Future<Output = Result<T, E>> + Send,
        E: From<sqlx::Error>
    {
        // `?` already applies the `From<sqlx::Error>` conversion into `E`.
        let tx = self.pool.begin().await?;
        f(TransactionContext::new(tx)).await
    }

    /// Begins a PostgreSQL transaction and passes it to `f`, which is expected to
    /// finish it explicitly.
    ///
    /// NOTE(review): as written this behaves exactly like [`Transaction::run`].
    /// Despite the name, no commit is issued here; the closure owns the
    /// [`TransactionContext`] and must call `commit()` on it.
    ///
    /// # Errors
    ///
    /// Returns the `sqlx::Error` from `begin()` converted into `E`, or whatever
    /// error the future returned by `f` resolves to.
    pub async fn run_with_commit<F, Fut, T, E>(self, f: F) -> Result<T, E>
    where
        F: FnOnce(TransactionContext) -> Fut + Send,
        Fut: Future<Output = Result<T, E>> + Send,
        E: From<sqlx::Error>
    {
        let tx = self.pool.begin().await?;
        f(TransactionContext::new(tx)).await
    }
}
316
#[cfg(test)]
// `transaction_error_debug` keeps the positional `format!("{:?}", err)` form.
#[allow(clippy::uninlined_format_args)]
mod tests {
    use std::error::Error;

    use super::*;

    /// Builds an `io::Error` carrying `msg`, used as a payload that implements `Error`.
    fn io_err(msg: &'static str) -> std::io::Error {
        std::io::Error::other(msg)
    }

    /// Produces `[Begin, Commit, Rollback, Operation]`, each wrapping a fresh payload.
    fn each_variant<E>(payload: impl Fn() -> E) -> [TransactionError<E>; 4] {
        [
            TransactionError::Begin(payload()),
            TransactionError::Commit(payload()),
            TransactionError::Rollback(payload()),
            TransactionError::Operation(payload())
        ]
    }

    /// Checks the `is_*` predicates against `[begin, commit, rollback, operation]`.
    fn assert_flags<E>(err: &TransactionError<E>, expected: [bool; 4]) {
        let actual = [err.is_begin(), err.is_commit(), err.is_rollback(), err.is_operation()];
        assert_eq!(actual, expected);
    }

    #[test]
    fn transaction_error_display_begin() {
        let rendered = TransactionError::Begin(io_err("test")).to_string();
        assert!(rendered.contains("begin"));
        assert!(rendered.contains("test"));
    }

    #[test]
    fn transaction_error_display_commit() {
        let rendered = TransactionError::Commit(io_err("test")).to_string();
        assert!(rendered.contains("commit"));
    }

    #[test]
    fn transaction_error_display_rollback() {
        let rendered = TransactionError::Rollback(io_err("test")).to_string();
        assert!(rendered.contains("rollback"));
    }

    #[test]
    fn transaction_error_display_operation() {
        let rendered = TransactionError::Operation(io_err("test")).to_string();
        assert!(rendered.contains("operation"));
    }

    #[test]
    fn transaction_error_is_methods() {
        let [begin, commit, rollback, operation] = each_variant(|| "e");

        assert_flags(&begin, [true, false, false, false]);
        assert_flags(&commit, [false, true, false, false]);
        assert_flags(&rollback, [false, false, true, false]);
        assert_flags(&operation, [false, false, false, true]);
    }

    #[test]
    fn transaction_error_into_inner() {
        let err: TransactionError<&str> = TransactionError::Operation("test");
        assert_eq!(err.into_inner(), "test");
    }

    #[test]
    fn transaction_error_into_inner_begin() {
        let err: TransactionError<&str> = TransactionError::Begin("begin_err");
        assert_eq!(err.into_inner(), "begin_err");
    }

    #[test]
    fn transaction_error_into_inner_commit() {
        let err: TransactionError<&str> = TransactionError::Commit("commit_err");
        assert_eq!(err.into_inner(), "commit_err");
    }

    #[test]
    fn transaction_error_into_inner_rollback() {
        let err: TransactionError<&str> = TransactionError::Rollback("rollback_err");
        assert_eq!(err.into_inner(), "rollback_err");
    }

    #[test]
    fn transaction_error_source_begin() {
        let err = TransactionError::Begin(io_err("src"));
        assert!(err.source().is_some());
    }

    #[test]
    fn transaction_error_source_commit() {
        let err = TransactionError::Commit(io_err("src"));
        assert!(err.source().is_some());
    }

    #[test]
    fn transaction_error_source_rollback() {
        let err = TransactionError::Rollback(io_err("src"));
        assert!(err.source().is_some());
    }

    #[test]
    fn transaction_error_source_operation() {
        let err = TransactionError::Operation(io_err("src"));
        assert!(err.source().is_some());
    }

    #[test]
    fn transaction_builder_new() {
        struct MockPool;
        let pool = MockPool;
        let builder = Transaction::new(&pool);
        let _ = builder.pool();
    }

    #[test]
    fn transaction_builder_pool_accessor() {
        struct MockPool {
            id: u32
        }
        let pool = MockPool { id: 42 };
        let builder = Transaction::new(&pool);
        assert_eq!(builder.pool().id, 42);
    }

    #[test]
    fn transaction_error_debug() {
        let err: TransactionError<&str> = TransactionError::Begin("test");
        let debug_str = format!("{:?}", err);
        assert!(debug_str.contains("Begin"));
        assert!(debug_str.contains("test"));
    }

    #[test]
    fn transaction_error_into_inner_all_variants() {
        let cases = [
            (TransactionError::Begin("begin".to_string()), "begin"),
            (TransactionError::Commit("commit".to_string()), "commit"),
            (TransactionError::Rollback("rollback".to_string()), "rollback"),
            (TransactionError::Operation("op".to_string()), "op")
        ];

        for (err, expected) in cases {
            assert_eq!(err.into_inner(), expected);
        }
    }

    #[test]
    fn transaction_error_source_all_variants() {
        for err in each_variant(|| io_err("src")) {
            assert!(err.source().is_some());
        }
    }

    #[test]
    fn transaction_error_display_all_variants() {
        let [begin_str, commit_str, rollback_str, operation_str] =
            each_variant(|| io_err("msg")).map(|err| err.to_string());

        assert!(begin_str.contains("begin"));
        assert!(commit_str.contains("commit"));
        assert!(rollback_str.contains("rollback"));
        assert!(operation_str.contains("operation"));
    }

    #[test]
    fn transaction_error_is_all_variants() {
        let [begin, commit, rollback, operation] = each_variant(|| "e");

        // Each variant answers `true` to its own predicate ...
        assert!(begin.is_begin());
        assert!(commit.is_commit());
        assert!(rollback.is_rollback());
        assert!(operation.is_operation());

        // ... and `false` to the other three.
        assert_flags(&begin, [true, false, false, false]);
        assert_flags(&commit, [false, true, false, false]);
        assert_flags(&rollback, [false, false, true, false]);
        assert_flags(&operation, [false, false, false, true]);
    }

    #[test]
    fn transaction_builder_new_const() {
        struct MockPool;
        let pool = MockPool;
        let _ = Transaction::new(&pool);
    }
}