datafusion_dft/
config.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Configuration management handling
19
20use std::path::PathBuf;
21
22use directories::{ProjectDirs, UserDirs};
23use lazy_static::lazy_static;
24use log::info;
25use serde::Deserialize;
26
27#[cfg(feature = "s3")]
28use color_eyre::Result;
29#[cfg(feature = "s3")]
30use object_store::aws::{AmazonS3, AmazonS3Builder};
31
32lazy_static! {
33    pub static ref PROJECT_NAME: String = env!("CARGO_CRATE_NAME").to_uppercase().to_string();
34    pub static ref DATA_FOLDER: Option<PathBuf> =
35        std::env::var(format!("{}_DATA", PROJECT_NAME.clone()))
36            .ok()
37            .map(PathBuf::from);
38    pub static ref LOG_ENV: String = format!("{}_LOGLEVEL", PROJECT_NAME.clone());
39    pub static ref LOG_FILE: String = format!("{}.log", env!("CARGO_PKG_NAME"));
40}
41
42fn project_directory() -> PathBuf {
43    if let Some(user_dirs) = UserDirs::new() {
44        return user_dirs.home_dir().join(".config").join("dft");
45    };
46
47    let maybe_project_dirs = ProjectDirs::from("", "", env!("CARGO_PKG_NAME"));
48    if let Some(project_dirs) = maybe_project_dirs {
49        project_dirs.data_local_dir().to_path_buf()
50    } else {
51        panic!("No known data directory")
52    }
53}
54
55pub fn get_data_dir() -> PathBuf {
56    if let Some(data_dir) = DATA_FOLDER.clone() {
57        data_dir
58    } else {
59        project_directory()
60    }
61}
62
63#[derive(Clone, Debug, Default, Deserialize)]
64pub struct AppConfig {
65    #[serde(default = "default_execution_config")]
66    pub execution: ExecutionConfig,
67    #[serde(default = "default_display_config")]
68    pub display: DisplayConfig,
69    #[serde(default = "default_interaction_config")]
70    pub interaction: InteractionConfig,
71    #[cfg(feature = "flightsql")]
72    #[serde(default = "default_flightsql_config")]
73    pub flightsql: FlightSQLConfig,
74    #[serde(default = "default_editor_config")]
75    pub editor: EditorConfig,
76}
77
78fn default_execution_config() -> ExecutionConfig {
79    ExecutionConfig::default()
80}
81
82fn default_display_config() -> DisplayConfig {
83    DisplayConfig::default()
84}
85
86fn default_interaction_config() -> InteractionConfig {
87    InteractionConfig::default()
88}
89
/// Serde fallback for `AppConfig::flightsql`: the `Default` value of [`FlightSQLConfig`].
#[cfg(feature = "flightsql")]
fn default_flightsql_config() -> FlightSQLConfig {
    Default::default()
}
94
95#[derive(Clone, Debug, Deserialize)]
96pub struct DisplayConfig {
97    #[serde(default = "default_frame_rate")]
98    pub frame_rate: f64,
99}
100
/// Frame rate applied when the configuration does not specify one.
fn default_frame_rate() -> f64 {
    const DEFAULT_FRAME_RATE: f64 = 30.0;
    DEFAULT_FRAME_RATE
}
104
105impl Default for DisplayConfig {
106    fn default() -> Self {
107        Self { frame_rate: 30.0 }
108    }
109}
110
/// Connection settings for a single S3 bucket (only with the `s3` feature).
///
/// Only `bucket_name` is mandatory; every other key may be left out of the input.
#[cfg(feature = "s3")]
#[derive(Clone, Debug, Deserialize)]
pub struct S3Config {
    /// Name of the bucket handed to the S3 builder.
    bucket_name: String,
    /// Optional object store URL, exposed through `object_store_url()`.
    object_store_url: Option<String>,
    /// Optional access key id.
    aws_access_key_id: Option<String>,
    /// Optional secret access key.
    aws_secret_access_key: Option<String>,
    /// Optional region. No serde rename is applied, so the input key carries the leading
    /// underscore: `_aws_default_region`.
    _aws_default_region: Option<String>,
    /// Optional endpoint override.
    aws_endpoint: Option<String>,
    /// Optional session token.
    aws_session_token: Option<String>,
    /// Optional flag forwarded to the builder's `with_allow_http`.
    aws_allow_http: Option<bool>,
}
123
#[cfg(feature = "s3")]
impl S3Config {
    /// Returns the configured object store URL, if one was provided.
    pub fn object_store_url(&self) -> &Option<String> {
        &self.object_store_url
    }

    /// Builds an [`AmazonS3`] store from this configuration.
    ///
    /// The bucket name is always applied; each optional setting (access key id, secret access
    /// key, region, endpoint, session token, allow-HTTP flag) is forwarded to the builder only
    /// when it is present.
    ///
    /// # Errors
    ///
    /// Returns an error when [`AmazonS3Builder::build`] rejects the resulting configuration.
    pub fn to_object_store(&self) -> Result<AmazonS3> {
        let mut builder = AmazonS3Builder::new().with_bucket_name(&self.bucket_name);
        if let Some(access_key) = &self.aws_access_key_id {
            builder = builder.with_access_key_id(access_key);
        }
        if let Some(secret) = &self.aws_secret_access_key {
            builder = builder.with_secret_access_key(secret);
        }
        // A configured region used to be parsed and then dropped; forward it to the builder.
        if let Some(region) = &self._aws_default_region {
            builder = builder.with_region(region);
        }
        if let Some(endpoint) = &self.aws_endpoint {
            builder = builder.with_endpoint(endpoint);
        }
        if let Some(token) = &self.aws_session_token {
            builder = builder.with_token(token);
        }
        if let Some(allow_http) = self.aws_allow_http {
            builder = builder.with_allow_http(allow_http);
        }

        Ok(builder.build()?)
    }
}
155
/// Settings for a single Hugging Face object store (only with the `huggingface` feature).
///
/// Every key is optional in the input.
#[cfg(feature = "huggingface")]
#[derive(Clone, Debug, Deserialize)]
pub struct HuggingFaceConfig {
    /// Optional repository type.
    pub repo_type: Option<String>,
    /// Optional repository id.
    pub repo_id: Option<String>,
    /// Optional revision.
    pub revision: Option<String>,
    /// Optional root.
    pub root: Option<String>,
    /// Optional access token.
    pub token: Option<String>,
}
165
166#[derive(Clone, Debug, Deserialize)]
167pub struct ObjectStoreConfig {
168    #[cfg(feature = "s3")]
169    pub s3: Option<Vec<S3Config>>,
170    #[cfg(feature = "huggingface")]
171    pub huggingface: Option<Vec<HuggingFaceConfig>>,
172}
173
174#[derive(Clone, Debug, Deserialize)]
175pub struct ExecutionConfig {
176    pub object_store: Option<ObjectStoreConfig>,
177    #[serde(default = "default_ddl_path")]
178    pub ddl_path: Option<PathBuf>,
179    #[serde(default = "default_benchmark_iterations")]
180    pub benchmark_iterations: usize,
181    #[serde(default = "default_cli_batch_size")]
182    pub cli_batch_size: usize,
183    #[serde(default = "default_tui_batch_size")]
184    pub tui_batch_size: usize,
185    #[serde(default = "default_flightsql_server_batch_size")]
186    pub flightsql_server_batch_size: usize,
187    #[serde(default = "default_dedicated_executor_enabled")]
188    pub dedicated_executor_enabled: bool,
189    #[serde(default = "default_dedicated_executor_threads")]
190    pub dedicated_executor_threads: usize,
191    #[serde(default = "default_iceberg_config")]
192    pub iceberg: IcebergConfig,
193}
194
195fn default_ddl_path() -> Option<PathBuf> {
196    info!("Creating default ExecutionConfig");
197    if let Some(user_dirs) = directories::UserDirs::new() {
198        let ddl_path = user_dirs
199            .home_dir()
200            .join(".config")
201            .join("dft")
202            .join("ddl.sql");
203        Some(ddl_path)
204    } else {
205        None
206    }
207}
208
/// Number of benchmark iterations applied when the configuration does not specify one.
fn default_benchmark_iterations() -> usize {
    const DEFAULT_BENCHMARK_ITERATIONS: usize = 10;
    DEFAULT_BENCHMARK_ITERATIONS
}
212
/// CLI batch size applied when the configuration does not specify one.
fn default_cli_batch_size() -> usize {
    // NOTE(review): 8092 may be a typo for 8192 — confirm before changing.
    const DEFAULT_CLI_BATCH_SIZE: usize = 8092;
    DEFAULT_CLI_BATCH_SIZE
}
216
/// TUI batch size applied when the configuration does not specify one.
fn default_tui_batch_size() -> usize {
    const DEFAULT_TUI_BATCH_SIZE: usize = 100;
    DEFAULT_TUI_BATCH_SIZE
}
220
/// FlightSQL server batch size applied when the configuration does not specify one.
fn default_flightsql_server_batch_size() -> usize {
    // NOTE(review): 8092 may be a typo for 8192 — confirm before changing.
    const DEFAULT_FLIGHTSQL_SERVER_BATCH_SIZE: usize = 8092;
    DEFAULT_FLIGHTSQL_SERVER_BATCH_SIZE
}
224
/// The dedicated executor is off unless the configuration enables it.
fn default_dedicated_executor_enabled() -> bool {
    const ENABLED_BY_DEFAULT: bool = false;
    ENABLED_BY_DEFAULT
}
228
229fn default_dedicated_executor_threads() -> usize {
230    // By default we slightly over provision CPUs.  For example, if you have N CPUs available we
231    // have N CPUs for the [`DedicatedExecutor`] and 1 for the main / IO runtime.
232    //
233    // Ref: https://github.com/datafusion-contrib/datafusion-dft/pull/247#discussion_r1848270250
234    num_cpus::get()
235}
236
237fn default_iceberg_config() -> IcebergConfig {
238    IcebergConfig {
239        rest_catalogs: Vec::new(),
240    }
241}
242
243impl Default for ExecutionConfig {
244    fn default() -> Self {
245        Self {
246            object_store: None,
247            ddl_path: default_ddl_path(),
248            benchmark_iterations: default_benchmark_iterations(),
249            cli_batch_size: default_cli_batch_size(),
250            tui_batch_size: default_tui_batch_size(),
251            flightsql_server_batch_size: default_flightsql_server_batch_size(),
252            dedicated_executor_enabled: default_dedicated_executor_enabled(),
253            dedicated_executor_threads: default_dedicated_executor_threads(),
254            iceberg: default_iceberg_config(),
255        }
256    }
257}
258
259#[derive(Clone, Debug, Deserialize)]
260pub struct RestCatalogConfig {
261    pub name: String,
262    pub addr: String,
263}
264
265#[derive(Clone, Debug, Deserialize)]
266pub struct IcebergConfig {
267    pub rest_catalogs: Vec<RestCatalogConfig>,
268}
269
270#[derive(Clone, Debug, Default, Deserialize)]
271pub struct InteractionConfig {
272    #[serde(default = "default_mouse")]
273    pub mouse: bool,
274    #[serde(default = "default_paste")]
275    pub paste: bool,
276}
277
/// The mouse flag is off unless the configuration enables it.
fn default_mouse() -> bool {
    const MOUSE_BY_DEFAULT: bool = false;
    MOUSE_BY_DEFAULT
}
281
/// The paste flag is off unless the configuration enables it.
fn default_paste() -> bool {
    const PASTE_BY_DEFAULT: bool = false;
    PASTE_BY_DEFAULT
}
285
/// Settings of the `flightsql` configuration section (only with the `flightsql` feature).
#[cfg(feature = "flightsql")]
#[derive(Clone, Debug, Deserialize)]
pub struct FlightSQLConfig {
    /// Connection URL; filled in by `default_connection_url` when omitted.
    #[serde(default = "default_connection_url")]
    pub connection_url: String,
    /// Number of benchmark iterations; shares `default_benchmark_iterations` with the
    /// execution section.
    #[serde(default = "default_benchmark_iterations")]
    pub benchmark_iterations: usize,
    /// Server metrics setting, present only with the `experimental-flightsql-server` feature;
    /// filled in by `default_server_metrics_port` when omitted.
    #[cfg(feature = "experimental-flightsql-server")]
    #[serde(default = "default_server_metrics_port")]
    pub server_metrics_port: String,
}
297
#[cfg(feature = "flightsql")]
impl Default for FlightSQLConfig {
    /// Builds the FlightSQL settings from the same `default_*` helpers serde uses for missing
    /// fields.
    fn default() -> Self {
        let connection_url = default_connection_url();
        let benchmark_iterations = default_benchmark_iterations();
        #[cfg(feature = "experimental-flightsql-server")]
        let server_metrics_port = default_server_metrics_port();

        Self {
            connection_url,
            benchmark_iterations,
            #[cfg(feature = "experimental-flightsql-server")]
            server_metrics_port,
        }
    }
}
309
/// Connection URL applied when the configuration does not specify one.
#[cfg(feature = "flightsql")]
pub fn default_connection_url() -> String {
    String::from("http://localhost:50051")
}
314
/// Server metrics value applied when the configuration does not specify one.
///
/// Despite the name, the value is a full `host:port` string.
#[cfg(feature = "experimental-flightsql-server")]
fn default_server_metrics_port() -> String {
    String::from("0.0.0.0:9000")
}
319
320#[derive(Clone, Debug, Default, Deserialize)]
321pub struct EditorConfig {
322    pub experimental_syntax_highlighting: bool,
323}
324
325fn default_editor_config() -> EditorConfig {
326    EditorConfig::default()
327}