datafusion_physical_plan/union.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

// Some of these functions reference the Postgres documentation
// or implementation to ensure compatibility and are subject to
// the Postgres license.

//! The Union operator combines multiple inputs with the same schema

use std::borrow::Borrow;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{any::Any, sync::Arc};

use super::{
    metrics::{ExecutionPlanMetricsSet, MetricsSet},
    ColumnStatistics, DisplayAs, DisplayFormatType, ExecutionPlan,
    ExecutionPlanProperties, Partitioning, PlanProperties, RecordBatchStream,
    SendableRecordBatchStream, Statistics,
};
use crate::execution_plan::{
    boundedness_from_children, check_default_invariants, emission_type_from_children,
    InvariantLevel,
};
use crate::metrics::BaselineMetrics;
use crate::projection::{make_with_child, ProjectionExec};
use crate::stream::ObservedStream;

use arrow::datatypes::{Field, Schema, SchemaRef};
use arrow::record_batch::RecordBatch;
use datafusion_common::stats::Precision;
use datafusion_common::{exec_err, internal_err, DataFusionError, Result};
use datafusion_execution::TaskContext;
use datafusion_physical_expr::{calculate_union, EquivalenceProperties};

use futures::Stream;
use itertools::Itertools;
use log::{debug, trace, warn};
use tokio::macros::support::thread_rng_n;

/// `UnionExec`: `UNION ALL` execution plan.
///
/// `UnionExec` combines multiple inputs with the same schema by
/// concatenating the partitions.  It does not mix or copy data within
/// or across partitions. Thus if the input partitions are sorted, the
/// output partitions of the union are also sorted.
///
/// For example, given a `UnionExec` of two inputs, with `N`
/// partitions, and `M` partitions, there will be `N+M` output
/// partitions. The first `N` output partitions are from Input 1
/// partitions, and the next `M` output partitions are from Input 2.
///
/// ```text
///                       ▲       ▲           ▲         ▲
///                       │       │           │         │
///     Output            │  ...  │           │         │
///   Partitions          │0      │N-1        │ N       │N+M-1
///(passes through   ┌────┴───────┴───────────┴─────────┴───┐
/// the N+M input    │              UnionExec               │
///  partitions)     │                                      │
///                  └──────────────────────────────────────┘
///                                      ▲
///                                      │
///                                      │
///       Input           ┌────────┬─────┴────┬──────────┐
///     Partitions        │ ...    │          │     ...  │
///                    0  │        │ N-1      │ 0        │  M-1
///                  ┌────┴────────┴───┐  ┌───┴──────────┴───┐
///                  │                 │  │                  │
///                  │                 │  │                  │
///                  │                 │  │                  │
///                  │                 │  │                  │
///                  │                 │  │                  │
///                  │                 │  │                  │
///                  │Input 1          │  │Input 2           │
///                  └─────────────────┘  └──────────────────┘
/// ```
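///
/// # Example
///
/// A minimal construction sketch (not from the original file) using `EmptyExec`
/// placeholders; any two plans sharing a schema work the same way. Each
/// `EmptyExec` has one partition, so the union exposes two:
///
/// ```
/// # use std::sync::Arc;
/// # use arrow::datatypes::{DataType, Field, Schema};
/// # use datafusion_physical_plan::empty::EmptyExec;
/// # use datafusion_physical_plan::union::UnionExec;
/// # use datafusion_physical_plan::ExecutionPlan;
/// let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)]));
/// let left: Arc<dyn ExecutionPlan> = Arc::new(EmptyExec::new(Arc::clone(&schema)));
/// let right: Arc<dyn ExecutionPlan> = Arc::new(EmptyExec::new(schema));
/// let union = UnionExec::new(vec![left, right]);
/// // The output partition count is the sum of the inputs' partition counts.
/// assert_eq!(union.properties().output_partitioning().partition_count(), 2);
/// ```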
#[derive(Debug, Clone)]
pub struct UnionExec {
    /// Input execution plans
    inputs: Vec<Arc<dyn ExecutionPlan>>,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// Cache holding plan properties like equivalences, output partitioning etc.
    cache: PlanProperties,
}

impl UnionExec {
    /// Create a new UnionExec
    pub fn new(inputs: Vec<Arc<dyn ExecutionPlan>>) -> Self {
        let schema = union_schema(&inputs);
        // The schemas of the inputs and the union schema are consistent when:
        // - They have the same number of fields, and
        // - Their fields have the same types at the same indices.
        // Here, we know that the schemas are consistent, so the call below
        // cannot return an error.
        let cache = Self::compute_properties(&inputs, schema).unwrap();
        UnionExec {
            inputs,
            metrics: ExecutionPlanMetricsSet::new(),
            cache,
        }
    }

    /// Get inputs of the execution plan
    pub fn inputs(&self) -> &Vec<Arc<dyn ExecutionPlan>> {
        &self.inputs
    }

    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
    fn compute_properties(
        inputs: &[Arc<dyn ExecutionPlan>],
        schema: SchemaRef,
    ) -> Result<PlanProperties> {
        // Calculate equivalence properties:
        let children_eqps = inputs
            .iter()
            .map(|child| child.equivalence_properties().clone())
            .collect::<Vec<_>>();
        let eq_properties = calculate_union(children_eqps, schema)?;

        // Calculate output partitioning; i.e. sum output partitions of the inputs.
        let num_partitions = inputs
            .iter()
            .map(|plan| plan.output_partitioning().partition_count())
            .sum();
        let output_partitioning = Partitioning::UnknownPartitioning(num_partitions);
        Ok(PlanProperties::new(
            eq_properties,
            output_partitioning,
            emission_type_from_children(inputs),
            boundedness_from_children(inputs),
        ))
    }
}

impl DisplayAs for UnionExec {
    fn fmt_as(
        &self,
        t: DisplayFormatType,
        f: &mut std::fmt::Formatter,
    ) -> std::fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(f, "UnionExec")
            }
            DisplayFormatType::TreeRender => Ok(()),
        }
    }
}

impl ExecutionPlan for UnionExec {
    fn name(&self) -> &'static str {
        "UnionExec"
    }

    /// Return a reference to Any that can be used for downcasting
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn properties(&self) -> &PlanProperties {
        &self.cache
    }

    fn check_invariants(&self, check: InvariantLevel) -> Result<()> {
        check_default_invariants(self, check)?;

        (self.inputs().len() >= 2)
            .then_some(())
            .ok_or(DataFusionError::Internal(
                "UnionExec should have at least 2 children".into(),
            ))
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        self.inputs.iter().collect()
    }

    fn maintains_input_order(&self) -> Vec<bool> {
        // If the Union has an output ordering, it maintains at least one
        // child's ordering (i.e. the meet).
        // For instance, assume that the first child is SortExpr('a','b','c'),
        // the second child is SortExpr('a','b') and the third child is
        // SortExpr('a','b'). The output ordering would be SortExpr('a','b'),
        // which is the "meet" of all input orderings. In this example, this
        // function will return vec![false, true, true], indicating that we
        // preserve the orderings for the 2nd and the 3rd children.
        if let Some(output_ordering) = self.properties().output_ordering() {
            self.inputs()
                .iter()
                .map(|child| {
                    if let Some(child_ordering) = child.output_ordering() {
                        output_ordering.len() == child_ordering.len()
                    } else {
                        false
                    }
                })
                .collect()
        } else {
            vec![false; self.inputs().len()]
        }
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        Ok(Arc::new(UnionExec::new(children)))
    }

    fn execute(
        &self,
        mut partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        trace!("Start UnionExec::execute for partition {} of context session_id {} and task_id {:?}", partition, context.session_id(), context.task_id());
        let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
        // record the tiny amount of work done in this function so
        // elapsed_compute is reported as non zero
        let elapsed_compute = baseline_metrics.elapsed_compute().clone();
        let _timer = elapsed_compute.timer(); // record on drop

        // find the input that the requested partition belongs to
        for input in self.inputs.iter() {
            // Check whether the partition index falls within the current input
            if partition < input.output_partitioning().partition_count() {
                let stream = input.execute(partition, context)?;
                debug!("Found a Union partition to execute");
                return Ok(Box::pin(ObservedStream::new(
                    stream,
                    baseline_metrics,
                    None,
                )));
            } else {
                partition -= input.output_partitioning().partition_count();
            }
        }

        warn!("Error in Union: Partition {partition} not found");

        exec_err!("Partition {partition} not found in Union")
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    fn statistics(&self) -> Result<Statistics> {
        self.partition_statistics(None)
    }

    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
        if let Some(partition_idx) = partition {
            // For a specific partition, find which input it belongs to
            let mut remaining_idx = partition_idx;
            for input in &self.inputs {
                let input_partition_count = input.output_partitioning().partition_count();
                if remaining_idx < input_partition_count {
                    // This partition belongs to this input
                    return input.partition_statistics(Some(remaining_idx));
                }
                remaining_idx -= input_partition_count;
            }
            // If we get here, the partition index is out of bounds
            Ok(Statistics::new_unknown(&self.schema()))
        } else {
            // Collect statistics from all inputs
            let stats = self
                .inputs
                .iter()
                .map(|input_exec| input_exec.partition_statistics(None))
                .collect::<Result<Vec<_>>>()?;

            Ok(stats
                .into_iter()
                .reduce(stats_union)
                .unwrap_or_else(|| Statistics::new_unknown(&self.schema())))
        }
    }

    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
        vec![false; self.children().len()]
    }

    fn supports_limit_pushdown(&self) -> bool {
        true
    }

    /// Tries to push `projection` down through `union`. If possible, performs the
    /// pushdown and returns a new [`UnionExec`] as the top plan which has projections
    /// as its children. Otherwise, returns `None`.
    fn try_swapping_with_projection(
        &self,
        projection: &ProjectionExec,
    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
        // If the projection doesn't narrow the schema, we shouldn't try to push it down.
        if projection.expr().len() >= projection.input().schema().fields().len() {
            return Ok(None);
        }

        let new_children = self
            .children()
            .into_iter()
            .map(|child| make_with_child(projection, child))
            .collect::<Result<Vec<_>>>()?;

        Ok(Some(Arc::new(UnionExec::new(new_children))))
    }
}

/// Combines multiple input streams by interleaving them.
///
/// This only works if all inputs have the same hash-partitioning.
///
/// # Data Flow
/// ```text
/// +---------+
/// |         |---+
/// | Input 1 |   |
/// |         |-------------+
/// +---------+   |         |
///               |         |         +---------+
///               +------------------>|         |
///                 +---------------->| Combine |-->
///                 | +-------------->|         |
///                 | |     |         +---------+
/// +---------+     | |     |
/// |         |-----+ |     |
/// | Input 2 |       |     |
/// |         |---------------+
/// +---------+       |     | |
///                   |     | |       +---------+
///                   |     +-------->|         |
///                   |       +------>| Combine |-->
///                   |         +---->|         |
///                   |         |     +---------+
/// +---------+       |         |
/// |         |-------+         |
/// | Input 3 |                 |
/// |         |-----------------+
/// +---------+
/// ```
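///
/// # Example
///
/// A minimal sketch (not from the original file): both children are given the
/// same hash partitioning via `RepartitionExec` over placeholder `EmptyExec`
/// inputs, which is the precondition checked by `can_interleave`:
///
/// ```
/// # use std::sync::Arc;
/// # use arrow::datatypes::{DataType, Field, Schema};
/// # use datafusion_physical_expr::expressions::col;
/// # use datafusion_physical_plan::empty::EmptyExec;
/// # use datafusion_physical_plan::repartition::RepartitionExec;
/// # use datafusion_physical_plan::union::InterleaveExec;
/// # use datafusion_physical_plan::{ExecutionPlan, Partitioning};
/// # fn main() -> datafusion_common::Result<()> {
/// let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)]));
/// // Hash-partition both inputs identically so they can be interleaved.
/// let partitioning = Partitioning::Hash(vec![col("a", &schema)?], 4);
/// let left: Arc<dyn ExecutionPlan> = Arc::new(RepartitionExec::try_new(
///     Arc::new(EmptyExec::new(Arc::clone(&schema))),
///     partitioning.clone(),
/// )?);
/// let right: Arc<dyn ExecutionPlan> = Arc::new(RepartitionExec::try_new(
///     Arc::new(EmptyExec::new(schema)),
///     partitioning,
/// )?);
/// let interleave = InterleaveExec::try_new(vec![left, right])?;
/// // The common hash partitioning is preserved: 4 output partitions, not 8.
/// assert_eq!(interleave.properties().output_partitioning().partition_count(), 4);
/// # Ok(())
/// # }
/// ```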
#[derive(Debug, Clone)]
pub struct InterleaveExec {
    /// Input execution plans
    inputs: Vec<Arc<dyn ExecutionPlan>>,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// Cache holding plan properties like equivalences, output partitioning etc.
    cache: PlanProperties,
}

impl InterleaveExec {
    /// Create a new InterleaveExec
    pub fn try_new(inputs: Vec<Arc<dyn ExecutionPlan>>) -> Result<Self> {
        if !can_interleave(inputs.iter()) {
            return internal_err!(
                "Not all InterleaveExec children have a consistent hash partitioning"
            );
        }
        let cache = Self::compute_properties(&inputs);
        Ok(InterleaveExec {
            inputs,
            metrics: ExecutionPlanMetricsSet::new(),
            cache,
        })
    }

    /// Get inputs of the execution plan
    pub fn inputs(&self) -> &Vec<Arc<dyn ExecutionPlan>> {
        &self.inputs
    }

    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
    fn compute_properties(inputs: &[Arc<dyn ExecutionPlan>]) -> PlanProperties {
        let schema = union_schema(inputs);
        let eq_properties = EquivalenceProperties::new(schema);
        // Get output partitioning:
        let output_partitioning = inputs[0].output_partitioning().clone();
        PlanProperties::new(
            eq_properties,
            output_partitioning,
            emission_type_from_children(inputs),
            boundedness_from_children(inputs),
        )
    }
}

impl DisplayAs for InterleaveExec {
    fn fmt_as(
        &self,
        t: DisplayFormatType,
        f: &mut std::fmt::Formatter,
    ) -> std::fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(f, "InterleaveExec")
            }
            DisplayFormatType::TreeRender => Ok(()),
        }
    }
}

impl ExecutionPlan for InterleaveExec {
    fn name(&self) -> &'static str {
        "InterleaveExec"
    }

    /// Return a reference to Any that can be used for downcasting
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn properties(&self) -> &PlanProperties {
        &self.cache
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        self.inputs.iter().collect()
    }

    fn maintains_input_order(&self) -> Vec<bool> {
        vec![false; self.inputs().len()]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        // If the new children can no longer be interleaved, this might indicate
        // a bug in an optimizer rewrite.
        if !can_interleave(children.iter()) {
            return internal_err!(
                "Can not create InterleaveExec: new children can not be interleaved"
            );
        }
        Ok(Arc::new(InterleaveExec::try_new(children)?))
    }

    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        trace!("Start InterleaveExec::execute for partition {} of context session_id {} and task_id {:?}", partition, context.session_id(), context.task_id());
        let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
        // record the tiny amount of work done in this function so
        // elapsed_compute is reported as non zero
        let elapsed_compute = baseline_metrics.elapsed_compute().clone();
        let _timer = elapsed_compute.timer(); // record on drop

        let mut input_stream_vec = vec![];
        for input in self.inputs.iter() {
            if partition < input.output_partitioning().partition_count() {
                input_stream_vec.push(input.execute(partition, Arc::clone(&context))?);
            } else {
                // Could not find a matching partition in this input
                break;
            }
        }
        if input_stream_vec.len() == self.inputs.len() {
            let stream = Box::pin(CombinedRecordBatchStream::new(
                self.schema(),
                input_stream_vec,
            ));
            return Ok(Box::pin(ObservedStream::new(
                stream,
                baseline_metrics,
                None,
            )));
        }

        warn!("Error in InterleaveExec: Partition {partition} not found");

        exec_err!("Partition {partition} not found in InterleaveExec")
    }


    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    fn statistics(&self) -> Result<Statistics> {
        self.partition_statistics(None)
    }

    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
        if partition.is_some() {
            return Ok(Statistics::new_unknown(&self.schema()));
        }
        let stats = self
            .inputs
            .iter()
            .map(|stat| stat.partition_statistics(None))
            .collect::<Result<Vec<_>>>()?;

        Ok(stats
            .into_iter()
            .reduce(stats_union)
            .unwrap_or_else(|| Statistics::new_unknown(&self.schema())))
    }

    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
        vec![false; self.children().len()]
    }
}

/// Returns `true` if all input plans have the same hash partitioning as the
/// first input, in which case the `InterleaveExec` is partition-aware.
///
/// This might be too strict when the input partitionings are compatible but
/// not exactly the same. For example, if one input has the partitioning
/// `Hash('a','b','c')` and another has `Hash('a')`, it would be safe to derive
/// the output partitioning `Hash('a','b','c')`.
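///
/// # Example
///
/// A small sketch (not from the original file): an `EmptyExec` placeholder
/// reports unknown rather than hash partitioning, so it cannot be interleaved,
/// and an empty input list is never interleavable:
///
/// ```
/// # use std::sync::Arc;
/// # use arrow::datatypes::{DataType, Field, Schema};
/// # use datafusion_physical_plan::empty::EmptyExec;
/// # use datafusion_physical_plan::union::can_interleave;
/// # use datafusion_physical_plan::ExecutionPlan;
/// let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)]));
/// let plan: Arc<dyn ExecutionPlan> = Arc::new(EmptyExec::new(schema));
/// assert!(!can_interleave(vec![plan].iter()));
/// assert!(!can_interleave(std::iter::empty::<Arc<dyn ExecutionPlan>>()));
/// ```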
pub fn can_interleave<T: Borrow<Arc<dyn ExecutionPlan>>>(
    mut inputs: impl Iterator<Item = T>,
) -> bool {
    let Some(first) = inputs.next() else {
        return false;
    };

    let reference = first.borrow().output_partitioning();
    matches!(reference, Partitioning::Hash(_, _))
        && inputs
            .map(|plan| plan.borrow().output_partitioning().clone())
            .all(|partition| partition == *reference)
}

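/// Builds the output schema of a union from the input schemas: field names are
/// taken from the first input (matching how the logical planner coerces
/// names), while nullability and field/schema metadata are merged across all
/// inputs.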
fn union_schema(inputs: &[Arc<dyn ExecutionPlan>]) -> SchemaRef {
    let first_schema = inputs[0].schema();

    let fields = (0..first_schema.fields().len())
        .map(|i| {
            // We take the name from the left side of the union to match how names are coerced during logical planning,
            // which also uses the left side names.
            let base_field = first_schema.field(i).clone();

            // Coerce metadata and nullability across all inputs
            let merged_field = inputs
                .iter()
                .enumerate()
                .map(|(input_idx, input)| {
                    let field = input.schema().field(i).clone();
                    let mut metadata = field.metadata().clone();

                    let other_metadatas = inputs
                        .iter()
                        .enumerate()
                        .filter(|(other_idx, _)| *other_idx != input_idx)
                        .flat_map(|(_, other_input)| {
                            other_input.schema().field(i).metadata().clone().into_iter()
                        });

                    metadata.extend(other_metadatas);
                    field.with_metadata(metadata)
                })
                .find_or_first(Field::is_nullable)
                // We can unwrap this because if inputs were empty, we would
                // have already panicked when we indexed into inputs[0].
                .unwrap()
                .with_name(base_field.name());

            merged_field
        })
        .collect::<Vec<_>>();

    let all_metadata_merged = inputs
        .iter()
        .flat_map(|i| i.schema().metadata().clone().into_iter())
        .collect();

    Arc::new(Schema::new_with_metadata(fields, all_metadata_merged))
}

/// CombinedRecordBatchStream can be used to combine a Vec of SendableRecordBatchStreams into one
struct CombinedRecordBatchStream {
    /// Schema wrapped by Arc
    schema: SchemaRef,
    /// Stream entries
    entries: Vec<SendableRecordBatchStream>,
}

impl CombinedRecordBatchStream {
    /// Create a CombinedRecordBatchStream
    pub fn new(schema: SchemaRef, entries: Vec<SendableRecordBatchStream>) -> Self {
        Self { schema, entries }
    }
}

impl RecordBatchStream for CombinedRecordBatchStream {
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}

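// Child streams are polled starting from a random index (via `thread_rng_n`)
// on every `poll_next` call, so no single input is persistently favored and
// the inputs are drained roughly evenly.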
impl Stream for CombinedRecordBatchStream {
    type Item = Result<RecordBatch>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        use Poll::*;

        let start = thread_rng_n(self.entries.len() as u32) as usize;
        let mut idx = start;

        for _ in 0..self.entries.len() {
            let stream = self.entries.get_mut(idx).unwrap();

            match Pin::new(stream).poll_next(cx) {
                Ready(Some(val)) => return Ready(Some(val)),
                Ready(None) => {
                    // Remove the entry
                    self.entries.swap_remove(idx);

                    // Check if this was the last entry, if so the cursor needs
                    // to wrap
                    if idx == self.entries.len() {
                        idx = 0;
                    } else if idx < start && start <= self.entries.len() {
                        // The stream being swapped into the current index has
                        // already been polled, so skip it.
                        idx = idx.wrapping_add(1) % self.entries.len();
                    }
                }
                Pending => {
                    idx = idx.wrapping_add(1) % self.entries.len();
                }
            }
        }

        // If there are no entries left, the combined stream is complete.
        if self.entries.is_empty() {
            Ready(None)
        } else {
            Pending
        }
    }
}

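/// Merges two [`ColumnStatistics`] for a union: min/max values widen, sums and
/// null counts add up, and the distinct count becomes `Absent` because overlap
/// between the inputs cannot be determined from the statistics alone.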
fn col_stats_union(
    mut left: ColumnStatistics,
    right: ColumnStatistics,
) -> ColumnStatistics {
    left.distinct_count = Precision::Absent;
    left.min_value = left.min_value.min(&right.min_value);
    left.max_value = left.max_value.max(&right.max_value);
    left.sum_value = left.sum_value.add(&right.sum_value);
    left.null_count = left.null_count.add(&right.null_count);

    left
}

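/// Merges two input [`Statistics`] into union-level statistics: row counts and
/// byte sizes add up, and the column statistics are merged pairwise via
/// [`col_stats_union`].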
fn stats_union(mut left: Statistics, right: Statistics) -> Statistics {
    left.num_rows = left.num_rows.add(&right.num_rows);
    left.total_byte_size = left.total_byte_size.add(&right.total_byte_size);
    left.column_statistics = left
        .column_statistics
        .into_iter()
        .zip(right.column_statistics)
        .map(|(a, b)| col_stats_union(a, b))
        .collect::<Vec<_>>();
    left
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::collect;
    use crate::test::{self, TestMemoryExec};

    use arrow::compute::SortOptions;
    use arrow::datatypes::DataType;
    use datafusion_common::ScalarValue;
    use datafusion_physical_expr::equivalence::convert_to_orderings;
    use datafusion_physical_expr::expressions::col;

    // Generate a schema which consists of 7 columns (a, b, c, d, e, f, g)
    fn create_test_schema() -> Result<SchemaRef> {
        let a = Field::new("a", DataType::Int32, true);
        let b = Field::new("b", DataType::Int32, true);
        let c = Field::new("c", DataType::Int32, true);
        let d = Field::new("d", DataType::Int32, true);
        let e = Field::new("e", DataType::Int32, true);
        let f = Field::new("f", DataType::Int32, true);
        let g = Field::new("g", DataType::Int32, true);
        let schema = Arc::new(Schema::new(vec![a, b, c, d, e, f, g]));

        Ok(schema)
    }

    #[tokio::test]
    async fn test_union_partitions() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());

        // Create inputs with different partitioning
        let csv = test::scan_partitioned(4);
        let csv2 = test::scan_partitioned(5);

        let union_exec = Arc::new(UnionExec::new(vec![csv, csv2]));

        // Should have 9 partitions and 9 output batches
        assert_eq!(
            union_exec
                .properties()
                .output_partitioning()
                .partition_count(),
            9
        );

        let result: Vec<RecordBatch> = collect(union_exec, task_ctx).await?;
        assert_eq!(result.len(), 9);

        Ok(())
    }

    #[tokio::test]
    async fn test_stats_union() {
        let left = Statistics {
            num_rows: Precision::Exact(5),
            total_byte_size: Precision::Exact(23),
            column_statistics: vec![
                ColumnStatistics {
                    distinct_count: Precision::Exact(5),
                    max_value: Precision::Exact(ScalarValue::Int64(Some(21))),
                    min_value: Precision::Exact(ScalarValue::Int64(Some(-4))),
                    sum_value: Precision::Exact(ScalarValue::Int64(Some(42))),
                    null_count: Precision::Exact(0),
                },
                ColumnStatistics {
                    distinct_count: Precision::Exact(1),
                    max_value: Precision::Exact(ScalarValue::from("x")),
                    min_value: Precision::Exact(ScalarValue::from("a")),
                    sum_value: Precision::Absent,
                    null_count: Precision::Exact(3),
                },
                ColumnStatistics {
                    distinct_count: Precision::Absent,
                    max_value: Precision::Exact(ScalarValue::Float32(Some(1.1))),
                    min_value: Precision::Exact(ScalarValue::Float32(Some(0.1))),
                    sum_value: Precision::Exact(ScalarValue::Float32(Some(42.0))),
                    null_count: Precision::Absent,
                },
            ],
        };

        let right = Statistics {
            num_rows: Precision::Exact(7),
            total_byte_size: Precision::Exact(29),
            column_statistics: vec![
                ColumnStatistics {
                    distinct_count: Precision::Exact(3),
                    max_value: Precision::Exact(ScalarValue::Int64(Some(34))),
                    min_value: Precision::Exact(ScalarValue::Int64(Some(1))),
                    sum_value: Precision::Exact(ScalarValue::Int64(Some(42))),
                    null_count: Precision::Exact(1),
                },
                ColumnStatistics {
                    distinct_count: Precision::Absent,
                    max_value: Precision::Exact(ScalarValue::from("c")),
                    min_value: Precision::Exact(ScalarValue::from("b")),
                    sum_value: Precision::Absent,
                    null_count: Precision::Absent,
                },
                ColumnStatistics {
                    distinct_count: Precision::Absent,
                    max_value: Precision::Absent,
                    min_value: Precision::Absent,
                    sum_value: Precision::Absent,
                    null_count: Precision::Absent,
                },
            ],
        };

        let result = stats_union(left, right);
        let expected = Statistics {
            num_rows: Precision::Exact(12),
            total_byte_size: Precision::Exact(52),
            column_statistics: vec![
                ColumnStatistics {
                    distinct_count: Precision::Absent,
                    max_value: Precision::Exact(ScalarValue::Int64(Some(34))),
                    min_value: Precision::Exact(ScalarValue::Int64(Some(-4))),
                    sum_value: Precision::Exact(ScalarValue::Int64(Some(84))),
                    null_count: Precision::Exact(1),
                },
                ColumnStatistics {
                    distinct_count: Precision::Absent,
                    max_value: Precision::Exact(ScalarValue::from("x")),
                    min_value: Precision::Exact(ScalarValue::from("a")),
                    sum_value: Precision::Absent,
                    null_count: Precision::Absent,
                },
                ColumnStatistics {
                    distinct_count: Precision::Absent,
                    max_value: Precision::Absent,
                    min_value: Precision::Absent,
                    sum_value: Precision::Absent,
                    null_count: Precision::Absent,
                },
            ],
        };

        assert_eq!(result, expected);
    }

    #[tokio::test]
    async fn test_union_equivalence_properties() -> Result<()> {
        let schema = create_test_schema()?;
        let col_a = &col("a", &schema)?;
        let col_b = &col("b", &schema)?;
        let col_c = &col("c", &schema)?;
        let col_d = &col("d", &schema)?;
        let col_e = &col("e", &schema)?;
        let col_f = &col("f", &schema)?;
        let options = SortOptions::default();
        let test_cases = [
            //-----------TEST CASE 1----------//
            (
                // First child orderings
                vec![
                    // [a ASC, b ASC, f ASC]
                    vec![(col_a, options), (col_b, options), (col_f, options)],
                ],
                // Second child orderings
                vec![
                    // [a ASC, b ASC, c ASC]
                    vec![(col_a, options), (col_b, options), (col_c, options)],
                    // [a ASC, b ASC, f ASC]
                    vec![(col_a, options), (col_b, options), (col_f, options)],
                ],
                // Union output orderings
                vec![
                    // [a ASC, b ASC, f ASC]
                    vec![(col_a, options), (col_b, options), (col_f, options)],
                ],
            ),
            //-----------TEST CASE 2----------//
            (
                // First child orderings
                vec![
                    // [a ASC, b ASC, f ASC]
                    vec![(col_a, options), (col_b, options), (col_f, options)],
                    // [d ASC]
                    vec![(col_d, options)],
                ],
                // Second child orderings
                vec![
                    // [a ASC, b ASC, c ASC]
                    vec![(col_a, options), (col_b, options), (col_c, options)],
                    // [e ASC]
                    vec![(col_e, options)],
                ],
                // Union output orderings
                vec![
                    // [a ASC, b ASC]
                    vec![(col_a, options), (col_b, options)],
                ],
            ),
        ];

        for (
            test_idx,
            (first_child_orderings, second_child_orderings, union_orderings),
        ) in test_cases.iter().enumerate()
        {
            let first_orderings = convert_to_orderings(first_child_orderings);
            let second_orderings = convert_to_orderings(second_child_orderings);
            let union_expected_orderings = convert_to_orderings(union_orderings);
            let child1 = Arc::new(TestMemoryExec::update_cache(Arc::new(
                TestMemoryExec::try_new(&[], Arc::clone(&schema), None)?
                    .try_with_sort_information(first_orderings)?,
            )));
            let child2 = Arc::new(TestMemoryExec::update_cache(Arc::new(
                TestMemoryExec::try_new(&[], Arc::clone(&schema), None)?
                    .try_with_sort_information(second_orderings)?,
            )));

            let mut union_expected_eq = EquivalenceProperties::new(Arc::clone(&schema));
            union_expected_eq.add_orderings(union_expected_orderings);

            let union = UnionExec::new(vec![child1, child2]);
            let union_eq_properties = union.properties().equivalence_properties();
            let err_msg = format!(
                "Error in test id: {:?}, test case: {:?}",
                test_idx, test_cases[test_idx]
            );
            assert_eq_properties_same(union_eq_properties, &union_expected_eq, err_msg);
        }
        Ok(())
    }

    fn assert_eq_properties_same(
        lhs: &EquivalenceProperties,
        rhs: &EquivalenceProperties,
        err_msg: String,
    ) {
        // Check whether the orderings are the same.
        let lhs_orderings = lhs.oeq_class();
        let rhs_orderings = rhs.oeq_class();
        assert_eq!(lhs_orderings.len(), rhs_orderings.len(), "{err_msg}");
        for rhs_ordering in rhs_orderings.iter() {
            assert!(lhs_orderings.contains(rhs_ordering), "{}", err_msg);
        }
    }
}