// sql_middleware/sqlite/transaction.rs

1use std::sync::Arc;
2
3use crate::adapters::params::convert_params;
4use crate::middleware::{ConversionMode, ResultSet, RowValues, SqlMiddlewareDbError};
5use crate::pool::MiddlewarePoolConnection;
6use crate::tx_outcome::TxOutcome;
7
8use super::connection::SqliteConnection;
9use super::params::Params;
10
11use std::sync::atomic::{AtomicBool, Ordering};
12
/// Test hook: when set, a connection whose rollback failed is still handed back to the pool
/// wrapper instead of being marked broken. Defaults to `false`.
static REWRAP_ON_ROLLBACK_FAILURE: AtomicBool = AtomicBool::new(false);

/// Enable or disable the rewrap-on-rollback-failure test hook for the whole process.
///
/// Intended for tests only; hidden from rustdoc.
#[doc(hidden)]
pub fn set_rewrap_on_rollback_failure_for_tests(rewrap: bool) {
    // `Relaxed` suffices: the flag is an independent toggle with no data published alongside it.
    AtomicBool::store(&REWRAP_ON_ROLLBACK_FAILURE, rewrap, Ordering::Relaxed);
}

/// Current value of the rewrap-on-rollback-failure test hook.
fn rewrap_on_rollback_failure_for_tests() -> bool {
    AtomicBool::load(&REWRAP_ON_ROLLBACK_FAILURE, Ordering::Relaxed)
}
23
24/// Transaction handle that owns the `SQLite` connection until completion.
25pub struct Tx<'a> {
26    conn: Option<SqliteConnection>,
27    conn_slot: &'a mut MiddlewarePoolConnection,
28}
29
/// Prepared statement tied to a `SQLite` transaction.
///
/// Holds only the SQL text behind an `Arc`, so cloning is a cheap reference-count bump.
#[derive(Debug, Clone)]
pub struct Prepared {
    sql: Arc<String>,
}
34
35/// Begin a transaction, temporarily taking ownership of the pooled `SQLite` connection
36/// until commit/rollback (or drop) returns it to the wrapper.
37///
38/// # Errors
39/// Returns `SqlMiddlewareDbError` if the transaction cannot be started.
40pub async fn begin_transaction(
41    conn_slot: &mut MiddlewarePoolConnection,
42) -> Result<Tx<'_>, SqlMiddlewareDbError> {
43    #[cfg(any(feature = "postgres", feature = "mssql", feature = "turso"))]
44    let MiddlewarePoolConnection::Sqlite { conn, .. } = conn_slot else {
45        return Err(SqlMiddlewareDbError::Unimplemented(
46            "begin_transaction is only available for SQLite connections".into(),
47        ));
48    };
49    #[cfg(not(any(feature = "postgres", feature = "mssql", feature = "turso")))]
50    let MiddlewarePoolConnection::Sqlite { conn, .. } = conn_slot;
51
52    let mut conn = conn.take().ok_or_else(|| {
53        SqlMiddlewareDbError::ExecutionError(
54            "SQLite connection already taken from pool wrapper".into(),
55        )
56    })?;
57    conn.begin().await?;
58    Ok(Tx {
59        conn: Some(conn),
60        conn_slot,
61    })
62}
63
64impl Tx<'_> {
65    fn conn_mut(&mut self) -> Result<&mut SqliteConnection, SqlMiddlewareDbError> {
66        self.conn.as_mut().ok_or_else(|| {
67            SqlMiddlewareDbError::ExecutionError("SQLite transaction already completed".into())
68        })
69    }
70
71    /// Prepare a statement within this transaction.
72    ///
73    /// # Errors
74    /// Returns `SqlMiddlewareDbError` if the transaction has already completed.
75    pub fn prepare(&self, sql: &str) -> Result<Prepared, SqlMiddlewareDbError> {
76        if self.conn.is_none() {
77            return Err(SqlMiddlewareDbError::ExecutionError(
78                "SQLite transaction already completed".into(),
79            ));
80        }
81        Ok(Prepared {
82            sql: Arc::new(sql.to_owned()),
83        })
84    }
85
86    /// Execute a prepared statement as DML within this transaction.
87    ///
88    /// # Errors
89    /// Returns `SqlMiddlewareDbError` if parameter conversion or execution fails.
90    pub async fn execute_prepared(
91        &mut self,
92        prepared: &Prepared,
93        params: &[RowValues],
94    ) -> Result<usize, SqlMiddlewareDbError> {
95        let converted = convert_params::<Params>(params, ConversionMode::Execute)?;
96        let conn = self.conn_mut()?;
97        conn.execute_dml_in_tx(prepared.sql.as_ref(), &converted.0)
98            .await
99    }
100
101    /// Execute a prepared statement as a query within this transaction.
102    ///
103    /// # Errors
104    /// Returns `SqlMiddlewareDbError` if parameter conversion or execution fails.
105    pub async fn query_prepared(
106        &mut self,
107        prepared: &Prepared,
108        params: &[RowValues],
109    ) -> Result<ResultSet, SqlMiddlewareDbError> {
110        let converted = convert_params::<Params>(params, ConversionMode::Query)?;
111        let conn = self.conn_mut()?;
112        conn.execute_select_in_tx(
113            prepared.sql.as_ref(),
114            &converted.0,
115            super::query::build_result_set,
116        )
117        .await
118    }
119
120    /// Execute a batch inside the open transaction.
121    ///
122    /// # Errors
123    /// Returns `SqlMiddlewareDbError` if executing the batch fails.
124    pub async fn execute_batch(&mut self, sql: &str) -> Result<(), SqlMiddlewareDbError> {
125        let conn = self.conn_mut()?;
126        conn.execute_batch_in_tx(sql).await
127    }
128
129    /// Commit the transaction and rewrap the pooled connection.
130    ///
131    /// # Errors
132    /// Returns `SqlMiddlewareDbError` if committing the transaction fails.
133    pub async fn commit(mut self) -> Result<TxOutcome, SqlMiddlewareDbError> {
134        let mut conn = self.conn.take().ok_or_else(|| {
135            SqlMiddlewareDbError::ExecutionError("SQLite transaction already completed".into())
136        })?;
137        match conn.commit().await {
138            Ok(()) => {
139                self.rewrap(conn);
140                Ok(TxOutcome::without_restored_connection())
141            }
142            Err(err) => {
143                let handle = conn.conn_handle();
144                let rollback_result = super::connection::rollback_with_busy_retries(&handle).await;
145                if rollback_result.is_ok() || rewrap_on_rollback_failure_for_tests() {
146                    conn.in_transaction = false;
147                    self.rewrap(conn);
148                }
149                if rollback_result.is_err() && !rewrap_on_rollback_failure_for_tests() {
150                    handle.mark_broken();
151                }
152                Err(err)
153            }
154        }
155    }
156
157    /// Roll back the transaction and rewrap the pooled connection.
158    ///
159    /// # Errors
160    /// Returns `SqlMiddlewareDbError` if rolling back fails.
161    pub async fn rollback(mut self) -> Result<TxOutcome, SqlMiddlewareDbError> {
162        let mut conn = self.conn.take().ok_or_else(|| {
163            SqlMiddlewareDbError::ExecutionError("SQLite transaction already completed".into())
164        })?;
165        let handle = conn.conn_handle();
166        match super::connection::rollback_with_busy_retries(&handle).await {
167            Ok(()) => {
168                conn.in_transaction = false;
169                self.rewrap(conn);
170                Ok(TxOutcome::without_restored_connection())
171            }
172            Err(err) => {
173                if rewrap_on_rollback_failure_for_tests() {
174                    conn.in_transaction = false;
175                    self.rewrap(conn);
176                }
177                if !rewrap_on_rollback_failure_for_tests() {
178                    handle.mark_broken();
179                }
180                Err(err)
181            }
182        }
183    }
184
185    fn rewrap(&mut self, conn: SqliteConnection) {
186        #[cfg(any(feature = "postgres", feature = "mssql", feature = "turso"))]
187        let MiddlewarePoolConnection::Sqlite { conn: slot, .. } = self.conn_slot else {
188            return;
189        };
190        #[cfg(not(any(feature = "postgres", feature = "mssql", feature = "turso")))]
191        let MiddlewarePoolConnection::Sqlite { conn: slot, .. } = self.conn_slot;
192        debug_assert!(slot.is_none(), "sqlite conn slot should be empty during tx");
193        *slot = Some(conn);
194    }
195}
196
197impl Drop for Tx<'_> {
198    /// Rolls back on drop to avoid leaking open transactions; the rollback is best-effort and
199    /// `SQLite` may report "no transaction is active" if the transaction was already completed
200    /// by user code (e.g., via `execute_batch_in_tx`). Such errors are ignored because the goal
201    /// is simply to leave the connection in a clean state before returning it to the pool.
202    fn drop(&mut self) {
203        if let Some(mut conn) = self.conn.take() {
204            let handle = conn.conn_handle();
205            let rollback_result = super::connection::rollback_with_busy_retries_blocking(&handle);
206            if rollback_result.is_ok() || rewrap_on_rollback_failure_for_tests() {
207                conn.in_transaction = false;
208                self.rewrap(conn);
209            } else {
210                // Mark broken so the pool will drop and replace this connection instead of
211                // handing out one that might still be mid-transaction.
212                handle.mark_broken();
213            }
214        }
215    }
216}