Skip to main content

datafusion_optimizer/
decorrelate.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`PullUpCorrelatedExpr`] converts correlated subqueries to `Joins`
19
20use std::collections::BTreeSet;
21use std::ops::Deref;
22use std::sync::Arc;
23
24use crate::simplify_expressions::ExprSimplifier;
25
26use datafusion_common::tree_node::{
27    Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter,
28};
29use datafusion_common::{Column, DFSchemaRef, HashMap, Result, ScalarValue, plan_err};
30use datafusion_expr::expr::Alias;
31use datafusion_expr::simplify::SimplifyContext;
32use datafusion_expr::utils::{
33    collect_subquery_cols, conjunction, find_join_exprs, split_conjunction,
34};
35use datafusion_expr::{
36    BinaryExpr, Cast, EmptyRelation, Expr, FetchType, LogicalPlan, LogicalPlanBuilder,
37    Operator, expr, lit,
38};
39
/// Rewrites a subquery plan by pulling up the correlated expressions
/// (expressions containing outer reference columns) from the inner subquery's
/// `Filter`. It adds the inner reference columns to the `Projection` or
/// `Aggregate` of the subquery if they are missing, so that they can be
/// evaluated by the parent operator as the join condition.
#[derive(Debug)]
pub struct PullUpCorrelatedExpr {
    /// Correlated predicates pulled out of the subquery's filters; they are
    /// meant to become the join condition. Each predicate is stored once.
    pub join_filters: Vec<Expr>,
    /// mapping from the plan to its holding correlated columns
    pub correlated_subquery_cols_map: HashMap<LogicalPlan, BTreeSet<Column>>,
    /// The `IN` predicate of the subquery, if any. Join filters found to
    /// duplicate it are not added to `join_filters`. Defaults to `None`.
    pub in_predicate_opt: Option<Expr>,
    /// Is this an Exists(Not Exists) SubQuery. Defaults to **FALSE**
    pub exists_sub_query: bool,
    /// Can the correlated expressions be pulled up. Defaults to **TRUE**
    pub can_pull_up: bool,
    /// Whether every correlated expression seen so far can be pulled up above
    /// an aggregation without changing the meaning of the query. Becomes
    /// `false` once one that cannot is encountered. Defaults to **TRUE**
    can_pull_over_aggregation: bool,
    /// Do we need to handle [the count bug] during the pull up process.
    ///
    /// The "count bug" was described in [Optimization of Nested SQL
    /// Queries Revisited](https://dl.acm.org/doi/pdf/10.1145/38714.38723). This bug is
    /// not specific to the COUNT function, and it can occur with any aggregate function,
    /// such as SUM, AVG, etc. The anomaly arises because aggregates fail to distinguish
    /// between an empty set and null values when optimizing a correlated query as a join.
    /// Here, we use "the count bug" to refer to all such cases.
    ///
    /// [the count bug]: https://github.com/apache/datafusion/issues/10553
    pub need_handle_count_bug: bool,
    /// mapping from the plan to its expressions' evaluation result on empty batch
    pub collected_count_expr_map: HashMap<LogicalPlan, ExprResultMap>,
    /// pull up having expr, which must be evaluated after the Join
    pub pull_up_having_expr: Option<Expr>,
    /// whether we have converted a scalar aggregation into a group aggregation. When unnesting
    /// lateral joins, we need to produce a left outer join in such cases.
    pub pulled_up_scalar_agg: bool,
}
77
78impl Default for PullUpCorrelatedExpr {
79    fn default() -> Self {
80        Self::new()
81    }
82}
83
84impl PullUpCorrelatedExpr {
85    pub fn new() -> Self {
86        Self {
87            join_filters: vec![],
88            correlated_subquery_cols_map: HashMap::new(),
89            in_predicate_opt: None,
90            exists_sub_query: false,
91            can_pull_up: true,
92            can_pull_over_aggregation: true,
93            need_handle_count_bug: false,
94            collected_count_expr_map: HashMap::new(),
95            pull_up_having_expr: None,
96            pulled_up_scalar_agg: false,
97        }
98    }
99
100    /// Set if we need to handle [the count bug] during the pull up process
101    ///
102    /// [the count bug]: https://github.com/apache/datafusion/issues/10553
103    pub fn with_need_handle_count_bug(mut self, need_handle_count_bug: bool) -> Self {
104        self.need_handle_count_bug = need_handle_count_bug;
105        self
106    }
107
108    /// Set the in_predicate_opt
109    pub fn with_in_predicate_opt(mut self, in_predicate_opt: Option<Expr>) -> Self {
110        self.in_predicate_opt = in_predicate_opt;
111        self
112    }
113
114    /// Set if this is an Exists(Not Exists) SubQuery
115    pub fn with_exists_sub_query(mut self, exists_sub_query: bool) -> Self {
116        self.exists_sub_query = exists_sub_query;
117        self
118    }
119}
120
/// Name of the indicator column used to detect the unmatched rows from the
/// inner (subquery) table after the left outer join. This is used to handle
/// [the Count bug].
///
/// The rewriter adds it to a scalar aggregate's group-by as `lit(true)` under
/// this name. NOTE(review): presumably a NULL value after the join marks an
/// unmatched outer row — confirm against the consumer of this column.
///
/// [the Count bug]: https://github.com/apache/datafusion/issues/10553
pub const UN_MATCHED_ROW_INDICATOR: &str = "__always_true";

/// Mapping from expr display name to its evaluation result on empty record
/// batch (for example: 'count(*)' is 'ScalarValue(0)', 'count(*) + 2' is
/// 'ScalarValue(2)')
pub type ExprResultMap = HashMap<String, Expr>;
131
impl TreeNodeRewriter for PullUpCorrelatedExpr {
    type Node = LogicalPlan;

    /// Top-down pass that rejects plan shapes whose outer references cannot be
    /// pulled up.
    ///
    /// On such a node it sets `can_pull_up` to `false` and returns
    /// `TreeNodeRecursion::Jump`, so the node's children are not visited.
    /// Filters are always accepted here; they are rewritten in `f_up`.
    fn f_down(&mut self, plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
        match plan {
            LogicalPlan::Filter(_) => Ok(Transformed::no(plan)),
            LogicalPlan::Union(_) | LogicalPlan::Sort(_) | LogicalPlan::Extension(_) => {
                let plan_hold_outer = !plan.all_out_ref_exprs().is_empty();
                if plan_hold_outer {
                    // the unsupported case
                    self.can_pull_up = false;
                    Ok(Transformed::new(plan, false, TreeNodeRecursion::Jump))
                } else {
                    Ok(Transformed::no(plan))
                }
            }
            LogicalPlan::Limit(_) => {
                let plan_hold_outer = !plan.all_out_ref_exprs().is_empty();
                // A limit holding outer references is only tolerated in an
                // EXISTS subquery (where `f_up` may drop the limit).
                match (self.exists_sub_query, plan_hold_outer) {
                    (false, true) => {
                        // the unsupported case
                        self.can_pull_up = false;
                        Ok(Transformed::new(plan, false, TreeNodeRecursion::Jump))
                    }
                    _ => Ok(Transformed::no(plan)),
                }
            }
            _ if plan.contains_outer_reference() => {
                // the unsupported cases, the plan expressions contain out reference columns(like window expressions)
                self.can_pull_up = false;
                Ok(Transformed::new(plan, false, TreeNodeRecursion::Jump))
            }
            _ => Ok(Transformed::no(plan)),
        }
    }

    /// Bottom-up pass that performs the rewrite.
    ///
    /// * `Filter`:
    ///   - Splits the predicate into conjuncts.
    ///   - Moves the join expressions (per `find_join_exprs`) into
    ///     `self.join_filters`, skipping duplicates and the `IN` predicate.
    ///   - Re-applies the remaining conjuncts as a filter.
    ///   - If count-bug results are tracked for the filter's input and the
    ///     remaining filter evaluates to `true` on empty input, it is instead
    ///     moved to `self.pull_up_having_expr` and dropped from the plan.
    /// * `Projection` / `Aggregate`, only once an `IN` predicate or a join
    ///   filter exists:
    ///   - Re-plans the node so its output (or group-by) also contains the
    ///     correlated columns recorded for its descendants and the columns
    ///     used by `self.pull_up_having_expr`.
    ///   - Propagates the count-bug bookkeeping.
    /// * `SubqueryAlias`: re-qualifies its descendants' correlated columns
    ///   with the alias name and forwards its input's count-bug map.
    /// * `Limit`: in an EXISTS subquery with join filters, removes the limit,
    ///   or replaces it with an empty relation for a literal fetch of 0.
    ///
    /// Returns a plan error if a second filter needs to be pulled up as a
    /// HAVING expression while one is already pending.
    fn f_up(&mut self, plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
        let subquery_schema = plan.schema();
        match &plan {
            LogicalPlan::Filter(plan_filter) => {
                let subquery_filter_exprs = split_conjunction(&plan_filter.predicate);
                // Sticky flag: once any correlated conjunct cannot cross an
                // aggregation, it stays false for the rest of the rewrite.
                self.can_pull_over_aggregation = self.can_pull_over_aggregation
                    && subquery_filter_exprs
                        .iter()
                        .filter(|e| e.contains_outer())
                        .all(|&e| can_pullup_over_aggregation(e));
                let (mut join_filters, subquery_filters) =
                    find_join_exprs(subquery_filter_exprs)?;
                if let Some(in_predicate) = &self.in_predicate_opt {
                    // in_predicate may be already included in the join filters, remove it from the join filters first.
                    join_filters = remove_duplicated_filter(join_filters, in_predicate);
                }
                let correlated_subquery_cols =
                    collect_subquery_cols(&join_filters, subquery_schema)?;
                for expr in join_filters {
                    if !self.join_filters.contains(&expr) {
                        self.join_filters.push(expr)
                    }
                }

                // NOTE(review): this map is filled by the call below but is
                // not read again in this arm (it is not stored in
                // `collected_count_expr_map`) — confirm whether that is
                // intended.
                let mut expr_result_map_for_count_bug = HashMap::new();
                // Some(filter) only when the remaining (non-correlated) filter
                // evaluates to `true` on an empty input of a count-bug-tracked
                // plan; see `filter_exprs_evaluation_result_on_empty_batch`.
                let pull_up_expr_opt = if let Some(expr_result_map) =
                    self.collected_count_expr_map.get(plan_filter.input.deref())
                {
                    if let Some(expr) = conjunction(subquery_filters.clone()) {
                        filter_exprs_evaluation_result_on_empty_batch(
                            &expr,
                            Arc::clone(plan_filter.input.schema()),
                            expr_result_map,
                            &mut expr_result_map_for_count_bug,
                        )?
                    } else {
                        None
                    }
                } else {
                    None
                };

                match (&pull_up_expr_opt, &self.pull_up_having_expr) {
                    (Some(_), Some(_)) => {
                        // Error path: only one pulled-up HAVING expression is supported
                        plan_err!("Unsupported Subquery plan")
                    }
                    (Some(_), None) => {
                        // The remaining filter is evaluated after the join instead,
                        // so the filter node is dropped here.
                        self.pull_up_having_expr = pull_up_expr_opt;
                        let new_plan =
                            LogicalPlanBuilder::from((*plan_filter.input).clone())
                                .build()?;
                        self.correlated_subquery_cols_map
                            .insert(new_plan.clone(), correlated_subquery_cols);
                        Ok(Transformed::yes(new_plan))
                    }
                    (None, _) => {
                        // if the subquery still has filter expressions, restore them.
                        let mut plan =
                            LogicalPlanBuilder::from((*plan_filter.input).clone());
                        if let Some(expr) = conjunction(subquery_filters) {
                            plan = plan.filter(expr)?
                        }
                        let new_plan = plan.build()?;
                        self.correlated_subquery_cols_map
                            .insert(new_plan.clone(), correlated_subquery_cols);
                        Ok(Transformed::yes(new_plan))
                    }
                }
            }
            LogicalPlan::Projection(projection)
                if self.in_predicate_opt.is_some() || !self.join_filters.is_empty() =>
            {
                let mut local_correlated_cols = BTreeSet::new();
                collect_local_correlated_cols(
                    &plan,
                    &self.correlated_subquery_cols_map,
                    &mut local_correlated_cols,
                );
                // add missing columns to Projection
                let mut missing_exprs =
                    self.collect_missing_exprs(&projection.expr, &local_correlated_cols)?;

                // Translate the input's empty-input results through this
                // projection's expressions.
                let mut expr_result_map_for_count_bug = HashMap::new();
                if let Some(expr_result_map) =
                    self.collected_count_expr_map.get(projection.input.deref())
                {
                    proj_exprs_evaluation_result_on_empty_batch(
                        &projection.expr,
                        projection.input.schema(),
                        expr_result_map,
                        &mut expr_result_map_for_count_bug,
                    )?;
                    if !expr_result_map_for_count_bug.is_empty() {
                        // has count bug
                        let un_matched_row = Expr::Column(Column::new_unqualified(
                            UN_MATCHED_ROW_INDICATOR.to_string(),
                        ));
                        // add the unmatched rows indicator to the Projection expressions
                        missing_exprs.push(un_matched_row);
                    }
                }

                let new_plan = LogicalPlanBuilder::from((*projection.input).clone())
                    .project(missing_exprs)?
                    .build()?;
                if !expr_result_map_for_count_bug.is_empty() {
                    self.collected_count_expr_map
                        .insert(new_plan.clone(), expr_result_map_for_count_bug);
                }
                Ok(Transformed::yes(new_plan))
            }
            LogicalPlan::Aggregate(aggregate)
                if self.in_predicate_opt.is_some() || !self.join_filters.is_empty() =>
            {
                // If the aggregation is from a distinct it will not change the result for
                // exists/in subqueries so we can still pull up all predicates.
                let is_distinct = aggregate.aggr_expr.is_empty();
                if !is_distinct {
                    self.can_pull_up = self.can_pull_up && self.can_pull_over_aggregation;
                }
                let mut local_correlated_cols = BTreeSet::new();
                collect_local_correlated_cols(
                    &plan,
                    &self.correlated_subquery_cols_map,
                    &mut local_correlated_cols,
                );
                // add missing columns to Aggregation's group expressions
                let mut missing_exprs = self.collect_missing_exprs(
                    &aggregate.group_expr,
                    &local_correlated_cols,
                )?;

                // if the original group expressions are empty, need to handle the Count bug
                let mut expr_result_map_for_count_bug = HashMap::new();
                if self.need_handle_count_bug
                    && aggregate.group_expr.is_empty()
                    && !missing_exprs.is_empty()
                {
                    agg_exprs_evaluation_result_on_empty_batch(
                        &aggregate.aggr_expr,
                        aggregate.input.schema(),
                        &mut expr_result_map_for_count_bug,
                    )?;
                    if !expr_result_map_for_count_bug.is_empty() {
                        // has count bug
                        let un_matched_row = lit(true).alias(UN_MATCHED_ROW_INDICATOR);
                        // add the unmatched rows indicator to the Aggregation's group expressions
                        missing_exprs.push(un_matched_row);
                    }
                }
                if aggregate.group_expr.is_empty() {
                    // TODO: how do we handle the case where we have pulled multiple aggregations? For example,
                    // a group agg with a scalar agg as child.
                    self.pulled_up_scalar_agg = true;
                }
                let new_plan = LogicalPlanBuilder::from((*aggregate.input).clone())
                    .aggregate(missing_exprs, aggregate.aggr_expr.to_vec())?
                    .build()?;
                if !expr_result_map_for_count_bug.is_empty() {
                    self.collected_count_expr_map
                        .insert(new_plan.clone(), expr_result_map_for_count_bug);
                }
                Ok(Transformed::yes(new_plan))
            }
            LogicalPlan::SubqueryAlias(alias) => {
                let mut local_correlated_cols = BTreeSet::new();
                collect_local_correlated_cols(
                    &plan,
                    &self.correlated_subquery_cols_map,
                    &mut local_correlated_cols,
                );
                // Above the alias, the correlated columns are visible under the
                // alias name rather than their original qualifier.
                let mut new_correlated_cols = BTreeSet::new();
                for col in local_correlated_cols.iter() {
                    new_correlated_cols
                        .insert(Column::new(Some(alias.alias.clone()), col.name.clone()));
                }
                self.correlated_subquery_cols_map
                    .insert(plan.clone(), new_correlated_cols);
                if let Some(input_map) =
                    self.collected_count_expr_map.get(alias.input.deref())
                {
                    self.collected_count_expr_map
                        .insert(plan.clone(), input_map.clone());
                }
                Ok(Transformed::no(plan))
            }
            LogicalPlan::Limit(limit) => {
                let input_expr_map = self
                    .collected_count_expr_map
                    .get(limit.input.deref())
                    .cloned();
                // handling the limit clause in the subquery
                // NOTE(review): only the fetch is inspected; any OFFSET on this
                // limit is discarded together with it, which looks like it could
                // change an EXISTS result (e.g. `LIMIT 1 OFFSET 1` needs two
                // matching rows) — confirm.
                let new_plan = match (self.exists_sub_query, self.join_filters.is_empty())
                {
                    // Correlated exist subquery, remove the limit(so that correlated expressions can pull up)
                    (true, false) => Transformed::yes(match limit.get_fetch_type()? {
                        // `LIMIT 0` can never produce a row.
                        FetchType::Literal(Some(0)) => {
                            LogicalPlan::EmptyRelation(EmptyRelation {
                                produce_one_row: false,
                                schema: Arc::clone(limit.input.schema()),
                            })
                        }
                        _ => LogicalPlanBuilder::from((*limit.input).clone()).build()?,
                    }),
                    _ => Transformed::no(plan),
                };
                if let Some(input_map) = input_expr_map {
                    self.collected_count_expr_map
                        .insert(new_plan.data.clone(), input_map);
                }
                Ok(new_plan)
            }
            _ => Ok(Transformed::no(plan)),
        }
    }
}
385
386impl PullUpCorrelatedExpr {
387    fn collect_missing_exprs(
388        &self,
389        exprs: &[Expr],
390        correlated_subquery_cols: &BTreeSet<Column>,
391    ) -> Result<Vec<Expr>> {
392        let mut missing_exprs = vec![];
393        for expr in exprs {
394            if !missing_exprs.contains(expr) {
395                missing_exprs.push(expr.clone())
396            }
397        }
398        for col in correlated_subquery_cols.iter() {
399            let col_expr = Expr::Column(col.clone());
400            if !missing_exprs.contains(&col_expr) {
401                missing_exprs.push(col_expr)
402            }
403        }
404        if let Some(pull_up_having) = &self.pull_up_having_expr {
405            let filter_apply_columns = pull_up_having.column_refs();
406            for col in filter_apply_columns {
407                // add to missing_exprs if not already there
408                let contains = missing_exprs
409                    .iter()
410                    .any(|expr| matches!(expr, Expr::Column(c) if c == col));
411                if !contains {
412                    missing_exprs.push(Expr::Column(col.clone()))
413                }
414            }
415        }
416        Ok(missing_exprs)
417    }
418}
419
420fn can_pullup_over_aggregation(expr: &Expr) -> bool {
421    if let Expr::BinaryExpr(BinaryExpr {
422        left,
423        op: Operator::Eq,
424        right,
425    }) = expr
426    {
427        match (left.deref(), right.deref()) {
428            (Expr::Column(_), right) => !right.any_column_refs(),
429            (left, Expr::Column(_)) => !left.any_column_refs(),
430            (Expr::Cast(Cast { expr, .. }), right)
431                if matches!(expr.deref(), Expr::Column(_)) =>
432            {
433                !right.any_column_refs()
434            }
435            (left, Expr::Cast(Cast { expr, .. }))
436                if matches!(expr.deref(), Expr::Column(_)) =>
437            {
438                !left.any_column_refs()
439            }
440            (_, _) => false,
441        }
442    } else {
443        false
444    }
445}
446
447fn collect_local_correlated_cols(
448    plan: &LogicalPlan,
449    all_cols_map: &HashMap<LogicalPlan, BTreeSet<Column>>,
450    local_cols: &mut BTreeSet<Column>,
451) {
452    for child in plan.inputs() {
453        if let Some(cols) = all_cols_map.get(child) {
454            local_cols.extend(cols.clone());
455        }
456        // SubqueryAlias is treated as the leaf node
457        if !matches!(child, LogicalPlan::SubqueryAlias(_)) {
458            collect_local_correlated_cols(child, all_cols_map, local_cols);
459        }
460    }
461}
462
463fn remove_duplicated_filter(filters: Vec<Expr>, in_predicate: &Expr) -> Vec<Expr> {
464    filters
465        .into_iter()
466        .filter(|filter| {
467            if filter == in_predicate {
468                return false;
469            }
470
471            // ignore the binary order
472            !match (filter, in_predicate) {
473                (Expr::BinaryExpr(a_expr), Expr::BinaryExpr(b_expr)) => {
474                    (a_expr.op == b_expr.op)
475                        && (a_expr.left == b_expr.left && a_expr.right == b_expr.right)
476                        || (a_expr.left == b_expr.right && a_expr.right == b_expr.left)
477                }
478                _ => false,
479            }
480        })
481        .collect::<Vec<_>>()
482}
483
484fn agg_exprs_evaluation_result_on_empty_batch(
485    agg_expr: &[Expr],
486    schema: &DFSchemaRef,
487    expr_result_map_for_count_bug: &mut ExprResultMap,
488) -> Result<()> {
489    for e in agg_expr.iter() {
490        let result_expr = e
491            .clone()
492            .transform_up(|expr| {
493                let new_expr = match expr {
494                    Expr::AggregateFunction(expr::AggregateFunction { func, .. }) => {
495                        if func.name() == "count" {
496                            Transformed::yes(Expr::Literal(
497                                ScalarValue::Int64(Some(0)),
498                                None,
499                            ))
500                        } else {
501                            Transformed::yes(Expr::Literal(ScalarValue::Null, None))
502                        }
503                    }
504                    _ => Transformed::no(expr),
505                };
506                Ok(new_expr)
507            })
508            .data()?;
509
510        let result_expr = result_expr.unalias();
511        let info = SimplifyContext::default().with_schema(Arc::clone(schema));
512        let simplifier = ExprSimplifier::new(info);
513        let result_expr = simplifier.simplify(result_expr)?;
514        expr_result_map_for_count_bug.insert(e.schema_name().to_string(), result_expr);
515    }
516    Ok(())
517}
518
519fn proj_exprs_evaluation_result_on_empty_batch(
520    proj_expr: &[Expr],
521    schema: &DFSchemaRef,
522    input_expr_result_map_for_count_bug: &ExprResultMap,
523    expr_result_map_for_count_bug: &mut ExprResultMap,
524) -> Result<()> {
525    for expr in proj_expr.iter() {
526        let result_expr = expr
527            .clone()
528            .transform_up(|expr| {
529                if let Expr::Column(Column { name, .. }) = &expr {
530                    if let Some(result_expr) =
531                        input_expr_result_map_for_count_bug.get(name)
532                    {
533                        Ok(Transformed::yes(result_expr.clone()))
534                    } else {
535                        Ok(Transformed::no(expr))
536                    }
537                } else {
538                    Ok(Transformed::no(expr))
539                }
540            })
541            .data()?;
542
543        if result_expr.ne(expr) {
544            let info = SimplifyContext::default().with_schema(Arc::clone(schema));
545            let simplifier = ExprSimplifier::new(info);
546            let result_expr = simplifier.simplify(result_expr)?;
547            let expr_name = match expr {
548                Expr::Alias(Alias { name, .. }) => name.to_string(),
549                Expr::Column(Column {
550                    relation: _,
551                    name,
552                    spans: _,
553                }) => name.to_string(),
554                _ => expr.schema_name().to_string(),
555            };
556            expr_result_map_for_count_bug.insert(expr_name, result_expr);
557        }
558    }
559    Ok(())
560}
561
562fn filter_exprs_evaluation_result_on_empty_batch(
563    filter_expr: &Expr,
564    schema: DFSchemaRef,
565    input_expr_result_map_for_count_bug: &ExprResultMap,
566    expr_result_map_for_count_bug: &mut ExprResultMap,
567) -> Result<Option<Expr>> {
568    let result_expr = filter_expr
569        .clone()
570        .transform_up(|expr| {
571            if let Expr::Column(Column { name, .. }) = &expr {
572                if let Some(result_expr) = input_expr_result_map_for_count_bug.get(name) {
573                    Ok(Transformed::yes(result_expr.clone()))
574                } else {
575                    Ok(Transformed::no(expr))
576                }
577            } else {
578                Ok(Transformed::no(expr))
579            }
580        })
581        .data()?;
582
583    let pull_up_expr = if result_expr.ne(filter_expr) {
584        let info = SimplifyContext::default().with_schema(schema);
585        let simplifier = ExprSimplifier::new(info);
586        let result_expr = simplifier.simplify(result_expr)?;
587        match &result_expr {
588            // evaluate to false or null on empty batch, no need to pull up
589            Expr::Literal(ScalarValue::Null, _)
590            | Expr::Literal(ScalarValue::Boolean(Some(false)), _) => None,
591            // evaluate to true on empty batch, need to pull up the expr
592            Expr::Literal(ScalarValue::Boolean(Some(true)), _) => {
593                for (name, exprs) in input_expr_result_map_for_count_bug {
594                    expr_result_map_for_count_bug.insert(name.clone(), exprs.clone());
595                }
596                Some(filter_expr.clone())
597            }
598            // can not evaluate statically
599            _ => {
600                for input_expr in input_expr_result_map_for_count_bug.values() {
601                    let new_expr = Expr::Case(expr::Case {
602                        expr: None,
603                        when_then_expr: vec![(
604                            Box::new(result_expr.clone()),
605                            Box::new(input_expr.clone()),
606                        )],
607                        else_expr: Some(Box::new(Expr::Literal(ScalarValue::Null, None))),
608                    });
609                    let expr_key = new_expr.schema_name().to_string();
610                    expr_result_map_for_count_bug.insert(expr_key, new_expr);
611                }
612                None
613            }
614        }
615    } else {
616        for (name, exprs) in input_expr_result_map_for_count_bug {
617            expr_result_map_for_count_bug.insert(name.clone(), exprs.clone());
618        }
619        None
620    };
621    Ok(pull_up_expr)
622}