Skip to main content

datafusion_optimizer/
decorrelate_predicate_subquery.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`DecorrelatePredicateSubquery`] converts `IN`/`EXISTS` subquery predicates to `SEMI`/`ANTI` joins
19use std::collections::BTreeSet;
20use std::ops::Deref;
21use std::sync::Arc;
22
23use crate::decorrelate::PullUpCorrelatedExpr;
24use crate::optimizer::ApplyOrder;
25use crate::utils::replace_qualified_name;
26use crate::{OptimizerConfig, OptimizerRule};
27
28use datafusion_common::alias::AliasGenerator;
29use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
30use datafusion_common::{
31    Column, DFSchemaRef, ExprSchema, NullEquality, Result, assert_or_internal_err,
32    plan_err,
33};
34use datafusion_expr::expr::{Exists, InSubquery};
35use datafusion_expr::expr_rewriter::create_col_from_scalar_expr;
36use datafusion_expr::logical_plan::{JoinType, Subquery};
37use datafusion_expr::utils::{conjunction, expr_to_columns, split_conjunction_owned};
38use datafusion_expr::{
39    BinaryExpr, Expr, Filter, LogicalPlan, LogicalPlanBuilder, Operator, exists,
40    in_subquery, lit, not, not_exists, not_in_subquery,
41};
42
43use log::debug;
44
/// Optimizer rule that rewrites `IN` / `EXISTS` subquery predicates into joins.
///
/// Top-level subquery predicates of a `Filter` become left semi / anti joins; subqueries
/// nested inside larger boolean expressions are rewritten through left mark joins.
#[derive(Default, Debug)]
pub struct DecorrelatePredicateSubquery {}

impl DecorrelatePredicateSubquery {
    /// Creates a new [`DecorrelatePredicateSubquery`] optimizer rule.
    pub fn new() -> Self {
        Self::default()
    }
}
55
56impl OptimizerRule for DecorrelatePredicateSubquery {
57    fn supports_rewrite(&self) -> bool {
58        true
59    }
60
61    fn rewrite(
62        &self,
63        plan: LogicalPlan,
64        config: &dyn OptimizerConfig,
65    ) -> Result<Transformed<LogicalPlan>> {
66        let plan = plan
67            .map_subqueries(|subquery| {
68                subquery.transform_down(|p| self.rewrite(p, config))
69            })?
70            .data;
71
72        let LogicalPlan::Filter(filter) = plan else {
73            return Ok(Transformed::no(plan));
74        };
75
76        if !has_subquery(&filter.predicate) {
77            return Ok(Transformed::no(LogicalPlan::Filter(filter)));
78        }
79
80        let (with_subqueries, mut other_exprs): (Vec<_>, Vec<_>) =
81            split_conjunction_owned(filter.predicate)
82                .into_iter()
83                .partition(has_subquery);
84
85        assert_or_internal_err!(
86            !with_subqueries.is_empty(),
87            "can not find expected subqueries in DecorrelatePredicateSubquery"
88        );
89
90        // iterate through all exists clauses in predicate, turning each into a join
91        let mut cur_input = Arc::unwrap_or_clone(filter.input);
92        for subquery_expr in with_subqueries {
93            match extract_subquery_info(subquery_expr) {
94                // The subquery expression is at the top level of the filter
95                SubqueryPredicate::Top(subquery) => {
96                    match build_join_top(&subquery, &cur_input, config.alias_generator())?
97                    {
98                        Some(plan) => cur_input = plan,
99                        // If the subquery can not be converted to a Join, reconstruct the subquery expression and add it to the Filter
100                        None => other_exprs.push(subquery.expr()),
101                    }
102                }
103                // The subquery expression is embedded within another expression
104                SubqueryPredicate::Embedded(expr) => {
105                    let (plan, expr_without_subqueries) =
106                        rewrite_inner_subqueries(cur_input, expr, config)?;
107                    cur_input = plan;
108                    other_exprs.push(expr_without_subqueries);
109                }
110            }
111        }
112
113        let expr = conjunction(other_exprs);
114        if let Some(expr) = expr {
115            let new_filter = Filter::try_new(expr, Arc::new(cur_input))?;
116            cur_input = LogicalPlan::Filter(new_filter);
117        }
118        Ok(Transformed::yes(cur_input))
119    }
120
121    fn name(&self) -> &str {
122        "decorrelate_predicate_subquery"
123    }
124
125    fn apply_order(&self) -> Option<ApplyOrder> {
126        Some(ApplyOrder::TopDown)
127    }
128}
129
130fn rewrite_inner_subqueries(
131    outer: LogicalPlan,
132    expr: Expr,
133    config: &dyn OptimizerConfig,
134) -> Result<(LogicalPlan, Expr)> {
135    let mut cur_input = outer;
136    let alias = config.alias_generator();
137    let expr_without_subqueries = expr.transform(|e| match e {
138        Expr::Exists(Exists {
139            subquery: Subquery { subquery, .. },
140            negated,
141        }) => match mark_join(&cur_input, &subquery, None, negated, alias)? {
142            Some((plan, exists_expr)) => {
143                cur_input = plan;
144                Ok(Transformed::yes(exists_expr))
145            }
146            None if negated => Ok(Transformed::no(not_exists(subquery))),
147            None => Ok(Transformed::no(exists(subquery))),
148        },
149        Expr::InSubquery(InSubquery {
150            expr,
151            subquery: Subquery { subquery, .. },
152            negated,
153        }) => {
154            let in_predicate = subquery
155                .head_output_expr()?
156                .map_or(plan_err!("single expression required."), |output_expr| {
157                    Ok(Expr::eq(*expr.clone(), output_expr))
158                })?;
159            match mark_join(&cur_input, &subquery, Some(&in_predicate), negated, alias)? {
160                Some((plan, exists_expr)) => {
161                    cur_input = plan;
162                    Ok(Transformed::yes(exists_expr))
163                }
164                None if negated => Ok(Transformed::no(not_in_subquery(*expr, subquery))),
165                None => Ok(Transformed::no(in_subquery(*expr, subquery))),
166            }
167        }
168        _ => Ok(Transformed::no(e)),
169    })?;
170    Ok((cur_input, expr_without_subqueries.data))
171}
172
173enum SubqueryPredicate {
174    // The subquery expression is at the top level of the filter and can be fully replaced by a
175    // semi/anti join
176    Top(SubqueryInfo),
177    // The subquery expression is embedded within another expression and is replaced using an
178    // existence join
179    Embedded(Expr),
180}
181
182fn extract_subquery_info(expr: Expr) -> SubqueryPredicate {
183    match expr {
184        Expr::Not(not_expr) => match *not_expr {
185            Expr::InSubquery(InSubquery {
186                expr,
187                subquery,
188                negated,
189            }) => SubqueryPredicate::Top(SubqueryInfo::new_with_in_expr(
190                subquery, *expr, !negated,
191            )),
192            Expr::Exists(Exists { subquery, negated }) => {
193                SubqueryPredicate::Top(SubqueryInfo::new(subquery, !negated))
194            }
195            expr => SubqueryPredicate::Embedded(not(expr)),
196        },
197        Expr::InSubquery(InSubquery {
198            expr,
199            subquery,
200            negated,
201        }) => SubqueryPredicate::Top(SubqueryInfo::new_with_in_expr(
202            subquery, *expr, negated,
203        )),
204        Expr::Exists(Exists { subquery, negated }) => {
205            SubqueryPredicate::Top(SubqueryInfo::new(subquery, negated))
206        }
207        expr => SubqueryPredicate::Embedded(expr),
208    }
209}
210
211fn has_subquery(expr: &Expr) -> bool {
212    expr.exists(|e| match e {
213        Expr::InSubquery(_) | Expr::Exists(_) => Ok(true),
214        _ => Ok(false),
215    })
216    .unwrap()
217}
218
219/// Optimize the subquery to left-anti/left-semi join.
220/// If the subquery is a correlated subquery, we need extract the join predicate from the subquery.
221///
222/// For example, given a query like:
223/// `select t1.a, t1.b from t1 where t1 in (select t2.a from t2 where t1.b = t2.b and t1.c > t2.c)`
224///
225/// The optimized plan will be:
226///
227/// ```text
228/// Projection: t1.a, t1.b
229///   LeftSemi Join:  Filter: t1.a = __correlated_sq_1.a AND t1.b = __correlated_sq_1.b AND t1.c > __correlated_sq_1.c
230///     TableScan: t1
231///     SubqueryAlias: __correlated_sq_1
232///       Projection: t2.a, t2.b, t2.c
233///         TableScan: t2
234/// ```
235///
236/// Given another query like:
237/// `select t1.id from t1 where exists(SELECT t2.id FROM t2 WHERE t1.id = t2.id)`
238///
239/// The optimized plan will be:
240///
241/// ```text
242/// Projection: t1.id
243///   LeftSemi Join:  Filter: t1.id = __correlated_sq_1.id
244///     TableScan: t1
245///     SubqueryAlias: __correlated_sq_1
246///       Projection: t2.id
247///         TableScan: t2
248/// ```
249fn build_join_top(
250    query_info: &SubqueryInfo,
251    left: &LogicalPlan,
252    alias: &Arc<AliasGenerator>,
253) -> Result<Option<LogicalPlan>> {
254    let where_in_expr_opt = &query_info.where_in_expr;
255    let in_predicate_opt = where_in_expr_opt
256        .clone()
257        .map(|where_in_expr| {
258            query_info
259                .query
260                .subquery
261                .head_output_expr()?
262                .map_or(plan_err!("single expression required."), |expr| {
263                    Ok(Expr::eq(where_in_expr, expr))
264                })
265        })
266        .map_or(Ok(None), |v| v.map(Some))?;
267
268    let join_type = match query_info.negated {
269        true => JoinType::LeftAnti,
270        false => JoinType::LeftSemi,
271    };
272    let subquery = query_info.query.subquery.as_ref();
273    let subquery_alias = alias.next("__correlated_sq");
274    build_join(
275        left,
276        subquery,
277        in_predicate_opt.as_ref(),
278        join_type,
279        subquery_alias,
280    )
281}
282
283/// This is used to handle the case when the subquery is embedded in a more complex boolean
284/// expression like and OR. For example
285///
286/// `select t1.id from t1 where t1.id < 0 OR exists(SELECT t2.id FROM t2 WHERE t1.id = t2.id)`
287///
288/// The optimized plan will be:
289///
290/// ```text
291/// Projection: t1.id
292///   Filter: t1.id < 0 OR __correlated_sq_1.mark
293///     LeftMark Join:  Filter: t1.id = __correlated_sq_1.id
294///       TableScan: t1
295///       SubqueryAlias: __correlated_sq_1
296///         Projection: t2.id
297///           TableScan: t2
298fn mark_join(
299    left: &LogicalPlan,
300    subquery: &LogicalPlan,
301    in_predicate_opt: Option<&Expr>,
302    negated: bool,
303    alias_generator: &Arc<AliasGenerator>,
304) -> Result<Option<(LogicalPlan, Expr)>> {
305    let alias = alias_generator.next("__correlated_sq");
306
307    let exists_col = Expr::Column(Column::new(Some(alias.clone()), "mark"));
308    let exists_expr = if negated { !exists_col } else { exists_col };
309
310    Ok(
311        build_join(left, subquery, in_predicate_opt, JoinType::LeftMark, alias)?
312            .map(|plan| (plan, exists_expr)),
313    )
314}
315
316/// Check if join keys in the join filter may contain NULL values
317///
318/// Returns true if any join key column is nullable on either side.
319/// This is used to optimize null-aware anti joins: if all join keys are non-nullable,
320/// we can use a regular anti join instead of the more expensive null-aware variant.
321fn join_keys_may_be_null(
322    join_filter: &Expr,
323    left_schema: &DFSchemaRef,
324    right_schema: &DFSchemaRef,
325) -> Result<bool> {
326    // Extract columns from the join filter
327    let mut columns = std::collections::HashSet::new();
328    expr_to_columns(join_filter, &mut columns)?;
329
330    // Check if any column is nullable
331    for col in columns {
332        // Check in left schema
333        if let Ok(field) = left_schema.field_from_column(&col)
334            && field.as_ref().is_nullable()
335        {
336            return Ok(true);
337        }
338        // Check in right schema
339        if let Ok(field) = right_schema.field_from_column(&col)
340            && field.as_ref().is_nullable()
341        {
342            return Ok(true);
343        }
344    }
345
346    Ok(false)
347}
348
349fn build_join(
350    left: &LogicalPlan,
351    subquery: &LogicalPlan,
352    in_predicate_opt: Option<&Expr>,
353    join_type: JoinType,
354    alias: String,
355) -> Result<Option<LogicalPlan>> {
356    let mut pull_up = PullUpCorrelatedExpr::new()
357        .with_in_predicate_opt(in_predicate_opt.cloned())
358        .with_exists_sub_query(in_predicate_opt.is_none());
359
360    let new_plan = subquery.clone().rewrite(&mut pull_up).data()?;
361    if !pull_up.can_pull_up {
362        return Ok(None);
363    }
364
365    let sub_query_alias = LogicalPlanBuilder::from(new_plan)
366        .alias(alias.to_string())?
367        .build()?;
368    let mut all_correlated_cols = BTreeSet::new();
369    pull_up
370        .correlated_subquery_cols_map
371        .values()
372        .for_each(|cols| all_correlated_cols.extend(cols.clone()));
373
374    // alias the join filter
375    let join_filter_opt = conjunction(pull_up.join_filters)
376        .map_or(Ok(None), |filter| {
377            replace_qualified_name(filter, &all_correlated_cols, &alias).map(Some)
378        })?;
379
380    let join_filter = match (join_filter_opt, in_predicate_opt.cloned()) {
381        (
382            Some(join_filter),
383            Some(Expr::BinaryExpr(BinaryExpr {
384                left,
385                op: Operator::Eq,
386                right,
387            })),
388        ) => {
389            let right_col = create_col_from_scalar_expr(right.deref(), alias)?;
390            let in_predicate = Expr::eq(left.deref().clone(), Expr::Column(right_col));
391            in_predicate.and(join_filter)
392        }
393        (Some(join_filter), _) => join_filter,
394        (
395            _,
396            Some(Expr::BinaryExpr(BinaryExpr {
397                left,
398                op: Operator::Eq,
399                right,
400            })),
401        ) => {
402            let right_col = create_col_from_scalar_expr(right.deref(), alias)?;
403
404            Expr::eq(left.deref().clone(), Expr::Column(right_col))
405        }
406        (None, None) => lit(true),
407        _ => return Ok(None),
408    };
409
410    if matches!(join_type, JoinType::LeftMark | JoinType::RightMark) {
411        let right_schema = sub_query_alias.schema();
412
413        // Gather all columns needed for the join filter + predicates
414        let mut needed = std::collections::HashSet::new();
415        expr_to_columns(&join_filter, &mut needed)?;
416        if let Some(in_pred) = in_predicate_opt {
417            expr_to_columns(in_pred, &mut needed)?;
418        }
419
420        // Keep only columns that actually belong to the RIGHT child, and sort by their
421        // position in the right schema for deterministic order.
422        let mut right_cols_idx_and_col: Vec<(usize, Column)> = needed
423            .into_iter()
424            .filter_map(|c| right_schema.index_of_column(&c).ok().map(|idx| (idx, c)))
425            .collect();
426
427        right_cols_idx_and_col.sort_by_key(|(idx, _)| *idx);
428
429        let right_proj_exprs: Vec<Expr> = right_cols_idx_and_col
430            .into_iter()
431            .map(|(_, c)| Expr::Column(c))
432            .collect();
433
434        let right_projected = if !right_proj_exprs.is_empty() {
435            LogicalPlanBuilder::from(sub_query_alias.clone())
436                .project(right_proj_exprs)?
437                .build()?
438        } else {
439            // Degenerate case: no right columns referenced by the predicate(s)
440            sub_query_alias.clone()
441        };
442
443        // Mark joins don't use null-aware semantics (they use three-valued logic with mark column)
444        let new_plan = LogicalPlanBuilder::from(left.clone())
445            .join_on(right_projected, join_type, Some(join_filter))?
446            .build()?;
447
448        debug!(
449            "predicate subquery optimized:\n{}",
450            new_plan.display_indent()
451        );
452
453        return Ok(Some(new_plan));
454    }
455
456    // Determine if this should be a null-aware anti join
457    // Null-aware semantics are only needed for NOT IN subqueries, not NOT EXISTS:
458    // - NOT IN: Uses three-valued logic, requires null-aware handling
459    // - NOT EXISTS: Uses two-valued logic, regular anti join is correct
460    // We can distinguish them: NOT IN has in_predicate_opt, NOT EXISTS does not
461    //
462    // Additionally, if the join keys are non-nullable on both sides, we don't need
463    // null-aware semantics because NULLs cannot exist in the data.
464    let null_aware = join_type == JoinType::LeftAnti
465        && in_predicate_opt.is_some()
466        && join_keys_may_be_null(&join_filter, left.schema(), sub_query_alias.schema())?;
467
468    // join our sub query into the main plan
469    let new_plan = if null_aware {
470        // Use join_detailed_with_options to set null_aware flag
471        LogicalPlanBuilder::from(left.clone())
472            .join_detailed_with_options(
473                sub_query_alias,
474                join_type,
475                (Vec::<Column>::new(), Vec::<Column>::new()), // No equijoin keys, filter-based join
476                Some(join_filter),
477                NullEquality::NullEqualsNothing,
478                true, // null_aware
479            )?
480            .build()?
481    } else {
482        LogicalPlanBuilder::from(left.clone())
483            .join_on(sub_query_alias, join_type, Some(join_filter))?
484            .build()?
485    };
486    debug!(
487        "predicate subquery optimized:\n{}",
488        new_plan.display_indent()
489    );
490    Ok(Some(new_plan))
491}
492
493#[derive(Debug)]
494struct SubqueryInfo {
495    query: Subquery,
496    where_in_expr: Option<Expr>,
497    negated: bool,
498}
499
500impl SubqueryInfo {
501    pub fn new(query: Subquery, negated: bool) -> Self {
502        Self {
503            query,
504            where_in_expr: None,
505            negated,
506        }
507    }
508
509    pub fn new_with_in_expr(query: Subquery, expr: Expr, negated: bool) -> Self {
510        Self {
511            query,
512            where_in_expr: Some(expr),
513            negated,
514        }
515    }
516
517    pub fn expr(self) -> Expr {
518        match self.where_in_expr {
519            Some(expr) => match self.negated {
520                true => not_in_subquery(expr, self.query.subquery),
521                false => in_subquery(expr, self.query.subquery),
522            },
523            None => match self.negated {
524                true => not_exists(self.query.subquery),
525                false => exists(self.query.subquery),
526            },
527        }
528    }
529}
530
531#[cfg(test)]
532mod tests {
533    use std::ops::Add;
534
535    use super::*;
536    use crate::test::*;
537
538    use crate::assert_optimized_plan_eq_display_indent_snapshot;
539    use arrow::datatypes::{DataType, Field, Schema};
540    use datafusion_expr::builder::table_source;
541    use datafusion_expr::{and, binary_expr, col, lit, not, out_ref_col, table_scan};
542
543    macro_rules! assert_optimized_plan_equal {
544        (
545            $plan:expr,
546            @ $expected:literal $(,)?
547        ) => {{
548            let rule: Arc<dyn crate::OptimizerRule + Send + Sync> = Arc::new(DecorrelatePredicateSubquery::new());
549            assert_optimized_plan_eq_display_indent_snapshot!(
550                rule,
551                $plan,
552                @ $expected,
553            )
554        }};
555    }
556
557    fn test_subquery_with_name(name: &str) -> Result<Arc<LogicalPlan>> {
558        let table_scan = test_table_scan_with_name(name)?;
559        Ok(Arc::new(
560            LogicalPlanBuilder::from(table_scan)
561                .project(vec![col("c")])?
562                .build()?,
563        ))
564    }
565
566    /// Test for several IN subquery expressions
567    #[test]
568    fn in_subquery_multiple() -> Result<()> {
569        let table_scan = test_table_scan()?;
570        let plan = LogicalPlanBuilder::from(table_scan)
571            .filter(and(
572                in_subquery(col("c"), test_subquery_with_name("sq_1")?),
573                in_subquery(col("b"), test_subquery_with_name("sq_2")?),
574            ))?
575            .project(vec![col("test.b")])?
576            .build()?;
577
578        assert_optimized_plan_equal!(
579            plan,
580            @r"
581        Projection: test.b [b:UInt32]
582          LeftSemi Join:  Filter: test.b = __correlated_sq_2.c [a:UInt32, b:UInt32, c:UInt32]
583            LeftSemi Join:  Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]
584              TableScan: test [a:UInt32, b:UInt32, c:UInt32]
585              SubqueryAlias: __correlated_sq_1 [c:UInt32]
586                Projection: sq_1.c [c:UInt32]
587                  TableScan: sq_1 [a:UInt32, b:UInt32, c:UInt32]
588            SubqueryAlias: __correlated_sq_2 [c:UInt32]
589              Projection: sq_2.c [c:UInt32]
590                TableScan: sq_2 [a:UInt32, b:UInt32, c:UInt32]
591        "
592        )
593    }
594
595    /// Test for IN subquery with additional AND filter
596    #[test]
597    fn in_subquery_with_and_filters() -> Result<()> {
598        let table_scan = test_table_scan()?;
599        let plan = LogicalPlanBuilder::from(table_scan)
600            .filter(and(
601                in_subquery(col("c"), test_subquery_with_name("sq")?),
602                and(
603                    binary_expr(col("a"), Operator::Eq, lit(1_u32)),
604                    binary_expr(col("b"), Operator::Lt, lit(30_u32)),
605                ),
606            ))?
607            .project(vec![col("test.b")])?
608            .build()?;
609
610        assert_optimized_plan_equal!(
611            plan,
612            @r"
613        Projection: test.b [b:UInt32]
614          Filter: test.a = UInt32(1) AND test.b < UInt32(30) [a:UInt32, b:UInt32, c:UInt32]
615            LeftSemi Join:  Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]
616              TableScan: test [a:UInt32, b:UInt32, c:UInt32]
617              SubqueryAlias: __correlated_sq_1 [c:UInt32]
618                Projection: sq.c [c:UInt32]
619                  TableScan: sq [a:UInt32, b:UInt32, c:UInt32]
620        "
621        )
622    }
623
624    /// Test for nested IN subqueries
625    #[test]
626    fn in_subquery_nested() -> Result<()> {
627        let table_scan = test_table_scan()?;
628
629        let subquery = LogicalPlanBuilder::from(test_table_scan_with_name("sq")?)
630            .filter(in_subquery(col("a"), test_subquery_with_name("sq_nested")?))?
631            .project(vec![col("a")])?
632            .build()?;
633
634        let plan = LogicalPlanBuilder::from(table_scan)
635            .filter(in_subquery(col("b"), Arc::new(subquery)))?
636            .project(vec![col("test.b")])?
637            .build()?;
638
639        assert_optimized_plan_equal!(
640            plan,
641            @r"
642        Projection: test.b [b:UInt32]
643          LeftSemi Join:  Filter: test.b = __correlated_sq_2.a [a:UInt32, b:UInt32, c:UInt32]
644            TableScan: test [a:UInt32, b:UInt32, c:UInt32]
645            SubqueryAlias: __correlated_sq_2 [a:UInt32]
646              Projection: sq.a [a:UInt32]
647                LeftSemi Join:  Filter: sq.a = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]
648                  TableScan: sq [a:UInt32, b:UInt32, c:UInt32]
649                  SubqueryAlias: __correlated_sq_1 [c:UInt32]
650                    Projection: sq_nested.c [c:UInt32]
651                      TableScan: sq_nested [a:UInt32, b:UInt32, c:UInt32]
652        "
653        )
654    }
655
656    /// Test multiple correlated subqueries
657    /// See subqueries.rs where_in_multiple()
658    #[test]
659    fn multiple_subqueries() -> Result<()> {
660        let orders = Arc::new(
661            LogicalPlanBuilder::from(scan_tpch_table("orders"))
662                .filter(
663                    col("orders.o_custkey")
664                        .eq(out_ref_col(DataType::Int64, "customer.c_custkey")),
665                )?
666                .project(vec![col("orders.o_custkey")])?
667                .build()?,
668        );
669        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
670            .filter(
671                in_subquery(col("customer.c_custkey"), Arc::clone(&orders))
672                    .and(in_subquery(col("customer.c_custkey"), orders)),
673            )?
674            .project(vec![col("customer.c_custkey")])?
675            .build()?;
676        debug!("plan to optimize:\n{}", plan.display_indent());
677
678        assert_optimized_plan_equal!(
679                plan,
680                @r"
681        Projection: customer.c_custkey [c_custkey:Int64]
682          LeftSemi Join:  Filter: customer.c_custkey = __correlated_sq_2.o_custkey [c_custkey:Int64, c_name:Utf8]
683            LeftSemi Join:  Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]
684              TableScan: customer [c_custkey:Int64, c_name:Utf8]
685              SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]
686                Projection: orders.o_custkey [o_custkey:Int64]
687                  TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
688            SubqueryAlias: __correlated_sq_2 [o_custkey:Int64]
689              Projection: orders.o_custkey [o_custkey:Int64]
690                TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
691        "    
692        )
693    }
694
695    /// Test recursive correlated subqueries
696    /// See subqueries.rs where_in_recursive()
697    #[test]
698    fn recursive_subqueries() -> Result<()> {
699        let lineitem = Arc::new(
700            LogicalPlanBuilder::from(scan_tpch_table("lineitem"))
701                .filter(
702                    col("lineitem.l_orderkey")
703                        .eq(out_ref_col(DataType::Int64, "orders.o_orderkey")),
704                )?
705                .project(vec![col("lineitem.l_orderkey")])?
706                .build()?,
707        );
708
709        let orders = Arc::new(
710            LogicalPlanBuilder::from(scan_tpch_table("orders"))
711                .filter(
712                    in_subquery(col("orders.o_orderkey"), lineitem).and(
713                        col("orders.o_custkey")
714                            .eq(out_ref_col(DataType::Int64, "customer.c_custkey")),
715                    ),
716                )?
717                .project(vec![col("orders.o_custkey")])?
718                .build()?,
719        );
720
721        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
722            .filter(in_subquery(col("customer.c_custkey"), orders))?
723            .project(vec![col("customer.c_custkey")])?
724            .build()?;
725
726        assert_optimized_plan_equal!(
727            plan,
728            @r"
729        Projection: customer.c_custkey [c_custkey:Int64]
730          LeftSemi Join:  Filter: customer.c_custkey = __correlated_sq_2.o_custkey [c_custkey:Int64, c_name:Utf8]
731            TableScan: customer [c_custkey:Int64, c_name:Utf8]
732            SubqueryAlias: __correlated_sq_2 [o_custkey:Int64]
733              Projection: orders.o_custkey [o_custkey:Int64]
734                LeftSemi Join:  Filter: orders.o_orderkey = __correlated_sq_1.l_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
735                  TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
736                  SubqueryAlias: __correlated_sq_1 [l_orderkey:Int64]
737                    Projection: lineitem.l_orderkey [l_orderkey:Int64]
738                      TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]
739        "
740        )
741    }
742
743    /// Test for correlated IN subquery filter with additional subquery filters
744    #[test]
745    fn in_subquery_with_subquery_filters() -> Result<()> {
746        let sq = Arc::new(
747            LogicalPlanBuilder::from(scan_tpch_table("orders"))
748                .filter(
749                    out_ref_col(DataType::Int64, "customer.c_custkey")
750                        .eq(col("orders.o_custkey"))
751                        .and(col("o_orderkey").eq(lit(1))),
752                )?
753                .project(vec![col("orders.o_custkey")])?
754                .build()?,
755        );
756
757        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
758            .filter(in_subquery(col("customer.c_custkey"), sq))?
759            .project(vec![col("customer.c_custkey")])?
760            .build()?;
761
762        assert_optimized_plan_equal!(
763            plan,
764            @r"
765        Projection: customer.c_custkey [c_custkey:Int64]
766          LeftSemi Join:  Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]
767            TableScan: customer [c_custkey:Int64, c_name:Utf8]
768            SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]
769              Projection: orders.o_custkey [o_custkey:Int64]
770                Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
771                  TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
772        "
773        )
774    }
775
776    /// Test for correlated IN subquery with no columns in schema
777    #[test]
778    fn in_subquery_no_cols() -> Result<()> {
779        let sq = Arc::new(
780            LogicalPlanBuilder::from(scan_tpch_table("orders"))
781                .filter(
782                    out_ref_col(DataType::Int64, "customer.c_custkey")
783                        .eq(out_ref_col(DataType::Int64, "customer.c_custkey")),
784                )?
785                .project(vec![col("orders.o_custkey")])?
786                .build()?,
787        );
788
789        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
790            .filter(in_subquery(col("customer.c_custkey"), sq))?
791            .project(vec![col("customer.c_custkey")])?
792            .build()?;
793
794        assert_optimized_plan_equal!(
795            plan,
796            @r"
797        Projection: customer.c_custkey [c_custkey:Int64]
798          LeftSemi Join:  Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]
799            TableScan: customer [c_custkey:Int64, c_name:Utf8]
800            SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]
801              Projection: orders.o_custkey [o_custkey:Int64]
802                TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
803        "
804        )
805    }
806
807    /// Test for IN subquery with both columns in schema
808    #[test]
809    fn in_subquery_with_no_correlated_cols() -> Result<()> {
810        let sq = Arc::new(
811            LogicalPlanBuilder::from(scan_tpch_table("orders"))
812                .filter(col("orders.o_custkey").eq(col("orders.o_custkey")))?
813                .project(vec![col("orders.o_custkey")])?
814                .build()?,
815        );
816
817        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
818            .filter(in_subquery(col("customer.c_custkey"), sq))?
819            .project(vec![col("customer.c_custkey")])?
820            .build()?;
821
822        assert_optimized_plan_equal!(
823            plan,
824            @r"
825        Projection: customer.c_custkey [c_custkey:Int64]
826          LeftSemi Join:  Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]
827            TableScan: customer [c_custkey:Int64, c_name:Utf8]
828            SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]
829              Projection: orders.o_custkey [o_custkey:Int64]
830                Filter: orders.o_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
831                  TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
832        "
833        )
834    }
835
836    /// Test for correlated IN subquery not equal
837    #[test]
838    fn in_subquery_where_not_eq() -> Result<()> {
839        let sq = Arc::new(
840            LogicalPlanBuilder::from(scan_tpch_table("orders"))
841                .filter(
842                    out_ref_col(DataType::Int64, "customer.c_custkey")
843                        .not_eq(col("orders.o_custkey")),
844                )?
845                .project(vec![col("orders.o_custkey")])?
846                .build()?,
847        );
848
849        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
850            .filter(in_subquery(col("customer.c_custkey"), sq))?
851            .project(vec![col("customer.c_custkey")])?
852            .build()?;
853
854        assert_optimized_plan_equal!(
855            plan,
856            @r"
857        Projection: customer.c_custkey [c_custkey:Int64]
858          LeftSemi Join:  Filter: customer.c_custkey = __correlated_sq_1.o_custkey AND customer.c_custkey != __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]
859            TableScan: customer [c_custkey:Int64, c_name:Utf8]
860            SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]
861              Projection: orders.o_custkey [o_custkey:Int64]
862                TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
863        "
864        )
865    }
866
867    /// Test for correlated IN subquery less than
868    #[test]
869    fn in_subquery_where_less_than() -> Result<()> {
870        let sq = Arc::new(
871            LogicalPlanBuilder::from(scan_tpch_table("orders"))
872                .filter(
873                    out_ref_col(DataType::Int64, "customer.c_custkey")
874                        .lt(col("orders.o_custkey")),
875                )?
876                .project(vec![col("orders.o_custkey")])?
877                .build()?,
878        );
879
880        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
881            .filter(in_subquery(col("customer.c_custkey"), sq))?
882            .project(vec![col("customer.c_custkey")])?
883            .build()?;
884
885        assert_optimized_plan_equal!(
886            plan,
887            @r"
888        Projection: customer.c_custkey [c_custkey:Int64]
889          LeftSemi Join:  Filter: customer.c_custkey = __correlated_sq_1.o_custkey AND customer.c_custkey < __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]
890            TableScan: customer [c_custkey:Int64, c_name:Utf8]
891            SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]
892              Projection: orders.o_custkey [o_custkey:Int64]
893                TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
894        "
895        )
896    }
897
898    /// Test for correlated IN subquery filter with subquery disjunction
899    #[test]
900    fn in_subquery_with_subquery_disjunction() -> Result<()> {
901        let sq = Arc::new(
902            LogicalPlanBuilder::from(scan_tpch_table("orders"))
903                .filter(
904                    out_ref_col(DataType::Int64, "customer.c_custkey")
905                        .eq(col("orders.o_custkey"))
906                        .or(col("o_orderkey").eq(lit(1))),
907                )?
908                .project(vec![col("orders.o_custkey")])?
909                .build()?,
910        );
911
912        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
913            .filter(in_subquery(col("customer.c_custkey"), sq))?
914            .project(vec![col("customer.c_custkey")])?
915            .build()?;
916
917        assert_optimized_plan_equal!(
918            plan,
919            @r"
920        Projection: customer.c_custkey [c_custkey:Int64]
921          LeftSemi Join:  Filter: customer.c_custkey = __correlated_sq_1.o_custkey AND (customer.c_custkey = __correlated_sq_1.o_custkey OR __correlated_sq_1.o_orderkey = Int32(1)) [c_custkey:Int64, c_name:Utf8]
922            TableScan: customer [c_custkey:Int64, c_name:Utf8]
923            SubqueryAlias: __correlated_sq_1 [o_custkey:Int64, o_orderkey:Int64]
924              Projection: orders.o_custkey, orders.o_orderkey [o_custkey:Int64, o_orderkey:Int64]
925                TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
926        "
927        )
928    }
929
930    /// Test for correlated IN without projection
931    #[test]
932    fn in_subquery_no_projection() -> Result<()> {
933        let sq = Arc::new(
934            LogicalPlanBuilder::from(scan_tpch_table("orders"))
935                .filter(col("customer.c_custkey").eq(col("orders.o_custkey")))?
936                .build()?,
937        );
938
939        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
940            .filter(in_subquery(col("customer.c_custkey"), sq))?
941            .project(vec![col("customer.c_custkey")])?
942            .build()?;
943
944        // Maybe okay if the table only has a single column?
945        let expected = "Invalid (non-executable) plan after Analyzer\
946        \ncaused by\
947        \nError during planning: InSubquery should only return one column, but found 4";
948        assert_analyzer_check_err(vec![], plan, expected);
949
950        Ok(())
951    }
952
953    /// Test for correlated IN subquery join on expression
954    #[test]
955    fn in_subquery_join_expr() -> Result<()> {
956        let sq = Arc::new(
957            LogicalPlanBuilder::from(scan_tpch_table("orders"))
958                .filter(
959                    out_ref_col(DataType::Int64, "customer.c_custkey")
960                        .eq(col("orders.o_custkey")),
961                )?
962                .project(vec![col("orders.o_custkey")])?
963                .build()?,
964        );
965
966        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
967            .filter(in_subquery(col("customer.c_custkey").add(lit(1)), sq))?
968            .project(vec![col("customer.c_custkey")])?
969            .build()?;
970
971        assert_optimized_plan_equal!(
972            plan,
973            @r"
974        Projection: customer.c_custkey [c_custkey:Int64]
975          LeftSemi Join:  Filter: customer.c_custkey + Int32(1) = __correlated_sq_1.o_custkey AND customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]
976            TableScan: customer [c_custkey:Int64, c_name:Utf8]
977            SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]
978              Projection: orders.o_custkey [o_custkey:Int64]
979                TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
980        "
981        )
982    }
983
984    /// Test for correlated IN expressions
985    #[test]
986    fn in_subquery_project_expr() -> Result<()> {
987        let sq = Arc::new(
988            LogicalPlanBuilder::from(scan_tpch_table("orders"))
989                .filter(
990                    out_ref_col(DataType::Int64, "customer.c_custkey")
991                        .eq(col("orders.o_custkey")),
992                )?
993                .project(vec![col("orders.o_custkey").add(lit(1))])?
994                .build()?,
995        );
996
997        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
998            .filter(in_subquery(col("customer.c_custkey"), sq))?
999            .project(vec![col("customer.c_custkey")])?
1000            .build()?;
1001
1002        assert_optimized_plan_equal!(
1003            plan,
1004            @r"
1005        Projection: customer.c_custkey [c_custkey:Int64]
1006          LeftSemi Join:  Filter: customer.c_custkey = __correlated_sq_1.orders.o_custkey + Int32(1) AND customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]
1007            TableScan: customer [c_custkey:Int64, c_name:Utf8]
1008            SubqueryAlias: __correlated_sq_1 [orders.o_custkey + Int32(1):Int64, o_custkey:Int64]
1009              Projection: orders.o_custkey + Int32(1), orders.o_custkey [orders.o_custkey + Int32(1):Int64, o_custkey:Int64]
1010                TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1011        "
1012        )
1013    }
1014
1015    /// Test for correlated IN subquery multiple projected columns
1016    #[test]
1017    fn in_subquery_multi_col() -> Result<()> {
1018        let sq = Arc::new(
1019            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1020                .filter(
1021                    out_ref_col(DataType::Int64, "customer.c_custkey")
1022                        .eq(col("orders.o_custkey")),
1023                )?
1024                .project(vec![col("orders.o_custkey"), col("orders.o_orderkey")])?
1025                .build()?,
1026        );
1027
1028        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
1029            .filter(
1030                in_subquery(col("customer.c_custkey"), sq)
1031                    .and(col("c_custkey").eq(lit(1))),
1032            )?
1033            .project(vec![col("customer.c_custkey")])?
1034            .build()?;
1035
1036        let expected = "Invalid (non-executable) plan after Analyzer\
1037        \ncaused by\
1038        \nError during planning: InSubquery should only return one column";
1039        assert_analyzer_check_err(vec![], plan, expected);
1040
1041        Ok(())
1042    }
1043
1044    /// Test for correlated IN subquery filter with additional filters
1045    #[test]
1046    fn should_support_additional_filters() -> Result<()> {
1047        let sq = Arc::new(
1048            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1049                .filter(
1050                    out_ref_col(DataType::Int64, "customer.c_custkey")
1051                        .eq(col("orders.o_custkey")),
1052                )?
1053                .project(vec![col("orders.o_custkey")])?
1054                .build()?,
1055        );
1056
1057        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
1058            .filter(
1059                in_subquery(col("customer.c_custkey"), sq)
1060                    .and(col("c_custkey").eq(lit(1))),
1061            )?
1062            .project(vec![col("customer.c_custkey")])?
1063            .build()?;
1064
1065        assert_optimized_plan_equal!(
1066            plan,
1067            @r"
1068        Projection: customer.c_custkey [c_custkey:Int64]
1069          Filter: customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8]
1070            LeftSemi Join:  Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]
1071              TableScan: customer [c_custkey:Int64, c_name:Utf8]
1072              SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]
1073                Projection: orders.o_custkey [o_custkey:Int64]
1074                  TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1075        "
1076        )
1077    }
1078
1079    /// Test for correlated IN subquery filter
1080    #[test]
1081    fn in_subquery_correlated() -> Result<()> {
1082        let sq = Arc::new(
1083            LogicalPlanBuilder::from(test_table_scan_with_name("sq")?)
1084                .filter(out_ref_col(DataType::UInt32, "test.a").eq(col("sq.a")))?
1085                .project(vec![col("c")])?
1086                .build()?,
1087        );
1088
1089        let plan = LogicalPlanBuilder::from(test_table_scan_with_name("test")?)
1090            .filter(in_subquery(col("c"), sq))?
1091            .project(vec![col("test.b")])?
1092            .build()?;
1093
1094        assert_optimized_plan_equal!(
1095            plan,
1096            @r"
1097        Projection: test.b [b:UInt32]
1098          LeftSemi Join:  Filter: test.c = __correlated_sq_1.c AND test.a = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]
1099            TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1100            SubqueryAlias: __correlated_sq_1 [c:UInt32, a:UInt32]
1101              Projection: sq.c, sq.a [c:UInt32, a:UInt32]
1102                TableScan: sq [a:UInt32, b:UInt32, c:UInt32]
1103        "
1104        )
1105    }
1106
1107    /// Test for single IN subquery filter
1108    #[test]
1109    fn in_subquery_simple() -> Result<()> {
1110        let table_scan = test_table_scan()?;
1111        let plan = LogicalPlanBuilder::from(table_scan)
1112            .filter(in_subquery(col("c"), test_subquery_with_name("sq")?))?
1113            .project(vec![col("test.b")])?
1114            .build()?;
1115
1116        assert_optimized_plan_equal!(
1117            plan,
1118            @r"
1119        Projection: test.b [b:UInt32]
1120          LeftSemi Join:  Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]
1121            TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1122            SubqueryAlias: __correlated_sq_1 [c:UInt32]
1123              Projection: sq.c [c:UInt32]
1124                TableScan: sq [a:UInt32, b:UInt32, c:UInt32]
1125        "
1126        )
1127    }
1128
1129    /// Test for single NOT IN subquery filter
1130    #[test]
1131    fn not_in_subquery_simple() -> Result<()> {
1132        let table_scan = test_table_scan()?;
1133        let plan = LogicalPlanBuilder::from(table_scan)
1134            .filter(not_in_subquery(col("c"), test_subquery_with_name("sq")?))?
1135            .project(vec![col("test.b")])?
1136            .build()?;
1137
1138        assert_optimized_plan_equal!(
1139            plan,
1140            @r"
1141        Projection: test.b [b:UInt32]
1142          LeftAnti Join:  Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]
1143            TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1144            SubqueryAlias: __correlated_sq_1 [c:UInt32]
1145              Projection: sq.c [c:UInt32]
1146                TableScan: sq [a:UInt32, b:UInt32, c:UInt32]
1147        "
1148        )
1149    }
1150
1151    #[test]
1152    fn wrapped_not_in_subquery() -> Result<()> {
1153        let table_scan = test_table_scan()?;
1154        let plan = LogicalPlanBuilder::from(table_scan)
1155            .filter(not(in_subquery(col("c"), test_subquery_with_name("sq")?)))?
1156            .project(vec![col("test.b")])?
1157            .build()?;
1158
1159        assert_optimized_plan_equal!(
1160            plan,
1161            @r"
1162        Projection: test.b [b:UInt32]
1163          LeftAnti Join:  Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]
1164            TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1165            SubqueryAlias: __correlated_sq_1 [c:UInt32]
1166              Projection: sq.c [c:UInt32]
1167                TableScan: sq [a:UInt32, b:UInt32, c:UInt32]
1168        "
1169        )
1170    }
1171
1172    #[test]
1173    fn wrapped_not_not_in_subquery() -> Result<()> {
1174        let table_scan = test_table_scan()?;
1175        let plan = LogicalPlanBuilder::from(table_scan)
1176            .filter(not(not_in_subquery(
1177                col("c"),
1178                test_subquery_with_name("sq")?,
1179            )))?
1180            .project(vec![col("test.b")])?
1181            .build()?;
1182
1183        assert_optimized_plan_equal!(
1184            plan,
1185            @r"
1186        Projection: test.b [b:UInt32]
1187          LeftSemi Join:  Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]
1188            TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1189            SubqueryAlias: __correlated_sq_1 [c:UInt32]
1190              Projection: sq.c [c:UInt32]
1191                TableScan: sq [a:UInt32, b:UInt32, c:UInt32]
1192        "
1193        )
1194    }
1195
1196    #[test]
1197    fn in_subquery_both_side_expr() -> Result<()> {
1198        let table_scan = test_table_scan()?;
1199        let subquery_scan = test_table_scan_with_name("sq")?;
1200
1201        let subquery = LogicalPlanBuilder::from(subquery_scan)
1202            .project(vec![col("c") * lit(2u32)])?
1203            .build()?;
1204
1205        let plan = LogicalPlanBuilder::from(table_scan)
1206            .filter(in_subquery(col("c") + lit(1u32), Arc::new(subquery)))?
1207            .project(vec![col("test.b")])?
1208            .build()?;
1209
1210        assert_optimized_plan_equal!(
1211            plan,
1212            @r"
1213        Projection: test.b [b:UInt32]
1214          LeftSemi Join:  Filter: test.c + UInt32(1) = __correlated_sq_1.sq.c * UInt32(2) [a:UInt32, b:UInt32, c:UInt32]
1215            TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1216            SubqueryAlias: __correlated_sq_1 [sq.c * UInt32(2):UInt32]
1217              Projection: sq.c * UInt32(2) [sq.c * UInt32(2):UInt32]
1218                TableScan: sq [a:UInt32, b:UInt32, c:UInt32]
1219        "
1220        )
1221    }
1222
1223    #[test]
1224    fn in_subquery_join_filter_and_inner_filter() -> Result<()> {
1225        let table_scan = test_table_scan()?;
1226        let subquery_scan = test_table_scan_with_name("sq")?;
1227
1228        let subquery = LogicalPlanBuilder::from(subquery_scan)
1229            .filter(
1230                out_ref_col(DataType::UInt32, "test.a")
1231                    .eq(col("sq.a"))
1232                    .and(col("sq.a").add(lit(1u32)).eq(col("sq.b"))),
1233            )?
1234            .project(vec![col("c") * lit(2u32)])?
1235            .build()?;
1236
1237        let plan = LogicalPlanBuilder::from(table_scan)
1238            .filter(in_subquery(col("c") + lit(1u32), Arc::new(subquery)))?
1239            .project(vec![col("test.b")])?
1240            .build()?;
1241
1242        assert_optimized_plan_equal!(
1243            plan,
1244            @r"
1245        Projection: test.b [b:UInt32]
1246          LeftSemi Join:  Filter: test.c + UInt32(1) = __correlated_sq_1.sq.c * UInt32(2) AND test.a = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]
1247            TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1248            SubqueryAlias: __correlated_sq_1 [sq.c * UInt32(2):UInt32, a:UInt32]
1249              Projection: sq.c * UInt32(2), sq.a [sq.c * UInt32(2):UInt32, a:UInt32]
1250                Filter: sq.a + UInt32(1) = sq.b [a:UInt32, b:UInt32, c:UInt32]
1251                  TableScan: sq [a:UInt32, b:UInt32, c:UInt32]
1252        "
1253        )
1254    }
1255
1256    #[test]
1257    fn in_subquery_multi_project_subquery_cols() -> Result<()> {
1258        let table_scan = test_table_scan()?;
1259        let subquery_scan = test_table_scan_with_name("sq")?;
1260
1261        let subquery = LogicalPlanBuilder::from(subquery_scan)
1262            .filter(
1263                out_ref_col(DataType::UInt32, "test.a")
1264                    .add(out_ref_col(DataType::UInt32, "test.b"))
1265                    .eq(col("sq.a").add(col("sq.b")))
1266                    .and(col("sq.a").add(lit(1u32)).eq(col("sq.b"))),
1267            )?
1268            .project(vec![col("c") * lit(2u32)])?
1269            .build()?;
1270
1271        let plan = LogicalPlanBuilder::from(table_scan)
1272            .filter(in_subquery(col("c") + lit(1u32), Arc::new(subquery)))?
1273            .project(vec![col("test.b")])?
1274            .build()?;
1275
1276        assert_optimized_plan_equal!(
1277            plan,
1278            @r"
1279        Projection: test.b [b:UInt32]
1280          LeftSemi Join:  Filter: test.c + UInt32(1) = __correlated_sq_1.sq.c * UInt32(2) AND test.a + test.b = __correlated_sq_1.a + __correlated_sq_1.b [a:UInt32, b:UInt32, c:UInt32]
1281            TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1282            SubqueryAlias: __correlated_sq_1 [sq.c * UInt32(2):UInt32, a:UInt32, b:UInt32]
1283              Projection: sq.c * UInt32(2), sq.a, sq.b [sq.c * UInt32(2):UInt32, a:UInt32, b:UInt32]
1284                Filter: sq.a + UInt32(1) = sq.b [a:UInt32, b:UInt32, c:UInt32]
1285                  TableScan: sq [a:UInt32, b:UInt32, c:UInt32]
1286        "
1287        )
1288    }
1289
1290    #[test]
1291    fn two_in_subquery_with_outer_filter() -> Result<()> {
1292        let table_scan = test_table_scan()?;
1293        let subquery_scan1 = test_table_scan_with_name("sq1")?;
1294        let subquery_scan2 = test_table_scan_with_name("sq2")?;
1295
1296        let subquery1 = LogicalPlanBuilder::from(subquery_scan1)
1297            .filter(out_ref_col(DataType::UInt32, "test.a").gt(col("sq1.a")))?
1298            .project(vec![col("c") * lit(2u32)])?
1299            .build()?;
1300
1301        let subquery2 = LogicalPlanBuilder::from(subquery_scan2)
1302            .filter(out_ref_col(DataType::UInt32, "test.a").gt(col("sq2.a")))?
1303            .project(vec![col("c") * lit(2u32)])?
1304            .build()?;
1305
1306        let plan = LogicalPlanBuilder::from(table_scan)
1307            .filter(
1308                in_subquery(col("c") + lit(1u32), Arc::new(subquery1)).and(
1309                    in_subquery(col("c") * lit(2u32), Arc::new(subquery2))
1310                        .and(col("test.c").gt(lit(1u32))),
1311                ),
1312            )?
1313            .project(vec![col("test.b")])?
1314            .build()?;
1315
1316        assert_optimized_plan_equal!(
1317            plan,
1318            @r"
1319        Projection: test.b [b:UInt32]
1320          Filter: test.c > UInt32(1) [a:UInt32, b:UInt32, c:UInt32]
1321            LeftSemi Join:  Filter: test.c * UInt32(2) = __correlated_sq_2.sq2.c * UInt32(2) AND test.a > __correlated_sq_2.a [a:UInt32, b:UInt32, c:UInt32]
1322              LeftSemi Join:  Filter: test.c + UInt32(1) = __correlated_sq_1.sq1.c * UInt32(2) AND test.a > __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]
1323                TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1324                SubqueryAlias: __correlated_sq_1 [sq1.c * UInt32(2):UInt32, a:UInt32]
1325                  Projection: sq1.c * UInt32(2), sq1.a [sq1.c * UInt32(2):UInt32, a:UInt32]
1326                    TableScan: sq1 [a:UInt32, b:UInt32, c:UInt32]
1327              SubqueryAlias: __correlated_sq_2 [sq2.c * UInt32(2):UInt32, a:UInt32]
1328                Projection: sq2.c * UInt32(2), sq2.a [sq2.c * UInt32(2):UInt32, a:UInt32]
1329                  TableScan: sq2 [a:UInt32, b:UInt32, c:UInt32]
1330        "
1331        )
1332    }
1333
1334    #[test]
1335    fn in_subquery_with_same_table() -> Result<()> {
1336        let outer_scan = test_table_scan()?;
1337        let subquery_scan = test_table_scan()?;
1338        let subquery = LogicalPlanBuilder::from(subquery_scan)
1339            .filter(col("test.a").gt(col("test.b")))?
1340            .project(vec![col("c")])?
1341            .build()?;
1342
1343        let plan = LogicalPlanBuilder::from(outer_scan)
1344            .filter(in_subquery(col("test.a"), Arc::new(subquery)))?
1345            .project(vec![col("test.b")])?
1346            .build()?;
1347
1348        // Subquery and outer query refer to the same table.
1349        assert_optimized_plan_equal!(
1350            plan,
1351            @r"
1352        Projection: test.b [b:UInt32]
1353          LeftSemi Join:  Filter: test.a = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]
1354            TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1355            SubqueryAlias: __correlated_sq_1 [c:UInt32]
1356              Projection: test.c [c:UInt32]
1357                Filter: test.a > test.b [a:UInt32, b:UInt32, c:UInt32]
1358                  TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1359        "
1360        )
1361    }
1362
1363    /// Test for multiple exists subqueries in the same filter expression
1364    #[test]
1365    fn multiple_exists_subqueries() -> Result<()> {
1366        let orders = Arc::new(
1367            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1368                .filter(
1369                    col("orders.o_custkey")
1370                        .eq(out_ref_col(DataType::Int64, "customer.c_custkey")),
1371                )?
1372                .project(vec![col("orders.o_custkey")])?
1373                .build()?,
1374        );
1375
1376        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
1377            .filter(exists(Arc::clone(&orders)).and(exists(orders)))?
1378            .project(vec![col("customer.c_custkey")])?
1379            .build()?;
1380
1381        assert_optimized_plan_equal!(
1382            plan,
1383            @r"
1384        Projection: customer.c_custkey [c_custkey:Int64]
1385          LeftSemi Join:  Filter: __correlated_sq_2.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8]
1386            LeftSemi Join:  Filter: __correlated_sq_1.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8]
1387              TableScan: customer [c_custkey:Int64, c_name:Utf8]
1388              SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]
1389                Projection: orders.o_custkey [o_custkey:Int64]
1390                  TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1391            SubqueryAlias: __correlated_sq_2 [o_custkey:Int64]
1392              Projection: orders.o_custkey [o_custkey:Int64]
1393                TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1394        "
1395        )
1396    }
1397
1398    /// Test recursive correlated subqueries
1399    #[test]
1400    fn recursive_exists_subqueries() -> Result<()> {
1401        let lineitem = Arc::new(
1402            LogicalPlanBuilder::from(scan_tpch_table("lineitem"))
1403                .filter(
1404                    col("lineitem.l_orderkey")
1405                        .eq(out_ref_col(DataType::Int64, "orders.o_orderkey")),
1406                )?
1407                .project(vec![col("lineitem.l_orderkey")])?
1408                .build()?,
1409        );
1410
1411        let orders = Arc::new(
1412            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1413                .filter(
1414                    exists(lineitem).and(
1415                        col("orders.o_custkey")
1416                            .eq(out_ref_col(DataType::Int64, "customer.c_custkey")),
1417                    ),
1418                )?
1419                .project(vec![col("orders.o_custkey")])?
1420                .build()?,
1421        );
1422
1423        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
1424            .filter(exists(orders))?
1425            .project(vec![col("customer.c_custkey")])?
1426            .build()?;
1427
1428        assert_optimized_plan_equal!(
1429            plan,
1430            @r"
1431        Projection: customer.c_custkey [c_custkey:Int64]
1432          LeftSemi Join:  Filter: __correlated_sq_2.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8]
1433            TableScan: customer [c_custkey:Int64, c_name:Utf8]
1434            SubqueryAlias: __correlated_sq_2 [o_custkey:Int64]
1435              Projection: orders.o_custkey [o_custkey:Int64]
1436                LeftSemi Join:  Filter: __correlated_sq_1.l_orderkey = orders.o_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1437                  TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1438                  SubqueryAlias: __correlated_sq_1 [l_orderkey:Int64]
1439                    Projection: lineitem.l_orderkey [l_orderkey:Int64]
1440                      TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]
1441        "
1442        )
1443    }
1444
1445    /// Test for correlated exists subquery filter with additional subquery filters
1446    #[test]
1447    fn exists_subquery_with_subquery_filters() -> Result<()> {
1448        let sq = Arc::new(
1449            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1450                .filter(
1451                    out_ref_col(DataType::Int64, "customer.c_custkey")
1452                        .eq(col("orders.o_custkey"))
1453                        .and(col("o_orderkey").eq(lit(1))),
1454                )?
1455                .project(vec![col("orders.o_custkey")])?
1456                .build()?,
1457        );
1458
1459        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
1460            .filter(exists(sq))?
1461            .project(vec![col("customer.c_custkey")])?
1462            .build()?;
1463
1464        assert_optimized_plan_equal!(
1465            plan,
1466            @r"
1467        Projection: customer.c_custkey [c_custkey:Int64]
1468          LeftSemi Join:  Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]
1469            TableScan: customer [c_custkey:Int64, c_name:Utf8]
1470            SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]
1471              Projection: orders.o_custkey [o_custkey:Int64]
1472                Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1473                  TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1474        "
1475        )
1476    }
1477
1478    #[test]
1479    fn exists_subquery_no_cols() -> Result<()> {
1480        let sq = Arc::new(
1481            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1482                .filter(out_ref_col(DataType::Int64, "customer.c_custkey").eq(lit(1u32)))?
1483                .project(vec![col("orders.o_custkey")])?
1484                .build()?,
1485        );
1486
1487        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
1488            .filter(exists(sq))?
1489            .project(vec![col("customer.c_custkey")])?
1490            .build()?;
1491
1492        // Other rule will pushdown `customer.c_custkey = 1`,
1493        assert_optimized_plan_equal!(
1494            plan,
1495            @r"
1496        Projection: customer.c_custkey [c_custkey:Int64]
1497          LeftSemi Join:  Filter: customer.c_custkey = UInt32(1) [c_custkey:Int64, c_name:Utf8]
1498            TableScan: customer [c_custkey:Int64, c_name:Utf8]
1499            SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]
1500              Projection: orders.o_custkey [o_custkey:Int64]
1501                TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1502        "
1503        )
1504    }
1505
1506    /// Test for exists subquery with both columns in schema
1507    #[test]
1508    fn exists_subquery_with_no_correlated_cols() -> Result<()> {
1509        let sq = Arc::new(
1510            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1511                .filter(col("orders.o_custkey").eq(col("orders.o_custkey")))?
1512                .project(vec![col("orders.o_custkey")])?
1513                .build()?,
1514        );
1515
1516        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
1517            .filter(exists(sq))?
1518            .project(vec![col("customer.c_custkey")])?
1519            .build()?;
1520
1521        assert_optimized_plan_equal!(
1522            plan,
1523            @r"
1524        Projection: customer.c_custkey [c_custkey:Int64]
1525          LeftSemi Join:  Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8]
1526            TableScan: customer [c_custkey:Int64, c_name:Utf8]
1527            SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]
1528              Projection: orders.o_custkey [o_custkey:Int64]
1529                Filter: orders.o_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1530                  TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1531        "
1532        )
1533    }
1534
1535    /// Test for correlated exists subquery not equal
1536    #[test]
1537    fn exists_subquery_where_not_eq() -> Result<()> {
1538        let sq = Arc::new(
1539            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1540                .filter(
1541                    out_ref_col(DataType::Int64, "customer.c_custkey")
1542                        .not_eq(col("orders.o_custkey")),
1543                )?
1544                .project(vec![col("orders.o_custkey")])?
1545                .build()?,
1546        );
1547
1548        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
1549            .filter(exists(sq))?
1550            .project(vec![col("customer.c_custkey")])?
1551            .build()?;
1552
1553        assert_optimized_plan_equal!(
1554            plan,
1555            @r"
1556        Projection: customer.c_custkey [c_custkey:Int64]
1557          LeftSemi Join:  Filter: customer.c_custkey != __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]
1558            TableScan: customer [c_custkey:Int64, c_name:Utf8]
1559            SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]
1560              Projection: orders.o_custkey [o_custkey:Int64]
1561                TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1562        "
1563        )
1564    }
1565
1566    /// Test for correlated exists subquery less than
1567    #[test]
1568    fn exists_subquery_where_less_than() -> Result<()> {
1569        let sq = Arc::new(
1570            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1571                .filter(
1572                    out_ref_col(DataType::Int64, "customer.c_custkey")
1573                        .lt(col("orders.o_custkey")),
1574                )?
1575                .project(vec![col("orders.o_custkey")])?
1576                .build()?,
1577        );
1578
1579        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
1580            .filter(exists(sq))?
1581            .project(vec![col("customer.c_custkey")])?
1582            .build()?;
1583
1584        assert_optimized_plan_equal!(
1585            plan,
1586            @r"
1587        Projection: customer.c_custkey [c_custkey:Int64]
1588          LeftSemi Join:  Filter: customer.c_custkey < __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]
1589            TableScan: customer [c_custkey:Int64, c_name:Utf8]
1590            SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]
1591              Projection: orders.o_custkey [o_custkey:Int64]
1592                TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1593        "
1594        )
1595    }
1596
1597    /// Test for correlated exists subquery filter with subquery disjunction
1598    #[test]
1599    fn exists_subquery_with_subquery_disjunction() -> Result<()> {
1600        let sq = Arc::new(
1601            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1602                .filter(
1603                    out_ref_col(DataType::Int64, "customer.c_custkey")
1604                        .eq(col("orders.o_custkey"))
1605                        .or(col("o_orderkey").eq(lit(1))),
1606                )?
1607                .project(vec![col("orders.o_custkey")])?
1608                .build()?,
1609        );
1610
1611        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
1612            .filter(exists(sq))?
1613            .project(vec![col("customer.c_custkey")])?
1614            .build()?;
1615
1616        assert_optimized_plan_equal!(
1617            plan,
1618            @r"
1619        Projection: customer.c_custkey [c_custkey:Int64]
1620          LeftSemi Join:  Filter: customer.c_custkey = __correlated_sq_1.o_custkey OR __correlated_sq_1.o_orderkey = Int32(1) [c_custkey:Int64, c_name:Utf8]
1621            TableScan: customer [c_custkey:Int64, c_name:Utf8]
1622            SubqueryAlias: __correlated_sq_1 [o_custkey:Int64, o_orderkey:Int64]
1623              Projection: orders.o_custkey, orders.o_orderkey [o_custkey:Int64, o_orderkey:Int64]
1624                TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1625        "
1626        )
1627    }
1628
1629    /// Test for correlated exists without projection
1630    #[test]
1631    fn exists_subquery_no_projection() -> Result<()> {
1632        let sq = Arc::new(
1633            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1634                .filter(
1635                    out_ref_col(DataType::Int64, "customer.c_custkey")
1636                        .eq(col("orders.o_custkey")),
1637                )?
1638                .build()?,
1639        );
1640
1641        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
1642            .filter(exists(sq))?
1643            .project(vec![col("customer.c_custkey")])?
1644            .build()?;
1645
1646        assert_optimized_plan_equal!(
1647            plan,
1648            @r"
1649        Projection: customer.c_custkey [c_custkey:Int64]
1650          LeftSemi Join:  Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]
1651            TableScan: customer [c_custkey:Int64, c_name:Utf8]
1652            SubqueryAlias: __correlated_sq_1 [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1653              TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1654        "
1655        )
1656    }
1657
1658    /// Test for correlated exists expressions
1659    #[test]
1660    fn exists_subquery_project_expr() -> Result<()> {
1661        let sq = Arc::new(
1662            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1663                .filter(
1664                    out_ref_col(DataType::Int64, "customer.c_custkey")
1665                        .eq(col("orders.o_custkey")),
1666                )?
1667                .project(vec![col("orders.o_custkey").add(lit(1))])?
1668                .build()?,
1669        );
1670
1671        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
1672            .filter(exists(sq))?
1673            .project(vec![col("customer.c_custkey")])?
1674            .build()?;
1675
1676        assert_optimized_plan_equal!(
1677            plan,
1678            @r"
1679        Projection: customer.c_custkey [c_custkey:Int64]
1680          LeftSemi Join:  Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]
1681            TableScan: customer [c_custkey:Int64, c_name:Utf8]
1682            SubqueryAlias: __correlated_sq_1 [orders.o_custkey + Int32(1):Int64, o_custkey:Int64]
1683              Projection: orders.o_custkey + Int32(1), orders.o_custkey [orders.o_custkey + Int32(1):Int64, o_custkey:Int64]
1684                TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1685        "
1686        )
1687    }
1688
1689    /// Test for correlated exists subquery filter with additional filters
1690    #[test]
1691    fn exists_subquery_should_support_additional_filters() -> Result<()> {
1692        let sq = Arc::new(
1693            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1694                .filter(
1695                    out_ref_col(DataType::Int64, "customer.c_custkey")
1696                        .eq(col("orders.o_custkey")),
1697                )?
1698                .project(vec![col("orders.o_custkey")])?
1699                .build()?,
1700        );
1701        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
1702            .filter(exists(sq).and(col("c_custkey").eq(lit(1))))?
1703            .project(vec![col("customer.c_custkey")])?
1704            .build()?;
1705
1706        assert_optimized_plan_equal!(
1707            plan,
1708            @r"
1709        Projection: customer.c_custkey [c_custkey:Int64]
1710          Filter: customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8]
1711            LeftSemi Join:  Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]
1712              TableScan: customer [c_custkey:Int64, c_name:Utf8]
1713              SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]
1714                Projection: orders.o_custkey [o_custkey:Int64]
1715                  TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1716        "
1717        )
1718    }
1719
1720    /// Test for correlated exists subquery filter with disjunctions
1721    #[test]
1722    fn exists_subquery_disjunction() -> Result<()> {
1723        let sq = Arc::new(
1724            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1725                .filter(col("customer.c_custkey").eq(col("orders.o_custkey")))?
1726                .project(vec![col("orders.o_custkey")])?
1727                .build()?,
1728        );
1729
1730        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
1731            .filter(exists(sq).or(col("customer.c_custkey").eq(lit(1))))?
1732            .project(vec![col("customer.c_custkey")])?
1733            .build()?;
1734
1735        assert_optimized_plan_equal!(
1736            plan,
1737            @r"
1738        Projection: customer.c_custkey [c_custkey:Int64]
1739          Filter: __correlated_sq_1.mark OR customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, mark:Boolean]
1740            LeftMark Join:  Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, mark:Boolean]
1741              TableScan: customer [c_custkey:Int64, c_name:Utf8]
1742              SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]
1743                Projection: orders.o_custkey [o_custkey:Int64]
1744                  Filter: customer.c_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1745                    TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1746        "
1747        )
1748    }
1749
1750    /// Test for correlated EXISTS subquery filter
1751    #[test]
1752    fn exists_subquery_correlated() -> Result<()> {
1753        let sq = Arc::new(
1754            LogicalPlanBuilder::from(test_table_scan_with_name("sq")?)
1755                .filter(out_ref_col(DataType::UInt32, "test.a").eq(col("sq.a")))?
1756                .project(vec![col("c")])?
1757                .build()?,
1758        );
1759
1760        let plan = LogicalPlanBuilder::from(test_table_scan_with_name("test")?)
1761            .filter(exists(sq))?
1762            .project(vec![col("test.c")])?
1763            .build()?;
1764
1765        assert_optimized_plan_equal!(
1766            plan,
1767            @r"
1768        Projection: test.c [c:UInt32]
1769          LeftSemi Join:  Filter: test.a = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]
1770            TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1771            SubqueryAlias: __correlated_sq_1 [c:UInt32, a:UInt32]
1772              Projection: sq.c, sq.a [c:UInt32, a:UInt32]
1773                TableScan: sq [a:UInt32, b:UInt32, c:UInt32]
1774        "
1775        )
1776    }
1777
1778    /// Test for single exists subquery filter
1779    #[test]
1780    fn exists_subquery_simple() -> Result<()> {
1781        let table_scan = test_table_scan()?;
1782        let plan = LogicalPlanBuilder::from(table_scan)
1783            .filter(exists(test_subquery_with_name("sq")?))?
1784            .project(vec![col("test.b")])?
1785            .build()?;
1786
1787        assert_optimized_plan_equal!(
1788            plan,
1789            @r"
1790        Projection: test.b [b:UInt32]
1791          LeftSemi Join:  Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32]
1792            TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1793            SubqueryAlias: __correlated_sq_1 [c:UInt32]
1794              Projection: sq.c [c:UInt32]
1795                TableScan: sq [a:UInt32, b:UInt32, c:UInt32]
1796        "
1797        )
1798    }
1799
1800    /// Test for single NOT exists subquery filter
1801    #[test]
1802    fn not_exists_subquery_simple() -> Result<()> {
1803        let table_scan = test_table_scan()?;
1804        let plan = LogicalPlanBuilder::from(table_scan)
1805            .filter(not_exists(test_subquery_with_name("sq")?))?
1806            .project(vec![col("test.b")])?
1807            .build()?;
1808
1809        assert_optimized_plan_equal!(
1810            plan,
1811            @r"
1812        Projection: test.b [b:UInt32]
1813          LeftAnti Join:  Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32]
1814            TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1815            SubqueryAlias: __correlated_sq_1 [c:UInt32]
1816              Projection: sq.c [c:UInt32]
1817                TableScan: sq [a:UInt32, b:UInt32, c:UInt32]
1818        "
1819        )
1820    }
1821
1822    #[test]
1823    fn two_exists_subquery_with_outer_filter() -> Result<()> {
1824        let table_scan = test_table_scan()?;
1825        let subquery_scan1 = test_table_scan_with_name("sq1")?;
1826        let subquery_scan2 = test_table_scan_with_name("sq2")?;
1827
1828        let subquery1 = LogicalPlanBuilder::from(subquery_scan1)
1829            .filter(out_ref_col(DataType::UInt32, "test.a").eq(col("sq1.a")))?
1830            .project(vec![col("c")])?
1831            .build()?;
1832
1833        let subquery2 = LogicalPlanBuilder::from(subquery_scan2)
1834            .filter(out_ref_col(DataType::UInt32, "test.a").eq(col("sq2.a")))?
1835            .project(vec![col("c")])?
1836            .build()?;
1837
1838        let plan = LogicalPlanBuilder::from(table_scan)
1839            .filter(
1840                exists(Arc::new(subquery1))
1841                    .and(exists(Arc::new(subquery2)).and(col("test.c").gt(lit(1u32)))),
1842            )?
1843            .project(vec![col("test.b")])?
1844            .build()?;
1845
1846        assert_optimized_plan_equal!(
1847            plan,
1848            @r"
1849        Projection: test.b [b:UInt32]
1850          Filter: test.c > UInt32(1) [a:UInt32, b:UInt32, c:UInt32]
1851            LeftSemi Join:  Filter: test.a = __correlated_sq_2.a [a:UInt32, b:UInt32, c:UInt32]
1852              LeftSemi Join:  Filter: test.a = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]
1853                TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1854                SubqueryAlias: __correlated_sq_1 [c:UInt32, a:UInt32]
1855                  Projection: sq1.c, sq1.a [c:UInt32, a:UInt32]
1856                    TableScan: sq1 [a:UInt32, b:UInt32, c:UInt32]
1857              SubqueryAlias: __correlated_sq_2 [c:UInt32, a:UInt32]
1858                Projection: sq2.c, sq2.a [c:UInt32, a:UInt32]
1859                  TableScan: sq2 [a:UInt32, b:UInt32, c:UInt32]
1860        "
1861        )
1862    }
1863
1864    #[test]
1865    fn exists_subquery_expr_filter() -> Result<()> {
1866        let table_scan = test_table_scan()?;
1867        let subquery_scan = test_table_scan_with_name("sq")?;
1868        let subquery = LogicalPlanBuilder::from(subquery_scan)
1869            .filter(
1870                (lit(1u32) + col("sq.a"))
1871                    .gt(out_ref_col(DataType::UInt32, "test.a") * lit(2u32)),
1872            )?
1873            .project(vec![lit(1u32)])?
1874            .build()?;
1875        let plan = LogicalPlanBuilder::from(table_scan)
1876            .filter(exists(Arc::new(subquery)))?
1877            .project(vec![col("test.b")])?
1878            .build()?;
1879
1880        assert_optimized_plan_equal!(
1881            plan,
1882            @r"
1883        Projection: test.b [b:UInt32]
1884          LeftSemi Join:  Filter: UInt32(1) + __correlated_sq_1.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32]
1885            TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1886            SubqueryAlias: __correlated_sq_1 [UInt32(1):UInt32, a:UInt32]
1887              Projection: UInt32(1), sq.a [UInt32(1):UInt32, a:UInt32]
1888                TableScan: sq [a:UInt32, b:UInt32, c:UInt32]
1889        "
1890        )
1891    }
1892
1893    #[test]
1894    fn exists_subquery_with_same_table() -> Result<()> {
1895        let outer_scan = test_table_scan()?;
1896        let subquery_scan = test_table_scan()?;
1897        let subquery = LogicalPlanBuilder::from(subquery_scan)
1898            .filter(col("test.a").gt(col("test.b")))?
1899            .project(vec![col("c")])?
1900            .build()?;
1901
1902        let plan = LogicalPlanBuilder::from(outer_scan)
1903            .filter(exists(Arc::new(subquery)))?
1904            .project(vec![col("test.b")])?
1905            .build()?;
1906
1907        // Subquery and outer query refer to the same table.
1908        assert_optimized_plan_equal!(
1909            plan,
1910            @r"
1911        Projection: test.b [b:UInt32]
1912          LeftSemi Join:  Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32]
1913            TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1914            SubqueryAlias: __correlated_sq_1 [c:UInt32]
1915              Projection: test.c [c:UInt32]
1916                Filter: test.a > test.b [a:UInt32, b:UInt32, c:UInt32]
1917                  TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1918        "
1919        )
1920    }
1921
1922    #[test]
1923    fn exists_distinct_subquery() -> Result<()> {
1924        let table_scan = test_table_scan()?;
1925        let subquery_scan = test_table_scan_with_name("sq")?;
1926        let subquery = LogicalPlanBuilder::from(subquery_scan)
1927            .filter(
1928                (lit(1u32) + col("sq.a"))
1929                    .gt(out_ref_col(DataType::UInt32, "test.a") * lit(2u32)),
1930            )?
1931            .project(vec![col("sq.c")])?
1932            .distinct()?
1933            .build()?;
1934        let plan = LogicalPlanBuilder::from(table_scan)
1935            .filter(exists(Arc::new(subquery)))?
1936            .project(vec![col("test.b")])?
1937            .build()?;
1938
1939        assert_optimized_plan_equal!(
1940            plan,
1941            @r"
1942        Projection: test.b [b:UInt32]
1943          LeftSemi Join:  Filter: UInt32(1) + __correlated_sq_1.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32]
1944            TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1945            SubqueryAlias: __correlated_sq_1 [c:UInt32, a:UInt32]
1946              Distinct: [c:UInt32, a:UInt32]
1947                Projection: sq.c, sq.a [c:UInt32, a:UInt32]
1948                  TableScan: sq [a:UInt32, b:UInt32, c:UInt32]
1949        "
1950        )
1951    }
1952
1953    #[test]
1954    fn exists_distinct_expr_subquery() -> Result<()> {
1955        let table_scan = test_table_scan()?;
1956        let subquery_scan = test_table_scan_with_name("sq")?;
1957        let subquery = LogicalPlanBuilder::from(subquery_scan)
1958            .filter(
1959                (lit(1u32) + col("sq.a"))
1960                    .gt(out_ref_col(DataType::UInt32, "test.a") * lit(2u32)),
1961            )?
1962            .project(vec![col("sq.b") + col("sq.c")])?
1963            .distinct()?
1964            .build()?;
1965        let plan = LogicalPlanBuilder::from(table_scan)
1966            .filter(exists(Arc::new(subquery)))?
1967            .project(vec![col("test.b")])?
1968            .build()?;
1969
1970        assert_optimized_plan_equal!(
1971            plan,
1972            @r"
1973        Projection: test.b [b:UInt32]
1974          LeftSemi Join:  Filter: UInt32(1) + __correlated_sq_1.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32]
1975            TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1976            SubqueryAlias: __correlated_sq_1 [sq.b + sq.c:UInt32, a:UInt32]
1977              Distinct: [sq.b + sq.c:UInt32, a:UInt32]
1978                Projection: sq.b + sq.c, sq.a [sq.b + sq.c:UInt32, a:UInt32]
1979                  TableScan: sq [a:UInt32, b:UInt32, c:UInt32]
1980        "
1981        )
1982    }
1983
1984    #[test]
1985    fn exists_distinct_subquery_with_literal() -> Result<()> {
1986        let table_scan = test_table_scan()?;
1987        let subquery_scan = test_table_scan_with_name("sq")?;
1988        let subquery = LogicalPlanBuilder::from(subquery_scan)
1989            .filter(
1990                (lit(1u32) + col("sq.a"))
1991                    .gt(out_ref_col(DataType::UInt32, "test.a") * lit(2u32)),
1992            )?
1993            .project(vec![lit(1u32), col("sq.c")])?
1994            .distinct()?
1995            .build()?;
1996        let plan = LogicalPlanBuilder::from(table_scan)
1997            .filter(exists(Arc::new(subquery)))?
1998            .project(vec![col("test.b")])?
1999            .build()?;
2000
2001        assert_optimized_plan_equal!(
2002            plan,
2003            @r"
2004        Projection: test.b [b:UInt32]
2005          LeftSemi Join:  Filter: UInt32(1) + __correlated_sq_1.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32]
2006            TableScan: test [a:UInt32, b:UInt32, c:UInt32]
2007            SubqueryAlias: __correlated_sq_1 [UInt32(1):UInt32, c:UInt32, a:UInt32]
2008              Distinct: [UInt32(1):UInt32, c:UInt32, a:UInt32]
2009                Projection: UInt32(1), sq.c, sq.a [UInt32(1):UInt32, c:UInt32, a:UInt32]
2010                  TableScan: sq [a:UInt32, b:UInt32, c:UInt32]
2011        "
2012        )
2013    }
2014
2015    #[test]
2016    fn exists_uncorrelated_unnest() -> Result<()> {
2017        let subquery_table_source = table_source(&Schema::new(vec![Field::new(
2018            "arr",
2019            DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))),
2020            true,
2021        )]));
2022        let subquery = LogicalPlanBuilder::scan_with_filters(
2023            "sq",
2024            subquery_table_source,
2025            None,
2026            vec![],
2027        )?
2028        .unnest_column("arr")?
2029        .build()?;
2030        let table_scan = test_table_scan()?;
2031        let plan = LogicalPlanBuilder::from(table_scan)
2032            .filter(exists(Arc::new(subquery)))?
2033            .project(vec![col("test.b")])?
2034            .build()?;
2035
2036        assert_optimized_plan_equal!(
2037            plan,
2038            @r"
2039        Projection: test.b [b:UInt32]
2040          LeftSemi Join:  Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32]
2041            TableScan: test [a:UInt32, b:UInt32, c:UInt32]
2042            SubqueryAlias: __correlated_sq_1 [arr:Int32;N]
2043              Unnest: lists[sq.arr|depth=1] structs[] [arr:Int32;N]
2044                TableScan: sq [arr:List(Int32);N]
2045        "
2046        )
2047    }
2048
2049    #[test]
2050    fn exists_correlated_unnest() -> Result<()> {
2051        let table_scan = test_table_scan()?;
2052        let subquery_table_source = table_source(&Schema::new(vec![Field::new(
2053            "a",
2054            DataType::List(Arc::new(Field::new_list_field(DataType::UInt32, true))),
2055            true,
2056        )]));
2057        let subquery = LogicalPlanBuilder::scan_with_filters(
2058            "sq",
2059            subquery_table_source,
2060            None,
2061            vec![],
2062        )?
2063        .unnest_column("a")?
2064        .filter(col("a").eq(out_ref_col(DataType::UInt32, "test.b")))?
2065        .build()?;
2066        let plan = LogicalPlanBuilder::from(table_scan)
2067            .filter(exists(Arc::new(subquery)))?
2068            .project(vec![col("test.b")])?
2069            .build()?;
2070
2071        assert_optimized_plan_equal!(
2072            plan,
2073            @r"
2074        Projection: test.b [b:UInt32]
2075          LeftSemi Join:  Filter: __correlated_sq_1.a = test.b [a:UInt32, b:UInt32, c:UInt32]
2076            TableScan: test [a:UInt32, b:UInt32, c:UInt32]
2077            SubqueryAlias: __correlated_sq_1 [a:UInt32;N]
2078              Unnest: lists[sq.a|depth=1] structs[] [a:UInt32;N]
2079                TableScan: sq [a:List(UInt32);N]
2080        "
2081        )
2082    }
2083
2084    #[test]
2085    fn upper_case_ident() -> Result<()> {
2086        let fields = vec![
2087            Field::new("A", DataType::UInt32, false),
2088            Field::new("B", DataType::UInt32, false),
2089        ];
2090
2091        let schema = Schema::new(fields);
2092        let table_scan_a = table_scan(Some("\"TEST_A\""), &schema, None)?.build()?;
2093        let table_scan_b = table_scan(Some("\"TEST_B\""), &schema, None)?.build()?;
2094
2095        let subquery = LogicalPlanBuilder::from(table_scan_b)
2096            .filter(col("\"A\"").eq(out_ref_col(DataType::UInt32, "\"TEST_A\".\"A\"")))?
2097            .project(vec![lit(1)])?
2098            .build()?;
2099
2100        let plan = LogicalPlanBuilder::from(table_scan_a)
2101            .filter(exists(Arc::new(subquery)))?
2102            .project(vec![col("\"TEST_A\".\"B\"")])?
2103            .build()?;
2104
2105        assert_optimized_plan_equal!(
2106            plan,
2107            @r"
2108        Projection: TEST_A.B [B:UInt32]
2109          LeftSemi Join:  Filter: __correlated_sq_1.A = TEST_A.A [A:UInt32, B:UInt32]
2110            TableScan: TEST_A [A:UInt32, B:UInt32]
2111            SubqueryAlias: __correlated_sq_1 [Int32(1):Int32, A:UInt32]
2112              Projection: Int32(1), TEST_B.A [Int32(1):Int32, A:UInt32]
2113                TableScan: TEST_B [A:UInt32, B:UInt32]
2114        "
2115        )
2116    }
2117}