datafusion_optimizer/
decorrelate.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`PullUpCorrelatedExpr`] converts correlated subqueries to `Joins`
19
20use std::collections::BTreeSet;
21use std::ops::Deref;
22use std::sync::Arc;
23
24use crate::simplify_expressions::ExprSimplifier;
25
26use datafusion_common::tree_node::{
27    Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter,
28};
29use datafusion_common::{plan_err, Column, DFSchemaRef, HashMap, Result, ScalarValue};
30use datafusion_expr::expr::Alias;
31use datafusion_expr::simplify::SimplifyContext;
32use datafusion_expr::utils::{
33    collect_subquery_cols, conjunction, find_join_exprs, split_conjunction,
34};
35use datafusion_expr::{
36    expr, lit, BinaryExpr, Cast, EmptyRelation, Expr, FetchType, LogicalPlan,
37    LogicalPlanBuilder, Operator,
38};
39use datafusion_physical_expr::execution_props::ExecutionProps;
40
41/// This struct rewrite the sub query plan by pull up the correlated
42/// expressions(contains outer reference columns) from the inner subquery's
43/// 'Filter'. It adds the inner reference columns to the 'Projection' or
44/// 'Aggregate' of the subquery if they are missing, so that they can be
45/// evaluated by the parent operator as the join condition.
46#[derive(Debug)]
47pub struct PullUpCorrelatedExpr {
48    pub join_filters: Vec<Expr>,
49    /// mapping from the plan to its holding correlated columns
50    pub correlated_subquery_cols_map: HashMap<LogicalPlan, BTreeSet<Column>>,
51    pub in_predicate_opt: Option<Expr>,
52    /// Is this an Exists(Not Exists) SubQuery. Defaults to **FALSE**
53    pub exists_sub_query: bool,
54    /// Can the correlated expressions be pulled up. Defaults to **TRUE**
55    pub can_pull_up: bool,
56    /// Indicates if we encounter any correlated expression that can not be pulled up
57    /// above a aggregation without changing the meaning of the query.
58    can_pull_over_aggregation: bool,
59    /// Do we need to handle [the count bug] during the pull up process.
60    ///
61    /// The "count bug" was described in [Optimization of Nested SQL
62    /// Queries Revisited](https://dl.acm.org/doi/pdf/10.1145/38714.38723). This bug is
63    /// not specific to the COUNT function, and it can occur with any aggregate function,
64    /// such as SUM, AVG, etc. The anomaly arises because aggregates fail to distinguish
65    /// between an empty set and null values when optimizing a correlated query as a join.
66    /// Here, we use "the count bug" to refer to all such cases.
67    ///
68    /// [the count bug]: https://github.com/apache/datafusion/issues/10553
69    pub need_handle_count_bug: bool,
70    /// mapping from the plan to its expressions' evaluation result on empty batch
71    pub collected_count_expr_map: HashMap<LogicalPlan, ExprResultMap>,
72    /// pull up having expr, which must be evaluated after the Join
73    pub pull_up_having_expr: Option<Expr>,
74    /// whether we have converted a scalar aggregation into a group aggregation. When unnesting
75    /// lateral joins, we need to produce a left outer join in such cases.
76    pub pulled_up_scalar_agg: bool,
77}
78
79impl Default for PullUpCorrelatedExpr {
80    fn default() -> Self {
81        Self::new()
82    }
83}
84
85impl PullUpCorrelatedExpr {
86    pub fn new() -> Self {
87        Self {
88            join_filters: vec![],
89            correlated_subquery_cols_map: HashMap::new(),
90            in_predicate_opt: None,
91            exists_sub_query: false,
92            can_pull_up: true,
93            can_pull_over_aggregation: true,
94            need_handle_count_bug: false,
95            collected_count_expr_map: HashMap::new(),
96            pull_up_having_expr: None,
97            pulled_up_scalar_agg: false,
98        }
99    }
100
101    /// Set if we need to handle [the count bug] during the pull up process
102    ///
103    /// [the count bug]: https://github.com/apache/datafusion/issues/10553
104    pub fn with_need_handle_count_bug(mut self, need_handle_count_bug: bool) -> Self {
105        self.need_handle_count_bug = need_handle_count_bug;
106        self
107    }
108
109    /// Set the in_predicate_opt
110    pub fn with_in_predicate_opt(mut self, in_predicate_opt: Option<Expr>) -> Self {
111        self.in_predicate_opt = in_predicate_opt;
112        self
113    }
114
115    /// Set if this is an Exists(Not Exists) SubQuery
116    pub fn with_exists_sub_query(mut self, exists_sub_query: bool) -> Self {
117        self.exists_sub_query = exists_sub_query;
118        self
119    }
120}
121
/// Name of the marker column used to recognize the unmatched rows of the
/// inner (subquery) table after the left outer join.
/// It is used to handle [the Count bug].
///
/// [the Count bug]: https://github.com/apache/datafusion/issues/10553
pub const UN_MATCHED_ROW_INDICATOR: &str = "__always_true";
127
128/// Mapping from expr display name to its evaluation result on empty record
129/// batch (for example: 'count(*)' is 'ScalarValue(0)', 'count(*) + 2' is
130/// 'ScalarValue(2)')
131pub type ExprResultMap = HashMap<String, Expr>;
132
133impl TreeNodeRewriter for PullUpCorrelatedExpr {
134    type Node = LogicalPlan;
135
136    fn f_down(&mut self, plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
137        match plan {
138            LogicalPlan::Filter(_) => Ok(Transformed::no(plan)),
139            LogicalPlan::Union(_) | LogicalPlan::Sort(_) | LogicalPlan::Extension(_) => {
140                let plan_hold_outer = !plan.all_out_ref_exprs().is_empty();
141                if plan_hold_outer {
142                    // the unsupported case
143                    self.can_pull_up = false;
144                    Ok(Transformed::new(plan, false, TreeNodeRecursion::Jump))
145                } else {
146                    Ok(Transformed::no(plan))
147                }
148            }
149            LogicalPlan::Limit(_) => {
150                let plan_hold_outer = !plan.all_out_ref_exprs().is_empty();
151                match (self.exists_sub_query, plan_hold_outer) {
152                    (false, true) => {
153                        // the unsupported case
154                        self.can_pull_up = false;
155                        Ok(Transformed::new(plan, false, TreeNodeRecursion::Jump))
156                    }
157                    _ => Ok(Transformed::no(plan)),
158                }
159            }
160            _ if plan.contains_outer_reference() => {
161                // the unsupported cases, the plan expressions contain out reference columns(like window expressions)
162                self.can_pull_up = false;
163                Ok(Transformed::new(plan, false, TreeNodeRecursion::Jump))
164            }
165            _ => Ok(Transformed::no(plan)),
166        }
167    }
168
169    fn f_up(&mut self, plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
170        let subquery_schema = plan.schema();
171        match &plan {
172            LogicalPlan::Filter(plan_filter) => {
173                let subquery_filter_exprs = split_conjunction(&plan_filter.predicate);
174                self.can_pull_over_aggregation = self.can_pull_over_aggregation
175                    && subquery_filter_exprs
176                        .iter()
177                        .filter(|e| e.contains_outer())
178                        .all(|&e| can_pullup_over_aggregation(e));
179                let (mut join_filters, subquery_filters) =
180                    find_join_exprs(subquery_filter_exprs)?;
181                if let Some(in_predicate) = &self.in_predicate_opt {
182                    // in_predicate may be already included in the join filters, remove it from the join filters first.
183                    join_filters = remove_duplicated_filter(join_filters, in_predicate);
184                }
185                let correlated_subquery_cols =
186                    collect_subquery_cols(&join_filters, subquery_schema)?;
187                for expr in join_filters {
188                    if !self.join_filters.contains(&expr) {
189                        self.join_filters.push(expr)
190                    }
191                }
192
193                let mut expr_result_map_for_count_bug = HashMap::new();
194                let pull_up_expr_opt = if let Some(expr_result_map) =
195                    self.collected_count_expr_map.get(plan_filter.input.deref())
196                {
197                    if let Some(expr) = conjunction(subquery_filters.clone()) {
198                        filter_exprs_evaluation_result_on_empty_batch(
199                            &expr,
200                            Arc::clone(plan_filter.input.schema()),
201                            expr_result_map,
202                            &mut expr_result_map_for_count_bug,
203                        )?
204                    } else {
205                        None
206                    }
207                } else {
208                    None
209                };
210
211                match (&pull_up_expr_opt, &self.pull_up_having_expr) {
212                    (Some(_), Some(_)) => {
213                        // Error path
214                        plan_err!("Unsupported Subquery plan")
215                    }
216                    (Some(_), None) => {
217                        self.pull_up_having_expr = pull_up_expr_opt;
218                        let new_plan =
219                            LogicalPlanBuilder::from((*plan_filter.input).clone())
220                                .build()?;
221                        self.correlated_subquery_cols_map
222                            .insert(new_plan.clone(), correlated_subquery_cols);
223                        Ok(Transformed::yes(new_plan))
224                    }
225                    (None, _) => {
226                        // if the subquery still has filter expressions, restore them.
227                        let mut plan =
228                            LogicalPlanBuilder::from((*plan_filter.input).clone());
229                        if let Some(expr) = conjunction(subquery_filters) {
230                            plan = plan.filter(expr)?
231                        }
232                        let new_plan = plan.build()?;
233                        self.correlated_subquery_cols_map
234                            .insert(new_plan.clone(), correlated_subquery_cols);
235                        Ok(Transformed::yes(new_plan))
236                    }
237                }
238            }
239            LogicalPlan::Projection(projection)
240                if self.in_predicate_opt.is_some() || !self.join_filters.is_empty() =>
241            {
242                let mut local_correlated_cols = BTreeSet::new();
243                collect_local_correlated_cols(
244                    &plan,
245                    &self.correlated_subquery_cols_map,
246                    &mut local_correlated_cols,
247                );
248                // add missing columns to Projection
249                let mut missing_exprs =
250                    self.collect_missing_exprs(&projection.expr, &local_correlated_cols)?;
251
252                let mut expr_result_map_for_count_bug = HashMap::new();
253                if let Some(expr_result_map) =
254                    self.collected_count_expr_map.get(projection.input.deref())
255                {
256                    proj_exprs_evaluation_result_on_empty_batch(
257                        &projection.expr,
258                        projection.input.schema(),
259                        expr_result_map,
260                        &mut expr_result_map_for_count_bug,
261                    )?;
262                    if !expr_result_map_for_count_bug.is_empty() {
263                        // has count bug
264                        let un_matched_row = Expr::Column(Column::new_unqualified(
265                            UN_MATCHED_ROW_INDICATOR.to_string(),
266                        ));
267                        // add the unmatched rows indicator to the Projection expressions
268                        missing_exprs.push(un_matched_row);
269                    }
270                }
271
272                let new_plan = LogicalPlanBuilder::from((*projection.input).clone())
273                    .project(missing_exprs)?
274                    .build()?;
275                if !expr_result_map_for_count_bug.is_empty() {
276                    self.collected_count_expr_map
277                        .insert(new_plan.clone(), expr_result_map_for_count_bug);
278                }
279                Ok(Transformed::yes(new_plan))
280            }
281            LogicalPlan::Aggregate(aggregate)
282                if self.in_predicate_opt.is_some() || !self.join_filters.is_empty() =>
283            {
284                // If the aggregation is from a distinct it will not change the result for
285                // exists/in subqueries so we can still pull up all predicates.
286                let is_distinct = aggregate.aggr_expr.is_empty();
287                if !is_distinct {
288                    self.can_pull_up = self.can_pull_up && self.can_pull_over_aggregation;
289                }
290                let mut local_correlated_cols = BTreeSet::new();
291                collect_local_correlated_cols(
292                    &plan,
293                    &self.correlated_subquery_cols_map,
294                    &mut local_correlated_cols,
295                );
296                // add missing columns to Aggregation's group expressions
297                let mut missing_exprs = self.collect_missing_exprs(
298                    &aggregate.group_expr,
299                    &local_correlated_cols,
300                )?;
301
302                // if the original group expressions are empty, need to handle the Count bug
303                let mut expr_result_map_for_count_bug = HashMap::new();
304                if self.need_handle_count_bug
305                    && aggregate.group_expr.is_empty()
306                    && !missing_exprs.is_empty()
307                {
308                    agg_exprs_evaluation_result_on_empty_batch(
309                        &aggregate.aggr_expr,
310                        aggregate.input.schema(),
311                        &mut expr_result_map_for_count_bug,
312                    )?;
313                    if !expr_result_map_for_count_bug.is_empty() {
314                        // has count bug
315                        let un_matched_row = lit(true).alias(UN_MATCHED_ROW_INDICATOR);
316                        // add the unmatched rows indicator to the Aggregation's group expressions
317                        missing_exprs.push(un_matched_row);
318                    }
319                }
320                if aggregate.group_expr.is_empty() {
321                    // TODO: how do we handle the case where we have pulled multiple aggregations? For example,
322                    // a group agg with a scalar agg as child.
323                    self.pulled_up_scalar_agg = true;
324                }
325                let new_plan = LogicalPlanBuilder::from((*aggregate.input).clone())
326                    .aggregate(missing_exprs, aggregate.aggr_expr.to_vec())?
327                    .build()?;
328                if !expr_result_map_for_count_bug.is_empty() {
329                    self.collected_count_expr_map
330                        .insert(new_plan.clone(), expr_result_map_for_count_bug);
331                }
332                Ok(Transformed::yes(new_plan))
333            }
334            LogicalPlan::SubqueryAlias(alias) => {
335                let mut local_correlated_cols = BTreeSet::new();
336                collect_local_correlated_cols(
337                    &plan,
338                    &self.correlated_subquery_cols_map,
339                    &mut local_correlated_cols,
340                );
341                let mut new_correlated_cols = BTreeSet::new();
342                for col in local_correlated_cols.iter() {
343                    new_correlated_cols
344                        .insert(Column::new(Some(alias.alias.clone()), col.name.clone()));
345                }
346                self.correlated_subquery_cols_map
347                    .insert(plan.clone(), new_correlated_cols);
348                if let Some(input_map) =
349                    self.collected_count_expr_map.get(alias.input.deref())
350                {
351                    self.collected_count_expr_map
352                        .insert(plan.clone(), input_map.clone());
353                }
354                Ok(Transformed::no(plan))
355            }
356            LogicalPlan::Limit(limit) => {
357                let input_expr_map = self
358                    .collected_count_expr_map
359                    .get(limit.input.deref())
360                    .cloned();
361                // handling the limit clause in the subquery
362                let new_plan = match (self.exists_sub_query, self.join_filters.is_empty())
363                {
364                    // Correlated exist subquery, remove the limit(so that correlated expressions can pull up)
365                    (true, false) => Transformed::yes(match limit.get_fetch_type()? {
366                        FetchType::Literal(Some(0)) => {
367                            LogicalPlan::EmptyRelation(EmptyRelation {
368                                produce_one_row: false,
369                                schema: Arc::clone(limit.input.schema()),
370                            })
371                        }
372                        _ => LogicalPlanBuilder::from((*limit.input).clone()).build()?,
373                    }),
374                    _ => Transformed::no(plan),
375                };
376                if let Some(input_map) = input_expr_map {
377                    self.collected_count_expr_map
378                        .insert(new_plan.data.clone(), input_map);
379                }
380                Ok(new_plan)
381            }
382            _ => Ok(Transformed::no(plan)),
383        }
384    }
385}
386
387impl PullUpCorrelatedExpr {
388    fn collect_missing_exprs(
389        &self,
390        exprs: &[Expr],
391        correlated_subquery_cols: &BTreeSet<Column>,
392    ) -> Result<Vec<Expr>> {
393        let mut missing_exprs = vec![];
394        for expr in exprs {
395            if !missing_exprs.contains(expr) {
396                missing_exprs.push(expr.clone())
397            }
398        }
399        for col in correlated_subquery_cols.iter() {
400            let col_expr = Expr::Column(col.clone());
401            if !missing_exprs.contains(&col_expr) {
402                missing_exprs.push(col_expr)
403            }
404        }
405        if let Some(pull_up_having) = &self.pull_up_having_expr {
406            let filter_apply_columns = pull_up_having.column_refs();
407            for col in filter_apply_columns {
408                // add to missing_exprs if not already there
409                let contains = missing_exprs
410                    .iter()
411                    .any(|expr| matches!(expr, Expr::Column(c) if c == col));
412                if !contains {
413                    missing_exprs.push(Expr::Column(col.clone()))
414                }
415            }
416        }
417        Ok(missing_exprs)
418    }
419}
420
421fn can_pullup_over_aggregation(expr: &Expr) -> bool {
422    if let Expr::BinaryExpr(BinaryExpr {
423        left,
424        op: Operator::Eq,
425        right,
426    }) = expr
427    {
428        match (left.deref(), right.deref()) {
429            (Expr::Column(_), right) => !right.any_column_refs(),
430            (left, Expr::Column(_)) => !left.any_column_refs(),
431            (Expr::Cast(Cast { expr, .. }), right)
432                if matches!(expr.deref(), Expr::Column(_)) =>
433            {
434                !right.any_column_refs()
435            }
436            (left, Expr::Cast(Cast { expr, .. }))
437                if matches!(expr.deref(), Expr::Column(_)) =>
438            {
439                !left.any_column_refs()
440            }
441            (_, _) => false,
442        }
443    } else {
444        false
445    }
446}
447
448fn collect_local_correlated_cols(
449    plan: &LogicalPlan,
450    all_cols_map: &HashMap<LogicalPlan, BTreeSet<Column>>,
451    local_cols: &mut BTreeSet<Column>,
452) {
453    for child in plan.inputs() {
454        if let Some(cols) = all_cols_map.get(child) {
455            local_cols.extend(cols.clone());
456        }
457        // SubqueryAlias is treated as the leaf node
458        if !matches!(child, LogicalPlan::SubqueryAlias(_)) {
459            collect_local_correlated_cols(child, all_cols_map, local_cols);
460        }
461    }
462}
463
464fn remove_duplicated_filter(filters: Vec<Expr>, in_predicate: &Expr) -> Vec<Expr> {
465    filters
466        .into_iter()
467        .filter(|filter| {
468            if filter == in_predicate {
469                return false;
470            }
471
472            // ignore the binary order
473            !match (filter, in_predicate) {
474                (Expr::BinaryExpr(a_expr), Expr::BinaryExpr(b_expr)) => {
475                    (a_expr.op == b_expr.op)
476                        && (a_expr.left == b_expr.left && a_expr.right == b_expr.right)
477                        || (a_expr.left == b_expr.right && a_expr.right == b_expr.left)
478                }
479                _ => false,
480            }
481        })
482        .collect::<Vec<_>>()
483}
484
485fn agg_exprs_evaluation_result_on_empty_batch(
486    agg_expr: &[Expr],
487    schema: &DFSchemaRef,
488    expr_result_map_for_count_bug: &mut ExprResultMap,
489) -> Result<()> {
490    for e in agg_expr.iter() {
491        let result_expr = e
492            .clone()
493            .transform_up(|expr| {
494                let new_expr = match expr {
495                    Expr::AggregateFunction(expr::AggregateFunction { func, .. }) => {
496                        if func.name() == "count" {
497                            Transformed::yes(Expr::Literal(
498                                ScalarValue::Int64(Some(0)),
499                                None,
500                            ))
501                        } else {
502                            Transformed::yes(Expr::Literal(ScalarValue::Null, None))
503                        }
504                    }
505                    _ => Transformed::no(expr),
506                };
507                Ok(new_expr)
508            })
509            .data()?;
510
511        let result_expr = result_expr.unalias();
512        let props = ExecutionProps::new();
513        let info = SimplifyContext::new(&props).with_schema(Arc::clone(schema));
514        let simplifier = ExprSimplifier::new(info);
515        let result_expr = simplifier.simplify(result_expr)?;
516        expr_result_map_for_count_bug.insert(e.schema_name().to_string(), result_expr);
517    }
518    Ok(())
519}
520
521fn proj_exprs_evaluation_result_on_empty_batch(
522    proj_expr: &[Expr],
523    schema: &DFSchemaRef,
524    input_expr_result_map_for_count_bug: &ExprResultMap,
525    expr_result_map_for_count_bug: &mut ExprResultMap,
526) -> Result<()> {
527    for expr in proj_expr.iter() {
528        let result_expr = expr
529            .clone()
530            .transform_up(|expr| {
531                if let Expr::Column(Column { name, .. }) = &expr {
532                    if let Some(result_expr) =
533                        input_expr_result_map_for_count_bug.get(name)
534                    {
535                        Ok(Transformed::yes(result_expr.clone()))
536                    } else {
537                        Ok(Transformed::no(expr))
538                    }
539                } else {
540                    Ok(Transformed::no(expr))
541                }
542            })
543            .data()?;
544
545        if result_expr.ne(expr) {
546            let props = ExecutionProps::new();
547            let info = SimplifyContext::new(&props).with_schema(Arc::clone(schema));
548            let simplifier = ExprSimplifier::new(info);
549            let result_expr = simplifier.simplify(result_expr)?;
550            let expr_name = match expr {
551                Expr::Alias(Alias { name, .. }) => name.to_string(),
552                Expr::Column(Column {
553                    relation: _,
554                    name,
555                    spans: _,
556                }) => name.to_string(),
557                _ => expr.schema_name().to_string(),
558            };
559            expr_result_map_for_count_bug.insert(expr_name, result_expr);
560        }
561    }
562    Ok(())
563}
564
565fn filter_exprs_evaluation_result_on_empty_batch(
566    filter_expr: &Expr,
567    schema: DFSchemaRef,
568    input_expr_result_map_for_count_bug: &ExprResultMap,
569    expr_result_map_for_count_bug: &mut ExprResultMap,
570) -> Result<Option<Expr>> {
571    let result_expr = filter_expr
572        .clone()
573        .transform_up(|expr| {
574            if let Expr::Column(Column { name, .. }) = &expr {
575                if let Some(result_expr) = input_expr_result_map_for_count_bug.get(name) {
576                    Ok(Transformed::yes(result_expr.clone()))
577                } else {
578                    Ok(Transformed::no(expr))
579                }
580            } else {
581                Ok(Transformed::no(expr))
582            }
583        })
584        .data()?;
585
586    let pull_up_expr = if result_expr.ne(filter_expr) {
587        let props = ExecutionProps::new();
588        let info = SimplifyContext::new(&props).with_schema(schema);
589        let simplifier = ExprSimplifier::new(info);
590        let result_expr = simplifier.simplify(result_expr)?;
591        match &result_expr {
592            // evaluate to false or null on empty batch, no need to pull up
593            Expr::Literal(ScalarValue::Null, _)
594            | Expr::Literal(ScalarValue::Boolean(Some(false)), _) => None,
595            // evaluate to true on empty batch, need to pull up the expr
596            Expr::Literal(ScalarValue::Boolean(Some(true)), _) => {
597                for (name, exprs) in input_expr_result_map_for_count_bug {
598                    expr_result_map_for_count_bug.insert(name.clone(), exprs.clone());
599                }
600                Some(filter_expr.clone())
601            }
602            // can not evaluate statically
603            _ => {
604                for input_expr in input_expr_result_map_for_count_bug.values() {
605                    let new_expr = Expr::Case(expr::Case {
606                        expr: None,
607                        when_then_expr: vec![(
608                            Box::new(result_expr.clone()),
609                            Box::new(input_expr.clone()),
610                        )],
611                        else_expr: Some(Box::new(Expr::Literal(ScalarValue::Null, None))),
612                    });
613                    let expr_key = new_expr.schema_name().to_string();
614                    expr_result_map_for_count_bug.insert(expr_key, new_expr);
615                }
616                None
617            }
618        }
619    } else {
620        for (name, exprs) in input_expr_result_map_for_count_bug {
621            expr_result_map_for_count_bug.insert(name.clone(), exprs.clone());
622        }
623        None
624    };
625    Ok(pull_up_expr)
626}