// clickhouse_testing/lib.rs
use std::fs::read_dir;
use std::io::ErrorKind;
use std::path::PathBuf;
use std::{env, fmt, fs, io};

use clickhouse::Row;
use dotenvy::dotenv;
use serde::Deserialize;

9pub type Client = clickhouse::Client;
10pub use clickhouse_testing_macros::test;
11
12#[derive(Debug)]
13pub enum Error {
14 Io(std::io::Error),
15 Env(std::env::VarError),
16 Clickhouse(clickhouse::error::Error),
17 Migration(String),
18}
19
20impl From<std::io::Error> for Error {
21 fn from(e: std::io::Error) -> Self {
22 Error::Io(e)
23 }
24}
25
26impl From<std::env::VarError> for Error {
27 fn from(e: std::env::VarError) -> Self {
28 Error::Env(e)
29 }
30}
31
32impl From<clickhouse::error::Error> for Error {
33 fn from(e: clickhouse::error::Error) -> Self {
34 Error::Clickhouse(e)
35 }
36}
37
38pub async fn init_test(module_path: &str, test_name: &str) -> Result<Client, Error> {
39 _ = dotenv();
40
41 let config = read_clickhouse_config();
42 let client = create_client(&config);
43 let databases = get_dbs_list(&client).await?;
44 let db_name = next_db_version(&databases, module_path, test_name);
45
46 create_database(&client, &db_name).await?;
47
48 let test_client = client.with_database(db_name);
49 apply_migrations(&test_client).await?;
50
51 Ok(test_client)
52}
53
54pub async fn cleanup_test(client: &Client) -> Result<(), Error> {
55 let current_db = get_current_db(client).await?;
56 drop_db(client, ¤t_db).await?;
57
58 Ok(())
59}
60
61async fn apply_migrations(client: &Client) -> Result<(), Error> {
62 let migrations_path = env::var("MIGRATIONS_DIR")?;
63 let project_root = get_project_root()?;
64
65 let mut sql_files: Vec<_> = read_dir(project_root.join(migrations_path))?
66 .filter_map(|entry| entry.ok())
67 .map(|entry| entry.path())
68 .filter(|path| path.extension().and_then(|ext| ext.to_str()) == Some("sql"))
69 .collect();
70
71 sql_files.sort();
72
73 for file in sql_files {
74 let script = fs::read_to_string(&file)?;
75
76 let script_parts: Vec<&str> = script.split(';').filter(|s| !s.trim().is_empty()).collect();
78
79 for script_part in script_parts {
80 client.query(script_part).execute().await?;
81 }
82 }
83
84 Ok(())
85}
86
87fn get_project_root() -> Result<PathBuf, Error> {
88 let path = env::current_dir()?;
89
90 for ancestor_path in path.ancestors() {
91 let has_cargo = read_dir(ancestor_path)?.any(|p| p.unwrap().file_name() == "Cargo.lock");
92 if has_cargo {
93 return Ok(PathBuf::from(ancestor_path));
94 }
95 }
96
97 Err(io::Error::new(ErrorKind::NotFound, "Cargo.lock not found").into())
98}
99
100fn create_client(config: &ClickhouseConfig) -> Client {
101 Client::default()
102 .with_url(&config.url)
103 .with_database(&config.db)
104 .with_user(&config.user)
105 .with_password(&config.password)
106}
107
108fn read_clickhouse_config() -> ClickhouseConfig {
109 ClickhouseConfig {
110 url: env::var("CLICKHOUSE_URL").unwrap_or("http://localhost:8123".into()),
111 db: env::var("CLICKHOUSE_DB").unwrap_or("default".into()),
112 user: env::var("CLICKHOUSE_USER").unwrap_or("default".into()),
113 password: env::var("CLICKHOUSE_PASSWORD").unwrap_or("".into()),
114 }
115}
116
117async fn get_dbs_list(client: &Client) -> Result<Vec<Database>, Error> {
118 let databases = client
119 .query("SELECT name FROM system.databases")
120 .fetch_all::<Database>()
121 .await?;
122
123 Ok(databases)
124}
125
126async fn create_database(client: &Client, db_name: &str) -> Result<(), Error> {
127 let query = format!("CREATE DATABASE IF NOT EXISTS {}", db_name);
128 client.query(&query).execute().await?;
129
130 Ok(())
131}
132
133async fn get_current_db(client: &Client) -> Result<Database, Error> {
134 let database = client
135 .query("SELECT currentDatabase() AS name")
136 .fetch_one::<Database>()
137 .await?;
138
139 Ok(database)
140}
141
142async fn drop_db(client: &Client, database: &Database) -> Result<(), Error> {
143 client
144 .query(&format!("DROP DATABASE {}", database.name))
145 .execute()
146 .await?;
147
148 Ok(())
149}
150
151fn next_db_version(tests_dbs: &[Database], module_path: &str, test_name: &str) -> String {
152 let current_test_db = format!("test_db_{module_path}_{test_name}_");
153
154 let db_version = tests_dbs
155 .iter()
156 .filter_map(|db| {
157 db.name
158 .strip_prefix(¤t_test_db)?
159 .parse::<usize>()
160 .ok()
161 })
162 .max()
163 .map(|v| v + 1)
164 .unwrap_or(1);
165
166 format!("{current_test_db}{db_version}")
167}
168
169#[derive(Debug, Deserialize, Row)]
170struct Database {
171 name: String,
172}
173
/// Connection settings for the ClickHouse server.
struct ClickhouseConfig {
    /// Server URL (the built-in default is `http://localhost:8123`).
    url: String,
    /// Database the initial client is bound to.
    db: String,
    /// User name to authenticate as.
    user: String,
    /// Password for `user`; redacted from `Debug` output.
    password: String,
}

// Written by hand instead of derived so the password never ends up in logs or panic messages.
impl fmt::Debug for ClickhouseConfig {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("ClickhouseConfig")
            .field("url", &self.url)
            .field("db", &self.db)
            .field("user", &self.user)
            .field("password", &"<redacted>")
            .finish()
    }
}