datafusion_datasource/write/demux.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Module containing helper methods/traits for dividing an input
//! stream into multiple output files at execution time

use std::borrow::Cow;
use std::collections::HashMap;
use std::sync::Arc;

use crate::url::ListingTableUrl;
use crate::write::FileSinkConfig;
use datafusion_common::error::Result;
use datafusion_physical_plan::SendableRecordBatchStream;

use arrow::array::{
    builder::UInt64Builder, cast::AsArray, downcast_dictionary_array, ArrayAccessor,
    RecordBatch, StringArray, StructArray,
};
use arrow::datatypes::{DataType, Schema};
use datafusion_common::cast::{
    as_boolean_array, as_date32_array, as_date64_array, as_float16_array,
    as_float32_array, as_float64_array, as_int16_array, as_int32_array, as_int64_array,
    as_int8_array, as_string_array, as_string_view_array, as_uint16_array,
    as_uint32_array, as_uint64_array, as_uint8_array,
};
use datafusion_common::{exec_datafusion_err, not_impl_err, DataFusionError};
use datafusion_common_runtime::SpawnedTask;
use datafusion_execution::TaskContext;

use chrono::NaiveDate;
use futures::StreamExt;
use object_store::path::Path;
use rand::distr::SampleString;
use tokio::sync::mpsc::{self, Receiver, Sender, UnboundedReceiver, UnboundedSender};

type RecordBatchReceiver = Receiver<RecordBatch>;
pub type DemuxedStreamReceiver = UnboundedReceiver<(Path, RecordBatchReceiver)>;

/// Splits a single [SendableRecordBatchStream] into a dynamically determined
/// number of partitions at execution time.
///
/// The partitions are determined by factors known only at execution time, such
/// as the total number of rows and partition column values. The demuxer task
/// communicates with the caller by sending channels over a channel. The inner
/// channels send RecordBatches which should be contained within the same output
/// file. The outer channel is used to send a dynamic number of inner channels,
/// representing a dynamic number of total output files.
///
/// The caller is also responsible for monitoring the demux task for errors and
/// aborting accordingly.
///
/// A `base_output_path` with a file extension forces a single file to be
/// written, using the extension from the path. Otherwise the default extension
/// is used and the output is split into multiple files.
///
/// Examples of `base_output_path`
///  * `tmp/dataset/` -> is a folder since it ends in `/`
///  * `tmp/dataset` -> is still a folder since it does not end in `/` but has no valid file extension
///  * `tmp/file.parquet` -> is a file since it does not end in `/` and has a valid file extension `.parquet`
///  * `tmp/file.parquet/` -> is a folder since it ends in `/`
///
/// The `partition_by` parameter additionally splits the input based on the
/// unique values of the specified columns, see
/// <https://github.com/apache/datafusion/issues/7744>
///
/// ```text
///                                                                              ┌───────────┐               ┌────────────┐    ┌─────────────┐
///                                                                     ┌──────▶ │  batch 1  ├────▶...──────▶│   Batch a  │    │ Output File1│
///                                                                     │        └───────────┘               └────────────┘    └─────────────┘
///                                                                     │
///                                                 ┌──────────┐        │        ┌───────────┐               ┌────────────┐    ┌─────────────┐
/// ┌───────────┐               ┌────────────┐      │          │        ├──────▶ │  batch a+1├────▶...──────▶│   Batch b  │    │ Output File2│
/// │  batch 1  ├────▶...──────▶│   Batch N  ├─────▶│  Demux   ├────────┤ ...    └───────────┘               └────────────┘    └─────────────┘
/// └───────────┘               └────────────┘      │          │        │
///                                                 └──────────┘        │        ┌───────────┐               ┌────────────┐    ┌─────────────┐
///                                                                     └──────▶ │  batch d  ├────▶...──────▶│   Batch n  │    │ Output FileN│
///                                                                              └───────────┘               └────────────┘    └─────────────┘
/// ```
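///
/// A minimal sketch of how a caller might consume the returned receiver
/// (`spawn_writer` here is a hypothetical stand-in for the format-specific
/// sink logic that serializes batches to the given path):
///
/// ```rust,ignore
/// let (demux_task, mut file_stream_rx) = start_demuxer_task(&config, data, &context);
/// // Each message on the outer channel announces one new output file.
/// while let Some((path, batch_rx)) = file_stream_rx.recv().await {
///     // Each inner channel yields the RecordBatches destined for that file.
///     spawn_writer(path, batch_rx);
/// }
/// // The demux task should also be awaited so that any error it hit is surfaced.
/// let demux_result = demux_task.join().await;
/// ```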
pub(crate) fn start_demuxer_task(
    config: &FileSinkConfig,
    data: SendableRecordBatchStream,
    context: &Arc<TaskContext>,
) -> (SpawnedTask<Result<()>>, DemuxedStreamReceiver) {
    let (tx, rx) = mpsc::unbounded_channel();
    let context = Arc::clone(context);
    let file_extension = config.file_extension.clone();
    let base_output_path = config.table_paths[0].clone();
    let task = if config.table_partition_cols.is_empty() {
        let single_file_output = !base_output_path.is_collection()
            && base_output_path.file_extension().is_some();
        SpawnedTask::spawn(async move {
            row_count_demuxer(
                tx,
                data,
                context,
                base_output_path,
                file_extension,
                single_file_output,
            )
            .await
        })
    } else {
        // There could be an arbitrarily large number of parallel hive style partitions being written to, so we cannot
        // bound this channel without risking a deadlock.
        let partition_by = config.table_partition_cols.clone();
        let keep_partition_by_columns = config.keep_partition_by_columns;
        SpawnedTask::spawn(async move {
            hive_style_partitions_demuxer(
                tx,
                data,
                context,
                partition_by,
                base_output_path,
                file_extension,
                keep_partition_by_columns,
            )
            .await
        })
    };

    (task, rx)
}

/// Dynamically partitions the input stream to achieve a desired maximum number
/// of rows per output file.
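///
/// Incoming batches are distributed round-robin across at least
/// `minimum_parallel_output_files` open files; once a file's row count reaches
/// `soft_max_rows_per_output_file`, a new file replaces it in the rotation. A
/// rough illustration (row counts chosen only for this example):
///
/// ```text
/// minimum_parallel_output_files = 2, soft_max_rows_per_output_file = 100
///
/// batch 1 (60 rows) ──▶ file 0   (file 0 now has  60 rows)
/// batch 2 (60 rows) ──▶ file 1   (file 1 now has  60 rows)
/// batch 3 (60 rows) ──▶ file 0   (file 0 now has 120 rows, soft max exceeded)
/// batch 4 (60 rows) ──▶ file 1   (file 1 now has 120 rows, soft max exceeded)
/// batch 5 (60 rows) ──▶ file 2   (new file replaces file 0 in the rotation)
/// ```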
async fn row_count_demuxer(
    mut tx: UnboundedSender<(Path, Receiver<RecordBatch>)>,
    mut input: SendableRecordBatchStream,
    context: Arc<TaskContext>,
    base_output_path: ListingTableUrl,
    file_extension: String,
    single_file_output: bool,
) -> Result<()> {
    let exec_options = &context.session_config().options().execution;

    let max_rows_per_file = exec_options.soft_max_rows_per_output_file;
    let max_buffered_batches = exec_options.max_buffered_batches_per_output_file;
    let minimum_parallel_files = exec_options.minimum_parallel_output_files;
    let mut part_idx = 0;
    let write_id = rand::distr::Alphanumeric.sample_string(&mut rand::rng(), 16);

    let mut open_file_streams = Vec::with_capacity(minimum_parallel_files);

    let mut next_send_stream = 0;
    let mut row_counts = Vec::with_capacity(minimum_parallel_files);

    // Override the limits if single_file_output is set
    let minimum_parallel_files = if single_file_output {
        1
    } else {
        minimum_parallel_files
    };

    let max_rows_per_file = if single_file_output {
        usize::MAX
    } else {
        max_rows_per_file
    };

    while let Some(rb) = input.next().await.transpose()? {
        // Ensure we have at least minimum_parallel_files open
        if open_file_streams.len() < minimum_parallel_files {
            open_file_streams.push(create_new_file_stream(
                &base_output_path,
                &write_id,
                part_idx,
                &file_extension,
                single_file_output,
                max_buffered_batches,
                &mut tx,
            )?);
            row_counts.push(0);
            part_idx += 1;
        } else if row_counts[next_send_stream] >= max_rows_per_file {
            row_counts[next_send_stream] = 0;
            open_file_streams[next_send_stream] = create_new_file_stream(
                &base_output_path,
                &write_id,
                part_idx,
                &file_extension,
                single_file_output,
                max_buffered_batches,
                &mut tx,
            )?;
            part_idx += 1;
        }
        row_counts[next_send_stream] += rb.num_rows();
        open_file_streams[next_send_stream]
            .send(rb)
            .await
            .map_err(|_| {
                DataFusionError::Execution(
                    "Error sending RecordBatch to file stream!".into(),
                )
            })?;

        next_send_stream = (next_send_stream + 1) % minimum_parallel_files;
    }
    Ok(())
}

/// Helper for the row count demuxer that computes the output path for the next file.
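///
/// For example, with `base_output_path = "tmp/dataset/"`, `part_idx = 3`,
/// `file_extension = "parquet"` and a (hypothetical) 16 character `write_id`,
/// the multi-file case produces a path like:
///
/// ```text
/// tmp/dataset/jYB9yGfmVfXmDJEa_3.parquet
/// ```
///
/// while `single_file_output` returns the base path itself.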
fn generate_file_path(
    base_output_path: &ListingTableUrl,
    write_id: &str,
    part_idx: usize,
    file_extension: &str,
    single_file_output: bool,
) -> Path {
    if !single_file_output {
        base_output_path
            .prefix()
            .child(format!("{write_id}_{part_idx}.{file_extension}"))
    } else {
        base_output_path.prefix().to_owned()
    }
}

/// Helper for row count demuxer
fn create_new_file_stream(
    base_output_path: &ListingTableUrl,
    write_id: &str,
    part_idx: usize,
    file_extension: &str,
    single_file_output: bool,
    max_buffered_batches: usize,
    tx: &mut UnboundedSender<(Path, Receiver<RecordBatch>)>,
) -> Result<Sender<RecordBatch>> {
    let file_path = generate_file_path(
        base_output_path,
        write_id,
        part_idx,
        file_extension,
        single_file_output,
    );
    let (tx_file, rx_file) = mpsc::channel(max_buffered_batches / 2);
    tx.send((file_path, rx_file)).map_err(|_| {
        DataFusionError::Execution("Error sending RecordBatch to file stream!".into())
    })?;
    Ok(tx_file)
}

/// Splits an input stream based on the distinct values of a set of columns.
/// Assumes standard hive style partition paths such as
/// `/col1=val1/col2=val2/outputfile.parquet`
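///
/// For example, a batch whose `month` partition column contains the values
/// `["jan", "jan", "feb"]` is split into two sub-batches that are routed to
/// two different files (the `write_id` shown is hypothetical):
///
/// ```text
/// <base_output_path>/month=jan/jYB9yGfmVfXmDJEa.parquet   <- rows 0 and 1
/// <base_output_path>/month=feb/jYB9yGfmVfXmDJEa.parquet   <- row 2
/// ```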
async fn hive_style_partitions_demuxer(
    tx: UnboundedSender<(Path, Receiver<RecordBatch>)>,
    mut input: SendableRecordBatchStream,
    context: Arc<TaskContext>,
    partition_by: Vec<(String, DataType)>,
    base_output_path: ListingTableUrl,
    file_extension: String,
    keep_partition_by_columns: bool,
) -> Result<()> {
    let write_id = rand::distr::Alphanumeric.sample_string(&mut rand::rng(), 16);

    let exec_options = &context.session_config().options().execution;
    let max_buffered_recordbatches = exec_options.max_buffered_batches_per_output_file;

    // Map from partition key to the sender for that partition's output file.
    // To support non-string partition column types, the values are converted to strings first.
    let mut value_map: HashMap<Vec<String>, Sender<RecordBatch>> = HashMap::new();

    while let Some(rb) = input.next().await.transpose()? {
        // First compute the partition key for each row of the batch, e.g. (col1=val1, col2=val2, ...)
        let all_partition_values = compute_partition_keys_by_row(&rb, &partition_by)?;

        // Next compute how the batch should be split up to take each distinct key to its own batch
        let take_map = compute_take_arrays(&rb, all_partition_values);

        // Divide up the batch into distinct partition key batches and send each batch
        for (part_key, mut builder) in take_map.into_iter() {
            // Take method adapted from https://github.com/lancedb/lance/pull/1337/files
            // TODO: upstream RecordBatch::take to arrow-rs
            let take_indices = builder.finish();
            let struct_array: StructArray = rb.clone().into();
            let parted_batch = RecordBatch::from(
                arrow::compute::take(&struct_array, &take_indices, None)?.as_struct(),
            );

            // Get or create channel for this batch
            let part_tx = match value_map.get_mut(&part_key) {
                Some(part_tx) => part_tx,
                None => {
                    // Create channel for previously unseen distinct partition key and notify consumer of new file
                    let (part_tx, part_rx) =
                        mpsc::channel::<RecordBatch>(max_buffered_recordbatches);
                    let file_path = compute_hive_style_file_path(
                        &part_key,
                        &partition_by,
                        &write_id,
                        &file_extension,
                        &base_output_path,
                    );

                    tx.send((file_path, part_rx)).map_err(|_| {
                        DataFusionError::Execution(
                            "Error sending new file stream!".into(),
                        )
                    })?;

                    value_map.insert(part_key.clone(), part_tx);
                    value_map
                        .get_mut(&part_key)
                        .ok_or(DataFusionError::Internal(
                            "Key must exist since it was just inserted!".into(),
                        ))?
                }
            };

            let final_batch_to_send = if keep_partition_by_columns {
                parted_batch
            } else {
                remove_partition_by_columns(&parted_batch, &partition_by)?
            };

            // Finally, send the sub-batch for this distinct partition key
            part_tx.send(final_batch_to_send).await.map_err(|_| {
                DataFusionError::Internal("Unexpected error sending parted batch!".into())
            })?;
        }
    }

    Ok(())
}

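/// Computes the string representation of each partition column value for every
/// row of `rb`, returning one `Vec` per partition column. For example,
/// partitioning a 3 row batch by `(year, month)` might yield
/// `[["2024", "2024", "2023"], ["jan", "feb", "jan"]]`.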
fn compute_partition_keys_by_row<'a>(
    rb: &'a RecordBatch,
    partition_by: &'a [(String, DataType)],
) -> Result<Vec<Vec<Cow<'a, str>>>> {
    let mut all_partition_values = vec![];

    const EPOCH_DAYS_FROM_CE: i32 = 719_163;

    // For the purposes of writing partitioned data, we can rely on schema inference
    // to determine the type of the partition cols in order to provide a more ergonomic
    // UI which does not require specifying DataTypes manually. So, we ignore the
    // DataType within the partition_by array and infer the correct type from the
    // batch schema instead.
    let schema = rb.schema();
    for (col, _) in partition_by.iter() {
        let mut partition_values = vec![];

        let dtype = schema.field_with_name(col)?.data_type();
        let col_array = rb.column_by_name(col).ok_or(exec_datafusion_err!(
            "PartitionBy Column {} does not exist in source data! Got schema {schema}.",
            col
        ))?;

        match dtype {
            DataType::Utf8 => {
                let array = as_string_array(col_array)?;
                for i in 0..rb.num_rows() {
                    partition_values.push(Cow::from(array.value(i)));
                }
            }
            DataType::Utf8View => {
                let array = as_string_view_array(col_array)?;
                for i in 0..rb.num_rows() {
                    partition_values.push(Cow::from(array.value(i)));
                }
            }
            DataType::Boolean => {
                let array = as_boolean_array(col_array)?;
                for i in 0..rb.num_rows() {
                    partition_values.push(Cow::from(array.value(i).to_string()));
                }
            }
            DataType::Date32 => {
                let array = as_date32_array(col_array)?;
                // ISO-8601/RFC3339 format - yyyy-mm-dd
                let format = "%Y-%m-%d";
                for i in 0..rb.num_rows() {
                    let date = NaiveDate::from_num_days_from_ce_opt(
                        EPOCH_DAYS_FROM_CE + array.value(i),
                    )
                    .unwrap()
                    .format(format)
                    .to_string();
                    partition_values.push(Cow::from(date));
                }
            }
            DataType::Date64 => {
                let array = as_date64_array(col_array)?;
                // ISO-8601/RFC3339 format - yyyy-mm-dd
                let format = "%Y-%m-%d";
                for i in 0..rb.num_rows() {
                    let date = NaiveDate::from_num_days_from_ce_opt(
                        EPOCH_DAYS_FROM_CE + (array.value(i) / 86_400_000) as i32,
                    )
                    .unwrap()
                    .format(format)
                    .to_string();
                    partition_values.push(Cow::from(date));
                }
            }
            DataType::Int8 => {
                let array = as_int8_array(col_array)?;
                for i in 0..rb.num_rows() {
                    partition_values.push(Cow::from(array.value(i).to_string()));
                }
            }
            DataType::Int16 => {
                let array = as_int16_array(col_array)?;
                for i in 0..rb.num_rows() {
                    partition_values.push(Cow::from(array.value(i).to_string()));
                }
            }
            DataType::Int32 => {
                let array = as_int32_array(col_array)?;
                for i in 0..rb.num_rows() {
                    partition_values.push(Cow::from(array.value(i).to_string()));
                }
            }
            DataType::Int64 => {
                let array = as_int64_array(col_array)?;
                for i in 0..rb.num_rows() {
                    partition_values.push(Cow::from(array.value(i).to_string()));
                }
            }
            DataType::UInt8 => {
                let array = as_uint8_array(col_array)?;
                for i in 0..rb.num_rows() {
                    partition_values.push(Cow::from(array.value(i).to_string()));
                }
            }
            DataType::UInt16 => {
                let array = as_uint16_array(col_array)?;
                for i in 0..rb.num_rows() {
                    partition_values.push(Cow::from(array.value(i).to_string()));
                }
            }
            DataType::UInt32 => {
                let array = as_uint32_array(col_array)?;
                for i in 0..rb.num_rows() {
                    partition_values.push(Cow::from(array.value(i).to_string()));
                }
            }
            DataType::UInt64 => {
                let array = as_uint64_array(col_array)?;
                for i in 0..rb.num_rows() {
                    partition_values.push(Cow::from(array.value(i).to_string()));
                }
            }
            DataType::Float16 => {
                let array = as_float16_array(col_array)?;
                for i in 0..rb.num_rows() {
                    partition_values.push(Cow::from(array.value(i).to_string()));
                }
            }
            DataType::Float32 => {
                let array = as_float32_array(col_array)?;
                for i in 0..rb.num_rows() {
                    partition_values.push(Cow::from(array.value(i).to_string()));
                }
            }
            DataType::Float64 => {
                let array = as_float64_array(col_array)?;
                for i in 0..rb.num_rows() {
                    partition_values.push(Cow::from(array.value(i).to_string()));
                }
            }
            DataType::Dictionary(_, _) => {
                downcast_dictionary_array!(
                    col_array => {
                        let array = col_array.downcast_dict::<StringArray>()
                            .ok_or(exec_datafusion_err!(
                                "it is not yet supported to write to hive partitions with datatype {}",
                                dtype
                            ))?;

                        for i in 0..rb.num_rows() {
                            partition_values.push(Cow::from(array.value(i)));
                        }
                    },
                    _ => unreachable!(),
                )
            }
            _ => {
                return not_impl_err!(
                    "it is not yet supported to write to hive partitions with datatype {}",
                    dtype
                )
            }
        }

        all_partition_values.push(partition_values);
    }

    Ok(all_partition_values)
}

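/// Maps each distinct partition key to a [`UInt64Builder`] of the row indices
/// that belong to it, so each group can later be extracted with
/// `arrow::compute::take`. For example, a single partition column with row
/// values `["a", "b", "a"]` yields `{["a"]: [0, 2], ["b"]: [1]}`.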
fn compute_take_arrays(
    rb: &RecordBatch,
    all_partition_values: Vec<Vec<Cow<str>>>,
) -> HashMap<Vec<String>, UInt64Builder> {
    let mut take_map = HashMap::new();
    for i in 0..rb.num_rows() {
        let mut part_key = vec![];
        for vals in all_partition_values.iter() {
            part_key.push(vals[i].clone().into());
        }
        let builder = take_map.entry(part_key).or_insert_with(UInt64Builder::new);
        builder.append_value(i as u64);
    }
    take_map
}

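/// Returns a new [`RecordBatch`] with the partition-by columns removed, e.g. a
/// batch with columns `(a, year, b)` partitioned by `year` becomes `(a, b)`.
/// Used when `keep_partition_by_columns` is false, since the partition values
/// are already encoded in the hive-style output path.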
fn remove_partition_by_columns(
    parted_batch: &RecordBatch,
    partition_by: &[(String, DataType)],
) -> Result<RecordBatch> {
    let partition_names: Vec<_> = partition_by.iter().map(|(s, _)| s).collect();
    let (non_part_cols, non_part_fields): (Vec<_>, Vec<_>) = parted_batch
        .columns()
        .iter()
        .zip(parted_batch.schema().fields())
        .filter_map(|(a, f)| {
            if !partition_names.contains(&f.name()) {
                Some((Arc::clone(a), (**f).clone()))
            } else {
                None
            }
        })
        .unzip();

    let non_part_schema = Schema::new(non_part_fields);
    let final_batch_to_send =
        RecordBatch::try_new(Arc::new(non_part_schema), non_part_cols)?;

    Ok(final_batch_to_send)
}

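/// Builds the hive-style output path for one distinct partition key. For
/// example, `part_key = ["2024", "jan"]` with
/// `partition_by = [("year", _), ("month", _)]` and a hypothetical `write_id`
/// produce:
///
/// ```text
/// <base_output_path>/year=2024/month=jan/jYB9yGfmVfXmDJEa.<file_extension>
/// ```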
fn compute_hive_style_file_path(
    part_key: &[String],
    partition_by: &[(String, DataType)],
    write_id: &str,
    file_extension: &str,
    base_output_path: &ListingTableUrl,
) -> Path {
    let mut file_path = base_output_path.prefix().clone();
    for j in 0..part_key.len() {
        file_path = file_path.child(format!("{}={}", partition_by[j].0, part_key[j]));
    }

    file_path.child(format!("{write_id}.{file_extension}"))
}