use std::sync::Arc;
use crate::args::TpchFormat;
use crate::config::AppConfig;
use color_eyre::{eyre, Result};
use datafusion::{arrow::record_batch::RecordBatch, datasource::listing::ListingTableUrl};
use datafusion_app::{
config::merge_configs, extensions::DftSessionStateBuilder, local::ExecutionContext,
};
use log::info;
use object_store::ObjectStore;
use parquet::arrow::ArrowWriter;
use tpchgen::generators::{
CustomerGenerator, LineItemGenerator, NationGenerator, OrderGenerator, PartGenerator,
PartSuppGenerator, RegionGenerator, SupplierGenerator,
};
use tpchgen_arrow::{
CustomerArrow, LineItemArrow, NationArrow, OrderArrow, PartArrow, PartSuppArrow, RegionArrow,
SupplierArrow,
};
use url::Url;
#[cfg(feature = "vortex")]
use {
datafusion::arrow::compute::concat_batches,
vortex::array::{arrow::FromArrowArray, ArrayRef},
vortex_file::VortexWriteOptions,
vortex_session::VortexSession,
};
/// Identifies one of the eight TPC-H tables that can be generated.
///
/// Each variant corresponds to one table directory name (see the
/// `TryFrom<&str>` implementation for the exact directory strings).
///
/// The derives make the type loggable with `{:?}`, comparable, and cheap to
/// pass by value; it carries no data, so `Copy` is free.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum GeneratorType {
    /// The `customer` table.
    Customer,
    /// The `orders` table.
    Order,
    /// The `lineitem` table.
    LineItem,
    /// The `nation` table.
    Nation,
    /// The `part` table.
    Part,
    /// The `partsupp` table.
    PartSupp,
    /// The `region` table.
    Region,
    /// The `supplier` table.
    Supplier,
}
/// Maps a table directory name (including its trailing `/`) to the
/// corresponding [`GeneratorType`].
impl TryFrom<&str> for GeneratorType {
    type Error = color_eyre::Report;

    /// Returns the generator for `value`, or an error naming the unrecognised
    /// string when it is not one of the eight known table directories.
    fn try_from(value: &str) -> std::result::Result<Self, Self::Error> {
        let generator = match value {
            "customer/" => Self::Customer,
            "orders/" => Self::Order,
            "lineitem/" => Self::LineItem,
            "nation/" => Self::Nation,
            "part/" => Self::Part,
            "partsupp/" => Self::PartSupp,
            "region/" => Self::Region,
            "supplier/" => Self::Supplier,
            _ => {
                return Err(eyre::Report::msg(format!("unknown generator type {value}")));
            }
        };
        Ok(generator)
    }
}
/// Derives the per-table URLs for the TPC-H dataset.
///
/// The URLs are built as `<config.db.path>/tables/dft/tpch/<table>/` for each
/// of the eight tables. Despite the name, the code visible here only joins
/// URLs and logs them; it performs no storage operations itself.
///
/// # Errors
///
/// Returns an error if any URL join fails or a directory name cannot be
/// converted into a [`GeneratorType`].
fn create_tpch_dirs(config: &AppConfig) -> Result<Vec<(GeneratorType, Url)>> {
    info!("...configured DB directory is {:?}", config.db.path);

    // Each segment carries a trailing slash so that subsequent `join` calls
    // append to it rather than replacing the last path segment.
    let tpch_root = config
        .db
        .path
        .join("tables/")?
        .join("dft/")?
        .join("tpch/")?;

    let table_dirs = [
        "customer/",
        "orders/",
        "lineitem/",
        "nation/",
        "part/",
        "partsupp/",
        "region/",
        "supplier/",
    ];

    // Collecting into `Result` stops at the first failure, like `?` in a loop.
    table_dirs
        .iter()
        .copied()
        .map(|dir| -> Result<(GeneratorType, Url)> {
            let table_url = tpch_root.join(dir)?;
            info!("table path {:?} for table {dir}", table_url.path());
            Ok((GeneratorType::try_from(dir)?, table_url))
        })
        .collect()
}
async fn write_batches_to_parquet<I>(
mut batches: std::iter::Peekable<I>,
table_path: &Url,
table_type: &str,
store: Arc<dyn ObjectStore>,
) -> Result<()>
where
I: Iterator<Item = RecordBatch>,
{
let first = batches.peek().ok_or(eyre::Error::msg(format!(
"unable to generate {table_type} TPC-H data"
)))?;
let file_url = table_path.join("data.parquet")?;
info!("...file URL '{file_url}'");
let mut buf: Vec<u8> = Vec::new();
{
let mut writer = ArrowWriter::try_new(&mut buf, Arc::clone(first.schema_ref()), None)?;
info!("...writing {table_type} batches");
for batch in batches {
writer.write(&batch)?;
}
writer.finish()?;
}
let file_path = object_store::path::Path::from_url_path(file_url.path())?;
info!("...putting to file path {}", file_path);
store.put(&file_path, buf.into()).await?;
Ok(())
}
/// Encodes all batches as a single Vortex file in memory and uploads it as
/// `data.vortex` beneath `table_path`.
///
/// Only compiled when the `vortex` feature is enabled. Every batch is
/// collected and concatenated into one record batch before conversion.
///
/// # Errors
///
/// Fails if `batches` yields nothing, if the file URL cannot be built, if
/// concatenation or the Vortex write fails, or if the upload fails.
#[cfg(feature = "vortex")]
async fn write_batches_to_vortex<I>(
    batches: std::iter::Peekable<I>,
    table_path: &Url,
    table_type: &str,
    store: Arc<dyn ObjectStore>,
) -> Result<()>
where
    I: Iterator<Item = RecordBatch>,
{
    let collected = batches.collect::<Vec<RecordBatch>>();

    // The first batch doubles as the emptiness check and the schema source.
    let schema = match collected.first() {
        Some(first) => first.schema(),
        None => {
            return Err(eyre::Error::msg(format!(
                "unable to generate {table_type} TPC-H data"
            )));
        }
    };

    let file_url = table_path.join("data.vortex")?;
    info!("...file URL '{file_url}'");

    let merged = concat_batches(&schema, &collected)?;
    let array = ArrayRef::from_arrow(merged, false);
    let stream = array.to_array_stream();

    let mut buf: Vec<u8> = Vec::new();
    info!("...writing {table_type} batches to vortex format");
    let session = VortexSession::empty();
    let written = VortexWriteOptions::new(session)
        .write(&mut buf, stream)
        .await;
    written.map_err(|e| eyre::Error::msg(format!("Failed to write Vortex file: {}", e)))?;

    let file_path = object_store::path::Path::from_url_path(file_url.path())?;
    info!("...putting to file path {}", file_path);
    store.put(&file_path, buf.into()).await?;
    Ok(())
}
/// Writes `batches` for one table using the writer selected by `format`.
///
/// All other arguments are forwarded unchanged to the format-specific writer.
/// The Vortex arm exists only when the `vortex` feature is enabled.
///
/// # Errors
///
/// Propagates whatever error the selected writer returns.
async fn write_batches<I>(
    batches: std::iter::Peekable<I>,
    table_path: &Url,
    table_type: &str,
    store: Arc<dyn ObjectStore>,
    format: &TpchFormat,
) -> Result<()>
where
    I: Iterator<Item = RecordBatch>,
{
    match *format {
        TpchFormat::Parquet => {
            write_batches_to_parquet(batches, table_path, table_type, store).await?
        }
        #[cfg(feature = "vortex")]
        TpchFormat::Vortex => {
            write_batches_to_vortex(batches, table_path, table_type, store).await?
        }
    }
    Ok(())
}
pub async fn generate(config: AppConfig, scale_factor: f64, format: TpchFormat) -> Result<()> {
let merged_exec_config = merge_configs(config.shared.clone(), config.cli.execution.clone());
let session_state_builder = DftSessionStateBuilder::try_new(Some(merged_exec_config.clone()))?
.with_extensions()
.await?;
let session_state = session_state_builder.build()?;
let execution_ctx = ExecutionContext::try_new(
&merged_exec_config,
session_state,
crate::APP_NAME,
env!("CARGO_PKG_VERSION"),
)?;
let tables_path = config.db.path.join("tables")?;
let tables_url = ListingTableUrl::parse(tables_path)?;
let store_url = tables_url.object_store();
let store = execution_ctx
.session_ctx()
.runtime_env()
.object_store(store_url)?;
info!("configured db store: {store:?}");
info!("generating TPC-H data");
let table_paths = create_tpch_dirs(&config)?;
for (table, table_path) in table_paths {
match table {
GeneratorType::Customer => {
info!("...generating customers");
let arrow_generator =
CustomerArrow::new(CustomerGenerator::new(scale_factor, 1, 1));
write_batches(
arrow_generator.peekable(),
&table_path,
"Customer",
Arc::clone(&store),
&format,
)
.await?;
}
GeneratorType::Order => {
info!("...generating orders");
let arrow_generator = OrderArrow::new(OrderGenerator::new(scale_factor, 1, 1));
write_batches(
arrow_generator.peekable(),
&table_path,
"Order",
Arc::clone(&store),
&format,
)
.await?;
}
GeneratorType::LineItem => {
info!("...generating LineItems");
let arrow_generator =
LineItemArrow::new(LineItemGenerator::new(scale_factor, 1, 1));
write_batches(
arrow_generator.peekable(),
&table_path,
"LineItem",
Arc::clone(&store),
&format,
)
.await?;
}
GeneratorType::Nation => {
info!("...generating Nations");
let arrow_generator = NationArrow::new(NationGenerator::new(scale_factor, 1, 1));
write_batches(
arrow_generator.peekable(),
&table_path,
"Nation",
Arc::clone(&store),
&format,
)
.await?;
}
GeneratorType::Part => {
info!("...generating Parts");
let arrow_generator = PartArrow::new(PartGenerator::new(scale_factor, 1, 1));
write_batches(
arrow_generator.peekable(),
&table_path,
"Part",
Arc::clone(&store),
&format,
)
.await?;
}
GeneratorType::PartSupp => {
info!("...generating PartSupps");
let arrow_generator =
PartSuppArrow::new(PartSuppGenerator::new(scale_factor, 1, 1));
write_batches(
arrow_generator.peekable(),
&table_path,
"PartSupp",
Arc::clone(&store),
&format,
)
.await?;
}
GeneratorType::Region => {
info!("...generating Regions");
let arrow_generator = RegionArrow::new(RegionGenerator::new(scale_factor, 1, 1));
write_batches(
arrow_generator.peekable(),
&table_path,
"Region",
Arc::clone(&store),
&format,
)
.await?;
}
GeneratorType::Supplier => {
info!("...generating Suppliers");
let arrow_generator =
SupplierArrow::new(SupplierGenerator::new(scale_factor, 1, 1));
write_batches(
arrow_generator.peekable(),
&table_path,
"Supplier",
Arc::clone(&store),
&format,
)
.await?;
}
}
}
let tpch_dir = config
.db
.path
.join("tables/")?
.join("dft/")?
.join("tpch/")?;
println!("TPC-H dataset saved to: {}", tpch_dir);
Ok(())
}