// tpchgen_cli/parquet.rs

1//! Parquet output format
2
3use crate::statistics::WriteStatistics;
4use arrow::datatypes::SchemaRef;
5use futures::StreamExt;
6use log::debug;
7use parquet::arrow::arrow_writer::{compute_leaves, ArrowColumnChunk};
8use parquet::arrow::ArrowSchemaConverter;
9use parquet::basic::Compression;
10use parquet::file::properties::WriterProperties;
11use parquet::file::writer::SerializedFileWriter;
12use parquet::schema::types::SchemaDescPtr;
13use std::io;
14use std::io::Write;
15use std::sync::Arc;
16use tokio::sync::mpsc::{Receiver, Sender};
17use tpchgen_arrow::RecordBatchIterator;
18
/// Conversion of a finished output sink into its total size in bytes.
///
/// Implemented by the writer types passed to [`generate_parquet`]: after the
/// Parquet footer is written, the writer is consumed via this trait so the
/// byte count can be recorded in the write statistics.
pub trait IntoSize {
    /// Convert the object into a size
    ///
    /// Consumes `self` and returns the number of bytes written, or an
    /// [`io::Error`] if the size cannot be determined.
    fn into_size(self) -> Result<usize, io::Error>;
}
23
24/// Converts a set of RecordBatchIterators into a Parquet file
25///
26/// Uses num_threads to generate the data in parallel
27///
28/// Note the input is an iterator of [`RecordBatchIterator`]; The batches
29/// produced by each iterator is encoded as its own row group.
30pub async fn generate_parquet<W: Write + Send + IntoSize + 'static, I>(
31    writer: W,
32    iter_iter: I,
33    num_threads: usize,
34    parquet_compression: Compression,
35) -> Result<(), io::Error>
36where
37    I: Iterator<Item: RecordBatchIterator> + 'static,
38{
39    debug!(
40        "Generating Parquet with {num_threads} threads, using {parquet_compression} compression"
41    );
42    // Based on example in https://docs.rs/parquet/latest/parquet/arrow/arrow_writer/struct.ArrowColumnWriter.html
43    let mut iter_iter = iter_iter.peekable();
44
45    // get schema from the first iterator
46    let Some(first_iter) = iter_iter.peek() else {
47        return Ok(()); // no data shrug
48    };
49    let schema = Arc::clone(first_iter.schema());
50
51    // Compute the parquet schema
52    let writer_properties = WriterProperties::builder()
53        .set_compression(parquet_compression)
54        .build();
55    let writer_properties = Arc::new(writer_properties);
56    let parquet_schema = Arc::new(
57        ArrowSchemaConverter::new()
58            .with_coerce_types(writer_properties.coerce_types())
59            .convert(&schema)
60            .unwrap(),
61    );
62
63    // create a stream that computes the data for each row group
64    let mut row_group_stream = futures::stream::iter(iter_iter)
65        .map(async |iter| {
66            let parquet_schema = Arc::clone(&parquet_schema);
67            let writer_properties = Arc::clone(&writer_properties);
68            let schema = Arc::clone(&schema);
69            // run on a separate thread
70            tokio::task::spawn(async move {
71                encode_row_group(parquet_schema, writer_properties, schema, iter)
72            })
73            .await
74            .expect("Inner task panicked")
75        })
76        .buffered(num_threads); // generate row groups in parallel
77
78    let mut statistics = WriteStatistics::new("row groups");
79
80    // A blocking task that writes the row groups to the file
81    // done in a blocking task to avoid having a thread waiting on IO
82    // Now, read each completed row group and write it to the file
83    let root_schema = parquet_schema.root_schema_ptr();
84    let writer_properties_captured = Arc::clone(&writer_properties);
85    let (tx, mut rx): (
86        Sender<Vec<ArrowColumnChunk>>,
87        Receiver<Vec<ArrowColumnChunk>>,
88    ) = tokio::sync::mpsc::channel(num_threads);
89    let writer_task = tokio::task::spawn_blocking(move || {
90        // Create parquet writer
91        let mut writer =
92            SerializedFileWriter::new(writer, root_schema, writer_properties_captured).unwrap();
93
94        while let Some(chunks) = rx.blocking_recv() {
95            // Start row group
96            let mut row_group_writer = writer.next_row_group().unwrap();
97
98            // Slap the chunks into the row group
99            for chunk in chunks {
100                chunk.append_to_row_group(&mut row_group_writer).unwrap();
101            }
102            row_group_writer.close().unwrap();
103            statistics.increment_chunks(1);
104        }
105        let size = writer.into_inner()?.into_size()?;
106        statistics.increment_bytes(size);
107        Ok(()) as Result<(), io::Error>
108    });
109
110    // now, drive the input stream and send results to the writer task
111    while let Some(chunks) = row_group_stream.next().await {
112        // send the chunks to the writer task
113        if let Err(e) = tx.send(chunks).await {
114            debug!("Error sending chunks to writer: {e}");
115            break; // stop early
116        }
117    }
118    // signal the writer task that we are done
119    drop(tx);
120
121    // Wait for the writer task to finish
122    writer_task.await??;
123
124    Ok(())
125}
126
127/// Creates the data for a particular row group
128///
129/// Note at the moment it does not use multiple tasks/threads but it could
130/// potentially encode multiple columns with different threads .
131///
132/// Returns an array of [`ArrowColumnChunk`]
133fn encode_row_group<I>(
134    parquet_schema: SchemaDescPtr,
135    writer_properties: Arc<WriterProperties>,
136    schema: SchemaRef,
137    iter: I,
138) -> Vec<ArrowColumnChunk>
139where
140    I: RecordBatchIterator,
141{
142    // Create writers for each of the leaf columns
143    #[allow(deprecated)]
144    let mut col_writers = parquet::arrow::arrow_writer::get_column_writers(
145        &parquet_schema,
146        &writer_properties,
147        &schema,
148    )
149    .unwrap();
150
151    // generate the data and send it to the tasks (via the sender channels)
152    for batch in iter {
153        let columns = batch.columns().iter();
154        let col_writers = col_writers.iter_mut();
155        let fields = schema.fields().iter();
156
157        for ((col_writer, field), arr) in col_writers.zip(fields).zip(columns) {
158            for leaves in compute_leaves(field.as_ref(), arr).unwrap() {
159                col_writer.write(&leaves).unwrap();
160            }
161        }
162    }
163    // finish the writers and create the column chunks
164    col_writers
165        .into_iter()
166        .map(|col_writer| col_writer.close().unwrap())
167        .collect()
168}