datafusion_datasource/write/demux.rs

use std::borrow::Cow;
use std::collections::HashMap;
use std::sync::Arc;

use crate::url::ListingTableUrl;
use crate::write::FileSinkConfig;
use datafusion_common::error::Result;
use datafusion_physical_plan::SendableRecordBatchStream;

use arrow::array::{
    ArrayAccessor, RecordBatch, StringArray, StructArray, builder::UInt64Builder,
    cast::AsArray, downcast_dictionary_array,
};
use arrow::datatypes::{DataType, Schema};
use datafusion_common::cast::{
    as_boolean_array, as_date32_array, as_date64_array, as_float16_array,
    as_float32_array, as_float64_array, as_int8_array, as_int16_array, as_int32_array,
    as_int64_array, as_string_array, as_string_view_array, as_uint8_array,
    as_uint16_array, as_uint32_array, as_uint64_array,
};
use datafusion_common::{exec_datafusion_err, internal_datafusion_err, not_impl_err};
use datafusion_common_runtime::SpawnedTask;

use chrono::NaiveDate;
use datafusion_execution::TaskContext;
use futures::StreamExt;
use object_store::path::Path;
use rand::distr::SampleString;
use tokio::sync::mpsc::{self, Receiver, Sender, UnboundedReceiver, UnboundedSender};

type RecordBatchReceiver = Receiver<RecordBatch>;
pub type DemuxedStreamReceiver = UnboundedReceiver<(Path, RecordBatchReceiver)>;

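/// Spawns a task that demultiplexes the input stream of [`RecordBatch`]es into
/// one channel per output file. Returns the [`SpawnedTask`] handle together
/// with an unbounded receiver that yields a `(Path, RecordBatchReceiver)` pair
/// for each file to be written. When `table_partition_cols` is empty the
/// row-count based demuxer is used; otherwise batches are split into
/// hive-style partition directories.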
pub(crate) fn start_demuxer_task(
    config: &FileSinkConfig,
    data: SendableRecordBatchStream,
    context: &Arc<TaskContext>,
) -> (SpawnedTask<Result<()>>, DemuxedStreamReceiver) {
    let (tx, rx) = mpsc::unbounded_channel();
    let context = Arc::clone(context);
    let file_extension = config.file_extension.clone();
    let base_output_path = config.table_paths[0].clone();
    let task = if config.table_partition_cols.is_empty() {
        let single_file_output = !base_output_path.is_collection()
            && base_output_path.file_extension().is_some();
        SpawnedTask::spawn(async move {
            row_count_demuxer(
                tx,
                data,
                context,
                base_output_path,
                file_extension,
                single_file_output,
            )
            .await
        })
    } else {
        let partition_by = config.table_partition_cols.clone();
        let keep_partition_by_columns = config.keep_partition_by_columns;
        SpawnedTask::spawn(async move {
            hive_style_partitions_demuxer(
                tx,
                data,
                context,
                partition_by,
                base_output_path,
                file_extension,
                keep_partition_by_columns,
            )
            .await
        })
    };

    (task, rx)
}

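/// Demuxes the input stream into output files of roughly
/// `soft_max_rows_per_output_file` rows, distributing batches round-robin over
/// `minimum_parallel_output_files` concurrently open file streams and rotating
/// to a new file once a stream reaches the row limit. With
/// `single_file_output` everything goes to a single file, and an empty batch
/// is sent if the input produced no rows.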
async fn row_count_demuxer(
    mut tx: UnboundedSender<(Path, Receiver<RecordBatch>)>,
    mut input: SendableRecordBatchStream,
    context: Arc<TaskContext>,
    base_output_path: ListingTableUrl,
    file_extension: String,
    single_file_output: bool,
) -> Result<()> {
    let exec_options = &context.session_config().options().execution;

    let max_rows_per_file = exec_options.soft_max_rows_per_output_file;
    let max_buffered_batches = exec_options.max_buffered_batches_per_output_file;
    let minimum_parallel_files = exec_options.minimum_parallel_output_files;
    let mut part_idx = 0;
    let write_id = rand::distr::Alphanumeric.sample_string(&mut rand::rng(), 16);

    let mut open_file_streams = Vec::with_capacity(minimum_parallel_files);

    let mut next_send_stream = 0;
    let mut row_counts = Vec::with_capacity(minimum_parallel_files);

    let minimum_parallel_files = if single_file_output {
        1
    } else {
        minimum_parallel_files
    };

    let max_rows_per_file = if single_file_output {
        usize::MAX
    } else {
        max_rows_per_file
    };

    if single_file_output {
        open_file_streams.push(create_new_file_stream(
            &base_output_path,
            &write_id,
            part_idx,
            &file_extension,
            single_file_output,
            max_buffered_batches,
            &mut tx,
        )?);
        row_counts.push(0);
        part_idx += 1;
    }

    let schema = input.schema();
    let mut is_batch_received = false;

    while let Some(rb) = input.next().await.transpose()? {
        is_batch_received = true;
        if open_file_streams.len() < minimum_parallel_files {
            open_file_streams.push(create_new_file_stream(
                &base_output_path,
                &write_id,
                part_idx,
                &file_extension,
                single_file_output,
                max_buffered_batches,
                &mut tx,
            )?);
            row_counts.push(0);
            part_idx += 1;
        } else if row_counts[next_send_stream] >= max_rows_per_file {
            row_counts[next_send_stream] = 0;
            open_file_streams[next_send_stream] = create_new_file_stream(
                &base_output_path,
                &write_id,
                part_idx,
                &file_extension,
                single_file_output,
                max_buffered_batches,
                &mut tx,
            )?;
            part_idx += 1;
        }
        row_counts[next_send_stream] += rb.num_rows();
        open_file_streams[next_send_stream]
            .send(rb)
            .await
            .map_err(|_| {
                exec_datafusion_err!("Error sending RecordBatch to file stream!")
            })?;

        next_send_stream = (next_send_stream + 1) % minimum_parallel_files;
    }

    if single_file_output && !is_batch_received {
        open_file_streams
            .first_mut()
            .ok_or_else(|| internal_datafusion_err!("Expected a single output file"))?
            .send(RecordBatch::new_empty(schema))
            .await
            .map_err(|_| {
                exec_datafusion_err!("Error sending empty RecordBatch to file stream!")
            })?;
    }

    Ok(())
}

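/// Computes the object store [`Path`] for an output file: the base output path
/// itself for single-file output, otherwise
/// `<base>/<write_id>_<part_idx>.<file_extension>`.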
fn generate_file_path(
    base_output_path: &ListingTableUrl,
    write_id: &str,
    part_idx: usize,
    file_extension: &str,
    single_file_output: bool,
) -> Path {
    if !single_file_output {
        base_output_path
            .prefix()
            .child(format!("{write_id}_{part_idx}.{file_extension}"))
    } else {
        base_output_path.prefix().to_owned()
    }
}

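/// Creates a bounded channel for a new output file, announces the
/// `(Path, Receiver)` pair on `tx`, and returns the [`Sender`] used to feed
/// [`RecordBatch`]es to that file.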
fn create_new_file_stream(
    base_output_path: &ListingTableUrl,
    write_id: &str,
    part_idx: usize,
    file_extension: &str,
    single_file_output: bool,
    max_buffered_batches: usize,
    tx: &mut UnboundedSender<(Path, Receiver<RecordBatch>)>,
) -> Result<Sender<RecordBatch>> {
    let file_path = generate_file_path(
        base_output_path,
        write_id,
        part_idx,
        file_extension,
        single_file_output,
    );
    let (tx_file, rx_file) = mpsc::channel(max_buffered_batches / 2);
    tx.send((file_path, rx_file))
        .map_err(|_| exec_datafusion_err!("Error sending RecordBatch to file stream!"))?;
    Ok(tx_file)
}

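/// Demuxes the input stream into hive-style partition directories
/// (`col1=val1/col2=val2/...`). Each distinct combination of partition values
/// gets its own output file stream, created lazily the first time the
/// combination is seen. Partition columns are dropped from the written batches
/// unless `keep_partition_by_columns` is set.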
async fn hive_style_partitions_demuxer(
    tx: UnboundedSender<(Path, Receiver<RecordBatch>)>,
    mut input: SendableRecordBatchStream,
    context: Arc<TaskContext>,
    partition_by: Vec<(String, DataType)>,
    base_output_path: ListingTableUrl,
    file_extension: String,
    keep_partition_by_columns: bool,
) -> Result<()> {
    let write_id = rand::distr::Alphanumeric.sample_string(&mut rand::rng(), 16);

    let exec_options = &context.session_config().options().execution;
    let max_buffered_recordbatches = exec_options.max_buffered_batches_per_output_file;

    let mut value_map: HashMap<Vec<String>, Sender<RecordBatch>> = HashMap::new();

    while let Some(rb) = input.next().await.transpose()? {
        let all_partition_values = compute_partition_keys_by_row(&rb, &partition_by)?;

        let take_map = compute_take_arrays(&rb, &all_partition_values);

        for (part_key, mut builder) in take_map.into_iter() {
            let take_indices = builder.finish();
            let struct_array: StructArray = rb.clone().into();
            let parted_batch = RecordBatch::from(
                arrow::compute::take(&struct_array, &take_indices, None)?.as_struct(),
            );

            let part_tx = match value_map.get_mut(&part_key) {
                Some(part_tx) => part_tx,
                None => {
                    let (part_tx, part_rx) =
                        mpsc::channel::<RecordBatch>(max_buffered_recordbatches);
                    let file_path = compute_hive_style_file_path(
                        &part_key,
                        &partition_by,
                        &write_id,
                        &file_extension,
                        &base_output_path,
                    );

                    tx.send((file_path, part_rx)).map_err(|_| {
                        exec_datafusion_err!("Error sending new file stream!")
                    })?;

                    value_map.insert(part_key.clone(), part_tx);
                    value_map.get_mut(&part_key).ok_or_else(|| {
                        exec_datafusion_err!("Key must exist since it was just inserted!")
                    })?
                }
            };

            let final_batch_to_send = if keep_partition_by_columns {
                parted_batch
            } else {
                remove_partition_by_columns(&parted_batch, &partition_by)?
            };

            part_tx.send(final_batch_to_send).await.map_err(|_| {
                internal_datafusion_err!("Unexpected error sending parted batch!")
            })?;
        }
    }

    Ok(())
}

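/// Renders the partition column values of `rb` as strings: one `Vec` per
/// partition column, each with one entry per row, in the order given by
/// `partition_by`.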
fn compute_partition_keys_by_row<'a>(
    rb: &'a RecordBatch,
    partition_by: &'a [(String, DataType)],
) -> Result<Vec<Vec<Cow<'a, str>>>> {
    let mut all_partition_values = vec![];

    const EPOCH_DAYS_FROM_CE: i32 = 719_163;

    let schema = rb.schema();
    for (col, _) in partition_by.iter() {
        let mut partition_values = vec![];

        let dtype = schema.field_with_name(col)?.data_type();
        let col_array = rb.column_by_name(col).ok_or(exec_datafusion_err!(
            "PartitionBy Column {} does not exist in source data! Got schema {schema}.",
            col
        ))?;

        match dtype {
            DataType::Utf8 => {
                let array = as_string_array(col_array)?;
                for i in 0..rb.num_rows() {
                    partition_values.push(Cow::from(array.value(i)));
                }
            }
            DataType::Utf8View => {
                let array = as_string_view_array(col_array)?;
                for i in 0..rb.num_rows() {
                    partition_values.push(Cow::from(array.value(i)));
                }
            }
            DataType::Boolean => {
                let array = as_boolean_array(col_array)?;
                for i in 0..rb.num_rows() {
                    partition_values.push(Cow::from(array.value(i).to_string()));
                }
            }
            DataType::Date32 => {
                let array = as_date32_array(col_array)?;
                let format = "%Y-%m-%d";
                for i in 0..rb.num_rows() {
                    let date = NaiveDate::from_num_days_from_ce_opt(
                        EPOCH_DAYS_FROM_CE + array.value(i),
                    )
                    .unwrap()
                    .format(format)
                    .to_string();
                    partition_values.push(Cow::from(date));
                }
            }
            DataType::Date64 => {
                let array = as_date64_array(col_array)?;
                let format = "%Y-%m-%d";
                for i in 0..rb.num_rows() {
                    let date = NaiveDate::from_num_days_from_ce_opt(
                        EPOCH_DAYS_FROM_CE + (array.value(i) / 86_400_000) as i32,
                    )
                    .unwrap()
                    .format(format)
                    .to_string();
                    partition_values.push(Cow::from(date));
                }
            }
            DataType::Int8 => {
                let array = as_int8_array(col_array)?;
                for i in 0..rb.num_rows() {
                    partition_values.push(Cow::from(array.value(i).to_string()));
                }
            }
            DataType::Int16 => {
                let array = as_int16_array(col_array)?;
                for i in 0..rb.num_rows() {
                    partition_values.push(Cow::from(array.value(i).to_string()));
                }
            }
            DataType::Int32 => {
                let array = as_int32_array(col_array)?;
                for i in 0..rb.num_rows() {
                    partition_values.push(Cow::from(array.value(i).to_string()));
                }
            }
            DataType::Int64 => {
                let array = as_int64_array(col_array)?;
                for i in 0..rb.num_rows() {
                    partition_values.push(Cow::from(array.value(i).to_string()));
                }
            }
            DataType::UInt8 => {
                let array = as_uint8_array(col_array)?;
                for i in 0..rb.num_rows() {
                    partition_values.push(Cow::from(array.value(i).to_string()));
                }
            }
            DataType::UInt16 => {
                let array = as_uint16_array(col_array)?;
                for i in 0..rb.num_rows() {
                    partition_values.push(Cow::from(array.value(i).to_string()));
                }
            }
            DataType::UInt32 => {
                let array = as_uint32_array(col_array)?;
                for i in 0..rb.num_rows() {
                    partition_values.push(Cow::from(array.value(i).to_string()));
                }
            }
            DataType::UInt64 => {
                let array = as_uint64_array(col_array)?;
                for i in 0..rb.num_rows() {
                    partition_values.push(Cow::from(array.value(i).to_string()));
                }
            }
            DataType::Float16 => {
                let array = as_float16_array(col_array)?;
                for i in 0..rb.num_rows() {
                    partition_values.push(Cow::from(array.value(i).to_string()));
                }
            }
            DataType::Float32 => {
                let array = as_float32_array(col_array)?;
                for i in 0..rb.num_rows() {
                    partition_values.push(Cow::from(array.value(i).to_string()));
                }
            }
            DataType::Float64 => {
                let array = as_float64_array(col_array)?;
                for i in 0..rb.num_rows() {
                    partition_values.push(Cow::from(array.value(i).to_string()));
                }
            }
            DataType::Dictionary(_, _) => {
                downcast_dictionary_array!(
                    col_array => {
                        let array = col_array.downcast_dict::<StringArray>()
                            .ok_or(exec_datafusion_err!(
                                "it is not yet supported to write to hive partitions with datatype {}",
                                dtype
                            ))?;

                        for i in 0..rb.num_rows() {
                            partition_values.push(Cow::from(array.value(i)));
                        }
                    },
                    _ => unreachable!(),
                )
            }
            _ => {
                return not_impl_err!(
                    "it is not yet supported to write to hive partitions with datatype {}",
                    dtype
                );
            }
        }

        all_partition_values.push(partition_values);
    }

    Ok(all_partition_values)
}

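/// Groups row indices by partition key, returning a [`UInt64Builder`] of take
/// indices per distinct key so the batch can be split with
/// `arrow::compute::take`.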
fn compute_take_arrays(
    rb: &RecordBatch,
    all_partition_values: &[Vec<Cow<str>>],
) -> HashMap<Vec<String>, UInt64Builder> {
    let mut take_map = HashMap::new();
    for i in 0..rb.num_rows() {
        let mut part_key = vec![];
        for vals in all_partition_values.iter() {
            part_key.push(vals[i].clone().into());
        }
        let builder = take_map.entry(part_key).or_insert_with(UInt64Builder::new);
        builder.append_value(i as u64);
    }
    take_map
}

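/// Projects away the partition columns, returning a batch that contains only
/// the non-partition columns and a matching schema.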
fn remove_partition_by_columns(
    parted_batch: &RecordBatch,
    partition_by: &[(String, DataType)],
) -> Result<RecordBatch> {
    let partition_names: Vec<_> = partition_by.iter().map(|(s, _)| s).collect();
    let (non_part_cols, non_part_fields): (Vec<_>, Vec<_>) = parted_batch
        .columns()
        .iter()
        .zip(parted_batch.schema().fields())
        .filter_map(|(a, f)| {
            if !partition_names.contains(&f.name()) {
                Some((Arc::clone(a), (**f).clone()))
            } else {
                None
            }
        })
        .unzip();

    let non_part_schema = Schema::new(non_part_fields);
    let final_batch_to_send =
        RecordBatch::try_new(Arc::new(non_part_schema), non_part_cols)?;

    Ok(final_batch_to_send)
}

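/// Builds the hive-style output path
/// `<base>/<col1>=<val1>/.../<write_id>.<file_extension>` for a partition key.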
fn compute_hive_style_file_path(
    part_key: &[String],
    partition_by: &[(String, DataType)],
    write_id: &str,
    file_extension: &str,
    base_output_path: &ListingTableUrl,
) -> Path {
    let mut file_path = base_output_path.prefix().clone();
    for j in 0..part_key.len() {
        file_path = file_path.child(format!("{}={}", partition_by[j].0, part_key[j]));
    }

    file_path.child(format!("{write_id}.{file_extension}"))
}