1use std::path::PathBuf;
21
22#[cfg(feature = "udfs-wasm")]
23use datafusion_udfs_wasm::WasmInputDataType;
24use serde::Deserialize;
25use std::collections::HashMap;
26
27#[cfg(feature = "s3")]
28use {
29 color_eyre::Result,
30 object_store::aws::{AmazonS3, AmazonS3Builder},
31};
32
33pub fn merge_configs(shared: ExecutionConfig, priority: ExecutionConfig) -> ExecutionConfig {
39 let mut merged = shared;
41
42 if let Some(object_store_config) = priority.object_store {
43 merged.object_store = Some(object_store_config)
44 }
45 if let Some(ddl_path) = priority.ddl_path {
46 merged.ddl_path = Some(ddl_path)
47 }
48 if let Some(datafusion) = priority.datafusion {
49 merged.datafusion = Some(datafusion)
50 }
51
52 if merged.benchmark_iterations != priority.benchmark_iterations {
53 merged.benchmark_iterations = priority.benchmark_iterations;
54 }
55 if merged.dedicated_executor_enabled != priority.dedicated_executor_enabled {
56 merged.dedicated_executor_enabled = priority.dedicated_executor_enabled
57 }
58 if merged.dedicated_executor_threads != priority.dedicated_executor_threads {
59 merged.dedicated_executor_threads = priority.dedicated_executor_threads
60 }
61 #[cfg(feature = "udfs-wasm")]
66 if !priority.wasm_udf.module_functions.is_empty() {
67 merged.wasm_udf = priority.wasm_udf
68 }
69
70 merged
71}
72
73#[derive(Clone, Debug, Deserialize)]
74pub struct ExecutionConfig {
75 #[serde(default)]
76 pub object_store: Option<ObjectStoreConfig>,
77 #[serde(default = "default_ddl_path")]
78 pub ddl_path: Option<PathBuf>,
79 #[serde(default = "default_benchmark_iterations")]
80 pub benchmark_iterations: usize,
81 #[serde(default)]
82 pub datafusion: Option<HashMap<String, String>>,
83 #[serde(default = "default_dedicated_executor_enabled")]
84 pub dedicated_executor_enabled: bool,
85 #[serde(default = "default_dedicated_executor_threads")]
86 pub dedicated_executor_threads: usize,
87 #[cfg(feature = "udfs-wasm")]
90 #[serde(default = "default_wasm_udf")]
91 pub wasm_udf: WasmUdfConfig,
92 #[serde(default = "default_catalog")]
93 pub catalog: CatalogConfig,
94 #[cfg(feature = "observability")]
95 #[serde(default)]
96 pub observability: ObservabilityConfig,
97}
98
99impl Default for ExecutionConfig {
100 fn default() -> Self {
101 Self {
102 object_store: None,
103 ddl_path: default_ddl_path(),
104 benchmark_iterations: default_benchmark_iterations(),
105 datafusion: None,
106 dedicated_executor_enabled: default_dedicated_executor_enabled(),
107 dedicated_executor_threads: default_dedicated_executor_threads(),
108 #[cfg(feature = "udfs-wasm")]
110 wasm_udf: default_wasm_udf(),
111 catalog: default_catalog(),
112 #[cfg(feature = "observability")]
113 observability: default_observability(),
114 }
115 }
116}
117
118fn default_ddl_path() -> Option<PathBuf> {
119 if let Some(user_dirs) = directories::UserDirs::new() {
120 let ddl_path = user_dirs
121 .home_dir()
122 .join(".config")
123 .join("dft")
124 .join("ddl.sql");
125 Some(ddl_path)
126 } else {
127 None
128 }
129}
130
/// Serde default for `ExecutionConfig::benchmark_iterations`.
fn default_benchmark_iterations() -> usize {
    const DEFAULT_ITERATIONS: usize = 10;
    DEFAULT_ITERATIONS
}
134
/// Serde default for `ExecutionConfig::dedicated_executor_enabled`: disabled.
fn default_dedicated_executor_enabled() -> bool {
    const ENABLED_BY_DEFAULT: bool = false;
    ENABLED_BY_DEFAULT
}
138
139fn default_dedicated_executor_threads() -> usize {
140 num_cpus::get()
145}
146
/// Serde default for `ExecutionConfig::wasm_udf`: a config with no modules registered.
#[cfg(feature = "udfs-wasm")]
fn default_wasm_udf() -> WasmUdfConfig {
    let module_functions = HashMap::new();
    WasmUdfConfig { module_functions }
}
159
/// Settings for one S3 bucket (only with the `s3` feature).
///
/// `to_object_store` turns these settings into an `AmazonS3` client.
#[cfg(feature = "s3")]
#[derive(Clone, Debug, Deserialize)]
pub struct S3Config {
    /// Bucket name handed to the S3 builder.
    bucket_name: String,
    /// Optional URL exposed through `object_store_url()`.
    /// NOTE(review): presumably the URL the store is registered under — confirm with the caller.
    object_store_url: Option<String>,
    /// When `true`, `to_object_store` starts from `AmazonS3Builder::from_env()`
    /// instead of an empty builder. Defaults to `false` when omitted.
    #[serde(default)]
    use_credential_chain: bool,
    /// Access key id; applied to the builder when present.
    aws_access_key_id: Option<String>,
    /// Secret access key; applied to the builder when present.
    aws_secret_access_key: Option<String>,
    /// Accepted during deserialization but not read by `to_object_store`.
    _aws_default_region: Option<String>,
    /// Endpoint override; applied to the builder when present.
    aws_endpoint: Option<String>,
    /// Session token; applied to the builder when present.
    aws_session_token: Option<String>,
    /// Allow-HTTP flag; applied to the builder when present.
    aws_allow_http: Option<bool>,
}
177
#[cfg(feature = "s3")]
impl S3Config {
    /// Borrows the configured `object_store_url`, which is `None` when it was not set.
    pub fn object_store_url(&self) -> &Option<String> {
        let Self {
            object_store_url, ..
        } = self;
        object_store_url
    }
}
184
#[cfg(feature = "s3")]
impl S3Config {
    /// Builds an `AmazonS3` client from this configuration.
    ///
    /// The builder is seeded from the environment when `use_credential_chain`
    /// is set, and starts empty otherwise. The bucket name is always applied;
    /// each optional setting is applied only when present, so explicit values
    /// are layered on top of whatever the seed builder already holds.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by `AmazonS3Builder::build`.
    pub fn to_object_store(&self) -> Result<AmazonS3> {
        let seed = if self.use_credential_chain {
            AmazonS3Builder::from_env()
        } else {
            AmazonS3Builder::new()
        };

        let mut builder = seed.with_bucket_name(self.bucket_name.as_str());

        if let Some(key_id) = self.aws_access_key_id.as_deref() {
            builder = builder.with_access_key_id(key_id);
        }
        if let Some(secret_key) = self.aws_secret_access_key.as_deref() {
            builder = builder.with_secret_access_key(secret_key);
        }
        if let Some(endpoint_url) = self.aws_endpoint.as_deref() {
            builder = builder.with_endpoint(endpoint_url);
        }
        if let Some(session_token) = self.aws_session_token.as_deref() {
            builder = builder.with_token(session_token);
        }
        if let Some(allow_http) = self.aws_allow_http {
            builder = builder.with_allow_http(allow_http);
        }

        let store = builder.build()?;
        Ok(store)
    }
}
226
/// Settings for one Hugging Face source (only with the `huggingface` feature).
///
/// All fields are optional strings; their accepted values are defined by the
/// code that consumes this struct, which is not part of this block.
#[cfg(feature = "huggingface")]
#[derive(Clone, Debug, Deserialize)]
pub struct HuggingFaceConfig {
    /// Optional repository type.
    pub repo_type: Option<String>,
    /// Optional repository identifier.
    pub repo_id: Option<String>,
    /// Optional revision.
    pub revision: Option<String>,
    /// Optional root.
    /// NOTE(review): presumably a path prefix inside the repository — confirm.
    pub root: Option<String>,
    /// Optional access token.
    pub token: Option<String>,
}
236
237#[derive(Clone, Debug, Deserialize)]
238pub struct ObjectStoreConfig {
239 #[cfg(feature = "s3")]
240 pub s3: Option<Vec<S3Config>>,
241 #[cfg(feature = "huggingface")]
242 pub huggingface: Option<Vec<HuggingFaceConfig>>,
243}
244
245#[derive(Clone, Debug, Deserialize, PartialEq)]
246pub struct RestCatalogConfig {
247 pub name: String,
248 pub addr: String,
249}
250
251#[derive(Clone, Debug, Deserialize, PartialEq)]
252pub struct IcebergConfig {
253 pub rest_catalogs: Vec<RestCatalogConfig>,
254}
255
/// Description of one function in a WASM module (only with the `udfs-wasm` feature).
#[cfg(feature = "udfs-wasm")]
#[derive(Clone, Debug, Deserialize, PartialEq)]
pub struct WasmFuncDetails {
    /// Function name.
    pub name: String,
    /// Input types, given as strings.
    /// NOTE(review): the accepted type spellings are not visible here — confirm with the consumer.
    pub input_types: Vec<String>,
    /// Return type, given as a string.
    pub return_type: String,
    /// Input data type, as defined by `datafusion_udfs_wasm::WasmInputDataType`.
    pub input_data_type: WasmInputDataType,
}
264
/// WASM UDF settings (only with the `udfs-wasm` feature).
#[cfg(feature = "udfs-wasm")]
#[derive(Clone, Debug, Deserialize, PartialEq)]
pub struct WasmUdfConfig {
    /// Maps a module path to the details of the functions listed for that module.
    pub module_functions: HashMap<PathBuf, Vec<WasmFuncDetails>>,
}
270
/// FlightSQL client settings (only with the `flightsql` feature).
#[cfg(feature = "flightsql")]
#[derive(Clone, Debug)]
pub struct FlightSQLConfig {
    /// URL of the server to connect to.
    pub connection_url: String,
    /// Number of benchmark iterations.
    pub benchmark_iterations: usize,
    /// Authentication settings.
    pub auth: AuthConfig,
    /// Header name/value pairs.
    /// NOTE(review): presumably sent with every request — confirm with the client code.
    pub headers: HashMap<String, String>,
}
279
#[cfg(feature = "flightsql")]
impl Default for FlightSQLConfig {
    /// Defaults: `http://localhost:50051`, 10 benchmark iterations, default
    /// auth settings and no headers.
    fn default() -> Self {
        Self::new(
            String::from("http://localhost:50051"),
            10,
            AuthConfig::default(),
            HashMap::new(),
        )
    }
}
291
#[cfg(feature = "flightsql")]
impl FlightSQLConfig {
    /// Creates a config from explicit values; each argument is stored in the
    /// field of the same name without modification.
    pub fn new(
        connection_url: String,
        benchmark_iterations: usize,
        auth: AuthConfig,
        headers: HashMap<String, String>,
    ) -> Self {
        FlightSQLConfig {
            connection_url,
            benchmark_iterations,
            auth,
            headers,
        }
    }
}
308
309#[derive(Clone, Debug, Default, Deserialize)]
310pub struct AuthConfig {
311 pub basic_auth: Option<BasicAuth>,
312 pub bearer_token: Option<String>,
313}
314
315#[derive(Clone, Debug, Default, Deserialize)]
316pub struct BasicAuth {
317 pub username: String,
318 pub password: String,
319}
320
321#[derive(Clone, Debug, Deserialize)]
322pub struct CatalogConfig {
323 #[serde(default = "default_catalog_name")]
324 pub name: String,
325}
326
327impl Default for CatalogConfig {
328 fn default() -> Self {
329 Self {
330 name: default_catalog_name(),
331 }
332 }
333}
334
335fn default_catalog() -> CatalogConfig {
336 CatalogConfig::default()
337}
338
/// Serde default for `CatalogConfig::name`.
fn default_catalog_name() -> String {
    String::from("dft")
}
342
/// Observability settings (only with the `observability` feature).
#[cfg(feature = "observability")]
#[derive(Clone, Debug, Deserialize)]
pub struct ObservabilityConfig {
    /// Schema name; falls back to `default_observability_schema_name()` when omitted.
    #[serde(default = "default_observability_schema_name")]
    pub schema_name: String,
    /// Tokio metrics switch; falls back to `default_tokio_metrics_enabled()` when omitted.
    #[serde(default = "default_tokio_metrics_enabled")]
    pub tokio_metrics_enabled: bool,
    /// Tokio metrics interval in seconds; falls back to
    /// `default_tokio_metrics_interval_secs()` when omitted.
    #[serde(default = "default_tokio_metrics_interval_secs")]
    pub tokio_metrics_interval_secs: u64,
}
353
#[cfg(feature = "observability")]
impl Default for ObservabilityConfig {
    /// Uses the same per-field defaults that serde applies to omitted fields.
    fn default() -> Self {
        let schema_name = default_observability_schema_name();
        let tokio_metrics_enabled = default_tokio_metrics_enabled();
        let tokio_metrics_interval_secs = default_tokio_metrics_interval_secs();

        Self {
            schema_name,
            tokio_metrics_enabled,
            tokio_metrics_interval_secs,
        }
    }
}
364
/// Returns `ObservabilityConfig::default()`; used when building the default `ExecutionConfig`.
#[cfg(feature = "observability")]
fn default_observability() -> ObservabilityConfig {
    Default::default()
}
369
/// Serde default for `ObservabilityConfig::schema_name`.
#[cfg(feature = "observability")]
fn default_observability_schema_name() -> String {
    String::from("observability")
}
374
/// Serde default for `ObservabilityConfig::tokio_metrics_enabled`: enabled.
#[cfg(feature = "observability")]
fn default_tokio_metrics_enabled() -> bool {
    const ENABLED_BY_DEFAULT: bool = true;
    ENABLED_BY_DEFAULT
}
379
/// Serde default for `ObservabilityConfig::tokio_metrics_interval_secs`.
#[cfg(feature = "observability")]
fn default_tokio_metrics_interval_secs() -> u64 {
    const DEFAULT_INTERVAL_SECS: u64 = 10;
    DEFAULT_INTERVAL_SECS
}