datafusion_optimizer/
scalar_subquery_to_join.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`ScalarSubqueryToJoin`] rewriting scalar subquery filters to `JOIN`s
19
20use std::collections::{BTreeSet, HashMap};
21use std::sync::Arc;
22
23use crate::decorrelate::{PullUpCorrelatedExpr, UN_MATCHED_ROW_INDICATOR};
24use crate::optimizer::ApplyOrder;
25use crate::utils::{evaluates_to_null, replace_qualified_name};
26use crate::{OptimizerConfig, OptimizerRule};
27
28use crate::analyzer::type_coercion::TypeCoercionRewriter;
29use datafusion_common::alias::AliasGenerator;
30use datafusion_common::tree_node::{
31    Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter,
32};
33use datafusion_common::{internal_err, plan_err, Column, Result, ScalarValue};
34use datafusion_expr::expr_rewriter::create_col_from_scalar_expr;
35use datafusion_expr::logical_plan::{JoinType, Subquery};
36use datafusion_expr::utils::conjunction;
37use datafusion_expr::{expr, EmptyRelation, Expr, LogicalPlan, LogicalPlanBuilder};
38
/// Optimizer rule that rewrites scalar subqueries found in `Filter` and
/// `Projection` nodes into joins against the (decorrelated) subquery.
///
/// When a filter is rewritten, an additional projection is placed on top of
/// it so that the original output schema is preserved.
#[derive(Default, Debug)]
pub struct ScalarSubqueryToJoin {}
44
45impl ScalarSubqueryToJoin {
46    #[allow(missing_docs)]
47    pub fn new() -> Self {
48        Self::default()
49    }
50
51    /// Finds expressions that have a scalar subquery in them (and recurses when found)
52    ///
53    /// # Arguments
54    /// * `predicate` - A conjunction to split and search
55    ///
56    /// Returns a tuple (subqueries, alias)
57    fn extract_subquery_exprs(
58        &self,
59        predicate: &Expr,
60        alias_gen: &Arc<AliasGenerator>,
61    ) -> Result<(Vec<(Subquery, String)>, Expr)> {
62        let mut extract = ExtractScalarSubQuery {
63            sub_query_info: vec![],
64            alias_gen,
65        };
66        predicate
67            .clone()
68            .rewrite(&mut extract)
69            .data()
70            .map(|new_expr| (extract.sub_query_info, new_expr))
71    }
72}
73
74impl OptimizerRule for ScalarSubqueryToJoin {
75    fn supports_rewrite(&self) -> bool {
76        true
77    }
78
79    fn rewrite(
80        &self,
81        plan: LogicalPlan,
82        config: &dyn OptimizerConfig,
83    ) -> Result<Transformed<LogicalPlan>> {
84        match plan {
85            LogicalPlan::Filter(filter) => {
86                // Optimization: skip the rest of the rule and its copies if
87                // there are no scalar subqueries
88                if !contains_scalar_subquery(&filter.predicate) {
89                    return Ok(Transformed::no(LogicalPlan::Filter(filter)));
90                }
91
92                let (subqueries, mut rewrite_expr) = self.extract_subquery_exprs(
93                    &filter.predicate,
94                    config.alias_generator(),
95                )?;
96
97                if subqueries.is_empty() {
98                    return internal_err!("Expected subqueries not found in filter");
99                }
100
101                // iterate through all subqueries in predicate, turning each into a left join
102                let mut cur_input = filter.input.as_ref().clone();
103                for (subquery, alias) in subqueries {
104                    if let Some((optimized_subquery, expr_check_map)) =
105                        build_join(&subquery, &cur_input, &alias)?
106                    {
107                        if !expr_check_map.is_empty() {
108                            rewrite_expr = rewrite_expr
109                                .transform_up(|expr| {
110                                    // replace column references with entry in map, if it exists
111                                    if let Some(map_expr) = expr
112                                        .try_as_col()
113                                        .and_then(|col| expr_check_map.get(&col.name))
114                                    {
115                                        Ok(Transformed::yes(map_expr.clone()))
116                                    } else {
117                                        Ok(Transformed::no(expr))
118                                    }
119                                })
120                                .data()?;
121                        }
122                        cur_input = optimized_subquery;
123                    } else {
124                        // if we can't handle all of the subqueries then bail for now
125                        return Ok(Transformed::no(LogicalPlan::Filter(filter)));
126                    }
127                }
128
129                // Preserve original schema as new Join might have more fields than what Filter & parents expect.
130                let projection =
131                    filter.input.schema().columns().into_iter().map(Expr::from);
132                let new_plan = LogicalPlanBuilder::from(cur_input)
133                    .filter(rewrite_expr)?
134                    .project(projection)?
135                    .build()?;
136                Ok(Transformed::yes(new_plan))
137            }
138            LogicalPlan::Projection(projection) => {
139                // Optimization: skip the rest of the rule and its copies if
140                // there are no scalar subqueries
141                if !projection.expr.iter().any(contains_scalar_subquery) {
142                    return Ok(Transformed::no(LogicalPlan::Projection(projection)));
143                }
144
145                let mut all_subqueries = vec![];
146                let mut expr_to_rewrite_expr_map = HashMap::new();
147                let mut subquery_to_expr_map = HashMap::new();
148                for expr in projection.expr.iter() {
149                    let (subqueries, rewrite_exprs) =
150                        self.extract_subquery_exprs(expr, config.alias_generator())?;
151                    for (subquery, _) in &subqueries {
152                        subquery_to_expr_map.insert(subquery.clone(), expr.clone());
153                    }
154                    all_subqueries.extend(subqueries);
155                    expr_to_rewrite_expr_map.insert(expr, rewrite_exprs);
156                }
157                if all_subqueries.is_empty() {
158                    return internal_err!("Expected subqueries not found in projection");
159                }
160                // iterate through all subqueries in predicate, turning each into a left join
161                let mut cur_input = projection.input.as_ref().clone();
162                for (subquery, alias) in all_subqueries {
163                    if let Some((optimized_subquery, expr_check_map)) =
164                        build_join(&subquery, &cur_input, &alias)?
165                    {
166                        cur_input = optimized_subquery;
167                        if !expr_check_map.is_empty() {
168                            if let Some(expr) = subquery_to_expr_map.get(&subquery) {
169                                if let Some(rewrite_expr) =
170                                    expr_to_rewrite_expr_map.get(expr)
171                                {
172                                    let new_expr = rewrite_expr
173                                        .clone()
174                                        .transform_up(|expr| {
175                                            // replace column references with entry in map, if it exists
176                                            if let Some(map_expr) =
177                                                expr.try_as_col().and_then(|col| {
178                                                    expr_check_map.get(&col.name)
179                                                })
180                                            {
181                                                Ok(Transformed::yes(map_expr.clone()))
182                                            } else {
183                                                Ok(Transformed::no(expr))
184                                            }
185                                        })
186                                        .data()?;
187                                    expr_to_rewrite_expr_map.insert(expr, new_expr);
188                                }
189                            }
190                        }
191                    } else {
192                        // if we can't handle all of the subqueries then bail for now
193                        return Ok(Transformed::no(LogicalPlan::Projection(projection)));
194                    }
195                }
196
197                let mut proj_exprs = vec![];
198                for expr in projection.expr.iter() {
199                    let old_expr_name = expr.schema_name().to_string();
200                    let new_expr = expr_to_rewrite_expr_map.get(expr).unwrap();
201                    let new_expr_name = new_expr.schema_name().to_string();
202                    if new_expr_name != old_expr_name {
203                        proj_exprs.push(new_expr.clone().alias(old_expr_name))
204                    } else {
205                        proj_exprs.push(new_expr.clone());
206                    }
207                }
208                let new_plan = LogicalPlanBuilder::from(cur_input)
209                    .project(proj_exprs)?
210                    .build()?;
211                Ok(Transformed::yes(new_plan))
212            }
213
214            plan => Ok(Transformed::no(plan)),
215        }
216    }
217
218    fn name(&self) -> &str {
219        "scalar_subquery_to_join"
220    }
221
222    fn apply_order(&self) -> Option<ApplyOrder> {
223        Some(ApplyOrder::TopDown)
224    }
225}
226
227/// Returns true if the expression has a scalar subquery somewhere in it
228/// false otherwise
229fn contains_scalar_subquery(expr: &Expr) -> bool {
230    expr.exists(|expr| Ok(matches!(expr, Expr::ScalarSubquery(_))))
231        .expect("Inner is always Ok")
232}
233
/// Expression rewriter that swaps each scalar subquery for a column
/// reference and records the subqueries it replaced.
struct ExtractScalarSubQuery<'a> {
    /// Replaced subqueries, each paired with the alias generated for it, in
    /// the order they were visited
    sub_query_info: Vec<(Subquery, String)>,
    /// Generator asked for a `__scalar_sq`-prefixed alias per subquery
    alias_gen: &'a Arc<AliasGenerator>,
}
238
239impl TreeNodeRewriter for ExtractScalarSubQuery<'_> {
240    type Node = Expr;
241
242    fn f_down(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
243        match expr {
244            Expr::ScalarSubquery(subquery) => {
245                let subqry_alias = self.alias_gen.next("__scalar_sq");
246                self.sub_query_info
247                    .push((subquery.clone(), subqry_alias.clone()));
248                let scalar_expr = subquery
249                    .subquery
250                    .head_output_expr()?
251                    .map_or(plan_err!("single expression required."), Ok)?;
252                Ok(Transformed::new(
253                    Expr::Column(create_col_from_scalar_expr(
254                        &scalar_expr,
255                        subqry_alias,
256                    )?),
257                    true,
258                    TreeNodeRecursion::Jump,
259                ))
260            }
261            _ => Ok(Transformed::no(expr)),
262        }
263    }
264}
265
266/// Takes a query like:
267///
268/// ```text
269/// select id from customers where balance >
270///     (select avg(total) from orders where orders.c_id = customers.id)
271/// ```
272///
273/// and optimizes it into:
274///
275/// ```text
276/// select c.id from customers c
277/// left join (select c_id, avg(total) as val from orders group by c_id) o on o.c_id = c.c_id
278/// where c.balance > o.val
279/// ```
280///
281/// Or a query like:
282///
283/// ```text
284/// select id from customers where balance >
285///     (select avg(total) from orders)
286/// ```
287///
288/// and optimizes it into:
289///
290/// ```text
291/// select c.id from customers c
292/// left join (select avg(total) as val from orders) a
293/// where c.balance > a.val
294/// ```
295///
296/// # Arguments
297///
298/// * `query_info` - The subquery portion of the `where` (select avg(total) from orders)
299/// * `filter_input` - The non-subquery portion (from customers)
300/// * `outer_others` - Any additional parts to the `where` expression (and c.x = y)
301/// * `subquery_alias` - Subquery aliases
302fn build_join(
303    subquery: &Subquery,
304    filter_input: &LogicalPlan,
305    subquery_alias: &str,
306) -> Result<Option<(LogicalPlan, HashMap<String, Expr>)>> {
307    let subquery_plan = subquery.subquery.as_ref();
308    let mut pull_up = PullUpCorrelatedExpr::new().with_need_handle_count_bug(true);
309    let new_plan = subquery_plan.clone().rewrite(&mut pull_up).data()?;
310    if !pull_up.can_pull_up {
311        return Ok(None);
312    }
313
314    let collected_count_expr_map =
315        pull_up.collected_count_expr_map.get(&new_plan).cloned();
316    let sub_query_alias = LogicalPlanBuilder::from(new_plan)
317        .alias(subquery_alias.to_string())?
318        .build()?;
319
320    let mut all_correlated_cols = BTreeSet::new();
321    pull_up
322        .correlated_subquery_cols_map
323        .values()
324        .for_each(|cols| all_correlated_cols.extend(cols.clone()));
325
326    // alias the join filter
327    let join_filter_opt =
328        conjunction(pull_up.join_filters).map_or(Ok(None), |filter| {
329            replace_qualified_name(filter, &all_correlated_cols, subquery_alias).map(Some)
330        })?;
331
332    // join our sub query into the main plan
333    let new_plan = if join_filter_opt.is_none() {
334        match filter_input {
335            LogicalPlan::EmptyRelation(EmptyRelation {
336                produce_one_row: true,
337                schema: _,
338            }) => sub_query_alias,
339            _ => {
340                // if not correlated, group down to 1 row and left join on that (preserving row count)
341                LogicalPlanBuilder::from(filter_input.clone())
342                    .join_on(
343                        sub_query_alias,
344                        JoinType::Left,
345                        vec![Expr::Literal(ScalarValue::Boolean(Some(true)), None)],
346                    )?
347                    .build()?
348            }
349        }
350    } else {
351        // left join if correlated, grouping by the join keys so we don't change row count
352        LogicalPlanBuilder::from(filter_input.clone())
353            .join_on(sub_query_alias, JoinType::Left, join_filter_opt)?
354            .build()?
355    };
356    let mut computation_project_expr = HashMap::new();
357    if let Some(expr_map) = collected_count_expr_map {
358        for (name, result) in expr_map {
359            if evaluates_to_null(result.clone(), result.column_refs())? {
360                // If expr always returns null when column is null, skip processing
361                continue;
362            }
363            let computer_expr = if let Some(filter) = &pull_up.pull_up_having_expr {
364                Expr::Case(expr::Case {
365                    expr: None,
366                    when_then_expr: vec![
367                        (
368                            Box::new(Expr::IsNull(Box::new(Expr::Column(
369                                Column::new_unqualified(UN_MATCHED_ROW_INDICATOR),
370                            )))),
371                            Box::new(result),
372                        ),
373                        (
374                            Box::new(Expr::Not(Box::new(filter.clone()))),
375                            Box::new(Expr::Literal(ScalarValue::Null, None)),
376                        ),
377                    ],
378                    else_expr: Some(Box::new(Expr::Column(Column::new_unqualified(
379                        name.clone(),
380                    )))),
381                })
382            } else {
383                Expr::Case(expr::Case {
384                    expr: None,
385                    when_then_expr: vec![(
386                        Box::new(Expr::IsNull(Box::new(Expr::Column(
387                            Column::new_unqualified(UN_MATCHED_ROW_INDICATOR),
388                        )))),
389                        Box::new(result),
390                    )],
391                    else_expr: Some(Box::new(Expr::Column(Column::new_unqualified(
392                        name.clone(),
393                    )))),
394                })
395            };
396            let mut expr_rewrite = TypeCoercionRewriter {
397                schema: new_plan.schema(),
398            };
399            computation_project_expr
400                .insert(name, computer_expr.rewrite(&mut expr_rewrite).data()?);
401        }
402    }
403
404    Ok(Some((new_plan, computation_project_expr)))
405}
406
407#[cfg(test)]
408mod tests {
409    use std::ops::Add;
410
411    use super::*;
412    use crate::test::*;
413
414    use arrow::datatypes::DataType;
415    use datafusion_expr::test::function_stub::sum;
416
417    use crate::assert_optimized_plan_eq_display_indent_snapshot;
418    use datafusion_expr::{col, lit, out_ref_col, scalar_subquery, Between};
419    use datafusion_functions_aggregate::min_max::{max, min};
420
421    macro_rules! assert_optimized_plan_equal {
422        (
423            $plan:expr,
424            @ $expected:literal $(,)?
425        ) => {{
426            let rule: Arc<dyn crate::OptimizerRule + Send + Sync> = Arc::new(ScalarSubqueryToJoin::new());
427            assert_optimized_plan_eq_display_indent_snapshot!(
428                rule,
429                $plan,
430                @ $expected,
431            )
432        }};
433    }
434
435    /// Test multiple correlated subqueries
436    #[test]
437    fn multiple_subqueries() -> Result<()> {
438        let orders = Arc::new(
439            LogicalPlanBuilder::from(scan_tpch_table("orders"))
440                .filter(
441                    col("orders.o_custkey")
442                        .eq(out_ref_col(DataType::Int64, "customer.c_custkey")),
443                )?
444                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
445                .project(vec![max(col("orders.o_custkey"))])?
446                .build()?,
447        );
448
449        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
450            .filter(
451                lit(1)
452                    .lt(scalar_subquery(Arc::clone(&orders)))
453                    .and(lit(1).lt(scalar_subquery(orders))),
454            )?
455            .project(vec![col("customer.c_custkey")])?
456            .build()?;
457
458        assert_optimized_plan_equal!(
459            plan,
460            @r"
461        Projection: customer.c_custkey [c_custkey:Int64]
462          Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8]
463            Filter: Int32(1) < __scalar_sq_1.max(orders.o_custkey) AND Int32(1) < __scalar_sq_2.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
464              Left Join:  Filter: __scalar_sq_2.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
465                Left Join:  Filter: __scalar_sq_1.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
466                  TableScan: customer [c_custkey:Int64, c_name:Utf8]
467                  SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
468                    Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
469                      Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]
470                        TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
471                SubqueryAlias: __scalar_sq_2 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
472                  Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
473                    Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]
474                      TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
475        "
476        )
477    }
478
479    /// Test recursive correlated subqueries
480    #[test]
481    fn recursive_subqueries() -> Result<()> {
482        let lineitem = Arc::new(
483            LogicalPlanBuilder::from(scan_tpch_table("lineitem"))
484                .filter(
485                    col("lineitem.l_orderkey")
486                        .eq(out_ref_col(DataType::Int64, "orders.o_orderkey")),
487                )?
488                .aggregate(
489                    Vec::<Expr>::new(),
490                    vec![sum(col("lineitem.l_extendedprice"))],
491                )?
492                .project(vec![sum(col("lineitem.l_extendedprice"))])?
493                .build()?,
494        );
495
496        let orders = Arc::new(
497            LogicalPlanBuilder::from(scan_tpch_table("orders"))
498                .filter(
499                    col("orders.o_custkey")
500                        .eq(out_ref_col(DataType::Int64, "customer.c_custkey"))
501                        .and(col("orders.o_totalprice").lt(scalar_subquery(lineitem))),
502                )?
503                .aggregate(Vec::<Expr>::new(), vec![sum(col("orders.o_totalprice"))])?
504                .project(vec![sum(col("orders.o_totalprice"))])?
505                .build()?,
506        );
507
508        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
509            .filter(col("customer.c_acctbal").lt(scalar_subquery(orders)))?
510            .project(vec![col("customer.c_custkey")])?
511            .build()?;
512
513        assert_optimized_plan_equal!(
514            plan,
515            @r"
516        Projection: customer.c_custkey [c_custkey:Int64]
517          Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8]
518            Filter: customer.c_acctbal < __scalar_sq_1.sum(orders.o_totalprice) [c_custkey:Int64, c_name:Utf8, sum(orders.o_totalprice):Float64;N, o_custkey:Int64;N, __always_true:Boolean;N]
519              Left Join:  Filter: __scalar_sq_1.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8, sum(orders.o_totalprice):Float64;N, o_custkey:Int64;N, __always_true:Boolean;N]
520                TableScan: customer [c_custkey:Int64, c_name:Utf8]
521                SubqueryAlias: __scalar_sq_1 [sum(orders.o_totalprice):Float64;N, o_custkey:Int64, __always_true:Boolean]
522                  Projection: sum(orders.o_totalprice), orders.o_custkey, __always_true [sum(orders.o_totalprice):Float64;N, o_custkey:Int64, __always_true:Boolean]
523                    Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[sum(orders.o_totalprice)]] [o_custkey:Int64, __always_true:Boolean, sum(orders.o_totalprice):Float64;N]
524                      Projection: orders.o_orderkey, orders.o_custkey, orders.o_orderstatus, orders.o_totalprice [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
525                        Filter: orders.o_totalprice < __scalar_sq_2.sum(lineitem.l_extendedprice) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N, sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64;N, __always_true:Boolean;N]
526                          Left Join:  Filter: __scalar_sq_2.l_orderkey = orders.o_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N, sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64;N, __always_true:Boolean;N]
527                            TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
528                            SubqueryAlias: __scalar_sq_2 [sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64, __always_true:Boolean]
529                              Projection: sum(lineitem.l_extendedprice), lineitem.l_orderkey, __always_true [sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64, __always_true:Boolean]
530                                Aggregate: groupBy=[[lineitem.l_orderkey, Boolean(true) AS __always_true]], aggr=[[sum(lineitem.l_extendedprice)]] [l_orderkey:Int64, __always_true:Boolean, sum(lineitem.l_extendedprice):Float64;N]
531                                  TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]
532        "
533        )
534    }
535
536    /// Test for correlated scalar subquery filter with additional subquery filters
537    #[test]
538    fn scalar_subquery_with_subquery_filters() -> Result<()> {
539        let sq = Arc::new(
540            LogicalPlanBuilder::from(scan_tpch_table("orders"))
541                .filter(
542                    out_ref_col(DataType::Int64, "customer.c_custkey")
543                        .eq(col("orders.o_custkey"))
544                        .and(col("o_orderkey").eq(lit(1))),
545                )?
546                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
547                .project(vec![max(col("orders.o_custkey"))])?
548                .build()?,
549        );
550
551        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
552            .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
553            .project(vec![col("customer.c_custkey")])?
554            .build()?;
555
556        assert_optimized_plan_equal!(
557            plan,
558            @r"
559        Projection: customer.c_custkey [c_custkey:Int64]
560          Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8]
561            Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
562              Left Join:  Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
563                TableScan: customer [c_custkey:Int64, c_name:Utf8]
564                SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
565                  Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
566                    Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]
567                      Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
568                        TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
569        "
570        )
571    }
572
573    /// Test for correlated scalar subquery with no columns in schema
574    #[test]
575    fn scalar_subquery_no_cols() -> Result<()> {
576        let sq = Arc::new(
577            LogicalPlanBuilder::from(scan_tpch_table("orders"))
578                .filter(
579                    out_ref_col(DataType::Int64, "customer.c_custkey")
580                        .eq(out_ref_col(DataType::Int64, "customer.c_custkey")),
581                )?
582                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
583                .project(vec![max(col("orders.o_custkey"))])?
584                .build()?,
585        );
586
587        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
588            .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
589            .project(vec![col("customer.c_custkey")])?
590            .build()?;
591
592        // it will optimize, but fail for the same reason the unoptimized query would
593        assert_optimized_plan_equal!(
594            plan,
595            @r"
596        Projection: customer.c_custkey [c_custkey:Int64]
597          Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8]
598            Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]
599              Left Join:  Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]
600                TableScan: customer [c_custkey:Int64, c_name:Utf8]
601                SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N]
602                  Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]
603                    Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]
604                      TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
605        "
606        )
607    }
608
    /// Test for an uncorrelated scalar subquery whose filter only compares columns of its own schema
610    #[test]
611    fn scalar_subquery_with_no_correlated_cols() -> Result<()> {
612        let sq = Arc::new(
613            LogicalPlanBuilder::from(scan_tpch_table("orders"))
614                .filter(col("orders.o_custkey").eq(col("orders.o_custkey")))?
615                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
616                .project(vec![max(col("orders.o_custkey"))])?
617                .build()?,
618        );
619
620        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
621            .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
622            .project(vec![col("customer.c_custkey")])?
623            .build()?;
624
625        assert_optimized_plan_equal!(
626            plan,
627            @r"
628        Projection: customer.c_custkey [c_custkey:Int64]
629          Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8]
630            Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]
631              Left Join:  Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]
632                TableScan: customer [c_custkey:Int64, c_name:Utf8]
633                SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N]
634                  Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]
635                    Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]
636                      Filter: orders.o_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
637                        TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
638        "
639        )
640    }
641
642    /// Test for correlated scalar subquery not equal
643    #[test]
644    fn scalar_subquery_where_not_eq() -> Result<()> {
645        let sq = Arc::new(
646            LogicalPlanBuilder::from(scan_tpch_table("orders"))
647                .filter(
648                    out_ref_col(DataType::Int64, "customer.c_custkey")
649                        .not_eq(col("orders.o_custkey")),
650                )?
651                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
652                .project(vec![max(col("orders.o_custkey"))])?
653                .build()?,
654        );
655
656        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
657            .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
658            .project(vec![col("customer.c_custkey")])?
659            .build()?;
660
661        // Unsupported predicate, subquery should not be decorrelated
662        assert_optimized_plan_equal!(
663            plan,
664            @r"
665        Projection: customer.c_custkey [c_custkey:Int64]
666          Filter: customer.c_custkey = (<subquery>) [c_custkey:Int64, c_name:Utf8]
667            Subquery: [max(orders.o_custkey):Int64;N]
668              Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]
669                Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]
670                  Filter: outer_ref(customer.c_custkey) != orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
671                    TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
672            TableScan: customer [c_custkey:Int64, c_name:Utf8]
673        "
674        )
675    }
676
677    /// Test for correlated scalar subquery less than
678    #[test]
679    fn scalar_subquery_where_less_than() -> Result<()> {
680        let sq = Arc::new(
681            LogicalPlanBuilder::from(scan_tpch_table("orders"))
682                .filter(
683                    out_ref_col(DataType::Int64, "customer.c_custkey")
684                        .lt(col("orders.o_custkey")),
685                )?
686                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
687                .project(vec![max(col("orders.o_custkey"))])?
688                .build()?,
689        );
690
691        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
692            .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
693            .project(vec![col("customer.c_custkey")])?
694            .build()?;
695
696        // Unsupported predicate, subquery should not be decorrelated
697        assert_optimized_plan_equal!(
698            plan,
699            @r"
700        Projection: customer.c_custkey [c_custkey:Int64]
701          Filter: customer.c_custkey = (<subquery>) [c_custkey:Int64, c_name:Utf8]
702            Subquery: [max(orders.o_custkey):Int64;N]
703              Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]
704                Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]
705                  Filter: outer_ref(customer.c_custkey) < orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
706                    TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
707            TableScan: customer [c_custkey:Int64, c_name:Utf8]
708        "
709        )
710    }
711
712    /// Test for correlated scalar subquery filter with subquery disjunction
713    #[test]
714    fn scalar_subquery_with_subquery_disjunction() -> Result<()> {
715        let sq = Arc::new(
716            LogicalPlanBuilder::from(scan_tpch_table("orders"))
717                .filter(
718                    out_ref_col(DataType::Int64, "customer.c_custkey")
719                        .eq(col("orders.o_custkey"))
720                        .or(col("o_orderkey").eq(lit(1))),
721                )?
722                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
723                .project(vec![max(col("orders.o_custkey"))])?
724                .build()?,
725        );
726
727        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
728            .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
729            .project(vec![col("customer.c_custkey")])?
730            .build()?;
731
732        // Unsupported predicate, subquery should not be decorrelated
733        assert_optimized_plan_equal!(
734            plan,
735            @r"
736        Projection: customer.c_custkey [c_custkey:Int64]
737          Filter: customer.c_custkey = (<subquery>) [c_custkey:Int64, c_name:Utf8]
738            Subquery: [max(orders.o_custkey):Int64;N]
739              Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]
740                Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]
741                  Filter: outer_ref(customer.c_custkey) = orders.o_custkey OR orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
742                    TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
743            TableScan: customer [c_custkey:Int64, c_name:Utf8]
744        "
745        )
746    }
747
748    /// Test for correlated scalar without projection
749    #[test]
750    fn scalar_subquery_no_projection() -> Result<()> {
751        let sq = Arc::new(
752            LogicalPlanBuilder::from(scan_tpch_table("orders"))
753                .filter(col("customer.c_custkey").eq(col("orders.o_custkey")))?
754                .build()?,
755        );
756
757        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
758            .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
759            .project(vec![col("customer.c_custkey")])?
760            .build()?;
761
762        let expected = "Error during planning: Scalar subquery should only return one column, but found 4: orders.o_orderkey, orders.o_custkey, orders.o_orderstatus, orders.o_totalprice";
763        assert_analyzer_check_err(vec![], plan, expected);
764        Ok(())
765    }
766
767    /// Test for correlated scalar expressions
768    #[test]
769    fn scalar_subquery_project_expr() -> Result<()> {
770        let sq = Arc::new(
771            LogicalPlanBuilder::from(scan_tpch_table("orders"))
772                .filter(
773                    out_ref_col(DataType::Int64, "customer.c_custkey")
774                        .eq(col("orders.o_custkey")),
775                )?
776                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
777                .project(vec![col("max(orders.o_custkey)").add(lit(1))])?
778                .build()?,
779        );
780
781        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
782            .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
783            .project(vec![col("customer.c_custkey")])?
784            .build()?;
785
786        assert_optimized_plan_equal!(
787            plan,
788            @r"
789        Projection: customer.c_custkey [c_custkey:Int64]
790          Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8]
791            Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) + Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
792              Left Join:  Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
793                TableScan: customer [c_custkey:Int64, c_name:Utf8]
794                SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64, __always_true:Boolean]
795                  Projection: max(orders.o_custkey) + Int32(1), orders.o_custkey, __always_true [max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64, __always_true:Boolean]
796                    Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]
797                      TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
798        "
799        )
800    }
801
802    /// Test for correlated scalar subquery with non-strong project
803    #[test]
804    fn scalar_subquery_with_non_strong_project() -> Result<()> {
805        let case = Expr::Case(expr::Case {
806            expr: None,
807            when_then_expr: vec![(
808                Box::new(col("max(orders.o_totalprice)")),
809                Box::new(lit("a")),
810            )],
811            else_expr: Some(Box::new(lit("b"))),
812        });
813
814        let sq = Arc::new(
815            LogicalPlanBuilder::from(scan_tpch_table("orders"))
816                .filter(
817                    out_ref_col(DataType::Int64, "customer.c_custkey")
818                        .eq(col("orders.o_custkey")),
819                )?
820                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_totalprice"))])?
821                .project(vec![case])?
822                .build()?,
823        );
824
825        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
826            .project(vec![col("customer.c_custkey"), scalar_subquery(sq)])?
827            .build()?;
828
829        assert_optimized_plan_equal!(
830            plan,
831            @r#"
832        Projection: customer.c_custkey, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN CASE WHEN CAST(NULL AS Boolean) THEN Utf8("a") ELSE Utf8("b") END ELSE __scalar_sq_1.CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END END AS CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END [c_custkey:Int64, CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END:Utf8;N]
833          Left Join:  Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END:Utf8;N, o_custkey:Int64;N, __always_true:Boolean;N]
834            TableScan: customer [c_custkey:Int64, c_name:Utf8]
835            SubqueryAlias: __scalar_sq_1 [CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END:Utf8, o_custkey:Int64, __always_true:Boolean]
836              Projection: CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END, orders.o_custkey, __always_true [CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END:Utf8, o_custkey:Int64, __always_true:Boolean]
837                Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_totalprice)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_totalprice):Float64;N]
838                  TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
839        "#
840        )
841    }
842
843    /// Test for correlated scalar subquery multiple projected columns
844    #[test]
845    fn scalar_subquery_multi_col() -> Result<()> {
846        let sq = Arc::new(
847            LogicalPlanBuilder::from(scan_tpch_table("orders"))
848                .filter(col("customer.c_custkey").eq(col("orders.o_custkey")))?
849                .project(vec![col("orders.o_custkey"), col("orders.o_orderkey")])?
850                .build()?,
851        );
852
853        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
854            .filter(
855                col("customer.c_custkey")
856                    .eq(scalar_subquery(sq))
857                    .and(col("c_custkey").eq(lit(1))),
858            )?
859            .project(vec![col("customer.c_custkey")])?
860            .build()?;
861
862        let expected = "Error during planning: Scalar subquery should only return one column, but found 2: orders.o_custkey, orders.o_orderkey";
863        assert_analyzer_check_err(vec![], plan, expected);
864        Ok(())
865    }
866
867    /// Test for correlated scalar subquery filter with additional filters
868    #[test]
869    fn scalar_subquery_additional_filters_with_non_equal_clause() -> Result<()> {
870        let sq = Arc::new(
871            LogicalPlanBuilder::from(scan_tpch_table("orders"))
872                .filter(
873                    out_ref_col(DataType::Int64, "customer.c_custkey")
874                        .eq(col("orders.o_custkey")),
875                )?
876                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
877                .project(vec![max(col("orders.o_custkey"))])?
878                .build()?,
879        );
880
881        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
882            .filter(
883                col("customer.c_custkey")
884                    .gt_eq(scalar_subquery(sq))
885                    .and(col("c_custkey").eq(lit(1))),
886            )?
887            .project(vec![col("customer.c_custkey")])?
888            .build()?;
889
890        assert_optimized_plan_equal!(
891            plan,
892            @r"
893        Projection: customer.c_custkey [c_custkey:Int64]
894          Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8]
895            Filter: customer.c_custkey >= __scalar_sq_1.max(orders.o_custkey) AND customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
896              Left Join:  Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
897                TableScan: customer [c_custkey:Int64, c_name:Utf8]
898                SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
899                  Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
900                    Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]
901                      TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
902        "
903        )
904    }
905
906    #[test]
907    fn scalar_subquery_additional_filters_with_equal_clause() -> Result<()> {
908        let sq = Arc::new(
909            LogicalPlanBuilder::from(scan_tpch_table("orders"))
910                .filter(
911                    out_ref_col(DataType::Int64, "customer.c_custkey")
912                        .eq(col("orders.o_custkey")),
913                )?
914                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
915                .project(vec![max(col("orders.o_custkey"))])?
916                .build()?,
917        );
918
919        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
920            .filter(
921                col("customer.c_custkey")
922                    .eq(scalar_subquery(sq))
923                    .and(col("c_custkey").eq(lit(1))),
924            )?
925            .project(vec![col("customer.c_custkey")])?
926            .build()?;
927
928        assert_optimized_plan_equal!(
929            plan,
930            @r"
931        Projection: customer.c_custkey [c_custkey:Int64]
932          Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8]
933            Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) AND customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
934              Left Join:  Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
935                TableScan: customer [c_custkey:Int64, c_name:Utf8]
936                SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
937                  Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
938                    Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]
939                      TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
940        "
941        )
942    }
943
944    /// Test for correlated scalar subquery filter with disjunctions
945    #[test]
946    fn scalar_subquery_disjunction() -> Result<()> {
947        let sq = Arc::new(
948            LogicalPlanBuilder::from(scan_tpch_table("orders"))
949                .filter(
950                    out_ref_col(DataType::Int64, "customer.c_custkey")
951                        .eq(col("orders.o_custkey")),
952                )?
953                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
954                .project(vec![max(col("orders.o_custkey"))])?
955                .build()?,
956        );
957
958        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
959            .filter(
960                col("customer.c_custkey")
961                    .eq(scalar_subquery(sq))
962                    .or(col("customer.c_custkey").eq(lit(1))),
963            )?
964            .project(vec![col("customer.c_custkey")])?
965            .build()?;
966
967        assert_optimized_plan_equal!(
968            plan,
969            @r"
970        Projection: customer.c_custkey [c_custkey:Int64]
971          Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8]
972            Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) OR customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
973              Left Join:  Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
974                TableScan: customer [c_custkey:Int64, c_name:Utf8]
975                SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
976                  Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
977                    Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]
978                      TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
979        "
980        )
981    }
982
983    /// Test for correlated scalar subquery filter
984    #[test]
985    fn exists_subquery_correlated() -> Result<()> {
986        let sq = Arc::new(
987            LogicalPlanBuilder::from(test_table_scan_with_name("sq")?)
988                .filter(out_ref_col(DataType::UInt32, "test.a").eq(col("sq.a")))?
989                .aggregate(Vec::<Expr>::new(), vec![min(col("c"))])?
990                .project(vec![min(col("c"))])?
991                .build()?,
992        );
993
994        let plan = LogicalPlanBuilder::from(test_table_scan_with_name("test")?)
995            .filter(col("test.c").lt(scalar_subquery(sq)))?
996            .project(vec![col("test.c")])?
997            .build()?;
998
999        assert_optimized_plan_equal!(
1000            plan,
1001            @r"
1002        Projection: test.c [c:UInt32]
1003          Projection: test.a, test.b, test.c [a:UInt32, b:UInt32, c:UInt32]
1004            Filter: test.c < __scalar_sq_1.min(sq.c) [a:UInt32, b:UInt32, c:UInt32, min(sq.c):UInt32;N, a:UInt32;N, __always_true:Boolean;N]
1005              Left Join:  Filter: test.a = __scalar_sq_1.a [a:UInt32, b:UInt32, c:UInt32, min(sq.c):UInt32;N, a:UInt32;N, __always_true:Boolean;N]
1006                TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1007                SubqueryAlias: __scalar_sq_1 [min(sq.c):UInt32;N, a:UInt32, __always_true:Boolean]
1008                  Projection: min(sq.c), sq.a, __always_true [min(sq.c):UInt32;N, a:UInt32, __always_true:Boolean]
1009                    Aggregate: groupBy=[[sq.a, Boolean(true) AS __always_true]], aggr=[[min(sq.c)]] [a:UInt32, __always_true:Boolean, min(sq.c):UInt32;N]
1010                      TableScan: sq [a:UInt32, b:UInt32, c:UInt32]
1011        "
1012        )
1013    }
1014
1015    /// Test for non-correlated scalar subquery with no filters
1016    #[test]
1017    fn scalar_subquery_non_correlated_no_filters_with_non_equal_clause() -> Result<()> {
1018        let sq = Arc::new(
1019            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1020                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
1021                .project(vec![max(col("orders.o_custkey"))])?
1022                .build()?,
1023        );
1024
1025        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
1026            .filter(col("customer.c_custkey").lt(scalar_subquery(sq)))?
1027            .project(vec![col("customer.c_custkey")])?
1028            .build()?;
1029
1030        assert_optimized_plan_equal!(
1031            plan,
1032            @r"
1033        Projection: customer.c_custkey [c_custkey:Int64]
1034          Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8]
1035            Filter: customer.c_custkey < __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]
1036              Left Join:  Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]
1037                TableScan: customer [c_custkey:Int64, c_name:Utf8]
1038                SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N]
1039                  Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]
1040                    Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]
1041                      TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1042        "
1043        )
1044    }
1045
1046    #[test]
1047    fn scalar_subquery_non_correlated_no_filters_with_equal_clause() -> Result<()> {
1048        let sq = Arc::new(
1049            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1050                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
1051                .project(vec![max(col("orders.o_custkey"))])?
1052                .build()?,
1053        );
1054
1055        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
1056            .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
1057            .project(vec![col("customer.c_custkey")])?
1058            .build()?;
1059
1060        assert_optimized_plan_equal!(
1061            plan,
1062            @r"
1063        Projection: customer.c_custkey [c_custkey:Int64]
1064          Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8]
1065            Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]
1066              Left Join:  Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]
1067                TableScan: customer [c_custkey:Int64, c_name:Utf8]
1068                SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N]
1069                  Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]
1070                    Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]
1071                      TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1072        "
1073        )
1074    }
1075
1076    #[test]
1077    fn correlated_scalar_subquery_in_between_clause() -> Result<()> {
1078        let sq1 = Arc::new(
1079            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1080                .filter(
1081                    out_ref_col(DataType::Int64, "customer.c_custkey")
1082                        .eq(col("orders.o_custkey")),
1083                )?
1084                .aggregate(Vec::<Expr>::new(), vec![min(col("orders.o_custkey"))])?
1085                .project(vec![min(col("orders.o_custkey"))])?
1086                .build()?,
1087        );
1088        let sq2 = Arc::new(
1089            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1090                .filter(
1091                    out_ref_col(DataType::Int64, "customer.c_custkey")
1092                        .eq(col("orders.o_custkey")),
1093                )?
1094                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
1095                .project(vec![max(col("orders.o_custkey"))])?
1096                .build()?,
1097        );
1098
1099        let between_expr = Expr::Between(Between {
1100            expr: Box::new(col("customer.c_custkey")),
1101            negated: false,
1102            low: Box::new(scalar_subquery(sq1)),
1103            high: Box::new(scalar_subquery(sq2)),
1104        });
1105
1106        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
1107            .filter(between_expr)?
1108            .project(vec![col("customer.c_custkey")])?
1109            .build()?;
1110
1111        assert_optimized_plan_equal!(
1112            plan,
1113            @r"
1114        Projection: customer.c_custkey [c_custkey:Int64]
1115          Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8]
1116            Filter: customer.c_custkey BETWEEN __scalar_sq_1.min(orders.o_custkey) AND __scalar_sq_2.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
1117              Left Join:  Filter: customer.c_custkey = __scalar_sq_2.o_custkey [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
1118                Left Join:  Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
1119                  TableScan: customer [c_custkey:Int64, c_name:Utf8]
1120                  SubqueryAlias: __scalar_sq_1 [min(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
1121                    Projection: min(orders.o_custkey), orders.o_custkey, __always_true [min(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
1122                      Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[min(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, min(orders.o_custkey):Int64;N]
1123                        TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1124                SubqueryAlias: __scalar_sq_2 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
1125                  Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
1126                    Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]
1127                      TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1128        "
1129        )
1130    }
1131
1132    #[test]
1133    fn uncorrelated_scalar_subquery_in_between_clause() -> Result<()> {
1134        let sq1 = Arc::new(
1135            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1136                .aggregate(Vec::<Expr>::new(), vec![min(col("orders.o_custkey"))])?
1137                .project(vec![min(col("orders.o_custkey"))])?
1138                .build()?,
1139        );
1140        let sq2 = Arc::new(
1141            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1142                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
1143                .project(vec![max(col("orders.o_custkey"))])?
1144                .build()?,
1145        );
1146
1147        let between_expr = Expr::Between(Between {
1148            expr: Box::new(col("customer.c_custkey")),
1149            negated: false,
1150            low: Box::new(scalar_subquery(sq1)),
1151            high: Box::new(scalar_subquery(sq2)),
1152        });
1153
1154        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
1155            .filter(between_expr)?
1156            .project(vec![col("customer.c_custkey")])?
1157            .build()?;
1158
1159        assert_optimized_plan_equal!(
1160            plan,
1161            @r"
1162        Projection: customer.c_custkey [c_custkey:Int64]
1163          Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8]
1164            Filter: customer.c_custkey BETWEEN __scalar_sq_1.min(orders.o_custkey) AND __scalar_sq_2.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, max(orders.o_custkey):Int64;N]
1165              Left Join:  Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, max(orders.o_custkey):Int64;N]
1166                Left Join:  Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N]
1167                  TableScan: customer [c_custkey:Int64, c_name:Utf8]
1168                  SubqueryAlias: __scalar_sq_1 [min(orders.o_custkey):Int64;N]
1169                    Projection: min(orders.o_custkey) [min(orders.o_custkey):Int64;N]
1170                      Aggregate: groupBy=[[]], aggr=[[min(orders.o_custkey)]] [min(orders.o_custkey):Int64;N]
1171                        TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1172                SubqueryAlias: __scalar_sq_2 [max(orders.o_custkey):Int64;N]
1173                  Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]
1174                    Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]
1175                      TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1176        "
1177        )
1178    }
1179}