// clickhouse_testing/lib.rs

use clickhouse::Row;
use dotenvy::dotenv;
use serde::Deserialize;
use std::fs::read_dir;
use std::io::ErrorKind;
use std::path::PathBuf;
use std::{env, fmt, fs, io};

9pub type Client = clickhouse::Client;
10pub use clickhouse_testing_macros::test;
11
12#[derive(Debug)]
13pub enum Error {
14    Io(std::io::Error),
15    Env(std::env::VarError),
16    Clickhouse(clickhouse::error::Error),
17    Migration(String),
18}
19
20impl From<std::io::Error> for Error {
21    fn from(e: std::io::Error) -> Self {
22        Error::Io(e)
23    }
24}
25
26impl From<std::env::VarError> for Error {
27    fn from(e: std::env::VarError) -> Self {
28        Error::Env(e)
29    }
30}
31
32impl From<clickhouse::error::Error> for Error {
33    fn from(e: clickhouse::error::Error) -> Self {
34        Error::Clickhouse(e)
35    }
36}
37
38pub async fn init_test(module_path: &str, test_name: &str) -> Result<Client, Error> {
39    _ = dotenv();
40
41    let config = read_clickhouse_config();
42    let client = create_client(&config);
43    let databases = get_dbs_list(&client).await?;
44    let db_name = next_db_version(&databases, module_path, test_name);
45
46    create_database(&client, &db_name).await?;
47
48    let test_client = client.with_database(db_name);
49    apply_migrations(&test_client).await?;
50
51    Ok(test_client)
52}
53
54pub async fn cleanup_test(client: &Client) -> Result<(), Error> {
55    let current_db = get_current_db(client).await?;
56    drop_db(client, &current_db).await?;
57
58    Ok(())
59}
60
61async fn apply_migrations(client: &Client) -> Result<(), Error> {
62    let migrations_path = env::var("MIGRATIONS_DIR")?;
63    let project_root = get_project_root()?;
64
65    let mut sql_files: Vec<_> = read_dir(project_root.join(migrations_path))?
66        .filter_map(|entry| entry.ok())
67        .map(|entry| entry.path())
68        .filter(|path| path.extension().and_then(|ext| ext.to_str()) == Some("sql"))
69        .collect();
70
71    sql_files.sort();
72
73    for file in sql_files {
74        let script = fs::read_to_string(&file)?;
75
76        // For avoiding "Multi-statements are not allowed" error
77        let script_parts: Vec<&str> = script.split(';').filter(|s| !s.trim().is_empty()).collect();
78
79        for script_part in script_parts {
80            client.query(script_part).execute().await?;
81        }
82    }
83
84    Ok(())
85}
86
87fn get_project_root() -> Result<PathBuf, Error> {
88    let path = env::current_dir()?;
89
90    for ancestor_path in path.ancestors() {
91        let has_cargo = read_dir(ancestor_path)?.any(|p| p.unwrap().file_name() == "Cargo.lock");
92        if has_cargo {
93            return Ok(PathBuf::from(ancestor_path));
94        }
95    }
96
97    Err(io::Error::new(ErrorKind::NotFound, "Cargo.lock not found").into())
98}
99
100fn create_client(config: &ClickhouseConfig) -> Client {
101    Client::default()
102        .with_url(&config.url)
103        .with_database(&config.db)
104        .with_user(&config.user)
105        .with_password(&config.password)
106}
107
108fn read_clickhouse_config() -> ClickhouseConfig {
109    ClickhouseConfig {
110        url: env::var("CLICKHOUSE_URL").unwrap_or("http://localhost:8123".into()),
111        db: env::var("CLICKHOUSE_DB").unwrap_or("default".into()),
112        user: env::var("CLICKHOUSE_USER").unwrap_or("default".into()),
113        password: env::var("CLICKHOUSE_PASSWORD").unwrap_or("".into()),
114    }
115}
116
117async fn get_dbs_list(client: &Client) -> Result<Vec<Database>, Error> {
118    let databases = client
119        .query("SELECT name FROM system.databases")
120        .fetch_all::<Database>()
121        .await?;
122
123    Ok(databases)
124}
125
126async fn create_database(client: &Client, db_name: &str) -> Result<(), Error> {
127    let query = format!("CREATE DATABASE IF NOT EXISTS {}", db_name);
128    client.query(&query).execute().await?;
129
130    Ok(())
131}
132
133async fn get_current_db(client: &Client) -> Result<Database, Error> {
134    let database = client
135        .query("SELECT currentDatabase() AS name")
136        .fetch_one::<Database>()
137        .await?;
138
139    Ok(database)
140}
141
142async fn drop_db(client: &Client, database: &Database) -> Result<(), Error> {
143    client
144        .query(&format!("DROP DATABASE {}", database.name))
145        .execute()
146        .await?;
147
148    Ok(())
149}
150
151fn next_db_version(tests_dbs: &[Database], module_path: &str, test_name: &str) -> String {
152    let current_test_db = format!("test_db_{module_path}_{test_name}_");
153
154    let db_version = tests_dbs
155        .iter()
156        .filter_map(|db| {
157            db.name
158                .strip_prefix(&current_test_db)?
159                .parse::<usize>()
160                .ok()
161        })
162        .max()
163        .map(|v| v + 1)
164        .unwrap_or(1);
165
166    format!("{current_test_db}{db_version}")
167}
168
169#[derive(Debug, Deserialize, Row)]
170struct Database {
171    name: String,
172}
173
/// Connection settings used to build the base ClickHouse client.
#[derive(Debug)]
struct ClickhouseConfig {
    /// Server URL (`CLICKHOUSE_URL`).
    url: String,
    /// Database the base client is bound to (`CLICKHOUSE_DB`).
    db: String,
    /// User name (`CLICKHOUSE_USER`).
    user: String,
    /// Password (`CLICKHOUSE_PASSWORD`); note the derived `Debug` prints it verbatim.
    password: String,
}