use std::{
fmt,
net::{IpAddr, Ipv4Addr, SocketAddr},
path::PathBuf,
str::FromStr,
};
use clap::{Args, Parser, Subcommand};
use serde::{Deserialize, Serialize};
// Top-level command line. Only `//` comments are used on clap-derived items in this block:
// `///` doc comments would be picked up by clap and change the `--help` output.
#[derive(Clone, Debug, Parser)]
#[command(
    version,
    author,
    about = "Single-binary Rust service for parquet-backed text2sql workloads"
)]
pub struct Cli {
    // Log filter directive: `--log-filter`, else `RUST_LOG`, else the built-in default.
    #[arg(long, env = "RUST_LOG", default_value = "info,text2sql=debug")]
    log_filter: String,
    // Optional subcommand; `Cli::command` substitutes `serve` when it is absent.
    #[command(subcommand)]
    command: Option<Command>,
}
// Subcommands accepted by the binary. Each variant wraps its own argument struct. Comments are
// `//` on purpose: clap would turn `///` on variants into subcommand help text.
#[derive(Clone, Debug, Subcommand)]
pub enum Command {
    // Start the service (see `ServeArgs`).
    Serve(ServeArgs),
    // Convert an input file into a dataset (see `ConvertArgs`).
    Convert(ConvertArgs),
    // Inspect a dataset's columns (see `SchemaArgs`).
    Schema(SchemaArgs),
    // Run SQL against a dataset (see `QueryArgs`).
    Query(QueryArgs),
    // Compare query modes on generated data (see `BenchmarkArgs`).
    Benchmark(BenchmarkArgs),
}
// Arguments for `text2sql serve`. Each flag may also be supplied through the environment
// variable named in its `env = ...` attribute.
#[derive(Clone, Debug, Args)]
pub struct ServeArgs {
    // Interface to bind; IPv4 loopback unless overridden.
    #[arg(
        long,
        env = "TEXT2SQL_HOST",
        default_value_t = IpAddr::V4(Ipv4Addr::LOCALHOST)
    )]
    pub host: IpAddr,
    // TCP port to listen on.
    #[arg(
        long,
        env = "TEXT2SQL_PORT",
        default_value_t = 3000
    )]
    pub port: u16,
    // Name the service reports about itself.
    #[arg(
        long,
        env = "TEXT2SQL_SERVICE_NAME",
        default_value = "text2sql"
    )]
    pub service_name: String,
}
/// Resolved server settings, produced from `ServeArgs` through its `From` impl.
#[derive(Clone, Debug)]
pub struct RuntimeConfig {
    /// Socket address to bind, assembled from the CLI host and port.
    pub bind_addr: SocketAddr,
    /// Service name carried over unchanged from the CLI.
    pub service_name: String,
}
// Arguments for `text2sql convert`. Comments are `//` so clap's generated help is unchanged.
#[derive(Debug, Clone, Args)]
pub struct ConvertArgs {
    // File to convert.
    #[arg(long)]
    pub input: PathBuf,
    // Destination URI; `s3://` destinations pick up the flattened S3 options.
    #[arg(long)]
    pub output: String,
    #[arg(long)]
    pub sheet_name: Option<String>,
    #[arg(long)]
    pub csv_encoding: Option<String>,
    // A plain `bool` with `default_value_t = true` gets clap's `SetTrue` action and can never
    // be switched off. An optional value keeps a bare `--normalize-columns` meaning `true`
    // while allowing `--normalize-columns=false`.
    #[arg(
        long,
        default_value_t = true,
        action = clap::ArgAction::Set,
        num_args = 0..=1,
        default_missing_value = "true"
    )]
    pub normalize_columns: bool,
    #[arg(long, default_value_t = false)]
    pub overwrite: bool,
    #[arg(long, default_value_t = false)]
    pub add_filename_column: bool,
    #[command(flatten)]
    pub s3: S3Options,
}
// Arguments for `text2sql schema`: describe the columns of `dataset`.
#[derive(Clone, Debug, Args)]
pub struct SchemaArgs {
    // Dataset URI or local path.
    #[arg(long)]
    pub dataset: String,
    // Table name the dataset is exposed under (presumably in the generated SQL context).
    #[arg(
        long,
        default_value = "dataset"
    )]
    pub table_name: String,
    // S3 connection settings, used when `dataset` is an S3 URI.
    #[command(flatten)]
    pub s3: S3Options,
}
// Arguments for `text2sql query`: execute `sql` against `dataset`.
#[derive(Clone, Debug, Args)]
pub struct QueryArgs {
    // SQL text to run.
    #[arg(long)]
    pub sql: String,
    // Dataset URI or local path.
    #[arg(long)]
    pub dataset: String,
    // Table name the SQL refers to (default `dataset`).
    #[arg(
        long,
        default_value = "dataset"
    )]
    pub table_name: String,
    // Execution strategy; parsed via `QueryMode`'s `FromStr`.
    #[arg(
        long,
        default_value_t = QueryMode::ParquetSelective
    )]
    pub mode: QueryMode,
    // Size threshold in bytes, presumably consulted by the hybrid mode — confirm in executor.
    #[arg(
        long,
        default_value_t = default_small_file_threshold_bytes()
    )]
    pub small_file_threshold_bytes: u64,
    #[arg(long)]
    pub limit: Option<usize>,
    // S3 connection settings, used when `dataset` is an S3 URI.
    #[command(flatten)]
    pub s3: S3Options,
}
// Arguments for `text2sql benchmark`. The row-count and SQL defaults come from the same helpers
// that back `BenchmarkRequest`'s serde defaults, so CLI and HTTP defaults cannot drift apart.
#[derive(Debug, Clone, Args)]
pub struct BenchmarkArgs {
    #[arg(long, default_value = "./benchmark-data")]
    pub output_dir: PathBuf,
    // Comma-separated list, e.g. `--row-counts 10000,75000`.
    #[arg(long, value_delimiter = ',', default_values_t = default_row_counts())]
    pub row_counts: Vec<usize>,
    #[arg(long, default_value_t = default_benchmark_sql())]
    pub sql: String,
    #[arg(long)]
    pub limit: Option<usize>,
}
impl Cli {
pub fn log_filter(&self) -> &str {
&self.log_filter
}
pub fn command(&self) -> Command {
self.command.clone().unwrap_or_else(|| {
Command::Serve(ServeArgs {
host: IpAddr::V4(Ipv4Addr::LOCALHOST),
port: 3000,
service_name: "text2sql".to_string(),
})
})
}
}
impl From<ServeArgs> for RuntimeConfig {
fn from(value: ServeArgs) -> Self {
Self {
bind_addr: SocketAddr::from((value.host, value.port)),
service_name: value.service_name,
}
}
}
/// Execution strategy for a query.
///
/// Serializes as `snake_case` (`parquet_selective`, `full_download`, `auto_hybrid`).
/// Deserialization also accepts the `hybrid` and `auto` spellings that [`FromStr`] accepts and
/// `Display` emits, so JSON request bodies and the CLI agree on which values are valid.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum QueryMode {
    /// Default mode.
    #[default]
    ParquetSelective,
    FullDownload,
    // Serialized output stays `auto_hybrid` for compatibility; the aliases only widen input.
    #[serde(alias = "hybrid", alias = "auto")]
    AutoHybrid,
}
impl fmt::Display for QueryMode {
    /// Writes the CLI spelling of the mode; `FromStr` accepts each of these strings back.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            Self::ParquetSelective => "parquet_selective",
            Self::FullDownload => "full_download",
            Self::AutoHybrid => "hybrid",
        })
    }
}
impl FromStr for QueryMode {
    type Err = String;

    /// Parses a mode name, ignoring surrounding whitespace and ASCII case.
    ///
    /// `hybrid`, `auto_hybrid` and `auto` all map to [`QueryMode::AutoHybrid`]. On failure the
    /// error message echoes the normalized (trimmed, lowercased) input.
    fn from_str(value: &str) -> Result<Self, Self::Err> {
        let normalized = value.trim().to_ascii_lowercase();
        let mode = match normalized.as_str() {
            "parquet_selective" => Self::ParquetSelective,
            "full_download" => Self::FullDownload,
            "hybrid" | "auto_hybrid" | "auto" => Self::AutoHybrid,
            unknown => {
                return Err(format!(
                    "unsupported query mode: {unknown} (expected parquet_selective, full_download, or hybrid)"
                ));
            }
        };
        Ok(mode)
    }
}
/// JSON body of a query request; the HTTP counterpart of the `query` subcommand.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct QueryRequest {
    /// SQL text to execute.
    pub sql: String,
    /// Dataset to query and how to reach it.
    pub dataset: DatasetSource,
    /// Table name used in `sql`; `"dataset"` when omitted.
    #[serde(default = "default_table_name")]
    pub table_name: String,
    /// Execution strategy; `QueryMode::default()` when omitted.
    #[serde(default)]
    pub mode: QueryMode,
    /// Size threshold in bytes; `DEFAULT_SMALL_FILE_THRESHOLD_BYTES` when omitted.
    #[serde(default = "default_small_file_threshold_bytes")]
    pub small_file_threshold_bytes: u64,
    /// Optional row limit.
    pub limit: Option<usize>,
}
/// JSON body of a convert request: convert the file at `input_path` into the dataset `output`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConvertRequest {
    pub input_path: String,
    pub output: DatasetSource,
    /// Defaults to `true` when omitted, matching the CLI's `--normalize-columns` default.
    /// (A plain `#[serde(default)]` would silently produce `false`, disagreeing with the CLI.)
    #[serde(default = "default_normalize_columns")]
    pub normalize_columns: bool,
    #[serde(default)]
    pub sheet_name: Option<String>,
    #[serde(default)]
    pub csv_encoding: Option<String>,
    #[serde(default)]
    pub overwrite: bool,
    #[serde(default)]
    pub add_filename_column: bool,
}

/// Serde default for [`ConvertRequest::normalize_columns`]; mirrors the CLI default.
const fn default_normalize_columns() -> bool {
    true
}
/// JSON body of a schema request; the HTTP counterpart of the `schema` subcommand.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SchemaRequest {
    /// Dataset to inspect.
    pub dataset: DatasetSource,
    /// Table name; `"dataset"` when omitted.
    #[serde(default = "default_table_name")]
    pub table_name: String,
}
/// JSON body of a benchmark request; the HTTP counterpart of the `benchmark` subcommand.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct BenchmarkRequest {
    /// Output directory; when `None` the handler picks one (presumably the CLI default — confirm).
    pub output_dir: Option<String>,
    /// Row counts to generate; `default_row_counts()` when omitted.
    #[serde(default = "default_row_counts")]
    pub row_counts: Vec<usize>,
    /// SQL to time; `default_benchmark_sql()` when omitted.
    #[serde(default = "default_benchmark_sql")]
    pub sql: String,
    /// Optional row limit.
    pub limit: Option<usize>,
}
/// Location of a dataset: its URI plus the storage backend used to reach it.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct DatasetSource {
    /// Dataset URI or local path.
    pub uri: String,
    /// Storage backend; local storage when omitted from a request body.
    #[serde(default)]
    pub storage: StorageConfig,
}
/// Storage backend for a dataset.
///
/// Internally tagged in JSON, e.g. `{"type": "local"}` or `{"type": "s3", "region": "..."}`.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", tag = "type")]
pub enum StorageConfig {
    /// Local filesystem; the default.
    #[default]
    Local,
    /// S3-compatible object storage with the given connection options.
    S3(S3Options),
}
#[derive(Debug, Clone, Args, Serialize, Deserialize, Default, PartialEq, Eq)]
pub struct S3Options {
#[arg(long, env = "TEXT2SQL_S3_REGION")]
pub region: Option<String>,
#[arg(long, env = "TEXT2SQL_S3_ENDPOINT")]
pub endpoint: Option<String>,
#[arg(long, env = "TEXT2SQL_S3_ACCESS_KEY_ID")]
pub access_key_id: Option<String>,
#[arg(long, env = "TEXT2SQL_S3_SECRET_ACCESS_KEY")]
pub secret_access_key: Option<String>,
#[arg(long, env = "TEXT2SQL_S3_SESSION_TOKEN")]
pub session_token: Option<String>,
#[arg(long, env = "TEXT2SQL_S3_ALLOW_HTTP", default_value_t = false)]
pub allow_http: bool,
#[arg(long, env = "TEXT2SQL_S3_FORCE_PATH_STYLE", default_value_t = true)]
pub force_path_style: bool,
}
/// Result of executing a query, returned as JSON.
#[derive(Clone, Debug, Serialize)]
pub struct QueryResponse {
    /// Table name the query ran against.
    pub table_name: String,
    /// Mode that was used to execute the query.
    pub mode: QueryMode,
    /// Number of rows returned.
    pub row_count: usize,
    /// Name and type of each result column.
    pub columns: Vec<ColumnSummary>,
    /// Result rows as JSON values.
    pub rows: Vec<serde_json::Value>,
    /// Elapsed time, presumably wall-clock milliseconds.
    pub elapsed_ms: u128,
    /// Signed memory change in bytes (what is measured is decided by the producer — confirm).
    pub memory_delta_bytes: i64,
    /// Free-form diagnostic notes.
    pub notes: Vec<String>,
}
/// Column listing and sample rows for a dataset, returned as JSON.
#[derive(Clone, Debug, Serialize)]
pub struct SchemaResponse {
    /// Table name the dataset was inspected under.
    pub table_name: String,
    /// Name and type of each column.
    pub columns: Vec<ColumnSummary>,
    /// Sample rows (see `DEFAULT_SCHEMA_PREVIEW_ROWS`).
    pub preview_rows: Vec<serde_json::Value>,
    /// Elapsed time, presumably wall-clock milliseconds.
    pub elapsed_ms: u128,
    /// Free-form diagnostic notes.
    pub notes: Vec<String>,
}
/// One column's name and type.
#[derive(Clone, Debug, Serialize)]
pub struct ColumnSummary {
    /// Column name.
    pub name: String,
    /// Type name, as reported by DuckDB per the field name.
    pub duckdb_type: String,
}
/// Outcome of a conversion, returned as JSON.
#[derive(Clone, Debug, Serialize)]
pub struct ConvertResponse {
    /// Input file that was converted.
    pub input_path: String,
    /// URI the converted dataset was written to.
    pub output_uri: String,
    /// Number of rows written.
    pub row_count: usize,
    /// Output column names.
    pub columns: Vec<String>,
    /// Elapsed time, presumably wall-clock milliseconds.
    pub elapsed_ms: u128,
    /// Free-form diagnostic notes.
    pub notes: Vec<String>,
}
/// Summary of a benchmark run, returned as JSON.
#[derive(Clone, Debug, Serialize)]
pub struct BenchmarkResponse {
    /// SQL that was timed.
    pub sql: String,
    /// Directory holding the generated data.
    pub output_dir: String,
    /// One entry per requested row count.
    pub results: Vec<BenchmarkResult>,
    /// Human-readable recommendation derived from the results.
    pub recommendation: String,
}
/// Benchmark measurements for a single generated dataset size.
#[derive(Clone, Debug, Serialize)]
pub struct BenchmarkResult {
    /// Number of generated rows.
    pub row_count: usize,
    /// Path of the generated parquet file.
    pub parquet_path: String,
    /// Measurement taken in `parquet_selective` mode.
    pub parquet_selective: BenchmarkMeasurement,
    /// Measurement taken in `full_download` mode.
    pub full_download: BenchmarkMeasurement,
}
/// A single timed query execution.
#[derive(Clone, Debug, Serialize)]
pub struct BenchmarkMeasurement {
    /// Elapsed time, presumably wall-clock milliseconds.
    pub elapsed_ms: u128,
    /// Signed memory change in bytes (what is measured is decided by the producer — confirm).
    pub memory_delta_bytes: i64,
    /// Number of rows the query returned.
    pub result_rows: usize,
}
/// Serde default for request `table_name` fields.
fn default_table_name() -> String {
    String::from("dataset")
}
/// Default `small_file_threshold_bytes`: 10 MiB.
pub const DEFAULT_SMALL_FILE_THRESHOLD_BYTES: u64 = 10 << 20;
/// Number of sample rows included in a schema preview.
pub const DEFAULT_SCHEMA_PREVIEW_ROWS: usize = 3;

/// `const fn` wrapper so the constant can be named in clap and serde default attributes.
const fn default_small_file_threshold_bytes() -> u64 {
    DEFAULT_SMALL_FILE_THRESHOLD_BYTES
}
/// Default dataset sizes for a benchmark run.
fn default_row_counts() -> Vec<usize> {
    Vec::from([10_000, 50_000, 100_000, 500_000])
}
/// Default query timed by the benchmark.
fn default_benchmark_sql() -> String {
    String::from(
        "SELECT symbol, AVG(trade_value) AS avg_trade_value FROM dataset WHERE trade_month = 9 GROUP BY symbol ORDER BY avg_trade_value DESC",
    )
}
/// Builds a [`DatasetSource`] for `uri`, attaching a copy of `s3` only for `s3://` URIs.
///
/// The scheme check ignores ASCII case, because URI schemes are case-insensitive
/// (RFC 3986 §3.1). `S3://bucket/key` is therefore routed to S3 as well. Any other input,
/// including an empty string, is treated as local storage and `s3` is ignored.
pub fn dataset_source_from_uri(uri: String, s3: &S3Options) -> DatasetSource {
    const S3_SCHEME: &str = "s3://";
    // `get` yields `None` when `uri` is shorter than the scheme or the cut falls inside a
    // multi-byte character, so this never panics on arbitrary input.
    let is_s3 = matches!(
        uri.get(..S3_SCHEME.len()),
        Some(prefix) if prefix.eq_ignore_ascii_case(S3_SCHEME)
    );
    let storage = if is_s3 {
        StorageConfig::S3(s3.clone())
    } else {
        StorageConfig::Local
    };
    DatasetSource { uri, storage }
}
#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use clap::Parser;

    use super::{Cli, Command, QueryMode};

    // Parses a `--mode` value through `FromStr`, failing the test if it is rejected.
    fn parse_mode(raw: &str) -> QueryMode {
        raw.parse::<QueryMode>().unwrap()
    }

    #[test]
    fn cli_defaults_to_serving() {
        let command = Cli::parse_from(["text2sql"]).command();
        match command {
            Command::Serve(serve) => {
                assert_eq!(serve.port, 3000);
                assert_eq!(serve.service_name, "text2sql");
            }
            other => panic!("expected serve command, got {other:?}"),
        }
    }

    #[test]
    fn query_mode_default_is_selective() {
        let mode = QueryMode::default();
        assert!(matches!(mode, QueryMode::ParquetSelective));
    }

    #[test]
    fn query_mode_parses_cli_strings() {
        assert!(matches!(parse_mode("parquet_selective"), QueryMode::ParquetSelective));
        assert!(matches!(parse_mode("full_download"), QueryMode::FullDownload));
        assert!(matches!(parse_mode("hybrid"), QueryMode::AutoHybrid));
    }

    #[test]
    fn cli_parses_schema_command() {
        let argv = ["text2sql", "schema", "--dataset", "./tmp/data.parquet"];
        match Cli::parse_from(argv).command() {
            Command::Schema(schema) => {
                assert_eq!(schema.dataset, "./tmp/data.parquet");
                assert_eq!(schema.table_name, "dataset");
            }
            other => panic!("expected schema command, got {other:?}"),
        }
    }

    #[test]
    fn cli_parses_query_command() {
        let argv = [
            "text2sql",
            "query",
            "--dataset",
            "./tmp/data.parquet",
            "--sql",
            "SELECT * FROM dataset",
            "--mode",
            "full_download",
            "--small-file-threshold-bytes",
            "4096",
            "--limit",
            "5",
        ];
        match Cli::parse_from(argv).command() {
            Command::Query(query) => {
                assert_eq!(query.dataset, "./tmp/data.parquet");
                assert_eq!(query.sql, "SELECT * FROM dataset");
                assert!(matches!(query.mode, QueryMode::FullDownload));
                assert_eq!(query.small_file_threshold_bytes, 4096);
                assert_eq!(query.limit, Some(5));
            }
            other => panic!("expected query command, got {other:?}"),
        }
    }

    #[test]
    fn cli_parses_benchmark_command() {
        let argv = [
            "text2sql",
            "benchmark",
            "--output-dir",
            "./tmp/bench",
            "--row-counts",
            "10000,75000",
            "--sql",
            "SELECT COUNT(*) FROM dataset",
            "--limit",
            "3",
        ];
        match Cli::parse_from(argv).command() {
            Command::Benchmark(bench) => {
                assert_eq!(bench.output_dir, PathBuf::from("./tmp/bench"));
                assert_eq!(bench.row_counts, vec![10_000, 75_000]);
                assert_eq!(bench.sql, "SELECT COUNT(*) FROM dataset");
                assert_eq!(bench.limit, Some(3));
            }
            other => panic!("expected benchmark command, got {other:?}"),
        }
    }
}