1use crate::config::get_data_dir;
21use clap::{Parser, Subcommand};
22use http::{HeaderName, HeaderValue};
23#[cfg(any(feature = "http", feature = "flightsql"))]
24use std::net::SocketAddr;
25use std::path::{Path, PathBuf};
26
/// Long help text for `dft`, wired into clap via `long_about` on [`DftArgs`]
/// (displayed by `dft --help`).
const LONG_ABOUT: &str = "
dft - DataFusion TUI

CLI and terminal UI data analysis tool using Apache DataFusion as its query
execution engine.

dft provides a rich terminal UI as well as a broad array of pre-integrated
data sources and formats for querying and analyzing data.

Environment Variables
RUST_LOG { trace | debug | info | warn | error }: Standard Rust logging level. Default is info.
";
39
// Top-level command-line arguments for `dft`, parsed with clap's derive API.
//
// Plain `//` comments are used throughout on purpose: clap turns `///` doc
// comments into help/about text, and every flag here already sets `help`
// explicitly (the long description comes from `LONG_ABOUT`).
#[derive(Clone, Debug, Parser, Default)]
#[command(author, version, about, long_about = LONG_ABOUT)]
pub struct DftArgs {
    // `-f/--files`: SQL files to execute. Each value must be an existing
    // regular file (validated by `parse_valid_file`). `num_args = 0..` lets the
    // flag appear with zero or more values.
    #[clap(
        short,
        long,
        num_args = 0..,
        help = "Execute commands from file(s), then exit",
        value_parser(parse_valid_file)
    )]
    pub files: Vec<PathBuf>,

    // `-c/--commands`: SQL strings to execute; empty strings are rejected by
    // `parse_command`.
    #[clap(
        short = 'c',
        long,
        num_args = 0..,
        help = "Execute the given SQL string(s), then exit.",
        value_parser(parse_command)
    )]
    pub commands: Vec<String>,

    // `--config`: declared `global`, so it is also accepted after a subcommand.
    // `config_path` resolves the effective path (falling back to the data dir).
    #[clap(long, global = true, help = "Path to the configuration file")]
    pub config: Option<String>,

    // `-q/--flightsql`: run against the FlightSQL client from the config.
    #[clap(
        long,
        short = 'q',
        help = "Use the FlightSQL client defined in your config"
    )]
    pub flightsql: bool,

    #[clap(long, help = "Run DDL prior to executing")]
    pub run_ddl: bool,

    #[clap(long, short, help = "Only show how long the query took to run")]
    pub time: bool,

    // Benchmark-related options: `bench`, `run_before`, `save`, `append`,
    // `benchmark_iterations` and `concurrent`.
    #[clap(long, short, help = "Benchmark the provided query")]
    pub bench: bool,

    #[clap(
        long,
        help = "Print a summary of the query's execution plan and statistics"
    )]
    pub analyze: bool,

    #[clap(long, help = "Run the provided query before running the benchmark")]
    pub run_before: Option<String>,

    #[clap(long, help = "Save the benchmark results to a file")]
    pub save: Option<PathBuf>,

    #[clap(long, help = "Append the benchmark results to an existing file")]
    pub append: bool,

    // Only a short form (`-n`) is declared; there is no `--long` spelling.
    #[clap(short = 'n', help = "Set the number of benchmark iterations to run")]
    pub benchmark_iterations: Option<usize>,

    #[clap(long, help = "Run benchmark iterations concurrently/in parallel")]
    pub concurrent: bool,

    // FlightSQL connection options: `host`, `header`, `headers_file`.
    #[clap(long, help = "Host address to query. Only used for FlightSQL")]
    pub host: Option<String>,

    // `--header` may be repeated (`ArgAction::Append`); each occurrence is
    // parsed into a `(name, value)` pair by `parse_header_line`.
    #[clap(
        long,
        help = "Header to add to Flight SQL connection. Only used for FlightSQL",
        value_parser(parse_header_line),
        action = clap::ArgAction::Append
    )]
    pub header: Option<Vec<(String, String)>>,

    // NOTE(review): presumably loaded by the caller via `parse_headers_file`;
    // this struct only stores the path.
    #[clap(
        long,
        help = "Path to file containing Flight SQL headers. Supports simple format ('Name: Value') and curl config format ('header = Name: Value' or '-H \"Name: Value\"'). Only used for FlightSQL"
    )]
    pub headers_file: Option<PathBuf>,

    #[clap(
        long,
        short,
        help = "Path to save output to. Type is inferred from file suffix"
    )]
    pub output: Option<PathBuf>,

    // Optional subcommand (servers, FlightSQL client requests, TPC-H generation).
    #[command(subcommand)]
    pub command: Option<Command>,
}
128
129impl DftArgs {
130 pub fn config_path(&self) -> PathBuf {
131 #[cfg(feature = "flightsql")]
132 if let Some(Command::ServeFlightSql {
133 config: Some(cfg), ..
134 }) = &self.command
135 {
136 return Path::new(cfg).to_path_buf();
137 }
138 if let Some(config) = self.config.as_ref() {
139 Path::new(config).to_path_buf()
140 } else {
141 let mut config = get_data_dir();
142 config.push("config.toml");
143 config
144 }
145 }
146}
147
148#[derive(Clone, Debug, Subcommand)]
150pub enum FlightSqlCommand {
151 StatementQuery {
153 #[clap(long)]
155 sql: String,
156 },
157 GetCatalogs,
159 GetDbSchemas {
161 #[clap(long)]
163 catalog: Option<String>,
164 #[clap(long)]
166 db_schema_filter_pattern: Option<String>,
167 },
168 GetTables {
170 #[clap(long)]
172 catalog: Option<String>,
173 #[clap(long)]
175 db_schema_filter_pattern: Option<String>,
176 #[clap(long)]
178 table_name_filter_pattern: Option<String>,
179 #[clap(long)]
181 table_types: Option<Vec<String>>,
182 },
183 GetTableTypes,
185 GetSqlInfo {
187 #[clap(long)]
189 info: Option<Vec<u32>>,
190 },
191 GetXdbcTypeInfo {
193 #[clap(long)]
195 data_type: Option<i32>,
196 },
197}
198
199#[derive(Clone, Debug, Subcommand)]
200pub enum Command {
201 #[cfg(feature = "http")]
203 ServeHttp {
204 #[clap(short, long)]
205 config: Option<String>,
206 #[clap(long, help = "Set the port to be used for server")]
207 addr: Option<SocketAddr>,
208 #[clap(long, help = "Set the port to be used for serving metrics")]
209 metrics_addr: Option<SocketAddr>,
210 },
211 #[cfg(feature = "flightsql")]
213 #[command(name = "flightsql")]
214 FlightSql {
215 #[clap(subcommand)]
216 command: FlightSqlCommand,
217 },
218 #[cfg(feature = "flightsql")]
220 #[command(name = "serve-flightsql")]
221 ServeFlightSql {
222 #[clap(short, long)]
223 config: Option<String>,
224 #[clap(long, help = "Set the port to be used for server")]
225 addr: Option<SocketAddr>,
226 #[clap(long, help = "Set the port to be used for serving metrics")]
227 metrics_addr: Option<SocketAddr>,
228 },
229 GenerateTpch {
230 #[clap(long, default_value = "1.0")]
231 scale_factor: f64,
232 #[clap(long, default_value = "parquet")]
233 format: TpchFormat,
234 },
235}
236
// Output formats accepted by the `--format` option of `GenerateTpch`.
// Parsed by clap's `ValueEnum` derive, so values are spelled in lowercase on
// the command line (e.g. the default `parquet`). Plain `//` comments are used
// because `///` on ValueEnum variants would change the possible-value help.
#[derive(Clone, Debug, clap::ValueEnum)]
pub enum TpchFormat {
    Parquet,
    // Only available when the crate is built with the `vortex` feature.
    #[cfg(feature = "vortex")]
    Vortex,
}
243
/// Clap value parser for `-f/--files`: accepts only existing regular files.
///
/// Symlinks are followed, so a link pointing at a regular file is accepted.
///
/// # Errors
/// Returns a message quoting `file` when the path does not exist (or its
/// metadata cannot be read), or when it exists but is not a regular file.
fn parse_valid_file(file: &str) -> std::result::Result<PathBuf, String> {
    let candidate = PathBuf::from(file);
    match candidate.metadata() {
        Err(_) => Err(format!("File does not exist: '{file}'")),
        Ok(meta) if !meta.is_file() => Err(format!("Exists but is not a file: '{file}'")),
        Ok(_) => Ok(candidate),
    }
}
254
/// Clap value parser for `-c/--commands`: rejects empty SQL strings.
///
/// Any non-empty input (including whitespace-only input) is returned unchanged.
///
/// # Errors
/// Returns an error message when `command` is the empty string.
fn parse_command(command: &str) -> std::result::Result<String, String> {
    if command.is_empty() {
        return Err("-c flag expects only non empty commands".to_string());
    }
    Ok(command.to_owned())
}
262
263fn parse_header_line(line: &str) -> Result<(String, String), String> {
264 let (name, value) = line
265 .split_once(':')
266 .ok_or_else(|| format!("Invalid header format: '{}'\n Expected format: 'Header-Name: Header-Value', 'header = Name: Value', or '-H \"Name: Value\"'", line))?;
267
268 let name =
269 HeaderName::try_from(name.trim()).map_err(|e| format!("Invalid header name: {}", e))?;
270 let value =
271 HeaderValue::try_from(value.trim()).map_err(|e| format!("Invalid header value: {}", e))?;
272
273 let value_str = value
274 .to_str()
275 .map_err(|e| format!("Header value contains invalid characters: {}", e))?;
276
277 Ok((name.to_string(), value_str.to_string()))
278}
279
280pub fn parse_headers_file(path: &Path) -> Result<Vec<(String, String)>, String> {
290 let content = std::fs::read_to_string(path)
291 .map_err(|e| format!("Failed to read headers file '{}': {}", path.display(), e))?;
292
293 let mut headers = Vec::new();
294 for (line_num, line) in content.lines().enumerate() {
295 let line = line.trim();
296
297 if line.is_empty() || line.starts_with('#') {
299 continue;
300 }
301
302 let header_value = if let Some(stripped) = line.strip_prefix("header") {
304 let stripped = stripped.trim_start();
306 if let Some(value) = stripped.strip_prefix('=') {
307 value.trim()
308 } else {
309 line }
311 } else if let Some(stripped) = line.strip_prefix("-H") {
312 let stripped = stripped.trim();
314 stripped.trim_matches(|c| c == '"' || c == '\'')
316 } else {
317 line
319 };
320
321 match parse_header_line(header_value) {
323 Ok(header) => headers.push(header),
324 Err(e) => {
325 return Err(format!(
326 "Invalid header format at line {} in '{}': '{}'\n{}",
327 line_num + 1,
328 path.display(),
329 line,
330 e
331 ));
332 }
333 }
334 }
335
336 Ok(headers)
337}
338
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Writes each entry of `lines`, followed by a newline, to a new temporary
    /// file. Keep the returned handle alive while the file is read:
    /// `NamedTempFile` deletes the file when dropped.
    fn write_temp_file(lines: &[&str]) -> NamedTempFile {
        let mut file = NamedTempFile::new().unwrap();
        for line in lines {
            writeln!(file, "{line}").unwrap();
        }
        file.flush().unwrap();
        file
    }

    /// Builds the owned `(name, value)` pair the header parsers return.
    fn header(name: &str, value: &str) -> (String, String) {
        (name.to_string(), value.to_string())
    }

    #[test]
    fn test_parse_headers_file_simple_format() {
        let file = write_temp_file(&["x-api-key: secret123", "database: production"]);

        let headers = parse_headers_file(file.path()).unwrap();
        assert_eq!(
            headers,
            vec![
                header("x-api-key", "secret123"),
                header("database", "production"),
            ]
        );
    }

    #[test]
    fn test_parse_headers_file_curl_format() {
        let file = write_temp_file(&[
            "header = x-api-key: secret123",
            "-H \"database: production\"",
        ]);

        let headers = parse_headers_file(file.path()).unwrap();
        assert_eq!(
            headers,
            vec![
                header("x-api-key", "secret123"),
                header("database", "production"),
            ]
        );
    }

    #[test]
    fn test_parse_headers_file_mixed_format() {
        let file = write_temp_file(&[
            "# Simple format",
            "x-test: value1",
            "",
            "# Curl config format",
            "header = x-api-key: secret123",
            "-H \"database: production\"",
        ]);

        let headers = parse_headers_file(file.path()).unwrap();
        assert_eq!(
            headers,
            vec![
                header("x-test", "value1"),
                header("x-api-key", "secret123"),
                header("database", "production"),
            ]
        );
    }

    #[test]
    fn test_parse_headers_file_with_comments() {
        let file = write_temp_file(&[
            "# This is a comment",
            "x-api-key: secret123",
            "# Another comment",
            "database: production",
        ]);

        let headers = parse_headers_file(file.path()).unwrap();
        assert_eq!(
            headers,
            vec![
                header("x-api-key", "secret123"),
                header("database", "production"),
            ]
        );
    }

    #[test]
    fn test_parse_headers_file_blank_lines() {
        let file = write_temp_file(&["x-api-key: secret123", "", " ", "database: production"]);

        let headers = parse_headers_file(file.path()).unwrap();
        assert_eq!(
            headers,
            vec![
                header("x-api-key", "secret123"),
                header("database", "production"),
            ]
        );
    }

    #[test]
    fn test_parse_headers_file_curl_with_quotes() {
        let file = write_temp_file(&[
            "-H \"x-api-key: secret123\"",
            "-H 'database: production'",
        ]);

        let headers = parse_headers_file(file.path()).unwrap();
        assert_eq!(
            headers,
            vec![
                header("x-api-key", "secret123"),
                header("database", "production"),
            ]
        );
    }

    #[test]
    fn test_parse_headers_file_name_starting_with_header() {
        // A simple-format name that merely starts with "header" must not be
        // mistaken for the curl `header = ...` form.
        let file = write_temp_file(&["header-x: value"]);

        let headers = parse_headers_file(file.path()).unwrap();
        assert_eq!(headers, vec![header("header-x", "value")]);
    }

    #[test]
    fn test_parse_headers_file_invalid_format() {
        let file = write_temp_file(&["x-api-key: secret123", "invalid-line-without-colon"]);

        let err = parse_headers_file(file.path()).unwrap_err();
        assert!(
            err.contains("Invalid header format at line 2"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn test_parse_headers_file_not_found() {
        let err = parse_headers_file(Path::new("/nonexistent/file.txt")).unwrap_err();
        assert!(
            err.contains("Failed to read headers file"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn test_parse_header_line_value_with_colon() {
        // Only the first ':' separates name from value.
        assert_eq!(
            parse_header_line("x-url: http://example.com:8080"),
            Ok(header("x-url", "http://example.com:8080"))
        );
    }

    #[test]
    fn test_parse_header_line_missing_colon() {
        let err = parse_header_line("no-colon-here").unwrap_err();
        assert!(
            err.starts_with("Invalid header format: 'no-colon-here'"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn test_parse_command() {
        assert_eq!(parse_command("select 1"), Ok("select 1".to_string()));
        assert!(parse_command("").is_err());
    }

    #[test]
    fn test_parse_valid_file() {
        let file = write_temp_file(&["select 1;"]);
        let path = file.path().to_str().unwrap();
        assert_eq!(parse_valid_file(path), Ok(file.path().to_path_buf()));

        let dir = std::env::temp_dir();
        let err = parse_valid_file(dir.to_str().unwrap()).unwrap_err();
        assert!(err.starts_with("Exists but is not a file"), "unexpected error: {err}");

        let err = parse_valid_file("/nonexistent/file.txt").unwrap_err();
        assert!(err.starts_with("File does not exist"), "unexpected error: {err}");
    }
}