datafusion_optimizer/
rewrite_set_comparison.rs1use crate::{OptimizerConfig, OptimizerRule};
23use datafusion_common::tree_node::{Transformed, TreeNode};
24use datafusion_common::{Column, DFSchema, ExprSchema, Result, ScalarValue, plan_err};
25use datafusion_expr::expr::{self, Exists, SetComparison, SetQuantifier};
26use datafusion_expr::logical_plan::Subquery;
27use datafusion_expr::logical_plan::builder::LogicalPlanBuilder;
28use datafusion_expr::{Expr, LogicalPlan, lit};
29use std::sync::Arc;
30
31use datafusion_expr::utils::merge_schema;
32
33#[derive(Debug, Default)]
38pub struct RewriteSetComparison;
39
40impl RewriteSetComparison {
41 pub fn new() -> Self {
43 Self
44 }
45
46 fn rewrite_plan(&self, plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
47 let schema = merge_schema(&plan.inputs());
48 plan.map_expressions(|expr| {
49 expr.transform_up(|expr| rewrite_set_comparison(expr, &schema))
50 })
51 }
52}
53
54impl OptimizerRule for RewriteSetComparison {
55 fn name(&self) -> &str {
56 "rewrite_set_comparison"
57 }
58
59 fn rewrite(
60 &self,
61 plan: LogicalPlan,
62 _config: &dyn OptimizerConfig,
63 ) -> Result<Transformed<LogicalPlan>> {
64 plan.transform_up_with_subqueries(|plan| self.rewrite_plan(plan))
65 }
66}
67
68fn rewrite_set_comparison(
69 expr: Expr,
70 outer_schema: &DFSchema,
71) -> Result<Transformed<Expr>> {
72 match expr {
73 Expr::SetComparison(set_comparison) => {
74 let rewritten = build_set_comparison_subquery(set_comparison, outer_schema)?;
75 Ok(Transformed::yes(rewritten))
76 }
77 _ => Ok(Transformed::no(expr)),
78 }
79}
80
81fn build_set_comparison_subquery(
82 set_comparison: SetComparison,
83 outer_schema: &DFSchema,
84) -> Result<Expr> {
85 let SetComparison {
86 expr,
87 subquery,
88 op,
89 quantifier,
90 } = set_comparison;
91
92 let left_expr = to_outer_reference(*expr, outer_schema)?;
93 let subquery_schema = subquery.subquery.schema();
94 if subquery_schema.fields().is_empty() {
95 return plan_err!("single expression required.");
96 }
97 let right_expr = Expr::Column(Column::from(subquery_schema.qualified_field(0)));
99
100 let comparison = Expr::BinaryExpr(expr::BinaryExpr::new(
101 Box::new(left_expr),
102 op,
103 Box::new(right_expr),
104 ));
105
106 let true_exists =
107 exists_subquery(&subquery, Expr::IsTrue(Box::new(comparison.clone())))?;
108 let null_exists =
109 exists_subquery(&subquery, Expr::IsNull(Box::new(comparison.clone())))?;
110
111 let result_expr = match quantifier {
112 SetQuantifier::Any => Expr::Case(expr::Case {
113 expr: None,
114 when_then_expr: vec![
115 (Box::new(true_exists), Box::new(lit(true))),
116 (
117 Box::new(null_exists),
118 Box::new(Expr::Literal(ScalarValue::Boolean(None), None)),
119 ),
120 ],
121 else_expr: Some(Box::new(lit(false))),
122 }),
123 SetQuantifier::All => {
124 let false_exists =
125 exists_subquery(&subquery, Expr::IsFalse(Box::new(comparison.clone())))?;
126 Expr::Case(expr::Case {
127 expr: None,
128 when_then_expr: vec![
129 (Box::new(false_exists), Box::new(lit(false))),
130 (
131 Box::new(null_exists),
132 Box::new(Expr::Literal(ScalarValue::Boolean(None), None)),
133 ),
134 ],
135 else_expr: Some(Box::new(lit(true))),
136 })
137 }
138 };
139
140 Ok(result_expr)
141}
142
143fn exists_subquery(subquery: &Subquery, filter: Expr) -> Result<Expr> {
144 let plan = LogicalPlanBuilder::from(subquery.subquery.as_ref().clone())
145 .filter(filter)?
146 .build()?;
147 let outer_ref_columns = plan.all_out_ref_exprs();
148 Ok(Expr::Exists(Exists {
149 subquery: Subquery {
150 subquery: Arc::new(plan),
151 outer_ref_columns,
152 spans: subquery.spans.clone(),
153 },
154 negated: false,
155 }))
156}
157
158fn to_outer_reference(expr: Expr, outer_schema: &DFSchema) -> Result<Expr> {
159 expr.transform_up(|expr| match expr {
160 Expr::Column(col) => {
161 let field = outer_schema.field_from_column(&col)?;
162 Ok(Transformed::yes(Expr::OuterReferenceColumn(
163 Arc::clone(field),
164 col,
165 )))
166 }
167 Expr::OuterReferenceColumn(_, _) => Ok(Transformed::no(expr)),
168 _ => Ok(Transformed::no(expr)),
169 })
170 .map(|t| t.data)
171}