datafusion_optimizer/
push_down_filter.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`PushDownFilter`] applies filters as early as possible
19
20use std::collections::{HashMap, HashSet};
21use std::sync::Arc;
22
23use indexmap::IndexSet;
24use itertools::Itertools;
25
26use datafusion_common::tree_node::{
27    Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
28};
29use datafusion_common::{
30    internal_err, plan_err, qualified_name, Column, DFSchema, Result,
31};
32use datafusion_expr::expr::WindowFunction;
33use datafusion_expr::expr_rewriter::replace_col;
34use datafusion_expr::logical_plan::{Join, JoinType, LogicalPlan, TableScan, Union};
35use datafusion_expr::utils::{
36    conjunction, expr_to_columns, split_conjunction, split_conjunction_owned,
37};
38use datafusion_expr::{
39    and, or, BinaryExpr, Expr, Filter, Operator, Projection, TableProviderFilterPushDown,
40};
41
42use crate::optimizer::ApplyOrder;
43use crate::utils::{has_all_column_refs, is_restrict_null_predicate};
44use crate::{OptimizerConfig, OptimizerRule};
45
46/// Optimizer rule for pushing (moving) filter expressions down in a plan so
47/// they are applied as early as possible.
48///
49/// # Introduction
50///
51/// The goal of this rule is to improve query performance by eliminating
52/// redundant work.
53///
54/// For example, given a plan that sorts all values where `a > 10`:
55///
56/// ```text
57///  Filter (a > 10)
58///    Sort (a, b)
59/// ```
60///
61/// A better plan is to  filter the data *before* the Sort, which sorts fewer
62/// rows and therefore does less work overall:
63///
64/// ```text
65///  Sort (a, b)
66///    Filter (a > 10)  <-- Filter is moved before the sort
67/// ```
68///
69/// However it is not always possible to push filters down. For example, given a
70/// plan that finds the top 3 values and then keeps only those that are greater
71/// than 10, if the filter is pushed below the limit it would produce a
72/// different result.
73///
74/// ```text
75///  Filter (a > 10)   <-- can not move this Filter before the limit
76///    Limit (fetch=3)
77///      Sort (a, b)
78/// ```
79///
80///
81/// More formally, a filter-commutative operation is an operation `op` that
82/// satisfies `filter(op(data)) = op(filter(data))`.
83///
84/// The filter-commutative property is plan and column-specific. A filter on `a`
85/// can be pushed through a `Aggregate(group_by = [a], agg=[sum(b))`. However, a
86/// filter on  `sum(b)` can not be pushed through the same aggregate.
87///
88/// # Handling Conjunctions
89///
90/// It is possible to only push down **part** of a filter expression if is
91/// connected with `AND`s (more formally if it is a "conjunction").
92///
93/// For example, given the following plan:
94///
95/// ```text
96/// Filter(a > 10 AND sum(b) < 5)
97///   Aggregate(group_by = [a], agg = [sum(b))
98/// ```
99///
100/// The `a > 10` is commutative with the `Aggregate` but  `sum(b) < 5` is not.
101/// Therefore it is possible to only push part of the expression, resulting in:
102///
103/// ```text
104/// Filter(sum(b) < 5)
105///   Aggregate(group_by = [a], agg = [sum(b))
106///     Filter(a > 10)
107/// ```
108///
109/// # Handling Column Aliases
110///
111/// This optimizer must sometimes handle re-writing filter expressions when they
112/// pushed, for example if there is a projection that aliases `a+1` to `"b"`:
113///
114/// ```text
115/// Filter (b > 10)
116///     Projection: [a+1 AS "b"]  <-- changes the name of `a+1` to `b`
117/// ```
118///
119/// To apply the filter prior to the `Projection`, all references to `b` must be
120/// rewritten to `a+1`:
121///
122/// ```text
123/// Projection: a AS "b"
124///     Filter: (a + 1 > 10)  <--- changed from b to a + 1
125/// ```
126/// # Implementation Notes
127///
128/// This implementation performs a single pass through the plan, "pushing" down
129/// filters. When it passes through a filter, it stores that filter, and when it
130/// reaches a plan node that does not commute with that filter, it adds the
131/// filter to that place. When it passes through a projection, it re-writes the
132/// filter's expression taking into account that projection.
133#[derive(Default, Debug)]
134pub struct PushDownFilter {}
135
136/// For a given JOIN type, determine whether each input of the join is preserved
137/// for post-join (`WHERE` clause) filters.
138///
139/// It is only correct to push filters below a join for preserved inputs.
140///
141/// # Return Value
142/// A tuple of booleans - (left_preserved, right_preserved).
143///
144/// # "Preserved" input definition
145///
146/// We say a join side is preserved if the join returns all or a subset of the rows from
147/// the relevant side, such that each row of the output table directly maps to a row of
148/// the preserved input table. If a table is not preserved, it can provide extra null rows.
149/// That is, there may be rows in the output table that don't directly map to a row in the
150/// input table.
151///
152/// For example:
153///   - In an inner join, both sides are preserved, because each row of the output
154///     maps directly to a row from each side.
155///
156///   - In a left join, the left side is preserved (we can push predicates) but
157///     the right is not, because there may be rows in the output that don't
158///     directly map to a row in the right input (due to nulls filling where there
159///     is no match on the right).
160pub(crate) fn lr_is_preserved(join_type: JoinType) -> (bool, bool) {
161    match join_type {
162        JoinType::Inner => (true, true),
163        JoinType::Left => (true, false),
164        JoinType::Right => (false, true),
165        JoinType::Full => (false, false),
166        // No columns from the right side of the join can be referenced in output
167        // predicates for semi/anti joins, so whether we specify t/f doesn't matter.
168        JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => (true, false),
169        // No columns from the left side of the join can be referenced in output
170        // predicates for semi/anti joins, so whether we specify t/f doesn't matter.
171        JoinType::RightSemi | JoinType::RightAnti => (false, true),
172    }
173}
174
175/// For a given JOIN type, determine whether each input of the join is preserved
176/// for the join condition (`ON` clause filters).
177///
178/// It is only correct to push filters below a join for preserved inputs.
179///
180/// # Return Value
181/// A tuple of booleans - (left_preserved, right_preserved).
182///
183/// See [`lr_is_preserved`] for a definition of "preserved".
184pub(crate) fn on_lr_is_preserved(join_type: JoinType) -> (bool, bool) {
185    match join_type {
186        JoinType::Inner => (true, true),
187        JoinType::Left => (false, true),
188        JoinType::Right => (true, false),
189        JoinType::Full => (false, false),
190        JoinType::LeftSemi | JoinType::RightSemi => (true, true),
191        JoinType::LeftAnti => (false, true),
192        JoinType::RightAnti => (true, false),
193        JoinType::LeftMark => (false, true),
194    }
195}
196
197/// Evaluates the columns referenced in the given expression to see if they refer
198/// only to the left or right columns
199#[derive(Debug)]
200struct ColumnChecker<'a> {
201    /// schema of left join input
202    left_schema: &'a DFSchema,
203    /// columns in left_schema, computed on demand
204    left_columns: Option<HashSet<Column>>,
205    /// schema of right join input
206    right_schema: &'a DFSchema,
207    /// columns in left_schema, computed on demand
208    right_columns: Option<HashSet<Column>>,
209}
210
211impl<'a> ColumnChecker<'a> {
212    fn new(left_schema: &'a DFSchema, right_schema: &'a DFSchema) -> Self {
213        Self {
214            left_schema,
215            left_columns: None,
216            right_schema,
217            right_columns: None,
218        }
219    }
220
221    /// Return true if the expression references only columns from the left side of the join
222    fn is_left_only(&mut self, predicate: &Expr) -> bool {
223        if self.left_columns.is_none() {
224            self.left_columns = Some(schema_columns(self.left_schema));
225        }
226        has_all_column_refs(predicate, self.left_columns.as_ref().unwrap())
227    }
228
229    /// Return true if the expression references only columns from the right side of the join
230    fn is_right_only(&mut self, predicate: &Expr) -> bool {
231        if self.right_columns.is_none() {
232            self.right_columns = Some(schema_columns(self.right_schema));
233        }
234        has_all_column_refs(predicate, self.right_columns.as_ref().unwrap())
235    }
236}
237
238/// Returns all columns in the schema
239fn schema_columns(schema: &DFSchema) -> HashSet<Column> {
240    schema
241        .iter()
242        .flat_map(|(qualifier, field)| {
243            [
244                Column::new(qualifier.cloned(), field.name()),
245                // we need to push down filter using unqualified column as well
246                Column::new_unqualified(field.name()),
247            ]
248        })
249        .collect::<HashSet<_>>()
250}
251
252/// Determine whether the predicate can evaluate as the join conditions
253fn can_evaluate_as_join_condition(predicate: &Expr) -> Result<bool> {
254    let mut is_evaluate = true;
255    predicate.apply(|expr| match expr {
256        Expr::Column(_)
257        | Expr::Literal(_, _)
258        | Expr::Placeholder(_)
259        | Expr::ScalarVariable(_, _) => Ok(TreeNodeRecursion::Jump),
260        Expr::Exists { .. }
261        | Expr::InSubquery(_)
262        | Expr::ScalarSubquery(_)
263        | Expr::OuterReferenceColumn(_, _)
264        | Expr::Unnest(_) => {
265            is_evaluate = false;
266            Ok(TreeNodeRecursion::Stop)
267        }
268        Expr::Alias(_)
269        | Expr::BinaryExpr(_)
270        | Expr::Like(_)
271        | Expr::SimilarTo(_)
272        | Expr::Not(_)
273        | Expr::IsNotNull(_)
274        | Expr::IsNull(_)
275        | Expr::IsTrue(_)
276        | Expr::IsFalse(_)
277        | Expr::IsUnknown(_)
278        | Expr::IsNotTrue(_)
279        | Expr::IsNotFalse(_)
280        | Expr::IsNotUnknown(_)
281        | Expr::Negative(_)
282        | Expr::Between(_)
283        | Expr::Case(_)
284        | Expr::Cast(_)
285        | Expr::TryCast(_)
286        | Expr::InList { .. }
287        | Expr::ScalarFunction(_) => Ok(TreeNodeRecursion::Continue),
288        // TODO: remove the next line after `Expr::Wildcard` is removed
289        #[expect(deprecated)]
290        Expr::AggregateFunction(_)
291        | Expr::WindowFunction(_)
292        | Expr::Wildcard { .. }
293        | Expr::GroupingSet(_) => internal_err!("Unsupported predicate type"),
294    })?;
295    Ok(is_evaluate)
296}
297
298/// examine OR clause to see if any useful clauses can be extracted and push down.
299/// extract at least one qual from each sub clauses of OR clause, then form the quals
300/// to new OR clause as predicate.
301///
302/// # Example
303/// ```text
304/// Filter: (a = c and a < 20) or (b = d and b > 10)
305///     join/crossjoin:
306///          TableScan: projection=[a, b]
307///          TableScan: projection=[c, d]
308/// ```
309///
310/// is optimized to
311///
312/// ```text
313/// Filter: (a = c and a < 20) or (b = d and b > 10)
314///     join/crossjoin:
315///          Filter: (a < 20) or (b > 10)
316///              TableScan: projection=[a, b]
317///          TableScan: projection=[c, d]
318/// ```
319///
320/// In general, predicates of this form:
321///
322/// ```sql
323/// (A AND B) OR (C AND D)
324/// ```
325///
326/// will be transformed to one of:
327///
328/// * `((A AND B) OR (C AND D)) AND (A OR C)`
329/// * `((A AND B) OR (C AND D)) AND ((A AND B) OR C)`
330/// * do nothing.
331fn extract_or_clauses_for_join<'a>(
332    filters: &'a [Expr],
333    schema: &'a DFSchema,
334) -> impl Iterator<Item = Expr> + 'a {
335    let schema_columns = schema_columns(schema);
336
337    // new formed OR clauses and their column references
338    filters.iter().filter_map(move |expr| {
339        if let Expr::BinaryExpr(BinaryExpr {
340            left,
341            op: Operator::Or,
342            right,
343        }) = expr
344        {
345            let left_expr = extract_or_clause(left.as_ref(), &schema_columns);
346            let right_expr = extract_or_clause(right.as_ref(), &schema_columns);
347
348            // If nothing can be extracted from any sub clauses, do nothing for this OR clause.
349            if let (Some(left_expr), Some(right_expr)) = (left_expr, right_expr) {
350                return Some(or(left_expr, right_expr));
351            }
352        }
353        None
354    })
355}
356
357/// extract qual from OR sub-clause.
358///
359/// A qual is extracted if it only contains set of column references in schema_columns.
360///
361/// For AND clause, we extract from both sub-clauses, then make new AND clause by extracted
362/// clauses if both extracted; Otherwise, use the extracted clause from any sub-clauses or None.
363///
364/// For OR clause, we extract from both sub-clauses, then make new OR clause by extracted clauses if both extracted;
365/// Otherwise, return None.
366///
367/// For other clause, apply the rule above to extract clause.
368fn extract_or_clause(expr: &Expr, schema_columns: &HashSet<Column>) -> Option<Expr> {
369    let mut predicate = None;
370
371    match expr {
372        Expr::BinaryExpr(BinaryExpr {
373            left: l_expr,
374            op: Operator::Or,
375            right: r_expr,
376        }) => {
377            let l_expr = extract_or_clause(l_expr, schema_columns);
378            let r_expr = extract_or_clause(r_expr, schema_columns);
379
380            if let (Some(l_expr), Some(r_expr)) = (l_expr, r_expr) {
381                predicate = Some(or(l_expr, r_expr));
382            }
383        }
384        Expr::BinaryExpr(BinaryExpr {
385            left: l_expr,
386            op: Operator::And,
387            right: r_expr,
388        }) => {
389            let l_expr = extract_or_clause(l_expr, schema_columns);
390            let r_expr = extract_or_clause(r_expr, schema_columns);
391
392            match (l_expr, r_expr) {
393                (Some(l_expr), Some(r_expr)) => {
394                    predicate = Some(and(l_expr, r_expr));
395                }
396                (Some(l_expr), None) => {
397                    predicate = Some(l_expr);
398                }
399                (None, Some(r_expr)) => {
400                    predicate = Some(r_expr);
401                }
402                (None, None) => {
403                    predicate = None;
404                }
405            }
406        }
407        _ => {
408            if has_all_column_refs(expr, schema_columns) {
409                predicate = Some(expr.clone());
410            }
411        }
412    }
413
414    predicate
415}
416
417/// push down join/cross-join
418fn push_down_all_join(
419    predicates: Vec<Expr>,
420    inferred_join_predicates: Vec<Expr>,
421    mut join: Join,
422    on_filter: Vec<Expr>,
423) -> Result<Transformed<LogicalPlan>> {
424    let is_inner_join = join.join_type == JoinType::Inner;
425    // Get pushable predicates from current optimizer state
426    let (left_preserved, right_preserved) = lr_is_preserved(join.join_type);
427
428    // The predicates can be divided to three categories:
429    // 1) can push through join to its children(left or right)
430    // 2) can be converted to join conditions if the join type is Inner
431    // 3) should be kept as filter conditions
432    let left_schema = join.left.schema();
433    let right_schema = join.right.schema();
434    let mut left_push = vec![];
435    let mut right_push = vec![];
436    let mut keep_predicates = vec![];
437    let mut join_conditions = vec![];
438    let mut checker = ColumnChecker::new(left_schema, right_schema);
439    for predicate in predicates {
440        if left_preserved && checker.is_left_only(&predicate) {
441            left_push.push(predicate);
442        } else if right_preserved && checker.is_right_only(&predicate) {
443            right_push.push(predicate);
444        } else if is_inner_join && can_evaluate_as_join_condition(&predicate)? {
445            // Here we do not differ it is eq or non-eq predicate, ExtractEquijoinPredicate will extract the eq predicate
446            // and convert to the join on condition
447            join_conditions.push(predicate);
448        } else {
449            keep_predicates.push(predicate);
450        }
451    }
452
453    // For infer predicates, if they can not push through join, just drop them
454    for predicate in inferred_join_predicates {
455        if left_preserved && checker.is_left_only(&predicate) {
456            left_push.push(predicate);
457        } else if right_preserved && checker.is_right_only(&predicate) {
458            right_push.push(predicate);
459        }
460    }
461
462    let mut on_filter_join_conditions = vec![];
463    let (on_left_preserved, on_right_preserved) = on_lr_is_preserved(join.join_type);
464
465    if !on_filter.is_empty() {
466        for on in on_filter {
467            if on_left_preserved && checker.is_left_only(&on) {
468                left_push.push(on)
469            } else if on_right_preserved && checker.is_right_only(&on) {
470                right_push.push(on)
471            } else {
472                on_filter_join_conditions.push(on)
473            }
474        }
475    }
476
477    // Extract from OR clause, generate new predicates for both side of join if possible.
478    // We only track the unpushable predicates above.
479    if left_preserved {
480        left_push.extend(extract_or_clauses_for_join(&keep_predicates, left_schema));
481        left_push.extend(extract_or_clauses_for_join(&join_conditions, left_schema));
482    }
483    if right_preserved {
484        right_push.extend(extract_or_clauses_for_join(&keep_predicates, right_schema));
485        right_push.extend(extract_or_clauses_for_join(&join_conditions, right_schema));
486    }
487
488    // For predicates from join filter, we should check with if a join side is preserved
489    // in term of join filtering.
490    if on_left_preserved {
491        left_push.extend(extract_or_clauses_for_join(
492            &on_filter_join_conditions,
493            left_schema,
494        ));
495    }
496    if on_right_preserved {
497        right_push.extend(extract_or_clauses_for_join(
498            &on_filter_join_conditions,
499            right_schema,
500        ));
501    }
502
503    if let Some(predicate) = conjunction(left_push) {
504        join.left = Arc::new(LogicalPlan::Filter(Filter::try_new(predicate, join.left)?));
505    }
506    if let Some(predicate) = conjunction(right_push) {
507        join.right =
508            Arc::new(LogicalPlan::Filter(Filter::try_new(predicate, join.right)?));
509    }
510
511    // Add any new join conditions as the non join predicates
512    join_conditions.extend(on_filter_join_conditions);
513    join.filter = conjunction(join_conditions);
514
515    // wrap the join on the filter whose predicates must be kept, if any
516    let plan = LogicalPlan::Join(join);
517    let plan = if let Some(predicate) = conjunction(keep_predicates) {
518        LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(plan))?)
519    } else {
520        plan
521    };
522    Ok(Transformed::yes(plan))
523}
524
525fn push_down_join(
526    join: Join,
527    parent_predicate: Option<&Expr>,
528) -> Result<Transformed<LogicalPlan>> {
529    // Split the parent predicate into individual conjunctive parts.
530    let predicates = parent_predicate
531        .map_or_else(Vec::new, |pred| split_conjunction_owned(pred.clone()));
532
533    // Extract conjunctions from the JOIN's ON filter, if present.
534    let on_filters = join
535        .filter
536        .as_ref()
537        .map_or_else(Vec::new, |filter| split_conjunction_owned(filter.clone()));
538
539    // Are there any new join predicates that can be inferred from the filter expressions?
540    let inferred_join_predicates =
541        infer_join_predicates(&join, &predicates, &on_filters)?;
542
543    if on_filters.is_empty()
544        && predicates.is_empty()
545        && inferred_join_predicates.is_empty()
546    {
547        return Ok(Transformed::no(LogicalPlan::Join(join)));
548    }
549
550    push_down_all_join(predicates, inferred_join_predicates, join, on_filters)
551}
552
553/// Extracts any equi-join join predicates from the given filter expressions.
554///
555/// Parameters
556/// * `join` the join in question
557///
558/// * `predicates` the pushed down filter expression
559///
560/// * `on_filters` filters from the join ON clause that have not already been
561///   identified as join predicates
562///
563fn infer_join_predicates(
564    join: &Join,
565    predicates: &[Expr],
566    on_filters: &[Expr],
567) -> Result<Vec<Expr>> {
568    // Only allow both side key is column.
569    let join_col_keys = join
570        .on
571        .iter()
572        .filter_map(|(l, r)| {
573            let left_col = l.try_as_col()?;
574            let right_col = r.try_as_col()?;
575            Some((left_col, right_col))
576        })
577        .collect::<Vec<_>>();
578
579    let join_type = join.join_type;
580
581    let mut inferred_predicates = InferredPredicates::new(join_type);
582
583    infer_join_predicates_from_predicates(
584        &join_col_keys,
585        predicates,
586        &mut inferred_predicates,
587    )?;
588
589    infer_join_predicates_from_on_filters(
590        &join_col_keys,
591        join_type,
592        on_filters,
593        &mut inferred_predicates,
594    )?;
595
596    Ok(inferred_predicates.predicates)
597}
598
599/// Inferred predicates collector.
600/// When the JoinType is not Inner, we need to detect whether the inferred predicate can strictly
601/// filter out NULL, otherwise ignore it. e.g.
602/// ```text
603/// SELECT * FROM t1 LEFT JOIN t2 ON t1.c0 = t2.c0 WHERE t2.c0 IS NULL;
604/// ```
605/// We cannot infer the predicate `t1.c0 IS NULL`, otherwise the predicate will be pushed down to
606/// the left side, resulting in the wrong result.
607struct InferredPredicates {
608    predicates: Vec<Expr>,
609    is_inner_join: bool,
610}
611
612impl InferredPredicates {
613    fn new(join_type: JoinType) -> Self {
614        Self {
615            predicates: vec![],
616            is_inner_join: matches!(join_type, JoinType::Inner),
617        }
618    }
619
620    fn try_build_predicate(
621        &mut self,
622        predicate: Expr,
623        replace_map: &HashMap<&Column, &Column>,
624    ) -> Result<()> {
625        if self.is_inner_join
626            || matches!(
627                is_restrict_null_predicate(
628                    predicate.clone(),
629                    replace_map.keys().cloned()
630                ),
631                Ok(true)
632            )
633        {
634            self.predicates.push(replace_col(predicate, replace_map)?);
635        }
636
637        Ok(())
638    }
639}
640
641/// Infer predicates from the pushed down predicates.
642///
643/// Parameters
644/// * `join_col_keys` column pairs from the join ON clause
645///
646/// * `predicates` the pushed down predicates
647///
648/// * `inferred_predicates` the inferred results
649///
650fn infer_join_predicates_from_predicates(
651    join_col_keys: &[(&Column, &Column)],
652    predicates: &[Expr],
653    inferred_predicates: &mut InferredPredicates,
654) -> Result<()> {
655    infer_join_predicates_impl::<true, true>(
656        join_col_keys,
657        predicates,
658        inferred_predicates,
659    )
660}
661
662/// Infer predicates from the join filter.
663///
664/// Parameters
665/// * `join_col_keys` column pairs from the join ON clause
666///
667/// * `join_type` the JoinType of Join
668///
669/// * `on_filters` filters from the join ON clause that have not already been
670///   identified as join predicates
671///
672/// * `inferred_predicates` the inferred results
673///
674fn infer_join_predicates_from_on_filters(
675    join_col_keys: &[(&Column, &Column)],
676    join_type: JoinType,
677    on_filters: &[Expr],
678    inferred_predicates: &mut InferredPredicates,
679) -> Result<()> {
680    match join_type {
681        JoinType::Full | JoinType::LeftAnti | JoinType::RightAnti => Ok(()),
682        JoinType::Inner => infer_join_predicates_impl::<true, true>(
683            join_col_keys,
684            on_filters,
685            inferred_predicates,
686        ),
687        JoinType::Left | JoinType::LeftSemi | JoinType::LeftMark => {
688            infer_join_predicates_impl::<true, false>(
689                join_col_keys,
690                on_filters,
691                inferred_predicates,
692            )
693        }
694        JoinType::Right | JoinType::RightSemi => {
695            infer_join_predicates_impl::<false, true>(
696                join_col_keys,
697                on_filters,
698                inferred_predicates,
699            )
700        }
701    }
702}
703
704/// Infer predicates from the given predicates.
705///
706/// Parameters
707/// * `join_col_keys` column pairs from the join ON clause
708///
709/// * `input_predicates` the given predicates. It can be the pushed down predicates,
710///   or it can be the filters of the Join
711///
712/// * `inferred_predicates` the inferred results
713///
714/// * `ENABLE_LEFT_TO_RIGHT` indicates that the right table related predicate can
715///   be inferred from the left table related predicate
716///
717/// * `ENABLE_RIGHT_TO_LEFT` indicates that the left table related predicate can
718///   be inferred from the right table related predicate
719///
720fn infer_join_predicates_impl<
721    const ENABLE_LEFT_TO_RIGHT: bool,
722    const ENABLE_RIGHT_TO_LEFT: bool,
723>(
724    join_col_keys: &[(&Column, &Column)],
725    input_predicates: &[Expr],
726    inferred_predicates: &mut InferredPredicates,
727) -> Result<()> {
728    for predicate in input_predicates {
729        let mut join_cols_to_replace = HashMap::new();
730
731        for &col in &predicate.column_refs() {
732            for (l, r) in join_col_keys.iter() {
733                if ENABLE_LEFT_TO_RIGHT && col == *l {
734                    join_cols_to_replace.insert(col, *r);
735                    break;
736                }
737                if ENABLE_RIGHT_TO_LEFT && col == *r {
738                    join_cols_to_replace.insert(col, *l);
739                    break;
740                }
741            }
742        }
743        if join_cols_to_replace.is_empty() {
744            continue;
745        }
746
747        inferred_predicates
748            .try_build_predicate(predicate.clone(), &join_cols_to_replace)?;
749    }
750    Ok(())
751}
752
753impl OptimizerRule for PushDownFilter {
754    fn name(&self) -> &str {
755        "push_down_filter"
756    }
757
758    fn apply_order(&self) -> Option<ApplyOrder> {
759        Some(ApplyOrder::TopDown)
760    }
761
762    fn supports_rewrite(&self) -> bool {
763        true
764    }
765
766    fn rewrite(
767        &self,
768        plan: LogicalPlan,
769        _config: &dyn OptimizerConfig,
770    ) -> Result<Transformed<LogicalPlan>> {
771        if let LogicalPlan::Join(join) = plan {
772            return push_down_join(join, None);
773        };
774
775        let plan_schema = Arc::clone(plan.schema());
776
777        let LogicalPlan::Filter(mut filter) = plan else {
778            return Ok(Transformed::no(plan));
779        };
780
781        match Arc::unwrap_or_clone(filter.input) {
782            LogicalPlan::Filter(child_filter) => {
783                let parents_predicates = split_conjunction_owned(filter.predicate);
784
785                // remove duplicated filters
786                let child_predicates = split_conjunction_owned(child_filter.predicate);
787                let new_predicates = parents_predicates
788                    .into_iter()
789                    .chain(child_predicates)
790                    // use IndexSet to remove dupes while preserving predicate order
791                    .collect::<IndexSet<_>>()
792                    .into_iter()
793                    .collect::<Vec<_>>();
794
795                let Some(new_predicate) = conjunction(new_predicates) else {
796                    return plan_err!("at least one expression exists");
797                };
798                let new_filter = LogicalPlan::Filter(Filter::try_new(
799                    new_predicate,
800                    child_filter.input,
801                )?);
802                #[allow(clippy::used_underscore_binding)]
803                self.rewrite(new_filter, _config)
804            }
805            LogicalPlan::Repartition(repartition) => {
806                let new_filter =
807                    Filter::try_new(filter.predicate, Arc::clone(&repartition.input))
808                        .map(LogicalPlan::Filter)?;
809                insert_below(LogicalPlan::Repartition(repartition), new_filter)
810            }
811            LogicalPlan::Distinct(distinct) => {
812                let new_filter =
813                    Filter::try_new(filter.predicate, Arc::clone(distinct.input()))
814                        .map(LogicalPlan::Filter)?;
815                insert_below(LogicalPlan::Distinct(distinct), new_filter)
816            }
817            LogicalPlan::Sort(sort) => {
818                let new_filter =
819                    Filter::try_new(filter.predicate, Arc::clone(&sort.input))
820                        .map(LogicalPlan::Filter)?;
821                insert_below(LogicalPlan::Sort(sort), new_filter)
822            }
823            LogicalPlan::SubqueryAlias(subquery_alias) => {
824                let mut replace_map = HashMap::new();
825                for (i, (qualifier, field)) in
826                    subquery_alias.input.schema().iter().enumerate()
827                {
828                    let (sub_qualifier, sub_field) =
829                        subquery_alias.schema.qualified_field(i);
830                    replace_map.insert(
831                        qualified_name(sub_qualifier, sub_field.name()),
832                        Expr::Column(Column::new(qualifier.cloned(), field.name())),
833                    );
834                }
835                let new_predicate = replace_cols_by_name(filter.predicate, &replace_map)?;
836
837                let new_filter = LogicalPlan::Filter(Filter::try_new(
838                    new_predicate,
839                    Arc::clone(&subquery_alias.input),
840                )?);
841                insert_below(LogicalPlan::SubqueryAlias(subquery_alias), new_filter)
842            }
843            LogicalPlan::Projection(projection) => {
844                let predicates = split_conjunction_owned(filter.predicate.clone());
845                let (new_projection, keep_predicate) =
846                    rewrite_projection(predicates, projection)?;
847                if new_projection.transformed {
848                    match keep_predicate {
849                        None => Ok(new_projection),
850                        Some(keep_predicate) => new_projection.map_data(|child_plan| {
851                            Filter::try_new(keep_predicate, Arc::new(child_plan))
852                                .map(LogicalPlan::Filter)
853                        }),
854                    }
855                } else {
856                    filter.input = Arc::new(new_projection.data);
857                    Ok(Transformed::no(LogicalPlan::Filter(filter)))
858                }
859            }
860            LogicalPlan::Unnest(mut unnest) => {
861                let predicates = split_conjunction_owned(filter.predicate.clone());
862                let mut non_unnest_predicates = vec![];
863                let mut unnest_predicates = vec![];
864                for predicate in predicates {
865                    // collect all the Expr::Column in predicate recursively
866                    let mut accum: HashSet<Column> = HashSet::new();
867                    expr_to_columns(&predicate, &mut accum)?;
868
869                    if unnest.list_type_columns.iter().any(|(_, unnest_list)| {
870                        accum.contains(&unnest_list.output_column)
871                    }) {
872                        unnest_predicates.push(predicate);
873                    } else {
874                        non_unnest_predicates.push(predicate);
875                    }
876                }
877
878                // Unnest predicates should not be pushed down.
879                // If no non-unnest predicates exist, early return
880                if non_unnest_predicates.is_empty() {
881                    filter.input = Arc::new(LogicalPlan::Unnest(unnest));
882                    return Ok(Transformed::no(LogicalPlan::Filter(filter)));
883                }
884
885                // Push down non-unnest filter predicate
886                // Unnest
887                //   Unnest Input (Projection)
888                // -> rewritten to
889                // Unnest
890                //   Filter
891                //     Unnest Input (Projection)
892
893                let unnest_input = std::mem::take(&mut unnest.input);
894
895                let filter_with_unnest_input = LogicalPlan::Filter(Filter::try_new(
896                    conjunction(non_unnest_predicates).unwrap(), // Safe to unwrap since non_unnest_predicates is not empty.
897                    unnest_input,
898                )?);
899
900                // Directly assign new filter plan as the new unnest's input.
901                // The new filter plan will go through another rewrite pass since the rule itself
902                // is applied recursively to all the child from top to down
903                let unnest_plan =
904                    insert_below(LogicalPlan::Unnest(unnest), filter_with_unnest_input)?;
905
906                match conjunction(unnest_predicates) {
907                    None => Ok(unnest_plan),
908                    Some(predicate) => Ok(Transformed::yes(LogicalPlan::Filter(
909                        Filter::try_new(predicate, Arc::new(unnest_plan.data))?,
910                    ))),
911                }
912            }
913            LogicalPlan::Union(ref union) => {
914                let mut inputs = Vec::with_capacity(union.inputs.len());
915                for input in &union.inputs {
916                    let mut replace_map = HashMap::new();
917                    for (i, (qualifier, field)) in input.schema().iter().enumerate() {
918                        let (union_qualifier, union_field) =
919                            union.schema.qualified_field(i);
920                        replace_map.insert(
921                            qualified_name(union_qualifier, union_field.name()),
922                            Expr::Column(Column::new(qualifier.cloned(), field.name())),
923                        );
924                    }
925
926                    let push_predicate =
927                        replace_cols_by_name(filter.predicate.clone(), &replace_map)?;
928                    inputs.push(Arc::new(LogicalPlan::Filter(Filter::try_new(
929                        push_predicate,
930                        Arc::clone(input),
931                    )?)))
932                }
933                Ok(Transformed::yes(LogicalPlan::Union(Union {
934                    inputs,
935                    schema: Arc::clone(&plan_schema),
936                })))
937            }
938            LogicalPlan::Aggregate(agg) => {
939                // We can push down Predicate which in groupby_expr.
940                let group_expr_columns = agg
941                    .group_expr
942                    .iter()
943                    .map(|e| Ok(Column::from_qualified_name(e.schema_name().to_string())))
944                    .collect::<Result<HashSet<_>>>()?;
945
946                let predicates = split_conjunction_owned(filter.predicate);
947
948                let mut keep_predicates = vec![];
949                let mut push_predicates = vec![];
950                for expr in predicates {
951                    let cols = expr.column_refs();
952                    if cols.iter().all(|c| group_expr_columns.contains(c)) {
953                        push_predicates.push(expr);
954                    } else {
955                        keep_predicates.push(expr);
956                    }
957                }
958
959                // As for plan Filter: Column(a+b) > 0 -- Agg: groupby:[Column(a)+Column(b)]
960                // After push, we need to replace `a+b` with Column(a)+Column(b)
961                // So we need create a replace_map, add {`a+b` --> Expr(Column(a)+Column(b))}
962                let mut replace_map = HashMap::new();
963                for expr in &agg.group_expr {
964                    replace_map.insert(expr.schema_name().to_string(), expr.clone());
965                }
966                let replaced_push_predicates = push_predicates
967                    .into_iter()
968                    .map(|expr| replace_cols_by_name(expr, &replace_map))
969                    .collect::<Result<Vec<_>>>()?;
970
971                let agg_input = Arc::clone(&agg.input);
972                Transformed::yes(LogicalPlan::Aggregate(agg))
973                    .transform_data(|new_plan| {
974                        // If we have a filter to push, we push it down to the input of the aggregate
975                        if let Some(predicate) = conjunction(replaced_push_predicates) {
976                            let new_filter = make_filter(predicate, agg_input)?;
977                            insert_below(new_plan, new_filter)
978                        } else {
979                            Ok(Transformed::no(new_plan))
980                        }
981                    })?
982                    .map_data(|child_plan| {
983                        // if there are any remaining predicates we can't push, add them
984                        // back as a filter
985                        if let Some(predicate) = conjunction(keep_predicates) {
986                            make_filter(predicate, Arc::new(child_plan))
987                        } else {
988                            Ok(child_plan)
989                        }
990                    })
991            }
992            // Tries to push filters based on the partition key(s) of the window function(s) used.
993            // Example:
994            //   Before:
995            //     Filter: (a > 1) and (b > 1) and (c > 1)
996            //      Window: func() PARTITION BY [a] ...
997            //   ---
998            //   After:
999            //     Filter: (b > 1) and (c > 1)
1000            //      Window: func() PARTITION BY [a] ...
1001            //        Filter: (a > 1)
1002            LogicalPlan::Window(window) => {
1003                // Retrieve the set of potential partition keys where we can push filters by.
1004                // Unlike aggregations, where there is only one statement per SELECT, there can be
1005                // multiple window functions, each with potentially different partition keys.
1006                // Therefore, we need to ensure that any potential partition key returned is used in
1007                // ALL window functions. Otherwise, filters cannot be pushed by through that column.
1008                let extract_partition_keys = |func: &WindowFunction| {
1009                    func.params
1010                        .partition_by
1011                        .iter()
1012                        .map(|c| Column::from_qualified_name(c.schema_name().to_string()))
1013                        .collect::<HashSet<_>>()
1014                };
1015                let potential_partition_keys = window
1016                    .window_expr
1017                    .iter()
1018                    .map(|e| {
1019                        match e {
1020                            Expr::WindowFunction(window_func) => {
1021                                extract_partition_keys(window_func)
1022                            }
1023                            Expr::Alias(alias) => {
1024                                if let Expr::WindowFunction(window_func) =
1025                                    alias.expr.as_ref()
1026                                {
1027                                    extract_partition_keys(window_func)
1028                                } else {
1029                                    // window functions expressions are only Expr::WindowFunction
1030                                    unreachable!()
1031                                }
1032                            }
1033                            _ => {
1034                                // window functions expressions are only Expr::WindowFunction
1035                                unreachable!()
1036                            }
1037                        }
1038                    })
1039                    // performs the set intersection of the partition keys of all window functions,
1040                    // returning only the common ones
1041                    .reduce(|a, b| &a & &b)
1042                    .unwrap_or_default();
1043
1044                let predicates = split_conjunction_owned(filter.predicate);
1045                let mut keep_predicates = vec![];
1046                let mut push_predicates = vec![];
1047                for expr in predicates {
1048                    let cols = expr.column_refs();
1049                    if cols.iter().all(|c| potential_partition_keys.contains(c)) {
1050                        push_predicates.push(expr);
1051                    } else {
1052                        keep_predicates.push(expr);
1053                    }
1054                }
1055
1056                // Unlike with aggregations, there are no cases where we have to replace, e.g.,
1057                // `a+b` with Column(a)+Column(b). This is because partition expressions are not
1058                // available as standalone columns to the user. For example, while an aggregation on
1059                // `a+b` becomes Column(a + b), in a window partition it becomes
1060                // `func() PARTITION BY [a + b] ...`. Thus, filters on expressions always remain in
1061                // place, so we can use `push_predicates` directly. This is consistent with other
1062                // optimizers, such as the one used by Postgres.
1063
1064                let window_input = Arc::clone(&window.input);
1065                Transformed::yes(LogicalPlan::Window(window))
1066                    .transform_data(|new_plan| {
1067                        // If we have a filter to push, we push it down to the input of the window
1068                        if let Some(predicate) = conjunction(push_predicates) {
1069                            let new_filter = make_filter(predicate, window_input)?;
1070                            insert_below(new_plan, new_filter)
1071                        } else {
1072                            Ok(Transformed::no(new_plan))
1073                        }
1074                    })?
1075                    .map_data(|child_plan| {
1076                        // if there are any remaining predicates we can't push, add them
1077                        // back as a filter
1078                        if let Some(predicate) = conjunction(keep_predicates) {
1079                            make_filter(predicate, Arc::new(child_plan))
1080                        } else {
1081                            Ok(child_plan)
1082                        }
1083                    })
1084            }
1085            LogicalPlan::Join(join) => push_down_join(join, Some(&filter.predicate)),
1086            LogicalPlan::TableScan(scan) => {
1087                let filter_predicates = split_conjunction(&filter.predicate);
1088
1089                let (volatile_filters, non_volatile_filters): (Vec<&Expr>, Vec<&Expr>) =
1090                    filter_predicates
1091                        .into_iter()
1092                        .partition(|pred| pred.is_volatile());
1093
1094                // Check which non-volatile filters are supported by source
1095                let supported_filters = scan
1096                    .source
1097                    .supports_filters_pushdown(non_volatile_filters.as_slice())?;
1098                if non_volatile_filters.len() != supported_filters.len() {
1099                    return internal_err!(
1100                        "Vec returned length: {} from supports_filters_pushdown is not the same size as the filters passed, which length is: {}",
1101                        supported_filters.len(),
1102                        non_volatile_filters.len());
1103                }
1104
1105                // Compose scan filters from non-volatile filters of `Exact` or `Inexact` pushdown type
1106                let zip = non_volatile_filters.into_iter().zip(supported_filters);
1107
1108                let new_scan_filters = zip
1109                    .clone()
1110                    .filter(|(_, res)| res != &TableProviderFilterPushDown::Unsupported)
1111                    .map(|(pred, _)| pred);
1112
1113                // Add new scan filters
1114                let new_scan_filters: Vec<Expr> = scan
1115                    .filters
1116                    .iter()
1117                    .chain(new_scan_filters)
1118                    .unique()
1119                    .cloned()
1120                    .collect();
1121
1122                // Compose predicates to be of `Unsupported` or `Inexact` pushdown type, and also include volatile filters
1123                let new_predicate: Vec<Expr> = zip
1124                    .filter(|(_, res)| res != &TableProviderFilterPushDown::Exact)
1125                    .map(|(pred, _)| pred)
1126                    .chain(volatile_filters)
1127                    .cloned()
1128                    .collect();
1129
1130                let new_scan = LogicalPlan::TableScan(TableScan {
1131                    filters: new_scan_filters,
1132                    ..scan
1133                });
1134
1135                Transformed::yes(new_scan).transform_data(|new_scan| {
1136                    if let Some(predicate) = conjunction(new_predicate) {
1137                        make_filter(predicate, Arc::new(new_scan)).map(Transformed::yes)
1138                    } else {
1139                        Ok(Transformed::no(new_scan))
1140                    }
1141                })
1142            }
1143            LogicalPlan::Extension(extension_plan) => {
1144                // This check prevents the Filter from being removed when the extension node has no children,
1145                // so we return the original Filter unchanged.
1146                if extension_plan.node.inputs().is_empty() {
1147                    filter.input = Arc::new(LogicalPlan::Extension(extension_plan));
1148                    return Ok(Transformed::no(LogicalPlan::Filter(filter)));
1149                }
1150                let prevent_cols =
1151                    extension_plan.node.prevent_predicate_push_down_columns();
1152
1153                // determine if we can push any predicates down past the extension node
1154
1155                // each element is true for push, false to keep
1156                let predicate_push_or_keep = split_conjunction(&filter.predicate)
1157                    .iter()
1158                    .map(|expr| {
1159                        let cols = expr.column_refs();
1160                        if cols.iter().any(|c| prevent_cols.contains(&c.name)) {
1161                            Ok(false) // No push (keep)
1162                        } else {
1163                            Ok(true) // push
1164                        }
1165                    })
1166                    .collect::<Result<Vec<_>>>()?;
1167
1168                // all predicates are kept, no changes needed
1169                if predicate_push_or_keep.iter().all(|&x| !x) {
1170                    filter.input = Arc::new(LogicalPlan::Extension(extension_plan));
1171                    return Ok(Transformed::no(LogicalPlan::Filter(filter)));
1172                }
1173
1174                // going to push some predicates down, so split the predicates
1175                let mut keep_predicates = vec![];
1176                let mut push_predicates = vec![];
1177                for (push, expr) in predicate_push_or_keep
1178                    .into_iter()
1179                    .zip(split_conjunction_owned(filter.predicate).into_iter())
1180                {
1181                    if !push {
1182                        keep_predicates.push(expr);
1183                    } else {
1184                        push_predicates.push(expr);
1185                    }
1186                }
1187
1188                let new_children = match conjunction(push_predicates) {
1189                    Some(predicate) => extension_plan
1190                        .node
1191                        .inputs()
1192                        .into_iter()
1193                        .map(|child| {
1194                            Ok(LogicalPlan::Filter(Filter::try_new(
1195                                predicate.clone(),
1196                                Arc::new(child.clone()),
1197                            )?))
1198                        })
1199                        .collect::<Result<Vec<_>>>()?,
1200                    None => extension_plan.node.inputs().into_iter().cloned().collect(),
1201                };
1202                // extension with new inputs.
1203                let child_plan = LogicalPlan::Extension(extension_plan);
1204                let new_extension =
1205                    child_plan.with_new_exprs(child_plan.expressions(), new_children)?;
1206
1207                let new_plan = match conjunction(keep_predicates) {
1208                    Some(predicate) => LogicalPlan::Filter(Filter::try_new(
1209                        predicate,
1210                        Arc::new(new_extension),
1211                    )?),
1212                    None => new_extension,
1213                };
1214                Ok(Transformed::yes(new_plan))
1215            }
1216            child => {
1217                filter.input = Arc::new(child);
1218                Ok(Transformed::no(LogicalPlan::Filter(filter)))
1219            }
1220        }
1221    }
1222}
1223
1224/// Attempts to push `predicate` into a `FilterExec` below `projection
1225///
1226/// # Returns
1227/// (plan, remaining_predicate)
1228///
1229/// `plan` is a LogicalPlan for `projection` with possibly a new FilterExec below it.
1230/// `remaining_predicate` is any part of the predicate that could not be pushed down
1231///
1232/// # Args
1233/// - predicates: Split predicates like `[foo=5, bar=6]`
1234/// - projection: The target projection plan to push down the predicates
1235///
1236/// # Example
1237///
1238/// Pushing a predicate like `foo=5 AND bar=6` with an input plan like this:
1239///
1240/// ```text
1241/// Projection(foo, c+d as bar)
1242/// ```
1243///
1244/// Might result in returning `remaining_predicate` of `bar=6` and a plan like
1245///
1246/// ```text
1247/// Projection(foo, c+d as bar)
1248///  Filter(foo=5)
1249///   ...
1250/// ```
1251fn rewrite_projection(
1252    predicates: Vec<Expr>,
1253    mut projection: Projection,
1254) -> Result<(Transformed<LogicalPlan>, Option<Expr>)> {
1255    // A projection is filter-commutable if it do not contain volatile predicates or contain volatile
1256    // predicates that are not used in the filter. However, we should re-writes all predicate expressions.
1257    // collect projection.
1258    let (volatile_map, non_volatile_map): (HashMap<_, _>, HashMap<_, _>) = projection
1259        .schema
1260        .iter()
1261        .zip(projection.expr.iter())
1262        .map(|((qualifier, field), expr)| {
1263            // strip alias, as they should not be part of filters
1264            let expr = expr.clone().unalias();
1265
1266            (qualified_name(qualifier, field.name()), expr)
1267        })
1268        .partition(|(_, value)| value.is_volatile());
1269
1270    let mut push_predicates = vec![];
1271    let mut keep_predicates = vec![];
1272    for expr in predicates {
1273        if contain(&expr, &volatile_map) {
1274            keep_predicates.push(expr);
1275        } else {
1276            push_predicates.push(expr);
1277        }
1278    }
1279
1280    match conjunction(push_predicates) {
1281        Some(expr) => {
1282            // re-write all filters based on this projection
1283            // E.g. in `Filter: b\n  Projection: a > 1 as b`, we can swap them, but the filter must be "a > 1"
1284            let new_filter = LogicalPlan::Filter(Filter::try_new(
1285                replace_cols_by_name(expr, &non_volatile_map)?,
1286                std::mem::take(&mut projection.input),
1287            )?);
1288
1289            projection.input = Arc::new(new_filter);
1290
1291            Ok((
1292                Transformed::yes(LogicalPlan::Projection(projection)),
1293                conjunction(keep_predicates),
1294            ))
1295        }
1296        None => Ok((Transformed::no(LogicalPlan::Projection(projection)), None)),
1297    }
1298}
1299
1300/// Creates a new LogicalPlan::Filter node.
1301pub fn make_filter(predicate: Expr, input: Arc<LogicalPlan>) -> Result<LogicalPlan> {
1302    Filter::try_new(predicate, input).map(LogicalPlan::Filter)
1303}
1304
1305/// Replace the existing child of the single input node with `new_child`.
1306///
1307/// Starting:
1308/// ```text
1309/// plan
1310///   child
1311/// ```
1312///
1313/// Ending:
1314/// ```text
1315/// plan
1316///   new_child
1317/// ```
1318fn insert_below(
1319    plan: LogicalPlan,
1320    new_child: LogicalPlan,
1321) -> Result<Transformed<LogicalPlan>> {
1322    let mut new_child = Some(new_child);
1323    let transformed_plan = plan.map_children(|_child| {
1324        if let Some(new_child) = new_child.take() {
1325            Ok(Transformed::yes(new_child))
1326        } else {
1327            // already took the new child
1328            internal_err!("node had more than one input")
1329        }
1330    })?;
1331
1332    // make sure we did the actual replacement
1333    if new_child.is_some() {
1334        return internal_err!("node had no  inputs");
1335    }
1336
1337    Ok(transformed_plan)
1338}
1339
1340impl PushDownFilter {
1341    #[allow(missing_docs)]
1342    pub fn new() -> Self {
1343        Self {}
1344    }
1345}
1346
1347/// replaces columns by its name on the projection.
1348pub fn replace_cols_by_name(
1349    e: Expr,
1350    replace_map: &HashMap<String, Expr>,
1351) -> Result<Expr> {
1352    e.transform_up(|expr| {
1353        Ok(if let Expr::Column(c) = &expr {
1354            match replace_map.get(&c.flat_name()) {
1355                Some(new_c) => Transformed::yes(new_c.clone()),
1356                None => Transformed::no(expr),
1357            }
1358        } else {
1359            Transformed::no(expr)
1360        })
1361    })
1362    .data()
1363}
1364
1365/// check whether the expression uses the columns in `check_map`.
1366fn contain(e: &Expr, check_map: &HashMap<String, Expr>) -> bool {
1367    let mut is_contain = false;
1368    e.apply(|expr| {
1369        Ok(if let Expr::Column(c) = &expr {
1370            match check_map.get(&c.flat_name()) {
1371                Some(_) => {
1372                    is_contain = true;
1373                    TreeNodeRecursion::Stop
1374                }
1375                None => TreeNodeRecursion::Continue,
1376            }
1377        } else {
1378            TreeNodeRecursion::Continue
1379        })
1380    })
1381    .unwrap();
1382    is_contain
1383}
1384
1385#[cfg(test)]
1386mod tests {
1387    use std::any::Any;
1388    use std::cmp::Ordering;
1389    use std::fmt::{Debug, Formatter};
1390
1391    use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
1392    use async_trait::async_trait;
1393
1394    use datafusion_common::{DFSchemaRef, DataFusionError, ScalarValue};
1395    use datafusion_expr::expr::{ScalarFunction, WindowFunction};
1396    use datafusion_expr::logical_plan::table_scan;
1397    use datafusion_expr::{
1398        col, in_list, in_subquery, lit, ColumnarValue, ExprFunctionExt, Extension,
1399        LogicalPlanBuilder, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
1400        TableSource, TableType, UserDefinedLogicalNodeCore, Volatility,
1401        WindowFunctionDefinition,
1402    };
1403
1404    use crate::assert_optimized_plan_eq_snapshot;
1405    use crate::optimizer::Optimizer;
1406    use crate::simplify_expressions::SimplifyExpressions;
1407    use crate::test::*;
1408    use crate::OptimizerContext;
1409    use datafusion_expr::test::function_stub::sum;
1410    use insta::assert_snapshot;
1411
1412    use super::*;
1413
1414    fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {}
1415
1416    macro_rules! assert_optimized_plan_equal {
1417        (
1418            $plan:expr,
1419            @ $expected:literal $(,)?
1420        ) => {{
1421            let optimizer_ctx = OptimizerContext::new().with_max_passes(1);
1422            let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> = vec![Arc::new(PushDownFilter::new())];
1423            assert_optimized_plan_eq_snapshot!(
1424                optimizer_ctx,
1425                rules,
1426                $plan,
1427                @ $expected,
1428            )
1429        }};
1430    }
1431
1432    macro_rules! assert_optimized_plan_eq_with_rewrite_predicate {
1433        (
1434            $plan:expr,
1435            @ $expected:literal $(,)?
1436        ) => {{
1437            let optimizer = Optimizer::with_rules(vec![
1438                Arc::new(SimplifyExpressions::new()),
1439                Arc::new(PushDownFilter::new()),
1440            ]);
1441            let optimized_plan = optimizer.optimize($plan, &OptimizerContext::new(), observe)?;
1442            assert_snapshot!(optimized_plan, @ $expected);
1443            Ok::<(), DataFusionError>(())
1444        }};
1445    }
1446
1447    #[test]
1448    fn filter_before_projection() -> Result<()> {
1449        let table_scan = test_table_scan()?;
1450        let plan = LogicalPlanBuilder::from(table_scan)
1451            .project(vec![col("a"), col("b")])?
1452            .filter(col("a").eq(lit(1i64)))?
1453            .build()?;
1454        // filter is before projection
1455        assert_optimized_plan_equal!(
1456            plan,
1457            @r"
1458        Projection: test.a, test.b
1459          TableScan: test, full_filters=[test.a = Int64(1)]
1460        "
1461        )
1462    }
1463
1464    #[test]
1465    fn filter_after_limit() -> Result<()> {
1466        let table_scan = test_table_scan()?;
1467        let plan = LogicalPlanBuilder::from(table_scan)
1468            .project(vec![col("a"), col("b")])?
1469            .limit(0, Some(10))?
1470            .filter(col("a").eq(lit(1i64)))?
1471            .build()?;
1472        // filter is before single projection
1473        assert_optimized_plan_equal!(
1474            plan,
1475            @r"
1476        Filter: test.a = Int64(1)
1477          Limit: skip=0, fetch=10
1478            Projection: test.a, test.b
1479              TableScan: test
1480        "
1481        )
1482    }
1483
1484    #[test]
1485    fn filter_no_columns() -> Result<()> {
1486        let table_scan = test_table_scan()?;
1487        let plan = LogicalPlanBuilder::from(table_scan)
1488            .filter(lit(0i64).eq(lit(1i64)))?
1489            .build()?;
1490        assert_optimized_plan_equal!(
1491            plan,
1492            @"TableScan: test, full_filters=[Int64(0) = Int64(1)]"
1493        )
1494    }
1495
1496    #[test]
1497    fn filter_jump_2_plans() -> Result<()> {
1498        let table_scan = test_table_scan()?;
1499        let plan = LogicalPlanBuilder::from(table_scan)
1500            .project(vec![col("a"), col("b"), col("c")])?
1501            .project(vec![col("c"), col("b")])?
1502            .filter(col("a").eq(lit(1i64)))?
1503            .build()?;
1504        // filter is before double projection
1505        assert_optimized_plan_equal!(
1506            plan,
1507            @r"
1508        Projection: test.c, test.b
1509          Projection: test.a, test.b, test.c
1510            TableScan: test, full_filters=[test.a = Int64(1)]
1511        "
1512        )
1513    }
1514
1515    #[test]
1516    fn filter_move_agg() -> Result<()> {
1517        let table_scan = test_table_scan()?;
1518        let plan = LogicalPlanBuilder::from(table_scan)
1519            .aggregate(vec![col("a")], vec![sum(col("b")).alias("total_salary")])?
1520            .filter(col("a").gt(lit(10i64)))?
1521            .build()?;
1522        // filter of key aggregation is commutative
1523        assert_optimized_plan_equal!(
1524            plan,
1525            @r"
1526        Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b) AS total_salary]]
1527          TableScan: test, full_filters=[test.a > Int64(10)]
1528        "
1529        )
1530    }
1531
1532    #[test]
1533    fn filter_complex_group_by() -> Result<()> {
1534        let table_scan = test_table_scan()?;
1535        let plan = LogicalPlanBuilder::from(table_scan)
1536            .aggregate(vec![add(col("b"), col("a"))], vec![sum(col("a")), col("b")])?
1537            .filter(col("b").gt(lit(10i64)))?
1538            .build()?;
1539        assert_optimized_plan_equal!(
1540            plan,
1541            @r"
1542        Filter: test.b > Int64(10)
1543          Aggregate: groupBy=[[test.b + test.a]], aggr=[[sum(test.a), test.b]]
1544            TableScan: test
1545        "
1546        )
1547    }
1548
1549    #[test]
1550    fn push_agg_need_replace_expr() -> Result<()> {
1551        let plan = LogicalPlanBuilder::from(test_table_scan()?)
1552            .aggregate(vec![add(col("b"), col("a"))], vec![sum(col("a")), col("b")])?
1553            .filter(col("test.b + test.a").gt(lit(10i64)))?
1554            .build()?;
1555        assert_optimized_plan_equal!(
1556            plan,
1557            @r"
1558        Aggregate: groupBy=[[test.b + test.a]], aggr=[[sum(test.a), test.b]]
1559          TableScan: test, full_filters=[test.b + test.a > Int64(10)]
1560        "
1561        )
1562    }
1563
1564    #[test]
1565    fn filter_keep_agg() -> Result<()> {
1566        let table_scan = test_table_scan()?;
1567        let plan = LogicalPlanBuilder::from(table_scan)
1568            .aggregate(vec![col("a")], vec![sum(col("b")).alias("b")])?
1569            .filter(col("b").gt(lit(10i64)))?
1570            .build()?;
1571        // filter of aggregate is after aggregation since they are non-commutative
1572        assert_optimized_plan_equal!(
1573            plan,
1574            @r"
1575        Filter: b > Int64(10)
1576          Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b) AS b]]
1577            TableScan: test
1578        "
1579        )
1580    }
1581
1582    /// verifies that when partitioning by 'a' and 'b', and filtering by 'b', 'b' is pushed
1583    #[test]
1584    fn filter_move_window() -> Result<()> {
1585        let table_scan = test_table_scan()?;
1586
1587        let window = Expr::from(WindowFunction::new(
1588            WindowFunctionDefinition::WindowUDF(
1589                datafusion_functions_window::rank::rank_udwf(),
1590            ),
1591            vec![],
1592        ))
1593        .partition_by(vec![col("a"), col("b")])
1594        .order_by(vec![col("c").sort(true, true)])
1595        .build()
1596        .unwrap();
1597
1598        let plan = LogicalPlanBuilder::from(table_scan)
1599            .window(vec![window])?
1600            .filter(col("b").gt(lit(10i64)))?
1601            .build()?;
1602
1603        assert_optimized_plan_equal!(
1604            plan,
1605            @r"
1606        WindowAggr: windowExpr=[[rank() PARTITION BY [test.a, test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1607          TableScan: test, full_filters=[test.b > Int64(10)]
1608        "
1609        )
1610    }
1611
1612    /// verifies that when partitioning by 'a' and 'b', and filtering by 'a' and 'b', both 'a' and
1613    /// 'b' are pushed
1614    #[test]
1615    fn filter_move_complex_window() -> Result<()> {
1616        let table_scan = test_table_scan()?;
1617
1618        let window = Expr::from(WindowFunction::new(
1619            WindowFunctionDefinition::WindowUDF(
1620                datafusion_functions_window::rank::rank_udwf(),
1621            ),
1622            vec![],
1623        ))
1624        .partition_by(vec![col("a"), col("b")])
1625        .order_by(vec![col("c").sort(true, true)])
1626        .build()
1627        .unwrap();
1628
1629        let plan = LogicalPlanBuilder::from(table_scan)
1630            .window(vec![window])?
1631            .filter(and(col("a").gt(lit(10i64)), col("b").eq(lit(1i64))))?
1632            .build()?;
1633
1634        assert_optimized_plan_equal!(
1635            plan,
1636            @r"
1637        WindowAggr: windowExpr=[[rank() PARTITION BY [test.a, test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1638          TableScan: test, full_filters=[test.a > Int64(10), test.b = Int64(1)]
1639        "
1640        )
1641    }
1642
1643    /// verifies that when partitioning by 'a' and filtering by 'a' and 'b', only 'a' is pushed
1644    #[test]
1645    fn filter_move_partial_window() -> Result<()> {
1646        let table_scan = test_table_scan()?;
1647
1648        let window = Expr::from(WindowFunction::new(
1649            WindowFunctionDefinition::WindowUDF(
1650                datafusion_functions_window::rank::rank_udwf(),
1651            ),
1652            vec![],
1653        ))
1654        .partition_by(vec![col("a")])
1655        .order_by(vec![col("c").sort(true, true)])
1656        .build()
1657        .unwrap();
1658
1659        let plan = LogicalPlanBuilder::from(table_scan)
1660            .window(vec![window])?
1661            .filter(and(col("a").gt(lit(10i64)), col("b").eq(lit(1i64))))?
1662            .build()?;
1663
1664        assert_optimized_plan_equal!(
1665            plan,
1666            @r"
1667        Filter: test.b = Int64(1)
1668          WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1669            TableScan: test, full_filters=[test.a > Int64(10)]
1670        "
1671        )
1672    }
1673
1674    /// verifies that filters on partition expressions are not pushed, as the single expression
1675    /// column is not available to the user, unlike with aggregations
1676    #[test]
1677    fn filter_expression_keep_window() -> Result<()> {
1678        let table_scan = test_table_scan()?;
1679
1680        let window = Expr::from(WindowFunction::new(
1681            WindowFunctionDefinition::WindowUDF(
1682                datafusion_functions_window::rank::rank_udwf(),
1683            ),
1684            vec![],
1685        ))
1686        .partition_by(vec![add(col("a"), col("b"))]) // PARTITION BY a + b
1687        .order_by(vec![col("c").sort(true, true)])
1688        .build()
1689        .unwrap();
1690
1691        let plan = LogicalPlanBuilder::from(table_scan)
1692            .window(vec![window])?
1693            // unlike with aggregations, single partition column "test.a + test.b" is not available
1694            // to the plan, so we use multiple columns when filtering
1695            .filter(add(col("a"), col("b")).gt(lit(10i64)))?
1696            .build()?;
1697
1698        assert_optimized_plan_equal!(
1699            plan,
1700            @r"
1701        Filter: test.a + test.b > Int64(10)
1702          WindowAggr: windowExpr=[[rank() PARTITION BY [test.a + test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1703            TableScan: test
1704        "
1705        )
1706    }
1707
1708    /// verifies that filters are not pushed on order by columns (that are not used in partitioning)
1709    #[test]
1710    fn filter_order_keep_window() -> Result<()> {
1711        let table_scan = test_table_scan()?;
1712
1713        let window = Expr::from(WindowFunction::new(
1714            WindowFunctionDefinition::WindowUDF(
1715                datafusion_functions_window::rank::rank_udwf(),
1716            ),
1717            vec![],
1718        ))
1719        .partition_by(vec![col("a")])
1720        .order_by(vec![col("c").sort(true, true)])
1721        .build()
1722        .unwrap();
1723
1724        let plan = LogicalPlanBuilder::from(table_scan)
1725            .window(vec![window])?
1726            .filter(col("c").gt(lit(10i64)))?
1727            .build()?;
1728
1729        assert_optimized_plan_equal!(
1730            plan,
1731            @r"
1732        Filter: test.c > Int64(10)
1733          WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1734            TableScan: test
1735        "
1736        )
1737    }
1738
1739    /// verifies that when we use multiple window functions with a common partition key, the filter
1740    /// on that key is pushed
1741    #[test]
1742    fn filter_multiple_windows_common_partitions() -> Result<()> {
1743        let table_scan = test_table_scan()?;
1744
1745        let window1 = Expr::from(WindowFunction::new(
1746            WindowFunctionDefinition::WindowUDF(
1747                datafusion_functions_window::rank::rank_udwf(),
1748            ),
1749            vec![],
1750        ))
1751        .partition_by(vec![col("a")])
1752        .order_by(vec![col("c").sort(true, true)])
1753        .build()
1754        .unwrap();
1755
1756        let window2 = Expr::from(WindowFunction::new(
1757            WindowFunctionDefinition::WindowUDF(
1758                datafusion_functions_window::rank::rank_udwf(),
1759            ),
1760            vec![],
1761        ))
1762        .partition_by(vec![col("b"), col("a")])
1763        .order_by(vec![col("c").sort(true, true)])
1764        .build()
1765        .unwrap();
1766
1767        let plan = LogicalPlanBuilder::from(table_scan)
1768            .window(vec![window1, window2])?
1769            .filter(col("a").gt(lit(10i64)))? // a appears in both window functions
1770            .build()?;
1771
1772        assert_optimized_plan_equal!(
1773            plan,
1774            @r"
1775        WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, rank() PARTITION BY [test.b, test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1776          TableScan: test, full_filters=[test.a > Int64(10)]
1777        "
1778        )
1779    }
1780
1781    /// verifies that when we use multiple window functions with different partitions keys, the
1782    /// filter cannot be pushed
1783    #[test]
1784    fn filter_multiple_windows_disjoint_partitions() -> Result<()> {
1785        let table_scan = test_table_scan()?;
1786
1787        let window1 = Expr::from(WindowFunction::new(
1788            WindowFunctionDefinition::WindowUDF(
1789                datafusion_functions_window::rank::rank_udwf(),
1790            ),
1791            vec![],
1792        ))
1793        .partition_by(vec![col("a")])
1794        .order_by(vec![col("c").sort(true, true)])
1795        .build()
1796        .unwrap();
1797
1798        let window2 = Expr::from(WindowFunction::new(
1799            WindowFunctionDefinition::WindowUDF(
1800                datafusion_functions_window::rank::rank_udwf(),
1801            ),
1802            vec![],
1803        ))
1804        .partition_by(vec![col("b"), col("a")])
1805        .order_by(vec![col("c").sort(true, true)])
1806        .build()
1807        .unwrap();
1808
1809        let plan = LogicalPlanBuilder::from(table_scan)
1810            .window(vec![window1, window2])?
1811            .filter(col("b").gt(lit(10i64)))? // b only appears in one window function
1812            .build()?;
1813
1814        assert_optimized_plan_equal!(
1815            plan,
1816            @r"
1817        Filter: test.b > Int64(10)
1818          WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, rank() PARTITION BY [test.b, test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1819            TableScan: test
1820        "
1821        )
1822    }
1823
1824    /// verifies that a filter is pushed to before a projection, the filter expression is correctly re-written
1825    #[test]
1826    fn alias() -> Result<()> {
1827        let table_scan = test_table_scan()?;
1828        let plan = LogicalPlanBuilder::from(table_scan)
1829            .project(vec![col("a").alias("b"), col("c")])?
1830            .filter(col("b").eq(lit(1i64)))?
1831            .build()?;
1832        // filter is before projection
1833        assert_optimized_plan_equal!(
1834            plan,
1835            @r"
1836        Projection: test.a AS b, test.c
1837          TableScan: test, full_filters=[test.a = Int64(1)]
1838        "
1839        )
1840    }
1841
1842    fn add(left: Expr, right: Expr) -> Expr {
1843        Expr::BinaryExpr(BinaryExpr::new(
1844            Box::new(left),
1845            Operator::Plus,
1846            Box::new(right),
1847        ))
1848    }
1849
1850    fn multiply(left: Expr, right: Expr) -> Expr {
1851        Expr::BinaryExpr(BinaryExpr::new(
1852            Box::new(left),
1853            Operator::Multiply,
1854            Box::new(right),
1855        ))
1856    }
1857
1858    /// verifies that a filter is pushed to before a projection with a complex expression, the filter expression is correctly re-written
1859    #[test]
1860    fn complex_expression() -> Result<()> {
1861        let table_scan = test_table_scan()?;
1862        let plan = LogicalPlanBuilder::from(table_scan)
1863            .project(vec![
1864                add(multiply(col("a"), lit(2)), col("c")).alias("b"),
1865                col("c"),
1866            ])?
1867            .filter(col("b").eq(lit(1i64)))?
1868            .build()?;
1869
1870        // not part of the test, just good to know:
1871        assert_snapshot!(plan,
1872        @r"
1873        Filter: b = Int64(1)
1874          Projection: test.a * Int32(2) + test.c AS b, test.c
1875            TableScan: test
1876        ",
1877        );
1878        // filter is before projection
1879        assert_optimized_plan_equal!(
1880            plan,
1881            @r"
1882        Projection: test.a * Int32(2) + test.c AS b, test.c
1883          TableScan: test, full_filters=[test.a * Int32(2) + test.c = Int64(1)]
1884        "
1885        )
1886    }
1887
1888    /// verifies that when a filter is pushed to after 2 projections, the filter expression is correctly re-written
1889    #[test]
1890    fn complex_plan() -> Result<()> {
1891        let table_scan = test_table_scan()?;
1892        let plan = LogicalPlanBuilder::from(table_scan)
1893            .project(vec![
1894                add(multiply(col("a"), lit(2)), col("c")).alias("b"),
1895                col("c"),
1896            ])?
1897            // second projection where we rename columns, just to make it difficult
1898            .project(vec![multiply(col("b"), lit(3)).alias("a"), col("c")])?
1899            .filter(col("a").eq(lit(1i64)))?
1900            .build()?;
1901
1902        // not part of the test, just good to know:
1903        assert_snapshot!(plan,
1904        @r"
1905        Filter: a = Int64(1)
1906          Projection: b * Int32(3) AS a, test.c
1907            Projection: test.a * Int32(2) + test.c AS b, test.c
1908              TableScan: test
1909        ",
1910        );
1911        // filter is before the projections
1912        assert_optimized_plan_equal!(
1913            plan,
1914            @r"
1915        Projection: b * Int32(3) AS a, test.c
1916          Projection: test.a * Int32(2) + test.c AS b, test.c
1917            TableScan: test, full_filters=[(test.a * Int32(2) + test.c) * Int32(3) = Int64(1)]
1918        "
1919        )
1920    }
1921
1922    #[derive(Debug, PartialEq, Eq, Hash)]
1923    struct NoopPlan {
1924        input: Vec<LogicalPlan>,
1925        schema: DFSchemaRef,
1926    }
1927
1928    // Manual implementation needed because of `schema` field. Comparison excludes this field.
1929    impl PartialOrd for NoopPlan {
1930        fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
1931            self.input.partial_cmp(&other.input)
1932        }
1933    }
1934
1935    impl UserDefinedLogicalNodeCore for NoopPlan {
1936        fn name(&self) -> &str {
1937            "NoopPlan"
1938        }
1939
1940        fn inputs(&self) -> Vec<&LogicalPlan> {
1941            self.input.iter().collect()
1942        }
1943
1944        fn schema(&self) -> &DFSchemaRef {
1945            &self.schema
1946        }
1947
1948        fn expressions(&self) -> Vec<Expr> {
1949            self.input
1950                .iter()
1951                .flat_map(|child| child.expressions())
1952                .collect()
1953        }
1954
1955        fn prevent_predicate_push_down_columns(&self) -> HashSet<String> {
1956            HashSet::from_iter(vec!["c".to_string()])
1957        }
1958
1959        fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
1960            write!(f, "NoopPlan")
1961        }
1962
1963        fn with_exprs_and_inputs(
1964            &self,
1965            _exprs: Vec<Expr>,
1966            inputs: Vec<LogicalPlan>,
1967        ) -> Result<Self> {
1968            Ok(Self {
1969                input: inputs,
1970                schema: Arc::clone(&self.schema),
1971            })
1972        }
1973
1974        fn supports_limit_pushdown(&self) -> bool {
1975            false // Disallow limit push-down by default
1976        }
1977    }
1978
1979    #[test]
1980    fn user_defined_plan() -> Result<()> {
1981        let table_scan = test_table_scan()?;
1982
1983        let custom_plan = LogicalPlan::Extension(Extension {
1984            node: Arc::new(NoopPlan {
1985                input: vec![table_scan.clone()],
1986                schema: Arc::clone(table_scan.schema()),
1987            }),
1988        });
1989        let plan = LogicalPlanBuilder::from(custom_plan)
1990            .filter(col("a").eq(lit(1i64)))?
1991            .build()?;
1992
1993        // Push filter below NoopPlan
1994        assert_optimized_plan_equal!(
1995            plan,
1996            @r"
1997        NoopPlan
1998          TableScan: test, full_filters=[test.a = Int64(1)]
1999        "
2000        )?;
2001
2002        let custom_plan = LogicalPlan::Extension(Extension {
2003            node: Arc::new(NoopPlan {
2004                input: vec![table_scan.clone()],
2005                schema: Arc::clone(table_scan.schema()),
2006            }),
2007        });
2008        let plan = LogicalPlanBuilder::from(custom_plan)
2009            .filter(col("a").eq(lit(1i64)).and(col("c").eq(lit(2i64))))?
2010            .build()?;
2011
2012        // Push only predicate on `a` below NoopPlan
2013        assert_optimized_plan_equal!(
2014            plan,
2015            @r"
2016        Filter: test.c = Int64(2)
2017          NoopPlan
2018            TableScan: test, full_filters=[test.a = Int64(1)]
2019        "
2020        )?;
2021
2022        let custom_plan = LogicalPlan::Extension(Extension {
2023            node: Arc::new(NoopPlan {
2024                input: vec![table_scan.clone(), table_scan.clone()],
2025                schema: Arc::clone(table_scan.schema()),
2026            }),
2027        });
2028        let plan = LogicalPlanBuilder::from(custom_plan)
2029            .filter(col("a").eq(lit(1i64)))?
2030            .build()?;
2031
2032        // Push filter below NoopPlan for each child branch
2033        assert_optimized_plan_equal!(
2034            plan,
2035            @r"
2036        NoopPlan
2037          TableScan: test, full_filters=[test.a = Int64(1)]
2038          TableScan: test, full_filters=[test.a = Int64(1)]
2039        "
2040        )?;
2041
2042        let custom_plan = LogicalPlan::Extension(Extension {
2043            node: Arc::new(NoopPlan {
2044                input: vec![table_scan.clone(), table_scan.clone()],
2045                schema: Arc::clone(table_scan.schema()),
2046            }),
2047        });
2048        let plan = LogicalPlanBuilder::from(custom_plan)
2049            .filter(col("a").eq(lit(1i64)).and(col("c").eq(lit(2i64))))?
2050            .build()?;
2051
2052        // Push only predicate on `a` below NoopPlan
2053        assert_optimized_plan_equal!(
2054            plan,
2055            @r"
2056        Filter: test.c = Int64(2)
2057          NoopPlan
2058            TableScan: test, full_filters=[test.a = Int64(1)]
2059            TableScan: test, full_filters=[test.a = Int64(1)]
2060        "
2061        )
2062    }
2063
2064    /// verifies that when two filters apply after an aggregation that only allows one to be pushed, one is pushed
2065    /// and the other not.
2066    #[test]
2067    fn multi_filter() -> Result<()> {
2068        // the aggregation allows one filter to pass (b), and the other one to not pass (sum(c))
2069        let table_scan = test_table_scan()?;
2070        let plan = LogicalPlanBuilder::from(table_scan)
2071            .project(vec![col("a").alias("b"), col("c")])?
2072            .aggregate(vec![col("b")], vec![sum(col("c"))])?
2073            .filter(col("b").gt(lit(10i64)))?
2074            .filter(col("sum(test.c)").gt(lit(10i64)))?
2075            .build()?;
2076
2077        // not part of the test, just good to know:
2078        assert_snapshot!(plan,
2079        @r"
2080        Filter: sum(test.c) > Int64(10)
2081          Filter: b > Int64(10)
2082            Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
2083              Projection: test.a AS b, test.c
2084                TableScan: test
2085        ",
2086        );
2087        // filter is before the projections
2088        assert_optimized_plan_equal!(
2089            plan,
2090            @r"
2091        Filter: sum(test.c) > Int64(10)
2092          Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
2093            Projection: test.a AS b, test.c
2094              TableScan: test, full_filters=[test.a > Int64(10)]
2095        "
2096        )
2097    }
2098
2099    /// verifies that when a filter with two predicates is applied after an aggregation that only allows one to be pushed, one is pushed
2100    /// and the other not.
2101    #[test]
2102    fn split_filter() -> Result<()> {
2103        // the aggregation allows one filter to pass (b), and the other one to not pass (sum(c))
2104        let table_scan = test_table_scan()?;
2105        let plan = LogicalPlanBuilder::from(table_scan)
2106            .project(vec![col("a").alias("b"), col("c")])?
2107            .aggregate(vec![col("b")], vec![sum(col("c"))])?
2108            .filter(and(
2109                col("sum(test.c)").gt(lit(10i64)),
2110                and(col("b").gt(lit(10i64)), col("sum(test.c)").lt(lit(20i64))),
2111            ))?
2112            .build()?;
2113
2114        // not part of the test, just good to know:
2115        assert_snapshot!(plan,
2116        @r"
2117        Filter: sum(test.c) > Int64(10) AND b > Int64(10) AND sum(test.c) < Int64(20)
2118          Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
2119            Projection: test.a AS b, test.c
2120              TableScan: test
2121        ",
2122        );
2123        // filter is before the projections
2124        assert_optimized_plan_equal!(
2125            plan,
2126            @r"
2127        Filter: sum(test.c) > Int64(10) AND sum(test.c) < Int64(20)
2128          Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
2129            Projection: test.a AS b, test.c
2130              TableScan: test, full_filters=[test.a > Int64(10)]
2131        "
2132        )
2133    }
2134
2135    /// verifies that when two limits are in place, we jump neither
2136    #[test]
2137    fn double_limit() -> Result<()> {
2138        let table_scan = test_table_scan()?;
2139        let plan = LogicalPlanBuilder::from(table_scan)
2140            .project(vec![col("a"), col("b")])?
2141            .limit(0, Some(20))?
2142            .limit(0, Some(10))?
2143            .project(vec![col("a"), col("b")])?
2144            .filter(col("a").eq(lit(1i64)))?
2145            .build()?;
2146        // filter does not just any of the limits
2147        assert_optimized_plan_equal!(
2148            plan,
2149            @r"
2150        Projection: test.a, test.b
2151          Filter: test.a = Int64(1)
2152            Limit: skip=0, fetch=10
2153              Limit: skip=0, fetch=20
2154                Projection: test.a, test.b
2155                  TableScan: test
2156        "
2157        )
2158    }
2159
2160    #[test]
2161    fn union_all() -> Result<()> {
2162        let table_scan = test_table_scan()?;
2163        let table_scan2 = test_table_scan_with_name("test2")?;
2164        let plan = LogicalPlanBuilder::from(table_scan)
2165            .union(LogicalPlanBuilder::from(table_scan2).build()?)?
2166            .filter(col("a").eq(lit(1i64)))?
2167            .build()?;
2168        // filter appears below Union
2169        assert_optimized_plan_equal!(
2170            plan,
2171            @r"
2172        Union
2173          TableScan: test, full_filters=[test.a = Int64(1)]
2174          TableScan: test2, full_filters=[test2.a = Int64(1)]
2175        "
2176        )
2177    }
2178
2179    #[test]
2180    fn union_all_on_projection() -> Result<()> {
2181        let table_scan = test_table_scan()?;
2182        let table = LogicalPlanBuilder::from(table_scan)
2183            .project(vec![col("a").alias("b")])?
2184            .alias("test2")?;
2185
2186        let plan = table
2187            .clone()
2188            .union(table.build()?)?
2189            .filter(col("b").eq(lit(1i64)))?
2190            .build()?;
2191
2192        // filter appears below Union
2193        assert_optimized_plan_equal!(
2194            plan,
2195            @r"
2196        Union
2197          SubqueryAlias: test2
2198            Projection: test.a AS b
2199              TableScan: test, full_filters=[test.a = Int64(1)]
2200          SubqueryAlias: test2
2201            Projection: test.a AS b
2202              TableScan: test, full_filters=[test.a = Int64(1)]
2203        "
2204        )
2205    }
2206
2207    #[test]
2208    fn test_union_different_schema() -> Result<()> {
2209        let left = LogicalPlanBuilder::from(test_table_scan()?)
2210            .project(vec![col("a"), col("b"), col("c")])?
2211            .build()?;
2212
2213        let schema = Schema::new(vec![
2214            Field::new("d", DataType::UInt32, false),
2215            Field::new("e", DataType::UInt32, false),
2216            Field::new("f", DataType::UInt32, false),
2217        ]);
2218        let right = table_scan(Some("test1"), &schema, None)?
2219            .project(vec![col("d"), col("e"), col("f")])?
2220            .build()?;
2221        let filter = and(col("test.a").eq(lit(1)), col("test1.d").gt(lit(2)));
2222        let plan = LogicalPlanBuilder::from(left)
2223            .cross_join(right)?
2224            .project(vec![col("test.a"), col("test1.d")])?
2225            .filter(filter)?
2226            .build()?;
2227
2228        assert_optimized_plan_equal!(
2229            plan,
2230            @r"
2231        Projection: test.a, test1.d
2232          Cross Join: 
2233            Projection: test.a, test.b, test.c
2234              TableScan: test, full_filters=[test.a = Int32(1)]
2235            Projection: test1.d, test1.e, test1.f
2236              TableScan: test1, full_filters=[test1.d > Int32(2)]
2237        "
2238        )
2239    }
2240
2241    #[test]
2242    fn test_project_same_name_different_qualifier() -> Result<()> {
2243        let table_scan = test_table_scan()?;
2244        let left = LogicalPlanBuilder::from(table_scan)
2245            .project(vec![col("a"), col("b"), col("c")])?
2246            .build()?;
2247        let right_table_scan = test_table_scan_with_name("test1")?;
2248        let right = LogicalPlanBuilder::from(right_table_scan)
2249            .project(vec![col("a"), col("b"), col("c")])?
2250            .build()?;
2251        let filter = and(col("test.a").eq(lit(1)), col("test1.a").gt(lit(2)));
2252        let plan = LogicalPlanBuilder::from(left)
2253            .cross_join(right)?
2254            .project(vec![col("test.a"), col("test1.a")])?
2255            .filter(filter)?
2256            .build()?;
2257
2258        assert_optimized_plan_equal!(
2259            plan,
2260            @r"
2261        Projection: test.a, test1.a
2262          Cross Join: 
2263            Projection: test.a, test.b, test.c
2264              TableScan: test, full_filters=[test.a = Int32(1)]
2265            Projection: test1.a, test1.b, test1.c
2266              TableScan: test1, full_filters=[test1.a > Int32(2)]
2267        "
2268        )
2269    }
2270
2271    /// verifies that filters with the same columns are correctly placed
2272    #[test]
2273    fn filter_2_breaks_limits() -> Result<()> {
2274        let table_scan = test_table_scan()?;
2275        let plan = LogicalPlanBuilder::from(table_scan)
2276            .project(vec![col("a")])?
2277            .filter(col("a").lt_eq(lit(1i64)))?
2278            .limit(0, Some(1))?
2279            .project(vec![col("a")])?
2280            .filter(col("a").gt_eq(lit(1i64)))?
2281            .build()?;
2282        // Should be able to move both filters below the projections
2283
2284        // not part of the test
2285        assert_snapshot!(plan,
2286        @r"
2287        Filter: test.a >= Int64(1)
2288          Projection: test.a
2289            Limit: skip=0, fetch=1
2290              Filter: test.a <= Int64(1)
2291                Projection: test.a
2292                  TableScan: test
2293        ",
2294        );
2295        assert_optimized_plan_equal!(
2296            plan,
2297            @r"
2298        Projection: test.a
2299          Filter: test.a >= Int64(1)
2300            Limit: skip=0, fetch=1
2301              Projection: test.a
2302                TableScan: test, full_filters=[test.a <= Int64(1)]
2303        "
2304        )
2305    }
2306
2307    /// verifies that filters to be placed on the same depth are ANDed
2308    #[test]
2309    fn two_filters_on_same_depth() -> Result<()> {
2310        let table_scan = test_table_scan()?;
2311        let plan = LogicalPlanBuilder::from(table_scan)
2312            .limit(0, Some(1))?
2313            .filter(col("a").lt_eq(lit(1i64)))?
2314            .filter(col("a").gt_eq(lit(1i64)))?
2315            .project(vec![col("a")])?
2316            .build()?;
2317
2318        // not part of the test
2319        assert_snapshot!(plan,
2320        @r"
2321        Projection: test.a
2322          Filter: test.a >= Int64(1)
2323            Filter: test.a <= Int64(1)
2324              Limit: skip=0, fetch=1
2325                TableScan: test
2326        ",
2327        );
2328        assert_optimized_plan_equal!(
2329            plan,
2330            @r"
2331        Projection: test.a
2332          Filter: test.a >= Int64(1) AND test.a <= Int64(1)
2333            Limit: skip=0, fetch=1
2334              TableScan: test
2335        "
2336        )
2337    }
2338
2339    /// verifies that filters on a plan with user nodes are not lost
2340    /// (ARROW-10547)
2341    #[test]
2342    fn filters_user_defined_node() -> Result<()> {
2343        let table_scan = test_table_scan()?;
2344        let plan = LogicalPlanBuilder::from(table_scan)
2345            .filter(col("a").lt_eq(lit(1i64)))?
2346            .build()?;
2347
2348        let plan = user_defined::new(plan);
2349
2350        // not part of the test
2351        assert_snapshot!(plan,
2352        @r"
2353        TestUserDefined
2354          Filter: test.a <= Int64(1)
2355            TableScan: test
2356        ",
2357        );
2358        assert_optimized_plan_equal!(
2359            plan,
2360            @r"
2361        TestUserDefined
2362          TableScan: test, full_filters=[test.a <= Int64(1)]
2363        "
2364        )
2365    }
2366
2367    /// post-on-join predicates on a column common to both sides is pushed to both sides
2368    #[test]
2369    fn filter_on_join_on_common_independent() -> Result<()> {
2370        let table_scan = test_table_scan()?;
2371        let left = LogicalPlanBuilder::from(table_scan).build()?;
2372        let right_table_scan = test_table_scan_with_name("test2")?;
2373        let right = LogicalPlanBuilder::from(right_table_scan)
2374            .project(vec![col("a")])?
2375            .build()?;
2376        let plan = LogicalPlanBuilder::from(left)
2377            .join(
2378                right,
2379                JoinType::Inner,
2380                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2381                None,
2382            )?
2383            .filter(col("test.a").lt_eq(lit(1i64)))?
2384            .build()?;
2385
2386        // not part of the test, just good to know:
2387        assert_snapshot!(plan,
2388        @r"
2389        Filter: test.a <= Int64(1)
2390          Inner Join: test.a = test2.a
2391            TableScan: test
2392            Projection: test2.a
2393              TableScan: test2
2394        ",
2395        );
2396        // filter sent to side before the join
2397        assert_optimized_plan_equal!(
2398            plan,
2399            @r"
2400        Inner Join: test.a = test2.a
2401          TableScan: test, full_filters=[test.a <= Int64(1)]
2402          Projection: test2.a
2403            TableScan: test2, full_filters=[test2.a <= Int64(1)]
2404        "
2405        )
2406    }
2407
2408    /// post-using-join predicates on a column common to both sides is pushed to both sides
2409    #[test]
2410    fn filter_using_join_on_common_independent() -> Result<()> {
2411        let table_scan = test_table_scan()?;
2412        let left = LogicalPlanBuilder::from(table_scan).build()?;
2413        let right_table_scan = test_table_scan_with_name("test2")?;
2414        let right = LogicalPlanBuilder::from(right_table_scan)
2415            .project(vec![col("a")])?
2416            .build()?;
2417        let plan = LogicalPlanBuilder::from(left)
2418            .join_using(
2419                right,
2420                JoinType::Inner,
2421                vec![Column::from_name("a".to_string())],
2422            )?
2423            .filter(col("a").lt_eq(lit(1i64)))?
2424            .build()?;
2425
2426        // not part of the test, just good to know:
2427        assert_snapshot!(plan,
2428        @r"
2429        Filter: test.a <= Int64(1)
2430          Inner Join: Using test.a = test2.a
2431            TableScan: test
2432            Projection: test2.a
2433              TableScan: test2
2434        ",
2435        );
2436        // filter sent to side before the join
2437        assert_optimized_plan_equal!(
2438            plan,
2439            @r"
2440        Inner Join: Using test.a = test2.a
2441          TableScan: test, full_filters=[test.a <= Int64(1)]
2442          Projection: test2.a
2443            TableScan: test2, full_filters=[test2.a <= Int64(1)]
2444        "
2445        )
2446    }
2447
2448    /// post-join predicates with columns from both sides are converted to join filters
2449    #[test]
2450    fn filter_join_on_common_dependent() -> Result<()> {
2451        let table_scan = test_table_scan()?;
2452        let left = LogicalPlanBuilder::from(table_scan)
2453            .project(vec![col("a"), col("c")])?
2454            .build()?;
2455        let right_table_scan = test_table_scan_with_name("test2")?;
2456        let right = LogicalPlanBuilder::from(right_table_scan)
2457            .project(vec![col("a"), col("b")])?
2458            .build()?;
2459        let plan = LogicalPlanBuilder::from(left)
2460            .join(
2461                right,
2462                JoinType::Inner,
2463                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2464                None,
2465            )?
2466            .filter(col("c").lt_eq(col("b")))?
2467            .build()?;
2468
2469        // not part of the test, just good to know:
2470        assert_snapshot!(plan,
2471        @r"
2472        Filter: test.c <= test2.b
2473          Inner Join: test.a = test2.a
2474            Projection: test.a, test.c
2475              TableScan: test
2476            Projection: test2.a, test2.b
2477              TableScan: test2
2478        ",
2479        );
2480        // Filter is converted to Join Filter
2481        assert_optimized_plan_equal!(
2482            plan,
2483            @r"
2484        Inner Join: test.a = test2.a Filter: test.c <= test2.b
2485          Projection: test.a, test.c
2486            TableScan: test
2487          Projection: test2.a, test2.b
2488            TableScan: test2
2489        "
2490        )
2491    }
2492
2493    /// post-join predicates with columns from one side of a join are pushed only to that side
2494    #[test]
2495    fn filter_join_on_one_side() -> Result<()> {
2496        let table_scan = test_table_scan()?;
2497        let left = LogicalPlanBuilder::from(table_scan)
2498            .project(vec![col("a"), col("b")])?
2499            .build()?;
2500        let table_scan_right = test_table_scan_with_name("test2")?;
2501        let right = LogicalPlanBuilder::from(table_scan_right)
2502            .project(vec![col("a"), col("c")])?
2503            .build()?;
2504
2505        let plan = LogicalPlanBuilder::from(left)
2506            .join(
2507                right,
2508                JoinType::Inner,
2509                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2510                None,
2511            )?
2512            .filter(col("b").lt_eq(lit(1i64)))?
2513            .build()?;
2514
2515        // not part of the test, just good to know:
2516        assert_snapshot!(plan,
2517        @r"
2518        Filter: test.b <= Int64(1)
2519          Inner Join: test.a = test2.a
2520            Projection: test.a, test.b
2521              TableScan: test
2522            Projection: test2.a, test2.c
2523              TableScan: test2
2524        ",
2525        );
2526        assert_optimized_plan_equal!(
2527            plan,
2528            @r"
2529        Inner Join: test.a = test2.a
2530          Projection: test.a, test.b
2531            TableScan: test, full_filters=[test.b <= Int64(1)]
2532          Projection: test2.a, test2.c
2533            TableScan: test2
2534        "
2535        )
2536    }
2537
2538    /// post-join predicates on the right side of a left join are not duplicated
2539    /// TODO: In this case we can sometimes convert the join to an INNER join
2540    #[test]
2541    fn filter_using_left_join() -> Result<()> {
2542        let table_scan = test_table_scan()?;
2543        let left = LogicalPlanBuilder::from(table_scan).build()?;
2544        let right_table_scan = test_table_scan_with_name("test2")?;
2545        let right = LogicalPlanBuilder::from(right_table_scan)
2546            .project(vec![col("a")])?
2547            .build()?;
2548        let plan = LogicalPlanBuilder::from(left)
2549            .join_using(
2550                right,
2551                JoinType::Left,
2552                vec![Column::from_name("a".to_string())],
2553            )?
2554            .filter(col("test2.a").lt_eq(lit(1i64)))?
2555            .build()?;
2556
2557        // not part of the test, just good to know:
2558        assert_snapshot!(plan,
2559        @r"
2560        Filter: test2.a <= Int64(1)
2561          Left Join: Using test.a = test2.a
2562            TableScan: test
2563            Projection: test2.a
2564              TableScan: test2
2565        ",
2566        );
2567        // filter not duplicated nor pushed down - i.e. noop
2568        assert_optimized_plan_equal!(
2569            plan,
2570            @r"
2571        Filter: test2.a <= Int64(1)
2572          Left Join: Using test.a = test2.a
2573            TableScan: test, full_filters=[test.a <= Int64(1)]
2574            Projection: test2.a
2575              TableScan: test2
2576        "
2577        )
2578    }
2579
2580    /// post-join predicates on the left side of a right join are not duplicated
2581    #[test]
2582    fn filter_using_right_join() -> Result<()> {
2583        let table_scan = test_table_scan()?;
2584        let left = LogicalPlanBuilder::from(table_scan).build()?;
2585        let right_table_scan = test_table_scan_with_name("test2")?;
2586        let right = LogicalPlanBuilder::from(right_table_scan)
2587            .project(vec![col("a")])?
2588            .build()?;
2589        let plan = LogicalPlanBuilder::from(left)
2590            .join_using(
2591                right,
2592                JoinType::Right,
2593                vec![Column::from_name("a".to_string())],
2594            )?
2595            .filter(col("test.a").lt_eq(lit(1i64)))?
2596            .build()?;
2597
2598        // not part of the test, just good to know:
2599        assert_snapshot!(plan,
2600        @r"
2601        Filter: test.a <= Int64(1)
2602          Right Join: Using test.a = test2.a
2603            TableScan: test
2604            Projection: test2.a
2605              TableScan: test2
2606        ",
2607        );
2608        // filter not duplicated nor pushed down - i.e. noop
2609        assert_optimized_plan_equal!(
2610            plan,
2611            @r"
2612        Filter: test.a <= Int64(1)
2613          Right Join: Using test.a = test2.a
2614            TableScan: test
2615            Projection: test2.a
2616              TableScan: test2, full_filters=[test2.a <= Int64(1)]
2617        "
2618        )
2619    }
2620
2621    /// post-left-join predicate on a column common to both sides is only pushed to the left side
2622    /// i.e. - not duplicated to the right side
2623    #[test]
2624    fn filter_using_left_join_on_common() -> Result<()> {
2625        let table_scan = test_table_scan()?;
2626        let left = LogicalPlanBuilder::from(table_scan).build()?;
2627        let right_table_scan = test_table_scan_with_name("test2")?;
2628        let right = LogicalPlanBuilder::from(right_table_scan)
2629            .project(vec![col("a")])?
2630            .build()?;
2631        let plan = LogicalPlanBuilder::from(left)
2632            .join_using(
2633                right,
2634                JoinType::Left,
2635                vec![Column::from_name("a".to_string())],
2636            )?
2637            .filter(col("a").lt_eq(lit(1i64)))?
2638            .build()?;
2639
2640        // not part of the test, just good to know:
2641        assert_snapshot!(plan,
2642        @r"
2643        Filter: test.a <= Int64(1)
2644          Left Join: Using test.a = test2.a
2645            TableScan: test
2646            Projection: test2.a
2647              TableScan: test2
2648        ",
2649        );
2650        // filter sent to left side of the join, not the right
2651        assert_optimized_plan_equal!(
2652            plan,
2653            @r"
2654        Left Join: Using test.a = test2.a
2655          TableScan: test, full_filters=[test.a <= Int64(1)]
2656          Projection: test2.a
2657            TableScan: test2
2658        "
2659        )
2660    }
2661
2662    /// post-right-join predicate on a column common to both sides is only pushed to the right side
2663    /// i.e. - not duplicated to the left side.
2664    #[test]
2665    fn filter_using_right_join_on_common() -> Result<()> {
2666        let table_scan = test_table_scan()?;
2667        let left = LogicalPlanBuilder::from(table_scan).build()?;
2668        let right_table_scan = test_table_scan_with_name("test2")?;
2669        let right = LogicalPlanBuilder::from(right_table_scan)
2670            .project(vec![col("a")])?
2671            .build()?;
2672        let plan = LogicalPlanBuilder::from(left)
2673            .join_using(
2674                right,
2675                JoinType::Right,
2676                vec![Column::from_name("a".to_string())],
2677            )?
2678            .filter(col("test2.a").lt_eq(lit(1i64)))?
2679            .build()?;
2680
2681        // not part of the test, just good to know:
2682        assert_snapshot!(plan,
2683        @r"
2684        Filter: test2.a <= Int64(1)
2685          Right Join: Using test.a = test2.a
2686            TableScan: test
2687            Projection: test2.a
2688              TableScan: test2
2689        ",
2690        );
2691        // filter sent to right side of join, not duplicated to the left
2692        assert_optimized_plan_equal!(
2693            plan,
2694            @r"
2695        Right Join: Using test.a = test2.a
2696          TableScan: test
2697          Projection: test2.a
2698            TableScan: test2, full_filters=[test2.a <= Int64(1)]
2699        "
2700        )
2701    }
2702
2703    /// single table predicate parts of ON condition should be pushed to both inputs
2704    #[test]
2705    fn join_on_with_filter() -> Result<()> {
2706        let table_scan = test_table_scan()?;
2707        let left = LogicalPlanBuilder::from(table_scan)
2708            .project(vec![col("a"), col("b"), col("c")])?
2709            .build()?;
2710        let right_table_scan = test_table_scan_with_name("test2")?;
2711        let right = LogicalPlanBuilder::from(right_table_scan)
2712            .project(vec![col("a"), col("b"), col("c")])?
2713            .build()?;
2714        let filter = col("test.c")
2715            .gt(lit(1u32))
2716            .and(col("test.b").lt(col("test2.b")))
2717            .and(col("test2.c").gt(lit(4u32)));
2718        let plan = LogicalPlanBuilder::from(left)
2719            .join(
2720                right,
2721                JoinType::Inner,
2722                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2723                Some(filter),
2724            )?
2725            .build()?;
2726
2727        // not part of the test, just good to know:
2728        assert_snapshot!(plan,
2729        @r"
2730        Inner Join: test.a = test2.a Filter: test.c > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
2731          Projection: test.a, test.b, test.c
2732            TableScan: test
2733          Projection: test2.a, test2.b, test2.c
2734            TableScan: test2
2735        ",
2736        );
2737        assert_optimized_plan_equal!(
2738            plan,
2739            @r"
2740        Inner Join: test.a = test2.a Filter: test.b < test2.b
2741          Projection: test.a, test.b, test.c
2742            TableScan: test, full_filters=[test.c > UInt32(1)]
2743          Projection: test2.a, test2.b, test2.c
2744            TableScan: test2, full_filters=[test2.c > UInt32(4)]
2745        "
2746        )
2747    }
2748
2749    /// join filter should be completely removed after pushdown
2750    #[test]
2751    fn join_filter_removed() -> Result<()> {
2752        let table_scan = test_table_scan()?;
2753        let left = LogicalPlanBuilder::from(table_scan)
2754            .project(vec![col("a"), col("b"), col("c")])?
2755            .build()?;
2756        let right_table_scan = test_table_scan_with_name("test2")?;
2757        let right = LogicalPlanBuilder::from(right_table_scan)
2758            .project(vec![col("a"), col("b"), col("c")])?
2759            .build()?;
2760        let filter = col("test.b")
2761            .gt(lit(1u32))
2762            .and(col("test2.c").gt(lit(4u32)));
2763        let plan = LogicalPlanBuilder::from(left)
2764            .join(
2765                right,
2766                JoinType::Inner,
2767                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2768                Some(filter),
2769            )?
2770            .build()?;
2771
2772        // not part of the test, just good to know:
2773        assert_snapshot!(plan,
2774        @r"
2775        Inner Join: test.a = test2.a Filter: test.b > UInt32(1) AND test2.c > UInt32(4)
2776          Projection: test.a, test.b, test.c
2777            TableScan: test
2778          Projection: test2.a, test2.b, test2.c
2779            TableScan: test2
2780        ",
2781        );
2782        assert_optimized_plan_equal!(
2783            plan,
2784            @r"
2785        Inner Join: test.a = test2.a
2786          Projection: test.a, test.b, test.c
2787            TableScan: test, full_filters=[test.b > UInt32(1)]
2788          Projection: test2.a, test2.b, test2.c
2789            TableScan: test2, full_filters=[test2.c > UInt32(4)]
2790        "
2791        )
2792    }
2793
2794    /// predicate on join key in filter expression should be pushed down to both inputs
2795    #[test]
2796    fn join_filter_on_common() -> Result<()> {
2797        let table_scan = test_table_scan()?;
2798        let left = LogicalPlanBuilder::from(table_scan)
2799            .project(vec![col("a")])?
2800            .build()?;
2801        let right_table_scan = test_table_scan_with_name("test2")?;
2802        let right = LogicalPlanBuilder::from(right_table_scan)
2803            .project(vec![col("b")])?
2804            .build()?;
2805        let filter = col("test.a").gt(lit(1u32));
2806        let plan = LogicalPlanBuilder::from(left)
2807            .join(
2808                right,
2809                JoinType::Inner,
2810                (vec![Column::from_name("a")], vec![Column::from_name("b")]),
2811                Some(filter),
2812            )?
2813            .build()?;
2814
2815        // not part of the test, just good to know:
2816        assert_snapshot!(plan,
2817        @r"
2818        Inner Join: test.a = test2.b Filter: test.a > UInt32(1)
2819          Projection: test.a
2820            TableScan: test
2821          Projection: test2.b
2822            TableScan: test2
2823        ",
2824        );
2825        assert_optimized_plan_equal!(
2826            plan,
2827            @r"
2828        Inner Join: test.a = test2.b
2829          Projection: test.a
2830            TableScan: test, full_filters=[test.a > UInt32(1)]
2831          Projection: test2.b
2832            TableScan: test2, full_filters=[test2.b > UInt32(1)]
2833        "
2834        )
2835    }
2836
2837    /// single table predicate parts of ON condition should be pushed to right input
2838    #[test]
2839    fn left_join_on_with_filter() -> Result<()> {
2840        let table_scan = test_table_scan()?;
2841        let left = LogicalPlanBuilder::from(table_scan)
2842            .project(vec![col("a"), col("b"), col("c")])?
2843            .build()?;
2844        let right_table_scan = test_table_scan_with_name("test2")?;
2845        let right = LogicalPlanBuilder::from(right_table_scan)
2846            .project(vec![col("a"), col("b"), col("c")])?
2847            .build()?;
2848        let filter = col("test.a")
2849            .gt(lit(1u32))
2850            .and(col("test.b").lt(col("test2.b")))
2851            .and(col("test2.c").gt(lit(4u32)));
2852        let plan = LogicalPlanBuilder::from(left)
2853            .join(
2854                right,
2855                JoinType::Left,
2856                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2857                Some(filter),
2858            )?
2859            .build()?;
2860
2861        // not part of the test, just good to know:
2862        assert_snapshot!(plan,
2863        @r"
2864        Left Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
2865          Projection: test.a, test.b, test.c
2866            TableScan: test
2867          Projection: test2.a, test2.b, test2.c
2868            TableScan: test2
2869        ",
2870        );
2871        assert_optimized_plan_equal!(
2872            plan,
2873            @r"
2874        Left Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b
2875          Projection: test.a, test.b, test.c
2876            TableScan: test
2877          Projection: test2.a, test2.b, test2.c
2878            TableScan: test2, full_filters=[test2.c > UInt32(4)]
2879        "
2880        )
2881    }
2882
2883    /// single table predicate parts of ON condition should be pushed to left input
2884    #[test]
2885    fn right_join_on_with_filter() -> Result<()> {
2886        let table_scan = test_table_scan()?;
2887        let left = LogicalPlanBuilder::from(table_scan)
2888            .project(vec![col("a"), col("b"), col("c")])?
2889            .build()?;
2890        let right_table_scan = test_table_scan_with_name("test2")?;
2891        let right = LogicalPlanBuilder::from(right_table_scan)
2892            .project(vec![col("a"), col("b"), col("c")])?
2893            .build()?;
2894        let filter = col("test.a")
2895            .gt(lit(1u32))
2896            .and(col("test.b").lt(col("test2.b")))
2897            .and(col("test2.c").gt(lit(4u32)));
2898        let plan = LogicalPlanBuilder::from(left)
2899            .join(
2900                right,
2901                JoinType::Right,
2902                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2903                Some(filter),
2904            )?
2905            .build()?;
2906
2907        // not part of the test, just good to know:
2908        assert_snapshot!(plan,
2909        @r"
2910        Right Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
2911          Projection: test.a, test.b, test.c
2912            TableScan: test
2913          Projection: test2.a, test2.b, test2.c
2914            TableScan: test2
2915        ",
2916        );
2917        assert_optimized_plan_equal!(
2918            plan,
2919            @r"
2920        Right Join: test.a = test2.a Filter: test.b < test2.b AND test2.c > UInt32(4)
2921          Projection: test.a, test.b, test.c
2922            TableScan: test, full_filters=[test.a > UInt32(1)]
2923          Projection: test2.a, test2.b, test2.c
2924            TableScan: test2
2925        "
2926        )
2927    }
2928
2929    /// single table predicate parts of ON condition should not be pushed
2930    #[test]
2931    fn full_join_on_with_filter() -> Result<()> {
2932        let table_scan = test_table_scan()?;
2933        let left = LogicalPlanBuilder::from(table_scan)
2934            .project(vec![col("a"), col("b"), col("c")])?
2935            .build()?;
2936        let right_table_scan = test_table_scan_with_name("test2")?;
2937        let right = LogicalPlanBuilder::from(right_table_scan)
2938            .project(vec![col("a"), col("b"), col("c")])?
2939            .build()?;
2940        let filter = col("test.a")
2941            .gt(lit(1u32))
2942            .and(col("test.b").lt(col("test2.b")))
2943            .and(col("test2.c").gt(lit(4u32)));
2944        let plan = LogicalPlanBuilder::from(left)
2945            .join(
2946                right,
2947                JoinType::Full,
2948                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2949                Some(filter),
2950            )?
2951            .build()?;
2952
2953        // not part of the test, just good to know:
2954        assert_snapshot!(plan,
2955        @r"
2956        Full Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
2957          Projection: test.a, test.b, test.c
2958            TableScan: test
2959          Projection: test2.a, test2.b, test2.c
2960            TableScan: test2
2961        ",
2962        );
2963        assert_optimized_plan_equal!(
2964            plan,
2965            @r"
2966        Full Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
2967          Projection: test.a, test.b, test.c
2968            TableScan: test
2969          Projection: test2.a, test2.b, test2.c
2970            TableScan: test2
2971        "
2972        )
2973    }
2974
2975    struct PushDownProvider {
2976        pub filter_support: TableProviderFilterPushDown,
2977    }
2978
2979    #[async_trait]
2980    impl TableSource for PushDownProvider {
2981        fn schema(&self) -> SchemaRef {
2982            Arc::new(Schema::new(vec![
2983                Field::new("a", DataType::Int32, true),
2984                Field::new("b", DataType::Int32, true),
2985            ]))
2986        }
2987
2988        fn table_type(&self) -> TableType {
2989            TableType::Base
2990        }
2991
2992        fn supports_filters_pushdown(
2993            &self,
2994            filters: &[&Expr],
2995        ) -> Result<Vec<TableProviderFilterPushDown>> {
2996            Ok((0..filters.len())
2997                .map(|_| self.filter_support.clone())
2998                .collect())
2999        }
3000
3001        fn as_any(&self) -> &dyn Any {
3002            self
3003        }
3004    }
3005
3006    fn table_scan_with_pushdown_provider_builder(
3007        filter_support: TableProviderFilterPushDown,
3008        filters: Vec<Expr>,
3009        projection: Option<Vec<usize>>,
3010    ) -> Result<LogicalPlanBuilder> {
3011        let test_provider = PushDownProvider { filter_support };
3012
3013        let table_scan = LogicalPlan::TableScan(TableScan {
3014            table_name: "test".into(),
3015            filters,
3016            projected_schema: Arc::new(DFSchema::try_from(
3017                (*test_provider.schema()).clone(),
3018            )?),
3019            projection,
3020            source: Arc::new(test_provider),
3021            fetch: None,
3022        });
3023
3024        Ok(LogicalPlanBuilder::from(table_scan))
3025    }
3026
3027    fn table_scan_with_pushdown_provider(
3028        filter_support: TableProviderFilterPushDown,
3029    ) -> Result<LogicalPlan> {
3030        table_scan_with_pushdown_provider_builder(filter_support, vec![], None)?
3031            .filter(col("a").eq(lit(1i64)))?
3032            .build()
3033    }
3034
3035    #[test]
3036    fn filter_with_table_provider_exact() -> Result<()> {
3037        let plan = table_scan_with_pushdown_provider(TableProviderFilterPushDown::Exact)?;
3038
3039        assert_optimized_plan_equal!(
3040            plan,
3041            @"TableScan: test, full_filters=[a = Int64(1)]"
3042        )
3043    }
3044
3045    #[test]
3046    fn filter_with_table_provider_inexact() -> Result<()> {
3047        let plan =
3048            table_scan_with_pushdown_provider(TableProviderFilterPushDown::Inexact)?;
3049
3050        assert_optimized_plan_equal!(
3051            plan,
3052            @r"
3053        Filter: a = Int64(1)
3054          TableScan: test, partial_filters=[a = Int64(1)]
3055        "
3056        )
3057    }
3058
3059    #[test]
3060    fn filter_with_table_provider_multiple_invocations() -> Result<()> {
3061        let plan =
3062            table_scan_with_pushdown_provider(TableProviderFilterPushDown::Inexact)?;
3063
3064        let optimized_plan = PushDownFilter::new()
3065            .rewrite(plan, &OptimizerContext::new())
3066            .expect("failed to optimize plan")
3067            .data;
3068
3069        // Optimizing the same plan multiple times should produce the same plan
3070        // each time.
3071        assert_optimized_plan_equal!(
3072            optimized_plan,
3073            @r"
3074        Filter: a = Int64(1)
3075          TableScan: test, partial_filters=[a = Int64(1)]
3076        "
3077        )
3078    }
3079
3080    #[test]
3081    fn filter_with_table_provider_unsupported() -> Result<()> {
3082        let plan =
3083            table_scan_with_pushdown_provider(TableProviderFilterPushDown::Unsupported)?;
3084
3085        assert_optimized_plan_equal!(
3086            plan,
3087            @r"
3088        Filter: a = Int64(1)
3089          TableScan: test
3090        "
3091        )
3092    }
3093
3094    #[test]
3095    fn multi_combined_filter() -> Result<()> {
3096        let plan = table_scan_with_pushdown_provider_builder(
3097            TableProviderFilterPushDown::Inexact,
3098            vec![col("a").eq(lit(10i64)), col("b").gt(lit(11i64))],
3099            Some(vec![0]),
3100        )?
3101        .filter(and(col("a").eq(lit(10i64)), col("b").gt(lit(11i64))))?
3102        .project(vec![col("a"), col("b")])?
3103        .build()?;
3104
3105        assert_optimized_plan_equal!(
3106            plan,
3107            @r"
3108        Projection: a, b
3109          Filter: a = Int64(10) AND b > Int64(11)
3110            TableScan: test projection=[a], partial_filters=[a = Int64(10), b > Int64(11)]
3111        "
3112        )
3113    }
3114
3115    #[test]
3116    fn multi_combined_filter_exact() -> Result<()> {
3117        let plan = table_scan_with_pushdown_provider_builder(
3118            TableProviderFilterPushDown::Exact,
3119            vec![],
3120            Some(vec![0]),
3121        )?
3122        .filter(and(col("a").eq(lit(10i64)), col("b").gt(lit(11i64))))?
3123        .project(vec![col("a"), col("b")])?
3124        .build()?;
3125
3126        assert_optimized_plan_equal!(
3127            plan,
3128            @r"
3129        Projection: a, b
3130          TableScan: test projection=[a], full_filters=[a = Int64(10), b > Int64(11)]
3131        "
3132        )
3133    }
3134
3135    #[test]
3136    fn test_filter_with_alias() -> Result<()> {
3137        // in table scan the true col name is 'test.a',
3138        // but we rename it as 'b', and use col 'b' in filter
3139        // we need rewrite filter col before push down.
3140        let table_scan = test_table_scan()?;
3141        let plan = LogicalPlanBuilder::from(table_scan)
3142            .project(vec![col("a").alias("b"), col("c")])?
3143            .filter(and(col("b").gt(lit(10i64)), col("c").gt(lit(10i64))))?
3144            .build()?;
3145
3146        // filter on col b
3147        assert_snapshot!(plan,
3148        @r"
3149        Filter: b > Int64(10) AND test.c > Int64(10)
3150          Projection: test.a AS b, test.c
3151            TableScan: test
3152        ",
3153        );
3154        // rewrite filter col b to test.a
3155        assert_optimized_plan_equal!(
3156            plan,
3157            @r"
3158        Projection: test.a AS b, test.c
3159          TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]
3160        "
3161        )
3162    }
3163
3164    #[test]
3165    fn test_filter_with_alias_2() -> Result<()> {
3166        // in table scan the true col name is 'test.a',
3167        // but we rename it as 'b', and use col 'b' in filter
3168        // we need rewrite filter col before push down.
3169        let table_scan = test_table_scan()?;
3170        let plan = LogicalPlanBuilder::from(table_scan)
3171            .project(vec![col("a").alias("b"), col("c")])?
3172            .project(vec![col("b"), col("c")])?
3173            .filter(and(col("b").gt(lit(10i64)), col("c").gt(lit(10i64))))?
3174            .build()?;
3175
3176        // filter on col b
3177        assert_snapshot!(plan,
3178        @r"
3179        Filter: b > Int64(10) AND test.c > Int64(10)
3180          Projection: b, test.c
3181            Projection: test.a AS b, test.c
3182              TableScan: test
3183        ",
3184        );
3185        // rewrite filter col b to test.a
3186        assert_optimized_plan_equal!(
3187            plan,
3188            @r"
3189        Projection: b, test.c
3190          Projection: test.a AS b, test.c
3191            TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]
3192        "
3193        )
3194    }
3195
3196    #[test]
3197    fn test_filter_with_multi_alias() -> Result<()> {
3198        let table_scan = test_table_scan()?;
3199        let plan = LogicalPlanBuilder::from(table_scan)
3200            .project(vec![col("a").alias("b"), col("c").alias("d")])?
3201            .filter(and(col("b").gt(lit(10i64)), col("d").gt(lit(10i64))))?
3202            .build()?;
3203
3204        // filter on col b and d
3205        assert_snapshot!(plan,
3206        @r"
3207        Filter: b > Int64(10) AND d > Int64(10)
3208          Projection: test.a AS b, test.c AS d
3209            TableScan: test
3210        ",
3211        );
3212        // rewrite filter col b to test.a, col d to test.c
3213        assert_optimized_plan_equal!(
3214            plan,
3215            @r"
3216        Projection: test.a AS b, test.c AS d
3217          TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]
3218        "
3219        )
3220    }
3221
3222    /// predicate on join key in filter expression should be pushed down to both inputs
3223    #[test]
3224    fn join_filter_with_alias() -> Result<()> {
3225        let table_scan = test_table_scan()?;
3226        let left = LogicalPlanBuilder::from(table_scan)
3227            .project(vec![col("a").alias("c")])?
3228            .build()?;
3229        let right_table_scan = test_table_scan_with_name("test2")?;
3230        let right = LogicalPlanBuilder::from(right_table_scan)
3231            .project(vec![col("b").alias("d")])?
3232            .build()?;
3233        let filter = col("c").gt(lit(1u32));
3234        let plan = LogicalPlanBuilder::from(left)
3235            .join(
3236                right,
3237                JoinType::Inner,
3238                (vec![Column::from_name("c")], vec![Column::from_name("d")]),
3239                Some(filter),
3240            )?
3241            .build()?;
3242
3243        assert_snapshot!(plan,
3244        @r"
3245        Inner Join: c = d Filter: c > UInt32(1)
3246          Projection: test.a AS c
3247            TableScan: test
3248          Projection: test2.b AS d
3249            TableScan: test2
3250        ",
3251        );
3252        // Change filter on col `c`, 'd' to `test.a`, 'test.b'
3253        assert_optimized_plan_equal!(
3254            plan,
3255            @r"
3256        Inner Join: c = d
3257          Projection: test.a AS c
3258            TableScan: test, full_filters=[test.a > UInt32(1)]
3259          Projection: test2.b AS d
3260            TableScan: test2, full_filters=[test2.b > UInt32(1)]
3261        "
3262        )
3263    }
3264
3265    #[test]
3266    fn test_in_filter_with_alias() -> Result<()> {
3267        // in table scan the true col name is 'test.a',
3268        // but we rename it as 'b', and use col 'b' in filter
3269        // we need rewrite filter col before push down.
3270        let table_scan = test_table_scan()?;
3271        let filter_value = vec![lit(1u32), lit(2u32), lit(3u32), lit(4u32)];
3272        let plan = LogicalPlanBuilder::from(table_scan)
3273            .project(vec![col("a").alias("b"), col("c")])?
3274            .filter(in_list(col("b"), filter_value, false))?
3275            .build()?;
3276
3277        // filter on col b
3278        assert_snapshot!(plan,
3279        @r"
3280        Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])
3281          Projection: test.a AS b, test.c
3282            TableScan: test
3283        ",
3284        );
3285        // rewrite filter col b to test.a
3286        assert_optimized_plan_equal!(
3287            plan,
3288            @r"
3289        Projection: test.a AS b, test.c
3290          TableScan: test, full_filters=[test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])]
3291        "
3292        )
3293    }
3294
3295    #[test]
3296    fn test_in_filter_with_alias_2() -> Result<()> {
3297        // in table scan the true col name is 'test.a',
3298        // but we rename it as 'b', and use col 'b' in filter
3299        // we need rewrite filter col before push down.
3300        let table_scan = test_table_scan()?;
3301        let filter_value = vec![lit(1u32), lit(2u32), lit(3u32), lit(4u32)];
3302        let plan = LogicalPlanBuilder::from(table_scan)
3303            .project(vec![col("a").alias("b"), col("c")])?
3304            .project(vec![col("b"), col("c")])?
3305            .filter(in_list(col("b"), filter_value, false))?
3306            .build()?;
3307
3308        // filter on col b
3309        assert_snapshot!(plan,
3310        @r"
3311        Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])
3312          Projection: b, test.c
3313            Projection: test.a AS b, test.c
3314              TableScan: test
3315        ",
3316        );
3317        // rewrite filter col b to test.a
3318        assert_optimized_plan_equal!(
3319            plan,
3320            @r"
3321        Projection: b, test.c
3322          Projection: test.a AS b, test.c
3323            TableScan: test, full_filters=[test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])]
3324        "
3325        )
3326    }
3327
3328    #[test]
3329    fn test_in_subquery_with_alias() -> Result<()> {
3330        // in table scan the true col name is 'test.a',
3331        // but we rename it as 'b', and use col 'b' in subquery filter
3332        let table_scan = test_table_scan()?;
3333        let table_scan_sq = test_table_scan_with_name("sq")?;
3334        let subplan = Arc::new(
3335            LogicalPlanBuilder::from(table_scan_sq)
3336                .project(vec![col("c")])?
3337                .build()?,
3338        );
3339        let plan = LogicalPlanBuilder::from(table_scan)
3340            .project(vec![col("a").alias("b"), col("c")])?
3341            .filter(in_subquery(col("b"), subplan))?
3342            .build()?;
3343
3344        // filter on col b in subquery
3345        assert_snapshot!(plan,
3346        @r"
3347        Filter: b IN (<subquery>)
3348          Subquery:
3349            Projection: sq.c
3350              TableScan: sq
3351          Projection: test.a AS b, test.c
3352            TableScan: test
3353        ",
3354        );
3355        // rewrite filter col b to test.a
3356        assert_optimized_plan_equal!(
3357            plan,
3358            @r"
3359        Projection: test.a AS b, test.c
3360          TableScan: test, full_filters=[test.a IN (<subquery>)]
3361            Subquery:
3362              Projection: sq.c
3363                TableScan: sq
3364        "
3365        )
3366    }
3367
3368    #[test]
3369    fn test_propagation_of_optimized_inner_filters_with_projections() -> Result<()> {
3370        // SELECT a FROM (SELECT 1 AS a) b WHERE b.a = 1
3371        let plan = LogicalPlanBuilder::empty(true)
3372            .project(vec![lit(0i64).alias("a")])?
3373            .alias("b")?
3374            .project(vec![col("b.a")])?
3375            .alias("b")?
3376            .filter(col("b.a").eq(lit(1i64)))?
3377            .project(vec![col("b.a")])?
3378            .build()?;
3379
3380        assert_snapshot!(plan,
3381        @r"
3382        Projection: b.a
3383          Filter: b.a = Int64(1)
3384            SubqueryAlias: b
3385              Projection: b.a
3386                SubqueryAlias: b
3387                  Projection: Int64(0) AS a
3388                    EmptyRelation
3389        ",
3390        );
3391        // Ensure that the predicate without any columns (0 = 1) is
3392        // still there.
3393        assert_optimized_plan_equal!(
3394            plan,
3395            @r"
3396        Projection: b.a
3397          SubqueryAlias: b
3398            Projection: b.a
3399              SubqueryAlias: b
3400                Projection: Int64(0) AS a
3401                  Filter: Int64(0) = Int64(1)
3402                    EmptyRelation
3403        "
3404        )
3405    }
3406
3407    #[test]
3408    fn test_crossjoin_with_or_clause() -> Result<()> {
3409        // select * from test,test1 where (test.a = test1.a and test.b > 1) or (test.b = test1.b and test.c < 10);
3410        let table_scan = test_table_scan()?;
3411        let left = LogicalPlanBuilder::from(table_scan)
3412            .project(vec![col("a"), col("b"), col("c")])?
3413            .build()?;
3414        let right_table_scan = test_table_scan_with_name("test1")?;
3415        let right = LogicalPlanBuilder::from(right_table_scan)
3416            .project(vec![col("a").alias("d"), col("a").alias("e")])?
3417            .build()?;
3418        let filter = or(
3419            and(col("a").eq(col("d")), col("b").gt(lit(1u32))),
3420            and(col("b").eq(col("e")), col("c").lt(lit(10u32))),
3421        );
3422        let plan = LogicalPlanBuilder::from(left)
3423            .cross_join(right)?
3424            .filter(filter)?
3425            .build()?;
3426
3427        assert_optimized_plan_eq_with_rewrite_predicate!(plan.clone(), @r"
3428        Inner Join:  Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10)
3429          Projection: test.a, test.b, test.c
3430            TableScan: test, full_filters=[test.b > UInt32(1) OR test.c < UInt32(10)]
3431          Projection: test1.a AS d, test1.a AS e
3432            TableScan: test1
3433        ")?;
3434
3435        // Originally global state which can help to avoid duplicate Filters been generated and pushed down.
3436        // Now the global state is removed. Need to double confirm that avoid duplicate Filters.
3437        let optimized_plan = PushDownFilter::new()
3438            .rewrite(plan, &OptimizerContext::new())
3439            .expect("failed to optimize plan")
3440            .data;
3441        assert_optimized_plan_equal!(
3442            optimized_plan,
3443            @r"
3444        Inner Join:  Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10)
3445          Projection: test.a, test.b, test.c
3446            TableScan: test, full_filters=[test.b > UInt32(1) OR test.c < UInt32(10)]
3447          Projection: test1.a AS d, test1.a AS e
3448            TableScan: test1
3449        "
3450        )
3451    }
3452
3453    #[test]
3454    fn left_semi_join() -> Result<()> {
3455        let left = test_table_scan_with_name("test1")?;
3456        let right_table_scan = test_table_scan_with_name("test2")?;
3457        let right = LogicalPlanBuilder::from(right_table_scan)
3458            .project(vec![col("a"), col("b")])?
3459            .build()?;
3460        let plan = LogicalPlanBuilder::from(left)
3461            .join(
3462                right,
3463                JoinType::LeftSemi,
3464                (
3465                    vec![Column::from_qualified_name("test1.a")],
3466                    vec![Column::from_qualified_name("test2.a")],
3467                ),
3468                None,
3469            )?
3470            .filter(col("test2.a").lt_eq(lit(1i64)))?
3471            .build()?;
3472
3473        // not part of the test, just good to know:
3474        assert_snapshot!(plan,
3475        @r"
3476        Filter: test2.a <= Int64(1)
3477          LeftSemi Join: test1.a = test2.a
3478            TableScan: test1
3479            Projection: test2.a, test2.b
3480              TableScan: test2
3481        ",
3482        );
3483        // Inferred the predicate `test1.a <= Int64(1)` and push it down to the left side.
3484        assert_optimized_plan_equal!(
3485            plan,
3486            @r"
3487        Filter: test2.a <= Int64(1)
3488          LeftSemi Join: test1.a = test2.a
3489            TableScan: test1, full_filters=[test1.a <= Int64(1)]
3490            Projection: test2.a, test2.b
3491              TableScan: test2
3492        "
3493        )
3494    }
3495
3496    #[test]
3497    fn left_semi_join_with_filters() -> Result<()> {
3498        let left = test_table_scan_with_name("test1")?;
3499        let right_table_scan = test_table_scan_with_name("test2")?;
3500        let right = LogicalPlanBuilder::from(right_table_scan)
3501            .project(vec![col("a"), col("b")])?
3502            .build()?;
3503        let plan = LogicalPlanBuilder::from(left)
3504            .join(
3505                right,
3506                JoinType::LeftSemi,
3507                (
3508                    vec![Column::from_qualified_name("test1.a")],
3509                    vec![Column::from_qualified_name("test2.a")],
3510                ),
3511                Some(
3512                    col("test1.b")
3513                        .gt(lit(1u32))
3514                        .and(col("test2.b").gt(lit(2u32))),
3515                ),
3516            )?
3517            .build()?;
3518
3519        // not part of the test, just good to know:
3520        assert_snapshot!(plan,
3521        @r"
3522        LeftSemi Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
3523          TableScan: test1
3524          Projection: test2.a, test2.b
3525            TableScan: test2
3526        ",
3527        );
3528        // Both side will be pushed down.
3529        assert_optimized_plan_equal!(
3530            plan,
3531            @r"
3532        LeftSemi Join: test1.a = test2.a
3533          TableScan: test1, full_filters=[test1.b > UInt32(1)]
3534          Projection: test2.a, test2.b
3535            TableScan: test2, full_filters=[test2.b > UInt32(2)]
3536        "
3537        )
3538    }
3539
3540    #[test]
3541    fn right_semi_join() -> Result<()> {
3542        let left = test_table_scan_with_name("test1")?;
3543        let right_table_scan = test_table_scan_with_name("test2")?;
3544        let right = LogicalPlanBuilder::from(right_table_scan)
3545            .project(vec![col("a"), col("b")])?
3546            .build()?;
3547        let plan = LogicalPlanBuilder::from(left)
3548            .join(
3549                right,
3550                JoinType::RightSemi,
3551                (
3552                    vec![Column::from_qualified_name("test1.a")],
3553                    vec![Column::from_qualified_name("test2.a")],
3554                ),
3555                None,
3556            )?
3557            .filter(col("test1.a").lt_eq(lit(1i64)))?
3558            .build()?;
3559
3560        // not part of the test, just good to know:
3561        assert_snapshot!(plan,
3562        @r"
3563        Filter: test1.a <= Int64(1)
3564          RightSemi Join: test1.a = test2.a
3565            TableScan: test1
3566            Projection: test2.a, test2.b
3567              TableScan: test2
3568        ",
3569        );
3570        // Inferred the predicate `test2.a <= Int64(1)` and push it down to the right side.
3571        assert_optimized_plan_equal!(
3572            plan,
3573            @r"
3574        Filter: test1.a <= Int64(1)
3575          RightSemi Join: test1.a = test2.a
3576            TableScan: test1
3577            Projection: test2.a, test2.b
3578              TableScan: test2, full_filters=[test2.a <= Int64(1)]
3579        "
3580        )
3581    }
3582
3583    #[test]
3584    fn right_semi_join_with_filters() -> Result<()> {
3585        let left = test_table_scan_with_name("test1")?;
3586        let right_table_scan = test_table_scan_with_name("test2")?;
3587        let right = LogicalPlanBuilder::from(right_table_scan)
3588            .project(vec![col("a"), col("b")])?
3589            .build()?;
3590        let plan = LogicalPlanBuilder::from(left)
3591            .join(
3592                right,
3593                JoinType::RightSemi,
3594                (
3595                    vec![Column::from_qualified_name("test1.a")],
3596                    vec![Column::from_qualified_name("test2.a")],
3597                ),
3598                Some(
3599                    col("test1.b")
3600                        .gt(lit(1u32))
3601                        .and(col("test2.b").gt(lit(2u32))),
3602                ),
3603            )?
3604            .build()?;
3605
3606        // not part of the test, just good to know:
3607        assert_snapshot!(plan,
3608        @r"
3609        RightSemi Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
3610          TableScan: test1
3611          Projection: test2.a, test2.b
3612            TableScan: test2
3613        ",
3614        );
3615        // Both side will be pushed down.
3616        assert_optimized_plan_equal!(
3617            plan,
3618            @r"
3619        RightSemi Join: test1.a = test2.a
3620          TableScan: test1, full_filters=[test1.b > UInt32(1)]
3621          Projection: test2.a, test2.b
3622            TableScan: test2, full_filters=[test2.b > UInt32(2)]
3623        "
3624        )
3625    }
3626
3627    #[test]
3628    fn left_anti_join() -> Result<()> {
3629        let table_scan = test_table_scan_with_name("test1")?;
3630        let left = LogicalPlanBuilder::from(table_scan)
3631            .project(vec![col("a"), col("b")])?
3632            .build()?;
3633        let right_table_scan = test_table_scan_with_name("test2")?;
3634        let right = LogicalPlanBuilder::from(right_table_scan)
3635            .project(vec![col("a"), col("b")])?
3636            .build()?;
3637        let plan = LogicalPlanBuilder::from(left)
3638            .join(
3639                right,
3640                JoinType::LeftAnti,
3641                (
3642                    vec![Column::from_qualified_name("test1.a")],
3643                    vec![Column::from_qualified_name("test2.a")],
3644                ),
3645                None,
3646            )?
3647            .filter(col("test2.a").gt(lit(2u32)))?
3648            .build()?;
3649
3650        // not part of the test, just good to know:
3651        assert_snapshot!(plan,
3652        @r"
3653        Filter: test2.a > UInt32(2)
3654          LeftAnti Join: test1.a = test2.a
3655            Projection: test1.a, test1.b
3656              TableScan: test1
3657            Projection: test2.a, test2.b
3658              TableScan: test2
3659        ",
3660        );
3661        // For left anti, filter of the right side filter can be pushed down.
3662        assert_optimized_plan_equal!(
3663            plan,
3664            @r"
3665        Filter: test2.a > UInt32(2)
3666          LeftAnti Join: test1.a = test2.a
3667            Projection: test1.a, test1.b
3668              TableScan: test1, full_filters=[test1.a > UInt32(2)]
3669            Projection: test2.a, test2.b
3670              TableScan: test2
3671        "
3672        )
3673    }
3674
3675    #[test]
3676    fn left_anti_join_with_filters() -> Result<()> {
3677        let table_scan = test_table_scan_with_name("test1")?;
3678        let left = LogicalPlanBuilder::from(table_scan)
3679            .project(vec![col("a"), col("b")])?
3680            .build()?;
3681        let right_table_scan = test_table_scan_with_name("test2")?;
3682        let right = LogicalPlanBuilder::from(right_table_scan)
3683            .project(vec![col("a"), col("b")])?
3684            .build()?;
3685        let plan = LogicalPlanBuilder::from(left)
3686            .join(
3687                right,
3688                JoinType::LeftAnti,
3689                (
3690                    vec![Column::from_qualified_name("test1.a")],
3691                    vec![Column::from_qualified_name("test2.a")],
3692                ),
3693                Some(
3694                    col("test1.b")
3695                        .gt(lit(1u32))
3696                        .and(col("test2.b").gt(lit(2u32))),
3697                ),
3698            )?
3699            .build()?;
3700
3701        // not part of the test, just good to know:
3702        assert_snapshot!(plan,
3703        @r"
3704        LeftAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
3705          Projection: test1.a, test1.b
3706            TableScan: test1
3707          Projection: test2.a, test2.b
3708            TableScan: test2
3709        ",
3710        );
3711        // For left anti, filter of the right side filter can be pushed down.
3712        assert_optimized_plan_equal!(
3713            plan,
3714            @r"
3715        LeftAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1)
3716          Projection: test1.a, test1.b
3717            TableScan: test1
3718          Projection: test2.a, test2.b
3719            TableScan: test2, full_filters=[test2.b > UInt32(2)]
3720        "
3721        )
3722    }
3723
3724    #[test]
3725    fn right_anti_join() -> Result<()> {
3726        let table_scan = test_table_scan_with_name("test1")?;
3727        let left = LogicalPlanBuilder::from(table_scan)
3728            .project(vec![col("a"), col("b")])?
3729            .build()?;
3730        let right_table_scan = test_table_scan_with_name("test2")?;
3731        let right = LogicalPlanBuilder::from(right_table_scan)
3732            .project(vec![col("a"), col("b")])?
3733            .build()?;
3734        let plan = LogicalPlanBuilder::from(left)
3735            .join(
3736                right,
3737                JoinType::RightAnti,
3738                (
3739                    vec![Column::from_qualified_name("test1.a")],
3740                    vec![Column::from_qualified_name("test2.a")],
3741                ),
3742                None,
3743            )?
3744            .filter(col("test1.a").gt(lit(2u32)))?
3745            .build()?;
3746
3747        // not part of the test, just good to know:
3748        assert_snapshot!(plan,
3749        @r"
3750        Filter: test1.a > UInt32(2)
3751          RightAnti Join: test1.a = test2.a
3752            Projection: test1.a, test1.b
3753              TableScan: test1
3754            Projection: test2.a, test2.b
3755              TableScan: test2
3756        ",
3757        );
3758        // For right anti, filter of the left side can be pushed down.
3759        assert_optimized_plan_equal!(
3760            plan,
3761            @r"
3762        Filter: test1.a > UInt32(2)
3763          RightAnti Join: test1.a = test2.a
3764            Projection: test1.a, test1.b
3765              TableScan: test1
3766            Projection: test2.a, test2.b
3767              TableScan: test2, full_filters=[test2.a > UInt32(2)]
3768        "
3769        )
3770    }
3771
3772    #[test]
3773    fn right_anti_join_with_filters() -> Result<()> {
3774        let table_scan = test_table_scan_with_name("test1")?;
3775        let left = LogicalPlanBuilder::from(table_scan)
3776            .project(vec![col("a"), col("b")])?
3777            .build()?;
3778        let right_table_scan = test_table_scan_with_name("test2")?;
3779        let right = LogicalPlanBuilder::from(right_table_scan)
3780            .project(vec![col("a"), col("b")])?
3781            .build()?;
3782        let plan = LogicalPlanBuilder::from(left)
3783            .join(
3784                right,
3785                JoinType::RightAnti,
3786                (
3787                    vec![Column::from_qualified_name("test1.a")],
3788                    vec![Column::from_qualified_name("test2.a")],
3789                ),
3790                Some(
3791                    col("test1.b")
3792                        .gt(lit(1u32))
3793                        .and(col("test2.b").gt(lit(2u32))),
3794                ),
3795            )?
3796            .build()?;
3797
3798        // not part of the test, just good to know:
3799        assert_snapshot!(plan,
3800        @r"
3801        RightAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
3802          Projection: test1.a, test1.b
3803            TableScan: test1
3804          Projection: test2.a, test2.b
3805            TableScan: test2
3806        ",
3807        );
3808        // For right anti, filter of the left side can be pushed down.
3809        assert_optimized_plan_equal!(
3810            plan,
3811            @r"
3812        RightAnti Join: test1.a = test2.a Filter: test2.b > UInt32(2)
3813          Projection: test1.a, test1.b
3814            TableScan: test1, full_filters=[test1.b > UInt32(1)]
3815          Projection: test2.a, test2.b
3816            TableScan: test2
3817        "
3818        )
3819    }
3820
    /// Minimal scalar UDF for the volatility tests in this module. The caller
    /// supplies its signature, and therefore its volatility. It always reports
    /// an `Int32` return type and always evaluates to the scalar `1`.
    #[derive(Debug)]
    struct TestScalarUDF {
        // Signature reported to the planner; the tests here construct it with
        // `Volatility::Volatile`.
        signature: Signature,
    }
3825
3826    impl ScalarUDFImpl for TestScalarUDF {
3827        fn as_any(&self) -> &dyn Any {
3828            self
3829        }
3830        fn name(&self) -> &str {
3831            "TestScalarUDF"
3832        }
3833
3834        fn signature(&self) -> &Signature {
3835            &self.signature
3836        }
3837
3838        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
3839            Ok(DataType::Int32)
3840        }
3841
3842        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
3843            Ok(ColumnarValue::Scalar(ScalarValue::from(1)))
3844        }
3845    }
3846
3847    #[test]
3848    fn test_push_down_volatile_function_in_aggregate() -> Result<()> {
3849        // SELECT t.a, t.r FROM (SELECT a, sum(b),  TestScalarUDF()+1 AS r FROM test1 GROUP BY a) AS t WHERE t.a > 5 AND t.r > 0.5;
3850        let table_scan = test_table_scan_with_name("test1")?;
3851        let fun = ScalarUDF::new_from_impl(TestScalarUDF {
3852            signature: Signature::exact(vec![], Volatility::Volatile),
3853        });
3854        let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
3855
3856        let plan = LogicalPlanBuilder::from(table_scan)
3857            .aggregate(vec![col("a")], vec![sum(col("b"))])?
3858            .project(vec![col("a"), sum(col("b")), add(expr, lit(1)).alias("r")])?
3859            .alias("t")?
3860            .filter(col("t.a").gt(lit(5)).and(col("t.r").gt(lit(0.5))))?
3861            .project(vec![col("t.a"), col("t.r")])?
3862            .build()?;
3863
3864        assert_snapshot!(plan,
3865        @r"
3866        Projection: t.a, t.r
3867          Filter: t.a > Int32(5) AND t.r > Float64(0.5)
3868            SubqueryAlias: t
3869              Projection: test1.a, sum(test1.b), TestScalarUDF() + Int32(1) AS r
3870                Aggregate: groupBy=[[test1.a]], aggr=[[sum(test1.b)]]
3871                  TableScan: test1
3872        ",
3873        );
3874        assert_optimized_plan_equal!(
3875            plan,
3876            @r"
3877        Projection: t.a, t.r
3878          SubqueryAlias: t
3879            Filter: r > Float64(0.5)
3880              Projection: test1.a, sum(test1.b), TestScalarUDF() + Int32(1) AS r
3881                Aggregate: groupBy=[[test1.a]], aggr=[[sum(test1.b)]]
3882                  TableScan: test1, full_filters=[test1.a > Int32(5)]
3883        "
3884        )
3885    }
3886
3887    #[test]
3888    fn test_push_down_volatile_function_in_join() -> Result<()> {
3889        // SELECT t.a, t.r FROM (SELECT test1.a AS a, TestScalarUDF() AS r FROM test1 join test2 ON test1.a = test2.a) AS t WHERE t.r > 0.5;
3890        let table_scan = test_table_scan_with_name("test1")?;
3891        let fun = ScalarUDF::new_from_impl(TestScalarUDF {
3892            signature: Signature::exact(vec![], Volatility::Volatile),
3893        });
3894        let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
3895        let left = LogicalPlanBuilder::from(table_scan).build()?;
3896        let right_table_scan = test_table_scan_with_name("test2")?;
3897        let right = LogicalPlanBuilder::from(right_table_scan).build()?;
3898        let plan = LogicalPlanBuilder::from(left)
3899            .join(
3900                right,
3901                JoinType::Inner,
3902                (
3903                    vec![Column::from_qualified_name("test1.a")],
3904                    vec![Column::from_qualified_name("test2.a")],
3905                ),
3906                None,
3907            )?
3908            .project(vec![col("test1.a").alias("a"), expr.alias("r")])?
3909            .alias("t")?
3910            .filter(col("t.r").gt(lit(0.8)))?
3911            .project(vec![col("t.a"), col("t.r")])?
3912            .build()?;
3913
3914        assert_snapshot!(plan,
3915        @r"
3916        Projection: t.a, t.r
3917          Filter: t.r > Float64(0.8)
3918            SubqueryAlias: t
3919              Projection: test1.a AS a, TestScalarUDF() AS r
3920                Inner Join: test1.a = test2.a
3921                  TableScan: test1
3922                  TableScan: test2
3923        ",
3924        );
3925        assert_optimized_plan_equal!(
3926            plan,
3927            @r"
3928        Projection: t.a, t.r
3929          SubqueryAlias: t
3930            Filter: r > Float64(0.8)
3931              Projection: test1.a AS a, TestScalarUDF() AS r
3932                Inner Join: test1.a = test2.a
3933                  TableScan: test1
3934                  TableScan: test2
3935        "
3936        )
3937    }
3938
3939    #[test]
3940    fn test_push_down_volatile_table_scan() -> Result<()> {
3941        // SELECT test.a, test.b FROM test as t WHERE TestScalarUDF() > 0.1;
3942        let table_scan = test_table_scan()?;
3943        let fun = ScalarUDF::new_from_impl(TestScalarUDF {
3944            signature: Signature::exact(vec![], Volatility::Volatile),
3945        });
3946        let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
3947        let plan = LogicalPlanBuilder::from(table_scan)
3948            .project(vec![col("a"), col("b")])?
3949            .filter(expr.gt(lit(0.1)))?
3950            .build()?;
3951
3952        assert_snapshot!(plan,
3953        @r"
3954        Filter: TestScalarUDF() > Float64(0.1)
3955          Projection: test.a, test.b
3956            TableScan: test
3957        ",
3958        );
3959        assert_optimized_plan_equal!(
3960            plan,
3961            @r"
3962        Projection: test.a, test.b
3963          Filter: TestScalarUDF() > Float64(0.1)
3964            TableScan: test
3965        "
3966        )
3967    }
3968
3969    #[test]
3970    fn test_push_down_volatile_mixed_table_scan() -> Result<()> {
3971        // SELECT test.a, test.b FROM test as t WHERE TestScalarUDF() > 0.1 and test.a > 5 and test.b > 10;
3972        let table_scan = test_table_scan()?;
3973        let fun = ScalarUDF::new_from_impl(TestScalarUDF {
3974            signature: Signature::exact(vec![], Volatility::Volatile),
3975        });
3976        let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
3977        let plan = LogicalPlanBuilder::from(table_scan)
3978            .project(vec![col("a"), col("b")])?
3979            .filter(
3980                expr.gt(lit(0.1))
3981                    .and(col("t.a").gt(lit(5)))
3982                    .and(col("t.b").gt(lit(10))),
3983            )?
3984            .build()?;
3985
3986        assert_snapshot!(plan,
3987        @r"
3988        Filter: TestScalarUDF() > Float64(0.1) AND t.a > Int32(5) AND t.b > Int32(10)
3989          Projection: test.a, test.b
3990            TableScan: test
3991        ",
3992        );
3993        assert_optimized_plan_equal!(
3994            plan,
3995            @r"
3996        Projection: test.a, test.b
3997          Filter: TestScalarUDF() > Float64(0.1)
3998            TableScan: test, full_filters=[t.a > Int32(5), t.b > Int32(10)]
3999        "
4000        )
4001    }
4002
4003    #[test]
4004    fn test_push_down_volatile_mixed_unsupported_table_scan() -> Result<()> {
4005        // SELECT test.a, test.b FROM test as t WHERE TestScalarUDF() > 0.1 and test.a > 5 and test.b > 10;
4006        let fun = ScalarUDF::new_from_impl(TestScalarUDF {
4007            signature: Signature::exact(vec![], Volatility::Volatile),
4008        });
4009        let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
4010        let plan = table_scan_with_pushdown_provider_builder(
4011            TableProviderFilterPushDown::Unsupported,
4012            vec![],
4013            None,
4014        )?
4015        .project(vec![col("a"), col("b")])?
4016        .filter(
4017            expr.gt(lit(0.1))
4018                .and(col("t.a").gt(lit(5)))
4019                .and(col("t.b").gt(lit(10))),
4020        )?
4021        .build()?;
4022
4023        assert_snapshot!(plan,
4024        @r"
4025        Filter: TestScalarUDF() > Float64(0.1) AND t.a > Int32(5) AND t.b > Int32(10)
4026          Projection: a, b
4027            TableScan: test
4028        ",
4029        );
4030        assert_optimized_plan_equal!(
4031            plan,
4032            @r"
4033        Projection: a, b
4034          Filter: t.a > Int32(5) AND t.b > Int32(10) AND TestScalarUDF() > Float64(0.1)
4035            TableScan: test
4036        "
4037        )
4038    }
4039
4040    #[test]
4041    fn test_push_down_filter_to_user_defined_node() -> Result<()> {
4042        // Define a custom user-defined logical node
4043        #[derive(Debug, Hash, Eq, PartialEq)]
4044        struct TestUserNode {
4045            schema: DFSchemaRef,
4046        }
4047
4048        impl PartialOrd for TestUserNode {
4049            fn partial_cmp(&self, _other: &Self) -> Option<Ordering> {
4050                None
4051            }
4052        }
4053
4054        impl TestUserNode {
4055            fn new() -> Self {
4056                let schema = Arc::new(
4057                    DFSchema::new_with_metadata(
4058                        vec![(None, Field::new("a", DataType::Int64, false).into())],
4059                        Default::default(),
4060                    )
4061                    .unwrap(),
4062                );
4063
4064                Self { schema }
4065            }
4066        }
4067
4068        impl UserDefinedLogicalNodeCore for TestUserNode {
4069            fn name(&self) -> &str {
4070                "test_node"
4071            }
4072
4073            fn inputs(&self) -> Vec<&LogicalPlan> {
4074                vec![]
4075            }
4076
4077            fn schema(&self) -> &DFSchemaRef {
4078                &self.schema
4079            }
4080
4081            fn expressions(&self) -> Vec<Expr> {
4082                vec![]
4083            }
4084
4085            fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
4086                write!(f, "TestUserNode")
4087            }
4088
4089            fn with_exprs_and_inputs(
4090                &self,
4091                exprs: Vec<Expr>,
4092                inputs: Vec<LogicalPlan>,
4093            ) -> Result<Self> {
4094                assert!(exprs.is_empty());
4095                assert!(inputs.is_empty());
4096                Ok(Self {
4097                    schema: Arc::clone(&self.schema),
4098                })
4099            }
4100        }
4101
4102        // Create a node and build a plan with a filter
4103        let node = LogicalPlan::Extension(Extension {
4104            node: Arc::new(TestUserNode::new()),
4105        });
4106
4107        let plan = LogicalPlanBuilder::from(node).filter(lit(false))?.build()?;
4108
4109        // Check the original plan format (not part of the test assertions)
4110        assert_snapshot!(plan,
4111        @r"
4112        Filter: Boolean(false)
4113          TestUserNode
4114        ",
4115        );
4116        // Check that the filter is pushed down to the user-defined node
4117        assert_optimized_plan_equal!(
4118            plan,
4119            @r"
4120        Filter: Boolean(false)
4121          TestUserNode
4122        "
4123        )
4124    }
4125}