Skip to main content

datafusion_dft/
tpch.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::sync::Arc;
19
20use crate::args::TpchFormat;
21use crate::config::AppConfig;
22use color_eyre::{eyre, Result};
23use datafusion::{arrow::record_batch::RecordBatch, datasource::listing::ListingTableUrl};
24use datafusion_app::{
25    config::merge_configs, extensions::DftSessionStateBuilder, local::ExecutionContext,
26};
27use log::info;
28use object_store::ObjectStore;
29use parquet::arrow::ArrowWriter;
30use tpchgen::generators::{
31    CustomerGenerator, LineItemGenerator, NationGenerator, OrderGenerator, PartGenerator,
32    PartSuppGenerator, RegionGenerator, SupplierGenerator,
33};
34use tpchgen_arrow::{
35    CustomerArrow, LineItemArrow, NationArrow, OrderArrow, PartArrow, PartSuppArrow, RegionArrow,
36    SupplierArrow,
37};
38use url::Url;
39
40#[cfg(feature = "vortex")]
41use {
42    datafusion::arrow::compute::concat_batches,
43    vortex::array::{arrow::FromArrowArray, ArrayRef},
44    vortex_file::VortexWriteOptions,
45    vortex_session::VortexSession,
46};
47
/// Identifies which TPC-H table a generation step produces.
///
/// Each variant is paired with a directory name (with a trailing `/`) by the
/// `TryFrom<&str>` implementation for this type.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum GeneratorType {
    /// Table stored under `customer/`.
    Customer,
    /// Table stored under `orders/`.
    Order,
    /// Table stored under `lineitem/`.
    LineItem,
    /// Table stored under `nation/`.
    Nation,
    /// Table stored under `part/`.
    Part,
    /// Table stored under `partsupp/`.
    PartSupp,
    /// Table stored under `region/`.
    Region,
    /// Table stored under `supplier/`.
    Supplier,
}
58
59impl TryFrom<&str> for GeneratorType {
60    type Error = color_eyre::Report;
61
62    fn try_from(value: &str) -> std::result::Result<Self, Self::Error> {
63        // `/` suffix is used so that the final path part is interpretted as a directory
64        match value {
65            "customer/" => Ok(Self::Customer),
66            "orders/" => Ok(Self::Order),
67            "lineitem/" => Ok(Self::LineItem),
68            "nation/" => Ok(Self::Nation),
69            "part/" => Ok(Self::Part),
70            "partsupp/" => Ok(Self::PartSupp),
71            "region/" => Ok(Self::Region),
72            "supplier/" => Ok(Self::Supplier),
73            _ => Err(eyre::Report::msg(format!("unknown generator type {value}"))),
74        }
75    }
76}
77
78fn create_tpch_dirs(config: &AppConfig) -> Result<Vec<(GeneratorType, Url)>> {
79    info!("...configured DB directory is {:?}", config.db.path);
80    // `/` suffix is used so that the final path part is interpretted as a directory
81    let tpch_dir = config
82        .db
83        .path
84        .join("tables/")?
85        .join("dft/")?
86        .join("tpch/")?;
87    let needed_dirs = [
88        "customer/",
89        "orders/",
90        "lineitem/",
91        "nation/",
92        "part/",
93        "partsupp/",
94        "region/",
95        "supplier/",
96    ];
97    let mut table_paths = Vec::new();
98    for dir in needed_dirs {
99        let table_path = tpch_dir.join(dir)?;
100        info!("table path {:?} for table {dir}", table_path.path());
101        table_paths.push((GeneratorType::try_from(dir)?, table_path))
102    }
103    Ok(table_paths)
104}
105
106async fn write_batches_to_parquet<I>(
107    mut batches: std::iter::Peekable<I>,
108    table_path: &Url,
109    table_type: &str,
110    store: Arc<dyn ObjectStore>,
111) -> Result<()>
112where
113    I: Iterator<Item = RecordBatch>,
114{
115    let first = batches.peek().ok_or(eyre::Error::msg(format!(
116        "unable to generate {table_type} TPC-H data"
117    )))?;
118
119    let file_url = table_path.join("data.parquet")?;
120    info!("...file URL '{file_url}'");
121    let mut buf: Vec<u8> = Vec::new();
122    {
123        let mut writer = ArrowWriter::try_new(&mut buf, Arc::clone(first.schema_ref()), None)?;
124        info!("...writing {table_type} batches");
125        for batch in batches {
126            writer.write(&batch)?;
127        }
128        writer.finish()?;
129    }
130    let file_path = object_store::path::Path::from_url_path(file_url.path())?;
131    info!("...putting to file path {}", file_path);
132    store.put(&file_path, buf.into()).await?;
133    Ok(())
134}
135
/// Encodes `batches` as one Vortex file held in memory and uploads it to
/// `store` as `data.vortex` resolved against `table_path`.
///
/// All batches are collected and concatenated into a single record batch
/// before conversion, so the whole table is held in memory at once.
///
/// # Errors
///
/// Returns an error if `batches` yields nothing, if the file URL or object
/// store path cannot be built, if concatenation or the Vortex write fails, or
/// if the upload fails.
#[cfg(feature = "vortex")]
async fn write_batches_to_vortex<I>(
    batches: std::iter::Peekable<I>,
    table_path: &Url,
    table_type: &str,
    store: Arc<dyn ObjectStore>,
) -> Result<()>
where
    I: Iterator<Item = RecordBatch>,
{
    let collected: Vec<RecordBatch> = batches.collect();

    // An empty iterator means the generator produced no data for this table.
    let Some(first) = collected.first() else {
        return Err(eyre::Error::msg(format!(
            "unable to generate {table_type} TPC-H data"
        )));
    };

    let file_url = table_path.join("data.vortex")?;
    info!("...file URL '{file_url}'");

    // Merge every batch into one, using the first batch's schema
    let schema = first.schema();
    let merged = concat_batches(&schema, &collected)?;

    // Arrow -> Vortex array -> array stream.
    // NOTE(review): the `false` argument is presumably the nullability flag — confirm
    // against the vortex `FromArrowArray` documentation.
    let vortex_array = ArrayRef::from_arrow(merged, false);
    let stream = vortex_array.to_array_stream();

    // Serialize into an in-memory buffer before uploading
    let mut buf: Vec<u8> = Vec::new();
    info!("...writing {table_type} batches to vortex format");
    let session = VortexSession::empty();
    VortexWriteOptions::new(session)
        .write(&mut buf, stream)
        .await
        .map_err(|e| eyre::Error::msg(format!("Failed to write Vortex file: {}", e)))?;

    let file_path = object_store::path::Path::from_url_path(file_url.path())?;
    info!("...putting to file path {}", file_path);
    store.put(&file_path, buf.into()).await?;
    Ok(())
}
181
182async fn write_batches<I>(
183    batches: std::iter::Peekable<I>,
184    table_path: &Url,
185    table_type: &str,
186    store: Arc<dyn ObjectStore>,
187    format: &TpchFormat,
188) -> Result<()>
189where
190    I: Iterator<Item = RecordBatch>,
191{
192    match format {
193        TpchFormat::Parquet => {
194            write_batches_to_parquet(batches, table_path, table_type, store).await
195        }
196        #[cfg(feature = "vortex")]
197        TpchFormat::Vortex => write_batches_to_vortex(batches, table_path, table_type, store).await,
198    }
199}
200
201pub async fn generate(config: AppConfig, scale_factor: f64, format: TpchFormat) -> Result<()> {
202    let merged_exec_config = merge_configs(config.shared.clone(), config.cli.execution.clone());
203    let session_state_builder = DftSessionStateBuilder::try_new(Some(merged_exec_config.clone()))?
204        .with_extensions()
205        .await?;
206
207    let session_state = session_state_builder.build()?;
208    let execution_ctx = ExecutionContext::try_new(
209        &merged_exec_config,
210        session_state,
211        crate::APP_NAME,
212        env!("CARGO_PKG_VERSION"),
213    )?;
214
215    let tables_path = config.db.path.join("tables")?;
216    let tables_url = ListingTableUrl::parse(tables_path)?;
217    let store_url = tables_url.object_store();
218    let store = execution_ctx
219        .session_ctx()
220        .runtime_env()
221        .object_store(store_url)?;
222    info!("configured db store: {store:?}");
223    info!("generating TPC-H data");
224    let table_paths = create_tpch_dirs(&config)?;
225    for (table, table_path) in table_paths {
226        match table {
227            GeneratorType::Customer => {
228                info!("...generating customers");
229                let arrow_generator =
230                    CustomerArrow::new(CustomerGenerator::new(scale_factor, 1, 1));
231                write_batches(
232                    arrow_generator.peekable(),
233                    &table_path,
234                    "Customer",
235                    Arc::clone(&store),
236                    &format,
237                )
238                .await?;
239            }
240            GeneratorType::Order => {
241                info!("...generating orders");
242                let arrow_generator = OrderArrow::new(OrderGenerator::new(scale_factor, 1, 1));
243                write_batches(
244                    arrow_generator.peekable(),
245                    &table_path,
246                    "Order",
247                    Arc::clone(&store),
248                    &format,
249                )
250                .await?;
251            }
252            GeneratorType::LineItem => {
253                info!("...generating LineItems");
254                let arrow_generator =
255                    LineItemArrow::new(LineItemGenerator::new(scale_factor, 1, 1));
256                write_batches(
257                    arrow_generator.peekable(),
258                    &table_path,
259                    "LineItem",
260                    Arc::clone(&store),
261                    &format,
262                )
263                .await?;
264            }
265            GeneratorType::Nation => {
266                info!("...generating Nations");
267                let arrow_generator = NationArrow::new(NationGenerator::new(scale_factor, 1, 1));
268                write_batches(
269                    arrow_generator.peekable(),
270                    &table_path,
271                    "Nation",
272                    Arc::clone(&store),
273                    &format,
274                )
275                .await?;
276            }
277            GeneratorType::Part => {
278                info!("...generating Parts");
279                let arrow_generator = PartArrow::new(PartGenerator::new(scale_factor, 1, 1));
280                write_batches(
281                    arrow_generator.peekable(),
282                    &table_path,
283                    "Part",
284                    Arc::clone(&store),
285                    &format,
286                )
287                .await?;
288            }
289            GeneratorType::PartSupp => {
290                info!("...generating PartSupps");
291                let arrow_generator =
292                    PartSuppArrow::new(PartSuppGenerator::new(scale_factor, 1, 1));
293                write_batches(
294                    arrow_generator.peekable(),
295                    &table_path,
296                    "PartSupp",
297                    Arc::clone(&store),
298                    &format,
299                )
300                .await?;
301            }
302            GeneratorType::Region => {
303                info!("...generating Regions");
304                let arrow_generator = RegionArrow::new(RegionGenerator::new(scale_factor, 1, 1));
305                write_batches(
306                    arrow_generator.peekable(),
307                    &table_path,
308                    "Region",
309                    Arc::clone(&store),
310                    &format,
311                )
312                .await?;
313            }
314            GeneratorType::Supplier => {
315                info!("...generating Suppliers");
316                let arrow_generator =
317                    SupplierArrow::new(SupplierGenerator::new(scale_factor, 1, 1));
318                write_batches(
319                    arrow_generator.peekable(),
320                    &table_path,
321                    "Supplier",
322                    Arc::clone(&store),
323                    &format,
324                )
325                .await?;
326            }
327        }
328    }
329
330    let tpch_dir = config
331        .db
332        .path
333        .join("tables/")?
334        .join("dft/")?
335        .join("tpch/")?;
336    println!("TPC-H dataset saved to: {}", tpch_dir);
337
338    Ok(())
339}