1use crate::csv::*;
4use crate::generate::{generate_in_chunks, Source};
5use crate::output_plan::{OutputLocation, OutputPlan};
6use crate::parquet::generate_parquet;
7use crate::tbl::*;
8use crate::tbl::{LineItemTblSource, NationTblSource, RegionTblSource};
9use crate::{OutputFormat, Table, WriterSink};
10use log::{debug, info};
11use std::io;
12use std::io::BufWriter;
13use tokio::task::{JoinError, JoinSet};
14use tpchgen::generators::{
15 CustomerGenerator, LineItemGenerator, NationGenerator, OrderGenerator, PartGenerator,
16 PartSuppGenerator, RegionGenerator, SupplierGenerator,
17};
18use tpchgen_arrow::{
19 CustomerArrow, LineItemArrow, NationArrow, OrderArrow, PartArrow, PartSuppArrow,
20 RecordBatchIterator, RegionArrow, SupplierArrow,
21};
22
/// Executes a set of [`OutputPlan`]s, keeping at most `num_threads` worker
/// threads busy across all concurrently running plans.
#[derive(Debug)]
pub struct PlanRunner {
    // Plans still waiting to be scheduled.
    plans: Vec<OutputPlan>,
    // Total thread budget shared by all plans.
    num_threads: usize,
}
30
31impl PlanRunner {
32 pub fn new(plans: Vec<OutputPlan>, num_threads: usize) -> Self {
34 Self { plans, num_threads }
35 }
36
37 pub async fn run(self) -> Result<(), io::Error> {
39 debug!(
40 "Running {} plans with {} threads...",
41 self.plans.len(),
42 self.num_threads
43 );
44 let Self {
45 mut plans,
46 num_threads,
47 } = self;
48
49 plans.sort_unstable_by(|a, b| {
51 let a_cnt = a.chunk_count();
52 let b_cnt = b.chunk_count();
53 a_cnt.cmp(&b_cnt)
54 });
55
56 let mut worker_queue = WorkerQueue::new(num_threads);
58 while let Some(plan) = plans.pop() {
59 worker_queue.schedule_plan(plan).await?;
60 }
61 worker_queue.join_all().await
62 }
63}
64
/// Tracks in-flight generation tasks and the thread budget still available.
///
/// Each spawned task resolves to the number of threads it was granted; that
/// count is returned to `available_threads` when the task is joined.
struct WorkerQueue {
    // Handles to spawned per-plan tasks; each yields its thread count back.
    join_set: JoinSet<io::Result<usize>>,
    // Threads not currently assigned to a running task.
    available_threads: usize,
}
86
87impl WorkerQueue {
88 pub fn new(max_threads: usize) -> Self {
89 assert!(max_threads > 0);
90 Self {
91 join_set: JoinSet::new(),
92 available_threads: max_threads,
93 }
94 }
95
96 pub async fn schedule_plan(&mut self, plan: OutputPlan) -> io::Result<()> {
106 debug!("scheduling plan {plan}");
107 loop {
108 if self.available_threads == 0 {
109 debug!("no threads left, wait for one to finish");
110 let Some(result) = self.join_set.join_next().await else {
111 return Err(io::Error::other(
112 "Internal Error No more tasks to wait for, but had no threads",
113 ));
114 };
115 self.available_threads += task_result(result)?;
116 continue; }
118
119 if let Some(result) = self.join_set.try_join_next() {
121 self.available_threads += task_result(result)?;
122 continue;
123 }
124
125 debug_assert!(
126 self.available_threads > 0,
127 "should have at least one thread to continue"
128 );
129
130 let chunk_count = plan.chunk_count();
133
134 let num_plan_threads = self.available_threads.min(chunk_count);
135
136 debug!("Spawning plan {plan} with {num_plan_threads} threads");
138
139 self.join_set
140 .spawn(async move { run_plan(plan, num_plan_threads).await });
141 self.available_threads -= num_plan_threads;
142 return Ok(());
143 }
144 }
145
146 pub async fn join_all(mut self) -> io::Result<()> {
148 debug!("Waiting for tasks to finish...");
149 while let Some(result) = self.join_set.join_next().await {
150 task_result(result)?;
151 }
152 debug!("Tasks finished.");
153 Ok(())
154 }
155}
156
157fn task_result<T>(result: Result<io::Result<T>, JoinError>) -> io::Result<T> {
159 result.map_err(|e| io::Error::other(format!("Task Panic: {e}")))?
160}
161
162async fn run_plan(plan: OutputPlan, num_threads: usize) -> io::Result<usize> {
164 match plan.table() {
165 Table::Nation => run_nation_plan(plan, num_threads).await,
166 Table::Region => run_region_plan(plan, num_threads).await,
167 Table::Part => run_part_plan(plan, num_threads).await,
168 Table::Supplier => run_supplier_plan(plan, num_threads).await,
169 Table::Partsupp => run_partsupp_plan(plan, num_threads).await,
170 Table::Customer => run_customer_plan(plan, num_threads).await,
171 Table::Orders => run_orders_plan(plan, num_threads).await,
172 Table::Lineitem => run_lineitem_plan(plan, num_threads).await,
173 }
174}
175
176async fn write_file<I>(plan: OutputPlan, num_threads: usize, sources: I) -> Result<(), io::Error>
178where
179 I: Iterator<Item: Source> + 'static,
180{
181 match plan.output_location() {
184 OutputLocation::Stdout => {
185 let sink = WriterSink::new(io::stdout());
186 generate_in_chunks(sink, sources, num_threads).await
187 }
188 OutputLocation::File(path) => {
189 if path.exists() {
191 info!("{} already exists, skipping generation", path.display());
192 return Ok(());
193 }
194 let temp_path = path.with_extension("inprogress");
196 let file = std::fs::File::create(&temp_path).map_err(|err| {
197 io::Error::other(format!("Failed to create {temp_path:?}: {err}"))
198 })?;
199 let sink = WriterSink::new(file);
200 generate_in_chunks(sink, sources, num_threads).await?;
201 std::fs::rename(&temp_path, path).map_err(|e| {
203 io::Error::other(format!(
204 "Failed to rename {temp_path:?} to {path:?} file: {e}"
205 ))
206 })?;
207 Ok(())
208 }
209 }
210}
211
212async fn write_parquet<I>(plan: OutputPlan, num_threads: usize, sources: I) -> Result<(), io::Error>
214where
215 I: Iterator<Item: RecordBatchIterator> + 'static,
216{
217 match plan.output_location() {
218 OutputLocation::Stdout => {
219 let writer = BufWriter::with_capacity(32 * 1024 * 1024, io::stdout()); generate_parquet(writer, sources, num_threads, plan.parquet_compression()).await
221 }
222 OutputLocation::File(path) => {
223 if path.exists() {
225 info!("{} already exists, skipping generation", path.display());
226 return Ok(());
227 }
228 let temp_path = path.with_extension("inprogress");
230 let file = std::fs::File::create(&temp_path).map_err(|err| {
231 io::Error::other(format!("Failed to create {temp_path:?}: {err}"))
232 })?;
233 let writer = BufWriter::with_capacity(32 * 1024 * 1024, file); generate_parquet(writer, sources, num_threads, plan.parquet_compression()).await?;
235 std::fs::rename(&temp_path, path).map_err(|e| {
237 io::Error::other(format!(
238 "Failed to rename {temp_path:?} to {path:?} file: {e}"
239 ))
240 })?;
241 Ok(())
242 }
243 }
244}
245
/// Generates an `async fn $FUN_NAME(plan, num_threads) -> io::Result<usize>`
/// that renders one TPC-H table in the plan's output format (tbl, csv, or
/// parquet) and returns `num_threads` so the scheduler can reclaim them.
///
/// `$GENERATOR` is the tpchgen row generator; `$TBL_SOURCE`, `$CSV_SOURCE`,
/// and `$PARQUET_SOURCE` adapt it to each output format.
macro_rules! define_run {
    ($FUN_NAME:ident, $GENERATOR:ident, $TBL_SOURCE:ty, $CSV_SOURCE:ty, $PARQUET_SOURCE:ty) => {
        async fn $FUN_NAME(plan: OutputPlan, num_threads: usize) -> io::Result<usize> {
            use crate::GenerationPlan;
            let scale_factor = plan.scale_factor();
            info!("Writing {plan} using {num_threads} threads");

            // The three helpers below are identical except for the source
            // adapter they map to; they are named functions (not closures) so
            // each has its own concrete `impl Iterator` return type.

            // One generator per (part, num_parts) chunk, wrapped for tbl output.
            fn tbl_sources(
                generation_plan: &GenerationPlan,
                scale_factor: f64,
            ) -> impl Iterator<Item: Source> + 'static {
                generation_plan
                    .clone()
                    .into_iter()
                    .map(move |(part, num_parts)| $GENERATOR::new(scale_factor, part, num_parts))
                    .map(<$TBL_SOURCE>::new)
            }

            // One generator per chunk, wrapped for csv output.
            fn csv_sources(
                generation_plan: &GenerationPlan,
                scale_factor: f64,
            ) -> impl Iterator<Item: Source> + 'static {
                generation_plan
                    .clone()
                    .into_iter()
                    .map(move |(part, num_parts)| $GENERATOR::new(scale_factor, part, num_parts))
                    .map(<$CSV_SOURCE>::new)
            }

            // One generator per chunk, wrapped as Arrow record batches for parquet.
            fn parquet_sources(
                generation_plan: &GenerationPlan,
                scale_factor: f64,
            ) -> impl Iterator<Item: RecordBatchIterator> + 'static {
                generation_plan
                    .clone()
                    .into_iter()
                    .map(move |(part, num_parts)| $GENERATOR::new(scale_factor, part, num_parts))
                    .map(<$PARQUET_SOURCE>::new)
            }

            match plan.output_format() {
                OutputFormat::Tbl => {
                    let gens = tbl_sources(plan.generation_plan(), scale_factor);
                    write_file(plan, num_threads, gens).await?
                }
                OutputFormat::Csv => {
                    let gens = csv_sources(plan.generation_plan(), scale_factor);
                    write_file(plan, num_threads, gens).await?
                }
                OutputFormat::Parquet => {
                    let gens = parquet_sources(plan.generation_plan(), scale_factor);
                    write_parquet(plan, num_threads, gens).await?
                }
            };
            // Report the threads this task held so the pool can reclaim them.
            Ok(num_threads)
        }
    };
}
321
// Per-table plan runners (`run_<table>_plan`) generated from `define_run!`,
// one for each of the eight TPC-H tables.
define_run!(
    run_lineitem_plan,
    LineItemGenerator,
    LineItemTblSource,
    LineItemCsvSource,
    LineItemArrow
);

define_run!(
    run_nation_plan,
    NationGenerator,
    NationTblSource,
    NationCsvSource,
    NationArrow
);

define_run!(
    run_region_plan,
    RegionGenerator,
    RegionTblSource,
    RegionCsvSource,
    RegionArrow
);

define_run!(
    run_part_plan,
    PartGenerator,
    PartTblSource,
    PartCsvSource,
    PartArrow
);

define_run!(
    run_supplier_plan,
    SupplierGenerator,
    SupplierTblSource,
    SupplierCsvSource,
    SupplierArrow
);
define_run!(
    run_partsupp_plan,
    PartSuppGenerator,
    PartSuppTblSource,
    PartSuppCsvSource,
    PartSuppArrow
);

define_run!(
    run_customer_plan,
    CustomerGenerator,
    CustomerTblSource,
    CustomerCsvSource,
    CustomerArrow
);

define_run!(
    run_orders_plan,
    OrderGenerator,
    OrderTblSource,
    OrderCsvSource,
    OrderArrow
);