1use crate::{OptimizerConfig, OptimizerRule};
22use datafusion_common::Result;
23use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
24use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema;
25use datafusion_expr::logical_plan::builder::LogicalPlanBuilder;
26use datafusion_expr::utils::disjunction;
27use datafusion_expr::{
28 Distinct, Expr, Filter, LogicalPlan, Projection, SubqueryAlias, Union,
29};
30use log::debug;
31use std::sync::Arc;
32
/// Optimizer rule that looks for a `DISTINCT` placed over a `UNION` and, when
/// several union branches read the same source, folds those branches into a
/// single branch whose filter is the `OR` of the individual branch predicates.
///
/// The rule only acts when the `enable_unions_to_filter` optimizer option is
/// switched on.
#[derive(Debug, Default)]
pub struct UnionsToFilter;
35
36impl UnionsToFilter {
37 #[expect(missing_docs)]
38 pub fn new() -> Self {
39 Self
40 }
41}
42
43impl OptimizerRule for UnionsToFilter {
44 fn name(&self) -> &str {
45 "unions_to_filter"
46 }
47
48 fn supports_rewrite(&self) -> bool {
49 true
50 }
51
52 fn rewrite(
53 &self,
54 plan: LogicalPlan,
55 config: &dyn OptimizerConfig,
56 ) -> Result<Transformed<LogicalPlan>> {
57 if !config.options().optimizer.enable_unions_to_filter {
58 return Ok(Transformed::no(plan));
59 }
60
61 if !plan.exists(|p| Ok(matches!(p, LogicalPlan::Distinct(Distinct::All(_)))))? {
66 return Ok(Transformed::no(plan));
67 }
68
69 plan.rewrite_with_subqueries(&mut UnionsToFilterRewriter)
70 }
71}
72
/// Stateless tree rewriter used by [`UnionsToFilter`]; its `f_up` hook tries
/// to rewrite every `Distinct::All` node it visits.
struct UnionsToFilterRewriter;
74
75impl TreeNodeRewriter for UnionsToFilterRewriter {
76 type Node = LogicalPlan;
77
78 fn f_up(&mut self, plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
79 match &plan {
80 LogicalPlan::Distinct(Distinct::All(input)) => {
81 match try_rewrite_distinct_union(input.as_ref().clone())? {
82 Some(rewritten) => Ok(Transformed::yes(rewritten)),
83 None => Ok(Transformed::no(plan)),
84 }
85 }
86 _ => Ok(Transformed::no(plan)),
87 }
88 }
89}
90
91fn try_rewrite_distinct_union(plan: LogicalPlan) -> Result<Option<LogicalPlan>> {
92 let LogicalPlan::Union(Union { inputs, schema }) = plan else {
93 debug!("unions_to_filter skipped: input is not a UNION");
94 return Ok(None);
95 };
96
97 if inputs.len() < 2 {
98 debug!(
99 "unions_to_filter skipped: UNION has {} input(s), need at least 2",
100 inputs.len()
101 );
102 return Ok(None);
103 }
104
105 let mut grouped: Vec<(GroupKey, Vec<Expr>)> = Vec::new();
109 let mut transformed = false;
110
111 for input in inputs {
112 let Some(branch) = extract_branch(Arc::unwrap_or_clone(input))? else {
113 return Ok(None);
114 };
115
116 let key = GroupKey {
117 source: branch.source,
118 wrappers: branch.wrappers,
119 };
120 if let Some((_, conds)) = grouped.iter_mut().find(|(k, _)| k == &key) {
121 conds.push(branch.predicate);
122 transformed = true;
123 } else {
124 grouped.push((key, vec![branch.predicate]));
125 }
126 }
127
128 if !transformed {
129 debug!("unions_to_filter skipped: no branch groups could be merged");
130 return Ok(None);
131 }
132
133 let mut builder: Option<LogicalPlanBuilder> = None;
134 for (key, predicates) in grouped {
135 let combined =
136 disjunction(predicates).expect("union branches always provide predicates");
137 let branch = LogicalPlanBuilder::from(key.source)
138 .filter(combined)?
139 .build()?;
140 let branch = wrap_branch(branch, &key.wrappers)?;
141 let branch = coerce_plan_expr_for_schema(branch, &schema)?;
142 let branch = align_plan_to_schema(branch, Arc::clone(&schema))?;
143 builder = Some(match builder {
144 None => LogicalPlanBuilder::from(branch),
145 Some(builder) => builder.union(branch)?,
146 });
147 }
148
149 let union = builder
150 .expect("at least one branch after rewrite")
151 .build()?;
152 Ok(Some(LogicalPlan::Distinct(Distinct::All(Arc::new(union)))))
153}
154
/// One union input, decomposed by `extract_branch`.
struct UnionBranch {
    /// Plan the branch reads from, after the wrappers and the optional
    /// filter have been removed.
    source: LogicalPlan,
    /// Filter predicate of the branch; a literal `true` when the branch had
    /// no filter directly below its wrappers.
    predicate: Expr,
    /// Projection / alias nodes that sat above the filter, as returned by
    /// `peel_wrappers`.
    wrappers: Vec<Wrapper>,
}
160
161fn extract_branch(plan: LogicalPlan) -> Result<Option<UnionBranch>> {
162 let (wrappers, plan) = peel_wrappers(plan);
163
164 if !wrapper_projections_are_safe(&wrappers) {
169 debug!(
170 "unions_to_filter skipped: projection wrapper contains volatile expression or subquery"
171 );
172 return Ok(None);
173 }
174
175 match plan {
176 LogicalPlan::Filter(Filter {
177 predicate, input, ..
178 }) => {
179 if !is_mergeable_predicate(&predicate) {
180 debug!(
181 "unions_to_filter skipped: branch predicate contains volatility or a subquery"
182 );
183 return Ok(None);
184 }
185 Ok(Some(UnionBranch {
186 source: strip_passthrough_nodes(Arc::unwrap_or_clone(input)),
187 predicate,
188 wrappers,
189 }))
190 }
191 LogicalPlan::Limit(_) => {
196 debug!("unions_to_filter skipped: branch contains LIMIT");
197 Ok(None)
198 }
199 LogicalPlan::Sort(_) => {
200 debug!("unions_to_filter skipped: branch contains ORDER BY / SORT");
201 Ok(None)
202 }
203 other => Ok(Some(UnionBranch {
204 source: strip_passthrough_nodes(other),
205 predicate: Expr::Literal(
206 datafusion_common::ScalarValue::Boolean(Some(true)),
207 None,
208 ),
209 wrappers,
210 })),
211 }
212}
213
/// Identity of a merge group in `try_rewrite_distinct_union`: two union
/// branches are merged only when both fields compare equal.
#[derive(Debug, Clone, PartialEq, Eq)]
struct GroupKey {
    /// Plan the branch filter reads from.
    source: LogicalPlan,
    /// Wrapper nodes that sat above the branch filter.
    wrappers: Vec<Wrapper>,
}
219
/// A node removed from the top of a union branch by `peel_wrappers` and put
/// back by `wrap_branch` after the branch filters have been merged.
#[derive(Debug, Clone, PartialEq, Eq)]
enum Wrapper {
    /// A `Projection` node: its expressions and its output schema.
    Projection {
        expr: Vec<Expr>,
        schema: datafusion_common::DFSchemaRef,
    },
    /// A `SubqueryAlias` node: its alias and its output schema. The schema
    /// takes part in equality but is not reused by `wrap_branch`, which
    /// rebuilds the alias node from the alias alone.
    SubqueryAlias {
        alias: datafusion_common::TableReference,
        schema: datafusion_common::DFSchemaRef,
    },
}
231
232fn peel_wrappers(mut plan: LogicalPlan) -> (Vec<Wrapper>, LogicalPlan) {
233 let mut wrappers = vec![];
234 loop {
235 match plan {
236 LogicalPlan::Projection(Projection {
237 expr,
238 input,
239 schema,
240 ..
241 }) => {
242 wrappers.push(Wrapper::Projection { expr, schema });
243 plan = Arc::unwrap_or_clone(input);
244 }
245 LogicalPlan::SubqueryAlias(SubqueryAlias {
246 input,
247 alias,
248 schema,
249 ..
250 }) => {
251 wrappers.push(Wrapper::SubqueryAlias { alias, schema });
252 plan = Arc::unwrap_or_clone(input);
253 }
254 other => return (wrappers, other),
255 }
256 }
257}
258
259fn wrap_branch(mut plan: LogicalPlan, wrappers: &[Wrapper]) -> Result<LogicalPlan> {
260 for wrapper in wrappers.iter().rev() {
261 plan = match wrapper {
262 Wrapper::Projection { expr, schema } => {
263 LogicalPlan::Projection(Projection::try_new_with_schema(
264 expr.clone(),
265 Arc::new(plan),
266 Arc::clone(schema),
267 )?)
268 }
269 Wrapper::SubqueryAlias { alias, .. } => LogicalPlan::SubqueryAlias(
274 SubqueryAlias::try_new(Arc::new(plan), alias.clone())?,
275 ),
276 };
277 }
278 Ok(plan)
279}
280
281fn strip_passthrough_nodes(mut plan: LogicalPlan) -> LogicalPlan {
282 loop {
283 plan = match plan {
284 LogicalPlan::Projection(Projection { input, .. }) => {
285 Arc::unwrap_or_clone(input)
286 }
287 LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. }) => {
288 Arc::unwrap_or_clone(input)
289 }
290 other => return other,
291 };
292 }
293}
294
295fn align_plan_to_schema(
296 plan: LogicalPlan,
297 schema: datafusion_common::DFSchemaRef,
298) -> Result<LogicalPlan> {
299 if plan.schema() == &schema {
300 return Ok(plan);
301 }
302
303 let expr = plan
304 .schema()
305 .iter()
306 .enumerate()
307 .map(|(i, _)| {
308 Expr::Column(datafusion_common::Column::from(
309 plan.schema().qualified_field(i),
310 ))
311 })
312 .collect::<Vec<_>>();
313
314 Ok(LogicalPlan::Projection(Projection::try_new_with_schema(
315 expr,
316 Arc::new(plan),
317 schema,
318 )?))
319}
320
321fn is_mergeable_predicate(expr: &Expr) -> bool {
322 !expr.is_volatile() && !expr_contains_subquery(expr)
323}
324
325fn wrapper_projections_are_safe(wrappers: &[Wrapper]) -> bool {
333 wrappers.iter().all(|w| match w {
334 Wrapper::Projection { expr, .. } => expr
335 .iter()
336 .all(|e| !e.is_volatile() && !expr_contains_subquery(e)),
337 Wrapper::SubqueryAlias { .. } => true,
338 })
339}
340
341fn expr_contains_subquery(expr: &Expr) -> bool {
342 expr.exists(|e| match e {
343 Expr::ScalarSubquery(_) | Expr::Exists(_) | Expr::InSubquery(_) => Ok(true),
344 _ => Ok(false),
345 })
346 .expect("boolean expression walk is infallible")
347}
348
#[cfg(test)]
mod tests {
    use super::*;
    use crate::OptimizerContext;
    use crate::assert_optimized_plan_eq_snapshot;
    use crate::test::test_table_scan_with_name;
    use arrow::datatypes::DataType;
    use datafusion_common::Result;
    use datafusion_expr::{
        ColumnarValue, Expr, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
        Volatility, col, lit,
    };

    /// Optimizer context with the rule switched on and a single pass.
    fn enabled_context() -> OptimizerContext {
        let mut options = datafusion_common::config::ConfigOptions::default();
        options.optimizer.enable_unions_to_filter = true;
        OptimizerContext::new_with_config_options(Arc::new(options)).with_max_passes(1)
    }

    /// Rule list containing only the rule under test.
    fn rule_set() -> Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> {
        vec![Arc::new(UnionsToFilter::new())]
    }

    /// Optimizes `$plan` with only `UnionsToFilter` enabled and compares the
    /// result against the inline snapshot.
    macro_rules! assert_optimized_plan_equal {
        (
            $plan:expr,
            @ $expected:literal $(,)?
        ) => {{
            let optimizer_ctx = enabled_context();
            let rules = rule_set();
            assert_optimized_plan_eq_snapshot!(
                optimizer_ctx,
                rules,
                $plan,
                @ $expected,
            )
        }};
    }

    /// Nullary UDF declared volatile; only ever planned, never executed.
    #[derive(Debug, PartialEq, Eq, Hash)]
    struct VolatileTestUdf;

    impl ScalarUDFImpl for VolatileTestUdf {
        fn name(&self) -> &str {
            "volatile_test"
        }

        fn signature(&self) -> &Signature {
            static SIGNATURE: std::sync::LazyLock<Signature> =
                std::sync::LazyLock::new(|| Signature::nullary(Volatility::Volatile));
            &SIGNATURE
        }

        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
            Ok(DataType::Float64)
        }

        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
            panic!("VolatileTestUdf is not intended for execution")
        }
    }

    /// Call expression for the volatile test UDF.
    fn volatile_expr() -> Expr {
        ScalarUDF::new_from_impl(VolatileTestUdf).call(vec![])
    }

    /// `Filter(predicate)` over a scan of the test table `table`.
    fn filtered_scan(table: &str, predicate: Expr) -> Result<LogicalPlan> {
        LogicalPlanBuilder::from(test_table_scan_with_name(table)?)
            .filter(predicate)?
            .build()
    }

    /// `left UNION DISTINCT right`.
    fn distinct_union(left: LogicalPlan, right: LogicalPlan) -> Result<LogicalPlan> {
        LogicalPlanBuilder::from(left)
            .union_distinct(right)?
            .build()
    }

    /// Projection list `a AS mgr, b AS comm` shared by the `emp` tests.
    fn emp_aliases() -> Vec<Expr> {
        vec![col("a").alias("mgr"), col("b").alias("comm")]
    }

    #[test]
    fn rewrite_union_distinct_same_source_filters() -> Result<()> {
        let left = filtered_scan("t", col("a").eq(lit(1)))?;
        let right = filtered_scan("t", col("a").eq(lit(2)))?;
        let plan = distinct_union(left, right)?;

        assert_optimized_plan_equal!(plan, @r"
        Distinct:
          Projection: t.a, t.b, t.c
            Filter: t.a = Int32(1) OR t.a = Int32(2)
              TableScan: t
        ")?;
        Ok(())
    }

    #[test]
    fn keep_union_distinct_different_sources() -> Result<()> {
        let left = filtered_scan("t1", col("a").eq(lit(1)))?;
        let right = filtered_scan("t2", col("a").eq(lit(2)))?;
        let plan = distinct_union(left, right)?;

        assert_optimized_plan_equal!(plan, @r"
        Distinct:
          Union
            Filter: t1.a = Int32(1)
              TableScan: t1
            Filter: t2.a = Int32(2)
              TableScan: t2
        ")?;
        Ok(())
    }

    #[test]
    fn keep_union_distinct_with_volatile_predicate() -> Result<()> {
        let left = filtered_scan("t", volatile_expr().gt(lit(0.5_f64)))?;
        let right = filtered_scan("t", col("a").eq(lit(2)))?;
        let plan = distinct_union(left, right)?;

        assert_optimized_plan_equal!(plan, @r"
        Distinct:
          Union
            Filter: volatile_test() > Float64(0.5)
              TableScan: t
            Filter: t.a = Int32(2)
              TableScan: t
        ")?;
        Ok(())
    }

    #[test]
    fn rewrite_union_distinct_with_matching_projection_prefix() -> Result<()> {
        let left = LogicalPlanBuilder::from(test_table_scan_with_name("emp")?)
            .project(emp_aliases())?
            .build()?;
        let right = LogicalPlanBuilder::from(test_table_scan_with_name("emp")?)
            .filter(col("b").eq(lit(5)))?
            .project(emp_aliases())?
            .build()?;
        let plan = distinct_union(left, right)?;

        assert_optimized_plan_equal!(plan, @r"
        Distinct:
          Projection: emp.a AS mgr, emp.b AS comm
            Filter: Boolean(true) OR emp.b = Int32(5)
              TableScan: emp
        ")?;
        Ok(())
    }

    #[test]
    fn keep_union_distinct_with_volatile_projection() -> Result<()> {
        let left = LogicalPlanBuilder::from(test_table_scan_with_name("t")?)
            .filter(col("a").eq(lit(1)))?
            .project(vec![volatile_expr().alias("v"), col("a")])?
            .build()?;
        let right = LogicalPlanBuilder::from(test_table_scan_with_name("t")?)
            .filter(col("a").eq(lit(2)))?
            .project(vec![volatile_expr().alias("v"), col("a")])?
            .build()?;
        let plan = distinct_union(left, right)?;

        assert_optimized_plan_equal!(plan, @r"
        Distinct:
          Union
            Projection: volatile_test() AS v, t.a
              Filter: t.a = Int32(1)
                TableScan: t
            Projection: volatile_test() AS v, t.a
              Filter: t.a = Int32(2)
                TableScan: t
        ")?;
        Ok(())
    }

    #[test]
    fn keep_union_distinct_with_subquery_in_projection() -> Result<()> {
        use datafusion_expr::scalar_subquery;

        let t2 = test_table_scan_with_name("t2")?;
        let subquery_plan = Arc::new(
            LogicalPlanBuilder::from(t2)
                .filter(col("t2.a").eq(col("t.a")))?
                .project(vec![col("t2.b")])?
                .build()?,
        );
        let sq = scalar_subquery(subquery_plan);

        let left = LogicalPlanBuilder::from(test_table_scan_with_name("t")?)
            .filter(col("a").eq(lit(1)))?
            .project(vec![sq.clone().alias("sub"), col("a")])?
            .build()?;
        let right = LogicalPlanBuilder::from(test_table_scan_with_name("t")?)
            .filter(col("a").eq(lit(2)))?
            .project(vec![sq.alias("sub"), col("a")])?
            .build()?;
        let plan = distinct_union(left, right)?;

        // Checked structurally rather than with a snapshot: only the shape
        // `Distinct(Union(..))` matters here.
        let optimizer_ctx = enabled_context();
        let optimized =
            crate::Optimizer::with_rules(rule_set()).optimize(plan, &optimizer_ctx, |_, _| {})?;

        let union_preserved = matches!(
            &optimized,
            LogicalPlan::Distinct(Distinct::All(inner))
                if matches!(inner.as_ref(), LogicalPlan::Union(_))
        );
        assert!(
            union_preserved,
            "expected Distinct(Union(...)) to be preserved, got:\n{optimized:?}"
        );
        Ok(())
    }

    #[test]
    fn keep_union_distinct_with_limit_branches() -> Result<()> {
        let left = LogicalPlanBuilder::from(test_table_scan_with_name("emp")?)
            .project(emp_aliases())?
            .limit(0, Some(2))?
            .build()?;
        let right = LogicalPlanBuilder::from(test_table_scan_with_name("emp")?)
            .project(emp_aliases())?
            .limit(0, Some(2))?
            .build()?;
        let plan = distinct_union(left, right)?;

        assert_optimized_plan_equal!(plan, @r"
        Distinct:
          Union
            Limit: skip=0, fetch=2
              Projection: emp.a AS mgr, emp.b AS comm
                TableScan: emp
            Limit: skip=0, fetch=2
              Projection: emp.a AS mgr, emp.b AS comm
                TableScan: emp
        ")?;
        Ok(())
    }

    #[test]
    fn keep_union_distinct_with_sort_branches() -> Result<()> {
        let left = LogicalPlanBuilder::from(test_table_scan_with_name("emp")?)
            .project(emp_aliases())?
            .sort(vec![col("a").sort(true, true)])?
            .build()?;
        let right = LogicalPlanBuilder::from(test_table_scan_with_name("emp")?)
            .project(emp_aliases())?
            .sort(vec![col("a").sort(true, true)])?
            .build()?;
        let plan = distinct_union(left, right)?;

        assert_optimized_plan_equal!(plan, @r"
        Distinct:
          Union
            Projection: mgr, comm
              Sort: emp.a ASC NULLS FIRST
                Projection: emp.a AS mgr, emp.b AS comm, emp.a
                  TableScan: emp
            Projection: mgr, comm
              Sort: emp.a ASC NULLS FIRST
                Projection: emp.a AS mgr, emp.b AS comm, emp.a
                  TableScan: emp
        ")?;
        Ok(())
    }
}