Skip to main content

datafusion_app/
config.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Configuration types, serde defaults, and merge helpers for the application's settings.
19
20use std::path::PathBuf;
21
22#[cfg(feature = "udfs-wasm")]
23use datafusion_udfs_wasm::WasmInputDataType;
24use serde::Deserialize;
25use std::collections::HashMap;
26
27#[cfg(feature = "s3")]
28use {
29    color_eyre::Result,
30    object_store::aws::{AmazonS3, AmazonS3Builder},
31};
32
33// Merges a shared config with a priority config. If a field is present in the priority config that
34// it replaces the entire field from the shared config.
35//
36// TODO: Implement full merge so that nested fields can be maintained from the shared config and
37// only selected fields are overwritten.
38pub fn merge_configs(shared: ExecutionConfig, priority: ExecutionConfig) -> ExecutionConfig {
39    // Baseline is the shared config
40    let mut merged = shared;
41
42    if let Some(object_store_config) = priority.object_store {
43        merged.object_store = Some(object_store_config)
44    }
45    if let Some(ddl_path) = priority.ddl_path {
46        merged.ddl_path = Some(ddl_path)
47    }
48    if let Some(datafusion) = priority.datafusion {
49        merged.datafusion = Some(datafusion)
50    }
51
52    if merged.benchmark_iterations != priority.benchmark_iterations {
53        merged.benchmark_iterations = priority.benchmark_iterations;
54    }
55    if merged.dedicated_executor_enabled != priority.dedicated_executor_enabled {
56        merged.dedicated_executor_enabled = priority.dedicated_executor_enabled
57    }
58    if merged.dedicated_executor_threads != priority.dedicated_executor_threads {
59        merged.dedicated_executor_threads = priority.dedicated_executor_threads
60    }
61    // if merged.iceberg != priority.iceberg {
62    //     merged.iceberg = priority.iceberg
63    // }
64
65    #[cfg(feature = "udfs-wasm")]
66    if !priority.wasm_udf.module_functions.is_empty() {
67        merged.wasm_udf = priority.wasm_udf
68    }
69
70    merged
71}
72
73#[derive(Clone, Debug, Deserialize)]
74pub struct ExecutionConfig {
75    #[serde(default)]
76    pub object_store: Option<ObjectStoreConfig>,
77    #[serde(default = "default_ddl_path")]
78    pub ddl_path: Option<PathBuf>,
79    #[serde(default = "default_benchmark_iterations")]
80    pub benchmark_iterations: usize,
81    #[serde(default)]
82    pub datafusion: Option<HashMap<String, String>>,
83    #[serde(default = "default_dedicated_executor_enabled")]
84    pub dedicated_executor_enabled: bool,
85    #[serde(default = "default_dedicated_executor_threads")]
86    pub dedicated_executor_threads: usize,
87    // #[serde(default = "default_iceberg_config")]
88    // pub iceberg: IcebergConfig,
89    #[cfg(feature = "udfs-wasm")]
90    #[serde(default = "default_wasm_udf")]
91    pub wasm_udf: WasmUdfConfig,
92    #[serde(default = "default_catalog")]
93    pub catalog: CatalogConfig,
94    #[cfg(feature = "observability")]
95    #[serde(default)]
96    pub observability: ObservabilityConfig,
97}
98
99impl Default for ExecutionConfig {
100    fn default() -> Self {
101        Self {
102            object_store: None,
103            ddl_path: default_ddl_path(),
104            benchmark_iterations: default_benchmark_iterations(),
105            datafusion: None,
106            dedicated_executor_enabled: default_dedicated_executor_enabled(),
107            dedicated_executor_threads: default_dedicated_executor_threads(),
108            // iceberg: default_iceberg_config(),
109            #[cfg(feature = "udfs-wasm")]
110            wasm_udf: default_wasm_udf(),
111            catalog: default_catalog(),
112            #[cfg(feature = "observability")]
113            observability: default_observability(),
114        }
115    }
116}
117
118fn default_ddl_path() -> Option<PathBuf> {
119    if let Some(user_dirs) = directories::UserDirs::new() {
120        let ddl_path = user_dirs
121            .home_dir()
122            .join(".config")
123            .join("dft")
124            .join("ddl.sql");
125        Some(ddl_path)
126    } else {
127        None
128    }
129}
130
/// Number of benchmark iterations used when the config does not specify one.
fn default_benchmark_iterations() -> usize {
    const DEFAULT_ITERATIONS: usize = 10;
    DEFAULT_ITERATIONS
}
134
/// The dedicated executor is opt-in: it stays off unless the config turns it on.
fn default_dedicated_executor_enabled() -> bool {
    const ENABLED_BY_DEFAULT: bool = false;
    ENABLED_BY_DEFAULT
}
138
/// Default thread count for the dedicated executor: the CPU count reported by `num_cpus::get()`.
fn default_dedicated_executor_threads() -> usize {
    // By default we slightly over-provision CPUs. For example, if you have N CPUs available we
    // use N for the [`DedicatedExecutor`] and 1 more for the main / IO runtime.
    //
    // Ref: https://github.com/datafusion-contrib/datafusion-dft/pull/247#discussion_r1848270250
    num_cpus::get()
}
146
// fn default_iceberg_config() -> IcebergConfig {
//     IcebergConfig {
//         rest_catalogs: Vec::new(),
//     }
// }
152
/// Default WASM UDF configuration: an empty module map, so no functions are registered.
#[cfg(feature = "udfs-wasm")]
fn default_wasm_udf() -> WasmUdfConfig {
    let module_functions = HashMap::default();
    WasmUdfConfig { module_functions }
}
159
/// Connection settings for one S3 (or S3-compatible) bucket.
///
/// NOTE(review): `Debug` is derived, so `{:?}` output includes `aws_secret_access_key` and
/// `aws_session_token` in clear text. Avoid logging this value (or structs containing it) as-is.
#[cfg(feature = "s3")]
#[derive(Clone, Debug, Deserialize)]
pub struct S3Config {
    /// Bucket to connect to; always passed to the builder.
    bucket_name: String,
    /// Optional URL for this store, exposed via [`S3Config::object_store_url`].
    /// It is not used by [`S3Config::to_object_store`].
    object_store_url: Option<String>,
    /// Enable the AWS credential chain. When true, the builder starts from
    /// `AmazonS3Builder::from_env()`; see [`S3Config::to_object_store`] for the sources it enables.
    /// Static credentials in this config take precedence over environment-based credentials.
    #[serde(default)]
    use_credential_chain: bool,
    aws_access_key_id: Option<String>,
    aws_secret_access_key: Option<String>,
    /// AWS region passed to the builder.
    ///
    /// The key `_aws_default_region` is accepted as an alias for configs written against the
    /// old field name. Without the alias, serde would silently ignore that key as unknown.
    #[serde(alias = "_aws_default_region")]
    aws_default_region: Option<String>,
    aws_endpoint: Option<String>,
    aws_session_token: Option<String>,
    aws_allow_http: Option<bool>,
}

#[cfg(feature = "s3")]
impl S3Config {
    /// Returns the configured object store URL, if any.
    pub fn object_store_url(&self) -> &Option<String> {
        &self.object_store_url
    }

    /// Builds an [`AmazonS3`] client from this configuration.
    ///
    /// With `use_credential_chain`, the builder starts from [`AmazonS3Builder::from_env`], which
    /// reads `AWS_*` environment variables and enables:
    /// - environment variable credentials (`AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`)
    /// - web identity token authentication (`AWS_WEB_IDENTITY_TOKEN_FILE` for EKS/IRSA)
    /// - container credentials (ECS via `AWS_CONTAINER_CREDENTIALS_RELATIVE_URI`)
    /// - EC2 instance profile via IMDSv2
    ///
    /// Otherwise it starts from [`AmazonS3Builder::new`] (static configuration only). Every value
    /// set explicitly in this config is applied afterwards and therefore overrides the environment.
    ///
    /// # Errors
    ///
    /// Returns an error if [`AmazonS3Builder::build`] rejects the resulting configuration.
    pub fn to_object_store(&self) -> Result<AmazonS3> {
        let mut builder = if self.use_credential_chain {
            AmazonS3Builder::from_env()
        } else {
            AmazonS3Builder::new()
        };

        // Always set the bucket name (required).
        builder = builder.with_bucket_name(&self.bucket_name);

        if let Some(access_key) = &self.aws_access_key_id {
            builder = builder.with_access_key_id(access_key);
        }
        if let Some(secret) = &self.aws_secret_access_key {
            builder = builder.with_secret_access_key(secret);
        }
        if let Some(region) = &self.aws_default_region {
            builder = builder.with_region(region);
        }
        if let Some(endpoint) = &self.aws_endpoint {
            builder = builder.with_endpoint(endpoint);
        }
        if let Some(token) = &self.aws_session_token {
            builder = builder.with_token(token);
        }
        if let Some(allow_http) = self.aws_allow_http {
            builder = builder.with_allow_http(allow_http);
        }

        Ok(builder.build()?)
    }
}
226
/// Settings for one Hugging Face object store (only with the `huggingface` feature).
///
/// NOTE(review): the field meanings below are inferred from their names; confirm them against
/// the code that consumes this struct.
#[cfg(feature = "huggingface")]
#[derive(Clone, Debug, Deserialize)]
pub struct HuggingFaceConfig {
    /// Repository type, presumably e.g. `dataset` or `model` — TODO confirm.
    pub repo_type: Option<String>,
    /// Repository identifier, presumably `owner/name` — TODO confirm.
    pub repo_id: Option<String>,
    /// Revision to read from, presumably a branch, tag or commit — TODO confirm.
    pub revision: Option<String>,
    /// Root path, presumably within the repository — TODO confirm.
    pub root: Option<String>,
    /// Access token. `Debug` is derived, so `{:?}` output includes it verbatim.
    pub token: Option<String>,
}
236
/// Object store registrations, grouped by backend.
///
/// Each backend's field only exists when its cargo feature (`s3`, `huggingface`) is enabled.
#[derive(Clone, Debug, Deserialize)]
pub struct ObjectStoreConfig {
    /// S3 (or S3-compatible) buckets; `None` when the config lists none.
    #[cfg(feature = "s3")]
    pub s3: Option<Vec<S3Config>>,
    /// Hugging Face stores; `None` when the config lists none.
    #[cfg(feature = "huggingface")]
    pub huggingface: Option<Vec<HuggingFaceConfig>>,
}
244
/// A named REST catalog entry, as listed in `IcebergConfig::rest_catalogs`.
#[derive(Clone, Debug, Deserialize, PartialEq)]
pub struct RestCatalogConfig {
    /// Catalog name, presumably the name it is registered under — confirm.
    pub name: String,
    /// Address of the REST catalog service. Its expected format (URL vs host:port) is not
    /// visible here.
    pub addr: String,
}
250
/// Iceberg settings: the list of REST catalogs.
///
/// NOTE(review): `ExecutionConfig` has no active field of this type, so it is not currently
/// read as part of the execution config.
#[derive(Clone, Debug, Deserialize, PartialEq)]
pub struct IcebergConfig {
    /// REST catalogs to make available.
    pub rest_catalogs: Vec<RestCatalogConfig>,
}
255
/// Signature of one function declared for a WASM UDF module (only with the `udfs-wasm` feature).
#[cfg(feature = "udfs-wasm")]
#[derive(Clone, Debug, Deserialize, PartialEq)]
pub struct WasmFuncDetails {
    /// Function name, presumably both the exported WASM name and the SQL-visible name — confirm.
    pub name: String,
    /// Argument type names as strings. They are presumably parsed into Arrow/DataFusion types
    /// elsewhere — confirm.
    pub input_types: Vec<String>,
    /// Return type name as a string (parsed elsewhere, presumably).
    pub return_type: String,
    /// Data representation used for inputs; see `WasmInputDataType`.
    pub input_data_type: WasmInputDataType,
}
264
/// WASM UDF registrations (only with the `udfs-wasm` feature).
#[cfg(feature = "udfs-wasm")]
#[derive(Clone, Debug, Deserialize, PartialEq)]
pub struct WasmUdfConfig {
    /// Maps each WASM module path to the functions declared for it.
    ///
    /// An empty map is treated as "not configured" when merging configs.
    pub module_functions: HashMap<PathBuf, Vec<WasmFuncDetails>>,
}
270
/// Settings for a FlightSQL connection (only with the `flightsql` feature).
///
/// Unlike the other config types here it does not derive `Deserialize`.
#[cfg(feature = "flightsql")]
#[derive(Clone, Debug)]
pub struct FlightSQLConfig {
    /// URL of the FlightSQL server (default `http://localhost:50051`).
    pub connection_url: String,
    /// Number of iterations used when benchmarking a query (default `10`).
    pub benchmark_iterations: usize,
    /// Authentication settings (default: no credentials).
    pub auth: AuthConfig,
    /// Extra headers (name -> value), presumably sent with each request — confirm.
    pub headers: HashMap<String, String>,
}
279
#[cfg(feature = "flightsql")]
impl Default for FlightSQLConfig {
    /// Targets a local server at `http://localhost:50051`.
    ///
    /// Uses default auth (neither basic auth nor a bearer token) and no extra headers.
    fn default() -> Self {
        Self {
            connection_url: "http://localhost:50051".to_string(),
            // Share the execution config's default so the two cannot drift apart.
            benchmark_iterations: default_benchmark_iterations(),
            auth: AuthConfig::default(),
            headers: HashMap::new(),
        }
    }
}
291
#[cfg(feature = "flightsql")]
impl FlightSQLConfig {
    /// Assembles a FlightSQL configuration from its individual parts.
    ///
    /// * `connection_url` - URL of the FlightSQL server.
    /// * `benchmark_iterations` - iterations to run when benchmarking a query.
    /// * `auth` - authentication settings.
    /// * `headers` - extra headers, keyed by header name.
    pub fn new(
        connection_url: String,
        benchmark_iterations: usize,
        auth: AuthConfig,
        headers: HashMap<String, String>,
    ) -> Self {
        FlightSQLConfig {
            auth,
            headers,
            connection_url,
            benchmark_iterations,
        }
    }
}
308
/// Authentication settings. The derived `Default` sets both fields to `None` (no credentials).
///
/// NOTE(review): how the two options interact when both are set is decided by the consumer,
/// not here.
#[derive(Clone, Debug, Default, Deserialize)]
pub struct AuthConfig {
    /// Username/password credentials, if any.
    pub basic_auth: Option<BasicAuth>,
    /// Bearer token, if any. `Debug` is derived, so `{:?}` output includes it verbatim.
    pub bearer_token: Option<String>,
}
314
/// Username/password pair, presumably for HTTP Basic authentication — confirm with the consumer.
///
/// `Debug` is derived, so `{:?}` output includes the password verbatim.
#[derive(Clone, Debug, Default, Deserialize)]
pub struct BasicAuth {
    pub username: String,
    pub password: String,
}
320
/// Catalog settings.
#[derive(Clone, Debug, Deserialize)]
pub struct CatalogConfig {
    /// Catalog name (default `dft`).
    #[serde(default = "default_catalog_name")]
    pub name: String,
}
326
327impl Default for CatalogConfig {
328    fn default() -> Self {
329        Self {
330            name: default_catalog_name(),
331        }
332    }
333}
334
335fn default_catalog() -> CatalogConfig {
336    CatalogConfig::default()
337}
338
/// Name of the catalog used when none is configured.
fn default_catalog_name() -> String {
    String::from("dft")
}
342
/// Observability settings (only with the `observability` feature).
///
/// Each field falls back to a serde default when missing.
#[cfg(feature = "observability")]
#[derive(Clone, Debug, Deserialize)]
pub struct ObservabilityConfig {
    /// Schema name (default `observability`), presumably where observability tables are
    /// exposed — confirm.
    #[serde(default = "default_observability_schema_name")]
    pub schema_name: String,
    /// Whether Tokio metrics are enabled (default `true`).
    #[serde(default = "default_tokio_metrics_enabled")]
    pub tokio_metrics_enabled: bool,
    /// Tokio metrics interval in seconds (default `10`), presumably the collection period —
    /// confirm.
    #[serde(default = "default_tokio_metrics_interval_secs")]
    pub tokio_metrics_interval_secs: u64,
}
353
#[cfg(feature = "observability")]
impl Default for ObservabilityConfig {
    /// Builds the configuration from the same helpers serde uses for missing fields.
    fn default() -> Self {
        let schema_name = default_observability_schema_name();
        let tokio_metrics_enabled = default_tokio_metrics_enabled();
        let tokio_metrics_interval_secs = default_tokio_metrics_interval_secs();

        ObservabilityConfig {
            schema_name,
            tokio_metrics_enabled,
            tokio_metrics_interval_secs,
        }
    }
}
364
/// Serde default for `ExecutionConfig::observability`; identical to
/// `ObservabilityConfig::default()`.
#[cfg(feature = "observability")]
fn default_observability() -> ObservabilityConfig {
    Default::default()
}
369
/// Schema name used for observability when none is configured.
#[cfg(feature = "observability")]
fn default_observability_schema_name() -> String {
    String::from("observability")
}
374
/// Tokio metrics are on unless the config disables them.
#[cfg(feature = "observability")]
fn default_tokio_metrics_enabled() -> bool {
    const ENABLED_BY_DEFAULT: bool = true;
    ENABLED_BY_DEFAULT
}
379
/// Tokio metrics interval, in seconds, used when none is configured.
#[cfg(feature = "observability")]
fn default_tokio_metrics_interval_secs() -> u64 {
    const DEFAULT_INTERVAL_SECS: u64 = 10;
    DEFAULT_INTERVAL_SECS
}