Skip to main content

datafusion_dft/
config.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Configuration management handling
19
20use std::path::PathBuf;
21
22#[cfg(any(feature = "flightsql", feature = "http"))]
23use std::net::{IpAddr, Ipv4Addr, SocketAddr};
24
25use datafusion_app::config::ExecutionConfig;
26use directories::{ProjectDirs, UserDirs};
27use lazy_static::lazy_static;
28use log::{debug, error};
29use serde::Deserialize;
30
31#[cfg(any(feature = "flightsql", feature = "http"))]
32use datafusion_app::config::AuthConfig;
33#[cfg(feature = "flightsql")]
34use std::collections::HashMap;
35use url::Url;
36
37lazy_static! {
38    pub static ref PROJECT_NAME: String = env!("CARGO_CRATE_NAME").to_uppercase().to_string();
39    pub static ref DATA_FOLDER: Option<PathBuf> =
40        std::env::var(format!("{}_DATA", PROJECT_NAME.clone()))
41            .ok()
42            .map(PathBuf::from);
43    pub static ref LOG_ENV: String = format!("{}_LOGLEVEL", PROJECT_NAME.clone());
44    pub static ref LOG_FILE: String = format!("{}.log", env!("CARGO_PKG_NAME"));
45}
46
47fn project_directory() -> PathBuf {
48    if let Some(user_dirs) = UserDirs::new() {
49        return user_dirs.home_dir().join(".config").join("dft");
50    };
51
52    let maybe_project_dirs = ProjectDirs::from("", "", env!("CARGO_PKG_NAME"));
53    if let Some(project_dirs) = maybe_project_dirs {
54        project_dirs.data_local_dir().to_path_buf()
55    } else {
56        panic!("No known data directory")
57    }
58}
59
60pub fn get_data_dir() -> PathBuf {
61    if let Some(data_dir) = DATA_FOLDER.clone() {
62        data_dir
63    } else {
64        project_directory()
65    }
66}
67
68#[derive(Clone, Debug, Default, Deserialize)]
69pub struct CliConfig {
70    #[serde(default = "default_execution_config")]
71    pub execution: ExecutionConfig,
72}
73
/// Settings read from the `tui` table of the config file.
///
/// Every sub-table is optional; a missing one takes its type's default value.
#[cfg(feature = "tui")]
#[derive(Debug, Clone, Default, Deserialize)]
pub struct TuiConfig {
    /// Execution settings used by the TUI.
    #[serde(default = "default_execution_config")]
    pub execution: ExecutionConfig,
    /// Display settings (frame rate).
    #[serde(default = "default_display_config")]
    pub display: DisplayConfig,
    /// Input-handling settings (mouse, paste).
    #[serde(default = "default_interaction_config")]
    pub interaction: InteractionConfig,
    /// Editor settings.
    #[serde(default = "default_editor_config")]
    pub editor: EditorConfig,
}
86
/// Settings read from the `flightsql_server` table of the config file.
#[cfg(feature = "flightsql")]
#[derive(Debug, Clone, Deserialize)]
pub struct FlightSQLServerConfig {
    /// Execution settings used by the server.
    #[serde(default = "default_execution_config")]
    pub execution: ExecutionConfig,
    /// Server URL; defaults to `default_connection_url()`.
    #[serde(default = "default_connection_url")]
    pub connection_url: String,
    /// Socket address for server metrics; defaults to `default_server_metrics_addr()`.
    #[serde(default = "default_server_metrics_addr")]
    pub server_metrics_addr: SocketAddr,
    /// Authentication settings.
    #[serde(default = "default_auth_config")]
    pub auth: AuthConfig,
}

#[cfg(feature = "flightsql")]
impl Default for FlightSQLServerConfig {
    // Uses the same per-field helpers as the serde defaults above, so a missing
    // `flightsql_server` table and an empty one yield the same value.
    fn default() -> Self {
        let execution = default_execution_config();
        let connection_url = default_connection_url();
        let server_metrics_addr = default_server_metrics_addr();
        let auth = default_auth_config();
        Self {
            execution,
            connection_url,
            server_metrics_addr,
            auth,
        }
    }
}
111
/// Settings read from the `flightsql_client` table of the config file.
#[cfg(feature = "flightsql")]
#[derive(Debug, Clone, Deserialize)]
pub struct FlightSQLClientConfig {
    /// Server URL to connect to; defaults to `default_connection_url()`.
    #[serde(default = "default_connection_url")]
    pub connection_url: String,
    /// Benchmark iteration count; defaults to `default_benchmark_iterations()`.
    #[serde(default = "default_benchmark_iterations")]
    pub benchmark_iterations: usize,
    /// Authentication settings.
    #[serde(default = "default_auth_config")]
    pub auth: AuthConfig,
    /// Headers as a string map (presumably name -> value); empty by default.
    #[serde(default = "default_headers")]
    pub headers: HashMap<String, String>,
    /// Optional path of a headers file; `None` by default.
    #[serde(default)]
    pub headers_file: Option<PathBuf>,
}

#[cfg(feature = "flightsql")]
impl Default for FlightSQLClientConfig {
    // Mirrors the serde field defaults above.
    fn default() -> Self {
        let connection_url = default_connection_url();
        let benchmark_iterations = default_benchmark_iterations();
        let auth = default_auth_config();
        let headers = default_headers();
        Self {
            connection_url,
            benchmark_iterations,
            auth,
            headers,
            headers_file: None,
        }
    }
}
139
/// Settings read from the `http_server` table of the config file.
#[cfg(feature = "http")]
#[derive(Debug, Clone, Deserialize)]
pub struct HttpServerConfig {
    /// Execution settings used by the server.
    #[serde(default = "default_execution_config")]
    pub execution: ExecutionConfig,
    /// Server URL; defaults to `default_connection_url()`.
    #[serde(default = "default_connection_url")]
    pub connection_url: String,
    /// Socket address for server metrics; defaults to `default_server_metrics_addr()`.
    #[serde(default = "default_server_metrics_addr")]
    pub server_metrics_addr: SocketAddr,
    /// Authentication settings.
    #[serde(default = "default_auth_config")]
    pub auth: AuthConfig,
    /// Timeout in seconds; defaults to `default_timeout_seconds()`.
    #[serde(default = "default_timeout_seconds")]
    pub timeout_seconds: u64,
    /// Result limit (presumably a maximum row count — confirm at use site);
    /// defaults to `default_result_limit()`.
    #[serde(default = "default_result_limit")]
    pub result_limit: usize,
}

#[cfg(feature = "http")]
impl Default for HttpServerConfig {
    // Mirrors the serde field defaults above.
    fn default() -> Self {
        let execution = default_execution_config();
        let connection_url = default_connection_url();
        let server_metrics_addr = default_server_metrics_addr();
        let auth = default_auth_config();
        let timeout_seconds = default_timeout_seconds();
        let result_limit = default_result_limit();
        Self {
            execution,
            connection_url,
            server_metrics_addr,
            auth,
            timeout_seconds,
            result_limit,
        }
    }
}
170
171#[derive(Clone, Debug, Default, Deserialize)]
172pub struct AppConfig {
173    #[serde(default)]
174    pub shared: ExecutionConfig,
175    #[serde(default)]
176    pub cli: CliConfig,
177    #[cfg(feature = "tui")]
178    #[serde(default)]
179    pub tui: TuiConfig,
180    #[cfg(feature = "flightsql")]
181    #[serde(default)]
182    pub flightsql_client: FlightSQLClientConfig,
183    #[cfg(feature = "flightsql")]
184    #[serde(default)]
185    pub flightsql_server: FlightSQLServerConfig,
186    #[cfg(feature = "http")]
187    #[serde(default)]
188    pub http_server: HttpServerConfig,
189    #[serde(default = "default_db_config")]
190    pub db: DbConfig,
191}
192
193fn default_execution_config() -> ExecutionConfig {
194    ExecutionConfig::default()
195}
196
197#[cfg(feature = "tui")]
198fn default_display_config() -> DisplayConfig {
199    DisplayConfig::default()
200}
201
202#[cfg(feature = "tui")]
203fn default_interaction_config() -> InteractionConfig {
204    InteractionConfig::default()
205}
206
207#[derive(Debug, Clone, Deserialize)]
208pub struct DbConfig {
209    #[serde(default = "default_db_path")]
210    pub path: Url,
211}
212
213impl Default for DbConfig {
214    fn default() -> Self {
215        default_db_config()
216    }
217}
218
219fn default_db_config() -> DbConfig {
220    DbConfig {
221        path: default_db_path(),
222    }
223}
224
225#[allow(unused)]
226fn default_db_path() -> Url {
227    let base = directories::BaseDirs::new().expect("Base directories should be available");
228    let path = base
229        .data_dir()
230        .to_path_buf()
231        .join("dft/")
232        .to_str()
233        .unwrap()
234        .to_string();
235    let with_schema = format!("file://{path}");
236    Url::parse(&with_schema).unwrap()
237}
238
/// Display settings for the TUI (the `tui.display` table).
#[cfg(feature = "tui")]
#[derive(Clone, Debug, Deserialize)]
pub struct DisplayConfig {
    /// Frame rate for the TUI display (presumably frames per second — confirm at
    /// use site); defaults to `default_frame_rate()` when omitted.
    #[serde(default = "default_frame_rate")]
    pub frame_rate: f64,
}

/// Default frame rate used when `frame_rate` is not configured.
#[cfg(feature = "tui")]
fn default_frame_rate() -> f64 {
    30.0
}

#[cfg(feature = "tui")]
impl Default for DisplayConfig {
    // Delegates to `default_frame_rate` rather than repeating the literal, so
    // `DisplayConfig::default()` and a deserialized table with `frame_rate`
    // omitted can never disagree.
    fn default() -> Self {
        Self {
            frame_rate: default_frame_rate(),
        }
    }
}
257
/// Serde default for `FlightSQLClientConfig::benchmark_iterations`.
#[cfg(feature = "flightsql")]
fn default_benchmark_iterations() -> usize {
    const DEFAULT_ITERATIONS: usize = 10;
    DEFAULT_ITERATIONS
}
262
/// Input-handling settings for the TUI (the `tui.interaction` table).
#[cfg(feature = "tui")]
#[derive(Debug, Clone, Default, Deserialize)]
pub struct InteractionConfig {
    /// Mouse toggle; `false` unless set.
    #[serde(default = "default_mouse")]
    pub mouse: bool,
    /// Paste toggle; `false` unless set.
    #[serde(default = "default_paste")]
    pub paste: bool,
}

/// Serde default for `InteractionConfig::mouse` (off).
#[cfg(feature = "tui")]
fn default_mouse() -> bool {
    false
}

/// Serde default for `InteractionConfig::paste` (off).
#[cfg(feature = "tui")]
fn default_paste() -> bool {
    false
}
281
/// Default server URL, shared by the FlightSQL and HTTP configs.
#[cfg(any(feature = "flightsql", feature = "http"))]
pub fn default_connection_url() -> String {
    String::from("http://localhost:50051")
}

/// Default (empty) `headers` map for `FlightSQLClientConfig`.
#[cfg(feature = "flightsql")]
pub fn default_headers() -> HashMap<String, String> {
    HashMap::default()
}

/// Default server metrics address: IPv4 loopback (`127.0.0.1`) on port 9000.
#[cfg(any(feature = "flightsql", feature = "http"))]
fn default_server_metrics_addr() -> SocketAddr {
    let loopback = IpAddr::V4(Ipv4Addr::LOCALHOST);
    SocketAddr::new(loopback, 9000)
}
296
/// Editor settings for the TUI (the `tui.editor` table).
#[cfg(feature = "tui")]
#[derive(Clone, Debug, Default, Deserialize)]
pub struct EditorConfig {
    /// Enables experimental syntax highlighting; `false` when omitted.
    // Without a serde default, a `[tui.editor]` table that omits this key fails
    // to deserialize, and `create_config` then discards the entire config file
    // in favour of `AppConfig::default()`.
    #[serde(default)]
    pub experimental_syntax_highlighting: bool,
}

/// Serde default for `TuiConfig::editor`: the derived `EditorConfig::default()`.
#[cfg(feature = "tui")]
fn default_editor_config() -> EditorConfig {
    EditorConfig::default()
}
307
/// Serde default for every `auth` field: the `Default` value of `AuthConfig`.
#[cfg(any(feature = "flightsql", feature = "http"))]
fn default_auth_config() -> AuthConfig {
    Default::default()
}

/// Serde default for `HttpServerConfig::timeout_seconds`.
#[cfg(feature = "http")]
fn default_timeout_seconds() -> u64 {
    const DEFAULT_TIMEOUT_SECONDS: u64 = 10;
    DEFAULT_TIMEOUT_SECONDS
}

/// Serde default for `HttpServerConfig::result_limit`.
#[cfg(feature = "http")]
fn default_result_limit() -> usize {
    const DEFAULT_RESULT_LIMIT: usize = 1000;
    DEFAULT_RESULT_LIMIT
}
322
323pub fn create_config(config_path: PathBuf) -> AppConfig {
324    if config_path.exists() {
325        debug!("Config exists");
326        let maybe_config_contents = std::fs::read_to_string(config_path);
327        if let Ok(config_contents) = maybe_config_contents {
328            let maybe_parsed_config: std::result::Result<AppConfig, toml::de::Error> =
329                toml::from_str(&config_contents);
330            match maybe_parsed_config {
331                Ok(parsed_config) => {
332                    debug!("Parsed config: {:?}", parsed_config);
333                    parsed_config
334                }
335                Err(err) => {
336                    error!("Error parsing config: {:?}", err);
337                    AppConfig::default()
338                }
339            }
340        } else {
341            AppConfig::default()
342        }
343    } else {
344        debug!("No config, using default");
345        AppConfig::default()
346    }
347}