geni 1.3.0

A standalone database CLI migration tool
Documentation
use anyhow::Result;
use chrono::Utc;
use log::info;
use std::env;
use std::fs;
use std::fs::File;
use std::io::Write;
use std::path::Path;
use tempfile::TempDir;

use geni::config::Database;
use geni::database_drivers;
use geni::migrate::{down, up};

use testcontainers::core::wait::LogWaitStrategy;
use testcontainers::core::{IntoContainerPort, WaitFor};
use testcontainers::runners::AsyncRunner;
use testcontainers::{GenericImage, ImageExt};

fn generate_test_migrations(migration_path: &str) -> Result<()> {
    let file_endings = vec!["up", "down"];
    let test_queries = [
        (
            "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT NOT NULL); CREATE TABLE computers (id INTEGER PRIMARY KEY, name TEXT NOT NULL); ;",
            "DROP TABLE users;",
        ),
        (
            "CREATE TABLE users2 (id INTEGER PRIMARY KEY, name TEXT NOT NULL);",
            "DROP TABLE users2;",
        ),
        (
            "CREATE TABLE users3 (id INTEGER PRIMARY KEY, name TEXT NOT NULL);",
            "DROP TABLE users3;",
        ),
        (
            "CREATE TABLE users4 (id INTEGER PRIMARY KEY, name TEXT NOT NULL);",
            "DROP TABLE users4;",
        ),
        (
            "CREATE TABLE users5 (id INTEGER PRIMARY KEY, name TEXT NOT NULL);",
            "DROP TABLE users5;",
        ),
        (
            "CREATE TABLE users6 (id INTEGER PRIMARY KEY, name TEXT NOT NULL);",
            "DROP TABLE users6;",
        ),
    ];

    for (index, t) in test_queries.iter().enumerate() {
        for f in &file_endings {
            let timestamp = Utc::now().timestamp() + index as i64;
            let filename = format!("{migration_path}/{timestamp}_{index}_test.{f}.sql");
            let path = std::path::Path::new(filename.as_str());

            if let Some(parent) = path.parent() {
                fs::create_dir_all(parent)?;
            }

            let mut file = File::create(path)?;
            match *f {
                "up" => file.write_all(t.0.as_bytes())?,
                "down" => file.write_all(t.1.as_bytes())?,
                _ => {}
            }
            info!("Generated {}", filename)
        }
    }

    Ok(())
}

/// Runs an up/down migration round-trip against `db_url` and checks the bookkeeping.
///
/// Steps, in order:
/// 1. generate the test migrations into a fresh temporary directory;
/// 2. for Postgres, MySQL and MariaDB, call `create_database` on a driver first;
/// 3. apply the migrations with `up` and expect 6 recorded entries;
/// 4. roll back 1 migration (expect 5 entries), then 3 more (expect 2 entries);
/// 5. expect the schema dump file to exist inside the migration directory.
///
/// The schema file name is read from the `DATABASE_SCHEMA_FILE` environment
/// variable and falls back to `schema.sql`. Entries are counted in the
/// `migrations_table` table through `get_or_create_schema_migrations`.
async fn test_migrate(database: Database, db_url: &str, migrations_table: &str) -> Result<()> {
    let connection_url = db_url.to_owned();
    let table = migrations_table.to_owned();
    let workspace = TempDir::new()?;
    let migrations_dir = workspace.path().to_str().unwrap().to_owned();
    let wait_timeout = Some(30);
    let schema_file =
        env::var("DATABASE_SCHEMA_FILE").unwrap_or_else(|_| "schema.sql".to_string());

    generate_test_migrations(&migrations_dir)?;

    // First driver: only used to create the target database on the server
    // engines that need it.
    let mut bootstrap_driver = database_drivers::new(
        connection_url.clone(),
        None,
        table.clone(),
        migrations_dir.clone(),
        schema_file.clone(),
        wait_timeout,
        false,
    )
    .await
    .unwrap();

    if matches!(
        database,
        Database::Postgres | Database::MySQL | Database::MariaDB
    ) {
        bootstrap_driver.create_database().await.unwrap();
    }

    // Second driver: used below to count the recorded migrations after each step.
    let mut client = database_drivers::new(
        connection_url.clone(),
        None,
        table.clone(),
        migrations_dir.clone(),
        schema_file.clone(),
        wait_timeout,
        true,
    )
    .await
    .unwrap();

    // Apply everything.
    let applied = up(
        connection_url.clone(),
        None,
        table.clone(),
        migrations_dir.clone(),
        schema_file.clone(),
        wait_timeout,
        true,
    )
    .await;
    assert!(applied.is_ok());
    let recorded = client.get_or_create_schema_migrations().await.unwrap();
    assert_eq!(recorded.len(), 6);

    // Roll back a single migration.
    let rolled_back_one = down(
        connection_url.clone(),
        None,
        table.clone(),
        migrations_dir.clone(),
        schema_file.clone(),
        wait_timeout,
        false,
        &1,
    )
    .await;
    assert!(rolled_back_one.is_ok());
    let recorded = client.get_or_create_schema_migrations().await.unwrap();
    assert_eq!(recorded.len(), 5);

    // Roll back three more.
    let rolled_back_three = down(
        connection_url,
        None,
        table,
        migrations_dir.clone(),
        schema_file.clone(),
        wait_timeout,
        false,
        &3,
    )
    .await;
    assert!(rolled_back_three.is_ok());
    let recorded = client.get_or_create_schema_migrations().await.unwrap();
    assert_eq!(recorded.len(), 2);

    let schema_dump_file = format!("{migrations_dir}/{schema_file}");
    assert!(Path::new(&schema_dump_file).exists());

    Ok(())
}

/// Runs the shared migration round-trip against a `postgres:18.0` container,
/// using the unqualified `schema_migrations` table.
///
/// The connection URL targets the `app` database rather than the
/// `development` one the image is configured with; `test_migrate` calls
/// `create_database` for Postgres before migrating.
#[tokio::test]
async fn test_migrate_postgres() -> Result<()> {
    let ready = WaitFor::message_on_stdout("database system is ready to accept connections");
    let image = GenericImage::new("postgres", "18.0")
        .with_exposed_port(5432.tcp())
        .with_wait_for(ready)
        .with_env_var("POSTGRES_DB", "development")
        .with_env_var("POSTGRES_USER", "postgres")
        .with_env_var("POSTGRES_PASSWORD", "mysecretpassword");
    let postgres = image.start().await.expect("Failed to start postgres");

    // The container port is mapped to a host port chosen at start-up.
    let mapped_port = postgres.get_host_port_ipv4(5432).await?;
    let url =
        format!("postgres://postgres:mysecretpassword@localhost:{mapped_port}/app?sslmode=disable");

    env::set_var("DATABASE_SCHEMA_FILE", "postgres_schema.sql");
    test_migrate(Database::Postgres, &url, "schema_migrations").await?;

    drop(postgres);
    Ok(())
}

/// Same as the plain Postgres test, but passes the schema-qualified table
/// name `migrations.migrations` as the migrations table.
#[tokio::test]
async fn test_migrate_postgres_schema_qualified() -> Result<()> {
    let ready = WaitFor::message_on_stdout("database system is ready to accept connections");
    let image = GenericImage::new("postgres", "18.0")
        .with_exposed_port(5432.tcp())
        .with_wait_for(ready)
        .with_env_var("POSTGRES_DB", "development")
        .with_env_var("POSTGRES_USER", "postgres")
        .with_env_var("POSTGRES_PASSWORD", "mysecretpassword");
    let postgres = image.start().await.expect("Failed to start postgres");

    // The container port is mapped to a host port chosen at start-up.
    let mapped_port = postgres.get_host_port_ipv4(5432).await?;
    let url =
        format!("postgres://postgres:mysecretpassword@localhost:{mapped_port}/app?sslmode=disable");

    env::set_var("DATABASE_SCHEMA_FILE", "postgres_schema.sql");
    test_migrate(Database::Postgres, &url, "migrations.migrations").await?;

    drop(postgres);
    Ok(())
}

/// Runs the shared migration round-trip against a `mysql:latest` container.
///
/// Start-up waits until "ready for connections" has been logged twice on
/// stdout or stderr.
#[tokio::test]
async fn test_migrate_mysql() -> Result<()> {
    let ready = WaitFor::Log(
        LogWaitStrategy::stdout_or_stderr("ready for connections").with_times(2),
    );
    let image = GenericImage::new("mysql", "latest")
        .with_exposed_port(3306.tcp())
        .with_wait_for(ready)
        .with_env_var("MYSQL_ROOT_PASSWORD", "password")
        .with_env_var("MYSQL_DATABASE", "development");
    let mysql = image.start().await.expect("Failed to start mysql");

    // The container port is mapped to a host port chosen at start-up.
    let mapped_port = mysql.get_host_port_ipv4(3306).await?;
    let url = format!("mysql://root:password@localhost:{mapped_port}/app");

    env::set_var("DATABASE_SCHEMA_FILE", "mysql_schema.sql");
    test_migrate(Database::MySQL, &url, "schema_migrations").await?;

    drop(mysql);
    Ok(())
}

/// Runs the shared migration round-trip against a `mariadb:11.1.3` container.
///
/// Start-up waits until "ready for connections" has been logged twice on
/// stdout or stderr.
#[tokio::test]
async fn test_migrate_maria() -> Result<()> {
    let ready = WaitFor::Log(
        LogWaitStrategy::stdout_or_stderr("ready for connections").with_times(2),
    );
    let image = GenericImage::new("mariadb", "11.1.3")
        .with_exposed_port(3306.tcp())
        .with_wait_for(ready)
        .with_env_var("MARIADB_ROOT_PASSWORD", "password")
        .with_env_var("MARIADB_DATABASE", "development");
    let mariadb = image.start().await.expect("Failed to start mariadb");

    // The container port is mapped to a host port chosen at start-up.
    let mapped_port = mariadb.get_host_port_ipv4(3306).await?;
    let url = format!("mariadb://root:password@localhost:{mapped_port}/app");

    env::set_var("DATABASE_SCHEMA_FILE", "maria_schema.sql");
    test_migrate(Database::MariaDB, &url, "schema_migrations").await?;

    drop(mariadb);
    Ok(())
}

/// Runs the shared migration round-trip against a libSQL server container,
/// talking to it over HTTP on the mapped port 8080.
#[tokio::test]
async fn test_migrate_libsql() -> Result<()> {
    let ready = WaitFor::message_on_either_std("listening for incoming user HTTP connection");
    let image = GenericImage::new("ghcr.io/tursodatabase/libsql-server", "latest")
        .with_exposed_port(8080.tcp())
        .with_wait_for(ready);
    let libsql = image.start().await.expect("Failed to start libsql");

    // The container port is mapped to a host port chosen at start-up.
    let mapped_port = libsql.get_host_port_ipv4(8080).await?;
    let url = format!("http://localhost:{mapped_port}");

    env::set_var("DATABASE_SCHEMA_FILE", "libsql_schema.sql");
    test_migrate(Database::LibSQL, &url, "schema_migrations").await?;

    drop(libsql);
    Ok(())
}

/// Runs the shared migration round-trip against an empty SQLite file created
/// inside a temporary directory.
#[tokio::test]
async fn test_migrate_sqlite() {
    env::set_var("DATABASE_SCHEMA_FILE", "sqlite_schema.sql");

    // Keep the TempDir binding alive until the end of the test; dropping it
    // removes the directory and the database file with it.
    let scratch = TempDir::new().unwrap();
    let db_file = format!("{}/test.sqlite", scratch.path().to_str().unwrap());
    File::create(&db_file).unwrap();

    let url = format!("sqlite://{db_file}");
    let outcome = test_migrate(Database::SQLite, &url, "schema_migrations").await;
    outcome.unwrap();
}

#[tokio::test]
async fn test_migrate_failure() {
    env::set_var("DATABASE_SCHEMA_FILE", "sqlite_schema.sql");
    let tmp_dir = TempDir::new().unwrap();
    let migration_folder_string = tmp_dir.path().to_str().unwrap().to_string();
    let filename = format!("{}/test.sqlite", migration_folder_string);
    let path = std::path::Path::new(&filename);
    File::create(path).unwrap();
    let url = format!("sqlite://{}", path.to_str().unwrap());

    let file_endings = vec!["up", "down"];
    let test_queries = [(
        r#"
            CREATE EXTENSION IF NOT EXISTS "uuid-ossp";

            CREATE OR REPLACE FUNCTION uuid_generate_v7()
            RETURN uuid
            AS $$
            SELECT encode(
                set_bit(
                set_bit(
                    overlay(uuid_send(gen_random_uuid())
                            placing substring(int8send(floor(extract(epoch from clock_timestamp()) * 1000)::bigint) from 3)
                            from 1 for 6
                    ),
                    52, 1
                ),
                53, 1
                ),
                'hex')::uuid;
            $$
            LANGUAGE SQL
            VOLATILE;

            CREATE TABLE tokens (
                id UUID NOT NULL DEFAULT uuid_generate_v7 ()
                token TEXT NOT NULL UNIQUE,
                app_id UUID NOT NULL,
                app_slug TEXT NOT NULL,
                expires_at TIMESTAMP,
                updated_at TIMESTAMP WITH TIME ZONE,
                created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
            );

            CREATE TABLE permissions (
                id UUID NOT NULL DEFAULT uuid_generate_v7 ()
                token UUID NOT NULL,
                updated_at TIMESTAMP WITH TIME ZONE,
                created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
            );
        "#,
        "",
    )];

    for (index, t) in test_queries.iter().enumerate() {
        for f in &file_endings {
            let timestamp = Utc::now().timestamp() + index as i64;
            let filename = format!("{migration_folder_string}/{timestamp}_{index}_test.{f}.sql");
            let path = std::path::Path::new(filename.as_str());

            if let Some(parent) = path.parent() {
                fs::create_dir_all(parent).unwrap()
            }

            let mut file = File::create(path).unwrap();
            match *f {
                "up" => file.write_all(t.0.as_bytes()).unwrap(),
                "down" => file.write_all(t.1.as_bytes()).unwrap(),
                _ => {}
            }
            info!("Generated {}", filename)
        }
    }

    let database_wait_timeout = 30;
    let database_schema_file = "sqlite_schema.sql".to_string();

    let u = up(
        url.clone(),
        None,
        "schema_migrations".to_string(),
        migration_folder_string.clone(),
        database_schema_file.clone(),
        Some(database_wait_timeout),
        true,
    )
    .await;
    assert!(u.is_err());
}