1use std::borrow::Borrow;
25use std::pin::Pin;
26use std::task::{Context, Poll};
27use std::{any::Any, sync::Arc};
28
29use super::{
30 metrics::{ExecutionPlanMetricsSet, MetricsSet},
31 ColumnStatistics, DisplayAs, DisplayFormatType, ExecutionPlan,
32 ExecutionPlanProperties, Partitioning, PlanProperties, RecordBatchStream,
33 SendableRecordBatchStream, Statistics,
34};
35use crate::execution_plan::{
36 boundedness_from_children, check_default_invariants, emission_type_from_children,
37 InvariantLevel,
38};
39use crate::metrics::BaselineMetrics;
40use crate::projection::{make_with_child, ProjectionExec};
41use crate::stream::ObservedStream;
42
43use arrow::datatypes::{Field, Schema, SchemaRef};
44use arrow::record_batch::RecordBatch;
45use datafusion_common::stats::Precision;
46use datafusion_common::{exec_err, internal_err, DataFusionError, Result};
47use datafusion_execution::TaskContext;
48use datafusion_physical_expr::{calculate_union, EquivalenceProperties};
49
50use futures::Stream;
51use itertools::Itertools;
52use log::{debug, trace, warn};
53use tokio::macros::support::thread_rng_n;
54
/// `UNION ALL` execution plan.
///
/// Output partitions are the partitions of all inputs laid end to end: the
/// union exposes the sum of its inputs' partition counts, and executing output
/// partition `p` executes the corresponding partition of whichever input owns
/// it. Rows are passed through unchanged; no deduplication is performed.
#[derive(Debug, Clone)]
pub struct UnionExec {
    /// Input execution plans, in union order.
    inputs: Vec<Arc<dyn ExecutionPlan>>,
    /// Execution metrics, recorded per output partition via `BaselineMetrics`.
    metrics: ExecutionPlanMetricsSet,
    /// Cached plan properties (equivalence properties, output partitioning,
    /// emission type and boundedness) computed once at construction time.
    cache: PlanProperties,
}
101
102impl UnionExec {
103 pub fn new(inputs: Vec<Arc<dyn ExecutionPlan>>) -> Self {
105 let schema = union_schema(&inputs);
106 let cache = Self::compute_properties(&inputs, schema).unwrap();
112 UnionExec {
113 inputs,
114 metrics: ExecutionPlanMetricsSet::new(),
115 cache,
116 }
117 }
118
119 pub fn inputs(&self) -> &Vec<Arc<dyn ExecutionPlan>> {
121 &self.inputs
122 }
123
124 fn compute_properties(
126 inputs: &[Arc<dyn ExecutionPlan>],
127 schema: SchemaRef,
128 ) -> Result<PlanProperties> {
129 let children_eqps = inputs
131 .iter()
132 .map(|child| child.equivalence_properties().clone())
133 .collect::<Vec<_>>();
134 let eq_properties = calculate_union(children_eqps, schema)?;
135
136 let num_partitions = inputs
138 .iter()
139 .map(|plan| plan.output_partitioning().partition_count())
140 .sum();
141 let output_partitioning = Partitioning::UnknownPartitioning(num_partitions);
142 Ok(PlanProperties::new(
143 eq_properties,
144 output_partitioning,
145 emission_type_from_children(inputs),
146 boundedness_from_children(inputs),
147 ))
148 }
149}
150
151impl DisplayAs for UnionExec {
152 fn fmt_as(
153 &self,
154 t: DisplayFormatType,
155 f: &mut std::fmt::Formatter,
156 ) -> std::fmt::Result {
157 match t {
158 DisplayFormatType::Default | DisplayFormatType::Verbose => {
159 write!(f, "UnionExec")
160 }
161 DisplayFormatType::TreeRender => Ok(()),
162 }
163 }
164}
165
166impl ExecutionPlan for UnionExec {
167 fn name(&self) -> &'static str {
168 "UnionExec"
169 }
170
171 fn as_any(&self) -> &dyn Any {
173 self
174 }
175
176 fn properties(&self) -> &PlanProperties {
177 &self.cache
178 }
179
180 fn check_invariants(&self, check: InvariantLevel) -> Result<()> {
181 check_default_invariants(self, check)?;
182
183 (self.inputs().len() >= 2)
184 .then_some(())
185 .ok_or(DataFusionError::Internal(
186 "UnionExec should have at least 2 children".into(),
187 ))
188 }
189
190 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
191 self.inputs.iter().collect()
192 }
193
194 fn maintains_input_order(&self) -> Vec<bool> {
195 if let Some(output_ordering) = self.properties().output_ordering() {
204 self.inputs()
205 .iter()
206 .map(|child| {
207 if let Some(child_ordering) = child.output_ordering() {
208 output_ordering.len() == child_ordering.len()
209 } else {
210 false
211 }
212 })
213 .collect()
214 } else {
215 vec![false; self.inputs().len()]
216 }
217 }
218
219 fn with_new_children(
220 self: Arc<Self>,
221 children: Vec<Arc<dyn ExecutionPlan>>,
222 ) -> Result<Arc<dyn ExecutionPlan>> {
223 Ok(Arc::new(UnionExec::new(children)))
224 }
225
226 fn execute(
227 &self,
228 mut partition: usize,
229 context: Arc<TaskContext>,
230 ) -> Result<SendableRecordBatchStream> {
231 trace!("Start UnionExec::execute for partition {} of context session_id {} and task_id {:?}", partition, context.session_id(), context.task_id());
232 let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
233 let elapsed_compute = baseline_metrics.elapsed_compute().clone();
236 let _timer = elapsed_compute.timer(); for input in self.inputs.iter() {
240 if partition < input.output_partitioning().partition_count() {
242 let stream = input.execute(partition, context)?;
243 debug!("Found a Union partition to execute");
244 return Ok(Box::pin(ObservedStream::new(
245 stream,
246 baseline_metrics,
247 None,
248 )));
249 } else {
250 partition -= input.output_partitioning().partition_count();
251 }
252 }
253
254 warn!("Error in Union: Partition {partition} not found");
255
256 exec_err!("Partition {partition} not found in Union")
257 }
258
259 fn metrics(&self) -> Option<MetricsSet> {
260 Some(self.metrics.clone_inner())
261 }
262
263 fn statistics(&self) -> Result<Statistics> {
264 self.partition_statistics(None)
265 }
266
267 fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
268 if let Some(partition_idx) = partition {
269 let mut remaining_idx = partition_idx;
271 for input in &self.inputs {
272 let input_partition_count = input.output_partitioning().partition_count();
273 if remaining_idx < input_partition_count {
274 return input.partition_statistics(Some(remaining_idx));
276 }
277 remaining_idx -= input_partition_count;
278 }
279 Ok(Statistics::new_unknown(&self.schema()))
281 } else {
282 let stats = self
284 .inputs
285 .iter()
286 .map(|input_exec| input_exec.partition_statistics(None))
287 .collect::<Result<Vec<_>>>()?;
288
289 Ok(stats
290 .into_iter()
291 .reduce(stats_union)
292 .unwrap_or_else(|| Statistics::new_unknown(&self.schema())))
293 }
294 }
295
296 fn benefits_from_input_partitioning(&self) -> Vec<bool> {
297 vec![false; self.children().len()]
298 }
299
300 fn supports_limit_pushdown(&self) -> bool {
301 true
302 }
303
304 fn try_swapping_with_projection(
308 &self,
309 projection: &ProjectionExec,
310 ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
311 if projection.expr().len() >= projection.input().schema().fields().len() {
313 return Ok(None);
314 }
315
316 let new_children = self
317 .children()
318 .into_iter()
319 .map(|child| make_with_child(projection, child))
320 .collect::<Result<Vec<_>>>()?;
321
322 Ok(Some(Arc::new(UnionExec::new(new_children))))
323 }
324}
325
/// Combines several identically hash-partitioned inputs partition by
/// partition.
///
/// Output partition `p` merges partition `p` of every input into one stream
/// (see `CombinedRecordBatchStream`). Construction fails unless all inputs
/// share the same `Partitioning::Hash` (see `can_interleave`), and the output
/// partitioning is that shared partitioning.
#[derive(Debug, Clone)]
pub struct InterleaveExec {
    /// Input execution plans; all share one hash partitioning.
    inputs: Vec<Arc<dyn ExecutionPlan>>,
    /// Execution metrics, recorded per output partition via `BaselineMetrics`.
    metrics: ExecutionPlanMetricsSet,
    /// Cached plan properties computed once at construction time.
    cache: PlanProperties,
}
367
368impl InterleaveExec {
369 pub fn try_new(inputs: Vec<Arc<dyn ExecutionPlan>>) -> Result<Self> {
371 if !can_interleave(inputs.iter()) {
372 return internal_err!(
373 "Not all InterleaveExec children have a consistent hash partitioning"
374 );
375 }
376 let cache = Self::compute_properties(&inputs);
377 Ok(InterleaveExec {
378 inputs,
379 metrics: ExecutionPlanMetricsSet::new(),
380 cache,
381 })
382 }
383
384 pub fn inputs(&self) -> &Vec<Arc<dyn ExecutionPlan>> {
386 &self.inputs
387 }
388
389 fn compute_properties(inputs: &[Arc<dyn ExecutionPlan>]) -> PlanProperties {
391 let schema = union_schema(inputs);
392 let eq_properties = EquivalenceProperties::new(schema);
393 let output_partitioning = inputs[0].output_partitioning().clone();
395 PlanProperties::new(
396 eq_properties,
397 output_partitioning,
398 emission_type_from_children(inputs),
399 boundedness_from_children(inputs),
400 )
401 }
402}
403
404impl DisplayAs for InterleaveExec {
405 fn fmt_as(
406 &self,
407 t: DisplayFormatType,
408 f: &mut std::fmt::Formatter,
409 ) -> std::fmt::Result {
410 match t {
411 DisplayFormatType::Default | DisplayFormatType::Verbose => {
412 write!(f, "InterleaveExec")
413 }
414 DisplayFormatType::TreeRender => Ok(()),
415 }
416 }
417}
418
419impl ExecutionPlan for InterleaveExec {
420 fn name(&self) -> &'static str {
421 "InterleaveExec"
422 }
423
424 fn as_any(&self) -> &dyn Any {
426 self
427 }
428
429 fn properties(&self) -> &PlanProperties {
430 &self.cache
431 }
432
433 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
434 self.inputs.iter().collect()
435 }
436
437 fn maintains_input_order(&self) -> Vec<bool> {
438 vec![false; self.inputs().len()]
439 }
440
441 fn with_new_children(
442 self: Arc<Self>,
443 children: Vec<Arc<dyn ExecutionPlan>>,
444 ) -> Result<Arc<dyn ExecutionPlan>> {
445 if !can_interleave(children.iter()) {
447 return internal_err!(
448 "Can not create InterleaveExec: new children can not be interleaved"
449 );
450 }
451 Ok(Arc::new(InterleaveExec::try_new(children)?))
452 }
453
454 fn execute(
455 &self,
456 partition: usize,
457 context: Arc<TaskContext>,
458 ) -> Result<SendableRecordBatchStream> {
459 trace!("Start InterleaveExec::execute for partition {} of context session_id {} and task_id {:?}", partition, context.session_id(), context.task_id());
460 let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
461 let elapsed_compute = baseline_metrics.elapsed_compute().clone();
464 let _timer = elapsed_compute.timer(); let mut input_stream_vec = vec![];
467 for input in self.inputs.iter() {
468 if partition < input.output_partitioning().partition_count() {
469 input_stream_vec.push(input.execute(partition, Arc::clone(&context))?);
470 } else {
471 break;
473 }
474 }
475 if input_stream_vec.len() == self.inputs.len() {
476 let stream = Box::pin(CombinedRecordBatchStream::new(
477 self.schema(),
478 input_stream_vec,
479 ));
480 return Ok(Box::pin(ObservedStream::new(
481 stream,
482 baseline_metrics,
483 None,
484 )));
485 }
486
487 warn!("Error in InterleaveExec: Partition {partition} not found");
488
489 exec_err!("Partition {partition} not found in InterleaveExec")
490 }
491
492 fn metrics(&self) -> Option<MetricsSet> {
493 Some(self.metrics.clone_inner())
494 }
495
496 fn statistics(&self) -> Result<Statistics> {
497 self.partition_statistics(None)
498 }
499
500 fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
501 if partition.is_some() {
502 return Ok(Statistics::new_unknown(&self.schema()));
503 }
504 let stats = self
505 .inputs
506 .iter()
507 .map(|stat| stat.partition_statistics(None))
508 .collect::<Result<Vec<_>>>()?;
509
510 Ok(stats
511 .into_iter()
512 .reduce(stats_union)
513 .unwrap_or_else(|| Statistics::new_unknown(&self.schema())))
514 }
515
516 fn benefits_from_input_partitioning(&self) -> Vec<bool> {
517 vec![false; self.children().len()]
518 }
519}
520
521pub fn can_interleave<T: Borrow<Arc<dyn ExecutionPlan>>>(
528 mut inputs: impl Iterator<Item = T>,
529) -> bool {
530 let Some(first) = inputs.next() else {
531 return false;
532 };
533
534 let reference = first.borrow().output_partitioning();
535 matches!(reference, Partitioning::Hash(_, _))
536 && inputs
537 .map(|plan| plan.borrow().output_partitioning().clone())
538 .all(|partition| partition == *reference)
539}
540
541fn union_schema(inputs: &[Arc<dyn ExecutionPlan>]) -> SchemaRef {
542 let first_schema = inputs[0].schema();
543
544 let fields = (0..first_schema.fields().len())
545 .map(|i| {
546 let base_field = first_schema.field(i).clone();
549
550 let merged_field = inputs
552 .iter()
553 .enumerate()
554 .map(|(input_idx, input)| {
555 let field = input.schema().field(i).clone();
556 let mut metadata = field.metadata().clone();
557
558 let other_metadatas = inputs
559 .iter()
560 .enumerate()
561 .filter(|(other_idx, _)| *other_idx != input_idx)
562 .flat_map(|(_, other_input)| {
563 other_input.schema().field(i).metadata().clone().into_iter()
564 });
565
566 metadata.extend(other_metadatas);
567 field.with_metadata(metadata)
568 })
569 .find_or_first(Field::is_nullable)
570 .unwrap()
573 .with_name(base_field.name());
574
575 merged_field
576 })
577 .collect::<Vec<_>>();
578
579 let all_metadata_merged = inputs
580 .iter()
581 .flat_map(|i| i.schema().metadata().clone().into_iter())
582 .collect();
583
584 Arc::new(Schema::new_with_metadata(fields, all_metadata_merged))
585}
586
/// Stream that yields batches from several child streams as they become
/// ready (used by `InterleaveExec` to merge one partition of each input).
struct CombinedRecordBatchStream {
    /// Schema reported for the combined output.
    schema: SchemaRef,
    /// Child streams that have not finished yet; exhausted ones are removed
    /// in `poll_next`.
    entries: Vec<SendableRecordBatchStream>,
}
594
595impl CombinedRecordBatchStream {
596 pub fn new(schema: SchemaRef, entries: Vec<SendableRecordBatchStream>) -> Self {
598 Self { schema, entries }
599 }
600}
601
602impl RecordBatchStream for CombinedRecordBatchStream {
603 fn schema(&self) -> SchemaRef {
604 Arc::clone(&self.schema)
605 }
606}
607
impl Stream for CombinedRecordBatchStream {
    type Item = Result<RecordBatch>;

    /// Polls the child streams round-robin, starting from a random index, and
    /// returns the first item (batch or error) that is ready.
    ///
    /// Children that return `Ready(None)` are removed with `swap_remove`. The
    /// stream ends (`Ready(None)`) once no children remain; otherwise, if no
    /// child produced an item during this pass, it returns `Pending`.
    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        use Poll::*;

        // Random starting point so that no child is systematically polled
        // first.
        // NOTE(review): `tokio::macros::support` is an internal tokio module;
        // confirm it is acceptable to depend on.
        let start = thread_rng_n(self.entries.len() as u32) as usize;
        let mut idx = start;

        // The pass length is fixed at the number of children when the call began.
        for _ in 0..self.entries.len() {
            let stream = self.entries.get_mut(idx).unwrap();

            match Pin::new(stream).poll_next(cx) {
                Ready(Some(val)) => return Ready(Some(val)),
                Ready(None) => {
                    // Remove the exhausted child; the last child is moved into
                    // slot `idx`.
                    self.entries.swap_remove(idx);

                    // If the removed child was the last slot, wrap the cursor.
                    if idx == self.entries.len() {
                        idx = 0;
                    } else if idx < start && start <= self.entries.len() {
                        // The cursor has already wrapped, so the child just
                        // moved into `idx` was polled earlier in this pass;
                        // skip it.
                        idx = idx.wrapping_add(1) % self.entries.len();
                    }
                }
                Pending => {
                    idx = idx.wrapping_add(1) % self.entries.len();
                }
            }
        }

        // No children left means every input is exhausted.
        if self.entries.is_empty() {
            Ready(None)
        } else {
            Pending
        }
    }
}
653
654fn col_stats_union(
655 mut left: ColumnStatistics,
656 right: ColumnStatistics,
657) -> ColumnStatistics {
658 left.distinct_count = Precision::Absent;
659 left.min_value = left.min_value.min(&right.min_value);
660 left.max_value = left.max_value.max(&right.max_value);
661 left.sum_value = left.sum_value.add(&right.sum_value);
662 left.null_count = left.null_count.add(&right.null_count);
663
664 left
665}
666
667fn stats_union(mut left: Statistics, right: Statistics) -> Statistics {
668 left.num_rows = left.num_rows.add(&right.num_rows);
669 left.total_byte_size = left.total_byte_size.add(&right.total_byte_size);
670 left.column_statistics = left
671 .column_statistics
672 .into_iter()
673 .zip(right.column_statistics)
674 .map(|(a, b)| col_stats_union(a, b))
675 .collect::<Vec<_>>();
676 left
677}
678
#[cfg(test)]
mod tests {
    use super::*;
    use crate::collect;
    use crate::test::{self, TestMemoryExec};

    use arrow::compute::SortOptions;
    use arrow::datatypes::DataType;
    use datafusion_common::ScalarValue;
    use datafusion_physical_expr::equivalence::convert_to_orderings;
    use datafusion_physical_expr::expressions::col;

    // Schema with seven nullable Int32 columns `a`..`g`, used by the
    // equivalence-properties test.
    fn create_test_schema() -> Result<SchemaRef> {
        let a = Field::new("a", DataType::Int32, true);
        let b = Field::new("b", DataType::Int32, true);
        let c = Field::new("c", DataType::Int32, true);
        let d = Field::new("d", DataType::Int32, true);
        let e = Field::new("e", DataType::Int32, true);
        let f = Field::new("f", DataType::Int32, true);
        let g = Field::new("g", DataType::Int32, true);
        let schema = Arc::new(Schema::new(vec![a, b, c, d, e, f, g]));

        Ok(schema)
    }

    // The union's partition count is the sum of its inputs' partition counts
    // (4 + 5), and collecting it yields 9 batches.
    #[tokio::test]
    async fn test_union_partitions() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());

        // Inputs with different numbers of partitions.
        let csv = test::scan_partitioned(4);
        let csv2 = test::scan_partitioned(5);

        let union_exec = Arc::new(UnionExec::new(vec![csv, csv2]));

        // Should have 9 partitions and 9 output batches.
        assert_eq!(
            union_exec
                .properties()
                .output_partitioning()
                .partition_count(),
            9
        );

        let result: Vec<RecordBatch> = collect(union_exec, task_ctx).await?;
        assert_eq!(result.len(), 9);

        Ok(())
    }

    // `stats_union` adds row counts, byte sizes, sums and null counts, widens
    // min/max, always drops distinct counts, and turns any value combined with
    // `Absent` into `Absent`.
    #[tokio::test]
    async fn test_stats_union() {
        let left = Statistics {
            num_rows: Precision::Exact(5),
            total_byte_size: Precision::Exact(23),
            column_statistics: vec![
                ColumnStatistics {
                    distinct_count: Precision::Exact(5),
                    max_value: Precision::Exact(ScalarValue::Int64(Some(21))),
                    min_value: Precision::Exact(ScalarValue::Int64(Some(-4))),
                    sum_value: Precision::Exact(ScalarValue::Int64(Some(42))),
                    null_count: Precision::Exact(0),
                },
                ColumnStatistics {
                    distinct_count: Precision::Exact(1),
                    max_value: Precision::Exact(ScalarValue::from("x")),
                    min_value: Precision::Exact(ScalarValue::from("a")),
                    sum_value: Precision::Absent,
                    null_count: Precision::Exact(3),
                },
                ColumnStatistics {
                    distinct_count: Precision::Absent,
                    max_value: Precision::Exact(ScalarValue::Float32(Some(1.1))),
                    min_value: Precision::Exact(ScalarValue::Float32(Some(0.1))),
                    sum_value: Precision::Exact(ScalarValue::Float32(Some(42.0))),
                    null_count: Precision::Absent,
                },
            ],
        };

        let right = Statistics {
            num_rows: Precision::Exact(7),
            total_byte_size: Precision::Exact(29),
            column_statistics: vec![
                ColumnStatistics {
                    distinct_count: Precision::Exact(3),
                    max_value: Precision::Exact(ScalarValue::Int64(Some(34))),
                    min_value: Precision::Exact(ScalarValue::Int64(Some(1))),
                    sum_value: Precision::Exact(ScalarValue::Int64(Some(42))),
                    null_count: Precision::Exact(1),
                },
                ColumnStatistics {
                    distinct_count: Precision::Absent,
                    max_value: Precision::Exact(ScalarValue::from("c")),
                    min_value: Precision::Exact(ScalarValue::from("b")),
                    sum_value: Precision::Absent,
                    null_count: Precision::Absent,
                },
                ColumnStatistics {
                    distinct_count: Precision::Absent,
                    max_value: Precision::Absent,
                    min_value: Precision::Absent,
                    sum_value: Precision::Absent,
                    null_count: Precision::Absent,
                },
            ],
        };

        let result = stats_union(left, right);
        let expected = Statistics {
            num_rows: Precision::Exact(12),
            total_byte_size: Precision::Exact(52),
            column_statistics: vec![
                // Both sides exact: min/max widened, sums and null counts added.
                ColumnStatistics {
                    distinct_count: Precision::Absent,
                    max_value: Precision::Exact(ScalarValue::Int64(Some(34))),
                    min_value: Precision::Exact(ScalarValue::Int64(Some(-4))),
                    sum_value: Precision::Exact(ScalarValue::Int64(Some(84))),
                    null_count: Precision::Exact(1),
                },
                // Right null count is absent, so the merged null count is too.
                ColumnStatistics {
                    distinct_count: Precision::Absent,
                    max_value: Precision::Exact(ScalarValue::from("x")),
                    min_value: Precision::Exact(ScalarValue::from("a")),
                    sum_value: Precision::Absent,
                    null_count: Precision::Absent,
                },
                // Right side entirely absent: everything becomes absent.
                ColumnStatistics {
                    distinct_count: Precision::Absent,
                    max_value: Precision::Absent,
                    min_value: Precision::Absent,
                    sum_value: Precision::Absent,
                    null_count: Precision::Absent,
                },
            ],
        };

        assert_eq!(result, expected);
    }

    // Each case gives (first child orderings, second child orderings,
    // expected union orderings) and checks that `UnionExec` computes the
    // expected ordering equivalence class.
    #[tokio::test]
    async fn test_union_equivalence_properties() -> Result<()> {
        let schema = create_test_schema()?;
        let col_a = &col("a", &schema)?;
        let col_b = &col("b", &schema)?;
        let col_c = &col("c", &schema)?;
        let col_d = &col("d", &schema)?;
        let col_e = &col("e", &schema)?;
        let col_f = &col("f", &schema)?;
        let options = SortOptions::default();
        let test_cases = [
            // ---------- TEST CASE 1 ----------
            (
                // First child orderings
                vec![
                    // [a, b, f]
                    vec![(col_a, options), (col_b, options), (col_f, options)],
                ],
                // Second child orderings
                vec![
                    // [a, b, c]
                    vec![(col_a, options), (col_b, options), (col_c, options)],
                    // [a, b, f]
                    vec![(col_a, options), (col_b, options), (col_f, options)],
                ],
                // Expected union orderings
                vec![
                    // [a, b, f]
                    vec![(col_a, options), (col_b, options), (col_f, options)],
                ],
            ),
            // ---------- TEST CASE 2 ----------
            (
                // First child orderings
                vec![
                    // [a, b, f]
                    vec![(col_a, options), (col_b, options), (col_f, options)],
                    // [d]
                    vec![(col_d, options)],
                ],
                // Second child orderings
                vec![
                    // [a, b, c]
                    vec![(col_a, options), (col_b, options), (col_c, options)],
                    // [e]
                    vec![(col_e, options)],
                ],
                // Expected union orderings: only the common prefix [a, b] survives
                vec![
                    // [a, b]
                    vec![(col_a, options), (col_b, options)],
                ],
            ),
        ];

        for (
            test_idx,
            (first_child_orderings, second_child_orderings, union_orderings),
        ) in test_cases.iter().enumerate()
        {
            let first_orderings = convert_to_orderings(first_child_orderings);
            let second_orderings = convert_to_orderings(second_child_orderings);
            let union_expected_orderings = convert_to_orderings(union_orderings);
            // Empty in-memory children that only carry the sort information.
            let child1 = Arc::new(TestMemoryExec::update_cache(Arc::new(
                TestMemoryExec::try_new(&[], Arc::clone(&schema), None)?
                    .try_with_sort_information(first_orderings)?,
            )));
            let child2 = Arc::new(TestMemoryExec::update_cache(Arc::new(
                TestMemoryExec::try_new(&[], Arc::clone(&schema), None)?
                    .try_with_sort_information(second_orderings)?,
            )));

            let mut union_expected_eq = EquivalenceProperties::new(Arc::clone(&schema));
            union_expected_eq.add_orderings(union_expected_orderings);

            let union = UnionExec::new(vec![child1, child2]);
            let union_eq_properties = union.properties().equivalence_properties();
            let err_msg = format!(
                "Error in test id: {:?}, test case: {:?}",
                test_idx, test_cases[test_idx]
            );
            assert_eq_properties_same(union_eq_properties, &union_expected_eq, err_msg);
        }
        Ok(())
    }

    // Asserts that both equivalence properties hold the same number of
    // orderings and that every ordering of `rhs` appears in `lhs`.
    fn assert_eq_properties_same(
        lhs: &EquivalenceProperties,
        rhs: &EquivalenceProperties,
        err_msg: String,
    ) {
        // Check whether orderings are same.
        let lhs_orderings = lhs.oeq_class();
        let rhs_orderings = rhs.oeq_class();
        assert_eq!(lhs_orderings.len(), rhs_orderings.len(), "{err_msg}");
        for rhs_ordering in rhs_orderings.iter() {
            assert!(lhs_orderings.contains(rhs_ordering), "{}", err_msg);
        }
    }
}