1use std::{error::Error as StdError, fmt, future::Future, marker::PhantomData};
40
/// Builder for running work inside a database transaction.
///
/// Borrows the pool `DB` for `'p`. `Repos` is recorded only at the type level
/// (through `PhantomData`) and defaults to `()`; `with_repo` swaps it for a
/// different type while keeping the same pool reference.
pub struct Transaction<'p, DB, Repos = ()> {
    /// Type-level marker for the selected repository set; stores no data.
    _repos: PhantomData<Repos>,
    /// Pool that transactions are started on.
    pool: &'p DB
}

impl<'p, DB> Transaction<'p, DB, ()> {
    /// Creates a builder over `pool` with no repositories selected (`Repos = ()`).
    pub const fn new(pool: &'p DB) -> Self {
        Transaction {
            _repos: PhantomData,
            pool
        }
    }
}

impl<'p, DB, Repos> Transaction<'p, DB, Repos> {
    /// Returns the pool reference this builder was created with.
    ///
    /// The reference is valid for the full pool lifetime `'p`, not just for the
    /// borrow of `self`.
    pub const fn pool(&self) -> &'p DB {
        self.pool
    }

    /// Re-types the builder with the repository set `NewRepos`, keeping the pool.
    #[doc(hidden)]
    pub const fn with_repo<NewRepos>(self) -> Transaction<'p, DB, NewRepos> {
        let Self { pool, .. } = self;
        Transaction {
            _repos: PhantomData,
            pool
        }
    }
}
92
/// Context handed to transactional closures.
///
/// Owns the transaction handle `Tx` together with the repository set `Repos`.
/// The lifetime `'t` appears only in a `PhantomData<&'t ()>` marker; no
/// reference of that lifetime is stored.
pub struct TransactionContext<'t, Tx, Repos> {
    /// The transaction handle.
    tx: Tx,
    /// Repositories made available alongside the transaction.
    repos: Repos,
    /// Lifetime marker only.
    _lifetime: PhantomData<&'t ()>
}

impl<'t, Tx, Repos> TransactionContext<'t, Tx, Repos> {
    /// Bundles a transaction handle with its repositories.
    #[doc(hidden)]
    pub const fn new(tx: Tx, repos: Repos) -> Self {
        TransactionContext {
            _lifetime: PhantomData,
            repos,
            tx
        }
    }

    /// Returns a mutable borrow of the transaction handle.
    ///
    /// Only a borrow is handed out, so operations that take `Tx` by value cannot
    /// be performed through this accessor.
    pub fn transaction(&mut self) -> &mut Tx {
        let Self { tx, .. } = self;
        tx
    }

    /// Returns a shared borrow of the repositories.
    pub const fn repos(&self) -> &Repos {
        &self.repos
    }

    /// Returns a mutable borrow of the repositories.
    pub fn repos_mut(&mut self) -> &mut Repos {
        let Self { repos, .. } = self;
        repos
    }
}
147
/// Error produced while driving a transaction, tagged with the phase in which
/// the wrapped error `E` occurred.
#[derive(Debug)]
pub enum TransactionError<E> {
    /// Starting the transaction failed.
    Begin(E),

    /// Committing the transaction failed.
    Commit(E),

    /// Rolling the transaction back failed.
    Rollback(E),

    /// The operation run inside the transaction failed.
    Operation(E)
}

/// Renders as `<phase message>: <wrapped error>`.
impl<E: fmt::Display> fmt::Display for TransactionError<E> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            TransactionError::Begin(ref e) => write!(f, "failed to begin transaction: {e}"),
            TransactionError::Commit(ref e) => write!(f, "failed to commit transaction: {e}"),
            TransactionError::Rollback(ref e) => write!(f, "failed to rollback transaction: {e}"),
            TransactionError::Operation(ref e) => write!(f, "transaction operation failed: {e}")
        }
    }
}

/// Every variant reports its wrapped error as the `source`.
impl<E: StdError + 'static> StdError for TransactionError<E> {
    fn source(&self) -> Option<&(dyn StdError + 'static)> {
        // Irrefutable or-pattern: each variant carries exactly one `E`.
        let (Self::Begin(inner)
        | Self::Commit(inner)
        | Self::Rollback(inner)
        | Self::Operation(inner)) = self;
        Some(inner)
    }
}

impl<E> TransactionError<E> {
    /// Returns `true` for [`TransactionError::Begin`].
    pub const fn is_begin(&self) -> bool {
        matches!(*self, TransactionError::Begin(..))
    }

    /// Returns `true` for [`TransactionError::Commit`].
    pub const fn is_commit(&self) -> bool {
        matches!(*self, TransactionError::Commit(..))
    }

    /// Returns `true` for [`TransactionError::Rollback`].
    pub const fn is_rollback(&self) -> bool {
        matches!(*self, TransactionError::Rollback(..))
    }

    /// Returns `true` for [`TransactionError::Operation`].
    pub const fn is_operation(&self) -> bool {
        matches!(*self, TransactionError::Operation(..))
    }

    /// Consumes the error and returns the wrapped `E`, discarding the phase.
    pub fn into_inner(self) -> E {
        let (Self::Begin(inner)
        | Self::Commit(inner)
        | Self::Rollback(inner)
        | Self::Operation(inner)) = self;
        inner
    }
}
213
/// A connection source that can open new transactions.
// `async_fn_in_trait` warns that callers cannot add a `Send` bound to the
// futures returned by `async fn` trait methods; it is allowed here on purpose.
#[allow(async_fn_in_trait)]
pub trait Transactional
where
    Self: Sized + Send + Sync
{
    /// Handle for an open transaction; may borrow from `Self` for `'t`.
    type Transaction<'t>: Send
    where
        Self: 't;

    /// Error reported when a transaction cannot be opened.
    type Error: StdError + Send + Sync;

    /// Starts a new transaction.
    ///
    /// # Errors
    ///
    /// Returns `Self::Error` if the transaction could not be started.
    async fn begin(&self) -> Result<Self::Transaction<'_>, Self::Error>;
}
230
/// Operations that finish an open transaction.
// `async_fn_in_trait` warns that callers cannot add a `Send` bound to the
// futures returned by `async fn` trait methods; it is allowed here on purpose.
#[allow(async_fn_in_trait)]
pub trait TransactionOps
where
    Self: Sized + Send
{
    /// Error reported when committing or rolling back fails.
    type Error: StdError + Send + Sync;

    /// Commits the transaction, consuming the handle.
    ///
    /// # Errors
    ///
    /// Returns `Self::Error` if the commit fails.
    async fn commit(self) -> Result<(), Self::Error>;

    /// Rolls the transaction back, consuming the handle.
    ///
    /// # Errors
    ///
    /// Returns `Self::Error` if the rollback fails.
    async fn rollback(self) -> Result<(), Self::Error>;
}
243
244#[allow(async_fn_in_trait)]
249pub trait TransactionRunner<'p, Repos>: Sized {
250 type Tx: TransactionOps;
252
253 type DbError: StdError + Send + Sync;
255
256 async fn run<F, Fut, T, E>(self, f: F) -> Result<T, E>
267 where
268 F: FnOnce(TransactionContext<'_, Self::Tx, Repos>) -> Fut + Send,
269 Fut: Future<Output = Result<T, E>> + Send,
270 E: From<TransactionError<Self::DbError>>;
271}
272
/// `sqlx`-backed PostgreSQL implementations of the transaction traits.
///
/// Only compiled when the `postgres` cargo feature is enabled.
#[cfg(feature = "postgres")]
mod postgres_impl {
    use sqlx::{PgPool, Postgres};

    use super::*;

    impl Transactional for PgPool {
        type Transaction<'t> = sqlx::Transaction<'t, Postgres>;
        type Error = sqlx::Error;

        /// Opens a transaction by delegating to `sqlx::Pool::begin`.
        async fn begin(&self) -> Result<Self::Transaction<'_>, Self::Error> {
            sqlx::pool::Pool::begin(self).await
        }
    }

    impl TransactionOps for sqlx::Transaction<'_, Postgres> {
        type Error = sqlx::Error;

        /// Commits by delegating to `sqlx::Transaction::commit`.
        async fn commit(self) -> Result<(), Self::Error> {
            sqlx::Transaction::commit(self).await
        }

        /// Rolls back by delegating to `sqlx::Transaction::rollback`.
        async fn rollback(self) -> Result<(), Self::Error> {
            sqlx::Transaction::rollback(self).await
        }
    }

    impl<'p, Repos: Send> Transaction<'p, PgPool, Repos> {
        /// Begins a transaction on the pool and passes it to `f`, together with a
        /// default-constructed `Repos`, inside a [`TransactionContext`].
        ///
        /// # Errors
        ///
        /// Returns `TransactionError::Begin` (converted into `E`) if the
        /// transaction cannot be started. Otherwise returns whatever `f`'s future
        /// resolves to, unchanged.
        ///
        /// NOTE(review): this method never calls `commit` or `rollback`. The
        /// transaction is moved into `f`, and `TransactionContext` only exposes
        /// `&mut Tx`, so the by-value `TransactionOps::commit` is not reachable from
        /// the closure either. `Fut` is declared outside the `for<'t>` binder, so it
        /// cannot name `'t`; presumably the returned future therefore cannot keep
        /// the transaction alive. The sqlx docs state that a `Transaction` dropped
        /// without `commit`/`rollback` is rolled back. Confirm whether `run` is
        /// meant to commit on `Ok`, which would require a different closure signature.
        pub async fn run<F, Fut, T, E>(self, f: F) -> Result<T, E>
        where
            F: for<'t> FnOnce(
                    TransactionContext<'t, sqlx::Transaction<'t, Postgres>, Repos>
                ) -> Fut
                + Send,
            Fut: Future<Output = Result<T, E>> + Send,
            E: From<TransactionError<sqlx::Error>>,
            Repos: Default
        {
            let tx = self.pool.begin().await.map_err(TransactionError::Begin)?;
            let ctx = TransactionContext::new(tx, Repos::default());

            // The closure's result is returned as-is.
            f(ctx).await
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps `msg` in an `io::Error` to serve as the inner error in display tests.
    fn io_err(msg: &str) -> std::io::Error {
        std::io::Error::other(msg)
    }

    /// Asserts that the rendered `err` mentions both `phase` and `detail`.
    fn assert_mentions(err: &TransactionError<std::io::Error>, phase: &str, detail: &str) {
        let rendered = err.to_string();
        assert!(rendered.contains(phase));
        assert!(rendered.contains(detail));
    }

    #[test]
    fn transaction_error_display_begin() {
        assert_mentions(&TransactionError::Begin(io_err("test")), "begin", "test");
    }

    #[test]
    fn transaction_error_display_commit() {
        assert_mentions(&TransactionError::Commit(io_err("commit_err")), "commit", "commit_err");
    }

    #[test]
    fn transaction_error_display_rollback() {
        let err = TransactionError::Rollback(io_err("rollback_err"));
        assert_mentions(&err, "rollback", "rollback_err");
    }

    #[test]
    fn transaction_error_display_operation() {
        assert_mentions(&TransactionError::Operation(io_err("op_err")), "operation", "op_err");
    }

    #[test]
    fn transaction_error_is_methods() {
        let begin = TransactionError::Begin("e");
        assert!(begin.is_begin());
        assert!(!(begin.is_commit() || begin.is_rollback() || begin.is_operation()));

        let commit = TransactionError::Commit("e");
        assert!(commit.is_commit() && !commit.is_begin());

        let rollback = TransactionError::Rollback("e");
        assert!(rollback.is_rollback() && !rollback.is_begin());

        let op = TransactionError::Operation("e");
        assert!(op.is_operation() && !op.is_begin());
    }

    #[test]
    fn transaction_error_into_inner() {
        let err = TransactionError::Operation("inner");
        assert_eq!(err.into_inner(), "inner");
    }

    #[test]
    fn transaction_error_into_inner_all_variants() {
        let cases = [
            (TransactionError::Begin("b"), "b"),
            (TransactionError::Commit("c"), "c"),
            (TransactionError::Rollback("r"), "r"),
            (TransactionError::Operation("o"), "o")
        ];
        for (err, expected) in cases {
            assert_eq!(err.into_inner(), expected);
        }
    }

    #[test]
    fn transaction_error_source() {
        let errors: [TransactionError<std::io::Error>; 4] = [
            TransactionError::Begin(io_err("source_err")),
            TransactionError::Commit(io_err("c")),
            TransactionError::Rollback(io_err("r")),
            TransactionError::Operation(io_err("o"))
        ];
        for err in &errors {
            assert!(err.source().is_some());
        }
    }

    #[test]
    fn transaction_builder_new() {
        struct MockPool;
        let pool = MockPool;
        let builder: Transaction<'_, MockPool> = Transaction::new(&pool);
        let _ = builder.pool();
    }

    #[test]
    fn transaction_builder_with_repo() {
        struct MockPool;
        let pool = MockPool;
        let retyped: Transaction<'_, MockPool, i32> = Transaction::new(&pool).with_repo();
        let _ = retyped.pool();
    }

    #[test]
    fn transaction_context_new() {
        let ctx = TransactionContext::new("mock_tx", (1, 2, 3));
        assert_eq!(ctx.repos(), &(1, 2, 3));
    }

    #[test]
    fn transaction_context_transaction() {
        let mut ctx = TransactionContext::new(String::from("mock_tx"), ());
        assert_eq!(ctx.transaction().as_str(), "mock_tx");
    }

    #[test]
    fn transaction_context_repos_mut() {
        let mut ctx = TransactionContext::new("mock_tx", vec![1, 2, 3]);
        ctx.repos_mut().push(4);
        assert_eq!(ctx.repos().as_slice(), &[1, 2, 3, 4]);
    }
}