datafusion_optimizer/
scalar_subquery_to_join.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! [`ScalarSubqueryToJoin`] rewrites scalar subqueries in filters and projections into `LEFT JOIN`s
19
20use std::collections::{BTreeSet, HashMap};
21use std::sync::Arc;
22
23use crate::decorrelate::{PullUpCorrelatedExpr, UN_MATCHED_ROW_INDICATOR};
24use crate::optimizer::ApplyOrder;
25use crate::utils::{evaluates_to_null, replace_qualified_name};
26use crate::{OptimizerConfig, OptimizerRule};
27
28use crate::analyzer::type_coercion::TypeCoercionRewriter;
29use datafusion_common::alias::AliasGenerator;
30use datafusion_common::tree_node::{
31    Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter,
32};
33use datafusion_common::{Column, Result, ScalarValue, assert_or_internal_err, plan_err};
34use datafusion_expr::expr_rewriter::create_col_from_scalar_expr;
35use datafusion_expr::logical_plan::{JoinType, Subquery};
36use datafusion_expr::utils::conjunction;
37use datafusion_expr::{EmptyRelation, Expr, LogicalPlan, LogicalPlanBuilder, expr};
38
/// Optimizer rule that rewrites scalar subqueries into left joins.
///
/// Scalar subqueries found in a `Filter` predicate or in `Projection`
/// expressions are decorrelated and joined onto the node's input. A projection
/// is then placed on top of the rewritten node so that its output schema matches
/// the original one.
#[derive(Default, Debug)]
pub struct ScalarSubqueryToJoin {}
44
45impl ScalarSubqueryToJoin {
46    #[expect(missing_docs)]
47    pub fn new() -> Self {
48        Self::default()
49    }
50
51    /// Finds expressions that have a scalar subquery in them (and recurses when found)
52    ///
53    /// # Arguments
54    /// * `predicate` - A conjunction to split and search
55    ///
56    /// Returns a tuple (subqueries, alias)
57    fn extract_subquery_exprs(
58        &self,
59        predicate: &Expr,
60        alias_gen: &Arc<AliasGenerator>,
61    ) -> Result<(Vec<(Subquery, String)>, Expr)> {
62        let mut extract = ExtractScalarSubQuery {
63            sub_query_info: vec![],
64            alias_gen,
65        };
66        predicate
67            .clone()
68            .rewrite(&mut extract)
69            .data()
70            .map(|new_expr| (extract.sub_query_info, new_expr))
71    }
72}
73
74impl OptimizerRule for ScalarSubqueryToJoin {
75    fn supports_rewrite(&self) -> bool {
76        true
77    }
78
79    fn rewrite(
80        &self,
81        plan: LogicalPlan,
82        config: &dyn OptimizerConfig,
83    ) -> Result<Transformed<LogicalPlan>> {
84        match plan {
85            LogicalPlan::Filter(filter) => {
86                // Optimization: skip the rest of the rule and its copies if
87                // there are no scalar subqueries
88                if !contains_scalar_subquery(&filter.predicate) {
89                    return Ok(Transformed::no(LogicalPlan::Filter(filter)));
90                }
91
92                let (subqueries, mut rewrite_expr) = self.extract_subquery_exprs(
93                    &filter.predicate,
94                    config.alias_generator(),
95                )?;
96
97                assert_or_internal_err!(
98                    !subqueries.is_empty(),
99                    "Expected subqueries not found in filter"
100                );
101
102                // iterate through all subqueries in predicate, turning each into a left join
103                let mut cur_input = filter.input.as_ref().clone();
104                for (subquery, alias) in subqueries {
105                    if let Some((optimized_subquery, expr_check_map)) =
106                        build_join(&subquery, &cur_input, &alias)?
107                    {
108                        if !expr_check_map.is_empty() {
109                            rewrite_expr = rewrite_expr
110                                .transform_up(|expr| {
111                                    // replace column references with entry in map, if it exists
112                                    if let Some(map_expr) = expr
113                                        .try_as_col()
114                                        .and_then(|col| expr_check_map.get(&col.name))
115                                    {
116                                        Ok(Transformed::yes(map_expr.clone()))
117                                    } else {
118                                        Ok(Transformed::no(expr))
119                                    }
120                                })
121                                .data()?;
122                        }
123                        cur_input = optimized_subquery;
124                    } else {
125                        // if we can't handle all of the subqueries then bail for now
126                        return Ok(Transformed::no(LogicalPlan::Filter(filter)));
127                    }
128                }
129
130                // Preserve original schema as new Join might have more fields than what Filter & parents expect.
131                let projection =
132                    filter.input.schema().columns().into_iter().map(Expr::from);
133                let new_plan = LogicalPlanBuilder::from(cur_input)
134                    .filter(rewrite_expr)?
135                    .project(projection)?
136                    .build()?;
137                Ok(Transformed::yes(new_plan))
138            }
139            LogicalPlan::Projection(projection) => {
140                // Optimization: skip the rest of the rule and its copies if
141                // there are no scalar subqueries
142                if !projection.expr.iter().any(contains_scalar_subquery) {
143                    return Ok(Transformed::no(LogicalPlan::Projection(projection)));
144                }
145
146                let mut all_subqueries = vec![];
147                let mut expr_to_rewrite_expr_map = HashMap::new();
148                let mut subquery_to_expr_map = HashMap::new();
149                for expr in projection.expr.iter() {
150                    let (subqueries, rewrite_exprs) =
151                        self.extract_subquery_exprs(expr, config.alias_generator())?;
152                    for (subquery, _) in &subqueries {
153                        subquery_to_expr_map.insert(subquery.clone(), expr.clone());
154                    }
155                    all_subqueries.extend(subqueries);
156                    expr_to_rewrite_expr_map.insert(expr, rewrite_exprs);
157                }
158                assert_or_internal_err!(
159                    !all_subqueries.is_empty(),
160                    "Expected subqueries not found in projection"
161                );
162                // iterate through all subqueries in predicate, turning each into a left join
163                let mut cur_input = projection.input.as_ref().clone();
164                for (subquery, alias) in all_subqueries {
165                    if let Some((optimized_subquery, expr_check_map)) =
166                        build_join(&subquery, &cur_input, &alias)?
167                    {
168                        cur_input = optimized_subquery;
169                        if !expr_check_map.is_empty()
170                            && let Some(expr) = subquery_to_expr_map.get(&subquery)
171                            && let Some(rewrite_expr) = expr_to_rewrite_expr_map.get(expr)
172                        {
173                            let new_expr = rewrite_expr
174                                .clone()
175                                .transform_up(|expr| {
176                                    // replace column references with entry in map, if it exists
177                                    if let Some(map_expr) = expr
178                                        .try_as_col()
179                                        .and_then(|col| expr_check_map.get(&col.name))
180                                    {
181                                        Ok(Transformed::yes(map_expr.clone()))
182                                    } else {
183                                        Ok(Transformed::no(expr))
184                                    }
185                                })
186                                .data()?;
187                            expr_to_rewrite_expr_map.insert(expr, new_expr);
188                        }
189                    } else {
190                        // if we can't handle all of the subqueries then bail for now
191                        return Ok(Transformed::no(LogicalPlan::Projection(projection)));
192                    }
193                }
194
195                let mut proj_exprs = vec![];
196                for expr in projection.expr.iter() {
197                    let old_expr_name = expr.schema_name().to_string();
198                    let new_expr = expr_to_rewrite_expr_map.get(expr).unwrap();
199                    let new_expr_name = new_expr.schema_name().to_string();
200                    if new_expr_name != old_expr_name {
201                        proj_exprs.push(new_expr.clone().alias(old_expr_name))
202                    } else {
203                        proj_exprs.push(new_expr.clone());
204                    }
205                }
206                let new_plan = LogicalPlanBuilder::from(cur_input)
207                    .project(proj_exprs)?
208                    .build()?;
209                Ok(Transformed::yes(new_plan))
210            }
211
212            plan => Ok(Transformed::no(plan)),
213        }
214    }
215
216    fn name(&self) -> &str {
217        "scalar_subquery_to_join"
218    }
219
220    fn apply_order(&self) -> Option<ApplyOrder> {
221        Some(ApplyOrder::TopDown)
222    }
223}
224
225/// Returns true if the expression has a scalar subquery somewhere in it
226/// false otherwise
227fn contains_scalar_subquery(expr: &Expr) -> bool {
228    expr.exists(|expr| Ok(matches!(expr, Expr::ScalarSubquery(_))))
229        .expect("Inner is always Ok")
230}
231
232struct ExtractScalarSubQuery<'a> {
233    sub_query_info: Vec<(Subquery, String)>,
234    alias_gen: &'a Arc<AliasGenerator>,
235}
236
237impl TreeNodeRewriter for ExtractScalarSubQuery<'_> {
238    type Node = Expr;
239
240    fn f_down(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
241        match expr {
242            Expr::ScalarSubquery(subquery) => {
243                let subqry_alias = self.alias_gen.next("__scalar_sq");
244                self.sub_query_info
245                    .push((subquery.clone(), subqry_alias.clone()));
246                let scalar_expr = subquery
247                    .subquery
248                    .head_output_expr()?
249                    .map_or(plan_err!("single expression required."), Ok)?;
250                Ok(Transformed::new(
251                    Expr::Column(create_col_from_scalar_expr(
252                        &scalar_expr,
253                        subqry_alias,
254                    )?),
255                    true,
256                    TreeNodeRecursion::Jump,
257                ))
258            }
259            _ => Ok(Transformed::no(expr)),
260        }
261    }
262}
263
264/// Takes a query like:
265///
266/// ```text
267/// select id from customers where balance >
268///     (select avg(total) from orders where orders.c_id = customers.id)
269/// ```
270///
271/// and optimizes it into:
272///
273/// ```text
274/// select c.id from customers c
275/// left join (select c_id, avg(total) as val from orders group by c_id) o on o.c_id = c.c_id
276/// where c.balance > o.val
277/// ```
278///
279/// Or a query like:
280///
281/// ```text
282/// select id from customers where balance >
283///     (select avg(total) from orders)
284/// ```
285///
286/// and optimizes it into:
287///
288/// ```text
289/// select c.id from customers c
290/// left join (select avg(total) as val from orders) a
291/// where c.balance > a.val
292/// ```
293///
294/// # Arguments
295///
296/// * `query_info` - The subquery portion of the `where` (select avg(total) from orders)
297/// * `filter_input` - The non-subquery portion (from customers)
298/// * `outer_others` - Any additional parts to the `where` expression (and c.x = y)
299/// * `subquery_alias` - Subquery aliases
300fn build_join(
301    subquery: &Subquery,
302    filter_input: &LogicalPlan,
303    subquery_alias: &str,
304) -> Result<Option<(LogicalPlan, HashMap<String, Expr>)>> {
305    let subquery_plan = subquery.subquery.as_ref();
306    let mut pull_up = PullUpCorrelatedExpr::new().with_need_handle_count_bug(true);
307    let new_plan = subquery_plan.clone().rewrite(&mut pull_up).data()?;
308    if !pull_up.can_pull_up {
309        return Ok(None);
310    }
311
312    let collected_count_expr_map =
313        pull_up.collected_count_expr_map.get(&new_plan).cloned();
314    let sub_query_alias = LogicalPlanBuilder::from(new_plan)
315        .alias(subquery_alias.to_string())?
316        .build()?;
317
318    let mut all_correlated_cols = BTreeSet::new();
319    pull_up
320        .correlated_subquery_cols_map
321        .values()
322        .for_each(|cols| all_correlated_cols.extend(cols.clone()));
323
324    // alias the join filter
325    let join_filter_opt =
326        conjunction(pull_up.join_filters).map_or(Ok(None), |filter| {
327            replace_qualified_name(filter, &all_correlated_cols, subquery_alias).map(Some)
328        })?;
329
330    // join our sub query into the main plan
331    let new_plan = if join_filter_opt.is_none() {
332        match filter_input {
333            LogicalPlan::EmptyRelation(EmptyRelation {
334                produce_one_row: true,
335                schema: _,
336            }) => sub_query_alias,
337            _ => {
338                // if not correlated, group down to 1 row and left join on that (preserving row count)
339                LogicalPlanBuilder::from(filter_input.clone())
340                    .join_on(
341                        sub_query_alias,
342                        JoinType::Left,
343                        vec![Expr::Literal(ScalarValue::Boolean(Some(true)), None)],
344                    )?
345                    .build()?
346            }
347        }
348    } else {
349        // left join if correlated, grouping by the join keys so we don't change row count
350        LogicalPlanBuilder::from(filter_input.clone())
351            .join_on(sub_query_alias, JoinType::Left, join_filter_opt)?
352            .build()?
353    };
354    let mut computation_project_expr = HashMap::new();
355    if let Some(expr_map) = collected_count_expr_map {
356        for (name, result) in expr_map {
357            if evaluates_to_null(result.clone(), result.column_refs())? {
358                // If expr always returns null when column is null, skip processing
359                continue;
360            }
361            let computer_expr = if let Some(filter) = &pull_up.pull_up_having_expr {
362                Expr::Case(expr::Case {
363                    expr: None,
364                    when_then_expr: vec![
365                        (
366                            Box::new(Expr::IsNull(Box::new(Expr::Column(
367                                Column::new_unqualified(UN_MATCHED_ROW_INDICATOR),
368                            )))),
369                            Box::new(result),
370                        ),
371                        (
372                            Box::new(Expr::Not(Box::new(filter.clone()))),
373                            Box::new(Expr::Literal(ScalarValue::Null, None)),
374                        ),
375                    ],
376                    else_expr: Some(Box::new(Expr::Column(Column::new_unqualified(
377                        name.clone(),
378                    )))),
379                })
380            } else {
381                Expr::Case(expr::Case {
382                    expr: None,
383                    when_then_expr: vec![(
384                        Box::new(Expr::IsNull(Box::new(Expr::Column(
385                            Column::new_unqualified(UN_MATCHED_ROW_INDICATOR),
386                        )))),
387                        Box::new(result),
388                    )],
389                    else_expr: Some(Box::new(Expr::Column(Column::new_unqualified(
390                        name.clone(),
391                    )))),
392                })
393            };
394            let mut expr_rewrite = TypeCoercionRewriter {
395                schema: new_plan.schema(),
396            };
397            computation_project_expr
398                .insert(name, computer_expr.rewrite(&mut expr_rewrite).data()?);
399        }
400    }
401
402    Ok(Some((new_plan, computation_project_expr)))
403}
404
405#[cfg(test)]
406mod tests {
407    use std::ops::Add;
408
409    use super::*;
410    use crate::test::*;
411
412    use arrow::datatypes::DataType;
413    use datafusion_expr::test::function_stub::sum;
414
415    use crate::assert_optimized_plan_eq_display_indent_snapshot;
416    use datafusion_expr::{Between, col, lit, out_ref_col, scalar_subquery};
417    use datafusion_functions_aggregate::min_max::{max, min};
418
419    macro_rules! assert_optimized_plan_equal {
420        (
421            $plan:expr,
422            @ $expected:literal $(,)?
423        ) => {{
424            let rule: Arc<dyn crate::OptimizerRule + Send + Sync> = Arc::new(ScalarSubqueryToJoin::new());
425            assert_optimized_plan_eq_display_indent_snapshot!(
426                rule,
427                $plan,
428                @ $expected,
429            )
430        }};
431    }
432
433    /// Test multiple correlated subqueries
434    #[test]
435    fn multiple_subqueries() -> Result<()> {
436        let orders = Arc::new(
437            LogicalPlanBuilder::from(scan_tpch_table("orders"))
438                .filter(
439                    col("orders.o_custkey")
440                        .eq(out_ref_col(DataType::Int64, "customer.c_custkey")),
441                )?
442                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
443                .project(vec![max(col("orders.o_custkey"))])?
444                .build()?,
445        );
446
447        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
448            .filter(
449                lit(1)
450                    .lt(scalar_subquery(Arc::clone(&orders)))
451                    .and(lit(1).lt(scalar_subquery(orders))),
452            )?
453            .project(vec![col("customer.c_custkey")])?
454            .build()?;
455
456        assert_optimized_plan_equal!(
457            plan,
458            @r"
459        Projection: customer.c_custkey [c_custkey:Int64]
460          Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8]
461            Filter: Int32(1) < __scalar_sq_1.max(orders.o_custkey) AND Int32(1) < __scalar_sq_2.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
462              Left Join:  Filter: __scalar_sq_2.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
463                Left Join:  Filter: __scalar_sq_1.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
464                  TableScan: customer [c_custkey:Int64, c_name:Utf8]
465                  SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
466                    Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
467                      Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]
468                        TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
469                SubqueryAlias: __scalar_sq_2 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
470                  Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
471                    Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]
472                      TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
473        "
474        )
475    }
476
477    /// Test recursive correlated subqueries
478    #[test]
479    fn recursive_subqueries() -> Result<()> {
480        let lineitem = Arc::new(
481            LogicalPlanBuilder::from(scan_tpch_table("lineitem"))
482                .filter(
483                    col("lineitem.l_orderkey")
484                        .eq(out_ref_col(DataType::Int64, "orders.o_orderkey")),
485                )?
486                .aggregate(
487                    Vec::<Expr>::new(),
488                    vec![sum(col("lineitem.l_extendedprice"))],
489                )?
490                .project(vec![sum(col("lineitem.l_extendedprice"))])?
491                .build()?,
492        );
493
494        let orders = Arc::new(
495            LogicalPlanBuilder::from(scan_tpch_table("orders"))
496                .filter(
497                    col("orders.o_custkey")
498                        .eq(out_ref_col(DataType::Int64, "customer.c_custkey"))
499                        .and(col("orders.o_totalprice").lt(scalar_subquery(lineitem))),
500                )?
501                .aggregate(Vec::<Expr>::new(), vec![sum(col("orders.o_totalprice"))])?
502                .project(vec![sum(col("orders.o_totalprice"))])?
503                .build()?,
504        );
505
506        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
507            .filter(col("customer.c_acctbal").lt(scalar_subquery(orders)))?
508            .project(vec![col("customer.c_custkey")])?
509            .build()?;
510
511        assert_optimized_plan_equal!(
512            plan,
513            @r"
514        Projection: customer.c_custkey [c_custkey:Int64]
515          Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8]
516            Filter: customer.c_acctbal < __scalar_sq_1.sum(orders.o_totalprice) [c_custkey:Int64, c_name:Utf8, sum(orders.o_totalprice):Float64;N, o_custkey:Int64;N, __always_true:Boolean;N]
517              Left Join:  Filter: __scalar_sq_1.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8, sum(orders.o_totalprice):Float64;N, o_custkey:Int64;N, __always_true:Boolean;N]
518                TableScan: customer [c_custkey:Int64, c_name:Utf8]
519                SubqueryAlias: __scalar_sq_1 [sum(orders.o_totalprice):Float64;N, o_custkey:Int64, __always_true:Boolean]
520                  Projection: sum(orders.o_totalprice), orders.o_custkey, __always_true [sum(orders.o_totalprice):Float64;N, o_custkey:Int64, __always_true:Boolean]
521                    Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[sum(orders.o_totalprice)]] [o_custkey:Int64, __always_true:Boolean, sum(orders.o_totalprice):Float64;N]
522                      Projection: orders.o_orderkey, orders.o_custkey, orders.o_orderstatus, orders.o_totalprice [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
523                        Filter: orders.o_totalprice < __scalar_sq_2.sum(lineitem.l_extendedprice) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N, sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64;N, __always_true:Boolean;N]
524                          Left Join:  Filter: __scalar_sq_2.l_orderkey = orders.o_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N, sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64;N, __always_true:Boolean;N]
525                            TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
526                            SubqueryAlias: __scalar_sq_2 [sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64, __always_true:Boolean]
527                              Projection: sum(lineitem.l_extendedprice), lineitem.l_orderkey, __always_true [sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64, __always_true:Boolean]
528                                Aggregate: groupBy=[[lineitem.l_orderkey, Boolean(true) AS __always_true]], aggr=[[sum(lineitem.l_extendedprice)]] [l_orderkey:Int64, __always_true:Boolean, sum(lineitem.l_extendedprice):Float64;N]
529                                  TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]
530        "
531        )
532    }
533
534    /// Test for correlated scalar subquery filter with additional subquery filters
535    #[test]
536    fn scalar_subquery_with_subquery_filters() -> Result<()> {
537        let sq = Arc::new(
538            LogicalPlanBuilder::from(scan_tpch_table("orders"))
539                .filter(
540                    out_ref_col(DataType::Int64, "customer.c_custkey")
541                        .eq(col("orders.o_custkey"))
542                        .and(col("o_orderkey").eq(lit(1))),
543                )?
544                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
545                .project(vec![max(col("orders.o_custkey"))])?
546                .build()?,
547        );
548
549        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
550            .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
551            .project(vec![col("customer.c_custkey")])?
552            .build()?;
553
554        assert_optimized_plan_equal!(
555            plan,
556            @r"
557        Projection: customer.c_custkey [c_custkey:Int64]
558          Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8]
559            Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
560              Left Join:  Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
561                TableScan: customer [c_custkey:Int64, c_name:Utf8]
562                SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
563                  Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
564                    Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]
565                      Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
566                        TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
567        "
568        )
569    }
570
571    /// Test for correlated scalar subquery with no columns in schema
572    #[test]
573    fn scalar_subquery_no_cols() -> Result<()> {
574        let sq = Arc::new(
575            LogicalPlanBuilder::from(scan_tpch_table("orders"))
576                .filter(
577                    out_ref_col(DataType::Int64, "customer.c_custkey")
578                        .eq(out_ref_col(DataType::Int64, "customer.c_custkey")),
579                )?
580                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
581                .project(vec![max(col("orders.o_custkey"))])?
582                .build()?,
583        );
584
585        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
586            .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
587            .project(vec![col("customer.c_custkey")])?
588            .build()?;
589
590        // it will optimize, but fail for the same reason the unoptimized query would
591        assert_optimized_plan_equal!(
592            plan,
593            @r"
594        Projection: customer.c_custkey [c_custkey:Int64]
595          Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8]
596            Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]
597              Left Join:  Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]
598                TableScan: customer [c_custkey:Int64, c_name:Utf8]
599                SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N]
600                  Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]
601                    Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]
602                      TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
603        "
604        )
605    }
606
607    /// Test for scalar subquery with both columns in schema
608    #[test]
609    fn scalar_subquery_with_no_correlated_cols() -> Result<()> {
610        let sq = Arc::new(
611            LogicalPlanBuilder::from(scan_tpch_table("orders"))
612                .filter(col("orders.o_custkey").eq(col("orders.o_custkey")))?
613                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
614                .project(vec![max(col("orders.o_custkey"))])?
615                .build()?,
616        );
617
618        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
619            .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
620            .project(vec![col("customer.c_custkey")])?
621            .build()?;
622
623        assert_optimized_plan_equal!(
624            plan,
625            @r"
626        Projection: customer.c_custkey [c_custkey:Int64]
627          Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8]
628            Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]
629              Left Join:  Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]
630                TableScan: customer [c_custkey:Int64, c_name:Utf8]
631                SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N]
632                  Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]
633                    Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]
634                      Filter: orders.o_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
635                        TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
636        "
637        )
638    }
639
640    /// Test for correlated scalar subquery not equal
641    #[test]
642    fn scalar_subquery_where_not_eq() -> Result<()> {
643        let sq = Arc::new(
644            LogicalPlanBuilder::from(scan_tpch_table("orders"))
645                .filter(
646                    out_ref_col(DataType::Int64, "customer.c_custkey")
647                        .not_eq(col("orders.o_custkey")),
648                )?
649                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
650                .project(vec![max(col("orders.o_custkey"))])?
651                .build()?,
652        );
653
654        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
655            .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
656            .project(vec![col("customer.c_custkey")])?
657            .build()?;
658
659        // Unsupported predicate, subquery should not be decorrelated
660        assert_optimized_plan_equal!(
661            plan,
662            @r"
663        Projection: customer.c_custkey [c_custkey:Int64]
664          Filter: customer.c_custkey = (<subquery>) [c_custkey:Int64, c_name:Utf8]
665            Subquery: [max(orders.o_custkey):Int64;N]
666              Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]
667                Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]
668                  Filter: outer_ref(customer.c_custkey) != orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
669                    TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
670            TableScan: customer [c_custkey:Int64, c_name:Utf8]
671        "
672        )
673    }
674
675    /// Test for correlated scalar subquery less than
676    #[test]
677    fn scalar_subquery_where_less_than() -> Result<()> {
678        let sq = Arc::new(
679            LogicalPlanBuilder::from(scan_tpch_table("orders"))
680                .filter(
681                    out_ref_col(DataType::Int64, "customer.c_custkey")
682                        .lt(col("orders.o_custkey")),
683                )?
684                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
685                .project(vec![max(col("orders.o_custkey"))])?
686                .build()?,
687        );
688
689        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
690            .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
691            .project(vec![col("customer.c_custkey")])?
692            .build()?;
693
694        // Unsupported predicate, subquery should not be decorrelated
695        assert_optimized_plan_equal!(
696            plan,
697            @r"
698        Projection: customer.c_custkey [c_custkey:Int64]
699          Filter: customer.c_custkey = (<subquery>) [c_custkey:Int64, c_name:Utf8]
700            Subquery: [max(orders.o_custkey):Int64;N]
701              Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]
702                Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]
703                  Filter: outer_ref(customer.c_custkey) < orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
704                    TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
705            TableScan: customer [c_custkey:Int64, c_name:Utf8]
706        "
707        )
708    }
709
710    /// Test for correlated scalar subquery filter with subquery disjunction
711    #[test]
712    fn scalar_subquery_with_subquery_disjunction() -> Result<()> {
713        let sq = Arc::new(
714            LogicalPlanBuilder::from(scan_tpch_table("orders"))
715                .filter(
716                    out_ref_col(DataType::Int64, "customer.c_custkey")
717                        .eq(col("orders.o_custkey"))
718                        .or(col("o_orderkey").eq(lit(1))),
719                )?
720                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
721                .project(vec![max(col("orders.o_custkey"))])?
722                .build()?,
723        );
724
725        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
726            .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
727            .project(vec![col("customer.c_custkey")])?
728            .build()?;
729
730        // Unsupported predicate, subquery should not be decorrelated
731        assert_optimized_plan_equal!(
732            plan,
733            @r"
734        Projection: customer.c_custkey [c_custkey:Int64]
735          Filter: customer.c_custkey = (<subquery>) [c_custkey:Int64, c_name:Utf8]
736            Subquery: [max(orders.o_custkey):Int64;N]
737              Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]
738                Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]
739                  Filter: outer_ref(customer.c_custkey) = orders.o_custkey OR orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
740                    TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
741            TableScan: customer [c_custkey:Int64, c_name:Utf8]
742        "
743        )
744    }
745
746    /// Test for correlated scalar without projection
747    #[test]
748    fn scalar_subquery_no_projection() -> Result<()> {
749        let sq = Arc::new(
750            LogicalPlanBuilder::from(scan_tpch_table("orders"))
751                .filter(col("customer.c_custkey").eq(col("orders.o_custkey")))?
752                .build()?,
753        );
754
755        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
756            .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
757            .project(vec![col("customer.c_custkey")])?
758            .build()?;
759
760        let expected = "Error during planning: Scalar subquery should only return one column, but found 4: orders.o_orderkey, orders.o_custkey, orders.o_orderstatus, orders.o_totalprice";
761        assert_analyzer_check_err(vec![], plan, expected);
762        Ok(())
763    }
764
765    /// Test for correlated scalar expressions
766    #[test]
767    fn scalar_subquery_project_expr() -> Result<()> {
768        let sq = Arc::new(
769            LogicalPlanBuilder::from(scan_tpch_table("orders"))
770                .filter(
771                    out_ref_col(DataType::Int64, "customer.c_custkey")
772                        .eq(col("orders.o_custkey")),
773                )?
774                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
775                .project(vec![col("max(orders.o_custkey)").add(lit(1))])?
776                .build()?,
777        );
778
779        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
780            .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
781            .project(vec![col("customer.c_custkey")])?
782            .build()?;
783
784        assert_optimized_plan_equal!(
785            plan,
786            @r"
787        Projection: customer.c_custkey [c_custkey:Int64]
788          Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8]
789            Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) + Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
790              Left Join:  Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
791                TableScan: customer [c_custkey:Int64, c_name:Utf8]
792                SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64, __always_true:Boolean]
793                  Projection: max(orders.o_custkey) + Int32(1), orders.o_custkey, __always_true [max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64, __always_true:Boolean]
794                    Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]
795                      TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
796        "
797        )
798    }
799
800    /// Test for correlated scalar subquery with non-strong project
801    #[test]
802    fn scalar_subquery_with_non_strong_project() -> Result<()> {
803        let case = Expr::Case(expr::Case {
804            expr: None,
805            when_then_expr: vec![(
806                Box::new(col("max(orders.o_totalprice)")),
807                Box::new(lit("a")),
808            )],
809            else_expr: Some(Box::new(lit("b"))),
810        });
811
812        let sq = Arc::new(
813            LogicalPlanBuilder::from(scan_tpch_table("orders"))
814                .filter(
815                    out_ref_col(DataType::Int64, "customer.c_custkey")
816                        .eq(col("orders.o_custkey")),
817                )?
818                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_totalprice"))])?
819                .project(vec![case])?
820                .build()?,
821        );
822
823        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
824            .project(vec![col("customer.c_custkey"), scalar_subquery(sq)])?
825            .build()?;
826
827        assert_optimized_plan_equal!(
828            plan,
829            @r#"
830        Projection: customer.c_custkey, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN CASE WHEN CAST(NULL AS Boolean) THEN Utf8("a") ELSE Utf8("b") END ELSE __scalar_sq_1.CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END END AS CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END [c_custkey:Int64, CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END:Utf8;N]
831          Left Join:  Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END:Utf8;N, o_custkey:Int64;N, __always_true:Boolean;N]
832            TableScan: customer [c_custkey:Int64, c_name:Utf8]
833            SubqueryAlias: __scalar_sq_1 [CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END:Utf8, o_custkey:Int64, __always_true:Boolean]
834              Projection: CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END, orders.o_custkey, __always_true [CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END:Utf8, o_custkey:Int64, __always_true:Boolean]
835                Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_totalprice)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_totalprice):Float64;N]
836                  TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
837        "#
838        )
839    }
840
841    /// Test for correlated scalar subquery multiple projected columns
842    #[test]
843    fn scalar_subquery_multi_col() -> Result<()> {
844        let sq = Arc::new(
845            LogicalPlanBuilder::from(scan_tpch_table("orders"))
846                .filter(col("customer.c_custkey").eq(col("orders.o_custkey")))?
847                .project(vec![col("orders.o_custkey"), col("orders.o_orderkey")])?
848                .build()?,
849        );
850
851        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
852            .filter(
853                col("customer.c_custkey")
854                    .eq(scalar_subquery(sq))
855                    .and(col("c_custkey").eq(lit(1))),
856            )?
857            .project(vec![col("customer.c_custkey")])?
858            .build()?;
859
860        let expected = "Error during planning: Scalar subquery should only return one column, but found 2: orders.o_custkey, orders.o_orderkey";
861        assert_analyzer_check_err(vec![], plan, expected);
862        Ok(())
863    }
864
865    /// Test for correlated scalar subquery filter with additional filters
866    #[test]
867    fn scalar_subquery_additional_filters_with_non_equal_clause() -> Result<()> {
868        let sq = Arc::new(
869            LogicalPlanBuilder::from(scan_tpch_table("orders"))
870                .filter(
871                    out_ref_col(DataType::Int64, "customer.c_custkey")
872                        .eq(col("orders.o_custkey")),
873                )?
874                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
875                .project(vec![max(col("orders.o_custkey"))])?
876                .build()?,
877        );
878
879        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
880            .filter(
881                col("customer.c_custkey")
882                    .gt_eq(scalar_subquery(sq))
883                    .and(col("c_custkey").eq(lit(1))),
884            )?
885            .project(vec![col("customer.c_custkey")])?
886            .build()?;
887
888        assert_optimized_plan_equal!(
889            plan,
890            @r"
891        Projection: customer.c_custkey [c_custkey:Int64]
892          Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8]
893            Filter: customer.c_custkey >= __scalar_sq_1.max(orders.o_custkey) AND customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
894              Left Join:  Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
895                TableScan: customer [c_custkey:Int64, c_name:Utf8]
896                SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
897                  Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
898                    Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]
899                      TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
900        "
901        )
902    }
903
904    #[test]
905    fn scalar_subquery_additional_filters_with_equal_clause() -> Result<()> {
906        let sq = Arc::new(
907            LogicalPlanBuilder::from(scan_tpch_table("orders"))
908                .filter(
909                    out_ref_col(DataType::Int64, "customer.c_custkey")
910                        .eq(col("orders.o_custkey")),
911                )?
912                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
913                .project(vec![max(col("orders.o_custkey"))])?
914                .build()?,
915        );
916
917        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
918            .filter(
919                col("customer.c_custkey")
920                    .eq(scalar_subquery(sq))
921                    .and(col("c_custkey").eq(lit(1))),
922            )?
923            .project(vec![col("customer.c_custkey")])?
924            .build()?;
925
926        assert_optimized_plan_equal!(
927            plan,
928            @r"
929        Projection: customer.c_custkey [c_custkey:Int64]
930          Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8]
931            Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) AND customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
932              Left Join:  Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
933                TableScan: customer [c_custkey:Int64, c_name:Utf8]
934                SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
935                  Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
936                    Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]
937                      TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
938        "
939        )
940    }
941
942    /// Test for correlated scalar subquery filter with disjunctions
943    #[test]
944    fn scalar_subquery_disjunction() -> Result<()> {
945        let sq = Arc::new(
946            LogicalPlanBuilder::from(scan_tpch_table("orders"))
947                .filter(
948                    out_ref_col(DataType::Int64, "customer.c_custkey")
949                        .eq(col("orders.o_custkey")),
950                )?
951                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
952                .project(vec![max(col("orders.o_custkey"))])?
953                .build()?,
954        );
955
956        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
957            .filter(
958                col("customer.c_custkey")
959                    .eq(scalar_subquery(sq))
960                    .or(col("customer.c_custkey").eq(lit(1))),
961            )?
962            .project(vec![col("customer.c_custkey")])?
963            .build()?;
964
965        assert_optimized_plan_equal!(
966            plan,
967            @r"
968        Projection: customer.c_custkey [c_custkey:Int64]
969          Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8]
970            Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) OR customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
971              Left Join:  Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
972                TableScan: customer [c_custkey:Int64, c_name:Utf8]
973                SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
974                  Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
975                    Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]
976                      TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
977        "
978        )
979    }
980
981    /// Test for correlated scalar subquery filter
982    #[test]
983    fn exists_subquery_correlated() -> Result<()> {
984        let sq = Arc::new(
985            LogicalPlanBuilder::from(test_table_scan_with_name("sq")?)
986                .filter(out_ref_col(DataType::UInt32, "test.a").eq(col("sq.a")))?
987                .aggregate(Vec::<Expr>::new(), vec![min(col("c"))])?
988                .project(vec![min(col("c"))])?
989                .build()?,
990        );
991
992        let plan = LogicalPlanBuilder::from(test_table_scan_with_name("test")?)
993            .filter(col("test.c").lt(scalar_subquery(sq)))?
994            .project(vec![col("test.c")])?
995            .build()?;
996
997        assert_optimized_plan_equal!(
998            plan,
999            @r"
1000        Projection: test.c [c:UInt32]
1001          Projection: test.a, test.b, test.c [a:UInt32, b:UInt32, c:UInt32]
1002            Filter: test.c < __scalar_sq_1.min(sq.c) [a:UInt32, b:UInt32, c:UInt32, min(sq.c):UInt32;N, a:UInt32;N, __always_true:Boolean;N]
1003              Left Join:  Filter: test.a = __scalar_sq_1.a [a:UInt32, b:UInt32, c:UInt32, min(sq.c):UInt32;N, a:UInt32;N, __always_true:Boolean;N]
1004                TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1005                SubqueryAlias: __scalar_sq_1 [min(sq.c):UInt32;N, a:UInt32, __always_true:Boolean]
1006                  Projection: min(sq.c), sq.a, __always_true [min(sq.c):UInt32;N, a:UInt32, __always_true:Boolean]
1007                    Aggregate: groupBy=[[sq.a, Boolean(true) AS __always_true]], aggr=[[min(sq.c)]] [a:UInt32, __always_true:Boolean, min(sq.c):UInt32;N]
1008                      TableScan: sq [a:UInt32, b:UInt32, c:UInt32]
1009        "
1010        )
1011    }
1012
1013    /// Test for non-correlated scalar subquery with no filters
1014    #[test]
1015    fn scalar_subquery_non_correlated_no_filters_with_non_equal_clause() -> Result<()> {
1016        let sq = Arc::new(
1017            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1018                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
1019                .project(vec![max(col("orders.o_custkey"))])?
1020                .build()?,
1021        );
1022
1023        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
1024            .filter(col("customer.c_custkey").lt(scalar_subquery(sq)))?
1025            .project(vec![col("customer.c_custkey")])?
1026            .build()?;
1027
1028        assert_optimized_plan_equal!(
1029            plan,
1030            @r"
1031        Projection: customer.c_custkey [c_custkey:Int64]
1032          Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8]
1033            Filter: customer.c_custkey < __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]
1034              Left Join:  Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]
1035                TableScan: customer [c_custkey:Int64, c_name:Utf8]
1036                SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N]
1037                  Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]
1038                    Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]
1039                      TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1040        "
1041        )
1042    }
1043
1044    #[test]
1045    fn scalar_subquery_non_correlated_no_filters_with_equal_clause() -> Result<()> {
1046        let sq = Arc::new(
1047            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1048                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
1049                .project(vec![max(col("orders.o_custkey"))])?
1050                .build()?,
1051        );
1052
1053        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
1054            .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
1055            .project(vec![col("customer.c_custkey")])?
1056            .build()?;
1057
1058        assert_optimized_plan_equal!(
1059            plan,
1060            @r"
1061        Projection: customer.c_custkey [c_custkey:Int64]
1062          Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8]
1063            Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]
1064              Left Join:  Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]
1065                TableScan: customer [c_custkey:Int64, c_name:Utf8]
1066                SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N]
1067                  Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]
1068                    Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]
1069                      TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1070        "
1071        )
1072    }
1073
1074    #[test]
1075    fn correlated_scalar_subquery_in_between_clause() -> Result<()> {
1076        let sq1 = Arc::new(
1077            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1078                .filter(
1079                    out_ref_col(DataType::Int64, "customer.c_custkey")
1080                        .eq(col("orders.o_custkey")),
1081                )?
1082                .aggregate(Vec::<Expr>::new(), vec![min(col("orders.o_custkey"))])?
1083                .project(vec![min(col("orders.o_custkey"))])?
1084                .build()?,
1085        );
1086        let sq2 = Arc::new(
1087            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1088                .filter(
1089                    out_ref_col(DataType::Int64, "customer.c_custkey")
1090                        .eq(col("orders.o_custkey")),
1091                )?
1092                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
1093                .project(vec![max(col("orders.o_custkey"))])?
1094                .build()?,
1095        );
1096
1097        let between_expr = Expr::Between(Between {
1098            expr: Box::new(col("customer.c_custkey")),
1099            negated: false,
1100            low: Box::new(scalar_subquery(sq1)),
1101            high: Box::new(scalar_subquery(sq2)),
1102        });
1103
1104        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
1105            .filter(between_expr)?
1106            .project(vec![col("customer.c_custkey")])?
1107            .build()?;
1108
1109        assert_optimized_plan_equal!(
1110            plan,
1111            @r"
1112        Projection: customer.c_custkey [c_custkey:Int64]
1113          Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8]
1114            Filter: customer.c_custkey BETWEEN __scalar_sq_1.min(orders.o_custkey) AND __scalar_sq_2.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
1115              Left Join:  Filter: customer.c_custkey = __scalar_sq_2.o_custkey [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
1116                Left Join:  Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
1117                  TableScan: customer [c_custkey:Int64, c_name:Utf8]
1118                  SubqueryAlias: __scalar_sq_1 [min(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
1119                    Projection: min(orders.o_custkey), orders.o_custkey, __always_true [min(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
1120                      Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[min(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, min(orders.o_custkey):Int64;N]
1121                        TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1122                SubqueryAlias: __scalar_sq_2 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
1123                  Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
1124                    Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]
1125                      TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1126        "
1127        )
1128    }
1129
1130    #[test]
1131    fn uncorrelated_scalar_subquery_in_between_clause() -> Result<()> {
1132        let sq1 = Arc::new(
1133            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1134                .aggregate(Vec::<Expr>::new(), vec![min(col("orders.o_custkey"))])?
1135                .project(vec![min(col("orders.o_custkey"))])?
1136                .build()?,
1137        );
1138        let sq2 = Arc::new(
1139            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1140                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
1141                .project(vec![max(col("orders.o_custkey"))])?
1142                .build()?,
1143        );
1144
1145        let between_expr = Expr::Between(Between {
1146            expr: Box::new(col("customer.c_custkey")),
1147            negated: false,
1148            low: Box::new(scalar_subquery(sq1)),
1149            high: Box::new(scalar_subquery(sq2)),
1150        });
1151
1152        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
1153            .filter(between_expr)?
1154            .project(vec![col("customer.c_custkey")])?
1155            .build()?;
1156
1157        assert_optimized_plan_equal!(
1158            plan,
1159            @r"
1160        Projection: customer.c_custkey [c_custkey:Int64]
1161          Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8]
1162            Filter: customer.c_custkey BETWEEN __scalar_sq_1.min(orders.o_custkey) AND __scalar_sq_2.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, max(orders.o_custkey):Int64;N]
1163              Left Join:  Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, max(orders.o_custkey):Int64;N]
1164                Left Join:  Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N]
1165                  TableScan: customer [c_custkey:Int64, c_name:Utf8]
1166                  SubqueryAlias: __scalar_sq_1 [min(orders.o_custkey):Int64;N]
1167                    Projection: min(orders.o_custkey) [min(orders.o_custkey):Int64;N]
1168                      Aggregate: groupBy=[[]], aggr=[[min(orders.o_custkey)]] [min(orders.o_custkey):Int64;N]
1169                        TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1170                SubqueryAlias: __scalar_sq_2 [max(orders.o_custkey):Int64;N]
1171                  Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]
1172                    Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]
1173                      TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1174        "
1175        )
1176    }
1177}