use crate::connection::AssertSend;
use crate::error::{firebird_err, Error};
use crate::{Firebird, FirebirdConnection};
use sqlx_core::sql_str::SqlStr;
use sqlx_core::transaction::TransactionManager;
/// [`TransactionManager`] implementation for Firebird connections.
///
/// Nesting is tracked through the connection's `transaction_depth` counter: the outermost
/// level maps to the driver connection's own commit/rollback, and every deeper level is
/// implemented with an SQL savepoint.
#[derive(Debug, Clone, Copy, Default)]
pub struct FirebirdTransactionManager;
impl TransactionManager for FirebirdTransactionManager {
    type Database = Firebird;

    /// Opens a transaction, or a savepoint if a transaction is already open.
    ///
    /// * Depth 0: the caller-supplied `statement`, if any, is executed verbatim. When
    ///   `statement` is `None`, no SQL is sent from here.
    ///   NOTE(review): this presumably relies on the underlying Firebird connection having an
    ///   implicit transaction that `commit`/`rollback` later finish — confirm against `inner`.
    /// * Depth > 0: a savepoint named after the current depth is created. A custom
    ///   `statement` is rejected because it cannot be used to open a nested level.
    ///
    /// `transaction_depth` is incremented only after the SQL (if any) has succeeded.
    ///
    /// # Errors
    ///
    /// * [`Error::InvalidSavePointStatement`] if `statement` is `Some` while a transaction is
    ///   already open.
    /// * Any driver error raised while executing the SQL, converted with `firebird_err`.
    fn begin(
        conn: &mut FirebirdConnection,
        statement: Option<SqlStr>,
    ) -> impl std::future::Future<Output = Result<(), Error>> + Send + '_ {
        let inner = conn.inner.clone();
        let depth = conn.transaction_depth;
        AssertSend(async move {
            if depth == 0 {
                if let Some(statement) = statement {
                    let mut guard = inner.lock().await;
                    guard
                        .execute_batch(statement.as_str())
                        .await
                        .map_err(firebird_err)?;
                }
            } else {
                if statement.is_some() {
                    return Err(Error::InvalidSavePointStatement);
                }
                let sql = format!("SAVEPOINT {}", savepoint_name(depth));
                let mut guard = inner.lock().await;
                guard.execute_batch(&sql).await.map_err(firebird_err)?;
            }
            conn.transaction_depth += 1;
            Ok(())
        })
    }

    /// Commits the innermost open level.
    ///
    /// At depth 1 this calls the driver connection's `commit`. At deeper levels it releases
    /// the savepoint that `begin` created for this level. At depth 0 it does nothing and
    /// returns `Ok(())`.
    ///
    /// `transaction_depth` is decremented only when the commit/release succeeded.
    ///
    /// # Errors
    ///
    /// Any driver error, converted with `firebird_err`.
    fn commit(
        conn: &mut FirebirdConnection,
    ) -> impl std::future::Future<Output = Result<(), Error>> + Send + '_ {
        let inner = conn.inner.clone();
        let depth = conn.transaction_depth;
        AssertSend(async move {
            match depth {
                // Not inside a transaction: nothing to commit.
                0 => return Ok(()),
                1 => {
                    let guard = inner.lock().await;
                    guard.commit().await.map_err(firebird_err)?;
                }
                _ => {
                    // `begin` created this level's savepoint while the depth was `depth - 1`.
                    let sql = format!("RELEASE SAVEPOINT {}", savepoint_name(depth - 1));
                    let mut guard = inner.lock().await;
                    guard.execute_batch(&sql).await.map_err(firebird_err)?;
                }
            }
            conn.transaction_depth = depth - 1;
            Ok(())
        })
    }

    /// Rolls back the innermost open level.
    ///
    /// At depth 1 this calls the driver connection's `rollback`. At deeper levels it rolls
    /// back to the savepoint that `begin` created for this level. At depth 0 it does nothing
    /// and returns `Ok(())`.
    ///
    /// `transaction_depth` is decremented only when the rollback succeeded.
    ///
    /// # Errors
    ///
    /// Any driver error, converted with `firebird_err`.
    fn rollback(
        conn: &mut FirebirdConnection,
    ) -> impl std::future::Future<Output = Result<(), Error>> + Send + '_ {
        let inner = conn.inner.clone();
        let depth = conn.transaction_depth;
        AssertSend(async move {
            match depth {
                // Not inside a transaction: nothing to roll back.
                0 => return Ok(()),
                1 => {
                    let mut guard = inner.lock().await;
                    guard.rollback().await.map_err(firebird_err)?;
                }
                _ => {
                    // `begin` created this level's savepoint while the depth was `depth - 1`.
                    let sql = format!("ROLLBACK TO SAVEPOINT {}", savepoint_name(depth - 1));
                    let mut guard = inner.lock().await;
                    guard.execute_batch(&sql).await.map_err(firebird_err)?;
                }
            }
            conn.transaction_depth = depth - 1;
            Ok(())
        })
    }

    /// Synchronously pops one nesting level (never below 0) without contacting the server.
    ///
    /// NOTE(review): no `ROLLBACK` / `ROLLBACK TO SAVEPOINT` is issued or queued here. If this
    /// is reached because an open `Transaction` was dropped, the work done inside it stays
    /// pending on the connection and may be committed by a later `commit`.
    /// TODO: record a pending rollback on `FirebirdConnection` and drain it before the next
    /// operation.
    fn start_rollback(conn: &mut FirebirdConnection) {
        conn.transaction_depth = conn.transaction_depth.saturating_sub(1);
    }

    /// Returns the current nesting depth (0 when no transaction is open).
    fn get_transaction_depth(conn: &FirebirdConnection) -> usize {
        conn.transaction_depth
    }
}

/// Builds the identifier of the savepoint used for nesting level `level`.
///
/// Firebird regular (unquoted) identifiers must begin with a letter; a leading underscore is
/// parsed as a character-set introducer. The name therefore starts with `sqlx`, not `_`.
fn savepoint_name(level: usize) -> String {
    format!("sqlx_savepoint_{level}")
}