clickhouse-testing 0.1.3

A crate that simplifies writing integration tests for ClickHouse, inspired by sqlx.
Documentation
use clickhouse::Row;
use dotenvy::dotenv;
use serde::Deserialize;
use std::fs::read_dir;
use std::io::ErrorKind;
use std::path::PathBuf;
use std::{env, fs, io};

/// Alias for [`clickhouse::Client`], so callers can name the client type
/// through this crate.
pub type Client = clickhouse::Client;
// Re-exports the `test` item from the companion `clickhouse_testing_macros`
// crate. NOTE(review): presumably the attribute macro that wraps a test with
// `init_test` / `cleanup_test` — confirm against the macros crate.
pub use clickhouse_testing_macros::test;

/// Errors produced while preparing or tearing down a test database.
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`], so it can be
/// propagated with `?` into `Box<dyn std::error::Error>` and similar
/// error-erasing types.
#[derive(Debug)]
pub enum Error {
    /// A filesystem or other I/O operation failed.
    Io(std::io::Error),
    /// An environment variable could not be read.
    Env(std::env::VarError),
    /// The ClickHouse client reported a failure.
    Clickhouse(clickhouse::error::Error),
    /// A migration problem described by a free-form message.
    Migration(String),
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Error::Io(e) => write!(f, "I/O error: {e}"),
            Error::Env(e) => write!(f, "environment variable error: {e}"),
            Error::Clickhouse(e) => write!(f, "ClickHouse error: {e}"),
            Error::Migration(msg) => write!(f, "migration error: {msg}"),
        }
    }
}

impl std::error::Error for Error {
    /// Returns the wrapped error, if any, so callers can walk the cause chain.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Error::Io(e) => Some(e),
            Error::Env(e) => Some(e),
            Error::Clickhouse(e) => Some(e),
            Error::Migration(_) => None,
        }
    }
}

impl From<std::io::Error> for Error {
    fn from(e: std::io::Error) -> Self {
        Error::Io(e)
    }
}

impl From<std::env::VarError> for Error {
    fn from(e: std::env::VarError) -> Self {
        Error::Env(e)
    }
}

impl From<clickhouse::error::Error> for Error {
    fn from(e: clickhouse::error::Error) -> Self {
        Error::Clickhouse(e)
    }
}

/// Provisions a fresh database for one test and returns a client bound to it.
///
/// Steps, in order:
/// 1. Load a `.env` file if one can be found (the outcome is ignored).
/// 2. Build a client from the `CLICKHOUSE_*` environment configuration.
/// 3. List the existing databases and derive the next versioned database
///    name for `module_path` / `test_name`.
/// 4. Create that database and run the migrations against it.
///
/// # Errors
///
/// Returns an [`Error`] if listing or creating databases fails, or if the
/// migrations cannot be located, read or executed.
pub async fn init_test(module_path: &str, test_name: &str) -> Result<Client, Error> {
    // Best effort: a missing `.env` file must not fail the test setup.
    let _ = dotenv();

    let admin = create_client(&read_clickhouse_config());

    let existing = get_dbs_list(&admin).await?;
    let db_name = next_db_version(&existing, module_path, test_name);
    create_database(&admin, &db_name).await?;

    // Re-target the client at the new database before migrating it.
    let test_client = admin.with_database(db_name);
    apply_migrations(&test_client).await?;

    Ok(test_client)
}

/// Drops the database that `client` is currently connected to.
///
/// The database name is obtained by asking the server for
/// `currentDatabase()`.
///
/// # Errors
///
/// Returns an [`Error`] if either the lookup or the drop statement fails.
pub async fn cleanup_test(client: &Client) -> Result<(), Error> {
    let database = get_current_db(client).await?;
    drop_db(client, &database).await
}

/// Executes every `.sql` file from the migrations directory against `client`.
///
/// The directory is the value of the `MIGRATIONS_DIR` environment variable,
/// joined onto the project root returned by `get_project_root`. Files are run
/// in sorted path order, and each file is split on `;` so that statements are
/// sent one at a time.
///
/// Note that the split is purely textual: a `;` inside a string literal or a
/// comment will also split the script.
///
/// # Errors
///
/// Returns an [`Error`] if `MIGRATIONS_DIR` is not set, the project root
/// cannot be found, the directory or any of its entries cannot be read, a
/// script cannot be read, or ClickHouse rejects a statement.
async fn apply_migrations(client: &Client) -> Result<(), Error> {
    let migrations_path = env::var("MIGRATIONS_DIR")?;
    let migrations_dir = get_project_root()?.join(migrations_path);

    let mut sql_files = Vec::new();
    for entry in read_dir(migrations_dir)? {
        // Propagate entry errors: silently skipping an unreadable entry could
        // leave a migration unapplied without any indication.
        let path = entry?.path();
        if path.extension().and_then(|ext| ext.to_str()) == Some("sql") {
            sql_files.push(path);
        }
    }

    // Directory iteration order is unspecified; sort for a stable order.
    sql_files.sort();

    for file in sql_files {
        let script = fs::read_to_string(&file)?;

        // For avoiding "Multi-statements are not allowed" error
        for statement in script.split(';').filter(|s| !s.trim().is_empty()) {
            client.query(statement).execute().await?;
        }
    }

    Ok(())
}

fn get_project_root() -> Result<PathBuf, Error> {
    let path = env::current_dir()?;

    for ancestor_path in path.ancestors() {
        let has_cargo = read_dir(ancestor_path)?.any(|p| p.unwrap().file_name() == "Cargo.lock");
        if has_cargo {
            return Ok(PathBuf::from(ancestor_path));
        }
    }

    Err(io::Error::new(ErrorKind::NotFound, "Cargo.lock not found").into())
}

fn create_client(config: &ClickhouseConfig) -> Client {
    Client::default()
        .with_url(&config.url)
        .with_database(&config.db)
        .with_user(&config.user)
        .with_password(&config.password)
}

/// Reads the ClickHouse connection settings from the environment.
///
/// Any variable that cannot be read falls back to its default:
/// `CLICKHOUSE_URL` → `http://localhost:8123`, `CLICKHOUSE_DB` → `default`,
/// `CLICKHOUSE_USER` → `default`, `CLICKHOUSE_PASSWORD` → empty string.
fn read_clickhouse_config() -> ClickhouseConfig {
    let var_or = |key: &str, fallback: &str| env::var(key).unwrap_or_else(|_| fallback.to_owned());

    ClickhouseConfig {
        url: var_or("CLICKHOUSE_URL", "http://localhost:8123"),
        db: var_or("CLICKHOUSE_DB", "default"),
        user: var_or("CLICKHOUSE_USER", "default"),
        password: var_or("CLICKHOUSE_PASSWORD", ""),
    }
}

async fn get_dbs_list(client: &Client) -> Result<Vec<Database>, Error> {
    let databases = client
        .query("SELECT name FROM system.databases")
        .fetch_all::<Database>()
        .await?;

    Ok(databases)
}

async fn create_database(client: &Client, db_name: &str) -> Result<(), Error> {
    let query = format!("CREATE DATABASE IF NOT EXISTS {}", db_name);
    client.query(&query).execute().await?;

    Ok(())
}

/// Asks the server which database `client` is currently using.
///
/// # Errors
///
/// Returns [`Error::Clickhouse`] if the query fails.
async fn get_current_db(client: &Client) -> Result<Database, Error> {
    let lookup = client.query("SELECT currentDatabase() AS name");

    Ok(lookup.fetch_one::<Database>().await?)
}

/// Drops `database`.
///
/// The name is interpolated into the statement verbatim.
///
/// # Errors
///
/// Returns [`Error::Clickhouse`] if the statement fails (the statement has no
/// `IF EXISTS` clause).
async fn drop_db(client: &Client, database: &Database) -> Result<(), Error> {
    let statement = format!("DROP DATABASE {}", database.name);
    client.query(&statement).execute().await?;

    Ok(())
}

/// Turns a Rust path or test name into a lowercase fragment for a database
/// name.
///
/// `::`, `-` and any remaining `:` each become `_`; leading and trailing
/// underscores are then stripped and the result is lowercased.
fn sanitize_name(name: &str) -> String {
    // Collapse `::` first so a path separator yields a single underscore.
    let underscored = name.replace("::", "_").replace(['-', ':'], "_");
    let trimmed = underscored.trim_matches('_');

    trimmed.to_lowercase()
}

/// Computes the next free versioned database name for a test.
///
/// Names have the form `test_db_<module>_<test>_<n>`, where `<module>` and
/// `<test>` are the sanitized `module_path` and `test_name`. `<n>` is one more
/// than the highest numeric suffix found among `tests_dbs` for the same
/// prefix, or `1` when none exists.
fn next_db_version(tests_dbs: &[Database], module_path: &str, test_name: &str) -> String {
    let prefix = format!(
        "test_db_{}_{}_",
        sanitize_name(module_path),
        sanitize_name(test_name)
    );

    // Databases with another prefix or a non-numeric suffix are ignored.
    // Starting the fold at 0 makes the first version 1.
    let latest = tests_dbs
        .iter()
        .filter_map(|db| db.name.strip_prefix(prefix.as_str()))
        .filter_map(|suffix| suffix.parse::<usize>().ok())
        .fold(0usize, |highest, version| highest.max(version));

    format!("{prefix}{}", latest + 1)
}
/// A row holding a single database name.
///
/// Used as the row type for the queries in this file that select a `name`
/// column (`system.databases` and `currentDatabase() AS name`).
#[derive(Debug, Deserialize, Row)]
struct Database {
    /// The database name as returned by the server.
    name: String,
}

/// Connection settings used to build the ClickHouse client.
///
/// Populated by `read_clickhouse_config` from the `CLICKHOUSE_*` environment
/// variables.
#[derive(Debug)]
struct ClickhouseConfig {
    /// Server URL (`CLICKHOUSE_URL`).
    url: String,
    /// Database the initial client connects to (`CLICKHOUSE_DB`).
    db: String,
    /// User name (`CLICKHOUSE_USER`).
    user: String,
    /// Password (`CLICKHOUSE_PASSWORD`).
    password: String,
}