1use crate::optimizer::ApplyOrder;
20use crate::{OptimizerConfig, OptimizerRule};
21use datafusion_common::Result;
22use datafusion_common::tree_node::Transformed;
23use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema;
24use datafusion_expr::{Distinct, LogicalPlan, Projection, Union};
25use itertools::Itertools;
26use std::sync::Arc;
27
/// Optimizer rule that simplifies `UNION` plans.
///
/// Applied bottom-up, it:
/// * replaces a union that has a single input with that input,
/// * splices the inputs of nested unions (including unions sitting directly
///   under a projection) into the enclosing union, coercing each of them to
///   the enclosing union's schema, and
/// * for a `DISTINCT` union, additionally drops `DISTINCT` wrappers from the
///   union's inputs, since the outer distinct already removes duplicates.
#[derive(Debug, Default)]
pub struct OptimizeUnions;
33
34impl OptimizeUnions {
35 #[expect(missing_docs)]
36 pub fn new() -> Self {
37 Self {}
38 }
39}
40
41impl OptimizerRule for OptimizeUnions {
42 fn name(&self) -> &str {
43 "optimize_unions"
44 }
45
46 fn apply_order(&self) -> Option<ApplyOrder> {
47 Some(ApplyOrder::BottomUp)
48 }
49
50 fn supports_rewrite(&self) -> bool {
51 true
52 }
53
54 fn rewrite(
55 &self,
56 plan: LogicalPlan,
57 _config: &dyn OptimizerConfig,
58 ) -> Result<Transformed<LogicalPlan>> {
59 match plan {
60 LogicalPlan::Union(Union { mut inputs, .. }) if inputs.len() == 1 => Ok(
61 Transformed::yes(Arc::unwrap_or_clone(inputs.pop().unwrap())),
62 ),
63 LogicalPlan::Union(Union { inputs, schema }) => {
64 let inputs = inputs
65 .into_iter()
66 .flat_map(extract_plans_from_union)
67 .map(|plan| coerce_plan_expr_for_schema(plan, &schema))
68 .collect::<Result<Vec<_>>>()?;
69
70 Ok(Transformed::yes(LogicalPlan::Union(Union {
71 inputs: inputs.into_iter().map(Arc::new).collect_vec(),
72 schema,
73 })))
74 }
75 LogicalPlan::Distinct(Distinct::All(nested_plan)) => {
76 match Arc::unwrap_or_clone(nested_plan) {
77 LogicalPlan::Union(Union { inputs, schema }) => {
78 let inputs = inputs
79 .into_iter()
80 .map(extract_plan_from_distinct)
81 .flat_map(extract_plans_from_union)
82 .map(|plan| coerce_plan_expr_for_schema(plan, &schema))
83 .collect::<Result<Vec<_>>>()?;
84
85 Ok(Transformed::yes(LogicalPlan::Distinct(Distinct::All(
86 Arc::new(LogicalPlan::Union(Union {
87 inputs: inputs.into_iter().map(Arc::new).collect_vec(),
88 schema: Arc::clone(&schema),
89 })),
90 ))))
91 }
92 nested_plan => Ok(Transformed::no(LogicalPlan::Distinct(
93 Distinct::All(Arc::new(nested_plan)),
94 ))),
95 }
96 }
97 _ => Ok(Transformed::no(plan)),
98 }
99 }
100}
101
102fn extract_plans_from_union(plan: Arc<LogicalPlan>) -> Vec<LogicalPlan> {
103 match Arc::unwrap_or_clone(plan) {
104 LogicalPlan::Union(Union { inputs, .. }) => inputs
105 .into_iter()
106 .map(Arc::unwrap_or_clone)
107 .collect::<Vec<_>>(),
108 LogicalPlan::Projection(Projection {
116 expr,
117 input,
118 schema,
119 ..
120 }) => match Arc::unwrap_or_clone(input) {
121 LogicalPlan::Union(Union { inputs, .. }) => inputs
122 .into_iter()
123 .map(Arc::unwrap_or_clone)
124 .map(|plan| {
125 LogicalPlan::Projection(
126 Projection::try_new_with_schema(
127 expr.clone(),
128 Arc::new(plan),
129 Arc::clone(&schema),
130 )
131 .unwrap(),
132 )
133 })
134 .collect::<Vec<_>>(),
135
136 plan => vec![LogicalPlan::Projection(
137 Projection::try_new_with_schema(expr, Arc::new(plan), schema).unwrap(),
138 )],
139 },
140 plan => vec![plan],
141 }
142}
143
144fn extract_plan_from_distinct(plan: Arc<LogicalPlan>) -> Arc<LogicalPlan> {
145 match Arc::unwrap_or_clone(plan) {
146 LogicalPlan::Distinct(Distinct::All(plan)) => plan,
147 plan => Arc::new(plan),
148 }
149}
150
#[cfg(test)]
mod tests {
    //! Snapshot tests for the `OptimizeUnions` rule. Each test builds a plan,
    //! runs a single optimizer pass and compares the indented plan display
    //! against an inline snapshot.

    use super::*;
    use crate::OptimizerContext;
    use crate::analyzer::Analyzer;
    use crate::analyzer::type_coercion::TypeCoercion;
    use crate::assert_optimized_plan_eq_snapshot;
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::config::ConfigOptions;
    use datafusion_expr::{col, logical_plan::table_scan};

    /// Schema of the `table` scans used by most tests.
    fn schema() -> Schema {
        Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("key", DataType::Utf8, false),
            Field::new("value", DataType::Float64, false),
        ])
    }

    /// Runs the `TypeCoercion` analyzer on `$plan`, then a single pass of
    /// `OptimizeUnions`, and compares the result with the inline snapshot.
    macro_rules! assert_optimized_plan_equal {
        (
            $plan:expr,
            @ $expected:literal $(,)?
        ) => {{
            let options = ConfigOptions::default();
            let analyzed_plan = Analyzer::with_rules(vec![Arc::new(TypeCoercion::new())])
                .execute_and_check($plan, &options, |_, _| {})?;
            let optimizer_ctx = OptimizerContext::new().with_max_passes(1);
            let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> = vec![Arc::new(OptimizeUnions::new())];
            assert_optimized_plan_eq_snapshot!(
                optimizer_ctx,
                rules,
                analyzed_plan,
                @ $expected,
            )
        }};
    }

    // A union of two plain scans has nothing to flatten.
    #[test]
    fn eliminate_nothing() -> Result<()> {
        let plan_builder = table_scan(Some("table"), &schema(), None)?;

        let plan = plan_builder.clone().union(plan_builder.build()?)?.build()?;

        assert_optimized_plan_equal!(plan, @r"
        Union
          TableScan: table
          TableScan: table
        ")
    }

    // A distinct union of two plain scans has nothing to flatten either.
    #[test]
    fn eliminate_distinct_nothing() -> Result<()> {
        let plan_builder = table_scan(Some("table"), &schema(), None)?;

        let plan = plan_builder
            .clone()
            .union_distinct(plan_builder.build()?)?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Distinct:
          Union
            TableScan: table
            TableScan: table
        ")
    }

    // Chained `union` calls nest unions; they collapse into a single one.
    #[test]
    fn eliminate_nested_union() -> Result<()> {
        let plan_builder = table_scan(Some("table"), &schema(), None)?;

        let plan = plan_builder
            .clone()
            .union(plan_builder.clone().build()?)?
            .union(plan_builder.clone().build()?)?
            .union(plan_builder.build()?)?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Union
          TableScan: table
          TableScan: table
          TableScan: table
          TableScan: table
        ")
    }

    // A distinct union inside a plain union keeps its `Distinct`; only the
    // plain unions around it are merged.
    #[test]
    fn eliminate_nested_union_with_distinct_union() -> Result<()> {
        let plan_builder = table_scan(Some("table"), &schema(), None)?;

        let plan = plan_builder
            .clone()
            .union_distinct(plan_builder.clone().build()?)?
            .union(plan_builder.clone().build()?)?
            .union(plan_builder.build()?)?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Union
          Distinct:
            Union
              TableScan: table
              TableScan: table
          TableScan: table
          TableScan: table
        ")
    }

    // Under an outer distinct union, inner distincts and unions are all
    // merged into one union.
    #[test]
    fn eliminate_nested_distinct_union() -> Result<()> {
        let plan_builder = table_scan(Some("table"), &schema(), None)?;

        let plan = plan_builder
            .clone()
            .union(plan_builder.clone().build()?)?
            .union_distinct(plan_builder.clone().build()?)?
            .union(plan_builder.clone().build()?)?
            .union_distinct(plan_builder.build()?)?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Distinct:
          Union
            TableScan: table
            TableScan: table
            TableScan: table
            TableScan: table
            TableScan: table
        ")
    }

    // `Distinct` directly on a union input is dropped when the union itself
    // is distinct.
    #[test]
    fn eliminate_nested_distinct_union_with_distinct_table() -> Result<()> {
        let plan_builder = table_scan(Some("table"), &schema(), None)?;

        let plan = plan_builder
            .clone()
            .union_distinct(plan_builder.clone().distinct()?.build()?)?
            .union(plan_builder.clone().distinct()?.build()?)?
            .union_distinct(plan_builder.build()?)?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Distinct:
          Union
            TableScan: table
            TableScan: table
            TableScan: table
            TableScan: table
        ")
    }

    // Aliased projections on union inputs are re-aliased to the union's
    // column names (`table_id`/`_id` become `id`) by schema coercion.
    #[test]
    fn eliminate_nested_union_with_projection() -> Result<()> {
        let plan_builder = table_scan(Some("table"), &schema(), None)?;

        let plan = plan_builder
            .clone()
            .union(
                plan_builder
                    .clone()
                    .project(vec![col("id").alias("table_id"), col("key"), col("value")])?
                    .build()?,
            )?
            .union(
                plan_builder
                    .project(vec![col("id").alias("_id"), col("key"), col("value")])?
                    .build()?,
            )?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Union
          TableScan: table
          Projection: table.id AS id, table.key, table.value
            TableScan: table
          Projection: table.id AS id, table.key, table.value
            TableScan: table
        ")
    }

    // Same as above, under a distinct union.
    #[test]
    fn eliminate_nested_distinct_union_with_projection() -> Result<()> {
        let plan_builder = table_scan(Some("table"), &schema(), None)?;

        let plan = plan_builder
            .clone()
            .union_distinct(
                plan_builder
                    .clone()
                    .project(vec![col("id").alias("table_id"), col("key"), col("value")])?
                    .build()?,
            )?
            .union_distinct(
                plan_builder
                    .project(vec![col("id").alias("_id"), col("key"), col("value")])?
                    .build()?,
            )?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Distinct:
          Union
            TableScan: table
            Projection: table.id AS id, table.key, table.value
              TableScan: table
            Projection: table.id AS id, table.key, table.value
              TableScan: table
        ")
    }

    // A projection over a nested union is pushed into each of that union's
    // inputs, so the inner union can be merged into the outer one.
    #[test]
    fn eliminate_nested_union_in_projection() -> Result<()> {
        let plan_builder = table_scan(Some("table"), &schema(), None)?;

        let plan = plan_builder
            .clone()
            .union(plan_builder.clone().build()?)?
            .project(vec![col("id").alias("table_id"), col("key"), col("value")])?
            .union(plan_builder.build()?)?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Union
          Projection: id AS table_id, key, value
            TableScan: table
          Projection: id AS table_id, key, value
            TableScan: table
          TableScan: table
        ")
    }

    // Inputs with narrower column types are cast to the union's column
    // types (Int64 / Float64 here) after flattening.
    #[test]
    fn eliminate_nested_union_with_type_cast_projection() -> Result<()> {
        let table_1 = table_scan(
            Some("table_1"),
            &Schema::new(vec![
                Field::new("id", DataType::Int64, false),
                Field::new("key", DataType::Utf8, false),
                Field::new("value", DataType::Float64, false),
            ]),
            None,
        )?;

        let table_2 = table_scan(
            Some("table_1"),
            &Schema::new(vec![
                Field::new("id", DataType::Int32, false),
                Field::new("key", DataType::Utf8, false),
                Field::new("value", DataType::Float32, false),
            ]),
            None,
        )?;

        let table_3 = table_scan(
            Some("table_1"),
            &Schema::new(vec![
                Field::new("id", DataType::Int16, false),
                Field::new("key", DataType::Utf8, false),
                Field::new("value", DataType::Float32, false),
            ]),
            None,
        )?;

        let plan = table_1
            .union(table_2.build()?)?
            .union(table_3.build()?)?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Union
          TableScan: table_1
          Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value
            TableScan: table_1
          Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value
            TableScan: table_1
        ")
    }

    // Same as above, under a distinct union.
    #[test]
    fn eliminate_nested_distinct_union_with_type_cast_projection() -> Result<()> {
        let table_1 = table_scan(
            Some("table_1"),
            &Schema::new(vec![
                Field::new("id", DataType::Int64, false),
                Field::new("key", DataType::Utf8, false),
                Field::new("value", DataType::Float64, false),
            ]),
            None,
        )?;

        let table_2 = table_scan(
            Some("table_1"),
            &Schema::new(vec![
                Field::new("id", DataType::Int32, false),
                Field::new("key", DataType::Utf8, false),
                Field::new("value", DataType::Float32, false),
            ]),
            None,
        )?;

        let table_3 = table_scan(
            Some("table_1"),
            &Schema::new(vec![
                Field::new("id", DataType::Int16, false),
                Field::new("key", DataType::Utf8, false),
                Field::new("value", DataType::Float32, false),
            ]),
            None,
        )?;

        let plan = table_1
            .union_distinct(table_2.build()?)?
            .union_distinct(table_3.build()?)?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Distinct:
          Union
            TableScan: table_1
            Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value
              TableScan: table_1
            Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value
              TableScan: table_1
        ")
    }

    // A single-input union is replaced by its input. The union is built by
    // hand and passed straight to the optimizer, skipping the analyzer step
    // used by `assert_optimized_plan_equal!`.
    #[test]
    fn eliminate_one_union() -> Result<()> {
        let plan = table_scan(Some("table"), &schema(), None)?.build()?;
        let schema = Arc::clone(plan.schema());
        let plan = LogicalPlan::Union(Union {
            inputs: vec![Arc::new(plan)],
            schema,
        });

        assert_optimized_plan_eq_snapshot!(
            OptimizerContext::new().with_max_passes(1),
            vec![Arc::new(OptimizeUnions::new())],
            plan,
            @"TableScan: table"
        )
    }
}