Skip to main content

datafusion_optimizer/
push_down_filter.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`PushDownFilter`] applies filters as early as possible
19
20use std::collections::{HashMap, HashSet};
21use std::sync::Arc;
22
23use arrow::datatypes::DataType;
24use indexmap::IndexSet;
25use itertools::Itertools;
26
27use datafusion_common::tree_node::{
28    Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
29};
30use datafusion_common::{
31    Column, DFSchema, Result, assert_eq_or_internal_err, assert_or_internal_err,
32    internal_err, plan_err, qualified_name,
33};
34use datafusion_expr::expr::WindowFunction;
35use datafusion_expr::expr_rewriter::replace_col;
36use datafusion_expr::logical_plan::{Join, JoinType, LogicalPlan, TableScan, Union};
37use datafusion_expr::utils::{
38    conjunction, expr_to_columns, split_conjunction, split_conjunction_owned,
39};
40use datafusion_expr::{
41    BinaryExpr, Expr, Filter, Operator, Projection, TableProviderFilterPushDown, and, or,
42};
43
44use crate::optimizer::ApplyOrder;
45use crate::simplify_expressions::simplify_predicates;
46use crate::utils::{has_all_column_refs, is_restrict_null_predicate};
47use crate::{OptimizerConfig, OptimizerRule};
48use datafusion_expr::ExpressionPlacement;
49
/// Optimizer rule for pushing (moving) filter expressions down in a plan so
/// they are applied as early as possible.
///
/// # Introduction
///
/// The goal of this rule is to improve query performance by eliminating
/// redundant work.
///
/// For example, given a plan that sorts all values where `a > 10`:
///
/// ```text
///  Filter (a > 10)
///    Sort (a, b)
/// ```
///
/// A better plan is to filter the data *before* the Sort, which sorts fewer
/// rows and therefore does less work overall:
///
/// ```text
///  Sort (a, b)
///    Filter (a > 10)  <-- Filter is moved before the sort
/// ```
///
/// However, it is not always possible to push filters down. For example, given a
/// plan that finds the top 3 values and then keeps only those that are greater
/// than 10, pushing the filter below the limit would produce a different
/// result.
///
/// ```text
///  Filter (a > 10)   <-- can not move this Filter before the limit
///    Limit (fetch=3)
///      Sort (a, b)
/// ```
///
/// More formally, a filter-commutative operation is an operation `op` that
/// satisfies `filter(op(data)) = op(filter(data))`.
///
/// The filter-commutative property is plan and column-specific. A filter on `a`
/// can be pushed through an `Aggregate(group_by = [a], agg = [sum(b)])`. However,
/// a filter on `sum(b)` can not be pushed through the same aggregate.
///
/// # Handling Conjunctions
///
/// It is possible to push down only **part** of a filter expression if it is
/// connected with `AND`s (more formally, if it is a "conjunction").
///
/// For example, given the following plan:
///
/// ```text
/// Filter(a > 10 AND sum(b) < 5)
///   Aggregate(group_by = [a], agg = [sum(b)])
/// ```
///
/// The `a > 10` is commutative with the `Aggregate` but `sum(b) < 5` is not.
/// Therefore it is possible to push only part of the expression, resulting in:
///
/// ```text
/// Filter(sum(b) < 5)
///   Aggregate(group_by = [a], agg = [sum(b)])
///     Filter(a > 10)
/// ```
///
/// # Handling Column Aliases
///
/// This optimizer must sometimes rewrite filter expressions when they are
/// pushed, for example if there is a projection that aliases `a+1` to `"b"`:
///
/// ```text
/// Filter (b > 10)
///     Projection: [a+1 AS "b"]  <-- changes the name of `a+1` to `b`
/// ```
///
/// To apply the filter prior to the `Projection`, all references to `b` must be
/// rewritten to `a+1`. The projection itself is unchanged:
///
/// ```text
/// Projection: [a+1 AS "b"]
///     Filter: (a + 1 > 10)  <--- changed from b to a + 1
/// ```
///
/// # Implementation Notes
///
/// This implementation performs a single pass through the plan, "pushing" down
/// filters. When it passes through a filter, it stores that filter, and when it
/// reaches a plan node that does not commute with that filter, it adds the
/// filter at that place. When it passes through a projection, it rewrites the
/// filter's expression to take that projection into account.
#[derive(Default, Debug)]
pub struct PushDownFilter {}
139
140/// For a given JOIN type, determine whether each input of the join is preserved
141/// for post-join (`WHERE` clause) filters.
142///
143/// It is only correct to push filters below a join for preserved inputs.
144///
145/// # Return Value
146/// A tuple of booleans - (left_preserved, right_preserved).
147///
148/// # "Preserved" input definition
149///
150/// We say a join side is preserved if the join returns all or a subset of the rows from
151/// the relevant side, such that each row of the output table directly maps to a row of
152/// the preserved input table. If a table is not preserved, it can provide extra null rows.
153/// That is, there may be rows in the output table that don't directly map to a row in the
154/// input table.
155///
156/// For example:
157///   - In an inner join, both sides are preserved, because each row of the output
158///     maps directly to a row from each side.
159///
160///   - In a left join, the left side is preserved (we can push predicates) but
161///     the right is not, because there may be rows in the output that don't
162///     directly map to a row in the right input (due to nulls filling where there
163///     is no match on the right).
164pub(crate) fn lr_is_preserved(join_type: JoinType) -> (bool, bool) {
165    match join_type {
166        JoinType::Inner => (true, true),
167        JoinType::Left => (true, false),
168        JoinType::Right => (false, true),
169        JoinType::Full => (false, false),
170        // No columns from the right side of the join can be referenced in output
171        // predicates for semi/anti joins, so whether we specify t/f doesn't matter.
172        JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => (true, false),
173        // No columns from the left side of the join can be referenced in output
174        // predicates for semi/anti joins, so whether we specify t/f doesn't matter.
175        JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => (false, true),
176    }
177}
178
179/// For a given JOIN type, determine whether each input of the join is preserved
180/// for the join condition (`ON` clause filters).
181///
182/// It is only correct to push filters below a join for preserved inputs.
183///
184/// # Return Value
185/// A tuple of booleans - (left_preserved, right_preserved).
186///
187/// See [`lr_is_preserved`] for a definition of "preserved".
188pub(crate) fn on_lr_is_preserved(join_type: JoinType) -> (bool, bool) {
189    match join_type {
190        JoinType::Inner => (true, true),
191        JoinType::Left => (false, true),
192        JoinType::Right => (true, false),
193        JoinType::Full => (false, false),
194        JoinType::LeftSemi | JoinType::RightSemi => (true, true),
195        JoinType::LeftAnti => (false, true),
196        JoinType::RightAnti => (true, false),
197        JoinType::LeftMark => (false, true),
198        JoinType::RightMark => (true, false),
199    }
200}
201
202/// Evaluates the columns referenced in the given expression to see if they refer
203/// only to the left or right columns
204#[derive(Debug)]
205struct ColumnChecker<'a> {
206    /// schema of left join input
207    left_schema: &'a DFSchema,
208    /// columns in left_schema, computed on demand
209    left_columns: Option<HashSet<Column>>,
210    /// schema of right join input
211    right_schema: &'a DFSchema,
212    /// columns in left_schema, computed on demand
213    right_columns: Option<HashSet<Column>>,
214}
215
216impl<'a> ColumnChecker<'a> {
217    fn new(left_schema: &'a DFSchema, right_schema: &'a DFSchema) -> Self {
218        Self {
219            left_schema,
220            left_columns: None,
221            right_schema,
222            right_columns: None,
223        }
224    }
225
226    /// Return true if the expression references only columns from the left side of the join
227    fn is_left_only(&mut self, predicate: &Expr) -> bool {
228        if self.left_columns.is_none() {
229            self.left_columns = Some(schema_columns(self.left_schema));
230        }
231        has_all_column_refs(predicate, self.left_columns.as_ref().unwrap())
232    }
233
234    /// Return true if the expression references only columns from the right side of the join
235    fn is_right_only(&mut self, predicate: &Expr) -> bool {
236        if self.right_columns.is_none() {
237            self.right_columns = Some(schema_columns(self.right_schema));
238        }
239        has_all_column_refs(predicate, self.right_columns.as_ref().unwrap())
240    }
241}
242
243/// Returns all columns in the schema
244fn schema_columns(schema: &DFSchema) -> HashSet<Column> {
245    schema
246        .iter()
247        .flat_map(|(qualifier, field)| {
248            [
249                Column::new(qualifier.cloned(), field.name()),
250                // we need to push down filter using unqualified column as well
251                Column::new_unqualified(field.name()),
252            ]
253        })
254        .collect::<HashSet<_>>()
255}
256
/// Determine whether the predicate can be evaluated as a join condition.
///
/// Walks every node of `predicate` and returns:
/// * `Ok(true)` if the expression is built only from columns, literals,
///   placeholders, scalar variables, and the scalar operators listed in the
///   "Continue" arm below (binary ops, LIKE / SIMILAR TO, NOT, the IS [NOT] ...
///   tests, negation, BETWEEN, CASE, CAST / TRY_CAST, IN lists, scalar functions,
///   aliases);
/// * `Ok(false)` as soon as a subquery (EXISTS, IN subquery, set comparison,
///   scalar subquery), an outer reference column, or an UNNEST is found;
/// * an internal error ("Unsupported predicate type") if an aggregate function,
///   window function, wildcard, or grouping set is encountered.
fn can_evaluate_as_join_condition(predicate: &Expr) -> Result<bool> {
    let mut is_evaluate = true;
    predicate.apply(|expr| match expr {
        // Leaf-like nodes: acceptable, and there is nothing below them to check,
        // so skip their children.
        Expr::Column(_)
        | Expr::Literal(_, _)
        | Expr::Placeholder(_)
        | Expr::ScalarVariable(_, _) => Ok(TreeNodeRecursion::Jump),
        // Any of these makes the whole predicate unusable as a join condition;
        // the answer is decided, so stop the traversal.
        Expr::Exists { .. }
        | Expr::InSubquery(_)
        | Expr::SetComparison(_)
        | Expr::ScalarSubquery(_)
        | Expr::OuterReferenceColumn(_, _)
        | Expr::Unnest(_) => {
            is_evaluate = false;
            Ok(TreeNodeRecursion::Stop)
        }
        // Acceptable operators: keep visiting their operands.
        Expr::Alias(_)
        | Expr::BinaryExpr(_)
        | Expr::Like(_)
        | Expr::SimilarTo(_)
        | Expr::Not(_)
        | Expr::IsNotNull(_)
        | Expr::IsNull(_)
        | Expr::IsTrue(_)
        | Expr::IsFalse(_)
        | Expr::IsUnknown(_)
        | Expr::IsNotTrue(_)
        | Expr::IsNotFalse(_)
        | Expr::IsNotUnknown(_)
        | Expr::Negative(_)
        | Expr::Between(_)
        | Expr::Case(_)
        | Expr::Cast(_)
        | Expr::TryCast(_)
        | Expr::InList { .. }
        | Expr::ScalarFunction(_) => Ok(TreeNodeRecursion::Continue),
        // TODO: remove the next line after `Expr::Wildcard` is removed
        #[expect(deprecated)]
        Expr::AggregateFunction(_)
        | Expr::WindowFunction(_)
        | Expr::Wildcard { .. }
        | Expr::GroupingSet(_) => internal_err!("Unsupported predicate type"),
    })?;
    Ok(is_evaluate)
}
303
304/// examine OR clause to see if any useful clauses can be extracted and push down.
305/// extract at least one qual from each sub clauses of OR clause, then form the quals
306/// to new OR clause as predicate.
307///
308/// # Example
309/// ```text
310/// Filter: (a = c and a < 20) or (b = d and b > 10)
311///     join/crossjoin:
312///          TableScan: projection=[a, b]
313///          TableScan: projection=[c, d]
314/// ```
315///
316/// is optimized to
317///
318/// ```text
319/// Filter: (a = c and a < 20) or (b = d and b > 10)
320///     join/crossjoin:
321///          Filter: (a < 20) or (b > 10)
322///              TableScan: projection=[a, b]
323///          TableScan: projection=[c, d]
324/// ```
325///
326/// In general, predicates of this form:
327///
328/// ```sql
329/// (A AND B) OR (C AND D)
330/// ```
331///
332/// will be transformed to one of:
333///
334/// * `((A AND B) OR (C AND D)) AND (A OR C)`
335/// * `((A AND B) OR (C AND D)) AND ((A AND B) OR C)`
336/// * do nothing.
337fn extract_or_clauses_for_join<'a>(
338    filters: &'a [Expr],
339    schema: &'a DFSchema,
340) -> impl Iterator<Item = Expr> + 'a {
341    let schema_columns = schema_columns(schema);
342
343    // new formed OR clauses and their column references
344    filters.iter().filter_map(move |expr| {
345        if let Expr::BinaryExpr(BinaryExpr {
346            left,
347            op: Operator::Or,
348            right,
349        }) = expr
350        {
351            let left_expr = extract_or_clause(left.as_ref(), &schema_columns);
352            let right_expr = extract_or_clause(right.as_ref(), &schema_columns);
353
354            // If nothing can be extracted from any sub clauses, do nothing for this OR clause.
355            if let (Some(left_expr), Some(right_expr)) = (left_expr, right_expr) {
356                return Some(or(left_expr, right_expr));
357            }
358        }
359        None
360    })
361}
362
363/// extract qual from OR sub-clause.
364///
365/// A qual is extracted if it only contains set of column references in schema_columns.
366///
367/// For AND clause, we extract from both sub-clauses, then make new AND clause by extracted
368/// clauses if both extracted; Otherwise, use the extracted clause from any sub-clauses or None.
369///
370/// For OR clause, we extract from both sub-clauses, then make new OR clause by extracted clauses if both extracted;
371/// Otherwise, return None.
372///
373/// For other clause, apply the rule above to extract clause.
374fn extract_or_clause(expr: &Expr, schema_columns: &HashSet<Column>) -> Option<Expr> {
375    let mut predicate = None;
376
377    match expr {
378        Expr::BinaryExpr(BinaryExpr {
379            left: l_expr,
380            op: Operator::Or,
381            right: r_expr,
382        }) => {
383            let l_expr = extract_or_clause(l_expr, schema_columns);
384            let r_expr = extract_or_clause(r_expr, schema_columns);
385
386            if let (Some(l_expr), Some(r_expr)) = (l_expr, r_expr) {
387                predicate = Some(or(l_expr, r_expr));
388            }
389        }
390        Expr::BinaryExpr(BinaryExpr {
391            left: l_expr,
392            op: Operator::And,
393            right: r_expr,
394        }) => {
395            let l_expr = extract_or_clause(l_expr, schema_columns);
396            let r_expr = extract_or_clause(r_expr, schema_columns);
397
398            match (l_expr, r_expr) {
399                (Some(l_expr), Some(r_expr)) => {
400                    predicate = Some(and(l_expr, r_expr));
401                }
402                (Some(l_expr), None) => {
403                    predicate = Some(l_expr);
404                }
405                (None, Some(r_expr)) => {
406                    predicate = Some(r_expr);
407                }
408                (None, None) => {
409                    predicate = None;
410                }
411            }
412        }
413        _ => {
414            if has_all_column_refs(expr, schema_columns) {
415                predicate = Some(expr.clone());
416            }
417        }
418    }
419
420    predicate
421}
422
/// Pushes predicates into the inputs of `join` (a join or cross join) wherever
/// that is semantically safe, and rebuilds the join's ON filter.
///
/// # Parameters
/// * `predicates` - conjuncts of the filter directly above the join (may be empty)
/// * `inferred_join_predicates` - predicates derived through the join keys
///   (as produced by `infer_join_predicates` in `push_down_join`)
/// * `join` - the join node to rewrite
/// * `on_filter` - conjuncts of the join's ON filter (`join.filter`)
///
/// # Returns
/// Always `Transformed::yes`. The result is the rewritten join (inputs possibly
/// wrapped in new `Filter`s, `join.filter` rebuilt), wrapped in a `Filter` with
/// the predicates that could neither be pushed down nor moved into the join, if
/// there are any.
///
/// # Errors
/// Propagates errors from `can_evaluate_as_join_condition` and `Filter::try_new`.
fn push_down_all_join(
    predicates: Vec<Expr>,
    inferred_join_predicates: Vec<Expr>,
    mut join: Join,
    on_filter: Vec<Expr>,
) -> Result<Transformed<LogicalPlan>> {
    let is_inner_join = join.join_type == JoinType::Inner;
    // Which inputs may receive predicates coming from above the join (WHERE).
    let (left_preserved, right_preserved) = lr_is_preserved(join.join_type);

    // The predicates can be divided into three categories:
    // 1) can push through join to its children(left or right)
    // 2) can be converted to join conditions if the join type is Inner
    // 3) should be kept as filter conditions
    let left_schema = join.left.schema();
    let right_schema = join.right.schema();
    let mut left_push = vec![];
    let mut right_push = vec![];
    let mut keep_predicates = vec![];
    let mut join_conditions = vec![];
    let mut checker = ColumnChecker::new(left_schema, right_schema);
    for predicate in predicates {
        if left_preserved && checker.is_left_only(&predicate) {
            left_push.push(predicate);
        } else if right_preserved && checker.is_right_only(&predicate) {
            right_push.push(predicate);
        } else if is_inner_join && can_evaluate_as_join_condition(&predicate)? {
            // Equality and non-equality predicates are not distinguished here;
            // ExtractEquijoinPredicate later extracts the equality predicates
            // and converts them into join ON conditions.
            join_conditions.push(predicate);
        } else {
            keep_predicates.push(predicate);
        }
    }

    // Push predicates inferred from the join expression.
    // NOTE(review): these are not gated on left/right preservation; for
    // non-inner joins `InferredPredicates` only keeps null-restricting
    // predicates. Inferred predicates referencing both sides are dropped.
    for predicate in inferred_join_predicates {
        if checker.is_left_only(&predicate) {
            left_push.push(predicate);
        } else if checker.is_right_only(&predicate) {
            right_push.push(predicate);
        }
    }

    // ON-clause conjuncts use the ON-specific preservation rules; those that
    // cannot be pushed to one side stay in the join filter.
    let mut on_filter_join_conditions = vec![];
    let (on_left_preserved, on_right_preserved) = on_lr_is_preserved(join.join_type);

    if !on_filter.is_empty() {
        for on in on_filter {
            if on_left_preserved && checker.is_left_only(&on) {
                left_push.push(on)
            } else if on_right_preserved && checker.is_right_only(&on) {
                right_push.push(on)
            } else {
                on_filter_join_conditions.push(on)
            }
        }
    }

    // Extract from OR clauses and generate new predicates for either side of
    // the join if possible. Only the predicates that could not be pushed above
    // are examined, and they stay where they are; the extracted predicates are
    // pushed in addition.
    if left_preserved {
        left_push.extend(extract_or_clauses_for_join(&keep_predicates, left_schema));
        left_push.extend(extract_or_clauses_for_join(&join_conditions, left_schema));
    }
    if right_preserved {
        right_push.extend(extract_or_clauses_for_join(&keep_predicates, right_schema));
        right_push.extend(extract_or_clauses_for_join(&join_conditions, right_schema));
    }

    // For predicates from the join filter, check whether a join side is
    // preserved in terms of join (ON clause) filtering.
    if on_left_preserved {
        left_push.extend(extract_or_clauses_for_join(
            &on_filter_join_conditions,
            left_schema,
        ));
    }
    if on_right_preserved {
        right_push.extend(extract_or_clauses_for_join(
            &on_filter_join_conditions,
            right_schema,
        ));
    }

    // Wrap each input in a Filter of everything collected for it.
    if let Some(predicate) = conjunction(left_push) {
        join.left = Arc::new(LogicalPlan::Filter(Filter::try_new(predicate, join.left)?));
    }
    if let Some(predicate) = conjunction(right_push) {
        join.right =
            Arc::new(LogicalPlan::Filter(Filter::try_new(predicate, join.right)?));
    }

    // The new join filter = WHERE predicates promoted to join conditions
    // (inner joins only) plus the ON conjuncts that were not pushed down.
    join_conditions.extend(on_filter_join_conditions);
    join.filter = conjunction(join_conditions);

    // wrap the join on the filter whose predicates must be kept, if any
    let plan = LogicalPlan::Join(join);
    let plan = if let Some(predicate) = conjunction(keep_predicates) {
        LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(plan))?)
    } else {
        plan
    };
    Ok(Transformed::yes(plan))
}
530
531fn push_down_join(
532    join: Join,
533    parent_predicate: Option<&Expr>,
534) -> Result<Transformed<LogicalPlan>> {
535    // Split the parent predicate into individual conjunctive parts.
536    let predicates = parent_predicate
537        .map_or_else(Vec::new, |pred| split_conjunction_owned(pred.clone()));
538
539    // Extract conjunctions from the JOIN's ON filter, if present.
540    let on_filters = join
541        .filter
542        .as_ref()
543        .map_or_else(Vec::new, |filter| split_conjunction_owned(filter.clone()));
544
545    // Are there any new join predicates that can be inferred from the filter expressions?
546    let inferred_join_predicates =
547        infer_join_predicates(&join, &predicates, &on_filters)?;
548
549    if on_filters.is_empty()
550        && predicates.is_empty()
551        && inferred_join_predicates.is_empty()
552    {
553        return Ok(Transformed::no(LogicalPlan::Join(join)));
554    }
555
556    push_down_all_join(predicates, inferred_join_predicates, join, on_filters)
557}
558
559/// Extracts any equi-join join predicates from the given filter expressions.
560///
561/// Parameters
562/// * `join` the join in question
563///
564/// * `predicates` the pushed down filter expression
565///
566/// * `on_filters` filters from the join ON clause that have not already been
567///   identified as join predicates
568fn infer_join_predicates(
569    join: &Join,
570    predicates: &[Expr],
571    on_filters: &[Expr],
572) -> Result<Vec<Expr>> {
573    // Only allow both side key is column.
574    let join_col_keys = join
575        .on
576        .iter()
577        .filter_map(|(l, r)| {
578            let left_col = l.try_as_col()?;
579            let right_col = r.try_as_col()?;
580            Some((left_col, right_col))
581        })
582        .collect::<Vec<_>>();
583
584    let join_type = join.join_type;
585
586    let mut inferred_predicates = InferredPredicates::new(join_type);
587
588    infer_join_predicates_from_predicates(
589        &join_col_keys,
590        predicates,
591        &mut inferred_predicates,
592    )?;
593
594    infer_join_predicates_from_on_filters(
595        &join_col_keys,
596        join_type,
597        on_filters,
598        &mut inferred_predicates,
599    )?;
600
601    Ok(inferred_predicates.predicates)
602}
603
604/// Inferred predicates collector.
605/// When the JoinType is not Inner, we need to detect whether the inferred predicate can strictly
606/// filter out NULL, otherwise ignore it. e.g.
607/// ```text
608/// SELECT * FROM t1 LEFT JOIN t2 ON t1.c0 = t2.c0 WHERE t2.c0 IS NULL;
609/// ```
610/// We cannot infer the predicate `t1.c0 IS NULL`, otherwise the predicate will be pushed down to
611/// the left side, resulting in the wrong result.
612struct InferredPredicates {
613    predicates: Vec<Expr>,
614    is_inner_join: bool,
615}
616
617impl InferredPredicates {
618    fn new(join_type: JoinType) -> Self {
619        Self {
620            predicates: vec![],
621            is_inner_join: join_type == JoinType::Inner,
622        }
623    }
624
625    fn try_build_predicate(
626        &mut self,
627        predicate: Expr,
628        replace_map: &HashMap<&Column, &Column>,
629    ) -> Result<()> {
630        if self.is_inner_join
631            || matches!(
632                is_restrict_null_predicate(
633                    predicate.clone(),
634                    replace_map.keys().cloned()
635                ),
636                Ok(true)
637            )
638        {
639            self.predicates.push(replace_col(predicate, replace_map)?);
640        }
641
642        Ok(())
643    }
644}
645
646/// Infer predicates from the pushed down predicates.
647///
648/// Parameters
649/// * `join_col_keys` column pairs from the join ON clause
650///
651/// * `predicates` the pushed down predicates
652///
653/// * `inferred_predicates` the inferred results
654fn infer_join_predicates_from_predicates(
655    join_col_keys: &[(&Column, &Column)],
656    predicates: &[Expr],
657    inferred_predicates: &mut InferredPredicates,
658) -> Result<()> {
659    infer_join_predicates_impl::<true, true>(
660        join_col_keys,
661        predicates,
662        inferred_predicates,
663    )
664}
665
666/// Infer predicates from the join filter.
667///
668/// Parameters
669/// * `join_col_keys` column pairs from the join ON clause
670///
671/// * `join_type` the JoinType of Join
672///
673/// * `on_filters` filters from the join ON clause that have not already been
674///   identified as join predicates
675///
676/// * `inferred_predicates` the inferred results
677fn infer_join_predicates_from_on_filters(
678    join_col_keys: &[(&Column, &Column)],
679    join_type: JoinType,
680    on_filters: &[Expr],
681    inferred_predicates: &mut InferredPredicates,
682) -> Result<()> {
683    match join_type {
684        JoinType::Full | JoinType::LeftAnti | JoinType::RightAnti => Ok(()),
685        JoinType::Inner => infer_join_predicates_impl::<true, true>(
686            join_col_keys,
687            on_filters,
688            inferred_predicates,
689        ),
690        JoinType::Left | JoinType::LeftSemi | JoinType::LeftMark => {
691            infer_join_predicates_impl::<true, false>(
692                join_col_keys,
693                on_filters,
694                inferred_predicates,
695            )
696        }
697        JoinType::Right | JoinType::RightSemi | JoinType::RightMark => {
698            infer_join_predicates_impl::<false, true>(
699                join_col_keys,
700                on_filters,
701                inferred_predicates,
702            )
703        }
704    }
705}
706
707/// Infer predicates from the given predicates.
708///
709/// Parameters
710/// * `join_col_keys` column pairs from the join ON clause
711///
712/// * `input_predicates` the given predicates. It can be the pushed down predicates,
713///   or it can be the filters of the Join
714///
715/// * `inferred_predicates` the inferred results
716///
717/// * `ENABLE_LEFT_TO_RIGHT` indicates that the right table related predicate can
718///   be inferred from the left table related predicate
719///
720/// * `ENABLE_RIGHT_TO_LEFT` indicates that the left table related predicate can
721///   be inferred from the right table related predicate
722fn infer_join_predicates_impl<
723    const ENABLE_LEFT_TO_RIGHT: bool,
724    const ENABLE_RIGHT_TO_LEFT: bool,
725>(
726    join_col_keys: &[(&Column, &Column)],
727    input_predicates: &[Expr],
728    inferred_predicates: &mut InferredPredicates,
729) -> Result<()> {
730    for predicate in input_predicates {
731        let mut join_cols_to_replace = HashMap::new();
732
733        for &col in &predicate.column_refs() {
734            for (l, r) in join_col_keys.iter() {
735                if ENABLE_LEFT_TO_RIGHT && col == *l {
736                    join_cols_to_replace.insert(col, *r);
737                    break;
738                }
739                if ENABLE_RIGHT_TO_LEFT && col == *r {
740                    join_cols_to_replace.insert(col, *l);
741                    break;
742                }
743            }
744        }
745        if join_cols_to_replace.is_empty() {
746            continue;
747        }
748
749        inferred_predicates
750            .try_build_predicate(predicate.clone(), &join_cols_to_replace)?;
751    }
752    Ok(())
753}
754
755impl OptimizerRule for PushDownFilter {
    /// Returns this rule's name, `"push_down_filter"`.
    fn name(&self) -> &str {
        "push_down_filter"
    }

    /// Requests top-down application of this rule.
    /// NOTE(review): presumably this lets a filter moved below a node be visited
    /// again further down the plan; confirm against `ApplyOrder` docs.
    fn apply_order(&self) -> Option<ApplyOrder> {
        Some(ApplyOrder::TopDown)
    }

    /// Signals that this rule implements `rewrite`.
    fn supports_rewrite(&self) -> bool {
        true
    }
767
768    fn rewrite(
769        &self,
770        plan: LogicalPlan,
771        config: &dyn OptimizerConfig,
772    ) -> Result<Transformed<LogicalPlan>> {
773        let _ = config.options();
774        if let LogicalPlan::Join(join) = plan {
775            return push_down_join(join, None);
776        };
777
778        let plan_schema = Arc::clone(plan.schema());
779
780        let LogicalPlan::Filter(mut filter) = plan else {
781            return Ok(Transformed::no(plan));
782        };
783
784        let predicate = split_conjunction_owned(filter.predicate.clone());
785        let old_predicate_len = predicate.len();
786        let new_predicates = simplify_predicates(predicate)?;
787        if old_predicate_len != new_predicates.len() {
788            let Some(new_predicate) = conjunction(new_predicates) else {
789                // new_predicates is empty - remove the filter entirely
790                // Return the child plan without the filter
791                return Ok(Transformed::yes(Arc::unwrap_or_clone(filter.input)));
792            };
793            filter.predicate = new_predicate;
794        }
795
796        // If the child has a fetch (limit) or skip (offset), pushing a filter
797        // below it would change semantics: the limit/offset should apply before
798        // the filter, not after.
799        if filter.input.fetch()?.is_some() || filter.input.skip()?.is_some() {
800            return Ok(Transformed::no(LogicalPlan::Filter(filter)));
801        }
802
803        match Arc::unwrap_or_clone(filter.input) {
804            LogicalPlan::Filter(child_filter) => {
805                let parents_predicates = split_conjunction_owned(filter.predicate);
806
807                // remove duplicated filters
808                let child_predicates = split_conjunction_owned(child_filter.predicate);
809                let new_predicates = parents_predicates
810                    .into_iter()
811                    .chain(child_predicates)
812                    // use IndexSet to remove dupes while preserving predicate order
813                    .collect::<IndexSet<_>>()
814                    .into_iter()
815                    .collect::<Vec<_>>();
816
817                let Some(new_predicate) = conjunction(new_predicates) else {
818                    return plan_err!("at least one expression exists");
819                };
820                let new_filter = LogicalPlan::Filter(Filter::try_new(
821                    new_predicate,
822                    child_filter.input,
823                )?);
824                self.rewrite(new_filter, config)
825            }
826            LogicalPlan::Repartition(repartition) => {
827                let new_filter =
828                    Filter::try_new(filter.predicate, Arc::clone(&repartition.input))
829                        .map(LogicalPlan::Filter)?;
830                insert_below(LogicalPlan::Repartition(repartition), new_filter)
831            }
832            LogicalPlan::Distinct(distinct) => {
833                let new_filter =
834                    Filter::try_new(filter.predicate, Arc::clone(distinct.input()))
835                        .map(LogicalPlan::Filter)?;
836                insert_below(LogicalPlan::Distinct(distinct), new_filter)
837            }
838            LogicalPlan::Sort(sort) => {
839                let new_filter =
840                    Filter::try_new(filter.predicate, Arc::clone(&sort.input))
841                        .map(LogicalPlan::Filter)?;
842                insert_below(LogicalPlan::Sort(sort), new_filter)
843            }
844            LogicalPlan::SubqueryAlias(subquery_alias) => {
845                let mut replace_map = HashMap::new();
846                for (i, (qualifier, field)) in
847                    subquery_alias.input.schema().iter().enumerate()
848                {
849                    let (sub_qualifier, sub_field) =
850                        subquery_alias.schema.qualified_field(i);
851                    replace_map.insert(
852                        qualified_name(sub_qualifier, sub_field.name()),
853                        Expr::Column(Column::new(qualifier.cloned(), field.name())),
854                    );
855                }
856                let new_predicate = replace_cols_by_name(filter.predicate, &replace_map)?;
857
858                let new_filter = LogicalPlan::Filter(Filter::try_new(
859                    new_predicate,
860                    Arc::clone(&subquery_alias.input),
861                )?);
862                insert_below(LogicalPlan::SubqueryAlias(subquery_alias), new_filter)
863            }
864            LogicalPlan::Projection(projection) => {
865                let predicates = split_conjunction_owned(filter.predicate.clone());
866                let (new_projection, keep_predicate) =
867                    rewrite_projection(predicates, projection)?;
868                if new_projection.transformed {
869                    match keep_predicate {
870                        None => Ok(new_projection),
871                        Some(keep_predicate) => new_projection.map_data(|child_plan| {
872                            Filter::try_new(keep_predicate, Arc::new(child_plan))
873                                .map(LogicalPlan::Filter)
874                        }),
875                    }
876                } else {
877                    filter.input = Arc::new(new_projection.data);
878                    Ok(Transformed::no(LogicalPlan::Filter(filter)))
879                }
880            }
881            LogicalPlan::Unnest(mut unnest) => {
882                let predicates = split_conjunction_owned(filter.predicate.clone());
883                let mut non_unnest_predicates = vec![];
884                let mut unnest_predicates = vec![];
885                let mut unnest_struct_columns = vec![];
886
887                for idx in &unnest.struct_type_columns {
888                    let (sub_qualifier, field) =
889                        unnest.input.schema().qualified_field(*idx);
890                    let field_name = field.name().clone();
891
892                    if let DataType::Struct(children) = field.data_type() {
893                        for child in children {
894                            let child_name = child.name().clone();
895                            unnest_struct_columns.push(Column::new(
896                                sub_qualifier.cloned(),
897                                format!("{field_name}.{child_name}"),
898                            ));
899                        }
900                    }
901                }
902
903                for predicate in predicates {
904                    // collect all the Expr::Column in predicate recursively
905                    let mut accum: HashSet<Column> = HashSet::new();
906                    expr_to_columns(&predicate, &mut accum)?;
907
908                    let contains_list_columns =
909                        unnest.list_type_columns.iter().any(|(_, unnest_list)| {
910                            accum.contains(&unnest_list.output_column)
911                        });
912                    let contains_struct_columns =
913                        unnest_struct_columns.iter().any(|c| accum.contains(c));
914
915                    if contains_list_columns || contains_struct_columns {
916                        unnest_predicates.push(predicate);
917                    } else {
918                        non_unnest_predicates.push(predicate);
919                    }
920                }
921
922                // Unnest predicates should not be pushed down.
923                // If no non-unnest predicates exist, early return
924                if non_unnest_predicates.is_empty() {
925                    filter.input = Arc::new(LogicalPlan::Unnest(unnest));
926                    return Ok(Transformed::no(LogicalPlan::Filter(filter)));
927                }
928
929                // Push down non-unnest filter predicate
930                // Unnest
931                //   Unnest Input (Projection)
932                // -> rewritten to
933                // Unnest
934                //   Filter
935                //     Unnest Input (Projection)
936
937                let unnest_input = std::mem::take(&mut unnest.input);
938
939                let filter_with_unnest_input = LogicalPlan::Filter(Filter::try_new(
940                    conjunction(non_unnest_predicates).unwrap(), // Safe to unwrap since non_unnest_predicates is not empty.
941                    unnest_input,
942                )?);
943
944                // Directly assign new filter plan as the new unnest's input.
945                // The new filter plan will go through another rewrite pass since the rule itself
946                // is applied recursively to all the child from top to down
947                let unnest_plan =
948                    insert_below(LogicalPlan::Unnest(unnest), filter_with_unnest_input)?;
949
950                match conjunction(unnest_predicates) {
951                    None => Ok(unnest_plan),
952                    Some(predicate) => Ok(Transformed::yes(LogicalPlan::Filter(
953                        Filter::try_new(predicate, Arc::new(unnest_plan.data))?,
954                    ))),
955                }
956            }
957            LogicalPlan::Union(ref union) => {
958                let mut inputs = Vec::with_capacity(union.inputs.len());
959                for input in &union.inputs {
960                    let mut replace_map = HashMap::new();
961                    for (i, (qualifier, field)) in input.schema().iter().enumerate() {
962                        let (union_qualifier, union_field) =
963                            union.schema.qualified_field(i);
964                        replace_map.insert(
965                            qualified_name(union_qualifier, union_field.name()),
966                            Expr::Column(Column::new(qualifier.cloned(), field.name())),
967                        );
968                    }
969
970                    let push_predicate =
971                        replace_cols_by_name(filter.predicate.clone(), &replace_map)?;
972                    inputs.push(Arc::new(LogicalPlan::Filter(Filter::try_new(
973                        push_predicate,
974                        Arc::clone(input),
975                    )?)))
976                }
977                Ok(Transformed::yes(LogicalPlan::Union(Union {
978                    inputs,
979                    schema: Arc::clone(&plan_schema),
980                })))
981            }
982            LogicalPlan::Aggregate(agg) => {
983                // We can push down Predicate which in groupby_expr.
984                let group_expr_columns = agg
985                    .group_expr
986                    .iter()
987                    .map(|e| {
988                        let (relation, name) = e.qualified_name();
989                        Column::new(relation, name)
990                    })
991                    .collect::<HashSet<_>>();
992
993                let predicates = split_conjunction_owned(filter.predicate);
994
995                let mut keep_predicates = vec![];
996                let mut push_predicates = vec![];
997                for expr in predicates {
998                    let cols = expr.column_refs();
999                    if cols.iter().all(|c| group_expr_columns.contains(c)) {
1000                        push_predicates.push(expr);
1001                    } else {
1002                        keep_predicates.push(expr);
1003                    }
1004                }
1005
1006                // As for plan Filter: Column(a+b) > 0 -- Agg: groupby:[Column(a)+Column(b)]
1007                // After push, we need to replace `a+b` with Column(a)+Column(b)
1008                // So we need create a replace_map, add {`a+b` --> Expr(Column(a)+Column(b))}
1009                let mut replace_map = HashMap::new();
1010                for expr in &agg.group_expr {
1011                    replace_map.insert(expr.schema_name().to_string(), expr.clone());
1012                }
1013                let replaced_push_predicates = push_predicates
1014                    .into_iter()
1015                    .map(|expr| replace_cols_by_name(expr, &replace_map))
1016                    .collect::<Result<Vec<_>>>()?;
1017
1018                let agg_input = Arc::clone(&agg.input);
1019                Transformed::yes(LogicalPlan::Aggregate(agg))
1020                    .transform_data(|new_plan| {
1021                        // If we have a filter to push, we push it down to the input of the aggregate
1022                        if let Some(predicate) = conjunction(replaced_push_predicates) {
1023                            let new_filter = make_filter(predicate, agg_input)?;
1024                            insert_below(new_plan, new_filter)
1025                        } else {
1026                            Ok(Transformed::no(new_plan))
1027                        }
1028                    })?
1029                    .map_data(|child_plan| {
1030                        // if there are any remaining predicates we can't push, add them
1031                        // back as a filter
1032                        if let Some(predicate) = conjunction(keep_predicates) {
1033                            make_filter(predicate, Arc::new(child_plan))
1034                        } else {
1035                            Ok(child_plan)
1036                        }
1037                    })
1038            }
1039            // Tries to push filters based on the partition key(s) of the window function(s) used.
1040            // Example:
1041            //   Before:
1042            //     Filter: (a > 1) and (b > 1) and (c > 1)
1043            //      Window: func() PARTITION BY [a] ...
1044            //   ---
1045            //   After:
1046            //     Filter: (b > 1) and (c > 1)
1047            //      Window: func() PARTITION BY [a] ...
1048            //        Filter: (a > 1)
1049            LogicalPlan::Window(window) => {
1050                // Retrieve the set of potential partition keys where we can push filters by.
1051                // Unlike aggregations, where there is only one statement per SELECT, there can be
1052                // multiple window functions, each with potentially different partition keys.
1053                // Therefore, we need to ensure that any potential partition key returned is used in
1054                // ALL window functions. Otherwise, filters cannot be pushed by through that column.
1055                let extract_partition_keys = |func: &WindowFunction| {
1056                    func.params
1057                        .partition_by
1058                        .iter()
1059                        .map(|c| {
1060                            let (relation, name) = c.qualified_name();
1061                            Column::new(relation, name)
1062                        })
1063                        .collect::<HashSet<_>>()
1064                };
1065                let potential_partition_keys = window
1066                    .window_expr
1067                    .iter()
1068                    .map(|e| {
1069                        match e {
1070                            Expr::WindowFunction(window_func) => {
1071                                extract_partition_keys(window_func)
1072                            }
1073                            Expr::Alias(alias) => {
1074                                if let Expr::WindowFunction(window_func) =
1075                                    alias.expr.as_ref()
1076                                {
1077                                    extract_partition_keys(window_func)
1078                                } else {
1079                                    // window functions expressions are only Expr::WindowFunction
1080                                    unreachable!()
1081                                }
1082                            }
1083                            _ => {
1084                                // window functions expressions are only Expr::WindowFunction
1085                                unreachable!()
1086                            }
1087                        }
1088                    })
1089                    // performs the set intersection of the partition keys of all window functions,
1090                    // returning only the common ones
1091                    .reduce(|a, b| &a & &b)
1092                    .unwrap_or_default();
1093
1094                let predicates = split_conjunction_owned(filter.predicate);
1095                let mut keep_predicates = vec![];
1096                let mut push_predicates = vec![];
1097                for expr in predicates {
1098                    let cols = expr.column_refs();
1099                    if cols.iter().all(|c| potential_partition_keys.contains(c)) {
1100                        push_predicates.push(expr);
1101                    } else {
1102                        keep_predicates.push(expr);
1103                    }
1104                }
1105
1106                // Unlike with aggregations, there are no cases where we have to replace, e.g.,
1107                // `a+b` with Column(a)+Column(b). This is because partition expressions are not
1108                // available as standalone columns to the user. For example, while an aggregation on
1109                // `a+b` becomes Column(a + b), in a window partition it becomes
1110                // `func() PARTITION BY [a + b] ...`. Thus, filters on expressions always remain in
1111                // place, so we can use `push_predicates` directly. This is consistent with other
1112                // optimizers, such as the one used by Postgres.
1113
1114                let window_input = Arc::clone(&window.input);
1115                Transformed::yes(LogicalPlan::Window(window))
1116                    .transform_data(|new_plan| {
1117                        // If we have a filter to push, we push it down to the input of the window
1118                        if let Some(predicate) = conjunction(push_predicates) {
1119                            let new_filter = make_filter(predicate, window_input)?;
1120                            insert_below(new_plan, new_filter)
1121                        } else {
1122                            Ok(Transformed::no(new_plan))
1123                        }
1124                    })?
1125                    .map_data(|child_plan| {
1126                        // if there are any remaining predicates we can't push, add them
1127                        // back as a filter
1128                        if let Some(predicate) = conjunction(keep_predicates) {
1129                            make_filter(predicate, Arc::new(child_plan))
1130                        } else {
1131                            Ok(child_plan)
1132                        }
1133                    })
1134            }
1135            LogicalPlan::Join(join) => push_down_join(join, Some(&filter.predicate)),
1136            LogicalPlan::TableScan(scan) => {
1137                let filter_predicates = split_conjunction(&filter.predicate);
1138
1139                let (volatile_filters, non_volatile_filters): (Vec<&Expr>, Vec<&Expr>) =
1140                    filter_predicates
1141                        .into_iter()
1142                        .partition(|pred| pred.is_volatile());
1143
1144                // Check which non-volatile filters are supported by source
1145                let supported_filters = scan
1146                    .source
1147                    .supports_filters_pushdown(non_volatile_filters.as_slice())?;
1148                assert_eq_or_internal_err!(
1149                    non_volatile_filters.len(),
1150                    supported_filters.len(),
1151                    "Vec returned length: {} from supports_filters_pushdown is not the same size as the filters passed, which length is: {}",
1152                    supported_filters.len(),
1153                    non_volatile_filters.len()
1154                );
1155
1156                // Compose scan filters from non-volatile filters of `Exact` or `Inexact` pushdown type
1157                let zip = non_volatile_filters.into_iter().zip(supported_filters);
1158
1159                let new_scan_filters = zip
1160                    .clone()
1161                    .filter(|(_, res)| res != &TableProviderFilterPushDown::Unsupported)
1162                    .map(|(pred, _)| pred);
1163
1164                // Add new scan filters
1165                let new_scan_filters: Vec<Expr> = scan
1166                    .filters
1167                    .iter()
1168                    .chain(new_scan_filters)
1169                    .unique()
1170                    .cloned()
1171                    .collect();
1172
1173                // Compose predicates to be of `Unsupported` or `Inexact` pushdown type, and also include volatile filters
1174                let new_predicate: Vec<Expr> = zip
1175                    .filter(|(_, res)| res != &TableProviderFilterPushDown::Exact)
1176                    .map(|(pred, _)| pred)
1177                    .chain(volatile_filters)
1178                    .cloned()
1179                    .collect();
1180
1181                let new_scan = LogicalPlan::TableScan(TableScan {
1182                    filters: new_scan_filters,
1183                    ..scan
1184                });
1185
1186                Transformed::yes(new_scan).transform_data(|new_scan| {
1187                    if let Some(predicate) = conjunction(new_predicate) {
1188                        make_filter(predicate, Arc::new(new_scan)).map(Transformed::yes)
1189                    } else {
1190                        Ok(Transformed::no(new_scan))
1191                    }
1192                })
1193            }
1194            LogicalPlan::Extension(extension_plan) => {
1195                // This check prevents the Filter from being removed when the extension node has no children,
1196                // so we return the original Filter unchanged.
1197                if extension_plan.node.inputs().is_empty() {
1198                    filter.input = Arc::new(LogicalPlan::Extension(extension_plan));
1199                    return Ok(Transformed::no(LogicalPlan::Filter(filter)));
1200                }
1201                let prevent_cols =
1202                    extension_plan.node.prevent_predicate_push_down_columns();
1203
1204                // determine if we can push any predicates down past the extension node
1205
1206                // each element is true for push, false to keep
1207                let predicate_push_or_keep = split_conjunction(&filter.predicate)
1208                    .iter()
1209                    .map(|expr| {
1210                        let cols = expr.column_refs();
1211                        if cols.iter().any(|c| prevent_cols.contains(&c.name)) {
1212                            Ok(false) // No push (keep)
1213                        } else {
1214                            Ok(true) // push
1215                        }
1216                    })
1217                    .collect::<Result<Vec<_>>>()?;
1218
1219                // all predicates are kept, no changes needed
1220                if predicate_push_or_keep.iter().all(|&x| !x) {
1221                    filter.input = Arc::new(LogicalPlan::Extension(extension_plan));
1222                    return Ok(Transformed::no(LogicalPlan::Filter(filter)));
1223                }
1224
1225                // going to push some predicates down, so split the predicates
1226                let mut keep_predicates = vec![];
1227                let mut push_predicates = vec![];
1228                for (push, expr) in predicate_push_or_keep
1229                    .into_iter()
1230                    .zip(split_conjunction_owned(filter.predicate).into_iter())
1231                {
1232                    if !push {
1233                        keep_predicates.push(expr);
1234                    } else {
1235                        push_predicates.push(expr);
1236                    }
1237                }
1238
1239                let new_children = match conjunction(push_predicates) {
1240                    Some(predicate) => extension_plan
1241                        .node
1242                        .inputs()
1243                        .into_iter()
1244                        .map(|child| {
1245                            Ok(LogicalPlan::Filter(Filter::try_new(
1246                                predicate.clone(),
1247                                Arc::new(child.clone()),
1248                            )?))
1249                        })
1250                        .collect::<Result<Vec<_>>>()?,
1251                    None => extension_plan.node.inputs().into_iter().cloned().collect(),
1252                };
1253                // extension with new inputs.
1254                let child_plan = LogicalPlan::Extension(extension_plan);
1255                let new_extension =
1256                    child_plan.with_new_exprs(child_plan.expressions(), new_children)?;
1257
1258                let new_plan = match conjunction(keep_predicates) {
1259                    Some(predicate) => LogicalPlan::Filter(Filter::try_new(
1260                        predicate,
1261                        Arc::new(new_extension),
1262                    )?),
1263                    None => new_extension,
1264                };
1265                Ok(Transformed::yes(new_plan))
1266            }
1267            child => {
1268                filter.input = Arc::new(child);
1269                Ok(Transformed::no(LogicalPlan::Filter(filter)))
1270            }
1271        }
1272    }
1273}
1274
1275/// Attempts to push `predicate` into a `FilterExec` below `projection
1276///
1277/// # Returns
1278/// (plan, remaining_predicate)
1279///
1280/// `plan` is a LogicalPlan for `projection` with possibly a new FilterExec below it.
1281/// `remaining_predicate` is any part of the predicate that could not be pushed down
1282///
1283/// # Args
1284/// - predicates: Split predicates like `[foo=5, bar=6]`
1285/// - projection: The target projection plan to push down the predicates
1286///
1287/// # Example
1288///
1289/// Pushing a predicate like `foo=5 AND bar=6` with an input plan like this:
1290///
1291/// ```text
1292/// Projection(foo, c+d as bar)
1293/// ```
1294///
1295/// Might result in returning `remaining_predicate` of `bar=6` and a plan like
1296///
1297/// ```text
1298/// Projection(foo, c+d as bar)
1299///  Filter(foo=5)
1300///   ...
1301/// ```
1302fn rewrite_projection(
1303    predicates: Vec<Expr>,
1304    mut projection: Projection,
1305) -> Result<(Transformed<LogicalPlan>, Option<Expr>)> {
1306    // Partition projection expressions into non-pushable vs pushable.
1307    // Non-pushable expressions are volatile (must not be duplicated) or
1308    // MoveTowardsLeafNodes (cheap expressions like get_field where re-inlining
1309    // into a filter causes optimizer instability — ExtractLeafExpressions will
1310    // undo the push-down, creating an infinite loop that runs until the
1311    // iteration limit is hit).
1312    let (non_pushable_map, pushable_map): (HashMap<_, _>, HashMap<_, _>) = projection
1313        .schema
1314        .iter()
1315        .zip(projection.expr.iter())
1316        .map(|((qualifier, field), expr)| {
1317            // strip alias, as they should not be part of filters
1318            let expr = expr.clone().unalias();
1319
1320            (qualified_name(qualifier, field.name()), expr)
1321        })
1322        .partition(|(_, value)| {
1323            value.is_volatile()
1324                || value.placement() == ExpressionPlacement::MoveTowardsLeafNodes
1325        });
1326
1327    let mut push_predicates = vec![];
1328    let mut keep_predicates = vec![];
1329    for expr in predicates {
1330        if contain(&expr, &non_pushable_map) {
1331            keep_predicates.push(expr);
1332        } else {
1333            push_predicates.push(expr);
1334        }
1335    }
1336
1337    match conjunction(push_predicates) {
1338        Some(expr) => {
1339            // re-write all filters based on this projection
1340            // E.g. in `Filter: b\n  Projection: a > 1 as b`, we can swap them, but the filter must be "a > 1"
1341            let new_filter = LogicalPlan::Filter(Filter::try_new(
1342                replace_cols_by_name(expr, &pushable_map)?,
1343                std::mem::take(&mut projection.input),
1344            )?);
1345
1346            projection.input = Arc::new(new_filter);
1347
1348            Ok((
1349                Transformed::yes(LogicalPlan::Projection(projection)),
1350                conjunction(keep_predicates),
1351            ))
1352        }
1353        None => Ok((
1354            Transformed::no(LogicalPlan::Projection(projection)),
1355            conjunction(keep_predicates),
1356        )),
1357    }
1358}
1359
1360/// Creates a new LogicalPlan::Filter node.
1361pub fn make_filter(predicate: Expr, input: Arc<LogicalPlan>) -> Result<LogicalPlan> {
1362    Filter::try_new(predicate, input).map(LogicalPlan::Filter)
1363}
1364
1365/// Replace the existing child of the single input node with `new_child`.
1366///
1367/// Starting:
1368/// ```text
1369/// plan
1370///   child
1371/// ```
1372///
1373/// Ending:
1374/// ```text
1375/// plan
1376///   new_child
1377/// ```
1378fn insert_below(
1379    plan: LogicalPlan,
1380    new_child: LogicalPlan,
1381) -> Result<Transformed<LogicalPlan>> {
1382    let mut new_child = Some(new_child);
1383    let transformed_plan = plan.map_children(|_child| {
1384        if let Some(new_child) = new_child.take() {
1385            Ok(Transformed::yes(new_child))
1386        } else {
1387            // already took the new child
1388            internal_err!("node had more than one input")
1389        }
1390    })?;
1391
1392    // make sure we did the actual replacement
1393    assert_or_internal_err!(new_child.is_none(), "node had no inputs");
1394
1395    Ok(transformed_plan)
1396}
1397
1398impl PushDownFilter {
1399    #[expect(missing_docs)]
1400    pub fn new() -> Self {
1401        Self {}
1402    }
1403}
1404
1405/// replaces columns by its name on the projection.
1406pub fn replace_cols_by_name(
1407    e: Expr,
1408    replace_map: &HashMap<String, Expr>,
1409) -> Result<Expr> {
1410    e.transform_up(|expr| {
1411        Ok(if let Expr::Column(c) = &expr {
1412            match replace_map.get(&c.flat_name()) {
1413                Some(new_c) => Transformed::yes(new_c.clone()),
1414                None => Transformed::no(expr),
1415            }
1416        } else {
1417            Transformed::no(expr)
1418        })
1419    })
1420    .data()
1421}
1422
1423/// check whether the expression uses the columns in `check_map`.
1424fn contain(e: &Expr, check_map: &HashMap<String, Expr>) -> bool {
1425    let mut is_contain = false;
1426    e.apply(|expr| {
1427        Ok(if let Expr::Column(c) = &expr {
1428            match check_map.get(&c.flat_name()) {
1429                Some(_) => {
1430                    is_contain = true;
1431                    TreeNodeRecursion::Stop
1432                }
1433                None => TreeNodeRecursion::Continue,
1434            }
1435        } else {
1436            TreeNodeRecursion::Continue
1437        })
1438    })
1439    .unwrap();
1440    is_contain
1441}
1442
1443#[cfg(test)]
1444mod tests {
1445    use std::any::Any;
1446    use std::cmp::Ordering;
1447    use std::fmt::{Debug, Formatter};
1448
1449    use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
1450    use async_trait::async_trait;
1451
1452    use datafusion_common::{DFSchemaRef, DataFusionError, ScalarValue};
1453    use datafusion_expr::expr::{ScalarFunction, WindowFunction};
1454    use datafusion_expr::logical_plan::table_scan;
1455    use datafusion_expr::{
1456        ColumnarValue, ExprFunctionExt, Extension, LogicalPlanBuilder,
1457        ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, TableSource, TableType,
1458        UserDefinedLogicalNodeCore, Volatility, WindowFunctionDefinition, col, in_list,
1459        in_subquery, lit,
1460    };
1461
1462    use crate::OptimizerContext;
1463    use crate::assert_optimized_plan_eq_snapshot;
1464    use crate::optimizer::Optimizer;
1465    use crate::simplify_expressions::SimplifyExpressions;
1466    use crate::test::udfs::leaf_udf_expr;
1467    use crate::test::*;
1468    use datafusion_expr::test::function_stub::sum;
1469    use insta::assert_snapshot;
1470
1471    use super::*;
1472
1473    fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {}
1474
1475    macro_rules! assert_optimized_plan_equal {
1476        (
1477            $plan:expr,
1478            @ $expected:literal $(,)?
1479        ) => {{
1480            let optimizer_ctx = OptimizerContext::new().with_max_passes(1);
1481            let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> = vec![Arc::new(PushDownFilter::new())];
1482            assert_optimized_plan_eq_snapshot!(
1483                optimizer_ctx,
1484                rules,
1485                $plan,
1486                @ $expected,
1487            )
1488        }};
1489    }
1490
1491    macro_rules! assert_optimized_plan_eq_with_rewrite_predicate {
1492        (
1493            $plan:expr,
1494            @ $expected:literal $(,)?
1495        ) => {{
1496            let optimizer = Optimizer::with_rules(vec![
1497                Arc::new(SimplifyExpressions::new()),
1498                Arc::new(PushDownFilter::new()),
1499            ]);
1500            let optimized_plan = optimizer.optimize($plan, &OptimizerContext::new(), observe)?;
1501            assert_snapshot!(optimized_plan, @ $expected);
1502            Ok::<(), DataFusionError>(())
1503        }};
1504    }
1505
1506    #[test]
1507    fn filter_before_projection() -> Result<()> {
1508        let table_scan = test_table_scan()?;
1509        let plan = LogicalPlanBuilder::from(table_scan)
1510            .project(vec![col("a"), col("b")])?
1511            .filter(col("a").eq(lit(1i64)))?
1512            .build()?;
1513        // filter is before projection
1514        assert_optimized_plan_equal!(
1515            plan,
1516            @r"
1517        Projection: test.a, test.b
1518          TableScan: test, full_filters=[test.a = Int64(1)]
1519        "
1520        )
1521    }
1522
1523    #[test]
1524    fn filter_after_limit() -> Result<()> {
1525        let table_scan = test_table_scan()?;
1526        let plan = LogicalPlanBuilder::from(table_scan)
1527            .project(vec![col("a"), col("b")])?
1528            .limit(0, Some(10))?
1529            .filter(col("a").eq(lit(1i64)))?
1530            .build()?;
1531        // filter is before single projection
1532        assert_optimized_plan_equal!(
1533            plan,
1534            @r"
1535        Filter: test.a = Int64(1)
1536          Limit: skip=0, fetch=10
1537            Projection: test.a, test.b
1538              TableScan: test
1539        "
1540        )
1541    }
1542
1543    #[test]
1544    fn filter_no_columns() -> Result<()> {
1545        let table_scan = test_table_scan()?;
1546        let plan = LogicalPlanBuilder::from(table_scan)
1547            .filter(lit(0i64).eq(lit(1i64)))?
1548            .build()?;
1549        assert_optimized_plan_equal!(
1550            plan,
1551            @"TableScan: test, full_filters=[Int64(0) = Int64(1)]"
1552        )
1553    }
1554
1555    #[test]
1556    fn filter_jump_2_plans() -> Result<()> {
1557        let table_scan = test_table_scan()?;
1558        let plan = LogicalPlanBuilder::from(table_scan)
1559            .project(vec![col("a"), col("b"), col("c")])?
1560            .project(vec![col("c"), col("b")])?
1561            .filter(col("a").eq(lit(1i64)))?
1562            .build()?;
1563        // filter is before double projection
1564        assert_optimized_plan_equal!(
1565            plan,
1566            @r"
1567        Projection: test.c, test.b
1568          Projection: test.a, test.b, test.c
1569            TableScan: test, full_filters=[test.a = Int64(1)]
1570        "
1571        )
1572    }
1573
1574    #[test]
1575    fn filter_move_agg() -> Result<()> {
1576        let table_scan = test_table_scan()?;
1577        let plan = LogicalPlanBuilder::from(table_scan)
1578            .aggregate(vec![col("a")], vec![sum(col("b")).alias("total_salary")])?
1579            .filter(col("a").gt(lit(10i64)))?
1580            .build()?;
1581        // filter of key aggregation is commutative
1582        assert_optimized_plan_equal!(
1583            plan,
1584            @r"
1585        Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b) AS total_salary]]
1586          TableScan: test, full_filters=[test.a > Int64(10)]
1587        "
1588        )
1589    }
1590
1591    /// verifies that filters with unusual column names are pushed down through aggregate operators
1592    #[test]
1593    fn filter_move_agg_special() -> Result<()> {
1594        let schema = Schema::new(vec![
1595            Field::new("$a", DataType::UInt32, false),
1596            Field::new("$b", DataType::UInt32, false),
1597            Field::new("$c", DataType::UInt32, false),
1598        ]);
1599        let table_scan = table_scan(Some("test"), &schema, None)?.build()?;
1600
1601        let plan = LogicalPlanBuilder::from(table_scan)
1602            .aggregate(vec![col("$a")], vec![sum(col("$b")).alias("total_salary")])?
1603            .filter(col("$a").gt(lit(10i64)))?
1604            .build()?;
1605        // filter of key aggregation is commutative
1606        assert_optimized_plan_equal!(
1607            plan,
1608            @r"
1609        Aggregate: groupBy=[[test.$a]], aggr=[[sum(test.$b) AS total_salary]]
1610          TableScan: test, full_filters=[test.$a > Int64(10)]
1611        "
1612        )
1613    }
1614
1615    #[test]
1616    fn filter_complex_group_by() -> Result<()> {
1617        let table_scan = test_table_scan()?;
1618        let plan = LogicalPlanBuilder::from(table_scan)
1619            .aggregate(vec![add(col("b"), col("a"))], vec![sum(col("a")), col("b")])?
1620            .filter(col("b").gt(lit(10i64)))?
1621            .build()?;
1622        assert_optimized_plan_equal!(
1623            plan,
1624            @r"
1625        Filter: test.b > Int64(10)
1626          Aggregate: groupBy=[[test.b + test.a]], aggr=[[sum(test.a), test.b]]
1627            TableScan: test
1628        "
1629        )
1630    }
1631
1632    #[test]
1633    fn push_agg_need_replace_expr() -> Result<()> {
1634        let plan = LogicalPlanBuilder::from(test_table_scan()?)
1635            .aggregate(vec![add(col("b"), col("a"))], vec![sum(col("a")), col("b")])?
1636            .filter(col("test.b + test.a").gt(lit(10i64)))?
1637            .build()?;
1638        assert_optimized_plan_equal!(
1639            plan,
1640            @r"
1641        Aggregate: groupBy=[[test.b + test.a]], aggr=[[sum(test.a), test.b]]
1642          TableScan: test, full_filters=[test.b + test.a > Int64(10)]
1643        "
1644        )
1645    }
1646
1647    #[test]
1648    fn filter_keep_agg() -> Result<()> {
1649        let table_scan = test_table_scan()?;
1650        let plan = LogicalPlanBuilder::from(table_scan)
1651            .aggregate(vec![col("a")], vec![sum(col("b")).alias("b")])?
1652            .filter(col("b").gt(lit(10i64)))?
1653            .build()?;
1654        // filter of aggregate is after aggregation since they are non-commutative
1655        assert_optimized_plan_equal!(
1656            plan,
1657            @r"
1658        Filter: b > Int64(10)
1659          Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b) AS b]]
1660            TableScan: test
1661        "
1662        )
1663    }
1664
1665    /// verifies that when partitioning by 'a' and 'b', and filtering by 'b', 'b' is pushed
1666    #[test]
1667    fn filter_move_window() -> Result<()> {
1668        let table_scan = test_table_scan()?;
1669
1670        let window = Expr::from(WindowFunction::new(
1671            WindowFunctionDefinition::WindowUDF(
1672                datafusion_functions_window::rank::rank_udwf(),
1673            ),
1674            vec![],
1675        ))
1676        .partition_by(vec![col("a"), col("b")])
1677        .order_by(vec![col("c").sort(true, true)])
1678        .build()
1679        .unwrap();
1680
1681        let plan = LogicalPlanBuilder::from(table_scan)
1682            .window(vec![window])?
1683            .filter(col("b").gt(lit(10i64)))?
1684            .build()?;
1685
1686        assert_optimized_plan_equal!(
1687            plan,
1688            @r"
1689        WindowAggr: windowExpr=[[rank() PARTITION BY [test.a, test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1690          TableScan: test, full_filters=[test.b > Int64(10)]
1691        "
1692        )
1693    }
1694
1695    /// verifies that filters with unusual identifier names are pushed down through window functions
1696    #[test]
1697    fn filter_window_special_identifier() -> Result<()> {
1698        let schema = Schema::new(vec![
1699            Field::new("$a", DataType::UInt32, false),
1700            Field::new("$b", DataType::UInt32, false),
1701            Field::new("$c", DataType::UInt32, false),
1702        ]);
1703        let table_scan = table_scan(Some("test"), &schema, None)?.build()?;
1704
1705        let window = Expr::from(WindowFunction::new(
1706            WindowFunctionDefinition::WindowUDF(
1707                datafusion_functions_window::rank::rank_udwf(),
1708            ),
1709            vec![],
1710        ))
1711        .partition_by(vec![col("$a"), col("$b")])
1712        .order_by(vec![col("$c").sort(true, true)])
1713        .build()
1714        .unwrap();
1715
1716        let plan = LogicalPlanBuilder::from(table_scan)
1717            .window(vec![window])?
1718            .filter(col("$b").gt(lit(10i64)))?
1719            .build()?;
1720
1721        assert_optimized_plan_equal!(
1722            plan,
1723            @r"
1724        WindowAggr: windowExpr=[[rank() PARTITION BY [test.$a, test.$b] ORDER BY [test.$c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1725          TableScan: test, full_filters=[test.$b > Int64(10)]
1726        "
1727        )
1728    }
1729
1730    /// verifies that when partitioning by 'a' and 'b', and filtering by 'a' and 'b', both 'a' and
1731    /// 'b' are pushed
1732    #[test]
1733    fn filter_move_complex_window() -> Result<()> {
1734        let table_scan = test_table_scan()?;
1735
1736        let window = Expr::from(WindowFunction::new(
1737            WindowFunctionDefinition::WindowUDF(
1738                datafusion_functions_window::rank::rank_udwf(),
1739            ),
1740            vec![],
1741        ))
1742        .partition_by(vec![col("a"), col("b")])
1743        .order_by(vec![col("c").sort(true, true)])
1744        .build()
1745        .unwrap();
1746
1747        let plan = LogicalPlanBuilder::from(table_scan)
1748            .window(vec![window])?
1749            .filter(and(col("a").gt(lit(10i64)), col("b").eq(lit(1i64))))?
1750            .build()?;
1751
1752        assert_optimized_plan_equal!(
1753            plan,
1754            @r"
1755        WindowAggr: windowExpr=[[rank() PARTITION BY [test.a, test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1756          TableScan: test, full_filters=[test.a > Int64(10), test.b = Int64(1)]
1757        "
1758        )
1759    }
1760
1761    /// verifies that when partitioning by 'a' and filtering by 'a' and 'b', only 'a' is pushed
1762    #[test]
1763    fn filter_move_partial_window() -> Result<()> {
1764        let table_scan = test_table_scan()?;
1765
1766        let window = Expr::from(WindowFunction::new(
1767            WindowFunctionDefinition::WindowUDF(
1768                datafusion_functions_window::rank::rank_udwf(),
1769            ),
1770            vec![],
1771        ))
1772        .partition_by(vec![col("a")])
1773        .order_by(vec![col("c").sort(true, true)])
1774        .build()
1775        .unwrap();
1776
1777        let plan = LogicalPlanBuilder::from(table_scan)
1778            .window(vec![window])?
1779            .filter(and(col("a").gt(lit(10i64)), col("b").eq(lit(1i64))))?
1780            .build()?;
1781
1782        assert_optimized_plan_equal!(
1783            plan,
1784            @r"
1785        Filter: test.b = Int64(1)
1786          WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1787            TableScan: test, full_filters=[test.a > Int64(10)]
1788        "
1789        )
1790    }
1791
1792    /// verifies that filters on partition expressions are not pushed, as the single expression
1793    /// column is not available to the user, unlike with aggregations
1794    #[test]
1795    fn filter_expression_keep_window() -> Result<()> {
1796        let table_scan = test_table_scan()?;
1797
1798        let window = Expr::from(WindowFunction::new(
1799            WindowFunctionDefinition::WindowUDF(
1800                datafusion_functions_window::rank::rank_udwf(),
1801            ),
1802            vec![],
1803        ))
1804        .partition_by(vec![add(col("a"), col("b"))]) // PARTITION BY a + b
1805        .order_by(vec![col("c").sort(true, true)])
1806        .build()
1807        .unwrap();
1808
1809        let plan = LogicalPlanBuilder::from(table_scan)
1810            .window(vec![window])?
1811            // unlike with aggregations, single partition column "test.a + test.b" is not available
1812            // to the plan, so we use multiple columns when filtering
1813            .filter(add(col("a"), col("b")).gt(lit(10i64)))?
1814            .build()?;
1815
1816        assert_optimized_plan_equal!(
1817            plan,
1818            @r"
1819        Filter: test.a + test.b > Int64(10)
1820          WindowAggr: windowExpr=[[rank() PARTITION BY [test.a + test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1821            TableScan: test
1822        "
1823        )
1824    }
1825
1826    /// verifies that filters are not pushed on order by columns (that are not used in partitioning)
1827    #[test]
1828    fn filter_order_keep_window() -> Result<()> {
1829        let table_scan = test_table_scan()?;
1830
1831        let window = Expr::from(WindowFunction::new(
1832            WindowFunctionDefinition::WindowUDF(
1833                datafusion_functions_window::rank::rank_udwf(),
1834            ),
1835            vec![],
1836        ))
1837        .partition_by(vec![col("a")])
1838        .order_by(vec![col("c").sort(true, true)])
1839        .build()
1840        .unwrap();
1841
1842        let plan = LogicalPlanBuilder::from(table_scan)
1843            .window(vec![window])?
1844            .filter(col("c").gt(lit(10i64)))?
1845            .build()?;
1846
1847        assert_optimized_plan_equal!(
1848            plan,
1849            @r"
1850        Filter: test.c > Int64(10)
1851          WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1852            TableScan: test
1853        "
1854        )
1855    }
1856
1857    /// verifies that when we use multiple window functions with a common partition key, the filter
1858    /// on that key is pushed
1859    #[test]
1860    fn filter_multiple_windows_common_partitions() -> Result<()> {
1861        let table_scan = test_table_scan()?;
1862
1863        let window1 = Expr::from(WindowFunction::new(
1864            WindowFunctionDefinition::WindowUDF(
1865                datafusion_functions_window::rank::rank_udwf(),
1866            ),
1867            vec![],
1868        ))
1869        .partition_by(vec![col("a")])
1870        .order_by(vec![col("c").sort(true, true)])
1871        .build()
1872        .unwrap();
1873
1874        let window2 = Expr::from(WindowFunction::new(
1875            WindowFunctionDefinition::WindowUDF(
1876                datafusion_functions_window::rank::rank_udwf(),
1877            ),
1878            vec![],
1879        ))
1880        .partition_by(vec![col("b"), col("a")])
1881        .order_by(vec![col("c").sort(true, true)])
1882        .build()
1883        .unwrap();
1884
1885        let plan = LogicalPlanBuilder::from(table_scan)
1886            .window(vec![window1, window2])?
1887            .filter(col("a").gt(lit(10i64)))? // a appears in both window functions
1888            .build()?;
1889
1890        assert_optimized_plan_equal!(
1891            plan,
1892            @r"
1893        WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, rank() PARTITION BY [test.b, test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1894          TableScan: test, full_filters=[test.a > Int64(10)]
1895        "
1896        )
1897    }
1898
1899    /// verifies that when we use multiple window functions with different partitions keys, the
1900    /// filter cannot be pushed
1901    #[test]
1902    fn filter_multiple_windows_disjoint_partitions() -> Result<()> {
1903        let table_scan = test_table_scan()?;
1904
1905        let window1 = Expr::from(WindowFunction::new(
1906            WindowFunctionDefinition::WindowUDF(
1907                datafusion_functions_window::rank::rank_udwf(),
1908            ),
1909            vec![],
1910        ))
1911        .partition_by(vec![col("a")])
1912        .order_by(vec![col("c").sort(true, true)])
1913        .build()
1914        .unwrap();
1915
1916        let window2 = Expr::from(WindowFunction::new(
1917            WindowFunctionDefinition::WindowUDF(
1918                datafusion_functions_window::rank::rank_udwf(),
1919            ),
1920            vec![],
1921        ))
1922        .partition_by(vec![col("b"), col("a")])
1923        .order_by(vec![col("c").sort(true, true)])
1924        .build()
1925        .unwrap();
1926
1927        let plan = LogicalPlanBuilder::from(table_scan)
1928            .window(vec![window1, window2])?
1929            .filter(col("b").gt(lit(10i64)))? // b only appears in one window function
1930            .build()?;
1931
1932        assert_optimized_plan_equal!(
1933            plan,
1934            @r"
1935        Filter: test.b > Int64(10)
1936          WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, rank() PARTITION BY [test.b, test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1937            TableScan: test
1938        "
1939        )
1940    }
1941
1942    /// verifies that a filter is pushed to before a projection, the filter expression is correctly re-written
1943    #[test]
1944    fn alias() -> Result<()> {
1945        let table_scan = test_table_scan()?;
1946        let plan = LogicalPlanBuilder::from(table_scan)
1947            .project(vec![col("a").alias("b"), col("c")])?
1948            .filter(col("b").eq(lit(1i64)))?
1949            .build()?;
1950        // filter is before projection
1951        assert_optimized_plan_equal!(
1952            plan,
1953            @r"
1954        Projection: test.a AS b, test.c
1955          TableScan: test, full_filters=[test.a = Int64(1)]
1956        "
1957        )
1958    }
1959
1960    fn add(left: Expr, right: Expr) -> Expr {
1961        Expr::BinaryExpr(BinaryExpr::new(
1962            Box::new(left),
1963            Operator::Plus,
1964            Box::new(right),
1965        ))
1966    }
1967
1968    fn multiply(left: Expr, right: Expr) -> Expr {
1969        Expr::BinaryExpr(BinaryExpr::new(
1970            Box::new(left),
1971            Operator::Multiply,
1972            Box::new(right),
1973        ))
1974    }
1975
1976    /// verifies that a filter is pushed to before a projection with a complex expression, the filter expression is correctly re-written
1977    #[test]
1978    fn complex_expression() -> Result<()> {
1979        let table_scan = test_table_scan()?;
1980        let plan = LogicalPlanBuilder::from(table_scan)
1981            .project(vec![
1982                add(multiply(col("a"), lit(2)), col("c")).alias("b"),
1983                col("c"),
1984            ])?
1985            .filter(col("b").eq(lit(1i64)))?
1986            .build()?;
1987
1988        // not part of the test, just good to know:
1989        assert_snapshot!(plan,
1990        @r"
1991        Filter: b = Int64(1)
1992          Projection: test.a * Int32(2) + test.c AS b, test.c
1993            TableScan: test
1994        ",
1995        );
1996        // filter is before projection
1997        assert_optimized_plan_equal!(
1998            plan,
1999            @r"
2000        Projection: test.a * Int32(2) + test.c AS b, test.c
2001          TableScan: test, full_filters=[test.a * Int32(2) + test.c = Int64(1)]
2002        "
2003        )
2004    }
2005
2006    /// verifies that when a filter is pushed to after 2 projections, the filter expression is correctly re-written
2007    #[test]
2008    fn complex_plan() -> Result<()> {
2009        let table_scan = test_table_scan()?;
2010        let plan = LogicalPlanBuilder::from(table_scan)
2011            .project(vec![
2012                add(multiply(col("a"), lit(2)), col("c")).alias("b"),
2013                col("c"),
2014            ])?
2015            // second projection where we rename columns, just to make it difficult
2016            .project(vec![multiply(col("b"), lit(3)).alias("a"), col("c")])?
2017            .filter(col("a").eq(lit(1i64)))?
2018            .build()?;
2019
2020        // not part of the test, just good to know:
2021        assert_snapshot!(plan,
2022        @r"
2023        Filter: a = Int64(1)
2024          Projection: b * Int32(3) AS a, test.c
2025            Projection: test.a * Int32(2) + test.c AS b, test.c
2026              TableScan: test
2027        ",
2028        );
2029        // filter is before the projections
2030        assert_optimized_plan_equal!(
2031            plan,
2032            @r"
2033        Projection: b * Int32(3) AS a, test.c
2034          Projection: test.a * Int32(2) + test.c AS b, test.c
2035            TableScan: test, full_filters=[(test.a * Int32(2) + test.c) * Int32(3) = Int64(1)]
2036        "
2037        )
2038    }
2039
    /// Test-only user-defined logical node used to exercise filter push-down through
    /// `LogicalPlan::Extension` nodes.
    #[derive(Debug, PartialEq, Eq, Hash)]
    struct NoopPlan {
        /// Child plans; the tests build the node with one or two children.
        input: Vec<LogicalPlan>,
        /// Schema reported by the node (the tests pass the child table scan's schema).
        schema: DFSchemaRef,
    }
2045
2046    // Manual implementation needed because of `schema` field. Comparison excludes this field.
2047    impl PartialOrd for NoopPlan {
2048        fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
2049            self.input
2050                .partial_cmp(&other.input)
2051                // TODO (https://github.com/apache/datafusion/issues/17477) avoid recomparing all fields
2052                .filter(|cmp| *cmp != Ordering::Equal || self == other)
2053        }
2054    }
2055
2056    impl UserDefinedLogicalNodeCore for NoopPlan {
2057        fn name(&self) -> &str {
2058            "NoopPlan"
2059        }
2060
2061        fn inputs(&self) -> Vec<&LogicalPlan> {
2062            self.input.iter().collect()
2063        }
2064
2065        fn schema(&self) -> &DFSchemaRef {
2066            &self.schema
2067        }
2068
2069        fn expressions(&self) -> Vec<Expr> {
2070            self.input
2071                .iter()
2072                .flat_map(|child| child.expressions())
2073                .collect()
2074        }
2075
2076        fn prevent_predicate_push_down_columns(&self) -> HashSet<String> {
2077            HashSet::from_iter(vec!["c".to_string()])
2078        }
2079
2080        fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
2081            write!(f, "NoopPlan")
2082        }
2083
2084        fn with_exprs_and_inputs(
2085            &self,
2086            _exprs: Vec<Expr>,
2087            inputs: Vec<LogicalPlan>,
2088        ) -> Result<Self> {
2089            Ok(Self {
2090                input: inputs,
2091                schema: Arc::clone(&self.schema),
2092            })
2093        }
2094
2095        fn supports_limit_pushdown(&self) -> bool {
2096            false // Disallow limit push-down by default
2097        }
2098    }
2099
2100    #[test]
2101    fn user_defined_plan() -> Result<()> {
2102        let table_scan = test_table_scan()?;
2103
2104        let custom_plan = LogicalPlan::Extension(Extension {
2105            node: Arc::new(NoopPlan {
2106                input: vec![table_scan.clone()],
2107                schema: Arc::clone(table_scan.schema()),
2108            }),
2109        });
2110        let plan = LogicalPlanBuilder::from(custom_plan)
2111            .filter(col("a").eq(lit(1i64)))?
2112            .build()?;
2113
2114        // Push filter below NoopPlan
2115        assert_optimized_plan_equal!(
2116            plan,
2117            @r"
2118        NoopPlan
2119          TableScan: test, full_filters=[test.a = Int64(1)]
2120        "
2121        )?;
2122
2123        let custom_plan = LogicalPlan::Extension(Extension {
2124            node: Arc::new(NoopPlan {
2125                input: vec![table_scan.clone()],
2126                schema: Arc::clone(table_scan.schema()),
2127            }),
2128        });
2129        let plan = LogicalPlanBuilder::from(custom_plan)
2130            .filter(col("a").eq(lit(1i64)).and(col("c").eq(lit(2i64))))?
2131            .build()?;
2132
2133        // Push only predicate on `a` below NoopPlan
2134        assert_optimized_plan_equal!(
2135            plan,
2136            @r"
2137        Filter: test.c = Int64(2)
2138          NoopPlan
2139            TableScan: test, full_filters=[test.a = Int64(1)]
2140        "
2141        )?;
2142
2143        let custom_plan = LogicalPlan::Extension(Extension {
2144            node: Arc::new(NoopPlan {
2145                input: vec![table_scan.clone(), table_scan.clone()],
2146                schema: Arc::clone(table_scan.schema()),
2147            }),
2148        });
2149        let plan = LogicalPlanBuilder::from(custom_plan)
2150            .filter(col("a").eq(lit(1i64)))?
2151            .build()?;
2152
2153        // Push filter below NoopPlan for each child branch
2154        assert_optimized_plan_equal!(
2155            plan,
2156            @r"
2157        NoopPlan
2158          TableScan: test, full_filters=[test.a = Int64(1)]
2159          TableScan: test, full_filters=[test.a = Int64(1)]
2160        "
2161        )?;
2162
2163        let custom_plan = LogicalPlan::Extension(Extension {
2164            node: Arc::new(NoopPlan {
2165                input: vec![table_scan.clone(), table_scan.clone()],
2166                schema: Arc::clone(table_scan.schema()),
2167            }),
2168        });
2169        let plan = LogicalPlanBuilder::from(custom_plan)
2170            .filter(col("a").eq(lit(1i64)).and(col("c").eq(lit(2i64))))?
2171            .build()?;
2172
2173        // Push only predicate on `a` below NoopPlan
2174        assert_optimized_plan_equal!(
2175            plan,
2176            @r"
2177        Filter: test.c = Int64(2)
2178          NoopPlan
2179            TableScan: test, full_filters=[test.a = Int64(1)]
2180            TableScan: test, full_filters=[test.a = Int64(1)]
2181        "
2182        )
2183    }
2184
2185    /// verifies that when two filters apply after an aggregation that only allows one to be pushed, one is pushed
2186    /// and the other not.
2187    #[test]
2188    fn multi_filter() -> Result<()> {
2189        // the aggregation allows one filter to pass (b), and the other one to not pass (sum(c))
2190        let table_scan = test_table_scan()?;
2191        let plan = LogicalPlanBuilder::from(table_scan)
2192            .project(vec![col("a").alias("b"), col("c")])?
2193            .aggregate(vec![col("b")], vec![sum(col("c"))])?
2194            .filter(col("b").gt(lit(10i64)))?
2195            .filter(col("sum(test.c)").gt(lit(10i64)))?
2196            .build()?;
2197
2198        // not part of the test, just good to know:
2199        assert_snapshot!(plan,
2200        @r"
2201        Filter: sum(test.c) > Int64(10)
2202          Filter: b > Int64(10)
2203            Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
2204              Projection: test.a AS b, test.c
2205                TableScan: test
2206        ",
2207        );
2208        // filter is before the projections
2209        assert_optimized_plan_equal!(
2210            plan,
2211            @r"
2212        Filter: sum(test.c) > Int64(10)
2213          Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
2214            Projection: test.a AS b, test.c
2215              TableScan: test, full_filters=[test.a > Int64(10)]
2216        "
2217        )
2218    }
2219
2220    /// verifies that when a filter with two predicates is applied after an aggregation that only allows one to be pushed, one is pushed
2221    /// and the other not.
2222    #[test]
2223    fn split_filter() -> Result<()> {
2224        // the aggregation allows one filter to pass (b), and the other one to not pass (sum(c))
2225        let table_scan = test_table_scan()?;
2226        let plan = LogicalPlanBuilder::from(table_scan)
2227            .project(vec![col("a").alias("b"), col("c")])?
2228            .aggregate(vec![col("b")], vec![sum(col("c"))])?
2229            .filter(and(
2230                col("sum(test.c)").gt(lit(10i64)),
2231                and(col("b").gt(lit(10i64)), col("sum(test.c)").lt(lit(20i64))),
2232            ))?
2233            .build()?;
2234
2235        // not part of the test, just good to know:
2236        assert_snapshot!(plan,
2237        @r"
2238        Filter: sum(test.c) > Int64(10) AND b > Int64(10) AND sum(test.c) < Int64(20)
2239          Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
2240            Projection: test.a AS b, test.c
2241              TableScan: test
2242        ",
2243        );
2244        // filter is before the projections
2245        assert_optimized_plan_equal!(
2246            plan,
2247            @r"
2248        Filter: sum(test.c) > Int64(10) AND sum(test.c) < Int64(20)
2249          Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
2250            Projection: test.a AS b, test.c
2251              TableScan: test, full_filters=[test.a > Int64(10)]
2252        "
2253        )
2254    }
2255
2256    /// verifies that when two limits are in place, we jump neither
2257    #[test]
2258    fn double_limit() -> Result<()> {
2259        let table_scan = test_table_scan()?;
2260        let plan = LogicalPlanBuilder::from(table_scan)
2261            .project(vec![col("a"), col("b")])?
2262            .limit(0, Some(20))?
2263            .limit(0, Some(10))?
2264            .project(vec![col("a"), col("b")])?
2265            .filter(col("a").eq(lit(1i64)))?
2266            .build()?;
2267        // filter does not just any of the limits
2268        assert_optimized_plan_equal!(
2269            plan,
2270            @r"
2271        Projection: test.a, test.b
2272          Filter: test.a = Int64(1)
2273            Limit: skip=0, fetch=10
2274              Limit: skip=0, fetch=20
2275                Projection: test.a, test.b
2276                  TableScan: test
2277        "
2278        )
2279    }
2280
2281    #[test]
2282    fn union_all() -> Result<()> {
2283        let table_scan = test_table_scan()?;
2284        let table_scan2 = test_table_scan_with_name("test2")?;
2285        let plan = LogicalPlanBuilder::from(table_scan)
2286            .union(LogicalPlanBuilder::from(table_scan2).build()?)?
2287            .filter(col("a").eq(lit(1i64)))?
2288            .build()?;
2289        // filter appears below Union
2290        assert_optimized_plan_equal!(
2291            plan,
2292            @r"
2293        Union
2294          TableScan: test, full_filters=[test.a = Int64(1)]
2295          TableScan: test2, full_filters=[test2.a = Int64(1)]
2296        "
2297        )
2298    }
2299
2300    #[test]
2301    fn union_all_on_projection() -> Result<()> {
2302        let table_scan = test_table_scan()?;
2303        let table = LogicalPlanBuilder::from(table_scan)
2304            .project(vec![col("a").alias("b")])?
2305            .alias("test2")?;
2306
2307        let plan = table
2308            .clone()
2309            .union(table.build()?)?
2310            .filter(col("b").eq(lit(1i64)))?
2311            .build()?;
2312
2313        // filter appears below Union
2314        assert_optimized_plan_equal!(
2315            plan,
2316            @r"
2317        Union
2318          SubqueryAlias: test2
2319            Projection: test.a AS b
2320              TableScan: test, full_filters=[test.a = Int64(1)]
2321          SubqueryAlias: test2
2322            Projection: test.a AS b
2323              TableScan: test, full_filters=[test.a = Int64(1)]
2324        "
2325        )
2326    }
2327
2328    #[test]
2329    fn test_union_different_schema() -> Result<()> {
2330        let left = LogicalPlanBuilder::from(test_table_scan()?)
2331            .project(vec![col("a"), col("b"), col("c")])?
2332            .build()?;
2333
2334        let schema = Schema::new(vec![
2335            Field::new("d", DataType::UInt32, false),
2336            Field::new("e", DataType::UInt32, false),
2337            Field::new("f", DataType::UInt32, false),
2338        ]);
2339        let right = table_scan(Some("test1"), &schema, None)?
2340            .project(vec![col("d"), col("e"), col("f")])?
2341            .build()?;
2342        let filter = and(col("test.a").eq(lit(1)), col("test1.d").gt(lit(2)));
2343        let plan = LogicalPlanBuilder::from(left)
2344            .cross_join(right)?
2345            .project(vec![col("test.a"), col("test1.d")])?
2346            .filter(filter)?
2347            .build()?;
2348
2349        assert_optimized_plan_equal!(
2350            plan,
2351            @r"
2352        Projection: test.a, test1.d
2353          Cross Join:
2354            Projection: test.a, test.b, test.c
2355              TableScan: test, full_filters=[test.a = Int32(1)]
2356            Projection: test1.d, test1.e, test1.f
2357              TableScan: test1, full_filters=[test1.d > Int32(2)]
2358        "
2359        )
2360    }
2361
2362    #[test]
2363    fn test_project_same_name_different_qualifier() -> Result<()> {
2364        let table_scan = test_table_scan()?;
2365        let left = LogicalPlanBuilder::from(table_scan)
2366            .project(vec![col("a"), col("b"), col("c")])?
2367            .build()?;
2368        let right_table_scan = test_table_scan_with_name("test1")?;
2369        let right = LogicalPlanBuilder::from(right_table_scan)
2370            .project(vec![col("a"), col("b"), col("c")])?
2371            .build()?;
2372        let filter = and(col("test.a").eq(lit(1)), col("test1.a").gt(lit(2)));
2373        let plan = LogicalPlanBuilder::from(left)
2374            .cross_join(right)?
2375            .project(vec![col("test.a"), col("test1.a")])?
2376            .filter(filter)?
2377            .build()?;
2378
2379        assert_optimized_plan_equal!(
2380            plan,
2381            @r"
2382        Projection: test.a, test1.a
2383          Cross Join:
2384            Projection: test.a, test.b, test.c
2385              TableScan: test, full_filters=[test.a = Int32(1)]
2386            Projection: test1.a, test1.b, test1.c
2387              TableScan: test1, full_filters=[test1.a > Int32(2)]
2388        "
2389        )
2390    }
2391
2392    /// verifies that filters with the same columns are correctly placed
2393    #[test]
2394    fn filter_2_breaks_limits() -> Result<()> {
2395        let table_scan = test_table_scan()?;
2396        let plan = LogicalPlanBuilder::from(table_scan)
2397            .project(vec![col("a")])?
2398            .filter(col("a").lt_eq(lit(1i64)))?
2399            .limit(0, Some(1))?
2400            .project(vec![col("a")])?
2401            .filter(col("a").gt_eq(lit(1i64)))?
2402            .build()?;
2403        // Should be able to move both filters below the projections
2404
2405        // not part of the test
2406        assert_snapshot!(plan,
2407        @r"
2408        Filter: test.a >= Int64(1)
2409          Projection: test.a
2410            Limit: skip=0, fetch=1
2411              Filter: test.a <= Int64(1)
2412                Projection: test.a
2413                  TableScan: test
2414        ",
2415        );
2416        assert_optimized_plan_equal!(
2417            plan,
2418            @r"
2419        Projection: test.a
2420          Filter: test.a >= Int64(1)
2421            Limit: skip=0, fetch=1
2422              Projection: test.a
2423                TableScan: test, full_filters=[test.a <= Int64(1)]
2424        "
2425        )
2426    }
2427
2428    /// verifies that filters to be placed on the same depth are ANDed
2429    #[test]
2430    fn two_filters_on_same_depth() -> Result<()> {
2431        let table_scan = test_table_scan()?;
2432        let plan = LogicalPlanBuilder::from(table_scan)
2433            .limit(0, Some(1))?
2434            .filter(col("a").lt_eq(lit(1i64)))?
2435            .filter(col("a").gt_eq(lit(1i64)))?
2436            .project(vec![col("a")])?
2437            .build()?;
2438
2439        // not part of the test
2440        assert_snapshot!(plan,
2441        @r"
2442        Projection: test.a
2443          Filter: test.a >= Int64(1)
2444            Filter: test.a <= Int64(1)
2445              Limit: skip=0, fetch=1
2446                TableScan: test
2447        ",
2448        );
2449        assert_optimized_plan_equal!(
2450            plan,
2451            @r"
2452        Projection: test.a
2453          Filter: test.a >= Int64(1) AND test.a <= Int64(1)
2454            Limit: skip=0, fetch=1
2455              TableScan: test
2456        "
2457        )
2458    }
2459
2460    /// verifies that filters on a plan with user nodes are not lost
2461    /// (ARROW-10547)
2462    #[test]
2463    fn filters_user_defined_node() -> Result<()> {
2464        let table_scan = test_table_scan()?;
2465        let plan = LogicalPlanBuilder::from(table_scan)
2466            .filter(col("a").lt_eq(lit(1i64)))?
2467            .build()?;
2468
2469        let plan = user_defined::new(plan);
2470
2471        // not part of the test
2472        assert_snapshot!(plan,
2473        @r"
2474        TestUserDefined
2475          Filter: test.a <= Int64(1)
2476            TableScan: test
2477        ",
2478        );
2479        assert_optimized_plan_equal!(
2480            plan,
2481            @r"
2482        TestUserDefined
2483          TableScan: test, full_filters=[test.a <= Int64(1)]
2484        "
2485        )
2486    }
2487
2488    /// post-on-join predicates on a column common to both sides is pushed to both sides
2489    #[test]
2490    fn filter_on_join_on_common_independent() -> Result<()> {
2491        let table_scan = test_table_scan()?;
2492        let left = LogicalPlanBuilder::from(table_scan).build()?;
2493        let right_table_scan = test_table_scan_with_name("test2")?;
2494        let right = LogicalPlanBuilder::from(right_table_scan)
2495            .project(vec![col("a")])?
2496            .build()?;
2497        let plan = LogicalPlanBuilder::from(left)
2498            .join(
2499                right,
2500                JoinType::Inner,
2501                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2502                None,
2503            )?
2504            .filter(col("test.a").lt_eq(lit(1i64)))?
2505            .build()?;
2506
2507        // not part of the test, just good to know:
2508        assert_snapshot!(plan,
2509        @r"
2510        Filter: test.a <= Int64(1)
2511          Inner Join: test.a = test2.a
2512            TableScan: test
2513            Projection: test2.a
2514              TableScan: test2
2515        ",
2516        );
2517        // filter sent to side before the join
2518        assert_optimized_plan_equal!(
2519            plan,
2520            @r"
2521        Inner Join: test.a = test2.a
2522          TableScan: test, full_filters=[test.a <= Int64(1)]
2523          Projection: test2.a
2524            TableScan: test2, full_filters=[test2.a <= Int64(1)]
2525        "
2526        )
2527    }
2528
2529    /// post-using-join predicates on a column common to both sides is pushed to both sides
2530    #[test]
2531    fn filter_using_join_on_common_independent() -> Result<()> {
2532        let table_scan = test_table_scan()?;
2533        let left = LogicalPlanBuilder::from(table_scan).build()?;
2534        let right_table_scan = test_table_scan_with_name("test2")?;
2535        let right = LogicalPlanBuilder::from(right_table_scan)
2536            .project(vec![col("a")])?
2537            .build()?;
2538        let plan = LogicalPlanBuilder::from(left)
2539            .join_using(
2540                right,
2541                JoinType::Inner,
2542                vec![Column::from_name("a".to_string())],
2543            )?
2544            .filter(col("a").lt_eq(lit(1i64)))?
2545            .build()?;
2546
2547        // not part of the test, just good to know:
2548        assert_snapshot!(plan,
2549        @r"
2550        Filter: test.a <= Int64(1)
2551          Inner Join: Using test.a = test2.a
2552            TableScan: test
2553            Projection: test2.a
2554              TableScan: test2
2555        ",
2556        );
2557        // filter sent to side before the join
2558        assert_optimized_plan_equal!(
2559            plan,
2560            @r"
2561        Inner Join: Using test.a = test2.a
2562          TableScan: test, full_filters=[test.a <= Int64(1)]
2563          Projection: test2.a
2564            TableScan: test2, full_filters=[test2.a <= Int64(1)]
2565        "
2566        )
2567    }
2568
2569    /// post-join predicates with columns from both sides are converted to join filters
2570    #[test]
2571    fn filter_join_on_common_dependent() -> Result<()> {
2572        let table_scan = test_table_scan()?;
2573        let left = LogicalPlanBuilder::from(table_scan)
2574            .project(vec![col("a"), col("c")])?
2575            .build()?;
2576        let right_table_scan = test_table_scan_with_name("test2")?;
2577        let right = LogicalPlanBuilder::from(right_table_scan)
2578            .project(vec![col("a"), col("b")])?
2579            .build()?;
2580        let plan = LogicalPlanBuilder::from(left)
2581            .join(
2582                right,
2583                JoinType::Inner,
2584                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2585                None,
2586            )?
2587            .filter(col("c").lt_eq(col("b")))?
2588            .build()?;
2589
2590        // not part of the test, just good to know:
2591        assert_snapshot!(plan,
2592        @r"
2593        Filter: test.c <= test2.b
2594          Inner Join: test.a = test2.a
2595            Projection: test.a, test.c
2596              TableScan: test
2597            Projection: test2.a, test2.b
2598              TableScan: test2
2599        ",
2600        );
2601        // Filter is converted to Join Filter
2602        assert_optimized_plan_equal!(
2603            plan,
2604            @r"
2605        Inner Join: test.a = test2.a Filter: test.c <= test2.b
2606          Projection: test.a, test.c
2607            TableScan: test
2608          Projection: test2.a, test2.b
2609            TableScan: test2
2610        "
2611        )
2612    }
2613
2614    /// post-join predicates with columns from one side of a join are pushed only to that side
2615    #[test]
2616    fn filter_join_on_one_side() -> Result<()> {
2617        let table_scan = test_table_scan()?;
2618        let left = LogicalPlanBuilder::from(table_scan)
2619            .project(vec![col("a"), col("b")])?
2620            .build()?;
2621        let table_scan_right = test_table_scan_with_name("test2")?;
2622        let right = LogicalPlanBuilder::from(table_scan_right)
2623            .project(vec![col("a"), col("c")])?
2624            .build()?;
2625
2626        let plan = LogicalPlanBuilder::from(left)
2627            .join(
2628                right,
2629                JoinType::Inner,
2630                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2631                None,
2632            )?
2633            .filter(col("b").lt_eq(lit(1i64)))?
2634            .build()?;
2635
2636        // not part of the test, just good to know:
2637        assert_snapshot!(plan,
2638        @r"
2639        Filter: test.b <= Int64(1)
2640          Inner Join: test.a = test2.a
2641            Projection: test.a, test.b
2642              TableScan: test
2643            Projection: test2.a, test2.c
2644              TableScan: test2
2645        ",
2646        );
2647        assert_optimized_plan_equal!(
2648            plan,
2649            @r"
2650        Inner Join: test.a = test2.a
2651          Projection: test.a, test.b
2652            TableScan: test, full_filters=[test.b <= Int64(1)]
2653          Projection: test2.a, test2.c
2654            TableScan: test2
2655        "
2656        )
2657    }
2658
2659    /// post-join predicates on the right side of a left join are not duplicated
2660    /// TODO: In this case we can sometimes convert the join to an INNER join
2661    #[test]
2662    fn filter_using_left_join() -> Result<()> {
2663        let table_scan = test_table_scan()?;
2664        let left = LogicalPlanBuilder::from(table_scan).build()?;
2665        let right_table_scan = test_table_scan_with_name("test2")?;
2666        let right = LogicalPlanBuilder::from(right_table_scan)
2667            .project(vec![col("a")])?
2668            .build()?;
2669        let plan = LogicalPlanBuilder::from(left)
2670            .join_using(
2671                right,
2672                JoinType::Left,
2673                vec![Column::from_name("a".to_string())],
2674            )?
2675            .filter(col("test2.a").lt_eq(lit(1i64)))?
2676            .build()?;
2677
2678        // not part of the test, just good to know:
2679        assert_snapshot!(plan,
2680        @r"
2681        Filter: test2.a <= Int64(1)
2682          Left Join: Using test.a = test2.a
2683            TableScan: test
2684            Projection: test2.a
2685              TableScan: test2
2686        ",
2687        );
2688        // filter not duplicated nor pushed down - i.e. noop
2689        assert_optimized_plan_equal!(
2690            plan,
2691            @r"
2692        Filter: test2.a <= Int64(1)
2693          Left Join: Using test.a = test2.a
2694            TableScan: test, full_filters=[test.a <= Int64(1)]
2695            Projection: test2.a
2696              TableScan: test2
2697        "
2698        )
2699    }
2700
2701    /// post-join predicates on the left side of a right join are not duplicated
2702    #[test]
2703    fn filter_using_right_join() -> Result<()> {
2704        let table_scan = test_table_scan()?;
2705        let left = LogicalPlanBuilder::from(table_scan).build()?;
2706        let right_table_scan = test_table_scan_with_name("test2")?;
2707        let right = LogicalPlanBuilder::from(right_table_scan)
2708            .project(vec![col("a")])?
2709            .build()?;
2710        let plan = LogicalPlanBuilder::from(left)
2711            .join_using(
2712                right,
2713                JoinType::Right,
2714                vec![Column::from_name("a".to_string())],
2715            )?
2716            .filter(col("test.a").lt_eq(lit(1i64)))?
2717            .build()?;
2718
2719        // not part of the test, just good to know:
2720        assert_snapshot!(plan,
2721        @r"
2722        Filter: test.a <= Int64(1)
2723          Right Join: Using test.a = test2.a
2724            TableScan: test
2725            Projection: test2.a
2726              TableScan: test2
2727        ",
2728        );
2729        // filter not duplicated nor pushed down - i.e. noop
2730        assert_optimized_plan_equal!(
2731            plan,
2732            @r"
2733        Filter: test.a <= Int64(1)
2734          Right Join: Using test.a = test2.a
2735            TableScan: test
2736            Projection: test2.a
2737              TableScan: test2, full_filters=[test2.a <= Int64(1)]
2738        "
2739        )
2740    }
2741
2742    /// post-left-join predicate on a column common to both sides is pushed to both sides
2743    #[test]
2744    fn filter_using_left_join_on_common() -> Result<()> {
2745        let table_scan = test_table_scan()?;
2746        let left = LogicalPlanBuilder::from(table_scan).build()?;
2747        let right_table_scan = test_table_scan_with_name("test2")?;
2748        let right = LogicalPlanBuilder::from(right_table_scan)
2749            .project(vec![col("a")])?
2750            .build()?;
2751        let plan = LogicalPlanBuilder::from(left)
2752            .join_using(
2753                right,
2754                JoinType::Left,
2755                vec![Column::from_name("a".to_string())],
2756            )?
2757            .filter(col("a").lt_eq(lit(1i64)))?
2758            .build()?;
2759
2760        // not part of the test, just good to know:
2761        assert_snapshot!(plan,
2762        @r"
2763        Filter: test.a <= Int64(1)
2764          Left Join: Using test.a = test2.a
2765            TableScan: test
2766            Projection: test2.a
2767              TableScan: test2
2768        ",
2769        );
2770        // filter sent to left side of the join and to the right
2771        assert_optimized_plan_equal!(
2772            plan,
2773            @r"
2774        Left Join: Using test.a = test2.a
2775          TableScan: test, full_filters=[test.a <= Int64(1)]
2776          Projection: test2.a
2777            TableScan: test2, full_filters=[test2.a <= Int64(1)]
2778        "
2779        )
2780    }
2781
2782    /// post-right-join predicate on a column common to both sides is pushed to both sides
2783    #[test]
2784    fn filter_using_right_join_on_common() -> Result<()> {
2785        let table_scan = test_table_scan()?;
2786        let left = LogicalPlanBuilder::from(table_scan).build()?;
2787        let right_table_scan = test_table_scan_with_name("test2")?;
2788        let right = LogicalPlanBuilder::from(right_table_scan)
2789            .project(vec![col("a")])?
2790            .build()?;
2791        let plan = LogicalPlanBuilder::from(left)
2792            .join_using(
2793                right,
2794                JoinType::Right,
2795                vec![Column::from_name("a".to_string())],
2796            )?
2797            .filter(col("test2.a").lt_eq(lit(1i64)))?
2798            .build()?;
2799
2800        // not part of the test, just good to know:
2801        assert_snapshot!(plan,
2802        @r"
2803        Filter: test2.a <= Int64(1)
2804          Right Join: Using test.a = test2.a
2805            TableScan: test
2806            Projection: test2.a
2807              TableScan: test2
2808        ",
2809        );
2810        // filter sent to right side of join, sent to the left as well
2811        assert_optimized_plan_equal!(
2812            plan,
2813            @r"
2814        Right Join: Using test.a = test2.a
2815          TableScan: test, full_filters=[test.a <= Int64(1)]
2816          Projection: test2.a
2817            TableScan: test2, full_filters=[test2.a <= Int64(1)]
2818        "
2819        )
2820    }
2821
2822    /// single table predicate parts of ON condition should be pushed to both inputs
2823    #[test]
2824    fn join_on_with_filter() -> Result<()> {
2825        let table_scan = test_table_scan()?;
2826        let left = LogicalPlanBuilder::from(table_scan)
2827            .project(vec![col("a"), col("b"), col("c")])?
2828            .build()?;
2829        let right_table_scan = test_table_scan_with_name("test2")?;
2830        let right = LogicalPlanBuilder::from(right_table_scan)
2831            .project(vec![col("a"), col("b"), col("c")])?
2832            .build()?;
2833        let filter = col("test.c")
2834            .gt(lit(1u32))
2835            .and(col("test.b").lt(col("test2.b")))
2836            .and(col("test2.c").gt(lit(4u32)));
2837        let plan = LogicalPlanBuilder::from(left)
2838            .join(
2839                right,
2840                JoinType::Inner,
2841                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2842                Some(filter),
2843            )?
2844            .build()?;
2845
2846        // not part of the test, just good to know:
2847        assert_snapshot!(plan,
2848        @r"
2849        Inner Join: test.a = test2.a Filter: test.c > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
2850          Projection: test.a, test.b, test.c
2851            TableScan: test
2852          Projection: test2.a, test2.b, test2.c
2853            TableScan: test2
2854        ",
2855        );
2856        assert_optimized_plan_equal!(
2857            plan,
2858            @r"
2859        Inner Join: test.a = test2.a Filter: test.b < test2.b
2860          Projection: test.a, test.b, test.c
2861            TableScan: test, full_filters=[test.c > UInt32(1)]
2862          Projection: test2.a, test2.b, test2.c
2863            TableScan: test2, full_filters=[test2.c > UInt32(4)]
2864        "
2865        )
2866    }
2867
2868    /// join filter should be completely removed after pushdown
2869    #[test]
2870    fn join_filter_removed() -> Result<()> {
2871        let table_scan = test_table_scan()?;
2872        let left = LogicalPlanBuilder::from(table_scan)
2873            .project(vec![col("a"), col("b"), col("c")])?
2874            .build()?;
2875        let right_table_scan = test_table_scan_with_name("test2")?;
2876        let right = LogicalPlanBuilder::from(right_table_scan)
2877            .project(vec![col("a"), col("b"), col("c")])?
2878            .build()?;
2879        let filter = col("test.b")
2880            .gt(lit(1u32))
2881            .and(col("test2.c").gt(lit(4u32)));
2882        let plan = LogicalPlanBuilder::from(left)
2883            .join(
2884                right,
2885                JoinType::Inner,
2886                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2887                Some(filter),
2888            )?
2889            .build()?;
2890
2891        // not part of the test, just good to know:
2892        assert_snapshot!(plan,
2893        @r"
2894        Inner Join: test.a = test2.a Filter: test.b > UInt32(1) AND test2.c > UInt32(4)
2895          Projection: test.a, test.b, test.c
2896            TableScan: test
2897          Projection: test2.a, test2.b, test2.c
2898            TableScan: test2
2899        ",
2900        );
2901        assert_optimized_plan_equal!(
2902            plan,
2903            @r"
2904        Inner Join: test.a = test2.a
2905          Projection: test.a, test.b, test.c
2906            TableScan: test, full_filters=[test.b > UInt32(1)]
2907          Projection: test2.a, test2.b, test2.c
2908            TableScan: test2, full_filters=[test2.c > UInt32(4)]
2909        "
2910        )
2911    }
2912
2913    /// predicate on join key in filter expression should be pushed down to both inputs
2914    #[test]
2915    fn join_filter_on_common() -> Result<()> {
2916        let table_scan = test_table_scan()?;
2917        let left = LogicalPlanBuilder::from(table_scan)
2918            .project(vec![col("a")])?
2919            .build()?;
2920        let right_table_scan = test_table_scan_with_name("test2")?;
2921        let right = LogicalPlanBuilder::from(right_table_scan)
2922            .project(vec![col("b")])?
2923            .build()?;
2924        let filter = col("test.a").gt(lit(1u32));
2925        let plan = LogicalPlanBuilder::from(left)
2926            .join(
2927                right,
2928                JoinType::Inner,
2929                (vec![Column::from_name("a")], vec![Column::from_name("b")]),
2930                Some(filter),
2931            )?
2932            .build()?;
2933
2934        // not part of the test, just good to know:
2935        assert_snapshot!(plan,
2936        @r"
2937        Inner Join: test.a = test2.b Filter: test.a > UInt32(1)
2938          Projection: test.a
2939            TableScan: test
2940          Projection: test2.b
2941            TableScan: test2
2942        ",
2943        );
2944        assert_optimized_plan_equal!(
2945            plan,
2946            @r"
2947        Inner Join: test.a = test2.b
2948          Projection: test.a
2949            TableScan: test, full_filters=[test.a > UInt32(1)]
2950          Projection: test2.b
2951            TableScan: test2, full_filters=[test2.b > UInt32(1)]
2952        "
2953        )
2954    }
2955
2956    /// single table predicate parts of ON condition should be pushed to right input
2957    #[test]
2958    fn left_join_on_with_filter() -> Result<()> {
2959        let table_scan = test_table_scan()?;
2960        let left = LogicalPlanBuilder::from(table_scan)
2961            .project(vec![col("a"), col("b"), col("c")])?
2962            .build()?;
2963        let right_table_scan = test_table_scan_with_name("test2")?;
2964        let right = LogicalPlanBuilder::from(right_table_scan)
2965            .project(vec![col("a"), col("b"), col("c")])?
2966            .build()?;
2967        let filter = col("test.a")
2968            .gt(lit(1u32))
2969            .and(col("test.b").lt(col("test2.b")))
2970            .and(col("test2.c").gt(lit(4u32)));
2971        let plan = LogicalPlanBuilder::from(left)
2972            .join(
2973                right,
2974                JoinType::Left,
2975                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2976                Some(filter),
2977            )?
2978            .build()?;
2979
2980        // not part of the test, just good to know:
2981        assert_snapshot!(plan,
2982        @r"
2983        Left Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
2984          Projection: test.a, test.b, test.c
2985            TableScan: test
2986          Projection: test2.a, test2.b, test2.c
2987            TableScan: test2
2988        ",
2989        );
2990        assert_optimized_plan_equal!(
2991            plan,
2992            @r"
2993        Left Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b
2994          Projection: test.a, test.b, test.c
2995            TableScan: test
2996          Projection: test2.a, test2.b, test2.c
2997            TableScan: test2, full_filters=[test2.a > UInt32(1), test2.c > UInt32(4)]
2998        "
2999        )
3000    }
3001
3002    /// single table predicate parts of ON condition should be pushed to left input
3003    #[test]
3004    fn right_join_on_with_filter() -> Result<()> {
3005        let table_scan = test_table_scan()?;
3006        let left = LogicalPlanBuilder::from(table_scan)
3007            .project(vec![col("a"), col("b"), col("c")])?
3008            .build()?;
3009        let right_table_scan = test_table_scan_with_name("test2")?;
3010        let right = LogicalPlanBuilder::from(right_table_scan)
3011            .project(vec![col("a"), col("b"), col("c")])?
3012            .build()?;
3013        let filter = col("test.a")
3014            .gt(lit(1u32))
3015            .and(col("test.b").lt(col("test2.b")))
3016            .and(col("test2.c").gt(lit(4u32)));
3017        let plan = LogicalPlanBuilder::from(left)
3018            .join(
3019                right,
3020                JoinType::Right,
3021                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
3022                Some(filter),
3023            )?
3024            .build()?;
3025
3026        // not part of the test, just good to know:
3027        assert_snapshot!(plan,
3028        @r"
3029        Right Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
3030          Projection: test.a, test.b, test.c
3031            TableScan: test
3032          Projection: test2.a, test2.b, test2.c
3033            TableScan: test2
3034        ",
3035        );
3036        assert_optimized_plan_equal!(
3037            plan,
3038            @r"
3039        Right Join: test.a = test2.a Filter: test.b < test2.b AND test2.c > UInt32(4)
3040          Projection: test.a, test.b, test.c
3041            TableScan: test, full_filters=[test.a > UInt32(1)]
3042          Projection: test2.a, test2.b, test2.c
3043            TableScan: test2
3044        "
3045        )
3046    }
3047
3048    /// single table predicate parts of ON condition should not be pushed
3049    #[test]
3050    fn full_join_on_with_filter() -> Result<()> {
3051        let table_scan = test_table_scan()?;
3052        let left = LogicalPlanBuilder::from(table_scan)
3053            .project(vec![col("a"), col("b"), col("c")])?
3054            .build()?;
3055        let right_table_scan = test_table_scan_with_name("test2")?;
3056        let right = LogicalPlanBuilder::from(right_table_scan)
3057            .project(vec![col("a"), col("b"), col("c")])?
3058            .build()?;
3059        let filter = col("test.a")
3060            .gt(lit(1u32))
3061            .and(col("test.b").lt(col("test2.b")))
3062            .and(col("test2.c").gt(lit(4u32)));
3063        let plan = LogicalPlanBuilder::from(left)
3064            .join(
3065                right,
3066                JoinType::Full,
3067                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
3068                Some(filter),
3069            )?
3070            .build()?;
3071
3072        // not part of the test, just good to know:
3073        assert_snapshot!(plan,
3074        @r"
3075        Full Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
3076          Projection: test.a, test.b, test.c
3077            TableScan: test
3078          Projection: test2.a, test2.b, test2.c
3079            TableScan: test2
3080        ",
3081        );
3082        assert_optimized_plan_equal!(
3083            plan,
3084            @r"
3085        Full Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
3086          Projection: test.a, test.b, test.c
3087            TableScan: test
3088          Projection: test2.a, test2.b, test2.c
3089            TableScan: test2
3090        "
3091        )
3092    }
3093
3094    struct PushDownProvider {
3095        pub filter_support: TableProviderFilterPushDown,
3096    }
3097
3098    #[async_trait]
3099    impl TableSource for PushDownProvider {
3100        fn schema(&self) -> SchemaRef {
3101            Arc::new(Schema::new(vec![
3102                Field::new("a", DataType::Int32, true),
3103                Field::new("b", DataType::Int32, true),
3104            ]))
3105        }
3106
3107        fn table_type(&self) -> TableType {
3108            TableType::Base
3109        }
3110
3111        fn supports_filters_pushdown(
3112            &self,
3113            filters: &[&Expr],
3114        ) -> Result<Vec<TableProviderFilterPushDown>> {
3115            Ok((0..filters.len())
3116                .map(|_| self.filter_support.clone())
3117                .collect())
3118        }
3119
3120        fn as_any(&self) -> &dyn Any {
3121            self
3122        }
3123    }
3124
3125    fn table_scan_with_pushdown_provider_builder(
3126        filter_support: TableProviderFilterPushDown,
3127        filters: Vec<Expr>,
3128        projection: Option<Vec<usize>>,
3129    ) -> Result<LogicalPlanBuilder> {
3130        let test_provider = PushDownProvider { filter_support };
3131
3132        let table_scan = LogicalPlan::TableScan(TableScan {
3133            table_name: "test".into(),
3134            filters,
3135            projected_schema: Arc::new(DFSchema::try_from(test_provider.schema())?),
3136            projection,
3137            source: Arc::new(test_provider),
3138            fetch: None,
3139        });
3140
3141        Ok(LogicalPlanBuilder::from(table_scan))
3142    }
3143
3144    fn table_scan_with_pushdown_provider(
3145        filter_support: TableProviderFilterPushDown,
3146    ) -> Result<LogicalPlan> {
3147        table_scan_with_pushdown_provider_builder(filter_support, vec![], None)?
3148            .filter(col("a").eq(lit(1i64)))?
3149            .build()
3150    }
3151
3152    #[test]
3153    fn filter_with_table_provider_exact() -> Result<()> {
3154        let plan = table_scan_with_pushdown_provider(TableProviderFilterPushDown::Exact)?;
3155
3156        assert_optimized_plan_equal!(
3157            plan,
3158            @"TableScan: test, full_filters=[a = Int64(1)]"
3159        )
3160    }
3161
3162    #[test]
3163    fn filter_with_table_provider_inexact() -> Result<()> {
3164        let plan =
3165            table_scan_with_pushdown_provider(TableProviderFilterPushDown::Inexact)?;
3166
3167        assert_optimized_plan_equal!(
3168            plan,
3169            @r"
3170        Filter: a = Int64(1)
3171          TableScan: test, partial_filters=[a = Int64(1)]
3172        "
3173        )
3174    }
3175
3176    #[test]
3177    fn filter_with_table_provider_multiple_invocations() -> Result<()> {
3178        let plan =
3179            table_scan_with_pushdown_provider(TableProviderFilterPushDown::Inexact)?;
3180
3181        let optimized_plan = PushDownFilter::new()
3182            .rewrite(plan, &OptimizerContext::new())
3183            .expect("failed to optimize plan")
3184            .data;
3185
3186        // Optimizing the same plan multiple times should produce the same plan
3187        // each time.
3188        assert_optimized_plan_equal!(
3189            optimized_plan,
3190            @r"
3191        Filter: a = Int64(1)
3192          TableScan: test, partial_filters=[a = Int64(1)]
3193        "
3194        )
3195    }
3196
3197    #[test]
3198    fn filter_with_table_provider_unsupported() -> Result<()> {
3199        let plan =
3200            table_scan_with_pushdown_provider(TableProviderFilterPushDown::Unsupported)?;
3201
3202        assert_optimized_plan_equal!(
3203            plan,
3204            @r"
3205        Filter: a = Int64(1)
3206          TableScan: test
3207        "
3208        )
3209    }
3210
3211    #[test]
3212    fn multi_combined_filter() -> Result<()> {
3213        let plan = table_scan_with_pushdown_provider_builder(
3214            TableProviderFilterPushDown::Inexact,
3215            vec![col("a").eq(lit(10i64)), col("b").gt(lit(11i64))],
3216            Some(vec![0]),
3217        )?
3218        .filter(and(col("a").eq(lit(10i64)), col("b").gt(lit(11i64))))?
3219        .project(vec![col("a"), col("b")])?
3220        .build()?;
3221
3222        assert_optimized_plan_equal!(
3223            plan,
3224            @r"
3225        Projection: a, b
3226          Filter: a = Int64(10) AND b > Int64(11)
3227            TableScan: test projection=[a], partial_filters=[a = Int64(10), b > Int64(11)]
3228        "
3229        )
3230    }
3231
3232    #[test]
3233    fn multi_combined_filter_exact() -> Result<()> {
3234        let plan = table_scan_with_pushdown_provider_builder(
3235            TableProviderFilterPushDown::Exact,
3236            vec![],
3237            Some(vec![0]),
3238        )?
3239        .filter(and(col("a").eq(lit(10i64)), col("b").gt(lit(11i64))))?
3240        .project(vec![col("a"), col("b")])?
3241        .build()?;
3242
3243        assert_optimized_plan_equal!(
3244            plan,
3245            @r"
3246        Projection: a, b
3247          TableScan: test projection=[a], full_filters=[a = Int64(10), b > Int64(11)]
3248        "
3249        )
3250    }
3251
3252    #[test]
3253    fn test_filter_with_alias() -> Result<()> {
3254        // in table scan the true col name is 'test.a',
3255        // but we rename it as 'b', and use col 'b' in filter
3256        // we need rewrite filter col before push down.
3257        let table_scan = test_table_scan()?;
3258        let plan = LogicalPlanBuilder::from(table_scan)
3259            .project(vec![col("a").alias("b"), col("c")])?
3260            .filter(and(col("b").gt(lit(10i64)), col("c").gt(lit(10i64))))?
3261            .build()?;
3262
3263        // filter on col b
3264        assert_snapshot!(plan,
3265        @r"
3266        Filter: b > Int64(10) AND test.c > Int64(10)
3267          Projection: test.a AS b, test.c
3268            TableScan: test
3269        ",
3270        );
3271        // rewrite filter col b to test.a
3272        assert_optimized_plan_equal!(
3273            plan,
3274            @r"
3275        Projection: test.a AS b, test.c
3276          TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]
3277        "
3278        )
3279    }
3280
3281    #[test]
3282    fn test_filter_with_alias_2() -> Result<()> {
3283        // in table scan the true col name is 'test.a',
3284        // but we rename it as 'b', and use col 'b' in filter
3285        // we need rewrite filter col before push down.
3286        let table_scan = test_table_scan()?;
3287        let plan = LogicalPlanBuilder::from(table_scan)
3288            .project(vec![col("a").alias("b"), col("c")])?
3289            .project(vec![col("b"), col("c")])?
3290            .filter(and(col("b").gt(lit(10i64)), col("c").gt(lit(10i64))))?
3291            .build()?;
3292
3293        // filter on col b
3294        assert_snapshot!(plan,
3295        @r"
3296        Filter: b > Int64(10) AND test.c > Int64(10)
3297          Projection: b, test.c
3298            Projection: test.a AS b, test.c
3299              TableScan: test
3300        ",
3301        );
3302        // rewrite filter col b to test.a
3303        assert_optimized_plan_equal!(
3304            plan,
3305            @r"
3306        Projection: b, test.c
3307          Projection: test.a AS b, test.c
3308            TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]
3309        "
3310        )
3311    }
3312
3313    #[test]
3314    fn test_filter_with_multi_alias() -> Result<()> {
3315        let table_scan = test_table_scan()?;
3316        let plan = LogicalPlanBuilder::from(table_scan)
3317            .project(vec![col("a").alias("b"), col("c").alias("d")])?
3318            .filter(and(col("b").gt(lit(10i64)), col("d").gt(lit(10i64))))?
3319            .build()?;
3320
3321        // filter on col b and d
3322        assert_snapshot!(plan,
3323        @r"
3324        Filter: b > Int64(10) AND d > Int64(10)
3325          Projection: test.a AS b, test.c AS d
3326            TableScan: test
3327        ",
3328        );
3329        // rewrite filter col b to test.a, col d to test.c
3330        assert_optimized_plan_equal!(
3331            plan,
3332            @r"
3333        Projection: test.a AS b, test.c AS d
3334          TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]
3335        "
3336        )
3337    }
3338
3339    /// predicate on join key in filter expression should be pushed down to both inputs
3340    #[test]
3341    fn join_filter_with_alias() -> Result<()> {
3342        let table_scan = test_table_scan()?;
3343        let left = LogicalPlanBuilder::from(table_scan)
3344            .project(vec![col("a").alias("c")])?
3345            .build()?;
3346        let right_table_scan = test_table_scan_with_name("test2")?;
3347        let right = LogicalPlanBuilder::from(right_table_scan)
3348            .project(vec![col("b").alias("d")])?
3349            .build()?;
3350        let filter = col("c").gt(lit(1u32));
3351        let plan = LogicalPlanBuilder::from(left)
3352            .join(
3353                right,
3354                JoinType::Inner,
3355                (vec![Column::from_name("c")], vec![Column::from_name("d")]),
3356                Some(filter),
3357            )?
3358            .build()?;
3359
3360        assert_snapshot!(plan,
3361        @r"
3362        Inner Join: c = d Filter: c > UInt32(1)
3363          Projection: test.a AS c
3364            TableScan: test
3365          Projection: test2.b AS d
3366            TableScan: test2
3367        ",
3368        );
3369        // Change filter on col `c`, 'd' to `test.a`, 'test.b'
3370        assert_optimized_plan_equal!(
3371            plan,
3372            @r"
3373        Inner Join: c = d
3374          Projection: test.a AS c
3375            TableScan: test, full_filters=[test.a > UInt32(1)]
3376          Projection: test2.b AS d
3377            TableScan: test2, full_filters=[test2.b > UInt32(1)]
3378        "
3379        )
3380    }
3381
3382    #[test]
3383    fn test_in_filter_with_alias() -> Result<()> {
3384        // in table scan the true col name is 'test.a',
3385        // but we rename it as 'b', and use col 'b' in filter
3386        // we need rewrite filter col before push down.
3387        let table_scan = test_table_scan()?;
3388        let filter_value = vec![lit(1u32), lit(2u32), lit(3u32), lit(4u32)];
3389        let plan = LogicalPlanBuilder::from(table_scan)
3390            .project(vec![col("a").alias("b"), col("c")])?
3391            .filter(in_list(col("b"), filter_value, false))?
3392            .build()?;
3393
3394        // filter on col b
3395        assert_snapshot!(plan,
3396        @r"
3397        Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])
3398          Projection: test.a AS b, test.c
3399            TableScan: test
3400        ",
3401        );
3402        // rewrite filter col b to test.a
3403        assert_optimized_plan_equal!(
3404            plan,
3405            @r"
3406        Projection: test.a AS b, test.c
3407          TableScan: test, full_filters=[test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])]
3408        "
3409        )
3410    }
3411
3412    #[test]
3413    fn test_in_filter_with_alias_2() -> Result<()> {
3414        // in table scan the true col name is 'test.a',
3415        // but we rename it as 'b', and use col 'b' in filter
3416        // we need rewrite filter col before push down.
3417        let table_scan = test_table_scan()?;
3418        let filter_value = vec![lit(1u32), lit(2u32), lit(3u32), lit(4u32)];
3419        let plan = LogicalPlanBuilder::from(table_scan)
3420            .project(vec![col("a").alias("b"), col("c")])?
3421            .project(vec![col("b"), col("c")])?
3422            .filter(in_list(col("b"), filter_value, false))?
3423            .build()?;
3424
3425        // filter on col b
3426        assert_snapshot!(plan,
3427        @r"
3428        Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])
3429          Projection: b, test.c
3430            Projection: test.a AS b, test.c
3431              TableScan: test
3432        ",
3433        );
3434        // rewrite filter col b to test.a
3435        assert_optimized_plan_equal!(
3436            plan,
3437            @r"
3438        Projection: b, test.c
3439          Projection: test.a AS b, test.c
3440            TableScan: test, full_filters=[test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])]
3441        "
3442        )
3443    }
3444
3445    #[test]
3446    fn test_in_subquery_with_alias() -> Result<()> {
3447        // in table scan the true col name is 'test.a',
3448        // but we rename it as 'b', and use col 'b' in subquery filter
3449        let table_scan = test_table_scan()?;
3450        let table_scan_sq = test_table_scan_with_name("sq")?;
3451        let subplan = Arc::new(
3452            LogicalPlanBuilder::from(table_scan_sq)
3453                .project(vec![col("c")])?
3454                .build()?,
3455        );
3456        let plan = LogicalPlanBuilder::from(table_scan)
3457            .project(vec![col("a").alias("b"), col("c")])?
3458            .filter(in_subquery(col("b"), subplan))?
3459            .build()?;
3460
3461        // filter on col b in subquery
3462        assert_snapshot!(plan,
3463        @r"
3464        Filter: b IN (<subquery>)
3465          Subquery:
3466            Projection: sq.c
3467              TableScan: sq
3468          Projection: test.a AS b, test.c
3469            TableScan: test
3470        ",
3471        );
3472        // rewrite filter col b to test.a
3473        assert_optimized_plan_equal!(
3474            plan,
3475            @r"
3476        Projection: test.a AS b, test.c
3477          TableScan: test, full_filters=[test.a IN (<subquery>)]
3478            Subquery:
3479              Projection: sq.c
3480                TableScan: sq
3481        "
3482        )
3483    }
3484
3485    #[test]
3486    fn test_propagation_of_optimized_inner_filters_with_projections() -> Result<()> {
3487        // SELECT a FROM (SELECT 1 AS a) b WHERE b.a = 1
3488        let plan = LogicalPlanBuilder::empty(true)
3489            .project(vec![lit(0i64).alias("a")])?
3490            .alias("b")?
3491            .project(vec![col("b.a")])?
3492            .alias("b")?
3493            .filter(col("b.a").eq(lit(1i64)))?
3494            .project(vec![col("b.a")])?
3495            .build()?;
3496
3497        assert_snapshot!(plan,
3498        @r"
3499        Projection: b.a
3500          Filter: b.a = Int64(1)
3501            SubqueryAlias: b
3502              Projection: b.a
3503                SubqueryAlias: b
3504                  Projection: Int64(0) AS a
3505                    EmptyRelation: rows=1
3506        ",
3507        );
3508        // Ensure that the predicate without any columns (0 = 1) is
3509        // still there.
3510        assert_optimized_plan_equal!(
3511            plan,
3512            @r"
3513        Projection: b.a
3514          SubqueryAlias: b
3515            Projection: b.a
3516              SubqueryAlias: b
3517                Projection: Int64(0) AS a
3518                  Filter: Int64(0) = Int64(1)
3519                    EmptyRelation: rows=1
3520        "
3521        )
3522    }
3523
3524    #[test]
3525    fn test_crossjoin_with_or_clause() -> Result<()> {
3526        // select * from test,test1 where (test.a = test1.a and test.b > 1) or (test.b = test1.b and test.c < 10);
3527        let table_scan = test_table_scan()?;
3528        let left = LogicalPlanBuilder::from(table_scan)
3529            .project(vec![col("a"), col("b"), col("c")])?
3530            .build()?;
3531        let right_table_scan = test_table_scan_with_name("test1")?;
3532        let right = LogicalPlanBuilder::from(right_table_scan)
3533            .project(vec![col("a").alias("d"), col("a").alias("e")])?
3534            .build()?;
3535        let filter = or(
3536            and(col("a").eq(col("d")), col("b").gt(lit(1u32))),
3537            and(col("b").eq(col("e")), col("c").lt(lit(10u32))),
3538        );
3539        let plan = LogicalPlanBuilder::from(left)
3540            .cross_join(right)?
3541            .filter(filter)?
3542            .build()?;
3543
3544        assert_optimized_plan_eq_with_rewrite_predicate!(plan.clone(), @r"
3545        Inner Join:  Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10)
3546          Projection: test.a, test.b, test.c
3547            TableScan: test, full_filters=[test.b > UInt32(1) OR test.c < UInt32(10)]
3548          Projection: test1.a AS d, test1.a AS e
3549            TableScan: test1
3550        ")?;
3551
3552        // Originally global state which can help to avoid duplicate Filters been generated and pushed down.
3553        // Now the global state is removed. Need to double confirm that avoid duplicate Filters.
3554        let optimized_plan = PushDownFilter::new()
3555            .rewrite(plan, &OptimizerContext::new())
3556            .expect("failed to optimize plan")
3557            .data;
3558        assert_optimized_plan_equal!(
3559            optimized_plan,
3560            @r"
3561        Inner Join:  Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10)
3562          Projection: test.a, test.b, test.c
3563            TableScan: test, full_filters=[test.b > UInt32(1) OR test.c < UInt32(10)]
3564          Projection: test1.a AS d, test1.a AS e
3565            TableScan: test1
3566        "
3567        )
3568    }
3569
3570    #[test]
3571    fn left_semi_join() -> Result<()> {
3572        let left = test_table_scan_with_name("test1")?;
3573        let right_table_scan = test_table_scan_with_name("test2")?;
3574        let right = LogicalPlanBuilder::from(right_table_scan)
3575            .project(vec![col("a"), col("b")])?
3576            .build()?;
3577        let plan = LogicalPlanBuilder::from(left)
3578            .join(
3579                right,
3580                JoinType::LeftSemi,
3581                (
3582                    vec![Column::from_qualified_name("test1.a")],
3583                    vec![Column::from_qualified_name("test2.a")],
3584                ),
3585                None,
3586            )?
3587            .filter(col("test2.a").lt_eq(lit(1i64)))?
3588            .build()?;
3589
3590        // not part of the test, just good to know:
3591        assert_snapshot!(plan,
3592        @r"
3593        Filter: test2.a <= Int64(1)
3594          LeftSemi Join: test1.a = test2.a
3595            TableScan: test1
3596            Projection: test2.a, test2.b
3597              TableScan: test2
3598        ",
3599        );
3600        // Inferred the predicate `test1.a <= Int64(1)` and push it down to the left side.
3601        assert_optimized_plan_equal!(
3602            plan,
3603            @r"
3604        Filter: test2.a <= Int64(1)
3605          LeftSemi Join: test1.a = test2.a
3606            TableScan: test1, full_filters=[test1.a <= Int64(1)]
3607            Projection: test2.a, test2.b
3608              TableScan: test2
3609        "
3610        )
3611    }
3612
3613    #[test]
3614    fn left_semi_join_with_filters() -> Result<()> {
3615        let left = test_table_scan_with_name("test1")?;
3616        let right_table_scan = test_table_scan_with_name("test2")?;
3617        let right = LogicalPlanBuilder::from(right_table_scan)
3618            .project(vec![col("a"), col("b")])?
3619            .build()?;
3620        let plan = LogicalPlanBuilder::from(left)
3621            .join(
3622                right,
3623                JoinType::LeftSemi,
3624                (
3625                    vec![Column::from_qualified_name("test1.a")],
3626                    vec![Column::from_qualified_name("test2.a")],
3627                ),
3628                Some(
3629                    col("test1.b")
3630                        .gt(lit(1u32))
3631                        .and(col("test2.b").gt(lit(2u32))),
3632                ),
3633            )?
3634            .build()?;
3635
3636        // not part of the test, just good to know:
3637        assert_snapshot!(plan,
3638        @r"
3639        LeftSemi Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
3640          TableScan: test1
3641          Projection: test2.a, test2.b
3642            TableScan: test2
3643        ",
3644        );
3645        // Both side will be pushed down.
3646        assert_optimized_plan_equal!(
3647            plan,
3648            @r"
3649        LeftSemi Join: test1.a = test2.a
3650          TableScan: test1, full_filters=[test1.b > UInt32(1)]
3651          Projection: test2.a, test2.b
3652            TableScan: test2, full_filters=[test2.b > UInt32(2)]
3653        "
3654        )
3655    }
3656
3657    #[test]
3658    fn right_semi_join() -> Result<()> {
3659        let left = test_table_scan_with_name("test1")?;
3660        let right_table_scan = test_table_scan_with_name("test2")?;
3661        let right = LogicalPlanBuilder::from(right_table_scan)
3662            .project(vec![col("a"), col("b")])?
3663            .build()?;
3664        let plan = LogicalPlanBuilder::from(left)
3665            .join(
3666                right,
3667                JoinType::RightSemi,
3668                (
3669                    vec![Column::from_qualified_name("test1.a")],
3670                    vec![Column::from_qualified_name("test2.a")],
3671                ),
3672                None,
3673            )?
3674            .filter(col("test1.a").lt_eq(lit(1i64)))?
3675            .build()?;
3676
3677        // not part of the test, just good to know:
3678        assert_snapshot!(plan,
3679        @r"
3680        Filter: test1.a <= Int64(1)
3681          RightSemi Join: test1.a = test2.a
3682            TableScan: test1
3683            Projection: test2.a, test2.b
3684              TableScan: test2
3685        ",
3686        );
3687        // Inferred the predicate `test2.a <= Int64(1)` and push it down to the right side.
3688        assert_optimized_plan_equal!(
3689            plan,
3690            @r"
3691        Filter: test1.a <= Int64(1)
3692          RightSemi Join: test1.a = test2.a
3693            TableScan: test1
3694            Projection: test2.a, test2.b
3695              TableScan: test2, full_filters=[test2.a <= Int64(1)]
3696        "
3697        )
3698    }
3699
3700    #[test]
3701    fn right_semi_join_with_filters() -> Result<()> {
3702        let left = test_table_scan_with_name("test1")?;
3703        let right_table_scan = test_table_scan_with_name("test2")?;
3704        let right = LogicalPlanBuilder::from(right_table_scan)
3705            .project(vec![col("a"), col("b")])?
3706            .build()?;
3707        let plan = LogicalPlanBuilder::from(left)
3708            .join(
3709                right,
3710                JoinType::RightSemi,
3711                (
3712                    vec![Column::from_qualified_name("test1.a")],
3713                    vec![Column::from_qualified_name("test2.a")],
3714                ),
3715                Some(
3716                    col("test1.b")
3717                        .gt(lit(1u32))
3718                        .and(col("test2.b").gt(lit(2u32))),
3719                ),
3720            )?
3721            .build()?;
3722
3723        // not part of the test, just good to know:
3724        assert_snapshot!(plan,
3725        @r"
3726        RightSemi Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
3727          TableScan: test1
3728          Projection: test2.a, test2.b
3729            TableScan: test2
3730        ",
3731        );
3732        // Both side will be pushed down.
3733        assert_optimized_plan_equal!(
3734            plan,
3735            @r"
3736        RightSemi Join: test1.a = test2.a
3737          TableScan: test1, full_filters=[test1.b > UInt32(1)]
3738          Projection: test2.a, test2.b
3739            TableScan: test2, full_filters=[test2.b > UInt32(2)]
3740        "
3741        )
3742    }
3743
3744    #[test]
3745    fn left_anti_join() -> Result<()> {
3746        let table_scan = test_table_scan_with_name("test1")?;
3747        let left = LogicalPlanBuilder::from(table_scan)
3748            .project(vec![col("a"), col("b")])?
3749            .build()?;
3750        let right_table_scan = test_table_scan_with_name("test2")?;
3751        let right = LogicalPlanBuilder::from(right_table_scan)
3752            .project(vec![col("a"), col("b")])?
3753            .build()?;
3754        let plan = LogicalPlanBuilder::from(left)
3755            .join(
3756                right,
3757                JoinType::LeftAnti,
3758                (
3759                    vec![Column::from_qualified_name("test1.a")],
3760                    vec![Column::from_qualified_name("test2.a")],
3761                ),
3762                None,
3763            )?
3764            .filter(col("test2.a").gt(lit(2u32)))?
3765            .build()?;
3766
3767        // not part of the test, just good to know:
3768        assert_snapshot!(plan,
3769        @r"
3770        Filter: test2.a > UInt32(2)
3771          LeftAnti Join: test1.a = test2.a
3772            Projection: test1.a, test1.b
3773              TableScan: test1
3774            Projection: test2.a, test2.b
3775              TableScan: test2
3776        ",
3777        );
3778        // For left anti, filter of the right side filter can be pushed down.
3779        assert_optimized_plan_equal!(
3780            plan,
3781            @r"
3782        Filter: test2.a > UInt32(2)
3783          LeftAnti Join: test1.a = test2.a
3784            Projection: test1.a, test1.b
3785              TableScan: test1, full_filters=[test1.a > UInt32(2)]
3786            Projection: test2.a, test2.b
3787              TableScan: test2
3788        "
3789        )
3790    }
3791
3792    #[test]
3793    fn left_anti_join_with_filters() -> Result<()> {
3794        let table_scan = test_table_scan_with_name("test1")?;
3795        let left = LogicalPlanBuilder::from(table_scan)
3796            .project(vec![col("a"), col("b")])?
3797            .build()?;
3798        let right_table_scan = test_table_scan_with_name("test2")?;
3799        let right = LogicalPlanBuilder::from(right_table_scan)
3800            .project(vec![col("a"), col("b")])?
3801            .build()?;
3802        let plan = LogicalPlanBuilder::from(left)
3803            .join(
3804                right,
3805                JoinType::LeftAnti,
3806                (
3807                    vec![Column::from_qualified_name("test1.a")],
3808                    vec![Column::from_qualified_name("test2.a")],
3809                ),
3810                Some(
3811                    col("test1.b")
3812                        .gt(lit(1u32))
3813                        .and(col("test2.b").gt(lit(2u32))),
3814                ),
3815            )?
3816            .build()?;
3817
3818        // not part of the test, just good to know:
3819        assert_snapshot!(plan,
3820        @r"
3821        LeftAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
3822          Projection: test1.a, test1.b
3823            TableScan: test1
3824          Projection: test2.a, test2.b
3825            TableScan: test2
3826        ",
3827        );
3828        // For left anti, filter of the right side filter can be pushed down.
3829        assert_optimized_plan_equal!(
3830            plan,
3831            @r"
3832        LeftAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1)
3833          Projection: test1.a, test1.b
3834            TableScan: test1
3835          Projection: test2.a, test2.b
3836            TableScan: test2, full_filters=[test2.b > UInt32(2)]
3837        "
3838        )
3839    }
3840
3841    #[test]
3842    fn right_anti_join() -> Result<()> {
3843        let table_scan = test_table_scan_with_name("test1")?;
3844        let left = LogicalPlanBuilder::from(table_scan)
3845            .project(vec![col("a"), col("b")])?
3846            .build()?;
3847        let right_table_scan = test_table_scan_with_name("test2")?;
3848        let right = LogicalPlanBuilder::from(right_table_scan)
3849            .project(vec![col("a"), col("b")])?
3850            .build()?;
3851        let plan = LogicalPlanBuilder::from(left)
3852            .join(
3853                right,
3854                JoinType::RightAnti,
3855                (
3856                    vec![Column::from_qualified_name("test1.a")],
3857                    vec![Column::from_qualified_name("test2.a")],
3858                ),
3859                None,
3860            )?
3861            .filter(col("test1.a").gt(lit(2u32)))?
3862            .build()?;
3863
3864        // not part of the test, just good to know:
3865        assert_snapshot!(plan,
3866        @r"
3867        Filter: test1.a > UInt32(2)
3868          RightAnti Join: test1.a = test2.a
3869            Projection: test1.a, test1.b
3870              TableScan: test1
3871            Projection: test2.a, test2.b
3872              TableScan: test2
3873        ",
3874        );
3875        // For right anti, filter of the left side can be pushed down.
3876        assert_optimized_plan_equal!(
3877            plan,
3878            @r"
3879        Filter: test1.a > UInt32(2)
3880          RightAnti Join: test1.a = test2.a
3881            Projection: test1.a, test1.b
3882              TableScan: test1
3883            Projection: test2.a, test2.b
3884              TableScan: test2, full_filters=[test2.a > UInt32(2)]
3885        "
3886        )
3887    }
3888
3889    #[test]
3890    fn right_anti_join_with_filters() -> Result<()> {
3891        let table_scan = test_table_scan_with_name("test1")?;
3892        let left = LogicalPlanBuilder::from(table_scan)
3893            .project(vec![col("a"), col("b")])?
3894            .build()?;
3895        let right_table_scan = test_table_scan_with_name("test2")?;
3896        let right = LogicalPlanBuilder::from(right_table_scan)
3897            .project(vec![col("a"), col("b")])?
3898            .build()?;
3899        let plan = LogicalPlanBuilder::from(left)
3900            .join(
3901                right,
3902                JoinType::RightAnti,
3903                (
3904                    vec![Column::from_qualified_name("test1.a")],
3905                    vec![Column::from_qualified_name("test2.a")],
3906                ),
3907                Some(
3908                    col("test1.b")
3909                        .gt(lit(1u32))
3910                        .and(col("test2.b").gt(lit(2u32))),
3911                ),
3912            )?
3913            .build()?;
3914
3915        // not part of the test, just good to know:
3916        assert_snapshot!(plan,
3917        @r"
3918        RightAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
3919          Projection: test1.a, test1.b
3920            TableScan: test1
3921          Projection: test2.a, test2.b
3922            TableScan: test2
3923        ",
3924        );
3925        // For right anti, filter of the left side can be pushed down.
3926        assert_optimized_plan_equal!(
3927            plan,
3928            @r"
3929        RightAnti Join: test1.a = test2.a Filter: test2.b > UInt32(2)
3930          Projection: test1.a, test1.b
3931            TableScan: test1, full_filters=[test1.b > UInt32(1)]
3932          Projection: test2.a, test2.b
3933            TableScan: test2
3934        "
3935        )
3936    }
3937
3938    #[derive(Debug, PartialEq, Eq, Hash)]
3939    struct TestScalarUDF {
3940        signature: Signature,
3941    }
3942
3943    impl ScalarUDFImpl for TestScalarUDF {
3944        fn as_any(&self) -> &dyn Any {
3945            self
3946        }
3947        fn name(&self) -> &str {
3948            "TestScalarUDF"
3949        }
3950
3951        fn signature(&self) -> &Signature {
3952            &self.signature
3953        }
3954
3955        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
3956            Ok(DataType::Int32)
3957        }
3958
3959        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
3960            Ok(ColumnarValue::Scalar(ScalarValue::from(1)))
3961        }
3962    }
3963
3964    #[test]
3965    fn test_push_down_volatile_function_in_aggregate() -> Result<()> {
3966        // SELECT t.a, t.r FROM (SELECT a, sum(b),  TestScalarUDF()+1 AS r FROM test1 GROUP BY a) AS t WHERE t.a > 5 AND t.r > 0.5;
3967        let table_scan = test_table_scan_with_name("test1")?;
3968        let fun = ScalarUDF::new_from_impl(TestScalarUDF {
3969            signature: Signature::exact(vec![], Volatility::Volatile),
3970        });
3971        let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
3972
3973        let plan = LogicalPlanBuilder::from(table_scan)
3974            .aggregate(vec![col("a")], vec![sum(col("b"))])?
3975            .project(vec![col("a"), sum(col("b")), add(expr, lit(1)).alias("r")])?
3976            .alias("t")?
3977            .filter(col("t.a").gt(lit(5)).and(col("t.r").gt(lit(0.5))))?
3978            .project(vec![col("t.a"), col("t.r")])?
3979            .build()?;
3980
3981        assert_snapshot!(plan,
3982        @r"
3983        Projection: t.a, t.r
3984          Filter: t.a > Int32(5) AND t.r > Float64(0.5)
3985            SubqueryAlias: t
3986              Projection: test1.a, sum(test1.b), TestScalarUDF() + Int32(1) AS r
3987                Aggregate: groupBy=[[test1.a]], aggr=[[sum(test1.b)]]
3988                  TableScan: test1
3989        ",
3990        );
3991        assert_optimized_plan_equal!(
3992            plan,
3993            @r"
3994        Projection: t.a, t.r
3995          SubqueryAlias: t
3996            Filter: r > Float64(0.5)
3997              Projection: test1.a, sum(test1.b), TestScalarUDF() + Int32(1) AS r
3998                Aggregate: groupBy=[[test1.a]], aggr=[[sum(test1.b)]]
3999                  TableScan: test1, full_filters=[test1.a > Int32(5)]
4000        "
4001        )
4002    }
4003
4004    #[test]
4005    fn test_push_down_volatile_function_in_join() -> Result<()> {
4006        // SELECT t.a, t.r FROM (SELECT test1.a AS a, TestScalarUDF() AS r FROM test1 join test2 ON test1.a = test2.a) AS t WHERE t.r > 0.5;
4007        let table_scan = test_table_scan_with_name("test1")?;
4008        let fun = ScalarUDF::new_from_impl(TestScalarUDF {
4009            signature: Signature::exact(vec![], Volatility::Volatile),
4010        });
4011        let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
4012        let left = LogicalPlanBuilder::from(table_scan).build()?;
4013        let right_table_scan = test_table_scan_with_name("test2")?;
4014        let right = LogicalPlanBuilder::from(right_table_scan).build()?;
4015        let plan = LogicalPlanBuilder::from(left)
4016            .join(
4017                right,
4018                JoinType::Inner,
4019                (
4020                    vec![Column::from_qualified_name("test1.a")],
4021                    vec![Column::from_qualified_name("test2.a")],
4022                ),
4023                None,
4024            )?
4025            .project(vec![col("test1.a").alias("a"), expr.alias("r")])?
4026            .alias("t")?
4027            .filter(col("t.r").gt(lit(0.8)))?
4028            .project(vec![col("t.a"), col("t.r")])?
4029            .build()?;
4030
4031        assert_snapshot!(plan,
4032        @r"
4033        Projection: t.a, t.r
4034          Filter: t.r > Float64(0.8)
4035            SubqueryAlias: t
4036              Projection: test1.a AS a, TestScalarUDF() AS r
4037                Inner Join: test1.a = test2.a
4038                  TableScan: test1
4039                  TableScan: test2
4040        ",
4041        );
4042        assert_optimized_plan_equal!(
4043            plan,
4044            @r"
4045        Projection: t.a, t.r
4046          SubqueryAlias: t
4047            Filter: r > Float64(0.8)
4048              Projection: test1.a AS a, TestScalarUDF() AS r
4049                Inner Join: test1.a = test2.a
4050                  TableScan: test1
4051                  TableScan: test2
4052        "
4053        )
4054    }
4055
4056    #[test]
4057    fn test_push_down_volatile_table_scan() -> Result<()> {
4058        // SELECT test.a, test.b FROM test as t WHERE TestScalarUDF() > 0.1;
4059        let table_scan = test_table_scan()?;
4060        let fun = ScalarUDF::new_from_impl(TestScalarUDF {
4061            signature: Signature::exact(vec![], Volatility::Volatile),
4062        });
4063        let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
4064        let plan = LogicalPlanBuilder::from(table_scan)
4065            .project(vec![col("a"), col("b")])?
4066            .filter(expr.gt(lit(0.1)))?
4067            .build()?;
4068
4069        assert_snapshot!(plan,
4070        @r"
4071        Filter: TestScalarUDF() > Float64(0.1)
4072          Projection: test.a, test.b
4073            TableScan: test
4074        ",
4075        );
4076        assert_optimized_plan_equal!(
4077            plan,
4078            @r"
4079        Projection: test.a, test.b
4080          Filter: TestScalarUDF() > Float64(0.1)
4081            TableScan: test
4082        "
4083        )
4084    }
4085
4086    #[test]
4087    fn test_push_down_volatile_mixed_table_scan() -> Result<()> {
4088        // SELECT test.a, test.b FROM test as t WHERE TestScalarUDF() > 0.1 and test.a > 5 and test.b > 10;
4089        let table_scan = test_table_scan()?;
4090        let fun = ScalarUDF::new_from_impl(TestScalarUDF {
4091            signature: Signature::exact(vec![], Volatility::Volatile),
4092        });
4093        let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
4094        let plan = LogicalPlanBuilder::from(table_scan)
4095            .project(vec![col("a"), col("b")])?
4096            .filter(
4097                expr.gt(lit(0.1))
4098                    .and(col("t.a").gt(lit(5)))
4099                    .and(col("t.b").gt(lit(10))),
4100            )?
4101            .build()?;
4102
4103        assert_snapshot!(plan,
4104        @r"
4105        Filter: TestScalarUDF() > Float64(0.1) AND t.a > Int32(5) AND t.b > Int32(10)
4106          Projection: test.a, test.b
4107            TableScan: test
4108        ",
4109        );
4110        assert_optimized_plan_equal!(
4111            plan,
4112            @r"
4113        Projection: test.a, test.b
4114          Filter: TestScalarUDF() > Float64(0.1)
4115            TableScan: test, full_filters=[t.a > Int32(5), t.b > Int32(10)]
4116        "
4117        )
4118    }
4119
4120    #[test]
4121    fn test_push_down_volatile_mixed_unsupported_table_scan() -> Result<()> {
4122        // SELECT test.a, test.b FROM test as t WHERE TestScalarUDF() > 0.1 and test.a > 5 and test.b > 10;
4123        let fun = ScalarUDF::new_from_impl(TestScalarUDF {
4124            signature: Signature::exact(vec![], Volatility::Volatile),
4125        });
4126        let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
4127        let plan = table_scan_with_pushdown_provider_builder(
4128            TableProviderFilterPushDown::Unsupported,
4129            vec![],
4130            None,
4131        )?
4132        .project(vec![col("a"), col("b")])?
4133        .filter(
4134            expr.gt(lit(0.1))
4135                .and(col("t.a").gt(lit(5)))
4136                .and(col("t.b").gt(lit(10))),
4137        )?
4138        .build()?;
4139
4140        assert_snapshot!(plan,
4141        @r"
4142        Filter: TestScalarUDF() > Float64(0.1) AND t.a > Int32(5) AND t.b > Int32(10)
4143          Projection: a, b
4144            TableScan: test
4145        ",
4146        );
4147        assert_optimized_plan_equal!(
4148            plan,
4149            @r"
4150        Projection: a, b
4151          Filter: t.a > Int32(5) AND t.b > Int32(10) AND TestScalarUDF() > Float64(0.1)
4152            TableScan: test
4153        "
4154        )
4155    }
4156
4157    #[test]
4158    fn test_push_down_filter_to_user_defined_node() -> Result<()> {
4159        // Define a custom user-defined logical node
4160        #[derive(Debug, Hash, Eq, PartialEq)]
4161        struct TestUserNode {
4162            schema: DFSchemaRef,
4163        }
4164
4165        impl PartialOrd for TestUserNode {
4166            fn partial_cmp(&self, _other: &Self) -> Option<Ordering> {
4167                None
4168            }
4169        }
4170
4171        impl TestUserNode {
4172            fn new() -> Self {
4173                let schema = Arc::new(
4174                    DFSchema::new_with_metadata(
4175                        vec![(None, Field::new("a", DataType::Int64, false).into())],
4176                        Default::default(),
4177                    )
4178                    .unwrap(),
4179                );
4180
4181                Self { schema }
4182            }
4183        }
4184
4185        impl UserDefinedLogicalNodeCore for TestUserNode {
4186            fn name(&self) -> &str {
4187                "test_node"
4188            }
4189
4190            fn inputs(&self) -> Vec<&LogicalPlan> {
4191                vec![]
4192            }
4193
4194            fn schema(&self) -> &DFSchemaRef {
4195                &self.schema
4196            }
4197
4198            fn expressions(&self) -> Vec<Expr> {
4199                vec![]
4200            }
4201
4202            fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
4203                write!(f, "TestUserNode")
4204            }
4205
4206            fn with_exprs_and_inputs(
4207                &self,
4208                exprs: Vec<Expr>,
4209                inputs: Vec<LogicalPlan>,
4210            ) -> Result<Self> {
4211                assert!(exprs.is_empty());
4212                assert!(inputs.is_empty());
4213                Ok(Self {
4214                    schema: Arc::clone(&self.schema),
4215                })
4216            }
4217        }
4218
4219        // Create a node and build a plan with a filter
4220        let node = LogicalPlan::Extension(Extension {
4221            node: Arc::new(TestUserNode::new()),
4222        });
4223
4224        let plan = LogicalPlanBuilder::from(node).filter(lit(false))?.build()?;
4225
4226        // Check the original plan format (not part of the test assertions)
4227        assert_snapshot!(plan,
4228        @r"
4229        Filter: Boolean(false)
4230          TestUserNode
4231        ",
4232        );
4233        // Check that the filter is pushed down to the user-defined node
4234        assert_optimized_plan_equal!(
4235            plan,
4236            @r"
4237        Filter: Boolean(false)
4238          TestUserNode
4239        "
4240        )
4241    }
4242
4243    /// Test that filters are NOT pushed through MoveTowardsLeafNodes projections.
4244    /// These are cheap expressions (like get_field) where re-inlining into a filter
4245    /// has no benefit and causes optimizer instability — ExtractLeafExpressions will
4246    /// undo the push-down, creating an infinite loop that runs until the iteration
4247    /// limit is hit.
4248    #[test]
4249    fn filter_not_pushed_through_move_towards_leaves_projection() -> Result<()> {
4250        let table_scan = test_table_scan()?;
4251
4252        // Create a projection with a MoveTowardsLeafNodes expression
4253        let proj = LogicalPlanBuilder::from(table_scan)
4254            .project(vec![
4255                leaf_udf_expr(col("a")).alias("val"),
4256                col("b"),
4257                col("c"),
4258            ])?
4259            .build()?;
4260
4261        // Put a filter on the MoveTowardsLeafNodes column
4262        let plan = LogicalPlanBuilder::from(proj)
4263            .filter(col("val").gt(lit(150i64)))?
4264            .build()?;
4265
4266        // Filter should NOT be pushed through — val maps to a MoveTowardsLeafNodes expr
4267        assert_optimized_plan_equal!(
4268            plan,
4269            @r"
4270        Filter: val > Int64(150)
4271          Projection: leaf_udf(test.a) AS val, test.b, test.c
4272            TableScan: test
4273        "
4274        )
4275    }
4276
4277    /// Test mixed predicates: Column predicate pushed, MoveTowardsLeafNodes kept.
4278    #[test]
4279    fn filter_mixed_predicates_partial_push() -> Result<()> {
4280        let table_scan = test_table_scan()?;
4281
4282        // Create a projection with both MoveTowardsLeafNodes and Column expressions
4283        let proj = LogicalPlanBuilder::from(table_scan)
4284            .project(vec![
4285                leaf_udf_expr(col("a")).alias("val"),
4286                col("b"),
4287                col("c"),
4288            ])?
4289            .build()?;
4290
4291        // Filter with both: val > 150 (MoveTowardsLeafNodes) AND b > 5 (Column)
4292        let plan = LogicalPlanBuilder::from(proj)
4293            .filter(col("val").gt(lit(150i64)).and(col("b").gt(lit(5i64))))?
4294            .build()?;
4295
4296        // val > 150 should be kept above, b > 5 should be pushed through
4297        assert_optimized_plan_equal!(
4298            plan,
4299            @r"
4300        Filter: val > Int64(150)
4301          Projection: leaf_udf(test.a) AS val, test.b, test.c
4302            TableScan: test, full_filters=[test.b > Int64(5)]
4303        "
4304        )
4305    }
4306
4307    #[test]
4308    fn filter_not_pushed_down_through_table_scan_with_fetch() -> Result<()> {
4309        let scan = test_table_scan()?;
4310        let scan_with_fetch = match scan {
4311            LogicalPlan::TableScan(scan) => LogicalPlan::TableScan(TableScan {
4312                fetch: Some(10),
4313                ..scan
4314            }),
4315            _ => unreachable!(),
4316        };
4317        let plan = LogicalPlanBuilder::from(scan_with_fetch)
4318            .filter(col("a").gt(lit(10i64)))?
4319            .build()?;
4320        // Filter must NOT be pushed into the table scan when it has a fetch (limit)
4321        assert_optimized_plan_equal!(
4322            plan,
4323            @r"
4324        Filter: test.a > Int64(10)
4325          TableScan: test, fetch=10
4326        "
4327        )
4328    }
4329
4330    #[test]
4331    fn filter_push_down_through_sort_without_fetch() -> Result<()> {
4332        let table_scan = test_table_scan()?;
4333        let plan = LogicalPlanBuilder::from(table_scan)
4334            .sort(vec![col("a").sort(true, true)])?
4335            .filter(col("a").gt(lit(10i64)))?
4336            .build()?;
4337        // Filter should be pushed below the sort
4338        assert_optimized_plan_equal!(
4339            plan,
4340            @r"
4341        Sort: test.a ASC NULLS FIRST
4342          TableScan: test, full_filters=[test.a > Int64(10)]
4343        "
4344        )
4345    }
4346
4347    #[test]
4348    fn filter_not_pushed_down_through_sort_with_fetch() -> Result<()> {
4349        let table_scan = test_table_scan()?;
4350        let plan = LogicalPlanBuilder::from(table_scan)
4351            .sort_with_limit(vec![col("a").sort(true, true)], Some(5))?
4352            .filter(col("a").gt(lit(10i64)))?
4353            .build()?;
4354        // Filter must NOT be pushed below the sort when it has a fetch (limit),
4355        // because the limit should apply before the filter.
4356        assert_optimized_plan_equal!(
4357            plan,
4358            @r"
4359        Filter: test.a > Int64(10)
4360          Sort: test.a ASC NULLS FIRST, fetch=5
4361            TableScan: test
4362        "
4363        )
4364    }
4365}