Skip to main content

datafusion_physical_plan/
union.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18// Some of these functions reference the Postgres documentation
19// or implementation to ensure compatibility and are subject to
20// the Postgres license.
21
22//! The Union operator combines multiple inputs with the same schema
23
24use std::borrow::Borrow;
25use std::pin::Pin;
26use std::task::{Context, Poll};
27use std::{any::Any, sync::Arc};
28
29use super::{
30    ColumnStatistics, DisplayAs, DisplayFormatType, ExecutionPlan,
31    ExecutionPlanProperties, Partitioning, PlanProperties, RecordBatchStream,
32    SendableRecordBatchStream, Statistics,
33    metrics::{ExecutionPlanMetricsSet, MetricsSet},
34};
35use crate::execution_plan::{
36    InvariantLevel, boundedness_from_children, check_default_invariants,
37    emission_type_from_children,
38};
39use crate::filter_pushdown::{FilterDescription, FilterPushdownPhase};
40use crate::metrics::BaselineMetrics;
41use crate::projection::{ProjectionExec, make_with_child};
42use crate::stream::ObservedStream;
43
44use arrow::datatypes::{Field, Schema, SchemaRef};
45use arrow::record_batch::RecordBatch;
46use datafusion_common::config::ConfigOptions;
47use datafusion_common::stats::Precision;
48use datafusion_common::{
49    Result, assert_or_internal_err, exec_err, internal_datafusion_err,
50};
51use datafusion_execution::TaskContext;
52use datafusion_physical_expr::{EquivalenceProperties, PhysicalExpr, calculate_union};
53
54use futures::Stream;
55use itertools::Itertools;
56use log::{debug, trace, warn};
57use tokio::macros::support::thread_rng_n;
58
/// `UnionExec`: `UNION ALL` execution plan.
///
/// `UnionExec` combines multiple inputs with the same schema by
/// concatenating the partitions.  It does not mix or copy data within
/// or across partitions. Thus if the input partitions are sorted, the
/// output partitions of the union are also sorted.
///
/// For example, given a `UnionExec` of two inputs, with `N`
/// partitions, and `M` partitions, there will be `N+M` output
/// partitions. The first `N` output partitions are from Input 1
/// partitions, and the next `M` output partitions are from Input 2.
///
/// Prefer constructing it with [`UnionExec::try_new`], which returns the
/// single input unchanged (instead of wrapping it) when given exactly one input.
///
/// ```text
///                        ▲       ▲           ▲         ▲
///                        │       │           │         │
///      Output            │  ...  │           │         │
///    Partitions          │0      │N-1        │ N       │N+M-1
/// (passes through   ┌────┴───────┴───────────┴─────────┴───┐
///  the N+M input    │              UnionExec               │
///   partitions)     │                                      │
///                   └──────────────────────────────────────┘
///                                      ▲
///                                      │
///                                      │
///       Input           ┌────────┬─────┴────┬──────────┐
///     Partitions        │ ...    │          │     ...  │
///                    0  │        │ N-1      │ 0        │  M-1
///                  ┌────┴────────┴───┐  ┌───┴──────────┴───┐
///                  │                 │  │                  │
///                  │                 │  │                  │
///                  │                 │  │                  │
///                  │                 │  │                  │
///                  │                 │  │                  │
///                  │                 │  │                  │
///                  │Input 1          │  │Input 2           │
///                  └─────────────────┘  └──────────────────┘
/// ```
#[derive(Debug, Clone)]
pub struct UnionExec {
    /// Input execution plans, in the order in which their partitions appear
    /// in the output (input 0's partitions first, then input 1's, ...).
    inputs: Vec<Arc<dyn ExecutionPlan>>,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// Cache holding plan properties like equivalences, output partitioning etc.
    cache: PlanProperties,
}
105
106impl UnionExec {
107    /// Create a new UnionExec
108    #[deprecated(since = "44.0.0", note = "Use UnionExec::try_new instead")]
109    pub fn new(inputs: Vec<Arc<dyn ExecutionPlan>>) -> Self {
110        let schema =
111            union_schema(&inputs).expect("UnionExec::new called with empty inputs");
112        // The schema of the inputs and the union schema is consistent when:
113        // - They have the same number of fields, and
114        // - Their fields have same types at the same indices.
115        // Here, we know that schemas are consistent and the call below can
116        // not return an error.
117        let cache = Self::compute_properties(&inputs, schema).unwrap();
118        UnionExec {
119            inputs,
120            metrics: ExecutionPlanMetricsSet::new(),
121            cache,
122        }
123    }
124
125    /// Try to create a new UnionExec.
126    ///
127    /// # Errors
128    /// Returns an error if:
129    /// - `inputs` is empty
130    ///
131    /// # Optimization
132    /// If there is only one input, returns that input directly rather than wrapping it in a UnionExec
133    pub fn try_new(
134        inputs: Vec<Arc<dyn ExecutionPlan>>,
135    ) -> Result<Arc<dyn ExecutionPlan>> {
136        match inputs.len() {
137            0 => exec_err!("UnionExec requires at least one input"),
138            1 => Ok(inputs.into_iter().next().unwrap()),
139            _ => {
140                let schema = union_schema(&inputs)?;
141                // The schema of the inputs and the union schema is consistent when:
142                // - They have the same number of fields, and
143                // - Their fields have same types at the same indices.
144                // Here, we know that schemas are consistent and the call below can
145                // not return an error.
146                let cache = Self::compute_properties(&inputs, schema).unwrap();
147                Ok(Arc::new(UnionExec {
148                    inputs,
149                    metrics: ExecutionPlanMetricsSet::new(),
150                    cache,
151                }))
152            }
153        }
154    }
155
156    /// Get inputs of the execution plan
157    pub fn inputs(&self) -> &Vec<Arc<dyn ExecutionPlan>> {
158        &self.inputs
159    }
160
161    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
162    fn compute_properties(
163        inputs: &[Arc<dyn ExecutionPlan>],
164        schema: SchemaRef,
165    ) -> Result<PlanProperties> {
166        // Calculate equivalence properties:
167        let children_eqps = inputs
168            .iter()
169            .map(|child| child.equivalence_properties().clone())
170            .collect::<Vec<_>>();
171        let eq_properties = calculate_union(children_eqps, schema)?;
172
173        // Calculate output partitioning; i.e. sum output partitions of the inputs.
174        let num_partitions = inputs
175            .iter()
176            .map(|plan| plan.output_partitioning().partition_count())
177            .sum();
178        let output_partitioning = Partitioning::UnknownPartitioning(num_partitions);
179        Ok(PlanProperties::new(
180            eq_properties,
181            output_partitioning,
182            emission_type_from_children(inputs),
183            boundedness_from_children(inputs),
184        ))
185    }
186}
187
188impl DisplayAs for UnionExec {
189    fn fmt_as(
190        &self,
191        t: DisplayFormatType,
192        f: &mut std::fmt::Formatter,
193    ) -> std::fmt::Result {
194        match t {
195            DisplayFormatType::Default | DisplayFormatType::Verbose => {
196                write!(f, "UnionExec")
197            }
198            DisplayFormatType::TreeRender => Ok(()),
199        }
200    }
201}
202
203impl ExecutionPlan for UnionExec {
204    fn name(&self) -> &'static str {
205        "UnionExec"
206    }
207
208    /// Return a reference to Any that can be used for downcasting
209    fn as_any(&self) -> &dyn Any {
210        self
211    }
212
213    fn properties(&self) -> &PlanProperties {
214        &self.cache
215    }
216
217    fn check_invariants(&self, check: InvariantLevel) -> Result<()> {
218        check_default_invariants(self, check)?;
219
220        (self.inputs().len() >= 2).then_some(()).ok_or_else(|| {
221            internal_datafusion_err!("UnionExec should have at least 2 children")
222        })
223    }
224
225    fn maintains_input_order(&self) -> Vec<bool> {
226        // If the Union has an output ordering, it maintains at least one
227        // child's ordering (i.e. the meet).
228        // For instance, assume that the first child is SortExpr('a','b','c'),
229        // the second child is SortExpr('a','b') and the third child is
230        // SortExpr('a','b'). The output ordering would be SortExpr('a','b'),
231        // which is the "meet" of all input orderings. In this example, this
232        // function will return vec![false, true, true], indicating that we
233        // preserve the orderings for the 2nd and the 3rd children.
234        if let Some(output_ordering) = self.properties().output_ordering() {
235            self.inputs()
236                .iter()
237                .map(|child| {
238                    if let Some(child_ordering) = child.output_ordering() {
239                        output_ordering.len() == child_ordering.len()
240                    } else {
241                        false
242                    }
243                })
244                .collect()
245        } else {
246            vec![false; self.inputs().len()]
247        }
248    }
249
250    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
251        vec![false; self.children().len()]
252    }
253
254    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
255        self.inputs.iter().collect()
256    }
257
258    fn with_new_children(
259        self: Arc<Self>,
260        children: Vec<Arc<dyn ExecutionPlan>>,
261    ) -> Result<Arc<dyn ExecutionPlan>> {
262        UnionExec::try_new(children)
263    }
264
265    fn execute(
266        &self,
267        mut partition: usize,
268        context: Arc<TaskContext>,
269    ) -> Result<SendableRecordBatchStream> {
270        trace!(
271            "Start UnionExec::execute for partition {} of context session_id {} and task_id {:?}",
272            partition,
273            context.session_id(),
274            context.task_id()
275        );
276        let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
277        // record the tiny amount of work done in this function so
278        // elapsed_compute is reported as non zero
279        let elapsed_compute = baseline_metrics.elapsed_compute().clone();
280        let _timer = elapsed_compute.timer(); // record on drop
281
282        // find partition to execute
283        for input in self.inputs.iter() {
284            // Calculate whether partition belongs to the current partition
285            if partition < input.output_partitioning().partition_count() {
286                let stream = input.execute(partition, context)?;
287                debug!("Found a Union partition to execute");
288                return Ok(Box::pin(ObservedStream::new(
289                    stream,
290                    baseline_metrics,
291                    None,
292                )));
293            } else {
294                partition -= input.output_partitioning().partition_count();
295            }
296        }
297
298        warn!("Error in Union: Partition {partition} not found");
299
300        exec_err!("Partition {partition} not found in Union")
301    }
302
303    fn metrics(&self) -> Option<MetricsSet> {
304        Some(self.metrics.clone_inner())
305    }
306
307    fn statistics(&self) -> Result<Statistics> {
308        self.partition_statistics(None)
309    }
310
311    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
312        if let Some(partition_idx) = partition {
313            // For a specific partition, find which input it belongs to
314            let mut remaining_idx = partition_idx;
315            for input in &self.inputs {
316                let input_partition_count = input.output_partitioning().partition_count();
317                if remaining_idx < input_partition_count {
318                    // This partition belongs to this input
319                    return input.partition_statistics(Some(remaining_idx));
320                }
321                remaining_idx -= input_partition_count;
322            }
323            // If we get here, the partition index is out of bounds
324            Ok(Statistics::new_unknown(&self.schema()))
325        } else {
326            // Collect statistics from all inputs
327            let stats = self
328                .inputs
329                .iter()
330                .map(|input_exec| input_exec.partition_statistics(None))
331                .collect::<Result<Vec<_>>>()?;
332
333            Ok(stats
334                .into_iter()
335                .reduce(stats_union)
336                .unwrap_or_else(|| Statistics::new_unknown(&self.schema())))
337        }
338    }
339
340    fn supports_limit_pushdown(&self) -> bool {
341        true
342    }
343
344    /// Tries to push `projection` down through `union`. If possible, performs the
345    /// pushdown and returns a new [`UnionExec`] as the top plan which has projections
346    /// as its children. Otherwise, returns `None`.
347    fn try_swapping_with_projection(
348        &self,
349        projection: &ProjectionExec,
350    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
351        // If the projection doesn't narrow the schema, we shouldn't try to push it down.
352        if projection.expr().len() >= projection.input().schema().fields().len() {
353            return Ok(None);
354        }
355
356        let new_children = self
357            .children()
358            .into_iter()
359            .map(|child| make_with_child(projection, child))
360            .collect::<Result<Vec<_>>>()?;
361
362        Ok(Some(UnionExec::try_new(new_children.clone())?))
363    }
364
365    fn gather_filters_for_pushdown(
366        &self,
367        _phase: FilterPushdownPhase,
368        parent_filters: Vec<Arc<dyn PhysicalExpr>>,
369        _config: &ConfigOptions,
370    ) -> Result<FilterDescription> {
371        FilterDescription::from_children(parent_filters, &self.children())
372    }
373}
374
/// Combines multiple input streams by interleaving them.
///
/// This only works if all inputs have the same hash partitioning:
/// [`InterleaveExec::try_new`] rejects inputs for which [`can_interleave`]
/// returns `false`. Output partition `i` merges partition `i` of every input,
/// and the operator reports the inputs' shared partitioning as its own.
///
/// # Data Flow
/// ```text
/// +---------+
/// |         |---+
/// | Input 1 |   |
/// |         |-------------+
/// +---------+   |         |
///               |         |         +---------+
///               +------------------>|         |
///                 +---------------->| Combine |-->
///                 | +-------------->|         |
///                 | |     |         +---------+
/// +---------+     | |     |
/// |         |-----+ |     |
/// | Input 2 |       |     |
/// |         |---------------+
/// +---------+       |     | |
///                   |     | |       +---------+
///                   |     +-------->|         |
///                   |       +------>| Combine |-->
///                   |         +---->|         |
///                   |         |     +---------+
/// +---------+       |         |
/// |         |-------+         |
/// | Input 3 |                 |
/// |         |-----------------+
/// +---------+
/// ```
#[derive(Debug, Clone)]
pub struct InterleaveExec {
    /// Input execution plans; all share the same hash partitioning
    /// (checked by `try_new`).
    inputs: Vec<Arc<dyn ExecutionPlan>>,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// Cache holding plan properties like equivalences, output partitioning etc.
    cache: PlanProperties,
}
416
417impl InterleaveExec {
418    /// Create a new InterleaveExec
419    pub fn try_new(inputs: Vec<Arc<dyn ExecutionPlan>>) -> Result<Self> {
420        assert_or_internal_err!(
421            can_interleave(inputs.iter()),
422            "Not all InterleaveExec children have a consistent hash partitioning"
423        );
424        let cache = Self::compute_properties(&inputs)?;
425        Ok(InterleaveExec {
426            inputs,
427            metrics: ExecutionPlanMetricsSet::new(),
428            cache,
429        })
430    }
431
432    /// Get inputs of the execution plan
433    pub fn inputs(&self) -> &Vec<Arc<dyn ExecutionPlan>> {
434        &self.inputs
435    }
436
437    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
438    fn compute_properties(inputs: &[Arc<dyn ExecutionPlan>]) -> Result<PlanProperties> {
439        let schema = union_schema(inputs)?;
440        let eq_properties = EquivalenceProperties::new(schema);
441        // Get output partitioning:
442        let output_partitioning = inputs[0].output_partitioning().clone();
443        Ok(PlanProperties::new(
444            eq_properties,
445            output_partitioning,
446            emission_type_from_children(inputs),
447            boundedness_from_children(inputs),
448        ))
449    }
450}
451
452impl DisplayAs for InterleaveExec {
453    fn fmt_as(
454        &self,
455        t: DisplayFormatType,
456        f: &mut std::fmt::Formatter,
457    ) -> std::fmt::Result {
458        match t {
459            DisplayFormatType::Default | DisplayFormatType::Verbose => {
460                write!(f, "InterleaveExec")
461            }
462            DisplayFormatType::TreeRender => Ok(()),
463        }
464    }
465}
466
467impl ExecutionPlan for InterleaveExec {
468    fn name(&self) -> &'static str {
469        "InterleaveExec"
470    }
471
472    /// Return a reference to Any that can be used for downcasting
473    fn as_any(&self) -> &dyn Any {
474        self
475    }
476
477    fn properties(&self) -> &PlanProperties {
478        &self.cache
479    }
480
481    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
482        self.inputs.iter().collect()
483    }
484
485    fn maintains_input_order(&self) -> Vec<bool> {
486        vec![false; self.inputs().len()]
487    }
488
489    fn with_new_children(
490        self: Arc<Self>,
491        children: Vec<Arc<dyn ExecutionPlan>>,
492    ) -> Result<Arc<dyn ExecutionPlan>> {
493        // New children are no longer interleavable, which might be a bug of optimization rewrite.
494        assert_or_internal_err!(
495            can_interleave(children.iter()),
496            "Can not create InterleaveExec: new children can not be interleaved"
497        );
498        Ok(Arc::new(InterleaveExec::try_new(children)?))
499    }
500
501    fn execute(
502        &self,
503        partition: usize,
504        context: Arc<TaskContext>,
505    ) -> Result<SendableRecordBatchStream> {
506        trace!(
507            "Start InterleaveExec::execute for partition {} of context session_id {} and task_id {:?}",
508            partition,
509            context.session_id(),
510            context.task_id()
511        );
512        let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
513        // record the tiny amount of work done in this function so
514        // elapsed_compute is reported as non zero
515        let elapsed_compute = baseline_metrics.elapsed_compute().clone();
516        let _timer = elapsed_compute.timer(); // record on drop
517
518        let mut input_stream_vec = vec![];
519        for input in self.inputs.iter() {
520            if partition < input.output_partitioning().partition_count() {
521                input_stream_vec.push(input.execute(partition, Arc::clone(&context))?);
522            } else {
523                // Do not find a partition to execute
524                break;
525            }
526        }
527        if input_stream_vec.len() == self.inputs.len() {
528            let stream = Box::pin(CombinedRecordBatchStream::new(
529                self.schema(),
530                input_stream_vec,
531            ));
532            return Ok(Box::pin(ObservedStream::new(
533                stream,
534                baseline_metrics,
535                None,
536            )));
537        }
538
539        warn!("Error in InterleaveExec: Partition {partition} not found");
540
541        exec_err!("Partition {partition} not found in InterleaveExec")
542    }
543
544    fn metrics(&self) -> Option<MetricsSet> {
545        Some(self.metrics.clone_inner())
546    }
547
548    fn statistics(&self) -> Result<Statistics> {
549        self.partition_statistics(None)
550    }
551
552    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
553        let stats = self
554            .inputs
555            .iter()
556            .map(|stat| stat.partition_statistics(partition))
557            .collect::<Result<Vec<_>>>()?;
558
559        Ok(stats
560            .into_iter()
561            .reduce(stats_union)
562            .unwrap_or_else(|| Statistics::new_unknown(&self.schema())))
563    }
564
565    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
566        vec![false; self.children().len()]
567    }
568}
569
570/// If all the input partitions have the same Hash partition spec with the first_input_partition
571/// The InterleaveExec is partition aware.
572///
573/// It might be too strict here in the case that the input partition specs are compatible but not exactly the same.
574/// For example one input partition has the partition spec Hash('a','b','c') and
575/// other has the partition spec Hash('a'), It is safe to derive the out partition with the spec Hash('a','b','c').
576pub fn can_interleave<T: Borrow<Arc<dyn ExecutionPlan>>>(
577    mut inputs: impl Iterator<Item = T>,
578) -> bool {
579    let Some(first) = inputs.next() else {
580        return false;
581    };
582
583    let reference = first.borrow().output_partitioning();
584    matches!(reference, Partitioning::Hash(_, _))
585        && inputs
586            .map(|plan| plan.borrow().output_partitioning().clone())
587            .all(|partition| partition == *reference)
588}
589
590fn union_schema(inputs: &[Arc<dyn ExecutionPlan>]) -> Result<SchemaRef> {
591    if inputs.is_empty() {
592        return exec_err!("Cannot create union schema from empty inputs");
593    }
594
595    let first_schema = inputs[0].schema();
596
597    let fields = (0..first_schema.fields().len())
598        .map(|i| {
599            // We take the name from the left side of the union to match how names are coerced during logical planning,
600            // which also uses the left side names.
601            let base_field = first_schema.field(i).clone();
602
603            // Coerce metadata and nullability across all inputs
604
605            inputs
606                .iter()
607                .enumerate()
608                .map(|(input_idx, input)| {
609                    let field = input.schema().field(i).clone();
610                    let mut metadata = field.metadata().clone();
611
612                    let other_metadatas = inputs
613                        .iter()
614                        .enumerate()
615                        .filter(|(other_idx, _)| *other_idx != input_idx)
616                        .flat_map(|(_, other_input)| {
617                            other_input.schema().field(i).metadata().clone().into_iter()
618                        });
619
620                    metadata.extend(other_metadatas);
621                    field.with_metadata(metadata)
622                })
623                .find_or_first(Field::is_nullable)
624                // We can unwrap this because if inputs was empty, this would've already panic'ed when we
625                // indexed into inputs[0].
626                .unwrap()
627                .with_name(base_field.name())
628        })
629        .collect::<Vec<_>>();
630
631    let all_metadata_merged = inputs
632        .iter()
633        .flat_map(|i| i.schema().metadata().clone().into_iter())
634        .collect();
635
636    Ok(Arc::new(Schema::new_with_metadata(
637        fields,
638        all_metadata_merged,
639    )))
640}
641
642/// CombinedRecordBatchStream can be used to combine a Vec of SendableRecordBatchStreams into one
643struct CombinedRecordBatchStream {
644    /// Schema wrapped by Arc
645    schema: SchemaRef,
646    /// Stream entries
647    entries: Vec<SendableRecordBatchStream>,
648}
649
650impl CombinedRecordBatchStream {
651    /// Create an CombinedRecordBatchStream
652    pub fn new(schema: SchemaRef, entries: Vec<SendableRecordBatchStream>) -> Self {
653        Self { schema, entries }
654    }
655}
656
657impl RecordBatchStream for CombinedRecordBatchStream {
658    fn schema(&self) -> SchemaRef {
659        Arc::clone(&self.schema)
660    }
661}
662
impl Stream for CombinedRecordBatchStream {
    type Item = Result<RecordBatch>;

    /// Polls the remaining input streams, starting at a random index, and
    /// returns the first item (batch or error) that is ready.
    ///
    /// Inputs that report end-of-stream are removed. Returns `Ready(None)`
    /// once no inputs remain, and `Pending` if no input produced an item
    /// during this call.
    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        use Poll::*;

        // Start at a random entry so the same input is not always polled
        // first (presumably to avoid starving later inputs).
        let start = thread_rng_n(self.entries.len() as u32) as usize;
        let mut idx = start;

        // The iteration count is fixed at the entry count when the loop
        // starts; the cursor logic below is meant to poll each of those
        // entries at most once per call.
        for _ in 0..self.entries.len() {
            let stream = self.entries.get_mut(idx).unwrap();

            match Pin::new(stream).poll_next(cx) {
                Ready(Some(val)) => return Ready(Some(val)),
                Ready(None) => {
                    // Remove the entry
                    // (swap_remove moves the last entry into `idx`).
                    self.entries.swap_remove(idx);

                    // Check if this was the last entry, if so the cursor needs
                    // to wrap
                    if idx == self.entries.len() {
                        idx = 0;
                    } else if idx < start && start <= self.entries.len() {
                        // The stream being swapped into the current index has
                        // already been polled, so skip it.
                        idx = idx.wrapping_add(1) % self.entries.len();
                    }
                }
                Pending => {
                    // Advance to the next entry, wrapping around.
                    idx = idx.wrapping_add(1) % self.entries.len();
                }
            }
        }

        // If the map is empty, then the stream is complete.
        // Otherwise every remaining input returned `Pending`; per the `Stream`
        // contract each of them arranged for `cx` to be woken later.
        if self.entries.is_empty() {
            Ready(None)
        } else {
            Pending
        }
    }
}
708
709fn col_stats_union(
710    mut left: ColumnStatistics,
711    right: &ColumnStatistics,
712) -> ColumnStatistics {
713    left.distinct_count = Precision::Absent;
714    left.min_value = left.min_value.min(&right.min_value);
715    left.max_value = left.max_value.max(&right.max_value);
716    left.sum_value = left.sum_value.add(&right.sum_value);
717    left.null_count = left.null_count.add(&right.null_count);
718
719    left
720}
721
722fn stats_union(mut left: Statistics, right: Statistics) -> Statistics {
723    let Statistics {
724        num_rows: right_num_rows,
725        total_byte_size: right_total_bytes,
726        column_statistics: right_column_statistics,
727        ..
728    } = right;
729    left.num_rows = left.num_rows.add(&right_num_rows);
730    left.total_byte_size = left.total_byte_size.add(&right_total_bytes);
731    left.column_statistics = left
732        .column_statistics
733        .into_iter()
734        .zip(right_column_statistics.iter())
735        .map(|(a, b)| col_stats_union(a, b))
736        .collect::<Vec<_>>();
737    left
738}
739
740#[cfg(test)]
741mod tests {
742    use super::*;
743    use crate::collect;
744    use crate::test::{self, TestMemoryExec};
745
746    use arrow::compute::SortOptions;
747    use arrow::datatypes::DataType;
748    use datafusion_common::ScalarValue;
749    use datafusion_physical_expr::equivalence::convert_to_orderings;
750    use datafusion_physical_expr::expressions::col;
751
752    // Generate a schema which consists of 7 columns (a, b, c, d, e, f, g)
753    fn create_test_schema() -> Result<SchemaRef> {
754        let a = Field::new("a", DataType::Int32, true);
755        let b = Field::new("b", DataType::Int32, true);
756        let c = Field::new("c", DataType::Int32, true);
757        let d = Field::new("d", DataType::Int32, true);
758        let e = Field::new("e", DataType::Int32, true);
759        let f = Field::new("f", DataType::Int32, true);
760        let g = Field::new("g", DataType::Int32, true);
761        let schema = Arc::new(Schema::new(vec![a, b, c, d, e, f, g]));
762
763        Ok(schema)
764    }
765
766    #[tokio::test]
767    async fn test_union_partitions() -> Result<()> {
768        let task_ctx = Arc::new(TaskContext::default());
769
770        // Create inputs with different partitioning
771        let csv = test::scan_partitioned(4);
772        let csv2 = test::scan_partitioned(5);
773
774        let union_exec: Arc<dyn ExecutionPlan> = UnionExec::try_new(vec![csv, csv2])?;
775
776        // Should have 9 partitions and 9 output batches
777        assert_eq!(
778            union_exec
779                .properties()
780                .output_partitioning()
781                .partition_count(),
782            9
783        );
784
785        let result: Vec<RecordBatch> = collect(union_exec, task_ctx).await?;
786        assert_eq!(result.len(), 9);
787
788        Ok(())
789    }
790
791    #[tokio::test]
792    async fn test_stats_union() {
793        let left = Statistics {
794            num_rows: Precision::Exact(5),
795            total_byte_size: Precision::Exact(23),
796            column_statistics: vec![
797                ColumnStatistics {
798                    distinct_count: Precision::Exact(5),
799                    max_value: Precision::Exact(ScalarValue::Int64(Some(21))),
800                    min_value: Precision::Exact(ScalarValue::Int64(Some(-4))),
801                    sum_value: Precision::Exact(ScalarValue::Int64(Some(42))),
802                    null_count: Precision::Exact(0),
803                    byte_size: Precision::Absent,
804                },
805                ColumnStatistics {
806                    distinct_count: Precision::Exact(1),
807                    max_value: Precision::Exact(ScalarValue::from("x")),
808                    min_value: Precision::Exact(ScalarValue::from("a")),
809                    sum_value: Precision::Absent,
810                    null_count: Precision::Exact(3),
811                    byte_size: Precision::Absent,
812                },
813                ColumnStatistics {
814                    distinct_count: Precision::Absent,
815                    max_value: Precision::Exact(ScalarValue::Float32(Some(1.1))),
816                    min_value: Precision::Exact(ScalarValue::Float32(Some(0.1))),
817                    sum_value: Precision::Exact(ScalarValue::Float32(Some(42.0))),
818                    null_count: Precision::Absent,
819                    byte_size: Precision::Absent,
820                },
821            ],
822        };
823
824        let right = Statistics {
825            num_rows: Precision::Exact(7),
826            total_byte_size: Precision::Exact(29),
827            column_statistics: vec![
828                ColumnStatistics {
829                    distinct_count: Precision::Exact(3),
830                    max_value: Precision::Exact(ScalarValue::Int64(Some(34))),
831                    min_value: Precision::Exact(ScalarValue::Int64(Some(1))),
832                    sum_value: Precision::Exact(ScalarValue::Int64(Some(42))),
833                    null_count: Precision::Exact(1),
834                    byte_size: Precision::Absent,
835                },
836                ColumnStatistics {
837                    distinct_count: Precision::Absent,
838                    max_value: Precision::Exact(ScalarValue::from("c")),
839                    min_value: Precision::Exact(ScalarValue::from("b")),
840                    sum_value: Precision::Absent,
841                    null_count: Precision::Absent,
842                    byte_size: Precision::Absent,
843                },
844                ColumnStatistics {
845                    distinct_count: Precision::Absent,
846                    max_value: Precision::Absent,
847                    min_value: Precision::Absent,
848                    sum_value: Precision::Absent,
849                    null_count: Precision::Absent,
850                    byte_size: Precision::Absent,
851                },
852            ],
853        };
854
855        let result = stats_union(left, right);
856        let expected = Statistics {
857            num_rows: Precision::Exact(12),
858            total_byte_size: Precision::Exact(52),
859            column_statistics: vec![
860                ColumnStatistics {
861                    distinct_count: Precision::Absent,
862                    max_value: Precision::Exact(ScalarValue::Int64(Some(34))),
863                    min_value: Precision::Exact(ScalarValue::Int64(Some(-4))),
864                    sum_value: Precision::Exact(ScalarValue::Int64(Some(84))),
865                    null_count: Precision::Exact(1),
866                    byte_size: Precision::Absent,
867                },
868                ColumnStatistics {
869                    distinct_count: Precision::Absent,
870                    max_value: Precision::Exact(ScalarValue::from("x")),
871                    min_value: Precision::Exact(ScalarValue::from("a")),
872                    sum_value: Precision::Absent,
873                    null_count: Precision::Absent,
874                    byte_size: Precision::Absent,
875                },
876                ColumnStatistics {
877                    distinct_count: Precision::Absent,
878                    max_value: Precision::Absent,
879                    min_value: Precision::Absent,
880                    sum_value: Precision::Absent,
881                    null_count: Precision::Absent,
882                    byte_size: Precision::Absent,
883                },
884            ],
885        };
886
887        assert_eq!(result, expected);
888    }
889
890    #[tokio::test]
891    async fn test_union_equivalence_properties() -> Result<()> {
892        let schema = create_test_schema()?;
893        let col_a = &col("a", &schema)?;
894        let col_b = &col("b", &schema)?;
895        let col_c = &col("c", &schema)?;
896        let col_d = &col("d", &schema)?;
897        let col_e = &col("e", &schema)?;
898        let col_f = &col("f", &schema)?;
899        let options = SortOptions::default();
900        let test_cases = [
901            //-----------TEST CASE 1----------//
902            (
903                // First child orderings
904                vec![
905                    // [a ASC, b ASC, f ASC]
906                    vec![(col_a, options), (col_b, options), (col_f, options)],
907                ],
908                // Second child orderings
909                vec![
910                    // [a ASC, b ASC, c ASC]
911                    vec![(col_a, options), (col_b, options), (col_c, options)],
912                    // [a ASC, b ASC, f ASC]
913                    vec![(col_a, options), (col_b, options), (col_f, options)],
914                ],
915                // Union output orderings
916                vec![
917                    // [a ASC, b ASC, f ASC]
918                    vec![(col_a, options), (col_b, options), (col_f, options)],
919                ],
920            ),
921            //-----------TEST CASE 2----------//
922            (
923                // First child orderings
924                vec![
925                    // [a ASC, b ASC, f ASC]
926                    vec![(col_a, options), (col_b, options), (col_f, options)],
927                    // d ASC
928                    vec![(col_d, options)],
929                ],
930                // Second child orderings
931                vec![
932                    // [a ASC, b ASC, c ASC]
933                    vec![(col_a, options), (col_b, options), (col_c, options)],
934                    // [e ASC]
935                    vec![(col_e, options)],
936                ],
937                // Union output orderings
938                vec![
939                    // [a ASC, b ASC]
940                    vec![(col_a, options), (col_b, options)],
941                ],
942            ),
943        ];
944
945        for (
946            test_idx,
947            (first_child_orderings, second_child_orderings, union_orderings),
948        ) in test_cases.iter().enumerate()
949        {
950            let first_orderings = convert_to_orderings(first_child_orderings);
951            let second_orderings = convert_to_orderings(second_child_orderings);
952            let union_expected_orderings = convert_to_orderings(union_orderings);
953            let child1_exec = TestMemoryExec::try_new(&[], Arc::clone(&schema), None)?
954                .try_with_sort_information(first_orderings)?;
955            let child1 = Arc::new(child1_exec);
956            let child1 = Arc::new(TestMemoryExec::update_cache(&child1));
957            let child2_exec = TestMemoryExec::try_new(&[], Arc::clone(&schema), None)?
958                .try_with_sort_information(second_orderings)?;
959            let child2 = Arc::new(child2_exec);
960            let child2 = Arc::new(TestMemoryExec::update_cache(&child2));
961
962            let mut union_expected_eq = EquivalenceProperties::new(Arc::clone(&schema));
963            union_expected_eq.add_orderings(union_expected_orderings);
964
965            let union: Arc<dyn ExecutionPlan> = UnionExec::try_new(vec![child1, child2])?;
966            let union_eq_properties = union.properties().equivalence_properties();
967            let err_msg = format!(
968                "Error in test id: {:?}, test case: {:?}",
969                test_idx, test_cases[test_idx]
970            );
971            assert_eq_properties_same(union_eq_properties, &union_expected_eq, err_msg);
972        }
973        Ok(())
974    }
975
976    fn assert_eq_properties_same(
977        lhs: &EquivalenceProperties,
978        rhs: &EquivalenceProperties,
979        err_msg: String,
980    ) {
981        // Check whether orderings are same.
982        let lhs_orderings = lhs.oeq_class();
983        let rhs_orderings = rhs.oeq_class();
984        assert_eq!(lhs_orderings.len(), rhs_orderings.len(), "{err_msg}");
985        for rhs_ordering in rhs_orderings.iter() {
986            assert!(lhs_orderings.contains(rhs_ordering), "{}", err_msg);
987        }
988    }
989
990    #[test]
991    fn test_union_empty_inputs() {
992        // Test that UnionExec::try_new fails with empty inputs
993        let result = UnionExec::try_new(vec![]);
994        assert!(
995            result
996                .unwrap_err()
997                .to_string()
998                .contains("UnionExec requires at least one input")
999        );
1000    }
1001
1002    #[test]
1003    fn test_union_schema_empty_inputs() {
1004        // Test that union_schema fails with empty inputs
1005        let result = union_schema(&[]);
1006        assert!(
1007            result
1008                .unwrap_err()
1009                .to_string()
1010                .contains("Cannot create union schema from empty inputs")
1011        );
1012    }
1013
1014    #[test]
1015    fn test_union_single_input() -> Result<()> {
1016        // Test that UnionExec::try_new returns the single input directly
1017        let schema = create_test_schema()?;
1018        let memory_exec: Arc<dyn ExecutionPlan> =
1019            Arc::new(TestMemoryExec::try_new(&[], Arc::clone(&schema), None)?);
1020        let memory_exec_clone = Arc::clone(&memory_exec);
1021        let result = UnionExec::try_new(vec![memory_exec])?;
1022
1023        // Check that the result is the same as the input (no UnionExec wrapper)
1024        assert_eq!(result.schema(), schema);
1025        // Verify it's the same execution plan
1026        assert!(Arc::ptr_eq(&result, &memory_exec_clone));
1027
1028        Ok(())
1029    }
1030
1031    #[test]
1032    fn test_union_schema_multiple_inputs() -> Result<()> {
1033        // Test that existing functionality with multiple inputs still works
1034        let schema = create_test_schema()?;
1035        let memory_exec1 =
1036            Arc::new(TestMemoryExec::try_new(&[], Arc::clone(&schema), None)?);
1037        let memory_exec2 =
1038            Arc::new(TestMemoryExec::try_new(&[], Arc::clone(&schema), None)?);
1039
1040        let union_plan = UnionExec::try_new(vec![memory_exec1, memory_exec2])?;
1041
1042        // Downcast to verify it's a UnionExec
1043        let union = union_plan
1044            .as_any()
1045            .downcast_ref::<UnionExec>()
1046            .expect("Expected UnionExec");
1047
1048        // Check that schema is correct
1049        assert_eq!(union.schema(), schema);
1050        // Check that we have 2 inputs
1051        assert_eq!(union.inputs().len(), 2);
1052
1053        Ok(())
1054    }
1055}