use tokio::runtime::Handle;
use crate::middleware::SqlMiddlewareDbError;
use super::core::{Idle, InTx, SKIP_DROP_ROLLBACK};
use super::{TursoConnection, TursoManager};
impl TursoConnection<Idle> {
    /// Starts a transaction on this idle connection.
    ///
    /// Moves the underlying pooled connection out of `self` and issues `BEGIN` on it,
    /// returning a handle in the `InTx` state.
    ///
    /// # Errors
    ///
    /// Propagates any error from `take_conn`, or the error produced by
    /// `begin_from_conn` when `BEGIN` fails.
    pub async fn begin(mut self) -> Result<TursoConnection<InTx>, SqlMiddlewareDbError> {
        let pooled = self.take_conn()?;
        begin_from_conn(pooled).await
    }
}
impl TursoConnection<InTx> {
    /// Commits the open transaction and returns the connection in the `Idle` state.
    ///
    /// # Errors
    ///
    /// Propagates any error from `take_conn`. If `COMMIT` fails, an explicit `ROLLBACK`
    /// is attempted and `SqlMiddlewareDbError::ExecutionError` is returned:
    /// - if the `ROLLBACK` succeeds, `needs_rollback` is cleared so `Drop` does not issue
    ///   a second, redundant `ROLLBACK`;
    /// - if the `ROLLBACK` fails, its error is appended to the message and `needs_rollback`
    ///   is left untouched, so `Drop` retries the rollback in the background.
    pub async fn commit(mut self) -> Result<TursoConnection<Idle>, SqlMiddlewareDbError> {
        let conn = self.take_conn()?;
        match conn.execute_batch("COMMIT").await {
            Ok(()) => Ok(TursoConnection {
                conn: Some(conn),
                needs_rollback: false,
                _state: std::marker::PhantomData,
            }),
            Err(e) => {
                let rollback_note = match conn.execute_batch("ROLLBACK").await {
                    Ok(()) => {
                        // The transaction is closed; stop Drop from re-running ROLLBACK.
                        self.needs_rollback = false;
                        String::new()
                    }
                    // Surface the failure instead of swallowing it; Drop will retry.
                    Err(rb) => format!("; rollback also failed: {rb}"),
                };
                // Hand the connection back to `self` so Drop can release it (and retry the
                // rollback if `needs_rollback` is still set).
                self.conn = Some(conn);
                Err(SqlMiddlewareDbError::ExecutionError(format!(
                    "turso commit error: {e}{rollback_note}"
                )))
            }
        }
    }

    /// Rolls back the open transaction and returns the connection in the `Idle` state.
    ///
    /// # Errors
    ///
    /// Propagates any error from `take_conn`. If `ROLLBACK` fails, returns
    /// `SqlMiddlewareDbError::ExecutionError` and leaves the connection in `self`. When
    /// `needs_rollback` is set, `Drop` then retries the rollback in the background.
    pub async fn rollback(mut self) -> Result<TursoConnection<Idle>, SqlMiddlewareDbError> {
        let conn = self.take_conn()?;
        match conn.execute_batch("ROLLBACK").await {
            Ok(()) => Ok(TursoConnection {
                conn: Some(conn),
                needs_rollback: false,
                _state: std::marker::PhantomData,
            }),
            Err(e) => {
                // Deliberately keep the connection in `self` so Drop gets a second attempt.
                self.conn = Some(conn);
                Err(SqlMiddlewareDbError::ExecutionError(format!(
                    "turso rollback error: {e}"
                )))
            }
        }
    }
}
/// Issues `BEGIN` on a pooled connection and wraps it in an `InTx` handle.
///
/// The returned handle has `needs_rollback` set, so dropping it without calling
/// `commit` or `rollback` lets the `Drop` impl attempt a rollback.
///
/// # Errors
///
/// Returns `SqlMiddlewareDbError::ExecutionError` if `BEGIN` fails. In that case the
/// connection is dropped here.
pub(crate) async fn begin_from_conn(
    conn: bb8::PooledConnection<'static, TursoManager>,
) -> Result<TursoConnection<InTx>, SqlMiddlewareDbError> {
    if let Err(err) = conn.execute_batch("BEGIN").await {
        return Err(SqlMiddlewareDbError::ExecutionError(format!(
            "turso begin error: {err}"
        )));
    }
    let in_tx = TursoConnection {
        conn: Some(conn),
        needs_rollback: true,
        _state: std::marker::PhantomData,
    };
    Ok(in_tx)
}
/// Reports whether the `Drop` impl should skip its automatic `ROLLBACK`.
///
/// Reads the shared `SKIP_DROP_ROLLBACK` flag with relaxed ordering.
fn skip_drop_rollback() -> bool {
    use std::sync::atomic::Ordering;
    SKIP_DROP_ROLLBACK.load(Ordering::Relaxed)
}
/// Test hook: toggles whether dropping a connection skips the automatic `ROLLBACK`.
///
/// Writes `skip` into the shared `SKIP_DROP_ROLLBACK` flag with relaxed ordering.
#[doc(hidden)]
pub fn set_skip_drop_rollback_for_tests(skip: bool) {
    use std::sync::atomic::Ordering::Relaxed;
    SKIP_DROP_ROLLBACK.store(skip, Relaxed);
}
impl<State> Drop for TursoConnection<State> {
    /// Best-effort rollback of a transaction that was never committed or rolled back.
    ///
    /// The rollback runs only when all of the following hold:
    /// - `needs_rollback` is set;
    /// - the test hook has not disabled it;
    /// - a connection is still held;
    /// - a Tokio runtime is current.
    ///
    /// When they do, a detached task is spawned that runs `ROLLBACK` and ignores its
    /// result. The connection moves into that task and is dropped after the statement
    /// finishes.
    ///
    /// NOTE(review): without a current runtime the connection is dropped here with no
    /// rollback. It is presumably returned to the bb8 pool with the transaction still
    /// open — confirm whether `TursoManager` detects this.
    fn drop(&mut self) {
        if !self.needs_rollback || skip_drop_rollback() {
            return;
        }
        let Some(pooled) = self.conn.take() else {
            return;
        };
        let Ok(runtime) = Handle::try_current() else {
            return;
        };
        runtime.spawn(async move {
            let _ = pooled.execute_batch("ROLLBACK").await;
        });
    }
}