Skip to main content

datafusion_optimizer/
scalar_subquery_to_join.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! [`ScalarSubqueryToJoin`] rewrites scalar subqueries in filters and projections to `JOIN`s
19
20use std::collections::{BTreeSet, HashMap};
21use std::sync::Arc;
22
23use crate::decorrelate::{PullUpCorrelatedExpr, UN_MATCHED_ROW_INDICATOR};
24use crate::optimizer::ApplyOrder;
25use crate::utils::{evaluates_to_null, replace_qualified_name};
26use crate::{OptimizerConfig, OptimizerRule};
27
28use crate::analyzer::type_coercion::TypeCoercionRewriter;
29use datafusion_common::alias::AliasGenerator;
30use datafusion_common::tree_node::{
31    Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter,
32};
33use datafusion_common::{Column, Result, ScalarValue, assert_or_internal_err, plan_err};
34use datafusion_expr::expr_rewriter::create_col_from_scalar_expr;
35use datafusion_expr::logical_plan::{JoinType, Subquery};
36use datafusion_expr::utils::conjunction;
37use datafusion_expr::{Expr, LogicalPlan, LogicalPlanBuilder, lit, not, when};
38
/// Optimizer rule that turns scalar subqueries in `Filter` predicates and
/// `Projection` expressions into left joins. When a filter is rewritten, an
/// extra projection is placed on top of it so the original schema is kept.
///
/// Which subqueries get rewritten is controlled by
/// [`datafusion_common::config::OptimizerOptions::enable_physical_uncorrelated_scalar_subquery`]:
///
/// * `true` (the default): only *correlated* scalar subqueries are rewritten;
///   uncorrelated ones are left for physical execution via `ScalarSubqueryExec`.
/// * `false`: every scalar subquery, correlated or uncorrelated, is rewritten
///   to a left join by this rule.
#[derive(Debug, Default)]
pub struct ScalarSubqueryToJoin {}
49
50impl ScalarSubqueryToJoin {
51    #[expect(missing_docs)]
52    pub fn new() -> Self {
53        Self::default()
54    }
55
56    /// Finds expressions that contain correlated scalar subqueries (and
57    /// recurses when found).
58    ///
59    /// # Arguments
60    /// * `predicate` - A conjunction to split and search.
61    /// * `alias_gen` - Generator used to produce unique aliases for each
62    ///   extracted scalar subquery (e.g. `__scalar_sq_1`, `__scalar_sq_2`).
63    ///   Each subquery is replaced by a column reference using the generated
64    ///   alias, and the same alias is later used to construct the join.
65    ///
66    /// Returns a tuple (subqueries, alias)
67    fn extract_subquery_exprs(
68        &self,
69        predicate: &Expr,
70        alias_gen: &Arc<AliasGenerator>,
71        physical_uncorrelated: bool,
72    ) -> Result<(Vec<(Subquery, String)>, Expr)> {
73        let mut extract = ExtractScalarSubQuery {
74            sub_query_info: vec![],
75            alias_gen,
76            physical_uncorrelated,
77        };
78        predicate
79            .clone()
80            .rewrite(&mut extract)
81            .data()
82            .map(|new_expr| (extract.sub_query_info, new_expr))
83    }
84}
85
86impl OptimizerRule for ScalarSubqueryToJoin {
87    fn supports_rewrite(&self) -> bool {
88        true
89    }
90
91    fn rewrite(
92        &self,
93        plan: LogicalPlan,
94        config: &dyn OptimizerConfig,
95    ) -> Result<Transformed<LogicalPlan>> {
96        match plan {
97            LogicalPlan::Filter(filter) => {
98                let physical_uncorrelated = config
99                    .options()
100                    .optimizer
101                    .enable_physical_uncorrelated_scalar_subquery;
102                // Optimization: skip the rest of the rule and its copies if
103                // there are no scalar subqueries this rule should rewrite
104                if !contains_scalar_subquery_to_rewrite(
105                    &filter.predicate,
106                    physical_uncorrelated,
107                ) {
108                    return Ok(Transformed::no(LogicalPlan::Filter(filter)));
109                }
110
111                let (subqueries, mut rewrite_expr) = self.extract_subquery_exprs(
112                    &filter.predicate,
113                    config.alias_generator(),
114                    physical_uncorrelated,
115                )?;
116
117                assert_or_internal_err!(
118                    !subqueries.is_empty(),
119                    "Expected subqueries not found in filter"
120                );
121
122                // iterate through all subqueries in predicate, turning each into a left join
123                let mut cur_input = filter.input.as_ref().clone();
124                for (subquery, alias) in subqueries {
125                    if let Some((optimized_subquery, compensation_exprs)) =
126                        build_join(&subquery, &cur_input, &alias)?
127                    {
128                        if !compensation_exprs.is_empty() {
129                            rewrite_expr = rewrite_expr
130                                .transform_up(|expr| {
131                                    if let Some(compensation_expr) = expr
132                                        .try_as_col()
133                                        .and_then(|col| compensation_exprs.get(col))
134                                    {
135                                        Ok(Transformed::yes(compensation_expr.clone()))
136                                    } else {
137                                        Ok(Transformed::no(expr))
138                                    }
139                                })
140                                .data()?;
141                        }
142                        cur_input = optimized_subquery;
143                    } else {
144                        // if we can't handle all of the subqueries then bail for now
145                        return Ok(Transformed::no(LogicalPlan::Filter(filter)));
146                    }
147                }
148
149                // Preserve original schema as new Join might have more fields than what Filter & parents expect.
150                let projection =
151                    filter.input.schema().columns().into_iter().map(Expr::from);
152                let new_plan = LogicalPlanBuilder::from(cur_input)
153                    .filter(rewrite_expr)?
154                    .project(projection)?
155                    .build()?;
156                Ok(Transformed::yes(new_plan))
157            }
158            LogicalPlan::Projection(projection) => {
159                let physical_uncorrelated = config
160                    .options()
161                    .optimizer
162                    .enable_physical_uncorrelated_scalar_subquery;
163                // Optimization: skip the rest of the rule and its copies if there
164                // are no scalar subqueries this rule should rewrite
165                if !projection.expr.iter().any(|expr| {
166                    contains_scalar_subquery_to_rewrite(expr, physical_uncorrelated)
167                }) {
168                    return Ok(Transformed::no(LogicalPlan::Projection(projection)));
169                }
170
171                let mut all_subqueries = vec![];
172                let mut alias_to_index: HashMap<String, usize> = HashMap::new();
173                let mut rewrite_exprs: Vec<Expr> =
174                    Vec::with_capacity(projection.expr.len());
175                for (idx, expr) in projection.expr.iter().enumerate() {
176                    let (subqueries, rewrite_expr) = self.extract_subquery_exprs(
177                        expr,
178                        config.alias_generator(),
179                        physical_uncorrelated,
180                    )?;
181                    for (_, alias) in &subqueries {
182                        alias_to_index.insert(alias.clone(), idx);
183                    }
184                    all_subqueries.extend(subqueries);
185                    rewrite_exprs.push(rewrite_expr);
186                }
187                assert_or_internal_err!(
188                    !all_subqueries.is_empty(),
189                    "Expected subqueries not found in projection"
190                );
191                // iterate through all subqueries in predicate, turning each into a left join
192                let mut cur_input = projection.input.as_ref().clone();
193                for (subquery, alias) in all_subqueries {
194                    if let Some((optimized_subquery, compensation_exprs)) =
195                        build_join(&subquery, &cur_input, &alias)?
196                    {
197                        cur_input = optimized_subquery;
198                        if !compensation_exprs.is_empty()
199                            && let Some(&idx) = alias_to_index.get(&alias)
200                        {
201                            let new_expr = rewrite_exprs[idx]
202                                .clone()
203                                .transform_up(|expr| {
204                                    if let Some(compensation_expr) = expr
205                                        .try_as_col()
206                                        .and_then(|col| compensation_exprs.get(col))
207                                    {
208                                        Ok(Transformed::yes(compensation_expr.clone()))
209                                    } else {
210                                        Ok(Transformed::no(expr))
211                                    }
212                                })
213                                .data()?;
214                            rewrite_exprs[idx] = new_expr;
215                        }
216                    } else {
217                        // if we can't handle all of the subqueries then bail for now
218                        return Ok(Transformed::no(LogicalPlan::Projection(projection)));
219                    }
220                }
221
222                let mut proj_exprs = vec![];
223                for (expr, new_expr) in projection.expr.iter().zip(rewrite_exprs) {
224                    let old_expr_name = expr.schema_name().to_string();
225                    let new_expr_name = new_expr.schema_name().to_string();
226                    if new_expr_name != old_expr_name {
227                        proj_exprs.push(new_expr.alias(old_expr_name))
228                    } else {
229                        proj_exprs.push(new_expr);
230                    }
231                }
232                let new_plan = LogicalPlanBuilder::from(cur_input)
233                    .project(proj_exprs)?
234                    .build()?;
235                Ok(Transformed::yes(new_plan))
236            }
237
238            plan => Ok(Transformed::no(plan)),
239        }
240    }
241
242    fn name(&self) -> &str {
243        "scalar_subquery_to_join"
244    }
245
246    fn apply_order(&self) -> Option<ApplyOrder> {
247        Some(ApplyOrder::TopDown)
248    }
249}
250
251/// Returns true if the expression contains a scalar subquery that this rule
252/// should rewrite to a join.
253///
254/// When `enable_physical_uncorrelated_scalar_subquery` is true (the default) only
255/// correlated scalar subqueries are rewritten — uncorrelated ones are handled
256/// by the physical planner via `ScalarSubqueryExec`. When it is false, all
257/// scalar subqueries (correlated and uncorrelated) are rewritten.
258fn contains_scalar_subquery_to_rewrite(expr: &Expr, physical_uncorrelated: bool) -> bool {
259    expr.exists(|expr| {
260        Ok(matches!(
261            expr,
262            Expr::ScalarSubquery(sq)
263                if !physical_uncorrelated || !sq.outer_ref_columns.is_empty()
264        ))
265    })
266    .expect("Inner is always Ok")
267}
268
269struct ExtractScalarSubQuery<'a> {
270    sub_query_info: Vec<(Subquery, String)>,
271    alias_gen: &'a Arc<AliasGenerator>,
272    physical_uncorrelated: bool,
273}
274
275impl TreeNodeRewriter for ExtractScalarSubQuery<'_> {
276    type Node = Expr;
277
278    fn f_down(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
279        match expr {
280            // Match scalar subqueries this rule should rewrite to a join. When
281            // `physical_uncorrelated` is true, only correlated subqueries are
282            // rewritten — uncorrelated ones are handled later by the physical
283            // planner. When false, both are rewritten.
284            Expr::ScalarSubquery(ref subquery)
285                if !self.physical_uncorrelated
286                    || !subquery.outer_ref_columns.is_empty() =>
287            {
288                let subquery = subquery.clone();
289                let scalar_expr = subquery
290                    .subquery
291                    .head_output_expr()?
292                    .map_or(plan_err!("single expression required."), Ok)?;
293                let subqry_alias = self.alias_gen.next("__scalar_sq");
294                let col =
295                    create_col_from_scalar_expr(&scalar_expr, subqry_alias.clone())?;
296                self.sub_query_info.push((subquery, subqry_alias));
297                Ok(Transformed::new(
298                    Expr::Column(col),
299                    true,
300                    TreeNodeRecursion::Jump,
301                ))
302            }
303            _ => Ok(Transformed::no(expr)),
304        }
305    }
306}
307
308/// Takes a query like:
309///
310/// ```text
311/// select id from customers where balance >
312///     (select avg(total) from orders where orders.c_id = customers.id)
313/// ```
314///
315/// and optimizes it into:
316///
317/// ```text
318/// select c.id from customers c
319/// left join (select c_id, avg(total) from orders group by c_id) o
320/// on o.c_id = c.id
321/// where c.balance > o."avg(total)"
322/// ```
323///
324/// When [`datafusion_common::config::OptimizerOptions::enable_physical_uncorrelated_scalar_subquery`] is
325/// false, this function also handles uncorrelated scalar subqueries, rewriting
326/// them as a `Left Join: Filter: Boolean(true)` instead of leaving them for
327/// `ScalarSubqueryExec`.
328///
329/// # Arguments
330///
331/// * `subquery` - The scalar subquery to rewrite (correlated, or uncorrelated
332///   when `enable_physical_uncorrelated_scalar_subquery` is false).
333/// * `outer_input` - The outer plan that the decorrelated subquery is
334///   left-joined onto — the input of the `Filter` or `Projection` node
335///   that contained the subquery.
336/// * `subquery_alias` - The unique alias assigned to the decorrelated
337///   subquery; used both to qualify the join condition and to produce
338///   column references for the caller to substitute.
339///
340/// Returns `Ok(None)` if the subquery cannot be decorrelated. On success,
341/// returns the rewritten outer plan and a map from each count-bug-affected
342/// column to its `CASE WHEN __always_true IS NULL THEN ... END` compensation
343/// expression, which the caller must substitute into any expression that
344/// references those columns.
345fn build_join(
346    subquery: &Subquery,
347    outer_input: &LogicalPlan,
348    subquery_alias: &str,
349) -> Result<Option<(LogicalPlan, HashMap<Column, Expr>)>> {
350    // `build_join` also handles uncorrelated scalar subqueries (as a left
351    // join with `Boolean(true)`) when the
352    // `enable_physical_uncorrelated_scalar_subquery` option is disabled.
353    let subquery_plan = subquery.subquery.as_ref();
354    let mut pull_up = PullUpCorrelatedExpr::new().with_need_handle_count_bug(true);
355    let decorrelated_subquery = subquery_plan.clone().rewrite(&mut pull_up).data()?;
356    if !pull_up.can_pull_up {
357        return Ok(None);
358    }
359
360    let collected_count_expr_map = pull_up
361        .collected_count_expr_map
362        .get(&decorrelated_subquery)
363        .cloned();
364    let aliased_subquery = LogicalPlanBuilder::from(decorrelated_subquery)
365        .alias(subquery_alias.to_string())?
366        .build()?;
367
368    let all_correlated_cols: BTreeSet<Column> = pull_up
369        .correlated_subquery_cols_map
370        .values()
371        .flatten()
372        .cloned()
373        .collect();
374
375    // Correlated columns now live in the decorrelated subquery's output,
376    // so re-qualify them with the subquery alias.
377    let join_filter_opt =
378        conjunction(pull_up.join_filters).map_or(Ok(None), |filter| {
379            replace_qualified_name(filter, &all_correlated_cols, subquery_alias).map(Some)
380        })?;
381
382    // When pull-up did not extract any usable join keys (a correlated subquery
383    // whose predicate references only outer columns), fall back to `ON true`:
384    // the decorrelated subquery still yields at most one row per outer row
385    // because its aggregate is grouped by the (empty) set of correlated inner
386    // columns.
387    let join_filter = join_filter_opt.or_else(|| Some(lit(true)));
388
389    let new_plan = LogicalPlanBuilder::from(outer_input.clone())
390        .join_on(aliased_subquery, JoinType::Left, join_filter)?
391        .build()?;
392
393    // Add count-bug compensation for each of the subquery's projected
394    // expressions that yield non-NULL values on empty input. We wrap each
395    // such expression in a CASE that substitutes the empty-input value
396    // when the LEFT JOIN produced synthetic right-side NULLs (no inner
397    // row matched), and uses the actual right-side value (which may
398    // itself be NULL) otherwise.
399    let mut compensation_exprs = HashMap::new();
400    if let Some(expr_map) = collected_count_expr_map {
401        let mut expr_rewrite = TypeCoercionRewriter {
402            schema: new_plan.schema(),
403        };
404        let having_arm = pull_up
405            .pull_up_having_expr
406            .as_ref()
407            .map(|f| (not(f.clone()), lit(ScalarValue::Null)));
408        for (name, result) in expr_map {
409            if evaluates_to_null(result.clone(), result.column_refs())? {
410                // Aggregates whose empty-input value is NULL (max/min/sum/…)
411                // need no compensation: the LEFT JOIN already produces NULL
412                // for unmatched outer rows.
413                continue;
414            }
415
416            let indicator_col =
417                Column::new(Some(subquery_alias), UN_MATCHED_ROW_INDICATOR);
418            // Qualify with the subquery alias to avoid ambiguity when the
419            // outer table has a column with the same name as the aggregate.
420            let value_col = Column::new(Some(subquery_alias), name);
421
422            let mut builder = when(Expr::Column(indicator_col).is_null(), result);
423            if let Some((when_expr, then_expr)) = &having_arm {
424                builder = builder.when(when_expr.clone(), then_expr.clone());
425            }
426            let compensation_expr = builder.otherwise(Expr::Column(value_col.clone()))?;
427            compensation_exprs.insert(
428                value_col,
429                compensation_expr.rewrite(&mut expr_rewrite).data()?,
430            );
431        }
432    }
433
434    Ok(Some((new_plan, compensation_exprs)))
435}
436
437#[cfg(test)]
438mod tests {
439    use std::ops::Add;
440
441    use super::*;
442    use crate::test::*;
443
444    use arrow::datatypes::DataType;
445    use datafusion_expr::test::function_stub::sum;
446
447    use crate::assert_optimized_plan_eq_display_indent_snapshot;
448    use datafusion_expr::{Between, col, expr, out_ref_col, scalar_subquery};
449    use datafusion_functions_aggregate::min_max::{max, min};
450
451    macro_rules! assert_optimized_plan_equal {
452        (
453            $plan:expr,
454            @ $expected:literal $(,)?
455        ) => {{
456            let rule: Arc<dyn crate::OptimizerRule + Send + Sync> = Arc::new(ScalarSubqueryToJoin::new());
457            assert_optimized_plan_eq_display_indent_snapshot!(
458                rule,
459                $plan,
460                @ $expected,
461            )
462        }};
463    }
464
465    /// Test multiple correlated subqueries
466    #[test]
467    fn multiple_subqueries() -> Result<()> {
468        let orders = Arc::new(
469            LogicalPlanBuilder::from(scan_tpch_table("orders"))
470                .filter(
471                    col("orders.o_custkey")
472                        .eq(out_ref_col(DataType::Int64, "customer.c_custkey")),
473                )?
474                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
475                .project(vec![max(col("orders.o_custkey"))])?
476                .build()?,
477        );
478
479        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
480            .filter(
481                lit(1)
482                    .lt(scalar_subquery(Arc::clone(&orders)))
483                    .and(lit(1).lt(scalar_subquery(orders))),
484            )?
485            .project(vec![col("customer.c_custkey")])?
486            .build()?;
487
488        assert_optimized_plan_equal!(
489            plan,
490            @r"
491        Projection: customer.c_custkey [c_custkey:Int64]
492          Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8]
493            Filter: Int32(1) < __scalar_sq_1.max(orders.o_custkey) AND Int32(1) < __scalar_sq_2.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
494              Left Join:  Filter: __scalar_sq_2.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
495                Left Join:  Filter: __scalar_sq_1.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
496                  TableScan: customer [c_custkey:Int64, c_name:Utf8]
497                  SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
498                    Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
499                      Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]
500                        TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
501                SubqueryAlias: __scalar_sq_2 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
502                  Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
503                    Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]
504                      TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
505        "
506        )
507    }
508
509    /// Test recursive correlated subqueries
510    #[test]
511    fn recursive_subqueries() -> Result<()> {
512        let lineitem = Arc::new(
513            LogicalPlanBuilder::from(scan_tpch_table("lineitem"))
514                .filter(
515                    col("lineitem.l_orderkey")
516                        .eq(out_ref_col(DataType::Int64, "orders.o_orderkey")),
517                )?
518                .aggregate(
519                    Vec::<Expr>::new(),
520                    vec![sum(col("lineitem.l_extendedprice"))],
521                )?
522                .project(vec![sum(col("lineitem.l_extendedprice"))])?
523                .build()?,
524        );
525
526        let orders = Arc::new(
527            LogicalPlanBuilder::from(scan_tpch_table("orders"))
528                .filter(
529                    col("orders.o_custkey")
530                        .eq(out_ref_col(DataType::Int64, "customer.c_custkey"))
531                        .and(col("orders.o_totalprice").lt(scalar_subquery(lineitem))),
532                )?
533                .aggregate(Vec::<Expr>::new(), vec![sum(col("orders.o_totalprice"))])?
534                .project(vec![sum(col("orders.o_totalprice"))])?
535                .build()?,
536        );
537
538        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
539            .filter(col("customer.c_acctbal").lt(scalar_subquery(orders)))?
540            .project(vec![col("customer.c_custkey")])?
541            .build()?;
542
543        assert_optimized_plan_equal!(
544            plan,
545            @r"
546        Projection: customer.c_custkey [c_custkey:Int64]
547          Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8]
548            Filter: customer.c_acctbal < __scalar_sq_1.sum(orders.o_totalprice) [c_custkey:Int64, c_name:Utf8, sum(orders.o_totalprice):Float64;N, o_custkey:Int64;N, __always_true:Boolean;N]
549              Left Join:  Filter: __scalar_sq_1.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8, sum(orders.o_totalprice):Float64;N, o_custkey:Int64;N, __always_true:Boolean;N]
550                TableScan: customer [c_custkey:Int64, c_name:Utf8]
551                SubqueryAlias: __scalar_sq_1 [sum(orders.o_totalprice):Float64;N, o_custkey:Int64, __always_true:Boolean]
552                  Projection: sum(orders.o_totalprice), orders.o_custkey, __always_true [sum(orders.o_totalprice):Float64;N, o_custkey:Int64, __always_true:Boolean]
553                    Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[sum(orders.o_totalprice)]] [o_custkey:Int64, __always_true:Boolean, sum(orders.o_totalprice):Float64;N]
554                      Projection: orders.o_orderkey, orders.o_custkey, orders.o_orderstatus, orders.o_totalprice [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
555                        Filter: orders.o_totalprice < __scalar_sq_2.sum(lineitem.l_extendedprice) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N, sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64;N, __always_true:Boolean;N]
556                          Left Join:  Filter: __scalar_sq_2.l_orderkey = orders.o_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N, sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64;N, __always_true:Boolean;N]
557                            TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
558                            SubqueryAlias: __scalar_sq_2 [sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64, __always_true:Boolean]
559                              Projection: sum(lineitem.l_extendedprice), lineitem.l_orderkey, __always_true [sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64, __always_true:Boolean]
560                                Aggregate: groupBy=[[lineitem.l_orderkey, Boolean(true) AS __always_true]], aggr=[[sum(lineitem.l_extendedprice)]] [l_orderkey:Int64, __always_true:Boolean, sum(lineitem.l_extendedprice):Float64;N]
561                                  TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]
562        "
563        )
564    }
565
566    /// Test for correlated scalar subquery filter with additional subquery filters
567    #[test]
568    fn scalar_subquery_with_subquery_filters() -> Result<()> {
569        let sq = Arc::new(
570            LogicalPlanBuilder::from(scan_tpch_table("orders"))
571                .filter(
572                    out_ref_col(DataType::Int64, "customer.c_custkey")
573                        .eq(col("orders.o_custkey"))
574                        .and(col("o_orderkey").eq(lit(1))),
575                )?
576                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
577                .project(vec![max(col("orders.o_custkey"))])?
578                .build()?,
579        );
580
581        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
582            .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
583            .project(vec![col("customer.c_custkey")])?
584            .build()?;
585
586        assert_optimized_plan_equal!(
587            plan,
588            @r"
589        Projection: customer.c_custkey [c_custkey:Int64]
590          Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8]
591            Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
592              Left Join:  Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
593                TableScan: customer [c_custkey:Int64, c_name:Utf8]
594                SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
595                  Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
596                    Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]
597                      Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
598                        TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
599        "
600        )
601    }
602
603    /// Test for correlated scalar subquery with no columns in schema
604    #[test]
605    fn scalar_subquery_no_cols() -> Result<()> {
606        let sq = Arc::new(
607            LogicalPlanBuilder::from(scan_tpch_table("orders"))
608                .filter(
609                    out_ref_col(DataType::Int64, "customer.c_custkey")
610                        .eq(out_ref_col(DataType::Int64, "customer.c_custkey")),
611                )?
612                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
613                .project(vec![max(col("orders.o_custkey"))])?
614                .build()?,
615        );
616
617        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
618            .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
619            .project(vec![col("customer.c_custkey")])?
620            .build()?;
621
622        // it will optimize, but fail for the same reason the unoptimized query would
623        assert_optimized_plan_equal!(
624            plan,
625            @r"
626        Projection: customer.c_custkey [c_custkey:Int64]
627          Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8]
628            Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]
629              Left Join:  Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]
630                TableScan: customer [c_custkey:Int64, c_name:Utf8]
631                SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N]
632                  Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]
633                    Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]
634                      TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
635        "
636        )
637    }
638
639    /// Test for scalar subquery with both columns in schema
640    #[test]
641    fn scalar_subquery_with_no_correlated_cols() -> Result<()> {
642        let sq = Arc::new(
643            LogicalPlanBuilder::from(scan_tpch_table("orders"))
644                .filter(col("orders.o_custkey").eq(col("orders.o_custkey")))?
645                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
646                .project(vec![max(col("orders.o_custkey"))])?
647                .build()?,
648        );
649
650        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
651            .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
652            .project(vec![col("customer.c_custkey")])?
653            .build()?;
654
655        assert_optimized_plan_equal!(
656            plan,
657            @r"
658        Projection: customer.c_custkey [c_custkey:Int64]
659          Filter: customer.c_custkey = (<subquery>) [c_custkey:Int64, c_name:Utf8]
660            Subquery: [max(orders.o_custkey):Int64;N]
661              Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]
662                Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]
663                  Filter: orders.o_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
664                    TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
665            TableScan: customer [c_custkey:Int64, c_name:Utf8]
666        "
667        )
668    }
669
670    /// Test for correlated scalar subquery not equal
671    #[test]
672    fn scalar_subquery_where_not_eq() -> Result<()> {
673        let sq = Arc::new(
674            LogicalPlanBuilder::from(scan_tpch_table("orders"))
675                .filter(
676                    out_ref_col(DataType::Int64, "customer.c_custkey")
677                        .not_eq(col("orders.o_custkey")),
678                )?
679                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
680                .project(vec![max(col("orders.o_custkey"))])?
681                .build()?,
682        );
683
684        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
685            .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
686            .project(vec![col("customer.c_custkey")])?
687            .build()?;
688
689        // Unsupported predicate, subquery should not be decorrelated
690        assert_optimized_plan_equal!(
691            plan,
692            @r"
693        Projection: customer.c_custkey [c_custkey:Int64]
694          Filter: customer.c_custkey = (<subquery>) [c_custkey:Int64, c_name:Utf8]
695            Subquery: [max(orders.o_custkey):Int64;N]
696              Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]
697                Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]
698                  Filter: outer_ref(customer.c_custkey) != orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
699                    TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
700            TableScan: customer [c_custkey:Int64, c_name:Utf8]
701        "
702        )
703    }
704
705    /// Test for correlated scalar subquery less than
706    #[test]
707    fn scalar_subquery_where_less_than() -> Result<()> {
708        let sq = Arc::new(
709            LogicalPlanBuilder::from(scan_tpch_table("orders"))
710                .filter(
711                    out_ref_col(DataType::Int64, "customer.c_custkey")
712                        .lt(col("orders.o_custkey")),
713                )?
714                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
715                .project(vec![max(col("orders.o_custkey"))])?
716                .build()?,
717        );
718
719        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
720            .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
721            .project(vec![col("customer.c_custkey")])?
722            .build()?;
723
724        // Unsupported predicate, subquery should not be decorrelated
725        assert_optimized_plan_equal!(
726            plan,
727            @r"
728        Projection: customer.c_custkey [c_custkey:Int64]
729          Filter: customer.c_custkey = (<subquery>) [c_custkey:Int64, c_name:Utf8]
730            Subquery: [max(orders.o_custkey):Int64;N]
731              Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]
732                Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]
733                  Filter: outer_ref(customer.c_custkey) < orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
734                    TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
735            TableScan: customer [c_custkey:Int64, c_name:Utf8]
736        "
737        )
738    }
739
740    /// Test for correlated scalar subquery filter with subquery disjunction
741    #[test]
742    fn scalar_subquery_with_subquery_disjunction() -> Result<()> {
743        let sq = Arc::new(
744            LogicalPlanBuilder::from(scan_tpch_table("orders"))
745                .filter(
746                    out_ref_col(DataType::Int64, "customer.c_custkey")
747                        .eq(col("orders.o_custkey"))
748                        .or(col("o_orderkey").eq(lit(1))),
749                )?
750                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
751                .project(vec![max(col("orders.o_custkey"))])?
752                .build()?,
753        );
754
755        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
756            .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
757            .project(vec![col("customer.c_custkey")])?
758            .build()?;
759
760        // Unsupported predicate, subquery should not be decorrelated
761        assert_optimized_plan_equal!(
762            plan,
763            @r"
764        Projection: customer.c_custkey [c_custkey:Int64]
765          Filter: customer.c_custkey = (<subquery>) [c_custkey:Int64, c_name:Utf8]
766            Subquery: [max(orders.o_custkey):Int64;N]
767              Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]
768                Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]
769                  Filter: outer_ref(customer.c_custkey) = orders.o_custkey OR orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
770                    TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
771            TableScan: customer [c_custkey:Int64, c_name:Utf8]
772        "
773        )
774    }
775
776    /// Test for correlated scalar without projection
777    #[test]
778    fn scalar_subquery_no_projection() -> Result<()> {
779        let sq = Arc::new(
780            LogicalPlanBuilder::from(scan_tpch_table("orders"))
781                .filter(col("customer.c_custkey").eq(col("orders.o_custkey")))?
782                .build()?,
783        );
784
785        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
786            .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
787            .project(vec![col("customer.c_custkey")])?
788            .build()?;
789
790        let expected = "Error during planning: Scalar subquery should only return one column, but found 4: orders.o_orderkey, orders.o_custkey, orders.o_orderstatus, orders.o_totalprice";
791        assert_analyzer_check_err(vec![], plan, expected);
792        Ok(())
793    }
794
795    /// Test for correlated scalar expressions
796    #[test]
797    fn scalar_subquery_project_expr() -> Result<()> {
798        let sq = Arc::new(
799            LogicalPlanBuilder::from(scan_tpch_table("orders"))
800                .filter(
801                    out_ref_col(DataType::Int64, "customer.c_custkey")
802                        .eq(col("orders.o_custkey")),
803                )?
804                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
805                .project(vec![col("max(orders.o_custkey)").add(lit(1))])?
806                .build()?,
807        );
808
809        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
810            .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
811            .project(vec![col("customer.c_custkey")])?
812            .build()?;
813
814        assert_optimized_plan_equal!(
815            plan,
816            @r"
817        Projection: customer.c_custkey [c_custkey:Int64]
818          Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8]
819            Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) + Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
820              Left Join:  Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
821                TableScan: customer [c_custkey:Int64, c_name:Utf8]
822                SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64, __always_true:Boolean]
823                  Projection: max(orders.o_custkey) + Int32(1), orders.o_custkey, __always_true [max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64, __always_true:Boolean]
824                    Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]
825                      TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
826        "
827        )
828    }
829
830    /// Test for correlated scalar subquery with non-strong project
831    #[test]
832    fn scalar_subquery_with_non_strong_project() -> Result<()> {
833        let case = Expr::Case(expr::Case {
834            expr: None,
835            when_then_expr: vec![(
836                Box::new(col("max(orders.o_totalprice)")),
837                Box::new(lit("a")),
838            )],
839            else_expr: Some(Box::new(lit("b"))),
840        });
841
842        let sq = Arc::new(
843            LogicalPlanBuilder::from(scan_tpch_table("orders"))
844                .filter(
845                    out_ref_col(DataType::Int64, "customer.c_custkey")
846                        .eq(col("orders.o_custkey")),
847                )?
848                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_totalprice"))])?
849                .project(vec![case])?
850                .build()?,
851        );
852
853        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
854            .project(vec![col("customer.c_custkey"), scalar_subquery(sq)])?
855            .build()?;
856
857        assert_optimized_plan_equal!(
858            plan,
859            @r#"
860        Projection: customer.c_custkey, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN CASE WHEN CAST(NULL AS Boolean) THEN Utf8("a") ELSE Utf8("b") END ELSE __scalar_sq_1.CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END END AS CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END [c_custkey:Int64, CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END:Utf8;N]
861          Left Join:  Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END:Utf8;N, o_custkey:Int64;N, __always_true:Boolean;N]
862            TableScan: customer [c_custkey:Int64, c_name:Utf8]
863            SubqueryAlias: __scalar_sq_1 [CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END:Utf8, o_custkey:Int64, __always_true:Boolean]
864              Projection: CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END, orders.o_custkey, __always_true [CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END:Utf8, o_custkey:Int64, __always_true:Boolean]
865                Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_totalprice)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_totalprice):Float64;N]
866                  TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
867        "#
868        )
869    }
870
871    /// Test for correlated scalar subquery multiple projected columns
872    #[test]
873    fn scalar_subquery_multi_col() -> Result<()> {
874        let sq = Arc::new(
875            LogicalPlanBuilder::from(scan_tpch_table("orders"))
876                .filter(col("customer.c_custkey").eq(col("orders.o_custkey")))?
877                .project(vec![col("orders.o_custkey"), col("orders.o_orderkey")])?
878                .build()?,
879        );
880
881        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
882            .filter(
883                col("customer.c_custkey")
884                    .eq(scalar_subquery(sq))
885                    .and(col("c_custkey").eq(lit(1))),
886            )?
887            .project(vec![col("customer.c_custkey")])?
888            .build()?;
889
890        let expected = "Error during planning: Scalar subquery should only return one column, but found 2: orders.o_custkey, orders.o_orderkey";
891        assert_analyzer_check_err(vec![], plan, expected);
892        Ok(())
893    }
894
895    /// Test for correlated scalar subquery filter with additional filters
896    #[test]
897    fn scalar_subquery_additional_filters_with_non_equal_clause() -> Result<()> {
898        let sq = Arc::new(
899            LogicalPlanBuilder::from(scan_tpch_table("orders"))
900                .filter(
901                    out_ref_col(DataType::Int64, "customer.c_custkey")
902                        .eq(col("orders.o_custkey")),
903                )?
904                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
905                .project(vec![max(col("orders.o_custkey"))])?
906                .build()?,
907        );
908
909        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
910            .filter(
911                col("customer.c_custkey")
912                    .gt_eq(scalar_subquery(sq))
913                    .and(col("c_custkey").eq(lit(1))),
914            )?
915            .project(vec![col("customer.c_custkey")])?
916            .build()?;
917
918        assert_optimized_plan_equal!(
919            plan,
920            @r"
921        Projection: customer.c_custkey [c_custkey:Int64]
922          Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8]
923            Filter: customer.c_custkey >= __scalar_sq_1.max(orders.o_custkey) AND customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
924              Left Join:  Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
925                TableScan: customer [c_custkey:Int64, c_name:Utf8]
926                SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
927                  Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
928                    Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]
929                      TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
930        "
931        )
932    }
933
934    #[test]
935    fn scalar_subquery_additional_filters_with_equal_clause() -> Result<()> {
936        let sq = Arc::new(
937            LogicalPlanBuilder::from(scan_tpch_table("orders"))
938                .filter(
939                    out_ref_col(DataType::Int64, "customer.c_custkey")
940                        .eq(col("orders.o_custkey")),
941                )?
942                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
943                .project(vec![max(col("orders.o_custkey"))])?
944                .build()?,
945        );
946
947        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
948            .filter(
949                col("customer.c_custkey")
950                    .eq(scalar_subquery(sq))
951                    .and(col("c_custkey").eq(lit(1))),
952            )?
953            .project(vec![col("customer.c_custkey")])?
954            .build()?;
955
956        assert_optimized_plan_equal!(
957            plan,
958            @r"
959        Projection: customer.c_custkey [c_custkey:Int64]
960          Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8]
961            Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) AND customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
962              Left Join:  Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
963                TableScan: customer [c_custkey:Int64, c_name:Utf8]
964                SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
965                  Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
966                    Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]
967                      TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
968        "
969        )
970    }
971
972    /// Test for correlated scalar subquery filter with disjunctions
973    #[test]
974    fn scalar_subquery_disjunction() -> Result<()> {
975        let sq = Arc::new(
976            LogicalPlanBuilder::from(scan_tpch_table("orders"))
977                .filter(
978                    out_ref_col(DataType::Int64, "customer.c_custkey")
979                        .eq(col("orders.o_custkey")),
980                )?
981                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
982                .project(vec![max(col("orders.o_custkey"))])?
983                .build()?,
984        );
985
986        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
987            .filter(
988                col("customer.c_custkey")
989                    .eq(scalar_subquery(sq))
990                    .or(col("customer.c_custkey").eq(lit(1))),
991            )?
992            .project(vec![col("customer.c_custkey")])?
993            .build()?;
994
995        assert_optimized_plan_equal!(
996            plan,
997            @r"
998        Projection: customer.c_custkey [c_custkey:Int64]
999          Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8]
1000            Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) OR customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
1001              Left Join:  Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
1002                TableScan: customer [c_custkey:Int64, c_name:Utf8]
1003                SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
1004                  Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
1005                    Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]
1006                      TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1007        "
1008        )
1009    }
1010
1011    /// Test for correlated scalar subquery filter
1012    #[test]
1013    fn exists_subquery_correlated() -> Result<()> {
1014        let sq = Arc::new(
1015            LogicalPlanBuilder::from(test_table_scan_with_name("sq")?)
1016                .filter(out_ref_col(DataType::UInt32, "test.a").eq(col("sq.a")))?
1017                .aggregate(Vec::<Expr>::new(), vec![min(col("c"))])?
1018                .project(vec![min(col("c"))])?
1019                .build()?,
1020        );
1021
1022        let plan = LogicalPlanBuilder::from(test_table_scan_with_name("test")?)
1023            .filter(col("test.c").lt(scalar_subquery(sq)))?
1024            .project(vec![col("test.c")])?
1025            .build()?;
1026
1027        assert_optimized_plan_equal!(
1028            plan,
1029            @r"
1030        Projection: test.c [c:UInt32]
1031          Projection: test.a, test.b, test.c [a:UInt32, b:UInt32, c:UInt32]
1032            Filter: test.c < __scalar_sq_1.min(sq.c) [a:UInt32, b:UInt32, c:UInt32, min(sq.c):UInt32;N, a:UInt32;N, __always_true:Boolean;N]
1033              Left Join:  Filter: test.a = __scalar_sq_1.a [a:UInt32, b:UInt32, c:UInt32, min(sq.c):UInt32;N, a:UInt32;N, __always_true:Boolean;N]
1034                TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1035                SubqueryAlias: __scalar_sq_1 [min(sq.c):UInt32;N, a:UInt32, __always_true:Boolean]
1036                  Projection: min(sq.c), sq.a, __always_true [min(sq.c):UInt32;N, a:UInt32, __always_true:Boolean]
1037                    Aggregate: groupBy=[[sq.a, Boolean(true) AS __always_true]], aggr=[[min(sq.c)]] [a:UInt32, __always_true:Boolean, min(sq.c):UInt32;N]
1038                      TableScan: sq [a:UInt32, b:UInt32, c:UInt32]
1039        "
1040        )
1041    }
1042
1043    /// Test for non-correlated scalar subquery with no filters
1044    #[test]
1045    fn scalar_subquery_non_correlated_no_filters_with_non_equal_clause() -> Result<()> {
1046        let sq = Arc::new(
1047            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1048                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
1049                .project(vec![max(col("orders.o_custkey"))])?
1050                .build()?,
1051        );
1052
1053        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
1054            .filter(col("customer.c_custkey").lt(scalar_subquery(sq)))?
1055            .project(vec![col("customer.c_custkey")])?
1056            .build()?;
1057
1058        assert_optimized_plan_equal!(
1059            plan,
1060            @r"
1061        Projection: customer.c_custkey [c_custkey:Int64]
1062          Filter: customer.c_custkey < (<subquery>) [c_custkey:Int64, c_name:Utf8]
1063            Subquery: [max(orders.o_custkey):Int64;N]
1064              Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]
1065                Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]
1066                  TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1067            TableScan: customer [c_custkey:Int64, c_name:Utf8]
1068        "
1069        )
1070    }
1071
1072    #[test]
1073    fn scalar_subquery_non_correlated_no_filters_with_equal_clause() -> Result<()> {
1074        let sq = Arc::new(
1075            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1076                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
1077                .project(vec![max(col("orders.o_custkey"))])?
1078                .build()?,
1079        );
1080
1081        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
1082            .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
1083            .project(vec![col("customer.c_custkey")])?
1084            .build()?;
1085
1086        assert_optimized_plan_equal!(
1087            plan,
1088            @r"
1089        Projection: customer.c_custkey [c_custkey:Int64]
1090          Filter: customer.c_custkey = (<subquery>) [c_custkey:Int64, c_name:Utf8]
1091            Subquery: [max(orders.o_custkey):Int64;N]
1092              Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]
1093                Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]
1094                  TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1095            TableScan: customer [c_custkey:Int64, c_name:Utf8]
1096        "
1097        )
1098    }
1099
1100    #[test]
1101    fn correlated_scalar_subquery_in_between_clause() -> Result<()> {
1102        let sq1 = Arc::new(
1103            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1104                .filter(
1105                    out_ref_col(DataType::Int64, "customer.c_custkey")
1106                        .eq(col("orders.o_custkey")),
1107                )?
1108                .aggregate(Vec::<Expr>::new(), vec![min(col("orders.o_custkey"))])?
1109                .project(vec![min(col("orders.o_custkey"))])?
1110                .build()?,
1111        );
1112        let sq2 = Arc::new(
1113            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1114                .filter(
1115                    out_ref_col(DataType::Int64, "customer.c_custkey")
1116                        .eq(col("orders.o_custkey")),
1117                )?
1118                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
1119                .project(vec![max(col("orders.o_custkey"))])?
1120                .build()?,
1121        );
1122
1123        let between_expr = Expr::Between(Between {
1124            expr: Box::new(col("customer.c_custkey")),
1125            negated: false,
1126            low: Box::new(scalar_subquery(sq1)),
1127            high: Box::new(scalar_subquery(sq2)),
1128        });
1129
1130        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
1131            .filter(between_expr)?
1132            .project(vec![col("customer.c_custkey")])?
1133            .build()?;
1134
1135        assert_optimized_plan_equal!(
1136            plan,
1137            @r"
1138        Projection: customer.c_custkey [c_custkey:Int64]
1139          Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8]
1140            Filter: customer.c_custkey BETWEEN __scalar_sq_1.min(orders.o_custkey) AND __scalar_sq_2.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
1141              Left Join:  Filter: customer.c_custkey = __scalar_sq_2.o_custkey [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
1142                Left Join:  Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
1143                  TableScan: customer [c_custkey:Int64, c_name:Utf8]
1144                  SubqueryAlias: __scalar_sq_1 [min(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
1145                    Projection: min(orders.o_custkey), orders.o_custkey, __always_true [min(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
1146                      Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[min(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, min(orders.o_custkey):Int64;N]
1147                        TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1148                SubqueryAlias: __scalar_sq_2 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
1149                  Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
1150                    Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]
1151                      TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1152        "
1153        )
1154    }
1155
1156    #[test]
1157    fn uncorrelated_scalar_subquery_in_between_clause() -> Result<()> {
1158        let sq1 = Arc::new(
1159            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1160                .aggregate(Vec::<Expr>::new(), vec![min(col("orders.o_custkey"))])?
1161                .project(vec![min(col("orders.o_custkey"))])?
1162                .build()?,
1163        );
1164        let sq2 = Arc::new(
1165            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1166                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
1167                .project(vec![max(col("orders.o_custkey"))])?
1168                .build()?,
1169        );
1170
1171        let between_expr = Expr::Between(Between {
1172            expr: Box::new(col("customer.c_custkey")),
1173            negated: false,
1174            low: Box::new(scalar_subquery(sq1)),
1175            high: Box::new(scalar_subquery(sq2)),
1176        });
1177
1178        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
1179            .filter(between_expr)?
1180            .project(vec![col("customer.c_custkey")])?
1181            .build()?;
1182
1183        assert_optimized_plan_equal!(
1184            plan,
1185            @r"
1186        Projection: customer.c_custkey [c_custkey:Int64]
1187          Filter: customer.c_custkey BETWEEN (<subquery>) AND (<subquery>) [c_custkey:Int64, c_name:Utf8]
1188            Subquery: [min(orders.o_custkey):Int64;N]
1189              Projection: min(orders.o_custkey) [min(orders.o_custkey):Int64;N]
1190                Aggregate: groupBy=[[]], aggr=[[min(orders.o_custkey)]] [min(orders.o_custkey):Int64;N]
1191                  TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1192            Subquery: [max(orders.o_custkey):Int64;N]
1193              Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]
1194                Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]
1195                  TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1196            TableScan: customer [c_custkey:Int64, c_name:Utf8]
1197        "
1198        )
1199    }
1200
1201    #[test]
1202    fn uncorrelated_scalar_subquery_rewritten_when_flag_off() -> Result<()> {
1203        use datafusion_common::config::ConfigOptions;
1204
1205        let sq = Arc::new(
1206            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1207                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
1208                .project(vec![max(col("orders.o_custkey"))])?
1209                .build()?,
1210        );
1211
1212        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
1213            .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
1214            .project(vec![col("customer.c_custkey")])?
1215            .build()?;
1216
1217        let mut options = ConfigOptions::default();
1218        options
1219            .optimizer
1220            .enable_physical_uncorrelated_scalar_subquery = false;
1221        let context = crate::OptimizerContext::new_with_config_options(Arc::new(options));
1222
1223        let rule: Arc<dyn OptimizerRule + Send + Sync> =
1224            Arc::new(ScalarSubqueryToJoin::new());
1225        let optimizer = crate::Optimizer::with_rules(vec![rule]);
1226        let optimized_plan = optimizer
1227            .optimize(plan, &context, |_, _| {})
1228            .expect("failed to optimize plan");
1229        let formatted_plan = optimized_plan.display_indent_schema();
1230
1231        insta::assert_snapshot!(
1232            formatted_plan,
1233            @r"
1234        Projection: customer.c_custkey [c_custkey:Int64]
1235          Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8]
1236            Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]
1237              Left Join:  Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]
1238                TableScan: customer [c_custkey:Int64, c_name:Utf8]
1239                SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N]
1240                  Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]
1241                    Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]
1242                      TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1243        "
1244        );
1245
1246        Ok(())
1247    }
1248}