datafusion_optimizer/
scalar_subquery_to_join.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`ScalarSubqueryToJoin`] rewriting scalar subquery filters to `JOIN`s
19
20use std::collections::{BTreeSet, HashMap};
21use std::sync::Arc;
22
23use crate::decorrelate::{PullUpCorrelatedExpr, UN_MATCHED_ROW_INDICATOR};
24use crate::optimizer::ApplyOrder;
25use crate::utils::{evaluates_to_null, replace_qualified_name};
26use crate::{OptimizerConfig, OptimizerRule};
27
28use crate::analyzer::type_coercion::TypeCoercionRewriter;
29use datafusion_common::alias::AliasGenerator;
30use datafusion_common::tree_node::{
31    Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter,
32};
33use datafusion_common::{internal_err, plan_err, Column, Result, ScalarValue};
34use datafusion_expr::expr_rewriter::create_col_from_scalar_expr;
35use datafusion_expr::logical_plan::{JoinType, Subquery};
36use datafusion_expr::utils::conjunction;
37use datafusion_expr::{expr, EmptyRelation, Expr, LogicalPlan, LogicalPlanBuilder};
38
/// Optimizer rule that rewrites scalar subqueries appearing in `Filter`
/// predicates and `Projection` expressions into joins against the node's input
#[derive(Default, Debug)]
pub struct ScalarSubqueryToJoin {}
42
43impl ScalarSubqueryToJoin {
44    #[allow(missing_docs)]
45    pub fn new() -> Self {
46        Self::default()
47    }
48
49    /// Finds expressions that have a scalar subquery in them (and recurses when found)
50    ///
51    /// # Arguments
52    /// * `predicate` - A conjunction to split and search
53    ///
54    /// Returns a tuple (subqueries, alias)
55    fn extract_subquery_exprs(
56        &self,
57        predicate: &Expr,
58        alias_gen: &Arc<AliasGenerator>,
59    ) -> Result<(Vec<(Subquery, String)>, Expr)> {
60        let mut extract = ExtractScalarSubQuery {
61            sub_query_info: vec![],
62            alias_gen,
63        };
64        predicate
65            .clone()
66            .rewrite(&mut extract)
67            .data()
68            .map(|new_expr| (extract.sub_query_info, new_expr))
69    }
70}
71
72impl OptimizerRule for ScalarSubqueryToJoin {
73    fn supports_rewrite(&self) -> bool {
74        true
75    }
76
77    fn rewrite(
78        &self,
79        plan: LogicalPlan,
80        config: &dyn OptimizerConfig,
81    ) -> Result<Transformed<LogicalPlan>> {
82        match plan {
83            LogicalPlan::Filter(filter) => {
84                // Optimization: skip the rest of the rule and its copies if
85                // there are no scalar subqueries
86                if !contains_scalar_subquery(&filter.predicate) {
87                    return Ok(Transformed::no(LogicalPlan::Filter(filter)));
88                }
89
90                let (subqueries, mut rewrite_expr) = self.extract_subquery_exprs(
91                    &filter.predicate,
92                    config.alias_generator(),
93                )?;
94
95                if subqueries.is_empty() {
96                    return internal_err!("Expected subqueries not found in filter");
97                }
98
99                // iterate through all subqueries in predicate, turning each into a left join
100                let mut cur_input = filter.input.as_ref().clone();
101                for (subquery, alias) in subqueries {
102                    if let Some((optimized_subquery, expr_check_map)) =
103                        build_join(&subquery, &cur_input, &alias)?
104                    {
105                        if !expr_check_map.is_empty() {
106                            rewrite_expr = rewrite_expr
107                                .transform_up(|expr| {
108                                    // replace column references with entry in map, if it exists
109                                    if let Some(map_expr) = expr
110                                        .try_as_col()
111                                        .and_then(|col| expr_check_map.get(&col.name))
112                                    {
113                                        Ok(Transformed::yes(map_expr.clone()))
114                                    } else {
115                                        Ok(Transformed::no(expr))
116                                    }
117                                })
118                                .data()?;
119                        }
120                        cur_input = optimized_subquery;
121                    } else {
122                        // if we can't handle all of the subqueries then bail for now
123                        return Ok(Transformed::no(LogicalPlan::Filter(filter)));
124                    }
125                }
126                let new_plan = LogicalPlanBuilder::from(cur_input)
127                    .filter(rewrite_expr)?
128                    .build()?;
129                Ok(Transformed::yes(new_plan))
130            }
131            LogicalPlan::Projection(projection) => {
132                // Optimization: skip the rest of the rule and its copies if
133                // there are no scalar subqueries
134                if !projection.expr.iter().any(contains_scalar_subquery) {
135                    return Ok(Transformed::no(LogicalPlan::Projection(projection)));
136                }
137
138                let mut all_subqueries = vec![];
139                let mut expr_to_rewrite_expr_map = HashMap::new();
140                let mut subquery_to_expr_map = HashMap::new();
141                for expr in projection.expr.iter() {
142                    let (subqueries, rewrite_exprs) =
143                        self.extract_subquery_exprs(expr, config.alias_generator())?;
144                    for (subquery, _) in &subqueries {
145                        subquery_to_expr_map.insert(subquery.clone(), expr.clone());
146                    }
147                    all_subqueries.extend(subqueries);
148                    expr_to_rewrite_expr_map.insert(expr, rewrite_exprs);
149                }
150                if all_subqueries.is_empty() {
151                    return internal_err!("Expected subqueries not found in projection");
152                }
153                // iterate through all subqueries in predicate, turning each into a left join
154                let mut cur_input = projection.input.as_ref().clone();
155                for (subquery, alias) in all_subqueries {
156                    if let Some((optimized_subquery, expr_check_map)) =
157                        build_join(&subquery, &cur_input, &alias)?
158                    {
159                        cur_input = optimized_subquery;
160                        if !expr_check_map.is_empty() {
161                            if let Some(expr) = subquery_to_expr_map.get(&subquery) {
162                                if let Some(rewrite_expr) =
163                                    expr_to_rewrite_expr_map.get(expr)
164                                {
165                                    let new_expr = rewrite_expr
166                                        .clone()
167                                        .transform_up(|expr| {
168                                            // replace column references with entry in map, if it exists
169                                            if let Some(map_expr) =
170                                                expr.try_as_col().and_then(|col| {
171                                                    expr_check_map.get(&col.name)
172                                                })
173                                            {
174                                                Ok(Transformed::yes(map_expr.clone()))
175                                            } else {
176                                                Ok(Transformed::no(expr))
177                                            }
178                                        })
179                                        .data()?;
180                                    expr_to_rewrite_expr_map.insert(expr, new_expr);
181                                }
182                            }
183                        }
184                    } else {
185                        // if we can't handle all of the subqueries then bail for now
186                        return Ok(Transformed::no(LogicalPlan::Projection(projection)));
187                    }
188                }
189
190                let mut proj_exprs = vec![];
191                for expr in projection.expr.iter() {
192                    let old_expr_name = expr.schema_name().to_string();
193                    let new_expr = expr_to_rewrite_expr_map.get(expr).unwrap();
194                    let new_expr_name = new_expr.schema_name().to_string();
195                    if new_expr_name != old_expr_name {
196                        proj_exprs.push(new_expr.clone().alias(old_expr_name))
197                    } else {
198                        proj_exprs.push(new_expr.clone());
199                    }
200                }
201                let new_plan = LogicalPlanBuilder::from(cur_input)
202                    .project(proj_exprs)?
203                    .build()?;
204                Ok(Transformed::yes(new_plan))
205            }
206
207            plan => Ok(Transformed::no(plan)),
208        }
209    }
210
211    fn name(&self) -> &str {
212        "scalar_subquery_to_join"
213    }
214
215    fn apply_order(&self) -> Option<ApplyOrder> {
216        Some(ApplyOrder::TopDown)
217    }
218}
219
220/// Returns true if the expression has a scalar subquery somewhere in it
221/// false otherwise
222fn contains_scalar_subquery(expr: &Expr) -> bool {
223    expr.exists(|expr| Ok(matches!(expr, Expr::ScalarSubquery(_))))
224        .expect("Inner is always Ok")
225}
226
/// Expression rewriter that replaces each scalar subquery with a column
/// reference and records the subquery together with the alias generated for it
struct ExtractScalarSubQuery<'a> {
    /// Subqueries found so far, each paired with its generated alias
    sub_query_info: Vec<(Subquery, String)>,
    /// Generator used to produce the `__scalar_sq` aliases
    alias_gen: &'a Arc<AliasGenerator>,
}
231
232impl TreeNodeRewriter for ExtractScalarSubQuery<'_> {
233    type Node = Expr;
234
235    fn f_down(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
236        match expr {
237            Expr::ScalarSubquery(subquery) => {
238                let subqry_alias = self.alias_gen.next("__scalar_sq");
239                self.sub_query_info
240                    .push((subquery.clone(), subqry_alias.clone()));
241                let scalar_expr = subquery
242                    .subquery
243                    .head_output_expr()?
244                    .map_or(plan_err!("single expression required."), Ok)?;
245                Ok(Transformed::new(
246                    Expr::Column(create_col_from_scalar_expr(
247                        &scalar_expr,
248                        subqry_alias,
249                    )?),
250                    true,
251                    TreeNodeRecursion::Jump,
252                ))
253            }
254            _ => Ok(Transformed::no(expr)),
255        }
256    }
257}
258
259/// Takes a query like:
260///
261/// ```text
262/// select id from customers where balance >
263///     (select avg(total) from orders where orders.c_id = customers.id)
264/// ```
265///
266/// and optimizes it into:
267///
268/// ```text
269/// select c.id from customers c
270/// left join (select c_id, avg(total) as val from orders group by c_id) o on o.c_id = c.c_id
271/// where c.balance > o.val
272/// ```
273///
274/// Or a query like:
275///
276/// ```text
277/// select id from customers where balance >
278///     (select avg(total) from orders)
279/// ```
280///
281/// and optimizes it into:
282///
283/// ```text
284/// select c.id from customers c
285/// left join (select avg(total) as val from orders) a
286/// where c.balance > a.val
287/// ```
288///
289/// # Arguments
290///
291/// * `query_info` - The subquery portion of the `where` (select avg(total) from orders)
292/// * `filter_input` - The non-subquery portion (from customers)
293/// * `outer_others` - Any additional parts to the `where` expression (and c.x = y)
294/// * `subquery_alias` - Subquery aliases
295fn build_join(
296    subquery: &Subquery,
297    filter_input: &LogicalPlan,
298    subquery_alias: &str,
299) -> Result<Option<(LogicalPlan, HashMap<String, Expr>)>> {
300    let subquery_plan = subquery.subquery.as_ref();
301    let mut pull_up = PullUpCorrelatedExpr::new().with_need_handle_count_bug(true);
302    let new_plan = subquery_plan.clone().rewrite(&mut pull_up).data()?;
303    if !pull_up.can_pull_up {
304        return Ok(None);
305    }
306
307    let collected_count_expr_map =
308        pull_up.collected_count_expr_map.get(&new_plan).cloned();
309    let sub_query_alias = LogicalPlanBuilder::from(new_plan)
310        .alias(subquery_alias.to_string())?
311        .build()?;
312
313    let mut all_correlated_cols = BTreeSet::new();
314    pull_up
315        .correlated_subquery_cols_map
316        .values()
317        .for_each(|cols| all_correlated_cols.extend(cols.clone()));
318
319    // alias the join filter
320    let join_filter_opt =
321        conjunction(pull_up.join_filters).map_or(Ok(None), |filter| {
322            replace_qualified_name(filter, &all_correlated_cols, subquery_alias).map(Some)
323        })?;
324
325    // join our sub query into the main plan
326    let new_plan = if join_filter_opt.is_none() {
327        match filter_input {
328            LogicalPlan::EmptyRelation(EmptyRelation {
329                produce_one_row: true,
330                schema: _,
331            }) => sub_query_alias,
332            _ => {
333                // if not correlated, group down to 1 row and left join on that (preserving row count)
334                LogicalPlanBuilder::from(filter_input.clone())
335                    .join_on(
336                        sub_query_alias,
337                        JoinType::Left,
338                        vec![Expr::Literal(ScalarValue::Boolean(Some(true)))],
339                    )?
340                    .build()?
341            }
342        }
343    } else {
344        // left join if correlated, grouping by the join keys so we don't change row count
345        LogicalPlanBuilder::from(filter_input.clone())
346            .join_on(sub_query_alias, JoinType::Left, join_filter_opt)?
347            .build()?
348    };
349    let mut computation_project_expr = HashMap::new();
350    if let Some(expr_map) = collected_count_expr_map {
351        for (name, result) in expr_map {
352            if evaluates_to_null(result.clone(), result.column_refs())? {
353                // If expr always returns null when column is null, skip processing
354                continue;
355            }
356            let computer_expr = if let Some(filter) = &pull_up.pull_up_having_expr {
357                Expr::Case(expr::Case {
358                    expr: None,
359                    when_then_expr: vec![
360                        (
361                            Box::new(Expr::IsNull(Box::new(Expr::Column(
362                                Column::new_unqualified(UN_MATCHED_ROW_INDICATOR),
363                            )))),
364                            Box::new(result),
365                        ),
366                        (
367                            Box::new(Expr::Not(Box::new(filter.clone()))),
368                            Box::new(Expr::Literal(ScalarValue::Null)),
369                        ),
370                    ],
371                    else_expr: Some(Box::new(Expr::Column(Column::new_unqualified(
372                        name.clone(),
373                    )))),
374                })
375            } else {
376                Expr::Case(expr::Case {
377                    expr: None,
378                    when_then_expr: vec![(
379                        Box::new(Expr::IsNull(Box::new(Expr::Column(
380                            Column::new_unqualified(UN_MATCHED_ROW_INDICATOR),
381                        )))),
382                        Box::new(result),
383                    )],
384                    else_expr: Some(Box::new(Expr::Column(Column::new_unqualified(
385                        name.clone(),
386                    )))),
387                })
388            };
389            let mut expr_rewrite = TypeCoercionRewriter {
390                schema: new_plan.schema(),
391            };
392            computation_project_expr
393                .insert(name, computer_expr.rewrite(&mut expr_rewrite).data()?);
394        }
395    }
396
397    Ok(Some((new_plan, computation_project_expr)))
398}
399
400#[cfg(test)]
401mod tests {
402    use std::ops::Add;
403
404    use super::*;
405    use crate::test::*;
406
407    use arrow::datatypes::DataType;
408    use datafusion_expr::test::function_stub::sum;
409
410    use datafusion_expr::{col, lit, out_ref_col, scalar_subquery, Between};
411    use datafusion_functions_aggregate::min_max::{max, min};
412
413    /// Test multiple correlated subqueries
414    #[test]
415    fn multiple_subqueries() -> Result<()> {
416        let orders = Arc::new(
417            LogicalPlanBuilder::from(scan_tpch_table("orders"))
418                .filter(
419                    col("orders.o_custkey")
420                        .eq(out_ref_col(DataType::Int64, "customer.c_custkey")),
421                )?
422                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
423                .project(vec![max(col("orders.o_custkey"))])?
424                .build()?,
425        );
426
427        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
428            .filter(
429                lit(1)
430                    .lt(scalar_subquery(Arc::clone(&orders)))
431                    .and(lit(1).lt(scalar_subquery(orders))),
432            )?
433            .project(vec![col("customer.c_custkey")])?
434            .build()?;
435
436        let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
437            \n  Filter: Int32(1) < __scalar_sq_1.max(orders.o_custkey) AND Int32(1) < __scalar_sq_2.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\
438            \n    Left Join:  Filter: __scalar_sq_2.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\
439            \n      Left Join:  Filter: __scalar_sq_1.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\
440            \n        TableScan: customer [c_custkey:Int64, c_name:Utf8]\
441            \n        SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]\
442            \n          Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]\
443            \n            Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]\
444            \n              TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
445            \n      SubqueryAlias: __scalar_sq_2 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]\
446            \n        Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]\
447            \n          Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]\
448            \n            TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
449        assert_multi_rules_optimized_plan_eq_display_indent(
450            vec![Arc::new(ScalarSubqueryToJoin::new())],
451            plan,
452            expected,
453        );
454        Ok(())
455    }
456
457    /// Test recursive correlated subqueries
458    #[test]
459    fn recursive_subqueries() -> Result<()> {
460        let lineitem = Arc::new(
461            LogicalPlanBuilder::from(scan_tpch_table("lineitem"))
462                .filter(
463                    col("lineitem.l_orderkey")
464                        .eq(out_ref_col(DataType::Int64, "orders.o_orderkey")),
465                )?
466                .aggregate(
467                    Vec::<Expr>::new(),
468                    vec![sum(col("lineitem.l_extendedprice"))],
469                )?
470                .project(vec![sum(col("lineitem.l_extendedprice"))])?
471                .build()?,
472        );
473
474        let orders = Arc::new(
475            LogicalPlanBuilder::from(scan_tpch_table("orders"))
476                .filter(
477                    col("orders.o_custkey")
478                        .eq(out_ref_col(DataType::Int64, "customer.c_custkey"))
479                        .and(col("orders.o_totalprice").lt(scalar_subquery(lineitem))),
480                )?
481                .aggregate(Vec::<Expr>::new(), vec![sum(col("orders.o_totalprice"))])?
482                .project(vec![sum(col("orders.o_totalprice"))])?
483                .build()?,
484        );
485
486        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
487            .filter(col("customer.c_acctbal").lt(scalar_subquery(orders)))?
488            .project(vec![col("customer.c_custkey")])?
489            .build()?;
490
491        let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
492            \n  Filter: customer.c_acctbal < __scalar_sq_1.sum(orders.o_totalprice) [c_custkey:Int64, c_name:Utf8, sum(orders.o_totalprice):Float64;N, o_custkey:Int64;N, __always_true:Boolean;N]\
493            \n    Left Join:  Filter: __scalar_sq_1.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8, sum(orders.o_totalprice):Float64;N, o_custkey:Int64;N, __always_true:Boolean;N]\
494            \n      TableScan: customer [c_custkey:Int64, c_name:Utf8]\
495            \n      SubqueryAlias: __scalar_sq_1 [sum(orders.o_totalprice):Float64;N, o_custkey:Int64, __always_true:Boolean]\
496            \n        Projection: sum(orders.o_totalprice), orders.o_custkey, __always_true [sum(orders.o_totalprice):Float64;N, o_custkey:Int64, __always_true:Boolean]\
497            \n          Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[sum(orders.o_totalprice)]] [o_custkey:Int64, __always_true:Boolean, sum(orders.o_totalprice):Float64;N]\
498            \n            Filter: orders.o_totalprice < __scalar_sq_2.sum(lineitem.l_extendedprice) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N, sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64;N, __always_true:Boolean;N]\
499            \n              Left Join:  Filter: __scalar_sq_2.l_orderkey = orders.o_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N, sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64;N, __always_true:Boolean;N]\
500            \n                TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
501            \n                SubqueryAlias: __scalar_sq_2 [sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64, __always_true:Boolean]\
502            \n                  Projection: sum(lineitem.l_extendedprice), lineitem.l_orderkey, __always_true [sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64, __always_true:Boolean]\
503            \n                    Aggregate: groupBy=[[lineitem.l_orderkey, Boolean(true) AS __always_true]], aggr=[[sum(lineitem.l_extendedprice)]] [l_orderkey:Int64, __always_true:Boolean, sum(lineitem.l_extendedprice):Float64;N]\
504            \n                      TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]";
505        assert_multi_rules_optimized_plan_eq_display_indent(
506            vec![Arc::new(ScalarSubqueryToJoin::new())],
507            plan,
508            expected,
509        );
510        Ok(())
511    }
512
513    /// Test for correlated scalar subquery filter with additional subquery filters
514    #[test]
515    fn scalar_subquery_with_subquery_filters() -> Result<()> {
516        let sq = Arc::new(
517            LogicalPlanBuilder::from(scan_tpch_table("orders"))
518                .filter(
519                    out_ref_col(DataType::Int64, "customer.c_custkey")
520                        .eq(col("orders.o_custkey"))
521                        .and(col("o_orderkey").eq(lit(1))),
522                )?
523                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
524                .project(vec![max(col("orders.o_custkey"))])?
525                .build()?,
526        );
527
528        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
529            .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
530            .project(vec![col("customer.c_custkey")])?
531            .build()?;
532
533        let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
534            \n  Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\
535            \n    Left Join:  Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\
536            \n      TableScan: customer [c_custkey:Int64, c_name:Utf8]\
537            \n      SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]\
538            \n        Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]\
539            \n          Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]\
540            \n            Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
541            \n              TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
542
543        assert_multi_rules_optimized_plan_eq_display_indent(
544            vec![Arc::new(ScalarSubqueryToJoin::new())],
545            plan,
546            expected,
547        );
548        Ok(())
549    }
550
551    /// Test for correlated scalar subquery with no columns in schema
552    #[test]
553    fn scalar_subquery_no_cols() -> Result<()> {
554        let sq = Arc::new(
555            LogicalPlanBuilder::from(scan_tpch_table("orders"))
556                .filter(
557                    out_ref_col(DataType::Int64, "customer.c_custkey")
558                        .eq(out_ref_col(DataType::Int64, "customer.c_custkey")),
559                )?
560                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
561                .project(vec![max(col("orders.o_custkey"))])?
562                .build()?,
563        );
564
565        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
566            .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
567            .project(vec![col("customer.c_custkey")])?
568            .build()?;
569
570        // it will optimize, but fail for the same reason the unoptimized query would
571        let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
572        \n  Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\
573        \n    Left Join:  Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\
574        \n      TableScan: customer [c_custkey:Int64, c_name:Utf8]\
575        \n      SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N]\
576        \n        Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\
577        \n          Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\
578        \n            TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
579        assert_multi_rules_optimized_plan_eq_display_indent(
580            vec![Arc::new(ScalarSubqueryToJoin::new())],
581            plan,
582            expected,
583        );
584        Ok(())
585    }
586
587    /// Test for scalar subquery with both columns in schema
588    #[test]
589    fn scalar_subquery_with_no_correlated_cols() -> Result<()> {
590        let sq = Arc::new(
591            LogicalPlanBuilder::from(scan_tpch_table("orders"))
592                .filter(col("orders.o_custkey").eq(col("orders.o_custkey")))?
593                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
594                .project(vec![max(col("orders.o_custkey"))])?
595                .build()?,
596        );
597
598        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
599            .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
600            .project(vec![col("customer.c_custkey")])?
601            .build()?;
602
603        let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
604        \n  Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\
605        \n    Left Join:  Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\
606        \n      TableScan: customer [c_custkey:Int64, c_name:Utf8]\
607        \n      SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N]\
608        \n        Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\
609        \n          Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\
610        \n            Filter: orders.o_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
611        \n              TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
612
613        assert_multi_rules_optimized_plan_eq_display_indent(
614            vec![Arc::new(ScalarSubqueryToJoin::new())],
615            plan,
616            expected,
617        );
618        Ok(())
619    }
620
621    /// Test for correlated scalar subquery not equal
622    #[test]
623    fn scalar_subquery_where_not_eq() -> Result<()> {
624        let sq = Arc::new(
625            LogicalPlanBuilder::from(scan_tpch_table("orders"))
626                .filter(
627                    out_ref_col(DataType::Int64, "customer.c_custkey")
628                        .not_eq(col("orders.o_custkey")),
629                )?
630                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
631                .project(vec![max(col("orders.o_custkey"))])?
632                .build()?,
633        );
634
635        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
636            .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
637            .project(vec![col("customer.c_custkey")])?
638            .build()?;
639
640        // Unsupported predicate, subquery should not be decorrelated
641        let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
642            \n  Filter: customer.c_custkey = (<subquery>) [c_custkey:Int64, c_name:Utf8]\
643            \n    Subquery: [max(orders.o_custkey):Int64;N]\
644            \n      Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\
645            \n        Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\
646            \n          Filter: outer_ref(customer.c_custkey) != orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
647            \n            TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
648            \n    TableScan: customer [c_custkey:Int64, c_name:Utf8]";
649
650        assert_multi_rules_optimized_plan_eq_display_indent(
651            vec![Arc::new(ScalarSubqueryToJoin::new())],
652            plan,
653            expected,
654        );
655        Ok(())
656    }
657
658    /// Test for correlated scalar subquery less than
659    #[test]
660    fn scalar_subquery_where_less_than() -> Result<()> {
661        let sq = Arc::new(
662            LogicalPlanBuilder::from(scan_tpch_table("orders"))
663                .filter(
664                    out_ref_col(DataType::Int64, "customer.c_custkey")
665                        .lt(col("orders.o_custkey")),
666                )?
667                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
668                .project(vec![max(col("orders.o_custkey"))])?
669                .build()?,
670        );
671
672        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
673            .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
674            .project(vec![col("customer.c_custkey")])?
675            .build()?;
676
677        // Unsupported predicate, subquery should not be decorrelated
678        let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
679            \n  Filter: customer.c_custkey = (<subquery>) [c_custkey:Int64, c_name:Utf8]\
680            \n    Subquery: [max(orders.o_custkey):Int64;N]\
681            \n      Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\
682            \n        Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\
683            \n          Filter: outer_ref(customer.c_custkey) < orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
684            \n            TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
685            \n    TableScan: customer [c_custkey:Int64, c_name:Utf8]";
686
687        assert_multi_rules_optimized_plan_eq_display_indent(
688            vec![Arc::new(ScalarSubqueryToJoin::new())],
689            plan,
690            expected,
691        );
692        Ok(())
693    }
694
695    /// Test for correlated scalar subquery filter with subquery disjunction
696    #[test]
697    fn scalar_subquery_with_subquery_disjunction() -> Result<()> {
698        let sq = Arc::new(
699            LogicalPlanBuilder::from(scan_tpch_table("orders"))
700                .filter(
701                    out_ref_col(DataType::Int64, "customer.c_custkey")
702                        .eq(col("orders.o_custkey"))
703                        .or(col("o_orderkey").eq(lit(1))),
704                )?
705                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
706                .project(vec![max(col("orders.o_custkey"))])?
707                .build()?,
708        );
709
710        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
711            .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
712            .project(vec![col("customer.c_custkey")])?
713            .build()?;
714
715        // Unsupported predicate, subquery should not be decorrelated
716        let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
717            \n  Filter: customer.c_custkey = (<subquery>) [c_custkey:Int64, c_name:Utf8]\
718            \n    Subquery: [max(orders.o_custkey):Int64;N]\
719            \n      Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\
720            \n        Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\
721            \n          Filter: outer_ref(customer.c_custkey) = orders.o_custkey OR orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
722            \n            TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
723            \n    TableScan: customer [c_custkey:Int64, c_name:Utf8]";
724
725        assert_multi_rules_optimized_plan_eq_display_indent(
726            vec![Arc::new(ScalarSubqueryToJoin::new())],
727            plan,
728            expected,
729        );
730        Ok(())
731    }
732
733    /// Test for correlated scalar without projection
734    #[test]
735    fn scalar_subquery_no_projection() -> Result<()> {
736        let sq = Arc::new(
737            LogicalPlanBuilder::from(scan_tpch_table("orders"))
738                .filter(col("customer.c_custkey").eq(col("orders.o_custkey")))?
739                .build()?,
740        );
741
742        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
743            .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
744            .project(vec![col("customer.c_custkey")])?
745            .build()?;
746
747        let expected = "Error during planning: Scalar subquery should only return one column, but found 4: orders.o_orderkey, orders.o_custkey, orders.o_orderstatus, orders.o_totalprice";
748        assert_analyzer_check_err(vec![], plan, expected);
749        Ok(())
750    }
751
752    /// Test for correlated scalar expressions
753    #[test]
754    fn scalar_subquery_project_expr() -> Result<()> {
755        let sq = Arc::new(
756            LogicalPlanBuilder::from(scan_tpch_table("orders"))
757                .filter(
758                    out_ref_col(DataType::Int64, "customer.c_custkey")
759                        .eq(col("orders.o_custkey")),
760                )?
761                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
762                .project(vec![col("max(orders.o_custkey)").add(lit(1))])?
763                .build()?,
764        );
765
766        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
767            .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
768            .project(vec![col("customer.c_custkey")])?
769            .build()?;
770
771        let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
772            \n  Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) + Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\
773            \n    Left Join:  Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\
774            \n      TableScan: customer [c_custkey:Int64, c_name:Utf8]\
775            \n      SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64, __always_true:Boolean]\
776            \n        Projection: max(orders.o_custkey) + Int32(1), orders.o_custkey, __always_true [max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64, __always_true:Boolean]\
777            \n          Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]\
778            \n            TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
779
780        assert_multi_rules_optimized_plan_eq_display_indent(
781            vec![Arc::new(ScalarSubqueryToJoin::new())],
782            plan,
783            expected,
784        );
785        Ok(())
786    }
787
788    /// Test for correlated scalar subquery with non-strong project
789    #[test]
790    fn scalar_subquery_with_non_strong_project() -> Result<()> {
791        let case = Expr::Case(expr::Case {
792            expr: None,
793            when_then_expr: vec![(
794                Box::new(col("max(orders.o_totalprice)")),
795                Box::new(lit("a")),
796            )],
797            else_expr: Some(Box::new(lit("b"))),
798        });
799
800        let sq = Arc::new(
801            LogicalPlanBuilder::from(scan_tpch_table("orders"))
802                .filter(
803                    out_ref_col(DataType::Int64, "customer.c_custkey")
804                        .eq(col("orders.o_custkey")),
805                )?
806                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_totalprice"))])?
807                .project(vec![case])?
808                .build()?,
809        );
810
811        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
812            .project(vec![col("customer.c_custkey"), scalar_subquery(sq)])?
813            .build()?;
814
815        let expected = "Projection: customer.c_custkey, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN CASE WHEN CAST(NULL AS Boolean) THEN Utf8(\"a\") ELSE Utf8(\"b\") END ELSE __scalar_sq_1.CASE WHEN max(orders.o_totalprice) THEN Utf8(\"a\") ELSE Utf8(\"b\") END END AS CASE WHEN max(orders.o_totalprice) THEN Utf8(\"a\") ELSE Utf8(\"b\") END [c_custkey:Int64, CASE WHEN max(orders.o_totalprice) THEN Utf8(\"a\") ELSE Utf8(\"b\") END:Utf8;N]\
816            \n  Left Join:  Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, CASE WHEN max(orders.o_totalprice) THEN Utf8(\"a\") ELSE Utf8(\"b\") END:Utf8;N, o_custkey:Int64;N, __always_true:Boolean;N]\
817            \n    TableScan: customer [c_custkey:Int64, c_name:Utf8]\
818            \n    SubqueryAlias: __scalar_sq_1 [CASE WHEN max(orders.o_totalprice) THEN Utf8(\"a\") ELSE Utf8(\"b\") END:Utf8, o_custkey:Int64, __always_true:Boolean]\
819            \n      Projection: CASE WHEN max(orders.o_totalprice) THEN Utf8(\"a\") ELSE Utf8(\"b\") END, orders.o_custkey, __always_true [CASE WHEN max(orders.o_totalprice) THEN Utf8(\"a\") ELSE Utf8(\"b\") END:Utf8, o_custkey:Int64, __always_true:Boolean]\
820            \n        Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_totalprice)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_totalprice):Float64;N]\
821            \n          TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
822
823        assert_multi_rules_optimized_plan_eq_display_indent(
824            vec![Arc::new(ScalarSubqueryToJoin::new())],
825            plan,
826            expected,
827        );
828        Ok(())
829    }
830
831    /// Test for correlated scalar subquery multiple projected columns
832    #[test]
833    fn scalar_subquery_multi_col() -> Result<()> {
834        let sq = Arc::new(
835            LogicalPlanBuilder::from(scan_tpch_table("orders"))
836                .filter(col("customer.c_custkey").eq(col("orders.o_custkey")))?
837                .project(vec![col("orders.o_custkey"), col("orders.o_orderkey")])?
838                .build()?,
839        );
840
841        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
842            .filter(
843                col("customer.c_custkey")
844                    .eq(scalar_subquery(sq))
845                    .and(col("c_custkey").eq(lit(1))),
846            )?
847            .project(vec![col("customer.c_custkey")])?
848            .build()?;
849
850        let expected = "Error during planning: Scalar subquery should only return one column, but found 2: orders.o_custkey, orders.o_orderkey";
851        assert_analyzer_check_err(vec![], plan, expected);
852        Ok(())
853    }
854
855    /// Test for correlated scalar subquery filter with additional filters
856    #[test]
857    fn scalar_subquery_additional_filters_with_non_equal_clause() -> Result<()> {
858        let sq = Arc::new(
859            LogicalPlanBuilder::from(scan_tpch_table("orders"))
860                .filter(
861                    out_ref_col(DataType::Int64, "customer.c_custkey")
862                        .eq(col("orders.o_custkey")),
863                )?
864                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
865                .project(vec![max(col("orders.o_custkey"))])?
866                .build()?,
867        );
868
869        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
870            .filter(
871                col("customer.c_custkey")
872                    .gt_eq(scalar_subquery(sq))
873                    .and(col("c_custkey").eq(lit(1))),
874            )?
875            .project(vec![col("customer.c_custkey")])?
876            .build()?;
877
878        let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
879            \n  Filter: customer.c_custkey >= __scalar_sq_1.max(orders.o_custkey) AND customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\
880            \n    Left Join:  Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\
881            \n      TableScan: customer [c_custkey:Int64, c_name:Utf8]\
882            \n      SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]\
883            \n        Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]\
884            \n          Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]\
885            \n            TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
886
887        assert_multi_rules_optimized_plan_eq_display_indent(
888            vec![Arc::new(ScalarSubqueryToJoin::new())],
889            plan,
890            expected,
891        );
892        Ok(())
893    }
894
895    #[test]
896    fn scalar_subquery_additional_filters_with_equal_clause() -> Result<()> {
897        let sq = Arc::new(
898            LogicalPlanBuilder::from(scan_tpch_table("orders"))
899                .filter(
900                    out_ref_col(DataType::Int64, "customer.c_custkey")
901                        .eq(col("orders.o_custkey")),
902                )?
903                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
904                .project(vec![max(col("orders.o_custkey"))])?
905                .build()?,
906        );
907
908        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
909            .filter(
910                col("customer.c_custkey")
911                    .eq(scalar_subquery(sq))
912                    .and(col("c_custkey").eq(lit(1))),
913            )?
914            .project(vec![col("customer.c_custkey")])?
915            .build()?;
916
917        let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
918            \n  Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) AND customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\
919            \n    Left Join:  Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\
920            \n      TableScan: customer [c_custkey:Int64, c_name:Utf8]\
921            \n      SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]\
922            \n        Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]\
923            \n          Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]\
924            \n            TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
925
926        assert_multi_rules_optimized_plan_eq_display_indent(
927            vec![Arc::new(ScalarSubqueryToJoin::new())],
928            plan,
929            expected,
930        );
931        Ok(())
932    }
933
934    /// Test for correlated scalar subquery filter with disjunctions
935    #[test]
936    fn scalar_subquery_disjunction() -> Result<()> {
937        let sq = Arc::new(
938            LogicalPlanBuilder::from(scan_tpch_table("orders"))
939                .filter(
940                    out_ref_col(DataType::Int64, "customer.c_custkey")
941                        .eq(col("orders.o_custkey")),
942                )?
943                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
944                .project(vec![max(col("orders.o_custkey"))])?
945                .build()?,
946        );
947
948        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
949            .filter(
950                col("customer.c_custkey")
951                    .eq(scalar_subquery(sq))
952                    .or(col("customer.c_custkey").eq(lit(1))),
953            )?
954            .project(vec![col("customer.c_custkey")])?
955            .build()?;
956
957        let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
958            \n  Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) OR customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\
959            \n    Left Join:  Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\
960            \n      TableScan: customer [c_custkey:Int64, c_name:Utf8]\
961            \n      SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]\
962            \n        Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]\
963            \n          Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]\
964            \n            TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
965
966        assert_multi_rules_optimized_plan_eq_display_indent(
967            vec![Arc::new(ScalarSubqueryToJoin::new())],
968            plan,
969            expected,
970        );
971        Ok(())
972    }
973
974    /// Test for correlated scalar subquery filter
975    #[test]
976    fn exists_subquery_correlated() -> Result<()> {
977        let sq = Arc::new(
978            LogicalPlanBuilder::from(test_table_scan_with_name("sq")?)
979                .filter(out_ref_col(DataType::UInt32, "test.a").eq(col("sq.a")))?
980                .aggregate(Vec::<Expr>::new(), vec![min(col("c"))])?
981                .project(vec![min(col("c"))])?
982                .build()?,
983        );
984
985        let plan = LogicalPlanBuilder::from(test_table_scan_with_name("test")?)
986            .filter(col("test.c").lt(scalar_subquery(sq)))?
987            .project(vec![col("test.c")])?
988            .build()?;
989
990        let expected = "Projection: test.c [c:UInt32]\
991            \n  Filter: test.c < __scalar_sq_1.min(sq.c) [a:UInt32, b:UInt32, c:UInt32, min(sq.c):UInt32;N, a:UInt32;N, __always_true:Boolean;N]\
992            \n    Left Join:  Filter: test.a = __scalar_sq_1.a [a:UInt32, b:UInt32, c:UInt32, min(sq.c):UInt32;N, a:UInt32;N, __always_true:Boolean;N]\
993            \n      TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
994            \n      SubqueryAlias: __scalar_sq_1 [min(sq.c):UInt32;N, a:UInt32, __always_true:Boolean]\
995            \n        Projection: min(sq.c), sq.a, __always_true [min(sq.c):UInt32;N, a:UInt32, __always_true:Boolean]\
996            \n          Aggregate: groupBy=[[sq.a, Boolean(true) AS __always_true]], aggr=[[min(sq.c)]] [a:UInt32, __always_true:Boolean, min(sq.c):UInt32;N]\
997            \n            TableScan: sq [a:UInt32, b:UInt32, c:UInt32]";
998
999        assert_multi_rules_optimized_plan_eq_display_indent(
1000            vec![Arc::new(ScalarSubqueryToJoin::new())],
1001            plan,
1002            expected,
1003        );
1004        Ok(())
1005    }
1006
1007    /// Test for non-correlated scalar subquery with no filters
1008    #[test]
1009    fn scalar_subquery_non_correlated_no_filters_with_non_equal_clause() -> Result<()> {
1010        let sq = Arc::new(
1011            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1012                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
1013                .project(vec![max(col("orders.o_custkey"))])?
1014                .build()?,
1015        );
1016
1017        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
1018            .filter(col("customer.c_custkey").lt(scalar_subquery(sq)))?
1019            .project(vec![col("customer.c_custkey")])?
1020            .build()?;
1021
1022        let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
1023        \n  Filter: customer.c_custkey < __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\
1024        \n    Left Join:  Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\
1025        \n      TableScan: customer [c_custkey:Int64, c_name:Utf8]\
1026        \n      SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N]\
1027        \n        Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\
1028        \n          Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\
1029        \n            TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
1030
1031        assert_multi_rules_optimized_plan_eq_display_indent(
1032            vec![Arc::new(ScalarSubqueryToJoin::new())],
1033            plan,
1034            expected,
1035        );
1036        Ok(())
1037    }
1038
1039    #[test]
1040    fn scalar_subquery_non_correlated_no_filters_with_equal_clause() -> Result<()> {
1041        let sq = Arc::new(
1042            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1043                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
1044                .project(vec![max(col("orders.o_custkey"))])?
1045                .build()?,
1046        );
1047
1048        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
1049            .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
1050            .project(vec![col("customer.c_custkey")])?
1051            .build()?;
1052
1053        let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
1054        \n  Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\
1055        \n    Left Join:  Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\
1056        \n      TableScan: customer [c_custkey:Int64, c_name:Utf8]\
1057        \n      SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N]\
1058        \n        Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\
1059        \n          Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\
1060        \n            TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
1061
1062        assert_multi_rules_optimized_plan_eq_display_indent(
1063            vec![Arc::new(ScalarSubqueryToJoin::new())],
1064            plan,
1065            expected,
1066        );
1067        Ok(())
1068    }
1069
1070    #[test]
1071    fn correlated_scalar_subquery_in_between_clause() -> Result<()> {
1072        let sq1 = Arc::new(
1073            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1074                .filter(
1075                    out_ref_col(DataType::Int64, "customer.c_custkey")
1076                        .eq(col("orders.o_custkey")),
1077                )?
1078                .aggregate(Vec::<Expr>::new(), vec![min(col("orders.o_custkey"))])?
1079                .project(vec![min(col("orders.o_custkey"))])?
1080                .build()?,
1081        );
1082        let sq2 = Arc::new(
1083            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1084                .filter(
1085                    out_ref_col(DataType::Int64, "customer.c_custkey")
1086                        .eq(col("orders.o_custkey")),
1087                )?
1088                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
1089                .project(vec![max(col("orders.o_custkey"))])?
1090                .build()?,
1091        );
1092
1093        let between_expr = Expr::Between(Between {
1094            expr: Box::new(col("customer.c_custkey")),
1095            negated: false,
1096            low: Box::new(scalar_subquery(sq1)),
1097            high: Box::new(scalar_subquery(sq2)),
1098        });
1099
1100        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
1101            .filter(between_expr)?
1102            .project(vec![col("customer.c_custkey")])?
1103            .build()?;
1104
1105        let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
1106            \n  Filter: customer.c_custkey BETWEEN __scalar_sq_1.min(orders.o_custkey) AND __scalar_sq_2.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\
1107            \n    Left Join:  Filter: customer.c_custkey = __scalar_sq_2.o_custkey [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\
1108            \n      Left Join:  Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\
1109            \n        TableScan: customer [c_custkey:Int64, c_name:Utf8]\
1110            \n        SubqueryAlias: __scalar_sq_1 [min(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]\
1111            \n          Projection: min(orders.o_custkey), orders.o_custkey, __always_true [min(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]\
1112            \n            Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[min(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, min(orders.o_custkey):Int64;N]\
1113            \n              TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
1114            \n      SubqueryAlias: __scalar_sq_2 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]\
1115            \n        Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]\
1116            \n          Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]\
1117            \n            TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
1118
1119        assert_multi_rules_optimized_plan_eq_display_indent(
1120            vec![Arc::new(ScalarSubqueryToJoin::new())],
1121            plan,
1122            expected,
1123        );
1124        Ok(())
1125    }
1126
1127    #[test]
1128    fn uncorrelated_scalar_subquery_in_between_clause() -> Result<()> {
1129        let sq1 = Arc::new(
1130            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1131                .aggregate(Vec::<Expr>::new(), vec![min(col("orders.o_custkey"))])?
1132                .project(vec![min(col("orders.o_custkey"))])?
1133                .build()?,
1134        );
1135        let sq2 = Arc::new(
1136            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1137                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
1138                .project(vec![max(col("orders.o_custkey"))])?
1139                .build()?,
1140        );
1141
1142        let between_expr = Expr::Between(Between {
1143            expr: Box::new(col("customer.c_custkey")),
1144            negated: false,
1145            low: Box::new(scalar_subquery(sq1)),
1146            high: Box::new(scalar_subquery(sq2)),
1147        });
1148
1149        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
1150            .filter(between_expr)?
1151            .project(vec![col("customer.c_custkey")])?
1152            .build()?;
1153
1154        let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
1155        \n  Filter: customer.c_custkey BETWEEN __scalar_sq_1.min(orders.o_custkey) AND __scalar_sq_2.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, max(orders.o_custkey):Int64;N]\
1156        \n    Left Join:  Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, max(orders.o_custkey):Int64;N]\
1157        \n      Left Join:  Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N]\
1158        \n        TableScan: customer [c_custkey:Int64, c_name:Utf8]\
1159        \n        SubqueryAlias: __scalar_sq_1 [min(orders.o_custkey):Int64;N]\
1160        \n          Projection: min(orders.o_custkey) [min(orders.o_custkey):Int64;N]\
1161        \n            Aggregate: groupBy=[[]], aggr=[[min(orders.o_custkey)]] [min(orders.o_custkey):Int64;N]\
1162        \n              TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
1163        \n      SubqueryAlias: __scalar_sq_2 [max(orders.o_custkey):Int64;N]\
1164        \n        Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\
1165        \n          Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\
1166        \n            TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
1167
1168        assert_multi_rules_optimized_plan_eq_display_indent(
1169            vec![Arc::new(ScalarSubqueryToJoin::new())],
1170            plan,
1171            expected,
1172        );
1173        Ok(())
1174    }
1175}