use std::sync::Arc;
use std::sync::atomic::Ordering;
use tokio::runtime::Handle;
use tokio::task::block_in_place;
use crate::middleware::SqlMiddlewareDbError;
use super::SqliteTypedConnection;
use super::core::{SKIP_DROP_ROLLBACK, begin_from_conn, run_blocking};
use crate::sqlite::connection::{rollback_with_busy_retries, rollback_with_busy_retries_blocking};
use crate::sqlite::config::SharedSqliteConnection;
impl SqliteTypedConnection<super::core::Idle> {
    /// Consumes an idle connection and opens a transaction on it.
    ///
    /// The underlying connection is moved out of `self` and handed to
    /// `begin_from_conn`, which produces the `InTx`-typed wrapper.
    ///
    /// # Errors
    ///
    /// Returns an error if the connection cannot be taken out of `self`, or
    /// if `begin_from_conn` fails.
    pub async fn begin(
        mut self,
    ) -> Result<SqliteTypedConnection<super::core::InTx>, SqlMiddlewareDbError> {
        let conn = self.take_conn()?;
        begin_from_conn(conn).await
    }
}
impl SqliteTypedConnection<super::core::InTx> {
    /// Commits the open transaction and returns the connection in the `Idle` state.
    ///
    /// `COMMIT` is executed through `run_blocking`. If it fails, a rollback is
    /// attempted via `rollback_with_busy_retries`; when that rollback also
    /// fails the connection handle is marked broken. In both failure cases the
    /// original commit error is what the caller receives.
    ///
    /// # Errors
    ///
    /// Returns an error if the connection handle cannot be obtained, if
    /// `COMMIT` fails (wrapped as `SqlMiddlewareDbError::SqliteError`), or if
    /// the connection cannot be taken out of `self` after a successful commit.
    pub async fn commit(
        mut self,
    ) -> Result<SqliteTypedConnection<super::core::Idle>, SqlMiddlewareDbError> {
        let conn_handle = self.conn_handle()?;
        let outcome = run_blocking(Arc::clone(&conn_handle), |guard| {
            guard
                .execute_batch("COMMIT")
                .map_err(SqlMiddlewareDbError::SqliteError)
        })
        .await;

        if let Err(commit_err) = outcome {
            // Try to leave the connection outside the failed transaction; a
            // connection that cannot even roll back is flagged as broken.
            // NOTE(review): `needs_rollback` is not cleared on this path and
            // `self` still appears to own the connection, so the `Drop` impl may
            // attempt another rollback afterwards — confirm that is intended.
            let rollback_failed = rollback_with_busy_retries(&conn_handle).await.is_err();
            if rollback_failed {
                conn_handle.mark_broken();
            }
            return Err(commit_err);
        }

        // Committed: move the connection into a fresh wrapper that no longer
        // asks for a rollback.
        let conn = self.take_conn()?;
        Ok(SqliteTypedConnection {
            conn: Some(conn),
            needs_rollback: false,
            _state: std::marker::PhantomData,
        })
    }

    /// Rolls back the open transaction and returns the connection in the `Idle` state.
    ///
    /// The rollback goes through `rollback_with_busy_retries`. If it fails,
    /// the connection handle is marked broken and the rollback error is returned.
    ///
    /// # Errors
    ///
    /// Returns an error if the connection handle cannot be obtained, if the
    /// rollback fails, or if the connection cannot be taken out of `self`
    /// after a successful rollback.
    pub async fn rollback(
        mut self,
    ) -> Result<SqliteTypedConnection<super::core::Idle>, SqlMiddlewareDbError> {
        let conn_handle = self.conn_handle()?;
        let outcome = rollback_with_busy_retries(&conn_handle).await;

        if let Err(rollback_err) = outcome {
            // NOTE(review): as in `commit`, `needs_rollback` is left untouched
            // here, so `Drop` may retry the rollback — confirm that is intended.
            conn_handle.mark_broken();
            return Err(rollback_err);
        }

        let conn = self.take_conn()?;
        Ok(SqliteTypedConnection {
            conn: Some(conn),
            needs_rollback: false,
            _state: std::marker::PhantomData,
        })
    }
}
impl<State> Drop for SqliteTypedConnection<State> {
fn drop(&mut self) {
if self.needs_rollback
&& !skip_drop_rollback()
&& let Some(conn) = self.conn.take()
{
let conn_handle: SharedSqliteConnection = Arc::clone(&*conn);
let rollback = || rollback_with_busy_retries_blocking(&conn_handle);
let result = if Handle::try_current().is_ok() {
block_in_place(rollback)
} else {
rollback()
};
if result.is_err() {
conn_handle.mark_broken();
}
}
}
}
/// Reports whether the rollback performed by `Drop` is currently disabled.
///
/// Reads the `SKIP_DROP_ROLLBACK` flag with relaxed ordering.
fn skip_drop_rollback() -> bool {
SKIP_DROP_ROLLBACK.load(Ordering::Relaxed)
}
/// Enables (`true`) or disables (`false`) skipping of the rollback that
/// `Drop` would otherwise perform.
///
/// Stores `skip` into the `SKIP_DROP_ROLLBACK` flag with relaxed ordering.
/// Hidden from the rendered docs; the name indicates it is meant for tests.
#[doc(hidden)]
pub fn set_skip_drop_rollback_for_tests(skip: bool) {
SKIP_DROP_ROLLBACK.store(skip, Ordering::Relaxed);
}