datafusion_optimizer/
optimize_unions.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! [`OptimizeUnions`]: flattens nested `Union` nodes and removes single-input `Union` nodes from the logical plan.
19use crate::optimizer::ApplyOrder;
20use crate::{OptimizerConfig, OptimizerRule};
21use datafusion_common::Result;
22use datafusion_common::tree_node::Transformed;
23use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema;
24use datafusion_expr::{Distinct, LogicalPlan, Projection, Union};
25use itertools::Itertools;
26use std::sync::Arc;
27
/// Optimizer rule that simplifies `Union` nodes:
///
/// 1. a nested union is merged into its parent union, and a `Projection`
///    sitting directly on a nested union is pushed down onto each of that
///    union's inputs;
/// 2. beneath a `Distinct::All` union, `Distinct::All` wrappers on the
///    union's inputs are dropped as redundant;
/// 3. a union with a single input is replaced by that input.
///
/// The inputs of every rewritten union are coerced to the union's schema.
#[derive(Default, Debug)]
pub struct OptimizeUnions;
33
34impl OptimizeUnions {
35    #[expect(missing_docs)]
36    pub fn new() -> Self {
37        Self {}
38    }
39}
40
41impl OptimizerRule for OptimizeUnions {
42    fn name(&self) -> &str {
43        "optimize_unions"
44    }
45
46    fn apply_order(&self) -> Option<ApplyOrder> {
47        Some(ApplyOrder::BottomUp)
48    }
49
50    fn supports_rewrite(&self) -> bool {
51        true
52    }
53
54    fn rewrite(
55        &self,
56        plan: LogicalPlan,
57        _config: &dyn OptimizerConfig,
58    ) -> Result<Transformed<LogicalPlan>> {
59        match plan {
60            LogicalPlan::Union(Union { mut inputs, .. }) if inputs.len() == 1 => Ok(
61                Transformed::yes(Arc::unwrap_or_clone(inputs.pop().unwrap())),
62            ),
63            LogicalPlan::Union(Union { inputs, schema }) => {
64                let inputs = inputs
65                    .into_iter()
66                    .flat_map(extract_plans_from_union)
67                    .map(|plan| coerce_plan_expr_for_schema(plan, &schema))
68                    .collect::<Result<Vec<_>>>()?;
69
70                Ok(Transformed::yes(LogicalPlan::Union(Union {
71                    inputs: inputs.into_iter().map(Arc::new).collect_vec(),
72                    schema,
73                })))
74            }
75            LogicalPlan::Distinct(Distinct::All(nested_plan)) => {
76                match Arc::unwrap_or_clone(nested_plan) {
77                    LogicalPlan::Union(Union { inputs, schema }) => {
78                        let inputs = inputs
79                            .into_iter()
80                            .map(extract_plan_from_distinct)
81                            .flat_map(extract_plans_from_union)
82                            .map(|plan| coerce_plan_expr_for_schema(plan, &schema))
83                            .collect::<Result<Vec<_>>>()?;
84
85                        Ok(Transformed::yes(LogicalPlan::Distinct(Distinct::All(
86                            Arc::new(LogicalPlan::Union(Union {
87                                inputs: inputs.into_iter().map(Arc::new).collect_vec(),
88                                schema: Arc::clone(&schema),
89                            })),
90                        ))))
91                    }
92                    nested_plan => Ok(Transformed::no(LogicalPlan::Distinct(
93                        Distinct::All(Arc::new(nested_plan)),
94                    ))),
95                }
96            }
97            _ => Ok(Transformed::no(plan)),
98        }
99    }
100}
101
102fn extract_plans_from_union(plan: Arc<LogicalPlan>) -> Vec<LogicalPlan> {
103    match Arc::unwrap_or_clone(plan) {
104        LogicalPlan::Union(Union { inputs, .. }) => inputs
105            .into_iter()
106            .map(Arc::unwrap_or_clone)
107            .collect::<Vec<_>>(),
108        // While unnesting, unwrap a Projection whose input is a nested Union,
109        // flatten the inner Union, and push the same Projection down onto
110        // each of the nested Union’s children.
111        //
112        // Example:
113        //   Union { Projection { Union { plan1, plan2 } }, plan3 }
114        //     => Union { Projection { plan1 }, Projection { plan2 }, plan3 }
115        LogicalPlan::Projection(Projection {
116            expr,
117            input,
118            schema,
119            ..
120        }) => match Arc::unwrap_or_clone(input) {
121            LogicalPlan::Union(Union { inputs, .. }) => inputs
122                .into_iter()
123                .map(Arc::unwrap_or_clone)
124                .map(|plan| {
125                    LogicalPlan::Projection(
126                        Projection::try_new_with_schema(
127                            expr.clone(),
128                            Arc::new(plan),
129                            Arc::clone(&schema),
130                        )
131                        .unwrap(),
132                    )
133                })
134                .collect::<Vec<_>>(),
135
136            plan => vec![LogicalPlan::Projection(
137                Projection::try_new_with_schema(expr, Arc::new(plan), schema).unwrap(),
138            )],
139        },
140        plan => vec![plan],
141    }
142}
143
144fn extract_plan_from_distinct(plan: Arc<LogicalPlan>) -> Arc<LogicalPlan> {
145    match Arc::unwrap_or_clone(plan) {
146        LogicalPlan::Distinct(Distinct::All(plan)) => plan,
147        plan => Arc::new(plan),
148    }
149}
150
#[cfg(test)]
mod tests {
    use super::*;
    use crate::OptimizerContext;
    use crate::analyzer::Analyzer;
    use crate::analyzer::type_coercion::TypeCoercion;
    use crate::assert_optimized_plan_eq_snapshot;
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::config::ConfigOptions;
    use datafusion_expr::{col, logical_plan::table_scan};

    /// The `id` / `key` / `value` schema used by most tests.
    fn schema() -> Schema {
        typed_schema(DataType::Int32, DataType::Float64)
    }

    /// An `id` / `key` / `value` schema whose numeric column types vary, used
    /// to force type coercion between union inputs.
    fn typed_schema(id_type: DataType, value_type: DataType) -> Schema {
        Schema::new(vec![
            Field::new("id", id_type, false),
            Field::new("key", DataType::Utf8, false),
            Field::new("value", value_type, false),
        ])
    }

    /// Runs type coercion on `$plan`, applies one pass of `OptimizeUnions`,
    /// and compares the result with the inline snapshot `$expected`.
    macro_rules! assert_optimized_plan_equal {
        (
            $plan:expr,
            @ $expected:literal $(,)?
        ) => {{
            let config = ConfigOptions::default();
            let coercion = Analyzer::with_rules(vec![Arc::new(TypeCoercion::new())]);
            let analyzed_plan = coercion.execute_and_check($plan, &config, |_, _| {})?;
            let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> =
                vec![Arc::new(OptimizeUnions::new())];
            assert_optimized_plan_eq_snapshot!(
                OptimizerContext::new().with_max_passes(1),
                rules,
                analyzed_plan,
                @ $expected,
            )
        }};
    }

    #[test]
    fn eliminate_nothing() -> Result<()> {
        let scan = table_scan(Some("table"), &schema(), None)?;
        let leaf = scan.clone().build()?;
        let plan = scan.union(leaf)?.build()?;

        assert_optimized_plan_equal!(plan, @r"
        Union
          TableScan: table
          TableScan: table
        ")
    }

    #[test]
    fn eliminate_distinct_nothing() -> Result<()> {
        let scan = table_scan(Some("table"), &schema(), None)?;
        let leaf = scan.clone().build()?;
        let plan = scan.union_distinct(leaf)?.build()?;

        assert_optimized_plan_equal!(plan, @r"
        Distinct:
          Union
            TableScan: table
            TableScan: table
        ")
    }

    #[test]
    fn eliminate_nested_union() -> Result<()> {
        let scan = table_scan(Some("table"), &schema(), None)?;
        let leaf = scan.clone().build()?;
        let plan = scan
            .union(leaf.clone())?
            .union(leaf.clone())?
            .union(leaf)?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Union
          TableScan: table
          TableScan: table
          TableScan: table
          TableScan: table
        ")
    }

    #[test]
    fn eliminate_nested_union_with_distinct_union() -> Result<()> {
        let scan = table_scan(Some("table"), &schema(), None)?;
        let leaf = scan.clone().build()?;
        let plan = scan
            .union_distinct(leaf.clone())?
            .union(leaf.clone())?
            .union(leaf)?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Union
          Distinct:
            Union
              TableScan: table
              TableScan: table
          TableScan: table
          TableScan: table
        ")
    }

    #[test]
    fn eliminate_nested_distinct_union() -> Result<()> {
        let scan = table_scan(Some("table"), &schema(), None)?;
        let leaf = scan.clone().build()?;
        let plan = scan
            .union(leaf.clone())?
            .union_distinct(leaf.clone())?
            .union(leaf.clone())?
            .union_distinct(leaf)?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Distinct:
          Union
            TableScan: table
            TableScan: table
            TableScan: table
            TableScan: table
            TableScan: table
        ")
    }

    #[test]
    fn eliminate_nested_distinct_union_with_distinct_table() -> Result<()> {
        let scan = table_scan(Some("table"), &schema(), None)?;
        let distinct_leaf = scan.clone().distinct()?.build()?;
        let leaf = scan.clone().build()?;
        let plan = scan
            .union_distinct(distinct_leaf.clone())?
            .union(distinct_leaf)?
            .union_distinct(leaf)?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Distinct:
          Union
            TableScan: table
            TableScan: table
            TableScan: table
            TableScan: table
        ")
    }

    // `LogicalPlanBuilder::union` already aligns the column names of every
    // input, so the optimizer never needs `project_with_column_index` here.
    #[test]
    fn eliminate_nested_union_with_projection() -> Result<()> {
        let scan = table_scan(Some("table"), &schema(), None)?;
        let project_id_as = |alias: &str| -> Result<LogicalPlan> {
            scan.clone()
                .project(vec![col("id").alias(alias), col("key"), col("value")])?
                .build()
        };
        let plan = scan
            .clone()
            .union(project_id_as("table_id")?)?
            .union(project_id_as("_id")?)?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Union
          TableScan: table
          Projection: table.id AS id, table.key, table.value
            TableScan: table
          Projection: table.id AS id, table.key, table.value
            TableScan: table
        ")
    }

    #[test]
    fn eliminate_nested_distinct_union_with_projection() -> Result<()> {
        let scan = table_scan(Some("table"), &schema(), None)?;
        let project_id_as = |alias: &str| -> Result<LogicalPlan> {
            scan.clone()
                .project(vec![col("id").alias(alias), col("key"), col("value")])?
                .build()
        };
        let plan = scan
            .clone()
            .union_distinct(project_id_as("table_id")?)?
            .union_distinct(project_id_as("_id")?)?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Distinct:
          Union
            TableScan: table
            Projection: table.id AS id, table.key, table.value
              TableScan: table
            Projection: table.id AS id, table.key, table.value
              TableScan: table
        ")
    }

    #[test]
    fn eliminate_nested_union_in_projection() -> Result<()> {
        let scan = table_scan(Some("table"), &schema(), None)?;
        let leaf = scan.clone().build()?;
        let plan = scan
            .union(leaf.clone())?
            .project(vec![col("id").alias("table_id"), col("key"), col("value")])?
            .union(leaf)?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Union
          Projection: id AS table_id, key, value
            TableScan: table
          Projection: id AS table_id, key, value
            TableScan: table
          TableScan: table
        ")
    }

    #[test]
    fn eliminate_nested_union_with_type_cast_projection() -> Result<()> {
        let scan = |id_type: DataType, value_type: DataType| {
            table_scan(Some("table_1"), &typed_schema(id_type, value_type), None)
        };
        let plan = scan(DataType::Int64, DataType::Float64)?
            .union(scan(DataType::Int32, DataType::Float32)?.build()?)?
            .union(scan(DataType::Int16, DataType::Float32)?.build()?)?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Union
          TableScan: table_1
          Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value
            TableScan: table_1
          Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value
            TableScan: table_1
        ")
    }

    #[test]
    fn eliminate_nested_distinct_union_with_type_cast_projection() -> Result<()> {
        let scan = |id_type: DataType, value_type: DataType| {
            table_scan(Some("table_1"), &typed_schema(id_type, value_type), None)
        };
        let plan = scan(DataType::Int64, DataType::Float64)?
            .union_distinct(scan(DataType::Int32, DataType::Float32)?.build()?)?
            .union_distinct(scan(DataType::Int16, DataType::Float32)?.build()?)?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Distinct:
          Union
            TableScan: table_1
            Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value
              TableScan: table_1
            Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value
              TableScan: table_1
        ")
    }

    #[test]
    fn eliminate_one_union() -> Result<()> {
        let scan = table_scan(Some("table"), &schema(), None)?.build()?;
        // LogicalPlanBuilder never produces a single-input union, so assemble
        // one by hand.
        let plan = LogicalPlan::Union(Union {
            schema: Arc::clone(scan.schema()),
            inputs: vec![Arc::new(scan)],
        });

        // `assert_optimized_plan_equal!` runs other passes first, and those
        // reject (or mishandle the schema of) a single-input union, so call the
        // snapshot assertion directly.
        assert_optimized_plan_eq_snapshot!(
            OptimizerContext::new().with_max_passes(1),
            vec![Arc::new(OptimizeUnions::new())],
            plan,
            @"TableScan: table"
        )
    }
}