Skip to main content

datafusion_optimizer/
decorrelate_lateral_join.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`DecorrelateLateralJoin`] decorrelates logical plans produced by lateral joins.
19
20use std::sync::Arc;
21
22use crate::decorrelate::{PullUpCorrelatedExpr, UN_MATCHED_ROW_INDICATOR};
23use crate::optimizer::ApplyOrder;
24use crate::utils::evaluates_to_null;
25use crate::{OptimizerConfig, OptimizerRule};
26use datafusion_expr::{Expr, Join, expr};
27
28use datafusion_common::tree_node::{
29    Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
30};
31use datafusion_common::{Column, DFSchema, Result, ScalarValue, TableReference};
32use datafusion_expr::logical_plan::{JoinType, Subquery};
33use datafusion_expr::utils::conjunction;
34use datafusion_expr::{LogicalPlan, LogicalPlanBuilder, SubqueryAlias};
35
/// Optimizer rule for rewriting lateral joins to joins.
///
/// The rule carries no state; all of its work happens when it is applied to a
/// `LogicalPlan::Join` node.
#[derive(Default, Debug)]
pub struct DecorrelateLateralJoin {}

impl DecorrelateLateralJoin {
    /// Creates a new [`DecorrelateLateralJoin`] rule.
    ///
    /// Equivalent to [`DecorrelateLateralJoin::default`].
    pub fn new() -> Self {
        Self::default()
    }
}
46
47impl OptimizerRule for DecorrelateLateralJoin {
48    fn supports_rewrite(&self) -> bool {
49        true
50    }
51
52    fn rewrite(
53        &self,
54        plan: LogicalPlan,
55        _config: &dyn OptimizerConfig,
56    ) -> Result<Transformed<LogicalPlan>> {
57        // Find cross joins with outer column references on the right side (i.e., the apply operator).
58        let LogicalPlan::Join(join) = plan else {
59            return Ok(Transformed::no(plan));
60        };
61
62        rewrite_internal(join)
63    }
64
65    fn name(&self) -> &str {
66        "decorrelate_lateral_join"
67    }
68
69    fn apply_order(&self) -> Option<ApplyOrder> {
70        Some(ApplyOrder::TopDown)
71    }
72}
73
74// Build the decorrelated join based on the original lateral join query.
75// Supports INNER and LEFT lateral joins.
76fn rewrite_internal(join: Join) -> Result<Transformed<LogicalPlan>> {
77    if !matches!(join.join_type, JoinType::Inner | JoinType::Left) {
78        return Ok(Transformed::no(LogicalPlan::Join(join)));
79    }
80    let original_join_type = join.join_type;
81
82    // The right side is wrapped in a Subquery node when it contains outer
83    // references. Quickly skip joins that don't have this structure.
84    let Some((subquery, alias)) = extract_lateral_subquery(join.right.as_ref()) else {
85        return Ok(Transformed::no(LogicalPlan::Join(join)));
86    };
87
88    // If the subquery has no outer references, there is nothing to decorrelate.
89    // A LATERAL with no outer references is just a cross join.
90    let has_outer_refs = matches!(
91        subquery.subquery.apply_with_subqueries(|p| {
92            if p.contains_outer_reference() {
93                Ok(TreeNodeRecursion::Stop)
94            } else {
95                Ok(TreeNodeRecursion::Continue)
96            }
97        })?,
98        TreeNodeRecursion::Stop
99    );
100    if !has_outer_refs {
101        return Ok(Transformed::no(LogicalPlan::Join(join)));
102    }
103
104    let subquery_plan = subquery.subquery.as_ref();
105    let original_join_filter = join.filter.clone();
106
107    // Walk the subquery plan bottom-up, extracting correlated filter
108    // predicates into join conditions and converting ungrouped aggregates
109    // into group-by aggregates keyed on the correlation columns.
110    let mut pull_up = PullUpCorrelatedExpr::new().with_need_handle_count_bug(true);
111    let rewritten_subquery = subquery_plan.clone().rewrite(&mut pull_up).data()?;
112    if !pull_up.can_pull_up {
113        return Ok(Transformed::no(LogicalPlan::Join(join)));
114    }
115
116    // TODO: support HAVING in lateral subqueries.
117    // <https://github.com/apache/datafusion/issues/21198>
118    if pull_up.pull_up_having_expr.is_some() {
119        return Ok(Transformed::no(LogicalPlan::Join(join)));
120    }
121
122    // The correlation predicates (extracted from the subquery's WHERE) become
123    // the rewritten join's ON clause. See below for discussion of how the
124    // user's original ON clause is handled.
125    let correlation_filter = conjunction(pull_up.join_filters);
126
127    // Look up each aggregate's default value on empty input (e.g., COUNT → 0,
128    // SUM → NULL). This must happen before wrapping in SubqueryAlias, because
129    // the map is keyed by LogicalPlan and wrapping changes the plan.
130    let collected_count_expr_map = pull_up
131        .collected_count_expr_map
132        .get(&rewritten_subquery)
133        .cloned();
134
135    // Re-wrap in SubqueryAlias if the original had one, preserving the alias name.
136    // The SubqueryAlias re-qualifies all columns with the alias, so we must also
137    // rewrite column references in both the correlation and ON-clause filters.
138    let (right_plan, correlation_filter, original_join_filter) =
139        if let Some(ref alias) = alias {
140            let inner_schema = Arc::clone(rewritten_subquery.schema());
141            let right = LogicalPlan::SubqueryAlias(SubqueryAlias::try_new(
142                Arc::new(rewritten_subquery),
143                alias.clone(),
144            )?);
145            let corr = correlation_filter
146                .map(|f| requalify_filter(f, &inner_schema, alias))
147                .transpose()?;
148            let on = original_join_filter
149                .map(|f| requalify_filter(f, &inner_schema, alias))
150                .transpose()?;
151            (right, corr, on)
152        } else {
153            (rewritten_subquery, correlation_filter, original_join_filter)
154        };
155
156    // For LEFT lateral joins, verify that all column references in the
157    // correlation filter are resolvable within the join's left and right
158    // schemas. If the lateral subquery references columns from an outer scope,
159    // the extracted filter will contain unresolvable columns and we must skip
160    // decorrelation.
161    //
162    // INNER lateral joins do not need this check: later optimizer passes
163    // (filter pushdown, join reordering) can restructure the plan to resolve
164    // cross-scope references. LEFT joins cannot be freely reordered.
165    if original_join_type == JoinType::Left
166        && let Some(ref filter) = correlation_filter
167    {
168        let left_schema = join.left.schema();
169        let right_schema = right_plan.schema();
170        let has_outer_scope_refs = filter
171            .column_refs()
172            .iter()
173            .any(|col| !left_schema.has_column(col) && !right_schema.has_column(col));
174        if has_outer_scope_refs {
175            return Ok(Transformed::no(LogicalPlan::Join(join)));
176        }
177    }
178
179    // Use a left join when the user wrote LEFT LATERAL or when a scalar
180    // aggregation was pulled up (preserves outer rows with no matches).
181    let join_type =
182        if original_join_type == JoinType::Left || pull_up.pulled_up_scalar_agg {
183            JoinType::Left
184        } else {
185            JoinType::Inner
186        };
187
188    // The correlation predicates (extracted from the subquery's WHERE) are
189    // turned into the rewritten join's ON clause. There are three cases that
190    // determine how the user's original ON clause is handled:
191    //
192    // - INNER lateral: user ON clause becomes a post-join filter. This restores
193    //   inner-join semantics if the join is upgraded to LEFT for count-bug
194    //   handling.
195    //
196    // - LEFT lateral with grouped (or no) agg: user ON clause is merged into
197    //   the rewritten ON clause, alongside the correlation predicates. LEFT
198    //   join semantics correctly preserve unmatched rows with NULLs.
199    //
200    // - LEFT lateral with an ungrouped aggregate (which decorrelation converts
201    //   to a group-by keyed on the correlation columns): user ON clause cannot
202    //   be placed in the join condition (it would conflict with count-bug
203    //   compensation) or as a post-join filter (that would remove
204    //   left-preserved rows). Instead, a projection is added after count-bug
205    //   compensation that replaces each right-side column with NULL when the ON
206    //   condition is not satisfied:
207    //
208    //      CASE WHEN (on_cond) IS NOT TRUE THEN NULL ELSE <col> END
209    //
210    //   This simulates LEFT JOIN semantics for the user's ON clause without
211    //   interfering with count-bug compensation.
212    let (join_filter, post_join_filter, on_condition_for_projection) =
213        if original_join_type == JoinType::Left {
214            if pull_up.pulled_up_scalar_agg {
215                (correlation_filter, None, original_join_filter)
216            } else {
217                let combined = conjunction(
218                    correlation_filter.into_iter().chain(original_join_filter),
219                );
220                (combined, None, None)
221            }
222        } else {
223            (correlation_filter, original_join_filter, None)
224        };
225
226    let left_field_count = join.left.schema().fields().len();
227    let new_plan = LogicalPlanBuilder::from(join.left)
228        .join_on(right_plan, join_type, join_filter)?
229        .build()?;
230
231    // Handle the count bug: in the rewritten left join, unmatched outer
232    // rows get NULLs for all right-side columns. But some aggregates
233    // have non-NULL defaults on empty input (e.g., COUNT returns 0, not
234    // NULL). Add a projection that wraps those columns:
235    //   CASE WHEN __always_true IS NULL THEN <default> ELSE <column> END
236    let new_plan = if let Some(expr_map) = collected_count_expr_map {
237        let join_schema = new_plan.schema();
238        let alias_qualifier = alias.as_ref();
239        let mut proj_exprs: Vec<Expr> = vec![];
240
241        for (i, (qualifier, field)) in join_schema.iter().enumerate() {
242            let col = Expr::Column(Column::new(qualifier.cloned(), field.name()));
243
244            // Only compensate right-side (subquery) fields. Left-side fields
245            // may share a name with an aggregate alias but must not be wrapped.
246            let name = field.name();
247            if i >= left_field_count
248                && let Some(default_value) = expr_map.get(name.as_str())
249                && !evaluates_to_null(default_value.clone(), default_value.column_refs())?
250            {
251                // Column whose aggregate doesn't naturally return NULL
252                // on empty input (e.g., COUNT returns 0). Wrap it.
253                let indicator_col =
254                    Column::new(alias_qualifier.cloned(), UN_MATCHED_ROW_INDICATOR);
255                let case_expr = Expr::Case(expr::Case {
256                    expr: None,
257                    when_then_expr: vec![(
258                        Box::new(Expr::IsNull(Box::new(Expr::Column(indicator_col)))),
259                        Box::new(default_value.clone()),
260                    )],
261                    else_expr: Some(Box::new(col)),
262                });
263                proj_exprs.push(Expr::Alias(expr::Alias {
264                    expr: Box::new(case_expr),
265                    relation: qualifier.cloned(),
266                    name: name.to_string(),
267                    metadata: None,
268                }));
269                continue;
270            }
271            proj_exprs.push(col);
272        }
273
274        LogicalPlanBuilder::from(new_plan)
275            .project(proj_exprs)?
276            .build()?
277    } else {
278        new_plan
279    };
280
281    // For LEFT lateral joins with an ungrouped aggregate, simulate LEFT JOIN
282    // semantics for the user's ON clause by adding a projection that replaces
283    // right-side columns with NULL when the ON condition is false (see
284    // commentary above).
285    //
286    // Note: the ON condition expression is duplicated per column, so this
287    // assumes it is deterministic.
288    let new_plan = if let Some(on_cond) = on_condition_for_projection {
289        let schema = Arc::clone(new_plan.schema());
290        let mut proj_exprs: Vec<Expr> = vec![];
291
292        for (i, (qualifier, field)) in schema.iter().enumerate() {
293            let col = Expr::Column(Column::new(qualifier.cloned(), field.name()));
294
295            if i < left_field_count {
296                proj_exprs.push(col);
297                continue;
298            }
299
300            let typed_null =
301                Expr::Literal(ScalarValue::try_from(field.data_type())?, None);
302            let case_expr = Expr::Case(expr::Case {
303                expr: None,
304                when_then_expr: vec![(
305                    Box::new(Expr::IsNotTrue(Box::new(on_cond.clone()))),
306                    Box::new(typed_null),
307                )],
308                else_expr: Some(Box::new(col)),
309            });
310            proj_exprs.push(case_expr.alias_qualified(qualifier.cloned(), field.name()));
311        }
312
313        LogicalPlanBuilder::from(new_plan)
314            .project(proj_exprs)?
315            .build()?
316    } else {
317        new_plan
318    };
319
320    // Apply the original ON clause as a post-join filter (INNER lateral only).
321    let new_plan = if let Some(on_filter) = post_join_filter {
322        LogicalPlanBuilder::from(new_plan)
323            .filter(on_filter)?
324            .build()?
325    } else {
326        new_plan
327    };
328
329    Ok(Transformed::new(new_plan, true, TreeNodeRecursion::Jump))
330}
331
332/// Extract the Subquery and optional alias from a lateral join's right side.
333fn extract_lateral_subquery(
334    plan: &LogicalPlan,
335) -> Option<(Subquery, Option<TableReference>)> {
336    match plan {
337        LogicalPlan::Subquery(sq) => Some((sq.clone(), None)),
338        LogicalPlan::SubqueryAlias(SubqueryAlias { input, alias, .. }) => {
339            if let LogicalPlan::Subquery(sq) = input.as_ref() {
340                Some((sq.clone(), Some(alias.clone())))
341            } else {
342                None
343            }
344        }
345        _ => None,
346    }
347}
348
349/// Rewrite column references in a join filter expression so that columns
350/// belonging to the inner (right) side use the SubqueryAlias qualifier.
351///
352/// The `PullUpCorrelatedExpr` pass extracts join filters with the inner
353/// columns qualified by their original table names (e.g., `t2.t1_id`).
354/// When the inner plan is wrapped in a `SubqueryAlias("sub")`, those
355/// columns are re-qualified as `sub.t1_id`. This function applies the
356/// same requalification to the filter so it matches the aliased schema.
357fn requalify_filter(
358    filter: Expr,
359    inner_schema: &DFSchema,
360    alias: &TableReference,
361) -> Result<Expr> {
362    filter
363        .transform(|expr| {
364            if let Expr::Column(col) = &expr
365                && inner_schema.has_column(col)
366            {
367                let new_col = Column::new(Some(alias.clone()), col.name.clone());
368                return Ok(Transformed::yes(Expr::Column(new_col)));
369            }
370            Ok(Transformed::no(expr))
371        })
372        .data()
373}