datafusion_optimizer/
scalar_subquery_to_join.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`ScalarSubqueryToJoin`] rewriting scalar subquery filters to `JOIN`s
19
20use std::collections::{BTreeSet, HashMap};
21use std::sync::Arc;
22
23use crate::decorrelate::{PullUpCorrelatedExpr, UN_MATCHED_ROW_INDICATOR};
24use crate::optimizer::ApplyOrder;
25use crate::utils::replace_qualified_name;
26use crate::{OptimizerConfig, OptimizerRule};
27
28use datafusion_common::alias::AliasGenerator;
29use datafusion_common::tree_node::{
30    Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter,
31};
32use datafusion_common::{internal_err, plan_err, Column, Result, ScalarValue};
33use datafusion_expr::expr_rewriter::create_col_from_scalar_expr;
34use datafusion_expr::logical_plan::{JoinType, Subquery};
35use datafusion_expr::utils::conjunction;
36use datafusion_expr::{expr, EmptyRelation, Expr, LogicalPlan, LogicalPlanBuilder};
37
/// Optimizer rule that rewrites scalar subqueries found in `Filter`
/// predicates and `Projection` expressions into joins.
#[derive(Debug, Default)]
pub struct ScalarSubqueryToJoin {}
41
42impl ScalarSubqueryToJoin {
43    #[allow(missing_docs)]
44    pub fn new() -> Self {
45        Self::default()
46    }
47
48    /// Finds expressions that have a scalar subquery in them (and recurses when found)
49    ///
50    /// # Arguments
51    /// * `predicate` - A conjunction to split and search
52    ///
53    /// Returns a tuple (subqueries, alias)
54    fn extract_subquery_exprs(
55        &self,
56        predicate: &Expr,
57        alias_gen: &Arc<AliasGenerator>,
58    ) -> Result<(Vec<(Subquery, String)>, Expr)> {
59        let mut extract = ExtractScalarSubQuery {
60            sub_query_info: vec![],
61            alias_gen,
62        };
63        predicate
64            .clone()
65            .rewrite(&mut extract)
66            .data()
67            .map(|new_expr| (extract.sub_query_info, new_expr))
68    }
69}
70
71impl OptimizerRule for ScalarSubqueryToJoin {
72    fn supports_rewrite(&self) -> bool {
73        true
74    }
75
76    fn rewrite(
77        &self,
78        plan: LogicalPlan,
79        config: &dyn OptimizerConfig,
80    ) -> Result<Transformed<LogicalPlan>> {
81        match plan {
82            LogicalPlan::Filter(filter) => {
83                // Optimization: skip the rest of the rule and its copies if
84                // there are no scalar subqueries
85                if !contains_scalar_subquery(&filter.predicate) {
86                    return Ok(Transformed::no(LogicalPlan::Filter(filter)));
87                }
88
89                let (subqueries, mut rewrite_expr) = self.extract_subquery_exprs(
90                    &filter.predicate,
91                    config.alias_generator(),
92                )?;
93
94                if subqueries.is_empty() {
95                    return internal_err!("Expected subqueries not found in filter");
96                }
97
98                // iterate through all subqueries in predicate, turning each into a left join
99                let mut cur_input = filter.input.as_ref().clone();
100                for (subquery, alias) in subqueries {
101                    if let Some((optimized_subquery, expr_check_map)) =
102                        build_join(&subquery, &cur_input, &alias)?
103                    {
104                        if !expr_check_map.is_empty() {
105                            rewrite_expr = rewrite_expr
106                                .transform_up(|expr| {
107                                    // replace column references with entry in map, if it exists
108                                    if let Some(map_expr) = expr
109                                        .try_as_col()
110                                        .and_then(|col| expr_check_map.get(&col.name))
111                                    {
112                                        Ok(Transformed::yes(map_expr.clone()))
113                                    } else {
114                                        Ok(Transformed::no(expr))
115                                    }
116                                })
117                                .data()?;
118                        }
119                        cur_input = optimized_subquery;
120                    } else {
121                        // if we can't handle all of the subqueries then bail for now
122                        return Ok(Transformed::no(LogicalPlan::Filter(filter)));
123                    }
124                }
125                let new_plan = LogicalPlanBuilder::from(cur_input)
126                    .filter(rewrite_expr)?
127                    .build()?;
128                Ok(Transformed::yes(new_plan))
129            }
130            LogicalPlan::Projection(projection) => {
131                // Optimization: skip the rest of the rule and its copies if
132                // there are no scalar subqueries
133                if !projection.expr.iter().any(contains_scalar_subquery) {
134                    return Ok(Transformed::no(LogicalPlan::Projection(projection)));
135                }
136
137                let mut all_subqueries = vec![];
138                let mut expr_to_rewrite_expr_map = HashMap::new();
139                let mut subquery_to_expr_map = HashMap::new();
140                for expr in projection.expr.iter() {
141                    let (subqueries, rewrite_exprs) =
142                        self.extract_subquery_exprs(expr, config.alias_generator())?;
143                    for (subquery, _) in &subqueries {
144                        subquery_to_expr_map.insert(subquery.clone(), expr.clone());
145                    }
146                    all_subqueries.extend(subqueries);
147                    expr_to_rewrite_expr_map.insert(expr, rewrite_exprs);
148                }
149                if all_subqueries.is_empty() {
150                    return internal_err!("Expected subqueries not found in projection");
151                }
152                // iterate through all subqueries in predicate, turning each into a left join
153                let mut cur_input = projection.input.as_ref().clone();
154                for (subquery, alias) in all_subqueries {
155                    if let Some((optimized_subquery, expr_check_map)) =
156                        build_join(&subquery, &cur_input, &alias)?
157                    {
158                        cur_input = optimized_subquery;
159                        if !expr_check_map.is_empty() {
160                            if let Some(expr) = subquery_to_expr_map.get(&subquery) {
161                                if let Some(rewrite_expr) =
162                                    expr_to_rewrite_expr_map.get(expr)
163                                {
164                                    let new_expr = rewrite_expr
165                                        .clone()
166                                        .transform_up(|expr| {
167                                            // replace column references with entry in map, if it exists
168                                            if let Some(map_expr) =
169                                                expr.try_as_col().and_then(|col| {
170                                                    expr_check_map.get(&col.name)
171                                                })
172                                            {
173                                                Ok(Transformed::yes(map_expr.clone()))
174                                            } else {
175                                                Ok(Transformed::no(expr))
176                                            }
177                                        })
178                                        .data()?;
179                                    expr_to_rewrite_expr_map.insert(expr, new_expr);
180                                }
181                            }
182                        }
183                    } else {
184                        // if we can't handle all of the subqueries then bail for now
185                        return Ok(Transformed::no(LogicalPlan::Projection(projection)));
186                    }
187                }
188
189                let mut proj_exprs = vec![];
190                for expr in projection.expr.iter() {
191                    let old_expr_name = expr.schema_name().to_string();
192                    let new_expr = expr_to_rewrite_expr_map.get(expr).unwrap();
193                    let new_expr_name = new_expr.schema_name().to_string();
194                    if new_expr_name != old_expr_name {
195                        proj_exprs.push(new_expr.clone().alias(old_expr_name))
196                    } else {
197                        proj_exprs.push(new_expr.clone());
198                    }
199                }
200                let new_plan = LogicalPlanBuilder::from(cur_input)
201                    .project(proj_exprs)?
202                    .build()?;
203                Ok(Transformed::yes(new_plan))
204            }
205
206            plan => Ok(Transformed::no(plan)),
207        }
208    }
209
210    fn name(&self) -> &str {
211        "scalar_subquery_to_join"
212    }
213
214    fn apply_order(&self) -> Option<ApplyOrder> {
215        Some(ApplyOrder::TopDown)
216    }
217}
218
219/// Returns true if the expression has a scalar subquery somewhere in it
220/// false otherwise
221fn contains_scalar_subquery(expr: &Expr) -> bool {
222    expr.exists(|expr| Ok(matches!(expr, Expr::ScalarSubquery(_))))
223        .expect("Inner is always Ok")
224}
225
226struct ExtractScalarSubQuery<'a> {
227    sub_query_info: Vec<(Subquery, String)>,
228    alias_gen: &'a Arc<AliasGenerator>,
229}
230
231impl TreeNodeRewriter for ExtractScalarSubQuery<'_> {
232    type Node = Expr;
233
234    fn f_down(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
235        match expr {
236            Expr::ScalarSubquery(subquery) => {
237                let subqry_alias = self.alias_gen.next("__scalar_sq");
238                self.sub_query_info
239                    .push((subquery.clone(), subqry_alias.clone()));
240                let scalar_expr = subquery
241                    .subquery
242                    .head_output_expr()?
243                    .map_or(plan_err!("single expression required."), Ok)?;
244                Ok(Transformed::new(
245                    Expr::Column(create_col_from_scalar_expr(
246                        &scalar_expr,
247                        subqry_alias,
248                    )?),
249                    true,
250                    TreeNodeRecursion::Jump,
251                ))
252            }
253            _ => Ok(Transformed::no(expr)),
254        }
255    }
256}
257
258/// Takes a query like:
259///
260/// ```text
261/// select id from customers where balance >
262///     (select avg(total) from orders where orders.c_id = customers.id)
263/// ```
264///
265/// and optimizes it into:
266///
267/// ```text
268/// select c.id from customers c
269/// left join (select c_id, avg(total) as val from orders group by c_id) o on o.c_id = c.c_id
270/// where c.balance > o.val
271/// ```
272///
273/// Or a query like:
274///
275/// ```text
276/// select id from customers where balance >
277///     (select avg(total) from orders)
278/// ```
279///
280/// and optimizes it into:
281///
282/// ```text
283/// select c.id from customers c
284/// left join (select avg(total) as val from orders) a
285/// where c.balance > a.val
286/// ```
287///
288/// # Arguments
289///
290/// * `query_info` - The subquery portion of the `where` (select avg(total) from orders)
291/// * `filter_input` - The non-subquery portion (from customers)
292/// * `outer_others` - Any additional parts to the `where` expression (and c.x = y)
293/// * `subquery_alias` - Subquery aliases
294fn build_join(
295    subquery: &Subquery,
296    filter_input: &LogicalPlan,
297    subquery_alias: &str,
298) -> Result<Option<(LogicalPlan, HashMap<String, Expr>)>> {
299    let subquery_plan = subquery.subquery.as_ref();
300    let mut pull_up = PullUpCorrelatedExpr::new().with_need_handle_count_bug(true);
301    let new_plan = subquery_plan.clone().rewrite(&mut pull_up).data()?;
302    if !pull_up.can_pull_up {
303        return Ok(None);
304    }
305
306    let collected_count_expr_map =
307        pull_up.collected_count_expr_map.get(&new_plan).cloned();
308    let sub_query_alias = LogicalPlanBuilder::from(new_plan)
309        .alias(subquery_alias.to_string())?
310        .build()?;
311
312    let mut all_correlated_cols = BTreeSet::new();
313    pull_up
314        .correlated_subquery_cols_map
315        .values()
316        .for_each(|cols| all_correlated_cols.extend(cols.clone()));
317
318    // alias the join filter
319    let join_filter_opt =
320        conjunction(pull_up.join_filters).map_or(Ok(None), |filter| {
321            replace_qualified_name(filter, &all_correlated_cols, subquery_alias).map(Some)
322        })?;
323
324    // join our sub query into the main plan
325    let new_plan = if join_filter_opt.is_none() {
326        match filter_input {
327            LogicalPlan::EmptyRelation(EmptyRelation {
328                produce_one_row: true,
329                schema: _,
330            }) => sub_query_alias,
331            _ => {
332                // if not correlated, group down to 1 row and left join on that (preserving row count)
333                LogicalPlanBuilder::from(filter_input.clone())
334                    .join_on(sub_query_alias, JoinType::Left, None)?
335                    .build()?
336            }
337        }
338    } else {
339        // left join if correlated, grouping by the join keys so we don't change row count
340        LogicalPlanBuilder::from(filter_input.clone())
341            .join_on(sub_query_alias, JoinType::Left, join_filter_opt)?
342            .build()?
343    };
344    let mut computation_project_expr = HashMap::new();
345    if let Some(expr_map) = collected_count_expr_map {
346        for (name, result) in expr_map {
347            let computer_expr = if let Some(filter) = &pull_up.pull_up_having_expr {
348                Expr::Case(expr::Case {
349                    expr: None,
350                    when_then_expr: vec![
351                        (
352                            Box::new(Expr::IsNull(Box::new(Expr::Column(
353                                Column::new_unqualified(UN_MATCHED_ROW_INDICATOR),
354                            )))),
355                            Box::new(result),
356                        ),
357                        (
358                            Box::new(Expr::Not(Box::new(filter.clone()))),
359                            Box::new(Expr::Literal(ScalarValue::Null)),
360                        ),
361                    ],
362                    else_expr: Some(Box::new(Expr::Column(Column::new_unqualified(
363                        name.clone(),
364                    )))),
365                })
366            } else {
367                Expr::Case(expr::Case {
368                    expr: None,
369                    when_then_expr: vec![(
370                        Box::new(Expr::IsNull(Box::new(Expr::Column(
371                            Column::new_unqualified(UN_MATCHED_ROW_INDICATOR),
372                        )))),
373                        Box::new(result),
374                    )],
375                    else_expr: Some(Box::new(Expr::Column(Column::new_unqualified(
376                        name.clone(),
377                    )))),
378                })
379            };
380            computation_project_expr.insert(name, computer_expr);
381        }
382    }
383
384    Ok(Some((new_plan, computation_project_expr)))
385}
386
387#[cfg(test)]
388mod tests {
389    use std::ops::Add;
390
391    use super::*;
392    use crate::test::*;
393
394    use arrow::datatypes::DataType;
395    use datafusion_expr::test::function_stub::sum;
396
397    use datafusion_expr::{col, lit, out_ref_col, scalar_subquery, Between};
398    use datafusion_functions_aggregate::min_max::{max, min};
399
400    /// Test multiple correlated subqueries
401    #[test]
402    fn multiple_subqueries() -> Result<()> {
403        let orders = Arc::new(
404            LogicalPlanBuilder::from(scan_tpch_table("orders"))
405                .filter(
406                    col("orders.o_custkey")
407                        .eq(out_ref_col(DataType::Int64, "customer.c_custkey")),
408                )?
409                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
410                .project(vec![max(col("orders.o_custkey"))])?
411                .build()?,
412        );
413
414        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
415            .filter(
416                lit(1)
417                    .lt(scalar_subquery(Arc::clone(&orders)))
418                    .and(lit(1).lt(scalar_subquery(orders))),
419            )?
420            .project(vec![col("customer.c_custkey")])?
421            .build()?;
422
423        let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
424        \n  Filter: Int32(1) < __scalar_sq_1.max(orders.o_custkey) AND Int32(1) < __scalar_sq_2.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\
425        \n    Left Join:  Filter: __scalar_sq_2.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\
426        \n      Left Join:  Filter: __scalar_sq_1.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\
427        \n        TableScan: customer [c_custkey:Int64, c_name:Utf8]\
428        \n        SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64]\
429        \n          Projection: max(orders.o_custkey), orders.o_custkey [max(orders.o_custkey):Int64;N, o_custkey:Int64]\
430        \n            Aggregate: groupBy=[[orders.o_custkey]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, max(orders.o_custkey):Int64;N]\
431        \n              TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
432        \n      SubqueryAlias: __scalar_sq_2 [max(orders.o_custkey):Int64;N, o_custkey:Int64]\
433        \n        Projection: max(orders.o_custkey), orders.o_custkey [max(orders.o_custkey):Int64;N, o_custkey:Int64]\
434        \n          Aggregate: groupBy=[[orders.o_custkey]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, max(orders.o_custkey):Int64;N]\
435        \n            TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
436        assert_multi_rules_optimized_plan_eq_display_indent(
437            vec![Arc::new(ScalarSubqueryToJoin::new())],
438            plan,
439            expected,
440        );
441        Ok(())
442    }
443
444    /// Test recursive correlated subqueries
445    #[test]
446    fn recursive_subqueries() -> Result<()> {
447        let lineitem = Arc::new(
448            LogicalPlanBuilder::from(scan_tpch_table("lineitem"))
449                .filter(
450                    col("lineitem.l_orderkey")
451                        .eq(out_ref_col(DataType::Int64, "orders.o_orderkey")),
452                )?
453                .aggregate(
454                    Vec::<Expr>::new(),
455                    vec![sum(col("lineitem.l_extendedprice"))],
456                )?
457                .project(vec![sum(col("lineitem.l_extendedprice"))])?
458                .build()?,
459        );
460
461        let orders = Arc::new(
462            LogicalPlanBuilder::from(scan_tpch_table("orders"))
463                .filter(
464                    col("orders.o_custkey")
465                        .eq(out_ref_col(DataType::Int64, "customer.c_custkey"))
466                        .and(col("orders.o_totalprice").lt(scalar_subquery(lineitem))),
467                )?
468                .aggregate(Vec::<Expr>::new(), vec![sum(col("orders.o_totalprice"))])?
469                .project(vec![sum(col("orders.o_totalprice"))])?
470                .build()?,
471        );
472
473        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
474            .filter(col("customer.c_acctbal").lt(scalar_subquery(orders)))?
475            .project(vec![col("customer.c_custkey")])?
476            .build()?;
477
478        let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
479        \n  Filter: customer.c_acctbal < __scalar_sq_1.sum(orders.o_totalprice) [c_custkey:Int64, c_name:Utf8, sum(orders.o_totalprice):Float64;N, o_custkey:Int64;N]\
480        \n    Left Join:  Filter: __scalar_sq_1.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8, sum(orders.o_totalprice):Float64;N, o_custkey:Int64;N]\
481        \n      TableScan: customer [c_custkey:Int64, c_name:Utf8]\
482        \n      SubqueryAlias: __scalar_sq_1 [sum(orders.o_totalprice):Float64;N, o_custkey:Int64]\
483        \n        Projection: sum(orders.o_totalprice), orders.o_custkey [sum(orders.o_totalprice):Float64;N, o_custkey:Int64]\
484        \n          Aggregate: groupBy=[[orders.o_custkey]], aggr=[[sum(orders.o_totalprice)]] [o_custkey:Int64, sum(orders.o_totalprice):Float64;N]\
485        \n            Filter: orders.o_totalprice < __scalar_sq_2.sum(lineitem.l_extendedprice) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N, sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64;N]\
486        \n              Left Join:  Filter: __scalar_sq_2.l_orderkey = orders.o_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N, sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64;N]\
487        \n                TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
488        \n                SubqueryAlias: __scalar_sq_2 [sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64]\
489        \n                  Projection: sum(lineitem.l_extendedprice), lineitem.l_orderkey [sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64]\
490        \n                    Aggregate: groupBy=[[lineitem.l_orderkey]], aggr=[[sum(lineitem.l_extendedprice)]] [l_orderkey:Int64, sum(lineitem.l_extendedprice):Float64;N]\
491        \n                      TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]";
492        assert_multi_rules_optimized_plan_eq_display_indent(
493            vec![Arc::new(ScalarSubqueryToJoin::new())],
494            plan,
495            expected,
496        );
497        Ok(())
498    }
499
500    /// Test for correlated scalar subquery filter with additional subquery filters
501    #[test]
502    fn scalar_subquery_with_subquery_filters() -> Result<()> {
503        let sq = Arc::new(
504            LogicalPlanBuilder::from(scan_tpch_table("orders"))
505                .filter(
506                    out_ref_col(DataType::Int64, "customer.c_custkey")
507                        .eq(col("orders.o_custkey"))
508                        .and(col("o_orderkey").eq(lit(1))),
509                )?
510                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
511                .project(vec![max(col("orders.o_custkey"))])?
512                .build()?,
513        );
514
515        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
516            .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
517            .project(vec![col("customer.c_custkey")])?
518            .build()?;
519
520        let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
521        \n  Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\
522        \n    Left Join:  Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\
523        \n      TableScan: customer [c_custkey:Int64, c_name:Utf8]\
524        \n      SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64]\
525        \n        Projection: max(orders.o_custkey), orders.o_custkey [max(orders.o_custkey):Int64;N, o_custkey:Int64]\
526        \n          Aggregate: groupBy=[[orders.o_custkey]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, max(orders.o_custkey):Int64;N]\
527        \n            Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
528        \n              TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
529
530        assert_multi_rules_optimized_plan_eq_display_indent(
531            vec![Arc::new(ScalarSubqueryToJoin::new())],
532            plan,
533            expected,
534        );
535        Ok(())
536    }
537
538    /// Test for correlated scalar subquery with no columns in schema
539    #[test]
540    fn scalar_subquery_no_cols() -> Result<()> {
541        let sq = Arc::new(
542            LogicalPlanBuilder::from(scan_tpch_table("orders"))
543                .filter(
544                    out_ref_col(DataType::Int64, "customer.c_custkey")
545                        .eq(out_ref_col(DataType::Int64, "customer.c_custkey")),
546                )?
547                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
548                .project(vec![max(col("orders.o_custkey"))])?
549                .build()?,
550        );
551
552        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
553            .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
554            .project(vec![col("customer.c_custkey")])?
555            .build()?;
556
557        // it will optimize, but fail for the same reason the unoptimized query would
558        let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
559        \n  Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\
560        \n    Left Join:  [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\
561        \n      TableScan: customer [c_custkey:Int64, c_name:Utf8]\
562        \n      SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N]\
563        \n        Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\
564        \n          Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\
565        \n            TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
566        assert_multi_rules_optimized_plan_eq_display_indent(
567            vec![Arc::new(ScalarSubqueryToJoin::new())],
568            plan,
569            expected,
570        );
571        Ok(())
572    }
573
574    /// Test for scalar subquery with both columns in schema
575    #[test]
576    fn scalar_subquery_with_no_correlated_cols() -> Result<()> {
577        let sq = Arc::new(
578            LogicalPlanBuilder::from(scan_tpch_table("orders"))
579                .filter(col("orders.o_custkey").eq(col("orders.o_custkey")))?
580                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
581                .project(vec![max(col("orders.o_custkey"))])?
582                .build()?,
583        );
584
585        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
586            .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
587            .project(vec![col("customer.c_custkey")])?
588            .build()?;
589
590        let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
591        \n  Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\
592        \n    Left Join:  [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\
593        \n      TableScan: customer [c_custkey:Int64, c_name:Utf8]\
594        \n      SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N]\
595        \n        Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\
596        \n          Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\
597        \n            Filter: orders.o_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
598        \n              TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
599
600        assert_multi_rules_optimized_plan_eq_display_indent(
601            vec![Arc::new(ScalarSubqueryToJoin::new())],
602            plan,
603            expected,
604        );
605        Ok(())
606    }
607
608    /// Test for correlated scalar subquery not equal
609    #[test]
610    fn scalar_subquery_where_not_eq() -> Result<()> {
611        let sq = Arc::new(
612            LogicalPlanBuilder::from(scan_tpch_table("orders"))
613                .filter(
614                    out_ref_col(DataType::Int64, "customer.c_custkey")
615                        .not_eq(col("orders.o_custkey")),
616                )?
617                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
618                .project(vec![max(col("orders.o_custkey"))])?
619                .build()?,
620        );
621
622        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
623            .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
624            .project(vec![col("customer.c_custkey")])?
625            .build()?;
626
627        // Unsupported predicate, subquery should not be decorrelated
628        let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
629            \n  Filter: customer.c_custkey = (<subquery>) [c_custkey:Int64, c_name:Utf8]\
630            \n    Subquery: [max(orders.o_custkey):Int64;N]\
631            \n      Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\
632            \n        Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\
633            \n          Filter: outer_ref(customer.c_custkey) != orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
634            \n            TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
635            \n    TableScan: customer [c_custkey:Int64, c_name:Utf8]";
636
637        assert_multi_rules_optimized_plan_eq_display_indent(
638            vec![Arc::new(ScalarSubqueryToJoin::new())],
639            plan,
640            expected,
641        );
642        Ok(())
643    }
644
/// Correlating the scalar subquery through a `<` comparison is not supported:
/// the expected plan still contains the original `Subquery` node.
#[test]
fn scalar_subquery_where_less_than() -> Result<()> {
    // Inner side: max(o_custkey) over the orders rows selected by the `<` correlation.
    let correlation =
        out_ref_col(DataType::Int64, "customer.c_custkey").lt(col("orders.o_custkey"));
    let inner_plan = LogicalPlanBuilder::from(scan_tpch_table("orders"))
        .filter(correlation)?
        .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
        .project(vec![max(col("orders.o_custkey"))])?
        .build()?;

    // Outer side: customers whose key equals the scalar subquery result.
    let key_equals_subquery =
        col("customer.c_custkey").eq(scalar_subquery(Arc::new(inner_plan)));
    let outer_plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
        .filter(key_equals_subquery)?
        .project(vec![col("customer.c_custkey")])?
        .build()?;

    // Unsupported predicate, so the subquery is expected to survive the rule as-is.
    let unchanged_plan = "Projection: customer.c_custkey [c_custkey:Int64]\
        \n  Filter: customer.c_custkey = (<subquery>) [c_custkey:Int64, c_name:Utf8]\
        \n    Subquery: [max(orders.o_custkey):Int64;N]\
        \n      Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\
        \n        Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\
        \n          Filter: outer_ref(customer.c_custkey) < orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
        \n            TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
        \n    TableScan: customer [c_custkey:Int64, c_name:Utf8]";

    assert_multi_rules_optimized_plan_eq_display_indent(
        vec![Arc::new(ScalarSubqueryToJoin::new())],
        outer_plan,
        unchanged_plan,
    );
    Ok(())
}
681
/// A subquery whose correlation predicate is OR-ed with another condition is
/// expected to stay as a `Subquery` node in the optimized plan.
#[test]
fn scalar_subquery_with_subquery_disjunction() -> Result<()> {
    // The correlated equality is only one arm of a disjunction inside the subquery.
    let key_match =
        out_ref_col(DataType::Int64, "customer.c_custkey").eq(col("orders.o_custkey"));
    let order_is_one = col("o_orderkey").eq(lit(1));
    let inner_plan = LogicalPlanBuilder::from(scan_tpch_table("orders"))
        .filter(key_match.or(order_is_one))?
        .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
        .project(vec![max(col("orders.o_custkey"))])?
        .build()?;

    let key_equals_subquery =
        col("customer.c_custkey").eq(scalar_subquery(Arc::new(inner_plan)));
    let outer_plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
        .filter(key_equals_subquery)?
        .project(vec![col("customer.c_custkey")])?
        .build()?;

    // Unsupported predicate, so the subquery is expected to survive the rule as-is.
    let unchanged_plan = "Projection: customer.c_custkey [c_custkey:Int64]\
        \n  Filter: customer.c_custkey = (<subquery>) [c_custkey:Int64, c_name:Utf8]\
        \n    Subquery: [max(orders.o_custkey):Int64;N]\
        \n      Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\
        \n        Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\
        \n          Filter: outer_ref(customer.c_custkey) = orders.o_custkey OR orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
        \n            TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
        \n    TableScan: customer [c_custkey:Int64, c_name:Utf8]";

    assert_multi_rules_optimized_plan_eq_display_indent(
        vec![Arc::new(ScalarSubqueryToJoin::new())],
        outer_plan,
        unchanged_plan,
    );
    Ok(())
}
719
/// A scalar subquery built without any projection: the test expects the
/// analyzer check to fail with the "should only return one column" error.
#[test]
fn scalar_subquery_no_projection() -> Result<()> {
    // The subquery is just a filtered scan of `orders`, with no projection on top.
    let key_match = col("customer.c_custkey").eq(col("orders.o_custkey"));
    let inner_plan = LogicalPlanBuilder::from(scan_tpch_table("orders"))
        .filter(key_match)?
        .build()?;

    let key_equals_subquery =
        col("customer.c_custkey").eq(scalar_subquery(Arc::new(inner_plan)));
    let outer_plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
        .filter(key_equals_subquery)?
        .project(vec![col("customer.c_custkey")])?
        .build()?;

    let expected_error = "Invalid (non-executable) plan after Analyzer\
        \ncaused by\
        \nError during planning: Scalar subquery should only return one column";
    // No optimizer rules are passed; only the analyzer check outcome is asserted.
    assert_analyzer_check_err(vec![], outer_plan, expected_error);
    Ok(())
}
740
/// The subquery projects an expression (`max(...) + 1`) instead of the bare
/// aggregate; the expected plan joins against `__scalar_sq_1` with a left join.
#[test]
fn scalar_subquery_project_expr() -> Result<()> {
    let correlation =
        out_ref_col(DataType::Int64, "customer.c_custkey").eq(col("orders.o_custkey"));
    // The projection refers to the aggregate output by its column name and adds one.
    let max_plus_one = col("max(orders.o_custkey)").add(lit(1));
    let inner_plan = LogicalPlanBuilder::from(scan_tpch_table("orders"))
        .filter(correlation)?
        .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
        .project(vec![max_plus_one])?
        .build()?;

    let key_equals_subquery =
        col("customer.c_custkey").eq(scalar_subquery(Arc::new(inner_plan)));
    let outer_plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
        .filter(key_equals_subquery)?
        .project(vec![col("customer.c_custkey")])?
        .build()?;

    let joined_plan = "Projection: customer.c_custkey [c_custkey:Int64]\
        \n  Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) + Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64;N]\
        \n    Left Join:  Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64;N]\
        \n      TableScan: customer [c_custkey:Int64, c_name:Utf8]\
        \n      SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64]\
        \n        Projection: max(orders.o_custkey) + Int32(1), orders.o_custkey [max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64]\
        \n          Aggregate: groupBy=[[orders.o_custkey]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, max(orders.o_custkey):Int64;N]\
        \n            TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";

    assert_multi_rules_optimized_plan_eq_display_indent(
        vec![Arc::new(ScalarSubqueryToJoin::new())],
        outer_plan,
        joined_plan,
    );
    Ok(())
}
776
/// A scalar subquery projecting two columns: the test expects the analyzer
/// check to fail with the "should only return one column" error.
#[test]
fn scalar_subquery_multi_col() -> Result<()> {
    // Two projected columns make this subquery invalid as a scalar subquery.
    let key_match = col("customer.c_custkey").eq(col("orders.o_custkey"));
    let inner_plan = LogicalPlanBuilder::from(scan_tpch_table("orders"))
        .filter(key_match)?
        .project(vec![col("orders.o_custkey"), col("orders.o_orderkey")])?
        .build()?;

    let key_equals_subquery =
        col("customer.c_custkey").eq(scalar_subquery(Arc::new(inner_plan)));
    let key_is_one = col("c_custkey").eq(lit(1));
    let outer_plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
        .filter(key_equals_subquery.and(key_is_one))?
        .project(vec![col("customer.c_custkey")])?
        .build()?;

    let expected_error = "Invalid (non-executable) plan after Analyzer\
        \ncaused by\
        \nError during planning: Scalar subquery should only return one column";
    // No optimizer rules are passed; only the analyzer check outcome is asserted.
    assert_analyzer_check_err(vec![], outer_plan, expected_error);
    Ok(())
}
802
/// Correlated scalar subquery compared with `>=` and AND-ed with an extra
/// outer filter; the expected plan is a left join against `__scalar_sq_1`.
#[test]
fn scalar_subquery_additional_filters_with_non_equal_clause() -> Result<()> {
    let correlation =
        out_ref_col(DataType::Int64, "customer.c_custkey").eq(col("orders.o_custkey"));
    let max_custkey = max(col("orders.o_custkey"));
    let inner_plan = LogicalPlanBuilder::from(scan_tpch_table("orders"))
        .filter(correlation)?
        .aggregate(Vec::<Expr>::new(), vec![max_custkey.clone()])?
        .project(vec![max_custkey])?
        .build()?;

    // Outer predicate: `c_custkey >= (<subquery>) AND c_custkey = 1`.
    let key_at_least_subquery =
        col("customer.c_custkey").gt_eq(scalar_subquery(Arc::new(inner_plan)));
    let key_is_one = col("c_custkey").eq(lit(1));
    let outer_plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
        .filter(key_at_least_subquery.and(key_is_one))?
        .project(vec![col("customer.c_custkey")])?
        .build()?;

    let joined_plan = "Projection: customer.c_custkey [c_custkey:Int64]\
        \n  Filter: customer.c_custkey >= __scalar_sq_1.max(orders.o_custkey) AND customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\
        \n    Left Join:  Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\
        \n      TableScan: customer [c_custkey:Int64, c_name:Utf8]\
        \n      SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64]\
        \n        Projection: max(orders.o_custkey), orders.o_custkey [max(orders.o_custkey):Int64;N, o_custkey:Int64]\
        \n          Aggregate: groupBy=[[orders.o_custkey]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, max(orders.o_custkey):Int64;N]\
        \n            TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";

    assert_multi_rules_optimized_plan_eq_display_indent(
        vec![Arc::new(ScalarSubqueryToJoin::new())],
        outer_plan,
        joined_plan,
    );
    Ok(())
}
842
/// Correlated scalar subquery compared with `=` and AND-ed with an extra
/// outer filter; the expected plan is a left join against `__scalar_sq_1`.
#[test]
fn scalar_subquery_additional_filters_with_equal_clause() -> Result<()> {
    let correlation =
        out_ref_col(DataType::Int64, "customer.c_custkey").eq(col("orders.o_custkey"));
    let max_custkey = max(col("orders.o_custkey"));
    let inner_plan = LogicalPlanBuilder::from(scan_tpch_table("orders"))
        .filter(correlation)?
        .aggregate(Vec::<Expr>::new(), vec![max_custkey.clone()])?
        .project(vec![max_custkey])?
        .build()?;

    // Outer predicate: `c_custkey = (<subquery>) AND c_custkey = 1`.
    let key_equals_subquery =
        col("customer.c_custkey").eq(scalar_subquery(Arc::new(inner_plan)));
    let key_is_one = col("c_custkey").eq(lit(1));
    let outer_plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
        .filter(key_equals_subquery.and(key_is_one))?
        .project(vec![col("customer.c_custkey")])?
        .build()?;

    let joined_plan = "Projection: customer.c_custkey [c_custkey:Int64]\
        \n  Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) AND customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\
        \n    Left Join:  Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\
        \n      TableScan: customer [c_custkey:Int64, c_name:Utf8]\
        \n      SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64]\
        \n        Projection: max(orders.o_custkey), orders.o_custkey [max(orders.o_custkey):Int64;N, o_custkey:Int64]\
        \n          Aggregate: groupBy=[[orders.o_custkey]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, max(orders.o_custkey):Int64;N]\
        \n            TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";

    assert_multi_rules_optimized_plan_eq_display_indent(
        vec![Arc::new(ScalarSubqueryToJoin::new())],
        outer_plan,
        joined_plan,
    );
    Ok(())
}
881
/// Correlated scalar subquery used as one arm of an OR in the outer filter;
/// the expected plan is a left join against `__scalar_sq_1`.
#[test]
fn scalar_subquery_disjunction() -> Result<()> {
    let correlation =
        out_ref_col(DataType::Int64, "customer.c_custkey").eq(col("orders.o_custkey"));
    let inner_plan = LogicalPlanBuilder::from(scan_tpch_table("orders"))
        .filter(correlation)?
        .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
        .project(vec![max(col("orders.o_custkey"))])?
        .build()?;

    // Outer predicate: `c_custkey = (<subquery>) OR c_custkey = 1`.
    let key_equals_subquery =
        col("customer.c_custkey").eq(scalar_subquery(Arc::new(inner_plan)));
    let key_is_one = col("customer.c_custkey").eq(lit(1));
    let outer_plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
        .filter(key_equals_subquery.or(key_is_one))?
        .project(vec![col("customer.c_custkey")])?
        .build()?;

    let joined_plan = "Projection: customer.c_custkey [c_custkey:Int64]\
        \n  Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) OR customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\
        \n    Left Join:  Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\
        \n      TableScan: customer [c_custkey:Int64, c_name:Utf8]\
        \n      SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64]\
        \n        Projection: max(orders.o_custkey), orders.o_custkey [max(orders.o_custkey):Int64;N, o_custkey:Int64]\
        \n          Aggregate: groupBy=[[orders.o_custkey]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, max(orders.o_custkey):Int64;N]\
        \n            TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";

    assert_multi_rules_optimized_plan_eq_display_indent(
        vec![Arc::new(ScalarSubqueryToJoin::new())],
        outer_plan,
        joined_plan,
    );
    Ok(())
}
921
/// Correlated scalar subquery over the generic `test`/`sq` tables, compared
/// with `<`. Despite the `exists_` prefix in its name, this test builds a
/// scalar subquery; the expected plan is a left join against `__scalar_sq_1`.
#[test]
fn exists_subquery_correlated() -> Result<()> {
    // Inner side: min(c) over `sq` rows whose `a` matches the outer `test.a`.
    let correlation = out_ref_col(DataType::UInt32, "test.a").eq(col("sq.a"));
    let inner_plan = LogicalPlanBuilder::from(test_table_scan_with_name("sq")?)
        .filter(correlation)?
        .aggregate(Vec::<Expr>::new(), vec![min(col("c"))])?
        .project(vec![min(col("c"))])?
        .build()?;

    let c_below_subquery = col("test.c").lt(scalar_subquery(Arc::new(inner_plan)));
    let outer_plan = LogicalPlanBuilder::from(test_table_scan_with_name("test")?)
        .filter(c_below_subquery)?
        .project(vec![col("test.c")])?
        .build()?;

    let joined_plan = "Projection: test.c [c:UInt32]\
        \n  Filter: test.c < __scalar_sq_1.min(sq.c) [a:UInt32, b:UInt32, c:UInt32, min(sq.c):UInt32;N, a:UInt32;N]\
        \n    Left Join:  Filter: test.a = __scalar_sq_1.a [a:UInt32, b:UInt32, c:UInt32, min(sq.c):UInt32;N, a:UInt32;N]\
        \n      TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
        \n      SubqueryAlias: __scalar_sq_1 [min(sq.c):UInt32;N, a:UInt32]\
        \n        Projection: min(sq.c), sq.a [min(sq.c):UInt32;N, a:UInt32]\
        \n          Aggregate: groupBy=[[sq.a]], aggr=[[min(sq.c)]] [a:UInt32, min(sq.c):UInt32;N]\
        \n            TableScan: sq [a:UInt32, b:UInt32, c:UInt32]";

    assert_multi_rules_optimized_plan_eq_display_indent(
        vec![Arc::new(ScalarSubqueryToJoin::new())],
        outer_plan,
        joined_plan,
    );
    Ok(())
}
954
/// Uncorrelated scalar subquery (no filter inside it) compared with `<`; the
/// expected plan is a left join with no join filter against `__scalar_sq_1`.
#[test]
fn scalar_subquery_non_correlated_no_filters_with_non_equal_clause() -> Result<()> {
    // Inner side: a global max over `orders`, with no reference to the outer query.
    let inner_plan = LogicalPlanBuilder::from(scan_tpch_table("orders"))
        .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
        .project(vec![max(col("orders.o_custkey"))])?
        .build()?;

    let key_below_subquery =
        col("customer.c_custkey").lt(scalar_subquery(Arc::new(inner_plan)));
    let outer_plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
        .filter(key_below_subquery)?
        .project(vec![col("customer.c_custkey")])?
        .build()?;

    let joined_plan = "Projection: customer.c_custkey [c_custkey:Int64]\
        \n  Filter: customer.c_custkey < __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\
        \n    Left Join:  [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\
        \n      TableScan: customer [c_custkey:Int64, c_name:Utf8]\
        \n      SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N]\
        \n        Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\
        \n          Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\
        \n            TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";

    assert_multi_rules_optimized_plan_eq_display_indent(
        vec![Arc::new(ScalarSubqueryToJoin::new())],
        outer_plan,
        joined_plan,
    );
    Ok(())
}
986
/// Uncorrelated scalar subquery (no filter inside it) compared with `=`; the
/// expected plan is a left join with no join filter against `__scalar_sq_1`.
#[test]
fn scalar_subquery_non_correlated_no_filters_with_equal_clause() -> Result<()> {
    // Inner side: a global max over `orders`, with no reference to the outer query.
    let inner_plan = LogicalPlanBuilder::from(scan_tpch_table("orders"))
        .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
        .project(vec![max(col("orders.o_custkey"))])?
        .build()?;

    let key_equals_subquery =
        col("customer.c_custkey").eq(scalar_subquery(Arc::new(inner_plan)));
    let outer_plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
        .filter(key_equals_subquery)?
        .project(vec![col("customer.c_custkey")])?
        .build()?;

    let joined_plan = "Projection: customer.c_custkey [c_custkey:Int64]\
        \n  Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\
        \n    Left Join:  [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\
        \n      TableScan: customer [c_custkey:Int64, c_name:Utf8]\
        \n      SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N]\
        \n        Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\
        \n          Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\
        \n            TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";

    assert_multi_rules_optimized_plan_eq_display_indent(
        vec![Arc::new(ScalarSubqueryToJoin::new())],
        outer_plan,
        joined_plan,
    );
    Ok(())
}
1017
/// Two correlated scalar subqueries used as the bounds of a `BETWEEN`; the
/// expected plan chains two left joins, against `__scalar_sq_1` and
/// `__scalar_sq_2`.
#[test]
fn correlated_scalar_subquery_in_between_clause() -> Result<()> {
    // Lower bound: min(o_custkey) for the matching customer.
    let min_correlation =
        out_ref_col(DataType::Int64, "customer.c_custkey").eq(col("orders.o_custkey"));
    let min_plan = LogicalPlanBuilder::from(scan_tpch_table("orders"))
        .filter(min_correlation)?
        .aggregate(Vec::<Expr>::new(), vec![min(col("orders.o_custkey"))])?
        .project(vec![min(col("orders.o_custkey"))])?
        .build()?;

    // Upper bound: max(o_custkey) for the matching customer.
    let max_correlation =
        out_ref_col(DataType::Int64, "customer.c_custkey").eq(col("orders.o_custkey"));
    let max_plan = LogicalPlanBuilder::from(scan_tpch_table("orders"))
        .filter(max_correlation)?
        .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
        .project(vec![max(col("orders.o_custkey"))])?
        .build()?;

    let lower_bound = scalar_subquery(Arc::new(min_plan));
    let upper_bound = scalar_subquery(Arc::new(max_plan));
    let between_filter = Expr::Between(Between {
        expr: Box::new(col("customer.c_custkey")),
        negated: false,
        low: Box::new(lower_bound),
        high: Box::new(upper_bound),
    });

    let outer_plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
        .filter(between_filter)?
        .project(vec![col("customer.c_custkey")])?
        .build()?;

    let joined_plan = "Projection: customer.c_custkey [c_custkey:Int64]\
        \n  Filter: customer.c_custkey BETWEEN __scalar_sq_1.min(orders.o_custkey) AND __scalar_sq_2.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, o_custkey:Int64;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\
        \n    Left Join:  Filter: customer.c_custkey = __scalar_sq_2.o_custkey [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, o_custkey:Int64;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\
        \n      Left Join:  Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, o_custkey:Int64;N]\
        \n        TableScan: customer [c_custkey:Int64, c_name:Utf8]\
        \n        SubqueryAlias: __scalar_sq_1 [min(orders.o_custkey):Int64;N, o_custkey:Int64]\
        \n          Projection: min(orders.o_custkey), orders.o_custkey [min(orders.o_custkey):Int64;N, o_custkey:Int64]\
        \n            Aggregate: groupBy=[[orders.o_custkey]], aggr=[[min(orders.o_custkey)]] [o_custkey:Int64, min(orders.o_custkey):Int64;N]\
        \n              TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
        \n      SubqueryAlias: __scalar_sq_2 [max(orders.o_custkey):Int64;N, o_custkey:Int64]\
        \n        Projection: max(orders.o_custkey), orders.o_custkey [max(orders.o_custkey):Int64;N, o_custkey:Int64]\
        \n          Aggregate: groupBy=[[orders.o_custkey]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, max(orders.o_custkey):Int64;N]\
        \n            TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";

    assert_multi_rules_optimized_plan_eq_display_indent(
        vec![Arc::new(ScalarSubqueryToJoin::new())],
        outer_plan,
        joined_plan,
    );
    Ok(())
}
1074
/// Two uncorrelated scalar subqueries used as the bounds of a `BETWEEN`; the
/// expected plan chains two left joins without join filters, against
/// `__scalar_sq_1` and `__scalar_sq_2`.
#[test]
fn uncorrelated_scalar_subquery_in_between_clause() -> Result<()> {
    // Lower bound: global min(o_custkey) over `orders`.
    let min_plan = LogicalPlanBuilder::from(scan_tpch_table("orders"))
        .aggregate(Vec::<Expr>::new(), vec![min(col("orders.o_custkey"))])?
        .project(vec![min(col("orders.o_custkey"))])?
        .build()?;

    // Upper bound: global max(o_custkey) over `orders`.
    let max_plan = LogicalPlanBuilder::from(scan_tpch_table("orders"))
        .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
        .project(vec![max(col("orders.o_custkey"))])?
        .build()?;

    let lower_bound = scalar_subquery(Arc::new(min_plan));
    let upper_bound = scalar_subquery(Arc::new(max_plan));
    let between_filter = Expr::Between(Between {
        expr: Box::new(col("customer.c_custkey")),
        negated: false,
        low: Box::new(lower_bound),
        high: Box::new(upper_bound),
    });

    let outer_plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
        .filter(between_filter)?
        .project(vec![col("customer.c_custkey")])?
        .build()?;

    let joined_plan = "Projection: customer.c_custkey [c_custkey:Int64]\
        \n  Filter: customer.c_custkey BETWEEN __scalar_sq_1.min(orders.o_custkey) AND __scalar_sq_2.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, max(orders.o_custkey):Int64;N]\
        \n    Left Join:  [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, max(orders.o_custkey):Int64;N]\
        \n      Left Join:  [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N]\
        \n        TableScan: customer [c_custkey:Int64, c_name:Utf8]\
        \n        SubqueryAlias: __scalar_sq_1 [min(orders.o_custkey):Int64;N]\
        \n          Projection: min(orders.o_custkey) [min(orders.o_custkey):Int64;N]\
        \n            Aggregate: groupBy=[[]], aggr=[[min(orders.o_custkey)]] [min(orders.o_custkey):Int64;N]\
        \n              TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
        \n      SubqueryAlias: __scalar_sq_2 [max(orders.o_custkey):Int64;N]\
        \n        Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\
        \n          Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\
        \n            TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";

    assert_multi_rules_optimized_plan_eq_display_indent(
        vec![Arc::new(ScalarSubqueryToJoin::new())],
        outer_plan,
        joined_plan,
    );
    Ok(())
}
1123}