datafusion_datasource/write/
demux.rs1use std::borrow::Cow;
22use std::collections::HashMap;
23use std::sync::Arc;
24
25use crate::url::ListingTableUrl;
26use crate::write::FileSinkConfig;
27use datafusion_common::error::Result;
28use datafusion_physical_plan::SendableRecordBatchStream;
29
30use arrow::array::{
31 ArrayAccessor, RecordBatch, StringArray, StructArray, builder::UInt64Builder,
32 cast::AsArray, downcast_dictionary_array,
33};
34use arrow::datatypes::{DataType, Schema};
35use datafusion_common::cast::{
36 as_boolean_array, as_date32_array, as_date64_array, as_float16_array,
37 as_float32_array, as_float64_array, as_int8_array, as_int16_array, as_int32_array,
38 as_int64_array, as_large_string_array, as_string_array, as_string_view_array,
39 as_uint8_array, as_uint16_array, as_uint32_array, as_uint64_array,
40};
41use datafusion_common::{exec_datafusion_err, internal_datafusion_err, not_impl_err};
42use datafusion_common_runtime::SpawnedTask;
43
44use chrono::NaiveDate;
45use datafusion_execution::TaskContext;
46use futures::StreamExt;
47use object_store::path::Path;
48use rand::distr::SampleString;
49use tokio::sync::mpsc::{self, Receiver, Sender, UnboundedReceiver, UnboundedSender};
50
/// Receiving half of the bounded tokio channel that carries the record
/// batches destined for a single output file.
51type RecordBatchReceiver = Receiver<RecordBatch>;
/// Receiving half of the unbounded channel on which the demuxer announces
/// each output file: every item pairs the file's `Path` with the
/// [`RecordBatchReceiver`] that yields the batches to write to it.
52pub type DemuxedStreamReceiver = UnboundedReceiver<(Path, RecordBatchReceiver)>;
53
54pub(crate) fn start_demuxer_task(
100 config: &FileSinkConfig,
101 data: SendableRecordBatchStream,
102 context: &Arc<TaskContext>,
103) -> (SpawnedTask<Result<()>>, DemuxedStreamReceiver) {
104 let (tx, rx) = mpsc::unbounded_channel();
105 let context = Arc::clone(context);
106 let file_extension = config.file_extension.clone();
107 let base_output_path = config.table_paths[0].clone();
108 let task = if config.table_partition_cols.is_empty() {
109 let single_file_output = config
110 .file_output_mode
111 .single_file_output(&base_output_path);
112 SpawnedTask::spawn(async move {
113 row_count_demuxer(
114 tx,
115 data,
116 context,
117 base_output_path,
118 file_extension,
119 single_file_output,
120 )
121 .await
122 })
123 } else {
124 let partition_by = config.table_partition_cols.clone();
127 let keep_partition_by_columns = config.keep_partition_by_columns;
128 SpawnedTask::spawn(async move {
129 hive_style_partitions_demuxer(
130 tx,
131 data,
132 context,
133 partition_by,
134 base_output_path,
135 file_extension,
136 keep_partition_by_columns,
137 )
138 .await
139 })
140 };
141
142 (task, rx)
143}
144
145async fn row_count_demuxer(
147 mut tx: UnboundedSender<(Path, Receiver<RecordBatch>)>,
148 mut input: SendableRecordBatchStream,
149 context: Arc<TaskContext>,
150 base_output_path: ListingTableUrl,
151 file_extension: String,
152 single_file_output: bool,
153) -> Result<()> {
154 let exec_options = &context.session_config().options().execution;
155
156 let max_rows_per_file = exec_options.soft_max_rows_per_output_file;
157 let max_buffered_batches = exec_options.max_buffered_batches_per_output_file;
158 let minimum_parallel_files = exec_options.minimum_parallel_output_files;
159 let mut part_idx = 0;
160 let write_id = rand::distr::Alphanumeric.sample_string(&mut rand::rng(), 16);
161
162 let mut open_file_streams = Vec::with_capacity(minimum_parallel_files);
163
164 let mut next_send_steam = 0;
165 let mut row_counts = Vec::with_capacity(minimum_parallel_files);
166
167 let minimum_parallel_files = if single_file_output {
169 1
170 } else {
171 minimum_parallel_files
172 };
173
174 let max_rows_per_file = if single_file_output {
175 usize::MAX
176 } else {
177 max_rows_per_file
178 };
179
180 if single_file_output {
181 open_file_streams.push(create_new_file_stream(
183 &base_output_path,
184 &write_id,
185 part_idx,
186 &file_extension,
187 single_file_output,
188 max_buffered_batches,
189 &mut tx,
190 )?);
191 row_counts.push(0);
192 part_idx += 1;
193 }
194
195 let schema = input.schema();
196 let mut is_batch_received = false;
197
198 while let Some(rb) = input.next().await.transpose()? {
199 is_batch_received = true;
200 if open_file_streams.len() < minimum_parallel_files {
202 open_file_streams.push(create_new_file_stream(
203 &base_output_path,
204 &write_id,
205 part_idx,
206 &file_extension,
207 single_file_output,
208 max_buffered_batches,
209 &mut tx,
210 )?);
211 row_counts.push(0);
212 part_idx += 1;
213 } else if row_counts[next_send_steam] >= max_rows_per_file {
214 row_counts[next_send_steam] = 0;
215 open_file_streams[next_send_steam] = create_new_file_stream(
216 &base_output_path,
217 &write_id,
218 part_idx,
219 &file_extension,
220 single_file_output,
221 max_buffered_batches,
222 &mut tx,
223 )?;
224 part_idx += 1;
225 }
226 row_counts[next_send_steam] += rb.num_rows();
227 open_file_streams[next_send_steam]
228 .send(rb)
229 .await
230 .map_err(|_| {
231 exec_datafusion_err!("Error sending RecordBatch to file stream!")
232 })?;
233
234 next_send_steam = (next_send_steam + 1) % minimum_parallel_files;
235 }
236
237 if single_file_output && !is_batch_received {
239 open_file_streams
240 .first_mut()
241 .ok_or_else(|| internal_datafusion_err!("Expected a single output file"))?
242 .send(RecordBatch::new_empty(schema))
243 .await
244 .map_err(|_| {
245 exec_datafusion_err!("Error sending empty RecordBatch to file stream!")
246 })?;
247 }
248
249 Ok(())
250}
251
252fn generate_file_path(
254 base_output_path: &ListingTableUrl,
255 write_id: &str,
256 part_idx: usize,
257 file_extension: &str,
258 single_file_output: bool,
259) -> Path {
260 if !single_file_output {
261 base_output_path
262 .prefix()
263 .clone()
264 .join(format!("{write_id}_{part_idx}.{file_extension}"))
265 } else {
266 base_output_path.prefix().to_owned()
267 }
268}
269
270fn create_new_file_stream(
272 base_output_path: &ListingTableUrl,
273 write_id: &str,
274 part_idx: usize,
275 file_extension: &str,
276 single_file_output: bool,
277 max_buffered_batches: usize,
278 tx: &mut UnboundedSender<(Path, Receiver<RecordBatch>)>,
279) -> Result<Sender<RecordBatch>> {
280 let file_path = generate_file_path(
281 base_output_path,
282 write_id,
283 part_idx,
284 file_extension,
285 single_file_output,
286 );
287 let (tx_file, rx_file) = mpsc::channel(max_buffered_batches / 2);
288 tx.send((file_path, rx_file))
289 .map_err(|_| exec_datafusion_err!("Error sending RecordBatch to file stream!"))?;
290 Ok(tx_file)
291}
292
293async fn hive_style_partitions_demuxer(
297 tx: UnboundedSender<(Path, Receiver<RecordBatch>)>,
298 mut input: SendableRecordBatchStream,
299 context: Arc<TaskContext>,
300 partition_by: Vec<(String, DataType)>,
301 base_output_path: ListingTableUrl,
302 file_extension: String,
303 keep_partition_by_columns: bool,
304) -> Result<()> {
305 let write_id = rand::distr::Alphanumeric.sample_string(&mut rand::rng(), 16);
306
307 let exec_options = &context.session_config().options().execution;
308 let max_buffered_recordbatches = exec_options.max_buffered_batches_per_output_file;
309
310 let mut value_map: HashMap<Vec<String>, Sender<RecordBatch>> = HashMap::new();
312
313 while let Some(rb) = input.next().await.transpose()? {
314 let all_partition_values = compute_partition_keys_by_row(&rb, &partition_by)?;
316
317 let take_map = compute_take_arrays(&rb, &all_partition_values);
319
320 for (part_key, mut builder) in take_map.into_iter() {
322 let take_indices = builder.finish();
325 let struct_array: StructArray = rb.clone().into();
326 let parted_batch = RecordBatch::from(
327 arrow::compute::take(&struct_array, &take_indices, None)?.as_struct(),
328 );
329
330 let part_tx = match value_map.get_mut(&part_key) {
332 Some(part_tx) => part_tx,
333 None => {
334 let (part_tx, part_rx) =
336 mpsc::channel::<RecordBatch>(max_buffered_recordbatches);
337 let file_path = compute_hive_style_file_path(
338 &part_key,
339 &partition_by,
340 &write_id,
341 &file_extension,
342 &base_output_path,
343 );
344
345 tx.send((file_path, part_rx)).map_err(|_| {
346 exec_datafusion_err!("Error sending new file stream!")
347 })?;
348
349 value_map.insert(part_key.clone(), part_tx);
350 value_map.get_mut(&part_key).ok_or_else(|| {
351 exec_datafusion_err!("Key must exist since it was just inserted!")
352 })?
353 }
354 };
355
356 let final_batch_to_send = if keep_partition_by_columns {
357 parted_batch
358 } else {
359 remove_partition_by_columns(&parted_batch, &partition_by)?
360 };
361
362 part_tx.send(final_batch_to_send).await.map_err(|_| {
364 internal_datafusion_err!("Unexpected error sending parted batch!")
365 })?;
366 }
367 }
368
369 Ok(())
370}
371
372fn compute_partition_keys_by_row<'a>(
373 rb: &'a RecordBatch,
374 partition_by: &'a [(String, DataType)],
375) -> Result<Vec<Vec<Cow<'a, str>>>> {
376 let mut all_partition_values = vec![];
377
378 const EPOCH_DAYS_FROM_CE: i32 = 719_163;
379
380 let schema = rb.schema();
386 for (col, _) in partition_by.iter() {
387 let mut partition_values = vec![];
388
389 let dtype = schema.field_with_name(col)?.data_type();
390 let col_array = rb.column_by_name(col).ok_or(exec_datafusion_err!(
391 "PartitionBy Column {} does not exist in source data! Got schema {schema}.",
392 col
393 ))?;
394
395 match dtype {
396 DataType::Utf8 => {
397 let array = as_string_array(col_array)?;
398 for i in 0..rb.num_rows() {
399 partition_values.push(Cow::from(array.value(i)));
400 }
401 }
402 DataType::LargeUtf8 => {
403 let array = as_large_string_array(col_array)?;
404 for i in 0..rb.num_rows() {
405 partition_values.push(Cow::from(array.value(i)));
406 }
407 }
408 DataType::Utf8View => {
409 let array = as_string_view_array(col_array)?;
410 for i in 0..rb.num_rows() {
411 partition_values.push(Cow::from(array.value(i)));
412 }
413 }
414 DataType::Boolean => {
415 let array = as_boolean_array(col_array)?;
416 for i in 0..rb.num_rows() {
417 partition_values.push(Cow::from(array.value(i).to_string()));
418 }
419 }
420 DataType::Date32 => {
421 let array = as_date32_array(col_array)?;
422 let format = "%Y-%m-%d";
424 for i in 0..rb.num_rows() {
425 let date = NaiveDate::from_num_days_from_ce_opt(
426 EPOCH_DAYS_FROM_CE + array.value(i),
427 )
428 .unwrap()
429 .format(format)
430 .to_string();
431 partition_values.push(Cow::from(date));
432 }
433 }
434 DataType::Date64 => {
435 let array = as_date64_array(col_array)?;
436 let format = "%Y-%m-%d";
438 for i in 0..rb.num_rows() {
439 let date = NaiveDate::from_num_days_from_ce_opt(
440 EPOCH_DAYS_FROM_CE + (array.value(i) / 86_400_000) as i32,
441 )
442 .unwrap()
443 .format(format)
444 .to_string();
445 partition_values.push(Cow::from(date));
446 }
447 }
448 DataType::Int8 => {
449 let array = as_int8_array(col_array)?;
450 for i in 0..rb.num_rows() {
451 partition_values.push(Cow::from(array.value(i).to_string()));
452 }
453 }
454 DataType::Int16 => {
455 let array = as_int16_array(col_array)?;
456 for i in 0..rb.num_rows() {
457 partition_values.push(Cow::from(array.value(i).to_string()));
458 }
459 }
460 DataType::Int32 => {
461 let array = as_int32_array(col_array)?;
462 for i in 0..rb.num_rows() {
463 partition_values.push(Cow::from(array.value(i).to_string()));
464 }
465 }
466 DataType::Int64 => {
467 let array = as_int64_array(col_array)?;
468 for i in 0..rb.num_rows() {
469 partition_values.push(Cow::from(array.value(i).to_string()));
470 }
471 }
472 DataType::UInt8 => {
473 let array = as_uint8_array(col_array)?;
474 for i in 0..rb.num_rows() {
475 partition_values.push(Cow::from(array.value(i).to_string()));
476 }
477 }
478 DataType::UInt16 => {
479 let array = as_uint16_array(col_array)?;
480 for i in 0..rb.num_rows() {
481 partition_values.push(Cow::from(array.value(i).to_string()));
482 }
483 }
484 DataType::UInt32 => {
485 let array = as_uint32_array(col_array)?;
486 for i in 0..rb.num_rows() {
487 partition_values.push(Cow::from(array.value(i).to_string()));
488 }
489 }
490 DataType::UInt64 => {
491 let array = as_uint64_array(col_array)?;
492 for i in 0..rb.num_rows() {
493 partition_values.push(Cow::from(array.value(i).to_string()));
494 }
495 }
496 DataType::Float16 => {
497 let array = as_float16_array(col_array)?;
498 for i in 0..rb.num_rows() {
499 partition_values.push(Cow::from(array.value(i).to_string()));
500 }
501 }
502 DataType::Float32 => {
503 let array = as_float32_array(col_array)?;
504 for i in 0..rb.num_rows() {
505 partition_values.push(Cow::from(array.value(i).to_string()));
506 }
507 }
508 DataType::Float64 => {
509 let array = as_float64_array(col_array)?;
510 for i in 0..rb.num_rows() {
511 partition_values.push(Cow::from(array.value(i).to_string()));
512 }
513 }
514 DataType::Dictionary(_, _) => {
515 downcast_dictionary_array!(
516 col_array => {
517 let array = col_array.downcast_dict::<StringArray>()
518 .ok_or(exec_datafusion_err!("it is not yet supported to write to hive partitions with datatype {}",
519 dtype))?;
520
521 for i in 0..rb.num_rows() {
522 partition_values.push(Cow::from(array.value(i)));
523 }
524 },
525 _ => unreachable!(),
526 )
527 }
528 _ => {
529 return not_impl_err!(
530 "it is not yet supported to write to hive partitions with datatype {}",
531 dtype
532 );
533 }
534 }
535
536 all_partition_values.push(partition_values);
537 }
538
539 Ok(all_partition_values)
540}
541
542fn compute_take_arrays(
543 rb: &RecordBatch,
544 all_partition_values: &[Vec<Cow<str>>],
545) -> HashMap<Vec<String>, UInt64Builder> {
546 let mut take_map = HashMap::new();
547 for i in 0..rb.num_rows() {
548 let mut part_key = vec![];
549 for vals in all_partition_values.iter() {
550 part_key.push(vals[i].clone().into());
551 }
552 let builder = take_map.entry(part_key).or_insert_with(UInt64Builder::new);
553 builder.append_value(i as u64);
554 }
555 take_map
556}
557
558fn remove_partition_by_columns(
559 parted_batch: &RecordBatch,
560 partition_by: &[(String, DataType)],
561) -> Result<RecordBatch> {
562 let partition_names: Vec<_> = partition_by.iter().map(|(s, _)| s).collect();
563 let (non_part_cols, non_part_fields): (Vec<_>, Vec<_>) = parted_batch
564 .columns()
565 .iter()
566 .zip(parted_batch.schema().fields())
567 .filter_map(|(a, f)| {
568 if !partition_names.contains(&f.name()) {
569 Some((Arc::clone(a), (**f).clone()))
570 } else {
571 None
572 }
573 })
574 .unzip();
575
576 let non_part_schema = Schema::new(non_part_fields);
577 let final_batch_to_send =
578 RecordBatch::try_new(Arc::new(non_part_schema), non_part_cols)?;
579
580 Ok(final_batch_to_send)
581}
582
583fn compute_hive_style_file_path(
584 part_key: &[String],
585 partition_by: &[(String, DataType)],
586 write_id: &str,
587 file_extension: &str,
588 base_output_path: &ListingTableUrl,
589) -> Path {
590 let mut file_path = base_output_path.prefix().clone();
591 for j in 0..part_key.len() {
592 file_path = file_path.join(format!("{}={}", partition_by[j].0, part_key[j]));
593 }
594
595 file_path.join(format!("{write_id}.{file_extension}"))
596}