entity_core/
transaction.rs1#[cfg(feature = "postgres")]
41use std::future::Future;
42use std::{error::Error as StdError, fmt};
43
/// Entry point for running work inside a database transaction on a pool.
///
/// This is a borrow-only handle: it stores a reference to the pool and nothing
/// else. Backend-specific methods are provided by separate, feature-gated
/// `impl` blocks.
pub struct Transaction<'p, DB> {
    /// Borrowed pool that transactions will be started on.
    pool: &'p DB,
}

impl<'p, DB> Transaction<'p, DB> {
    /// Creates a handle bound to `pool`.
    ///
    /// Only the reference is stored; no transaction is started here.
    pub const fn new(pool: &'p DB) -> Self {
        Transaction { pool }
    }

    /// Returns the pool this handle was created with.
    ///
    /// The returned reference carries the pool's own lifetime `'p`, so it may
    /// outlive this handle.
    pub const fn pool(&self) -> &'p DB {
        self.pool
    }
}
94
/// Owns an in-flight PostgreSQL transaction handed to user code.
///
/// Finish it with [`commit`](Self::commit) or [`rollback`](Self::rollback);
/// both consume the context.
#[cfg(feature = "postgres")]
pub struct TransactionContext {
    /// The underlying sqlx transaction; `'static` because it owns its connection.
    tx: sqlx::Transaction<'static, sqlx::Postgres>,
}

#[cfg(feature = "postgres")]
impl TransactionContext {
    /// Wraps an already-started sqlx transaction.
    #[doc(hidden)]
    pub fn new(tx: sqlx::Transaction<'static, sqlx::Postgres>) -> Self {
        TransactionContext { tx }
    }

    /// Borrows the underlying sqlx transaction mutably so queries can be run on it.
    pub fn transaction(&mut self) -> &mut sqlx::Transaction<'static, sqlx::Postgres> {
        let Self { tx } = self;
        tx
    }

    /// Commits the transaction, consuming the context.
    ///
    /// # Errors
    ///
    /// Returns the `sqlx::Error` reported by the commit.
    pub async fn commit(self) -> Result<(), sqlx::Error> {
        let Self { tx } = self;
        tx.commit().await
    }

    /// Rolls the transaction back, consuming the context.
    ///
    /// # Errors
    ///
    /// Returns the `sqlx::Error` reported by the rollback.
    pub async fn rollback(self) -> Result<(), sqlx::Error> {
        let Self { tx } = self;
        tx.rollback().await
    }
}
157
/// Error produced while driving a transaction, tagged with the phase that failed.
///
/// Every variant wraps the underlying error `E` unchanged; use
/// [`TransactionError::into_inner`] to recover it regardless of phase.
#[derive(Debug)]
pub enum TransactionError<E> {
    /// Starting the transaction failed.
    Begin(E),

    /// Committing the transaction failed.
    Commit(E),

    /// Rolling the transaction back failed.
    Rollback(E),

    /// The work performed inside the transaction failed.
    Operation(E),
}

impl<E: fmt::Display> fmt::Display for TransactionError<E> {
    /// Renders `"<phase message>: <inner error>"`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (context, inner) = match self {
            Self::Begin(inner) => ("failed to begin transaction", inner),
            Self::Commit(inner) => ("failed to commit transaction", inner),
            Self::Rollback(inner) => ("failed to rollback transaction", inner),
            Self::Operation(inner) => ("transaction operation failed", inner),
        };
        write!(f, "{context}: {inner}")
    }
}

impl<E: StdError + 'static> StdError for TransactionError<E> {
    // NOTE(review): the wrapped error is also rendered by `Display`, so reporters
    // that walk the `source()` chain will print its message twice.
    fn source(&self) -> Option<&(dyn StdError + 'static)> {
        let inner: &E = match self {
            Self::Begin(e) | Self::Commit(e) | Self::Rollback(e) | Self::Operation(e) => e,
        };
        Some(inner)
    }
}

impl<E> TransactionError<E> {
    /// Returns `true` if the failure happened while beginning the transaction.
    pub const fn is_begin(&self) -> bool {
        matches!(*self, TransactionError::Begin(..))
    }

    /// Returns `true` if the failure happened while committing.
    pub const fn is_commit(&self) -> bool {
        matches!(*self, TransactionError::Commit(..))
    }

    /// Returns `true` if the failure happened while rolling back.
    pub const fn is_rollback(&self) -> bool {
        matches!(*self, TransactionError::Rollback(..))
    }

    /// Returns `true` if the work inside the transaction failed.
    pub const fn is_operation(&self) -> bool {
        matches!(*self, TransactionError::Operation(..))
    }

    /// Consumes the error and returns the wrapped inner error, discarding the phase.
    pub fn into_inner(self) -> E {
        let (Self::Begin(inner)
        | Self::Commit(inner)
        | Self::Rollback(inner)
        | Self::Operation(inner)) = self;
        inner
    }
}
223
/// Collapses a phase-tagged transaction error back into the plain `sqlx::Error`.
///
/// The phase (begin/commit/rollback/operation) is discarded. This lets `?` be
/// used on `TransactionError<sqlx::Error>` in functions returning `sqlx::Error`.
#[cfg(feature = "postgres")]
impl From<TransactionError<sqlx::Error>> for sqlx::Error {
    fn from(value: TransactionError<sqlx::Error>) -> Self {
        TransactionError::into_inner(value)
    }
}
230
#[cfg(feature = "postgres")]
impl<'p> Transaction<'p, sqlx::PgPool> {
    /// Begins a transaction on the pool and hands its context to `f`.
    ///
    /// The transaction is **not** committed by this method. `f` takes ownership
    /// of the [`TransactionContext`], so it is responsible for calling
    /// [`TransactionContext::commit`] (or [`TransactionContext::rollback`]).
    /// sqlx documents that a transaction dropped while still in progress is
    /// rolled back, so returning from `f` without committing discards its changes.
    ///
    /// NOTE(review): this currently behaves exactly like
    /// [`run_with_commit`](Self::run_with_commit). The name suggests an
    /// auto-commit-on-`Ok` variant was intended, which this signature cannot
    /// provide because `f` consumes the context. Confirm the intended contract.
    ///
    /// # Errors
    ///
    /// Returns `E::from(sqlx::Error)` if the transaction cannot be started.
    /// Otherwise it returns whatever `f` returns.
    pub async fn run<F, Fut, T, E>(self, f: F) -> Result<T, E>
    where
        F: FnOnce(TransactionContext) -> Fut + Send,
        Fut: Future<Output = Result<T, E>> + Send,
        E: From<sqlx::Error>,
    {
        // `?` converts the begin failure through `E: From<sqlx::Error>`.
        let ctx = self.begin_context().await?;
        f(ctx).await
    }

    /// Begins a transaction and hands its context to `f`, which is expected to
    /// commit it explicitly via [`TransactionContext::commit`].
    ///
    /// If `f` drops the context without committing, sqlx rolls the transaction
    /// back on drop.
    ///
    /// # Errors
    ///
    /// Returns `E::from(sqlx::Error)` if the transaction cannot be started.
    /// Otherwise it returns whatever `f` returns.
    pub async fn run_with_commit<F, Fut, T, E>(self, f: F) -> Result<T, E>
    where
        F: FnOnce(TransactionContext) -> Fut + Send,
        Fut: Future<Output = Result<T, E>> + Send,
        E: From<sqlx::Error>,
    {
        let ctx = self.begin_context().await?;
        f(ctx).await
    }

    /// Starts a transaction on the borrowed pool and wraps it in a [`TransactionContext`].
    async fn begin_context(&self) -> Result<TransactionContext, sqlx::Error> {
        self.pool.begin().await.map(TransactionContext::new)
    }
}
298
#[cfg(test)]
mod tests {
    use std::error::Error;
    use std::io;

    use super::*;

    /// Builds an `io::Error` whose `Display` output is exactly `msg`.
    fn io_err(msg: &str) -> io::Error {
        io::Error::other(msg.to_owned())
    }

    /// Asserts that `err.source()` is the wrapped `io::Error` and that it carries `msg`.
    fn assert_io_source(err: &dyn Error, msg: &str) {
        let source = err.source().expect("source() should expose the wrapped error");
        let inner = source
            .downcast_ref::<io::Error>()
            .expect("source() should be the wrapped io::Error");
        assert_eq!(inner.to_string(), msg);
    }

    #[test]
    fn transaction_error_display_begin() {
        let err = TransactionError::Begin(io_err("test"));
        assert_eq!(err.to_string(), "failed to begin transaction: test");
    }

    #[test]
    fn transaction_error_display_commit() {
        let err = TransactionError::Commit(io_err("test"));
        assert_eq!(err.to_string(), "failed to commit transaction: test");
    }

    #[test]
    fn transaction_error_display_rollback() {
        let err = TransactionError::Rollback(io_err("test"));
        assert_eq!(err.to_string(), "failed to rollback transaction: test");
    }

    #[test]
    fn transaction_error_display_operation() {
        let err = TransactionError::Operation(io_err("test"));
        assert_eq!(err.to_string(), "transaction operation failed: test");
    }

    #[test]
    fn transaction_error_is_methods() {
        // Each row: the error and its expected [is_begin, is_commit, is_rollback, is_operation].
        let cases: [(TransactionError<&str>, [bool; 4]); 4] = [
            (TransactionError::Begin("e"), [true, false, false, false]),
            (TransactionError::Commit("e"), [false, true, false, false]),
            (TransactionError::Rollback("e"), [false, false, true, false]),
            (TransactionError::Operation("e"), [false, false, false, true]),
        ];
        for (err, expected) in &cases {
            let actual = [err.is_begin(), err.is_commit(), err.is_rollback(), err.is_operation()];
            assert_eq!(actual, *expected, "predicates mismatch for {err:?}");
        }
    }

    #[test]
    fn transaction_error_into_inner() {
        let err: TransactionError<&str> = TransactionError::Operation("test");
        assert_eq!(err.into_inner(), "test");
    }

    #[test]
    fn transaction_error_into_inner_begin() {
        let err: TransactionError<&str> = TransactionError::Begin("begin_err");
        assert_eq!(err.into_inner(), "begin_err");
    }

    #[test]
    fn transaction_error_into_inner_commit() {
        let err: TransactionError<&str> = TransactionError::Commit("commit_err");
        assert_eq!(err.into_inner(), "commit_err");
    }

    #[test]
    fn transaction_error_into_inner_rollback() {
        let err: TransactionError<&str> = TransactionError::Rollback("rollback_err");
        assert_eq!(err.into_inner(), "rollback_err");
    }

    #[test]
    fn transaction_error_source_begin() {
        assert_io_source(&TransactionError::Begin(io_err("src")), "src");
    }

    #[test]
    fn transaction_error_source_commit() {
        assert_io_source(&TransactionError::Commit(io_err("src")), "src");
    }

    #[test]
    fn transaction_error_source_rollback() {
        assert_io_source(&TransactionError::Rollback(io_err("src")), "src");
    }

    #[test]
    fn transaction_error_source_operation() {
        assert_io_source(&TransactionError::Operation(io_err("src")), "src");
    }

    #[test]
    fn transaction_builder_new() {
        struct MockPool;
        let pool = MockPool;
        let tx = Transaction::new(&pool);
        // The handle must hand back the very pool it was built from.
        assert!(std::ptr::eq(tx.pool(), &pool));
    }

    #[test]
    fn transaction_builder_pool_accessor() {
        struct MockPool {
            id: u32,
        }
        let pool = MockPool { id: 42 };
        let tx = Transaction::new(&pool);
        assert_eq!(tx.pool().id, 42);
        assert!(std::ptr::eq(tx.pool(), &pool));
    }

    #[test]
    fn transaction_error_debug() {
        let err: TransactionError<&str> = TransactionError::Begin("test");
        assert_eq!(format!("{err:?}"), "Begin(\"test\")");
    }

    #[cfg(feature = "postgres")]
    #[test]
    fn transaction_error_into_sqlx_error_keeps_inner() {
        let err = TransactionError::Commit(sqlx::Error::PoolTimedOut);
        let converted: sqlx::Error = err.into();
        assert!(matches!(converted, sqlx::Error::PoolTimedOut));
    }
}