Skip to main content

datafusion_datasource/write/
demux.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Module containing helper methods/traits related to enabling
19//! dividing input stream into multiple output files at execution time
20
21use std::borrow::Cow;
22use std::collections::HashMap;
23use std::sync::Arc;
24
25use crate::url::ListingTableUrl;
26use crate::write::FileSinkConfig;
27use datafusion_common::error::Result;
28use datafusion_physical_plan::SendableRecordBatchStream;
29
30use arrow::array::{
31    ArrayAccessor, RecordBatch, StringArray, StructArray, builder::UInt64Builder,
32    cast::AsArray, downcast_dictionary_array,
33};
34use arrow::datatypes::{DataType, Schema};
35use datafusion_common::cast::{
36    as_boolean_array, as_date32_array, as_date64_array, as_float16_array,
37    as_float32_array, as_float64_array, as_int8_array, as_int16_array, as_int32_array,
38    as_int64_array, as_large_string_array, as_string_array, as_string_view_array,
39    as_uint8_array, as_uint16_array, as_uint32_array, as_uint64_array,
40};
41use datafusion_common::{exec_datafusion_err, internal_datafusion_err, not_impl_err};
42use datafusion_common_runtime::SpawnedTask;
43
44use chrono::NaiveDate;
45use datafusion_execution::TaskContext;
46use futures::StreamExt;
47use object_store::path::Path;
48use rand::distr::SampleString;
49use tokio::sync::mpsc::{self, Receiver, Sender, UnboundedReceiver, UnboundedSender};
50
/// Receiving half of a per-file channel; it carries the record batches that
/// belong in a single output file.
type RecordBatchReceiver = Receiver<RecordBatch>;
/// Receiving half of the demuxer's outer channel. Each message pairs an output
/// file path with the channel that will deliver that file's record batches.
pub type DemuxedStreamReceiver = UnboundedReceiver<(Path, RecordBatchReceiver)>;
53
54/// Splits a single [SendableRecordBatchStream] into a dynamically determined
55/// number of partitions at execution time.
56///
57/// The partitions are determined by factors known only at execution time, such
58/// as total number of rows and partition column values. The demuxer task
59/// communicates to the caller by sending channels over a channel. The inner
60/// channels send RecordBatches which should be contained within the same output
61/// file. The outer channel is used to send a dynamic number of inner channels,
62/// representing a dynamic number of total output files.
63///
64/// The caller is also responsible to monitor the demux task for errors and
65/// abort accordingly.
66///
67/// A path with an extension will force only a single file to
68/// be written with the extension from the path. Otherwise the default extension
69/// will be used and the output will be split into multiple files.
70///
71/// Output file guarantees:
72///  - Partitioned files: Files are created only for non-empty partitions.
73///  - Single-file output: 1 file is always written, even when the stream is empty.
74///  - Multi-file output: Depending on the number of record batches, 0 or more files are written.
75///
76/// Examples of `base_output_path`
77///  * `tmp/dataset/` -> is a folder since it ends in `/`
78///  * `tmp/dataset` -> is still a folder since it does not end in `/` but has no valid file extension
79///  * `tmp/file.parquet` -> is a file since it does not end in `/` and has a valid file extension `.parquet`
80///  * `tmp/file.parquet/` -> is a folder since it ends in `/`
81///
82/// The `partition_by` parameter will additionally split the input based on the
83/// unique values of a specific column, see
84/// <https://github.com/apache/datafusion/issues/7744>
85///
86/// ```text
87///                                                                              ┌───────────┐               ┌────────────┐    ┌─────────────┐
88///                                                                     ┌──────▶ │  batch 1  ├────▶...──────▶│   Batch a  │    │ Output File1│
89///                                                                     │        └───────────┘               └────────────┘    └─────────────┘
90///                                                                     │
91///                                                 ┌──────────┐        │        ┌───────────┐               ┌────────────┐    ┌─────────────┐
92/// ┌───────────┐               ┌────────────┐      │          │        ├──────▶ │  batch a+1├────▶...──────▶│   Batch b  │    │ Output File2│
93/// │  batch 1  ├────▶...──────▶│   Batch N  ├─────▶│  Demux   ├────────┤ ...    └───────────┘               └────────────┘    └─────────────┘
94/// └───────────┘               └────────────┘      │          │        │
95///                                                 └──────────┘        │        ┌───────────┐               ┌────────────┐    ┌─────────────┐
96///                                                                     └──────▶ │  batch d  ├────▶...──────▶│   Batch n  │    │ Output FileN│
97///                                                                              └───────────┘               └────────────┘    └─────────────┘
98/// ```
99pub(crate) fn start_demuxer_task(
100    config: &FileSinkConfig,
101    data: SendableRecordBatchStream,
102    context: &Arc<TaskContext>,
103) -> (SpawnedTask<Result<()>>, DemuxedStreamReceiver) {
104    let (tx, rx) = mpsc::unbounded_channel();
105    let context = Arc::clone(context);
106    let file_extension = config.file_extension.clone();
107    let base_output_path = config.table_paths[0].clone();
108    let task = if config.table_partition_cols.is_empty() {
109        let single_file_output = config
110            .file_output_mode
111            .single_file_output(&base_output_path);
112        SpawnedTask::spawn(async move {
113            row_count_demuxer(
114                tx,
115                data,
116                context,
117                base_output_path,
118                file_extension,
119                single_file_output,
120            )
121            .await
122        })
123    } else {
124        // There could be an arbitrarily large number of parallel hive style partitions being written to, so we cannot
125        // bound this channel without risking a deadlock.
126        let partition_by = config.table_partition_cols.clone();
127        let keep_partition_by_columns = config.keep_partition_by_columns;
128        SpawnedTask::spawn(async move {
129            hive_style_partitions_demuxer(
130                tx,
131                data,
132                context,
133                partition_by,
134                base_output_path,
135                file_extension,
136                keep_partition_by_columns,
137            )
138            .await
139        })
140    };
141
142    (task, rx)
143}
144
145/// Dynamically partitions input stream to achieve desired maximum rows per file
146async fn row_count_demuxer(
147    mut tx: UnboundedSender<(Path, Receiver<RecordBatch>)>,
148    mut input: SendableRecordBatchStream,
149    context: Arc<TaskContext>,
150    base_output_path: ListingTableUrl,
151    file_extension: String,
152    single_file_output: bool,
153) -> Result<()> {
154    let exec_options = &context.session_config().options().execution;
155
156    let max_rows_per_file = exec_options.soft_max_rows_per_output_file;
157    let max_buffered_batches = exec_options.max_buffered_batches_per_output_file;
158    let minimum_parallel_files = exec_options.minimum_parallel_output_files;
159    let mut part_idx = 0;
160    let write_id = rand::distr::Alphanumeric.sample_string(&mut rand::rng(), 16);
161
162    let mut open_file_streams = Vec::with_capacity(minimum_parallel_files);
163
164    let mut next_send_steam = 0;
165    let mut row_counts = Vec::with_capacity(minimum_parallel_files);
166
167    // Overrides if single_file_output is set
168    let minimum_parallel_files = if single_file_output {
169        1
170    } else {
171        minimum_parallel_files
172    };
173
174    let max_rows_per_file = if single_file_output {
175        usize::MAX
176    } else {
177        max_rows_per_file
178    };
179
180    if single_file_output {
181        // ensure we have one file open, even when the input stream is empty
182        open_file_streams.push(create_new_file_stream(
183            &base_output_path,
184            &write_id,
185            part_idx,
186            &file_extension,
187            single_file_output,
188            max_buffered_batches,
189            &mut tx,
190        )?);
191        row_counts.push(0);
192        part_idx += 1;
193    }
194
195    let schema = input.schema();
196    let mut is_batch_received = false;
197
198    while let Some(rb) = input.next().await.transpose()? {
199        is_batch_received = true;
200        // ensure we have at least minimum_parallel_files open
201        if open_file_streams.len() < minimum_parallel_files {
202            open_file_streams.push(create_new_file_stream(
203                &base_output_path,
204                &write_id,
205                part_idx,
206                &file_extension,
207                single_file_output,
208                max_buffered_batches,
209                &mut tx,
210            )?);
211            row_counts.push(0);
212            part_idx += 1;
213        } else if row_counts[next_send_steam] >= max_rows_per_file {
214            row_counts[next_send_steam] = 0;
215            open_file_streams[next_send_steam] = create_new_file_stream(
216                &base_output_path,
217                &write_id,
218                part_idx,
219                &file_extension,
220                single_file_output,
221                max_buffered_batches,
222                &mut tx,
223            )?;
224            part_idx += 1;
225        }
226        row_counts[next_send_steam] += rb.num_rows();
227        open_file_streams[next_send_steam]
228            .send(rb)
229            .await
230            .map_err(|_| {
231                exec_datafusion_err!("Error sending RecordBatch to file stream!")
232            })?;
233
234        next_send_steam = (next_send_steam + 1) % minimum_parallel_files;
235    }
236
237    // if there is no batch send but with a single file, send an empty batch
238    if single_file_output && !is_batch_received {
239        open_file_streams
240            .first_mut()
241            .ok_or_else(|| internal_datafusion_err!("Expected a single output file"))?
242            .send(RecordBatch::new_empty(schema))
243            .await
244            .map_err(|_| {
245                exec_datafusion_err!("Error sending empty RecordBatch to file stream!")
246            })?;
247    }
248
249    Ok(())
250}
251
252/// Helper for row count demuxer
253fn generate_file_path(
254    base_output_path: &ListingTableUrl,
255    write_id: &str,
256    part_idx: usize,
257    file_extension: &str,
258    single_file_output: bool,
259) -> Path {
260    if !single_file_output {
261        base_output_path
262            .prefix()
263            .clone()
264            .join(format!("{write_id}_{part_idx}.{file_extension}"))
265    } else {
266        base_output_path.prefix().to_owned()
267    }
268}
269
270/// Helper for row count demuxer
271fn create_new_file_stream(
272    base_output_path: &ListingTableUrl,
273    write_id: &str,
274    part_idx: usize,
275    file_extension: &str,
276    single_file_output: bool,
277    max_buffered_batches: usize,
278    tx: &mut UnboundedSender<(Path, Receiver<RecordBatch>)>,
279) -> Result<Sender<RecordBatch>> {
280    let file_path = generate_file_path(
281        base_output_path,
282        write_id,
283        part_idx,
284        file_extension,
285        single_file_output,
286    );
287    let (tx_file, rx_file) = mpsc::channel(max_buffered_batches / 2);
288    tx.send((file_path, rx_file))
289        .map_err(|_| exec_datafusion_err!("Error sending RecordBatch to file stream!"))?;
290    Ok(tx_file)
291}
292
293/// Splits an input stream based on the distinct values of a set of columns
294/// Assumes standard hive style partition paths such as
295/// /col1=val1/col2=val2/outputfile.parquet
296async fn hive_style_partitions_demuxer(
297    tx: UnboundedSender<(Path, Receiver<RecordBatch>)>,
298    mut input: SendableRecordBatchStream,
299    context: Arc<TaskContext>,
300    partition_by: Vec<(String, DataType)>,
301    base_output_path: ListingTableUrl,
302    file_extension: String,
303    keep_partition_by_columns: bool,
304) -> Result<()> {
305    let write_id = rand::distr::Alphanumeric.sample_string(&mut rand::rng(), 16);
306
307    let exec_options = &context.session_config().options().execution;
308    let max_buffered_recordbatches = exec_options.max_buffered_batches_per_output_file;
309
310    // To support non string partition col types, cast the type to &str first
311    let mut value_map: HashMap<Vec<String>, Sender<RecordBatch>> = HashMap::new();
312
313    while let Some(rb) = input.next().await.transpose()? {
314        // First compute partition key for each row of batch, e.g. (col1=val1, col2=val2, ...)
315        let all_partition_values = compute_partition_keys_by_row(&rb, &partition_by)?;
316
317        // Next compute how the batch should be split up to take each distinct key to its own batch
318        let take_map = compute_take_arrays(&rb, &all_partition_values);
319
320        // Divide up the batch into distinct partition key batches and send each batch
321        for (part_key, mut builder) in take_map.into_iter() {
322            // Take method adapted from https://github.com/lancedb/lance/pull/1337/files
323            // TODO: upstream RecordBatch::take to arrow-rs
324            let take_indices = builder.finish();
325            let struct_array: StructArray = rb.clone().into();
326            let parted_batch = RecordBatch::from(
327                arrow::compute::take(&struct_array, &take_indices, None)?.as_struct(),
328            );
329
330            // Get or create channel for this batch
331            let part_tx = match value_map.get_mut(&part_key) {
332                Some(part_tx) => part_tx,
333                None => {
334                    // Create channel for previously unseen distinct partition key and notify consumer of new file
335                    let (part_tx, part_rx) =
336                        mpsc::channel::<RecordBatch>(max_buffered_recordbatches);
337                    let file_path = compute_hive_style_file_path(
338                        &part_key,
339                        &partition_by,
340                        &write_id,
341                        &file_extension,
342                        &base_output_path,
343                    );
344
345                    tx.send((file_path, part_rx)).map_err(|_| {
346                        exec_datafusion_err!("Error sending new file stream!")
347                    })?;
348
349                    value_map.insert(part_key.clone(), part_tx);
350                    value_map.get_mut(&part_key).ok_or_else(|| {
351                        exec_datafusion_err!("Key must exist since it was just inserted!")
352                    })?
353                }
354            };
355
356            let final_batch_to_send = if keep_partition_by_columns {
357                parted_batch
358            } else {
359                remove_partition_by_columns(&parted_batch, &partition_by)?
360            };
361
362            // Finally send the partial batch partitioned by distinct value!
363            part_tx.send(final_batch_to_send).await.map_err(|_| {
364                internal_datafusion_err!("Unexpected error sending parted batch!")
365            })?;
366        }
367    }
368
369    Ok(())
370}
371
372fn compute_partition_keys_by_row<'a>(
373    rb: &'a RecordBatch,
374    partition_by: &'a [(String, DataType)],
375) -> Result<Vec<Vec<Cow<'a, str>>>> {
376    let mut all_partition_values = vec![];
377
378    const EPOCH_DAYS_FROM_CE: i32 = 719_163;
379
380    // For the purposes of writing partitioned data, we can rely on schema inference
381    // to determine the type of the partition cols in order to provide a more ergonomic
382    // UI which does not require specifying DataTypes manually. So, we ignore the
383    // DataType within the partition_by array and infer the correct type from the
384    // batch schema instead.
385    let schema = rb.schema();
386    for (col, _) in partition_by.iter() {
387        let mut partition_values = vec![];
388
389        let dtype = schema.field_with_name(col)?.data_type();
390        let col_array = rb.column_by_name(col).ok_or(exec_datafusion_err!(
391            "PartitionBy Column {} does not exist in source data! Got schema {schema}.",
392            col
393        ))?;
394
395        match dtype {
396            DataType::Utf8 => {
397                let array = as_string_array(col_array)?;
398                for i in 0..rb.num_rows() {
399                    partition_values.push(Cow::from(array.value(i)));
400                }
401            }
402            DataType::LargeUtf8 => {
403                let array = as_large_string_array(col_array)?;
404                for i in 0..rb.num_rows() {
405                    partition_values.push(Cow::from(array.value(i)));
406                }
407            }
408            DataType::Utf8View => {
409                let array = as_string_view_array(col_array)?;
410                for i in 0..rb.num_rows() {
411                    partition_values.push(Cow::from(array.value(i)));
412                }
413            }
414            DataType::Boolean => {
415                let array = as_boolean_array(col_array)?;
416                for i in 0..rb.num_rows() {
417                    partition_values.push(Cow::from(array.value(i).to_string()));
418                }
419            }
420            DataType::Date32 => {
421                let array = as_date32_array(col_array)?;
422                // ISO-8601/RFC3339 format - yyyy-mm-dd
423                let format = "%Y-%m-%d";
424                for i in 0..rb.num_rows() {
425                    let date = NaiveDate::from_num_days_from_ce_opt(
426                        EPOCH_DAYS_FROM_CE + array.value(i),
427                    )
428                    .unwrap()
429                    .format(format)
430                    .to_string();
431                    partition_values.push(Cow::from(date));
432                }
433            }
434            DataType::Date64 => {
435                let array = as_date64_array(col_array)?;
436                // ISO-8601/RFC3339 format - yyyy-mm-dd
437                let format = "%Y-%m-%d";
438                for i in 0..rb.num_rows() {
439                    let date = NaiveDate::from_num_days_from_ce_opt(
440                        EPOCH_DAYS_FROM_CE + (array.value(i) / 86_400_000) as i32,
441                    )
442                    .unwrap()
443                    .format(format)
444                    .to_string();
445                    partition_values.push(Cow::from(date));
446                }
447            }
448            DataType::Int8 => {
449                let array = as_int8_array(col_array)?;
450                for i in 0..rb.num_rows() {
451                    partition_values.push(Cow::from(array.value(i).to_string()));
452                }
453            }
454            DataType::Int16 => {
455                let array = as_int16_array(col_array)?;
456                for i in 0..rb.num_rows() {
457                    partition_values.push(Cow::from(array.value(i).to_string()));
458                }
459            }
460            DataType::Int32 => {
461                let array = as_int32_array(col_array)?;
462                for i in 0..rb.num_rows() {
463                    partition_values.push(Cow::from(array.value(i).to_string()));
464                }
465            }
466            DataType::Int64 => {
467                let array = as_int64_array(col_array)?;
468                for i in 0..rb.num_rows() {
469                    partition_values.push(Cow::from(array.value(i).to_string()));
470                }
471            }
472            DataType::UInt8 => {
473                let array = as_uint8_array(col_array)?;
474                for i in 0..rb.num_rows() {
475                    partition_values.push(Cow::from(array.value(i).to_string()));
476                }
477            }
478            DataType::UInt16 => {
479                let array = as_uint16_array(col_array)?;
480                for i in 0..rb.num_rows() {
481                    partition_values.push(Cow::from(array.value(i).to_string()));
482                }
483            }
484            DataType::UInt32 => {
485                let array = as_uint32_array(col_array)?;
486                for i in 0..rb.num_rows() {
487                    partition_values.push(Cow::from(array.value(i).to_string()));
488                }
489            }
490            DataType::UInt64 => {
491                let array = as_uint64_array(col_array)?;
492                for i in 0..rb.num_rows() {
493                    partition_values.push(Cow::from(array.value(i).to_string()));
494                }
495            }
496            DataType::Float16 => {
497                let array = as_float16_array(col_array)?;
498                for i in 0..rb.num_rows() {
499                    partition_values.push(Cow::from(array.value(i).to_string()));
500                }
501            }
502            DataType::Float32 => {
503                let array = as_float32_array(col_array)?;
504                for i in 0..rb.num_rows() {
505                    partition_values.push(Cow::from(array.value(i).to_string()));
506                }
507            }
508            DataType::Float64 => {
509                let array = as_float64_array(col_array)?;
510                for i in 0..rb.num_rows() {
511                    partition_values.push(Cow::from(array.value(i).to_string()));
512                }
513            }
514            DataType::Dictionary(_, _) => {
515                downcast_dictionary_array!(
516                    col_array =>  {
517                        let array = col_array.downcast_dict::<StringArray>()
518                            .ok_or(exec_datafusion_err!("it is not yet supported to write to hive partitions with datatype {}",
519                            dtype))?;
520
521                        for i in 0..rb.num_rows() {
522                            partition_values.push(Cow::from(array.value(i)));
523                        }
524                    },
525                    _ => unreachable!(),
526                )
527            }
528            _ => {
529                return not_impl_err!(
530                    "it is not yet supported to write to hive partitions with datatype {}",
531                    dtype
532                );
533            }
534        }
535
536        all_partition_values.push(partition_values);
537    }
538
539    Ok(all_partition_values)
540}
541
542fn compute_take_arrays(
543    rb: &RecordBatch,
544    all_partition_values: &[Vec<Cow<str>>],
545) -> HashMap<Vec<String>, UInt64Builder> {
546    let mut take_map = HashMap::new();
547    for i in 0..rb.num_rows() {
548        let mut part_key = vec![];
549        for vals in all_partition_values.iter() {
550            part_key.push(vals[i].clone().into());
551        }
552        let builder = take_map.entry(part_key).or_insert_with(UInt64Builder::new);
553        builder.append_value(i as u64);
554    }
555    take_map
556}
557
558fn remove_partition_by_columns(
559    parted_batch: &RecordBatch,
560    partition_by: &[(String, DataType)],
561) -> Result<RecordBatch> {
562    let partition_names: Vec<_> = partition_by.iter().map(|(s, _)| s).collect();
563    let (non_part_cols, non_part_fields): (Vec<_>, Vec<_>) = parted_batch
564        .columns()
565        .iter()
566        .zip(parted_batch.schema().fields())
567        .filter_map(|(a, f)| {
568            if !partition_names.contains(&f.name()) {
569                Some((Arc::clone(a), (**f).clone()))
570            } else {
571                None
572            }
573        })
574        .unzip();
575
576    let non_part_schema = Schema::new(non_part_fields);
577    let final_batch_to_send =
578        RecordBatch::try_new(Arc::new(non_part_schema), non_part_cols)?;
579
580    Ok(final_batch_to_send)
581}
582
583fn compute_hive_style_file_path(
584    part_key: &[String],
585    partition_by: &[(String, DataType)],
586    write_id: &str,
587    file_extension: &str,
588    base_output_path: &ListingTableUrl,
589) -> Path {
590    let mut file_path = base_output_path.prefix().clone();
591    for j in 0..part_key.len() {
592        file_path = file_path.join(format!("{}={}", partition_by[j].0, part_key[j]));
593    }
594
595    file_path.join(format!("{write_id}.{file_extension}"))
596}