#[cfg(feature = "flightsql")]
use std::collections::HashMap;
use std::{io::Write, path::PathBuf};
use tempfile::{tempdir, TempDir};
/// A config file written to disk by [`TestConfigBuilder::build`].
///
/// The file lives inside a temporary directory that this value owns.
#[derive(Debug)]
pub struct TestConfig {
/// Owner of the temporary directory containing the config file. It is never read;
/// it is stored so the directory is not removed while this value is alive
/// (`tempfile::TempDir` deletes its directory on drop), hence the `dead_code` allow.
#[allow(dead_code)]
dir: TempDir,
/// Full path of the written config file (`<tempdir>/<name>`).
pub path: PathBuf,
}
/// Incrementally assembles the text of a TOML config file for tests.
///
/// Start from `TestConfigBuilder::default()`, call the `with_*` methods to append
/// sections, then call `build` to write the text to a temporary file.
#[derive(Debug, Default)]
pub struct TestConfigBuilder {
/// Raw TOML text accumulated so far; each `with_*` method appends to it verbatim.
config_text: String,
}
impl TestConfigBuilder {
    /// Writes the accumulated configuration text to a file called `name` inside a
    /// freshly created temporary directory.
    ///
    /// The returned [`TestConfig`] stores the [`TempDir`] alongside the file path.
    /// `tempfile::TempDir` removes its directory when dropped, so keep the
    /// `TestConfig` alive for as long as the file at `path` is needed.
    ///
    /// # Panics
    ///
    /// Panics if the temporary directory or the config file cannot be created or
    /// written. This is test-support code, so an I/O failure here means the test
    /// environment is broken rather than being a recoverable error.
    pub fn build(self, name: &str) -> TestConfig {
        let tempdir = tempdir().expect("failed to create temporary directory for test config");
        let path = tempdir.path().join(name);
        let mut file = std::fs::File::create(&path).expect("failed to create test config file");
        file.write_all(self.config_text.as_bytes())
            .expect("failed to write test config file");
        file.flush().expect("failed to flush test config file");
        TestConfig { dir: tempdir, path }
    }

    /// Appends an `[<app>.execution]` table that sets `ddl_path`.
    ///
    /// The path is written as a TOML literal (single-quoted) string, so backslashes
    /// in Windows paths are kept verbatim.
    ///
    /// Note: this emits its own `[<app>.execution]` header, as does
    /// [`Self::with_benchmark_iterations`]. TOML forbids defining the same table
    /// twice, so combining both for the same `app` produces an invalid file.
    pub fn with_ddl_path(&mut self, app: &str, ddl_path: PathBuf) -> &mut Self {
        self.config_text.push_str(&format!("[{app}.execution]\n"));
        let param = format!("ddl_path = '{}'\n", ddl_path.display());
        self.config_text.push_str(&param);
        self
    }

    /// Appends the `[<app>.execution.object_store]` header, a new
    /// `[[<app>.execution.object_store.<store>]]` array entry, and that entry's
    /// `bucket_name` and `object_store_url` keys.
    ///
    /// Every public S3 method starts with this header. Calling two of them for the
    /// same `app` therefore defines that table twice, which TOML rejects.
    #[cfg(feature = "s3")]
    fn push_object_store_entry(
        &mut self,
        app: &str,
        store: &str,
        bucket_name: &str,
        object_store_url: &str,
    ) {
        self.config_text
            .push_str(&format!("[{app}.execution.object_store]\n"));
        self.config_text
            .push_str(&format!("[[{app}.execution.object_store.{store}]]\n"));
        self.config_text
            .push_str(&format!("bucket_name = '{bucket_name}'\n"));
        self.config_text
            .push_str(&format!("object_store_url = '{object_store_url}'\n"));
    }

    /// Appends the endpoint, static access key / secret key, and `aws_allow_http`
    /// keys to the object-store entry most recently started.
    #[cfg(feature = "s3")]
    fn push_s3_static_settings(
        &mut self,
        endpoint: &str,
        access_key: &str,
        secret_key: &str,
        allow_http: bool,
    ) {
        self.config_text
            .push_str(&format!("aws_endpoint = '{endpoint}'\n"));
        self.config_text
            .push_str(&format!("aws_access_key_id = '{access_key}'\n"));
        self.config_text
            .push_str(&format!("aws_secret_access_key = '{secret_key}'\n"));
        self.config_text
            .push_str(&format!("aws_allow_http = {allow_http}\n"));
    }

    /// Appends an S3 object-store entry for `app` configured with a static
    /// endpoint, access key and secret key.
    // One argument per TOML key keeps call sites in tests explicit.
    #[cfg(feature = "s3")]
    #[allow(clippy::too_many_arguments)]
    pub fn with_s3_object_store(
        &mut self,
        app: &str,
        store: &str,
        bucket_name: &str,
        object_store_url: &str,
        endpoint: &str,
        access_key: &str,
        secret_key: &str,
        allow_http: bool,
    ) -> &mut Self {
        self.push_object_store_entry(app, store, bucket_name, object_store_url);
        self.push_s3_static_settings(endpoint, access_key, secret_key, allow_http);
        self
    }

    /// Appends an S3 object-store entry for `app` that sets
    /// `use_credential_chain = true` and writes no static keys.
    ///
    /// `aws_endpoint` is written only when `endpoint` is `Some`.
    #[cfg(feature = "s3")]
    pub fn with_s3_credential_chain(
        &mut self,
        app: &str,
        store: &str,
        bucket_name: &str,
        object_store_url: &str,
        endpoint: Option<&str>,
        allow_http: bool,
    ) -> &mut Self {
        self.push_object_store_entry(app, store, bucket_name, object_store_url);
        self.config_text.push_str("use_credential_chain = true\n");
        if let Some(endpoint) = endpoint {
            self.config_text
                .push_str(&format!("aws_endpoint = '{endpoint}'\n"));
        }
        self.config_text
            .push_str(&format!("aws_allow_http = {allow_http}\n"));
        self
    }

    /// Appends an S3 object-store entry for `app` that sets
    /// `use_credential_chain = true` and also writes a static endpoint, access key
    /// and secret key.
    // One argument per TOML key keeps call sites in tests explicit.
    #[cfg(feature = "s3")]
    #[allow(clippy::too_many_arguments)]
    pub fn with_s3_credential_chain_and_static_override(
        &mut self,
        app: &str,
        store: &str,
        bucket_name: &str,
        object_store_url: &str,
        endpoint: &str,
        access_key: &str,
        secret_key: &str,
        allow_http: bool,
    ) -> &mut Self {
        self.push_object_store_entry(app, store, bucket_name, object_store_url);
        self.config_text.push_str("use_credential_chain = true\n");
        self.push_s3_static_settings(endpoint, access_key, secret_key, allow_http);
        self
    }

    /// Appends an `[<app>.execution]` table that sets `benchmark_iterations`.
    ///
    /// Note: [`Self::with_ddl_path`] emits the same table header. TOML forbids
    /// defining a table twice, so do not combine both for the same `app`.
    pub fn with_benchmark_iterations(&mut self, app: &str, iterations: u64) -> &mut Self {
        self.config_text.push_str(&format!(
            "[{app}.execution]\nbenchmark_iterations = {iterations}\n"
        ));
        self
    }

    /// Appends a `[flightsql_client]` table that sets `benchmark_iterations`.
    #[cfg(feature = "flightsql")]
    pub fn with_flightsql_benchmark_iterations(&mut self, iterations: u64) -> &mut Self {
        self.config_text.push_str(&format!(
            "[flightsql_client]\nbenchmark_iterations = {iterations}\n"
        ));
        self
    }

    /// Appends a `[[execution.object_store.huggingface]]` array entry with the given
    /// repository type, id and revision.
    // NOTE(review): unlike the S3 helpers, this header has no `<app>.` prefix —
    // confirm that is what the config loader expects.
    #[cfg(feature = "huggingface")]
    pub fn with_huggingface(
        &mut self,
        repo_type: &str,
        repo_id: &str,
        revision: &str,
    ) -> &mut Self {
        self.config_text
            .push_str("[[execution.object_store.huggingface]]\n");
        self.config_text
            .push_str(&format!("repo_type = '{repo_type}'\n"));
        self.config_text
            .push_str(&format!("repo_id = '{repo_id}'\n"));
        self.config_text
            .push_str(&format!("revision = '{revision}'\n"));
        self
    }

    /// Appends a `[shared.wasm_udf]` table registering one WASM UDF.
    ///
    /// The output is a single inline-table line, for example:
    ///
    /// ```text
    /// module_functions = { "m.wasm" = [{ name = "f", input_data_type = "Row", input_types = ["Int32", "Int64"], return_type = "Int64" }] }
    /// ```
    ///
    /// An empty `input_types` slice produces `input_types = []`.
    #[cfg(feature = "udfs-wasm")]
    pub fn with_udfs_wasm(
        &mut self,
        module_path: &str,
        function_name: &str,
        input_data_type: &str,
        input_types: &[&str],
        return_type: &str,
    ) -> &mut Self {
        let quoted_types = input_types
            .iter()
            .map(|ty| format!("\"{ty}\""))
            .collect::<Vec<_>>()
            .join(", ");
        self.config_text.push_str("[shared.wasm_udf]\n");
        self.config_text.push_str(&format!(
            "module_functions = {{ \"{module_path}\" = [{{ name = \"{function_name}\", \
             input_data_type = \"{input_data_type}\", input_types = [{quoted_types}], \
             return_type = \"{return_type}\" }}] }}\n"
        ));
        self
    }

    /// Appends a `[flightsql_client.auth]` table with the provided credentials.
    ///
    /// Each value is inserted verbatim, without quoting or escaping. To produce a
    /// TOML string, the caller must include the quotes, e.g. `"\"token\"".to_string()`.
    /// Basic-auth values are written as the dotted keys `client_basic_auth.username`
    /// and `client_basic_auth.password`.
    #[cfg(feature = "flightsql")]
    pub fn with_client_auth(
        &mut self,
        client_bearer: Option<String>,
        client_basic_username: Option<String>,
        client_basic_password: Option<String>,
    ) -> &mut Self {
        self.config_text.push_str("[flightsql_client.auth]\n");
        if let Some(client_bearer) = client_bearer {
            self.config_text
                .push_str(&format!("client_bearer_token = {client_bearer}\n"));
        }
        if let Some(client_basic_username) = client_basic_username {
            self.config_text.push_str(&format!(
                "client_basic_auth.username = {client_basic_username}\n"
            ));
        }
        if let Some(client_basic_password) = client_basic_password {
            self.config_text.push_str(&format!(
                "client_basic_auth.password = {client_basic_password}\n"
            ));
        }
        self
    }

    /// Appends a `[flightsql_client.headers]` table with one `name = value` line per
    /// entry; with `None`, only the empty table header is written.
    ///
    /// Names and values are inserted verbatim, without quoting, so callers must
    /// supply valid TOML keys and values. Line order follows `HashMap` iteration
    /// order, which is unspecified.
    #[cfg(feature = "flightsql")]
    pub fn with_client_headers(&mut self, headers: Option<HashMap<String, String>>) -> &mut Self {
        self.config_text.push_str("[flightsql_client.headers]\n");
        if let Some(headers) = &headers {
            for (name, value) in headers {
                self.config_text.push_str(&format!("{name} = {value}\n"));
            }
        }
        self
    }

    /// Appends a `[flightsql_server.auth]` table with the provided credentials.
    ///
    /// Like [`Self::with_client_auth`], each value is inserted verbatim, without
    /// quoting, so callers must pass already-quoted TOML strings.
    // NOTE(review): `dead_code` is allowed presumably because not every test target
    // that includes this module calls this method — confirm.
    #[allow(dead_code)]
    #[cfg(feature = "flightsql")]
    pub fn with_server_auth(
        &mut self,
        server_bearer: Option<String>,
        server_basic_username: Option<String>,
        server_basic_password: Option<String>,
    ) -> &mut Self {
        self.config_text.push_str("[flightsql_server.auth]\n");
        if let Some(server_bearer) = server_bearer {
            self.config_text
                .push_str(&format!("server_bearer_token = {server_bearer}\n"));
        }
        if let Some(server_basic_username) = server_basic_username {
            self.config_text.push_str(&format!(
                "server_basic_username = {server_basic_username}\n"
            ));
        }
        if let Some(server_basic_password) = server_basic_password {
            self.config_text.push_str(&format!(
                "server_basic_password = {server_basic_password}\n"
            ));
        }
        self
    }

    /// Appends a `[db]` table that sets `path`.
    ///
    /// The path is written as a TOML literal (single-quoted) string, like
    /// `ddl_path`. A basic (double-quoted) string would treat the backslashes in
    /// Windows paths such as `C:\Users\...` as escape sequences and fail to parse.
    pub fn with_db_path(&mut self, path: &str) -> &mut Self {
        self.config_text.push_str("[db]\n");
        self.config_text.push_str(&format!("path = '{path}'\n"));
        self
    }
}