Skip to main content

datafusion_optimizer/
unions_to_filter.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Rewrites `UNION DISTINCT` branches that differ only by filter predicates
19//! into a single filtered branch plus `DISTINCT`.
20
21use crate::{OptimizerConfig, OptimizerRule};
22use datafusion_common::Result;
23use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
24use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema;
25use datafusion_expr::logical_plan::builder::LogicalPlanBuilder;
26use datafusion_expr::utils::disjunction;
27use datafusion_expr::{
28    Distinct, Expr, Filter, LogicalPlan, Projection, SubqueryAlias, Union,
29};
30use log::debug;
31use std::sync::Arc;
32
/// Optimizer rule that collapses the inputs of a `UNION DISTINCT` which share
/// the same source and wrapper nodes into a single input whose filter is the
/// `OR` of the original predicates.
///
/// The rule does nothing unless the `enable_unions_to_filter` optimizer option
/// is turned on.
#[derive(Default, Debug)]
pub struct UnionsToFilter;
35
36impl UnionsToFilter {
37    #[expect(missing_docs)]
38    pub fn new() -> Self {
39        Self
40    }
41}
42
43impl OptimizerRule for UnionsToFilter {
44    fn name(&self) -> &str {
45        "unions_to_filter"
46    }
47
48    fn supports_rewrite(&self) -> bool {
49        true
50    }
51
52    fn rewrite(
53        &self,
54        plan: LogicalPlan,
55        config: &dyn OptimizerConfig,
56    ) -> Result<Transformed<LogicalPlan>> {
57        if !config.options().optimizer.enable_unions_to_filter {
58            return Ok(Transformed::no(plan));
59        }
60
61        // Fast pre-check: if the plan tree has no Distinct::All node at all we can
62        // skip the expensive bottom-up rewrite_with_subqueries traversal entirely.
63        // This matters for large UNION ALL plans (e.g. TPC-DS Q4) where the rule
64        // can never fire and the traversal overhead is otherwise measurable.
65        if !plan.exists(|p| Ok(matches!(p, LogicalPlan::Distinct(Distinct::All(_)))))? {
66            return Ok(Transformed::no(plan));
67        }
68
69        plan.rewrite_with_subqueries(&mut UnionsToFilterRewriter)
70    }
71}
72
/// Stateless bottom-up plan rewriter that carries out the `UnionsToFilter`
/// rewrite on each `Distinct::All` node it visits.
struct UnionsToFilterRewriter;
74
75impl TreeNodeRewriter for UnionsToFilterRewriter {
76    type Node = LogicalPlan;
77
78    fn f_up(&mut self, plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
79        match &plan {
80            LogicalPlan::Distinct(Distinct::All(input)) => {
81                match try_rewrite_distinct_union(input.as_ref().clone())? {
82                    Some(rewritten) => Ok(Transformed::yes(rewritten)),
83                    None => Ok(Transformed::no(plan)),
84                }
85            }
86            _ => Ok(Transformed::no(plan)),
87        }
88    }
89}
90
91fn try_rewrite_distinct_union(plan: LogicalPlan) -> Result<Option<LogicalPlan>> {
92    let LogicalPlan::Union(Union { inputs, schema }) = plan else {
93        debug!("unions_to_filter skipped: input is not a UNION");
94        return Ok(None);
95    };
96
97    if inputs.len() < 2 {
98        debug!(
99            "unions_to_filter skipped: UNION has {} input(s), need at least 2",
100            inputs.len()
101        );
102        return Ok(None);
103    }
104
105    // Use a Vec instead of HashMap: union branches are typically 2-10 entries,
106    // so a linear scan with PartialEq is faster than recursively hashing entire
107    // LogicalPlan subtrees (O(N * tree_size) hashing for every insert/lookup).
108    let mut grouped: Vec<(GroupKey, Vec<Expr>)> = Vec::new();
109    let mut transformed = false;
110
111    for input in inputs {
112        let Some(branch) = extract_branch(Arc::unwrap_or_clone(input))? else {
113            return Ok(None);
114        };
115
116        let key = GroupKey {
117            source: branch.source,
118            wrappers: branch.wrappers,
119        };
120        if let Some((_, conds)) = grouped.iter_mut().find(|(k, _)| k == &key) {
121            conds.push(branch.predicate);
122            transformed = true;
123        } else {
124            grouped.push((key, vec![branch.predicate]));
125        }
126    }
127
128    if !transformed {
129        debug!("unions_to_filter skipped: no branch groups could be merged");
130        return Ok(None);
131    }
132
133    let mut builder: Option<LogicalPlanBuilder> = None;
134    for (key, predicates) in grouped {
135        let combined =
136            disjunction(predicates).expect("union branches always provide predicates");
137        let branch = LogicalPlanBuilder::from(key.source)
138            .filter(combined)?
139            .build()?;
140        let branch = wrap_branch(branch, &key.wrappers)?;
141        let branch = coerce_plan_expr_for_schema(branch, &schema)?;
142        let branch = align_plan_to_schema(branch, Arc::clone(&schema))?;
143        builder = Some(match builder {
144            None => LogicalPlanBuilder::from(branch),
145            Some(builder) => builder.union(branch)?,
146        });
147    }
148
149    let union = builder
150        .expect("at least one branch after rewrite")
151        .build()?;
152    Ok(Some(LogicalPlan::Distinct(Distinct::All(Arc::new(union)))))
153}
154
/// One UNION input, decomposed into the pieces the rewrite groups and merges.
struct UnionBranch {
    // Plan the predicate is applied to; compared for equality when grouping.
    source: LogicalPlan,
    // Predicate of the branch's filter, or literal `true` when it has none.
    predicate: Expr,
    // Projection / SubqueryAlias nodes removed from the top of the branch,
    // outermost first.
    wrappers: Vec<Wrapper>,
}
160
161fn extract_branch(plan: LogicalPlan) -> Result<Option<UnionBranch>> {
162    let (wrappers, plan) = peel_wrappers(plan);
163
164    // Volatile or subquery expressions in the projection must not be merged:
165    // they are evaluated once per branch in the original plan but would be
166    // evaluated once per combined row after the rewrite, which can change the
167    // output row set.
168    if !wrapper_projections_are_safe(&wrappers) {
169        debug!(
170            "unions_to_filter skipped: projection wrapper contains volatile expression or subquery"
171        );
172        return Ok(None);
173    }
174
175    match plan {
176        LogicalPlan::Filter(Filter {
177            predicate, input, ..
178        }) => {
179            if !is_mergeable_predicate(&predicate) {
180                debug!(
181                    "unions_to_filter skipped: branch predicate contains volatility or a subquery"
182                );
183                return Ok(None);
184            }
185            Ok(Some(UnionBranch {
186                source: strip_passthrough_nodes(Arc::unwrap_or_clone(input)),
187                predicate,
188                wrappers,
189            }))
190        }
191        // A Limit or Sort node changes the row-set semantics of the branch.
192        // Merging two such branches into one would silently drop the per-branch
193        // row restriction (LIMIT) or rely on an order guarantee that UNION does
194        // not preserve (ORDER BY).  Bail out to leave the UNION unchanged.
195        LogicalPlan::Limit(_) => {
196            debug!("unions_to_filter skipped: branch contains LIMIT");
197            Ok(None)
198        }
199        LogicalPlan::Sort(_) => {
200            debug!("unions_to_filter skipped: branch contains ORDER BY / SORT");
201            Ok(None)
202        }
203        other => Ok(Some(UnionBranch {
204            source: strip_passthrough_nodes(other),
205            predicate: Expr::Literal(
206                datafusion_common::ScalarValue::Boolean(Some(true)),
207                None,
208            ),
209            wrappers,
210        })),
211    }
212}
213
/// Identity of a UNION input for grouping purposes: two inputs are merged only
/// when both their source plan and their wrapper stack compare equal.
#[derive(Debug, Clone, PartialEq, Eq)]
struct GroupKey {
    // Plan found beneath the wrappers (and beneath the filter, if any).
    source: LogicalPlan,
    // Wrapper nodes peeled from the top of the input, outermost first.
    wrappers: Vec<Wrapper>,
}
219
/// A node peeled off the top of a UNION input, holding what is needed to
/// compare it with other inputs and to re-apply it after merging.
#[derive(Debug, Clone, PartialEq, Eq)]
enum Wrapper {
    /// A `Projection` node: its expression list and output schema.
    Projection {
        expr: Vec<Expr>,
        schema: datafusion_common::DFSchemaRef,
    },
    /// A `SubqueryAlias` node: its alias and output schema. The schema takes
    /// part in equality; `wrap_branch` does not read it when rebuilding.
    SubqueryAlias {
        alias: datafusion_common::TableReference,
        schema: datafusion_common::DFSchemaRef,
    },
}
231
232fn peel_wrappers(mut plan: LogicalPlan) -> (Vec<Wrapper>, LogicalPlan) {
233    let mut wrappers = vec![];
234    loop {
235        match plan {
236            LogicalPlan::Projection(Projection {
237                expr,
238                input,
239                schema,
240                ..
241            }) => {
242                wrappers.push(Wrapper::Projection { expr, schema });
243                plan = Arc::unwrap_or_clone(input);
244            }
245            LogicalPlan::SubqueryAlias(SubqueryAlias {
246                input,
247                alias,
248                schema,
249                ..
250            }) => {
251                wrappers.push(Wrapper::SubqueryAlias { alias, schema });
252                plan = Arc::unwrap_or_clone(input);
253            }
254            other => return (wrappers, other),
255        }
256    }
257}
258
259fn wrap_branch(mut plan: LogicalPlan, wrappers: &[Wrapper]) -> Result<LogicalPlan> {
260    for wrapper in wrappers.iter().rev() {
261        plan = match wrapper {
262            Wrapper::Projection { expr, schema } => {
263                LogicalPlan::Projection(Projection::try_new_with_schema(
264                    expr.clone(),
265                    Arc::new(plan),
266                    Arc::clone(schema),
267                )?)
268            }
269            // SubqueryAlias::try_new recomputes the schema from the new input.
270            // This is safe because the source table is unchanged; only the
271            // filter predicate differs, so the recomputed schema matches the
272            // original one stored in peel_wrappers.
273            Wrapper::SubqueryAlias { alias, .. } => LogicalPlan::SubqueryAlias(
274                SubqueryAlias::try_new(Arc::new(plan), alias.clone())?,
275            ),
276        };
277    }
278    Ok(plan)
279}
280
281fn strip_passthrough_nodes(mut plan: LogicalPlan) -> LogicalPlan {
282    loop {
283        plan = match plan {
284            LogicalPlan::Projection(Projection { input, .. }) => {
285                Arc::unwrap_or_clone(input)
286            }
287            LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. }) => {
288                Arc::unwrap_or_clone(input)
289            }
290            other => return other,
291        };
292    }
293}
294
295fn align_plan_to_schema(
296    plan: LogicalPlan,
297    schema: datafusion_common::DFSchemaRef,
298) -> Result<LogicalPlan> {
299    if plan.schema() == &schema {
300        return Ok(plan);
301    }
302
303    let expr = plan
304        .schema()
305        .iter()
306        .enumerate()
307        .map(|(i, _)| {
308            Expr::Column(datafusion_common::Column::from(
309                plan.schema().qualified_field(i),
310            ))
311        })
312        .collect::<Vec<_>>();
313
314    Ok(LogicalPlan::Projection(Projection::try_new_with_schema(
315        expr,
316        Arc::new(plan),
317        schema,
318    )?))
319}
320
321fn is_mergeable_predicate(expr: &Expr) -> bool {
322    !expr.is_volatile() && !expr_contains_subquery(expr)
323}
324
325/// Returns `true` when every projection expression in `wrappers` is both
326/// non-volatile and subquery-free.
327///
328/// Volatile expressions (e.g. `random()`, `now()`) or correlated subqueries
329/// in the SELECT list cannot be safely merged: the original plan evaluates
330/// them once per branch execution, while the rewritten plan evaluates them
331/// once per combined row, which can change the set of output rows.
332fn wrapper_projections_are_safe(wrappers: &[Wrapper]) -> bool {
333    wrappers.iter().all(|w| match w {
334        Wrapper::Projection { expr, .. } => expr
335            .iter()
336            .all(|e| !e.is_volatile() && !expr_contains_subquery(e)),
337        Wrapper::SubqueryAlias { .. } => true,
338    })
339}
340
341fn expr_contains_subquery(expr: &Expr) -> bool {
342    expr.exists(|e| match e {
343        Expr::ScalarSubquery(_) | Expr::Exists(_) | Expr::InSubquery(_) => Ok(true),
344        _ => Ok(false),
345    })
346    .expect("boolean expression walk is infallible")
347}
348
// Unit tests for the rule: each test builds a `UNION DISTINCT` plan by hand,
// runs only `UnionsToFilter` for a single pass, and checks the resulting plan.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::OptimizerContext;
    use crate::assert_optimized_plan_eq_snapshot;
    use crate::test::test_table_scan_with_name;
    use arrow::datatypes::DataType;
    use datafusion_common::Result;
    use datafusion_expr::{
        ColumnarValue, Expr, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
        Volatility, col, lit,
    };

    // Optimizes `$plan` with only `UnionsToFilter` (option enabled, one pass)
    // and compares the result against the inline snapshot `$expected`.
    macro_rules! assert_optimized_plan_equal {
        (
            $plan:expr,
            @ $expected:literal $(,)?
        ) => {{
            let mut options = datafusion_common::config::ConfigOptions::default();
            options.optimizer.enable_unions_to_filter = true;
            let optimizer_ctx = OptimizerContext::new_with_config_options(Arc::new(options))
                .with_max_passes(1);
            let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> =
                vec![Arc::new(UnionsToFilter::new())];
            assert_optimized_plan_eq_snapshot!(
                optimizer_ctx,
                rules,
                $plan,
                @ $expected,
            )
        }};
    }

    /// Zero-argument scalar UDF declared `Volatility::Volatile`; used only to
    /// build plans and never executed.
    #[derive(Debug, PartialEq, Eq, Hash)]
    struct VolatileTestUdf;

    impl ScalarUDFImpl for VolatileTestUdf {
        fn name(&self) -> &str {
            "volatile_test"
        }

        fn signature(&self) -> &Signature {
            // Built once and shared, since the trait hands out a reference.
            static SIGNATURE: std::sync::LazyLock<Signature> =
                std::sync::LazyLock::new(|| Signature::nullary(Volatility::Volatile));
            &SIGNATURE
        }

        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
            Ok(DataType::Float64)
        }

        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
            panic!("VolatileTestUdf is not intended for execution")
        }
    }

    /// Builds a call expression `volatile_test()`.
    fn volatile_expr() -> Expr {
        ScalarUDF::new_from_impl(VolatileTestUdf).call(vec![])
    }

    /// Two filters over the same table are merged into one filter whose
    /// predicate is the `OR` of both.
    #[test]
    fn rewrite_union_distinct_same_source_filters() -> Result<()> {
        let left = LogicalPlanBuilder::from(test_table_scan_with_name("t")?)
            .filter(col("a").eq(lit(1)))?
            .build()?;
        let right = LogicalPlanBuilder::from(test_table_scan_with_name("t")?)
            .filter(col("a").eq(lit(2)))?
            .build()?;

        let plan = LogicalPlanBuilder::from(left)
            .union_distinct(right)?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Distinct:
          Projection: t.a, t.b, t.c
            Filter: t.a = Int32(1) OR t.a = Int32(2)
              TableScan: t
        ")?;
        Ok(())
    }

    /// Branches over different tables share no group, so the plan is left
    /// unchanged.
    #[test]
    fn keep_union_distinct_different_sources() -> Result<()> {
        let left = LogicalPlanBuilder::from(test_table_scan_with_name("t1")?)
            .filter(col("a").eq(lit(1)))?
            .build()?;
        let right = LogicalPlanBuilder::from(test_table_scan_with_name("t2")?)
            .filter(col("a").eq(lit(2)))?
            .build()?;

        let plan = LogicalPlanBuilder::from(left)
            .union_distinct(right)?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Distinct:
          Union
            Filter: t1.a = Int32(1)
              TableScan: t1
            Filter: t2.a = Int32(2)
              TableScan: t2
        ")?;
        Ok(())
    }

    /// A volatile filter predicate in any branch blocks the rewrite.
    #[test]
    fn keep_union_distinct_with_volatile_predicate() -> Result<()> {
        let left = LogicalPlanBuilder::from(test_table_scan_with_name("t")?)
            .filter(volatile_expr().gt(lit(0.5_f64)))?
            .build()?;
        let right = LogicalPlanBuilder::from(test_table_scan_with_name("t")?)
            .filter(col("a").eq(lit(2)))?
            .build()?;

        let plan = LogicalPlanBuilder::from(left)
            .union_distinct(right)?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Distinct:
          Union
            Filter: volatile_test() > Float64(0.5)
              TableScan: t
            Filter: t.a = Int32(2)
              TableScan: t
        ")?;
        Ok(())
    }

    /// Branches with identical projections are merged even when only one of
    /// them has a filter; the unfiltered branch contributes `Boolean(true)`.
    #[test]
    fn rewrite_union_distinct_with_matching_projection_prefix() -> Result<()> {
        let left = LogicalPlanBuilder::from(test_table_scan_with_name("emp")?)
            .project(vec![col("a").alias("mgr"), col("b").alias("comm")])?
            .build()?;
        let right = LogicalPlanBuilder::from(test_table_scan_with_name("emp")?)
            .filter(col("b").eq(lit(5)))?
            .project(vec![col("a").alias("mgr"), col("b").alias("comm")])?
            .build()?;

        let plan = LogicalPlanBuilder::from(left)
            .union_distinct(right)?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Distinct:
          Projection: emp.a AS mgr, emp.b AS comm
            Filter: Boolean(true) OR emp.b = Int32(5)
              TableScan: emp
        ")?;
        Ok(())
    }

    /// A volatile expression in the **projection** (SELECT list) must block the
    /// rewrite.  Each original branch evaluates it independently; merging them
    /// would evaluate it once per combined row, changing the row set.
    #[test]
    fn keep_union_distinct_with_volatile_projection() -> Result<()> {
        // Both branches project volatile_test() AS v over the same source.
        let left = LogicalPlanBuilder::from(test_table_scan_with_name("t")?)
            .filter(col("a").eq(lit(1)))?
            .project(vec![volatile_expr().alias("v"), col("a")])?
            .build()?;
        let right = LogicalPlanBuilder::from(test_table_scan_with_name("t")?)
            .filter(col("a").eq(lit(2)))?
            .project(vec![volatile_expr().alias("v"), col("a")])?
            .build()?;

        let plan = LogicalPlanBuilder::from(left)
            .union_distinct(right)?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Distinct:
          Union
            Projection: volatile_test() AS v, t.a
              Filter: t.a = Int32(1)
                TableScan: t
            Projection: volatile_test() AS v, t.a
              Filter: t.a = Int32(2)
                TableScan: t
        ")?;
        Ok(())
    }

    /// A scalar subquery in the **projection** must also block the rewrite.
    #[test]
    fn keep_union_distinct_with_subquery_in_projection() -> Result<()> {
        use datafusion_expr::scalar_subquery;

        // Build a simple scalar subquery: (SELECT t2.b FROM t2 WHERE t2.a = t.a)
        let t2 = test_table_scan_with_name("t2")?;
        let subquery_plan = Arc::new(
            LogicalPlanBuilder::from(t2)
                .filter(col("t2.a").eq(col("t.a")))?
                .project(vec![col("t2.b")])?
                .build()?,
        );
        let sq = scalar_subquery(subquery_plan);

        let left = LogicalPlanBuilder::from(test_table_scan_with_name("t")?)
            .filter(col("a").eq(lit(1)))?
            .project(vec![sq.clone().alias("sub"), col("a")])?
            .build()?;
        let right = LogicalPlanBuilder::from(test_table_scan_with_name("t")?)
            .filter(col("a").eq(lit(2)))?
            .project(vec![sq.alias("sub"), col("a")])?
            .build()?;

        let plan = LogicalPlanBuilder::from(left)
            .union_distinct(right)?
            .build()?;

        // Plan should be left untouched because the projection contains a subquery.
        // The optimizer is driven by hand here, and the result is checked
        // structurally instead of against a snapshot.
        let optimized = {
            let mut options = datafusion_common::config::ConfigOptions::default();
            options.optimizer.enable_unions_to_filter = true;
            let optimizer_ctx =
                OptimizerContext::new_with_config_options(Arc::new(options))
                    .with_max_passes(1);
            let rules: Vec<Arc<dyn OptimizerRule + Send + Sync>> =
                vec![Arc::new(UnionsToFilter::new())];
            crate::Optimizer::with_rules(rules).optimize(
                plan.clone(),
                &optimizer_ctx,
                |_, _| {},
            )?
        };
        // The Distinct(Union(...)) structure must be preserved.
        assert!(
            matches!(
                &optimized,
                LogicalPlan::Distinct(Distinct::All(inner))
                if matches!(inner.as_ref(), LogicalPlan::Union(_))
            ),
            "expected Distinct(Union(...)) to be preserved, got:\n{optimized:?}"
        );
        Ok(())
    }

    /// A UNION where both branches have a LIMIT must **not** be rewritten.
    /// Each branch independently restricts the row-set; collapsing them into a
    /// single branch would lose the per-branch LIMIT semantics.
    #[test]
    fn keep_union_distinct_with_limit_branches() -> Result<()> {
        let left = LogicalPlanBuilder::from(test_table_scan_with_name("emp")?)
            .project(vec![col("a").alias("mgr"), col("b").alias("comm")])?
            .limit(0, Some(2))?
            .build()?;
        let right = LogicalPlanBuilder::from(test_table_scan_with_name("emp")?)
            .project(vec![col("a").alias("mgr"), col("b").alias("comm")])?
            .limit(0, Some(2))?
            .build()?;

        let plan = LogicalPlanBuilder::from(left)
            .union_distinct(right)?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Distinct:
          Union
            Limit: skip=0, fetch=2
              Projection: emp.a AS mgr, emp.b AS comm
                TableScan: emp
            Limit: skip=0, fetch=2
              Projection: emp.a AS mgr, emp.b AS comm
                TableScan: emp
        ")?;
        Ok(())
    }

    /// A UNION where both branches have an ORDER BY (Sort) must **not** be
    /// rewritten.  ORDER BY inside a UNION subquery does not guarantee ordering
    /// in the result; merging the branches would silently discard the Sort.
    #[test]
    fn keep_union_distinct_with_sort_branches() -> Result<()> {
        let left = LogicalPlanBuilder::from(test_table_scan_with_name("emp")?)
            .project(vec![col("a").alias("mgr"), col("b").alias("comm")])?
            .sort(vec![col("a").sort(true, true)])?
            .build()?;
        let right = LogicalPlanBuilder::from(test_table_scan_with_name("emp")?)
            .project(vec![col("a").alias("mgr"), col("b").alias("comm")])?
            .sort(vec![col("a").sort(true, true)])?
            .build()?;

        let plan = LogicalPlanBuilder::from(left)
            .union_distinct(right)?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Distinct:
          Union
            Projection: mgr, comm
              Sort: emp.a ASC NULLS FIRST
                Projection: emp.a AS mgr, emp.b AS comm, emp.a
                  TableScan: emp
            Projection: mgr, comm
              Sort: emp.a ASC NULLS FIRST
                Projection: emp.a AS mgr, emp.b AS comm, emp.a
                  TableScan: emp
        ")?;
        Ok(())
    }
}