Skip to main content

datafusion_physical_plan/
union.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18// Some of these functions reference the Postgres documentation
19// or implementation to ensure compatibility and are subject to
20// the Postgres license.
21
22//! The Union operator combines multiple inputs with the same schema
23
24use std::borrow::Borrow;
25use std::pin::Pin;
26use std::sync::Arc;
27use std::task::{Context, Poll};
28
29use super::{
30    DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, Partitioning,
31    PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics,
32    metrics::{ExecutionPlanMetricsSet, MetricsSet},
33};
34use crate::check_if_same_properties;
35use crate::execution_plan::{
36    CardinalityEffect, InvariantLevel, boundedness_from_children,
37    check_default_invariants, emission_type_from_children,
38};
39use crate::filter::FilterExec;
40use crate::filter_pushdown::{
41    ChildPushdownResult, FilterDescription, FilterPushdownPhase,
42    FilterPushdownPropagation, PushedDown,
43};
44use crate::metrics::BaselineMetrics;
45use crate::projection::{ProjectionExec, make_with_child};
46use crate::stream::ObservedStream;
47
48use arrow::datatypes::{Field, Schema, SchemaRef};
49use arrow::record_batch::RecordBatch;
50use datafusion_common::config::ConfigOptions;
51use datafusion_common::stats::NdvFallback;
52use datafusion_common::{
53    Result, assert_or_internal_err, exec_err, internal_datafusion_err,
54};
55use datafusion_execution::TaskContext;
56use datafusion_physical_expr::{
57    EquivalenceProperties, PhysicalExpr, calculate_union, conjunction,
58};
59
60use futures::Stream;
61use itertools::Itertools;
62use log::{debug, trace, warn};
63use tokio::macros::support::thread_rng_n;
64
65/// `UnionExec`: `UNION ALL` execution plan.
66///
67/// `UnionExec` combines multiple inputs with the same schema by
68/// concatenating the partitions.  It does not mix or copy data within
69/// or across partitions. Thus if the input partitions are sorted, the
70/// output partitions of the union are also sorted.
71///
72/// For example, given a `UnionExec` of two inputs, with `N`
73/// partitions, and `M` partitions, there will be `N+M` output
74/// partitions. The first `N` output partitions are from Input 1
75/// partitions, and then next `M` output partitions are from Input 2.
76///
77/// ```text
78///                        ▲       ▲           ▲         ▲
79///                        │       │           │         │
80///      Output            │  ...  │           │         │
81///    Partitions          │0      │N-1        │ N       │N+M-1
82/// (passes through   ┌────┴───────┴───────────┴─────────┴───┐
83///  the N+M input    │              UnionExec               │
84///   partitions)     │                                      │
85///                   └──────────────────────────────────────┘
86///                                      ▲
87///                                      │
88///                                      │
89///       Input           ┌────────┬─────┴────┬──────────┐
90///     Partitions        │ ...    │          │     ...  │
91///                    0  │        │ N-1      │ 0        │  M-1
92///                  ┌────┴────────┴───┐  ┌───┴──────────┴───┐
93///                  │                 │  │                  │
94///                  │                 │  │                  │
95///                  │                 │  │                  │
96///                  │                 │  │                  │
97///                  │                 │  │                  │
98///                  │                 │  │                  │
99///                  │Input 1          │  │Input 2           │
100///                  └─────────────────┘  └──────────────────┘
101/// ```
102#[derive(Debug, Clone)]
103pub struct UnionExec {
104    /// Input execution plan
105    inputs: Vec<Arc<dyn ExecutionPlan>>,
106    /// Execution metrics
107    metrics: ExecutionPlanMetricsSet,
108    /// Cache holding plan properties like equivalences, output partitioning etc.
109    cache: Arc<PlanProperties>,
110}
111
112impl UnionExec {
113    /// Create a new UnionExec
114    #[deprecated(since = "44.0.0", note = "Use UnionExec::try_new instead")]
115    pub fn new(inputs: Vec<Arc<dyn ExecutionPlan>>) -> Self {
116        let schema =
117            union_schema(&inputs).expect("UnionExec::new called with empty inputs");
118        // The schema of the inputs and the union schema is consistent when:
119        // - They have the same number of fields, and
120        // - Their fields have same types at the same indices.
121        // Here, we know that schemas are consistent and the call below can
122        // not return an error.
123        let cache = Self::compute_properties(&inputs, schema).unwrap();
124        UnionExec {
125            inputs,
126            metrics: ExecutionPlanMetricsSet::new(),
127            cache: Arc::new(cache),
128        }
129    }
130
131    /// Try to create a new UnionExec.
132    ///
133    /// # Errors
134    /// Returns an error if:
135    /// - `inputs` is empty
136    ///
137    /// # Optimization
138    /// If there is only one input, returns that input directly rather than wrapping it in a UnionExec
139    pub fn try_new(
140        inputs: Vec<Arc<dyn ExecutionPlan>>,
141    ) -> Result<Arc<dyn ExecutionPlan>> {
142        match inputs.len() {
143            0 => exec_err!("UnionExec requires at least one input"),
144            1 => Ok(inputs.into_iter().next().unwrap()),
145            _ => {
146                let schema = union_schema(&inputs)?;
147                // The schema of the inputs and the union schema is consistent when:
148                // - They have the same number of fields, and
149                // - Their fields have same types at the same indices.
150                // Here, we know that schemas are consistent and the call below can
151                // not return an error.
152                let cache = Self::compute_properties(&inputs, schema).unwrap();
153                Ok(Arc::new(UnionExec {
154                    inputs,
155                    metrics: ExecutionPlanMetricsSet::new(),
156                    cache: Arc::new(cache),
157                }))
158            }
159        }
160    }
161
162    /// Get inputs of the execution plan
163    pub fn inputs(&self) -> &Vec<Arc<dyn ExecutionPlan>> {
164        &self.inputs
165    }
166
167    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
168    fn compute_properties(
169        inputs: &[Arc<dyn ExecutionPlan>],
170        schema: SchemaRef,
171    ) -> Result<PlanProperties> {
172        // Calculate equivalence properties:
173        let children_eqps = inputs
174            .iter()
175            .map(|child| child.equivalence_properties().clone())
176            .collect::<Vec<_>>();
177        let eq_properties = calculate_union(children_eqps, schema)?;
178
179        // Calculate output partitioning; i.e. sum output partitions of the inputs.
180        let num_partitions = inputs
181            .iter()
182            .map(|plan| plan.output_partitioning().partition_count())
183            .sum();
184        let output_partitioning = Partitioning::UnknownPartitioning(num_partitions);
185        Ok(PlanProperties::new(
186            eq_properties,
187            output_partitioning,
188            emission_type_from_children(inputs),
189            boundedness_from_children(inputs),
190        ))
191    }
192
193    fn with_new_children_and_same_properties(
194        &self,
195        children: Vec<Arc<dyn ExecutionPlan>>,
196    ) -> Self {
197        Self {
198            inputs: children,
199            metrics: ExecutionPlanMetricsSet::new(),
200            ..Self::clone(self)
201        }
202    }
203}
204
205impl DisplayAs for UnionExec {
206    fn fmt_as(
207        &self,
208        t: DisplayFormatType,
209        f: &mut std::fmt::Formatter,
210    ) -> std::fmt::Result {
211        match t {
212            DisplayFormatType::Default | DisplayFormatType::Verbose => {
213                write!(f, "UnionExec")
214            }
215            DisplayFormatType::TreeRender => Ok(()),
216        }
217    }
218}
219
220impl ExecutionPlan for UnionExec {
221    fn name(&self) -> &'static str {
222        "UnionExec"
223    }
224
225    /// Return a reference to Any that can be used for downcasting
226    fn properties(&self) -> &Arc<PlanProperties> {
227        &self.cache
228    }
229
230    fn check_invariants(&self, check: InvariantLevel) -> Result<()> {
231        check_default_invariants(self, check)?;
232
233        (self.inputs().len() >= 2).then_some(()).ok_or_else(|| {
234            internal_datafusion_err!("UnionExec should have at least 2 children")
235        })
236    }
237
238    fn maintains_input_order(&self) -> Vec<bool> {
239        // If the Union has an output ordering, it maintains at least one
240        // child's ordering (i.e. the meet).
241        // For instance, assume that the first child is SortExpr('a','b','c'),
242        // the second child is SortExpr('a','b') and the third child is
243        // SortExpr('a','b'). The output ordering would be SortExpr('a','b'),
244        // which is the "meet" of all input orderings. In this example, this
245        // function will return vec![false, true, true], indicating that we
246        // preserve the orderings for the 2nd and the 3rd children.
247        if let Some(output_ordering) = self.properties().output_ordering() {
248            self.inputs()
249                .iter()
250                .map(|child| {
251                    if let Some(child_ordering) = child.output_ordering() {
252                        output_ordering.len() == child_ordering.len()
253                    } else {
254                        false
255                    }
256                })
257                .collect()
258        } else {
259            vec![false; self.inputs().len()]
260        }
261    }
262
263    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
264        vec![false; self.children().len()]
265    }
266
267    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
268        self.inputs.iter().collect()
269    }
270
271    fn with_new_children(
272        self: Arc<Self>,
273        children: Vec<Arc<dyn ExecutionPlan>>,
274    ) -> Result<Arc<dyn ExecutionPlan>> {
275        check_if_same_properties!(self, children);
276        UnionExec::try_new(children)
277    }
278
279    fn execute(
280        &self,
281        mut partition: usize,
282        context: Arc<TaskContext>,
283    ) -> Result<SendableRecordBatchStream> {
284        trace!(
285            "Start UnionExec::execute for partition {} of context session_id {} and task_id {:?}",
286            partition,
287            context.session_id(),
288            context.task_id()
289        );
290        let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
291        // record the tiny amount of work done in this function so
292        // elapsed_compute is reported as non zero
293        let elapsed_compute = baseline_metrics.elapsed_compute().clone();
294        let _timer = elapsed_compute.timer(); // record on drop
295
296        // find partition to execute
297        for input in self.inputs.iter() {
298            // Calculate whether partition belongs to the current partition
299            if partition < input.output_partitioning().partition_count() {
300                let stream = input.execute(partition, context)?;
301                debug!("Found a Union partition to execute");
302                return Ok(Box::pin(ObservedStream::new(
303                    stream,
304                    baseline_metrics,
305                    None,
306                )));
307            } else {
308                partition -= input.output_partitioning().partition_count();
309            }
310        }
311
312        warn!("Error in Union: Partition {partition} not found");
313
314        exec_err!("Partition {partition} not found in Union")
315    }
316
317    fn metrics(&self) -> Option<MetricsSet> {
318        Some(self.metrics.clone_inner())
319    }
320
321    fn partition_statistics(&self, partition: Option<usize>) -> Result<Arc<Statistics>> {
322        if let Some(partition_idx) = partition {
323            // For a specific partition, find which input it belongs to
324            let mut remaining_idx = partition_idx;
325            for input in &self.inputs {
326                let input_partition_count = input.output_partitioning().partition_count();
327                if remaining_idx < input_partition_count {
328                    // This partition belongs to this input
329                    return input.partition_statistics(Some(remaining_idx));
330                }
331                remaining_idx -= input_partition_count;
332            }
333            // If we get here, the partition index is out of bounds
334            Ok(Arc::new(Statistics::new_unknown(&self.schema())))
335        } else {
336            let schema = self.schema();
337            Ok(Arc::new(merge_input_statistics(
338                &self.inputs,
339                None,
340                schema.as_ref(),
341            )?))
342        }
343    }
344
345    fn cardinality_effect(&self) -> CardinalityEffect {
346        // Union combines rows from multiple inputs, so output rows are not tied
347        // to any single input and can only be constrained as greater-or-equal.
348        CardinalityEffect::GreaterEqual
349    }
350
351    fn supports_limit_pushdown(&self) -> bool {
352        true
353    }
354
355    /// Tries to push `projection` down through `union`. If possible, performs the
356    /// pushdown and returns a new [`UnionExec`] as the top plan which has projections
357    /// as its children. Otherwise, returns `None`.
358    fn try_swapping_with_projection(
359        &self,
360        projection: &ProjectionExec,
361    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
362        // If the projection doesn't narrow the schema, we shouldn't try to push it down.
363        if projection.expr().len() >= projection.input().schema().fields().len() {
364            return Ok(None);
365        }
366
367        let new_children = self
368            .children()
369            .into_iter()
370            .map(|child| make_with_child(projection, child))
371            .collect::<Result<Vec<_>>>()?;
372
373        Ok(Some(UnionExec::try_new(new_children.clone())?))
374    }
375
376    fn gather_filters_for_pushdown(
377        &self,
378        _phase: FilterPushdownPhase,
379        parent_filters: Vec<Arc<dyn PhysicalExpr>>,
380        _config: &ConfigOptions,
381    ) -> Result<FilterDescription> {
382        FilterDescription::from_children(parent_filters, &self.children())
383    }
384
385    fn handle_child_pushdown_result(
386        &self,
387        phase: FilterPushdownPhase,
388        child_pushdown_result: ChildPushdownResult,
389        _config: &ConfigOptions,
390    ) -> Result<FilterPushdownPropagation<Arc<dyn ExecutionPlan>>> {
391        // Pre phase: handle heterogeneous pushdown by wrapping individual
392        // children with FilterExec and reporting all filters as handled.
393        // Post phase: use default behavior to let the filter creator decide how to handle
394        // filters that weren't fully pushed down.
395        if phase != FilterPushdownPhase::Pre {
396            return Ok(FilterPushdownPropagation::if_all(child_pushdown_result));
397        }
398
399        // UnionExec needs specialized filter pushdown handling when children have
400        // heterogeneous pushdown support. Without this, when some children support
401        // pushdown and others don't, the default behavior would leave FilterExec
402        // above UnionExec, re-applying filters to outputs of all children—including
403        // those that already applied the filters via pushdown. This specialized
404        // implementation adds FilterExec only to children that don't support
405        // pushdown, avoiding redundant filtering and improving performance.
406        //
407        // Example: Given Child1 (no pushdown support) and Child2 (has pushdown support)
408        //   Default behavior:          This implementation:
409        //   FilterExec                 UnionExec
410        //     UnionExec                  FilterExec
411        //       Child1                     Child1
412        //       Child2(filter)           Child2(filter)
413
414        // Collect unsupported filters for each child
415        let mut unsupported_filters_per_child = vec![Vec::new(); self.inputs.len()];
416        for parent_filter_result in child_pushdown_result.parent_filters.iter() {
417            for (child_idx, &child_result) in
418                parent_filter_result.child_results.iter().enumerate()
419            {
420                if matches!(child_result, PushedDown::No) {
421                    unsupported_filters_per_child[child_idx]
422                        .push(Arc::clone(&parent_filter_result.filter));
423                }
424            }
425        }
426
427        // Wrap children that have unsupported filters with FilterExec
428        let mut new_children = self.inputs.clone();
429        for (child_idx, unsupported_filters) in
430            unsupported_filters_per_child.iter().enumerate()
431        {
432            if !unsupported_filters.is_empty() {
433                let combined_filter = conjunction(unsupported_filters.clone());
434                new_children[child_idx] = Arc::new(FilterExec::try_new(
435                    combined_filter,
436                    Arc::clone(&self.inputs[child_idx]),
437                )?);
438            }
439        }
440
441        // Check if any children were modified
442        let children_modified = new_children
443            .iter()
444            .zip(self.inputs.iter())
445            .any(|(new, old)| !Arc::ptr_eq(new, old));
446
447        let all_filters_pushed =
448            vec![PushedDown::Yes; child_pushdown_result.parent_filters.len()];
449        let propagation = if children_modified {
450            let updated_node = UnionExec::try_new(new_children)?;
451            FilterPushdownPropagation::with_parent_pushdown_result(all_filters_pushed)
452                .with_updated_node(updated_node)
453        } else {
454            FilterPushdownPropagation::with_parent_pushdown_result(all_filters_pushed)
455        };
456
457        // Report all parent filters as supported since we've ensured they're applied
458        // on all children (either pushed down or via FilterExec)
459        Ok(propagation)
460    }
461}
462
463/// Combines multiple input streams by interleaving them.
464///
465/// This only works if all inputs have the same hash-partitioning.
466///
467/// # Data Flow
468/// ```text
469/// +---------+
470/// |         |---+
471/// | Input 1 |   |
472/// |         |-------------+
473/// +---------+   |         |
474///               |         |         +---------+
475///               +------------------>|         |
476///                 +---------------->| Combine |-->
477///                 | +-------------->|         |
478///                 | |     |         +---------+
479/// +---------+     | |     |
480/// |         |-----+ |     |
481/// | Input 2 |       |     |
482/// |         |---------------+
483/// +---------+       |     | |
484///                   |     | |       +---------+
485///                   |     +-------->|         |
486///                   |       +------>| Combine |-->
487///                   |         +---->|         |
488///                   |         |     +---------+
489/// +---------+       |         |
490/// |         |-------+         |
491/// | Input 3 |                 |
492/// |         |-----------------+
493/// +---------+
494/// ```
495#[derive(Debug, Clone)]
496pub struct InterleaveExec {
497    /// Input execution plan
498    inputs: Vec<Arc<dyn ExecutionPlan>>,
499    /// Execution metrics
500    metrics: ExecutionPlanMetricsSet,
501    /// Cache holding plan properties like equivalences, output partitioning etc.
502    cache: Arc<PlanProperties>,
503}
504
505impl InterleaveExec {
506    /// Create a new InterleaveExec
507    pub fn try_new(inputs: Vec<Arc<dyn ExecutionPlan>>) -> Result<Self> {
508        assert_or_internal_err!(
509            can_interleave(inputs.iter()),
510            "Not all InterleaveExec children have a consistent hash partitioning"
511        );
512        let cache = Self::compute_properties(&inputs)?;
513        Ok(InterleaveExec {
514            inputs,
515            metrics: ExecutionPlanMetricsSet::new(),
516            cache: Arc::new(cache),
517        })
518    }
519
520    /// Get inputs of the execution plan
521    pub fn inputs(&self) -> &Vec<Arc<dyn ExecutionPlan>> {
522        &self.inputs
523    }
524
525    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
526    fn compute_properties(inputs: &[Arc<dyn ExecutionPlan>]) -> Result<PlanProperties> {
527        let schema = union_schema(inputs)?;
528        let eq_properties = EquivalenceProperties::new(schema);
529        // Get output partitioning:
530        let output_partitioning = inputs[0].output_partitioning().clone();
531        Ok(PlanProperties::new(
532            eq_properties,
533            output_partitioning,
534            emission_type_from_children(inputs),
535            boundedness_from_children(inputs),
536        ))
537    }
538
539    fn with_new_children_and_same_properties(
540        &self,
541        children: Vec<Arc<dyn ExecutionPlan>>,
542    ) -> Self {
543        Self {
544            inputs: children,
545            metrics: ExecutionPlanMetricsSet::new(),
546            ..Self::clone(self)
547        }
548    }
549}
550
551impl DisplayAs for InterleaveExec {
552    fn fmt_as(
553        &self,
554        t: DisplayFormatType,
555        f: &mut std::fmt::Formatter,
556    ) -> std::fmt::Result {
557        match t {
558            DisplayFormatType::Default | DisplayFormatType::Verbose => {
559                write!(f, "InterleaveExec")
560            }
561            DisplayFormatType::TreeRender => Ok(()),
562        }
563    }
564}
565
566impl ExecutionPlan for InterleaveExec {
567    fn name(&self) -> &'static str {
568        "InterleaveExec"
569    }
570
571    /// Return a reference to Any that can be used for downcasting
572    fn properties(&self) -> &Arc<PlanProperties> {
573        &self.cache
574    }
575
576    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
577        self.inputs.iter().collect()
578    }
579
580    fn maintains_input_order(&self) -> Vec<bool> {
581        vec![false; self.inputs().len()]
582    }
583
584    fn with_new_children(
585        self: Arc<Self>,
586        children: Vec<Arc<dyn ExecutionPlan>>,
587    ) -> Result<Arc<dyn ExecutionPlan>> {
588        // New children are no longer interleavable, which might be a bug of optimization rewrite.
589        assert_or_internal_err!(
590            can_interleave(children.iter()),
591            "Can not create InterleaveExec: new children can not be interleaved"
592        );
593        check_if_same_properties!(self, children);
594        Ok(Arc::new(InterleaveExec::try_new(children)?))
595    }
596
597    fn execute(
598        &self,
599        partition: usize,
600        context: Arc<TaskContext>,
601    ) -> Result<SendableRecordBatchStream> {
602        trace!(
603            "Start InterleaveExec::execute for partition {} of context session_id {} and task_id {:?}",
604            partition,
605            context.session_id(),
606            context.task_id()
607        );
608        let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
609        // record the tiny amount of work done in this function so
610        // elapsed_compute is reported as non zero
611        let elapsed_compute = baseline_metrics.elapsed_compute().clone();
612        let _timer = elapsed_compute.timer(); // record on drop
613
614        let mut input_stream_vec = vec![];
615        for input in self.inputs.iter() {
616            if partition < input.output_partitioning().partition_count() {
617                input_stream_vec.push(input.execute(partition, Arc::clone(&context))?);
618            } else {
619                // Do not find a partition to execute
620                break;
621            }
622        }
623        if input_stream_vec.len() == self.inputs.len() {
624            let stream = Box::pin(CombinedRecordBatchStream::new(
625                self.schema(),
626                input_stream_vec,
627            ));
628            return Ok(Box::pin(ObservedStream::new(
629                stream,
630                baseline_metrics,
631                None,
632            )));
633        }
634
635        warn!("Error in InterleaveExec: Partition {partition} not found");
636
637        exec_err!("Partition {partition} not found in InterleaveExec")
638    }
639
640    fn metrics(&self) -> Option<MetricsSet> {
641        Some(self.metrics.clone_inner())
642    }
643
644    fn partition_statistics(&self, partition: Option<usize>) -> Result<Arc<Statistics>> {
645        let schema = self.schema();
646        Ok(Arc::new(merge_input_statistics(
647            &self.inputs,
648            partition,
649            schema.as_ref(),
650        )?))
651    }
652
653    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
654        vec![false; self.children().len()]
655    }
656}
657
658/// If all the input partitions have the same Hash partition spec with the first_input_partition
659/// The InterleaveExec is partition aware.
660///
661/// It might be too strict here in the case that the input partition specs are compatible but not exactly the same.
662/// For example one input partition has the partition spec Hash('a','b','c') and
663/// other has the partition spec Hash('a'), It is safe to derive the out partition with the spec Hash('a','b','c').
664pub fn can_interleave<T: Borrow<Arc<dyn ExecutionPlan>>>(
665    mut inputs: impl Iterator<Item = T>,
666) -> bool {
667    let Some(first) = inputs.next() else {
668        return false;
669    };
670
671    let reference = first.borrow().output_partitioning();
672    matches!(reference, Partitioning::Hash(_, _))
673        && inputs
674            .map(|plan| plan.borrow().output_partitioning().clone())
675            .all(|partition| partition == *reference)
676}
677
678fn union_schema(inputs: &[Arc<dyn ExecutionPlan>]) -> Result<SchemaRef> {
679    if inputs.is_empty() {
680        return exec_err!("Cannot create union schema from empty inputs");
681    }
682
683    let first_schema = inputs[0].schema();
684    let first_field_count = first_schema.fields().len();
685
686    // validate that all inputs have the same number of fields
687    for (idx, input) in inputs.iter().enumerate().skip(1) {
688        let field_count = input.schema().fields().len();
689        if field_count != first_field_count {
690            return exec_err!(
691                "UnionExec/InterleaveExec requires all inputs to have the same number of fields. \
692                 Input 0 has {first_field_count} fields, but input {idx} has {field_count} fields"
693            );
694        }
695    }
696
697    let fields = (0..first_field_count)
698        .map(|i| {
699            // We take the name from the left side of the union to match how names are coerced during logical planning,
700            // which also uses the left side names.
701            let base_field = first_schema.field(i).clone();
702
703            // Coerce metadata and nullability across all inputs
704
705            inputs
706                .iter()
707                .enumerate()
708                .map(|(input_idx, input)| {
709                    let field = input.schema().field(i).clone();
710                    let mut metadata = field.metadata().clone();
711
712                    let other_metadatas = inputs
713                        .iter()
714                        .enumerate()
715                        .filter(|(other_idx, _)| *other_idx != input_idx)
716                        .flat_map(|(_, other_input)| {
717                            other_input.schema().field(i).metadata().clone().into_iter()
718                        });
719
720                    metadata.extend(other_metadatas);
721                    field.with_metadata(metadata)
722                })
723                .find_or_first(Field::is_nullable)
724                // We can unwrap this because if inputs was empty, this would've already panic'ed when we
725                // indexed into inputs[0].
726                .unwrap()
727                .with_name(base_field.name())
728        })
729        .collect::<Vec<_>>();
730
731    let all_metadata_merged = inputs
732        .iter()
733        .flat_map(|i| i.schema().metadata().clone().into_iter())
734        .collect();
735
736    Ok(Arc::new(Schema::new_with_metadata(
737        fields,
738        all_metadata_merged,
739    )))
740}
741
742/// CombinedRecordBatchStream can be used to combine a Vec of SendableRecordBatchStreams into one
743struct CombinedRecordBatchStream {
744    /// Schema wrapped by Arc
745    schema: SchemaRef,
746    /// Stream entries
747    entries: Vec<SendableRecordBatchStream>,
748}
749
750impl CombinedRecordBatchStream {
751    /// Create an CombinedRecordBatchStream
752    pub fn new(schema: SchemaRef, entries: Vec<SendableRecordBatchStream>) -> Self {
753        Self { schema, entries }
754    }
755}
756
757impl RecordBatchStream for CombinedRecordBatchStream {
758    fn schema(&self) -> SchemaRef {
759        Arc::clone(&self.schema)
760    }
761}
762
/// Drives all inputs of a [`CombinedRecordBatchStream`] as one stream, polling
/// them round-robin from a start index chosen by `thread_rng_n` (presumably
/// random, so that no single input is always polled first — confirm against
/// `thread_rng_n`).
impl Stream for CombinedRecordBatchStream {
    type Item = Result<RecordBatch>;

    /// Polls the remaining inputs at most once each and returns:
    /// - `Ready(Some(item))` with the first ready item; errors from an input are
    ///   passed through unchanged, just like batches;
    /// - `Ready(None)` once every input has been exhausted and removed;
    /// - `Pending` otherwise.
    ///
    /// Exhausted inputs are removed with `swap_remove`, so the relative order
    /// of `entries` is not preserved.
    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        use Poll::*;

        // `thread_rng_n(n)` is expected to yield an index in `0..n`. When
        // `entries` is empty the loop below never runs, so `start` is unused.
        let start = thread_rng_n(self.entries.len() as u32) as usize;
        let mut idx = start;

        // The iteration count is fixed to the initial length. Each iteration
        // either returns, removes one entry, or advances the cursor, so the
        // remaining iterations never exceed the remaining entries. This keeps
        // `idx` in bounds for the `unwrap` below and every `%` divisor non-zero.
        for _ in 0..self.entries.len() {
            let stream = self.entries.get_mut(idx).unwrap();

            match Pin::new(stream).poll_next(cx) {
                Ready(Some(val)) => return Ready(Some(val)),
                Ready(None) => {
                    // Remove the exhausted entry; `swap_remove` moves the last
                    // entry into slot `idx`
                    self.entries.swap_remove(idx);

                    // Check if this was the last entry, if so the cursor needs
                    // to wrap
                    if idx == self.entries.len() {
                        idx = 0;
                    } else if idx < start && start <= self.entries.len() {
                        // The stream being swapped into the current index has
                        // already been polled, so skip it.
                        idx = idx.wrapping_add(1) % self.entries.len();
                    }
                    // Otherwise the swapped-in stream has not been polled yet,
                    // so `idx` stays put and it is polled next.
                }
                Pending => {
                    idx = idx.wrapping_add(1) % self.entries.len();
                }
            }
        }

        // If every input has been removed, the combined stream is complete.
        // Otherwise the bookkeeping above is intended to ensure each remaining
        // input was polled with `cx` in this call and returned `Pending`.
        if self.entries.is_empty() {
            Ready(None)
        } else {
            Pending
        }
    }
}
808
809fn merge_input_statistics(
810    inputs: &[Arc<dyn ExecutionPlan>],
811    partition: Option<usize>,
812    schema: &Schema,
813) -> Result<Statistics> {
814    let stats = inputs
815        .iter()
816        .map(|input| {
817            input
818                .partition_statistics(partition)
819                .map(Arc::unwrap_or_clone)
820        })
821        .collect::<Result<Vec<_>>>()?;
822
823    Statistics::try_merge_iter_with_ndv_fallback(stats.iter(), schema, NdvFallback::Sum)
824}
825
#[cfg(test)]
mod tests {
    use super::*;
    use crate::collect;
    use crate::repartition::RepartitionExec;
    use crate::test::exec::StatisticsExec;
    use crate::test::{self, TestMemoryExec};

    use arrow::compute::SortOptions;
    use arrow::datatypes::DataType;
    use datafusion_common::stats::Precision;
    use datafusion_common::{ColumnStatistics, ScalarValue};
    use datafusion_physical_expr::equivalence::convert_to_orderings;
    use datafusion_physical_expr::expressions::col;

    /// Generates a schema of 7 nullable `Int32` columns (a, b, c, d, e, f, g).
    fn create_test_schema() -> Result<SchemaRef> {
        let a = Field::new("a", DataType::Int32, true);
        let b = Field::new("b", DataType::Int32, true);
        let c = Field::new("c", DataType::Int32, true);
        let d = Field::new("d", DataType::Int32, true);
        let e = Field::new("e", DataType::Int32, true);
        let f = Field::new("f", DataType::Int32, true);
        let g = Field::new("g", DataType::Int32, true);
        let schema = Arc::new(Schema::new(vec![a, b, c, d, e, f, g]));

        Ok(schema)
    }

    /// Generates a schema of 6 nullable `Int32` columns (a, b, c, d, e, f):
    /// one field fewer than `create_test_schema`, used for mismatch tests.
    fn create_test_schema2() -> Result<SchemaRef> {
        let a = Field::new("a", DataType::Int32, true);
        let b = Field::new("b", DataType::Int32, true);
        let c = Field::new("c", DataType::Int32, true);
        let d = Field::new("d", DataType::Int32, true);
        let e = Field::new("e", DataType::Int32, true);
        let f = Field::new("f", DataType::Int32, true);
        let schema = Arc::new(Schema::new(vec![a, b, c, d, e, f]));

        Ok(schema)
    }

    #[tokio::test]
    async fn test_union_partitions() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());

        // Create inputs with different partitioning
        let csv = test::scan_partitioned(4);
        let csv2 = test::scan_partitioned(5);

        let union_exec: Arc<dyn ExecutionPlan> = UnionExec::try_new(vec![csv, csv2])?;

        // Should have 9 partitions and 9 output batches
        assert_eq!(
            union_exec
                .properties()
                .output_partitioning()
                .partition_count(),
            9
        );

        let result: Vec<RecordBatch> = collect(union_exec, task_ctx).await?;
        assert_eq!(result.len(), 9);

        Ok(())
    }

    /// Single-column (`UInt32`) fixture returning `(schema, left, right, expected)`,
    /// where `expected` is the merged statistics of `left` and `right`.
    ///
    /// The value ranges `[1, 21]` and `[22, 34]` are disjoint, and the merged sum
    /// is expected as `UInt64` even though both inputs report `UInt32` sums.
    fn stats_merge_inputs() -> (SchemaRef, Statistics, Statistics, Statistics) {
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::UInt32, true)]));

        let left = Statistics::default()
            .with_num_rows(Precision::Exact(5))
            .with_total_byte_size(Precision::Exact(23))
            .add_column_statistics(
                ColumnStatistics::new_unknown()
                    .with_distinct_count(Precision::Exact(5))
                    .with_min_value(Precision::Exact(ScalarValue::UInt32(Some(1))))
                    .with_max_value(Precision::Exact(ScalarValue::UInt32(Some(21))))
                    .with_sum_value(Precision::Exact(ScalarValue::UInt32(Some(42))))
                    .with_null_count(Precision::Exact(0))
                    .with_byte_size(Precision::Exact(40)),
            );

        let right = Statistics::default()
            .with_num_rows(Precision::Exact(7))
            .with_total_byte_size(Precision::Exact(29))
            .add_column_statistics(
                ColumnStatistics::new_unknown()
                    .with_distinct_count(Precision::Exact(3))
                    .with_min_value(Precision::Exact(ScalarValue::UInt32(Some(22))))
                    .with_max_value(Precision::Exact(ScalarValue::UInt32(Some(34))))
                    .with_sum_value(Precision::Exact(ScalarValue::UInt32(Some(8))))
                    .with_null_count(Precision::Exact(1))
                    .with_byte_size(Precision::Exact(60)),
            );

        let expected = Statistics::default()
            .with_num_rows(Precision::Exact(12))
            .with_total_byte_size(Precision::Exact(52))
            .add_column_statistics(
                ColumnStatistics::new_unknown()
                    .with_distinct_count(Precision::Inexact(8))
                    .with_min_value(Precision::Exact(ScalarValue::UInt32(Some(1))))
                    .with_max_value(Precision::Exact(ScalarValue::UInt32(Some(34))))
                    .with_sum_value(Precision::Exact(ScalarValue::UInt64(Some(50))))
                    .with_null_count(Precision::Exact(1))
                    .with_byte_size(Precision::Exact(100)),
            );

        (schema, left, right, expected)
    }

    /// Three-column fixture (`Int64`, `Utf8`, `Float32`) returning
    /// `(schema, left, right, expected)`.
    ///
    /// Statistics present on only one side (the null count of `b`, and every
    /// statistic of `c`, which is unknown on the right) are expected to be
    /// unknown in the merged result.
    fn stats_merge_multicolumn_inputs() -> (SchemaRef, Statistics, Statistics, Statistics)
    {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int64, true),
            Field::new("b", DataType::Utf8, true),
            Field::new("c", DataType::Float32, true),
        ]));

        let left = Statistics::default()
            .with_num_rows(Precision::Exact(5))
            .with_total_byte_size(Precision::Exact(23))
            .add_column_statistics(
                ColumnStatistics::new_unknown()
                    .with_distinct_count(Precision::Exact(5))
                    .with_min_value(Precision::Exact(ScalarValue::Int64(Some(-4))))
                    .with_max_value(Precision::Exact(ScalarValue::Int64(Some(21))))
                    .with_sum_value(Precision::Exact(ScalarValue::Int64(Some(42))))
                    .with_null_count(Precision::Exact(0)),
            )
            .add_column_statistics(
                ColumnStatistics::new_unknown()
                    .with_distinct_count(Precision::Exact(2))
                    .with_min_value(Precision::Exact(ScalarValue::from("a")))
                    .with_max_value(Precision::Exact(ScalarValue::from("x")))
                    .with_null_count(Precision::Exact(3)),
            )
            .add_column_statistics(
                ColumnStatistics::new_unknown()
                    .with_max_value(Precision::Exact(ScalarValue::Float32(Some(1.1))))
                    .with_min_value(Precision::Exact(ScalarValue::Float32(Some(0.1))))
                    .with_sum_value(Precision::Exact(ScalarValue::Float32(Some(42.0)))),
            );

        let right = Statistics::default()
            .with_num_rows(Precision::Exact(7))
            .with_total_byte_size(Precision::Exact(29))
            .add_column_statistics(
                ColumnStatistics::new_unknown()
                    .with_distinct_count(Precision::Exact(3))
                    .with_min_value(Precision::Exact(ScalarValue::Int64(Some(1))))
                    .with_max_value(Precision::Exact(ScalarValue::Int64(Some(34))))
                    .with_sum_value(Precision::Exact(ScalarValue::Int64(Some(42))))
                    .with_null_count(Precision::Exact(1)),
            )
            .add_column_statistics(
                ColumnStatistics::new_unknown()
                    .with_distinct_count(Precision::Exact(3))
                    .with_min_value(Precision::Exact(ScalarValue::from("b")))
                    .with_max_value(Precision::Exact(ScalarValue::from("z"))),
            )
            .add_column_statistics(ColumnStatistics::new_unknown());

        let expected = Statistics::default()
            .with_num_rows(Precision::Exact(12))
            .with_total_byte_size(Precision::Exact(52))
            .add_column_statistics(
                ColumnStatistics::new_unknown()
                    .with_distinct_count(Precision::Inexact(6))
                    .with_min_value(Precision::Exact(ScalarValue::Int64(Some(-4))))
                    .with_max_value(Precision::Exact(ScalarValue::Int64(Some(34))))
                    .with_sum_value(Precision::Exact(ScalarValue::Int64(Some(84))))
                    .with_null_count(Precision::Exact(1)),
            )
            .add_column_statistics(
                ColumnStatistics::new_unknown()
                    .with_distinct_count(Precision::Inexact(5))
                    .with_min_value(Precision::Exact(ScalarValue::from("a")))
                    .with_max_value(Precision::Exact(ScalarValue::from("z"))),
            )
            .add_column_statistics(ColumnStatistics::new_unknown());

        (schema, left, right, expected)
    }

    #[test]
    fn test_union_partition_statistics_uses_shared_statistics_merge() -> Result<()> {
        let (schema, left, right, expected) = stats_merge_inputs();

        let left: Arc<dyn ExecutionPlan> =
            Arc::new(StatisticsExec::new(left, schema.as_ref().clone()));
        let right: Arc<dyn ExecutionPlan> =
            Arc::new(StatisticsExec::new(right, schema.as_ref().clone()));

        let union = UnionExec::try_new(vec![left, right])?;
        let stats = union.partition_statistics(None)?;

        assert_eq!(stats.as_ref(), &expected);
        Ok(())
    }

    #[test]
    fn test_union_partition_statistics_uses_shared_statistics_merge_multicolumn()
    -> Result<()> {
        let (schema, left, right, expected) = stats_merge_multicolumn_inputs();

        let left: Arc<dyn ExecutionPlan> =
            Arc::new(StatisticsExec::new(left, schema.as_ref().clone()));
        let right: Arc<dyn ExecutionPlan> =
            Arc::new(StatisticsExec::new(right, schema.as_ref().clone()));

        let union = UnionExec::try_new(vec![left, right])?;
        let stats = union.partition_statistics(None)?;

        assert_eq!(stats.as_ref(), &expected);
        Ok(())
    }

    #[test]
    fn test_interleave_partition_statistics_uses_shared_statistics_merge() -> Result<()> {
        let (schema, left, right, expected) = stats_merge_inputs();
        let hash_expr = vec![col("a", schema.as_ref())?];

        let left: Arc<dyn ExecutionPlan> = Arc::new(RepartitionExec::try_new(
            Arc::new(StatisticsExec::new(left, schema.as_ref().clone())),
            Partitioning::Hash(hash_expr.clone(), 2),
        )?);
        let right: Arc<dyn ExecutionPlan> = Arc::new(RepartitionExec::try_new(
            Arc::new(StatisticsExec::new(right, schema.as_ref().clone())),
            Partitioning::Hash(hash_expr, 2),
        )?);

        let interleave = InterleaveExec::try_new(vec![left, right])?;
        let stats = interleave.partition_statistics(None)?;

        assert_eq!(stats.as_ref(), &expected);
        Ok(())
    }

    #[test]
    fn test_interleave_partition_statistics_for_partition_uses_shared_statistics_merge()
    -> Result<()> {
        let (schema, left, right, _) = stats_merge_inputs();
        let hash_expr = vec![col("a", schema.as_ref())?];

        let left: Arc<dyn ExecutionPlan> = Arc::new(RepartitionExec::try_new(
            Arc::new(StatisticsExec::new(left, schema.as_ref().clone())),
            Partitioning::Hash(hash_expr.clone(), 2),
        )?);
        let right: Arc<dyn ExecutionPlan> = Arc::new(RepartitionExec::try_new(
            Arc::new(StatisticsExec::new(right, schema.as_ref().clone())),
            Partitioning::Hash(hash_expr, 2),
        )?);

        let interleave = InterleaveExec::try_new(vec![left, right])?;
        let stats = interleave.partition_statistics(Some(0))?;

        let expected = Statistics::default()
            .with_num_rows(Precision::Inexact(5))
            .with_total_byte_size(Precision::Inexact(25))
            .add_column_statistics(ColumnStatistics::new_unknown());

        assert_eq!(stats.as_ref(), &expected);
        Ok(())
    }

    // Plain `#[test]`: this test never awaits, so it needs no async runtime.
    #[test]
    fn test_union_equivalence_properties() -> Result<()> {
        let schema = create_test_schema()?;
        let col_a = &col("a", &schema)?;
        let col_b = &col("b", &schema)?;
        let col_c = &col("c", &schema)?;
        let col_d = &col("d", &schema)?;
        let col_e = &col("e", &schema)?;
        let col_f = &col("f", &schema)?;
        let options = SortOptions::default();
        let test_cases = [
            //-----------TEST CASE 1----------//
            (
                // First child orderings
                vec![
                    // [a ASC, b ASC, f ASC]
                    vec![(col_a, options), (col_b, options), (col_f, options)],
                ],
                // Second child orderings
                vec![
                    // [a ASC, b ASC, c ASC]
                    vec![(col_a, options), (col_b, options), (col_c, options)],
                    // [a ASC, b ASC, f ASC]
                    vec![(col_a, options), (col_b, options), (col_f, options)],
                ],
                // Union output orderings
                vec![
                    // [a ASC, b ASC, f ASC]
                    vec![(col_a, options), (col_b, options), (col_f, options)],
                ],
            ),
            //-----------TEST CASE 2----------//
            (
                // First child orderings
                vec![
                    // [a ASC, b ASC, f ASC]
                    vec![(col_a, options), (col_b, options), (col_f, options)],
                    // [d ASC]
                    vec![(col_d, options)],
                ],
                // Second child orderings
                vec![
                    // [a ASC, b ASC, c ASC]
                    vec![(col_a, options), (col_b, options), (col_c, options)],
                    // [e ASC]
                    vec![(col_e, options)],
                ],
                // Union output orderings
                vec![
                    // [a ASC, b ASC]
                    vec![(col_a, options), (col_b, options)],
                ],
            ),
        ];

        for (
            test_idx,
            (first_child_orderings, second_child_orderings, union_orderings),
        ) in test_cases.iter().enumerate()
        {
            let first_orderings = convert_to_orderings(first_child_orderings);
            let second_orderings = convert_to_orderings(second_child_orderings);
            let union_expected_orderings = convert_to_orderings(union_orderings);
            let child1_exec = TestMemoryExec::try_new(&[], Arc::clone(&schema), None)?
                .try_with_sort_information(first_orderings)?;
            let child1 = Arc::new(child1_exec);
            let child1 = Arc::new(TestMemoryExec::update_cache(&child1));
            let child2_exec = TestMemoryExec::try_new(&[], Arc::clone(&schema), None)?
                .try_with_sort_information(second_orderings)?;
            let child2 = Arc::new(child2_exec);
            let child2 = Arc::new(TestMemoryExec::update_cache(&child2));

            let mut union_expected_eq = EquivalenceProperties::new(Arc::clone(&schema));
            union_expected_eq.add_orderings(union_expected_orderings);

            let union: Arc<dyn ExecutionPlan> = UnionExec::try_new(vec![child1, child2])?;
            let union_eq_properties = union.properties().equivalence_properties();
            let err_msg = format!(
                "Error in test id: {:?}, test case: {:?}",
                test_idx, test_cases[test_idx]
            );
            assert_eq_properties_same(union_eq_properties, &union_expected_eq, &err_msg);
        }
        Ok(())
    }

    /// Asserts that `lhs` and `rhs` hold the same orderings, regardless of the
    /// order in which they are listed. Only the ordering equivalence classes
    /// are compared; other parts of the properties are not checked.
    fn assert_eq_properties_same(
        lhs: &EquivalenceProperties,
        rhs: &EquivalenceProperties,
        err_msg: &str,
    ) {
        // Check whether the orderings are the same.
        let lhs_orderings = lhs.oeq_class();
        let rhs_orderings = rhs.oeq_class();
        assert_eq!(lhs_orderings.len(), rhs_orderings.len(), "{err_msg}");
        for rhs_ordering in rhs_orderings.iter() {
            assert!(lhs_orderings.contains(rhs_ordering), "{err_msg}");
        }
    }

    #[test]
    fn test_union_empty_inputs() {
        // Test that UnionExec::try_new fails with empty inputs
        let result = UnionExec::try_new(vec![]);
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("UnionExec requires at least one input")
        );
    }

    #[test]
    fn test_union_schema_empty_inputs() {
        // Test that union_schema fails with empty inputs
        let result = union_schema(&[]);
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Cannot create union schema from empty inputs")
        );
    }

    #[test]
    fn test_union_single_input() -> Result<()> {
        // Test that UnionExec::try_new returns the single input directly
        let schema = create_test_schema()?;
        let memory_exec: Arc<dyn ExecutionPlan> =
            Arc::new(TestMemoryExec::try_new(&[], Arc::clone(&schema), None)?);
        let memory_exec_clone = Arc::clone(&memory_exec);
        let result = UnionExec::try_new(vec![memory_exec])?;

        // Check that the result is the same as the input (no UnionExec wrapper)
        assert_eq!(result.schema(), schema);
        // Verify it's the same execution plan
        assert!(Arc::ptr_eq(&result, &memory_exec_clone));

        Ok(())
    }

    #[test]
    fn test_union_schema_multiple_inputs() -> Result<()> {
        // Test that existing functionality with multiple inputs still works
        let schema = create_test_schema()?;
        let memory_exec1 =
            Arc::new(TestMemoryExec::try_new(&[], Arc::clone(&schema), None)?);
        let memory_exec2 =
            Arc::new(TestMemoryExec::try_new(&[], Arc::clone(&schema), None)?);

        let union_plan = UnionExec::try_new(vec![memory_exec1, memory_exec2])?;

        // Downcast to verify it's a UnionExec
        let union = union_plan
            .downcast_ref::<UnionExec>()
            .expect("Expected UnionExec");

        // Check that schema is correct
        assert_eq!(union.schema(), schema);
        // Check that we have 2 inputs
        assert_eq!(union.inputs().len(), 2);

        Ok(())
    }

    #[test]
    fn test_union_schema_mismatch() {
        // Test that UnionExec properly rejects inputs with different field counts
        let schema = create_test_schema().unwrap();
        let schema2 = create_test_schema2().unwrap();
        let memory_exec1 =
            Arc::new(TestMemoryExec::try_new(&[], Arc::clone(&schema), None).unwrap());
        let memory_exec2 =
            Arc::new(TestMemoryExec::try_new(&[], Arc::clone(&schema2), None).unwrap());

        let result = UnionExec::try_new(vec![memory_exec1, memory_exec2]);
        assert!(result.is_err());
        assert!(
            result.unwrap_err().to_string().contains(
                "UnionExec/InterleaveExec requires all inputs to have the same number of fields"
            )
        );
    }

    #[test]
    fn test_union_cardinality_effect() -> Result<()> {
        let schema = create_test_schema()?;
        let input1: Arc<dyn ExecutionPlan> =
            Arc::new(TestMemoryExec::try_new(&[], Arc::clone(&schema), None)?);
        let input2: Arc<dyn ExecutionPlan> =
            Arc::new(TestMemoryExec::try_new(&[], Arc::clone(&schema), None)?);

        let union = UnionExec::try_new(vec![input1, input2])?;
        let union = union
            .downcast_ref::<UnionExec>()
            .expect("expected UnionExec for multiple inputs");

        assert!(matches!(
            union.cardinality_effect(),
            CardinalityEffect::GreaterEqual
        ));
        Ok(())
    }
}