Skip to main content

bottle_orm/
transaction.rs

//! # Transaction Module
//!
//! This module provides transaction management for Bottle ORM.
//! It lets multiple database operations run atomically, ensuring data consistency.
5
6// ============================================================================
7// External Crate Imports
8// ============================================================================
9
10use heck::ToSnakeCase;
11use std::sync::Arc;
12use tokio::sync::Mutex;
13use futures::future::BoxFuture;
14use sqlx::any::AnyArguments;
15
16// ============================================================================
17// Internal Crate Imports
18// ============================================================================
19
20use crate::{
21    database::{Connection, Drivers, RawQuery},
22    Model, QueryBuilder,
23};
24
25// ============================================================================
26// Transaction Struct
27// ============================================================================
28
29/// A wrapper around a SQLx transaction.
30///
31/// Provides a way to execute multiple queries atomically. If any query fails,
32/// the transaction can be rolled back. If all succeed, it can be committed.
33#[derive(Debug, Clone)]
34pub struct Transaction<'a> {
35    pub(crate) tx: Arc<Mutex<Option<sqlx::Transaction<'a, sqlx::Any>>>>,
36    pub(crate) pool: sqlx::AnyPool,
37    pub(crate) driver: Drivers,
38}
39
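// Transaction is Send + Sync (assuming `Drivers` is): the SQLx transaction sits behind an
// `Arc<tokio::sync::Mutex<..>>`, which is Send + Sync whenever the wrapped value is Send,
// and `AnyPool` is itself Send + Sync. This lets it be used easily in async handlers (like Axum).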
42
43// ============================================================================
44// Connection Implementation
45// ============================================================================
46
47impl Connection for Transaction<'_> {
48    fn driver(&self) -> Drivers { self.driver }
49    fn execute<'a, 'q: 'a>(&'a self, sql: &'q str, args: AnyArguments<'q>) -> BoxFuture<'a, Result<sqlx::any::AnyQueryResult, sqlx::Error>> {
50        Box::pin(async move {
51            let mut guard = self.tx.lock().await;
52            if let Some(tx) = guard.as_mut() {
53                sqlx::query_with(sql, args).execute(&mut **tx).await
54            } else {
55                Err(sqlx::Error::WorkerCrashed)
56            }
57        })
58    }
59
60    fn fetch_all<'a, 'q: 'a>(&'a self, sql: &'q str, args: AnyArguments<'q>) -> BoxFuture<'a, Result<Vec<sqlx::any::AnyRow>, sqlx::Error>> {
61        Box::pin(async move {
62            let mut guard = self.tx.lock().await;
63            if let Some(tx) = guard.as_mut() {
64                sqlx::query_with(sql, args).fetch_all(&mut **tx).await
65            } else {
66                Err(sqlx::Error::WorkerCrashed)
67            }
68        })
69    }
70
71    fn fetch_one<'a, 'q: 'a>(&'a self, sql: &'q str, args: AnyArguments<'q>) -> BoxFuture<'a, Result<sqlx::any::AnyRow, sqlx::Error>> {
72        Box::pin(async move {
73            let mut guard = self.tx.lock().await;
74            if let Some(tx) = guard.as_mut() {
75                sqlx::query_with(sql, args).fetch_one(&mut **tx).await
76            } else {
77                Err(sqlx::Error::WorkerCrashed)
78            }
79        })
80    }
81
82    fn fetch_optional<'a, 'q: 'a>(&'a self, sql: &'q str, args: AnyArguments<'q>) -> BoxFuture<'a, Result<Option<sqlx::any::AnyRow>, sqlx::Error>> {
83        Box::pin(async move {
84            let mut guard = self.tx.lock().await;
85            if let Some(tx) = guard.as_mut() {
86                sqlx::query_with(sql, args).fetch_optional(&mut **tx).await
87            } else {
88                Err(sqlx::Error::WorkerCrashed)
89            }
90        })
91    }
92
93    fn clone_db(&self) -> crate::Database {
94        crate::Database {
95            pool: self.pool.clone(),
96            driver: self.driver,
97        }
98    }
99}
100
101// ============================================================================
102// Transaction Implementation
103// ============================================================================
104
105impl<'a> Transaction<'a> {
106    /// Starts building a query within this transaction.
107    pub fn model<T: Model + Send + Sync + Unpin + crate::AnyImpl>(
108        &self,
109    ) -> QueryBuilder<T, Self> {
110        let active_columns = T::active_columns();
111        let mut columns: Vec<String> = Vec::with_capacity(active_columns.capacity());
112
113        for col in active_columns {
114            columns.push(col.strip_prefix("r#").unwrap_or(col).to_snake_case());
115        }
116
117        QueryBuilder::new(self.clone(), self.driver, T::table_name(), <T as Model>::columns(), columns)
118    }
119
120    /// Creates a raw SQL query builder attached to this transaction.
121    pub fn raw<'b>(&self, sql: &'b str) -> RawQuery<'b, Self> {
122        RawQuery::new(self.clone(), sql)
123    }
124
125    /// Commits the transaction.
126    pub async fn commit(self) -> Result<(), sqlx::Error> {
127        let mut guard = self.tx.lock().await;
128        if let Some(tx) = guard.take() {
129            tx.commit().await
130        } else {
131            Ok(())
132        }
133    }
134
135    /// Rolls back the transaction.
136    pub async fn rollback(self) -> Result<(), sqlx::Error> {
137        let mut guard = self.tx.lock().await;
138        if let Some(tx) = guard.take() {
139            tx.rollback().await
140        } else {
141            Ok(())
142        }
143    }
144}