Skip to main content

datafusion_dft/
args.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Command line argument parsing: [`DftArgs`]
19
20use crate::config::get_data_dir;
21use clap::{Parser, Subcommand};
22use http::{HeaderName, HeaderValue};
23#[cfg(any(feature = "http", feature = "flightsql"))]
24use std::net::SocketAddr;
25use std::path::{Path, PathBuf};
26
/// Long-form description printed by `--help`.
///
/// Wired into clap through `long_about = LONG_ABOUT` on [`DftArgs`]. The
/// text is emitted verbatim, including the leading and trailing newline.
const LONG_ABOUT: &str = "
dft - DataFusion TUI

CLI and terminal UI data analysis tool using Apache DataFusion as query
execution engine.

dft provides a rich terminal UI as well as a broad array of pre-integrated
data sources and formats for querying and analyzing data.

Environment Variables
RUST_LOG { trace | debug | info | warn | error }: Standard rust logging level.  Default is info.
";
39
40#[derive(Clone, Debug, Parser, Default)]
41#[command(author, version, about, long_about = LONG_ABOUT)]
42pub struct DftArgs {
43    #[clap(
44        short,
45        long,
46        num_args = 0..,
47        help = "Execute commands from file(s), then exit",
48        value_parser(parse_valid_file)
49    )]
50    pub files: Vec<PathBuf>,
51
52    #[clap(
53        short = 'c',
54        long,
55        num_args = 0..,
56        help = "Execute the given SQL string(s), then exit.",
57        value_parser(parse_command)
58    )]
59    pub commands: Vec<String>,
60
61    #[clap(long, global = true, help = "Path to the configuration file")]
62    pub config: Option<String>,
63
64    #[clap(
65        long,
66        short = 'q',
67        help = "Use the FlightSQL client defined in your config"
68    )]
69    pub flightsql: bool,
70
71    #[clap(long, help = "Run DDL prior to executing")]
72    pub run_ddl: bool,
73
74    #[clap(long, short, help = "Only show how long the query took to run")]
75    pub time: bool,
76
77    #[clap(long, short, help = "Benchmark the provided query")]
78    pub bench: bool,
79
80    #[clap(
81        long,
82        help = "Print a summary of the query's execution plan and statistics"
83    )]
84    pub analyze: bool,
85
86    #[clap(long, help = "Run the provided query before running the benchmark")]
87    pub run_before: Option<String>,
88
89    #[clap(long, help = "Save the benchmark results to a file")]
90    pub save: Option<PathBuf>,
91
92    #[clap(long, help = "Append the benchmark results to an existing file")]
93    pub append: bool,
94
95    #[clap(short = 'n', help = "Set the number of benchmark iterations to run")]
96    pub benchmark_iterations: Option<usize>,
97
98    #[clap(long, help = "Run benchmark iterations concurrently/in parallel")]
99    pub concurrent: bool,
100
101    #[clap(long, help = "Host address to query. Only used for FlightSQL")]
102    pub host: Option<String>,
103
104    #[clap(
105        long,
106        help = "Header to add to Flight SQL connection. Only used for FlightSQL",
107        value_parser(parse_header_line),
108        action = clap::ArgAction::Append
109    )]
110    pub header: Option<Vec<(String, String)>>,
111
112    #[clap(
113        long,
114        help = "Path to file containing Flight SQL headers. Supports simple format ('Name: Value') and curl config format ('header = Name: Value' or '-H \"Name: Value\"'). Only used for FlightSQL"
115    )]
116    pub headers_file: Option<PathBuf>,
117
118    #[clap(
119        long,
120        short,
121        help = "Path to save output to. Type is inferred from file suffix"
122    )]
123    pub output: Option<PathBuf>,
124
125    #[command(subcommand)]
126    pub command: Option<Command>,
127}
128
129impl DftArgs {
130    pub fn config_path(&self) -> PathBuf {
131        #[cfg(feature = "flightsql")]
132        if let Some(Command::ServeFlightSql {
133            config: Some(cfg), ..
134        }) = &self.command
135        {
136            return Path::new(cfg).to_path_buf();
137        }
138        if let Some(config) = self.config.as_ref() {
139            Path::new(config).to_path_buf()
140        } else {
141            let mut config = get_data_dir();
142            config.push("config.toml");
143            config
144        }
145    }
146}
147
148/// Parameters for each command match to exactly how they are defined in specification (https://arrow.apache.org/docs/format/FlightSql.html#protocol-buffer-definitions)
149#[derive(Clone, Debug, Subcommand)]
150pub enum FlightSqlCommand {
151    /// Executes `CommandStatementQuery` and `DoGet` to return results
152    StatementQuery {
153        /// The query to execute
154        #[clap(long)]
155        sql: String,
156    },
157    /// Executes `CommandGetCatalogs` and `DoGet` to return results
158    GetCatalogs,
159    /// Executes `CommandGetDbSchemas` and `DoGet` to return results
160    GetDbSchemas {
161        /// The catalog to retrieve schemas
162        #[clap(long)]
163        catalog: Option<String>,
164        /// Schema filter pattern to apply
165        #[clap(long)]
166        db_schema_filter_pattern: Option<String>,
167    },
168    /// Executes `CommandGetDbSchemas` and `DoGet` to return results
169    GetTables {
170        /// The catalog to retrieve schemas
171        #[clap(long)]
172        catalog: Option<String>,
173        /// Schema filter pattern to apply
174        #[clap(long)]
175        db_schema_filter_pattern: Option<String>,
176        /// Table name filter pattern to apply
177        #[clap(long)]
178        table_name_filter_pattern: Option<String>,
179        /// Specific table types to return
180        #[clap(long)]
181        table_types: Option<Vec<String>>,
182    },
183    /// Executes `CommandGetTableTypes` and `DoGet` to return supported table types
184    GetTableTypes,
185    /// Executes `CommandGetSqlInfo` and `DoGet` to return server SQL capabilities
186    GetSqlInfo {
187        /// Specific SQL info IDs to retrieve (if not provided, returns all)
188        #[clap(long)]
189        info: Option<Vec<u32>>,
190    },
191    /// Executes `CommandGetXdbcTypeInfo` and `DoGet` to return type information
192    GetXdbcTypeInfo {
193        /// Optional data type to filter by
194        #[clap(long)]
195        data_type: Option<i32>,
196    },
197}
198
199#[derive(Clone, Debug, Subcommand)]
200pub enum Command {
201    /// Start a HTTP server
202    #[cfg(feature = "http")]
203    ServeHttp {
204        #[clap(short, long)]
205        config: Option<String>,
206        #[clap(long, help = "Set the port to be used for server")]
207        addr: Option<SocketAddr>,
208        #[clap(long, help = "Set the port to be used for serving metrics")]
209        metrics_addr: Option<SocketAddr>,
210    },
211    /// Make a request to a FlightSQL server
212    #[cfg(feature = "flightsql")]
213    #[command(name = "flightsql")]
214    FlightSql {
215        #[clap(subcommand)]
216        command: FlightSqlCommand,
217    },
218    /// Start a FlightSQL server
219    #[cfg(feature = "flightsql")]
220    #[command(name = "serve-flightsql")]
221    ServeFlightSql {
222        #[clap(short, long)]
223        config: Option<String>,
224        #[clap(long, help = "Set the port to be used for server")]
225        addr: Option<SocketAddr>,
226        #[clap(long, help = "Set the port to be used for serving metrics")]
227        metrics_addr: Option<SocketAddr>,
228    },
229    GenerateTpch {
230        #[clap(long, default_value = "1.0")]
231        scale_factor: f64,
232        #[clap(long, default_value = "parquet")]
233        format: TpchFormat,
234    },
235}
236
237#[derive(Clone, Debug, clap::ValueEnum)]
238pub enum TpchFormat {
239    Parquet,
240    #[cfg(feature = "vortex")]
241    Vortex,
242}
243
/// clap value parser that accepts `file` only if it names an existing
/// regular file.
///
/// # Errors
/// Returns a message saying the path does not exist, or that it exists but
/// is not a file (for example a directory).
fn parse_valid_file(file: &str) -> std::result::Result<PathBuf, String> {
    let candidate = PathBuf::from(file);

    // Happy path first: a regular file is returned as-is.
    if candidate.is_file() {
        return Ok(candidate);
    }

    // Not a regular file: say whether something else lives at that path.
    let message = if candidate.exists() {
        format!("Exists but is not a file: '{file}'")
    } else {
        format!("File does not exist: '{file}'")
    };
    Err(message)
}
254
/// clap value parser for `-c`: passes any non-empty string through as an
/// owned `String` and rejects the empty string.
fn parse_command(command: &str) -> std::result::Result<String, String> {
    if command.is_empty() {
        return Err("-c flag expects only non empty commands".to_string());
    }
    Ok(command.to_string())
}
262
263fn parse_header_line(line: &str) -> Result<(String, String), String> {
264    let (name, value) = line
265        .split_once(':')
266        .ok_or_else(|| format!("Invalid header format: '{}'\n       Expected format: 'Header-Name: Header-Value', 'header = Name: Value', or '-H \"Name: Value\"'", line))?;
267
268    let name =
269        HeaderName::try_from(name.trim()).map_err(|e| format!("Invalid header name: {}", e))?;
270    let value =
271        HeaderValue::try_from(value.trim()).map_err(|e| format!("Invalid header value: {}", e))?;
272
273    let value_str = value
274        .to_str()
275        .map_err(|e| format!("Header value contains invalid characters: {}", e))?;
276
277    Ok((name.to_string(), value_str.to_string()))
278}
279
280/// Parse headers from a file supporting both simple and curl config formats
281///
282/// Supported formats:
283/// - Simple: `Header-Name: Header-Value`
284/// - Curl config: `header = Name: Value` or `-H "Name: Value"`
285/// - Comments: Lines starting with `#`
286/// - Blank lines are ignored
287///
288/// Both formats can be mixed in the same file.
289pub fn parse_headers_file(path: &Path) -> Result<Vec<(String, String)>, String> {
290    let content = std::fs::read_to_string(path)
291        .map_err(|e| format!("Failed to read headers file '{}': {}", path.display(), e))?;
292
293    let mut headers = Vec::new();
294    for (line_num, line) in content.lines().enumerate() {
295        let line = line.trim();
296
297        // Skip comments and blank lines
298        if line.is_empty() || line.starts_with('#') {
299            continue;
300        }
301
302        // Detect and parse format
303        let header_value = if let Some(stripped) = line.strip_prefix("header") {
304            // Curl config format: "header = Name: Value" or "header=Name: Value"
305            let stripped = stripped.trim_start();
306            if let Some(value) = stripped.strip_prefix('=') {
307                value.trim()
308            } else {
309                line // Not curl format, try simple format
310            }
311        } else if let Some(stripped) = line.strip_prefix("-H") {
312            // Curl config format: -H "Name: Value" or -H Name: Value
313            let stripped = stripped.trim();
314            // Remove surrounding quotes if present
315            stripped.trim_matches(|c| c == '"' || c == '\'')
316        } else {
317            // Simple format: "Name: Value"
318            line
319        };
320
321        // Parse header line
322        match parse_header_line(header_value) {
323            Ok(header) => headers.push(header),
324            Err(e) => {
325                return Err(format!(
326                    "Invalid header format at line {} in '{}': '{}'\n{}",
327                    line_num + 1,
328                    path.display(),
329                    line,
330                    e
331                ));
332            }
333        }
334    }
335
336    Ok(headers)
337}
338
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Creates a temporary file holding `lines`, one per line, and flushes it.
    /// The returned handle must stay alive for as long as the path is used.
    fn headers_file(lines: &[&str]) -> NamedTempFile {
        let mut file = NamedTempFile::new().unwrap();
        for line in lines {
            writeln!(file, "{line}").unwrap();
        }
        file.flush().unwrap();
        file
    }

    /// Builds an owned `(name, value)` pair for comparisons.
    fn header(name: &str, value: &str) -> (String, String) {
        (name.to_string(), value.to_string())
    }

    /// The two headers most fixtures in this module are expected to yield.
    fn api_key_and_database() -> Vec<(String, String)> {
        vec![
            header("x-api-key", "secret123"),
            header("database", "production"),
        ]
    }

    #[test]
    fn test_parse_headers_file_simple_format() {
        let file = headers_file(&["x-api-key: secret123", "database: production"]);

        let headers = parse_headers_file(file.path()).unwrap();
        assert_eq!(headers, api_key_and_database());
    }

    #[test]
    fn test_parse_headers_file_curl_format() {
        let file = headers_file(&[
            "header = x-api-key: secret123",
            "-H \"database: production\"",
        ]);

        let headers = parse_headers_file(file.path()).unwrap();
        assert_eq!(headers, api_key_and_database());
    }

    #[test]
    fn test_parse_headers_file_mixed_format() {
        let file = headers_file(&[
            "# Simple format",
            "x-test: value1",
            "",
            "# Curl config format",
            "header = x-api-key: secret123",
            "-H \"database: production\"",
        ]);

        let headers = parse_headers_file(file.path()).unwrap();
        let expected = vec![
            header("x-test", "value1"),
            header("x-api-key", "secret123"),
            header("database", "production"),
        ];
        assert_eq!(headers, expected);
    }

    #[test]
    fn test_parse_headers_file_with_comments() {
        let file = headers_file(&[
            "# This is a comment",
            "x-api-key: secret123",
            "# Another comment",
            "database: production",
        ]);

        let headers = parse_headers_file(file.path()).unwrap();
        assert_eq!(headers, api_key_and_database());
    }

    #[test]
    fn test_parse_headers_file_blank_lines() {
        let file = headers_file(&[
            "x-api-key: secret123",
            "",
            "   ",
            "database: production",
        ]);

        let headers = parse_headers_file(file.path()).unwrap();
        assert_eq!(headers, api_key_and_database());
    }

    #[test]
    fn test_parse_headers_file_curl_with_quotes() {
        let file = headers_file(&[
            "-H \"x-api-key: secret123\"",
            "-H 'database: production'",
        ]);

        let headers = parse_headers_file(file.path()).unwrap();
        assert_eq!(headers, api_key_and_database());
    }

    #[test]
    fn test_parse_headers_file_invalid_format() {
        let file = headers_file(&["x-api-key: secret123", "invalid-line-without-colon"]);

        // The second line has no ':' separator, so parsing must fail there.
        let error = parse_headers_file(file.path()).unwrap_err();
        assert!(error.contains("Invalid header format at line 2"));
    }

    #[test]
    fn test_parse_headers_file_not_found() {
        let error = parse_headers_file(Path::new("/nonexistent/file.txt")).unwrap_err();
        assert!(error.contains("Failed to read headers file"));
    }
}