use crate::databases::DatabaseConnection;
use anyhow::Result;
/// Tracks the transaction state of a single database connection together with
/// the stack of savepoints opened inside the current transaction.
pub struct TransactionManager {
    /// Connection every transaction/savepoint statement is issued on.
    db: Box<dyn DatabaseConnection>,
    /// `true` between a successful `begin` and a successful `commit`/`rollback`.
    in_transaction: bool,
    /// Savepoint names in creation order (oldest first), kept in step with the
    /// savepoint stack on the database side.
    savepoints: Vec<String>,
}

impl TransactionManager {
    /// Wraps `db` in a manager that starts outside any transaction.
    pub fn new(db: Box<dyn DatabaseConnection>) -> Self {
        Self {
            db,
            in_transaction: false,
            savepoints: Vec::new(),
        }
    }

    /// Starts a new transaction.
    ///
    /// # Errors
    /// Fails if a transaction is already in progress, or if the connection
    /// reports an error from `begin_transaction`.
    pub async fn begin(&mut self) -> Result<()> {
        if self.in_transaction {
            return Err(anyhow::anyhow!("Transaction already in progress"));
        }
        self.db.begin_transaction().await?;
        self.in_transaction = true;
        Ok(())
    }

    /// Commits the current transaction and forgets all tracked savepoints.
    ///
    /// # Errors
    /// Fails if no transaction is in progress, or if `commit_transaction`
    /// fails (in which case the manager still considers the transaction open).
    pub async fn commit(&mut self) -> Result<()> {
        self.ensure_in_transaction()?;
        self.db.commit_transaction().await?;
        self.in_transaction = false;
        self.savepoints.clear();
        Ok(())
    }

    /// Rolls back the current transaction and forgets all tracked savepoints.
    ///
    /// # Errors
    /// Fails if no transaction is in progress, or if `rollback_transaction`
    /// fails (in which case the manager still considers the transaction open).
    pub async fn rollback(&mut self) -> Result<()> {
        self.ensure_in_transaction()?;
        self.db.rollback_transaction().await?;
        self.in_transaction = false;
        self.savepoints.clear();
        Ok(())
    }

    /// Creates a savepoint called `name` inside the current transaction.
    ///
    /// # Errors
    /// Fails if no transaction is in progress, if `name` is not a plain SQL
    /// identifier (see [`Self::validate_savepoint_name`]), or if the statement fails.
    pub async fn savepoint(&mut self, name: &str) -> Result<()> {
        self.ensure_in_transaction()?;
        Self::validate_savepoint_name(name)?;
        let sql = format!("SAVEPOINT {}", name);
        self.db.execute(&sql).await?;
        self.savepoints.push(name.to_string());
        Ok(())
    }

    /// Rolls back to the most recent savepoint called `name`.
    ///
    /// The target savepoint stays open; every savepoint created after it is
    /// dropped from the tracked stack, mirroring SQL `ROLLBACK TO` semantics.
    ///
    /// # Errors
    /// Fails if no transaction is in progress, if `name` is not a plain SQL
    /// identifier, or if the statement fails.
    pub async fn rollback_to_savepoint(&mut self, name: &str) -> Result<()> {
        self.ensure_in_transaction()?;
        Self::validate_savepoint_name(name)?;
        let sql = format!("ROLLBACK TO SAVEPOINT {}", name);
        self.db.execute(&sql).await?;
        // SQL resolves a duplicated name to the most recently created savepoint,
        // hence `rposition` rather than `position`.
        if let Some(pos) = self.savepoints.iter().rposition(|s| s == name) {
            self.savepoints.truncate(pos + 1);
        }
        Ok(())
    }

    /// Releases the most recent savepoint called `name`.
    ///
    /// SQL `RELEASE` removes the target savepoint *and* every savepoint created
    /// after it, so all of those are dropped from the tracked stack.
    ///
    /// # Errors
    /// Fails if no transaction is in progress, if `name` is not a plain SQL
    /// identifier, or if the statement fails.
    pub async fn release_savepoint(&mut self, name: &str) -> Result<()> {
        self.ensure_in_transaction()?;
        Self::validate_savepoint_name(name)?;
        let sql = format!("RELEASE SAVEPOINT {}", name);
        self.db.execute(&sql).await?;
        if let Some(pos) = self.savepoints.iter().rposition(|s| s == name) {
            // Previously `remove(pos)`, which left already-released, newer
            // savepoints in the list.
            self.savepoints.truncate(pos);
        }
        Ok(())
    }

    /// Returns `true` while a transaction started by [`Self::begin`] is open.
    pub fn is_in_transaction(&self) -> bool {
        self.in_transaction
    }

    /// Returns the tracked savepoint names, oldest first.
    pub fn get_savepoints(&self) -> &[String] {
        &self.savepoints
    }

    /// Returns an error unless a transaction is currently open.
    fn ensure_in_transaction(&self) -> Result<()> {
        if self.in_transaction {
            Ok(())
        } else {
            Err(anyhow::anyhow!("No transaction in progress"))
        }
    }

    /// Accepts only plain identifiers matching `[A-Za-z_][A-Za-z0-9_]*`.
    ///
    /// Savepoint names cannot be bound as query parameters and are spliced into
    /// the SQL text, so anything else is rejected to rule out SQL injection via
    /// the `name` argument.
    fn validate_savepoint_name(name: &str) -> Result<()> {
        let mut chars = name.chars();
        let valid = match chars.next() {
            Some(first) => {
                (first.is_ascii_alphabetic() || first == '_')
                    && chars.all(|c| c.is_ascii_alphanumeric() || c == '_')
            }
            None => false,
        };
        if valid {
            Ok(())
        } else {
            Err(anyhow::anyhow!("Invalid savepoint name: {:?}", name))
        }
    }
}
/// Helpers that run closures inside a transaction driven by a [`TransactionManager`].
pub struct TransactionExecutor;

impl TransactionExecutor {
    /// Runs `operation` inside a fresh transaction on `db`.
    ///
    /// Commits when `operation` succeeds and rolls back when it fails.
    ///
    /// # Errors
    /// - the error from `begin` if a transaction cannot be started;
    /// - the error returned by `operation` (after a best-effort rollback);
    /// - the error from `commit` (after a best-effort rollback, so `db` is not
    ///   left with a dangling open transaction).
    pub async fn execute<F, Fut>(db: &mut TransactionManager, operation: F) -> Result<()>
    where
        F: FnOnce() -> Fut,
        Fut: std::future::Future<Output = Result<()>>,
    {
        db.begin().await?;
        let outcome = operation().await;
        Self::finish(db, outcome).await
    }

    /// Runs each operation in its own savepoint inside one shared transaction.
    ///
    /// A failing operation is rolled back to (and then released from) its own
    /// savepoint, so it does not affect the others; its error is recorded in
    /// the returned vector at the operation's index. The transaction is
    /// committed once all operations have run.
    ///
    /// # Errors
    /// Returns `Err` (after a best-effort rollback of the whole transaction) if
    /// the transaction cannot be started, if any savepoint/release/rollback-to
    /// statement fails, or if the final commit fails.
    pub async fn execute_batch<F, Fut>(
        db: &mut TransactionManager,
        operations: Vec<F>,
    ) -> Result<Vec<Result<()>>>
    where
        F: FnOnce() -> Fut,
        Fut: std::future::Future<Output = Result<()>>,
    {
        db.begin().await?;
        let outcome = Self::run_isolated(db, operations).await;
        Self::finish(db, outcome).await
    }

    /// Runs every operation under its own savepoint `sp_<index>`, collecting
    /// per-operation results; returns `Err` only on savepoint-management failure.
    async fn run_isolated<F, Fut>(
        db: &mut TransactionManager,
        operations: Vec<F>,
    ) -> Result<Vec<Result<()>>>
    where
        F: FnOnce() -> Fut,
        Fut: std::future::Future<Output = Result<()>>,
    {
        let mut results = Vec::with_capacity(operations.len());
        for (i, operation) in operations.into_iter().enumerate() {
            let savepoint_name = format!("sp_{}", i);
            db.savepoint(&savepoint_name).await?;
            match operation().await {
                Ok(()) => {
                    db.release_savepoint(&savepoint_name).await?;
                    results.push(Ok(()));
                }
                Err(e) => {
                    db.rollback_to_savepoint(&savepoint_name).await?;
                    // ROLLBACK TO leaves the savepoint open; release it so later
                    // savepoints are not nested inside a dead one.
                    db.release_savepoint(&savepoint_name).await?;
                    results.push(Err(e));
                }
            }
        }
        Ok(results)
    }

    /// Commits on `Ok`, rolls back on `Err`; a failed commit is also followed by
    /// a rollback. Rollback is best-effort: its own error is discarded so the
    /// original error is what the caller sees.
    async fn finish<T>(db: &mut TransactionManager, outcome: Result<T>) -> Result<T> {
        match outcome {
            Ok(value) => match db.commit().await {
                Ok(()) => Ok(value),
                Err(commit_err) => {
                    let _ = db.rollback().await;
                    Err(commit_err)
                }
            },
            Err(e) => {
                let _ = db.rollback().await;
                Err(e)
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{create_connection, DatabaseType};

    /// Opens an in-memory SQLite connection and begins a transaction on it.
    ///
    /// Prints the `begin` outcome; when `begin` fails it prints a skip notice
    /// and returns `None` so the calling test can bail out early.
    async fn begin_in_memory() -> Option<TransactionManager> {
        let connection = create_connection(DatabaseType::SQLite, "sqlite::memory:")
            .await
            .unwrap();
        let mut manager = TransactionManager::new(connection);
        let begun = manager.begin().await;
        println!("Begin result: {:?}", begun);
        match begun {
            Ok(()) => Some(manager),
            Err(_) => {
                println!("Transaction not supported, skipping...");
                None
            }
        }
    }

    // All tests are marked #[ignore]; run them with `cargo test -- --ignored`.

    #[tokio::test]
    #[ignore]
    async fn test_transaction_basic() {
        let Some(mut manager) = begin_in_memory().await else {
            return;
        };
        assert!(manager.is_in_transaction());

        let committed = manager.commit().await;
        println!("Commit result: {:?}", committed);
        if committed.is_ok() {
            assert!(!manager.is_in_transaction());
        }
    }

    #[tokio::test]
    #[ignore]
    async fn test_transaction_rollback() {
        let Some(mut manager) = begin_in_memory().await else {
            return;
        };

        let rolled_back = manager.rollback().await;
        println!("Rollback result: {:?}", rolled_back);
        if rolled_back.is_ok() {
            assert!(!manager.is_in_transaction());
        }
    }

    #[tokio::test]
    #[ignore]
    async fn test_savepoint() {
        let Some(mut manager) = begin_in_memory().await else {
            return;
        };

        let created = manager.savepoint("sp1").await;
        println!("Savepoint result: {:?}", created);
        if created.is_ok() {
            assert_eq!(manager.get_savepoints().len(), 1);
        }

        let restored = manager.rollback_to_savepoint("sp1").await;
        println!("Rollback to savepoint result: {:?}", restored);

        let committed = manager.commit().await;
        println!("Commit result: {:?}", committed);
    }
}