use crate::tpch_cli::{Compression, DEFAULT_PARQUET_ROW_GROUP_BYTES};
use clap::{ArgAction, Args, Subcommand};
use std::fmt;
use std::path::PathBuf;
use tpcdsgen::config::{CompatMode, Session, SessionBuilder, Table};
use tpcdsgen::error::InvalidOptionError;
/// Submodule providing `dat::generate`, which is called once per built session
/// by the `dat` output path in this file.
pub mod dat;
/// Result alias used throughout this module: any error is boxed as `dyn Error`.
type Result<T> = std::result::Result<T, Box<dyn std::error::Error>>;
/// Text shown by the `NotImplemented` error, which the CSV and Parquet
/// subcommands currently return.
const NOT_IMPLEMENTED: &str = "TPC-DS data generation is not yet implemented";
// Top-level TPC-DS command-line arguments.
//
// An explicit subcommand selects the output format; when none is given the
// flattened `DatArgs` are used instead (see `Cli::run`).
//
// NOTE: plain `//` comments are used here on purpose. clap's derive macros turn
// `///` doc comments into help/about text, so adding them would change CLI output.
#[derive(Args)]
#[command(version)]
// Top-level arguments may not be mixed with a subcommand.
#[command(args_conflicts_with_subcommands = true)]
pub struct Cli {
// Output-format subcommand, if one was given on the command line.
#[command(subcommand)]
command: Option<Commands>,
// Arguments that apply when no subcommand is present.
#[command(flatten)]
args: DatArgs,
}
// Output-format subcommands.
//
// Only `Dat` reaches a generator in this file; `Csv` and `Parquet` parse their
// arguments and then fail with the `NotImplemented` error.
// (Plain `//` comments: clap would turn `///` into subcommand help text.)
#[derive(Subcommand)]
enum Commands {
// Arguments for the `dat` output path.
Dat(DatArgs),
// Arguments for CSV output (currently reports "not implemented").
Csv(CsvArgs),
// Arguments for Parquet output (currently reports "not implemented").
Parquet(ParquetArgs),
}
// Arguments for `dat` output: nothing beyond the options shared by all formats.
// (Plain `//` comments: clap would turn `///` into help text.)
#[derive(Args)]
struct DatArgs {
// Options shared by every output format.
#[command(flatten)]
common: CommonArgs,
}
// Arguments for CSV output: the shared options plus a field delimiter.
// (Plain `//` comments: clap would turn `///` into help text.)
#[derive(Args)]
struct CsvArgs {
// Options shared by every output format.
#[command(flatten)]
common: CommonArgs,
// Field delimiter, validated by `parse_delimiter`; defaults to a comma.
// Not consumed yet: `CsvArgs::run` discards it.
#[arg(long, default_value = ",", value_parser = parse_delimiter)]
delimiter: char,
}
// Arguments for Parquet output: the shared options plus compression and
// row-group sizing. Both extra options are discarded by `ParquetArgs::run`.
// (Plain `//` comments: clap would turn `///` into help text.)
#[derive(Args)]
struct ParquetArgs {
// Options shared by every output format.
#[command(flatten)]
common: CommonArgs,
// Compression setting (`-c` / `--compression`); the default string is "SNAPPY".
#[arg(short = 'c', long, default_value = "SNAPPY")]
compression: Compression,
// Row-group size; defaults to `DEFAULT_PARQUET_ROW_GROUP_BYTES` from `tpch_cli`.
// NOTE(review): presumably a byte count, going by the name — confirm.
#[arg(long, default_value_t = DEFAULT_PARQUET_ROW_GROUP_BYTES)]
row_group_bytes: i64,
}
// Options shared by every output format.
//
// NOTE: plain `//` comments are used here on purpose. clap's derive macros turn
// `///` doc comments into per-option help text, so adding them would change CLI output.
#[derive(Args)]
pub struct CommonArgs {
// Scale factor handed to the session builder; defaults to 1.
#[arg(short, long, default_value_t = 1.)]
scale_factor: f64,
// Directory that receives the output; created if missing by the `dat` path.
// Defaults to the current directory.
#[arg(short, long, default_value = ".")]
output_dir: PathBuf,
// Comma-separated table names (`-T` / `--tables`). Each name is parsed as a
// `Table` when sessions are built; when absent, one session is built with no
// table set on the builder.
#[arg(short = 'T', long = "tables", value_delimiter = ',')]
tables: Option<Vec<String>>,
// Compatibility mode forwarded to the session builder; defaults to `CompatMode::Trino`.
#[arg(long, default_value_t = CompatMode::Trino)]
compat: CompatMode,
// Mutually exclusive with `--quiet`. Not read by any code visible in this file.
#[arg(short, long, default_value_t = false, conflicts_with = "quiet")]
verbose: bool,
// Mutually exclusive with `--verbose`. Not read by any code visible in this file.
#[arg(short, long, default_value_t = false, conflicts_with = "verbose")]
quiet: bool,
// True unless `--no-progress` is passed (the flag stores `false`).
// The `dat` path currently discards this value.
#[arg(long = "no-progress", action = ArgAction::SetFalse, default_value_t = true)]
progress_bars_enabled: bool,
}
impl Cli {
    /// Runs the selected subcommand, or the top-level `dat` arguments when no
    /// subcommand was given.
    ///
    /// # Errors
    ///
    /// Returns whatever error the chosen subcommand's `run` produces.
    pub fn run(self) -> Result<()> {
        let Self { command, args } = self;
        match command {
            // No subcommand: the flattened top-level arguments drive the run.
            None => args.run(),
            Some(Commands::Dat(dat_args)) => dat_args.run(),
            Some(Commands::Csv(csv_args)) => csv_args.run(),
            Some(Commands::Parquet(parquet_args)) => parquet_args.run(),
        }
    }
}
impl DatArgs {
    /// Generates `dat` output using the shared options.
    ///
    /// # Errors
    ///
    /// Returns any error reported by `CommonArgs::run_dat`.
    fn run(self) -> Result<()> {
        let Self { common } = self;
        common.run_dat()
    }
}
impl CsvArgs {
    /// CSV output is not available yet; always fails with the
    /// "not implemented" error.
    ///
    /// # Errors
    ///
    /// Always returns the error produced by `CommonArgs::run_not_implemented`.
    fn run(self) -> Result<()> {
        // The delimiter is parsed and validated by clap but has no consumer yet.
        let Self {
            common,
            delimiter: _,
        } = self;
        common.run_not_implemented()
    }
}
impl ParquetArgs {
    /// Parquet output is not available yet; always fails with the
    /// "not implemented" error.
    ///
    /// # Errors
    ///
    /// Always returns the error produced by `CommonArgs::run_not_implemented`.
    fn run(self) -> Result<()> {
        // Compression and row-group size are accepted but have no consumer yet.
        let Self {
            common,
            compression: _,
            row_group_bytes: _,
        } = self;
        common.run_not_implemented()
    }
}
impl CommonArgs {
    /// Generates `dat` output into `output_dir`.
    ///
    /// With `--tables`, one session is built and generated per listed table, in
    /// the order given. Without it, a single session is built with no table set
    /// on the builder.
    ///
    /// Every requested table name is validated *before* any generation starts,
    /// so a typo late in the list cannot leave a partially generated data set.
    ///
    /// # Errors
    ///
    /// Fails if the output directory cannot be created, if a table name does not
    /// parse as a `Table`, if the session cannot be built, or if `dat::generate`
    /// reports an error.
    fn run_dat(self) -> Result<()> {
        // The progress flag has no consumer on this path yet.
        let _ = self.progress_bars_enabled;
        std::fs::create_dir_all(&self.output_dir)?;
        match &self.tables {
            Some(tables) => {
                // Fail fast: reject unknown names before writing any table.
                for name in tables {
                    Self::parse_table(name)?;
                }
                for name in tables {
                    self.run_dat_for_table(Some(name.clone()))?;
                }
                Ok(())
            }
            None => self.run_dat_for_table(None),
        }
    }

    /// Builds a session for `table` (or for no specific table) and runs the
    /// `dat` generator on it.
    fn run_dat_for_table(&self, table: Option<String>) -> Result<()> {
        let session = self.to_session(table)?;
        dat::generate(&session)
    }

    /// Parses a user-supplied table name.
    ///
    /// # Errors
    ///
    /// Returns an `InvalidOptionError` for option `"table"` carrying the
    /// offending name when it does not parse as a `Table`.
    fn parse_table(name: &str) -> Result<Table> {
        Ok(name
            .parse::<Table>()
            .map_err(|_| InvalidOptionError::new("table", name))?)
    }

    /// Translates these CLI options into a generator `Session`.
    ///
    /// `table`, when present, is parsed and set on the builder. The process
    /// command line is recorded on the session as a single space-joined string.
    ///
    /// # Errors
    ///
    /// Fails if `table` is not a valid table name or if the session builder
    /// rejects the options.
    fn to_session(&self, table: Option<String>) -> Result<Session> {
        let table = table.as_deref().map(Self::parse_table).transpose()?;
        // `std::env::args()` panics on arguments that are not valid Unicode
        // (possible here, e.g. a non-UTF-8 output path), so read the raw
        // arguments and convert lossily instead.
        let command_line_arguments = std::env::args_os()
            .map(|arg| arg.to_string_lossy().into_owned())
            .collect::<Vec<_>>()
            .join(" ");
        let mut builder = SessionBuilder::new()
            .with_scale_factor(self.scale_factor)
            .with_target_directory(self.output_dir.to_string_lossy())
            .with_compat_mode(self.compat)
            .with_command_line_arguments(command_line_arguments);
        if let Some(table) = table {
            builder = builder.with_table(table);
        }
        Ok(builder.build()?)
    }

    /// Consumes the options and reports that the requested output format is
    /// not implemented.
    ///
    /// # Errors
    ///
    /// Always returns a boxed `NotImplemented`.
    fn run_not_implemented(self) -> Result<()> {
        let _ = self;
        Err(Box::new(NotImplemented))
    }
}
/// Error returned by output formats that parse their arguments but have no
/// generator behind them yet.
struct NotImplemented;

impl std::error::Error for NotImplemented {}

impl fmt::Debug for NotImplemented {
    /// Debug output is deliberately the same text as `Display`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        <Self as fmt::Display>::fmt(self, f)
    }
}

impl fmt::Display for NotImplemented {
    /// Writes the fixed `NOT_IMPLEMENTED` message, ignoring width/padding flags.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", NOT_IMPLEMENTED)
    }
}
/// Parses the `--delimiter` argument into a single ASCII character.
///
/// Accepts exactly one character, or one of the two-character escape spellings
/// `\t`, `\n`, `\r` and `\\` (a literal backslash followed by the letter or a
/// second backslash), which map to tab, line feed, carriage return and
/// backslash respectively.
///
/// # Errors
///
/// Returns a message when the input is empty, longer than one character and
/// not a recognised escape, or when the resulting character is not ASCII.
fn parse_delimiter(s: &str) -> std::result::Result<char, String> {
    let escaped = match s {
        "\\t" => Some('\t'),
        "\\n" => Some('\n'),
        "\\r" => Some('\r'),
        "\\\\" => Some('\\'),
        _ => None,
    };
    let delimiter = if let Some(ch) = escaped {
        ch
    } else {
        // Exactly one character: the first exists and nothing follows it.
        let mut rest = s.chars();
        match (rest.next(), rest.next()) {
            (Some(only), None) => only,
            _ => {
                return Err(format!(
                    "Delimiter must be a single character or escape sequence (\\t, \\n, \\r, \\\\), got: '{}'",
                    s
                ));
            }
        }
    };
    if delimiter.is_ascii() {
        Ok(delimiter)
    } else {
        Err(format!(
            "Delimiter must be an ASCII character, got: '{}'",
            delimiter
        ))
    }
}