sqlx-firebirdsql 0.1.0

Firebird SQL driver for SQLx
use crate::connection::AssertSend;
use crate::error::{firebird_err, Error};
use crate::{Firebird, FirebirdConnection};

use sqlx_core::sql_str::SqlStr;
use sqlx_core::transaction::TransactionManager;

/// Implementation of [`TransactionManager`] for Firebird.
///
/// Nesting is tracked through `FirebirdConnection::transaction_depth`:
/// - the outermost level (depth `0 -> 1`) relies on the transaction Firebird
///   starts automatically, optionally running a caller-supplied statement first;
/// - every deeper level is emulated with a named Firebird `SAVEPOINT`.
pub struct FirebirdTransactionManager;

/// Returns the savepoint identifier guarding the nesting level that `begin`
/// opened while the depth was `level` (always `>= 1`).
///
/// Firebird regular (unquoted) identifiers must start with an ASCII letter, so
/// the name must not begin with an underscore; an unquoted `_sqlx_savepoint_1`
/// would be rejected by the SQL parser.
fn savepoint_name(level: usize) -> String {
    format!("sqlx_savepoint_{level}")
}

impl TransactionManager for FirebirdTransactionManager {
    type Database = Firebird;

    /// Opens a transaction (depth 0) or a nested savepoint (depth > 0).
    ///
    /// At depth 0 the optional `statement` is executed verbatim; without one
    /// nothing is sent, because Firebird has already started a transaction.
    /// At depth > 0 a savepoint named after the current depth is created.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidSavePointStatement`] if a custom `statement` is
    /// supplied for a nested transaction. Returns the driver error if executing
    /// the SQL fails. The depth is only incremented on success.
    fn begin(
        conn: &mut FirebirdConnection,
        statement: Option<SqlStr>,
    ) -> impl std::future::Future<Output = Result<(), Error>> + Send + '_ {
        let inner = conn.inner.clone();
        let depth = conn.transaction_depth;
        AssertSend(async move {
            if depth == 0 {
                // Firebird auto-starts transactions; only a caller-supplied
                // statement needs to be sent, otherwise just track depth.
                if let Some(statement) = statement {
                    let mut guard = inner.lock().await;
                    guard
                        .execute_batch(statement.as_str())
                        .await
                        .map_err(firebird_err)?;
                }
            } else {
                if statement.is_some() {
                    return Err(Error::InvalidSavePointStatement);
                }
                let sql = format!("SAVEPOINT {}", savepoint_name(depth));
                let mut guard = inner.lock().await;
                guard.execute_batch(&sql).await.map_err(firebird_err)?;
            }

            conn.transaction_depth += 1;
            Ok(())
        })
    }

    /// Commits the innermost open level.
    ///
    /// At depth 1 the whole Firebird transaction is committed. At deeper levels
    /// the matching savepoint is released. Does nothing when no transaction is
    /// open.
    ///
    /// # Errors
    ///
    /// Returns the driver error if the commit or release fails. In that case
    /// the depth is left unchanged.
    fn commit(
        conn: &mut FirebirdConnection,
    ) -> impl std::future::Future<Output = Result<(), Error>> + Send + '_ {
        let inner = conn.inner.clone();
        let depth = conn.transaction_depth;
        AssertSend(async move {
            if depth > 0 {
                if depth == 1 {
                    let guard = inner.lock().await;
                    guard.commit().await.map_err(firebird_err)?;
                } else {
                    // `begin` created this level's savepoint while the depth
                    // was `depth - 1`.
                    let sql = format!("RELEASE SAVEPOINT {}", savepoint_name(depth - 1));
                    let mut guard = inner.lock().await;
                    guard.execute_batch(&sql).await.map_err(firebird_err)?;
                }
                conn.transaction_depth = depth - 1;
            }
            Ok(())
        })
    }

    /// Rolls back the innermost open level.
    ///
    /// At depth 1 the whole Firebird transaction is rolled back. At deeper
    /// levels the transaction is rolled back to the matching savepoint. Does
    /// nothing when no transaction is open.
    ///
    /// # Errors
    ///
    /// Returns the driver error if the rollback fails. In that case the depth
    /// is left unchanged.
    fn rollback(
        conn: &mut FirebirdConnection,
    ) -> impl std::future::Future<Output = Result<(), Error>> + Send + '_ {
        let inner = conn.inner.clone();
        let depth = conn.transaction_depth;
        AssertSend(async move {
            if depth > 0 {
                if depth == 1 {
                    let mut guard = inner.lock().await;
                    guard.rollback().await.map_err(firebird_err)?;
                } else {
                    // `begin` created this level's savepoint while the depth
                    // was `depth - 1`.
                    let sql = format!("ROLLBACK TO SAVEPOINT {}", savepoint_name(depth - 1));
                    let mut guard = inner.lock().await;
                    guard.execute_batch(&sql).await.map_err(firebird_err)?;
                }
                conn.transaction_depth = depth - 1;
            }
            Ok(())
        })
    }

    /// Synchronously unwinds one tracked nesting level (no-op at depth 0).
    ///
    /// NOTE(review): no SQL is sent here — neither a transaction rollback nor a
    /// `ROLLBACK TO SAVEPOINT`. The work done in the abandoned level therefore
    /// stays in the active Firebird transaction, and a later depth-1 `commit`
    /// would commit it. If this hook is reached from a dropped `Transaction`, a
    /// pending rollback should probably be recorded on the connection and
    /// flushed before the connection is used again — confirm against the
    /// connection implementation.
    fn start_rollback(conn: &mut FirebirdConnection) {
        conn.transaction_depth = conn.transaction_depth.saturating_sub(1);
    }

    /// Returns the current nesting depth (0 means no transaction is open).
    fn get_transaction_depth(conn: &FirebirdConnection) -> usize {
        conn.transaction_depth
    }
}