1use crate::optimizer::ApplyOrder;
20use crate::{OptimizerConfig, OptimizerRule};
21use datafusion_common::tree_node::Transformed;
22use datafusion_common::Result;
23use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema;
24use datafusion_expr::{Distinct, LogicalPlan, Union};
25use itertools::Itertools;
26use std::sync::Arc;
27
/// Optimizer rule that flattens nested `Union` plan nodes into a single
/// `Union`.
///
/// The type is a stateless unit struct: it carries no configuration, so all
/// instances behave identically.
#[derive(Default, Debug)]
pub struct EliminateNestedUnion;

impl EliminateNestedUnion {
    /// Creates a new instance of the rule.
    ///
    /// Equivalent to [`EliminateNestedUnion::default`].
    pub fn new() -> Self {
        Self
    }
}
38
39impl OptimizerRule for EliminateNestedUnion {
40 fn name(&self) -> &str {
41 "eliminate_nested_union"
42 }
43
44 fn apply_order(&self) -> Option<ApplyOrder> {
45 Some(ApplyOrder::BottomUp)
46 }
47
48 fn supports_rewrite(&self) -> bool {
49 true
50 }
51
52 fn rewrite(
53 &self,
54 plan: LogicalPlan,
55 _config: &dyn OptimizerConfig,
56 ) -> Result<Transformed<LogicalPlan>> {
57 match plan {
58 LogicalPlan::Union(Union { inputs, schema }) => {
59 let inputs = inputs
60 .into_iter()
61 .flat_map(extract_plans_from_union)
62 .map(|plan| coerce_plan_expr_for_schema(plan, &schema))
63 .collect::<Result<Vec<_>>>()?;
64
65 Ok(Transformed::yes(LogicalPlan::Union(Union {
66 inputs: inputs.into_iter().map(Arc::new).collect_vec(),
67 schema,
68 })))
69 }
70 LogicalPlan::Distinct(Distinct::All(nested_plan)) => {
71 match Arc::unwrap_or_clone(nested_plan) {
72 LogicalPlan::Union(Union { inputs, schema }) => {
73 let inputs = inputs
74 .into_iter()
75 .map(extract_plan_from_distinct)
76 .flat_map(extract_plans_from_union)
77 .map(|plan| coerce_plan_expr_for_schema(plan, &schema))
78 .collect::<Result<Vec<_>>>()?;
79
80 Ok(Transformed::yes(LogicalPlan::Distinct(Distinct::All(
81 Arc::new(LogicalPlan::Union(Union {
82 inputs: inputs.into_iter().map(Arc::new).collect_vec(),
83 schema: Arc::clone(&schema),
84 })),
85 ))))
86 }
87 nested_plan => Ok(Transformed::no(LogicalPlan::Distinct(
88 Distinct::All(Arc::new(nested_plan)),
89 ))),
90 }
91 }
92 _ => Ok(Transformed::no(plan)),
93 }
94 }
95}
96
97fn extract_plans_from_union(plan: Arc<LogicalPlan>) -> Vec<LogicalPlan> {
98 match Arc::unwrap_or_clone(plan) {
99 LogicalPlan::Union(Union { inputs, .. }) => inputs
100 .into_iter()
101 .map(Arc::unwrap_or_clone)
102 .collect::<Vec<_>>(),
103 plan => vec![plan],
104 }
105}
106
107fn extract_plan_from_distinct(plan: Arc<LogicalPlan>) -> Arc<LogicalPlan> {
108 match Arc::unwrap_or_clone(plan) {
109 LogicalPlan::Distinct(Distinct::All(plan)) => plan,
110 plan => Arc::new(plan),
111 }
112}
113
// Snapshot tests for `EliminateNestedUnion`. Each test builds a logical plan
// with the plan builder, then uses the `assert_optimized_plan_equal!` macro
// defined in this module to compare the optimized plan with an inline snapshot.
114#[cfg(test)]
115mod tests {
116 use super::*;
117 use crate::analyzer::type_coercion::TypeCoercion;
118 use crate::analyzer::Analyzer;
119 use crate::assert_optimized_plan_eq_snapshot;
120 use crate::OptimizerContext;
121 use arrow::datatypes::{DataType, Field, Schema};
122 use datafusion_common::config::ConfigOptions;
123 use datafusion_expr::{col, logical_plan::table_scan};
124
 // Schema shared by most tests: `id` (Int32), `key` (Utf8), `value` (Float64).
125 fn schema() -> Schema {
126 Schema::new(vec![
127 Field::new("id", DataType::Int32, false),
128 Field::new("key", DataType::Utf8, false),
129 Field::new("value", DataType::Float64, false),
130 ])
131 }
132
 // Runs `$plan` through an `Analyzer` containing only `TypeCoercion`, then
 // passes the analyzed plan, an `OptimizerContext` limited to one pass, and a
 // rule list holding only `EliminateNestedUnion` to
 // `assert_optimized_plan_eq_snapshot!` along with the expected snapshot.
 // The expansion uses `?`, so callers must return a `Result`.
133 macro_rules! assert_optimized_plan_equal {
134 (
135 $plan:expr,
136 @ $expected:literal $(,)?
137 ) => {{
138 let options = ConfigOptions::default();
139 let analyzed_plan = Analyzer::with_rules(vec![Arc::new(TypeCoercion::new())])
140 .execute_and_check($plan, &options, |_, _| {})?;
141 let optimizer_ctx = OptimizerContext::new().with_max_passes(1);
142 let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> = vec![Arc::new(EliminateNestedUnion::new())];
143 assert_optimized_plan_eq_snapshot!(
144 optimizer_ctx,
145 rules,
146 analyzed_plan,
147 @ $expected,
148 )
149 }};
150 }
151
 // A single `union` of two scans: the snapshot lists one `Union` and two
 // `TableScan: table` entries.
152 #[test]
153 fn eliminate_nothing() -> Result<()> {
154 let plan_builder = table_scan(Some("table"), &schema(), None)?;
155
156 let plan = plan_builder.clone().union(plan_builder.build()?)?.build()?;
157
158 assert_optimized_plan_equal!(plan, @r"
159 Union
160 TableScan: table
161 TableScan: table
162 ")
163 }
164
 // A single `union_distinct` of two scans: the snapshot lists `Distinct:`,
 // one `Union` and two `TableScan: table` entries.
165 #[test]
166 fn eliminate_distinct_nothing() -> Result<()> {
167 let plan_builder = table_scan(Some("table"), &schema(), None)?;
168
169 let plan = plan_builder
170 .clone()
171 .union_distinct(plan_builder.build()?)?
172 .build()?;
173
174 assert_optimized_plan_equal!(plan, @r"
175 Distinct:
176 Union
177 TableScan: table
178 TableScan: table
179 ")
180 }
181
 // Three chained `union` calls over four scans: the snapshot lists a single
 // `Union` followed by four `TableScan: table` entries.
182 #[test]
183 fn eliminate_nested_union() -> Result<()> {
184 let plan_builder = table_scan(Some("table"), &schema(), None)?;
185
186 let plan = plan_builder
187 .clone()
188 .union(plan_builder.clone().build()?)?
189 .union(plan_builder.clone().build()?)?
190 .union(plan_builder.build()?)?
191 .build()?;
192
193 assert_optimized_plan_equal!(plan, @r"
194 Union
195 TableScan: table
196 TableScan: table
197 TableScan: table
198 TableScan: table
199 ")
200 }
201
 // A `union_distinct` followed by two plain `union` calls: the snapshot still
 // contains a `Distinct:` entry and a second `Union` entry after the first
 // `Union`, plus four `TableScan: table` entries.
202 #[test]
203 fn eliminate_nested_union_with_distinct_union() -> Result<()> {
204 let plan_builder = table_scan(Some("table"), &schema(), None)?;
205
206 let plan = plan_builder
207 .clone()
208 .union_distinct(plan_builder.clone().build()?)?
209 .union(plan_builder.clone().build()?)?
210 .union(plan_builder.build()?)?
211 .build()?;
212
213 assert_optimized_plan_equal!(plan, @r"
214 Union
215 Distinct:
216 Union
217 TableScan: table
218 TableScan: table
219 TableScan: table
220 TableScan: table
221 ")
222 }
223
 // Alternating `union` / `union_distinct` calls ending in `union_distinct`:
 // the snapshot lists one `Distinct:`, one `Union` and five
 // `TableScan: table` entries.
224 #[test]
225 fn eliminate_nested_distinct_union() -> Result<()> {
226 let plan_builder = table_scan(Some("table"), &schema(), None)?;
227
228 let plan = plan_builder
229 .clone()
230 .union(plan_builder.clone().build()?)?
231 .union_distinct(plan_builder.clone().build()?)?
232 .union(plan_builder.clone().build()?)?
233 .union_distinct(plan_builder.build()?)?
234 .build()?;
235
236 assert_optimized_plan_equal!(plan, @r"
237 Distinct:
238 Union
239 TableScan: table
240 TableScan: table
241 TableScan: table
242 TableScan: table
243 TableScan: table
244 ")
245 }
246
 // Two of the inputs are built with `.distinct()`: the snapshot lists only
 // one `Distinct:` entry, one `Union` and four `TableScan: table` entries.
247 #[test]
248 fn eliminate_nested_distinct_union_with_distinct_table() -> Result<()> {
249 let plan_builder = table_scan(Some("table"), &schema(), None)?;
250
251 let plan = plan_builder
252 .clone()
253 .union_distinct(plan_builder.clone().distinct()?.build()?)?
254 .union(plan_builder.clone().distinct()?.build()?)?
255 .union_distinct(plan_builder.build()?)?
256 .build()?;
257
258 assert_optimized_plan_equal!(plan, @r"
259 Distinct:
260 Union
261 TableScan: table
262 TableScan: table
263 TableScan: table
264 TableScan: table
265 ")
266 }
267
 // Two inputs project `id` under the aliases `table_id` and `_id`; in the
 // snapshot both projections read `table.id AS id`, under a single `Union`.
268 #[test]
271 fn eliminate_nested_union_with_projection() -> Result<()> {
272 let plan_builder = table_scan(Some("table"), &schema(), None)?;
273
274 let plan = plan_builder
275 .clone()
276 .union(
277 plan_builder
278 .clone()
279 .project(vec![col("id").alias("table_id"), col("key"), col("value")])?
280 .build()?,
281 )?
282 .union(
283 plan_builder
284 .project(vec![col("id").alias("_id"), col("key"), col("value")])?
285 .build()?,
286 )?
287 .build()?;
288
289 assert_optimized_plan_equal!(plan, @r"
290 Union
291 TableScan: table
292 Projection: table.id AS id, table.key, table.value
293 TableScan: table
294 Projection: table.id AS id, table.key, table.value
295 TableScan: table
296 ")
297 }
298
 // Same aliased projections as the previous test, combined with
 // `union_distinct`: the snapshot adds a single leading `Distinct:` entry.
299 #[test]
300 fn eliminate_nested_distinct_union_with_projection() -> Result<()> {
301 let plan_builder = table_scan(Some("table"), &schema(), None)?;
302
303 let plan = plan_builder
304 .clone()
305 .union_distinct(
306 plan_builder
307 .clone()
308 .project(vec![col("id").alias("table_id"), col("key"), col("value")])?
309 .build()?,
310 )?
311 .union_distinct(
312 plan_builder
313 .project(vec![col("id").alias("_id"), col("key"), col("value")])?
314 .build()?,
315 )?
316 .build()?;
317
318 assert_optimized_plan_equal!(plan, @r"
319 Distinct:
320 Union
321 TableScan: table
322 Projection: table.id AS id, table.key, table.value
323 TableScan: table
324 Projection: table.id AS id, table.key, table.value
325 TableScan: table
326 ")
327 }
328
 // Three scans whose `id` / `value` types differ (Int64/Float64, Int32/Float32,
 // Int16/Float32) are unioned; the snapshot shows two projections casting
 // `id` to Int64 and `value` to Float64.
 // NOTE(review): all three scans are created with the name "table_1" even
 // though the variables are `table_1`..`table_3`; the snapshot references
 // `table_1` throughout, so this is presumably intentional — confirm.
329 #[test]
330 fn eliminate_nested_union_with_type_cast_projection() -> Result<()> {
331 let table_1 = table_scan(
332 Some("table_1"),
333 &Schema::new(vec![
334 Field::new("id", DataType::Int64, false),
335 Field::new("key", DataType::Utf8, false),
336 Field::new("value", DataType::Float64, false),
337 ]),
338 None,
339 )?;
340
341 let table_2 = table_scan(
342 Some("table_1"),
343 &Schema::new(vec![
344 Field::new("id", DataType::Int32, false),
345 Field::new("key", DataType::Utf8, false),
346 Field::new("value", DataType::Float32, false),
347 ]),
348 None,
349 )?;
350
351 let table_3 = table_scan(
352 Some("table_1"),
353 &Schema::new(vec![
354 Field::new("id", DataType::Int16, false),
355 Field::new("key", DataType::Utf8, false),
356 Field::new("value", DataType::Float32, false),
357 ]),
358 None,
359 )?;
360
361 let plan = table_1
362 .union(table_2.build()?)?
363 .union(table_3.build()?)?
364 .build()?;
365
366 assert_optimized_plan_equal!(plan, @r"
367 Union
368 TableScan: table_1
369 Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value
370 TableScan: table_1
371 Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value
372 TableScan: table_1
373 ")
374 }
375
 // `union_distinct` variant of the type-cast test above: same three scans and
 // the same cast projections in the snapshot, with a leading `Distinct:` entry.
376 #[test]
377 fn eliminate_nested_distinct_union_with_type_cast_projection() -> Result<()> {
378 let table_1 = table_scan(
379 Some("table_1"),
380 &Schema::new(vec![
381 Field::new("id", DataType::Int64, false),
382 Field::new("key", DataType::Utf8, false),
383 Field::new("value", DataType::Float64, false),
384 ]),
385 None,
386 )?;
387
388 let table_2 = table_scan(
389 Some("table_1"),
390 &Schema::new(vec![
391 Field::new("id", DataType::Int32, false),
392 Field::new("key", DataType::Utf8, false),
393 Field::new("value", DataType::Float32, false),
394 ]),
395 None,
396 )?;
397
398 let table_3 = table_scan(
399 Some("table_1"),
400 &Schema::new(vec![
401 Field::new("id", DataType::Int16, false),
402 Field::new("key", DataType::Utf8, false),
403 Field::new("value", DataType::Float32, false),
404 ]),
405 None,
406 )?;
407
408 let plan = table_1
409 .union_distinct(table_2.build()?)?
410 .union_distinct(table_3.build()?)?
411 .build()?;
412
413 assert_optimized_plan_equal!(plan, @r"
414 Distinct:
415 Union
416 TableScan: table_1
417 Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value
418 TableScan: table_1
419 Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value
420 TableScan: table_1
421 ")
422 }
423}