datafusion_physical_plan/
union.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18// Some of these functions reference the Postgres documentation
19// or implementation to ensure compatibility and are subject to
20// the Postgres license.
21
22//! The Union operator combines multiple inputs with the same schema
23
24use std::borrow::Borrow;
25use std::pin::Pin;
26use std::task::{Context, Poll};
27use std::{any::Any, sync::Arc};
28
29use super::{
30    metrics::{ExecutionPlanMetricsSet, MetricsSet},
31    ColumnStatistics, DisplayAs, DisplayFormatType, ExecutionPlan,
32    ExecutionPlanProperties, Partitioning, PlanProperties, RecordBatchStream,
33    SendableRecordBatchStream, Statistics,
34};
35use crate::execution_plan::{
36    boundedness_from_children, check_default_invariants, emission_type_from_children,
37    InvariantLevel,
38};
39use crate::filter_pushdown::{FilterDescription, FilterPushdownPhase};
40use crate::metrics::BaselineMetrics;
41use crate::projection::{make_with_child, ProjectionExec};
42use crate::stream::ObservedStream;
43
44use arrow::datatypes::{Field, Schema, SchemaRef};
45use arrow::record_batch::RecordBatch;
46use datafusion_common::config::ConfigOptions;
47use datafusion_common::stats::Precision;
48use datafusion_common::{exec_err, internal_datafusion_err, internal_err, Result};
49use datafusion_execution::TaskContext;
50use datafusion_physical_expr::{calculate_union, EquivalenceProperties, PhysicalExpr};
51
52use futures::Stream;
53use itertools::Itertools;
54use log::{debug, trace, warn};
55use tokio::macros::support::thread_rng_n;
56
57/// `UnionExec`: `UNION ALL` execution plan.
58///
59/// `UnionExec` combines multiple inputs with the same schema by
60/// concatenating the partitions.  It does not mix or copy data within
61/// or across partitions. Thus if the input partitions are sorted, the
62/// output partitions of the union are also sorted.
63///
64/// For example, given a `UnionExec` of two inputs, with `N`
65/// partitions, and `M` partitions, there will be `N+M` output
66/// partitions. The first `N` output partitions are from Input 1
/// partitions, and the next `M` output partitions are from Input 2.
68///
69/// ```text
70///                        ▲       ▲           ▲         ▲
71///                        │       │           │         │
72///      Output            │  ...  │           │         │
73///    Partitions          │0      │N-1        │ N       │N+M-1
74/// (passes through   ┌────┴───────┴───────────┴─────────┴───┐
75///  the N+M input    │              UnionExec               │
76///   partitions)     │                                      │
77///                   └──────────────────────────────────────┘
78///                                      ▲
79///                                      │
80///                                      │
81///       Input           ┌────────┬─────┴────┬──────────┐
82///     Partitions        │ ...    │          │     ...  │
83///                    0  │        │ N-1      │ 0        │  M-1
84///                  ┌────┴────────┴───┐  ┌───┴──────────┴───┐
85///                  │                 │  │                  │
86///                  │                 │  │                  │
87///                  │                 │  │                  │
88///                  │                 │  │                  │
89///                  │                 │  │                  │
90///                  │                 │  │                  │
91///                  │Input 1          │  │Input 2           │
92///                  └─────────────────┘  └──────────────────┘
93/// ```
#[derive(Debug, Clone)]
pub struct UnionExec {
    /// Input execution plans. Their partitions are concatenated, in this
    /// order, to form the output partitions of the union.
    inputs: Vec<Arc<dyn ExecutionPlan>>,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// Cache holding plan properties like equivalences, output partitioning etc.
    cache: PlanProperties,
}
103
104impl UnionExec {
105    /// Create a new UnionExec
106    #[deprecated(since = "44.0.0", note = "Use UnionExec::try_new instead")]
107    pub fn new(inputs: Vec<Arc<dyn ExecutionPlan>>) -> Self {
108        let schema =
109            union_schema(&inputs).expect("UnionExec::new called with empty inputs");
110        // The schema of the inputs and the union schema is consistent when:
111        // - They have the same number of fields, and
112        // - Their fields have same types at the same indices.
113        // Here, we know that schemas are consistent and the call below can
114        // not return an error.
115        let cache = Self::compute_properties(&inputs, schema).unwrap();
116        UnionExec {
117            inputs,
118            metrics: ExecutionPlanMetricsSet::new(),
119            cache,
120        }
121    }
122
123    /// Try to create a new UnionExec.
124    ///
125    /// # Errors
126    /// Returns an error if:
127    /// - `inputs` is empty
128    ///
129    /// # Optimization
130    /// If there is only one input, returns that input directly rather than wrapping it in a UnionExec
131    pub fn try_new(
132        inputs: Vec<Arc<dyn ExecutionPlan>>,
133    ) -> Result<Arc<dyn ExecutionPlan>> {
134        match inputs.len() {
135            0 => exec_err!("UnionExec requires at least one input"),
136            1 => Ok(inputs.into_iter().next().unwrap()),
137            _ => {
138                let schema = union_schema(&inputs)?;
139                // The schema of the inputs and the union schema is consistent when:
140                // - They have the same number of fields, and
141                // - Their fields have same types at the same indices.
142                // Here, we know that schemas are consistent and the call below can
143                // not return an error.
144                let cache = Self::compute_properties(&inputs, schema).unwrap();
145                Ok(Arc::new(UnionExec {
146                    inputs,
147                    metrics: ExecutionPlanMetricsSet::new(),
148                    cache,
149                }))
150            }
151        }
152    }
153
154    /// Get inputs of the execution plan
155    pub fn inputs(&self) -> &Vec<Arc<dyn ExecutionPlan>> {
156        &self.inputs
157    }
158
159    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
160    fn compute_properties(
161        inputs: &[Arc<dyn ExecutionPlan>],
162        schema: SchemaRef,
163    ) -> Result<PlanProperties> {
164        // Calculate equivalence properties:
165        let children_eqps = inputs
166            .iter()
167            .map(|child| child.equivalence_properties().clone())
168            .collect::<Vec<_>>();
169        let eq_properties = calculate_union(children_eqps, schema)?;
170
171        // Calculate output partitioning; i.e. sum output partitions of the inputs.
172        let num_partitions = inputs
173            .iter()
174            .map(|plan| plan.output_partitioning().partition_count())
175            .sum();
176        let output_partitioning = Partitioning::UnknownPartitioning(num_partitions);
177        Ok(PlanProperties::new(
178            eq_properties,
179            output_partitioning,
180            emission_type_from_children(inputs),
181            boundedness_from_children(inputs),
182        ))
183    }
184}
185
186impl DisplayAs for UnionExec {
187    fn fmt_as(
188        &self,
189        t: DisplayFormatType,
190        f: &mut std::fmt::Formatter,
191    ) -> std::fmt::Result {
192        match t {
193            DisplayFormatType::Default | DisplayFormatType::Verbose => {
194                write!(f, "UnionExec")
195            }
196            DisplayFormatType::TreeRender => Ok(()),
197        }
198    }
199}
200
201impl ExecutionPlan for UnionExec {
202    fn name(&self) -> &'static str {
203        "UnionExec"
204    }
205
206    /// Return a reference to Any that can be used for downcasting
207    fn as_any(&self) -> &dyn Any {
208        self
209    }
210
211    fn properties(&self) -> &PlanProperties {
212        &self.cache
213    }
214
215    fn check_invariants(&self, check: InvariantLevel) -> Result<()> {
216        check_default_invariants(self, check)?;
217
218        (self.inputs().len() >= 2).then_some(()).ok_or_else(|| {
219            internal_datafusion_err!("UnionExec should have at least 2 children")
220        })
221    }
222
223    fn maintains_input_order(&self) -> Vec<bool> {
224        // If the Union has an output ordering, it maintains at least one
225        // child's ordering (i.e. the meet).
226        // For instance, assume that the first child is SortExpr('a','b','c'),
227        // the second child is SortExpr('a','b') and the third child is
228        // SortExpr('a','b'). The output ordering would be SortExpr('a','b'),
229        // which is the "meet" of all input orderings. In this example, this
230        // function will return vec![false, true, true], indicating that we
231        // preserve the orderings for the 2nd and the 3rd children.
232        if let Some(output_ordering) = self.properties().output_ordering() {
233            self.inputs()
234                .iter()
235                .map(|child| {
236                    if let Some(child_ordering) = child.output_ordering() {
237                        output_ordering.len() == child_ordering.len()
238                    } else {
239                        false
240                    }
241                })
242                .collect()
243        } else {
244            vec![false; self.inputs().len()]
245        }
246    }
247
248    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
249        vec![false; self.children().len()]
250    }
251
252    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
253        self.inputs.iter().collect()
254    }
255
256    fn with_new_children(
257        self: Arc<Self>,
258        children: Vec<Arc<dyn ExecutionPlan>>,
259    ) -> Result<Arc<dyn ExecutionPlan>> {
260        UnionExec::try_new(children)
261    }
262
263    fn execute(
264        &self,
265        mut partition: usize,
266        context: Arc<TaskContext>,
267    ) -> Result<SendableRecordBatchStream> {
268        trace!("Start UnionExec::execute for partition {} of context session_id {} and task_id {:?}", partition, context.session_id(), context.task_id());
269        let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
270        // record the tiny amount of work done in this function so
271        // elapsed_compute is reported as non zero
272        let elapsed_compute = baseline_metrics.elapsed_compute().clone();
273        let _timer = elapsed_compute.timer(); // record on drop
274
275        // find partition to execute
276        for input in self.inputs.iter() {
277            // Calculate whether partition belongs to the current partition
278            if partition < input.output_partitioning().partition_count() {
279                let stream = input.execute(partition, context)?;
280                debug!("Found a Union partition to execute");
281                return Ok(Box::pin(ObservedStream::new(
282                    stream,
283                    baseline_metrics,
284                    None,
285                )));
286            } else {
287                partition -= input.output_partitioning().partition_count();
288            }
289        }
290
291        warn!("Error in Union: Partition {partition} not found");
292
293        exec_err!("Partition {partition} not found in Union")
294    }
295
296    fn metrics(&self) -> Option<MetricsSet> {
297        Some(self.metrics.clone_inner())
298    }
299
300    fn statistics(&self) -> Result<Statistics> {
301        self.partition_statistics(None)
302    }
303
304    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
305        if let Some(partition_idx) = partition {
306            // For a specific partition, find which input it belongs to
307            let mut remaining_idx = partition_idx;
308            for input in &self.inputs {
309                let input_partition_count = input.output_partitioning().partition_count();
310                if remaining_idx < input_partition_count {
311                    // This partition belongs to this input
312                    return input.partition_statistics(Some(remaining_idx));
313                }
314                remaining_idx -= input_partition_count;
315            }
316            // If we get here, the partition index is out of bounds
317            Ok(Statistics::new_unknown(&self.schema()))
318        } else {
319            // Collect statistics from all inputs
320            let stats = self
321                .inputs
322                .iter()
323                .map(|input_exec| input_exec.partition_statistics(None))
324                .collect::<Result<Vec<_>>>()?;
325
326            Ok(stats
327                .into_iter()
328                .reduce(stats_union)
329                .unwrap_or_else(|| Statistics::new_unknown(&self.schema())))
330        }
331    }
332
333    fn supports_limit_pushdown(&self) -> bool {
334        true
335    }
336
337    /// Tries to push `projection` down through `union`. If possible, performs the
338    /// pushdown and returns a new [`UnionExec`] as the top plan which has projections
339    /// as its children. Otherwise, returns `None`.
340    fn try_swapping_with_projection(
341        &self,
342        projection: &ProjectionExec,
343    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
344        // If the projection doesn't narrow the schema, we shouldn't try to push it down.
345        if projection.expr().len() >= projection.input().schema().fields().len() {
346            return Ok(None);
347        }
348
349        let new_children = self
350            .children()
351            .into_iter()
352            .map(|child| make_with_child(projection, child))
353            .collect::<Result<Vec<_>>>()?;
354
355        Ok(Some(UnionExec::try_new(new_children.clone())?))
356    }
357
358    fn gather_filters_for_pushdown(
359        &self,
360        _phase: FilterPushdownPhase,
361        parent_filters: Vec<Arc<dyn PhysicalExpr>>,
362        _config: &ConfigOptions,
363    ) -> Result<FilterDescription> {
364        FilterDescription::from_children(parent_filters, &self.children())
365    }
366}
367
368/// Combines multiple input streams by interleaving them.
369///
370/// This only works if all inputs have the same hash-partitioning.
371///
372/// # Data Flow
373/// ```text
374/// +---------+
375/// |         |---+
376/// | Input 1 |   |
377/// |         |-------------+
378/// +---------+   |         |
379///               |         |         +---------+
380///               +------------------>|         |
381///                 +---------------->| Combine |-->
382///                 | +-------------->|         |
383///                 | |     |         +---------+
384/// +---------+     | |     |
385/// |         |-----+ |     |
386/// | Input 2 |       |     |
387/// |         |---------------+
388/// +---------+       |     | |
389///                   |     | |       +---------+
390///                   |     +-------->|         |
391///                   |       +------>| Combine |-->
392///                   |         +---->|         |
393///                   |         |     +---------+
394/// +---------+       |         |
395/// |         |-------+         |
396/// | Input 3 |                 |
397/// |         |-----------------+
398/// +---------+
399/// ```
#[derive(Debug, Clone)]
pub struct InterleaveExec {
    /// Input execution plans. `try_new` only accepts them when they all report
    /// the same hash partitioning.
    inputs: Vec<Arc<dyn ExecutionPlan>>,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// Cache holding plan properties like equivalences, output partitioning etc.
    cache: PlanProperties,
}
409
410impl InterleaveExec {
411    /// Create a new InterleaveExec
412    pub fn try_new(inputs: Vec<Arc<dyn ExecutionPlan>>) -> Result<Self> {
413        if !can_interleave(inputs.iter()) {
414            return internal_err!(
415                "Not all InterleaveExec children have a consistent hash partitioning"
416            );
417        }
418        let cache = Self::compute_properties(&inputs)?;
419        Ok(InterleaveExec {
420            inputs,
421            metrics: ExecutionPlanMetricsSet::new(),
422            cache,
423        })
424    }
425
426    /// Get inputs of the execution plan
427    pub fn inputs(&self) -> &Vec<Arc<dyn ExecutionPlan>> {
428        &self.inputs
429    }
430
431    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
432    fn compute_properties(inputs: &[Arc<dyn ExecutionPlan>]) -> Result<PlanProperties> {
433        let schema = union_schema(inputs)?;
434        let eq_properties = EquivalenceProperties::new(schema);
435        // Get output partitioning:
436        let output_partitioning = inputs[0].output_partitioning().clone();
437        Ok(PlanProperties::new(
438            eq_properties,
439            output_partitioning,
440            emission_type_from_children(inputs),
441            boundedness_from_children(inputs),
442        ))
443    }
444}
445
446impl DisplayAs for InterleaveExec {
447    fn fmt_as(
448        &self,
449        t: DisplayFormatType,
450        f: &mut std::fmt::Formatter,
451    ) -> std::fmt::Result {
452        match t {
453            DisplayFormatType::Default | DisplayFormatType::Verbose => {
454                write!(f, "InterleaveExec")
455            }
456            DisplayFormatType::TreeRender => Ok(()),
457        }
458    }
459}
460
461impl ExecutionPlan for InterleaveExec {
462    fn name(&self) -> &'static str {
463        "InterleaveExec"
464    }
465
466    /// Return a reference to Any that can be used for downcasting
467    fn as_any(&self) -> &dyn Any {
468        self
469    }
470
471    fn properties(&self) -> &PlanProperties {
472        &self.cache
473    }
474
475    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
476        self.inputs.iter().collect()
477    }
478
479    fn maintains_input_order(&self) -> Vec<bool> {
480        vec![false; self.inputs().len()]
481    }
482
483    fn with_new_children(
484        self: Arc<Self>,
485        children: Vec<Arc<dyn ExecutionPlan>>,
486    ) -> Result<Arc<dyn ExecutionPlan>> {
487        // New children are no longer interleavable, which might be a bug of optimization rewrite.
488        if !can_interleave(children.iter()) {
489            return internal_err!(
490                "Can not create InterleaveExec: new children can not be interleaved"
491            );
492        }
493        Ok(Arc::new(InterleaveExec::try_new(children)?))
494    }
495
496    fn execute(
497        &self,
498        partition: usize,
499        context: Arc<TaskContext>,
500    ) -> Result<SendableRecordBatchStream> {
501        trace!("Start InterleaveExec::execute for partition {} of context session_id {} and task_id {:?}", partition, context.session_id(), context.task_id());
502        let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
503        // record the tiny amount of work done in this function so
504        // elapsed_compute is reported as non zero
505        let elapsed_compute = baseline_metrics.elapsed_compute().clone();
506        let _timer = elapsed_compute.timer(); // record on drop
507
508        let mut input_stream_vec = vec![];
509        for input in self.inputs.iter() {
510            if partition < input.output_partitioning().partition_count() {
511                input_stream_vec.push(input.execute(partition, Arc::clone(&context))?);
512            } else {
513                // Do not find a partition to execute
514                break;
515            }
516        }
517        if input_stream_vec.len() == self.inputs.len() {
518            let stream = Box::pin(CombinedRecordBatchStream::new(
519                self.schema(),
520                input_stream_vec,
521            ));
522            return Ok(Box::pin(ObservedStream::new(
523                stream,
524                baseline_metrics,
525                None,
526            )));
527        }
528
529        warn!("Error in InterleaveExec: Partition {partition} not found");
530
531        exec_err!("Partition {partition} not found in InterleaveExec")
532    }
533
534    fn metrics(&self) -> Option<MetricsSet> {
535        Some(self.metrics.clone_inner())
536    }
537
538    fn statistics(&self) -> Result<Statistics> {
539        self.partition_statistics(None)
540    }
541
542    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
543        let stats = self
544            .inputs
545            .iter()
546            .map(|stat| stat.partition_statistics(partition))
547            .collect::<Result<Vec<_>>>()?;
548
549        Ok(stats
550            .into_iter()
551            .reduce(stats_union)
552            .unwrap_or_else(|| Statistics::new_unknown(&self.schema())))
553    }
554
555    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
556        vec![false; self.children().len()]
557    }
558}
559
560/// If all the input partitions have the same Hash partition spec with the first_input_partition
561/// The InterleaveExec is partition aware.
562///
563/// It might be too strict here in the case that the input partition specs are compatible but not exactly the same.
564/// For example one input partition has the partition spec Hash('a','b','c') and
565/// other has the partition spec Hash('a'), It is safe to derive the out partition with the spec Hash('a','b','c').
566pub fn can_interleave<T: Borrow<Arc<dyn ExecutionPlan>>>(
567    mut inputs: impl Iterator<Item = T>,
568) -> bool {
569    let Some(first) = inputs.next() else {
570        return false;
571    };
572
573    let reference = first.borrow().output_partitioning();
574    matches!(reference, Partitioning::Hash(_, _))
575        && inputs
576            .map(|plan| plan.borrow().output_partitioning().clone())
577            .all(|partition| partition == *reference)
578}
579
580fn union_schema(inputs: &[Arc<dyn ExecutionPlan>]) -> Result<SchemaRef> {
581    if inputs.is_empty() {
582        return exec_err!("Cannot create union schema from empty inputs");
583    }
584
585    let first_schema = inputs[0].schema();
586
587    let fields = (0..first_schema.fields().len())
588        .map(|i| {
589            // We take the name from the left side of the union to match how names are coerced during logical planning,
590            // which also uses the left side names.
591            let base_field = first_schema.field(i).clone();
592
593            // Coerce metadata and nullability across all inputs
594            let merged_field = inputs
595                .iter()
596                .enumerate()
597                .map(|(input_idx, input)| {
598                    let field = input.schema().field(i).clone();
599                    let mut metadata = field.metadata().clone();
600
601                    let other_metadatas = inputs
602                        .iter()
603                        .enumerate()
604                        .filter(|(other_idx, _)| *other_idx != input_idx)
605                        .flat_map(|(_, other_input)| {
606                            other_input.schema().field(i).metadata().clone().into_iter()
607                        });
608
609                    metadata.extend(other_metadatas);
610                    field.with_metadata(metadata)
611                })
612                .find_or_first(Field::is_nullable)
613                // We can unwrap this because if inputs was empty, this would've already panic'ed when we
614                // indexed into inputs[0].
615                .unwrap()
616                .with_name(base_field.name());
617
618            merged_field
619        })
620        .collect::<Vec<_>>();
621
622    let all_metadata_merged = inputs
623        .iter()
624        .flat_map(|i| i.schema().metadata().clone().into_iter())
625        .collect();
626
627    Ok(Arc::new(Schema::new_with_metadata(
628        fields,
629        all_metadata_merged,
630    )))
631}
632
/// CombinedRecordBatchStream can be used to combine a Vec of SendableRecordBatchStreams into one
///
/// Each poll yields the next item produced by any of the entries; entries
/// that are exhausted are dropped, and the combined stream ends once no
/// entries remain.
struct CombinedRecordBatchStream {
    /// Schema reported for the combined stream (shared via `Arc`)
    schema: SchemaRef,
    /// The underlying streams that have not yet reported end-of-stream
    entries: Vec<SendableRecordBatchStream>,
}
640
641impl CombinedRecordBatchStream {
642    /// Create an CombinedRecordBatchStream
643    pub fn new(schema: SchemaRef, entries: Vec<SendableRecordBatchStream>) -> Self {
644        Self { schema, entries }
645    }
646}
647
648impl RecordBatchStream for CombinedRecordBatchStream {
649    fn schema(&self) -> SchemaRef {
650        Arc::clone(&self.schema)
651    }
652}
653
impl Stream for CombinedRecordBatchStream {
    type Item = Result<RecordBatch>;

    /// Polls the entries starting from a randomly chosen index, polling each
    /// entry present at the start of the call at most once.
    ///
    /// Returns the first item any entry yields. Entries that report
    /// end-of-stream are removed with `swap_remove`. Returns `Ready(None)` once
    /// no entries remain and `Pending` otherwise.
    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        use Poll::*;

        // Random starting point so no entry is systematically polled first.
        // NOTE(review): `thread_rng_n` is imported from
        // `tokio::macros::support`, which does not look like a documented
        // public tokio API — confirm it remains available across upgrades.
        let start = thread_rng_n(self.entries.len() as u32) as usize;
        let mut idx = start;

        // The iteration count is fixed at the entry count when the loop
        // starts; each iteration removes at most one entry, so `entries` is
        // non-empty whenever it is indexed below.
        for _ in 0..self.entries.len() {
            let stream = self.entries.get_mut(idx).unwrap();

            match Pin::new(stream).poll_next(cx) {
                Ready(Some(val)) => return Ready(Some(val)),
                Ready(None) => {
                    // Remove the entry
                    self.entries.swap_remove(idx);

                    // Check if this was the last entry, if so the cursor needs
                    // to wrap
                    if idx == self.entries.len() {
                        idx = 0;
                    } else if idx < start && start <= self.entries.len() {
                        // The stream being swapped into the current index has
                        // already been polled, so skip it.
                        idx = idx.wrapping_add(1) % self.entries.len();
                    }
                }
                Pending => {
                    idx = idx.wrapping_add(1) % self.entries.len();
                }
            }
        }

        // If the map is empty, then the stream is complete.
        if self.entries.is_empty() {
            Ready(None)
        } else {
            Pending
        }
    }
}
699
700fn col_stats_union(
701    mut left: ColumnStatistics,
702    right: ColumnStatistics,
703) -> ColumnStatistics {
704    left.distinct_count = Precision::Absent;
705    left.min_value = left.min_value.min(&right.min_value);
706    left.max_value = left.max_value.max(&right.max_value);
707    left.sum_value = left.sum_value.add(&right.sum_value);
708    left.null_count = left.null_count.add(&right.null_count);
709
710    left
711}
712
713fn stats_union(mut left: Statistics, right: Statistics) -> Statistics {
714    left.num_rows = left.num_rows.add(&right.num_rows);
715    left.total_byte_size = left.total_byte_size.add(&right.total_byte_size);
716    left.column_statistics = left
717        .column_statistics
718        .into_iter()
719        .zip(right.column_statistics)
720        .map(|(a, b)| col_stats_union(a, b))
721        .collect::<Vec<_>>();
722    left
723}
724
725#[cfg(test)]
726mod tests {
727    use super::*;
728    use crate::collect;
729    use crate::test::{self, TestMemoryExec};
730
731    use arrow::compute::SortOptions;
732    use arrow::datatypes::DataType;
733    use datafusion_common::ScalarValue;
734    use datafusion_physical_expr::equivalence::convert_to_orderings;
735    use datafusion_physical_expr::expressions::col;
736
737    // Generate a schema which consists of 7 columns (a, b, c, d, e, f, g)
738    fn create_test_schema() -> Result<SchemaRef> {
739        let a = Field::new("a", DataType::Int32, true);
740        let b = Field::new("b", DataType::Int32, true);
741        let c = Field::new("c", DataType::Int32, true);
742        let d = Field::new("d", DataType::Int32, true);
743        let e = Field::new("e", DataType::Int32, true);
744        let f = Field::new("f", DataType::Int32, true);
745        let g = Field::new("g", DataType::Int32, true);
746        let schema = Arc::new(Schema::new(vec![a, b, c, d, e, f, g]));
747
748        Ok(schema)
749    }
750
751    #[tokio::test]
752    async fn test_union_partitions() -> Result<()> {
753        let task_ctx = Arc::new(TaskContext::default());
754
755        // Create inputs with different partitioning
756        let csv = test::scan_partitioned(4);
757        let csv2 = test::scan_partitioned(5);
758
759        let union_exec: Arc<dyn ExecutionPlan> = UnionExec::try_new(vec![csv, csv2])?;
760
761        // Should have 9 partitions and 9 output batches
762        assert_eq!(
763            union_exec
764                .properties()
765                .output_partitioning()
766                .partition_count(),
767            9
768        );
769
770        let result: Vec<RecordBatch> = collect(union_exec, task_ctx).await?;
771        assert_eq!(result.len(), 9);
772
773        Ok(())
774    }
775
776    #[tokio::test]
777    async fn test_stats_union() {
778        let left = Statistics {
779            num_rows: Precision::Exact(5),
780            total_byte_size: Precision::Exact(23),
781            column_statistics: vec![
782                ColumnStatistics {
783                    distinct_count: Precision::Exact(5),
784                    max_value: Precision::Exact(ScalarValue::Int64(Some(21))),
785                    min_value: Precision::Exact(ScalarValue::Int64(Some(-4))),
786                    sum_value: Precision::Exact(ScalarValue::Int64(Some(42))),
787                    null_count: Precision::Exact(0),
788                },
789                ColumnStatistics {
790                    distinct_count: Precision::Exact(1),
791                    max_value: Precision::Exact(ScalarValue::from("x")),
792                    min_value: Precision::Exact(ScalarValue::from("a")),
793                    sum_value: Precision::Absent,
794                    null_count: Precision::Exact(3),
795                },
796                ColumnStatistics {
797                    distinct_count: Precision::Absent,
798                    max_value: Precision::Exact(ScalarValue::Float32(Some(1.1))),
799                    min_value: Precision::Exact(ScalarValue::Float32(Some(0.1))),
800                    sum_value: Precision::Exact(ScalarValue::Float32(Some(42.0))),
801                    null_count: Precision::Absent,
802                },
803            ],
804        };
805
806        let right = Statistics {
807            num_rows: Precision::Exact(7),
808            total_byte_size: Precision::Exact(29),
809            column_statistics: vec![
810                ColumnStatistics {
811                    distinct_count: Precision::Exact(3),
812                    max_value: Precision::Exact(ScalarValue::Int64(Some(34))),
813                    min_value: Precision::Exact(ScalarValue::Int64(Some(1))),
814                    sum_value: Precision::Exact(ScalarValue::Int64(Some(42))),
815                    null_count: Precision::Exact(1),
816                },
817                ColumnStatistics {
818                    distinct_count: Precision::Absent,
819                    max_value: Precision::Exact(ScalarValue::from("c")),
820                    min_value: Precision::Exact(ScalarValue::from("b")),
821                    sum_value: Precision::Absent,
822                    null_count: Precision::Absent,
823                },
824                ColumnStatistics {
825                    distinct_count: Precision::Absent,
826                    max_value: Precision::Absent,
827                    min_value: Precision::Absent,
828                    sum_value: Precision::Absent,
829                    null_count: Precision::Absent,
830                },
831            ],
832        };
833
834        let result = stats_union(left, right);
835        let expected = Statistics {
836            num_rows: Precision::Exact(12),
837            total_byte_size: Precision::Exact(52),
838            column_statistics: vec![
839                ColumnStatistics {
840                    distinct_count: Precision::Absent,
841                    max_value: Precision::Exact(ScalarValue::Int64(Some(34))),
842                    min_value: Precision::Exact(ScalarValue::Int64(Some(-4))),
843                    sum_value: Precision::Exact(ScalarValue::Int64(Some(84))),
844                    null_count: Precision::Exact(1),
845                },
846                ColumnStatistics {
847                    distinct_count: Precision::Absent,
848                    max_value: Precision::Exact(ScalarValue::from("x")),
849                    min_value: Precision::Exact(ScalarValue::from("a")),
850                    sum_value: Precision::Absent,
851                    null_count: Precision::Absent,
852                },
853                ColumnStatistics {
854                    distinct_count: Precision::Absent,
855                    max_value: Precision::Absent,
856                    min_value: Precision::Absent,
857                    sum_value: Precision::Absent,
858                    null_count: Precision::Absent,
859                },
860            ],
861        };
862
863        assert_eq!(result, expected);
864    }
865
866    #[tokio::test]
867    async fn test_union_equivalence_properties() -> Result<()> {
868        let schema = create_test_schema()?;
869        let col_a = &col("a", &schema)?;
870        let col_b = &col("b", &schema)?;
871        let col_c = &col("c", &schema)?;
872        let col_d = &col("d", &schema)?;
873        let col_e = &col("e", &schema)?;
874        let col_f = &col("f", &schema)?;
875        let options = SortOptions::default();
876        let test_cases = [
877            //-----------TEST CASE 1----------//
878            (
879                // First child orderings
880                vec![
881                    // [a ASC, b ASC, f ASC]
882                    vec![(col_a, options), (col_b, options), (col_f, options)],
883                ],
884                // Second child orderings
885                vec![
886                    // [a ASC, b ASC, c ASC]
887                    vec![(col_a, options), (col_b, options), (col_c, options)],
888                    // [a ASC, b ASC, f ASC]
889                    vec![(col_a, options), (col_b, options), (col_f, options)],
890                ],
891                // Union output orderings
892                vec![
893                    // [a ASC, b ASC, f ASC]
894                    vec![(col_a, options), (col_b, options), (col_f, options)],
895                ],
896            ),
897            //-----------TEST CASE 2----------//
898            (
899                // First child orderings
900                vec![
901                    // [a ASC, b ASC, f ASC]
902                    vec![(col_a, options), (col_b, options), (col_f, options)],
903                    // d ASC
904                    vec![(col_d, options)],
905                ],
906                // Second child orderings
907                vec![
908                    // [a ASC, b ASC, c ASC]
909                    vec![(col_a, options), (col_b, options), (col_c, options)],
910                    // [e ASC]
911                    vec![(col_e, options)],
912                ],
913                // Union output orderings
914                vec![
915                    // [a ASC, b ASC]
916                    vec![(col_a, options), (col_b, options)],
917                ],
918            ),
919        ];
920
921        for (
922            test_idx,
923            (first_child_orderings, second_child_orderings, union_orderings),
924        ) in test_cases.iter().enumerate()
925        {
926            let first_orderings = convert_to_orderings(first_child_orderings);
927            let second_orderings = convert_to_orderings(second_child_orderings);
928            let union_expected_orderings = convert_to_orderings(union_orderings);
929            let child1 = Arc::new(TestMemoryExec::update_cache(Arc::new(
930                TestMemoryExec::try_new(&[], Arc::clone(&schema), None)?
931                    .try_with_sort_information(first_orderings)?,
932            )));
933            let child2 = Arc::new(TestMemoryExec::update_cache(Arc::new(
934                TestMemoryExec::try_new(&[], Arc::clone(&schema), None)?
935                    .try_with_sort_information(second_orderings)?,
936            )));
937
938            let mut union_expected_eq = EquivalenceProperties::new(Arc::clone(&schema));
939            union_expected_eq.add_orderings(union_expected_orderings);
940
941            let union: Arc<dyn ExecutionPlan> = UnionExec::try_new(vec![child1, child2])?;
942            let union_eq_properties = union.properties().equivalence_properties();
943            let err_msg = format!(
944                "Error in test id: {:?}, test case: {:?}",
945                test_idx, test_cases[test_idx]
946            );
947            assert_eq_properties_same(union_eq_properties, &union_expected_eq, err_msg);
948        }
949        Ok(())
950    }
951
952    fn assert_eq_properties_same(
953        lhs: &EquivalenceProperties,
954        rhs: &EquivalenceProperties,
955        err_msg: String,
956    ) {
957        // Check whether orderings are same.
958        let lhs_orderings = lhs.oeq_class();
959        let rhs_orderings = rhs.oeq_class();
960        assert_eq!(lhs_orderings.len(), rhs_orderings.len(), "{err_msg}");
961        for rhs_ordering in rhs_orderings.iter() {
962            assert!(lhs_orderings.contains(rhs_ordering), "{}", err_msg);
963        }
964    }
965
966    #[test]
967    fn test_union_empty_inputs() {
968        // Test that UnionExec::try_new fails with empty inputs
969        let result = UnionExec::try_new(vec![]);
970        assert!(result
971            .unwrap_err()
972            .to_string()
973            .contains("UnionExec requires at least one input"));
974    }
975
976    #[test]
977    fn test_union_schema_empty_inputs() {
978        // Test that union_schema fails with empty inputs
979        let result = union_schema(&[]);
980        assert!(result
981            .unwrap_err()
982            .to_string()
983            .contains("Cannot create union schema from empty inputs"));
984    }
985
986    #[test]
987    fn test_union_single_input() -> Result<()> {
988        // Test that UnionExec::try_new returns the single input directly
989        let schema = create_test_schema()?;
990        let memory_exec: Arc<dyn ExecutionPlan> =
991            Arc::new(TestMemoryExec::try_new(&[], Arc::clone(&schema), None)?);
992        let memory_exec_clone = Arc::clone(&memory_exec);
993        let result = UnionExec::try_new(vec![memory_exec])?;
994
995        // Check that the result is the same as the input (no UnionExec wrapper)
996        assert_eq!(result.schema(), schema);
997        // Verify it's the same execution plan
998        assert!(Arc::ptr_eq(&result, &memory_exec_clone));
999
1000        Ok(())
1001    }
1002
1003    #[test]
1004    fn test_union_schema_multiple_inputs() -> Result<()> {
1005        // Test that existing functionality with multiple inputs still works
1006        let schema = create_test_schema()?;
1007        let memory_exec1 =
1008            Arc::new(TestMemoryExec::try_new(&[], Arc::clone(&schema), None)?);
1009        let memory_exec2 =
1010            Arc::new(TestMemoryExec::try_new(&[], Arc::clone(&schema), None)?);
1011
1012        let union_plan = UnionExec::try_new(vec![memory_exec1, memory_exec2])?;
1013
1014        // Downcast to verify it's a UnionExec
1015        let union = union_plan
1016            .as_any()
1017            .downcast_ref::<UnionExec>()
1018            .expect("Expected UnionExec");
1019
1020        // Check that schema is correct
1021        assert_eq!(union.schema(), schema);
1022        // Check that we have 2 inputs
1023        assert_eq!(union.inputs().len(), 2);
1024
1025        Ok(())
1026    }
1027}