use std::sync::Arc;
use crate::adapters::params::convert_params;
use crate::middleware::{ConversionMode, ResultSet, RowValues, SqlMiddlewareDbError};
use crate::pool::MiddlewarePoolConnection;
use crate::tx_outcome::TxOutcome;
use super::connection::SqliteConnection;
use super::params::Params;
use std::sync::atomic::{AtomicBool, Ordering};
/// Test-only switch: when set, a connection whose rollback failed is still put back
/// into its pool wrapper instead of having its handle marked broken.
///
/// `Relaxed` ordering is sufficient because the flag is only ever read and written as an
/// independent boolean; it does not publish any other memory.
static REWRAP_ON_ROLLBACK_FAILURE: AtomicBool = AtomicBool::new(false);

/// Enable or disable the rewrap-on-rollback-failure test hook for the whole process.
#[doc(hidden)]
pub fn set_rewrap_on_rollback_failure_for_tests(rewrap: bool) {
    REWRAP_ON_ROLLBACK_FAILURE.store(rewrap, Ordering::Relaxed)
}

/// Current value of the rewrap-on-rollback-failure test hook.
fn rewrap_on_rollback_failure_for_tests() -> bool {
    let flag = &REWRAP_ON_ROLLBACK_FAILURE;
    flag.load(Ordering::Relaxed)
}
/// An open SQLite transaction whose connection has been moved out of a pool wrapper.
///
/// The connection is put back into `conn_slot` when the transaction finishes via
/// `commit`, `rollback`, or `Drop` (which attempts a rollback). If a rollback fails, the
/// connection handle is marked broken and the connection is not put back, unless the
/// rewrap-on-rollback-failure test hook is enabled.
pub struct Tx<'a> {
    /// The connection running the transaction; `None` once the transaction has completed.
    conn: Option<SqliteConnection>,
    /// The wrapper the connection was taken from and is returned to.
    conn_slot: &'a mut MiddlewarePoolConnection,
}
/// A SQL statement captured by `Tx::prepare` for repeated execution inside a transaction.
///
/// Only the SQL text is stored (no SQLite statement handle). It sits behind an `Arc`, so
/// cloning a `Prepared` is a reference-count bump rather than a string copy.
#[derive(Debug, Clone)]
pub struct Prepared {
    sql: Arc<String>,
}
/// Start a SQLite transaction on the connection held by `conn_slot`.
///
/// The SQLite connection is moved out of the wrapper for the duration of the transaction
/// and handed back by the returned [`Tx`] when it finishes.
///
/// # Errors
///
/// - `Unimplemented` if `conn_slot` is not a SQLite connection (only reachable when
///   another backend feature is compiled in).
/// - `ExecutionError` if the SQLite connection has already been taken from the wrapper.
/// - Any error from starting the transaction. In that case the connection is put back
///   into `conn_slot`, so the wrapper stays usable.
pub async fn begin_transaction(
    conn_slot: &mut MiddlewarePoolConnection,
) -> Result<Tx<'_>, SqlMiddlewareDbError> {
    #[cfg(any(feature = "postgres", feature = "mssql", feature = "turso"))]
    let MiddlewarePoolConnection::Sqlite { conn: slot, .. } = conn_slot else {
        return Err(SqlMiddlewareDbError::Unimplemented(
            "begin_transaction is only available for SQLite connections".into(),
        ));
    };
    #[cfg(not(any(feature = "postgres", feature = "mssql", feature = "turso")))]
    let MiddlewarePoolConnection::Sqlite { conn: slot, .. } = conn_slot;
    let mut conn = slot.take().ok_or_else(|| {
        SqlMiddlewareDbError::ExecutionError(
            "SQLite connection already taken from pool wrapper".into(),
        )
    })?;
    if let Err(err) = conn.begin().await {
        // Restore the connection. Otherwise it would be dropped here and the wrapper left
        // empty, making every later use fail with "already taken".
        // NOTE(review): assumes a failed `begin()` leaves no transaction open — confirm.
        *slot = Some(conn);
        return Err(err);
    }
    Ok(Tx {
        conn: Some(conn),
        conn_slot,
    })
}
impl Tx<'_> {
    /// Error returned by every operation attempted after the transaction has completed.
    fn completed_error() -> SqlMiddlewareDbError {
        SqlMiddlewareDbError::ExecutionError("SQLite transaction already completed".into())
    }

    /// Borrow the live connection, or fail if the transaction has already completed.
    fn conn_mut(&mut self) -> Result<&mut SqliteConnection, SqlMiddlewareDbError> {
        self.conn.as_mut().ok_or_else(Self::completed_error)
    }

    /// Move the connection out of the transaction, or fail if it has already completed.
    fn take_conn(&mut self) -> Result<SqliteConnection, SqlMiddlewareDbError> {
        self.conn.take().ok_or_else(Self::completed_error)
    }

    /// Capture `sql` for later use with [`Tx::execute_prepared`] / [`Tx::query_prepared`].
    ///
    /// Only the SQL text is stored; nothing is sent to SQLite here.
    ///
    /// # Errors
    ///
    /// `ExecutionError` if the transaction has already completed.
    pub fn prepare(&self, sql: &str) -> Result<Prepared, SqlMiddlewareDbError> {
        if self.conn.is_none() {
            return Err(Self::completed_error());
        }
        Ok(Prepared {
            sql: Arc::new(sql.to_owned()),
        })
    }

    /// Execute a prepared DML statement inside the transaction.
    ///
    /// Returns the `usize` reported by the connection's `execute_dml_in_tx` (presumably
    /// the affected row count — TODO confirm).
    ///
    /// # Errors
    ///
    /// Parameter conversion failures, `ExecutionError` if the transaction has completed,
    /// or any error from executing the statement.
    pub async fn execute_prepared(
        &mut self,
        prepared: &Prepared,
        params: &[RowValues],
    ) -> Result<usize, SqlMiddlewareDbError> {
        let converted = convert_params::<Params>(params, ConversionMode::Execute)?;
        let conn = self.conn_mut()?;
        conn.execute_dml_in_tx(prepared.sql.as_ref(), &converted.0)
            .await
    }

    /// Run a prepared query inside the transaction and collect its rows into a `ResultSet`.
    ///
    /// # Errors
    ///
    /// Parameter conversion failures, `ExecutionError` if the transaction has completed,
    /// or any error from executing the query.
    pub async fn query_prepared(
        &mut self,
        prepared: &Prepared,
        params: &[RowValues],
    ) -> Result<ResultSet, SqlMiddlewareDbError> {
        let converted = convert_params::<Params>(params, ConversionMode::Query)?;
        let conn = self.conn_mut()?;
        conn.execute_select_in_tx(
            prepared.sql.as_ref(),
            &converted.0,
            super::query::build_result_set,
        )
        .await
    }

    /// Execute a raw SQL batch inside the transaction.
    ///
    /// # Errors
    ///
    /// `ExecutionError` if the transaction has completed, or any error from the batch.
    pub async fn execute_batch(&mut self, sql: &str) -> Result<(), SqlMiddlewareDbError> {
        let conn = self.conn_mut()?;
        conn.execute_batch_in_tx(sql).await
    }

    /// Commit the transaction and return the connection to its wrapper.
    ///
    /// If the commit fails, a rollback is attempted. The connection is then restored or
    /// marked broken exactly as [`Tx::rollback`] does, and the commit error is returned.
    ///
    /// NOTE(review): the connection is held in a local across the `.await`s below. If this
    /// future is dropped before completing, `Drop for Tx` finds no connection, so it is
    /// neither rolled back nor returned to the wrapper — confirm callers never cancel this.
    ///
    /// # Errors
    ///
    /// `ExecutionError` if the transaction has already completed, or the commit error.
    pub async fn commit(mut self) -> Result<TxOutcome, SqlMiddlewareDbError> {
        let mut conn = self.take_conn()?;
        match conn.commit().await {
            Ok(()) => {
                self.rewrap(conn);
                Ok(TxOutcome::without_restored_connection())
            }
            Err(commit_err) => {
                // Best-effort cleanup: the caller needs the commit error, so a rollback
                // failure only affects whether the connection is restored or marked broken.
                let _ = self.rollback_and_restore(conn).await;
                Err(commit_err)
            }
        }
    }

    /// Roll back the transaction.
    ///
    /// On success the connection is returned to its wrapper. On failure its handle is
    /// marked broken, unless the rewrap-on-rollback-failure test hook is enabled.
    ///
    /// NOTE(review): same cancellation caveat as [`Tx::commit`].
    ///
    /// # Errors
    ///
    /// `ExecutionError` if the transaction has already completed, or the rollback error.
    pub async fn rollback(mut self) -> Result<TxOutcome, SqlMiddlewareDbError> {
        let conn = self.take_conn()?;
        self.rollback_and_restore(conn).await?;
        Ok(TxOutcome::without_restored_connection())
    }

    /// Roll back the open transaction on `conn`, then settle where the connection goes.
    ///
    /// If the rollback succeeded, or the test hook forces it, the connection is flagged as
    /// out of transaction and put back into the wrapper. Otherwise its handle is marked
    /// broken. Returns the rollback result.
    async fn rollback_and_restore(
        &mut self,
        mut conn: SqliteConnection,
    ) -> Result<(), SqlMiddlewareDbError> {
        let handle = conn.conn_handle();
        let rollback_result = super::connection::rollback_with_busy_retries(&handle).await;
        // Decide once, so "restore" and "mark broken" are always mutually exclusive even
        // if the test hook is toggled concurrently.
        if rollback_result.is_ok() || rewrap_on_rollback_failure_for_tests() {
            conn.in_transaction = false;
            self.rewrap(conn);
        } else {
            handle.mark_broken();
        }
        rollback_result
    }

    /// Put `conn` back into the wrapper slot it was taken from.
    ///
    /// If the wrapper is not a SQLite variant (only possible when other backend features
    /// are compiled in), `conn` is simply dropped.
    fn rewrap(&mut self, conn: SqliteConnection) {
        #[cfg(any(feature = "postgres", feature = "mssql", feature = "turso"))]
        let MiddlewarePoolConnection::Sqlite { conn: slot, .. } = self.conn_slot else {
            return;
        };
        #[cfg(not(any(feature = "postgres", feature = "mssql", feature = "turso")))]
        let MiddlewarePoolConnection::Sqlite { conn: slot, .. } = self.conn_slot;
        debug_assert!(slot.is_none(), "sqlite conn slot should be empty during tx");
        *slot = Some(conn);
    }
}
impl Drop for Tx<'_> {
    /// Clean up a transaction that was neither committed nor rolled back.
    ///
    /// The rollback runs synchronously through the blocking busy-retry helper, because
    /// `drop` cannot await. If it succeeds, or the rewrap-on-rollback-failure test hook is
    /// enabled, the connection is put back into its wrapper. Otherwise its handle is marked
    /// broken. Already-completed transactions need no cleanup.
    fn drop(&mut self) {
        let Some(mut conn) = self.conn.take() else {
            return;
        };
        let handle = conn.conn_handle();
        let rollback_result = super::connection::rollback_with_busy_retries_blocking(&handle);
        let keep_connection = rollback_result.is_ok() || rewrap_on_rollback_failure_for_tests();
        if !keep_connection {
            handle.mark_broken();
            return;
        }
        conn.in_transaction = false;
        self.rewrap(conn);
    }
}