tpchgen_cli/
runner.rs

1//! [`PlanRunner`] for running [`OutputPlan`]s.
2
3use crate::csv::*;
4use crate::generate::{generate_in_chunks, Source};
5use crate::output_plan::{OutputLocation, OutputPlan};
6use crate::parquet::generate_parquet;
7use crate::tbl::*;
8use crate::tbl::{LineItemTblSource, NationTblSource, RegionTblSource};
9use crate::{OutputFormat, Table, WriterSink};
10use log::{debug, info};
11use std::io;
12use std::io::BufWriter;
13use tokio::task::{JoinError, JoinSet};
14use tpchgen::generators::{
15    CustomerGenerator, LineItemGenerator, NationGenerator, OrderGenerator, PartGenerator,
16    PartSuppGenerator, RegionGenerator, SupplierGenerator,
17};
18use tpchgen_arrow::{
19    CustomerArrow, LineItemArrow, NationArrow, OrderArrow, PartArrow, PartSuppArrow,
20    RecordBatchIterator, RegionArrow, SupplierArrow,
21};
22
/// Runs multiple [`OutputPlan`]s in parallel, managing the number of threads
/// used to run them.
#[derive(Debug)]
pub struct PlanRunner {
    /// Plans to execute; drained (largest chunk-count first) when run.
    plans: Vec<OutputPlan>,
    /// Maximum number of worker threads shared across all plans.
    num_threads: usize,
}
30
31impl PlanRunner {
32    /// Create a new [`PlanRunner`] with the given plans and number of threads.
33    pub fn new(plans: Vec<OutputPlan>, num_threads: usize) -> Self {
34        Self { plans, num_threads }
35    }
36
37    /// Run all the plans in the runner.
38    pub async fn run(self) -> Result<(), io::Error> {
39        debug!(
40            "Running {} plans with {} threads...",
41            self.plans.len(),
42            self.num_threads
43        );
44        let Self {
45            mut plans,
46            num_threads,
47        } = self;
48
49        // Sort the plans by the number of parts so the largest are first
50        plans.sort_unstable_by(|a, b| {
51            let a_cnt = a.chunk_count();
52            let b_cnt = b.chunk_count();
53            a_cnt.cmp(&b_cnt)
54        });
55
56        // Do the actual work in parallel, using a worker queue
57        let mut worker_queue = WorkerQueue::new(num_threads);
58        while let Some(plan) = plans.pop() {
59            worker_queue.schedule_plan(plan).await?;
60        }
61        worker_queue.join_all().await
62    }
63}
64
/// Manages worker tasks, limiting the number of total outstanding threads
/// to some fixed number
///
/// The runner executes each plan with a number of threads equal to the
/// number of parts in the plan, but no more than the total number of
/// threads specified when creating the runner. If a plan does not need all
/// the threads, the remaining threads are used to run other plans.
///
/// This is important to keep all cores busy for smaller tables that may not
/// have sufficient parts to keep all threads busy (see [`GenerationPlan`]
/// for more details), but not schedule more tasks than we have threads for.
///
/// Scheduling too many tasks requires more memory and leads to context
/// switching overhead, which can slow down the generation process.
///
/// [`GenerationPlan`]: crate::plan::GenerationPlan
struct WorkerQueue {
    /// Outstanding plan tasks; each resolves to the number of threads it
    /// was allocated, so finished tasks return threads to the pool.
    join_set: JoinSet<io::Result<usize>>,
    /// Current number of threads available to commit
    available_threads: usize,
}
86
impl WorkerQueue {
    /// Create a queue with `max_threads` threads available to allocate.
    ///
    /// # Panics
    /// Panics if `max_threads` is zero (the queue could never schedule work).
    pub fn new(max_threads: usize) -> Self {
        assert!(max_threads > 0);
        Self {
            join_set: JoinSet::new(),
            available_threads: max_threads,
        }
    }

    /// Spawns a task to run the plan with as many threads as possible
    /// without exceeding the maximum number of threads.
    ///
    /// If there are no threads available, it will wait for one to finish
    /// before spawning the new task.
    ///
    /// Note this algorithm does not guarantee that all threads are always busy,
    /// but it should be good enough for most cases. For best thread utilization
    /// spawn the largest plans first.
    pub async fn schedule_plan(&mut self, plan: OutputPlan) -> io::Result<()> {
        debug!("scheduling plan {plan}");
        loop {
            if self.available_threads == 0 {
                debug!("no threads left, wait for one to finish");
                // Block until a running plan completes; an empty join set
                // here means threads were lost without a task to reclaim
                // them from, which is an internal bookkeeping bug.
                let Some(result) = self.join_set.join_next().await else {
                    return Err(io::Error::other(
                        "Internal Error No more tasks to wait for, but had no threads",
                    ));
                };
                // Each finished task yields the thread count it was given.
                self.available_threads += task_result(result)?;
                continue; // look for threads again
            }

            // Check for any other jobs done so we can reuse their threads
            // (non-blocking: try_join_next returns None when nothing is done)
            if let Some(result) = self.join_set.try_join_next() {
                self.available_threads += task_result(result)?;
                continue;
            }

            debug_assert!(
                self.available_threads > 0,
                "should have at least one thread to continue"
            );

            // figure out how many threads to allocate to this plan. Each plan
            // can use up to `part_count` threads.
            let chunk_count = plan.chunk_count();

            let num_plan_threads = self.available_threads.min(chunk_count);

            // run the plan in a separate task, which returns the number of threads it used
            debug!("Spawning plan {plan} with {num_plan_threads} threads");

            self.join_set
                .spawn(async move { run_plan(plan, num_plan_threads).await });
            self.available_threads -= num_plan_threads;
            return Ok(());
        }
    }

    /// Wait for all outstanding tasks to finish, propagating the first error.
    pub async fn join_all(mut self) -> io::Result<()> {
        debug!("Waiting for tasks to finish...");
        while let Some(result) = self.join_set.join_next().await {
            task_result(result)?;
        }
        debug!("Tasks finished.");
        Ok(())
    }
}
156
157/// unwraps the result of a task and converts it to an `io::Result<T>`.
158fn task_result<T>(result: Result<io::Result<T>, JoinError>) -> io::Result<T> {
159    result.map_err(|e| io::Error::other(format!("Task Panic: {e}")))?
160}
161
/// Run a single [`OutputPlan`]
///
/// Dispatches to the table-specific runner generated by `define_run!`.
/// Returns the number of threads the plan was run with so the caller
/// ([`WorkerQueue`]) can return them to its pool.
async fn run_plan(plan: OutputPlan, num_threads: usize) -> io::Result<usize> {
    match plan.table() {
        Table::Nation => run_nation_plan(plan, num_threads).await,
        Table::Region => run_region_plan(plan, num_threads).await,
        Table::Part => run_part_plan(plan, num_threads).await,
        Table::Supplier => run_supplier_plan(plan, num_threads).await,
        Table::Partsupp => run_partsupp_plan(plan, num_threads).await,
        Table::Customer => run_customer_plan(plan, num_threads).await,
        Table::Orders => run_orders_plan(plan, num_threads).await,
        Table::Lineitem => run_lineitem_plan(plan, num_threads).await,
    }
}
175
176/// Writes a CSV/TSV output from the sources
177async fn write_file<I>(plan: OutputPlan, num_threads: usize, sources: I) -> Result<(), io::Error>
178where
179    I: Iterator<Item: Source> + 'static,
180{
181    // Since generate_in_chunks already buffers, there is no need to buffer
182    // again (aka don't use BufWriter here)
183    match plan.output_location() {
184        OutputLocation::Stdout => {
185            let sink = WriterSink::new(io::stdout());
186            generate_in_chunks(sink, sources, num_threads).await
187        }
188        OutputLocation::File(path) => {
189            // if the output already exists, skip running
190            if path.exists() {
191                info!("{} already exists, skipping generation", path.display());
192                return Ok(());
193            }
194            // write to a temp file and then rename to avoid partial files
195            let temp_path = path.with_extension("inprogress");
196            let file = std::fs::File::create(&temp_path).map_err(|err| {
197                io::Error::other(format!("Failed to create {temp_path:?}: {err}"))
198            })?;
199            let sink = WriterSink::new(file);
200            generate_in_chunks(sink, sources, num_threads).await?;
201            // rename the temp file to the final path
202            std::fs::rename(&temp_path, path).map_err(|e| {
203                io::Error::other(format!(
204                    "Failed to rename {temp_path:?} to {path:?} file: {e}"
205                ))
206            })?;
207            Ok(())
208        }
209    }
210}
211
212/// Generates an output parquet file from the sources
213async fn write_parquet<I>(plan: OutputPlan, num_threads: usize, sources: I) -> Result<(), io::Error>
214where
215    I: Iterator<Item: RecordBatchIterator> + 'static,
216{
217    match plan.output_location() {
218        OutputLocation::Stdout => {
219            let writer = BufWriter::with_capacity(32 * 1024 * 1024, io::stdout()); // 32MB buffer
220            generate_parquet(writer, sources, num_threads, plan.parquet_compression()).await
221        }
222        OutputLocation::File(path) => {
223            // if the output already exists, skip running
224            if path.exists() {
225                info!("{} already exists, skipping generation", path.display());
226                return Ok(());
227            }
228            // write to a temp file and then rename to avoid partial files
229            let temp_path = path.with_extension("inprogress");
230            let file = std::fs::File::create(&temp_path).map_err(|err| {
231                io::Error::other(format!("Failed to create {temp_path:?}: {err}"))
232            })?;
233            let writer = BufWriter::with_capacity(32 * 1024 * 1024, file); // 32MB buffer
234            generate_parquet(writer, sources, num_threads, plan.parquet_compression()).await?;
235            // rename the temp file to the final path
236            std::fs::rename(&temp_path, path).map_err(|e| {
237                io::Error::other(format!(
238                    "Failed to rename {temp_path:?} to {path:?} file: {e}"
239                ))
240            })?;
241            Ok(())
242        }
243    }
244}
245
/// macro to create a function for generating a part of a particular table
///
/// Arguments:
/// $FUN_NAME: name of the function to create
/// $GENERATOR: The generator type to use
/// $TBL_SOURCE: The [`Source`] type to use for TBL format
/// $CSV_SOURCE: The [`Source`] type to use for CSV format
/// $PARQUET_SOURCE: The [`RecordBatchIterator`] type to use for Parquet format
macro_rules! define_run {
    ($FUN_NAME:ident, $GENERATOR:ident, $TBL_SOURCE:ty, $CSV_SOURCE:ty, $PARQUET_SOURCE:ty) => {
        async fn $FUN_NAME(plan: OutputPlan, num_threads: usize) -> io::Result<usize> {
            use crate::GenerationPlan;
            let scale_factor = plan.scale_factor();
            info!("Writing {plan} using {num_threads} threads");

            /// These interior functions are used to tell the compiler that the lifetime is 'static
            /// (when these were closures, the compiler could not figure out the lifetime) and
            /// resulted in errors like this:
            ///          let _ = join_set.spawn(async move {
            ///                 |  _____________________^
            ///              96 | |                 run_plan(plan, num_plan_threads).await
            ///              97 | |             });
            ///                 | |______________^ implementation of `FnOnce` is not general enough
            fn tbl_sources(
                generation_plan: &GenerationPlan,
                scale_factor: f64,
            ) -> impl Iterator<Item: Source> + 'static {
                generation_plan
                    .clone()
                    .into_iter()
                    .map(move |(part, num_parts)| $GENERATOR::new(scale_factor, part, num_parts))
                    .map(<$TBL_SOURCE>::new)
            }

            fn csv_sources(
                generation_plan: &GenerationPlan,
                scale_factor: f64,
            ) -> impl Iterator<Item: Source> + 'static {
                generation_plan
                    .clone()
                    .into_iter()
                    .map(move |(part, num_parts)| $GENERATOR::new(scale_factor, part, num_parts))
                    .map(<$CSV_SOURCE>::new)
            }

            fn parquet_sources(
                generation_plan: &GenerationPlan,
                scale_factor: f64,
            ) -> impl Iterator<Item: RecordBatchIterator> + 'static {
                generation_plan
                    .clone()
                    .into_iter()
                    .map(move |(part, num_parts)| $GENERATOR::new(scale_factor, part, num_parts))
                    .map(<$PARQUET_SOURCE>::new)
            }

            // Dispatch to the appropriate output format
            match plan.output_format() {
                OutputFormat::Tbl => {
                    let gens = tbl_sources(plan.generation_plan(), scale_factor);
                    write_file(plan, num_threads, gens).await?
                }
                OutputFormat::Csv => {
                    let gens = csv_sources(plan.generation_plan(), scale_factor);
                    write_file(plan, num_threads, gens).await?
                }
                OutputFormat::Parquet => {
                    let gens = parquet_sources(plan.generation_plan(), scale_factor);
                    write_parquet(plan, num_threads, gens).await?
                }
            };
            // Report back how many threads this plan consumed so the
            // WorkerQueue can return them to its pool.
            Ok(num_threads)
        }
    };
}
321
// Generate one `run_<table>_plan` function per TPC-H table, wiring each
// table's generator to its TBL, CSV, and Parquet (Arrow) source types.
define_run!(
    run_lineitem_plan,
    LineItemGenerator,
    LineItemTblSource,
    LineItemCsvSource,
    LineItemArrow
);

define_run!(
    run_nation_plan,
    NationGenerator,
    NationTblSource,
    NationCsvSource,
    NationArrow
);

define_run!(
    run_region_plan,
    RegionGenerator,
    RegionTblSource,
    RegionCsvSource,
    RegionArrow
);

define_run!(
    run_part_plan,
    PartGenerator,
    PartTblSource,
    PartCsvSource,
    PartArrow
);

define_run!(
    run_supplier_plan,
    SupplierGenerator,
    SupplierTblSource,
    SupplierCsvSource,
    SupplierArrow
);
define_run!(
    run_partsupp_plan,
    PartSuppGenerator,
    PartSuppTblSource,
    PartSuppCsvSource,
    PartSuppArrow
);

define_run!(
    run_customer_plan,
    CustomerGenerator,
    CustomerTblSource,
    CustomerCsvSource,
    CustomerArrow
);

define_run!(
    run_orders_plan,
    OrderGenerator,
    OrderTblSource,
    OrderCsvSource,
    OrderArrow
);