Skip to main content

datafusion_optimizer/
push_down_filter.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`PushDownFilter`] applies filters as early as possible
19
20use std::collections::{HashMap, HashSet};
21use std::sync::Arc;
22
23use arrow::datatypes::DataType;
24use indexmap::IndexSet;
25use itertools::Itertools;
26
27use datafusion_common::tree_node::{
28    Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
29};
30use datafusion_common::{
31    Column, DFSchema, Result, assert_eq_or_internal_err, assert_or_internal_err,
32    internal_err, plan_err, qualified_name,
33};
34use datafusion_expr::expr::WindowFunction;
35use datafusion_expr::expr_rewriter::replace_col;
36use datafusion_expr::logical_plan::{Join, JoinType, LogicalPlan, TableScan, Union};
37use datafusion_expr::utils::{
38    conjunction, expr_to_columns, split_conjunction, split_conjunction_owned,
39};
40use datafusion_expr::{
41    BinaryExpr, Expr, Filter, Operator, Projection, TableProviderFilterPushDown, and, or,
42};
43
44use crate::optimizer::ApplyOrder;
45use crate::simplify_expressions::simplify_predicates;
46use crate::utils::{has_all_column_refs, is_restrict_null_predicate};
47use crate::{OptimizerConfig, OptimizerRule};
48use datafusion_expr::ExpressionPlacement;
49
/// Optimizer rule for pushing (moving) filter expressions down in a plan so
/// they are applied as early as possible.
///
/// # Introduction
///
/// The goal of this rule is to improve query performance by eliminating
/// redundant work.
///
/// For example, given a plan that sorts all values where `a > 10`:
///
/// ```text
///  Filter (a > 10)
///    Sort (a, b)
/// ```
///
/// A better plan is to filter the data *before* the Sort, which sorts fewer
/// rows and therefore does less work overall:
///
/// ```text
///  Sort (a, b)
///    Filter (a > 10)  <-- Filter is moved before the sort
/// ```
///
/// However, it is not always possible to push filters down. For example, given
/// a plan that finds the top 3 values and then keeps only those that are
/// greater than 10, if the filter is pushed below the limit it would produce a
/// different result.
///
/// ```text
///  Filter (a > 10)   <-- can not move this Filter before the limit
///    Limit (fetch=3)
///      Sort (a, b)
/// ```
///
///
/// More formally, a filter-commutative operation is an operation `op` that
/// satisfies `filter(op(data)) = op(filter(data))`.
///
/// The filter-commutative property is plan and column-specific. A filter on `a`
/// can be pushed through an `Aggregate(group_by = [a], agg = [sum(b)])`.
/// However, a filter on `sum(b)` can not be pushed through the same aggregate.
///
/// # Handling Conjunctions
///
/// It is possible to only push down **part** of a filter expression if it is
/// connected with `AND`s (more formally if it is a "conjunction").
///
/// For example, given the following plan:
///
/// ```text
/// Filter(a > 10 AND sum(b) < 5)
///   Aggregate(group_by = [a], agg = [sum(b)])
/// ```
///
/// The `a > 10` is commutative with the `Aggregate` but `sum(b) < 5` is not.
/// Therefore it is possible to only push part of the expression, resulting in:
///
/// ```text
/// Filter(sum(b) < 5)
///   Aggregate(group_by = [a], agg = [sum(b)])
///     Filter(a > 10)
/// ```
///
/// # Handling Column Aliases
///
/// This optimizer must sometimes handle re-writing filter expressions when they
/// are pushed, for example if there is a projection that aliases `a+1` to `"b"`:
///
/// ```text
/// Filter (b > 10)
///     Projection: [a+1 AS "b"]  <-- changes the name of `a+1` to `b`
/// ```
///
/// To apply the filter prior to the `Projection`, all references to `b` must be
/// rewritten to `a+1`:
///
/// ```text
/// Projection: [a+1 AS "b"]
///     Filter: (a + 1 > 10)  <--- changed from b to a + 1
/// ```
///
/// # Implementation Notes
///
/// This implementation performs a single pass through the plan, "pushing" down
/// filters. When it passes through a filter, it stores that filter, and when it
/// reaches a plan node that does not commute with that filter, it adds the
/// filter to that place. When it passes through a projection, it re-writes the
/// filter's expression taking into account that projection.
#[derive(Default, Debug)]
pub struct PushDownFilter {}
139
140/// For a given JOIN type, determine whether each input of the join is preserved
141/// for post-join (`WHERE` clause) filters.
142///
143/// It is only correct to push filters below a join for preserved inputs.
144///
145/// # Return Value
146/// A tuple of booleans - (left_preserved, right_preserved).
147///
148/// # "Preserved" input definition
149///
150/// We say a join side is preserved if the join returns all or a subset of the rows from
151/// the relevant side, such that each row of the output table directly maps to a row of
152/// the preserved input table. If a table is not preserved, it can provide extra null rows.
153/// That is, there may be rows in the output table that don't directly map to a row in the
154/// input table.
155///
156/// For example:
157///   - In an inner join, both sides are preserved, because each row of the output
158///     maps directly to a row from each side.
159///
160///   - In a left join, the left side is preserved (we can push predicates) but
161///     the right is not, because there may be rows in the output that don't
162///     directly map to a row in the right input (due to nulls filling where there
163///     is no match on the right).
164pub(crate) fn lr_is_preserved(join_type: JoinType) -> (bool, bool) {
165    match join_type {
166        JoinType::Inner => (true, true),
167        JoinType::Left => (true, false),
168        JoinType::Right => (false, true),
169        JoinType::Full => (false, false),
170        // No columns from the right side of the join can be referenced in output
171        // predicates for semi/anti joins, so whether we specify t/f doesn't matter.
172        JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => (true, false),
173        // No columns from the left side of the join can be referenced in output
174        // predicates for semi/anti joins, so whether we specify t/f doesn't matter.
175        JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => (false, true),
176    }
177}
178
179/// For a given JOIN type, determine whether each input of the join is preserved
180/// for the join condition (`ON` clause filters).
181///
182/// It is only correct to push filters below a join for preserved inputs.
183///
184/// # Return Value
185/// A tuple of booleans - (left_preserved, right_preserved).
186///
187/// See [`lr_is_preserved`] for a definition of "preserved".
188pub(crate) fn on_lr_is_preserved(join_type: JoinType) -> (bool, bool) {
189    match join_type {
190        JoinType::Inner => (true, true),
191        JoinType::Left => (false, true),
192        JoinType::Right => (true, false),
193        JoinType::Full => (false, false),
194        JoinType::LeftSemi | JoinType::RightSemi => (true, true),
195        JoinType::LeftAnti => (false, true),
196        JoinType::RightAnti => (true, false),
197        JoinType::LeftMark => (false, true),
198        JoinType::RightMark => (true, false),
199    }
200}
201
202/// Evaluates the columns referenced in the given expression to see if they refer
203/// only to the left or right columns
204#[derive(Debug)]
205struct ColumnChecker<'a> {
206    /// schema of left join input
207    left_schema: &'a DFSchema,
208    /// columns in left_schema, computed on demand
209    left_columns: Option<HashSet<Column>>,
210    /// schema of right join input
211    right_schema: &'a DFSchema,
212    /// columns in left_schema, computed on demand
213    right_columns: Option<HashSet<Column>>,
214}
215
216impl<'a> ColumnChecker<'a> {
217    fn new(left_schema: &'a DFSchema, right_schema: &'a DFSchema) -> Self {
218        Self {
219            left_schema,
220            left_columns: None,
221            right_schema,
222            right_columns: None,
223        }
224    }
225
226    /// Return true if the expression references only columns from the left side of the join
227    fn is_left_only(&mut self, predicate: &Expr) -> bool {
228        if self.left_columns.is_none() {
229            self.left_columns = Some(schema_columns(self.left_schema));
230        }
231        has_all_column_refs(predicate, self.left_columns.as_ref().unwrap())
232    }
233
234    /// Return true if the expression references only columns from the right side of the join
235    fn is_right_only(&mut self, predicate: &Expr) -> bool {
236        if self.right_columns.is_none() {
237            self.right_columns = Some(schema_columns(self.right_schema));
238        }
239        has_all_column_refs(predicate, self.right_columns.as_ref().unwrap())
240    }
241}
242
243/// Returns all columns in the schema
244fn schema_columns(schema: &DFSchema) -> HashSet<Column> {
245    schema
246        .iter()
247        .flat_map(|(qualifier, field)| {
248            [
249                Column::new(qualifier.cloned(), field.name()),
250                // we need to push down filter using unqualified column as well
251                Column::new_unqualified(field.name()),
252            ]
253        })
254        .collect::<HashSet<_>>()
255}
256
257/// Determine whether the predicate can evaluate as the join conditions
258fn can_evaluate_as_join_condition(predicate: &Expr) -> Result<bool> {
259    let mut is_evaluate = true;
260    predicate.apply(|expr| match expr {
261        Expr::Column(_)
262        | Expr::Literal(_, _)
263        | Expr::Placeholder(_)
264        | Expr::ScalarVariable(_, _) => Ok(TreeNodeRecursion::Jump),
265        Expr::Exists { .. }
266        | Expr::InSubquery(_)
267        | Expr::SetComparison(_)
268        | Expr::ScalarSubquery(_)
269        | Expr::OuterReferenceColumn(_, _)
270        | Expr::Unnest(_) => {
271            is_evaluate = false;
272            Ok(TreeNodeRecursion::Stop)
273        }
274        Expr::Alias(_)
275        | Expr::BinaryExpr(_)
276        | Expr::Like(_)
277        | Expr::SimilarTo(_)
278        | Expr::Not(_)
279        | Expr::IsNotNull(_)
280        | Expr::IsNull(_)
281        | Expr::IsTrue(_)
282        | Expr::IsFalse(_)
283        | Expr::IsUnknown(_)
284        | Expr::IsNotTrue(_)
285        | Expr::IsNotFalse(_)
286        | Expr::IsNotUnknown(_)
287        | Expr::Negative(_)
288        | Expr::Between(_)
289        | Expr::Case(_)
290        | Expr::Cast(_)
291        | Expr::TryCast(_)
292        | Expr::InList { .. }
293        | Expr::ScalarFunction(_) => Ok(TreeNodeRecursion::Continue),
294        // TODO: remove the next line after `Expr::Wildcard` is removed
295        #[expect(deprecated)]
296        Expr::AggregateFunction(_)
297        | Expr::WindowFunction(_)
298        | Expr::Wildcard { .. }
299        | Expr::GroupingSet(_) => internal_err!("Unsupported predicate type"),
300    })?;
301    Ok(is_evaluate)
302}
303
304/// examine OR clause to see if any useful clauses can be extracted and push down.
305/// extract at least one qual from each sub clauses of OR clause, then form the quals
306/// to new OR clause as predicate.
307///
308/// # Example
309/// ```text
310/// Filter: (a = c and a < 20) or (b = d and b > 10)
311///     join/crossjoin:
312///          TableScan: projection=[a, b]
313///          TableScan: projection=[c, d]
314/// ```
315///
316/// is optimized to
317///
318/// ```text
319/// Filter: (a = c and a < 20) or (b = d and b > 10)
320///     join/crossjoin:
321///          Filter: (a < 20) or (b > 10)
322///              TableScan: projection=[a, b]
323///          TableScan: projection=[c, d]
324/// ```
325///
326/// In general, predicates of this form:
327///
328/// ```sql
329/// (A AND B) OR (C AND D)
330/// ```
331///
332/// will be transformed to one of:
333///
334/// * `((A AND B) OR (C AND D)) AND (A OR C)`
335/// * `((A AND B) OR (C AND D)) AND ((A AND B) OR C)`
336/// * do nothing.
337fn extract_or_clauses_for_join<'a>(
338    filters: &'a [Expr],
339    schema: &'a DFSchema,
340) -> impl Iterator<Item = Expr> + 'a {
341    let schema_columns = schema_columns(schema);
342
343    // new formed OR clauses and their column references
344    filters.iter().filter_map(move |expr| {
345        if let Expr::BinaryExpr(BinaryExpr {
346            left,
347            op: Operator::Or,
348            right,
349        }) = expr
350        {
351            let left_expr = extract_or_clause(left.as_ref(), &schema_columns);
352            let right_expr = extract_or_clause(right.as_ref(), &schema_columns);
353
354            // If nothing can be extracted from any sub clauses, do nothing for this OR clause.
355            if let (Some(left_expr), Some(right_expr)) = (left_expr, right_expr) {
356                return Some(or(left_expr, right_expr));
357            }
358        }
359        None
360    })
361}
362
363/// extract qual from OR sub-clause.
364///
365/// A qual is extracted if it only contains set of column references in schema_columns.
366///
367/// For AND clause, we extract from both sub-clauses, then make new AND clause by extracted
368/// clauses if both extracted; Otherwise, use the extracted clause from any sub-clauses or None.
369///
370/// For OR clause, we extract from both sub-clauses, then make new OR clause by extracted clauses if both extracted;
371/// Otherwise, return None.
372///
373/// For other clause, apply the rule above to extract clause.
374fn extract_or_clause(expr: &Expr, schema_columns: &HashSet<Column>) -> Option<Expr> {
375    let mut predicate = None;
376
377    match expr {
378        Expr::BinaryExpr(BinaryExpr {
379            left: l_expr,
380            op: Operator::Or,
381            right: r_expr,
382        }) => {
383            let l_expr = extract_or_clause(l_expr, schema_columns);
384            let r_expr = extract_or_clause(r_expr, schema_columns);
385
386            if let (Some(l_expr), Some(r_expr)) = (l_expr, r_expr) {
387                predicate = Some(or(l_expr, r_expr));
388            }
389        }
390        Expr::BinaryExpr(BinaryExpr {
391            left: l_expr,
392            op: Operator::And,
393            right: r_expr,
394        }) => {
395            let l_expr = extract_or_clause(l_expr, schema_columns);
396            let r_expr = extract_or_clause(r_expr, schema_columns);
397
398            match (l_expr, r_expr) {
399                (Some(l_expr), Some(r_expr)) => {
400                    predicate = Some(and(l_expr, r_expr));
401                }
402                (Some(l_expr), None) => {
403                    predicate = Some(l_expr);
404                }
405                (None, Some(r_expr)) => {
406                    predicate = Some(r_expr);
407                }
408                (None, None) => {
409                    predicate = None;
410                }
411            }
412        }
413        _ => {
414            if has_all_column_refs(expr, schema_columns) {
415                predicate = Some(expr.clone());
416            }
417        }
418    }
419
420    predicate
421}
422
423/// push down join/cross-join
424fn push_down_all_join(
425    predicates: Vec<Expr>,
426    inferred_join_predicates: Vec<Expr>,
427    mut join: Join,
428    on_filter: Vec<Expr>,
429) -> Result<Transformed<LogicalPlan>> {
430    let is_inner_join = join.join_type == JoinType::Inner;
431    // Get pushable predicates from current optimizer state
432    let (left_preserved, right_preserved) = lr_is_preserved(join.join_type);
433
434    // The predicates can be divided to three categories:
435    // 1) can push through join to its children(left or right)
436    // 2) can be converted to join conditions if the join type is Inner
437    // 3) should be kept as filter conditions
438    let left_schema = join.left.schema();
439    let right_schema = join.right.schema();
440    let mut left_push = vec![];
441    let mut right_push = vec![];
442    let mut keep_predicates = vec![];
443    let mut join_conditions = vec![];
444    let mut checker = ColumnChecker::new(left_schema, right_schema);
445    for predicate in predicates {
446        if left_preserved && checker.is_left_only(&predicate) {
447            left_push.push(predicate);
448        } else if right_preserved && checker.is_right_only(&predicate) {
449            right_push.push(predicate);
450        } else if is_inner_join && can_evaluate_as_join_condition(&predicate)? {
451            // Here we do not differ it is eq or non-eq predicate, ExtractEquijoinPredicate will extract the eq predicate
452            // and convert to the join on condition
453            join_conditions.push(predicate);
454        } else {
455            keep_predicates.push(predicate);
456        }
457    }
458
459    // Push predicates inferred from the join expression
460    for predicate in inferred_join_predicates {
461        if checker.is_left_only(&predicate) {
462            left_push.push(predicate);
463        } else if checker.is_right_only(&predicate) {
464            right_push.push(predicate);
465        }
466    }
467
468    let mut on_filter_join_conditions = vec![];
469    let (on_left_preserved, on_right_preserved) = on_lr_is_preserved(join.join_type);
470
471    if !on_filter.is_empty() {
472        for on in on_filter {
473            if on_left_preserved && checker.is_left_only(&on) {
474                left_push.push(on)
475            } else if on_right_preserved && checker.is_right_only(&on) {
476                right_push.push(on)
477            } else {
478                on_filter_join_conditions.push(on)
479            }
480        }
481    }
482
483    // Extract from OR clause, generate new predicates for both side of join if possible.
484    // We only track the unpushable predicates above.
485    if left_preserved {
486        left_push.extend(extract_or_clauses_for_join(&keep_predicates, left_schema));
487        left_push.extend(extract_or_clauses_for_join(&join_conditions, left_schema));
488    }
489    if right_preserved {
490        right_push.extend(extract_or_clauses_for_join(&keep_predicates, right_schema));
491        right_push.extend(extract_or_clauses_for_join(&join_conditions, right_schema));
492    }
493
494    // For predicates from join filter, we should check with if a join side is preserved
495    // in term of join filtering.
496    if on_left_preserved {
497        left_push.extend(extract_or_clauses_for_join(
498            &on_filter_join_conditions,
499            left_schema,
500        ));
501    }
502    if on_right_preserved {
503        right_push.extend(extract_or_clauses_for_join(
504            &on_filter_join_conditions,
505            right_schema,
506        ));
507    }
508
509    if let Some(predicate) = conjunction(left_push) {
510        join.left = Arc::new(LogicalPlan::Filter(Filter::try_new(predicate, join.left)?));
511    }
512    if let Some(predicate) = conjunction(right_push) {
513        join.right =
514            Arc::new(LogicalPlan::Filter(Filter::try_new(predicate, join.right)?));
515    }
516
517    // Add any new join conditions as the non join predicates
518    join_conditions.extend(on_filter_join_conditions);
519    join.filter = conjunction(join_conditions);
520
521    // wrap the join on the filter whose predicates must be kept, if any
522    let plan = LogicalPlan::Join(join);
523    let plan = if let Some(predicate) = conjunction(keep_predicates) {
524        LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(plan))?)
525    } else {
526        plan
527    };
528    Ok(Transformed::yes(plan))
529}
530
531fn push_down_join(
532    join: Join,
533    parent_predicate: Option<&Expr>,
534) -> Result<Transformed<LogicalPlan>> {
535    // Split the parent predicate into individual conjunctive parts.
536    let predicates = parent_predicate
537        .map_or_else(Vec::new, |pred| split_conjunction_owned(pred.clone()));
538
539    // Extract conjunctions from the JOIN's ON filter, if present.
540    let on_filters = join
541        .filter
542        .as_ref()
543        .map_or_else(Vec::new, |filter| split_conjunction_owned(filter.clone()));
544
545    // Are there any new join predicates that can be inferred from the filter expressions?
546    let inferred_join_predicates =
547        infer_join_predicates(&join, &predicates, &on_filters)?;
548
549    if on_filters.is_empty()
550        && predicates.is_empty()
551        && inferred_join_predicates.is_empty()
552    {
553        return Ok(Transformed::no(LogicalPlan::Join(join)));
554    }
555
556    push_down_all_join(predicates, inferred_join_predicates, join, on_filters)
557}
558
559/// Extracts any equi-join join predicates from the given filter expressions.
560///
561/// Parameters
562/// * `join` the join in question
563///
564/// * `predicates` the pushed down filter expression
565///
566/// * `on_filters` filters from the join ON clause that have not already been
567///   identified as join predicates
568fn infer_join_predicates(
569    join: &Join,
570    predicates: &[Expr],
571    on_filters: &[Expr],
572) -> Result<Vec<Expr>> {
573    // Only allow both side key is column.
574    let join_col_keys = join
575        .on
576        .iter()
577        .filter_map(|(l, r)| {
578            let left_col = l.try_as_col()?;
579            let right_col = r.try_as_col()?;
580            Some((left_col, right_col))
581        })
582        .collect::<Vec<_>>();
583
584    let join_type = join.join_type;
585
586    let mut inferred_predicates = InferredPredicates::new(join_type);
587
588    infer_join_predicates_from_predicates(
589        &join_col_keys,
590        predicates,
591        &mut inferred_predicates,
592    )?;
593
594    infer_join_predicates_from_on_filters(
595        &join_col_keys,
596        join_type,
597        on_filters,
598        &mut inferred_predicates,
599    )?;
600
601    Ok(inferred_predicates.predicates)
602}
603
604/// Inferred predicates collector.
605/// When the JoinType is not Inner, we need to detect whether the inferred predicate can strictly
606/// filter out NULL, otherwise ignore it. e.g.
607/// ```text
608/// SELECT * FROM t1 LEFT JOIN t2 ON t1.c0 = t2.c0 WHERE t2.c0 IS NULL;
609/// ```
610/// We cannot infer the predicate `t1.c0 IS NULL`, otherwise the predicate will be pushed down to
611/// the left side, resulting in the wrong result.
612struct InferredPredicates {
613    predicates: Vec<Expr>,
614    is_inner_join: bool,
615}
616
617impl InferredPredicates {
618    fn new(join_type: JoinType) -> Self {
619        Self {
620            predicates: vec![],
621            is_inner_join: join_type == JoinType::Inner,
622        }
623    }
624
625    fn try_build_predicate(
626        &mut self,
627        predicate: Expr,
628        replace_map: &HashMap<&Column, &Column>,
629    ) -> Result<()> {
630        if self.is_inner_join
631            || matches!(
632                is_restrict_null_predicate(
633                    predicate.clone(),
634                    replace_map.keys().cloned()
635                ),
636                Ok(true)
637            )
638        {
639            self.predicates.push(replace_col(predicate, replace_map)?);
640        }
641
642        Ok(())
643    }
644}
645
646/// Infer predicates from the pushed down predicates.
647///
648/// Parameters
649/// * `join_col_keys` column pairs from the join ON clause
650///
651/// * `predicates` the pushed down predicates
652///
653/// * `inferred_predicates` the inferred results
654fn infer_join_predicates_from_predicates(
655    join_col_keys: &[(&Column, &Column)],
656    predicates: &[Expr],
657    inferred_predicates: &mut InferredPredicates,
658) -> Result<()> {
659    infer_join_predicates_impl::<true, true>(
660        join_col_keys,
661        predicates,
662        inferred_predicates,
663    )
664}
665
666/// Infer predicates from the join filter.
667///
668/// Parameters
669/// * `join_col_keys` column pairs from the join ON clause
670///
671/// * `join_type` the JoinType of Join
672///
673/// * `on_filters` filters from the join ON clause that have not already been
674///   identified as join predicates
675///
676/// * `inferred_predicates` the inferred results
677fn infer_join_predicates_from_on_filters(
678    join_col_keys: &[(&Column, &Column)],
679    join_type: JoinType,
680    on_filters: &[Expr],
681    inferred_predicates: &mut InferredPredicates,
682) -> Result<()> {
683    match join_type {
684        JoinType::Full | JoinType::LeftAnti | JoinType::RightAnti => Ok(()),
685        JoinType::Inner => infer_join_predicates_impl::<true, true>(
686            join_col_keys,
687            on_filters,
688            inferred_predicates,
689        ),
690        JoinType::Left | JoinType::LeftSemi | JoinType::LeftMark => {
691            infer_join_predicates_impl::<true, false>(
692                join_col_keys,
693                on_filters,
694                inferred_predicates,
695            )
696        }
697        JoinType::Right | JoinType::RightSemi | JoinType::RightMark => {
698            infer_join_predicates_impl::<false, true>(
699                join_col_keys,
700                on_filters,
701                inferred_predicates,
702            )
703        }
704    }
705}
706
707/// Infer predicates from the given predicates.
708///
709/// Parameters
710/// * `join_col_keys` column pairs from the join ON clause
711///
712/// * `input_predicates` the given predicates. It can be the pushed down predicates,
713///   or it can be the filters of the Join
714///
715/// * `inferred_predicates` the inferred results
716///
717/// * `ENABLE_LEFT_TO_RIGHT` indicates that the right table related predicate can
718///   be inferred from the left table related predicate
719///
720/// * `ENABLE_RIGHT_TO_LEFT` indicates that the left table related predicate can
721///   be inferred from the right table related predicate
722fn infer_join_predicates_impl<
723    const ENABLE_LEFT_TO_RIGHT: bool,
724    const ENABLE_RIGHT_TO_LEFT: bool,
725>(
726    join_col_keys: &[(&Column, &Column)],
727    input_predicates: &[Expr],
728    inferred_predicates: &mut InferredPredicates,
729) -> Result<()> {
730    for predicate in input_predicates {
731        let mut join_cols_to_replace = HashMap::new();
732
733        for &col in &predicate.column_refs() {
734            for (l, r) in join_col_keys.iter() {
735                if ENABLE_LEFT_TO_RIGHT && col == *l {
736                    join_cols_to_replace.insert(col, *r);
737                    break;
738                }
739                if ENABLE_RIGHT_TO_LEFT && col == *r {
740                    join_cols_to_replace.insert(col, *l);
741                    break;
742                }
743            }
744        }
745        if join_cols_to_replace.is_empty() {
746            continue;
747        }
748
749        inferred_predicates
750            .try_build_predicate(predicate.clone(), &join_cols_to_replace)?;
751    }
752    Ok(())
753}
754
755impl OptimizerRule for PushDownFilter {
756    fn name(&self) -> &str {
757        "push_down_filter"
758    }
759
760    fn apply_order(&self) -> Option<ApplyOrder> {
761        Some(ApplyOrder::TopDown)
762    }
763
764    fn supports_rewrite(&self) -> bool {
765        true
766    }
767
768    fn rewrite(
769        &self,
770        plan: LogicalPlan,
771        config: &dyn OptimizerConfig,
772    ) -> Result<Transformed<LogicalPlan>> {
773        let _ = config.options();
774        if let LogicalPlan::Join(join) = plan {
775            return push_down_join(join, None);
776        };
777
778        let plan_schema = Arc::clone(plan.schema());
779
780        let LogicalPlan::Filter(mut filter) = plan else {
781            return Ok(Transformed::no(plan));
782        };
783
784        let predicate = split_conjunction_owned(filter.predicate.clone());
785        let old_predicate_len = predicate.len();
786        let new_predicates = simplify_predicates(predicate)?;
787        if old_predicate_len != new_predicates.len() {
788            let Some(new_predicate) = conjunction(new_predicates) else {
789                // new_predicates is empty - remove the filter entirely
790                // Return the child plan without the filter
791                return Ok(Transformed::yes(Arc::unwrap_or_clone(filter.input)));
792            };
793            filter.predicate = new_predicate;
794        }
795
796        match Arc::unwrap_or_clone(filter.input) {
797            LogicalPlan::Filter(child_filter) => {
798                let parents_predicates = split_conjunction_owned(filter.predicate);
799
800                // remove duplicated filters
801                let child_predicates = split_conjunction_owned(child_filter.predicate);
802                let new_predicates = parents_predicates
803                    .into_iter()
804                    .chain(child_predicates)
805                    // use IndexSet to remove dupes while preserving predicate order
806                    .collect::<IndexSet<_>>()
807                    .into_iter()
808                    .collect::<Vec<_>>();
809
810                let Some(new_predicate) = conjunction(new_predicates) else {
811                    return plan_err!("at least one expression exists");
812                };
813                let new_filter = LogicalPlan::Filter(Filter::try_new(
814                    new_predicate,
815                    child_filter.input,
816                )?);
817                self.rewrite(new_filter, config)
818            }
819            LogicalPlan::Repartition(repartition) => {
820                let new_filter =
821                    Filter::try_new(filter.predicate, Arc::clone(&repartition.input))
822                        .map(LogicalPlan::Filter)?;
823                insert_below(LogicalPlan::Repartition(repartition), new_filter)
824            }
825            LogicalPlan::Distinct(distinct) => {
826                let new_filter =
827                    Filter::try_new(filter.predicate, Arc::clone(distinct.input()))
828                        .map(LogicalPlan::Filter)?;
829                insert_below(LogicalPlan::Distinct(distinct), new_filter)
830            }
831            LogicalPlan::Sort(sort) => {
832                let new_filter =
833                    Filter::try_new(filter.predicate, Arc::clone(&sort.input))
834                        .map(LogicalPlan::Filter)?;
835                insert_below(LogicalPlan::Sort(sort), new_filter)
836            }
837            LogicalPlan::SubqueryAlias(subquery_alias) => {
838                let mut replace_map = HashMap::new();
839                for (i, (qualifier, field)) in
840                    subquery_alias.input.schema().iter().enumerate()
841                {
842                    let (sub_qualifier, sub_field) =
843                        subquery_alias.schema.qualified_field(i);
844                    replace_map.insert(
845                        qualified_name(sub_qualifier, sub_field.name()),
846                        Expr::Column(Column::new(qualifier.cloned(), field.name())),
847                    );
848                }
849                let new_predicate = replace_cols_by_name(filter.predicate, &replace_map)?;
850
851                let new_filter = LogicalPlan::Filter(Filter::try_new(
852                    new_predicate,
853                    Arc::clone(&subquery_alias.input),
854                )?);
855                insert_below(LogicalPlan::SubqueryAlias(subquery_alias), new_filter)
856            }
857            LogicalPlan::Projection(projection) => {
858                let predicates = split_conjunction_owned(filter.predicate.clone());
859                let (new_projection, keep_predicate) =
860                    rewrite_projection(predicates, projection)?;
861                if new_projection.transformed {
862                    match keep_predicate {
863                        None => Ok(new_projection),
864                        Some(keep_predicate) => new_projection.map_data(|child_plan| {
865                            Filter::try_new(keep_predicate, Arc::new(child_plan))
866                                .map(LogicalPlan::Filter)
867                        }),
868                    }
869                } else {
870                    filter.input = Arc::new(new_projection.data);
871                    Ok(Transformed::no(LogicalPlan::Filter(filter)))
872                }
873            }
874            LogicalPlan::Unnest(mut unnest) => {
875                let predicates = split_conjunction_owned(filter.predicate.clone());
876                let mut non_unnest_predicates = vec![];
877                let mut unnest_predicates = vec![];
878                let mut unnest_struct_columns = vec![];
879
880                for idx in &unnest.struct_type_columns {
881                    let (sub_qualifier, field) =
882                        unnest.input.schema().qualified_field(*idx);
883                    let field_name = field.name().clone();
884
885                    if let DataType::Struct(children) = field.data_type() {
886                        for child in children {
887                            let child_name = child.name().clone();
888                            unnest_struct_columns.push(Column::new(
889                                sub_qualifier.cloned(),
890                                format!("{field_name}.{child_name}"),
891                            ));
892                        }
893                    }
894                }
895
896                for predicate in predicates {
897                    // collect all the Expr::Column in predicate recursively
898                    let mut accum: HashSet<Column> = HashSet::new();
899                    expr_to_columns(&predicate, &mut accum)?;
900
901                    let contains_list_columns =
902                        unnest.list_type_columns.iter().any(|(_, unnest_list)| {
903                            accum.contains(&unnest_list.output_column)
904                        });
905                    let contains_struct_columns =
906                        unnest_struct_columns.iter().any(|c| accum.contains(c));
907
908                    if contains_list_columns || contains_struct_columns {
909                        unnest_predicates.push(predicate);
910                    } else {
911                        non_unnest_predicates.push(predicate);
912                    }
913                }
914
915                // Unnest predicates should not be pushed down.
916                // If no non-unnest predicates exist, early return
917                if non_unnest_predicates.is_empty() {
918                    filter.input = Arc::new(LogicalPlan::Unnest(unnest));
919                    return Ok(Transformed::no(LogicalPlan::Filter(filter)));
920                }
921
922                // Push down non-unnest filter predicate
923                // Unnest
924                //   Unnest Input (Projection)
925                // -> rewritten to
926                // Unnest
927                //   Filter
928                //     Unnest Input (Projection)
929
930                let unnest_input = std::mem::take(&mut unnest.input);
931
932                let filter_with_unnest_input = LogicalPlan::Filter(Filter::try_new(
933                    conjunction(non_unnest_predicates).unwrap(), // Safe to unwrap since non_unnest_predicates is not empty.
934                    unnest_input,
935                )?);
936
937                // Directly assign new filter plan as the new unnest's input.
938                // The new filter plan will go through another rewrite pass since the rule itself
939                // is applied recursively to all children, from the top down
940                let unnest_plan =
941                    insert_below(LogicalPlan::Unnest(unnest), filter_with_unnest_input)?;
942
943                match conjunction(unnest_predicates) {
944                    None => Ok(unnest_plan),
945                    Some(predicate) => Ok(Transformed::yes(LogicalPlan::Filter(
946                        Filter::try_new(predicate, Arc::new(unnest_plan.data))?,
947                    ))),
948                }
949            }
950            LogicalPlan::Union(ref union) => {
951                let mut inputs = Vec::with_capacity(union.inputs.len());
952                for input in &union.inputs {
953                    let mut replace_map = HashMap::new();
954                    for (i, (qualifier, field)) in input.schema().iter().enumerate() {
955                        let (union_qualifier, union_field) =
956                            union.schema.qualified_field(i);
957                        replace_map.insert(
958                            qualified_name(union_qualifier, union_field.name()),
959                            Expr::Column(Column::new(qualifier.cloned(), field.name())),
960                        );
961                    }
962
963                    let push_predicate =
964                        replace_cols_by_name(filter.predicate.clone(), &replace_map)?;
965                    inputs.push(Arc::new(LogicalPlan::Filter(Filter::try_new(
966                        push_predicate,
967                        Arc::clone(input),
968                    )?)))
969                }
970                Ok(Transformed::yes(LogicalPlan::Union(Union {
971                    inputs,
972                    schema: Arc::clone(&plan_schema),
973                })))
974            }
975            LogicalPlan::Aggregate(agg) => {
976                // We can push down predicates that only reference group-by expressions.
977                let group_expr_columns = agg
978                    .group_expr
979                    .iter()
980                    .map(|e| {
981                        let (relation, name) = e.qualified_name();
982                        Column::new(relation, name)
983                    })
984                    .collect::<HashSet<_>>();
985
986                let predicates = split_conjunction_owned(filter.predicate);
987
988                let mut keep_predicates = vec![];
989                let mut push_predicates = vec![];
990                for expr in predicates {
991                    let cols = expr.column_refs();
992                    if cols.iter().all(|c| group_expr_columns.contains(c)) {
993                        push_predicates.push(expr);
994                    } else {
995                        keep_predicates.push(expr);
996                    }
997                }
998
999                // As for plan Filter: Column(a+b) > 0 -- Agg: groupby:[Column(a)+Column(b)]
1000                // After push, we need to replace `a+b` with Column(a)+Column(b)
1001                // So we need create a replace_map, add {`a+b` --> Expr(Column(a)+Column(b))}
1002                let mut replace_map = HashMap::new();
1003                for expr in &agg.group_expr {
1004                    replace_map.insert(expr.schema_name().to_string(), expr.clone());
1005                }
1006                let replaced_push_predicates = push_predicates
1007                    .into_iter()
1008                    .map(|expr| replace_cols_by_name(expr, &replace_map))
1009                    .collect::<Result<Vec<_>>>()?;
1010
1011                let agg_input = Arc::clone(&agg.input);
1012                Transformed::yes(LogicalPlan::Aggregate(agg))
1013                    .transform_data(|new_plan| {
1014                        // If we have a filter to push, we push it down to the input of the aggregate
1015                        if let Some(predicate) = conjunction(replaced_push_predicates) {
1016                            let new_filter = make_filter(predicate, agg_input)?;
1017                            insert_below(new_plan, new_filter)
1018                        } else {
1019                            Ok(Transformed::no(new_plan))
1020                        }
1021                    })?
1022                    .map_data(|child_plan| {
1023                        // if there are any remaining predicates we can't push, add them
1024                        // back as a filter
1025                        if let Some(predicate) = conjunction(keep_predicates) {
1026                            make_filter(predicate, Arc::new(child_plan))
1027                        } else {
1028                            Ok(child_plan)
1029                        }
1030                    })
1031            }
1032            // Tries to push filters based on the partition key(s) of the window function(s) used.
1033            // Example:
1034            //   Before:
1035            //     Filter: (a > 1) and (b > 1) and (c > 1)
1036            //      Window: func() PARTITION BY [a] ...
1037            //   ---
1038            //   After:
1039            //     Filter: (b > 1) and (c > 1)
1040            //      Window: func() PARTITION BY [a] ...
1041            //        Filter: (a > 1)
1042            LogicalPlan::Window(window) => {
1043                // Retrieve the set of potential partition keys where we can push filters by.
1044                // Unlike aggregations, where there is only one statement per SELECT, there can be
1045                // multiple window functions, each with potentially different partition keys.
1046                // Therefore, we need to ensure that any potential partition key returned is used in
1047                // ALL window functions. Otherwise, filters cannot be pushed down through that column.
1048                let extract_partition_keys = |func: &WindowFunction| {
1049                    func.params
1050                        .partition_by
1051                        .iter()
1052                        .map(|c| {
1053                            let (relation, name) = c.qualified_name();
1054                            Column::new(relation, name)
1055                        })
1056                        .collect::<HashSet<_>>()
1057                };
1058                let potential_partition_keys = window
1059                    .window_expr
1060                    .iter()
1061                    .map(|e| {
1062                        match e {
1063                            Expr::WindowFunction(window_func) => {
1064                                extract_partition_keys(window_func)
1065                            }
1066                            Expr::Alias(alias) => {
1067                                if let Expr::WindowFunction(window_func) =
1068                                    alias.expr.as_ref()
1069                                {
1070                                    extract_partition_keys(window_func)
1071                                } else {
1072                                    // window functions expressions are only Expr::WindowFunction
1073                                    unreachable!()
1074                                }
1075                            }
1076                            _ => {
1077                                // window functions expressions are only Expr::WindowFunction
1078                                unreachable!()
1079                            }
1080                        }
1081                    })
1082                    // performs the set intersection of the partition keys of all window functions,
1083                    // returning only the common ones
1084                    .reduce(|a, b| &a & &b)
1085                    .unwrap_or_default();
1086
1087                let predicates = split_conjunction_owned(filter.predicate);
1088                let mut keep_predicates = vec![];
1089                let mut push_predicates = vec![];
1090                for expr in predicates {
1091                    let cols = expr.column_refs();
1092                    if cols.iter().all(|c| potential_partition_keys.contains(c)) {
1093                        push_predicates.push(expr);
1094                    } else {
1095                        keep_predicates.push(expr);
1096                    }
1097                }
1098
1099                // Unlike with aggregations, there are no cases where we have to replace, e.g.,
1100                // `a+b` with Column(a)+Column(b). This is because partition expressions are not
1101                // available as standalone columns to the user. For example, while an aggregation on
1102                // `a+b` becomes Column(a + b), in a window partition it becomes
1103                // `func() PARTITION BY [a + b] ...`. Thus, filters on expressions always remain in
1104                // place, so we can use `push_predicates` directly. This is consistent with other
1105                // optimizers, such as the one used by Postgres.
1106
1107                let window_input = Arc::clone(&window.input);
1108                Transformed::yes(LogicalPlan::Window(window))
1109                    .transform_data(|new_plan| {
1110                        // If we have a filter to push, we push it down to the input of the window
1111                        if let Some(predicate) = conjunction(push_predicates) {
1112                            let new_filter = make_filter(predicate, window_input)?;
1113                            insert_below(new_plan, new_filter)
1114                        } else {
1115                            Ok(Transformed::no(new_plan))
1116                        }
1117                    })?
1118                    .map_data(|child_plan| {
1119                        // if there are any remaining predicates we can't push, add them
1120                        // back as a filter
1121                        if let Some(predicate) = conjunction(keep_predicates) {
1122                            make_filter(predicate, Arc::new(child_plan))
1123                        } else {
1124                            Ok(child_plan)
1125                        }
1126                    })
1127            }
1128            LogicalPlan::Join(join) => push_down_join(join, Some(&filter.predicate)),
1129            LogicalPlan::TableScan(scan) => {
1130                let filter_predicates = split_conjunction(&filter.predicate);
1131
1132                let (volatile_filters, non_volatile_filters): (Vec<&Expr>, Vec<&Expr>) =
1133                    filter_predicates
1134                        .into_iter()
1135                        .partition(|pred| pred.is_volatile());
1136
1137                // Check which non-volatile filters are supported by source
1138                let supported_filters = scan
1139                    .source
1140                    .supports_filters_pushdown(non_volatile_filters.as_slice())?;
1141                assert_eq_or_internal_err!(
1142                    non_volatile_filters.len(),
1143                    supported_filters.len(),
1144                    "Vec returned length: {} from supports_filters_pushdown is not the same size as the filters passed, which length is: {}",
1145                    supported_filters.len(),
1146                    non_volatile_filters.len()
1147                );
1148
1149                // Compose scan filters from non-volatile filters of `Exact` or `Inexact` pushdown type
1150                let zip = non_volatile_filters.into_iter().zip(supported_filters);
1151
1152                let new_scan_filters = zip
1153                    .clone()
1154                    .filter(|(_, res)| res != &TableProviderFilterPushDown::Unsupported)
1155                    .map(|(pred, _)| pred);
1156
1157                // Add new scan filters
1158                let new_scan_filters: Vec<Expr> = scan
1159                    .filters
1160                    .iter()
1161                    .chain(new_scan_filters)
1162                    .unique()
1163                    .cloned()
1164                    .collect();
1165
1166                // Compose predicates to be of `Unsupported` or `Inexact` pushdown type, and also include volatile filters
1167                let new_predicate: Vec<Expr> = zip
1168                    .filter(|(_, res)| res != &TableProviderFilterPushDown::Exact)
1169                    .map(|(pred, _)| pred)
1170                    .chain(volatile_filters)
1171                    .cloned()
1172                    .collect();
1173
1174                let new_scan = LogicalPlan::TableScan(TableScan {
1175                    filters: new_scan_filters,
1176                    ..scan
1177                });
1178
1179                Transformed::yes(new_scan).transform_data(|new_scan| {
1180                    if let Some(predicate) = conjunction(new_predicate) {
1181                        make_filter(predicate, Arc::new(new_scan)).map(Transformed::yes)
1182                    } else {
1183                        Ok(Transformed::no(new_scan))
1184                    }
1185                })
1186            }
1187            LogicalPlan::Extension(extension_plan) => {
1188                // This check prevents the Filter from being removed when the extension node has no children,
1189                // so we return the original Filter unchanged.
1190                if extension_plan.node.inputs().is_empty() {
1191                    filter.input = Arc::new(LogicalPlan::Extension(extension_plan));
1192                    return Ok(Transformed::no(LogicalPlan::Filter(filter)));
1193                }
1194                let prevent_cols =
1195                    extension_plan.node.prevent_predicate_push_down_columns();
1196
1197                // determine if we can push any predicates down past the extension node
1198
1199                // each element is true for push, false to keep
1200                let predicate_push_or_keep = split_conjunction(&filter.predicate)
1201                    .iter()
1202                    .map(|expr| {
1203                        let cols = expr.column_refs();
1204                        if cols.iter().any(|c| prevent_cols.contains(&c.name)) {
1205                            Ok(false) // No push (keep)
1206                        } else {
1207                            Ok(true) // push
1208                        }
1209                    })
1210                    .collect::<Result<Vec<_>>>()?;
1211
1212                // all predicates are kept, no changes needed
1213                if predicate_push_or_keep.iter().all(|&x| !x) {
1214                    filter.input = Arc::new(LogicalPlan::Extension(extension_plan));
1215                    return Ok(Transformed::no(LogicalPlan::Filter(filter)));
1216                }
1217
1218                // going to push some predicates down, so split the predicates
1219                let mut keep_predicates = vec![];
1220                let mut push_predicates = vec![];
1221                for (push, expr) in predicate_push_or_keep
1222                    .into_iter()
1223                    .zip(split_conjunction_owned(filter.predicate).into_iter())
1224                {
1225                    if !push {
1226                        keep_predicates.push(expr);
1227                    } else {
1228                        push_predicates.push(expr);
1229                    }
1230                }
1231
1232                let new_children = match conjunction(push_predicates) {
1233                    Some(predicate) => extension_plan
1234                        .node
1235                        .inputs()
1236                        .into_iter()
1237                        .map(|child| {
1238                            Ok(LogicalPlan::Filter(Filter::try_new(
1239                                predicate.clone(),
1240                                Arc::new(child.clone()),
1241                            )?))
1242                        })
1243                        .collect::<Result<Vec<_>>>()?,
1244                    None => extension_plan.node.inputs().into_iter().cloned().collect(),
1245                };
1246                // extension with new inputs.
1247                let child_plan = LogicalPlan::Extension(extension_plan);
1248                let new_extension =
1249                    child_plan.with_new_exprs(child_plan.expressions(), new_children)?;
1250
1251                let new_plan = match conjunction(keep_predicates) {
1252                    Some(predicate) => LogicalPlan::Filter(Filter::try_new(
1253                        predicate,
1254                        Arc::new(new_extension),
1255                    )?),
1256                    None => new_extension,
1257                };
1258                Ok(Transformed::yes(new_plan))
1259            }
1260            child => {
1261                filter.input = Arc::new(child);
1262                Ok(Transformed::no(LogicalPlan::Filter(filter)))
1263            }
1264        }
1265    }
1266}
1267
1268/// Attempts to push `predicate` into a `FilterExec` below `projection
1269///
1270/// # Returns
1271/// (plan, remaining_predicate)
1272///
1273/// `plan` is a LogicalPlan for `projection` with possibly a new FilterExec below it.
1274/// `remaining_predicate` is any part of the predicate that could not be pushed down
1275///
1276/// # Args
1277/// - predicates: Split predicates like `[foo=5, bar=6]`
1278/// - projection: The target projection plan to push down the predicates
1279///
1280/// # Example
1281///
1282/// Pushing a predicate like `foo=5 AND bar=6` with an input plan like this:
1283///
1284/// ```text
1285/// Projection(foo, c+d as bar)
1286/// ```
1287///
1288/// Might result in returning `remaining_predicate` of `bar=6` and a plan like
1289///
1290/// ```text
1291/// Projection(foo, c+d as bar)
1292///  Filter(foo=5)
1293///   ...
1294/// ```
1295fn rewrite_projection(
1296    predicates: Vec<Expr>,
1297    mut projection: Projection,
1298) -> Result<(Transformed<LogicalPlan>, Option<Expr>)> {
1299    // Partition projection expressions into non-pushable vs pushable.
1300    // Non-pushable expressions are volatile (must not be duplicated) or
1301    // MoveTowardsLeafNodes (cheap expressions like get_field where re-inlining
1302    // into a filter causes optimizer instability — ExtractLeafExpressions will
1303    // undo the push-down, creating an infinite loop that runs until the
1304    // iteration limit is hit).
1305    let (non_pushable_map, pushable_map): (HashMap<_, _>, HashMap<_, _>) = projection
1306        .schema
1307        .iter()
1308        .zip(projection.expr.iter())
1309        .map(|((qualifier, field), expr)| {
1310            // strip alias, as they should not be part of filters
1311            let expr = expr.clone().unalias();
1312
1313            (qualified_name(qualifier, field.name()), expr)
1314        })
1315        .partition(|(_, value)| {
1316            value.is_volatile()
1317                || value.placement() == ExpressionPlacement::MoveTowardsLeafNodes
1318        });
1319
1320    let mut push_predicates = vec![];
1321    let mut keep_predicates = vec![];
1322    for expr in predicates {
1323        if contain(&expr, &non_pushable_map) {
1324            keep_predicates.push(expr);
1325        } else {
1326            push_predicates.push(expr);
1327        }
1328    }
1329
1330    match conjunction(push_predicates) {
1331        Some(expr) => {
1332            // re-write all filters based on this projection
1333            // E.g. in `Filter: b\n  Projection: a > 1 as b`, we can swap them, but the filter must be "a > 1"
1334            let new_filter = LogicalPlan::Filter(Filter::try_new(
1335                replace_cols_by_name(expr, &pushable_map)?,
1336                std::mem::take(&mut projection.input),
1337            )?);
1338
1339            projection.input = Arc::new(new_filter);
1340
1341            Ok((
1342                Transformed::yes(LogicalPlan::Projection(projection)),
1343                conjunction(keep_predicates),
1344            ))
1345        }
1346        None => Ok((
1347            Transformed::no(LogicalPlan::Projection(projection)),
1348            conjunction(keep_predicates),
1349        )),
1350    }
1351}
1352
1353/// Creates a new LogicalPlan::Filter node.
1354pub fn make_filter(predicate: Expr, input: Arc<LogicalPlan>) -> Result<LogicalPlan> {
1355    Filter::try_new(predicate, input).map(LogicalPlan::Filter)
1356}
1357
1358/// Replace the existing child of the single input node with `new_child`.
1359///
1360/// Starting:
1361/// ```text
1362/// plan
1363///   child
1364/// ```
1365///
1366/// Ending:
1367/// ```text
1368/// plan
1369///   new_child
1370/// ```
1371fn insert_below(
1372    plan: LogicalPlan,
1373    new_child: LogicalPlan,
1374) -> Result<Transformed<LogicalPlan>> {
1375    let mut new_child = Some(new_child);
1376    let transformed_plan = plan.map_children(|_child| {
1377        if let Some(new_child) = new_child.take() {
1378            Ok(Transformed::yes(new_child))
1379        } else {
1380            // already took the new child
1381            internal_err!("node had more than one input")
1382        }
1383    })?;
1384
1385    // make sure we did the actual replacement
1386    assert_or_internal_err!(new_child.is_none(), "node had no inputs");
1387
1388    Ok(transformed_plan)
1389}
1390
1391impl PushDownFilter {
1392    #[expect(missing_docs)]
1393    pub fn new() -> Self {
1394        Self {}
1395    }
1396}
1397
1398/// replaces columns by its name on the projection.
1399pub fn replace_cols_by_name(
1400    e: Expr,
1401    replace_map: &HashMap<String, Expr>,
1402) -> Result<Expr> {
1403    e.transform_up(|expr| {
1404        Ok(if let Expr::Column(c) = &expr {
1405            match replace_map.get(&c.flat_name()) {
1406                Some(new_c) => Transformed::yes(new_c.clone()),
1407                None => Transformed::no(expr),
1408            }
1409        } else {
1410            Transformed::no(expr)
1411        })
1412    })
1413    .data()
1414}
1415
1416/// check whether the expression uses the columns in `check_map`.
1417fn contain(e: &Expr, check_map: &HashMap<String, Expr>) -> bool {
1418    let mut is_contain = false;
1419    e.apply(|expr| {
1420        Ok(if let Expr::Column(c) = &expr {
1421            match check_map.get(&c.flat_name()) {
1422                Some(_) => {
1423                    is_contain = true;
1424                    TreeNodeRecursion::Stop
1425                }
1426                None => TreeNodeRecursion::Continue,
1427            }
1428        } else {
1429            TreeNodeRecursion::Continue
1430        })
1431    })
1432    .unwrap();
1433    is_contain
1434}
1435
1436#[cfg(test)]
1437mod tests {
1438    use std::any::Any;
1439    use std::cmp::Ordering;
1440    use std::fmt::{Debug, Formatter};
1441
1442    use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
1443    use async_trait::async_trait;
1444
1445    use datafusion_common::{DFSchemaRef, DataFusionError, ScalarValue};
1446    use datafusion_expr::expr::{ScalarFunction, WindowFunction};
1447    use datafusion_expr::logical_plan::table_scan;
1448    use datafusion_expr::{
1449        ColumnarValue, ExprFunctionExt, Extension, LogicalPlanBuilder,
1450        ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, TableSource, TableType,
1451        UserDefinedLogicalNodeCore, Volatility, WindowFunctionDefinition, col, in_list,
1452        in_subquery, lit,
1453    };
1454
1455    use crate::OptimizerContext;
1456    use crate::assert_optimized_plan_eq_snapshot;
1457    use crate::optimizer::Optimizer;
1458    use crate::simplify_expressions::SimplifyExpressions;
1459    use crate::test::udfs::leaf_udf_expr;
1460    use crate::test::*;
1461    use datafusion_expr::test::function_stub::sum;
1462    use insta::assert_snapshot;
1463
1464    use super::*;
1465
1466    fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {}
1467
1468    macro_rules! assert_optimized_plan_equal {
1469        (
1470            $plan:expr,
1471            @ $expected:literal $(,)?
1472        ) => {{
1473            let optimizer_ctx = OptimizerContext::new().with_max_passes(1);
1474            let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> = vec![Arc::new(PushDownFilter::new())];
1475            assert_optimized_plan_eq_snapshot!(
1476                optimizer_ctx,
1477                rules,
1478                $plan,
1479                @ $expected,
1480            )
1481        }};
1482    }
1483
1484    macro_rules! assert_optimized_plan_eq_with_rewrite_predicate {
1485        (
1486            $plan:expr,
1487            @ $expected:literal $(,)?
1488        ) => {{
1489            let optimizer = Optimizer::with_rules(vec![
1490                Arc::new(SimplifyExpressions::new()),
1491                Arc::new(PushDownFilter::new()),
1492            ]);
1493            let optimized_plan = optimizer.optimize($plan, &OptimizerContext::new(), observe)?;
1494            assert_snapshot!(optimized_plan, @ $expected);
1495            Ok::<(), DataFusionError>(())
1496        }};
1497    }
1498
1499    #[test]
1500    fn filter_before_projection() -> Result<()> {
1501        let table_scan = test_table_scan()?;
1502        let plan = LogicalPlanBuilder::from(table_scan)
1503            .project(vec![col("a"), col("b")])?
1504            .filter(col("a").eq(lit(1i64)))?
1505            .build()?;
1506        // filter is before projection
1507        assert_optimized_plan_equal!(
1508            plan,
1509            @r"
1510        Projection: test.a, test.b
1511          TableScan: test, full_filters=[test.a = Int64(1)]
1512        "
1513        )
1514    }
1515
1516    #[test]
1517    fn filter_after_limit() -> Result<()> {
1518        let table_scan = test_table_scan()?;
1519        let plan = LogicalPlanBuilder::from(table_scan)
1520            .project(vec![col("a"), col("b")])?
1521            .limit(0, Some(10))?
1522            .filter(col("a").eq(lit(1i64)))?
1523            .build()?;
1524        // filter is before single projection
1525        assert_optimized_plan_equal!(
1526            plan,
1527            @r"
1528        Filter: test.a = Int64(1)
1529          Limit: skip=0, fetch=10
1530            Projection: test.a, test.b
1531              TableScan: test
1532        "
1533        )
1534    }
1535
1536    #[test]
1537    fn filter_no_columns() -> Result<()> {
1538        let table_scan = test_table_scan()?;
1539        let plan = LogicalPlanBuilder::from(table_scan)
1540            .filter(lit(0i64).eq(lit(1i64)))?
1541            .build()?;
1542        assert_optimized_plan_equal!(
1543            plan,
1544            @"TableScan: test, full_filters=[Int64(0) = Int64(1)]"
1545        )
1546    }
1547
/// verifies that a filter moves below two stacked projections
#[test]
fn filter_jump_2_plans() -> Result<()> {
    let inner_columns = vec![col("a"), col("b"), col("c")];
    let outer_columns = vec![col("c"), col("b")];
    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .project(inner_columns)?
        .project(outer_columns)?
        .filter(col("a").eq(lit(1i64)))?
        .build()?;

    // the predicate lands on the scan, below both projections
    assert_optimized_plan_equal!(
        plan,
        @r"
        Projection: test.c, test.b
          Projection: test.a, test.b, test.c
            TableScan: test, full_filters=[test.a = Int64(1)]
        "
    )
}
1566
/// verifies that a filter on a group-by key is moved below the aggregate
#[test]
fn filter_move_agg() -> Result<()> {
    let group_keys = vec![col("a")];
    let aggregates = vec![sum(col("b")).alias("total_salary")];
    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .aggregate(group_keys, aggregates)?
        .filter(col("a").gt(lit(10i64)))?
        .build()?;

    // filtering on the grouping key commutes with the aggregation
    assert_optimized_plan_equal!(
        plan,
        @r"
        Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b) AS total_salary]]
          TableScan: test, full_filters=[test.a > Int64(10)]
        "
    )
}
1583
/// verifies that filters on columns with unusual names (here prefixed with `$`) are pushed
/// down through aggregate operators
#[test]
fn filter_move_agg_special() -> Result<()> {
    let fields: Vec<Field> = ["$a", "$b", "$c"]
        .into_iter()
        .map(|name| Field::new(name, DataType::UInt32, false))
        .collect();
    let schema = Schema::new(fields);
    let scan = table_scan(Some("test"), &schema, None)?.build()?;

    let group_keys = vec![col("$a")];
    let aggregates = vec![sum(col("$b")).alias("total_salary")];
    let plan = LogicalPlanBuilder::from(scan)
        .aggregate(group_keys, aggregates)?
        .filter(col("$a").gt(lit(10i64)))?
        .build()?;

    // filtering on the grouping key commutes with the aggregation
    assert_optimized_plan_equal!(
        plan,
        @r"
        Aggregate: groupBy=[[test.$a]], aggr=[[sum(test.$b) AS total_salary]]
          TableScan: test, full_filters=[test.$a > Int64(10)]
        "
    )
}
1607
/// verifies that a filter on `b` is kept above an aggregate whose only group key is the
/// expression `b + a`
#[test]
fn filter_complex_group_by() -> Result<()> {
    let group_keys = vec![add(col("b"), col("a"))];
    let aggregates = vec![sum(col("a")), col("b")];
    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .aggregate(group_keys, aggregates)?
        .filter(col("b").gt(lit(10i64)))?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
        Filter: test.b > Int64(10)
          Aggregate: groupBy=[[test.b + test.a]], aggr=[[sum(test.a), test.b]]
            TableScan: test
        "
    )
}
1624
/// verifies that a filter on the aggregate's output column `test.b + test.a` is pushed below
/// the aggregate, with the column reference replaced by the underlying expression
#[test]
fn push_agg_need_replace_expr() -> Result<()> {
    let group_keys = vec![add(col("b"), col("a"))];
    let aggregates = vec![sum(col("a")), col("b")];
    // refers to the group-by output column by its generated name
    let predicate = col("test.b + test.a").gt(lit(10i64));

    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .aggregate(group_keys, aggregates)?
        .filter(predicate)?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
        Aggregate: groupBy=[[test.b + test.a]], aggr=[[sum(test.a), test.b]]
          TableScan: test, full_filters=[test.b + test.a > Int64(10)]
        "
    )
}
1639
/// verifies that a filter on an aggregated value is not moved below the aggregate
#[test]
fn filter_keep_agg() -> Result<()> {
    let group_keys = vec![col("a")];
    // the aggregate output is deliberately named `b`, like the input column
    let aggregates = vec![sum(col("b")).alias("b")];
    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .aggregate(group_keys, aggregates)?
        .filter(col("b").gt(lit(10i64)))?
        .build()?;

    // a filter on the aggregate result does not commute with the aggregation, so it stays above
    assert_optimized_plan_equal!(
        plan,
        @r"
        Filter: b > Int64(10)
          Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b) AS b]]
            TableScan: test
        "
    )
}
1657
/// verifies that when partitioning by 'a' and 'b', and filtering by 'b', 'b' is pushed
#[test]
fn filter_move_window() -> Result<()> {
    let table_scan = test_table_scan()?;

    // rank() OVER (PARTITION BY a, b ORDER BY c); builder errors are propagated with `?`
    // so a failure is reported as a test error instead of a panic
    let window = Expr::from(WindowFunction::new(
        WindowFunctionDefinition::WindowUDF(
            datafusion_functions_window::rank::rank_udwf(),
        ),
        vec![],
    ))
    .partition_by(vec![col("a"), col("b")])
    .order_by(vec![col("c").sort(true, true)])
    .build()?;

    let plan = LogicalPlanBuilder::from(table_scan)
        .window(vec![window])?
        .filter(col("b").gt(lit(10i64)))?
        .build()?;

    // `b` is a partition key, so the predicate moves below the window into the scan
    assert_optimized_plan_equal!(
        plan,
        @r"
        WindowAggr: windowExpr=[[rank() PARTITION BY [test.a, test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
          TableScan: test, full_filters=[test.b > Int64(10)]
        "
    )
}
1687
/// verifies that filters with unusual identifier names are pushed down through window functions
#[test]
fn filter_window_special_identifier() -> Result<()> {
    let schema = Schema::new(vec![
        Field::new("$a", DataType::UInt32, false),
        Field::new("$b", DataType::UInt32, false),
        Field::new("$c", DataType::UInt32, false),
    ]);
    let table_scan = table_scan(Some("test"), &schema, None)?.build()?;

    // rank() OVER (PARTITION BY $a, $b ORDER BY $c); builder errors are propagated with `?`
    // so a failure is reported as a test error instead of a panic
    let window = Expr::from(WindowFunction::new(
        WindowFunctionDefinition::WindowUDF(
            datafusion_functions_window::rank::rank_udwf(),
        ),
        vec![],
    ))
    .partition_by(vec![col("$a"), col("$b")])
    .order_by(vec![col("$c").sort(true, true)])
    .build()?;

    let plan = LogicalPlanBuilder::from(table_scan)
        .window(vec![window])?
        .filter(col("$b").gt(lit(10i64)))?
        .build()?;

    // `$b` is a partition key, so the predicate moves below the window into the scan
    assert_optimized_plan_equal!(
        plan,
        @r"
        WindowAggr: windowExpr=[[rank() PARTITION BY [test.$a, test.$b] ORDER BY [test.$c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
          TableScan: test, full_filters=[test.$b > Int64(10)]
        "
    )
}
1722
/// verifies that when partitioning by 'a' and 'b', and filtering by 'a' and 'b', both 'a' and
/// 'b' are pushed
#[test]
fn filter_move_complex_window() -> Result<()> {
    let table_scan = test_table_scan()?;

    // rank() OVER (PARTITION BY a, b ORDER BY c); builder errors are propagated with `?`
    // so a failure is reported as a test error instead of a panic
    let window = Expr::from(WindowFunction::new(
        WindowFunctionDefinition::WindowUDF(
            datafusion_functions_window::rank::rank_udwf(),
        ),
        vec![],
    ))
    .partition_by(vec![col("a"), col("b")])
    .order_by(vec![col("c").sort(true, true)])
    .build()?;

    let plan = LogicalPlanBuilder::from(table_scan)
        .window(vec![window])?
        .filter(and(col("a").gt(lit(10i64)), col("b").eq(lit(1i64))))?
        .build()?;

    // both conjuncts reference only partition keys, so both reach the scan
    assert_optimized_plan_equal!(
        plan,
        @r"
        WindowAggr: windowExpr=[[rank() PARTITION BY [test.a, test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
          TableScan: test, full_filters=[test.a > Int64(10), test.b = Int64(1)]
        "
    )
}
1753
/// verifies that when partitioning by 'a' and filtering by 'a' and 'b', only 'a' is pushed
#[test]
fn filter_move_partial_window() -> Result<()> {
    let table_scan = test_table_scan()?;

    // rank() OVER (PARTITION BY a ORDER BY c); builder errors are propagated with `?`
    // so a failure is reported as a test error instead of a panic
    let window = Expr::from(WindowFunction::new(
        WindowFunctionDefinition::WindowUDF(
            datafusion_functions_window::rank::rank_udwf(),
        ),
        vec![],
    ))
    .partition_by(vec![col("a")])
    .order_by(vec![col("c").sort(true, true)])
    .build()?;

    let plan = LogicalPlanBuilder::from(table_scan)
        .window(vec![window])?
        .filter(and(col("a").gt(lit(10i64)), col("b").eq(lit(1i64))))?
        .build()?;

    // the conjunct on partition key `a` reaches the scan; the one on `b` stays above the window
    assert_optimized_plan_equal!(
        plan,
        @r"
        Filter: test.b = Int64(1)
          WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
            TableScan: test, full_filters=[test.a > Int64(10)]
        "
    )
}
1784
/// verifies that filters on partition expressions are not pushed, as the single expression
/// column is not available to the user, unlike with aggregations
#[test]
fn filter_expression_keep_window() -> Result<()> {
    let table_scan = test_table_scan()?;

    // rank() OVER (PARTITION BY a + b ORDER BY c); builder errors are propagated with `?`
    // so a failure is reported as a test error instead of a panic
    let window = Expr::from(WindowFunction::new(
        WindowFunctionDefinition::WindowUDF(
            datafusion_functions_window::rank::rank_udwf(),
        ),
        vec![],
    ))
    .partition_by(vec![add(col("a"), col("b"))]) // PARTITION BY a + b
    .order_by(vec![col("c").sort(true, true)])
    .build()?;

    let plan = LogicalPlanBuilder::from(table_scan)
        .window(vec![window])?
        // unlike with aggregations, single partition column "test.a + test.b" is not available
        // to the plan, so we use multiple columns when filtering
        .filter(add(col("a"), col("b")).gt(lit(10i64)))?
        .build()?;

    // the plan is expected to be unchanged
    assert_optimized_plan_equal!(
        plan,
        @r"
        Filter: test.a + test.b > Int64(10)
          WindowAggr: windowExpr=[[rank() PARTITION BY [test.a + test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
            TableScan: test
        "
    )
}
1818
/// verifies that filters are not pushed on order by columns (that are not used in partitioning)
#[test]
fn filter_order_keep_window() -> Result<()> {
    let table_scan = test_table_scan()?;

    // rank() OVER (PARTITION BY a ORDER BY c); builder errors are propagated with `?`
    // so a failure is reported as a test error instead of a panic
    let window = Expr::from(WindowFunction::new(
        WindowFunctionDefinition::WindowUDF(
            datafusion_functions_window::rank::rank_udwf(),
        ),
        vec![],
    ))
    .partition_by(vec![col("a")])
    .order_by(vec![col("c").sort(true, true)])
    .build()?;

    let plan = LogicalPlanBuilder::from(table_scan)
        .window(vec![window])?
        .filter(col("c").gt(lit(10i64)))?
        .build()?;

    // `c` is only an ORDER BY column, so the plan is expected to be unchanged
    assert_optimized_plan_equal!(
        plan,
        @r"
        Filter: test.c > Int64(10)
          WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
            TableScan: test
        "
    )
}
1849
/// verifies that when we use multiple window functions with a common partition key, the filter
/// on that key is pushed
#[test]
fn filter_multiple_windows_common_partitions() -> Result<()> {
    let table_scan = test_table_scan()?;

    // rank() OVER (PARTITION BY a ORDER BY c); builder errors are propagated with `?`
    // so a failure is reported as a test error instead of a panic
    let window1 = Expr::from(WindowFunction::new(
        WindowFunctionDefinition::WindowUDF(
            datafusion_functions_window::rank::rank_udwf(),
        ),
        vec![],
    ))
    .partition_by(vec![col("a")])
    .order_by(vec![col("c").sort(true, true)])
    .build()?;

    // rank() OVER (PARTITION BY b, a ORDER BY c)
    let window2 = Expr::from(WindowFunction::new(
        WindowFunctionDefinition::WindowUDF(
            datafusion_functions_window::rank::rank_udwf(),
        ),
        vec![],
    ))
    .partition_by(vec![col("b"), col("a")])
    .order_by(vec![col("c").sort(true, true)])
    .build()?;

    let plan = LogicalPlanBuilder::from(table_scan)
        .window(vec![window1, window2])?
        .filter(col("a").gt(lit(10i64)))? // a appears in both window functions
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
        WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, rank() PARTITION BY [test.b, test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
          TableScan: test, full_filters=[test.a > Int64(10)]
        "
    )
}
1891
/// verifies that when we use multiple window functions with different partitions keys, the
/// filter cannot be pushed
#[test]
fn filter_multiple_windows_disjoint_partitions() -> Result<()> {
    let table_scan = test_table_scan()?;

    // rank() OVER (PARTITION BY a ORDER BY c); builder errors are propagated with `?`
    // so a failure is reported as a test error instead of a panic
    let window1 = Expr::from(WindowFunction::new(
        WindowFunctionDefinition::WindowUDF(
            datafusion_functions_window::rank::rank_udwf(),
        ),
        vec![],
    ))
    .partition_by(vec![col("a")])
    .order_by(vec![col("c").sort(true, true)])
    .build()?;

    // rank() OVER (PARTITION BY b, a ORDER BY c)
    let window2 = Expr::from(WindowFunction::new(
        WindowFunctionDefinition::WindowUDF(
            datafusion_functions_window::rank::rank_udwf(),
        ),
        vec![],
    ))
    .partition_by(vec![col("b"), col("a")])
    .order_by(vec![col("c").sort(true, true)])
    .build()?;

    let plan = LogicalPlanBuilder::from(table_scan)
        .window(vec![window1, window2])?
        .filter(col("b").gt(lit(10i64)))? // b only appears in one window function
        .build()?;

    // the plan is expected to be unchanged
    assert_optimized_plan_equal!(
        plan,
        @r"
        Filter: test.b > Int64(10)
          WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, rank() PARTITION BY [test.b, test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
            TableScan: test
        "
    )
}
1934
/// verifies that when a filter is pushed below a projection that aliases a column, the filter
/// expression is rewritten in terms of the underlying column
#[test]
fn alias() -> Result<()> {
    // `a` is exposed under the name `b`; the filter refers to that alias
    let projected = vec![col("a").alias("b"), col("c")];
    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .project(projected)?
        .filter(col("b").eq(lit(1i64)))?
        .build()?;

    // below the projection the predicate is expressed on `test.a`
    assert_optimized_plan_equal!(
        plan,
        @r"
        Projection: test.a AS b, test.c
          TableScan: test, full_filters=[test.a = Int64(1)]
        "
    )
}
1952
1953    fn add(left: Expr, right: Expr) -> Expr {
1954        Expr::BinaryExpr(BinaryExpr::new(
1955            Box::new(left),
1956            Operator::Plus,
1957            Box::new(right),
1958        ))
1959    }
1960
1961    fn multiply(left: Expr, right: Expr) -> Expr {
1962        Expr::BinaryExpr(BinaryExpr::new(
1963            Box::new(left),
1964            Operator::Multiply,
1965            Box::new(right),
1966        ))
1967    }
1968
/// verifies that when a filter is pushed below a projection that aliases a complex expression,
/// the filter is rewritten in terms of that expression
#[test]
fn complex_expression() -> Result<()> {
    // b := a * 2 + c
    let derived_b = add(multiply(col("a"), lit(2)), col("c")).alias("b");
    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![derived_b, col("c")])?
        .filter(col("b").eq(lit(1i64)))?
        .build()?;

    // for reference only: the plan before optimization
    assert_snapshot!(plan,
    @r"
    Filter: b = Int64(1)
      Projection: test.a * Int32(2) + test.c AS b, test.c
        TableScan: test
    ",
    );
    // after optimization the predicate sits on the scan, with `b` expanded
    assert_optimized_plan_equal!(
        plan,
        @r"
        Projection: test.a * Int32(2) + test.c AS b, test.c
          TableScan: test, full_filters=[test.a * Int32(2) + test.c = Int64(1)]
        "
    )
}
1998
/// verifies that when a filter is pushed below two projections, the filter expression is
/// rewritten through both of them
#[test]
fn complex_plan() -> Result<()> {
    // first projection: b := a * 2 + c
    let inner = vec![
        add(multiply(col("a"), lit(2)), col("c")).alias("b"),
        col("c"),
    ];
    // second projection reuses the name `a` for b * 3, just to make it difficult
    let outer = vec![multiply(col("b"), lit(3)).alias("a"), col("c")];

    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .project(inner)?
        .project(outer)?
        .filter(col("a").eq(lit(1i64)))?
        .build()?;

    // for reference only: the plan before optimization
    assert_snapshot!(plan,
    @r"
    Filter: a = Int64(1)
      Projection: b * Int32(3) AS a, test.c
        Projection: test.a * Int32(2) + test.c AS b, test.c
          TableScan: test
    ",
    );
    // after optimization the predicate sits on the scan, with both aliases expanded
    assert_optimized_plan_equal!(
        plan,
        @r"
        Projection: b * Int32(3) AS a, test.c
          Projection: test.a * Int32(2) + test.c AS b, test.c
            TableScan: test, full_filters=[(test.a * Int32(2) + test.c) * Int32(3) = Int64(1)]
        "
    )
}
2032
2033    #[derive(Debug, PartialEq, Eq, Hash)]
2034    struct NoopPlan {
2035        input: Vec<LogicalPlan>,
2036        schema: DFSchemaRef,
2037    }
2038
2039    // Manual implementation needed because of `schema` field. Comparison excludes this field.
2040    impl PartialOrd for NoopPlan {
2041        fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
2042            self.input
2043                .partial_cmp(&other.input)
2044                // TODO (https://github.com/apache/datafusion/issues/17477) avoid recomparing all fields
2045                .filter(|cmp| *cmp != Ordering::Equal || self == other)
2046        }
2047    }
2048
2049    impl UserDefinedLogicalNodeCore for NoopPlan {
2050        fn name(&self) -> &str {
2051            "NoopPlan"
2052        }
2053
2054        fn inputs(&self) -> Vec<&LogicalPlan> {
2055            self.input.iter().collect()
2056        }
2057
2058        fn schema(&self) -> &DFSchemaRef {
2059            &self.schema
2060        }
2061
2062        fn expressions(&self) -> Vec<Expr> {
2063            self.input
2064                .iter()
2065                .flat_map(|child| child.expressions())
2066                .collect()
2067        }
2068
2069        fn prevent_predicate_push_down_columns(&self) -> HashSet<String> {
2070            HashSet::from_iter(vec!["c".to_string()])
2071        }
2072
2073        fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
2074            write!(f, "NoopPlan")
2075        }
2076
2077        fn with_exprs_and_inputs(
2078            &self,
2079            _exprs: Vec<Expr>,
2080            inputs: Vec<LogicalPlan>,
2081        ) -> Result<Self> {
2082            Ok(Self {
2083                input: inputs,
2084                schema: Arc::clone(&self.schema),
2085            })
2086        }
2087
2088        fn supports_limit_pushdown(&self) -> bool {
2089            false // Disallow limit push-down by default
2090        }
2091    }
2092
/// verifies push-down through a user-defined node (`NoopPlan`) with one and with two inputs:
/// predicates on `a` are pushed into every input, predicates on `c` stay above the node
#[test]
fn user_defined_plan() -> Result<()> {
    let scan = test_table_scan()?;

    // wraps the given inputs in a `NoopPlan` extension node that reuses the scan's schema
    let noop_over = |inputs: Vec<LogicalPlan>| {
        LogicalPlan::Extension(Extension {
            node: Arc::new(NoopPlan {
                input: inputs,
                schema: Arc::clone(scan.schema()),
            }),
        })
    };

    // single input, predicate on `a` only
    let plan = LogicalPlanBuilder::from(noop_over(vec![scan.clone()]))
        .filter(col("a").eq(lit(1i64)))?
        .build()?;

    // Push filter below NoopPlan
    assert_optimized_plan_equal!(
        plan,
        @r"
        NoopPlan
          TableScan: test, full_filters=[test.a = Int64(1)]
        "
    )?;

    // single input, predicate on `a` and on `c`
    let plan = LogicalPlanBuilder::from(noop_over(vec![scan.clone()]))
        .filter(col("a").eq(lit(1i64)).and(col("c").eq(lit(2i64))))?
        .build()?;

    // Push only predicate on `a` below NoopPlan
    assert_optimized_plan_equal!(
        plan,
        @r"
        Filter: test.c = Int64(2)
          NoopPlan
            TableScan: test, full_filters=[test.a = Int64(1)]
        "
    )?;

    // two inputs, predicate on `a` only
    let plan = LogicalPlanBuilder::from(noop_over(vec![scan.clone(), scan.clone()]))
        .filter(col("a").eq(lit(1i64)))?
        .build()?;

    // Push filter below NoopPlan for each child branch
    assert_optimized_plan_equal!(
        plan,
        @r"
        NoopPlan
          TableScan: test, full_filters=[test.a = Int64(1)]
          TableScan: test, full_filters=[test.a = Int64(1)]
        "
    )?;

    // two inputs, predicate on `a` and on `c`
    let plan = LogicalPlanBuilder::from(noop_over(vec![scan.clone(), scan.clone()]))
        .filter(col("a").eq(lit(1i64)).and(col("c").eq(lit(2i64))))?
        .build()?;

    // Push only predicate on `a` below NoopPlan
    assert_optimized_plan_equal!(
        plan,
        @r"
        Filter: test.c = Int64(2)
          NoopPlan
            TableScan: test, full_filters=[test.a = Int64(1)]
            TableScan: test, full_filters=[test.a = Int64(1)]
        "
    )
}
2177
/// verifies that when two separate filters sit above an aggregation that lets only one of them
/// through, that one is pushed and the other stays in place
#[test]
fn multi_filter() -> Result<()> {
    // the filter on group key `b` can pass the aggregation, the one on `sum(test.c)` cannot
    let on_group_key = col("b").gt(lit(10i64));
    let on_aggregate = col("sum(test.c)").gt(lit(10i64));

    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a").alias("b"), col("c")])?
        .aggregate(vec![col("b")], vec![sum(col("c"))])?
        .filter(on_group_key)?
        .filter(on_aggregate)?
        .build()?;

    // for reference only: the plan before optimization
    assert_snapshot!(plan,
    @r"
    Filter: sum(test.c) > Int64(10)
      Filter: b > Int64(10)
        Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
          Projection: test.a AS b, test.c
            TableScan: test
    ",
    );
    // the group-key filter reaches the scan (rewritten to `test.a`), the other stays on top
    assert_optimized_plan_equal!(
        plan,
        @r"
        Filter: sum(test.c) > Int64(10)
          Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
            Projection: test.a AS b, test.c
              TableScan: test, full_filters=[test.a > Int64(10)]
        "
    )
}
2212
/// verifies that a single filter made of several conjuncts is split when it sits above an
/// aggregation: the conjunct on the group key is pushed, the others stay in place
#[test]
fn split_filter() -> Result<()> {
    // only the conjunct on group key `b` can pass the aggregation
    let sum_above_10 = col("sum(test.c)").gt(lit(10i64));
    let b_above_10 = col("b").gt(lit(10i64));
    let sum_below_20 = col("sum(test.c)").lt(lit(20i64));
    let predicate = and(sum_above_10, and(b_above_10, sum_below_20));

    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a").alias("b"), col("c")])?
        .aggregate(vec![col("b")], vec![sum(col("c"))])?
        .filter(predicate)?
        .build()?;

    // for reference only: the plan before optimization
    assert_snapshot!(plan,
    @r"
    Filter: sum(test.c) > Int64(10) AND b > Int64(10) AND sum(test.c) < Int64(20)
      Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
        Projection: test.a AS b, test.c
          TableScan: test
    ",
    );
    // the group-key conjunct reaches the scan (rewritten to `test.a`), the rest stays on top
    assert_optimized_plan_equal!(
        plan,
        @r"
        Filter: sum(test.c) > Int64(10) AND sum(test.c) < Int64(20)
          Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
            Projection: test.a AS b, test.c
              TableScan: test, full_filters=[test.a > Int64(10)]
        "
    )
}
2248
/// verifies that when two limits are in place, we jump neither
#[test]
fn double_limit() -> Result<()> {
    let predicate = col("a").eq(lit(1i64));
    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a"), col("b")])?
        .limit(0, Some(20))?
        .limit(0, Some(10))?
        .project(vec![col("a"), col("b")])?
        .filter(predicate)?
        .build()?;

    // the filter moves below the top projection but does not jump either of the limits
    assert_optimized_plan_equal!(
        plan,
        @r"
        Projection: test.a, test.b
          Filter: test.a = Int64(1)
            Limit: skip=0, fetch=10
              Limit: skip=0, fetch=20
                Projection: test.a, test.b
                  TableScan: test
        "
    )
}
2273
/// verifies that a filter above a union is pushed into each input of the union
#[test]
fn union_all() -> Result<()> {
    let left = test_table_scan()?;
    let right = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?).build()?;

    let plan = LogicalPlanBuilder::from(left)
        .union(right)?
        .filter(col("a").eq(lit(1i64)))?
        .build()?;

    // the predicate appears on both scans, qualified with each table's own name
    assert_optimized_plan_equal!(
        plan,
        @r"
        Union
          TableScan: test, full_filters=[test.a = Int64(1)]
          TableScan: test2, full_filters=[test2.a = Int64(1)]
        "
    )
}
2292
/// verifies that a filter above a union of aliased projections is pushed through the union,
/// the subquery alias and the projection of each branch
#[test]
fn union_all_on_projection() -> Result<()> {
    // both union branches are the same aliased projection exposing `a` as `b`
    let branch = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a").alias("b")])?
        .alias("test2")?;
    let right = branch.clone().build()?;

    let plan = branch
        .union(right)?
        .filter(col("b").eq(lit(1i64)))?
        .build()?;

    // the predicate reaches both scans, rewritten to the underlying column `test.a`
    assert_optimized_plan_equal!(
        plan,
        @r"
        Union
          SubqueryAlias: test2
            Projection: test.a AS b
              TableScan: test, full_filters=[test.a = Int64(1)]
          SubqueryAlias: test2
            Projection: test.a AS b
              TableScan: test, full_filters=[test.a = Int64(1)]
        "
    )
}
2320
/// verifies that a conjunction over two differently-shaped inputs is split and each conjunct is
/// pushed to the side it refers to
///
/// NOTE: despite the test name, the two inputs are combined with a cross join, not a union.
#[test]
fn test_union_different_schema() -> Result<()> {
    let left = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;

    // right side: table `test1` with columns d, e, f
    let right_fields: Vec<Field> = ["d", "e", "f"]
        .into_iter()
        .map(|name| Field::new(name, DataType::UInt32, false))
        .collect();
    let right_schema = Schema::new(right_fields);
    let right = table_scan(Some("test1"), &right_schema, None)?
        .project(vec![col("d"), col("e"), col("f")])?
        .build()?;

    // one conjunct per side
    let on_left = col("test.a").eq(lit(1));
    let on_right = col("test1.d").gt(lit(2));

    let plan = LogicalPlanBuilder::from(left)
        .cross_join(right)?
        .project(vec![col("test.a"), col("test1.d")])?
        .filter(and(on_left, on_right))?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
        Projection: test.a, test1.d
          Cross Join:
            Projection: test.a, test.b, test.c
              TableScan: test, full_filters=[test.a = Int32(1)]
            Projection: test1.d, test1.e, test1.f
              TableScan: test1, full_filters=[test1.d > Int32(2)]
        "
    )
}
2354
/// verifies that columns sharing a name but carrying different qualifiers (`test.a` and
/// `test1.a`) are told apart: each conjunct is pushed to the scan of its own table
#[test]
fn test_project_same_name_different_qualifier() -> Result<()> {
    let left = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;
    let right = LogicalPlanBuilder::from(test_table_scan_with_name("test1")?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;

    // one conjunct per side, both on a column named `a`
    let on_left = col("test.a").eq(lit(1));
    let on_right = col("test1.a").gt(lit(2));

    let plan = LogicalPlanBuilder::from(left)
        .cross_join(right)?
        .project(vec![col("test.a"), col("test1.a")])?
        .filter(and(on_left, on_right))?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
        Projection: test.a, test1.a
          Cross Join:
            Projection: test.a, test.b, test.c
              TableScan: test, full_filters=[test.a = Int32(1)]
            Projection: test1.a, test1.b, test1.c
              TableScan: test1, full_filters=[test1.a > Int32(2)]
        "
    )
}
2384
2385    /// verifies that filters with the same columns are correctly placed
2386    #[test]
2387    fn filter_2_breaks_limits() -> Result<()> {
2388        let table_scan = test_table_scan()?;
2389        let plan = LogicalPlanBuilder::from(table_scan)
2390            .project(vec![col("a")])?
2391            .filter(col("a").lt_eq(lit(1i64)))?
2392            .limit(0, Some(1))?
2393            .project(vec![col("a")])?
2394            .filter(col("a").gt_eq(lit(1i64)))?
2395            .build()?;
2396        // Should be able to move both filters below the projections
2397
2398        // not part of the test
2399        assert_snapshot!(plan,
2400        @r"
2401        Filter: test.a >= Int64(1)
2402          Projection: test.a
2403            Limit: skip=0, fetch=1
2404              Filter: test.a <= Int64(1)
2405                Projection: test.a
2406                  TableScan: test
2407        ",
2408        );
2409        assert_optimized_plan_equal!(
2410            plan,
2411            @r"
2412        Projection: test.a
2413          Filter: test.a >= Int64(1)
2414            Limit: skip=0, fetch=1
2415              Projection: test.a
2416                TableScan: test, full_filters=[test.a <= Int64(1)]
2417        "
2418        )
2419    }
2420
2421    /// verifies that filters to be placed on the same depth are ANDed
2422    #[test]
2423    fn two_filters_on_same_depth() -> Result<()> {
2424        let table_scan = test_table_scan()?;
2425        let plan = LogicalPlanBuilder::from(table_scan)
2426            .limit(0, Some(1))?
2427            .filter(col("a").lt_eq(lit(1i64)))?
2428            .filter(col("a").gt_eq(lit(1i64)))?
2429            .project(vec![col("a")])?
2430            .build()?;
2431
2432        // not part of the test
2433        assert_snapshot!(plan,
2434        @r"
2435        Projection: test.a
2436          Filter: test.a >= Int64(1)
2437            Filter: test.a <= Int64(1)
2438              Limit: skip=0, fetch=1
2439                TableScan: test
2440        ",
2441        );
2442        assert_optimized_plan_equal!(
2443            plan,
2444            @r"
2445        Projection: test.a
2446          Filter: test.a >= Int64(1) AND test.a <= Int64(1)
2447            Limit: skip=0, fetch=1
2448              TableScan: test
2449        "
2450        )
2451    }
2452
2453    /// verifies that filters on a plan with user nodes are not lost
2454    /// (ARROW-10547)
2455    #[test]
2456    fn filters_user_defined_node() -> Result<()> {
2457        let table_scan = test_table_scan()?;
2458        let plan = LogicalPlanBuilder::from(table_scan)
2459            .filter(col("a").lt_eq(lit(1i64)))?
2460            .build()?;
2461
2462        let plan = user_defined::new(plan);
2463
2464        // not part of the test
2465        assert_snapshot!(plan,
2466        @r"
2467        TestUserDefined
2468          Filter: test.a <= Int64(1)
2469            TableScan: test
2470        ",
2471        );
2472        assert_optimized_plan_equal!(
2473            plan,
2474            @r"
2475        TestUserDefined
2476          TableScan: test, full_filters=[test.a <= Int64(1)]
2477        "
2478        )
2479    }
2480
2481    /// post-on-join predicates on a column common to both sides is pushed to both sides
2482    #[test]
2483    fn filter_on_join_on_common_independent() -> Result<()> {
2484        let table_scan = test_table_scan()?;
2485        let left = LogicalPlanBuilder::from(table_scan).build()?;
2486        let right_table_scan = test_table_scan_with_name("test2")?;
2487        let right = LogicalPlanBuilder::from(right_table_scan)
2488            .project(vec![col("a")])?
2489            .build()?;
2490        let plan = LogicalPlanBuilder::from(left)
2491            .join(
2492                right,
2493                JoinType::Inner,
2494                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2495                None,
2496            )?
2497            .filter(col("test.a").lt_eq(lit(1i64)))?
2498            .build()?;
2499
2500        // not part of the test, just good to know:
2501        assert_snapshot!(plan,
2502        @r"
2503        Filter: test.a <= Int64(1)
2504          Inner Join: test.a = test2.a
2505            TableScan: test
2506            Projection: test2.a
2507              TableScan: test2
2508        ",
2509        );
2510        // filter sent to side before the join
2511        assert_optimized_plan_equal!(
2512            plan,
2513            @r"
2514        Inner Join: test.a = test2.a
2515          TableScan: test, full_filters=[test.a <= Int64(1)]
2516          Projection: test2.a
2517            TableScan: test2, full_filters=[test2.a <= Int64(1)]
2518        "
2519        )
2520    }
2521
2522    /// post-using-join predicates on a column common to both sides is pushed to both sides
2523    #[test]
2524    fn filter_using_join_on_common_independent() -> Result<()> {
2525        let table_scan = test_table_scan()?;
2526        let left = LogicalPlanBuilder::from(table_scan).build()?;
2527        let right_table_scan = test_table_scan_with_name("test2")?;
2528        let right = LogicalPlanBuilder::from(right_table_scan)
2529            .project(vec![col("a")])?
2530            .build()?;
2531        let plan = LogicalPlanBuilder::from(left)
2532            .join_using(
2533                right,
2534                JoinType::Inner,
2535                vec![Column::from_name("a".to_string())],
2536            )?
2537            .filter(col("a").lt_eq(lit(1i64)))?
2538            .build()?;
2539
2540        // not part of the test, just good to know:
2541        assert_snapshot!(plan,
2542        @r"
2543        Filter: test.a <= Int64(1)
2544          Inner Join: Using test.a = test2.a
2545            TableScan: test
2546            Projection: test2.a
2547              TableScan: test2
2548        ",
2549        );
2550        // filter sent to side before the join
2551        assert_optimized_plan_equal!(
2552            plan,
2553            @r"
2554        Inner Join: Using test.a = test2.a
2555          TableScan: test, full_filters=[test.a <= Int64(1)]
2556          Projection: test2.a
2557            TableScan: test2, full_filters=[test2.a <= Int64(1)]
2558        "
2559        )
2560    }
2561
2562    /// post-join predicates with columns from both sides are converted to join filters
2563    #[test]
2564    fn filter_join_on_common_dependent() -> Result<()> {
2565        let table_scan = test_table_scan()?;
2566        let left = LogicalPlanBuilder::from(table_scan)
2567            .project(vec![col("a"), col("c")])?
2568            .build()?;
2569        let right_table_scan = test_table_scan_with_name("test2")?;
2570        let right = LogicalPlanBuilder::from(right_table_scan)
2571            .project(vec![col("a"), col("b")])?
2572            .build()?;
2573        let plan = LogicalPlanBuilder::from(left)
2574            .join(
2575                right,
2576                JoinType::Inner,
2577                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2578                None,
2579            )?
2580            .filter(col("c").lt_eq(col("b")))?
2581            .build()?;
2582
2583        // not part of the test, just good to know:
2584        assert_snapshot!(plan,
2585        @r"
2586        Filter: test.c <= test2.b
2587          Inner Join: test.a = test2.a
2588            Projection: test.a, test.c
2589              TableScan: test
2590            Projection: test2.a, test2.b
2591              TableScan: test2
2592        ",
2593        );
2594        // Filter is converted to Join Filter
2595        assert_optimized_plan_equal!(
2596            plan,
2597            @r"
2598        Inner Join: test.a = test2.a Filter: test.c <= test2.b
2599          Projection: test.a, test.c
2600            TableScan: test
2601          Projection: test2.a, test2.b
2602            TableScan: test2
2603        "
2604        )
2605    }
2606
2607    /// post-join predicates with columns from one side of a join are pushed only to that side
2608    #[test]
2609    fn filter_join_on_one_side() -> Result<()> {
2610        let table_scan = test_table_scan()?;
2611        let left = LogicalPlanBuilder::from(table_scan)
2612            .project(vec![col("a"), col("b")])?
2613            .build()?;
2614        let table_scan_right = test_table_scan_with_name("test2")?;
2615        let right = LogicalPlanBuilder::from(table_scan_right)
2616            .project(vec![col("a"), col("c")])?
2617            .build()?;
2618
2619        let plan = LogicalPlanBuilder::from(left)
2620            .join(
2621                right,
2622                JoinType::Inner,
2623                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2624                None,
2625            )?
2626            .filter(col("b").lt_eq(lit(1i64)))?
2627            .build()?;
2628
2629        // not part of the test, just good to know:
2630        assert_snapshot!(plan,
2631        @r"
2632        Filter: test.b <= Int64(1)
2633          Inner Join: test.a = test2.a
2634            Projection: test.a, test.b
2635              TableScan: test
2636            Projection: test2.a, test2.c
2637              TableScan: test2
2638        ",
2639        );
2640        assert_optimized_plan_equal!(
2641            plan,
2642            @r"
2643        Inner Join: test.a = test2.a
2644          Projection: test.a, test.b
2645            TableScan: test, full_filters=[test.b <= Int64(1)]
2646          Projection: test2.a, test2.c
2647            TableScan: test2
2648        "
2649        )
2650    }
2651
2652    /// post-join predicates on the right side of a left join are not duplicated
2653    /// TODO: In this case we can sometimes convert the join to an INNER join
2654    #[test]
2655    fn filter_using_left_join() -> Result<()> {
2656        let table_scan = test_table_scan()?;
2657        let left = LogicalPlanBuilder::from(table_scan).build()?;
2658        let right_table_scan = test_table_scan_with_name("test2")?;
2659        let right = LogicalPlanBuilder::from(right_table_scan)
2660            .project(vec![col("a")])?
2661            .build()?;
2662        let plan = LogicalPlanBuilder::from(left)
2663            .join_using(
2664                right,
2665                JoinType::Left,
2666                vec![Column::from_name("a".to_string())],
2667            )?
2668            .filter(col("test2.a").lt_eq(lit(1i64)))?
2669            .build()?;
2670
2671        // not part of the test, just good to know:
2672        assert_snapshot!(plan,
2673        @r"
2674        Filter: test2.a <= Int64(1)
2675          Left Join: Using test.a = test2.a
2676            TableScan: test
2677            Projection: test2.a
2678              TableScan: test2
2679        ",
2680        );
2681        // filter not duplicated nor pushed down - i.e. noop
2682        assert_optimized_plan_equal!(
2683            plan,
2684            @r"
2685        Filter: test2.a <= Int64(1)
2686          Left Join: Using test.a = test2.a
2687            TableScan: test, full_filters=[test.a <= Int64(1)]
2688            Projection: test2.a
2689              TableScan: test2
2690        "
2691        )
2692    }
2693
2694    /// post-join predicates on the left side of a right join are not duplicated
2695    #[test]
2696    fn filter_using_right_join() -> Result<()> {
2697        let table_scan = test_table_scan()?;
2698        let left = LogicalPlanBuilder::from(table_scan).build()?;
2699        let right_table_scan = test_table_scan_with_name("test2")?;
2700        let right = LogicalPlanBuilder::from(right_table_scan)
2701            .project(vec![col("a")])?
2702            .build()?;
2703        let plan = LogicalPlanBuilder::from(left)
2704            .join_using(
2705                right,
2706                JoinType::Right,
2707                vec![Column::from_name("a".to_string())],
2708            )?
2709            .filter(col("test.a").lt_eq(lit(1i64)))?
2710            .build()?;
2711
2712        // not part of the test, just good to know:
2713        assert_snapshot!(plan,
2714        @r"
2715        Filter: test.a <= Int64(1)
2716          Right Join: Using test.a = test2.a
2717            TableScan: test
2718            Projection: test2.a
2719              TableScan: test2
2720        ",
2721        );
2722        // filter not duplicated nor pushed down - i.e. noop
2723        assert_optimized_plan_equal!(
2724            plan,
2725            @r"
2726        Filter: test.a <= Int64(1)
2727          Right Join: Using test.a = test2.a
2728            TableScan: test
2729            Projection: test2.a
2730              TableScan: test2, full_filters=[test2.a <= Int64(1)]
2731        "
2732        )
2733    }
2734
2735    /// post-left-join predicate on a column common to both sides is pushed to both sides
2736    #[test]
2737    fn filter_using_left_join_on_common() -> Result<()> {
2738        let table_scan = test_table_scan()?;
2739        let left = LogicalPlanBuilder::from(table_scan).build()?;
2740        let right_table_scan = test_table_scan_with_name("test2")?;
2741        let right = LogicalPlanBuilder::from(right_table_scan)
2742            .project(vec![col("a")])?
2743            .build()?;
2744        let plan = LogicalPlanBuilder::from(left)
2745            .join_using(
2746                right,
2747                JoinType::Left,
2748                vec![Column::from_name("a".to_string())],
2749            )?
2750            .filter(col("a").lt_eq(lit(1i64)))?
2751            .build()?;
2752
2753        // not part of the test, just good to know:
2754        assert_snapshot!(plan,
2755        @r"
2756        Filter: test.a <= Int64(1)
2757          Left Join: Using test.a = test2.a
2758            TableScan: test
2759            Projection: test2.a
2760              TableScan: test2
2761        ",
2762        );
2763        // filter sent to left side of the join and to the right
2764        assert_optimized_plan_equal!(
2765            plan,
2766            @r"
2767        Left Join: Using test.a = test2.a
2768          TableScan: test, full_filters=[test.a <= Int64(1)]
2769          Projection: test2.a
2770            TableScan: test2, full_filters=[test2.a <= Int64(1)]
2771        "
2772        )
2773    }
2774
2775    /// post-right-join predicate on a column common to both sides is pushed to both sides
2776    #[test]
2777    fn filter_using_right_join_on_common() -> Result<()> {
2778        let table_scan = test_table_scan()?;
2779        let left = LogicalPlanBuilder::from(table_scan).build()?;
2780        let right_table_scan = test_table_scan_with_name("test2")?;
2781        let right = LogicalPlanBuilder::from(right_table_scan)
2782            .project(vec![col("a")])?
2783            .build()?;
2784        let plan = LogicalPlanBuilder::from(left)
2785            .join_using(
2786                right,
2787                JoinType::Right,
2788                vec![Column::from_name("a".to_string())],
2789            )?
2790            .filter(col("test2.a").lt_eq(lit(1i64)))?
2791            .build()?;
2792
2793        // not part of the test, just good to know:
2794        assert_snapshot!(plan,
2795        @r"
2796        Filter: test2.a <= Int64(1)
2797          Right Join: Using test.a = test2.a
2798            TableScan: test
2799            Projection: test2.a
2800              TableScan: test2
2801        ",
2802        );
2803        // filter sent to right side of join, sent to the left as well
2804        assert_optimized_plan_equal!(
2805            plan,
2806            @r"
2807        Right Join: Using test.a = test2.a
2808          TableScan: test, full_filters=[test.a <= Int64(1)]
2809          Projection: test2.a
2810            TableScan: test2, full_filters=[test2.a <= Int64(1)]
2811        "
2812        )
2813    }
2814
2815    /// single table predicate parts of ON condition should be pushed to both inputs
2816    #[test]
2817    fn join_on_with_filter() -> Result<()> {
2818        let table_scan = test_table_scan()?;
2819        let left = LogicalPlanBuilder::from(table_scan)
2820            .project(vec![col("a"), col("b"), col("c")])?
2821            .build()?;
2822        let right_table_scan = test_table_scan_with_name("test2")?;
2823        let right = LogicalPlanBuilder::from(right_table_scan)
2824            .project(vec![col("a"), col("b"), col("c")])?
2825            .build()?;
2826        let filter = col("test.c")
2827            .gt(lit(1u32))
2828            .and(col("test.b").lt(col("test2.b")))
2829            .and(col("test2.c").gt(lit(4u32)));
2830        let plan = LogicalPlanBuilder::from(left)
2831            .join(
2832                right,
2833                JoinType::Inner,
2834                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2835                Some(filter),
2836            )?
2837            .build()?;
2838
2839        // not part of the test, just good to know:
2840        assert_snapshot!(plan,
2841        @r"
2842        Inner Join: test.a = test2.a Filter: test.c > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
2843          Projection: test.a, test.b, test.c
2844            TableScan: test
2845          Projection: test2.a, test2.b, test2.c
2846            TableScan: test2
2847        ",
2848        );
2849        assert_optimized_plan_equal!(
2850            plan,
2851            @r"
2852        Inner Join: test.a = test2.a Filter: test.b < test2.b
2853          Projection: test.a, test.b, test.c
2854            TableScan: test, full_filters=[test.c > UInt32(1)]
2855          Projection: test2.a, test2.b, test2.c
2856            TableScan: test2, full_filters=[test2.c > UInt32(4)]
2857        "
2858        )
2859    }
2860
2861    /// join filter should be completely removed after pushdown
2862    #[test]
2863    fn join_filter_removed() -> Result<()> {
2864        let table_scan = test_table_scan()?;
2865        let left = LogicalPlanBuilder::from(table_scan)
2866            .project(vec![col("a"), col("b"), col("c")])?
2867            .build()?;
2868        let right_table_scan = test_table_scan_with_name("test2")?;
2869        let right = LogicalPlanBuilder::from(right_table_scan)
2870            .project(vec![col("a"), col("b"), col("c")])?
2871            .build()?;
2872        let filter = col("test.b")
2873            .gt(lit(1u32))
2874            .and(col("test2.c").gt(lit(4u32)));
2875        let plan = LogicalPlanBuilder::from(left)
2876            .join(
2877                right,
2878                JoinType::Inner,
2879                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2880                Some(filter),
2881            )?
2882            .build()?;
2883
2884        // not part of the test, just good to know:
2885        assert_snapshot!(plan,
2886        @r"
2887        Inner Join: test.a = test2.a Filter: test.b > UInt32(1) AND test2.c > UInt32(4)
2888          Projection: test.a, test.b, test.c
2889            TableScan: test
2890          Projection: test2.a, test2.b, test2.c
2891            TableScan: test2
2892        ",
2893        );
2894        assert_optimized_plan_equal!(
2895            plan,
2896            @r"
2897        Inner Join: test.a = test2.a
2898          Projection: test.a, test.b, test.c
2899            TableScan: test, full_filters=[test.b > UInt32(1)]
2900          Projection: test2.a, test2.b, test2.c
2901            TableScan: test2, full_filters=[test2.c > UInt32(4)]
2902        "
2903        )
2904    }
2905
2906    /// predicate on join key in filter expression should be pushed down to both inputs
2907    #[test]
2908    fn join_filter_on_common() -> Result<()> {
2909        let table_scan = test_table_scan()?;
2910        let left = LogicalPlanBuilder::from(table_scan)
2911            .project(vec![col("a")])?
2912            .build()?;
2913        let right_table_scan = test_table_scan_with_name("test2")?;
2914        let right = LogicalPlanBuilder::from(right_table_scan)
2915            .project(vec![col("b")])?
2916            .build()?;
2917        let filter = col("test.a").gt(lit(1u32));
2918        let plan = LogicalPlanBuilder::from(left)
2919            .join(
2920                right,
2921                JoinType::Inner,
2922                (vec![Column::from_name("a")], vec![Column::from_name("b")]),
2923                Some(filter),
2924            )?
2925            .build()?;
2926
2927        // not part of the test, just good to know:
2928        assert_snapshot!(plan,
2929        @r"
2930        Inner Join: test.a = test2.b Filter: test.a > UInt32(1)
2931          Projection: test.a
2932            TableScan: test
2933          Projection: test2.b
2934            TableScan: test2
2935        ",
2936        );
2937        assert_optimized_plan_equal!(
2938            plan,
2939            @r"
2940        Inner Join: test.a = test2.b
2941          Projection: test.a
2942            TableScan: test, full_filters=[test.a > UInt32(1)]
2943          Projection: test2.b
2944            TableScan: test2, full_filters=[test2.b > UInt32(1)]
2945        "
2946        )
2947    }
2948
2949    /// single table predicate parts of ON condition should be pushed to right input
2950    #[test]
2951    fn left_join_on_with_filter() -> Result<()> {
2952        let table_scan = test_table_scan()?;
2953        let left = LogicalPlanBuilder::from(table_scan)
2954            .project(vec![col("a"), col("b"), col("c")])?
2955            .build()?;
2956        let right_table_scan = test_table_scan_with_name("test2")?;
2957        let right = LogicalPlanBuilder::from(right_table_scan)
2958            .project(vec![col("a"), col("b"), col("c")])?
2959            .build()?;
2960        let filter = col("test.a")
2961            .gt(lit(1u32))
2962            .and(col("test.b").lt(col("test2.b")))
2963            .and(col("test2.c").gt(lit(4u32)));
2964        let plan = LogicalPlanBuilder::from(left)
2965            .join(
2966                right,
2967                JoinType::Left,
2968                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2969                Some(filter),
2970            )?
2971            .build()?;
2972
2973        // not part of the test, just good to know:
2974        assert_snapshot!(plan,
2975        @r"
2976        Left Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
2977          Projection: test.a, test.b, test.c
2978            TableScan: test
2979          Projection: test2.a, test2.b, test2.c
2980            TableScan: test2
2981        ",
2982        );
2983        assert_optimized_plan_equal!(
2984            plan,
2985            @r"
2986        Left Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b
2987          Projection: test.a, test.b, test.c
2988            TableScan: test
2989          Projection: test2.a, test2.b, test2.c
2990            TableScan: test2, full_filters=[test2.a > UInt32(1), test2.c > UInt32(4)]
2991        "
2992        )
2993    }
2994
2995    /// single table predicate parts of ON condition should be pushed to left input
2996    #[test]
2997    fn right_join_on_with_filter() -> Result<()> {
2998        let table_scan = test_table_scan()?;
2999        let left = LogicalPlanBuilder::from(table_scan)
3000            .project(vec![col("a"), col("b"), col("c")])?
3001            .build()?;
3002        let right_table_scan = test_table_scan_with_name("test2")?;
3003        let right = LogicalPlanBuilder::from(right_table_scan)
3004            .project(vec![col("a"), col("b"), col("c")])?
3005            .build()?;
3006        let filter = col("test.a")
3007            .gt(lit(1u32))
3008            .and(col("test.b").lt(col("test2.b")))
3009            .and(col("test2.c").gt(lit(4u32)));
3010        let plan = LogicalPlanBuilder::from(left)
3011            .join(
3012                right,
3013                JoinType::Right,
3014                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
3015                Some(filter),
3016            )?
3017            .build()?;
3018
3019        // not part of the test, just good to know:
3020        assert_snapshot!(plan,
3021        @r"
3022        Right Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
3023          Projection: test.a, test.b, test.c
3024            TableScan: test
3025          Projection: test2.a, test2.b, test2.c
3026            TableScan: test2
3027        ",
3028        );
3029        assert_optimized_plan_equal!(
3030            plan,
3031            @r"
3032        Right Join: test.a = test2.a Filter: test.b < test2.b AND test2.c > UInt32(4)
3033          Projection: test.a, test.b, test.c
3034            TableScan: test, full_filters=[test.a > UInt32(1)]
3035          Projection: test2.a, test2.b, test2.c
3036            TableScan: test2
3037        "
3038        )
3039    }
3040
3041    /// single table predicate parts of ON condition should not be pushed
3042    #[test]
3043    fn full_join_on_with_filter() -> Result<()> {
3044        let table_scan = test_table_scan()?;
3045        let left = LogicalPlanBuilder::from(table_scan)
3046            .project(vec![col("a"), col("b"), col("c")])?
3047            .build()?;
3048        let right_table_scan = test_table_scan_with_name("test2")?;
3049        let right = LogicalPlanBuilder::from(right_table_scan)
3050            .project(vec![col("a"), col("b"), col("c")])?
3051            .build()?;
3052        let filter = col("test.a")
3053            .gt(lit(1u32))
3054            .and(col("test.b").lt(col("test2.b")))
3055            .and(col("test2.c").gt(lit(4u32)));
3056        let plan = LogicalPlanBuilder::from(left)
3057            .join(
3058                right,
3059                JoinType::Full,
3060                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
3061                Some(filter),
3062            )?
3063            .build()?;
3064
3065        // not part of the test, just good to know:
3066        assert_snapshot!(plan,
3067        @r"
3068        Full Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
3069          Projection: test.a, test.b, test.c
3070            TableScan: test
3071          Projection: test2.a, test2.b, test2.c
3072            TableScan: test2
3073        ",
3074        );
3075        assert_optimized_plan_equal!(
3076            plan,
3077            @r"
3078        Full Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
3079          Projection: test.a, test.b, test.c
3080            TableScan: test
3081          Projection: test2.a, test2.b, test2.c
3082            TableScan: test2
3083        "
3084        )
3085    }
3086
3087    struct PushDownProvider {
3088        pub filter_support: TableProviderFilterPushDown,
3089    }
3090
3091    #[async_trait]
3092    impl TableSource for PushDownProvider {
3093        fn schema(&self) -> SchemaRef {
3094            Arc::new(Schema::new(vec![
3095                Field::new("a", DataType::Int32, true),
3096                Field::new("b", DataType::Int32, true),
3097            ]))
3098        }
3099
3100        fn table_type(&self) -> TableType {
3101            TableType::Base
3102        }
3103
3104        fn supports_filters_pushdown(
3105            &self,
3106            filters: &[&Expr],
3107        ) -> Result<Vec<TableProviderFilterPushDown>> {
3108            Ok((0..filters.len())
3109                .map(|_| self.filter_support.clone())
3110                .collect())
3111        }
3112
3113        fn as_any(&self) -> &dyn Any {
3114            self
3115        }
3116    }
3117
3118    fn table_scan_with_pushdown_provider_builder(
3119        filter_support: TableProviderFilterPushDown,
3120        filters: Vec<Expr>,
3121        projection: Option<Vec<usize>>,
3122    ) -> Result<LogicalPlanBuilder> {
3123        let test_provider = PushDownProvider { filter_support };
3124
3125        let table_scan = LogicalPlan::TableScan(TableScan {
3126            table_name: "test".into(),
3127            filters,
3128            projected_schema: Arc::new(DFSchema::try_from(test_provider.schema())?),
3129            projection,
3130            source: Arc::new(test_provider),
3131            fetch: None,
3132        });
3133
3134        Ok(LogicalPlanBuilder::from(table_scan))
3135    }
3136
3137    fn table_scan_with_pushdown_provider(
3138        filter_support: TableProviderFilterPushDown,
3139    ) -> Result<LogicalPlan> {
3140        table_scan_with_pushdown_provider_builder(filter_support, vec![], None)?
3141            .filter(col("a").eq(lit(1i64)))?
3142            .build()
3143    }
3144
3145    #[test]
3146    fn filter_with_table_provider_exact() -> Result<()> {
3147        let plan = table_scan_with_pushdown_provider(TableProviderFilterPushDown::Exact)?;
3148
3149        assert_optimized_plan_equal!(
3150            plan,
3151            @"TableScan: test, full_filters=[a = Int64(1)]"
3152        )
3153    }
3154
3155    #[test]
3156    fn filter_with_table_provider_inexact() -> Result<()> {
3157        let plan =
3158            table_scan_with_pushdown_provider(TableProviderFilterPushDown::Inexact)?;
3159
3160        assert_optimized_plan_equal!(
3161            plan,
3162            @r"
3163        Filter: a = Int64(1)
3164          TableScan: test, partial_filters=[a = Int64(1)]
3165        "
3166        )
3167    }
3168
3169    #[test]
3170    fn filter_with_table_provider_multiple_invocations() -> Result<()> {
3171        let plan =
3172            table_scan_with_pushdown_provider(TableProviderFilterPushDown::Inexact)?;
3173
3174        let optimized_plan = PushDownFilter::new()
3175            .rewrite(plan, &OptimizerContext::new())
3176            .expect("failed to optimize plan")
3177            .data;
3178
3179        // Optimizing the same plan multiple times should produce the same plan
3180        // each time.
3181        assert_optimized_plan_equal!(
3182            optimized_plan,
3183            @r"
3184        Filter: a = Int64(1)
3185          TableScan: test, partial_filters=[a = Int64(1)]
3186        "
3187        )
3188    }
3189
3190    #[test]
3191    fn filter_with_table_provider_unsupported() -> Result<()> {
3192        let plan =
3193            table_scan_with_pushdown_provider(TableProviderFilterPushDown::Unsupported)?;
3194
3195        assert_optimized_plan_equal!(
3196            plan,
3197            @r"
3198        Filter: a = Int64(1)
3199          TableScan: test
3200        "
3201        )
3202    }
3203
3204    #[test]
3205    fn multi_combined_filter() -> Result<()> {
3206        let plan = table_scan_with_pushdown_provider_builder(
3207            TableProviderFilterPushDown::Inexact,
3208            vec![col("a").eq(lit(10i64)), col("b").gt(lit(11i64))],
3209            Some(vec![0]),
3210        )?
3211        .filter(and(col("a").eq(lit(10i64)), col("b").gt(lit(11i64))))?
3212        .project(vec![col("a"), col("b")])?
3213        .build()?;
3214
3215        assert_optimized_plan_equal!(
3216            plan,
3217            @r"
3218        Projection: a, b
3219          Filter: a = Int64(10) AND b > Int64(11)
3220            TableScan: test projection=[a], partial_filters=[a = Int64(10), b > Int64(11)]
3221        "
3222        )
3223    }
3224
3225    #[test]
3226    fn multi_combined_filter_exact() -> Result<()> {
3227        let plan = table_scan_with_pushdown_provider_builder(
3228            TableProviderFilterPushDown::Exact,
3229            vec![],
3230            Some(vec![0]),
3231        )?
3232        .filter(and(col("a").eq(lit(10i64)), col("b").gt(lit(11i64))))?
3233        .project(vec![col("a"), col("b")])?
3234        .build()?;
3235
3236        assert_optimized_plan_equal!(
3237            plan,
3238            @r"
3239        Projection: a, b
3240          TableScan: test projection=[a], full_filters=[a = Int64(10), b > Int64(11)]
3241        "
3242        )
3243    }
3244
3245    #[test]
3246    fn test_filter_with_alias() -> Result<()> {
3247        // in table scan the true col name is 'test.a',
3248        // but we rename it as 'b', and use col 'b' in filter
3249        // we need rewrite filter col before push down.
3250        let table_scan = test_table_scan()?;
3251        let plan = LogicalPlanBuilder::from(table_scan)
3252            .project(vec![col("a").alias("b"), col("c")])?
3253            .filter(and(col("b").gt(lit(10i64)), col("c").gt(lit(10i64))))?
3254            .build()?;
3255
3256        // filter on col b
3257        assert_snapshot!(plan,
3258        @r"
3259        Filter: b > Int64(10) AND test.c > Int64(10)
3260          Projection: test.a AS b, test.c
3261            TableScan: test
3262        ",
3263        );
3264        // rewrite filter col b to test.a
3265        assert_optimized_plan_equal!(
3266            plan,
3267            @r"
3268        Projection: test.a AS b, test.c
3269          TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]
3270        "
3271        )
3272    }
3273
3274    #[test]
3275    fn test_filter_with_alias_2() -> Result<()> {
3276        // in table scan the true col name is 'test.a',
3277        // but we rename it as 'b', and use col 'b' in filter
3278        // we need rewrite filter col before push down.
3279        let table_scan = test_table_scan()?;
3280        let plan = LogicalPlanBuilder::from(table_scan)
3281            .project(vec![col("a").alias("b"), col("c")])?
3282            .project(vec![col("b"), col("c")])?
3283            .filter(and(col("b").gt(lit(10i64)), col("c").gt(lit(10i64))))?
3284            .build()?;
3285
3286        // filter on col b
3287        assert_snapshot!(plan,
3288        @r"
3289        Filter: b > Int64(10) AND test.c > Int64(10)
3290          Projection: b, test.c
3291            Projection: test.a AS b, test.c
3292              TableScan: test
3293        ",
3294        );
3295        // rewrite filter col b to test.a
3296        assert_optimized_plan_equal!(
3297            plan,
3298            @r"
3299        Projection: b, test.c
3300          Projection: test.a AS b, test.c
3301            TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]
3302        "
3303        )
3304    }
3305
3306    #[test]
3307    fn test_filter_with_multi_alias() -> Result<()> {
3308        let table_scan = test_table_scan()?;
3309        let plan = LogicalPlanBuilder::from(table_scan)
3310            .project(vec![col("a").alias("b"), col("c").alias("d")])?
3311            .filter(and(col("b").gt(lit(10i64)), col("d").gt(lit(10i64))))?
3312            .build()?;
3313
3314        // filter on col b and d
3315        assert_snapshot!(plan,
3316        @r"
3317        Filter: b > Int64(10) AND d > Int64(10)
3318          Projection: test.a AS b, test.c AS d
3319            TableScan: test
3320        ",
3321        );
3322        // rewrite filter col b to test.a, col d to test.c
3323        assert_optimized_plan_equal!(
3324            plan,
3325            @r"
3326        Projection: test.a AS b, test.c AS d
3327          TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]
3328        "
3329        )
3330    }
3331
3332    /// predicate on join key in filter expression should be pushed down to both inputs
3333    #[test]
3334    fn join_filter_with_alias() -> Result<()> {
3335        let table_scan = test_table_scan()?;
3336        let left = LogicalPlanBuilder::from(table_scan)
3337            .project(vec![col("a").alias("c")])?
3338            .build()?;
3339        let right_table_scan = test_table_scan_with_name("test2")?;
3340        let right = LogicalPlanBuilder::from(right_table_scan)
3341            .project(vec![col("b").alias("d")])?
3342            .build()?;
3343        let filter = col("c").gt(lit(1u32));
3344        let plan = LogicalPlanBuilder::from(left)
3345            .join(
3346                right,
3347                JoinType::Inner,
3348                (vec![Column::from_name("c")], vec![Column::from_name("d")]),
3349                Some(filter),
3350            )?
3351            .build()?;
3352
3353        assert_snapshot!(plan,
3354        @r"
3355        Inner Join: c = d Filter: c > UInt32(1)
3356          Projection: test.a AS c
3357            TableScan: test
3358          Projection: test2.b AS d
3359            TableScan: test2
3360        ",
3361        );
3362        // Change filter on col `c`, 'd' to `test.a`, 'test.b'
3363        assert_optimized_plan_equal!(
3364            plan,
3365            @r"
3366        Inner Join: c = d
3367          Projection: test.a AS c
3368            TableScan: test, full_filters=[test.a > UInt32(1)]
3369          Projection: test2.b AS d
3370            TableScan: test2, full_filters=[test2.b > UInt32(1)]
3371        "
3372        )
3373    }
3374
3375    #[test]
3376    fn test_in_filter_with_alias() -> Result<()> {
3377        // in table scan the true col name is 'test.a',
3378        // but we rename it as 'b', and use col 'b' in filter
3379        // we need rewrite filter col before push down.
3380        let table_scan = test_table_scan()?;
3381        let filter_value = vec![lit(1u32), lit(2u32), lit(3u32), lit(4u32)];
3382        let plan = LogicalPlanBuilder::from(table_scan)
3383            .project(vec![col("a").alias("b"), col("c")])?
3384            .filter(in_list(col("b"), filter_value, false))?
3385            .build()?;
3386
3387        // filter on col b
3388        assert_snapshot!(plan,
3389        @r"
3390        Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])
3391          Projection: test.a AS b, test.c
3392            TableScan: test
3393        ",
3394        );
3395        // rewrite filter col b to test.a
3396        assert_optimized_plan_equal!(
3397            plan,
3398            @r"
3399        Projection: test.a AS b, test.c
3400          TableScan: test, full_filters=[test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])]
3401        "
3402        )
3403    }
3404
3405    #[test]
3406    fn test_in_filter_with_alias_2() -> Result<()> {
3407        // in table scan the true col name is 'test.a',
3408        // but we rename it as 'b', and use col 'b' in filter
3409        // we need rewrite filter col before push down.
3410        let table_scan = test_table_scan()?;
3411        let filter_value = vec![lit(1u32), lit(2u32), lit(3u32), lit(4u32)];
3412        let plan = LogicalPlanBuilder::from(table_scan)
3413            .project(vec![col("a").alias("b"), col("c")])?
3414            .project(vec![col("b"), col("c")])?
3415            .filter(in_list(col("b"), filter_value, false))?
3416            .build()?;
3417
3418        // filter on col b
3419        assert_snapshot!(plan,
3420        @r"
3421        Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])
3422          Projection: b, test.c
3423            Projection: test.a AS b, test.c
3424              TableScan: test
3425        ",
3426        );
3427        // rewrite filter col b to test.a
3428        assert_optimized_plan_equal!(
3429            plan,
3430            @r"
3431        Projection: b, test.c
3432          Projection: test.a AS b, test.c
3433            TableScan: test, full_filters=[test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])]
3434        "
3435        )
3436    }
3437
3438    #[test]
3439    fn test_in_subquery_with_alias() -> Result<()> {
3440        // in table scan the true col name is 'test.a',
3441        // but we rename it as 'b', and use col 'b' in subquery filter
3442        let table_scan = test_table_scan()?;
3443        let table_scan_sq = test_table_scan_with_name("sq")?;
3444        let subplan = Arc::new(
3445            LogicalPlanBuilder::from(table_scan_sq)
3446                .project(vec![col("c")])?
3447                .build()?,
3448        );
3449        let plan = LogicalPlanBuilder::from(table_scan)
3450            .project(vec![col("a").alias("b"), col("c")])?
3451            .filter(in_subquery(col("b"), subplan))?
3452            .build()?;
3453
3454        // filter on col b in subquery
3455        assert_snapshot!(plan,
3456        @r"
3457        Filter: b IN (<subquery>)
3458          Subquery:
3459            Projection: sq.c
3460              TableScan: sq
3461          Projection: test.a AS b, test.c
3462            TableScan: test
3463        ",
3464        );
3465        // rewrite filter col b to test.a
3466        assert_optimized_plan_equal!(
3467            plan,
3468            @r"
3469        Projection: test.a AS b, test.c
3470          TableScan: test, full_filters=[test.a IN (<subquery>)]
3471            Subquery:
3472              Projection: sq.c
3473                TableScan: sq
3474        "
3475        )
3476    }
3477
3478    #[test]
3479    fn test_propagation_of_optimized_inner_filters_with_projections() -> Result<()> {
3480        // SELECT a FROM (SELECT 1 AS a) b WHERE b.a = 1
3481        let plan = LogicalPlanBuilder::empty(true)
3482            .project(vec![lit(0i64).alias("a")])?
3483            .alias("b")?
3484            .project(vec![col("b.a")])?
3485            .alias("b")?
3486            .filter(col("b.a").eq(lit(1i64)))?
3487            .project(vec![col("b.a")])?
3488            .build()?;
3489
3490        assert_snapshot!(plan,
3491        @r"
3492        Projection: b.a
3493          Filter: b.a = Int64(1)
3494            SubqueryAlias: b
3495              Projection: b.a
3496                SubqueryAlias: b
3497                  Projection: Int64(0) AS a
3498                    EmptyRelation: rows=1
3499        ",
3500        );
3501        // Ensure that the predicate without any columns (0 = 1) is
3502        // still there.
3503        assert_optimized_plan_equal!(
3504            plan,
3505            @r"
3506        Projection: b.a
3507          SubqueryAlias: b
3508            Projection: b.a
3509              SubqueryAlias: b
3510                Projection: Int64(0) AS a
3511                  Filter: Int64(0) = Int64(1)
3512                    EmptyRelation: rows=1
3513        "
3514        )
3515    }
3516
3517    #[test]
3518    fn test_crossjoin_with_or_clause() -> Result<()> {
3519        // select * from test,test1 where (test.a = test1.a and test.b > 1) or (test.b = test1.b and test.c < 10);
3520        let table_scan = test_table_scan()?;
3521        let left = LogicalPlanBuilder::from(table_scan)
3522            .project(vec![col("a"), col("b"), col("c")])?
3523            .build()?;
3524        let right_table_scan = test_table_scan_with_name("test1")?;
3525        let right = LogicalPlanBuilder::from(right_table_scan)
3526            .project(vec![col("a").alias("d"), col("a").alias("e")])?
3527            .build()?;
3528        let filter = or(
3529            and(col("a").eq(col("d")), col("b").gt(lit(1u32))),
3530            and(col("b").eq(col("e")), col("c").lt(lit(10u32))),
3531        );
3532        let plan = LogicalPlanBuilder::from(left)
3533            .cross_join(right)?
3534            .filter(filter)?
3535            .build()?;
3536
3537        assert_optimized_plan_eq_with_rewrite_predicate!(plan.clone(), @r"
3538        Inner Join:  Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10)
3539          Projection: test.a, test.b, test.c
3540            TableScan: test, full_filters=[test.b > UInt32(1) OR test.c < UInt32(10)]
3541          Projection: test1.a AS d, test1.a AS e
3542            TableScan: test1
3543        ")?;
3544
3545        // Originally global state which can help to avoid duplicate Filters been generated and pushed down.
3546        // Now the global state is removed. Need to double confirm that avoid duplicate Filters.
3547        let optimized_plan = PushDownFilter::new()
3548            .rewrite(plan, &OptimizerContext::new())
3549            .expect("failed to optimize plan")
3550            .data;
3551        assert_optimized_plan_equal!(
3552            optimized_plan,
3553            @r"
3554        Inner Join:  Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10)
3555          Projection: test.a, test.b, test.c
3556            TableScan: test, full_filters=[test.b > UInt32(1) OR test.c < UInt32(10)]
3557          Projection: test1.a AS d, test1.a AS e
3558            TableScan: test1
3559        "
3560        )
3561    }
3562
3563    #[test]
3564    fn left_semi_join() -> Result<()> {
3565        let left = test_table_scan_with_name("test1")?;
3566        let right_table_scan = test_table_scan_with_name("test2")?;
3567        let right = LogicalPlanBuilder::from(right_table_scan)
3568            .project(vec![col("a"), col("b")])?
3569            .build()?;
3570        let plan = LogicalPlanBuilder::from(left)
3571            .join(
3572                right,
3573                JoinType::LeftSemi,
3574                (
3575                    vec![Column::from_qualified_name("test1.a")],
3576                    vec![Column::from_qualified_name("test2.a")],
3577                ),
3578                None,
3579            )?
3580            .filter(col("test2.a").lt_eq(lit(1i64)))?
3581            .build()?;
3582
3583        // not part of the test, just good to know:
3584        assert_snapshot!(plan,
3585        @r"
3586        Filter: test2.a <= Int64(1)
3587          LeftSemi Join: test1.a = test2.a
3588            TableScan: test1
3589            Projection: test2.a, test2.b
3590              TableScan: test2
3591        ",
3592        );
3593        // Inferred the predicate `test1.a <= Int64(1)` and push it down to the left side.
3594        assert_optimized_plan_equal!(
3595            plan,
3596            @r"
3597        Filter: test2.a <= Int64(1)
3598          LeftSemi Join: test1.a = test2.a
3599            TableScan: test1, full_filters=[test1.a <= Int64(1)]
3600            Projection: test2.a, test2.b
3601              TableScan: test2
3602        "
3603        )
3604    }
3605
3606    #[test]
3607    fn left_semi_join_with_filters() -> Result<()> {
3608        let left = test_table_scan_with_name("test1")?;
3609        let right_table_scan = test_table_scan_with_name("test2")?;
3610        let right = LogicalPlanBuilder::from(right_table_scan)
3611            .project(vec![col("a"), col("b")])?
3612            .build()?;
3613        let plan = LogicalPlanBuilder::from(left)
3614            .join(
3615                right,
3616                JoinType::LeftSemi,
3617                (
3618                    vec![Column::from_qualified_name("test1.a")],
3619                    vec![Column::from_qualified_name("test2.a")],
3620                ),
3621                Some(
3622                    col("test1.b")
3623                        .gt(lit(1u32))
3624                        .and(col("test2.b").gt(lit(2u32))),
3625                ),
3626            )?
3627            .build()?;
3628
3629        // not part of the test, just good to know:
3630        assert_snapshot!(plan,
3631        @r"
3632        LeftSemi Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
3633          TableScan: test1
3634          Projection: test2.a, test2.b
3635            TableScan: test2
3636        ",
3637        );
3638        // Both side will be pushed down.
3639        assert_optimized_plan_equal!(
3640            plan,
3641            @r"
3642        LeftSemi Join: test1.a = test2.a
3643          TableScan: test1, full_filters=[test1.b > UInt32(1)]
3644          Projection: test2.a, test2.b
3645            TableScan: test2, full_filters=[test2.b > UInt32(2)]
3646        "
3647        )
3648    }
3649
3650    #[test]
3651    fn right_semi_join() -> Result<()> {
3652        let left = test_table_scan_with_name("test1")?;
3653        let right_table_scan = test_table_scan_with_name("test2")?;
3654        let right = LogicalPlanBuilder::from(right_table_scan)
3655            .project(vec![col("a"), col("b")])?
3656            .build()?;
3657        let plan = LogicalPlanBuilder::from(left)
3658            .join(
3659                right,
3660                JoinType::RightSemi,
3661                (
3662                    vec![Column::from_qualified_name("test1.a")],
3663                    vec![Column::from_qualified_name("test2.a")],
3664                ),
3665                None,
3666            )?
3667            .filter(col("test1.a").lt_eq(lit(1i64)))?
3668            .build()?;
3669
3670        // not part of the test, just good to know:
3671        assert_snapshot!(plan,
3672        @r"
3673        Filter: test1.a <= Int64(1)
3674          RightSemi Join: test1.a = test2.a
3675            TableScan: test1
3676            Projection: test2.a, test2.b
3677              TableScan: test2
3678        ",
3679        );
3680        // Inferred the predicate `test2.a <= Int64(1)` and push it down to the right side.
3681        assert_optimized_plan_equal!(
3682            plan,
3683            @r"
3684        Filter: test1.a <= Int64(1)
3685          RightSemi Join: test1.a = test2.a
3686            TableScan: test1
3687            Projection: test2.a, test2.b
3688              TableScan: test2, full_filters=[test2.a <= Int64(1)]
3689        "
3690        )
3691    }
3692
3693    #[test]
3694    fn right_semi_join_with_filters() -> Result<()> {
3695        let left = test_table_scan_with_name("test1")?;
3696        let right_table_scan = test_table_scan_with_name("test2")?;
3697        let right = LogicalPlanBuilder::from(right_table_scan)
3698            .project(vec![col("a"), col("b")])?
3699            .build()?;
3700        let plan = LogicalPlanBuilder::from(left)
3701            .join(
3702                right,
3703                JoinType::RightSemi,
3704                (
3705                    vec![Column::from_qualified_name("test1.a")],
3706                    vec![Column::from_qualified_name("test2.a")],
3707                ),
3708                Some(
3709                    col("test1.b")
3710                        .gt(lit(1u32))
3711                        .and(col("test2.b").gt(lit(2u32))),
3712                ),
3713            )?
3714            .build()?;
3715
3716        // not part of the test, just good to know:
3717        assert_snapshot!(plan,
3718        @r"
3719        RightSemi Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
3720          TableScan: test1
3721          Projection: test2.a, test2.b
3722            TableScan: test2
3723        ",
3724        );
3725        // Both side will be pushed down.
3726        assert_optimized_plan_equal!(
3727            plan,
3728            @r"
3729        RightSemi Join: test1.a = test2.a
3730          TableScan: test1, full_filters=[test1.b > UInt32(1)]
3731          Projection: test2.a, test2.b
3732            TableScan: test2, full_filters=[test2.b > UInt32(2)]
3733        "
3734        )
3735    }
3736
3737    #[test]
3738    fn left_anti_join() -> Result<()> {
3739        let table_scan = test_table_scan_with_name("test1")?;
3740        let left = LogicalPlanBuilder::from(table_scan)
3741            .project(vec![col("a"), col("b")])?
3742            .build()?;
3743        let right_table_scan = test_table_scan_with_name("test2")?;
3744        let right = LogicalPlanBuilder::from(right_table_scan)
3745            .project(vec![col("a"), col("b")])?
3746            .build()?;
3747        let plan = LogicalPlanBuilder::from(left)
3748            .join(
3749                right,
3750                JoinType::LeftAnti,
3751                (
3752                    vec![Column::from_qualified_name("test1.a")],
3753                    vec![Column::from_qualified_name("test2.a")],
3754                ),
3755                None,
3756            )?
3757            .filter(col("test2.a").gt(lit(2u32)))?
3758            .build()?;
3759
3760        // not part of the test, just good to know:
3761        assert_snapshot!(plan,
3762        @r"
3763        Filter: test2.a > UInt32(2)
3764          LeftAnti Join: test1.a = test2.a
3765            Projection: test1.a, test1.b
3766              TableScan: test1
3767            Projection: test2.a, test2.b
3768              TableScan: test2
3769        ",
3770        );
3771        // For left anti, filter of the right side filter can be pushed down.
3772        assert_optimized_plan_equal!(
3773            plan,
3774            @r"
3775        Filter: test2.a > UInt32(2)
3776          LeftAnti Join: test1.a = test2.a
3777            Projection: test1.a, test1.b
3778              TableScan: test1, full_filters=[test1.a > UInt32(2)]
3779            Projection: test2.a, test2.b
3780              TableScan: test2
3781        "
3782        )
3783    }
3784
3785    #[test]
3786    fn left_anti_join_with_filters() -> Result<()> {
3787        let table_scan = test_table_scan_with_name("test1")?;
3788        let left = LogicalPlanBuilder::from(table_scan)
3789            .project(vec![col("a"), col("b")])?
3790            .build()?;
3791        let right_table_scan = test_table_scan_with_name("test2")?;
3792        let right = LogicalPlanBuilder::from(right_table_scan)
3793            .project(vec![col("a"), col("b")])?
3794            .build()?;
3795        let plan = LogicalPlanBuilder::from(left)
3796            .join(
3797                right,
3798                JoinType::LeftAnti,
3799                (
3800                    vec![Column::from_qualified_name("test1.a")],
3801                    vec![Column::from_qualified_name("test2.a")],
3802                ),
3803                Some(
3804                    col("test1.b")
3805                        .gt(lit(1u32))
3806                        .and(col("test2.b").gt(lit(2u32))),
3807                ),
3808            )?
3809            .build()?;
3810
3811        // not part of the test, just good to know:
3812        assert_snapshot!(plan,
3813        @r"
3814        LeftAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
3815          Projection: test1.a, test1.b
3816            TableScan: test1
3817          Projection: test2.a, test2.b
3818            TableScan: test2
3819        ",
3820        );
3821        // For left anti, filter of the right side filter can be pushed down.
3822        assert_optimized_plan_equal!(
3823            plan,
3824            @r"
3825        LeftAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1)
3826          Projection: test1.a, test1.b
3827            TableScan: test1
3828          Projection: test2.a, test2.b
3829            TableScan: test2, full_filters=[test2.b > UInt32(2)]
3830        "
3831        )
3832    }
3833
3834    #[test]
3835    fn right_anti_join() -> Result<()> {
3836        let table_scan = test_table_scan_with_name("test1")?;
3837        let left = LogicalPlanBuilder::from(table_scan)
3838            .project(vec![col("a"), col("b")])?
3839            .build()?;
3840        let right_table_scan = test_table_scan_with_name("test2")?;
3841        let right = LogicalPlanBuilder::from(right_table_scan)
3842            .project(vec![col("a"), col("b")])?
3843            .build()?;
3844        let plan = LogicalPlanBuilder::from(left)
3845            .join(
3846                right,
3847                JoinType::RightAnti,
3848                (
3849                    vec![Column::from_qualified_name("test1.a")],
3850                    vec![Column::from_qualified_name("test2.a")],
3851                ),
3852                None,
3853            )?
3854            .filter(col("test1.a").gt(lit(2u32)))?
3855            .build()?;
3856
3857        // not part of the test, just good to know:
3858        assert_snapshot!(plan,
3859        @r"
3860        Filter: test1.a > UInt32(2)
3861          RightAnti Join: test1.a = test2.a
3862            Projection: test1.a, test1.b
3863              TableScan: test1
3864            Projection: test2.a, test2.b
3865              TableScan: test2
3866        ",
3867        );
3868        // For right anti, filter of the left side can be pushed down.
3869        assert_optimized_plan_equal!(
3870            plan,
3871            @r"
3872        Filter: test1.a > UInt32(2)
3873          RightAnti Join: test1.a = test2.a
3874            Projection: test1.a, test1.b
3875              TableScan: test1
3876            Projection: test2.a, test2.b
3877              TableScan: test2, full_filters=[test2.a > UInt32(2)]
3878        "
3879        )
3880    }
3881
3882    #[test]
3883    fn right_anti_join_with_filters() -> Result<()> {
3884        let table_scan = test_table_scan_with_name("test1")?;
3885        let left = LogicalPlanBuilder::from(table_scan)
3886            .project(vec![col("a"), col("b")])?
3887            .build()?;
3888        let right_table_scan = test_table_scan_with_name("test2")?;
3889        let right = LogicalPlanBuilder::from(right_table_scan)
3890            .project(vec![col("a"), col("b")])?
3891            .build()?;
3892        let plan = LogicalPlanBuilder::from(left)
3893            .join(
3894                right,
3895                JoinType::RightAnti,
3896                (
3897                    vec![Column::from_qualified_name("test1.a")],
3898                    vec![Column::from_qualified_name("test2.a")],
3899                ),
3900                Some(
3901                    col("test1.b")
3902                        .gt(lit(1u32))
3903                        .and(col("test2.b").gt(lit(2u32))),
3904                ),
3905            )?
3906            .build()?;
3907
3908        // not part of the test, just good to know:
3909        assert_snapshot!(plan,
3910        @r"
3911        RightAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
3912          Projection: test1.a, test1.b
3913            TableScan: test1
3914          Projection: test2.a, test2.b
3915            TableScan: test2
3916        ",
3917        );
3918        // For right anti, filter of the left side can be pushed down.
3919        assert_optimized_plan_equal!(
3920            plan,
3921            @r"
3922        RightAnti Join: test1.a = test2.a Filter: test2.b > UInt32(2)
3923          Projection: test1.a, test1.b
3924            TableScan: test1, full_filters=[test1.b > UInt32(1)]
3925          Projection: test2.a, test2.b
3926            TableScan: test2
3927        "
3928        )
3929    }
3930
3931    #[derive(Debug, PartialEq, Eq, Hash)]
3932    struct TestScalarUDF {
3933        signature: Signature,
3934    }
3935
3936    impl ScalarUDFImpl for TestScalarUDF {
3937        fn as_any(&self) -> &dyn Any {
3938            self
3939        }
3940        fn name(&self) -> &str {
3941            "TestScalarUDF"
3942        }
3943
3944        fn signature(&self) -> &Signature {
3945            &self.signature
3946        }
3947
3948        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
3949            Ok(DataType::Int32)
3950        }
3951
3952        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
3953            Ok(ColumnarValue::Scalar(ScalarValue::from(1)))
3954        }
3955    }
3956
3957    #[test]
3958    fn test_push_down_volatile_function_in_aggregate() -> Result<()> {
3959        // SELECT t.a, t.r FROM (SELECT a, sum(b),  TestScalarUDF()+1 AS r FROM test1 GROUP BY a) AS t WHERE t.a > 5 AND t.r > 0.5;
3960        let table_scan = test_table_scan_with_name("test1")?;
3961        let fun = ScalarUDF::new_from_impl(TestScalarUDF {
3962            signature: Signature::exact(vec![], Volatility::Volatile),
3963        });
3964        let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
3965
3966        let plan = LogicalPlanBuilder::from(table_scan)
3967            .aggregate(vec![col("a")], vec![sum(col("b"))])?
3968            .project(vec![col("a"), sum(col("b")), add(expr, lit(1)).alias("r")])?
3969            .alias("t")?
3970            .filter(col("t.a").gt(lit(5)).and(col("t.r").gt(lit(0.5))))?
3971            .project(vec![col("t.a"), col("t.r")])?
3972            .build()?;
3973
3974        assert_snapshot!(plan,
3975        @r"
3976        Projection: t.a, t.r
3977          Filter: t.a > Int32(5) AND t.r > Float64(0.5)
3978            SubqueryAlias: t
3979              Projection: test1.a, sum(test1.b), TestScalarUDF() + Int32(1) AS r
3980                Aggregate: groupBy=[[test1.a]], aggr=[[sum(test1.b)]]
3981                  TableScan: test1
3982        ",
3983        );
3984        assert_optimized_plan_equal!(
3985            plan,
3986            @r"
3987        Projection: t.a, t.r
3988          SubqueryAlias: t
3989            Filter: r > Float64(0.5)
3990              Projection: test1.a, sum(test1.b), TestScalarUDF() + Int32(1) AS r
3991                Aggregate: groupBy=[[test1.a]], aggr=[[sum(test1.b)]]
3992                  TableScan: test1, full_filters=[test1.a > Int32(5)]
3993        "
3994        )
3995    }
3996
3997    #[test]
3998    fn test_push_down_volatile_function_in_join() -> Result<()> {
3999        // SELECT t.a, t.r FROM (SELECT test1.a AS a, TestScalarUDF() AS r FROM test1 join test2 ON test1.a = test2.a) AS t WHERE t.r > 0.5;
4000        let table_scan = test_table_scan_with_name("test1")?;
4001        let fun = ScalarUDF::new_from_impl(TestScalarUDF {
4002            signature: Signature::exact(vec![], Volatility::Volatile),
4003        });
4004        let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
4005        let left = LogicalPlanBuilder::from(table_scan).build()?;
4006        let right_table_scan = test_table_scan_with_name("test2")?;
4007        let right = LogicalPlanBuilder::from(right_table_scan).build()?;
4008        let plan = LogicalPlanBuilder::from(left)
4009            .join(
4010                right,
4011                JoinType::Inner,
4012                (
4013                    vec![Column::from_qualified_name("test1.a")],
4014                    vec![Column::from_qualified_name("test2.a")],
4015                ),
4016                None,
4017            )?
4018            .project(vec![col("test1.a").alias("a"), expr.alias("r")])?
4019            .alias("t")?
4020            .filter(col("t.r").gt(lit(0.8)))?
4021            .project(vec![col("t.a"), col("t.r")])?
4022            .build()?;
4023
4024        assert_snapshot!(plan,
4025        @r"
4026        Projection: t.a, t.r
4027          Filter: t.r > Float64(0.8)
4028            SubqueryAlias: t
4029              Projection: test1.a AS a, TestScalarUDF() AS r
4030                Inner Join: test1.a = test2.a
4031                  TableScan: test1
4032                  TableScan: test2
4033        ",
4034        );
4035        assert_optimized_plan_equal!(
4036            plan,
4037            @r"
4038        Projection: t.a, t.r
4039          SubqueryAlias: t
4040            Filter: r > Float64(0.8)
4041              Projection: test1.a AS a, TestScalarUDF() AS r
4042                Inner Join: test1.a = test2.a
4043                  TableScan: test1
4044                  TableScan: test2
4045        "
4046        )
4047    }
4048
4049    #[test]
4050    fn test_push_down_volatile_table_scan() -> Result<()> {
4051        // SELECT test.a, test.b FROM test as t WHERE TestScalarUDF() > 0.1;
4052        let table_scan = test_table_scan()?;
4053        let fun = ScalarUDF::new_from_impl(TestScalarUDF {
4054            signature: Signature::exact(vec![], Volatility::Volatile),
4055        });
4056        let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
4057        let plan = LogicalPlanBuilder::from(table_scan)
4058            .project(vec![col("a"), col("b")])?
4059            .filter(expr.gt(lit(0.1)))?
4060            .build()?;
4061
4062        assert_snapshot!(plan,
4063        @r"
4064        Filter: TestScalarUDF() > Float64(0.1)
4065          Projection: test.a, test.b
4066            TableScan: test
4067        ",
4068        );
4069        assert_optimized_plan_equal!(
4070            plan,
4071            @r"
4072        Projection: test.a, test.b
4073          Filter: TestScalarUDF() > Float64(0.1)
4074            TableScan: test
4075        "
4076        )
4077    }
4078
4079    #[test]
4080    fn test_push_down_volatile_mixed_table_scan() -> Result<()> {
4081        // SELECT test.a, test.b FROM test as t WHERE TestScalarUDF() > 0.1 and test.a > 5 and test.b > 10;
4082        let table_scan = test_table_scan()?;
4083        let fun = ScalarUDF::new_from_impl(TestScalarUDF {
4084            signature: Signature::exact(vec![], Volatility::Volatile),
4085        });
4086        let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
4087        let plan = LogicalPlanBuilder::from(table_scan)
4088            .project(vec![col("a"), col("b")])?
4089            .filter(
4090                expr.gt(lit(0.1))
4091                    .and(col("t.a").gt(lit(5)))
4092                    .and(col("t.b").gt(lit(10))),
4093            )?
4094            .build()?;
4095
4096        assert_snapshot!(plan,
4097        @r"
4098        Filter: TestScalarUDF() > Float64(0.1) AND t.a > Int32(5) AND t.b > Int32(10)
4099          Projection: test.a, test.b
4100            TableScan: test
4101        ",
4102        );
4103        assert_optimized_plan_equal!(
4104            plan,
4105            @r"
4106        Projection: test.a, test.b
4107          Filter: TestScalarUDF() > Float64(0.1)
4108            TableScan: test, full_filters=[t.a > Int32(5), t.b > Int32(10)]
4109        "
4110        )
4111    }
4112
4113    #[test]
4114    fn test_push_down_volatile_mixed_unsupported_table_scan() -> Result<()> {
4115        // SELECT test.a, test.b FROM test as t WHERE TestScalarUDF() > 0.1 and test.a > 5 and test.b > 10;
4116        let fun = ScalarUDF::new_from_impl(TestScalarUDF {
4117            signature: Signature::exact(vec![], Volatility::Volatile),
4118        });
4119        let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
4120        let plan = table_scan_with_pushdown_provider_builder(
4121            TableProviderFilterPushDown::Unsupported,
4122            vec![],
4123            None,
4124        )?
4125        .project(vec![col("a"), col("b")])?
4126        .filter(
4127            expr.gt(lit(0.1))
4128                .and(col("t.a").gt(lit(5)))
4129                .and(col("t.b").gt(lit(10))),
4130        )?
4131        .build()?;
4132
4133        assert_snapshot!(plan,
4134        @r"
4135        Filter: TestScalarUDF() > Float64(0.1) AND t.a > Int32(5) AND t.b > Int32(10)
4136          Projection: a, b
4137            TableScan: test
4138        ",
4139        );
4140        assert_optimized_plan_equal!(
4141            plan,
4142            @r"
4143        Projection: a, b
4144          Filter: t.a > Int32(5) AND t.b > Int32(10) AND TestScalarUDF() > Float64(0.1)
4145            TableScan: test
4146        "
4147        )
4148    }
4149
4150    #[test]
4151    fn test_push_down_filter_to_user_defined_node() -> Result<()> {
4152        // Define a custom user-defined logical node
4153        #[derive(Debug, Hash, Eq, PartialEq)]
4154        struct TestUserNode {
4155            schema: DFSchemaRef,
4156        }
4157
4158        impl PartialOrd for TestUserNode {
4159            fn partial_cmp(&self, _other: &Self) -> Option<Ordering> {
4160                None
4161            }
4162        }
4163
4164        impl TestUserNode {
4165            fn new() -> Self {
4166                let schema = Arc::new(
4167                    DFSchema::new_with_metadata(
4168                        vec![(None, Field::new("a", DataType::Int64, false).into())],
4169                        Default::default(),
4170                    )
4171                    .unwrap(),
4172                );
4173
4174                Self { schema }
4175            }
4176        }
4177
4178        impl UserDefinedLogicalNodeCore for TestUserNode {
4179            fn name(&self) -> &str {
4180                "test_node"
4181            }
4182
4183            fn inputs(&self) -> Vec<&LogicalPlan> {
4184                vec![]
4185            }
4186
4187            fn schema(&self) -> &DFSchemaRef {
4188                &self.schema
4189            }
4190
4191            fn expressions(&self) -> Vec<Expr> {
4192                vec![]
4193            }
4194
4195            fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
4196                write!(f, "TestUserNode")
4197            }
4198
4199            fn with_exprs_and_inputs(
4200                &self,
4201                exprs: Vec<Expr>,
4202                inputs: Vec<LogicalPlan>,
4203            ) -> Result<Self> {
4204                assert!(exprs.is_empty());
4205                assert!(inputs.is_empty());
4206                Ok(Self {
4207                    schema: Arc::clone(&self.schema),
4208                })
4209            }
4210        }
4211
4212        // Create a node and build a plan with a filter
4213        let node = LogicalPlan::Extension(Extension {
4214            node: Arc::new(TestUserNode::new()),
4215        });
4216
4217        let plan = LogicalPlanBuilder::from(node).filter(lit(false))?.build()?;
4218
4219        // Check the original plan format (not part of the test assertions)
4220        assert_snapshot!(plan,
4221        @r"
4222        Filter: Boolean(false)
4223          TestUserNode
4224        ",
4225        );
4226        // Check that the filter is pushed down to the user-defined node
4227        assert_optimized_plan_equal!(
4228            plan,
4229            @r"
4230        Filter: Boolean(false)
4231          TestUserNode
4232        "
4233        )
4234    }
4235
4236    /// Test that filters are NOT pushed through MoveTowardsLeafNodes projections.
4237    /// These are cheap expressions (like get_field) where re-inlining into a filter
4238    /// has no benefit and causes optimizer instability — ExtractLeafExpressions will
4239    /// undo the push-down, creating an infinite loop that runs until the iteration
4240    /// limit is hit.
4241    #[test]
4242    fn filter_not_pushed_through_move_towards_leaves_projection() -> Result<()> {
4243        let table_scan = test_table_scan()?;
4244
4245        // Create a projection with a MoveTowardsLeafNodes expression
4246        let proj = LogicalPlanBuilder::from(table_scan)
4247            .project(vec![
4248                leaf_udf_expr(col("a")).alias("val"),
4249                col("b"),
4250                col("c"),
4251            ])?
4252            .build()?;
4253
4254        // Put a filter on the MoveTowardsLeafNodes column
4255        let plan = LogicalPlanBuilder::from(proj)
4256            .filter(col("val").gt(lit(150i64)))?
4257            .build()?;
4258
4259        // Filter should NOT be pushed through — val maps to a MoveTowardsLeafNodes expr
4260        assert_optimized_plan_equal!(
4261            plan,
4262            @r"
4263        Filter: val > Int64(150)
4264          Projection: leaf_udf(test.a) AS val, test.b, test.c
4265            TableScan: test
4266        "
4267        )
4268    }
4269
4270    /// Test mixed predicates: Column predicate pushed, MoveTowardsLeafNodes kept.
4271    #[test]
4272    fn filter_mixed_predicates_partial_push() -> Result<()> {
4273        let table_scan = test_table_scan()?;
4274
4275        // Create a projection with both MoveTowardsLeafNodes and Column expressions
4276        let proj = LogicalPlanBuilder::from(table_scan)
4277            .project(vec![
4278                leaf_udf_expr(col("a")).alias("val"),
4279                col("b"),
4280                col("c"),
4281            ])?
4282            .build()?;
4283
4284        // Filter with both: val > 150 (MoveTowardsLeafNodes) AND b > 5 (Column)
4285        let plan = LogicalPlanBuilder::from(proj)
4286            .filter(col("val").gt(lit(150i64)).and(col("b").gt(lit(5i64))))?
4287            .build()?;
4288
4289        // val > 150 should be kept above, b > 5 should be pushed through
4290        assert_optimized_plan_equal!(
4291            plan,
4292            @r"
4293        Filter: val > Int64(150)
4294          Projection: leaf_udf(test.a) AS val, test.b, test.c
4295            TableScan: test, full_filters=[test.b > Int64(5)]
4296        "
4297        )
4298    }
4299}