// datafusion_datasource/write/demux.rs

use std::borrow::Cow;
use std::collections::HashMap;
use std::sync::Arc;

use crate::url::ListingTableUrl;
use crate::write::FileSinkConfig;
use datafusion_common::error::Result;
use datafusion_physical_plan::SendableRecordBatchStream;

use arrow::array::{
    builder::UInt64Builder, cast::AsArray, downcast_dictionary_array, ArrayAccessor,
    RecordBatch, StringArray, StructArray,
};
use arrow::datatypes::{DataType, Schema};
use datafusion_common::cast::{
    as_boolean_array, as_date32_array, as_date64_array, as_float16_array,
    as_float32_array, as_float64_array, as_int16_array, as_int32_array, as_int64_array,
    as_int8_array, as_string_array, as_string_view_array, as_uint16_array,
    as_uint32_array, as_uint64_array, as_uint8_array,
};
use datafusion_common::{exec_datafusion_err, not_impl_err, DataFusionError};
use datafusion_common_runtime::SpawnedTask;
use datafusion_execution::TaskContext;

use chrono::NaiveDate;
use futures::StreamExt;
use object_store::path::Path;
use rand::distr::SampleString;
use tokio::sync::mpsc::{self, Receiver, Sender, UnboundedReceiver, UnboundedSender};

type RecordBatchReceiver = Receiver<RecordBatch>;
pub type DemuxedStreamReceiver = UnboundedReceiver<(Path, RecordBatchReceiver)>;

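/// Spawns a task that demuxes the input [`SendableRecordBatchStream`] into one or more
/// output streams, one per file to be written. Each output stream is announced on the
/// returned [`DemuxedStreamReceiver`] together with the object store [`Path`] the file
/// should be written to. When no partition columns are configured, batches are
/// distributed across files by row count; otherwise rows are split into hive-style
/// `col=value/` directories.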
pub(crate) fn start_demuxer_task(
    config: &FileSinkConfig,
    data: SendableRecordBatchStream,
    context: &Arc<TaskContext>,
) -> (SpawnedTask<Result<()>>, DemuxedStreamReceiver) {
    let (tx, rx) = mpsc::unbounded_channel();
    let context = Arc::clone(context);
    let file_extension = config.file_extension.clone();
    let base_output_path = config.table_paths[0].clone();
    let task = if config.table_partition_cols.is_empty() {
        let single_file_output = !base_output_path.is_collection()
            && base_output_path.file_extension().is_some();
        SpawnedTask::spawn(async move {
            row_count_demuxer(
                tx,
                data,
                context,
                base_output_path,
                file_extension,
                single_file_output,
            )
            .await
        })
    } else {
        let partition_by = config.table_partition_cols.clone();
        let keep_partition_by_columns = config.keep_partition_by_columns;
        SpawnedTask::spawn(async move {
            hive_style_partitions_demuxer(
                tx,
                data,
                context,
                partition_by,
                base_output_path,
                file_extension,
                keep_partition_by_columns,
            )
            .await
        })
    };

    (task, rx)
}

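/// Demuxes the input stream into output files based solely on row count: batches are
/// distributed round-robin over at least `minimum_parallel_output_files` file streams,
/// and a stream is rotated to a new file once it has received at least
/// `soft_max_rows_per_output_file` rows. When `single_file_output` is true, everything
/// is written to a single file at `base_output_path`.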
async fn row_count_demuxer(
    mut tx: UnboundedSender<(Path, Receiver<RecordBatch>)>,
    mut input: SendableRecordBatchStream,
    context: Arc<TaskContext>,
    base_output_path: ListingTableUrl,
    file_extension: String,
    single_file_output: bool,
) -> Result<()> {
    let exec_options = &context.session_config().options().execution;

    let max_rows_per_file = exec_options.soft_max_rows_per_output_file;
    let max_buffered_batches = exec_options.max_buffered_batches_per_output_file;
    let minimum_parallel_files = exec_options.minimum_parallel_output_files;
    let mut part_idx = 0;
    let write_id = rand::distr::Alphanumeric.sample_string(&mut rand::rng(), 16);

    let mut open_file_streams = Vec::with_capacity(minimum_parallel_files);

    let mut next_send_stream = 0;
    let mut row_counts = Vec::with_capacity(minimum_parallel_files);

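    // A single-file write collapses to one output stream with no row limit.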
    let minimum_parallel_files = if single_file_output {
        1
    } else {
        minimum_parallel_files
    };

    let max_rows_per_file = if single_file_output {
        usize::MAX
    } else {
        max_rows_per_file
    };

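    // Distribute incoming batches round-robin over the open file streams, opening a
    // replacement file whenever a stream exceeds the soft row limit.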
    while let Some(rb) = input.next().await.transpose()? {
        if open_file_streams.len() < minimum_parallel_files {
            open_file_streams.push(create_new_file_stream(
                &base_output_path,
                &write_id,
                part_idx,
                &file_extension,
                single_file_output,
                max_buffered_batches,
                &mut tx,
            )?);
            row_counts.push(0);
            part_idx += 1;
        } else if row_counts[next_send_stream] >= max_rows_per_file {
            row_counts[next_send_stream] = 0;
            open_file_streams[next_send_stream] = create_new_file_stream(
                &base_output_path,
                &write_id,
                part_idx,
                &file_extension,
                single_file_output,
                max_buffered_batches,
                &mut tx,
            )?;
            part_idx += 1;
        }
        row_counts[next_send_stream] += rb.num_rows();
        open_file_streams[next_send_stream]
            .send(rb)
            .await
            .map_err(|_| {
                DataFusionError::Execution(
                    "Error sending RecordBatch to file stream!".into(),
                )
            })?;

        next_send_stream = (next_send_stream + 1) % minimum_parallel_files;
    }
    Ok(())
}

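/// Computes the object store path for the next output file. For multi-file writes the
/// file is named `{write_id}_{part_idx}.{file_extension}` under the base path; for a
/// single-file write the base path itself is used.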
fn generate_file_path(
    base_output_path: &ListingTableUrl,
    write_id: &str,
    part_idx: usize,
    file_extension: &str,
    single_file_output: bool,
) -> Path {
    if !single_file_output {
        base_output_path
            .prefix()
            .child(format!("{write_id}_{part_idx}.{file_extension}"))
    } else {
        base_output_path.prefix().to_owned()
    }
}

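/// Creates a bounded channel for a new output file, announces the file path and the
/// receiving end on `tx`, and returns the sender used to push [`RecordBatch`]es into
/// that file's stream.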
fn create_new_file_stream(
    base_output_path: &ListingTableUrl,
    write_id: &str,
    part_idx: usize,
    file_extension: &str,
    single_file_output: bool,
    max_buffered_batches: usize,
    tx: &mut UnboundedSender<(Path, Receiver<RecordBatch>)>,
) -> Result<Sender<RecordBatch>> {
    let file_path = generate_file_path(
        base_output_path,
        write_id,
        part_idx,
        file_extension,
        single_file_output,
    );
    let (tx_file, rx_file) = mpsc::channel(max_buffered_batches / 2);
    tx.send((file_path, rx_file)).map_err(|_| {
        DataFusionError::Execution("Error sending RecordBatch to file stream!".into())
    })?;
    Ok(tx_file)
}

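/// Demuxes the input stream into hive-style partitions: each row is routed to a
/// directory of the form `col1=val1/col2=val2/...` derived from the values of the
/// `partition_by` columns, with one output file per distinct partition key seen during
/// this write.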
async fn hive_style_partitions_demuxer(
    tx: UnboundedSender<(Path, Receiver<RecordBatch>)>,
    mut input: SendableRecordBatchStream,
    context: Arc<TaskContext>,
    partition_by: Vec<(String, DataType)>,
    base_output_path: ListingTableUrl,
    file_extension: String,
    keep_partition_by_columns: bool,
) -> Result<()> {
    let write_id = rand::distr::Alphanumeric.sample_string(&mut rand::rng(), 16);

    let exec_options = &context.session_config().options().execution;
    let max_buffered_recordbatches = exec_options.max_buffered_batches_per_output_file;

    let mut value_map: HashMap<Vec<String>, Sender<RecordBatch>> = HashMap::new();

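    // For each input batch: compute the partition key of every row, split the batch by
    // key, and forward each slice to the file stream for that partition, creating the
    // stream the first time a key is seen.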
    while let Some(rb) = input.next().await.transpose()? {
        let all_partition_values = compute_partition_keys_by_row(&rb, &partition_by)?;

        let take_map = compute_take_arrays(&rb, all_partition_values);

        for (part_key, mut builder) in take_map.into_iter() {
            let take_indices = builder.finish();
            let struct_array: StructArray = rb.clone().into();
            let parted_batch = RecordBatch::from(
                arrow::compute::take(&struct_array, &take_indices, None)?.as_struct(),
            );

            let part_tx = match value_map.get_mut(&part_key) {
                Some(part_tx) => part_tx,
                None => {
                    let (part_tx, part_rx) =
                        mpsc::channel::<RecordBatch>(max_buffered_recordbatches);
                    let file_path = compute_hive_style_file_path(
                        &part_key,
                        &partition_by,
                        &write_id,
                        &file_extension,
                        &base_output_path,
                    );

                    tx.send((file_path, part_rx)).map_err(|_| {
                        DataFusionError::Execution(
                            "Error sending new file stream!".into(),
                        )
                    })?;

                    value_map.insert(part_key.clone(), part_tx);
                    value_map
                        .get_mut(&part_key)
                        .ok_or(DataFusionError::Internal(
                            "Key must exist since it was just inserted!".into(),
                        ))?
                }
            };

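            // Optionally drop the partition columns before writing, since their values
            // are already encoded in the directory path.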
            let final_batch_to_send = if keep_partition_by_columns {
                parted_batch
            } else {
                remove_partition_by_columns(&parted_batch, &partition_by)?
            };

            part_tx.send(final_batch_to_send).await.map_err(|_| {
                DataFusionError::Internal("Unexpected error sending parted batch!".into())
            })?;
        }
    }

    Ok(())
}

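/// Computes the string form of each row's value for every partition column, returning
/// one `Vec<Cow<str>>` per partition column (in `partition_by` order), each with one
/// entry per row of `rb`.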
fn compute_partition_keys_by_row<'a>(
    rb: &'a RecordBatch,
    partition_by: &'a [(String, DataType)],
) -> Result<Vec<Vec<Cow<'a, str>>>> {
    let mut all_partition_values = vec![];

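    // Offset used to convert days since the Unix epoch (1970-01-01) into chrono's
    // days-from-CE representation when formatting Date32/Date64 values.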
    const EPOCH_DAYS_FROM_CE: i32 = 719_163;

    let schema = rb.schema();
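    // For each partition column, stringify every row's value according to the column's
    // data type.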
    for (col, _) in partition_by.iter() {
        let mut partition_values = vec![];

        let dtype = schema.field_with_name(col)?.data_type();
        let col_array = rb.column_by_name(col).ok_or(exec_datafusion_err!(
            "PartitionBy Column {} does not exist in source data! Got schema {schema}.",
            col
        ))?;

        match dtype {
            DataType::Utf8 => {
                let array = as_string_array(col_array)?;
                for i in 0..rb.num_rows() {
                    partition_values.push(Cow::from(array.value(i)));
                }
            }
            DataType::Utf8View => {
                let array = as_string_view_array(col_array)?;
                for i in 0..rb.num_rows() {
                    partition_values.push(Cow::from(array.value(i)));
                }
            }
            DataType::Boolean => {
                let array = as_boolean_array(col_array)?;
                for i in 0..rb.num_rows() {
                    partition_values.push(Cow::from(array.value(i).to_string()));
                }
            }
            DataType::Date32 => {
                let array = as_date32_array(col_array)?;
                let format = "%Y-%m-%d";
                for i in 0..rb.num_rows() {
                    let date = NaiveDate::from_num_days_from_ce_opt(
                        EPOCH_DAYS_FROM_CE + array.value(i),
                    )
                    .unwrap()
                    .format(format)
                    .to_string();
                    partition_values.push(Cow::from(date));
                }
            }
            DataType::Date64 => {
                let array = as_date64_array(col_array)?;
                let format = "%Y-%m-%d";
                for i in 0..rb.num_rows() {
                    let date = NaiveDate::from_num_days_from_ce_opt(
                        EPOCH_DAYS_FROM_CE + (array.value(i) / 86_400_000) as i32,
                    )
                    .unwrap()
                    .format(format)
                    .to_string();
                    partition_values.push(Cow::from(date));
                }
            }
            DataType::Int8 => {
                let array = as_int8_array(col_array)?;
                for i in 0..rb.num_rows() {
                    partition_values.push(Cow::from(array.value(i).to_string()));
                }
            }
            DataType::Int16 => {
                let array = as_int16_array(col_array)?;
                for i in 0..rb.num_rows() {
                    partition_values.push(Cow::from(array.value(i).to_string()));
                }
            }
            DataType::Int32 => {
                let array = as_int32_array(col_array)?;
                for i in 0..rb.num_rows() {
                    partition_values.push(Cow::from(array.value(i).to_string()));
                }
            }
            DataType::Int64 => {
                let array = as_int64_array(col_array)?;
                for i in 0..rb.num_rows() {
                    partition_values.push(Cow::from(array.value(i).to_string()));
                }
            }
            DataType::UInt8 => {
                let array = as_uint8_array(col_array)?;
                for i in 0..rb.num_rows() {
                    partition_values.push(Cow::from(array.value(i).to_string()));
                }
            }
            DataType::UInt16 => {
                let array = as_uint16_array(col_array)?;
                for i in 0..rb.num_rows() {
                    partition_values.push(Cow::from(array.value(i).to_string()));
                }
            }
            DataType::UInt32 => {
                let array = as_uint32_array(col_array)?;
                for i in 0..rb.num_rows() {
                    partition_values.push(Cow::from(array.value(i).to_string()));
                }
            }
            DataType::UInt64 => {
                let array = as_uint64_array(col_array)?;
                for i in 0..rb.num_rows() {
                    partition_values.push(Cow::from(array.value(i).to_string()));
                }
            }
            DataType::Float16 => {
                let array = as_float16_array(col_array)?;
                for i in 0..rb.num_rows() {
                    partition_values.push(Cow::from(array.value(i).to_string()));
                }
            }
            DataType::Float32 => {
                let array = as_float32_array(col_array)?;
                for i in 0..rb.num_rows() {
                    partition_values.push(Cow::from(array.value(i).to_string()));
                }
            }
            DataType::Float64 => {
                let array = as_float64_array(col_array)?;
                for i in 0..rb.num_rows() {
                    partition_values.push(Cow::from(array.value(i).to_string()));
                }
            }
            DataType::Dictionary(_, _) => {
                downcast_dictionary_array!(
                    col_array => {
                        let array = col_array.downcast_dict::<StringArray>()
                            .ok_or(exec_datafusion_err!("it is not yet supported to write to hive partitions with datatype {}",
                            dtype))?;

                        for i in 0..rb.num_rows() {
                            partition_values.push(Cow::from(array.value(i)));
                        }
                    },
                    _ => unreachable!(),
                )
            }
            _ => {
                return not_impl_err!(
                    "it is not yet supported to write to hive partitions with datatype {}",
                    dtype
                )
            }
        }

        all_partition_values.push(partition_values);
    }

    Ok(all_partition_values)
}

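/// Groups row indices by partition key: returns a map from the per-row partition key
/// (one string per partition column) to a [`UInt64Builder`] holding the indices of the
/// rows in `rb` that belong to that partition.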
fn compute_take_arrays(
    rb: &RecordBatch,
    all_partition_values: Vec<Vec<Cow<str>>>,
) -> HashMap<Vec<String>, UInt64Builder> {
    let mut take_map = HashMap::new();
    for i in 0..rb.num_rows() {
        let mut part_key = vec![];
        for vals in all_partition_values.iter() {
            part_key.push(vals[i].clone().into());
        }
        let builder = take_map.entry(part_key).or_insert_with(UInt64Builder::new);
        builder.append_value(i as u64);
    }
    take_map
}

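/// Returns a new [`RecordBatch`] with the partition columns removed, keeping the
/// remaining columns and fields in their original order.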
fn remove_partition_by_columns(
    parted_batch: &RecordBatch,
    partition_by: &[(String, DataType)],
) -> Result<RecordBatch> {
    let partition_names: Vec<_> = partition_by.iter().map(|(s, _)| s).collect();
    let (non_part_cols, non_part_fields): (Vec<_>, Vec<_>) = parted_batch
        .columns()
        .iter()
        .zip(parted_batch.schema().fields())
        .filter_map(|(a, f)| {
            if !partition_names.contains(&f.name()) {
                Some((Arc::clone(a), (**f).clone()))
            } else {
                None
            }
        })
        .unzip();

    let non_part_schema = Schema::new(non_part_fields);
    let final_batch_to_send =
        RecordBatch::try_new(Arc::new(non_part_schema), non_part_cols)?;

    Ok(final_batch_to_send)
}

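/// Builds the hive-style output path for a partition key: one `{column}={value}`
/// directory per partition column under the base path, ending in
/// `{write_id}.{file_extension}`.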
fn compute_hive_style_file_path(
    part_key: &[String],
    partition_by: &[(String, DataType)],
    write_id: &str,
    file_extension: &str,
    base_output_path: &ListingTableUrl,
) -> Path {
    let mut file_path = base_output_path.prefix().clone();
    for j in 0..part_key.len() {
        file_path = file_path.child(format!("{}={}", partition_by[j].0, part_key[j]));
    }

    file_path.child(format!("{write_id}.{file_extension}"))
}