Skip to main content

datafusion_physical_plan/
union.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18// Some of these functions reference the Postgres documentation
19// or implementation to ensure compatibility and are subject to
20// the Postgres license.
21
22//! The Union operator combines multiple inputs with the same schema
23
24use std::borrow::Borrow;
25use std::pin::Pin;
26use std::task::{Context, Poll};
27use std::{any::Any, sync::Arc};
28
29use super::{
30    ColumnStatistics, DisplayAs, DisplayFormatType, ExecutionPlan,
31    ExecutionPlanProperties, Partitioning, PlanProperties, RecordBatchStream,
32    SendableRecordBatchStream, Statistics,
33    metrics::{ExecutionPlanMetricsSet, MetricsSet},
34};
35use crate::check_if_same_properties;
36use crate::execution_plan::{
37    InvariantLevel, boundedness_from_children, check_default_invariants,
38    emission_type_from_children,
39};
40use crate::filter::FilterExec;
41use crate::filter_pushdown::{
42    ChildPushdownResult, FilterDescription, FilterPushdownPhase,
43    FilterPushdownPropagation, PushedDown,
44};
45use crate::metrics::BaselineMetrics;
46use crate::projection::{ProjectionExec, make_with_child};
47use crate::stream::ObservedStream;
48
49use arrow::datatypes::{Field, Schema, SchemaRef};
50use arrow::record_batch::RecordBatch;
51use datafusion_common::config::ConfigOptions;
52use datafusion_common::stats::Precision;
53use datafusion_common::{
54    Result, assert_or_internal_err, exec_err, internal_datafusion_err,
55};
56use datafusion_execution::TaskContext;
57use datafusion_physical_expr::{
58    EquivalenceProperties, PhysicalExpr, calculate_union, conjunction,
59};
60
61use futures::Stream;
62use itertools::Itertools;
63use log::{debug, trace, warn};
64use tokio::macros::support::thread_rng_n;
65
66/// `UnionExec`: `UNION ALL` execution plan.
67///
68/// `UnionExec` combines multiple inputs with the same schema by
69/// concatenating the partitions.  It does not mix or copy data within
70/// or across partitions. Thus if the input partitions are sorted, the
71/// output partitions of the union are also sorted.
72///
73/// For example, given a `UnionExec` of two inputs, with `N`
74/// partitions, and `M` partitions, there will be `N+M` output
75/// partitions. The first `N` output partitions are from Input 1
76/// partitions, and then next `M` output partitions are from Input 2.
77///
78/// ```text
79///                        ▲       ▲           ▲         ▲
80///                        │       │           │         │
81///      Output            │  ...  │           │         │
82///    Partitions          │0      │N-1        │ N       │N+M-1
83/// (passes through   ┌────┴───────┴───────────┴─────────┴───┐
84///  the N+M input    │              UnionExec               │
85///   partitions)     │                                      │
86///                   └──────────────────────────────────────┘
87///                                      ▲
88///                                      │
89///                                      │
90///       Input           ┌────────┬─────┴────┬──────────┐
91///     Partitions        │ ...    │          │     ...  │
92///                    0  │        │ N-1      │ 0        │  M-1
93///                  ┌────┴────────┴───┐  ┌───┴──────────┴───┐
94///                  │                 │  │                  │
95///                  │                 │  │                  │
96///                  │                 │  │                  │
97///                  │                 │  │                  │
98///                  │                 │  │                  │
99///                  │                 │  │                  │
100///                  │Input 1          │  │Input 2           │
101///                  └─────────────────┘  └──────────────────┘
102/// ```
103#[derive(Debug, Clone)]
104pub struct UnionExec {
105    /// Input execution plan
106    inputs: Vec<Arc<dyn ExecutionPlan>>,
107    /// Execution metrics
108    metrics: ExecutionPlanMetricsSet,
109    /// Cache holding plan properties like equivalences, output partitioning etc.
110    cache: Arc<PlanProperties>,
111}
112
113impl UnionExec {
114    /// Create a new UnionExec
115    #[deprecated(since = "44.0.0", note = "Use UnionExec::try_new instead")]
116    pub fn new(inputs: Vec<Arc<dyn ExecutionPlan>>) -> Self {
117        let schema =
118            union_schema(&inputs).expect("UnionExec::new called with empty inputs");
119        // The schema of the inputs and the union schema is consistent when:
120        // - They have the same number of fields, and
121        // - Their fields have same types at the same indices.
122        // Here, we know that schemas are consistent and the call below can
123        // not return an error.
124        let cache = Self::compute_properties(&inputs, schema).unwrap();
125        UnionExec {
126            inputs,
127            metrics: ExecutionPlanMetricsSet::new(),
128            cache: Arc::new(cache),
129        }
130    }
131
132    /// Try to create a new UnionExec.
133    ///
134    /// # Errors
135    /// Returns an error if:
136    /// - `inputs` is empty
137    ///
138    /// # Optimization
139    /// If there is only one input, returns that input directly rather than wrapping it in a UnionExec
140    pub fn try_new(
141        inputs: Vec<Arc<dyn ExecutionPlan>>,
142    ) -> Result<Arc<dyn ExecutionPlan>> {
143        match inputs.len() {
144            0 => exec_err!("UnionExec requires at least one input"),
145            1 => Ok(inputs.into_iter().next().unwrap()),
146            _ => {
147                let schema = union_schema(&inputs)?;
148                // The schema of the inputs and the union schema is consistent when:
149                // - They have the same number of fields, and
150                // - Their fields have same types at the same indices.
151                // Here, we know that schemas are consistent and the call below can
152                // not return an error.
153                let cache = Self::compute_properties(&inputs, schema).unwrap();
154                Ok(Arc::new(UnionExec {
155                    inputs,
156                    metrics: ExecutionPlanMetricsSet::new(),
157                    cache: Arc::new(cache),
158                }))
159            }
160        }
161    }
162
163    /// Get inputs of the execution plan
164    pub fn inputs(&self) -> &Vec<Arc<dyn ExecutionPlan>> {
165        &self.inputs
166    }
167
168    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
169    fn compute_properties(
170        inputs: &[Arc<dyn ExecutionPlan>],
171        schema: SchemaRef,
172    ) -> Result<PlanProperties> {
173        // Calculate equivalence properties:
174        let children_eqps = inputs
175            .iter()
176            .map(|child| child.equivalence_properties().clone())
177            .collect::<Vec<_>>();
178        let eq_properties = calculate_union(children_eqps, schema)?;
179
180        // Calculate output partitioning; i.e. sum output partitions of the inputs.
181        let num_partitions = inputs
182            .iter()
183            .map(|plan| plan.output_partitioning().partition_count())
184            .sum();
185        let output_partitioning = Partitioning::UnknownPartitioning(num_partitions);
186        Ok(PlanProperties::new(
187            eq_properties,
188            output_partitioning,
189            emission_type_from_children(inputs),
190            boundedness_from_children(inputs),
191        ))
192    }
193
194    fn with_new_children_and_same_properties(
195        &self,
196        children: Vec<Arc<dyn ExecutionPlan>>,
197    ) -> Self {
198        Self {
199            inputs: children,
200            metrics: ExecutionPlanMetricsSet::new(),
201            ..Self::clone(self)
202        }
203    }
204}
205
206impl DisplayAs for UnionExec {
207    fn fmt_as(
208        &self,
209        t: DisplayFormatType,
210        f: &mut std::fmt::Formatter,
211    ) -> std::fmt::Result {
212        match t {
213            DisplayFormatType::Default | DisplayFormatType::Verbose => {
214                write!(f, "UnionExec")
215            }
216            DisplayFormatType::TreeRender => Ok(()),
217        }
218    }
219}
220
221impl ExecutionPlan for UnionExec {
222    fn name(&self) -> &'static str {
223        "UnionExec"
224    }
225
226    /// Return a reference to Any that can be used for downcasting
227    fn as_any(&self) -> &dyn Any {
228        self
229    }
230
231    fn properties(&self) -> &Arc<PlanProperties> {
232        &self.cache
233    }
234
235    fn check_invariants(&self, check: InvariantLevel) -> Result<()> {
236        check_default_invariants(self, check)?;
237
238        (self.inputs().len() >= 2).then_some(()).ok_or_else(|| {
239            internal_datafusion_err!("UnionExec should have at least 2 children")
240        })
241    }
242
243    fn maintains_input_order(&self) -> Vec<bool> {
244        // If the Union has an output ordering, it maintains at least one
245        // child's ordering (i.e. the meet).
246        // For instance, assume that the first child is SortExpr('a','b','c'),
247        // the second child is SortExpr('a','b') and the third child is
248        // SortExpr('a','b'). The output ordering would be SortExpr('a','b'),
249        // which is the "meet" of all input orderings. In this example, this
250        // function will return vec![false, true, true], indicating that we
251        // preserve the orderings for the 2nd and the 3rd children.
252        if let Some(output_ordering) = self.properties().output_ordering() {
253            self.inputs()
254                .iter()
255                .map(|child| {
256                    if let Some(child_ordering) = child.output_ordering() {
257                        output_ordering.len() == child_ordering.len()
258                    } else {
259                        false
260                    }
261                })
262                .collect()
263        } else {
264            vec![false; self.inputs().len()]
265        }
266    }
267
268    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
269        vec![false; self.children().len()]
270    }
271
272    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
273        self.inputs.iter().collect()
274    }
275
276    fn with_new_children(
277        self: Arc<Self>,
278        children: Vec<Arc<dyn ExecutionPlan>>,
279    ) -> Result<Arc<dyn ExecutionPlan>> {
280        check_if_same_properties!(self, children);
281        UnionExec::try_new(children)
282    }
283
284    fn execute(
285        &self,
286        mut partition: usize,
287        context: Arc<TaskContext>,
288    ) -> Result<SendableRecordBatchStream> {
289        trace!(
290            "Start UnionExec::execute for partition {} of context session_id {} and task_id {:?}",
291            partition,
292            context.session_id(),
293            context.task_id()
294        );
295        let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
296        // record the tiny amount of work done in this function so
297        // elapsed_compute is reported as non zero
298        let elapsed_compute = baseline_metrics.elapsed_compute().clone();
299        let _timer = elapsed_compute.timer(); // record on drop
300
301        // find partition to execute
302        for input in self.inputs.iter() {
303            // Calculate whether partition belongs to the current partition
304            if partition < input.output_partitioning().partition_count() {
305                let stream = input.execute(partition, context)?;
306                debug!("Found a Union partition to execute");
307                return Ok(Box::pin(ObservedStream::new(
308                    stream,
309                    baseline_metrics,
310                    None,
311                )));
312            } else {
313                partition -= input.output_partitioning().partition_count();
314            }
315        }
316
317        warn!("Error in Union: Partition {partition} not found");
318
319        exec_err!("Partition {partition} not found in Union")
320    }
321
322    fn metrics(&self) -> Option<MetricsSet> {
323        Some(self.metrics.clone_inner())
324    }
325
326    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
327        if let Some(partition_idx) = partition {
328            // For a specific partition, find which input it belongs to
329            let mut remaining_idx = partition_idx;
330            for input in &self.inputs {
331                let input_partition_count = input.output_partitioning().partition_count();
332                if remaining_idx < input_partition_count {
333                    // This partition belongs to this input
334                    return input.partition_statistics(Some(remaining_idx));
335                }
336                remaining_idx -= input_partition_count;
337            }
338            // If we get here, the partition index is out of bounds
339            Ok(Statistics::new_unknown(&self.schema()))
340        } else {
341            // Collect statistics from all inputs
342            let stats = self
343                .inputs
344                .iter()
345                .map(|input_exec| input_exec.partition_statistics(None))
346                .collect::<Result<Vec<_>>>()?;
347
348            Ok(stats
349                .into_iter()
350                .reduce(stats_union)
351                .unwrap_or_else(|| Statistics::new_unknown(&self.schema())))
352        }
353    }
354
355    fn supports_limit_pushdown(&self) -> bool {
356        true
357    }
358
359    /// Tries to push `projection` down through `union`. If possible, performs the
360    /// pushdown and returns a new [`UnionExec`] as the top plan which has projections
361    /// as its children. Otherwise, returns `None`.
362    fn try_swapping_with_projection(
363        &self,
364        projection: &ProjectionExec,
365    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
366        // If the projection doesn't narrow the schema, we shouldn't try to push it down.
367        if projection.expr().len() >= projection.input().schema().fields().len() {
368            return Ok(None);
369        }
370
371        let new_children = self
372            .children()
373            .into_iter()
374            .map(|child| make_with_child(projection, child))
375            .collect::<Result<Vec<_>>>()?;
376
377        Ok(Some(UnionExec::try_new(new_children.clone())?))
378    }
379
380    fn gather_filters_for_pushdown(
381        &self,
382        _phase: FilterPushdownPhase,
383        parent_filters: Vec<Arc<dyn PhysicalExpr>>,
384        _config: &ConfigOptions,
385    ) -> Result<FilterDescription> {
386        FilterDescription::from_children(parent_filters, &self.children())
387    }
388
389    fn handle_child_pushdown_result(
390        &self,
391        phase: FilterPushdownPhase,
392        child_pushdown_result: ChildPushdownResult,
393        _config: &ConfigOptions,
394    ) -> Result<FilterPushdownPropagation<Arc<dyn ExecutionPlan>>> {
395        // Pre phase: handle heterogeneous pushdown by wrapping individual
396        // children with FilterExec and reporting all filters as handled.
397        // Post phase: use default behavior to let the filter creator decide how to handle
398        // filters that weren't fully pushed down.
399        if phase != FilterPushdownPhase::Pre {
400            return Ok(FilterPushdownPropagation::if_all(child_pushdown_result));
401        }
402
403        // UnionExec needs specialized filter pushdown handling when children have
404        // heterogeneous pushdown support. Without this, when some children support
405        // pushdown and others don't, the default behavior would leave FilterExec
406        // above UnionExec, re-applying filters to outputs of all children—including
407        // those that already applied the filters via pushdown. This specialized
408        // implementation adds FilterExec only to children that don't support
409        // pushdown, avoiding redundant filtering and improving performance.
410        //
411        // Example: Given Child1 (no pushdown support) and Child2 (has pushdown support)
412        //   Default behavior:          This implementation:
413        //   FilterExec                 UnionExec
414        //     UnionExec                  FilterExec
415        //       Child1                     Child1
416        //       Child2(filter)           Child2(filter)
417
418        // Collect unsupported filters for each child
419        let mut unsupported_filters_per_child = vec![Vec::new(); self.inputs.len()];
420        for parent_filter_result in child_pushdown_result.parent_filters.iter() {
421            for (child_idx, &child_result) in
422                parent_filter_result.child_results.iter().enumerate()
423            {
424                if matches!(child_result, PushedDown::No) {
425                    unsupported_filters_per_child[child_idx]
426                        .push(Arc::clone(&parent_filter_result.filter));
427                }
428            }
429        }
430
431        // Wrap children that have unsupported filters with FilterExec
432        let mut new_children = self.inputs.clone();
433        for (child_idx, unsupported_filters) in
434            unsupported_filters_per_child.iter().enumerate()
435        {
436            if !unsupported_filters.is_empty() {
437                let combined_filter = conjunction(unsupported_filters.clone());
438                new_children[child_idx] = Arc::new(FilterExec::try_new(
439                    combined_filter,
440                    Arc::clone(&self.inputs[child_idx]),
441                )?);
442            }
443        }
444
445        // Check if any children were modified
446        let children_modified = new_children
447            .iter()
448            .zip(self.inputs.iter())
449            .any(|(new, old)| !Arc::ptr_eq(new, old));
450
451        let all_filters_pushed =
452            vec![PushedDown::Yes; child_pushdown_result.parent_filters.len()];
453        let propagation = if children_modified {
454            let updated_node = UnionExec::try_new(new_children)?;
455            FilterPushdownPropagation::with_parent_pushdown_result(all_filters_pushed)
456                .with_updated_node(updated_node)
457        } else {
458            FilterPushdownPropagation::with_parent_pushdown_result(all_filters_pushed)
459        };
460
461        // Report all parent filters as supported since we've ensured they're applied
462        // on all children (either pushed down or via FilterExec)
463        Ok(propagation)
464    }
465}
466
467/// Combines multiple input streams by interleaving them.
468///
469/// This only works if all inputs have the same hash-partitioning.
470///
471/// # Data Flow
472/// ```text
473/// +---------+
474/// |         |---+
475/// | Input 1 |   |
476/// |         |-------------+
477/// +---------+   |         |
478///               |         |         +---------+
479///               +------------------>|         |
480///                 +---------------->| Combine |-->
481///                 | +-------------->|         |
482///                 | |     |         +---------+
483/// +---------+     | |     |
484/// |         |-----+ |     |
485/// | Input 2 |       |     |
486/// |         |---------------+
487/// +---------+       |     | |
488///                   |     | |       +---------+
489///                   |     +-------->|         |
490///                   |       +------>| Combine |-->
491///                   |         +---->|         |
492///                   |         |     +---------+
493/// +---------+       |         |
494/// |         |-------+         |
495/// | Input 3 |                 |
496/// |         |-----------------+
497/// +---------+
498/// ```
499#[derive(Debug, Clone)]
500pub struct InterleaveExec {
501    /// Input execution plan
502    inputs: Vec<Arc<dyn ExecutionPlan>>,
503    /// Execution metrics
504    metrics: ExecutionPlanMetricsSet,
505    /// Cache holding plan properties like equivalences, output partitioning etc.
506    cache: Arc<PlanProperties>,
507}
508
509impl InterleaveExec {
510    /// Create a new InterleaveExec
511    pub fn try_new(inputs: Vec<Arc<dyn ExecutionPlan>>) -> Result<Self> {
512        assert_or_internal_err!(
513            can_interleave(inputs.iter()),
514            "Not all InterleaveExec children have a consistent hash partitioning"
515        );
516        let cache = Self::compute_properties(&inputs)?;
517        Ok(InterleaveExec {
518            inputs,
519            metrics: ExecutionPlanMetricsSet::new(),
520            cache: Arc::new(cache),
521        })
522    }
523
524    /// Get inputs of the execution plan
525    pub fn inputs(&self) -> &Vec<Arc<dyn ExecutionPlan>> {
526        &self.inputs
527    }
528
529    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
530    fn compute_properties(inputs: &[Arc<dyn ExecutionPlan>]) -> Result<PlanProperties> {
531        let schema = union_schema(inputs)?;
532        let eq_properties = EquivalenceProperties::new(schema);
533        // Get output partitioning:
534        let output_partitioning = inputs[0].output_partitioning().clone();
535        Ok(PlanProperties::new(
536            eq_properties,
537            output_partitioning,
538            emission_type_from_children(inputs),
539            boundedness_from_children(inputs),
540        ))
541    }
542
543    fn with_new_children_and_same_properties(
544        &self,
545        children: Vec<Arc<dyn ExecutionPlan>>,
546    ) -> Self {
547        Self {
548            inputs: children,
549            metrics: ExecutionPlanMetricsSet::new(),
550            ..Self::clone(self)
551        }
552    }
553}
554
555impl DisplayAs for InterleaveExec {
556    fn fmt_as(
557        &self,
558        t: DisplayFormatType,
559        f: &mut std::fmt::Formatter,
560    ) -> std::fmt::Result {
561        match t {
562            DisplayFormatType::Default | DisplayFormatType::Verbose => {
563                write!(f, "InterleaveExec")
564            }
565            DisplayFormatType::TreeRender => Ok(()),
566        }
567    }
568}
569
570impl ExecutionPlan for InterleaveExec {
571    fn name(&self) -> &'static str {
572        "InterleaveExec"
573    }
574
575    /// Return a reference to Any that can be used for downcasting
576    fn as_any(&self) -> &dyn Any {
577        self
578    }
579
580    fn properties(&self) -> &Arc<PlanProperties> {
581        &self.cache
582    }
583
584    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
585        self.inputs.iter().collect()
586    }
587
588    fn maintains_input_order(&self) -> Vec<bool> {
589        vec![false; self.inputs().len()]
590    }
591
592    fn with_new_children(
593        self: Arc<Self>,
594        children: Vec<Arc<dyn ExecutionPlan>>,
595    ) -> Result<Arc<dyn ExecutionPlan>> {
596        // New children are no longer interleavable, which might be a bug of optimization rewrite.
597        assert_or_internal_err!(
598            can_interleave(children.iter()),
599            "Can not create InterleaveExec: new children can not be interleaved"
600        );
601        check_if_same_properties!(self, children);
602        Ok(Arc::new(InterleaveExec::try_new(children)?))
603    }
604
605    fn execute(
606        &self,
607        partition: usize,
608        context: Arc<TaskContext>,
609    ) -> Result<SendableRecordBatchStream> {
610        trace!(
611            "Start InterleaveExec::execute for partition {} of context session_id {} and task_id {:?}",
612            partition,
613            context.session_id(),
614            context.task_id()
615        );
616        let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
617        // record the tiny amount of work done in this function so
618        // elapsed_compute is reported as non zero
619        let elapsed_compute = baseline_metrics.elapsed_compute().clone();
620        let _timer = elapsed_compute.timer(); // record on drop
621
622        let mut input_stream_vec = vec![];
623        for input in self.inputs.iter() {
624            if partition < input.output_partitioning().partition_count() {
625                input_stream_vec.push(input.execute(partition, Arc::clone(&context))?);
626            } else {
627                // Do not find a partition to execute
628                break;
629            }
630        }
631        if input_stream_vec.len() == self.inputs.len() {
632            let stream = Box::pin(CombinedRecordBatchStream::new(
633                self.schema(),
634                input_stream_vec,
635            ));
636            return Ok(Box::pin(ObservedStream::new(
637                stream,
638                baseline_metrics,
639                None,
640            )));
641        }
642
643        warn!("Error in InterleaveExec: Partition {partition} not found");
644
645        exec_err!("Partition {partition} not found in InterleaveExec")
646    }
647
648    fn metrics(&self) -> Option<MetricsSet> {
649        Some(self.metrics.clone_inner())
650    }
651
652    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
653        let stats = self
654            .inputs
655            .iter()
656            .map(|stat| stat.partition_statistics(partition))
657            .collect::<Result<Vec<_>>>()?;
658
659        Ok(stats
660            .into_iter()
661            .reduce(stats_union)
662            .unwrap_or_else(|| Statistics::new_unknown(&self.schema())))
663    }
664
665    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
666        vec![false; self.children().len()]
667    }
668}
669
670/// If all the input partitions have the same Hash partition spec with the first_input_partition
671/// The InterleaveExec is partition aware.
672///
673/// It might be too strict here in the case that the input partition specs are compatible but not exactly the same.
674/// For example one input partition has the partition spec Hash('a','b','c') and
675/// other has the partition spec Hash('a'), It is safe to derive the out partition with the spec Hash('a','b','c').
676pub fn can_interleave<T: Borrow<Arc<dyn ExecutionPlan>>>(
677    mut inputs: impl Iterator<Item = T>,
678) -> bool {
679    let Some(first) = inputs.next() else {
680        return false;
681    };
682
683    let reference = first.borrow().output_partitioning();
684    matches!(reference, Partitioning::Hash(_, _))
685        && inputs
686            .map(|plan| plan.borrow().output_partitioning().clone())
687            .all(|partition| partition == *reference)
688}
689
690fn union_schema(inputs: &[Arc<dyn ExecutionPlan>]) -> Result<SchemaRef> {
691    if inputs.is_empty() {
692        return exec_err!("Cannot create union schema from empty inputs");
693    }
694
695    let first_schema = inputs[0].schema();
696    let first_field_count = first_schema.fields().len();
697
698    // validate that all inputs have the same number of fields
699    for (idx, input) in inputs.iter().enumerate().skip(1) {
700        let field_count = input.schema().fields().len();
701        if field_count != first_field_count {
702            return exec_err!(
703                "UnionExec/InterleaveExec requires all inputs to have the same number of fields. \
704                 Input 0 has {first_field_count} fields, but input {idx} has {field_count} fields"
705            );
706        }
707    }
708
709    let fields = (0..first_field_count)
710        .map(|i| {
711            // We take the name from the left side of the union to match how names are coerced during logical planning,
712            // which also uses the left side names.
713            let base_field = first_schema.field(i).clone();
714
715            // Coerce metadata and nullability across all inputs
716
717            inputs
718                .iter()
719                .enumerate()
720                .map(|(input_idx, input)| {
721                    let field = input.schema().field(i).clone();
722                    let mut metadata = field.metadata().clone();
723
724                    let other_metadatas = inputs
725                        .iter()
726                        .enumerate()
727                        .filter(|(other_idx, _)| *other_idx != input_idx)
728                        .flat_map(|(_, other_input)| {
729                            other_input.schema().field(i).metadata().clone().into_iter()
730                        });
731
732                    metadata.extend(other_metadatas);
733                    field.with_metadata(metadata)
734                })
735                .find_or_first(Field::is_nullable)
736                // We can unwrap this because if inputs was empty, this would've already panic'ed when we
737                // indexed into inputs[0].
738                .unwrap()
739                .with_name(base_field.name())
740        })
741        .collect::<Vec<_>>();
742
743    let all_metadata_merged = inputs
744        .iter()
745        .flat_map(|i| i.schema().metadata().clone().into_iter())
746        .collect();
747
748    Ok(Arc::new(Schema::new_with_metadata(
749        fields,
750        all_metadata_merged,
751    )))
752}
753
/// CombinedRecordBatchStream can be used to combine a Vec of SendableRecordBatchStreams into one.
///
/// Each poll returns an item from whichever remaining input is ready. Inputs
/// are removed as they finish, and the combined stream ends once none remain.
/// Used by `InterleaveExec::execute` to merge the same partition of every input.
struct CombinedRecordBatchStream {
    /// Output schema reported by this stream (supplied by the caller)
    schema: SchemaRef,
    /// Input streams that have not yet reported end-of-stream
    entries: Vec<SendableRecordBatchStream>,
}
761
762impl CombinedRecordBatchStream {
763    /// Create an CombinedRecordBatchStream
764    pub fn new(schema: SchemaRef, entries: Vec<SendableRecordBatchStream>) -> Self {
765        Self { schema, entries }
766    }
767}
768
769impl RecordBatchStream for CombinedRecordBatchStream {
770    fn schema(&self) -> SchemaRef {
771        Arc::clone(&self.schema)
772    }
773}
774
impl Stream for CombinedRecordBatchStream {
    type Item = Result<RecordBatch>;

    /// Returns the next item (batch or error) from whichever remaining input
    /// stream is ready.
    ///
    /// Polling starts at an index chosen with `thread_rng_n` (presumably so the
    /// same input is not always favored) and walks the remaining inputs
    /// round-robin. The loop makes at most one poll per input present at the
    /// start of the call. Inputs that return `Ready(None)` are removed with
    /// `swap_remove`.
    ///
    /// Returns `Ready(None)` once every input has been removed, and `Pending`
    /// if inputs remain but none of them produced an item during this call.
    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        use Poll::*;

        // Starting index for this round of polling.
        // NOTE(review): `len() as u32` truncates if there were ever more than
        // u32::MAX inputs; not a practical concern.
        let start = thread_rng_n(self.entries.len() as u32) as usize;
        let mut idx = start;

        // The iteration count is fixed to the number of inputs at entry, even
        // though entries may be removed inside the loop.
        for _ in 0..self.entries.len() {
            // The cursor updates below keep `idx` in bounds whenever another
            // iteration runs.
            let stream = self.entries.get_mut(idx).unwrap();

            match Pin::new(stream).poll_next(cx) {
                Ready(Some(val)) => return Ready(Some(val)),
                Ready(None) => {
                    // Remove the entry
                    self.entries.swap_remove(idx);

                    // Check if this was the last entry, if so the cursor needs
                    // to wrap
                    if idx == self.entries.len() {
                        idx = 0;
                    } else if idx < start && start <= self.entries.len() {
                        // The stream being swapped into the current index has
                        // already been polled, so skip it.
                        idx = idx.wrapping_add(1) % self.entries.len();
                    }
                }
                Pending => {
                    idx = idx.wrapping_add(1) % self.entries.len();
                }
            }
        }

        // If the map is empty, then the stream is complete.
        if self.entries.is_empty() {
            Ready(None)
        } else {
            Pending
        }
    }
}
820
821fn col_stats_union(
822    mut left: ColumnStatistics,
823    right: &ColumnStatistics,
824) -> ColumnStatistics {
825    left.distinct_count = Precision::Absent;
826    left.min_value = left.min_value.min(&right.min_value);
827    left.max_value = left.max_value.max(&right.max_value);
828    left.sum_value = left.sum_value.add(&right.sum_value);
829    left.null_count = left.null_count.add(&right.null_count);
830
831    left
832}
833
834fn stats_union(mut left: Statistics, right: Statistics) -> Statistics {
835    let Statistics {
836        num_rows: right_num_rows,
837        total_byte_size: right_total_bytes,
838        column_statistics: right_column_statistics,
839        ..
840    } = right;
841    left.num_rows = left.num_rows.add(&right_num_rows);
842    left.total_byte_size = left.total_byte_size.add(&right_total_bytes);
843    left.column_statistics = left
844        .column_statistics
845        .into_iter()
846        .zip(right_column_statistics.iter())
847        .map(|(a, b)| col_stats_union(a, b))
848        .collect::<Vec<_>>();
849    left
850}
851
#[cfg(test)]
mod tests {
    use super::*;
    use crate::collect;
    use crate::test::{self, TestMemoryExec};

    use arrow::compute::SortOptions;
    use arrow::datatypes::DataType;
    use datafusion_common::ScalarValue;
    use datafusion_physical_expr::equivalence::convert_to_orderings;
    use datafusion_physical_expr::expressions::col;

    // Generate a schema which consists of 7 columns (a, b, c, d, e, f, g)
    fn create_test_schema() -> Result<SchemaRef> {
        let a = Field::new("a", DataType::Int32, true);
        let b = Field::new("b", DataType::Int32, true);
        let c = Field::new("c", DataType::Int32, true);
        let d = Field::new("d", DataType::Int32, true);
        let e = Field::new("e", DataType::Int32, true);
        let f = Field::new("f", DataType::Int32, true);
        let g = Field::new("g", DataType::Int32, true);
        let schema = Arc::new(Schema::new(vec![a, b, c, d, e, f, g]));

        Ok(schema)
    }

    // Generate a schema which consists of 6 columns (a, b, c, d, e, f); its
    // field count differs from `create_test_schema`, for mismatch tests.
    fn create_test_schema2() -> Result<SchemaRef> {
        let a = Field::new("a", DataType::Int32, true);
        let b = Field::new("b", DataType::Int32, true);
        let c = Field::new("c", DataType::Int32, true);
        let d = Field::new("d", DataType::Int32, true);
        let e = Field::new("e", DataType::Int32, true);
        let f = Field::new("f", DataType::Int32, true);
        let schema = Arc::new(Schema::new(vec![a, b, c, d, e, f]));

        Ok(schema)
    }

    #[tokio::test]
    async fn test_union_partitions() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());

        // Create inputs with different partitioning
        let csv = test::scan_partitioned(4);
        let csv2 = test::scan_partitioned(5);

        let union_exec: Arc<dyn ExecutionPlan> = UnionExec::try_new(vec![csv, csv2])?;

        // Should have 9 partitions and 9 output batches
        assert_eq!(
            union_exec
                .properties()
                .output_partitioning()
                .partition_count(),
            9
        );

        let result: Vec<RecordBatch> = collect(union_exec, task_ctx).await?;
        assert_eq!(result.len(), 9);

        Ok(())
    }

    // Pure statistics arithmetic: nothing is awaited, so no async runtime is
    // needed.
    #[test]
    fn test_stats_union() {
        let left = Statistics {
            num_rows: Precision::Exact(5),
            total_byte_size: Precision::Exact(23),
            column_statistics: vec![
                ColumnStatistics {
                    distinct_count: Precision::Exact(5),
                    max_value: Precision::Exact(ScalarValue::Int64(Some(21))),
                    min_value: Precision::Exact(ScalarValue::Int64(Some(-4))),
                    sum_value: Precision::Exact(ScalarValue::Int64(Some(42))),
                    null_count: Precision::Exact(0),
                    byte_size: Precision::Absent,
                },
                ColumnStatistics {
                    distinct_count: Precision::Exact(1),
                    max_value: Precision::Exact(ScalarValue::from("x")),
                    min_value: Precision::Exact(ScalarValue::from("a")),
                    sum_value: Precision::Absent,
                    null_count: Precision::Exact(3),
                    byte_size: Precision::Absent,
                },
                ColumnStatistics {
                    distinct_count: Precision::Absent,
                    max_value: Precision::Exact(ScalarValue::Float32(Some(1.1))),
                    min_value: Precision::Exact(ScalarValue::Float32(Some(0.1))),
                    sum_value: Precision::Exact(ScalarValue::Float32(Some(42.0))),
                    null_count: Precision::Absent,
                    byte_size: Precision::Absent,
                },
            ],
        };

        let right = Statistics {
            num_rows: Precision::Exact(7),
            total_byte_size: Precision::Exact(29),
            column_statistics: vec![
                ColumnStatistics {
                    distinct_count: Precision::Exact(3),
                    max_value: Precision::Exact(ScalarValue::Int64(Some(34))),
                    min_value: Precision::Exact(ScalarValue::Int64(Some(1))),
                    sum_value: Precision::Exact(ScalarValue::Int64(Some(42))),
                    null_count: Precision::Exact(1),
                    byte_size: Precision::Absent,
                },
                ColumnStatistics {
                    distinct_count: Precision::Absent,
                    max_value: Precision::Exact(ScalarValue::from("c")),
                    min_value: Precision::Exact(ScalarValue::from("b")),
                    sum_value: Precision::Absent,
                    null_count: Precision::Absent,
                    byte_size: Precision::Absent,
                },
                ColumnStatistics {
                    distinct_count: Precision::Absent,
                    max_value: Precision::Absent,
                    min_value: Precision::Absent,
                    sum_value: Precision::Absent,
                    null_count: Precision::Absent,
                    byte_size: Precision::Absent,
                },
            ],
        };

        let result = stats_union(left, right);
        // Expected: rows 5 + 7, bytes 23 + 29; column 0 takes min/max over
        // both sides and sums sum/null counts; any statistic combined with an
        // Absent side becomes Absent; distinct counts are always dropped.
        let expected = Statistics {
            num_rows: Precision::Exact(12),
            total_byte_size: Precision::Exact(52),
            column_statistics: vec![
                ColumnStatistics {
                    distinct_count: Precision::Absent,
                    max_value: Precision::Exact(ScalarValue::Int64(Some(34))),
                    min_value: Precision::Exact(ScalarValue::Int64(Some(-4))),
                    sum_value: Precision::Exact(ScalarValue::Int64(Some(84))),
                    null_count: Precision::Exact(1),
                    byte_size: Precision::Absent,
                },
                ColumnStatistics {
                    distinct_count: Precision::Absent,
                    max_value: Precision::Exact(ScalarValue::from("x")),
                    min_value: Precision::Exact(ScalarValue::from("a")),
                    sum_value: Precision::Absent,
                    null_count: Precision::Absent,
                    byte_size: Precision::Absent,
                },
                ColumnStatistics {
                    distinct_count: Precision::Absent,
                    max_value: Precision::Absent,
                    min_value: Precision::Absent,
                    sum_value: Precision::Absent,
                    null_count: Precision::Absent,
                    byte_size: Precision::Absent,
                },
            ],
        };

        assert_eq!(result, expected);
    }

    #[tokio::test]
    async fn test_union_equivalence_properties() -> Result<()> {
        let schema = create_test_schema()?;
        let col_a = &col("a", &schema)?;
        let col_b = &col("b", &schema)?;
        let col_c = &col("c", &schema)?;
        let col_d = &col("d", &schema)?;
        let col_e = &col("e", &schema)?;
        let col_f = &col("f", &schema)?;
        let options = SortOptions::default();
        let test_cases = [
            //-----------TEST CASE 1----------//
            (
                // First child orderings
                vec![
                    // [a ASC, b ASC, f ASC]
                    vec![(col_a, options), (col_b, options), (col_f, options)],
                ],
                // Second child orderings
                vec![
                    // [a ASC, b ASC, c ASC]
                    vec![(col_a, options), (col_b, options), (col_c, options)],
                    // [a ASC, b ASC, f ASC]
                    vec![(col_a, options), (col_b, options), (col_f, options)],
                ],
                // Union output orderings
                vec![
                    // [a ASC, b ASC, f ASC]
                    vec![(col_a, options), (col_b, options), (col_f, options)],
                ],
            ),
            //-----------TEST CASE 2----------//
            (
                // First child orderings
                vec![
                    // [a ASC, b ASC, f ASC]
                    vec![(col_a, options), (col_b, options), (col_f, options)],
                    // d ASC
                    vec![(col_d, options)],
                ],
                // Second child orderings
                vec![
                    // [a ASC, b ASC, c ASC]
                    vec![(col_a, options), (col_b, options), (col_c, options)],
                    // [e ASC]
                    vec![(col_e, options)],
                ],
                // Union output orderings
                vec![
                    // [a ASC, b ASC]
                    vec![(col_a, options), (col_b, options)],
                ],
            ),
        ];

        for (
            test_idx,
            (first_child_orderings, second_child_orderings, union_orderings),
        ) in test_cases.iter().enumerate()
        {
            let first_orderings = convert_to_orderings(first_child_orderings);
            let second_orderings = convert_to_orderings(second_child_orderings);
            let union_expected_orderings = convert_to_orderings(union_orderings);
            let child1_exec = TestMemoryExec::try_new(&[], Arc::clone(&schema), None)?
                .try_with_sort_information(first_orderings)?;
            let child1 = Arc::new(child1_exec);
            let child1 = Arc::new(TestMemoryExec::update_cache(&child1));
            let child2_exec = TestMemoryExec::try_new(&[], Arc::clone(&schema), None)?
                .try_with_sort_information(second_orderings)?;
            let child2 = Arc::new(child2_exec);
            let child2 = Arc::new(TestMemoryExec::update_cache(&child2));

            let mut union_expected_eq = EquivalenceProperties::new(Arc::clone(&schema));
            union_expected_eq.add_orderings(union_expected_orderings);

            let union: Arc<dyn ExecutionPlan> = UnionExec::try_new(vec![child1, child2])?;
            let union_eq_properties = union.properties().equivalence_properties();
            let err_msg = format!(
                "Error in test id: {:?}, test case: {:?}",
                test_idx, test_cases[test_idx]
            );
            assert_eq_properties_same(union_eq_properties, &union_expected_eq, err_msg);
        }
        Ok(())
    }

    // Asserts that `lhs` and `rhs` hold the same set of orderings (same count,
    // and every ordering of `rhs` is present in `lhs`).
    fn assert_eq_properties_same(
        lhs: &EquivalenceProperties,
        rhs: &EquivalenceProperties,
        err_msg: String,
    ) {
        // Check whether orderings are same.
        let lhs_orderings = lhs.oeq_class();
        let rhs_orderings = rhs.oeq_class();
        assert_eq!(lhs_orderings.len(), rhs_orderings.len(), "{err_msg}");
        for rhs_ordering in rhs_orderings.iter() {
            assert!(lhs_orderings.contains(rhs_ordering), "{err_msg}");
        }
    }

    #[test]
    fn test_union_empty_inputs() {
        // Test that UnionExec::try_new fails with empty inputs
        let result = UnionExec::try_new(vec![]);
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("UnionExec requires at least one input")
        );
    }

    #[test]
    fn test_union_schema_empty_inputs() {
        // Test that union_schema fails with empty inputs
        let result = union_schema(&[]);
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Cannot create union schema from empty inputs")
        );
    }

    #[test]
    fn test_union_single_input() -> Result<()> {
        // Test that UnionExec::try_new returns the single input directly
        let schema = create_test_schema()?;
        let memory_exec: Arc<dyn ExecutionPlan> =
            Arc::new(TestMemoryExec::try_new(&[], Arc::clone(&schema), None)?);
        let memory_exec_clone = Arc::clone(&memory_exec);
        let result = UnionExec::try_new(vec![memory_exec])?;

        // Check that the result is the same as the input (no UnionExec wrapper)
        assert_eq!(result.schema(), schema);
        // Verify it's the same execution plan
        assert!(Arc::ptr_eq(&result, &memory_exec_clone));

        Ok(())
    }

    #[test]
    fn test_union_schema_multiple_inputs() -> Result<()> {
        // Test that existing functionality with multiple inputs still works
        let schema = create_test_schema()?;
        let memory_exec1 =
            Arc::new(TestMemoryExec::try_new(&[], Arc::clone(&schema), None)?);
        let memory_exec2 =
            Arc::new(TestMemoryExec::try_new(&[], Arc::clone(&schema), None)?);

        let union_plan = UnionExec::try_new(vec![memory_exec1, memory_exec2])?;

        // Downcast to verify it's a UnionExec
        let union = union_plan
            .as_any()
            .downcast_ref::<UnionExec>()
            .expect("Expected UnionExec");

        // Check that schema is correct
        assert_eq!(union.schema(), schema);
        // Check that we have 2 inputs
        assert_eq!(union.inputs().len(), 2);

        Ok(())
    }

    #[test]
    fn test_union_schema_mismatch() {
        // Test that UnionExec properly rejects inputs with different field counts
        let schema = create_test_schema().unwrap();
        let schema2 = create_test_schema2().unwrap();
        let memory_exec1 =
            Arc::new(TestMemoryExec::try_new(&[], Arc::clone(&schema), None).unwrap());
        let memory_exec2 =
            Arc::new(TestMemoryExec::try_new(&[], Arc::clone(&schema2), None).unwrap());

        let result = UnionExec::try_new(vec![memory_exec1, memory_exec2]);
        assert!(result.is_err());
        assert!(
            result.unwrap_err().to_string().contains(
                "UnionExec/InterleaveExec requires all inputs to have the same number of fields"
            )
        );
    }
}