Skip to main content

datafusion_optimizer/
rewrite_set_comparison.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Optimizer rule that rewrites `SetComparison` subqueries (e.g. `= ANY`,
//! `> ALL`) into boolean `CASE` expressions built from `EXISTS` subqueries,
//! preserving SQL three-valued logic.

use crate::{OptimizerConfig, OptimizerRule};
use datafusion_common::tree_node::{Transformed, TreeNode};
use datafusion_common::{Column, DFSchema, ExprSchema, Result, ScalarValue, plan_err};
use datafusion_expr::expr::{self, Exists, SetComparison, SetQuantifier};
use datafusion_expr::expr_rewriter::NamePreserver;
use datafusion_expr::logical_plan::Subquery;
use datafusion_expr::logical_plan::builder::LogicalPlanBuilder;
use datafusion_expr::utils::merge_schema;
use datafusion_expr::{Expr, LogicalPlan, lit};
use std::sync::Arc;

/// Optimizer rule that eliminates `SetComparison` expressions
/// (`<expr> <op> ANY (<subquery>)` and `<expr> <op> ALL (<subquery>)`).
///
/// Each set comparison is replaced by a `CASE` expression over `EXISTS`
/// subqueries that reproduces SQL three-valued logic, where `cmp` is the
/// per-row comparison of `<expr>` with the subquery's first column:
///
/// * `ANY`: `TRUE` if some row has `cmp` TRUE, otherwise `NULL` if some row
///   has `cmp` NULL, otherwise `FALSE` (so an empty subquery yields `FALSE`).
/// * `ALL`: `FALSE` if some row has `cmp` FALSE, otherwise `NULL` if some row
///   has `cmp` NULL, otherwise `TRUE` (so an empty subquery yields `TRUE`).
///
/// The generated `EXISTS` subqueries remain in the plan; decorrelating them is
/// left to later optimizer rules.
#[derive(Debug, Default)]
pub struct RewriteSetComparison;

40impl RewriteSetComparison {
41    /// Create a new `RewriteSetComparison` optimizer rule.
42    pub fn new() -> Self {
43        Self
44    }
45
46    fn rewrite_plan(&self, plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
47        let schema = merge_schema(&plan.inputs());
48        plan.map_expressions(|expr| {
49            expr.transform_up(|expr| rewrite_set_comparison(expr, &schema))
50        })
51    }
52}
53
54impl OptimizerRule for RewriteSetComparison {
55    fn name(&self) -> &str {
56        "rewrite_set_comparison"
57    }
58
59    fn rewrite(
60        &self,
61        plan: LogicalPlan,
62        _config: &dyn OptimizerConfig,
63    ) -> Result<Transformed<LogicalPlan>> {
64        plan.transform_up_with_subqueries(|plan| self.rewrite_plan(plan))
65    }
66}
67
68fn rewrite_set_comparison(
69    expr: Expr,
70    outer_schema: &DFSchema,
71) -> Result<Transformed<Expr>> {
72    match expr {
73        Expr::SetComparison(set_comparison) => {
74            let rewritten = build_set_comparison_subquery(set_comparison, outer_schema)?;
75            Ok(Transformed::yes(rewritten))
76        }
77        _ => Ok(Transformed::no(expr)),
78    }
79}
80
81fn build_set_comparison_subquery(
82    set_comparison: SetComparison,
83    outer_schema: &DFSchema,
84) -> Result<Expr> {
85    let SetComparison {
86        expr,
87        subquery,
88        op,
89        quantifier,
90    } = set_comparison;
91
92    let left_expr = to_outer_reference(*expr, outer_schema)?;
93    let subquery_schema = subquery.subquery.schema();
94    if subquery_schema.fields().is_empty() {
95        return plan_err!("single expression required.");
96    }
97    // avoid `head_output_expr` for aggr/window plan, it will gives group-by expr if exists
98    let right_expr = Expr::Column(Column::from(subquery_schema.qualified_field(0)));
99
100    let comparison = Expr::BinaryExpr(expr::BinaryExpr::new(
101        Box::new(left_expr),
102        op,
103        Box::new(right_expr),
104    ));
105
106    let true_exists =
107        exists_subquery(&subquery, Expr::IsTrue(Box::new(comparison.clone())))?;
108    let null_exists =
109        exists_subquery(&subquery, Expr::IsNull(Box::new(comparison.clone())))?;
110
111    let result_expr = match quantifier {
112        SetQuantifier::Any => Expr::Case(expr::Case {
113            expr: None,
114            when_then_expr: vec![
115                (Box::new(true_exists), Box::new(lit(true))),
116                (
117                    Box::new(null_exists),
118                    Box::new(Expr::Literal(ScalarValue::Boolean(None), None)),
119                ),
120            ],
121            else_expr: Some(Box::new(lit(false))),
122        }),
123        SetQuantifier::All => {
124            let false_exists =
125                exists_subquery(&subquery, Expr::IsFalse(Box::new(comparison.clone())))?;
126            Expr::Case(expr::Case {
127                expr: None,
128                when_then_expr: vec![
129                    (Box::new(false_exists), Box::new(lit(false))),
130                    (
131                        Box::new(null_exists),
132                        Box::new(Expr::Literal(ScalarValue::Boolean(None), None)),
133                    ),
134                ],
135                else_expr: Some(Box::new(lit(true))),
136            })
137        }
138    };
139
140    Ok(result_expr)
141}
142
143fn exists_subquery(subquery: &Subquery, filter: Expr) -> Result<Expr> {
144    let plan = LogicalPlanBuilder::from(subquery.subquery.as_ref().clone())
145        .filter(filter)?
146        .build()?;
147    let outer_ref_columns = plan.all_out_ref_exprs();
148    Ok(Expr::Exists(Exists {
149        subquery: Subquery {
150            subquery: Arc::new(plan),
151            outer_ref_columns,
152            spans: subquery.spans.clone(),
153        },
154        negated: false,
155    }))
156}
157
158fn to_outer_reference(expr: Expr, outer_schema: &DFSchema) -> Result<Expr> {
159    expr.transform_up(|expr| match expr {
160        Expr::Column(col) => {
161            let field = outer_schema.field_from_column(&col)?;
162            Ok(Transformed::yes(Expr::OuterReferenceColumn(
163                Arc::clone(field),
164                col,
165            )))
166        }
167        Expr::OuterReferenceColumn(_, _) => Ok(Transformed::no(expr)),
168        _ => Ok(Transformed::no(expr)),
169    })
170    .map(|t| t.data)
171}