use crate::config::get_data_dir;
use clap::{Parser, Subcommand};
use http::{HeaderName, HeaderValue};
#[cfg(any(feature = "http", feature = "flightsql"))]
use std::net::SocketAddr;
use std::path::{Path, PathBuf};
/// Long description used as clap's `long_about` for [`DftArgs`].
///
/// The leading newline is intentional: the text starts on its own line in the help output.
const LONG_ABOUT: &str = "
dft - DataFusion TUI
CLI and terminal UI data analysis tool using Apache DataFusion as the query
execution engine.
dft provides a rich terminal UI as well as a broad array of pre-integrated
data sources and formats for querying and analyzing data.
Environment Variables
RUST_LOG { trace | debug | info | warn | error }: Standard Rust logging level. Default is info.
";
// Top-level command-line arguments for `dft`, derived with clap.
// NOTE(review): plain `//` comments are used in this struct on purpose: clap's
// derive turns `///` doc comments into help text, which would alter `--help`.
#[derive(Clone, Debug, Parser, Default)]
#[command(author, version, about, long_about = LONG_ABOUT)]
pub struct DftArgs {
// Files whose commands are executed before exiting. Every value is checked by
// `parse_valid_file`, so it must name an existing regular file.
#[clap(
short,
long,
num_args = 0..,
help = "Execute commands from file(s), then exit",
value_parser(parse_valid_file)
)]
pub files: Vec<PathBuf>,
// SQL strings to execute before exiting. Empty strings are rejected by
// `parse_command`.
#[clap(
short = 'c',
long,
num_args = 0..,
help = "Execute the given SQL string(s), then exit.",
value_parser(parse_command)
)]
pub commands: Vec<String>,
// Top-level config path. It is marked `global`, so clap propagates it to subcommands.
// NOTE(review): `serve-http` / `serve-flightsql` declare their own `config`
// argument with the same name. Confirm which field clap populates in that case.
#[clap(long, global = true, help = "Path to the configuration file")]
pub config: Option<String>,
// Switch query execution to the FlightSQL client configured in the config file.
#[clap(
long,
short = 'q',
help = "Use the FlightSQL client defined in your config"
)]
pub flightsql: bool,
#[clap(long, help = "Run DDL prior to executing")]
pub run_ddl: bool,
#[clap(long, short, help = "Only show how long the query took to run")]
pub time: bool,
// Benchmark mode. `run_before`, `save`, `append`, `benchmark_iterations` and
// `concurrent` below describe themselves as benchmark options in their help text.
#[clap(long, short, help = "Benchmark the provided query")]
pub bench: bool,
#[clap(
long,
help = "Print a summary of the query's execution plan and statistics"
)]
pub analyze: bool,
#[clap(long, help = "Run the provided query before running the benchmark")]
pub run_before: Option<String>,
#[clap(long, help = "Save the benchmark results to a file")]
pub save: Option<PathBuf>,
#[clap(long, help = "Append the benchmark results to an existing file")]
pub append: bool,
// Only a short flag (`-n`) is declared; there is no `--long` form.
#[clap(short = 'n', help = "Set the number of benchmark iterations to run")]
pub benchmark_iterations: Option<usize>,
#[clap(long, help = "Run benchmark iterations concurrently/in parallel")]
pub concurrent: bool,
#[clap(long, help = "Host address to query. Only used for FlightSQL")]
pub host: Option<String>,
// Repeatable (`ArgAction::Append`). Each occurrence is split into a
// `(name, value)` pair by `parse_header_line`.
#[clap(
long,
help = "Header to add to Flight SQL connection. Only used for FlightSQL",
value_parser(parse_header_line),
action = clap::ArgAction::Append
)]
pub header: Option<Vec<(String, String)>>,
// Path to a headers file. The formats named in the help text match those that
// `parse_headers_file` accepts; presumably that function reads this path (confirm at call site).
#[clap(
long,
help = "Path to file containing Flight SQL headers. Supports simple format ('Name: Value') and curl config format ('header = Name: Value' or '-H \"Name: Value\"'). Only used for FlightSQL"
)]
pub headers_file: Option<PathBuf>,
#[clap(
long,
short,
help = "Path to save output to. Type is inferred from file suffix"
)]
pub output: Option<PathBuf>,
// Optional subcommand. `None` presumably selects the default interactive /
// `-c` / `-f` mode (TODO confirm in main).
#[command(subcommand)]
pub command: Option<Command>,
}
impl DftArgs {
    /// Resolves the path of the configuration file to load.
    ///
    /// Precedence, highest first:
    /// 1. `--config` given to the `serve-flightsql` subcommand (feature `flightsql`);
    /// 2. `--config` given to the `serve-http` subcommand (feature `http`);
    /// 3. the top-level `--config` option;
    /// 4. `config.toml` inside the directory returned by `get_data_dir()`.
    pub fn config_path(&self) -> PathBuf {
        #[cfg(feature = "flightsql")]
        if let Some(Command::ServeFlightSql {
            config: Some(cfg), ..
        }) = &self.command
        {
            return Path::new(cfg).to_path_buf();
        }
        // `serve-http` declares its own `--config` exactly like `serve-flightsql`.
        // Without this branch, that value was never consulted when resolving the path.
        #[cfg(feature = "http")]
        if let Some(Command::ServeHttp {
            config: Some(cfg), ..
        }) = &self.command
        {
            return Path::new(cfg).to_path_buf();
        }
        match &self.config {
            Some(config) => Path::new(config).to_path_buf(),
            None => get_data_dir().join("config.toml"),
        }
    }
}
// Subcommands of `dft flightsql`. Each variant becomes a clap subcommand, and its
// fields become `--long` options.
// NOTE(review): variant names mirror Arrow Flight SQL commands (CommandStatementQuery,
// CommandGetCatalogs, ...); presumably each is dispatched to the matching client
// call elsewhere — confirm at the call site.
// `//` rather than `///` is used so clap's generated help text is unchanged.
#[derive(Clone, Debug, Subcommand)]
pub enum FlightSqlCommand {
// Run an arbitrary SQL statement.
StatementQuery {
#[clap(long)]
sql: String,
},
GetCatalogs,
// Optional catalog name and schema filter pattern; both may be omitted.
GetDbSchemas {
#[clap(long)]
catalog: Option<String>,
#[clap(long)]
db_schema_filter_pattern: Option<String>,
},
// All filters are optional; `table_types` accepts several values.
GetTables {
#[clap(long)]
catalog: Option<String>,
#[clap(long)]
db_schema_filter_pattern: Option<String>,
#[clap(long)]
table_name_filter_pattern: Option<String>,
#[clap(long)]
table_types: Option<Vec<String>>,
},
GetTableTypes,
// `info` holds numeric SQL-info ids; `None` presumably means "all" (TODO confirm).
GetSqlInfo {
#[clap(long)]
info: Option<Vec<u32>>,
},
// Optional numeric data-type filter.
GetXdbcTypeInfo {
#[clap(long)]
data_type: Option<i32>,
},
}
// Top-level `dft` subcommands. The server and FlightSQL variants exist only when the
// corresponding cargo feature (`http` / `flightsql`) is enabled.
// `//` rather than `///` is used so clap's generated help text is unchanged.
#[derive(Clone, Debug, Subcommand)]
pub enum Command {
// Run the HTTP server, with an optional config path plus server / metrics addresses.
// NOTE(review): the help text says "port", but `addr`/`metrics_addr` are parsed as
// `SocketAddr`, so a full `host:port` is expected. Consider rewording the help.
#[cfg(feature = "http")]
ServeHttp {
#[clap(short, long)]
config: Option<String>,
#[clap(long, help = "Set the port to be used for server")]
addr: Option<SocketAddr>,
#[clap(long, help = "Set the port to be used for serving metrics")]
metrics_addr: Option<SocketAddr>,
},
// `dft flightsql <cmd>`: issue a single Flight SQL request (see `FlightSqlCommand`).
#[cfg(feature = "flightsql")]
#[command(name = "flightsql")]
FlightSql {
#[clap(subcommand)]
command: FlightSqlCommand,
},
// Run the FlightSQL server. The fields mirror `ServeHttp`, with the same `SocketAddr` note.
#[cfg(feature = "flightsql")]
#[command(name = "serve-flightsql")]
ServeFlightSql {
#[clap(short, long)]
config: Option<String>,
#[clap(long, help = "Set the port to be used for server")]
addr: Option<SocketAddr>,
#[clap(long, help = "Set the port to be used for serving metrics")]
metrics_addr: Option<SocketAddr>,
},
// Generate TPC-H data. Defaults: scale factor 1.0, Parquet output.
GenerateTpch {
#[clap(long, default_value = "1.0")]
scale_factor: f64,
#[clap(long, default_value = "parquet")]
format: TpchFormat,
},
}
// Output file format for `generate-tpch`. `Parquet` is the default selected in
// `Command::GenerateTpch`; `Vortex` is only available with the `vortex` feature.
// `//` rather than `///` so clap's possible-value help is unchanged.
#[derive(Clone, Debug, clap::ValueEnum)]
pub enum TpchFormat {
Parquet,
#[cfg(feature = "vortex")]
Vortex,
}
/// clap value parser for `--files`: accepts `file` only when it names an
/// existing regular file.
///
/// # Errors
///
/// Returns a message quoting `file` when the path is missing, or when it exists
/// but is not a regular file (for example a directory).
fn parse_valid_file(file: &str) -> std::result::Result<PathBuf, String> {
    let candidate = PathBuf::from(file);
    if !candidate.exists() {
        return Err(format!("File does not exist: '{file}'"));
    }
    if !candidate.is_file() {
        return Err(format!("Exists but is not a file: '{file}'"));
    }
    Ok(candidate)
}
/// clap value parser for `-c/--commands`: rejects the empty string and passes
/// every other value through unchanged (whitespace-only input is accepted).
///
/// # Errors
///
/// Returns an explanatory message when `command` is empty.
fn parse_command(command: &str) -> std::result::Result<String, String> {
    if command.is_empty() {
        return Err("-c flag expects only non empty commands".to_string());
    }
    Ok(command.to_owned())
}
/// Parses a single `Name: Value` header into an owned `(name, value)` pair.
///
/// The line is split at its first `:`; both halves are trimmed and then
/// validated with the `http` crate's `HeaderName` / `HeaderValue`. The returned
/// name is `HeaderName`'s string form. The value must also be representable as a
/// visible-ASCII `&str` (`HeaderValue::to_str`).
///
/// # Errors
///
/// Returns a message when there is no `:`, when the name or value is rejected
/// by `http`, or when the value is not convertible to a string.
fn parse_header_line(line: &str) -> Result<(String, String), String> {
    let (raw_name, raw_value) = match line.split_once(':') {
        Some(parts) => parts,
        None => {
            return Err(format!(
                "Invalid header format: '{line}'\n Expected format: 'Header-Name: Header-Value', 'header = Name: Value', or '-H \"Name: Value\"'"
            ))
        }
    };

    let header_name = HeaderName::try_from(raw_name.trim())
        .map_err(|err| format!("Invalid header name: {}", err))?;
    let header_value = HeaderValue::try_from(raw_value.trim())
        .map_err(|err| format!("Invalid header value: {}", err))?;
    let value_text = header_value
        .to_str()
        .map_err(|err| format!("Header value contains invalid characters: {}", err))?;

    Ok((header_name.to_string(), value_text.to_owned()))
}
/// Parses a headers file into `(name, value)` pairs, in file order.
///
/// Blank lines and lines starting with `#` are skipped. Every other line (after
/// trimming) must hold exactly one header in one of these forms:
///
/// * simple: `Name: Value`
/// * curl config: `header = Name: Value` or `header = "Name: Value"`
/// * curl flag: `-H "Name: Value"`, `-H 'Name: Value'` or `-H Name: Value`
///
/// At most one pair of matching surrounding quotes is removed, so values that
/// are themselves quoted (e.g. an `ETag` like `"abc"`) keep their own quotes.
/// Escape sequences inside quotes are not interpreted. A leading UTF-8 byte
/// order mark is ignored.
///
/// # Errors
///
/// Returns a message if the file cannot be read. Otherwise the message names the
/// 1-based line number and content of the first line that is not a valid header.
pub fn parse_headers_file(path: &Path) -> Result<Vec<(String, String)>, String> {
    let raw = std::fs::read_to_string(path)
        .map_err(|e| format!("Failed to read headers file '{}': {}", path.display(), e))?;
    // Some editors prepend a BOM. `trim` does not remove it, so it would otherwise
    // end up glued to the first header name and make it invalid.
    let content = raw.strip_prefix('\u{feff}').unwrap_or(raw.as_str());

    let mut headers = Vec::new();
    for (line_num, line) in content.lines().enumerate() {
        let line = line.trim();
        if line.is_empty() || line.starts_with('#') {
            continue;
        }
        let header = parse_header_line(header_payload(line)).map_err(|e| {
            format!(
                "Invalid header format at line {} in '{}': '{}'\n{}",
                line_num + 1,
                path.display(),
                line,
                e
            )
        })?;
        headers.push(header);
    }
    Ok(headers)
}

/// Extracts the `Name: Value` part of a trimmed, non-comment headers-file line
/// by removing a curl-style `header =` or `-H` prefix and any surrounding quotes.
fn header_payload(line: &str) -> &str {
    if let Some(rest) = line.strip_prefix("header") {
        // `header` is only the curl config keyword when an `=` follows. Otherwise
        // this is a simple-format header whose name merely starts with "header".
        if let Some(value) = rest.trim_start().strip_prefix('=') {
            return strip_matching_quotes(value.trim());
        }
        line
    } else if let Some(rest) = line.strip_prefix("-H") {
        strip_matching_quotes(rest.trim())
    } else {
        line
    }
}

/// Removes one pair of matching surrounding `"` or `'` quotes, if present.
/// Unlike `trim_matches`, this leaves quotes that belong to the value itself.
fn strip_matching_quotes(s: &str) -> &str {
    for quote in ['"', '\''] {
        if let Some(inner) = s.strip_prefix(quote).and_then(|r| r.strip_suffix(quote)) {
            return inner;
        }
    }
    s
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Creates a temporary headers file that contains `lines`, one per line.
    ///
    /// Writing every entry via `writeln!(file, "{line}")` also covers blank
    /// lines, avoiding clippy's `writeln_empty_string` lint on `writeln!(file, "")`.
    fn headers_file(lines: &[&str]) -> NamedTempFile {
        let mut file = NamedTempFile::new().unwrap();
        for line in lines {
            writeln!(file, "{line}").unwrap();
        }
        file.flush().unwrap();
        file
    }

    /// Shorthand for an owned `(name, value)` header pair.
    fn header(name: &str, value: &str) -> (String, String) {
        (name.to_string(), value.to_string())
    }

    #[test]
    fn test_parse_headers_file_simple_format() {
        let file = headers_file(&["x-api-key: secret123", "database: production"]);
        let headers = parse_headers_file(file.path()).unwrap();
        assert_eq!(
            headers,
            vec![
                header("x-api-key", "secret123"),
                header("database", "production"),
            ]
        );
    }

    #[test]
    fn test_parse_headers_file_curl_format() {
        let file = headers_file(&[
            "header = x-api-key: secret123",
            "-H \"database: production\"",
        ]);
        let headers = parse_headers_file(file.path()).unwrap();
        assert_eq!(
            headers,
            vec![
                header("x-api-key", "secret123"),
                header("database", "production"),
            ]
        );
    }

    #[test]
    fn test_parse_headers_file_mixed_format() {
        let file = headers_file(&[
            "# Simple format",
            "x-test: value1",
            "",
            "# Curl config format",
            "header = x-api-key: secret123",
            "-H \"database: production\"",
        ]);
        let headers = parse_headers_file(file.path()).unwrap();
        assert_eq!(
            headers,
            vec![
                header("x-test", "value1"),
                header("x-api-key", "secret123"),
                header("database", "production"),
            ]
        );
    }

    #[test]
    fn test_parse_headers_file_with_comments() {
        let file = headers_file(&[
            "# This is a comment",
            "x-api-key: secret123",
            "# Another comment",
            "database: production",
        ]);
        let headers = parse_headers_file(file.path()).unwrap();
        assert_eq!(
            headers,
            vec![
                header("x-api-key", "secret123"),
                header("database", "production"),
            ]
        );
    }

    #[test]
    fn test_parse_headers_file_blank_lines() {
        let file = headers_file(&["x-api-key: secret123", "", " ", "database: production"]);
        let headers = parse_headers_file(file.path()).unwrap();
        assert_eq!(
            headers,
            vec![
                header("x-api-key", "secret123"),
                header("database", "production"),
            ]
        );
    }

    #[test]
    fn test_parse_headers_file_curl_with_quotes() {
        let file = headers_file(&["-H \"x-api-key: secret123\"", "-H 'database: production'"]);
        let headers = parse_headers_file(file.path()).unwrap();
        assert_eq!(
            headers,
            vec![
                header("x-api-key", "secret123"),
                header("database", "production"),
            ]
        );
    }

    #[test]
    fn test_parse_headers_file_invalid_format() {
        let file = headers_file(&["x-api-key: secret123", "invalid-line-without-colon"]);
        let err = parse_headers_file(file.path()).unwrap_err();
        assert!(err.contains("Invalid header format at line 2"));
    }

    #[test]
    fn test_parse_headers_file_not_found() {
        let err = parse_headers_file(Path::new("/nonexistent/file.txt")).unwrap_err();
        assert!(err.contains("Failed to read headers file"));
    }
}