1use std::path::PathBuf;
21
22#[cfg(any(feature = "flightsql", feature = "http"))]
23use std::net::{IpAddr, Ipv4Addr, SocketAddr};
24
25use datafusion_app::config::ExecutionConfig;
26use directories::{ProjectDirs, UserDirs};
27use lazy_static::lazy_static;
28use log::{debug, error};
29use serde::Deserialize;
30
31#[cfg(any(feature = "flightsql", feature = "http"))]
32use datafusion_app::config::AuthConfig;
33#[cfg(feature = "flightsql")]
34use std::collections::HashMap;
35use url::Url;
36
37lazy_static! {
38 pub static ref PROJECT_NAME: String = env!("CARGO_CRATE_NAME").to_uppercase().to_string();
39 pub static ref DATA_FOLDER: Option<PathBuf> =
40 std::env::var(format!("{}_DATA", PROJECT_NAME.clone()))
41 .ok()
42 .map(PathBuf::from);
43 pub static ref LOG_ENV: String = format!("{}_LOGLEVEL", PROJECT_NAME.clone());
44 pub static ref LOG_FILE: String = format!("{}.log", env!("CARGO_PKG_NAME"));
45}
46
47fn project_directory() -> PathBuf {
48 if let Some(user_dirs) = UserDirs::new() {
49 return user_dirs.home_dir().join(".config").join("dft");
50 };
51
52 let maybe_project_dirs = ProjectDirs::from("", "", env!("CARGO_PKG_NAME"));
53 if let Some(project_dirs) = maybe_project_dirs {
54 project_dirs.data_local_dir().to_path_buf()
55 } else {
56 panic!("No known data directory")
57 }
58}
59
60pub fn get_data_dir() -> PathBuf {
61 if let Some(data_dir) = DATA_FOLDER.clone() {
62 data_dir
63 } else {
64 project_directory()
65 }
66}
67
68#[derive(Clone, Debug, Default, Deserialize)]
69pub struct CliConfig {
70 #[serde(default = "default_execution_config")]
71 pub execution: ExecutionConfig,
72}
73
/// Settings for the `tui` section of the application config (only with the `tui` feature).
///
/// Every field falls back to its type's default when missing from the file.
#[cfg(feature = "tui")]
#[derive(Debug, Clone, Default, Deserialize)]
pub struct TuiConfig {
    /// Execution settings.
    #[serde(default = "default_execution_config")]
    pub execution: ExecutionConfig,
    /// Display settings (frame rate).
    #[serde(default = "default_display_config")]
    pub display: DisplayConfig,
    /// Interaction settings (mouse / paste flags).
    #[serde(default = "default_interaction_config")]
    pub interaction: InteractionConfig,
    /// Editor settings.
    #[serde(default = "default_editor_config")]
    pub editor: EditorConfig,
}
86
/// Settings for the `flightsql_server` section (only with the `flightsql` feature).
#[cfg(feature = "flightsql")]
#[derive(Debug, Clone, Deserialize)]
pub struct FlightSQLServerConfig {
    /// Execution settings; defaults to `ExecutionConfig::default()`.
    #[serde(default = "default_execution_config")]
    pub execution: ExecutionConfig,
    /// Connection URL; defaults to `http://localhost:50051`.
    #[serde(default = "default_connection_url")]
    pub connection_url: String,
    /// Metrics socket address; defaults to `127.0.0.1:9000`.
    #[serde(default = "default_server_metrics_addr")]
    pub server_metrics_addr: SocketAddr,
    /// Authentication settings; defaults to `AuthConfig::default()`.
    #[serde(default = "default_auth_config")]
    pub auth: AuthConfig,
}
99
#[cfg(feature = "flightsql")]
impl Default for FlightSQLServerConfig {
    /// Builds the config from the same helper functions serde uses for missing fields.
    fn default() -> Self {
        let execution = default_execution_config();
        let connection_url = default_connection_url();
        let server_metrics_addr = default_server_metrics_addr();
        let auth = default_auth_config();
        Self {
            execution,
            connection_url,
            server_metrics_addr,
            auth,
        }
    }
}
111
/// Settings for the `flightsql_client` section (only with the `flightsql` feature).
#[cfg(feature = "flightsql")]
#[derive(Debug, Clone, Deserialize)]
pub struct FlightSQLClientConfig {
    /// Connection URL; defaults to `http://localhost:50051`.
    #[serde(default = "default_connection_url")]
    pub connection_url: String,
    /// Number of benchmark iterations; defaults to 10.
    #[serde(default = "default_benchmark_iterations")]
    pub benchmark_iterations: usize,
    /// Authentication settings; defaults to `AuthConfig::default()`.
    #[serde(default = "default_auth_config")]
    pub auth: AuthConfig,
    /// Headers map; defaults to an empty map.
    #[serde(default = "default_headers")]
    pub headers: HashMap<String, String>,
    /// Optional path to a headers file; `None` when absent.
    #[serde(default)]
    pub headers_file: Option<PathBuf>,
}
126
#[cfg(feature = "flightsql")]
impl Default for FlightSQLClientConfig {
    /// Builds the config from the same helper functions serde uses for missing fields;
    /// `headers_file` starts out unset.
    fn default() -> Self {
        let connection_url = default_connection_url();
        let benchmark_iterations = default_benchmark_iterations();
        let auth = default_auth_config();
        let headers = default_headers();
        Self {
            connection_url,
            benchmark_iterations,
            auth,
            headers,
            headers_file: None,
        }
    }
}
139
/// Settings for the `http_server` section (only with the `http` feature).
#[cfg(feature = "http")]
#[derive(Debug, Clone, Deserialize)]
pub struct HttpServerConfig {
    /// Execution settings; defaults to `ExecutionConfig::default()`.
    #[serde(default = "default_execution_config")]
    pub execution: ExecutionConfig,
    /// Connection URL; defaults to `http://localhost:50051`.
    #[serde(default = "default_connection_url")]
    pub connection_url: String,
    /// Metrics socket address; defaults to `127.0.0.1:9000`.
    #[serde(default = "default_server_metrics_addr")]
    pub server_metrics_addr: SocketAddr,
    /// Authentication settings; defaults to `AuthConfig::default()`.
    #[serde(default = "default_auth_config")]
    pub auth: AuthConfig,
    /// Timeout in seconds; defaults to 10.
    #[serde(default = "default_timeout_seconds")]
    pub timeout_seconds: u64,
    /// Result limit; defaults to 1000.
    #[serde(default = "default_result_limit")]
    pub result_limit: usize,
}
156
#[cfg(feature = "http")]
impl Default for HttpServerConfig {
    /// Builds the config from the same helper functions serde uses for missing fields.
    fn default() -> Self {
        let execution = default_execution_config();
        let connection_url = default_connection_url();
        let server_metrics_addr = default_server_metrics_addr();
        let auth = default_auth_config();
        let timeout_seconds = default_timeout_seconds();
        let result_limit = default_result_limit();
        Self {
            execution,
            connection_url,
            server_metrics_addr,
            auth,
            timeout_seconds,
            result_limit,
        }
    }
}
170
171#[derive(Clone, Debug, Default, Deserialize)]
172pub struct AppConfig {
173 #[serde(default)]
174 pub shared: ExecutionConfig,
175 #[serde(default)]
176 pub cli: CliConfig,
177 #[cfg(feature = "tui")]
178 #[serde(default)]
179 pub tui: TuiConfig,
180 #[cfg(feature = "flightsql")]
181 #[serde(default)]
182 pub flightsql_client: FlightSQLClientConfig,
183 #[cfg(feature = "flightsql")]
184 #[serde(default)]
185 pub flightsql_server: FlightSQLServerConfig,
186 #[cfg(feature = "http")]
187 #[serde(default)]
188 pub http_server: HttpServerConfig,
189 #[serde(default = "default_db_config")]
190 pub db: DbConfig,
191}
192
193fn default_execution_config() -> ExecutionConfig {
194 ExecutionConfig::default()
195}
196
/// Serde fallback for `TuiConfig::display`: the `Default` value of `DisplayConfig`.
#[cfg(feature = "tui")]
fn default_display_config() -> DisplayConfig {
    Default::default()
}
201
/// Serde fallback for `TuiConfig::interaction`: the `Default` value of `InteractionConfig`.
#[cfg(feature = "tui")]
fn default_interaction_config() -> InteractionConfig {
    Default::default()
}
206
207#[derive(Debug, Clone, Deserialize)]
208pub struct DbConfig {
209 #[serde(default = "default_db_path")]
210 pub path: Url,
211}
212
213impl Default for DbConfig {
214 fn default() -> Self {
215 default_db_config()
216 }
217}
218
219fn default_db_config() -> DbConfig {
220 DbConfig {
221 path: default_db_path(),
222 }
223}
224
225#[allow(unused)]
226fn default_db_path() -> Url {
227 let base = directories::BaseDirs::new().expect("Base directories should be available");
228 let path = base
229 .data_dir()
230 .to_path_buf()
231 .join("dft/")
232 .to_str()
233 .unwrap()
234 .to_string();
235 let with_schema = format!("file://{path}");
236 Url::parse(&with_schema).unwrap()
237}
238
/// Display settings for the TUI (only with the `tui` feature).
#[cfg(feature = "tui")]
#[derive(Debug, Clone, Deserialize)]
pub struct DisplayConfig {
    /// Frame rate; defaults to 30.0 when the key is absent.
    #[serde(default = "default_frame_rate")]
    pub frame_rate: f64,
}
245
246#[cfg(feature = "tui")]
/// Serde fallback for `DisplayConfig::frame_rate`.
fn default_frame_rate() -> f64 {
    const DEFAULT_FRAME_RATE: f64 = 30.0;
    DEFAULT_FRAME_RATE
}
250
#[cfg(feature = "tui")]
impl Default for DisplayConfig {
    /// Uses the same frame rate serde falls back to when the key is missing.
    fn default() -> Self {
        let frame_rate = default_frame_rate();
        Self { frame_rate }
    }
}
257
258#[cfg(feature = "flightsql")]
/// Serde fallback for `FlightSQLClientConfig::benchmark_iterations`.
fn default_benchmark_iterations() -> usize {
    const DEFAULT_ITERATIONS: usize = 10;
    DEFAULT_ITERATIONS
}
262
/// Interaction settings for the TUI (only with the `tui` feature).
#[cfg(feature = "tui")]
#[derive(Debug, Clone, Default, Deserialize)]
pub struct InteractionConfig {
    /// Mouse flag; `false` when the key is absent.
    #[serde(default = "default_mouse")]
    pub mouse: bool,
    /// Paste flag; `false` when the key is absent.
    #[serde(default = "default_paste")]
    pub paste: bool,
}
271
272#[cfg(feature = "tui")]
/// Serde fallback for `InteractionConfig::mouse`.
fn default_mouse() -> bool {
    const MOUSE_DEFAULT: bool = false;
    MOUSE_DEFAULT
}
276
277#[cfg(feature = "tui")]
/// Serde fallback for `InteractionConfig::paste`.
fn default_paste() -> bool {
    const PASTE_DEFAULT: bool = false;
    PASTE_DEFAULT
}
281
282#[cfg(any(feature = "flightsql", feature = "http"))]
/// Default connection URL used by the FlightSQL and HTTP config sections.
pub fn default_connection_url() -> String {
    String::from("http://localhost:50051")
}
286
287#[cfg(feature = "flightsql")]
/// Serde fallback for `FlightSQLClientConfig::headers`: an empty map.
pub fn default_headers() -> HashMap<String, String> {
    HashMap::default()
}
291
292#[cfg(any(feature = "flightsql", feature = "http"))]
/// Serde fallback for the servers' `server_metrics_addr`: IPv4 loopback, port 9000.
fn default_server_metrics_addr() -> SocketAddr {
    let loopback = IpAddr::from(Ipv4Addr::LOCALHOST);
    SocketAddr::new(loopback, 9000)
}
296
/// Editor settings for the TUI (only with the `tui` feature).
#[cfg(feature = "tui")]
#[derive(Clone, Debug, Default, Deserialize)]
pub struct EditorConfig {
    /// Experimental syntax-highlighting flag.
    ///
    /// Falls back to `false` when the key is absent. Without the serde default, an `editor`
    /// table that omits this key would fail to deserialize.
    #[serde(default)]
    pub experimental_syntax_highlighting: bool,
}
302
/// Serde fallback for `TuiConfig::editor`: the `Default` value of `EditorConfig`.
#[cfg(feature = "tui")]
fn default_editor_config() -> EditorConfig {
    Default::default()
}
307
/// Serde fallback for `auth` fields: the `Default` value of `AuthConfig`.
#[cfg(any(feature = "flightsql", feature = "http"))]
fn default_auth_config() -> AuthConfig {
    Default::default()
}
312
313#[cfg(feature = "http")]
/// Serde fallback for `HttpServerConfig::timeout_seconds`.
fn default_timeout_seconds() -> u64 {
    const DEFAULT_TIMEOUT_SECONDS: u64 = 10;
    DEFAULT_TIMEOUT_SECONDS
}
317
318#[cfg(feature = "http")]
/// Serde fallback for `HttpServerConfig::result_limit`.
fn default_result_limit() -> usize {
    const DEFAULT_RESULT_LIMIT: usize = 1000;
    DEFAULT_RESULT_LIMIT
}
322
323pub fn create_config(config_path: PathBuf) -> AppConfig {
324 if config_path.exists() {
325 debug!("Config exists");
326 let maybe_config_contents = std::fs::read_to_string(config_path);
327 if let Ok(config_contents) = maybe_config_contents {
328 let maybe_parsed_config: std::result::Result<AppConfig, toml::de::Error> =
329 toml::from_str(&config_contents);
330 match maybe_parsed_config {
331 Ok(parsed_config) => {
332 debug!("Parsed config: {:?}", parsed_config);
333 parsed_config
334 }
335 Err(err) => {
336 error!("Error parsing config: {:?}", err);
337 AppConfig::default()
338 }
339 }
340 } else {
341 AppConfig::default()
342 }
343 } else {
344 debug!("No config, using default");
345 AppConfig::default()
346 }
347}