datafusion_optimizer/
eliminate_nested_union.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`EliminateNestedUnion`]: flattens nested `Union` to a single `Union`
19use crate::optimizer::ApplyOrder;
20use crate::{OptimizerConfig, OptimizerRule};
21use datafusion_common::tree_node::Transformed;
22use datafusion_common::Result;
23use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema;
24use datafusion_expr::{Distinct, LogicalPlan, Union};
25use itertools::Itertools;
26use std::sync::Arc;
27
/// An optimization rule that replaces nested unions with a single union.
#[derive(Default, Debug)]
pub struct EliminateNestedUnion;

impl EliminateNestedUnion {
    /// Creates a new [`EliminateNestedUnion`] rule.
    ///
    /// Equivalent to [`EliminateNestedUnion::default`]; the rule carries no
    /// configuration.
    pub fn new() -> Self {
        Self
    }
}
38
39impl OptimizerRule for EliminateNestedUnion {
40    fn name(&self) -> &str {
41        "eliminate_nested_union"
42    }
43
44    fn apply_order(&self) -> Option<ApplyOrder> {
45        Some(ApplyOrder::BottomUp)
46    }
47
48    fn supports_rewrite(&self) -> bool {
49        true
50    }
51
52    fn rewrite(
53        &self,
54        plan: LogicalPlan,
55        _config: &dyn OptimizerConfig,
56    ) -> Result<Transformed<LogicalPlan>> {
57        match plan {
58            LogicalPlan::Union(Union { inputs, schema }) => {
59                let inputs = inputs
60                    .into_iter()
61                    .flat_map(extract_plans_from_union)
62                    .map(|plan| coerce_plan_expr_for_schema(plan, &schema))
63                    .collect::<Result<Vec<_>>>()?;
64
65                Ok(Transformed::yes(LogicalPlan::Union(Union {
66                    inputs: inputs.into_iter().map(Arc::new).collect_vec(),
67                    schema,
68                })))
69            }
70            LogicalPlan::Distinct(Distinct::All(nested_plan)) => {
71                match Arc::unwrap_or_clone(nested_plan) {
72                    LogicalPlan::Union(Union { inputs, schema }) => {
73                        let inputs = inputs
74                            .into_iter()
75                            .map(extract_plan_from_distinct)
76                            .flat_map(extract_plans_from_union)
77                            .map(|plan| coerce_plan_expr_for_schema(plan, &schema))
78                            .collect::<Result<Vec<_>>>()?;
79
80                        Ok(Transformed::yes(LogicalPlan::Distinct(Distinct::All(
81                            Arc::new(LogicalPlan::Union(Union {
82                                inputs: inputs.into_iter().map(Arc::new).collect_vec(),
83                                schema: Arc::clone(&schema),
84                            })),
85                        ))))
86                    }
87                    nested_plan => Ok(Transformed::no(LogicalPlan::Distinct(
88                        Distinct::All(Arc::new(nested_plan)),
89                    ))),
90                }
91            }
92            _ => Ok(Transformed::no(plan)),
93        }
94    }
95}
96
97fn extract_plans_from_union(plan: Arc<LogicalPlan>) -> Vec<LogicalPlan> {
98    match Arc::unwrap_or_clone(plan) {
99        LogicalPlan::Union(Union { inputs, .. }) => inputs
100            .into_iter()
101            .map(Arc::unwrap_or_clone)
102            .collect::<Vec<_>>(),
103        plan => vec![plan],
104    }
105}
106
107fn extract_plan_from_distinct(plan: Arc<LogicalPlan>) -> Arc<LogicalPlan> {
108    match Arc::unwrap_or_clone(plan) {
109        LogicalPlan::Distinct(Distinct::All(plan)) => plan,
110        plan => Arc::new(plan),
111    }
112}
113
#[cfg(test)]
mod tests {
    // Snapshot tests for `EliminateNestedUnion`. Each test builds a logical
    // plan with `LogicalPlanBuilder` and compares the optimized plan against
    // an inline insta snapshot.
    use super::*;
    use crate::analyzer::type_coercion::TypeCoercion;
    use crate::analyzer::Analyzer;
    use crate::assert_optimized_plan_eq_snapshot;
    use crate::OptimizerContext;
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::config::ConfigOptions;
    use datafusion_expr::{col, logical_plan::table_scan};

    /// Schema shared by most tests: `id: Int32`, `key: Utf8`,
    /// `value: Float64`, all non-nullable.
    fn schema() -> Schema {
        Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("key", DataType::Utf8, false),
            Field::new("value", DataType::Float64, false),
        ])
    }

    // Runs the `TypeCoercion` analyzer on `$plan`, then a single optimizer
    // pass (`with_max_passes(1)`) containing only `EliminateNestedUnion`, and
    // checks the result against the inline snapshot `$expected`.
    macro_rules! assert_optimized_plan_equal {
        (
            $plan:expr,
            @ $expected:literal $(,)?
        ) => {{
            let options = ConfigOptions::default();
            let analyzed_plan = Analyzer::with_rules(vec![Arc::new(TypeCoercion::new())])
                .execute_and_check($plan, &options, |_, _| {})?;
            let optimizer_ctx = OptimizerContext::new().with_max_passes(1);
            let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> = vec![Arc::new(EliminateNestedUnion::new())];
            assert_optimized_plan_eq_snapshot!(
                optimizer_ctx,
                rules,
                analyzed_plan,
                @ $expected,
            )
        }};
    }

    // A single-level `UNION ALL` has nothing to flatten.
    #[test]
    fn eliminate_nothing() -> Result<()> {
        let plan_builder = table_scan(Some("table"), &schema(), None)?;

        let plan = plan_builder.clone().union(plan_builder.build()?)?.build()?;

        assert_optimized_plan_equal!(plan, @r"
        Union
          TableScan: table
          TableScan: table
        ")
    }

    // A single-level distinct union has nothing to flatten.
    #[test]
    fn eliminate_distinct_nothing() -> Result<()> {
        let plan_builder = table_scan(Some("table"), &schema(), None)?;

        let plan = plan_builder
            .clone()
            .union_distinct(plan_builder.build()?)?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Distinct:
          Union
            TableScan: table
            TableScan: table
        ")
    }

    // Chained `UNION ALL`s collapse into a single four-input `Union`.
    #[test]
    fn eliminate_nested_union() -> Result<()> {
        let plan_builder = table_scan(Some("table"), &schema(), None)?;

        let plan = plan_builder
            .clone()
            .union(plan_builder.clone().build()?)?
            .union(plan_builder.clone().build()?)?
            .union(plan_builder.build()?)?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Union
          TableScan: table
          TableScan: table
          TableScan: table
          TableScan: table
        ")
    }

    // A distinct union nested under `UNION ALL` must be preserved; only the
    // `UNION ALL` levels above it are flattened.
    #[test]
    fn eliminate_nested_union_with_distinct_union() -> Result<()> {
        let plan_builder = table_scan(Some("table"), &schema(), None)?;

        let plan = plan_builder
            .clone()
            .union_distinct(plan_builder.clone().build()?)?
            .union(plan_builder.clone().build()?)?
            .union(plan_builder.build()?)?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Union
          Distinct:
            Union
              TableScan: table
              TableScan: table
          TableScan: table
          TableScan: table
        ")
    }

    // Under an outer distinct union, both nested `UNION ALL`s and nested
    // distinct unions collapse into one distinct union.
    #[test]
    fn eliminate_nested_distinct_union() -> Result<()> {
        let plan_builder = table_scan(Some("table"), &schema(), None)?;

        let plan = plan_builder
            .clone()
            .union(plan_builder.clone().build()?)?
            .union_distinct(plan_builder.clone().build()?)?
            .union(plan_builder.clone().build()?)?
            .union_distinct(plan_builder.build()?)?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Distinct:
          Union
            TableScan: table
            TableScan: table
            TableScan: table
            TableScan: table
            TableScan: table
        ")
    }

    // `Distinct` wrappers on individual inputs are also removed when the
    // enclosing union is distinct.
    #[test]
    fn eliminate_nested_distinct_union_with_distinct_table() -> Result<()> {
        let plan_builder = table_scan(Some("table"), &schema(), None)?;

        let plan = plan_builder
            .clone()
            .union_distinct(plan_builder.clone().distinct()?.build()?)?
            .union(plan_builder.clone().distinct()?.build()?)?
            .union_distinct(plan_builder.build()?)?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Distinct:
          Union
            TableScan: table
            TableScan: table
            TableScan: table
            TableScan: table
        ")
    }

    // The logical optimizer does not need `project_with_column_index` here:
    // after `LogicalPlanBuilder::union`, all inputs already carry matching
    // expression aliases. The snapshot shows the `table_id` / `_id` aliases
    // rewritten to the union's column name `id`.
    #[test]
    fn eliminate_nested_union_with_projection() -> Result<()> {
        let plan_builder = table_scan(Some("table"), &schema(), None)?;

        let plan = plan_builder
            .clone()
            .union(
                plan_builder
                    .clone()
                    .project(vec![col("id").alias("table_id"), col("key"), col("value")])?
                    .build()?,
            )?
            .union(
                plan_builder
                    .project(vec![col("id").alias("_id"), col("key"), col("value")])?
                    .build()?,
            )?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Union
          TableScan: table
          Projection: table.id AS id, table.key, table.value
            TableScan: table
          Projection: table.id AS id, table.key, table.value
            TableScan: table
        ")
    }

    // Same as above, but with distinct unions.
    #[test]
    fn eliminate_nested_distinct_union_with_projection() -> Result<()> {
        let plan_builder = table_scan(Some("table"), &schema(), None)?;

        let plan = plan_builder
            .clone()
            .union_distinct(
                plan_builder
                    .clone()
                    .project(vec![col("id").alias("table_id"), col("key"), col("value")])?
                    .build()?,
            )?
            .union_distinct(
                plan_builder
                    .project(vec![col("id").alias("_id"), col("key"), col("value")])?
                    .build()?,
            )?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Distinct:
          Union
            TableScan: table
            Projection: table.id AS id, table.key, table.value
              TableScan: table
            Projection: table.id AS id, table.key, table.value
              TableScan: table
        ")
    }

    // All three scans share the table name `table_1` but differ in `id` and
    // `value` types. After flattening, the narrower inputs are wrapped in
    // projections that cast to the union's `Int64` / `Float64` columns.
    #[test]
    fn eliminate_nested_union_with_type_cast_projection() -> Result<()> {
        let table_1 = table_scan(
            Some("table_1"),
            &Schema::new(vec![
                Field::new("id", DataType::Int64, false),
                Field::new("key", DataType::Utf8, false),
                Field::new("value", DataType::Float64, false),
            ]),
            None,
        )?;

        let table_2 = table_scan(
            Some("table_1"),
            &Schema::new(vec![
                Field::new("id", DataType::Int32, false),
                Field::new("key", DataType::Utf8, false),
                Field::new("value", DataType::Float32, false),
            ]),
            None,
        )?;

        let table_3 = table_scan(
            Some("table_1"),
            &Schema::new(vec![
                Field::new("id", DataType::Int16, false),
                Field::new("key", DataType::Utf8, false),
                Field::new("value", DataType::Float32, false),
            ]),
            None,
        )?;

        let plan = table_1
            .union(table_2.build()?)?
            .union(table_3.build()?)?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Union
          TableScan: table_1
          Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value
            TableScan: table_1
          Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value
            TableScan: table_1
        ")
    }

    // Same as above, but with distinct unions.
    #[test]
    fn eliminate_nested_distinct_union_with_type_cast_projection() -> Result<()> {
        let table_1 = table_scan(
            Some("table_1"),
            &Schema::new(vec![
                Field::new("id", DataType::Int64, false),
                Field::new("key", DataType::Utf8, false),
                Field::new("value", DataType::Float64, false),
            ]),
            None,
        )?;

        let table_2 = table_scan(
            Some("table_1"),
            &Schema::new(vec![
                Field::new("id", DataType::Int32, false),
                Field::new("key", DataType::Utf8, false),
                Field::new("value", DataType::Float32, false),
            ]),
            None,
        )?;

        let table_3 = table_scan(
            Some("table_1"),
            &Schema::new(vec![
                Field::new("id", DataType::Int16, false),
                Field::new("key", DataType::Utf8, false),
                Field::new("value", DataType::Float32, false),
            ]),
            None,
        )?;

        let plan = table_1
            .union_distinct(table_2.build()?)?
            .union_distinct(table_3.build()?)?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Distinct:
          Union
            TableScan: table_1
            Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value
              TableScan: table_1
            Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value
              TableScan: table_1
        ")
    }
}