datafusion_optimizer/
decorrelate_lateral_join.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`DecorrelateLateralJoin`] decorrelates logical plans produced by lateral joins.
19
20use std::collections::BTreeSet;
21
22use crate::decorrelate::PullUpCorrelatedExpr;
23use crate::optimizer::ApplyOrder;
24use crate::{OptimizerConfig, OptimizerRule};
25use datafusion_expr::{lit, Join};
26
27use datafusion_common::tree_node::{
28    Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
29};
30use datafusion_common::Result;
31use datafusion_expr::logical_plan::JoinType;
32use datafusion_expr::utils::conjunction;
33use datafusion_expr::{LogicalPlan, LogicalPlanBuilder};
34
/// Optimizer rule that rewrites lateral joins into regular joins.
///
/// The rule carries no state; the rewrite itself lives in its
/// `OptimizerRule` implementation.
#[derive(Debug, Default)]
pub struct DecorrelateLateralJoin {}
38
39impl DecorrelateLateralJoin {
40    #[allow(missing_docs)]
41    pub fn new() -> Self {
42        Self::default()
43    }
44}
45
46impl OptimizerRule for DecorrelateLateralJoin {
47    fn supports_rewrite(&self) -> bool {
48        true
49    }
50
51    fn rewrite(
52        &self,
53        plan: LogicalPlan,
54        _config: &dyn OptimizerConfig,
55    ) -> Result<Transformed<LogicalPlan>> {
56        // Find cross joins with outer column references on the right side (i.e., the apply operator).
57        let LogicalPlan::Join(join) = plan else {
58            return Ok(Transformed::no(plan));
59        };
60
61        rewrite_internal(join)
62    }
63
64    fn name(&self) -> &str {
65        "decorrelate_lateral_join"
66    }
67
68    fn apply_order(&self) -> Option<ApplyOrder> {
69        Some(ApplyOrder::TopDown)
70    }
71}
72
73// Build the decorrelated join based on the original lateral join query. For now, we only support cross/inner
74// lateral joins.
75fn rewrite_internal(join: Join) -> Result<Transformed<LogicalPlan>> {
76    if join.join_type != JoinType::Inner {
77        return Ok(Transformed::no(LogicalPlan::Join(join)));
78    }
79
80    match join.right.apply_with_subqueries(|p| {
81        // TODO: support outer joins
82        if p.contains_outer_reference() {
83            Ok(TreeNodeRecursion::Stop)
84        } else {
85            Ok(TreeNodeRecursion::Continue)
86        }
87    })? {
88        TreeNodeRecursion::Stop => {}
89        TreeNodeRecursion::Continue => {
90            // The left side contains outer references, we need to decorrelate it.
91            return Ok(Transformed::new(
92                LogicalPlan::Join(join),
93                false,
94                TreeNodeRecursion::Jump,
95            ));
96        }
97        TreeNodeRecursion::Jump => {
98            unreachable!("")
99        }
100    }
101
102    let LogicalPlan::Subquery(subquery) = join.right.as_ref() else {
103        return Ok(Transformed::no(LogicalPlan::Join(join)));
104    };
105
106    if join.join_type != JoinType::Inner {
107        return Ok(Transformed::no(LogicalPlan::Join(join)));
108    }
109    let subquery_plan = subquery.subquery.as_ref();
110    let mut pull_up = PullUpCorrelatedExpr::new().with_need_handle_count_bug(true);
111    let rewritten_subquery = subquery_plan.clone().rewrite(&mut pull_up).data()?;
112    if !pull_up.can_pull_up {
113        return Ok(Transformed::no(LogicalPlan::Join(join)));
114    }
115
116    let mut all_correlated_cols = BTreeSet::new();
117    pull_up
118        .correlated_subquery_cols_map
119        .values()
120        .for_each(|cols| all_correlated_cols.extend(cols.clone()));
121    let join_filter_opt = conjunction(pull_up.join_filters);
122    let join_filter = match join_filter_opt {
123        Some(join_filter) => join_filter,
124        None => lit(true),
125    };
126    // -- inner join but the right side always has one row, we need to rewrite it to a left join
127    // SELECT * FROM t0, LATERAL (SELECT sum(v1) FROM t1 WHERE t0.v0 = t1.v0);
128    // -- inner join but the right side number of rows is related to the filter (join) condition, so keep inner join.
129    // SELECT * FROM t0, LATERAL (SELECT * FROM t1 WHERE t0.v0 = t1.v0);
130    let new_plan = LogicalPlanBuilder::from(join.left)
131        .join_on(
132            rewritten_subquery,
133            if pull_up.pulled_up_scalar_agg {
134                JoinType::Left
135            } else {
136                JoinType::Inner
137            },
138            Some(join_filter),
139        )?
140        .build()?;
141    // TODO: handle count(*) bug
142    Ok(Transformed::new(new_plan, true, TreeNodeRecursion::Jump))
143}