// sqlx-firebirdsql 0.1.0
//
// Firebird SQL driver for SQLx
use std::str::FromStr;
use std::sync::OnceLock;
use std::time::Duration;

use crate::connection::AssertSend;
use crate::error::firebird_err;
use crate::{Firebird, FirebirdConnectOptions, FirebirdConnection};

use sqlx_core::error::Error;
use sqlx_core::pool::{Pool, PoolOptions};

pub(crate) use sqlx_core::testing::*;

// Process-wide pool of connections to the "master" database named by
// `DATABASE_URL`. It is initialized lazily by whichever of `test_context()` or
// `cleanup_test_dbs()` runs first. It is then used to maintain the
// test-database tracking table and serves as the parent pool of every
// per-test pool handed out by `test_context()`.
//
// Using a blocking `OnceLock` here because the critical sections are short.
static MASTER_POOL: OnceLock<Pool<Firebird>> = OnceLock::new();

impl TestSupport for Firebird {
    /// Provision a dedicated test database for `args` by delegating to this
    /// module's free function `test_context`, asserting the future is `Send`.
    fn test_context(
        args: &TestArgs,
    ) -> impl std::future::Future<Output = Result<TestContext<Self>, Error>> + Send + '_ {
        let setup = test_context(args);
        AssertSend(setup)
    }

    /// Drop the test database called `db_name` and remove its tracking entry
    /// (delegates to this module's free function `cleanup_test`).
    fn cleanup_test(
        db_name: &str,
    ) -> impl std::future::Future<Output = Result<(), Error>> + Send + '_ {
        let teardown = cleanup_test(db_name);
        AssertSend(teardown)
    }

    /// Drop every test database recorded in the tracking table (delegates to
    /// this module's free function `cleanup_test_dbs`).
    fn cleanup_test_dbs(
    ) -> impl std::future::Future<Output = Result<Option<usize>, Error>> + Send + 'static {
        let sweep = cleanup_test_dbs();
        AssertSend(sweep)
    }

    /// Fixture snapshots are not supported yet: the returned future panics via
    /// `todo!()` as soon as it is polled.
    fn snapshot(
        _conn: &mut FirebirdConnection,
    ) -> impl std::future::Future<Output = Result<FixtureSnapshot<Self>, Error>> + Send + '_ {
        async { todo!("snapshot is not yet implemented for Firebird") }
    }
}

/// Provision a fresh Firebird database for a single test run.
///
/// Steps:
/// 1. Parse `DATABASE_URL` (the "master" database) and lazily initialize the
///    shared `MASTER_POOL`, reusing an existing one if another test got there first.
/// 2. Make sure the `"_SQLX_TEST_DATABASES"` tracking table exists in the master database.
/// 3. Drop any leftover database with the same name, then create a new database
///    file next to the master database.
/// 4. Record the new database in the tracking table.
/// 5. Return pool/connect options that point at the new database, using the
///    master pool as parent.
///
/// # Panics
/// If `DATABASE_URL` is unset or cannot be parsed.
///
/// # Errors
/// If a master connection cannot be acquired, or the tracking table cannot be
/// created or written. Also if `DATABASE_URL` has no database path, or the test
/// database cannot be created.
async fn test_context(args: &TestArgs) -> Result<TestContext<Firebird>, Error> {
    let url = std::env::var("DATABASE_URL")
        .expect("DATABASE_URL must be set to run sqlx tests for Firebird");

    let master_opts =
        FirebirdConnectOptions::from_str(&url).expect("failed to parse DATABASE_URL");

    let pool = PoolOptions::new()
        // Firebird's default connection limit is modest; don't use too many.
        .max_connections(20)
        // Immediately close master connections.
        .after_release(|_conn, _| Box::pin(async move { Ok(false) }))
        .connect_lazy_with(master_opts.clone());

    // If the master pool already exists, use it and drop the freshly built pool,
    // which is lazy and so has not opened any connection yet.
    let master_pool = match once_lock_try_insert_polyfill(&MASTER_POOL, pool) {
        Ok(inserted) => inserted,
        Err((existing, _pool)) => existing,
    };

    // Ensure the tracking table exists in the master database.
    //
    // The table name starts with `_`, which Firebird only accepts as a quoted
    // (delimited) identifier. An unquoted regular identifier must begin with a
    // letter. The quoted upper-case name is stored exactly as written, so the
    // RDB$RELATIONS lookup below still matches it.
    {
        let conn = master_pool.acquire().await?;
        let mut inner = conn.inner.lock().await;
        inner
            .execute_batch(
                "EXECUTE BLOCK AS \
                BEGIN \
                    IF (NOT EXISTS( \
                        SELECT 1 FROM RDB$RELATIONS \
                        WHERE TRIM(RDB$RELATION_NAME) = '_SQLX_TEST_DATABASES' \
                    )) THEN \
                        EXECUTE STATEMENT \
                        'CREATE TABLE \"_SQLX_TEST_DATABASES\" ( \
                            DB_NAME VARCHAR(255) NOT NULL PRIMARY KEY, \
                            DB_PATH VARCHAR(512) NOT NULL, \
                            TEST_PATH VARCHAR(512) NOT NULL, \
                            CREATED_AT TIMESTAMP DEFAULT CURRENT_TIMESTAMP \
                        )'; \
                END",
            )
            .await
            .map_err(firebird_err)?;
        inner.commit().await.map_err(firebird_err)?;
    }

    let db_name = Firebird::db_name(args);
    let master_db = master_opts
        .database
        .as_deref()
        .ok_or_else(|| Error::Configuration("DATABASE_URL must include a database path".into()))?;
    let db_file = test_db_path(master_db, &db_name);

    // Best-effort removal of any existing test database with this name, e.g.
    // left behind by a previous run that did not clean up.
    let _ = do_cleanup(&master_opts, &db_name, &db_file).await;

    // Create the test database.
    // Firebird databases are files on the server; we must use create_database_url
    // because CREATE DATABASE cannot be run from a connection to another database.
    let url = test_db_url(&master_opts, &db_file);
    let new_conn = firebirust::ConnectionAsync::create_database_url(&url)
        .await
        .map_err(|e| Error::Protocol(format!("Failed to create test database: {:?}", e)))?;
    drop(new_conn);

    eprintln!("created test database {db_name}");

    // Record the new database in the tracking table. Values are embedded as SQL
    // string literals, so single quotes are doubled to escape them.
    {
        let conn = master_pool.acquire().await?;
        let mut inner = conn.inner.lock().await;
        let sql = format!(
            "INSERT INTO \"_SQLX_TEST_DATABASES\" (DB_NAME, DB_PATH, TEST_PATH) \
             VALUES ('{}', '{}', '{}')",
            db_name.replace('\'', "''"),
            db_file.replace('\'', "''"),
            args.test_path.replace('\'', "''"),
        );
        inner.execute_batch(&sql).await.map_err(firebird_err)?;
        inner.commit().await.map_err(firebird_err)?;
    }

    // Build connection options for the test database.
    let mut test_opts = master_opts.clone();
    test_opts.database = Some(db_file);

    Ok(TestContext {
        pool_opts: PoolOptions::new()
            // Don't allow a single test to take all the connections.
            .max_connections(5)
            // Close connections ASAP if left in the idle queue.
            .idle_timeout(Some(Duration::from_secs(1)))
            .parent(master_pool.clone()),
        connect_opts: test_opts,
        db_name,
    })
}

async fn cleanup_test(db_name: &str) -> Result<(), Error> {
    let master_opts = FirebirdConnectOptions::from_str(
        &std::env::var("DATABASE_URL").expect("DATABASE_URL must be set"),
    )?;
    let master_db = master_opts
        .database
        .as_deref()
        .ok_or_else(|| Error::Configuration("DATABASE_URL must include a database path".into()))?;
    let db_file = test_db_path(master_db, db_name);

    do_cleanup(&master_opts, db_name, &db_file).await
}

/// Drop every test database recorded in the master database's tracking table.
///
/// If no test has initialized the shared `MASTER_POOL` yet, it is created
/// here with the same settings `test_context()` uses. Then all
/// `(DB_NAME, DB_PATH)` rows are read from `"_SQLX_TEST_DATABASES"` and
/// `do_cleanup()` runs on each.
///
/// Returns `Ok(None)` when the tracking table cannot be read (e.g. it has not been
/// created yet) or has no rows. Otherwise returns `Ok(Some(n))`, where `n` is the
/// number of databases for which cleanup reported success.
///
/// # Panics
/// If `DATABASE_URL` is unset.
///
/// # Errors
/// If `DATABASE_URL` cannot be parsed or a master connection cannot be acquired.
async fn cleanup_test_dbs() -> Result<Option<usize>, Error> {
    let master_opts = FirebirdConnectOptions::from_str(
        &std::env::var("DATABASE_URL").expect("DATABASE_URL must be set"),
    )?;

    let master_pool = MASTER_POOL.get_or_init(|| {
        PoolOptions::new()
            .max_connections(20)
            .after_release(|_conn, _| Box::pin(async move { Ok(false) }))
            .connect_lazy_with(master_opts.clone())
    });

    // Query all tracked test databases.
    let databases = {
        let conn = master_pool.acquire().await?;
        let mut inner = conn.inner.lock().await;

        // The tracking table is only created by `test_context()`. If preparing or
        // running the query fails, assume the table doesn't exist yet and report
        // that there is nothing to clean up.
        // The table name must be quoted: Firebird rejects unquoted identifiers
        // that begin with `_`.
        let Ok(mut stmt) = inner
            .prepare("SELECT DB_NAME, DB_PATH FROM \"_SQLX_TEST_DATABASES\"")
            .await
        else {
            return Ok(None);
        };

        let Ok(result) = stmt.query(()).await else {
            return Ok(None);
        };

        // Rows whose columns cannot be decoded as strings are skipped.
        let mut dbs: Vec<(String, String)> = Vec::new();
        for row in result {
            if let (Ok(name), Ok(path)) = (row.get::<String>(0), row.get::<String>(1)) {
                dbs.push((name, path));
            }
        }
        dbs
    };

    if databases.is_empty() {
        return Ok(None);
    }

    let mut deleted = 0usize;
    for (db_name, db_file) in &databases {
        match do_cleanup(&master_opts, db_name, db_file).await {
            Ok(()) => {
                deleted += 1;
            }
            Err(e) => {
                eprintln!("could not clean test database {db_name:?}: {e}");
            }
        }
    }

    Ok(Some(deleted))
}

/// Drop a test database by connecting to it and executing `DROP DATABASE`,
/// then remove the tracking entry from the master database.
///
/// In Firebird, `DROP DATABASE` drops the currently connected database and
/// deletes the underlying file. This is the only way to drop a database
/// from a client connection (unlike PostgreSQL/MySQL where you can drop
/// other databases from an admin connection).
/// Best-effort removal of one test database and its tracking-table entry.
///
/// In Firebird, `DROP DATABASE` drops the *currently connected* database and
/// deletes the underlying file. The test database therefore has to be dropped
/// from a connection to `db_file` itself, not from the master connection (unlike
/// PostgreSQL/MySQL, where an admin connection can drop other databases).
/// Afterwards, the row for `db_name` is deleted from `"_SQLX_TEST_DATABASES"`
/// through the shared `MASTER_POOL`.
///
/// Every step is best-effort. Failures to connect, drop, acquire a master
/// connection, delete or commit are ignored; for example, the database may never
/// have been created. As written, this function always returns `Ok(())`.
///
/// # Panics
/// If `MASTER_POOL` has not been initialized yet, i.e. neither `test_context()`
/// nor `cleanup_test_dbs()` has run in this process.
async fn do_cleanup(
    master_opts: &FirebirdConnectOptions,
    db_name: &str,
    db_file: &str,
) -> Result<(), Error> {
    // Connect to the test database itself and drop it.
    let url = test_db_url(master_opts, db_file);
    if let Ok(mut conn) = firebirust::ConnectionAsync::connect_url(&url).await {
        let _ = conn.execute_batch("DROP DATABASE").await;
    }

    // Remove the entry from the tracking table in the master database.
    // The table name is quoted because Firebird does not accept unquoted
    // identifiers that begin with `_`. Single quotes in the name are doubled
    // because it is embedded as an SQL string literal.
    let master_pool = MASTER_POOL
        .get()
        .expect("do_cleanup() invoked before master pool initialization");
    if let Ok(conn) = master_pool.acquire().await {
        let mut inner = conn.inner.lock().await;
        let sql = format!(
            "DELETE FROM \"_SQLX_TEST_DATABASES\" WHERE DB_NAME = '{}'",
            db_name.replace('\'', "''"),
        );
        let _ = inner.execute_batch(&sql).await;
        let _ = inner.commit().await;
    }

    Ok(())
}

/// Derive the test database file path from the master database path.
///
/// Places the test database in the same directory as the master database.
/// For example, if master is `/var/firebird/data/master.fdb`, a test named
/// `_sqlx_test_abc123` becomes `/var/firebird/data/_sqlx_test_abc123.fdb`.
/// Derive the test database file path from the master database path.
///
/// The test database goes in the same directory as the master database and is
/// named `<db_name>.fdb`. Whichever separator (`/` or `\`) ends the master
/// path's directory part is kept.
///
/// Examples:
/// - Master `/var/firebird/data/master.fdb` with test `_sqlx_test_abc123` gives
///   `/var/firebird/data/_sqlx_test_abc123.fdb`.
/// - Master `C:\fb\master.fdb` gives `C:\fb\_sqlx_test_abc123.fdb`.
///
/// If the master path has no directory component (e.g. a bare file name), the
/// result is just `<db_name>.fdb`.
fn test_db_path(master_db: &str, db_name: &str) -> String {
    // Split at the *last* separator of either kind. Looking for `/` first would
    // pick the wrong directory for mixed paths such as `C:/fb\data\master.fdb`.
    // Both separators are ASCII, so `..=pos` always ends on a char boundary.
    let dir = master_db
        .rfind(|c: char| c == '/' || c == '\\')
        .map_or("", |pos| &master_db[..=pos]);
    format!("{dir}{db_name}.fdb")
}

/// Build a Firebird connection URL for a given database file path.
/// Build a `firebird://user:password@host:port/<db_file>` URL.
///
/// The URL reuses the credentials and server address from `opts` but points at
/// the database file `db_file` instead of the master database.
///
/// NOTE(review): the username and password are interpolated verbatim, with no
/// percent-encoding. Credentials containing `@`, `:` or `/` would therefore
/// produce an ambiguous URL. Confirm how the URL parser handles that.
fn test_db_url(opts: &FirebirdConnectOptions, db_file: &str) -> String {
    let credentials = format!("{}:{}", opts.username, opts.password);
    let address = format!("{}:{}", opts.host, opts.port);
    format!("firebird://{credentials}@{address}/{db_file}")
}

/// Stable stand-in for the unstable `OnceLock::try_insert`.
///
/// If `this` is still empty, `value` is stored and `Ok` is returned with a
/// reference to it. If the cell was already initialized, `Err` is returned
/// with a reference to the existing value, and `value` is handed back unused.
fn once_lock_try_insert_polyfill<T>(this: &OnceLock<T>, value: T) -> Result<&T, (&T, T)> {
    let mut candidate = Some(value);
    // The initializer only runs when the cell is empty. In that case it moves
    // `candidate` into the cell; otherwise `candidate` is left as it was.
    let current = this.get_or_init(|| candidate.take().unwrap());
    candidate.map_or(Ok(current), |rejected| Err((current, rejected)))
}

#[cfg(test)]
mod tests {
    use super::test_db_path;

    /// A master database given by an absolute Unix path gets its test database
    /// in the same directory.
    #[test]
    fn test_db_path_unix() {
        let derived = test_db_path("/var/firebird/data/master.fdb", "_sqlx_test_abc");
        assert_eq!(derived, "/var/firebird/data/_sqlx_test_abc.fdb");
    }

    /// A master path without any directory yields a bare `<name>.fdb` file name.
    #[test]
    fn test_db_path_no_separator() {
        let derived = test_db_path("master.fdb", "_sqlx_test_abc");
        assert_eq!(derived, "_sqlx_test_abc.fdb");
    }
}