datafusion_optimizer/
decorrelate_predicate_subquery.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`DecorrelatePredicateSubquery`] converts `IN`/`EXISTS` subquery predicates to `SEMI`/`ANTI` joins
19use std::collections::BTreeSet;
20use std::ops::Deref;
21use std::sync::Arc;
22
23use crate::decorrelate::PullUpCorrelatedExpr;
24use crate::optimizer::ApplyOrder;
25use crate::utils::replace_qualified_name;
26use crate::{OptimizerConfig, OptimizerRule};
27
28use datafusion_common::alias::AliasGenerator;
29use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
30use datafusion_common::{internal_err, plan_err, Column, Result};
31use datafusion_expr::expr::{Exists, InSubquery};
32use datafusion_expr::expr_rewriter::create_col_from_scalar_expr;
33use datafusion_expr::logical_plan::{JoinType, Subquery};
34use datafusion_expr::utils::{conjunction, split_conjunction_owned};
35use datafusion_expr::{
36    exists, in_subquery, lit, not, not_exists, not_in_subquery, BinaryExpr, Expr, Filter,
37    LogicalPlan, LogicalPlanBuilder, Operator,
38};
39
40use log::debug;
41
/// Optimizer rule for rewriting predicate (`IN` / `EXISTS`) subqueries into joins.
///
/// Subquery predicates that are top-level conjuncts of a `Filter` are replaced by
/// `LeftSemi` (or `LeftAnti` when negated) joins. Subquery predicates embedded in a
/// larger expression are replaced by a `LeftMark` join whose `mark` column stands in
/// for the original predicate.
#[derive(Default, Debug)]
pub struct DecorrelatePredicateSubquery {}

impl DecorrelatePredicateSubquery {
    /// Creates a new instance of the rule; equivalent to [`Self::default`].
    pub fn new() -> Self {
        Self::default()
    }
}
52
53impl OptimizerRule for DecorrelatePredicateSubquery {
54    fn supports_rewrite(&self) -> bool {
55        true
56    }
57
58    fn rewrite(
59        &self,
60        plan: LogicalPlan,
61        config: &dyn OptimizerConfig,
62    ) -> Result<Transformed<LogicalPlan>> {
63        let plan = plan
64            .map_subqueries(|subquery| {
65                subquery.transform_down(|p| self.rewrite(p, config))
66            })?
67            .data;
68
69        let LogicalPlan::Filter(filter) = plan else {
70            return Ok(Transformed::no(plan));
71        };
72
73        if !has_subquery(&filter.predicate) {
74            return Ok(Transformed::no(LogicalPlan::Filter(filter)));
75        }
76
77        let (with_subqueries, mut other_exprs): (Vec<_>, Vec<_>) =
78            split_conjunction_owned(filter.predicate)
79                .into_iter()
80                .partition(has_subquery);
81
82        if with_subqueries.is_empty() {
83            return internal_err!(
84                "can not find expected subqueries in DecorrelatePredicateSubquery"
85            );
86        }
87
88        // iterate through all exists clauses in predicate, turning each into a join
89        let mut cur_input = Arc::unwrap_or_clone(filter.input);
90        for subquery_expr in with_subqueries {
91            match extract_subquery_info(subquery_expr) {
92                // The subquery expression is at the top level of the filter
93                SubqueryPredicate::Top(subquery) => {
94                    match build_join_top(&subquery, &cur_input, config.alias_generator())?
95                    {
96                        Some(plan) => cur_input = plan,
97                        // If the subquery can not be converted to a Join, reconstruct the subquery expression and add it to the Filter
98                        None => other_exprs.push(subquery.expr()),
99                    }
100                }
101                // The subquery expression is embedded within another expression
102                SubqueryPredicate::Embedded(expr) => {
103                    let (plan, expr_without_subqueries) =
104                        rewrite_inner_subqueries(cur_input, expr, config)?;
105                    cur_input = plan;
106                    other_exprs.push(expr_without_subqueries);
107                }
108            }
109        }
110
111        let expr = conjunction(other_exprs);
112        if let Some(expr) = expr {
113            let new_filter = Filter::try_new(expr, Arc::new(cur_input))?;
114            cur_input = LogicalPlan::Filter(new_filter);
115        }
116        Ok(Transformed::yes(cur_input))
117    }
118
119    fn name(&self) -> &str {
120        "decorrelate_predicate_subquery"
121    }
122
123    fn apply_order(&self) -> Option<ApplyOrder> {
124        Some(ApplyOrder::TopDown)
125    }
126}
127
128fn rewrite_inner_subqueries(
129    outer: LogicalPlan,
130    expr: Expr,
131    config: &dyn OptimizerConfig,
132) -> Result<(LogicalPlan, Expr)> {
133    let mut cur_input = outer;
134    let alias = config.alias_generator();
135    let expr_without_subqueries = expr.transform(|e| match e {
136        Expr::Exists(Exists {
137            subquery: Subquery { subquery, .. },
138            negated,
139        }) => match mark_join(&cur_input, Arc::clone(&subquery), None, negated, alias)? {
140            Some((plan, exists_expr)) => {
141                cur_input = plan;
142                Ok(Transformed::yes(exists_expr))
143            }
144            None if negated => Ok(Transformed::no(not_exists(subquery))),
145            None => Ok(Transformed::no(exists(subquery))),
146        },
147        Expr::InSubquery(InSubquery {
148            expr,
149            subquery: Subquery { subquery, .. },
150            negated,
151        }) => {
152            let in_predicate = subquery
153                .head_output_expr()?
154                .map_or(plan_err!("single expression required."), |output_expr| {
155                    Ok(Expr::eq(*expr.clone(), output_expr))
156                })?;
157            match mark_join(
158                &cur_input,
159                Arc::clone(&subquery),
160                Some(in_predicate),
161                negated,
162                alias,
163            )? {
164                Some((plan, exists_expr)) => {
165                    cur_input = plan;
166                    Ok(Transformed::yes(exists_expr))
167                }
168                None if negated => Ok(Transformed::no(not_in_subquery(*expr, subquery))),
169                None => Ok(Transformed::no(in_subquery(*expr, subquery))),
170            }
171        }
172        _ => Ok(Transformed::no(e)),
173    })?;
174    Ok((cur_input, expr_without_subqueries.data))
175}
176
177enum SubqueryPredicate {
178    // The subquery expression is at the top level of the filter and can be fully replaced by a
179    // semi/anti join
180    Top(SubqueryInfo),
181    // The subquery expression is embedded within another expression and is replaced using an
182    // existence join
183    Embedded(Expr),
184}
185
186fn extract_subquery_info(expr: Expr) -> SubqueryPredicate {
187    match expr {
188        Expr::Not(not_expr) => match *not_expr {
189            Expr::InSubquery(InSubquery {
190                expr,
191                subquery,
192                negated,
193            }) => SubqueryPredicate::Top(SubqueryInfo::new_with_in_expr(
194                subquery, *expr, !negated,
195            )),
196            Expr::Exists(Exists { subquery, negated }) => {
197                SubqueryPredicate::Top(SubqueryInfo::new(subquery, !negated))
198            }
199            expr => SubqueryPredicate::Embedded(not(expr)),
200        },
201        Expr::InSubquery(InSubquery {
202            expr,
203            subquery,
204            negated,
205        }) => SubqueryPredicate::Top(SubqueryInfo::new_with_in_expr(
206            subquery, *expr, negated,
207        )),
208        Expr::Exists(Exists { subquery, negated }) => {
209            SubqueryPredicate::Top(SubqueryInfo::new(subquery, negated))
210        }
211        expr => SubqueryPredicate::Embedded(expr),
212    }
213}
214
215fn has_subquery(expr: &Expr) -> bool {
216    expr.exists(|e| match e {
217        Expr::InSubquery(_) | Expr::Exists(_) => Ok(true),
218        _ => Ok(false),
219    })
220    .unwrap()
221}
222
223/// Optimize the subquery to left-anti/left-semi join.
224/// If the subquery is a correlated subquery, we need extract the join predicate from the subquery.
225///
226/// For example, given a query like:
227/// `select t1.a, t1.b from t1 where t1 in (select t2.a from t2 where t1.b = t2.b and t1.c > t2.c)`
228///
229/// The optimized plan will be:
230///
231/// ```text
232/// Projection: t1.a, t1.b
233///   LeftSemi Join:  Filter: t1.a = __correlated_sq_1.a AND t1.b = __correlated_sq_1.b AND t1.c > __correlated_sq_1.c
234///     TableScan: t1
235///     SubqueryAlias: __correlated_sq_1
236///       Projection: t2.a, t2.b, t2.c
237///         TableScan: t2
238/// ```
239///
240/// Given another query like:
241/// `select t1.id from t1 where exists(SELECT t2.id FROM t2 WHERE t1.id = t2.id)`
242///
243/// The optimized plan will be:
244///
245/// ```text
246/// Projection: t1.id
247///   LeftSemi Join:  Filter: t1.id = __correlated_sq_1.id
248///     TableScan: t1
249///     SubqueryAlias: __correlated_sq_1
250///       Projection: t2.id
251///         TableScan: t2
252/// ```
253fn build_join_top(
254    query_info: &SubqueryInfo,
255    left: &LogicalPlan,
256    alias: &Arc<AliasGenerator>,
257) -> Result<Option<LogicalPlan>> {
258    let where_in_expr_opt = &query_info.where_in_expr;
259    let in_predicate_opt = where_in_expr_opt
260        .clone()
261        .map(|where_in_expr| {
262            query_info
263                .query
264                .subquery
265                .head_output_expr()?
266                .map_or(plan_err!("single expression required."), |expr| {
267                    Ok(Expr::eq(where_in_expr, expr))
268                })
269        })
270        .map_or(Ok(None), |v| v.map(Some))?;
271
272    let join_type = match query_info.negated {
273        true => JoinType::LeftAnti,
274        false => JoinType::LeftSemi,
275    };
276    let subquery = query_info.query.subquery.as_ref();
277    let subquery_alias = alias.next("__correlated_sq");
278    build_join(left, subquery, in_predicate_opt, join_type, subquery_alias)
279}
280
281/// This is used to handle the case when the subquery is embedded in a more complex boolean
282/// expression like and OR. For example
283///
284/// `select t1.id from t1 where t1.id < 0 OR exists(SELECT t2.id FROM t2 WHERE t1.id = t2.id)`
285///
286/// The optimized plan will be:
287///
288/// ```text
289/// Projection: t1.id
290///   Filter: t1.id < 0 OR __correlated_sq_1.mark
291///     LeftMark Join:  Filter: t1.id = __correlated_sq_1.id
292///       TableScan: t1
293///       SubqueryAlias: __correlated_sq_1
294///         Projection: t2.id
295///           TableScan: t2
296fn mark_join(
297    left: &LogicalPlan,
298    subquery: Arc<LogicalPlan>,
299    in_predicate_opt: Option<Expr>,
300    negated: bool,
301    alias_generator: &Arc<AliasGenerator>,
302) -> Result<Option<(LogicalPlan, Expr)>> {
303    let alias = alias_generator.next("__correlated_sq");
304
305    let exists_col = Expr::Column(Column::new(Some(alias.clone()), "mark"));
306    let exists_expr = if negated { !exists_col } else { exists_col };
307
308    Ok(
309        build_join(left, &subquery, in_predicate_opt, JoinType::LeftMark, alias)?
310            .map(|plan| (plan, exists_expr)),
311    )
312}
313
314fn build_join(
315    left: &LogicalPlan,
316    subquery: &LogicalPlan,
317    in_predicate_opt: Option<Expr>,
318    join_type: JoinType,
319    alias: String,
320) -> Result<Option<LogicalPlan>> {
321    let mut pull_up = PullUpCorrelatedExpr::new()
322        .with_in_predicate_opt(in_predicate_opt.clone())
323        .with_exists_sub_query(in_predicate_opt.is_none());
324
325    let new_plan = subquery.clone().rewrite(&mut pull_up).data()?;
326    if !pull_up.can_pull_up {
327        return Ok(None);
328    }
329
330    let sub_query_alias = LogicalPlanBuilder::from(new_plan)
331        .alias(alias.to_string())?
332        .build()?;
333    let mut all_correlated_cols = BTreeSet::new();
334    pull_up
335        .correlated_subquery_cols_map
336        .values()
337        .for_each(|cols| all_correlated_cols.extend(cols.clone()));
338
339    // alias the join filter
340    let join_filter_opt = conjunction(pull_up.join_filters)
341        .map_or(Ok(None), |filter| {
342            replace_qualified_name(filter, &all_correlated_cols, &alias).map(Some)
343        })?;
344
345    let join_filter = match (join_filter_opt, in_predicate_opt) {
346        (
347            Some(join_filter),
348            Some(Expr::BinaryExpr(BinaryExpr {
349                left,
350                op: Operator::Eq,
351                right,
352            })),
353        ) => {
354            let right_col = create_col_from_scalar_expr(right.deref(), alias)?;
355            let in_predicate = Expr::eq(left.deref().clone(), Expr::Column(right_col));
356            in_predicate.and(join_filter)
357        }
358        (Some(join_filter), _) => join_filter,
359        (
360            _,
361            Some(Expr::BinaryExpr(BinaryExpr {
362                left,
363                op: Operator::Eq,
364                right,
365            })),
366        ) => {
367            let right_col = create_col_from_scalar_expr(right.deref(), alias)?;
368            let in_predicate = Expr::eq(left.deref().clone(), Expr::Column(right_col));
369            in_predicate
370        }
371        (None, None) => lit(true),
372        _ => return Ok(None),
373    };
374    // join our sub query into the main plan
375    let new_plan = LogicalPlanBuilder::from(left.clone())
376        .join_on(sub_query_alias, join_type, Some(join_filter))?
377        .build()?;
378    debug!(
379        "predicate subquery optimized:\n{}",
380        new_plan.display_indent()
381    );
382    Ok(Some(new_plan))
383}
384
385#[derive(Debug)]
386struct SubqueryInfo {
387    query: Subquery,
388    where_in_expr: Option<Expr>,
389    negated: bool,
390}
391
392impl SubqueryInfo {
393    pub fn new(query: Subquery, negated: bool) -> Self {
394        Self {
395            query,
396            where_in_expr: None,
397            negated,
398        }
399    }
400
401    pub fn new_with_in_expr(query: Subquery, expr: Expr, negated: bool) -> Self {
402        Self {
403            query,
404            where_in_expr: Some(expr),
405            negated,
406        }
407    }
408
409    pub fn expr(self) -> Expr {
410        match self.where_in_expr {
411            Some(expr) => match self.negated {
412                true => not_in_subquery(expr, self.query.subquery),
413                false => in_subquery(expr, self.query.subquery),
414            },
415            None => match self.negated {
416                true => not_exists(self.query.subquery),
417                false => exists(self.query.subquery),
418            },
419        }
420    }
421}
422
423#[cfg(test)]
424mod tests {
425    use std::ops::Add;
426
427    use super::*;
428    use crate::test::*;
429
430    use crate::assert_optimized_plan_eq_display_indent_snapshot;
431    use arrow::datatypes::{DataType, Field, Schema};
432    use datafusion_expr::builder::table_source;
433    use datafusion_expr::{and, binary_expr, col, lit, not, out_ref_col, table_scan};
434
435    macro_rules! assert_optimized_plan_equal {
436        (
437            $plan:expr,
438            @ $expected:literal $(,)?
439        ) => {{
440            let rule: Arc<dyn crate::OptimizerRule + Send + Sync> = Arc::new(DecorrelatePredicateSubquery::new());
441            assert_optimized_plan_eq_display_indent_snapshot!(
442                rule,
443                $plan,
444                @ $expected,
445            )
446        }};
447    }
448
449    fn test_subquery_with_name(name: &str) -> Result<Arc<LogicalPlan>> {
450        let table_scan = test_table_scan_with_name(name)?;
451        Ok(Arc::new(
452            LogicalPlanBuilder::from(table_scan)
453                .project(vec![col("c")])?
454                .build()?,
455        ))
456    }
457
458    /// Test for several IN subquery expressions
459    #[test]
460    fn in_subquery_multiple() -> Result<()> {
461        let table_scan = test_table_scan()?;
462        let plan = LogicalPlanBuilder::from(table_scan)
463            .filter(and(
464                in_subquery(col("c"), test_subquery_with_name("sq_1")?),
465                in_subquery(col("b"), test_subquery_with_name("sq_2")?),
466            ))?
467            .project(vec![col("test.b")])?
468            .build()?;
469
470        assert_optimized_plan_equal!(
471            plan,
472            @r"
473        Projection: test.b [b:UInt32]
474          LeftSemi Join:  Filter: test.b = __correlated_sq_2.c [a:UInt32, b:UInt32, c:UInt32]
475            LeftSemi Join:  Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]
476              TableScan: test [a:UInt32, b:UInt32, c:UInt32]
477              SubqueryAlias: __correlated_sq_1 [c:UInt32]
478                Projection: sq_1.c [c:UInt32]
479                  TableScan: sq_1 [a:UInt32, b:UInt32, c:UInt32]
480            SubqueryAlias: __correlated_sq_2 [c:UInt32]
481              Projection: sq_2.c [c:UInt32]
482                TableScan: sq_2 [a:UInt32, b:UInt32, c:UInt32]
483        "
484        )
485    }
486
487    /// Test for IN subquery with additional AND filter
488    #[test]
489    fn in_subquery_with_and_filters() -> Result<()> {
490        let table_scan = test_table_scan()?;
491        let plan = LogicalPlanBuilder::from(table_scan)
492            .filter(and(
493                in_subquery(col("c"), test_subquery_with_name("sq")?),
494                and(
495                    binary_expr(col("a"), Operator::Eq, lit(1_u32)),
496                    binary_expr(col("b"), Operator::Lt, lit(30_u32)),
497                ),
498            ))?
499            .project(vec![col("test.b")])?
500            .build()?;
501
502        assert_optimized_plan_equal!(
503            plan,
504            @r"
505        Projection: test.b [b:UInt32]
506          Filter: test.a = UInt32(1) AND test.b < UInt32(30) [a:UInt32, b:UInt32, c:UInt32]
507            LeftSemi Join:  Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]
508              TableScan: test [a:UInt32, b:UInt32, c:UInt32]
509              SubqueryAlias: __correlated_sq_1 [c:UInt32]
510                Projection: sq.c [c:UInt32]
511                  TableScan: sq [a:UInt32, b:UInt32, c:UInt32]
512        "
513        )
514    }
515
516    /// Test for nested IN subqueries
517    #[test]
518    fn in_subquery_nested() -> Result<()> {
519        let table_scan = test_table_scan()?;
520
521        let subquery = LogicalPlanBuilder::from(test_table_scan_with_name("sq")?)
522            .filter(in_subquery(col("a"), test_subquery_with_name("sq_nested")?))?
523            .project(vec![col("a")])?
524            .build()?;
525
526        let plan = LogicalPlanBuilder::from(table_scan)
527            .filter(in_subquery(col("b"), Arc::new(subquery)))?
528            .project(vec![col("test.b")])?
529            .build()?;
530
531        assert_optimized_plan_equal!(
532            plan,
533            @r"
534        Projection: test.b [b:UInt32]
535          LeftSemi Join:  Filter: test.b = __correlated_sq_2.a [a:UInt32, b:UInt32, c:UInt32]
536            TableScan: test [a:UInt32, b:UInt32, c:UInt32]
537            SubqueryAlias: __correlated_sq_2 [a:UInt32]
538              Projection: sq.a [a:UInt32]
539                LeftSemi Join:  Filter: sq.a = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]
540                  TableScan: sq [a:UInt32, b:UInt32, c:UInt32]
541                  SubqueryAlias: __correlated_sq_1 [c:UInt32]
542                    Projection: sq_nested.c [c:UInt32]
543                      TableScan: sq_nested [a:UInt32, b:UInt32, c:UInt32]
544        "
545        )
546    }
547
548    /// Test multiple correlated subqueries
549    /// See subqueries.rs where_in_multiple()
550    #[test]
551    fn multiple_subqueries() -> Result<()> {
552        let orders = Arc::new(
553            LogicalPlanBuilder::from(scan_tpch_table("orders"))
554                .filter(
555                    col("orders.o_custkey")
556                        .eq(out_ref_col(DataType::Int64, "customer.c_custkey")),
557                )?
558                .project(vec![col("orders.o_custkey")])?
559                .build()?,
560        );
561        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
562            .filter(
563                in_subquery(col("customer.c_custkey"), Arc::clone(&orders))
564                    .and(in_subquery(col("customer.c_custkey"), orders)),
565            )?
566            .project(vec![col("customer.c_custkey")])?
567            .build()?;
568        debug!("plan to optimize:\n{}", plan.display_indent());
569
570        assert_optimized_plan_equal!(
571                plan,
572                @r###"
573        Projection: customer.c_custkey [c_custkey:Int64]
574          LeftSemi Join:  Filter: customer.c_custkey = __correlated_sq_2.o_custkey [c_custkey:Int64, c_name:Utf8]
575            LeftSemi Join:  Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]
576              TableScan: customer [c_custkey:Int64, c_name:Utf8]
577              SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]
578                Projection: orders.o_custkey [o_custkey:Int64]
579                  TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
580            SubqueryAlias: __correlated_sq_2 [o_custkey:Int64]
581              Projection: orders.o_custkey [o_custkey:Int64]
582                TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
583        "###    
584        )
585    }
586
587    /// Test recursive correlated subqueries
588    /// See subqueries.rs where_in_recursive()
589    #[test]
590    fn recursive_subqueries() -> Result<()> {
591        let lineitem = Arc::new(
592            LogicalPlanBuilder::from(scan_tpch_table("lineitem"))
593                .filter(
594                    col("lineitem.l_orderkey")
595                        .eq(out_ref_col(DataType::Int64, "orders.o_orderkey")),
596                )?
597                .project(vec![col("lineitem.l_orderkey")])?
598                .build()?,
599        );
600
601        let orders = Arc::new(
602            LogicalPlanBuilder::from(scan_tpch_table("orders"))
603                .filter(
604                    in_subquery(col("orders.o_orderkey"), lineitem).and(
605                        col("orders.o_custkey")
606                            .eq(out_ref_col(DataType::Int64, "customer.c_custkey")),
607                    ),
608                )?
609                .project(vec![col("orders.o_custkey")])?
610                .build()?,
611        );
612
613        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
614            .filter(in_subquery(col("customer.c_custkey"), orders))?
615            .project(vec![col("customer.c_custkey")])?
616            .build()?;
617
618        assert_optimized_plan_equal!(
619            plan,
620            @r"
621        Projection: customer.c_custkey [c_custkey:Int64]
622          LeftSemi Join:  Filter: customer.c_custkey = __correlated_sq_2.o_custkey [c_custkey:Int64, c_name:Utf8]
623            TableScan: customer [c_custkey:Int64, c_name:Utf8]
624            SubqueryAlias: __correlated_sq_2 [o_custkey:Int64]
625              Projection: orders.o_custkey [o_custkey:Int64]
626                LeftSemi Join:  Filter: orders.o_orderkey = __correlated_sq_1.l_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
627                  TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
628                  SubqueryAlias: __correlated_sq_1 [l_orderkey:Int64]
629                    Projection: lineitem.l_orderkey [l_orderkey:Int64]
630                      TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]
631        "
632        )
633    }
634
635    /// Test for correlated IN subquery filter with additional subquery filters
636    #[test]
637    fn in_subquery_with_subquery_filters() -> Result<()> {
638        let sq = Arc::new(
639            LogicalPlanBuilder::from(scan_tpch_table("orders"))
640                .filter(
641                    out_ref_col(DataType::Int64, "customer.c_custkey")
642                        .eq(col("orders.o_custkey"))
643                        .and(col("o_orderkey").eq(lit(1))),
644                )?
645                .project(vec![col("orders.o_custkey")])?
646                .build()?,
647        );
648
649        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
650            .filter(in_subquery(col("customer.c_custkey"), sq))?
651            .project(vec![col("customer.c_custkey")])?
652            .build()?;
653
654        assert_optimized_plan_equal!(
655            plan,
656            @r"
657        Projection: customer.c_custkey [c_custkey:Int64]
658          LeftSemi Join:  Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]
659            TableScan: customer [c_custkey:Int64, c_name:Utf8]
660            SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]
661              Projection: orders.o_custkey [o_custkey:Int64]
662                Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
663                  TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
664        "
665        )
666    }
667
668    /// Test for correlated IN subquery with no columns in schema
669    #[test]
670    fn in_subquery_no_cols() -> Result<()> {
671        let sq = Arc::new(
672            LogicalPlanBuilder::from(scan_tpch_table("orders"))
673                .filter(
674                    out_ref_col(DataType::Int64, "customer.c_custkey")
675                        .eq(out_ref_col(DataType::Int64, "customer.c_custkey")),
676                )?
677                .project(vec![col("orders.o_custkey")])?
678                .build()?,
679        );
680
681        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
682            .filter(in_subquery(col("customer.c_custkey"), sq))?
683            .project(vec![col("customer.c_custkey")])?
684            .build()?;
685
686        assert_optimized_plan_equal!(
687            plan,
688            @r"
689        Projection: customer.c_custkey [c_custkey:Int64]
690          LeftSemi Join:  Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]
691            TableScan: customer [c_custkey:Int64, c_name:Utf8]
692            SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]
693              Projection: orders.o_custkey [o_custkey:Int64]
694                TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
695        "
696        )
697    }
698
699    /// Test for IN subquery with both columns in schema
700    #[test]
701    fn in_subquery_with_no_correlated_cols() -> Result<()> {
702        let sq = Arc::new(
703            LogicalPlanBuilder::from(scan_tpch_table("orders"))
704                .filter(col("orders.o_custkey").eq(col("orders.o_custkey")))?
705                .project(vec![col("orders.o_custkey")])?
706                .build()?,
707        );
708
709        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
710            .filter(in_subquery(col("customer.c_custkey"), sq))?
711            .project(vec![col("customer.c_custkey")])?
712            .build()?;
713
714        assert_optimized_plan_equal!(
715            plan,
716            @r"
717        Projection: customer.c_custkey [c_custkey:Int64]
718          LeftSemi Join:  Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]
719            TableScan: customer [c_custkey:Int64, c_name:Utf8]
720            SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]
721              Projection: orders.o_custkey [o_custkey:Int64]
722                Filter: orders.o_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
723                  TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
724        "
725        )
726    }
727
728    /// Test for correlated IN subquery not equal
729    #[test]
730    fn in_subquery_where_not_eq() -> Result<()> {
731        let sq = Arc::new(
732            LogicalPlanBuilder::from(scan_tpch_table("orders"))
733                .filter(
734                    out_ref_col(DataType::Int64, "customer.c_custkey")
735                        .not_eq(col("orders.o_custkey")),
736                )?
737                .project(vec![col("orders.o_custkey")])?
738                .build()?,
739        );
740
741        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
742            .filter(in_subquery(col("customer.c_custkey"), sq))?
743            .project(vec![col("customer.c_custkey")])?
744            .build()?;
745
746        assert_optimized_plan_equal!(
747            plan,
748            @r"
749        Projection: customer.c_custkey [c_custkey:Int64]
750          LeftSemi Join:  Filter: customer.c_custkey = __correlated_sq_1.o_custkey AND customer.c_custkey != __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]
751            TableScan: customer [c_custkey:Int64, c_name:Utf8]
752            SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]
753              Projection: orders.o_custkey [o_custkey:Int64]
754                TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
755        "
756        )
757    }
758
759    /// Test for correlated IN subquery less than
760    #[test]
761    fn in_subquery_where_less_than() -> Result<()> {
762        let sq = Arc::new(
763            LogicalPlanBuilder::from(scan_tpch_table("orders"))
764                .filter(
765                    out_ref_col(DataType::Int64, "customer.c_custkey")
766                        .lt(col("orders.o_custkey")),
767                )?
768                .project(vec![col("orders.o_custkey")])?
769                .build()?,
770        );
771
772        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
773            .filter(in_subquery(col("customer.c_custkey"), sq))?
774            .project(vec![col("customer.c_custkey")])?
775            .build()?;
776
777        assert_optimized_plan_equal!(
778            plan,
779            @r"
780        Projection: customer.c_custkey [c_custkey:Int64]
781          LeftSemi Join:  Filter: customer.c_custkey = __correlated_sq_1.o_custkey AND customer.c_custkey < __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]
782            TableScan: customer [c_custkey:Int64, c_name:Utf8]
783            SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]
784              Projection: orders.o_custkey [o_custkey:Int64]
785                TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
786        "
787        )
788    }
789
790    /// Test for correlated IN subquery filter with subquery disjunction
791    #[test]
792    fn in_subquery_with_subquery_disjunction() -> Result<()> {
793        let sq = Arc::new(
794            LogicalPlanBuilder::from(scan_tpch_table("orders"))
795                .filter(
796                    out_ref_col(DataType::Int64, "customer.c_custkey")
797                        .eq(col("orders.o_custkey"))
798                        .or(col("o_orderkey").eq(lit(1))),
799                )?
800                .project(vec![col("orders.o_custkey")])?
801                .build()?,
802        );
803
804        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
805            .filter(in_subquery(col("customer.c_custkey"), sq))?
806            .project(vec![col("customer.c_custkey")])?
807            .build()?;
808
809        assert_optimized_plan_equal!(
810            plan,
811            @r"
812        Projection: customer.c_custkey [c_custkey:Int64]
813          LeftSemi Join:  Filter: customer.c_custkey = __correlated_sq_1.o_custkey AND (customer.c_custkey = __correlated_sq_1.o_custkey OR __correlated_sq_1.o_orderkey = Int32(1)) [c_custkey:Int64, c_name:Utf8]
814            TableScan: customer [c_custkey:Int64, c_name:Utf8]
815            SubqueryAlias: __correlated_sq_1 [o_custkey:Int64, o_orderkey:Int64]
816              Projection: orders.o_custkey, orders.o_orderkey [o_custkey:Int64, o_orderkey:Int64]
817                TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
818        "
819        )
820    }
821
822    /// Test for correlated IN without projection
823    #[test]
824    fn in_subquery_no_projection() -> Result<()> {
825        let sq = Arc::new(
826            LogicalPlanBuilder::from(scan_tpch_table("orders"))
827                .filter(col("customer.c_custkey").eq(col("orders.o_custkey")))?
828                .build()?,
829        );
830
831        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
832            .filter(in_subquery(col("customer.c_custkey"), sq))?
833            .project(vec![col("customer.c_custkey")])?
834            .build()?;
835
836        // Maybe okay if the table only has a single column?
837        let expected = "Invalid (non-executable) plan after Analyzer\
838        \ncaused by\
839        \nError during planning: InSubquery should only return one column, but found 4";
840        assert_analyzer_check_err(vec![], plan, expected);
841
842        Ok(())
843    }
844
845    /// Test for correlated IN subquery join on expression
846    #[test]
847    fn in_subquery_join_expr() -> Result<()> {
848        let sq = Arc::new(
849            LogicalPlanBuilder::from(scan_tpch_table("orders"))
850                .filter(
851                    out_ref_col(DataType::Int64, "customer.c_custkey")
852                        .eq(col("orders.o_custkey")),
853                )?
854                .project(vec![col("orders.o_custkey")])?
855                .build()?,
856        );
857
858        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
859            .filter(in_subquery(col("customer.c_custkey").add(lit(1)), sq))?
860            .project(vec![col("customer.c_custkey")])?
861            .build()?;
862
863        assert_optimized_plan_equal!(
864            plan,
865            @r"
866        Projection: customer.c_custkey [c_custkey:Int64]
867          LeftSemi Join:  Filter: customer.c_custkey + Int32(1) = __correlated_sq_1.o_custkey AND customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]
868            TableScan: customer [c_custkey:Int64, c_name:Utf8]
869            SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]
870              Projection: orders.o_custkey [o_custkey:Int64]
871                TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
872        "
873        )
874    }
875
876    /// Test for correlated IN expressions
877    #[test]
878    fn in_subquery_project_expr() -> Result<()> {
879        let sq = Arc::new(
880            LogicalPlanBuilder::from(scan_tpch_table("orders"))
881                .filter(
882                    out_ref_col(DataType::Int64, "customer.c_custkey")
883                        .eq(col("orders.o_custkey")),
884                )?
885                .project(vec![col("orders.o_custkey").add(lit(1))])?
886                .build()?,
887        );
888
889        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
890            .filter(in_subquery(col("customer.c_custkey"), sq))?
891            .project(vec![col("customer.c_custkey")])?
892            .build()?;
893
894        assert_optimized_plan_equal!(
895            plan,
896            @r"
897        Projection: customer.c_custkey [c_custkey:Int64]
898          LeftSemi Join:  Filter: customer.c_custkey = __correlated_sq_1.orders.o_custkey + Int32(1) AND customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]
899            TableScan: customer [c_custkey:Int64, c_name:Utf8]
900            SubqueryAlias: __correlated_sq_1 [orders.o_custkey + Int32(1):Int64, o_custkey:Int64]
901              Projection: orders.o_custkey + Int32(1), orders.o_custkey [orders.o_custkey + Int32(1):Int64, o_custkey:Int64]
902                TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
903        "
904        )
905    }
906
907    /// Test for correlated IN subquery multiple projected columns
908    #[test]
909    fn in_subquery_multi_col() -> Result<()> {
910        let sq = Arc::new(
911            LogicalPlanBuilder::from(scan_tpch_table("orders"))
912                .filter(
913                    out_ref_col(DataType::Int64, "customer.c_custkey")
914                        .eq(col("orders.o_custkey")),
915                )?
916                .project(vec![col("orders.o_custkey"), col("orders.o_orderkey")])?
917                .build()?,
918        );
919
920        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
921            .filter(
922                in_subquery(col("customer.c_custkey"), sq)
923                    .and(col("c_custkey").eq(lit(1))),
924            )?
925            .project(vec![col("customer.c_custkey")])?
926            .build()?;
927
928        let expected = "Invalid (non-executable) plan after Analyzer\
929        \ncaused by\
930        \nError during planning: InSubquery should only return one column";
931        assert_analyzer_check_err(vec![], plan, expected);
932
933        Ok(())
934    }
935
936    /// Test for correlated IN subquery filter with additional filters
937    #[test]
938    fn should_support_additional_filters() -> Result<()> {
939        let sq = Arc::new(
940            LogicalPlanBuilder::from(scan_tpch_table("orders"))
941                .filter(
942                    out_ref_col(DataType::Int64, "customer.c_custkey")
943                        .eq(col("orders.o_custkey")),
944                )?
945                .project(vec![col("orders.o_custkey")])?
946                .build()?,
947        );
948
949        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
950            .filter(
951                in_subquery(col("customer.c_custkey"), sq)
952                    .and(col("c_custkey").eq(lit(1))),
953            )?
954            .project(vec![col("customer.c_custkey")])?
955            .build()?;
956
957        assert_optimized_plan_equal!(
958            plan,
959            @r"
960        Projection: customer.c_custkey [c_custkey:Int64]
961          Filter: customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8]
962            LeftSemi Join:  Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]
963              TableScan: customer [c_custkey:Int64, c_name:Utf8]
964              SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]
965                Projection: orders.o_custkey [o_custkey:Int64]
966                  TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
967        "
968        )
969    }
970
971    /// Test for correlated IN subquery filter
972    #[test]
973    fn in_subquery_correlated() -> Result<()> {
974        let sq = Arc::new(
975            LogicalPlanBuilder::from(test_table_scan_with_name("sq")?)
976                .filter(out_ref_col(DataType::UInt32, "test.a").eq(col("sq.a")))?
977                .project(vec![col("c")])?
978                .build()?,
979        );
980
981        let plan = LogicalPlanBuilder::from(test_table_scan_with_name("test")?)
982            .filter(in_subquery(col("c"), sq))?
983            .project(vec![col("test.b")])?
984            .build()?;
985
986        assert_optimized_plan_equal!(
987            plan,
988            @r"
989        Projection: test.b [b:UInt32]
990          LeftSemi Join:  Filter: test.c = __correlated_sq_1.c AND test.a = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]
991            TableScan: test [a:UInt32, b:UInt32, c:UInt32]
992            SubqueryAlias: __correlated_sq_1 [c:UInt32, a:UInt32]
993              Projection: sq.c, sq.a [c:UInt32, a:UInt32]
994                TableScan: sq [a:UInt32, b:UInt32, c:UInt32]
995        "
996        )
997    }
998
999    /// Test for single IN subquery filter
1000    #[test]
1001    fn in_subquery_simple() -> Result<()> {
1002        let table_scan = test_table_scan()?;
1003        let plan = LogicalPlanBuilder::from(table_scan)
1004            .filter(in_subquery(col("c"), test_subquery_with_name("sq")?))?
1005            .project(vec![col("test.b")])?
1006            .build()?;
1007
1008        assert_optimized_plan_equal!(
1009            plan,
1010            @r"
1011        Projection: test.b [b:UInt32]
1012          LeftSemi Join:  Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]
1013            TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1014            SubqueryAlias: __correlated_sq_1 [c:UInt32]
1015              Projection: sq.c [c:UInt32]
1016                TableScan: sq [a:UInt32, b:UInt32, c:UInt32]
1017        "
1018        )
1019    }
1020
1021    /// Test for single NOT IN subquery filter
1022    #[test]
1023    fn not_in_subquery_simple() -> Result<()> {
1024        let table_scan = test_table_scan()?;
1025        let plan = LogicalPlanBuilder::from(table_scan)
1026            .filter(not_in_subquery(col("c"), test_subquery_with_name("sq")?))?
1027            .project(vec![col("test.b")])?
1028            .build()?;
1029
1030        assert_optimized_plan_equal!(
1031            plan,
1032            @r"
1033        Projection: test.b [b:UInt32]
1034          LeftAnti Join:  Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]
1035            TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1036            SubqueryAlias: __correlated_sq_1 [c:UInt32]
1037              Projection: sq.c [c:UInt32]
1038                TableScan: sq [a:UInt32, b:UInt32, c:UInt32]
1039        "
1040        )
1041    }
1042
1043    #[test]
1044    fn wrapped_not_in_subquery() -> Result<()> {
1045        let table_scan = test_table_scan()?;
1046        let plan = LogicalPlanBuilder::from(table_scan)
1047            .filter(not(in_subquery(col("c"), test_subquery_with_name("sq")?)))?
1048            .project(vec![col("test.b")])?
1049            .build()?;
1050
1051        assert_optimized_plan_equal!(
1052            plan,
1053            @r"
1054        Projection: test.b [b:UInt32]
1055          LeftAnti Join:  Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]
1056            TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1057            SubqueryAlias: __correlated_sq_1 [c:UInt32]
1058              Projection: sq.c [c:UInt32]
1059                TableScan: sq [a:UInt32, b:UInt32, c:UInt32]
1060        "
1061        )
1062    }
1063
1064    #[test]
1065    fn wrapped_not_not_in_subquery() -> Result<()> {
1066        let table_scan = test_table_scan()?;
1067        let plan = LogicalPlanBuilder::from(table_scan)
1068            .filter(not(not_in_subquery(
1069                col("c"),
1070                test_subquery_with_name("sq")?,
1071            )))?
1072            .project(vec![col("test.b")])?
1073            .build()?;
1074
1075        assert_optimized_plan_equal!(
1076            plan,
1077            @r"
1078        Projection: test.b [b:UInt32]
1079          LeftSemi Join:  Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]
1080            TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1081            SubqueryAlias: __correlated_sq_1 [c:UInt32]
1082              Projection: sq.c [c:UInt32]
1083                TableScan: sq [a:UInt32, b:UInt32, c:UInt32]
1084        "
1085        )
1086    }
1087
1088    #[test]
1089    fn in_subquery_both_side_expr() -> Result<()> {
1090        let table_scan = test_table_scan()?;
1091        let subquery_scan = test_table_scan_with_name("sq")?;
1092
1093        let subquery = LogicalPlanBuilder::from(subquery_scan)
1094            .project(vec![col("c") * lit(2u32)])?
1095            .build()?;
1096
1097        let plan = LogicalPlanBuilder::from(table_scan)
1098            .filter(in_subquery(col("c") + lit(1u32), Arc::new(subquery)))?
1099            .project(vec![col("test.b")])?
1100            .build()?;
1101
1102        assert_optimized_plan_equal!(
1103            plan,
1104            @r"
1105        Projection: test.b [b:UInt32]
1106          LeftSemi Join:  Filter: test.c + UInt32(1) = __correlated_sq_1.sq.c * UInt32(2) [a:UInt32, b:UInt32, c:UInt32]
1107            TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1108            SubqueryAlias: __correlated_sq_1 [sq.c * UInt32(2):UInt32]
1109              Projection: sq.c * UInt32(2) [sq.c * UInt32(2):UInt32]
1110                TableScan: sq [a:UInt32, b:UInt32, c:UInt32]
1111        "
1112        )
1113    }
1114
1115    #[test]
1116    fn in_subquery_join_filter_and_inner_filter() -> Result<()> {
1117        let table_scan = test_table_scan()?;
1118        let subquery_scan = test_table_scan_with_name("sq")?;
1119
1120        let subquery = LogicalPlanBuilder::from(subquery_scan)
1121            .filter(
1122                out_ref_col(DataType::UInt32, "test.a")
1123                    .eq(col("sq.a"))
1124                    .and(col("sq.a").add(lit(1u32)).eq(col("sq.b"))),
1125            )?
1126            .project(vec![col("c") * lit(2u32)])?
1127            .build()?;
1128
1129        let plan = LogicalPlanBuilder::from(table_scan)
1130            .filter(in_subquery(col("c") + lit(1u32), Arc::new(subquery)))?
1131            .project(vec![col("test.b")])?
1132            .build()?;
1133
1134        assert_optimized_plan_equal!(
1135            plan,
1136            @r"
1137        Projection: test.b [b:UInt32]
1138          LeftSemi Join:  Filter: test.c + UInt32(1) = __correlated_sq_1.sq.c * UInt32(2) AND test.a = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]
1139            TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1140            SubqueryAlias: __correlated_sq_1 [sq.c * UInt32(2):UInt32, a:UInt32]
1141              Projection: sq.c * UInt32(2), sq.a [sq.c * UInt32(2):UInt32, a:UInt32]
1142                Filter: sq.a + UInt32(1) = sq.b [a:UInt32, b:UInt32, c:UInt32]
1143                  TableScan: sq [a:UInt32, b:UInt32, c:UInt32]
1144        "
1145        )
1146    }
1147
1148    #[test]
1149    fn in_subquery_multi_project_subquery_cols() -> Result<()> {
1150        let table_scan = test_table_scan()?;
1151        let subquery_scan = test_table_scan_with_name("sq")?;
1152
1153        let subquery = LogicalPlanBuilder::from(subquery_scan)
1154            .filter(
1155                out_ref_col(DataType::UInt32, "test.a")
1156                    .add(out_ref_col(DataType::UInt32, "test.b"))
1157                    .eq(col("sq.a").add(col("sq.b")))
1158                    .and(col("sq.a").add(lit(1u32)).eq(col("sq.b"))),
1159            )?
1160            .project(vec![col("c") * lit(2u32)])?
1161            .build()?;
1162
1163        let plan = LogicalPlanBuilder::from(table_scan)
1164            .filter(in_subquery(col("c") + lit(1u32), Arc::new(subquery)))?
1165            .project(vec![col("test.b")])?
1166            .build()?;
1167
1168        assert_optimized_plan_equal!(
1169            plan,
1170            @r"
1171        Projection: test.b [b:UInt32]
1172          LeftSemi Join:  Filter: test.c + UInt32(1) = __correlated_sq_1.sq.c * UInt32(2) AND test.a + test.b = __correlated_sq_1.a + __correlated_sq_1.b [a:UInt32, b:UInt32, c:UInt32]
1173            TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1174            SubqueryAlias: __correlated_sq_1 [sq.c * UInt32(2):UInt32, a:UInt32, b:UInt32]
1175              Projection: sq.c * UInt32(2), sq.a, sq.b [sq.c * UInt32(2):UInt32, a:UInt32, b:UInt32]
1176                Filter: sq.a + UInt32(1) = sq.b [a:UInt32, b:UInt32, c:UInt32]
1177                  TableScan: sq [a:UInt32, b:UInt32, c:UInt32]
1178        "
1179        )
1180    }
1181
1182    #[test]
1183    fn two_in_subquery_with_outer_filter() -> Result<()> {
1184        let table_scan = test_table_scan()?;
1185        let subquery_scan1 = test_table_scan_with_name("sq1")?;
1186        let subquery_scan2 = test_table_scan_with_name("sq2")?;
1187
1188        let subquery1 = LogicalPlanBuilder::from(subquery_scan1)
1189            .filter(out_ref_col(DataType::UInt32, "test.a").gt(col("sq1.a")))?
1190            .project(vec![col("c") * lit(2u32)])?
1191            .build()?;
1192
1193        let subquery2 = LogicalPlanBuilder::from(subquery_scan2)
1194            .filter(out_ref_col(DataType::UInt32, "test.a").gt(col("sq2.a")))?
1195            .project(vec![col("c") * lit(2u32)])?
1196            .build()?;
1197
1198        let plan = LogicalPlanBuilder::from(table_scan)
1199            .filter(
1200                in_subquery(col("c") + lit(1u32), Arc::new(subquery1)).and(
1201                    in_subquery(col("c") * lit(2u32), Arc::new(subquery2))
1202                        .and(col("test.c").gt(lit(1u32))),
1203                ),
1204            )?
1205            .project(vec![col("test.b")])?
1206            .build()?;
1207
1208        assert_optimized_plan_equal!(
1209            plan,
1210            @r"
1211        Projection: test.b [b:UInt32]
1212          Filter: test.c > UInt32(1) [a:UInt32, b:UInt32, c:UInt32]
1213            LeftSemi Join:  Filter: test.c * UInt32(2) = __correlated_sq_2.sq2.c * UInt32(2) AND test.a > __correlated_sq_2.a [a:UInt32, b:UInt32, c:UInt32]
1214              LeftSemi Join:  Filter: test.c + UInt32(1) = __correlated_sq_1.sq1.c * UInt32(2) AND test.a > __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]
1215                TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1216                SubqueryAlias: __correlated_sq_1 [sq1.c * UInt32(2):UInt32, a:UInt32]
1217                  Projection: sq1.c * UInt32(2), sq1.a [sq1.c * UInt32(2):UInt32, a:UInt32]
1218                    TableScan: sq1 [a:UInt32, b:UInt32, c:UInt32]
1219              SubqueryAlias: __correlated_sq_2 [sq2.c * UInt32(2):UInt32, a:UInt32]
1220                Projection: sq2.c * UInt32(2), sq2.a [sq2.c * UInt32(2):UInt32, a:UInt32]
1221                  TableScan: sq2 [a:UInt32, b:UInt32, c:UInt32]
1222        "
1223        )
1224    }
1225
1226    #[test]
1227    fn in_subquery_with_same_table() -> Result<()> {
1228        let outer_scan = test_table_scan()?;
1229        let subquery_scan = test_table_scan()?;
1230        let subquery = LogicalPlanBuilder::from(subquery_scan)
1231            .filter(col("test.a").gt(col("test.b")))?
1232            .project(vec![col("c")])?
1233            .build()?;
1234
1235        let plan = LogicalPlanBuilder::from(outer_scan)
1236            .filter(in_subquery(col("test.a"), Arc::new(subquery)))?
1237            .project(vec![col("test.b")])?
1238            .build()?;
1239
1240        // Subquery and outer query refer to the same table.
1241        assert_optimized_plan_equal!(
1242            plan,
1243            @r"
1244        Projection: test.b [b:UInt32]
1245          LeftSemi Join:  Filter: test.a = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]
1246            TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1247            SubqueryAlias: __correlated_sq_1 [c:UInt32]
1248              Projection: test.c [c:UInt32]
1249                Filter: test.a > test.b [a:UInt32, b:UInt32, c:UInt32]
1250                  TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1251        "
1252        )
1253    }
1254
1255    /// Test for multiple exists subqueries in the same filter expression
1256    #[test]
1257    fn multiple_exists_subqueries() -> Result<()> {
1258        let orders = Arc::new(
1259            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1260                .filter(
1261                    col("orders.o_custkey")
1262                        .eq(out_ref_col(DataType::Int64, "customer.c_custkey")),
1263                )?
1264                .project(vec![col("orders.o_custkey")])?
1265                .build()?,
1266        );
1267
1268        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
1269            .filter(exists(Arc::clone(&orders)).and(exists(orders)))?
1270            .project(vec![col("customer.c_custkey")])?
1271            .build()?;
1272
1273        assert_optimized_plan_equal!(
1274            plan,
1275            @r"
1276        Projection: customer.c_custkey [c_custkey:Int64]
1277          LeftSemi Join:  Filter: __correlated_sq_2.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8]
1278            LeftSemi Join:  Filter: __correlated_sq_1.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8]
1279              TableScan: customer [c_custkey:Int64, c_name:Utf8]
1280              SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]
1281                Projection: orders.o_custkey [o_custkey:Int64]
1282                  TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1283            SubqueryAlias: __correlated_sq_2 [o_custkey:Int64]
1284              Projection: orders.o_custkey [o_custkey:Int64]
1285                TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1286        "
1287        )
1288    }
1289
1290    /// Test recursive correlated subqueries
1291    #[test]
1292    fn recursive_exists_subqueries() -> Result<()> {
1293        let lineitem = Arc::new(
1294            LogicalPlanBuilder::from(scan_tpch_table("lineitem"))
1295                .filter(
1296                    col("lineitem.l_orderkey")
1297                        .eq(out_ref_col(DataType::Int64, "orders.o_orderkey")),
1298                )?
1299                .project(vec![col("lineitem.l_orderkey")])?
1300                .build()?,
1301        );
1302
1303        let orders = Arc::new(
1304            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1305                .filter(
1306                    exists(lineitem).and(
1307                        col("orders.o_custkey")
1308                            .eq(out_ref_col(DataType::Int64, "customer.c_custkey")),
1309                    ),
1310                )?
1311                .project(vec![col("orders.o_custkey")])?
1312                .build()?,
1313        );
1314
1315        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
1316            .filter(exists(orders))?
1317            .project(vec![col("customer.c_custkey")])?
1318            .build()?;
1319
1320        assert_optimized_plan_equal!(
1321            plan,
1322            @r"
1323        Projection: customer.c_custkey [c_custkey:Int64]
1324          LeftSemi Join:  Filter: __correlated_sq_2.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8]
1325            TableScan: customer [c_custkey:Int64, c_name:Utf8]
1326            SubqueryAlias: __correlated_sq_2 [o_custkey:Int64]
1327              Projection: orders.o_custkey [o_custkey:Int64]
1328                LeftSemi Join:  Filter: __correlated_sq_1.l_orderkey = orders.o_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1329                  TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1330                  SubqueryAlias: __correlated_sq_1 [l_orderkey:Int64]
1331                    Projection: lineitem.l_orderkey [l_orderkey:Int64]
1332                      TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]
1333        "
1334        )
1335    }
1336
1337    /// Test for correlated exists subquery filter with additional subquery filters
1338    #[test]
1339    fn exists_subquery_with_subquery_filters() -> Result<()> {
1340        let sq = Arc::new(
1341            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1342                .filter(
1343                    out_ref_col(DataType::Int64, "customer.c_custkey")
1344                        .eq(col("orders.o_custkey"))
1345                        .and(col("o_orderkey").eq(lit(1))),
1346                )?
1347                .project(vec![col("orders.o_custkey")])?
1348                .build()?,
1349        );
1350
1351        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
1352            .filter(exists(sq))?
1353            .project(vec![col("customer.c_custkey")])?
1354            .build()?;
1355
1356        assert_optimized_plan_equal!(
1357            plan,
1358            @r"
1359        Projection: customer.c_custkey [c_custkey:Int64]
1360          LeftSemi Join:  Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]
1361            TableScan: customer [c_custkey:Int64, c_name:Utf8]
1362            SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]
1363              Projection: orders.o_custkey [o_custkey:Int64]
1364                Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1365                  TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1366        "
1367        )
1368    }
1369
1370    #[test]
1371    fn exists_subquery_no_cols() -> Result<()> {
1372        let sq = Arc::new(
1373            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1374                .filter(out_ref_col(DataType::Int64, "customer.c_custkey").eq(lit(1u32)))?
1375                .project(vec![col("orders.o_custkey")])?
1376                .build()?,
1377        );
1378
1379        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
1380            .filter(exists(sq))?
1381            .project(vec![col("customer.c_custkey")])?
1382            .build()?;
1383
1384        // Other rule will pushdown `customer.c_custkey = 1`,
1385        assert_optimized_plan_equal!(
1386            plan,
1387            @r"
1388        Projection: customer.c_custkey [c_custkey:Int64]
1389          LeftSemi Join:  Filter: customer.c_custkey = UInt32(1) [c_custkey:Int64, c_name:Utf8]
1390            TableScan: customer [c_custkey:Int64, c_name:Utf8]
1391            SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]
1392              Projection: orders.o_custkey [o_custkey:Int64]
1393                TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1394        "
1395        )
1396    }
1397
1398    /// Test for exists subquery with both columns in schema
1399    #[test]
1400    fn exists_subquery_with_no_correlated_cols() -> Result<()> {
1401        let sq = Arc::new(
1402            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1403                .filter(col("orders.o_custkey").eq(col("orders.o_custkey")))?
1404                .project(vec![col("orders.o_custkey")])?
1405                .build()?,
1406        );
1407
1408        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
1409            .filter(exists(sq))?
1410            .project(vec![col("customer.c_custkey")])?
1411            .build()?;
1412
1413        assert_optimized_plan_equal!(
1414            plan,
1415            @r"
1416        Projection: customer.c_custkey [c_custkey:Int64]
1417          LeftSemi Join:  Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8]
1418            TableScan: customer [c_custkey:Int64, c_name:Utf8]
1419            SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]
1420              Projection: orders.o_custkey [o_custkey:Int64]
1421                Filter: orders.o_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1422                  TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1423        "
1424        )
1425    }
1426
1427    /// Test for correlated exists subquery not equal
1428    #[test]
1429    fn exists_subquery_where_not_eq() -> Result<()> {
1430        let sq = Arc::new(
1431            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1432                .filter(
1433                    out_ref_col(DataType::Int64, "customer.c_custkey")
1434                        .not_eq(col("orders.o_custkey")),
1435                )?
1436                .project(vec![col("orders.o_custkey")])?
1437                .build()?,
1438        );
1439
1440        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
1441            .filter(exists(sq))?
1442            .project(vec![col("customer.c_custkey")])?
1443            .build()?;
1444
1445        assert_optimized_plan_equal!(
1446            plan,
1447            @r"
1448        Projection: customer.c_custkey [c_custkey:Int64]
1449          LeftSemi Join:  Filter: customer.c_custkey != __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]
1450            TableScan: customer [c_custkey:Int64, c_name:Utf8]
1451            SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]
1452              Projection: orders.o_custkey [o_custkey:Int64]
1453                TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1454        "
1455        )
1456    }
1457
1458    /// Test for correlated exists subquery less than
1459    #[test]
1460    fn exists_subquery_where_less_than() -> Result<()> {
1461        let sq = Arc::new(
1462            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1463                .filter(
1464                    out_ref_col(DataType::Int64, "customer.c_custkey")
1465                        .lt(col("orders.o_custkey")),
1466                )?
1467                .project(vec![col("orders.o_custkey")])?
1468                .build()?,
1469        );
1470
1471        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
1472            .filter(exists(sq))?
1473            .project(vec![col("customer.c_custkey")])?
1474            .build()?;
1475
1476        assert_optimized_plan_equal!(
1477            plan,
1478            @r"
1479        Projection: customer.c_custkey [c_custkey:Int64]
1480          LeftSemi Join:  Filter: customer.c_custkey < __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]
1481            TableScan: customer [c_custkey:Int64, c_name:Utf8]
1482            SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]
1483              Projection: orders.o_custkey [o_custkey:Int64]
1484                TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1485        "
1486        )
1487    }
1488
1489    /// Test for correlated exists subquery filter with subquery disjunction
1490    #[test]
1491    fn exists_subquery_with_subquery_disjunction() -> Result<()> {
1492        let sq = Arc::new(
1493            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1494                .filter(
1495                    out_ref_col(DataType::Int64, "customer.c_custkey")
1496                        .eq(col("orders.o_custkey"))
1497                        .or(col("o_orderkey").eq(lit(1))),
1498                )?
1499                .project(vec![col("orders.o_custkey")])?
1500                .build()?,
1501        );
1502
1503        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
1504            .filter(exists(sq))?
1505            .project(vec![col("customer.c_custkey")])?
1506            .build()?;
1507
1508        assert_optimized_plan_equal!(
1509            plan,
1510            @r"
1511        Projection: customer.c_custkey [c_custkey:Int64]
1512          LeftSemi Join:  Filter: customer.c_custkey = __correlated_sq_1.o_custkey OR __correlated_sq_1.o_orderkey = Int32(1) [c_custkey:Int64, c_name:Utf8]
1513            TableScan: customer [c_custkey:Int64, c_name:Utf8]
1514            SubqueryAlias: __correlated_sq_1 [o_custkey:Int64, o_orderkey:Int64]
1515              Projection: orders.o_custkey, orders.o_orderkey [o_custkey:Int64, o_orderkey:Int64]
1516                TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1517        "
1518        )
1519    }
1520
1521    /// Test for correlated exists without projection
1522    #[test]
1523    fn exists_subquery_no_projection() -> Result<()> {
1524        let sq = Arc::new(
1525            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1526                .filter(
1527                    out_ref_col(DataType::Int64, "customer.c_custkey")
1528                        .eq(col("orders.o_custkey")),
1529                )?
1530                .build()?,
1531        );
1532
1533        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
1534            .filter(exists(sq))?
1535            .project(vec![col("customer.c_custkey")])?
1536            .build()?;
1537
1538        assert_optimized_plan_equal!(
1539            plan,
1540            @r"
1541        Projection: customer.c_custkey [c_custkey:Int64]
1542          LeftSemi Join:  Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]
1543            TableScan: customer [c_custkey:Int64, c_name:Utf8]
1544            SubqueryAlias: __correlated_sq_1 [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1545              TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1546        "
1547        )
1548    }
1549
1550    /// Test for correlated exists expressions
1551    #[test]
1552    fn exists_subquery_project_expr() -> Result<()> {
1553        let sq = Arc::new(
1554            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1555                .filter(
1556                    out_ref_col(DataType::Int64, "customer.c_custkey")
1557                        .eq(col("orders.o_custkey")),
1558                )?
1559                .project(vec![col("orders.o_custkey").add(lit(1))])?
1560                .build()?,
1561        );
1562
1563        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
1564            .filter(exists(sq))?
1565            .project(vec![col("customer.c_custkey")])?
1566            .build()?;
1567
1568        assert_optimized_plan_equal!(
1569            plan,
1570            @r"
1571        Projection: customer.c_custkey [c_custkey:Int64]
1572          LeftSemi Join:  Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]
1573            TableScan: customer [c_custkey:Int64, c_name:Utf8]
1574            SubqueryAlias: __correlated_sq_1 [orders.o_custkey + Int32(1):Int64, o_custkey:Int64]
1575              Projection: orders.o_custkey + Int32(1), orders.o_custkey [orders.o_custkey + Int32(1):Int64, o_custkey:Int64]
1576                TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1577        "
1578        )
1579    }
1580
1581    /// Test for correlated exists subquery filter with additional filters
1582    #[test]
1583    fn exists_subquery_should_support_additional_filters() -> Result<()> {
1584        let sq = Arc::new(
1585            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1586                .filter(
1587                    out_ref_col(DataType::Int64, "customer.c_custkey")
1588                        .eq(col("orders.o_custkey")),
1589                )?
1590                .project(vec![col("orders.o_custkey")])?
1591                .build()?,
1592        );
1593        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
1594            .filter(exists(sq).and(col("c_custkey").eq(lit(1))))?
1595            .project(vec![col("customer.c_custkey")])?
1596            .build()?;
1597
1598        assert_optimized_plan_equal!(
1599            plan,
1600            @r"
1601        Projection: customer.c_custkey [c_custkey:Int64]
1602          Filter: customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8]
1603            LeftSemi Join:  Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]
1604              TableScan: customer [c_custkey:Int64, c_name:Utf8]
1605              SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]
1606                Projection: orders.o_custkey [o_custkey:Int64]
1607                  TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1608        "
1609        )
1610    }
1611
1612    /// Test for correlated exists subquery filter with disjunctions
1613    #[test]
1614    fn exists_subquery_disjunction() -> Result<()> {
1615        let sq = Arc::new(
1616            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1617                .filter(col("customer.c_custkey").eq(col("orders.o_custkey")))?
1618                .project(vec![col("orders.o_custkey")])?
1619                .build()?,
1620        );
1621
1622        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
1623            .filter(exists(sq).or(col("customer.c_custkey").eq(lit(1))))?
1624            .project(vec![col("customer.c_custkey")])?
1625            .build()?;
1626
1627        assert_optimized_plan_equal!(
1628            plan,
1629            @r"
1630        Projection: customer.c_custkey [c_custkey:Int64]
1631          Filter: __correlated_sq_1.mark OR customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, mark:Boolean]
1632            LeftMark Join:  Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, mark:Boolean]
1633              TableScan: customer [c_custkey:Int64, c_name:Utf8]
1634              SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]
1635                Projection: orders.o_custkey [o_custkey:Int64]
1636                  Filter: customer.c_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1637                    TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1638        "
1639        )
1640    }
1641
1642    /// Test for correlated EXISTS subquery filter
1643    #[test]
1644    fn exists_subquery_correlated() -> Result<()> {
1645        let sq = Arc::new(
1646            LogicalPlanBuilder::from(test_table_scan_with_name("sq")?)
1647                .filter(out_ref_col(DataType::UInt32, "test.a").eq(col("sq.a")))?
1648                .project(vec![col("c")])?
1649                .build()?,
1650        );
1651
1652        let plan = LogicalPlanBuilder::from(test_table_scan_with_name("test")?)
1653            .filter(exists(sq))?
1654            .project(vec![col("test.c")])?
1655            .build()?;
1656
1657        assert_optimized_plan_equal!(
1658            plan,
1659            @r"
1660        Projection: test.c [c:UInt32]
1661          LeftSemi Join:  Filter: test.a = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]
1662            TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1663            SubqueryAlias: __correlated_sq_1 [c:UInt32, a:UInt32]
1664              Projection: sq.c, sq.a [c:UInt32, a:UInt32]
1665                TableScan: sq [a:UInt32, b:UInt32, c:UInt32]
1666        "
1667        )
1668    }
1669
1670    /// Test for single exists subquery filter
1671    #[test]
1672    fn exists_subquery_simple() -> Result<()> {
1673        let table_scan = test_table_scan()?;
1674        let plan = LogicalPlanBuilder::from(table_scan)
1675            .filter(exists(test_subquery_with_name("sq")?))?
1676            .project(vec![col("test.b")])?
1677            .build()?;
1678
1679        assert_optimized_plan_equal!(
1680            plan,
1681            @r"
1682        Projection: test.b [b:UInt32]
1683          LeftSemi Join:  Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32]
1684            TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1685            SubqueryAlias: __correlated_sq_1 [c:UInt32]
1686              Projection: sq.c [c:UInt32]
1687                TableScan: sq [a:UInt32, b:UInt32, c:UInt32]
1688        "
1689        )
1690    }
1691
1692    /// Test for single NOT exists subquery filter
1693    #[test]
1694    fn not_exists_subquery_simple() -> Result<()> {
1695        let table_scan = test_table_scan()?;
1696        let plan = LogicalPlanBuilder::from(table_scan)
1697            .filter(not_exists(test_subquery_with_name("sq")?))?
1698            .project(vec![col("test.b")])?
1699            .build()?;
1700
1701        assert_optimized_plan_equal!(
1702            plan,
1703            @r"
1704        Projection: test.b [b:UInt32]
1705          LeftAnti Join:  Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32]
1706            TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1707            SubqueryAlias: __correlated_sq_1 [c:UInt32]
1708              Projection: sq.c [c:UInt32]
1709                TableScan: sq [a:UInt32, b:UInt32, c:UInt32]
1710        "
1711        )
1712    }
1713
1714    #[test]
1715    fn two_exists_subquery_with_outer_filter() -> Result<()> {
1716        let table_scan = test_table_scan()?;
1717        let subquery_scan1 = test_table_scan_with_name("sq1")?;
1718        let subquery_scan2 = test_table_scan_with_name("sq2")?;
1719
1720        let subquery1 = LogicalPlanBuilder::from(subquery_scan1)
1721            .filter(out_ref_col(DataType::UInt32, "test.a").eq(col("sq1.a")))?
1722            .project(vec![col("c")])?
1723            .build()?;
1724
1725        let subquery2 = LogicalPlanBuilder::from(subquery_scan2)
1726            .filter(out_ref_col(DataType::UInt32, "test.a").eq(col("sq2.a")))?
1727            .project(vec![col("c")])?
1728            .build()?;
1729
1730        let plan = LogicalPlanBuilder::from(table_scan)
1731            .filter(
1732                exists(Arc::new(subquery1))
1733                    .and(exists(Arc::new(subquery2)).and(col("test.c").gt(lit(1u32)))),
1734            )?
1735            .project(vec![col("test.b")])?
1736            .build()?;
1737
1738        assert_optimized_plan_equal!(
1739            plan,
1740            @r"
1741        Projection: test.b [b:UInt32]
1742          Filter: test.c > UInt32(1) [a:UInt32, b:UInt32, c:UInt32]
1743            LeftSemi Join:  Filter: test.a = __correlated_sq_2.a [a:UInt32, b:UInt32, c:UInt32]
1744              LeftSemi Join:  Filter: test.a = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]
1745                TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1746                SubqueryAlias: __correlated_sq_1 [c:UInt32, a:UInt32]
1747                  Projection: sq1.c, sq1.a [c:UInt32, a:UInt32]
1748                    TableScan: sq1 [a:UInt32, b:UInt32, c:UInt32]
1749              SubqueryAlias: __correlated_sq_2 [c:UInt32, a:UInt32]
1750                Projection: sq2.c, sq2.a [c:UInt32, a:UInt32]
1751                  TableScan: sq2 [a:UInt32, b:UInt32, c:UInt32]
1752        "
1753        )
1754    }
1755
1756    #[test]
1757    fn exists_subquery_expr_filter() -> Result<()> {
1758        let table_scan = test_table_scan()?;
1759        let subquery_scan = test_table_scan_with_name("sq")?;
1760        let subquery = LogicalPlanBuilder::from(subquery_scan)
1761            .filter(
1762                (lit(1u32) + col("sq.a"))
1763                    .gt(out_ref_col(DataType::UInt32, "test.a") * lit(2u32)),
1764            )?
1765            .project(vec![lit(1u32)])?
1766            .build()?;
1767        let plan = LogicalPlanBuilder::from(table_scan)
1768            .filter(exists(Arc::new(subquery)))?
1769            .project(vec![col("test.b")])?
1770            .build()?;
1771
1772        assert_optimized_plan_equal!(
1773            plan,
1774            @r"
1775        Projection: test.b [b:UInt32]
1776          LeftSemi Join:  Filter: UInt32(1) + __correlated_sq_1.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32]
1777            TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1778            SubqueryAlias: __correlated_sq_1 [UInt32(1):UInt32, a:UInt32]
1779              Projection: UInt32(1), sq.a [UInt32(1):UInt32, a:UInt32]
1780                TableScan: sq [a:UInt32, b:UInt32, c:UInt32]
1781        "
1782        )
1783    }
1784
1785    #[test]
1786    fn exists_subquery_with_same_table() -> Result<()> {
1787        let outer_scan = test_table_scan()?;
1788        let subquery_scan = test_table_scan()?;
1789        let subquery = LogicalPlanBuilder::from(subquery_scan)
1790            .filter(col("test.a").gt(col("test.b")))?
1791            .project(vec![col("c")])?
1792            .build()?;
1793
1794        let plan = LogicalPlanBuilder::from(outer_scan)
1795            .filter(exists(Arc::new(subquery)))?
1796            .project(vec![col("test.b")])?
1797            .build()?;
1798
1799        // Subquery and outer query refer to the same table.
1800        assert_optimized_plan_equal!(
1801            plan,
1802            @r"
1803        Projection: test.b [b:UInt32]
1804          LeftSemi Join:  Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32]
1805            TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1806            SubqueryAlias: __correlated_sq_1 [c:UInt32]
1807              Projection: test.c [c:UInt32]
1808                Filter: test.a > test.b [a:UInt32, b:UInt32, c:UInt32]
1809                  TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1810        "
1811        )
1812    }
1813
1814    #[test]
1815    fn exists_distinct_subquery() -> Result<()> {
1816        let table_scan = test_table_scan()?;
1817        let subquery_scan = test_table_scan_with_name("sq")?;
1818        let subquery = LogicalPlanBuilder::from(subquery_scan)
1819            .filter(
1820                (lit(1u32) + col("sq.a"))
1821                    .gt(out_ref_col(DataType::UInt32, "test.a") * lit(2u32)),
1822            )?
1823            .project(vec![col("sq.c")])?
1824            .distinct()?
1825            .build()?;
1826        let plan = LogicalPlanBuilder::from(table_scan)
1827            .filter(exists(Arc::new(subquery)))?
1828            .project(vec![col("test.b")])?
1829            .build()?;
1830
1831        assert_optimized_plan_equal!(
1832            plan,
1833            @r"
1834        Projection: test.b [b:UInt32]
1835          LeftSemi Join:  Filter: UInt32(1) + __correlated_sq_1.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32]
1836            TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1837            SubqueryAlias: __correlated_sq_1 [c:UInt32, a:UInt32]
1838              Distinct: [c:UInt32, a:UInt32]
1839                Projection: sq.c, sq.a [c:UInt32, a:UInt32]
1840                  TableScan: sq [a:UInt32, b:UInt32, c:UInt32]
1841        "
1842        )
1843    }
1844
1845    #[test]
1846    fn exists_distinct_expr_subquery() -> Result<()> {
1847        let table_scan = test_table_scan()?;
1848        let subquery_scan = test_table_scan_with_name("sq")?;
1849        let subquery = LogicalPlanBuilder::from(subquery_scan)
1850            .filter(
1851                (lit(1u32) + col("sq.a"))
1852                    .gt(out_ref_col(DataType::UInt32, "test.a") * lit(2u32)),
1853            )?
1854            .project(vec![col("sq.b") + col("sq.c")])?
1855            .distinct()?
1856            .build()?;
1857        let plan = LogicalPlanBuilder::from(table_scan)
1858            .filter(exists(Arc::new(subquery)))?
1859            .project(vec![col("test.b")])?
1860            .build()?;
1861
1862        assert_optimized_plan_equal!(
1863            plan,
1864            @r"
1865        Projection: test.b [b:UInt32]
1866          LeftSemi Join:  Filter: UInt32(1) + __correlated_sq_1.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32]
1867            TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1868            SubqueryAlias: __correlated_sq_1 [sq.b + sq.c:UInt32, a:UInt32]
1869              Distinct: [sq.b + sq.c:UInt32, a:UInt32]
1870                Projection: sq.b + sq.c, sq.a [sq.b + sq.c:UInt32, a:UInt32]
1871                  TableScan: sq [a:UInt32, b:UInt32, c:UInt32]
1872        "
1873        )
1874    }
1875
1876    #[test]
1877    fn exists_distinct_subquery_with_literal() -> Result<()> {
1878        let table_scan = test_table_scan()?;
1879        let subquery_scan = test_table_scan_with_name("sq")?;
1880        let subquery = LogicalPlanBuilder::from(subquery_scan)
1881            .filter(
1882                (lit(1u32) + col("sq.a"))
1883                    .gt(out_ref_col(DataType::UInt32, "test.a") * lit(2u32)),
1884            )?
1885            .project(vec![lit(1u32), col("sq.c")])?
1886            .distinct()?
1887            .build()?;
1888        let plan = LogicalPlanBuilder::from(table_scan)
1889            .filter(exists(Arc::new(subquery)))?
1890            .project(vec![col("test.b")])?
1891            .build()?;
1892
1893        assert_optimized_plan_equal!(
1894            plan,
1895            @r"
1896        Projection: test.b [b:UInt32]
1897          LeftSemi Join:  Filter: UInt32(1) + __correlated_sq_1.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32]
1898            TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1899            SubqueryAlias: __correlated_sq_1 [UInt32(1):UInt32, c:UInt32, a:UInt32]
1900              Distinct: [UInt32(1):UInt32, c:UInt32, a:UInt32]
1901                Projection: UInt32(1), sq.c, sq.a [UInt32(1):UInt32, c:UInt32, a:UInt32]
1902                  TableScan: sq [a:UInt32, b:UInt32, c:UInt32]
1903        "
1904        )
1905    }
1906
1907    #[test]
1908    fn exists_uncorrelated_unnest() -> Result<()> {
1909        let subquery_table_source = table_source(&Schema::new(vec![Field::new(
1910            "arr",
1911            DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))),
1912            true,
1913        )]));
1914        let subquery = LogicalPlanBuilder::scan_with_filters(
1915            "sq",
1916            subquery_table_source,
1917            None,
1918            vec![],
1919        )?
1920        .unnest_column("arr")?
1921        .build()?;
1922        let table_scan = test_table_scan()?;
1923        let plan = LogicalPlanBuilder::from(table_scan)
1924            .filter(exists(Arc::new(subquery)))?
1925            .project(vec![col("test.b")])?
1926            .build()?;
1927
1928        assert_optimized_plan_equal!(
1929            plan,
1930            @r#"
1931        Projection: test.b [b:UInt32]
1932          LeftSemi Join:  Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32]
1933            TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1934            SubqueryAlias: __correlated_sq_1 [arr:Int32;N]
1935              Unnest: lists[sq.arr|depth=1] structs[] [arr:Int32;N]
1936                TableScan: sq [arr:List(Field { name: "item", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} });N]
1937        "#
1938        )
1939    }
1940
1941    #[test]
1942    fn exists_correlated_unnest() -> Result<()> {
1943        let table_scan = test_table_scan()?;
1944        let subquery_table_source = table_source(&Schema::new(vec![Field::new(
1945            "a",
1946            DataType::List(Arc::new(Field::new_list_field(DataType::UInt32, true))),
1947            true,
1948        )]));
1949        let subquery = LogicalPlanBuilder::scan_with_filters(
1950            "sq",
1951            subquery_table_source,
1952            None,
1953            vec![],
1954        )?
1955        .unnest_column("a")?
1956        .filter(col("a").eq(out_ref_col(DataType::UInt32, "test.b")))?
1957        .build()?;
1958        let plan = LogicalPlanBuilder::from(table_scan)
1959            .filter(exists(Arc::new(subquery)))?
1960            .project(vec![col("test.b")])?
1961            .build()?;
1962
1963        assert_optimized_plan_equal!(
1964            plan,
1965            @r#"
1966        Projection: test.b [b:UInt32]
1967          LeftSemi Join:  Filter: __correlated_sq_1.a = test.b [a:UInt32, b:UInt32, c:UInt32]
1968            TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1969            SubqueryAlias: __correlated_sq_1 [a:UInt32;N]
1970              Unnest: lists[sq.a|depth=1] structs[] [a:UInt32;N]
1971                TableScan: sq [a:List(Field { name: "item", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} });N]
1972        "#
1973        )
1974    }
1975
1976    #[test]
1977    fn upper_case_ident() -> Result<()> {
1978        let fields = vec![
1979            Field::new("A", DataType::UInt32, false),
1980            Field::new("B", DataType::UInt32, false),
1981        ];
1982
1983        let schema = Schema::new(fields);
1984        let table_scan_a = table_scan(Some("\"TEST_A\""), &schema, None)?.build()?;
1985        let table_scan_b = table_scan(Some("\"TEST_B\""), &schema, None)?.build()?;
1986
1987        let subquery = LogicalPlanBuilder::from(table_scan_b)
1988            .filter(col("\"A\"").eq(out_ref_col(DataType::UInt32, "\"TEST_A\".\"A\"")))?
1989            .project(vec![lit(1)])?
1990            .build()?;
1991
1992        let plan = LogicalPlanBuilder::from(table_scan_a)
1993            .filter(exists(Arc::new(subquery)))?
1994            .project(vec![col("\"TEST_A\".\"B\"")])?
1995            .build()?;
1996
1997        assert_optimized_plan_equal!(
1998            plan,
1999            @r"
2000        Projection: TEST_A.B [B:UInt32]
2001          LeftSemi Join:  Filter: __correlated_sq_1.A = TEST_A.A [A:UInt32, B:UInt32]
2002            TableScan: TEST_A [A:UInt32, B:UInt32]
2003            SubqueryAlias: __correlated_sq_1 [Int32(1):Int32, A:UInt32]
2004              Projection: Int32(1), TEST_B.A [Int32(1):Int32, A:UInt32]
2005                TableScan: TEST_B [A:UInt32, B:UInt32]
2006        "
2007        )
2008    }
2009}