datafusion_optimizer/
decorrelate_predicate_subquery.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`DecorrelatePredicateSubquery`] converts `IN`/`EXISTS` subquery predicates to `SEMI`/`ANTI` joins
19use std::collections::BTreeSet;
20use std::ops::Deref;
21use std::sync::Arc;
22
23use crate::decorrelate::PullUpCorrelatedExpr;
24use crate::optimizer::ApplyOrder;
25use crate::utils::replace_qualified_name;
26use crate::{OptimizerConfig, OptimizerRule};
27
28use datafusion_common::alias::AliasGenerator;
29use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
30use datafusion_common::{internal_err, plan_err, Column, Result};
31use datafusion_expr::expr::{Exists, InSubquery};
32use datafusion_expr::expr_rewriter::create_col_from_scalar_expr;
33use datafusion_expr::logical_plan::{JoinType, Subquery};
34use datafusion_expr::utils::{conjunction, expr_to_columns, split_conjunction_owned};
35use datafusion_expr::{
36    exists, in_subquery, lit, not, not_exists, not_in_subquery, BinaryExpr, Expr, Filter,
37    LogicalPlan, LogicalPlanBuilder, Operator,
38};
39
40use log::debug;
41
/// Optimizer rule that rewrites `IN` / `EXISTS` subquery predicates in a `Filter` into
/// left semi / left anti joins, falling back to left mark joins when the subquery is nested
/// inside a larger boolean expression.
#[derive(Debug, Default)]
pub struct DecorrelatePredicateSubquery {}

impl DecorrelatePredicateSubquery {
    /// Creates the rule. Equivalent to [`Default::default`].
    pub fn new() -> Self {
        Self {}
    }
}
52
53impl OptimizerRule for DecorrelatePredicateSubquery {
54    fn supports_rewrite(&self) -> bool {
55        true
56    }
57
58    fn rewrite(
59        &self,
60        plan: LogicalPlan,
61        config: &dyn OptimizerConfig,
62    ) -> Result<Transformed<LogicalPlan>> {
63        let plan = plan
64            .map_subqueries(|subquery| {
65                subquery.transform_down(|p| self.rewrite(p, config))
66            })?
67            .data;
68
69        let LogicalPlan::Filter(filter) = plan else {
70            return Ok(Transformed::no(plan));
71        };
72
73        if !has_subquery(&filter.predicate) {
74            return Ok(Transformed::no(LogicalPlan::Filter(filter)));
75        }
76
77        let (with_subqueries, mut other_exprs): (Vec<_>, Vec<_>) =
78            split_conjunction_owned(filter.predicate)
79                .into_iter()
80                .partition(has_subquery);
81
82        if with_subqueries.is_empty() {
83            return internal_err!(
84                "can not find expected subqueries in DecorrelatePredicateSubquery"
85            );
86        }
87
88        // iterate through all exists clauses in predicate, turning each into a join
89        let mut cur_input = Arc::unwrap_or_clone(filter.input);
90        for subquery_expr in with_subqueries {
91            match extract_subquery_info(subquery_expr) {
92                // The subquery expression is at the top level of the filter
93                SubqueryPredicate::Top(subquery) => {
94                    match build_join_top(&subquery, &cur_input, config.alias_generator())?
95                    {
96                        Some(plan) => cur_input = plan,
97                        // If the subquery can not be converted to a Join, reconstruct the subquery expression and add it to the Filter
98                        None => other_exprs.push(subquery.expr()),
99                    }
100                }
101                // The subquery expression is embedded within another expression
102                SubqueryPredicate::Embedded(expr) => {
103                    let (plan, expr_without_subqueries) =
104                        rewrite_inner_subqueries(cur_input, expr, config)?;
105                    cur_input = plan;
106                    other_exprs.push(expr_without_subqueries);
107                }
108            }
109        }
110
111        let expr = conjunction(other_exprs);
112        if let Some(expr) = expr {
113            let new_filter = Filter::try_new(expr, Arc::new(cur_input))?;
114            cur_input = LogicalPlan::Filter(new_filter);
115        }
116        Ok(Transformed::yes(cur_input))
117    }
118
119    fn name(&self) -> &str {
120        "decorrelate_predicate_subquery"
121    }
122
123    fn apply_order(&self) -> Option<ApplyOrder> {
124        Some(ApplyOrder::TopDown)
125    }
126}
127
128fn rewrite_inner_subqueries(
129    outer: LogicalPlan,
130    expr: Expr,
131    config: &dyn OptimizerConfig,
132) -> Result<(LogicalPlan, Expr)> {
133    let mut cur_input = outer;
134    let alias = config.alias_generator();
135    let expr_without_subqueries = expr.transform(|e| match e {
136        Expr::Exists(Exists {
137            subquery: Subquery { subquery, .. },
138            negated,
139        }) => match mark_join(&cur_input, &subquery, None, negated, alias)? {
140            Some((plan, exists_expr)) => {
141                cur_input = plan;
142                Ok(Transformed::yes(exists_expr))
143            }
144            None if negated => Ok(Transformed::no(not_exists(subquery))),
145            None => Ok(Transformed::no(exists(subquery))),
146        },
147        Expr::InSubquery(InSubquery {
148            expr,
149            subquery: Subquery { subquery, .. },
150            negated,
151        }) => {
152            let in_predicate = subquery
153                .head_output_expr()?
154                .map_or(plan_err!("single expression required."), |output_expr| {
155                    Ok(Expr::eq(*expr.clone(), output_expr))
156                })?;
157            match mark_join(&cur_input, &subquery, Some(&in_predicate), negated, alias)? {
158                Some((plan, exists_expr)) => {
159                    cur_input = plan;
160                    Ok(Transformed::yes(exists_expr))
161                }
162                None if negated => Ok(Transformed::no(not_in_subquery(*expr, subquery))),
163                None => Ok(Transformed::no(in_subquery(*expr, subquery))),
164            }
165        }
166        _ => Ok(Transformed::no(e)),
167    })?;
168    Ok((cur_input, expr_without_subqueries.data))
169}
170
171enum SubqueryPredicate {
172    // The subquery expression is at the top level of the filter and can be fully replaced by a
173    // semi/anti join
174    Top(SubqueryInfo),
175    // The subquery expression is embedded within another expression and is replaced using an
176    // existence join
177    Embedded(Expr),
178}
179
180fn extract_subquery_info(expr: Expr) -> SubqueryPredicate {
181    match expr {
182        Expr::Not(not_expr) => match *not_expr {
183            Expr::InSubquery(InSubquery {
184                expr,
185                subquery,
186                negated,
187            }) => SubqueryPredicate::Top(SubqueryInfo::new_with_in_expr(
188                subquery, *expr, !negated,
189            )),
190            Expr::Exists(Exists { subquery, negated }) => {
191                SubqueryPredicate::Top(SubqueryInfo::new(subquery, !negated))
192            }
193            expr => SubqueryPredicate::Embedded(not(expr)),
194        },
195        Expr::InSubquery(InSubquery {
196            expr,
197            subquery,
198            negated,
199        }) => SubqueryPredicate::Top(SubqueryInfo::new_with_in_expr(
200            subquery, *expr, negated,
201        )),
202        Expr::Exists(Exists { subquery, negated }) => {
203            SubqueryPredicate::Top(SubqueryInfo::new(subquery, negated))
204        }
205        expr => SubqueryPredicate::Embedded(expr),
206    }
207}
208
209fn has_subquery(expr: &Expr) -> bool {
210    expr.exists(|e| match e {
211        Expr::InSubquery(_) | Expr::Exists(_) => Ok(true),
212        _ => Ok(false),
213    })
214    .unwrap()
215}
216
217/// Optimize the subquery to left-anti/left-semi join.
218/// If the subquery is a correlated subquery, we need extract the join predicate from the subquery.
219///
220/// For example, given a query like:
221/// `select t1.a, t1.b from t1 where t1 in (select t2.a from t2 where t1.b = t2.b and t1.c > t2.c)`
222///
223/// The optimized plan will be:
224///
225/// ```text
226/// Projection: t1.a, t1.b
227///   LeftSemi Join:  Filter: t1.a = __correlated_sq_1.a AND t1.b = __correlated_sq_1.b AND t1.c > __correlated_sq_1.c
228///     TableScan: t1
229///     SubqueryAlias: __correlated_sq_1
230///       Projection: t2.a, t2.b, t2.c
231///         TableScan: t2
232/// ```
233///
234/// Given another query like:
235/// `select t1.id from t1 where exists(SELECT t2.id FROM t2 WHERE t1.id = t2.id)`
236///
237/// The optimized plan will be:
238///
239/// ```text
240/// Projection: t1.id
241///   LeftSemi Join:  Filter: t1.id = __correlated_sq_1.id
242///     TableScan: t1
243///     SubqueryAlias: __correlated_sq_1
244///       Projection: t2.id
245///         TableScan: t2
246/// ```
247fn build_join_top(
248    query_info: &SubqueryInfo,
249    left: &LogicalPlan,
250    alias: &Arc<AliasGenerator>,
251) -> Result<Option<LogicalPlan>> {
252    let where_in_expr_opt = &query_info.where_in_expr;
253    let in_predicate_opt = where_in_expr_opt
254        .clone()
255        .map(|where_in_expr| {
256            query_info
257                .query
258                .subquery
259                .head_output_expr()?
260                .map_or(plan_err!("single expression required."), |expr| {
261                    Ok(Expr::eq(where_in_expr, expr))
262                })
263        })
264        .map_or(Ok(None), |v| v.map(Some))?;
265
266    let join_type = match query_info.negated {
267        true => JoinType::LeftAnti,
268        false => JoinType::LeftSemi,
269    };
270    let subquery = query_info.query.subquery.as_ref();
271    let subquery_alias = alias.next("__correlated_sq");
272    build_join(
273        left,
274        subquery,
275        in_predicate_opt.as_ref(),
276        join_type,
277        subquery_alias,
278    )
279}
280
281/// This is used to handle the case when the subquery is embedded in a more complex boolean
282/// expression like and OR. For example
283///
284/// `select t1.id from t1 where t1.id < 0 OR exists(SELECT t2.id FROM t2 WHERE t1.id = t2.id)`
285///
286/// The optimized plan will be:
287///
288/// ```text
289/// Projection: t1.id
290///   Filter: t1.id < 0 OR __correlated_sq_1.mark
291///     LeftMark Join:  Filter: t1.id = __correlated_sq_1.id
292///       TableScan: t1
293///       SubqueryAlias: __correlated_sq_1
294///         Projection: t2.id
295///           TableScan: t2
296fn mark_join(
297    left: &LogicalPlan,
298    subquery: &LogicalPlan,
299    in_predicate_opt: Option<&Expr>,
300    negated: bool,
301    alias_generator: &Arc<AliasGenerator>,
302) -> Result<Option<(LogicalPlan, Expr)>> {
303    let alias = alias_generator.next("__correlated_sq");
304
305    let exists_col = Expr::Column(Column::new(Some(alias.clone()), "mark"));
306    let exists_expr = if negated { !exists_col } else { exists_col };
307
308    Ok(
309        build_join(left, subquery, in_predicate_opt, JoinType::LeftMark, alias)?
310            .map(|plan| (plan, exists_expr)),
311    )
312}
313
314fn build_join(
315    left: &LogicalPlan,
316    subquery: &LogicalPlan,
317    in_predicate_opt: Option<&Expr>,
318    join_type: JoinType,
319    alias: String,
320) -> Result<Option<LogicalPlan>> {
321    let mut pull_up = PullUpCorrelatedExpr::new()
322        .with_in_predicate_opt(in_predicate_opt.cloned())
323        .with_exists_sub_query(in_predicate_opt.is_none());
324
325    let new_plan = subquery.clone().rewrite(&mut pull_up).data()?;
326    if !pull_up.can_pull_up {
327        return Ok(None);
328    }
329
330    let sub_query_alias = LogicalPlanBuilder::from(new_plan)
331        .alias(alias.to_string())?
332        .build()?;
333    let mut all_correlated_cols = BTreeSet::new();
334    pull_up
335        .correlated_subquery_cols_map
336        .values()
337        .for_each(|cols| all_correlated_cols.extend(cols.clone()));
338
339    // alias the join filter
340    let join_filter_opt = conjunction(pull_up.join_filters)
341        .map_or(Ok(None), |filter| {
342            replace_qualified_name(filter, &all_correlated_cols, &alias).map(Some)
343        })?;
344
345    let join_filter = match (join_filter_opt, in_predicate_opt.cloned()) {
346        (
347            Some(join_filter),
348            Some(Expr::BinaryExpr(BinaryExpr {
349                left,
350                op: Operator::Eq,
351                right,
352            })),
353        ) => {
354            let right_col = create_col_from_scalar_expr(right.deref(), alias)?;
355            let in_predicate = Expr::eq(left.deref().clone(), Expr::Column(right_col));
356            in_predicate.and(join_filter)
357        }
358        (Some(join_filter), _) => join_filter,
359        (
360            _,
361            Some(Expr::BinaryExpr(BinaryExpr {
362                left,
363                op: Operator::Eq,
364                right,
365            })),
366        ) => {
367            let right_col = create_col_from_scalar_expr(right.deref(), alias)?;
368            let in_predicate = Expr::eq(left.deref().clone(), Expr::Column(right_col));
369            in_predicate
370        }
371        (None, None) => lit(true),
372        _ => return Ok(None),
373    };
374
375    if matches!(join_type, JoinType::LeftMark | JoinType::RightMark) {
376        let right_schema = sub_query_alias.schema();
377
378        // Gather all columns needed for the join filter + predicates
379        let mut needed = std::collections::HashSet::new();
380        expr_to_columns(&join_filter, &mut needed)?;
381        if let Some(in_pred) = in_predicate_opt {
382            expr_to_columns(in_pred, &mut needed)?;
383        }
384
385        // Keep only columns that actually belong to the RIGHT child, and sort by their
386        // position in the right schema for deterministic order.
387        let mut right_cols_idx_and_col: Vec<(usize, Column)> = needed
388            .into_iter()
389            .filter_map(|c| right_schema.index_of_column(&c).ok().map(|idx| (idx, c)))
390            .collect();
391
392        right_cols_idx_and_col.sort_by_key(|(idx, _)| *idx);
393
394        let right_proj_exprs: Vec<Expr> = right_cols_idx_and_col
395            .into_iter()
396            .map(|(_, c)| Expr::Column(c))
397            .collect();
398
399        let right_projected = if !right_proj_exprs.is_empty() {
400            LogicalPlanBuilder::from(sub_query_alias.clone())
401                .project(right_proj_exprs)?
402                .build()?
403        } else {
404            // Degenerate case: no right columns referenced by the predicate(s)
405            sub_query_alias.clone()
406        };
407        let new_plan = LogicalPlanBuilder::from(left.clone())
408            .join_on(right_projected, join_type, Some(join_filter))?
409            .build()?;
410
411        debug!(
412            "predicate subquery optimized:\n{}",
413            new_plan.display_indent()
414        );
415
416        return Ok(Some(new_plan));
417    }
418
419    // join our sub query into the main plan
420    let new_plan = LogicalPlanBuilder::from(left.clone())
421        .join_on(sub_query_alias, join_type, Some(join_filter))?
422        .build()?;
423    debug!(
424        "predicate subquery optimized:\n{}",
425        new_plan.display_indent()
426    );
427    Ok(Some(new_plan))
428}
429
430#[derive(Debug)]
431struct SubqueryInfo {
432    query: Subquery,
433    where_in_expr: Option<Expr>,
434    negated: bool,
435}
436
437impl SubqueryInfo {
438    pub fn new(query: Subquery, negated: bool) -> Self {
439        Self {
440            query,
441            where_in_expr: None,
442            negated,
443        }
444    }
445
446    pub fn new_with_in_expr(query: Subquery, expr: Expr, negated: bool) -> Self {
447        Self {
448            query,
449            where_in_expr: Some(expr),
450            negated,
451        }
452    }
453
454    pub fn expr(self) -> Expr {
455        match self.where_in_expr {
456            Some(expr) => match self.negated {
457                true => not_in_subquery(expr, self.query.subquery),
458                false => in_subquery(expr, self.query.subquery),
459            },
460            None => match self.negated {
461                true => not_exists(self.query.subquery),
462                false => exists(self.query.subquery),
463            },
464        }
465    }
466}
467
468#[cfg(test)]
469mod tests {
470    use std::ops::Add;
471
472    use super::*;
473    use crate::test::*;
474
475    use crate::assert_optimized_plan_eq_display_indent_snapshot;
476    use arrow::datatypes::{DataType, Field, Schema};
477    use datafusion_expr::builder::table_source;
478    use datafusion_expr::{and, binary_expr, col, lit, not, out_ref_col, table_scan};
479
480    macro_rules! assert_optimized_plan_equal {
481        (
482            $plan:expr,
483            @ $expected:literal $(,)?
484        ) => {{
485            let rule: Arc<dyn crate::OptimizerRule + Send + Sync> = Arc::new(DecorrelatePredicateSubquery::new());
486            assert_optimized_plan_eq_display_indent_snapshot!(
487                rule,
488                $plan,
489                @ $expected,
490            )
491        }};
492    }
493
494    fn test_subquery_with_name(name: &str) -> Result<Arc<LogicalPlan>> {
495        let table_scan = test_table_scan_with_name(name)?;
496        Ok(Arc::new(
497            LogicalPlanBuilder::from(table_scan)
498                .project(vec![col("c")])?
499                .build()?,
500        ))
501    }
502
503    /// Test for several IN subquery expressions
504    #[test]
505    fn in_subquery_multiple() -> Result<()> {
506        let table_scan = test_table_scan()?;
507        let plan = LogicalPlanBuilder::from(table_scan)
508            .filter(and(
509                in_subquery(col("c"), test_subquery_with_name("sq_1")?),
510                in_subquery(col("b"), test_subquery_with_name("sq_2")?),
511            ))?
512            .project(vec![col("test.b")])?
513            .build()?;
514
515        assert_optimized_plan_equal!(
516            plan,
517            @r"
518        Projection: test.b [b:UInt32]
519          LeftSemi Join:  Filter: test.b = __correlated_sq_2.c [a:UInt32, b:UInt32, c:UInt32]
520            LeftSemi Join:  Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]
521              TableScan: test [a:UInt32, b:UInt32, c:UInt32]
522              SubqueryAlias: __correlated_sq_1 [c:UInt32]
523                Projection: sq_1.c [c:UInt32]
524                  TableScan: sq_1 [a:UInt32, b:UInt32, c:UInt32]
525            SubqueryAlias: __correlated_sq_2 [c:UInt32]
526              Projection: sq_2.c [c:UInt32]
527                TableScan: sq_2 [a:UInt32, b:UInt32, c:UInt32]
528        "
529        )
530    }
531
532    /// Test for IN subquery with additional AND filter
533    #[test]
534    fn in_subquery_with_and_filters() -> Result<()> {
535        let table_scan = test_table_scan()?;
536        let plan = LogicalPlanBuilder::from(table_scan)
537            .filter(and(
538                in_subquery(col("c"), test_subquery_with_name("sq")?),
539                and(
540                    binary_expr(col("a"), Operator::Eq, lit(1_u32)),
541                    binary_expr(col("b"), Operator::Lt, lit(30_u32)),
542                ),
543            ))?
544            .project(vec![col("test.b")])?
545            .build()?;
546
547        assert_optimized_plan_equal!(
548            plan,
549            @r"
550        Projection: test.b [b:UInt32]
551          Filter: test.a = UInt32(1) AND test.b < UInt32(30) [a:UInt32, b:UInt32, c:UInt32]
552            LeftSemi Join:  Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]
553              TableScan: test [a:UInt32, b:UInt32, c:UInt32]
554              SubqueryAlias: __correlated_sq_1 [c:UInt32]
555                Projection: sq.c [c:UInt32]
556                  TableScan: sq [a:UInt32, b:UInt32, c:UInt32]
557        "
558        )
559    }
560
561    /// Test for nested IN subqueries
562    #[test]
563    fn in_subquery_nested() -> Result<()> {
564        let table_scan = test_table_scan()?;
565
566        let subquery = LogicalPlanBuilder::from(test_table_scan_with_name("sq")?)
567            .filter(in_subquery(col("a"), test_subquery_with_name("sq_nested")?))?
568            .project(vec![col("a")])?
569            .build()?;
570
571        let plan = LogicalPlanBuilder::from(table_scan)
572            .filter(in_subquery(col("b"), Arc::new(subquery)))?
573            .project(vec![col("test.b")])?
574            .build()?;
575
576        assert_optimized_plan_equal!(
577            plan,
578            @r"
579        Projection: test.b [b:UInt32]
580          LeftSemi Join:  Filter: test.b = __correlated_sq_2.a [a:UInt32, b:UInt32, c:UInt32]
581            TableScan: test [a:UInt32, b:UInt32, c:UInt32]
582            SubqueryAlias: __correlated_sq_2 [a:UInt32]
583              Projection: sq.a [a:UInt32]
584                LeftSemi Join:  Filter: sq.a = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]
585                  TableScan: sq [a:UInt32, b:UInt32, c:UInt32]
586                  SubqueryAlias: __correlated_sq_1 [c:UInt32]
587                    Projection: sq_nested.c [c:UInt32]
588                      TableScan: sq_nested [a:UInt32, b:UInt32, c:UInt32]
589        "
590        )
591    }
592
593    /// Test multiple correlated subqueries
594    /// See subqueries.rs where_in_multiple()
595    #[test]
596    fn multiple_subqueries() -> Result<()> {
597        let orders = Arc::new(
598            LogicalPlanBuilder::from(scan_tpch_table("orders"))
599                .filter(
600                    col("orders.o_custkey")
601                        .eq(out_ref_col(DataType::Int64, "customer.c_custkey")),
602                )?
603                .project(vec![col("orders.o_custkey")])?
604                .build()?,
605        );
606        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
607            .filter(
608                in_subquery(col("customer.c_custkey"), Arc::clone(&orders))
609                    .and(in_subquery(col("customer.c_custkey"), orders)),
610            )?
611            .project(vec![col("customer.c_custkey")])?
612            .build()?;
613        debug!("plan to optimize:\n{}", plan.display_indent());
614
615        assert_optimized_plan_equal!(
616                plan,
617                @r###"
618        Projection: customer.c_custkey [c_custkey:Int64]
619          LeftSemi Join:  Filter: customer.c_custkey = __correlated_sq_2.o_custkey [c_custkey:Int64, c_name:Utf8]
620            LeftSemi Join:  Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]
621              TableScan: customer [c_custkey:Int64, c_name:Utf8]
622              SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]
623                Projection: orders.o_custkey [o_custkey:Int64]
624                  TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
625            SubqueryAlias: __correlated_sq_2 [o_custkey:Int64]
626              Projection: orders.o_custkey [o_custkey:Int64]
627                TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
628        "###    
629        )
630    }
631
632    /// Test recursive correlated subqueries
633    /// See subqueries.rs where_in_recursive()
634    #[test]
635    fn recursive_subqueries() -> Result<()> {
636        let lineitem = Arc::new(
637            LogicalPlanBuilder::from(scan_tpch_table("lineitem"))
638                .filter(
639                    col("lineitem.l_orderkey")
640                        .eq(out_ref_col(DataType::Int64, "orders.o_orderkey")),
641                )?
642                .project(vec![col("lineitem.l_orderkey")])?
643                .build()?,
644        );
645
646        let orders = Arc::new(
647            LogicalPlanBuilder::from(scan_tpch_table("orders"))
648                .filter(
649                    in_subquery(col("orders.o_orderkey"), lineitem).and(
650                        col("orders.o_custkey")
651                            .eq(out_ref_col(DataType::Int64, "customer.c_custkey")),
652                    ),
653                )?
654                .project(vec![col("orders.o_custkey")])?
655                .build()?,
656        );
657
658        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
659            .filter(in_subquery(col("customer.c_custkey"), orders))?
660            .project(vec![col("customer.c_custkey")])?
661            .build()?;
662
663        assert_optimized_plan_equal!(
664            plan,
665            @r"
666        Projection: customer.c_custkey [c_custkey:Int64]
667          LeftSemi Join:  Filter: customer.c_custkey = __correlated_sq_2.o_custkey [c_custkey:Int64, c_name:Utf8]
668            TableScan: customer [c_custkey:Int64, c_name:Utf8]
669            SubqueryAlias: __correlated_sq_2 [o_custkey:Int64]
670              Projection: orders.o_custkey [o_custkey:Int64]
671                LeftSemi Join:  Filter: orders.o_orderkey = __correlated_sq_1.l_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
672                  TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
673                  SubqueryAlias: __correlated_sq_1 [l_orderkey:Int64]
674                    Projection: lineitem.l_orderkey [l_orderkey:Int64]
675                      TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]
676        "
677        )
678    }
679
680    /// Test for correlated IN subquery filter with additional subquery filters
681    #[test]
682    fn in_subquery_with_subquery_filters() -> Result<()> {
683        let sq = Arc::new(
684            LogicalPlanBuilder::from(scan_tpch_table("orders"))
685                .filter(
686                    out_ref_col(DataType::Int64, "customer.c_custkey")
687                        .eq(col("orders.o_custkey"))
688                        .and(col("o_orderkey").eq(lit(1))),
689                )?
690                .project(vec![col("orders.o_custkey")])?
691                .build()?,
692        );
693
694        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
695            .filter(in_subquery(col("customer.c_custkey"), sq))?
696            .project(vec![col("customer.c_custkey")])?
697            .build()?;
698
699        assert_optimized_plan_equal!(
700            plan,
701            @r"
702        Projection: customer.c_custkey [c_custkey:Int64]
703          LeftSemi Join:  Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]
704            TableScan: customer [c_custkey:Int64, c_name:Utf8]
705            SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]
706              Projection: orders.o_custkey [o_custkey:Int64]
707                Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
708                  TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
709        "
710        )
711    }
712
713    /// Test for correlated IN subquery with no columns in schema
714    #[test]
715    fn in_subquery_no_cols() -> Result<()> {
716        let sq = Arc::new(
717            LogicalPlanBuilder::from(scan_tpch_table("orders"))
718                .filter(
719                    out_ref_col(DataType::Int64, "customer.c_custkey")
720                        .eq(out_ref_col(DataType::Int64, "customer.c_custkey")),
721                )?
722                .project(vec![col("orders.o_custkey")])?
723                .build()?,
724        );
725
726        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
727            .filter(in_subquery(col("customer.c_custkey"), sq))?
728            .project(vec![col("customer.c_custkey")])?
729            .build()?;
730
731        assert_optimized_plan_equal!(
732            plan,
733            @r"
734        Projection: customer.c_custkey [c_custkey:Int64]
735          LeftSemi Join:  Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]
736            TableScan: customer [c_custkey:Int64, c_name:Utf8]
737            SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]
738              Projection: orders.o_custkey [o_custkey:Int64]
739                TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
740        "
741        )
742    }
743
744    /// Test for IN subquery with both columns in schema
745    #[test]
746    fn in_subquery_with_no_correlated_cols() -> Result<()> {
747        let sq = Arc::new(
748            LogicalPlanBuilder::from(scan_tpch_table("orders"))
749                .filter(col("orders.o_custkey").eq(col("orders.o_custkey")))?
750                .project(vec![col("orders.o_custkey")])?
751                .build()?,
752        );
753
754        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
755            .filter(in_subquery(col("customer.c_custkey"), sq))?
756            .project(vec![col("customer.c_custkey")])?
757            .build()?;
758
759        assert_optimized_plan_equal!(
760            plan,
761            @r"
762        Projection: customer.c_custkey [c_custkey:Int64]
763          LeftSemi Join:  Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]
764            TableScan: customer [c_custkey:Int64, c_name:Utf8]
765            SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]
766              Projection: orders.o_custkey [o_custkey:Int64]
767                Filter: orders.o_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
768                  TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
769        "
770        )
771    }
772
773    /// Test for correlated IN subquery not equal
774    #[test]
775    fn in_subquery_where_not_eq() -> Result<()> {
776        let sq = Arc::new(
777            LogicalPlanBuilder::from(scan_tpch_table("orders"))
778                .filter(
779                    out_ref_col(DataType::Int64, "customer.c_custkey")
780                        .not_eq(col("orders.o_custkey")),
781                )?
782                .project(vec![col("orders.o_custkey")])?
783                .build()?,
784        );
785
786        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
787            .filter(in_subquery(col("customer.c_custkey"), sq))?
788            .project(vec![col("customer.c_custkey")])?
789            .build()?;
790
791        assert_optimized_plan_equal!(
792            plan,
793            @r"
794        Projection: customer.c_custkey [c_custkey:Int64]
795          LeftSemi Join:  Filter: customer.c_custkey = __correlated_sq_1.o_custkey AND customer.c_custkey != __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]
796            TableScan: customer [c_custkey:Int64, c_name:Utf8]
797            SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]
798              Projection: orders.o_custkey [o_custkey:Int64]
799                TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
800        "
801        )
802    }
803
804    /// Test for correlated IN subquery less than
805    #[test]
806    fn in_subquery_where_less_than() -> Result<()> {
807        let sq = Arc::new(
808            LogicalPlanBuilder::from(scan_tpch_table("orders"))
809                .filter(
810                    out_ref_col(DataType::Int64, "customer.c_custkey")
811                        .lt(col("orders.o_custkey")),
812                )?
813                .project(vec![col("orders.o_custkey")])?
814                .build()?,
815        );
816
817        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
818            .filter(in_subquery(col("customer.c_custkey"), sq))?
819            .project(vec![col("customer.c_custkey")])?
820            .build()?;
821
822        assert_optimized_plan_equal!(
823            plan,
824            @r"
825        Projection: customer.c_custkey [c_custkey:Int64]
826          LeftSemi Join:  Filter: customer.c_custkey = __correlated_sq_1.o_custkey AND customer.c_custkey < __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]
827            TableScan: customer [c_custkey:Int64, c_name:Utf8]
828            SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]
829              Projection: orders.o_custkey [o_custkey:Int64]
830                TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
831        "
832        )
833    }
834
835    /// Test for correlated IN subquery filter with subquery disjunction
836    #[test]
837    fn in_subquery_with_subquery_disjunction() -> Result<()> {
838        let sq = Arc::new(
839            LogicalPlanBuilder::from(scan_tpch_table("orders"))
840                .filter(
841                    out_ref_col(DataType::Int64, "customer.c_custkey")
842                        .eq(col("orders.o_custkey"))
843                        .or(col("o_orderkey").eq(lit(1))),
844                )?
845                .project(vec![col("orders.o_custkey")])?
846                .build()?,
847        );
848
849        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
850            .filter(in_subquery(col("customer.c_custkey"), sq))?
851            .project(vec![col("customer.c_custkey")])?
852            .build()?;
853
854        assert_optimized_plan_equal!(
855            plan,
856            @r"
857        Projection: customer.c_custkey [c_custkey:Int64]
858          LeftSemi Join:  Filter: customer.c_custkey = __correlated_sq_1.o_custkey AND (customer.c_custkey = __correlated_sq_1.o_custkey OR __correlated_sq_1.o_orderkey = Int32(1)) [c_custkey:Int64, c_name:Utf8]
859            TableScan: customer [c_custkey:Int64, c_name:Utf8]
860            SubqueryAlias: __correlated_sq_1 [o_custkey:Int64, o_orderkey:Int64]
861              Projection: orders.o_custkey, orders.o_orderkey [o_custkey:Int64, o_orderkey:Int64]
862                TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
863        "
864        )
865    }
866
867    /// Test for correlated IN without projection
868    #[test]
869    fn in_subquery_no_projection() -> Result<()> {
870        let sq = Arc::new(
871            LogicalPlanBuilder::from(scan_tpch_table("orders"))
872                .filter(col("customer.c_custkey").eq(col("orders.o_custkey")))?
873                .build()?,
874        );
875
876        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
877            .filter(in_subquery(col("customer.c_custkey"), sq))?
878            .project(vec![col("customer.c_custkey")])?
879            .build()?;
880
881        // Maybe okay if the table only has a single column?
882        let expected = "Invalid (non-executable) plan after Analyzer\
883        \ncaused by\
884        \nError during planning: InSubquery should only return one column, but found 4";
885        assert_analyzer_check_err(vec![], plan, expected);
886
887        Ok(())
888    }
889
890    /// Test for correlated IN subquery join on expression
891    #[test]
892    fn in_subquery_join_expr() -> Result<()> {
893        let sq = Arc::new(
894            LogicalPlanBuilder::from(scan_tpch_table("orders"))
895                .filter(
896                    out_ref_col(DataType::Int64, "customer.c_custkey")
897                        .eq(col("orders.o_custkey")),
898                )?
899                .project(vec![col("orders.o_custkey")])?
900                .build()?,
901        );
902
903        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
904            .filter(in_subquery(col("customer.c_custkey").add(lit(1)), sq))?
905            .project(vec![col("customer.c_custkey")])?
906            .build()?;
907
908        assert_optimized_plan_equal!(
909            plan,
910            @r"
911        Projection: customer.c_custkey [c_custkey:Int64]
912          LeftSemi Join:  Filter: customer.c_custkey + Int32(1) = __correlated_sq_1.o_custkey AND customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]
913            TableScan: customer [c_custkey:Int64, c_name:Utf8]
914            SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]
915              Projection: orders.o_custkey [o_custkey:Int64]
916                TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
917        "
918        )
919    }
920
921    /// Test for correlated IN expressions
922    #[test]
923    fn in_subquery_project_expr() -> Result<()> {
924        let sq = Arc::new(
925            LogicalPlanBuilder::from(scan_tpch_table("orders"))
926                .filter(
927                    out_ref_col(DataType::Int64, "customer.c_custkey")
928                        .eq(col("orders.o_custkey")),
929                )?
930                .project(vec![col("orders.o_custkey").add(lit(1))])?
931                .build()?,
932        );
933
934        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
935            .filter(in_subquery(col("customer.c_custkey"), sq))?
936            .project(vec![col("customer.c_custkey")])?
937            .build()?;
938
939        assert_optimized_plan_equal!(
940            plan,
941            @r"
942        Projection: customer.c_custkey [c_custkey:Int64]
943          LeftSemi Join:  Filter: customer.c_custkey = __correlated_sq_1.orders.o_custkey + Int32(1) AND customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]
944            TableScan: customer [c_custkey:Int64, c_name:Utf8]
945            SubqueryAlias: __correlated_sq_1 [orders.o_custkey + Int32(1):Int64, o_custkey:Int64]
946              Projection: orders.o_custkey + Int32(1), orders.o_custkey [orders.o_custkey + Int32(1):Int64, o_custkey:Int64]
947                TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
948        "
949        )
950    }
951
952    /// Test for correlated IN subquery multiple projected columns
953    #[test]
954    fn in_subquery_multi_col() -> Result<()> {
955        let sq = Arc::new(
956            LogicalPlanBuilder::from(scan_tpch_table("orders"))
957                .filter(
958                    out_ref_col(DataType::Int64, "customer.c_custkey")
959                        .eq(col("orders.o_custkey")),
960                )?
961                .project(vec![col("orders.o_custkey"), col("orders.o_orderkey")])?
962                .build()?,
963        );
964
965        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
966            .filter(
967                in_subquery(col("customer.c_custkey"), sq)
968                    .and(col("c_custkey").eq(lit(1))),
969            )?
970            .project(vec![col("customer.c_custkey")])?
971            .build()?;
972
973        let expected = "Invalid (non-executable) plan after Analyzer\
974        \ncaused by\
975        \nError during planning: InSubquery should only return one column";
976        assert_analyzer_check_err(vec![], plan, expected);
977
978        Ok(())
979    }
980
981    /// Test for correlated IN subquery filter with additional filters
982    #[test]
983    fn should_support_additional_filters() -> Result<()> {
984        let sq = Arc::new(
985            LogicalPlanBuilder::from(scan_tpch_table("orders"))
986                .filter(
987                    out_ref_col(DataType::Int64, "customer.c_custkey")
988                        .eq(col("orders.o_custkey")),
989                )?
990                .project(vec![col("orders.o_custkey")])?
991                .build()?,
992        );
993
994        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
995            .filter(
996                in_subquery(col("customer.c_custkey"), sq)
997                    .and(col("c_custkey").eq(lit(1))),
998            )?
999            .project(vec![col("customer.c_custkey")])?
1000            .build()?;
1001
1002        assert_optimized_plan_equal!(
1003            plan,
1004            @r"
1005        Projection: customer.c_custkey [c_custkey:Int64]
1006          Filter: customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8]
1007            LeftSemi Join:  Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]
1008              TableScan: customer [c_custkey:Int64, c_name:Utf8]
1009              SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]
1010                Projection: orders.o_custkey [o_custkey:Int64]
1011                  TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1012        "
1013        )
1014    }
1015
1016    /// Test for correlated IN subquery filter
1017    #[test]
1018    fn in_subquery_correlated() -> Result<()> {
1019        let sq = Arc::new(
1020            LogicalPlanBuilder::from(test_table_scan_with_name("sq")?)
1021                .filter(out_ref_col(DataType::UInt32, "test.a").eq(col("sq.a")))?
1022                .project(vec![col("c")])?
1023                .build()?,
1024        );
1025
1026        let plan = LogicalPlanBuilder::from(test_table_scan_with_name("test")?)
1027            .filter(in_subquery(col("c"), sq))?
1028            .project(vec![col("test.b")])?
1029            .build()?;
1030
1031        assert_optimized_plan_equal!(
1032            plan,
1033            @r"
1034        Projection: test.b [b:UInt32]
1035          LeftSemi Join:  Filter: test.c = __correlated_sq_1.c AND test.a = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]
1036            TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1037            SubqueryAlias: __correlated_sq_1 [c:UInt32, a:UInt32]
1038              Projection: sq.c, sq.a [c:UInt32, a:UInt32]
1039                TableScan: sq [a:UInt32, b:UInt32, c:UInt32]
1040        "
1041        )
1042    }
1043
1044    /// Test for single IN subquery filter
1045    #[test]
1046    fn in_subquery_simple() -> Result<()> {
1047        let table_scan = test_table_scan()?;
1048        let plan = LogicalPlanBuilder::from(table_scan)
1049            .filter(in_subquery(col("c"), test_subquery_with_name("sq")?))?
1050            .project(vec![col("test.b")])?
1051            .build()?;
1052
1053        assert_optimized_plan_equal!(
1054            plan,
1055            @r"
1056        Projection: test.b [b:UInt32]
1057          LeftSemi Join:  Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]
1058            TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1059            SubqueryAlias: __correlated_sq_1 [c:UInt32]
1060              Projection: sq.c [c:UInt32]
1061                TableScan: sq [a:UInt32, b:UInt32, c:UInt32]
1062        "
1063        )
1064    }
1065
1066    /// Test for single NOT IN subquery filter
1067    #[test]
1068    fn not_in_subquery_simple() -> Result<()> {
1069        let table_scan = test_table_scan()?;
1070        let plan = LogicalPlanBuilder::from(table_scan)
1071            .filter(not_in_subquery(col("c"), test_subquery_with_name("sq")?))?
1072            .project(vec![col("test.b")])?
1073            .build()?;
1074
1075        assert_optimized_plan_equal!(
1076            plan,
1077            @r"
1078        Projection: test.b [b:UInt32]
1079          LeftAnti Join:  Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]
1080            TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1081            SubqueryAlias: __correlated_sq_1 [c:UInt32]
1082              Projection: sq.c [c:UInt32]
1083                TableScan: sq [a:UInt32, b:UInt32, c:UInt32]
1084        "
1085        )
1086    }
1087
1088    #[test]
1089    fn wrapped_not_in_subquery() -> Result<()> {
1090        let table_scan = test_table_scan()?;
1091        let plan = LogicalPlanBuilder::from(table_scan)
1092            .filter(not(in_subquery(col("c"), test_subquery_with_name("sq")?)))?
1093            .project(vec![col("test.b")])?
1094            .build()?;
1095
1096        assert_optimized_plan_equal!(
1097            plan,
1098            @r"
1099        Projection: test.b [b:UInt32]
1100          LeftAnti Join:  Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]
1101            TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1102            SubqueryAlias: __correlated_sq_1 [c:UInt32]
1103              Projection: sq.c [c:UInt32]
1104                TableScan: sq [a:UInt32, b:UInt32, c:UInt32]
1105        "
1106        )
1107    }
1108
1109    #[test]
1110    fn wrapped_not_not_in_subquery() -> Result<()> {
1111        let table_scan = test_table_scan()?;
1112        let plan = LogicalPlanBuilder::from(table_scan)
1113            .filter(not(not_in_subquery(
1114                col("c"),
1115                test_subquery_with_name("sq")?,
1116            )))?
1117            .project(vec![col("test.b")])?
1118            .build()?;
1119
1120        assert_optimized_plan_equal!(
1121            plan,
1122            @r"
1123        Projection: test.b [b:UInt32]
1124          LeftSemi Join:  Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]
1125            TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1126            SubqueryAlias: __correlated_sq_1 [c:UInt32]
1127              Projection: sq.c [c:UInt32]
1128                TableScan: sq [a:UInt32, b:UInt32, c:UInt32]
1129        "
1130        )
1131    }
1132
1133    #[test]
1134    fn in_subquery_both_side_expr() -> Result<()> {
1135        let table_scan = test_table_scan()?;
1136        let subquery_scan = test_table_scan_with_name("sq")?;
1137
1138        let subquery = LogicalPlanBuilder::from(subquery_scan)
1139            .project(vec![col("c") * lit(2u32)])?
1140            .build()?;
1141
1142        let plan = LogicalPlanBuilder::from(table_scan)
1143            .filter(in_subquery(col("c") + lit(1u32), Arc::new(subquery)))?
1144            .project(vec![col("test.b")])?
1145            .build()?;
1146
1147        assert_optimized_plan_equal!(
1148            plan,
1149            @r"
1150        Projection: test.b [b:UInt32]
1151          LeftSemi Join:  Filter: test.c + UInt32(1) = __correlated_sq_1.sq.c * UInt32(2) [a:UInt32, b:UInt32, c:UInt32]
1152            TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1153            SubqueryAlias: __correlated_sq_1 [sq.c * UInt32(2):UInt32]
1154              Projection: sq.c * UInt32(2) [sq.c * UInt32(2):UInt32]
1155                TableScan: sq [a:UInt32, b:UInt32, c:UInt32]
1156        "
1157        )
1158    }
1159
1160    #[test]
1161    fn in_subquery_join_filter_and_inner_filter() -> Result<()> {
1162        let table_scan = test_table_scan()?;
1163        let subquery_scan = test_table_scan_with_name("sq")?;
1164
1165        let subquery = LogicalPlanBuilder::from(subquery_scan)
1166            .filter(
1167                out_ref_col(DataType::UInt32, "test.a")
1168                    .eq(col("sq.a"))
1169                    .and(col("sq.a").add(lit(1u32)).eq(col("sq.b"))),
1170            )?
1171            .project(vec![col("c") * lit(2u32)])?
1172            .build()?;
1173
1174        let plan = LogicalPlanBuilder::from(table_scan)
1175            .filter(in_subquery(col("c") + lit(1u32), Arc::new(subquery)))?
1176            .project(vec![col("test.b")])?
1177            .build()?;
1178
1179        assert_optimized_plan_equal!(
1180            plan,
1181            @r"
1182        Projection: test.b [b:UInt32]
1183          LeftSemi Join:  Filter: test.c + UInt32(1) = __correlated_sq_1.sq.c * UInt32(2) AND test.a = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]
1184            TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1185            SubqueryAlias: __correlated_sq_1 [sq.c * UInt32(2):UInt32, a:UInt32]
1186              Projection: sq.c * UInt32(2), sq.a [sq.c * UInt32(2):UInt32, a:UInt32]
1187                Filter: sq.a + UInt32(1) = sq.b [a:UInt32, b:UInt32, c:UInt32]
1188                  TableScan: sq [a:UInt32, b:UInt32, c:UInt32]
1189        "
1190        )
1191    }
1192
1193    #[test]
1194    fn in_subquery_multi_project_subquery_cols() -> Result<()> {
1195        let table_scan = test_table_scan()?;
1196        let subquery_scan = test_table_scan_with_name("sq")?;
1197
1198        let subquery = LogicalPlanBuilder::from(subquery_scan)
1199            .filter(
1200                out_ref_col(DataType::UInt32, "test.a")
1201                    .add(out_ref_col(DataType::UInt32, "test.b"))
1202                    .eq(col("sq.a").add(col("sq.b")))
1203                    .and(col("sq.a").add(lit(1u32)).eq(col("sq.b"))),
1204            )?
1205            .project(vec![col("c") * lit(2u32)])?
1206            .build()?;
1207
1208        let plan = LogicalPlanBuilder::from(table_scan)
1209            .filter(in_subquery(col("c") + lit(1u32), Arc::new(subquery)))?
1210            .project(vec![col("test.b")])?
1211            .build()?;
1212
1213        assert_optimized_plan_equal!(
1214            plan,
1215            @r"
1216        Projection: test.b [b:UInt32]
1217          LeftSemi Join:  Filter: test.c + UInt32(1) = __correlated_sq_1.sq.c * UInt32(2) AND test.a + test.b = __correlated_sq_1.a + __correlated_sq_1.b [a:UInt32, b:UInt32, c:UInt32]
1218            TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1219            SubqueryAlias: __correlated_sq_1 [sq.c * UInt32(2):UInt32, a:UInt32, b:UInt32]
1220              Projection: sq.c * UInt32(2), sq.a, sq.b [sq.c * UInt32(2):UInt32, a:UInt32, b:UInt32]
1221                Filter: sq.a + UInt32(1) = sq.b [a:UInt32, b:UInt32, c:UInt32]
1222                  TableScan: sq [a:UInt32, b:UInt32, c:UInt32]
1223        "
1224        )
1225    }
1226
1227    #[test]
1228    fn two_in_subquery_with_outer_filter() -> Result<()> {
1229        let table_scan = test_table_scan()?;
1230        let subquery_scan1 = test_table_scan_with_name("sq1")?;
1231        let subquery_scan2 = test_table_scan_with_name("sq2")?;
1232
1233        let subquery1 = LogicalPlanBuilder::from(subquery_scan1)
1234            .filter(out_ref_col(DataType::UInt32, "test.a").gt(col("sq1.a")))?
1235            .project(vec![col("c") * lit(2u32)])?
1236            .build()?;
1237
1238        let subquery2 = LogicalPlanBuilder::from(subquery_scan2)
1239            .filter(out_ref_col(DataType::UInt32, "test.a").gt(col("sq2.a")))?
1240            .project(vec![col("c") * lit(2u32)])?
1241            .build()?;
1242
1243        let plan = LogicalPlanBuilder::from(table_scan)
1244            .filter(
1245                in_subquery(col("c") + lit(1u32), Arc::new(subquery1)).and(
1246                    in_subquery(col("c") * lit(2u32), Arc::new(subquery2))
1247                        .and(col("test.c").gt(lit(1u32))),
1248                ),
1249            )?
1250            .project(vec![col("test.b")])?
1251            .build()?;
1252
1253        assert_optimized_plan_equal!(
1254            plan,
1255            @r"
1256        Projection: test.b [b:UInt32]
1257          Filter: test.c > UInt32(1) [a:UInt32, b:UInt32, c:UInt32]
1258            LeftSemi Join:  Filter: test.c * UInt32(2) = __correlated_sq_2.sq2.c * UInt32(2) AND test.a > __correlated_sq_2.a [a:UInt32, b:UInt32, c:UInt32]
1259              LeftSemi Join:  Filter: test.c + UInt32(1) = __correlated_sq_1.sq1.c * UInt32(2) AND test.a > __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]
1260                TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1261                SubqueryAlias: __correlated_sq_1 [sq1.c * UInt32(2):UInt32, a:UInt32]
1262                  Projection: sq1.c * UInt32(2), sq1.a [sq1.c * UInt32(2):UInt32, a:UInt32]
1263                    TableScan: sq1 [a:UInt32, b:UInt32, c:UInt32]
1264              SubqueryAlias: __correlated_sq_2 [sq2.c * UInt32(2):UInt32, a:UInt32]
1265                Projection: sq2.c * UInt32(2), sq2.a [sq2.c * UInt32(2):UInt32, a:UInt32]
1266                  TableScan: sq2 [a:UInt32, b:UInt32, c:UInt32]
1267        "
1268        )
1269    }
1270
1271    #[test]
1272    fn in_subquery_with_same_table() -> Result<()> {
1273        let outer_scan = test_table_scan()?;
1274        let subquery_scan = test_table_scan()?;
1275        let subquery = LogicalPlanBuilder::from(subquery_scan)
1276            .filter(col("test.a").gt(col("test.b")))?
1277            .project(vec![col("c")])?
1278            .build()?;
1279
1280        let plan = LogicalPlanBuilder::from(outer_scan)
1281            .filter(in_subquery(col("test.a"), Arc::new(subquery)))?
1282            .project(vec![col("test.b")])?
1283            .build()?;
1284
1285        // Subquery and outer query refer to the same table.
1286        assert_optimized_plan_equal!(
1287            plan,
1288            @r"
1289        Projection: test.b [b:UInt32]
1290          LeftSemi Join:  Filter: test.a = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]
1291            TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1292            SubqueryAlias: __correlated_sq_1 [c:UInt32]
1293              Projection: test.c [c:UInt32]
1294                Filter: test.a > test.b [a:UInt32, b:UInt32, c:UInt32]
1295                  TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1296        "
1297        )
1298    }
1299
1300    /// Test for multiple exists subqueries in the same filter expression
1301    #[test]
1302    fn multiple_exists_subqueries() -> Result<()> {
1303        let orders = Arc::new(
1304            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1305                .filter(
1306                    col("orders.o_custkey")
1307                        .eq(out_ref_col(DataType::Int64, "customer.c_custkey")),
1308                )?
1309                .project(vec![col("orders.o_custkey")])?
1310                .build()?,
1311        );
1312
1313        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
1314            .filter(exists(Arc::clone(&orders)).and(exists(orders)))?
1315            .project(vec![col("customer.c_custkey")])?
1316            .build()?;
1317
1318        assert_optimized_plan_equal!(
1319            plan,
1320            @r"
1321        Projection: customer.c_custkey [c_custkey:Int64]
1322          LeftSemi Join:  Filter: __correlated_sq_2.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8]
1323            LeftSemi Join:  Filter: __correlated_sq_1.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8]
1324              TableScan: customer [c_custkey:Int64, c_name:Utf8]
1325              SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]
1326                Projection: orders.o_custkey [o_custkey:Int64]
1327                  TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1328            SubqueryAlias: __correlated_sq_2 [o_custkey:Int64]
1329              Projection: orders.o_custkey [o_custkey:Int64]
1330                TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1331        "
1332        )
1333    }
1334
1335    /// Test recursive correlated subqueries
1336    #[test]
1337    fn recursive_exists_subqueries() -> Result<()> {
1338        let lineitem = Arc::new(
1339            LogicalPlanBuilder::from(scan_tpch_table("lineitem"))
1340                .filter(
1341                    col("lineitem.l_orderkey")
1342                        .eq(out_ref_col(DataType::Int64, "orders.o_orderkey")),
1343                )?
1344                .project(vec![col("lineitem.l_orderkey")])?
1345                .build()?,
1346        );
1347
1348        let orders = Arc::new(
1349            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1350                .filter(
1351                    exists(lineitem).and(
1352                        col("orders.o_custkey")
1353                            .eq(out_ref_col(DataType::Int64, "customer.c_custkey")),
1354                    ),
1355                )?
1356                .project(vec![col("orders.o_custkey")])?
1357                .build()?,
1358        );
1359
1360        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
1361            .filter(exists(orders))?
1362            .project(vec![col("customer.c_custkey")])?
1363            .build()?;
1364
1365        assert_optimized_plan_equal!(
1366            plan,
1367            @r"
1368        Projection: customer.c_custkey [c_custkey:Int64]
1369          LeftSemi Join:  Filter: __correlated_sq_2.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8]
1370            TableScan: customer [c_custkey:Int64, c_name:Utf8]
1371            SubqueryAlias: __correlated_sq_2 [o_custkey:Int64]
1372              Projection: orders.o_custkey [o_custkey:Int64]
1373                LeftSemi Join:  Filter: __correlated_sq_1.l_orderkey = orders.o_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1374                  TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1375                  SubqueryAlias: __correlated_sq_1 [l_orderkey:Int64]
1376                    Projection: lineitem.l_orderkey [l_orderkey:Int64]
1377                      TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]
1378        "
1379        )
1380    }
1381
1382    /// Test for correlated exists subquery filter with additional subquery filters
1383    #[test]
1384    fn exists_subquery_with_subquery_filters() -> Result<()> {
1385        let sq = Arc::new(
1386            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1387                .filter(
1388                    out_ref_col(DataType::Int64, "customer.c_custkey")
1389                        .eq(col("orders.o_custkey"))
1390                        .and(col("o_orderkey").eq(lit(1))),
1391                )?
1392                .project(vec![col("orders.o_custkey")])?
1393                .build()?,
1394        );
1395
1396        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
1397            .filter(exists(sq))?
1398            .project(vec![col("customer.c_custkey")])?
1399            .build()?;
1400
1401        assert_optimized_plan_equal!(
1402            plan,
1403            @r"
1404        Projection: customer.c_custkey [c_custkey:Int64]
1405          LeftSemi Join:  Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]
1406            TableScan: customer [c_custkey:Int64, c_name:Utf8]
1407            SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]
1408              Projection: orders.o_custkey [o_custkey:Int64]
1409                Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1410                  TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1411        "
1412        )
1413    }
1414
1415    #[test]
1416    fn exists_subquery_no_cols() -> Result<()> {
1417        let sq = Arc::new(
1418            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1419                .filter(out_ref_col(DataType::Int64, "customer.c_custkey").eq(lit(1u32)))?
1420                .project(vec![col("orders.o_custkey")])?
1421                .build()?,
1422        );
1423
1424        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
1425            .filter(exists(sq))?
1426            .project(vec![col("customer.c_custkey")])?
1427            .build()?;
1428
1429        // Other rule will pushdown `customer.c_custkey = 1`,
1430        assert_optimized_plan_equal!(
1431            plan,
1432            @r"
1433        Projection: customer.c_custkey [c_custkey:Int64]
1434          LeftSemi Join:  Filter: customer.c_custkey = UInt32(1) [c_custkey:Int64, c_name:Utf8]
1435            TableScan: customer [c_custkey:Int64, c_name:Utf8]
1436            SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]
1437              Projection: orders.o_custkey [o_custkey:Int64]
1438                TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1439        "
1440        )
1441    }
1442
1443    /// Test for exists subquery with both columns in schema
1444    #[test]
1445    fn exists_subquery_with_no_correlated_cols() -> Result<()> {
1446        let sq = Arc::new(
1447            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1448                .filter(col("orders.o_custkey").eq(col("orders.o_custkey")))?
1449                .project(vec![col("orders.o_custkey")])?
1450                .build()?,
1451        );
1452
1453        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
1454            .filter(exists(sq))?
1455            .project(vec![col("customer.c_custkey")])?
1456            .build()?;
1457
1458        assert_optimized_plan_equal!(
1459            plan,
1460            @r"
1461        Projection: customer.c_custkey [c_custkey:Int64]
1462          LeftSemi Join:  Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8]
1463            TableScan: customer [c_custkey:Int64, c_name:Utf8]
1464            SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]
1465              Projection: orders.o_custkey [o_custkey:Int64]
1466                Filter: orders.o_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1467                  TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1468        "
1469        )
1470    }
1471
1472    /// Test for correlated exists subquery not equal
1473    #[test]
1474    fn exists_subquery_where_not_eq() -> Result<()> {
1475        let sq = Arc::new(
1476            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1477                .filter(
1478                    out_ref_col(DataType::Int64, "customer.c_custkey")
1479                        .not_eq(col("orders.o_custkey")),
1480                )?
1481                .project(vec![col("orders.o_custkey")])?
1482                .build()?,
1483        );
1484
1485        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
1486            .filter(exists(sq))?
1487            .project(vec![col("customer.c_custkey")])?
1488            .build()?;
1489
1490        assert_optimized_plan_equal!(
1491            plan,
1492            @r"
1493        Projection: customer.c_custkey [c_custkey:Int64]
1494          LeftSemi Join:  Filter: customer.c_custkey != __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]
1495            TableScan: customer [c_custkey:Int64, c_name:Utf8]
1496            SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]
1497              Projection: orders.o_custkey [o_custkey:Int64]
1498                TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1499        "
1500        )
1501    }
1502
1503    /// Test for correlated exists subquery less than
1504    #[test]
1505    fn exists_subquery_where_less_than() -> Result<()> {
1506        let sq = Arc::new(
1507            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1508                .filter(
1509                    out_ref_col(DataType::Int64, "customer.c_custkey")
1510                        .lt(col("orders.o_custkey")),
1511                )?
1512                .project(vec![col("orders.o_custkey")])?
1513                .build()?,
1514        );
1515
1516        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
1517            .filter(exists(sq))?
1518            .project(vec![col("customer.c_custkey")])?
1519            .build()?;
1520
1521        assert_optimized_plan_equal!(
1522            plan,
1523            @r"
1524        Projection: customer.c_custkey [c_custkey:Int64]
1525          LeftSemi Join:  Filter: customer.c_custkey < __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]
1526            TableScan: customer [c_custkey:Int64, c_name:Utf8]
1527            SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]
1528              Projection: orders.o_custkey [o_custkey:Int64]
1529                TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1530        "
1531        )
1532    }
1533
1534    /// Test for correlated exists subquery filter with subquery disjunction
1535    #[test]
1536    fn exists_subquery_with_subquery_disjunction() -> Result<()> {
1537        let sq = Arc::new(
1538            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1539                .filter(
1540                    out_ref_col(DataType::Int64, "customer.c_custkey")
1541                        .eq(col("orders.o_custkey"))
1542                        .or(col("o_orderkey").eq(lit(1))),
1543                )?
1544                .project(vec![col("orders.o_custkey")])?
1545                .build()?,
1546        );
1547
1548        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
1549            .filter(exists(sq))?
1550            .project(vec![col("customer.c_custkey")])?
1551            .build()?;
1552
1553        assert_optimized_plan_equal!(
1554            plan,
1555            @r"
1556        Projection: customer.c_custkey [c_custkey:Int64]
1557          LeftSemi Join:  Filter: customer.c_custkey = __correlated_sq_1.o_custkey OR __correlated_sq_1.o_orderkey = Int32(1) [c_custkey:Int64, c_name:Utf8]
1558            TableScan: customer [c_custkey:Int64, c_name:Utf8]
1559            SubqueryAlias: __correlated_sq_1 [o_custkey:Int64, o_orderkey:Int64]
1560              Projection: orders.o_custkey, orders.o_orderkey [o_custkey:Int64, o_orderkey:Int64]
1561                TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1562        "
1563        )
1564    }
1565
1566    /// Test for correlated exists without projection
1567    #[test]
1568    fn exists_subquery_no_projection() -> Result<()> {
1569        let sq = Arc::new(
1570            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1571                .filter(
1572                    out_ref_col(DataType::Int64, "customer.c_custkey")
1573                        .eq(col("orders.o_custkey")),
1574                )?
1575                .build()?,
1576        );
1577
1578        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
1579            .filter(exists(sq))?
1580            .project(vec![col("customer.c_custkey")])?
1581            .build()?;
1582
1583        assert_optimized_plan_equal!(
1584            plan,
1585            @r"
1586        Projection: customer.c_custkey [c_custkey:Int64]
1587          LeftSemi Join:  Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]
1588            TableScan: customer [c_custkey:Int64, c_name:Utf8]
1589            SubqueryAlias: __correlated_sq_1 [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1590              TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1591        "
1592        )
1593    }
1594
1595    /// Test for correlated exists expressions
1596    #[test]
1597    fn exists_subquery_project_expr() -> Result<()> {
1598        let sq = Arc::new(
1599            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1600                .filter(
1601                    out_ref_col(DataType::Int64, "customer.c_custkey")
1602                        .eq(col("orders.o_custkey")),
1603                )?
1604                .project(vec![col("orders.o_custkey").add(lit(1))])?
1605                .build()?,
1606        );
1607
1608        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
1609            .filter(exists(sq))?
1610            .project(vec![col("customer.c_custkey")])?
1611            .build()?;
1612
1613        assert_optimized_plan_equal!(
1614            plan,
1615            @r"
1616        Projection: customer.c_custkey [c_custkey:Int64]
1617          LeftSemi Join:  Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]
1618            TableScan: customer [c_custkey:Int64, c_name:Utf8]
1619            SubqueryAlias: __correlated_sq_1 [orders.o_custkey + Int32(1):Int64, o_custkey:Int64]
1620              Projection: orders.o_custkey + Int32(1), orders.o_custkey [orders.o_custkey + Int32(1):Int64, o_custkey:Int64]
1621                TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1622        "
1623        )
1624    }
1625
1626    /// Test for correlated exists subquery filter with additional filters
1627    #[test]
1628    fn exists_subquery_should_support_additional_filters() -> Result<()> {
1629        let sq = Arc::new(
1630            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1631                .filter(
1632                    out_ref_col(DataType::Int64, "customer.c_custkey")
1633                        .eq(col("orders.o_custkey")),
1634                )?
1635                .project(vec![col("orders.o_custkey")])?
1636                .build()?,
1637        );
1638        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
1639            .filter(exists(sq).and(col("c_custkey").eq(lit(1))))?
1640            .project(vec![col("customer.c_custkey")])?
1641            .build()?;
1642
1643        assert_optimized_plan_equal!(
1644            plan,
1645            @r"
1646        Projection: customer.c_custkey [c_custkey:Int64]
1647          Filter: customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8]
1648            LeftSemi Join:  Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]
1649              TableScan: customer [c_custkey:Int64, c_name:Utf8]
1650              SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]
1651                Projection: orders.o_custkey [o_custkey:Int64]
1652                  TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1653        "
1654        )
1655    }
1656
1657    /// Test for correlated exists subquery filter with disjunctions
1658    #[test]
1659    fn exists_subquery_disjunction() -> Result<()> {
1660        let sq = Arc::new(
1661            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1662                .filter(col("customer.c_custkey").eq(col("orders.o_custkey")))?
1663                .project(vec![col("orders.o_custkey")])?
1664                .build()?,
1665        );
1666
1667        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
1668            .filter(exists(sq).or(col("customer.c_custkey").eq(lit(1))))?
1669            .project(vec![col("customer.c_custkey")])?
1670            .build()?;
1671
1672        assert_optimized_plan_equal!(
1673            plan,
1674            @r"
1675        Projection: customer.c_custkey [c_custkey:Int64]
1676          Filter: __correlated_sq_1.mark OR customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, mark:Boolean]
1677            LeftMark Join:  Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, mark:Boolean]
1678              TableScan: customer [c_custkey:Int64, c_name:Utf8]
1679              SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]
1680                Projection: orders.o_custkey [o_custkey:Int64]
1681                  Filter: customer.c_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1682                    TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1683        "
1684        )
1685    }
1686
1687    /// Test for correlated EXISTS subquery filter
1688    #[test]
1689    fn exists_subquery_correlated() -> Result<()> {
1690        let sq = Arc::new(
1691            LogicalPlanBuilder::from(test_table_scan_with_name("sq")?)
1692                .filter(out_ref_col(DataType::UInt32, "test.a").eq(col("sq.a")))?
1693                .project(vec![col("c")])?
1694                .build()?,
1695        );
1696
1697        let plan = LogicalPlanBuilder::from(test_table_scan_with_name("test")?)
1698            .filter(exists(sq))?
1699            .project(vec![col("test.c")])?
1700            .build()?;
1701
1702        assert_optimized_plan_equal!(
1703            plan,
1704            @r"
1705        Projection: test.c [c:UInt32]
1706          LeftSemi Join:  Filter: test.a = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]
1707            TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1708            SubqueryAlias: __correlated_sq_1 [c:UInt32, a:UInt32]
1709              Projection: sq.c, sq.a [c:UInt32, a:UInt32]
1710                TableScan: sq [a:UInt32, b:UInt32, c:UInt32]
1711        "
1712        )
1713    }
1714
1715    /// Test for single exists subquery filter
1716    #[test]
1717    fn exists_subquery_simple() -> Result<()> {
1718        let table_scan = test_table_scan()?;
1719        let plan = LogicalPlanBuilder::from(table_scan)
1720            .filter(exists(test_subquery_with_name("sq")?))?
1721            .project(vec![col("test.b")])?
1722            .build()?;
1723
1724        assert_optimized_plan_equal!(
1725            plan,
1726            @r"
1727        Projection: test.b [b:UInt32]
1728          LeftSemi Join:  Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32]
1729            TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1730            SubqueryAlias: __correlated_sq_1 [c:UInt32]
1731              Projection: sq.c [c:UInt32]
1732                TableScan: sq [a:UInt32, b:UInt32, c:UInt32]
1733        "
1734        )
1735    }
1736
1737    /// Test for single NOT exists subquery filter
1738    #[test]
1739    fn not_exists_subquery_simple() -> Result<()> {
1740        let table_scan = test_table_scan()?;
1741        let plan = LogicalPlanBuilder::from(table_scan)
1742            .filter(not_exists(test_subquery_with_name("sq")?))?
1743            .project(vec![col("test.b")])?
1744            .build()?;
1745
1746        assert_optimized_plan_equal!(
1747            plan,
1748            @r"
1749        Projection: test.b [b:UInt32]
1750          LeftAnti Join:  Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32]
1751            TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1752            SubqueryAlias: __correlated_sq_1 [c:UInt32]
1753              Projection: sq.c [c:UInt32]
1754                TableScan: sq [a:UInt32, b:UInt32, c:UInt32]
1755        "
1756        )
1757    }
1758
1759    #[test]
1760    fn two_exists_subquery_with_outer_filter() -> Result<()> {
1761        let table_scan = test_table_scan()?;
1762        let subquery_scan1 = test_table_scan_with_name("sq1")?;
1763        let subquery_scan2 = test_table_scan_with_name("sq2")?;
1764
1765        let subquery1 = LogicalPlanBuilder::from(subquery_scan1)
1766            .filter(out_ref_col(DataType::UInt32, "test.a").eq(col("sq1.a")))?
1767            .project(vec![col("c")])?
1768            .build()?;
1769
1770        let subquery2 = LogicalPlanBuilder::from(subquery_scan2)
1771            .filter(out_ref_col(DataType::UInt32, "test.a").eq(col("sq2.a")))?
1772            .project(vec![col("c")])?
1773            .build()?;
1774
1775        let plan = LogicalPlanBuilder::from(table_scan)
1776            .filter(
1777                exists(Arc::new(subquery1))
1778                    .and(exists(Arc::new(subquery2)).and(col("test.c").gt(lit(1u32)))),
1779            )?
1780            .project(vec![col("test.b")])?
1781            .build()?;
1782
1783        assert_optimized_plan_equal!(
1784            plan,
1785            @r"
1786        Projection: test.b [b:UInt32]
1787          Filter: test.c > UInt32(1) [a:UInt32, b:UInt32, c:UInt32]
1788            LeftSemi Join:  Filter: test.a = __correlated_sq_2.a [a:UInt32, b:UInt32, c:UInt32]
1789              LeftSemi Join:  Filter: test.a = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]
1790                TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1791                SubqueryAlias: __correlated_sq_1 [c:UInt32, a:UInt32]
1792                  Projection: sq1.c, sq1.a [c:UInt32, a:UInt32]
1793                    TableScan: sq1 [a:UInt32, b:UInt32, c:UInt32]
1794              SubqueryAlias: __correlated_sq_2 [c:UInt32, a:UInt32]
1795                Projection: sq2.c, sq2.a [c:UInt32, a:UInt32]
1796                  TableScan: sq2 [a:UInt32, b:UInt32, c:UInt32]
1797        "
1798        )
1799    }
1800
1801    #[test]
1802    fn exists_subquery_expr_filter() -> Result<()> {
1803        let table_scan = test_table_scan()?;
1804        let subquery_scan = test_table_scan_with_name("sq")?;
1805        let subquery = LogicalPlanBuilder::from(subquery_scan)
1806            .filter(
1807                (lit(1u32) + col("sq.a"))
1808                    .gt(out_ref_col(DataType::UInt32, "test.a") * lit(2u32)),
1809            )?
1810            .project(vec![lit(1u32)])?
1811            .build()?;
1812        let plan = LogicalPlanBuilder::from(table_scan)
1813            .filter(exists(Arc::new(subquery)))?
1814            .project(vec![col("test.b")])?
1815            .build()?;
1816
1817        assert_optimized_plan_equal!(
1818            plan,
1819            @r"
1820        Projection: test.b [b:UInt32]
1821          LeftSemi Join:  Filter: UInt32(1) + __correlated_sq_1.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32]
1822            TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1823            SubqueryAlias: __correlated_sq_1 [UInt32(1):UInt32, a:UInt32]
1824              Projection: UInt32(1), sq.a [UInt32(1):UInt32, a:UInt32]
1825                TableScan: sq [a:UInt32, b:UInt32, c:UInt32]
1826        "
1827        )
1828    }
1829
1830    #[test]
1831    fn exists_subquery_with_same_table() -> Result<()> {
1832        let outer_scan = test_table_scan()?;
1833        let subquery_scan = test_table_scan()?;
1834        let subquery = LogicalPlanBuilder::from(subquery_scan)
1835            .filter(col("test.a").gt(col("test.b")))?
1836            .project(vec![col("c")])?
1837            .build()?;
1838
1839        let plan = LogicalPlanBuilder::from(outer_scan)
1840            .filter(exists(Arc::new(subquery)))?
1841            .project(vec![col("test.b")])?
1842            .build()?;
1843
1844        // Subquery and outer query refer to the same table.
1845        assert_optimized_plan_equal!(
1846            plan,
1847            @r"
1848        Projection: test.b [b:UInt32]
1849          LeftSemi Join:  Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32]
1850            TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1851            SubqueryAlias: __correlated_sq_1 [c:UInt32]
1852              Projection: test.c [c:UInt32]
1853                Filter: test.a > test.b [a:UInt32, b:UInt32, c:UInt32]
1854                  TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1855        "
1856        )
1857    }
1858
1859    #[test]
1860    fn exists_distinct_subquery() -> Result<()> {
1861        let table_scan = test_table_scan()?;
1862        let subquery_scan = test_table_scan_with_name("sq")?;
1863        let subquery = LogicalPlanBuilder::from(subquery_scan)
1864            .filter(
1865                (lit(1u32) + col("sq.a"))
1866                    .gt(out_ref_col(DataType::UInt32, "test.a") * lit(2u32)),
1867            )?
1868            .project(vec![col("sq.c")])?
1869            .distinct()?
1870            .build()?;
1871        let plan = LogicalPlanBuilder::from(table_scan)
1872            .filter(exists(Arc::new(subquery)))?
1873            .project(vec![col("test.b")])?
1874            .build()?;
1875
1876        assert_optimized_plan_equal!(
1877            plan,
1878            @r"
1879        Projection: test.b [b:UInt32]
1880          LeftSemi Join:  Filter: UInt32(1) + __correlated_sq_1.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32]
1881            TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1882            SubqueryAlias: __correlated_sq_1 [c:UInt32, a:UInt32]
1883              Distinct: [c:UInt32, a:UInt32]
1884                Projection: sq.c, sq.a [c:UInt32, a:UInt32]
1885                  TableScan: sq [a:UInt32, b:UInt32, c:UInt32]
1886        "
1887        )
1888    }
1889
1890    #[test]
1891    fn exists_distinct_expr_subquery() -> Result<()> {
1892        let table_scan = test_table_scan()?;
1893        let subquery_scan = test_table_scan_with_name("sq")?;
1894        let subquery = LogicalPlanBuilder::from(subquery_scan)
1895            .filter(
1896                (lit(1u32) + col("sq.a"))
1897                    .gt(out_ref_col(DataType::UInt32, "test.a") * lit(2u32)),
1898            )?
1899            .project(vec![col("sq.b") + col("sq.c")])?
1900            .distinct()?
1901            .build()?;
1902        let plan = LogicalPlanBuilder::from(table_scan)
1903            .filter(exists(Arc::new(subquery)))?
1904            .project(vec![col("test.b")])?
1905            .build()?;
1906
1907        assert_optimized_plan_equal!(
1908            plan,
1909            @r"
1910        Projection: test.b [b:UInt32]
1911          LeftSemi Join:  Filter: UInt32(1) + __correlated_sq_1.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32]
1912            TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1913            SubqueryAlias: __correlated_sq_1 [sq.b + sq.c:UInt32, a:UInt32]
1914              Distinct: [sq.b + sq.c:UInt32, a:UInt32]
1915                Projection: sq.b + sq.c, sq.a [sq.b + sq.c:UInt32, a:UInt32]
1916                  TableScan: sq [a:UInt32, b:UInt32, c:UInt32]
1917        "
1918        )
1919    }
1920
1921    #[test]
1922    fn exists_distinct_subquery_with_literal() -> Result<()> {
1923        let table_scan = test_table_scan()?;
1924        let subquery_scan = test_table_scan_with_name("sq")?;
1925        let subquery = LogicalPlanBuilder::from(subquery_scan)
1926            .filter(
1927                (lit(1u32) + col("sq.a"))
1928                    .gt(out_ref_col(DataType::UInt32, "test.a") * lit(2u32)),
1929            )?
1930            .project(vec![lit(1u32), col("sq.c")])?
1931            .distinct()?
1932            .build()?;
1933        let plan = LogicalPlanBuilder::from(table_scan)
1934            .filter(exists(Arc::new(subquery)))?
1935            .project(vec![col("test.b")])?
1936            .build()?;
1937
1938        assert_optimized_plan_equal!(
1939            plan,
1940            @r"
1941        Projection: test.b [b:UInt32]
1942          LeftSemi Join:  Filter: UInt32(1) + __correlated_sq_1.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32]
1943            TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1944            SubqueryAlias: __correlated_sq_1 [UInt32(1):UInt32, c:UInt32, a:UInt32]
1945              Distinct: [UInt32(1):UInt32, c:UInt32, a:UInt32]
1946                Projection: UInt32(1), sq.c, sq.a [UInt32(1):UInt32, c:UInt32, a:UInt32]
1947                  TableScan: sq [a:UInt32, b:UInt32, c:UInt32]
1948        "
1949        )
1950    }
1951
1952    #[test]
1953    fn exists_uncorrelated_unnest() -> Result<()> {
1954        let subquery_table_source = table_source(&Schema::new(vec![Field::new(
1955            "arr",
1956            DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))),
1957            true,
1958        )]));
1959        let subquery = LogicalPlanBuilder::scan_with_filters(
1960            "sq",
1961            subquery_table_source,
1962            None,
1963            vec![],
1964        )?
1965        .unnest_column("arr")?
1966        .build()?;
1967        let table_scan = test_table_scan()?;
1968        let plan = LogicalPlanBuilder::from(table_scan)
1969            .filter(exists(Arc::new(subquery)))?
1970            .project(vec![col("test.b")])?
1971            .build()?;
1972
1973        assert_optimized_plan_equal!(
1974            plan,
1975            @r"
1976        Projection: test.b [b:UInt32]
1977          LeftSemi Join:  Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32]
1978            TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1979            SubqueryAlias: __correlated_sq_1 [arr:Int32;N]
1980              Unnest: lists[sq.arr|depth=1] structs[] [arr:Int32;N]
1981                TableScan: sq [arr:List(Field { data_type: Int32, nullable: true });N]
1982        "
1983        )
1984    }
1985
1986    #[test]
1987    fn exists_correlated_unnest() -> Result<()> {
1988        let table_scan = test_table_scan()?;
1989        let subquery_table_source = table_source(&Schema::new(vec![Field::new(
1990            "a",
1991            DataType::List(Arc::new(Field::new_list_field(DataType::UInt32, true))),
1992            true,
1993        )]));
1994        let subquery = LogicalPlanBuilder::scan_with_filters(
1995            "sq",
1996            subquery_table_source,
1997            None,
1998            vec![],
1999        )?
2000        .unnest_column("a")?
2001        .filter(col("a").eq(out_ref_col(DataType::UInt32, "test.b")))?
2002        .build()?;
2003        let plan = LogicalPlanBuilder::from(table_scan)
2004            .filter(exists(Arc::new(subquery)))?
2005            .project(vec![col("test.b")])?
2006            .build()?;
2007
2008        assert_optimized_plan_equal!(
2009            plan,
2010            @r"
2011        Projection: test.b [b:UInt32]
2012          LeftSemi Join:  Filter: __correlated_sq_1.a = test.b [a:UInt32, b:UInt32, c:UInt32]
2013            TableScan: test [a:UInt32, b:UInt32, c:UInt32]
2014            SubqueryAlias: __correlated_sq_1 [a:UInt32;N]
2015              Unnest: lists[sq.a|depth=1] structs[] [a:UInt32;N]
2016                TableScan: sq [a:List(Field { data_type: UInt32, nullable: true });N]
2017        "
2018        )
2019    }
2020
2021    #[test]
2022    fn upper_case_ident() -> Result<()> {
2023        let fields = vec![
2024            Field::new("A", DataType::UInt32, false),
2025            Field::new("B", DataType::UInt32, false),
2026        ];
2027
2028        let schema = Schema::new(fields);
2029        let table_scan_a = table_scan(Some("\"TEST_A\""), &schema, None)?.build()?;
2030        let table_scan_b = table_scan(Some("\"TEST_B\""), &schema, None)?.build()?;
2031
2032        let subquery = LogicalPlanBuilder::from(table_scan_b)
2033            .filter(col("\"A\"").eq(out_ref_col(DataType::UInt32, "\"TEST_A\".\"A\"")))?
2034            .project(vec![lit(1)])?
2035            .build()?;
2036
2037        let plan = LogicalPlanBuilder::from(table_scan_a)
2038            .filter(exists(Arc::new(subquery)))?
2039            .project(vec![col("\"TEST_A\".\"B\"")])?
2040            .build()?;
2041
2042        assert_optimized_plan_equal!(
2043            plan,
2044            @r"
2045        Projection: TEST_A.B [B:UInt32]
2046          LeftSemi Join:  Filter: __correlated_sq_1.A = TEST_A.A [A:UInt32, B:UInt32]
2047            TableScan: TEST_A [A:UInt32, B:UInt32]
2048            SubqueryAlias: __correlated_sq_1 [Int32(1):Int32, A:UInt32]
2049              Projection: Int32(1), TEST_B.A [Int32(1):Int32, A:UInt32]
2050                TableScan: TEST_B [A:UInt32, B:UInt32]
2051        "
2052        )
2053    }
2054}