use super::generate::Sink;
use super::output_plan::OutputPlanGenerator;
use super::parquet::IntoSize;
use super::plan::DEFAULT_PARQUET_ROW_GROUP_BYTES;
use super::progress::ProgressTracker;
use super::runner::PlanRunner;
use super::statistics::WriteStatistics;
pub use ::parquet::basic::Compression;
use log::info;
use std::fmt::Display;
use std::fs::File;
use std::io;
use std::io::{BufWriter, Stdout, Write};
use std::str::FromStr;
use std::sync::Arc;
use std::time::Instant;
use tpchgen::distribution::Distributions;
use tpchgen::text::TextPool;
/// A [`Sink`] that forwards each buffer to the wrapped [`Write`] implementation
/// and counts chunks and bytes in a [`WriteStatistics`].
pub struct WriterSink<W: Write> {
/// Chunk/byte counters, updated on every call to `sink`.
statistics: WriteStatistics,
/// Writer that receives the bytes.
inner: W,
}
impl<W: Write> WriterSink<W> {
    /// Wraps `inner` in a sink whose counters are created with
    /// `WriteStatistics::new("buffers")`.
    pub fn new(inner: W) -> Self {
        let statistics = WriteStatistics::new("buffers");
        Self { statistics, inner }
    }
}
impl<W: Write + Send> Sink for WriterSink<W> {
    /// Records one chunk of `buffer.len()` bytes, then writes the whole buffer
    /// to the wrapped writer.
    ///
    /// The counters are bumped before the write is attempted, so they also
    /// include a buffer whose write fails.
    fn sink(&mut self, buffer: &[u8]) -> Result<(), io::Error> {
        let stats = &mut self.statistics;
        stats.increment_chunks(1);
        stats.increment_bytes(buffer.len());
        self.inner.write_all(buffer)
    }

    /// Flushes the wrapped writer, consuming the sink.
    fn flush(mut self) -> Result<(), io::Error> {
        Write::flush(&mut self.inner)
    }
}
impl IntoSize for BufWriter<Stdout> {
    /// Flushes any buffered bytes to stdout and reports a size of `0`, because
    /// the amount written to stdout is not measured here.
    ///
    /// # Errors
    ///
    /// Returns the I/O error from the final flush. Flushing explicitly matters
    /// because `BufWriter`'s `Drop` flushes too, but it silently discards errors.
    /// This keeps the behaviour consistent with the `BufWriter<File>` impl.
    fn into_size(self) -> Result<usize, io::Error> {
        // `into_inner` flushes; `?` converts `IntoInnerError` into `io::Error`.
        self.into_inner()?;
        Ok(0)
    }
}
impl IntoSize for BufWriter<File> {
    /// Flushes the buffer, then returns the file's length in bytes as reported
    /// by its metadata.
    ///
    /// If the length does not fit in `usize` (only possible where `usize` is
    /// narrower than 64 bits), the result saturates to `usize::MAX`. A plain
    /// `as` cast would silently truncate it instead.
    ///
    /// # Errors
    ///
    /// Returns an error if flushing the buffered data or reading the metadata fails.
    fn into_size(self) -> Result<usize, io::Error> {
        // `into_inner` flushes first, so the metadata reflects every buffered byte.
        let file = self.into_inner()?;
        let len = file.metadata()?.len();
        // Clamp before casting so an oversized length saturates instead of truncating.
        Ok(len.min(usize::MAX as u64) as usize)
    }
}
/// One of the eight TPC-H tables the generator can produce.
///
/// Variant order is significant: it defines the derived `PartialOrd`/`Ord`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum Table {
Nation,
Region,
Part,
Supplier,
Partsupp,
Customer,
Orders,
Lineitem,
}
impl Display for Table {
    /// Writes the table's lowercase name (e.g. `lineitem`).
    ///
    /// Width/alignment flags from the caller are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.name())
    }
}
impl FromStr for Table {
    type Err = &'static str;

    /// Parses a table from its full lowercase name (e.g. `"lineitem"`) or its
    /// single-letter abbreviation (e.g. `"L"`).
    ///
    /// Abbreviations are case-sensitive: `s` is `supplier`, while `S` is `partsupp`.
    ///
    /// # Errors
    ///
    /// Returns a static message listing every accepted name when `s` matches none.
    /// `Err` is `&'static str`, so the rejected input cannot be interpolated. The
    /// previous message contained a literal, never-expanded `{s}` placeholder.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "n" | "nation" => Ok(Table::Nation),
            "r" | "region" => Ok(Table::Region),
            "s" | "supplier" => Ok(Table::Supplier),
            "P" | "part" => Ok(Table::Part),
            "S" | "partsupp" => Ok(Table::Partsupp),
            "c" | "customer" => Ok(Table::Customer),
            "O" | "orders" => Ok(Table::Orders),
            "L" | "lineitem" => Ok(Table::Lineitem),
            _ => Err(concat!(
                "Invalid table name. Valid names are: nation (n), region (r), part (P), ",
                "supplier (s), partsupp (S), customer (c), orders (O), lineitem (L)"
            )),
        }
    }
}
impl Table {
fn name(&self) -> &'static str {
match self {
Table::Nation => "nation",
Table::Region => "region",
Table::Part => "part",
Table::Supplier => "supplier",
Table::Partsupp => "partsupp",
Table::Customer => "customer",
Table::Orders => "orders",
Table::Lineitem => "lineitem",
}
}
}
/// File format used for generated table data.
///
/// Variant order is significant: it defines the derived `PartialOrd`/`Ord`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum OutputFormat {
/// Rendered as `tbl`.
Tbl,
/// Rendered as `csv`.
Csv,
/// Rendered as `parquet`.
Parquet,
}
impl FromStr for OutputFormat {
    type Err = String;

    /// Parses `tbl`, `csv`, or `parquet`, ignoring case.
    ///
    /// # Errors
    ///
    /// Returns a message naming the rejected input and the valid formats.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let lowered = s.to_lowercase();
        let parsed = match lowered.as_str() {
            "parquet" => OutputFormat::Parquet,
            "csv" => OutputFormat::Csv,
            "tbl" => OutputFormat::Tbl,
            _ => {
                return Err(format!(
                    "Invalid output format: {s}. Valid formats are: tbl, csv, parquet"
                ))
            }
        };
        Ok(parsed)
    }
}
impl Display for OutputFormat {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
OutputFormat::Tbl => write!(f, "tbl"),
OutputFormat::Csv => write!(f, "csv"),
OutputFormat::Parquet => write!(f, "parquet"),
}
}
}
/// All settings that control a [`TpchGenerator`] run.
#[derive(Debug, Clone)]
pub struct GeneratorConfig {
/// Scale factor passed to the output planner (default `1.0`).
pub scale_factor: f64,
/// Directory passed to the output planner; created by `generate` unless `stdout` is set.
pub output_dir: std::path::PathBuf,
/// Tables to generate; `None` means all eight tables.
pub tables: Option<Vec<Table>>,
/// Output format (default `Tbl`).
pub format: OutputFormat,
/// Worker count handed to the plan runner (default `num_cpus::get()`).
pub num_threads: usize,
/// Parquet compression codec (default `SNAPPY`).
pub parquet_compression: Compression,
/// Parquet row-group size setting, presumably in bytes given the name
/// (default `DEFAULT_PARQUET_ROW_GROUP_BYTES`).
pub parquet_row_group_bytes: i64,
/// Total part count, forwarded with `part` to the output planner.
pub parts: Option<i32>,
/// Selected part, forwarded with `parts` to the output planner.
pub part: Option<i32>,
/// Stdout-output flag passed to the planner; when `true`, `output_dir` is not created.
pub stdout: bool,
/// CSV delimiter passed to the output planner (default `','`).
pub csv_delimiter: char,
}
impl Default for GeneratorConfig {
    /// Baseline configuration:
    /// - scale factor `1.0`;
    /// - `tbl` output into the current directory;
    /// - every table;
    /// - as many threads as `num_cpus::get()` reports;
    /// - Snappy Parquet compression with the default row-group size;
    /// - no partitioning and no stdout output;
    /// - `,` as the CSV delimiter.
    fn default() -> Self {
        let num_threads = num_cpus::get();
        Self {
            scale_factor: 1.0,
            output_dir: ".".into(),
            tables: None,
            format: OutputFormat::Tbl,
            num_threads,
            parquet_compression: Compression::SNAPPY,
            parquet_row_group_bytes: DEFAULT_PARQUET_ROW_GROUP_BYTES,
            parts: None,
            part: None,
            stdout: false,
            csv_delimiter: ',',
        }
    }
}
/// A configured generator, produced by [`TpchGeneratorBuilder::build`]; call
/// `generate` to run it.
pub struct TpchGenerator {
/// Settings that drive planning and execution.
config: GeneratorConfig,
/// Optional tracker handed to the plan runner.
progress_tracker: Option<Arc<dyn ProgressTracker>>,
}
impl TpchGenerator {
pub fn builder() -> TpchGeneratorBuilder {
TpchGeneratorBuilder::new()
}
pub async fn generate(self) -> io::Result<()> {
let config = self.config;
let progress_tracker = self.progress_tracker;
if !config.stdout {
std::fs::create_dir_all(&config.output_dir)?;
}
let tables: Vec<Table> = if let Some(tables) = config.tables {
tables
} else {
vec![
Table::Nation,
Table::Region,
Table::Part,
Table::Supplier,
Table::Partsupp,
Table::Customer,
Table::Orders,
Table::Lineitem,
]
};
let mut output_plan_generator = OutputPlanGenerator::new(
config.format,
config.scale_factor,
config.parquet_compression,
config.parquet_row_group_bytes,
config.stdout,
config.output_dir,
config.csv_delimiter,
);
for table in tables {
output_plan_generator.generate_plans(table, config.part, config.parts)?;
}
let output_plans = output_plan_generator.build();
let start = Instant::now();
Distributions::static_default();
TextPool::get_or_init_default();
let elapsed = start.elapsed();
info!("Created static distributions and text pools in {elapsed:?}");
let runner = PlanRunner::new(output_plans, config.num_threads);
let runner = if let Some(tracker) = progress_tracker {
runner.with_progress_tracker(tracker)
} else {
runner
};
runner.run().await?;
info!("Generation complete!");
Ok(())
}
}
/// Fluent builder for [`TpchGenerator`]; starts from `GeneratorConfig::default()`.
#[derive(Debug, Clone)]
pub struct TpchGeneratorBuilder {
/// Accumulated settings, moved into the generator by `build`.
config: GeneratorConfig,
/// Tracker to attach to the generator, if any.
progress_tracker: Option<Arc<dyn ProgressTracker>>,
}
impl TpchGeneratorBuilder {
    /// Creates a builder that starts from [`GeneratorConfig::default`] with no
    /// progress tracker attached.
    #[must_use]
    pub fn new() -> Self {
        Self {
            config: GeneratorConfig::default(),
            progress_tracker: None,
        }
    }

    /// Returns the scale factor currently configured on this builder.
    #[must_use]
    pub fn scale_factor(&self) -> f64 {
        self.config.scale_factor
    }

    /// Sets the scale factor (default `1.0`).
    ///
    /// Every `with_*` method consumes and returns the builder; `#[must_use]`
    /// flags call sites that drop the result and so lose the setting.
    #[must_use]
    pub fn with_scale_factor(mut self, scale_factor: f64) -> Self {
        self.config.scale_factor = scale_factor;
        self
    }

    /// Sets the output directory (default `"."`). `generate` creates it unless
    /// stdout output is enabled.
    #[must_use]
    pub fn with_output_dir(mut self, output_dir: impl Into<std::path::PathBuf>) -> Self {
        self.config.output_dir = output_dir.into();
        self
    }

    /// Restricts generation to `tables`; if never called, all eight tables are generated.
    #[must_use]
    pub fn with_tables(mut self, tables: Vec<Table>) -> Self {
        self.config.tables = Some(tables);
        self
    }

    /// Sets the output format (default [`OutputFormat::Tbl`]).
    #[must_use]
    pub fn with_format(mut self, format: OutputFormat) -> Self {
        self.config.format = format;
        self
    }

    /// Sets the worker count handed to the plan runner (default `num_cpus::get()`).
    #[must_use]
    pub fn with_num_threads(mut self, num_threads: usize) -> Self {
        self.config.num_threads = num_threads;
        self
    }

    /// Sets the Parquet compression codec (default [`Compression::SNAPPY`]).
    #[must_use]
    pub fn with_parquet_compression(mut self, compression: Compression) -> Self {
        self.config.parquet_compression = compression;
        self
    }

    /// Sets the Parquet row-group size in bytes (default
    /// `DEFAULT_PARQUET_ROW_GROUP_BYTES`).
    #[must_use]
    pub fn with_parquet_row_group_bytes(mut self, bytes: i64) -> Self {
        self.config.parquet_row_group_bytes = bytes;
        self
    }

    /// Sets the total part count; forwarded together with `part` to the output planner.
    #[must_use]
    pub fn with_parts(mut self, parts: i32) -> Self {
        self.config.parts = Some(parts);
        self
    }

    /// Selects a single part; forwarded together with `parts` to the output planner.
    #[must_use]
    pub fn with_part(mut self, part: i32) -> Self {
        self.config.part = Some(part);
        self
    }

    /// Enables or disables stdout output (default `false`). When enabled,
    /// `generate` does not create the output directory.
    #[must_use]
    pub fn with_stdout(mut self, stdout: bool) -> Self {
        self.config.stdout = stdout;
        self
    }

    /// Sets the CSV field delimiter (default `','`).
    #[must_use]
    pub fn with_csv_delimiter(mut self, delimiter: char) -> Self {
        self.config.csv_delimiter = delimiter;
        self
    }

    /// Attaches a tracker that the plan runner will report progress to.
    #[must_use]
    pub fn with_progress_tracker(mut self, tracker: Arc<dyn ProgressTracker>) -> Self {
        self.progress_tracker = Some(tracker);
        self
    }

    /// Consumes the builder and returns the configured [`TpchGenerator`].
    #[must_use]
    pub fn build(self) -> TpchGenerator {
        TpchGenerator {
            config: self.config,
            progress_tracker: self.progress_tracker,
        }
    }
}
impl Default for TpchGeneratorBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tpch_cli::progress::ProgressTracker;
    use std::sync::{
        atomic::{AtomicU64, Ordering},
        Arc, Mutex,
    };

    /// A [`ProgressTracker`] that records every callback it receives, so a test
    /// can assert exactly what the runner reported.
    #[derive(Debug, Default)]
    struct RecordingProgress {
        /// `(table, total_units)` pairs passed to `register`, in call order.
        registered: Mutex<Vec<(Table, u64)>>,
        /// `(table, units)` pairs passed to `increment`, in call order.
        increments: Mutex<Vec<(Table, u64)>>,
        /// Number of `finish` calls.
        finishes: AtomicU64,
    }

    impl RecordingProgress {
        /// Returns a copy of one of the recorded call logs.
        fn recorded(log: &Mutex<Vec<(Table, u64)>>) -> Vec<(Table, u64)> {
            log.lock().unwrap().clone()
        }
    }

    impl ProgressTracker for RecordingProgress {
        fn register(&self, table: Table, total_units: u64) {
            let mut log = self.registered.lock().unwrap();
            log.push((table, total_units));
        }

        fn increment(&self, table: Table, units: u64) {
            let mut log = self.increments.lock().unwrap();
            log.push((table, units));
        }

        fn finish(&self) {
            self.finishes.fetch_add(1, Ordering::Relaxed);
        }
    }

    #[tokio::test]
    async fn builder_passes_custom_progress_tracker_to_runner() {
        let output_dir = tempfile::tempdir().unwrap();
        let recorder = Arc::new(RecordingProgress::default());
        let tracker: Arc<dyn ProgressTracker> = Arc::clone(&recorder);

        let generator = TpchGenerator::builder()
            .with_output_dir(output_dir.path())
            .with_tables(vec![Table::Region])
            .with_num_threads(1)
            .with_progress_tracker(tracker)
            .build();
        generator.generate().await.unwrap();

        let expected: Vec<(Table, u64)> = vec![(Table::Region, 1)];
        assert_eq!(RecordingProgress::recorded(&recorder.registered), expected);
        assert_eq!(RecordingProgress::recorded(&recorder.increments), expected);
        assert_eq!(recorder.finishes.load(Ordering::Relaxed), 1);
    }
}