Skip to main content

datafusion_optimizer/
push_down_filter.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`PushDownFilter`] applies filters as early as possible
19
20use std::collections::{HashMap, HashSet};
21use std::sync::Arc;
22
23use arrow::datatypes::DataType;
24use indexmap::IndexSet;
25use itertools::Itertools;
26
27use datafusion_common::tree_node::{
28    Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
29};
30use datafusion_common::{
31    Column, DFSchema, Result, assert_eq_or_internal_err, assert_or_internal_err,
32    internal_err, plan_err, qualified_name,
33};
34use datafusion_expr::expr::WindowFunction;
35use datafusion_expr::expr_rewriter::replace_col;
36use datafusion_expr::logical_plan::{Join, JoinType, LogicalPlan, TableScan, Union};
37use datafusion_expr::utils::{
38    conjunction, expr_to_columns, split_conjunction, split_conjunction_owned,
39};
40use datafusion_expr::{
41    BinaryExpr, Expr, Filter, Operator, Projection, TableProviderFilterPushDown, and, or,
42};
43
44use crate::optimizer::ApplyOrder;
45use crate::simplify_expressions::simplify_predicates;
46use crate::utils::{has_all_column_refs, is_restrict_null_predicate};
47use crate::{OptimizerConfig, OptimizerRule};
48
/// Optimizer rule for pushing (moving) filter expressions down in a plan so
/// they are applied as early as possible.
///
/// # Introduction
///
/// The goal of this rule is to improve query performance by eliminating
/// redundant work.
///
/// For example, given a plan that sorts all values where `a > 10`:
///
/// ```text
///  Filter (a > 10)
///    Sort (a, b)
/// ```
///
/// A better plan is to filter the data *before* the Sort, which sorts fewer
/// rows and therefore does less work overall:
///
/// ```text
///  Sort (a, b)
///    Filter (a > 10)  <-- Filter is moved before the sort
/// ```
///
/// However it is not always possible to push filters down. For example, given a
/// plan that finds the top 3 values and then keeps only those that are greater
/// than 10, if the filter is pushed below the limit it would produce a
/// different result.
///
/// ```text
///  Filter (a > 10)   <-- can not move this Filter before the limit
///    Limit (fetch=3)
///      Sort (a, b)
/// ```
///
/// A filter can only be moved below an operation when applying the filter
/// before or after that operation yields the same result. More formally, a
/// filter-commutative operation is an operation `op` that satisfies
/// `filter(op(data)) = op(filter(data))`.
///
/// The filter-commutative property is plan and column-specific. A filter on `a`
/// can be pushed through an `Aggregate(group_by = [a], agg = [sum(b)])`. However,
/// a filter on `sum(b)` can not be pushed through the same aggregate.
///
/// # Handling Conjunctions
///
/// It is possible to only push down **part** of a filter expression if it is
/// connected with `AND`s (more formally if it is a "conjunction").
///
/// For example, given the following plan:
///
/// ```text
/// Filter(a > 10 AND sum(b) < 5)
///   Aggregate(group_by = [a], agg = [sum(b)])
/// ```
///
/// The `a > 10` is commutative with the `Aggregate` but `sum(b) < 5` is not.
/// Therefore it is possible to only push part of the expression, resulting in:
///
/// ```text
/// Filter(sum(b) < 5)
///   Aggregate(group_by = [a], agg = [sum(b)])
///     Filter(a > 10)
/// ```
///
/// # Handling Column Aliases
///
/// This optimizer must sometimes handle re-writing filter expressions when they
/// are pushed, for example if there is a projection that aliases `a+1` to `"b"`:
///
/// ```text
/// Filter (b > 10)
///     Projection: [a+1 AS "b"]  <-- changes the name of `a+1` to `b`
/// ```
///
/// To apply the filter prior to the `Projection`, all references to `b` must be
/// rewritten to `a+1`:
///
/// ```text
/// Projection: [a+1 AS "b"]
///     Filter: (a + 1 > 10)  <--- changed from b to a + 1
/// ```
///
/// # Implementation Notes
///
/// This rule is applied top-down and performs a single pass through the plan,
/// "pushing" down filters. When it passes through a filter, it stores that
/// filter, and when it reaches a plan node that does not commute with that
/// filter, it adds the filter to that place. When it passes through a
/// projection, it re-writes the filter's expression taking into account that
/// projection. A filter is never pushed below a node that has a fetch (limit)
/// or skip (offset), since that would change which rows the limit/offset sees.
#[derive(Default, Debug)]
pub struct PushDownFilter {}
138
139/// For a given JOIN type, determine whether each input of the join is preserved
140/// for post-join (`WHERE` clause) filters.
141///
142/// It is only correct to push filters below a join for preserved inputs.
143///
144/// # Return Value
145/// A tuple of booleans - (left_preserved, right_preserved).
146///
147/// # "Preserved" input definition
148///
149/// We say a join side is preserved if the join returns all or a subset of the rows from
150/// the relevant side, such that each row of the output table directly maps to a row of
151/// the preserved input table. If a table is not preserved, it can provide extra null rows.
152/// That is, there may be rows in the output table that don't directly map to a row in the
153/// input table.
154///
155/// For example:
156///   - In an inner join, both sides are preserved, because each row of the output
157///     maps directly to a row from each side.
158///
159///   - In a left join, the left side is preserved (we can push predicates) but
160///     the right is not, because there may be rows in the output that don't
161///     directly map to a row in the right input (due to nulls filling where there
162///     is no match on the right).
163pub(crate) fn lr_is_preserved(join_type: JoinType) -> (bool, bool) {
164    match join_type {
165        JoinType::Inner => (true, true),
166        JoinType::Left => (true, false),
167        JoinType::Right => (false, true),
168        JoinType::Full => (false, false),
169        // No columns from the right side of the join can be referenced in output
170        // predicates for semi/anti joins, so whether we specify t/f doesn't matter.
171        JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => (true, false),
172        // No columns from the left side of the join can be referenced in output
173        // predicates for semi/anti joins, so whether we specify t/f doesn't matter.
174        JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => (false, true),
175    }
176}
177
178/// For a given JOIN type, determine whether each input of the join is preserved
179/// for the join condition (`ON` clause filters).
180///
181/// It is only correct to push filters below a join for preserved inputs.
182///
183/// # Return Value
184/// A tuple of booleans - (left_preserved, right_preserved).
185///
186/// See [`lr_is_preserved`] for a definition of "preserved".
187pub(crate) fn on_lr_is_preserved(join_type: JoinType) -> (bool, bool) {
188    match join_type {
189        JoinType::Inner => (true, true),
190        JoinType::Left => (false, true),
191        JoinType::Right => (true, false),
192        JoinType::Full => (false, false),
193        JoinType::LeftSemi | JoinType::RightSemi => (true, true),
194        JoinType::LeftAnti => (false, true),
195        JoinType::RightAnti => (true, false),
196        JoinType::LeftMark => (false, true),
197        JoinType::RightMark => (true, false),
198    }
199}
200
201/// Evaluates the columns referenced in the given expression to see if they refer
202/// only to the left or right columns
203#[derive(Debug)]
204struct ColumnChecker<'a> {
205    /// schema of left join input
206    left_schema: &'a DFSchema,
207    /// columns in left_schema, computed on demand
208    left_columns: Option<HashSet<Column>>,
209    /// schema of right join input
210    right_schema: &'a DFSchema,
211    /// columns in left_schema, computed on demand
212    right_columns: Option<HashSet<Column>>,
213}
214
215impl<'a> ColumnChecker<'a> {
216    fn new(left_schema: &'a DFSchema, right_schema: &'a DFSchema) -> Self {
217        Self {
218            left_schema,
219            left_columns: None,
220            right_schema,
221            right_columns: None,
222        }
223    }
224
225    /// Return true if the expression references only columns from the left side of the join
226    fn is_left_only(&mut self, predicate: &Expr) -> bool {
227        if self.left_columns.is_none() {
228            self.left_columns = Some(schema_columns(self.left_schema));
229        }
230        has_all_column_refs(predicate, self.left_columns.as_ref().unwrap())
231    }
232
233    /// Return true if the expression references only columns from the right side of the join
234    fn is_right_only(&mut self, predicate: &Expr) -> bool {
235        if self.right_columns.is_none() {
236            self.right_columns = Some(schema_columns(self.right_schema));
237        }
238        has_all_column_refs(predicate, self.right_columns.as_ref().unwrap())
239    }
240}
241
242/// Returns all columns in the schema
243fn schema_columns(schema: &DFSchema) -> HashSet<Column> {
244    schema
245        .iter()
246        .flat_map(|(qualifier, field)| {
247            [
248                Column::new(qualifier.cloned(), field.name()),
249                // we need to push down filter using unqualified column as well
250                Column::new_unqualified(field.name()),
251            ]
252        })
253        .collect::<HashSet<_>>()
254}
255
256/// Determine whether the predicate can evaluate as the join conditions
257fn can_evaluate_as_join_condition(predicate: &Expr) -> Result<bool> {
258    let mut is_evaluate = true;
259    predicate.apply(|expr| match expr {
260        Expr::Column(_)
261        | Expr::Literal(_, _)
262        | Expr::Placeholder(_)
263        | Expr::ScalarVariable(_, _) => Ok(TreeNodeRecursion::Jump),
264        Expr::Exists { .. }
265        | Expr::InSubquery(_)
266        | Expr::ScalarSubquery(_)
267        | Expr::OuterReferenceColumn(_, _)
268        | Expr::Unnest(_) => {
269            is_evaluate = false;
270            Ok(TreeNodeRecursion::Stop)
271        }
272        Expr::Alias(_)
273        | Expr::BinaryExpr(_)
274        | Expr::Like(_)
275        | Expr::SimilarTo(_)
276        | Expr::Not(_)
277        | Expr::IsNotNull(_)
278        | Expr::IsNull(_)
279        | Expr::IsTrue(_)
280        | Expr::IsFalse(_)
281        | Expr::IsUnknown(_)
282        | Expr::IsNotTrue(_)
283        | Expr::IsNotFalse(_)
284        | Expr::IsNotUnknown(_)
285        | Expr::Negative(_)
286        | Expr::Between(_)
287        | Expr::Case(_)
288        | Expr::Cast(_)
289        | Expr::TryCast(_)
290        | Expr::InList { .. }
291        | Expr::ScalarFunction(_) => Ok(TreeNodeRecursion::Continue),
292        // TODO: remove the next line after `Expr::Wildcard` is removed
293        #[expect(deprecated)]
294        Expr::AggregateFunction(_)
295        | Expr::WindowFunction(_)
296        | Expr::Wildcard { .. }
297        | Expr::GroupingSet(_) => internal_err!("Unsupported predicate type"),
298    })?;
299    Ok(is_evaluate)
300}
301
302/// examine OR clause to see if any useful clauses can be extracted and push down.
303/// extract at least one qual from each sub clauses of OR clause, then form the quals
304/// to new OR clause as predicate.
305///
306/// # Example
307/// ```text
308/// Filter: (a = c and a < 20) or (b = d and b > 10)
309///     join/crossjoin:
310///          TableScan: projection=[a, b]
311///          TableScan: projection=[c, d]
312/// ```
313///
314/// is optimized to
315///
316/// ```text
317/// Filter: (a = c and a < 20) or (b = d and b > 10)
318///     join/crossjoin:
319///          Filter: (a < 20) or (b > 10)
320///              TableScan: projection=[a, b]
321///          TableScan: projection=[c, d]
322/// ```
323///
324/// In general, predicates of this form:
325///
326/// ```sql
327/// (A AND B) OR (C AND D)
328/// ```
329///
330/// will be transformed to one of:
331///
332/// * `((A AND B) OR (C AND D)) AND (A OR C)`
333/// * `((A AND B) OR (C AND D)) AND ((A AND B) OR C)`
334/// * do nothing.
335fn extract_or_clauses_for_join<'a>(
336    filters: &'a [Expr],
337    schema: &'a DFSchema,
338) -> impl Iterator<Item = Expr> + 'a {
339    let schema_columns = schema_columns(schema);
340
341    // new formed OR clauses and their column references
342    filters.iter().filter_map(move |expr| {
343        if let Expr::BinaryExpr(BinaryExpr {
344            left,
345            op: Operator::Or,
346            right,
347        }) = expr
348        {
349            let left_expr = extract_or_clause(left.as_ref(), &schema_columns);
350            let right_expr = extract_or_clause(right.as_ref(), &schema_columns);
351
352            // If nothing can be extracted from any sub clauses, do nothing for this OR clause.
353            if let (Some(left_expr), Some(right_expr)) = (left_expr, right_expr) {
354                return Some(or(left_expr, right_expr));
355            }
356        }
357        None
358    })
359}
360
361/// extract qual from OR sub-clause.
362///
363/// A qual is extracted if it only contains set of column references in schema_columns.
364///
365/// For AND clause, we extract from both sub-clauses, then make new AND clause by extracted
366/// clauses if both extracted; Otherwise, use the extracted clause from any sub-clauses or None.
367///
368/// For OR clause, we extract from both sub-clauses, then make new OR clause by extracted clauses if both extracted;
369/// Otherwise, return None.
370///
371/// For other clause, apply the rule above to extract clause.
372fn extract_or_clause(expr: &Expr, schema_columns: &HashSet<Column>) -> Option<Expr> {
373    let mut predicate = None;
374
375    match expr {
376        Expr::BinaryExpr(BinaryExpr {
377            left: l_expr,
378            op: Operator::Or,
379            right: r_expr,
380        }) => {
381            let l_expr = extract_or_clause(l_expr, schema_columns);
382            let r_expr = extract_or_clause(r_expr, schema_columns);
383
384            if let (Some(l_expr), Some(r_expr)) = (l_expr, r_expr) {
385                predicate = Some(or(l_expr, r_expr));
386            }
387        }
388        Expr::BinaryExpr(BinaryExpr {
389            left: l_expr,
390            op: Operator::And,
391            right: r_expr,
392        }) => {
393            let l_expr = extract_or_clause(l_expr, schema_columns);
394            let r_expr = extract_or_clause(r_expr, schema_columns);
395
396            match (l_expr, r_expr) {
397                (Some(l_expr), Some(r_expr)) => {
398                    predicate = Some(and(l_expr, r_expr));
399                }
400                (Some(l_expr), None) => {
401                    predicate = Some(l_expr);
402                }
403                (None, Some(r_expr)) => {
404                    predicate = Some(r_expr);
405                }
406                (None, None) => {
407                    predicate = None;
408                }
409            }
410        }
411        _ => {
412            if has_all_column_refs(expr, schema_columns) {
413                predicate = Some(expr.clone());
414            }
415        }
416    }
417
418    predicate
419}
420
421/// push down join/cross-join
422fn push_down_all_join(
423    predicates: Vec<Expr>,
424    inferred_join_predicates: Vec<Expr>,
425    mut join: Join,
426    on_filter: Vec<Expr>,
427) -> Result<Transformed<LogicalPlan>> {
428    let is_inner_join = join.join_type == JoinType::Inner;
429    // Get pushable predicates from current optimizer state
430    let (left_preserved, right_preserved) = lr_is_preserved(join.join_type);
431
432    // The predicates can be divided to three categories:
433    // 1) can push through join to its children(left or right)
434    // 2) can be converted to join conditions if the join type is Inner
435    // 3) should be kept as filter conditions
436    let left_schema = join.left.schema();
437    let right_schema = join.right.schema();
438    let mut left_push = vec![];
439    let mut right_push = vec![];
440    let mut keep_predicates = vec![];
441    let mut join_conditions = vec![];
442    let mut checker = ColumnChecker::new(left_schema, right_schema);
443    for predicate in predicates {
444        if left_preserved && checker.is_left_only(&predicate) {
445            left_push.push(predicate);
446        } else if right_preserved && checker.is_right_only(&predicate) {
447            right_push.push(predicate);
448        } else if is_inner_join && can_evaluate_as_join_condition(&predicate)? {
449            // Here we do not differ it is eq or non-eq predicate, ExtractEquijoinPredicate will extract the eq predicate
450            // and convert to the join on condition
451            join_conditions.push(predicate);
452        } else {
453            keep_predicates.push(predicate);
454        }
455    }
456
457    // For infer predicates, if they can not push through join, just drop them
458    for predicate in inferred_join_predicates {
459        if left_preserved && checker.is_left_only(&predicate) {
460            left_push.push(predicate);
461        } else if right_preserved && checker.is_right_only(&predicate) {
462            right_push.push(predicate);
463        }
464    }
465
466    let mut on_filter_join_conditions = vec![];
467    let (on_left_preserved, on_right_preserved) = on_lr_is_preserved(join.join_type);
468
469    if !on_filter.is_empty() {
470        for on in on_filter {
471            if on_left_preserved && checker.is_left_only(&on) {
472                left_push.push(on)
473            } else if on_right_preserved && checker.is_right_only(&on) {
474                right_push.push(on)
475            } else {
476                on_filter_join_conditions.push(on)
477            }
478        }
479    }
480
481    // Extract from OR clause, generate new predicates for both side of join if possible.
482    // We only track the unpushable predicates above.
483    if left_preserved {
484        left_push.extend(extract_or_clauses_for_join(&keep_predicates, left_schema));
485        left_push.extend(extract_or_clauses_for_join(&join_conditions, left_schema));
486    }
487    if right_preserved {
488        right_push.extend(extract_or_clauses_for_join(&keep_predicates, right_schema));
489        right_push.extend(extract_or_clauses_for_join(&join_conditions, right_schema));
490    }
491
492    // For predicates from join filter, we should check with if a join side is preserved
493    // in term of join filtering.
494    if on_left_preserved {
495        left_push.extend(extract_or_clauses_for_join(
496            &on_filter_join_conditions,
497            left_schema,
498        ));
499    }
500    if on_right_preserved {
501        right_push.extend(extract_or_clauses_for_join(
502            &on_filter_join_conditions,
503            right_schema,
504        ));
505    }
506
507    if let Some(predicate) = conjunction(left_push) {
508        join.left = Arc::new(LogicalPlan::Filter(Filter::try_new(predicate, join.left)?));
509    }
510    if let Some(predicate) = conjunction(right_push) {
511        join.right =
512            Arc::new(LogicalPlan::Filter(Filter::try_new(predicate, join.right)?));
513    }
514
515    // Add any new join conditions as the non join predicates
516    join_conditions.extend(on_filter_join_conditions);
517    join.filter = conjunction(join_conditions);
518
519    // wrap the join on the filter whose predicates must be kept, if any
520    let plan = LogicalPlan::Join(join);
521    let plan = if let Some(predicate) = conjunction(keep_predicates) {
522        LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(plan))?)
523    } else {
524        plan
525    };
526    Ok(Transformed::yes(plan))
527}
528
529fn push_down_join(
530    join: Join,
531    parent_predicate: Option<&Expr>,
532) -> Result<Transformed<LogicalPlan>> {
533    // Split the parent predicate into individual conjunctive parts.
534    let predicates = parent_predicate
535        .map_or_else(Vec::new, |pred| split_conjunction_owned(pred.clone()));
536
537    // Extract conjunctions from the JOIN's ON filter, if present.
538    let on_filters = join
539        .filter
540        .as_ref()
541        .map_or_else(Vec::new, |filter| split_conjunction_owned(filter.clone()));
542
543    // Are there any new join predicates that can be inferred from the filter expressions?
544    let inferred_join_predicates =
545        infer_join_predicates(&join, &predicates, &on_filters)?;
546
547    if on_filters.is_empty()
548        && predicates.is_empty()
549        && inferred_join_predicates.is_empty()
550    {
551        return Ok(Transformed::no(LogicalPlan::Join(join)));
552    }
553
554    push_down_all_join(predicates, inferred_join_predicates, join, on_filters)
555}
556
557/// Extracts any equi-join join predicates from the given filter expressions.
558///
559/// Parameters
560/// * `join` the join in question
561///
562/// * `predicates` the pushed down filter expression
563///
564/// * `on_filters` filters from the join ON clause that have not already been
565///   identified as join predicates
566fn infer_join_predicates(
567    join: &Join,
568    predicates: &[Expr],
569    on_filters: &[Expr],
570) -> Result<Vec<Expr>> {
571    // Only allow both side key is column.
572    let join_col_keys = join
573        .on
574        .iter()
575        .filter_map(|(l, r)| {
576            let left_col = l.try_as_col()?;
577            let right_col = r.try_as_col()?;
578            Some((left_col, right_col))
579        })
580        .collect::<Vec<_>>();
581
582    let join_type = join.join_type;
583
584    let mut inferred_predicates = InferredPredicates::new(join_type);
585
586    infer_join_predicates_from_predicates(
587        &join_col_keys,
588        predicates,
589        &mut inferred_predicates,
590    )?;
591
592    infer_join_predicates_from_on_filters(
593        &join_col_keys,
594        join_type,
595        on_filters,
596        &mut inferred_predicates,
597    )?;
598
599    Ok(inferred_predicates.predicates)
600}
601
602/// Inferred predicates collector.
603/// When the JoinType is not Inner, we need to detect whether the inferred predicate can strictly
604/// filter out NULL, otherwise ignore it. e.g.
605/// ```text
606/// SELECT * FROM t1 LEFT JOIN t2 ON t1.c0 = t2.c0 WHERE t2.c0 IS NULL;
607/// ```
608/// We cannot infer the predicate `t1.c0 IS NULL`, otherwise the predicate will be pushed down to
609/// the left side, resulting in the wrong result.
610struct InferredPredicates {
611    predicates: Vec<Expr>,
612    is_inner_join: bool,
613}
614
615impl InferredPredicates {
616    fn new(join_type: JoinType) -> Self {
617        Self {
618            predicates: vec![],
619            is_inner_join: matches!(join_type, JoinType::Inner),
620        }
621    }
622
623    fn try_build_predicate(
624        &mut self,
625        predicate: Expr,
626        replace_map: &HashMap<&Column, &Column>,
627    ) -> Result<()> {
628        if self.is_inner_join
629            || matches!(
630                is_restrict_null_predicate(
631                    predicate.clone(),
632                    replace_map.keys().cloned()
633                ),
634                Ok(true)
635            )
636        {
637            self.predicates.push(replace_col(predicate, replace_map)?);
638        }
639
640        Ok(())
641    }
642}
643
644/// Infer predicates from the pushed down predicates.
645///
646/// Parameters
647/// * `join_col_keys` column pairs from the join ON clause
648///
649/// * `predicates` the pushed down predicates
650///
651/// * `inferred_predicates` the inferred results
652fn infer_join_predicates_from_predicates(
653    join_col_keys: &[(&Column, &Column)],
654    predicates: &[Expr],
655    inferred_predicates: &mut InferredPredicates,
656) -> Result<()> {
657    infer_join_predicates_impl::<true, true>(
658        join_col_keys,
659        predicates,
660        inferred_predicates,
661    )
662}
663
664/// Infer predicates from the join filter.
665///
666/// Parameters
667/// * `join_col_keys` column pairs from the join ON clause
668///
669/// * `join_type` the JoinType of Join
670///
671/// * `on_filters` filters from the join ON clause that have not already been
672///   identified as join predicates
673///
674/// * `inferred_predicates` the inferred results
675fn infer_join_predicates_from_on_filters(
676    join_col_keys: &[(&Column, &Column)],
677    join_type: JoinType,
678    on_filters: &[Expr],
679    inferred_predicates: &mut InferredPredicates,
680) -> Result<()> {
681    match join_type {
682        JoinType::Full | JoinType::LeftAnti | JoinType::RightAnti => Ok(()),
683        JoinType::Inner => infer_join_predicates_impl::<true, true>(
684            join_col_keys,
685            on_filters,
686            inferred_predicates,
687        ),
688        JoinType::Left | JoinType::LeftSemi | JoinType::LeftMark => {
689            infer_join_predicates_impl::<true, false>(
690                join_col_keys,
691                on_filters,
692                inferred_predicates,
693            )
694        }
695        JoinType::Right | JoinType::RightSemi | JoinType::RightMark => {
696            infer_join_predicates_impl::<false, true>(
697                join_col_keys,
698                on_filters,
699                inferred_predicates,
700            )
701        }
702    }
703}
704
705/// Infer predicates from the given predicates.
706///
707/// Parameters
708/// * `join_col_keys` column pairs from the join ON clause
709///
710/// * `input_predicates` the given predicates. It can be the pushed down predicates,
711///   or it can be the filters of the Join
712///
713/// * `inferred_predicates` the inferred results
714///
715/// * `ENABLE_LEFT_TO_RIGHT` indicates that the right table related predicate can
716///   be inferred from the left table related predicate
717///
718/// * `ENABLE_RIGHT_TO_LEFT` indicates that the left table related predicate can
719///   be inferred from the right table related predicate
720fn infer_join_predicates_impl<
721    const ENABLE_LEFT_TO_RIGHT: bool,
722    const ENABLE_RIGHT_TO_LEFT: bool,
723>(
724    join_col_keys: &[(&Column, &Column)],
725    input_predicates: &[Expr],
726    inferred_predicates: &mut InferredPredicates,
727) -> Result<()> {
728    for predicate in input_predicates {
729        let mut join_cols_to_replace = HashMap::new();
730
731        for &col in &predicate.column_refs() {
732            for (l, r) in join_col_keys.iter() {
733                if ENABLE_LEFT_TO_RIGHT && col == *l {
734                    join_cols_to_replace.insert(col, *r);
735                    break;
736                }
737                if ENABLE_RIGHT_TO_LEFT && col == *r {
738                    join_cols_to_replace.insert(col, *l);
739                    break;
740                }
741            }
742        }
743        if join_cols_to_replace.is_empty() {
744            continue;
745        }
746
747        inferred_predicates
748            .try_build_predicate(predicate.clone(), &join_cols_to_replace)?;
749    }
750    Ok(())
751}
752
753impl OptimizerRule for PushDownFilter {
754    fn name(&self) -> &str {
755        "push_down_filter"
756    }
757
758    fn apply_order(&self) -> Option<ApplyOrder> {
759        Some(ApplyOrder::TopDown)
760    }
761
762    fn supports_rewrite(&self) -> bool {
763        true
764    }
765
766    fn rewrite(
767        &self,
768        plan: LogicalPlan,
769        config: &dyn OptimizerConfig,
770    ) -> Result<Transformed<LogicalPlan>> {
771        let _ = config.options();
772        if let LogicalPlan::Join(join) = plan {
773            return push_down_join(join, None);
774        };
775
776        let plan_schema = Arc::clone(plan.schema());
777
778        let LogicalPlan::Filter(mut filter) = plan else {
779            return Ok(Transformed::no(plan));
780        };
781
782        let predicate = split_conjunction_owned(filter.predicate.clone());
783        let old_predicate_len = predicate.len();
784        let new_predicates = simplify_predicates(predicate)?;
785        if old_predicate_len != new_predicates.len() {
786            let Some(new_predicate) = conjunction(new_predicates) else {
787                // new_predicates is empty - remove the filter entirely
788                // Return the child plan without the filter
789                return Ok(Transformed::yes(Arc::unwrap_or_clone(filter.input)));
790            };
791            filter.predicate = new_predicate;
792        }
793
794        // If the child has a fetch (limit) or skip (offset), pushing a filter
795        // below it would change semantics: the limit/offset should apply before
796        // the filter, not after.
797        if filter.input.fetch()?.is_some() || filter.input.skip()?.is_some() {
798            return Ok(Transformed::no(LogicalPlan::Filter(filter)));
799        }
800
801        match Arc::unwrap_or_clone(filter.input) {
802            LogicalPlan::Filter(child_filter) => {
803                let parents_predicates = split_conjunction_owned(filter.predicate);
804
805                // remove duplicated filters
806                let child_predicates = split_conjunction_owned(child_filter.predicate);
807                let new_predicates = parents_predicates
808                    .into_iter()
809                    .chain(child_predicates)
810                    // use IndexSet to remove dupes while preserving predicate order
811                    .collect::<IndexSet<_>>()
812                    .into_iter()
813                    .collect::<Vec<_>>();
814
815                let Some(new_predicate) = conjunction(new_predicates) else {
816                    return plan_err!("at least one expression exists");
817                };
818                let new_filter = LogicalPlan::Filter(Filter::try_new(
819                    new_predicate,
820                    child_filter.input,
821                )?);
822                self.rewrite(new_filter, config)
823            }
824            LogicalPlan::Repartition(repartition) => {
825                let new_filter =
826                    Filter::try_new(filter.predicate, Arc::clone(&repartition.input))
827                        .map(LogicalPlan::Filter)?;
828                insert_below(LogicalPlan::Repartition(repartition), new_filter)
829            }
830            LogicalPlan::Distinct(distinct) => {
831                let new_filter =
832                    Filter::try_new(filter.predicate, Arc::clone(distinct.input()))
833                        .map(LogicalPlan::Filter)?;
834                insert_below(LogicalPlan::Distinct(distinct), new_filter)
835            }
836            LogicalPlan::Sort(sort) => {
837                let new_filter =
838                    Filter::try_new(filter.predicate, Arc::clone(&sort.input))
839                        .map(LogicalPlan::Filter)?;
840                insert_below(LogicalPlan::Sort(sort), new_filter)
841            }
842            LogicalPlan::SubqueryAlias(subquery_alias) => {
843                let mut replace_map = HashMap::new();
844                for (i, (qualifier, field)) in
845                    subquery_alias.input.schema().iter().enumerate()
846                {
847                    let (sub_qualifier, sub_field) =
848                        subquery_alias.schema.qualified_field(i);
849                    replace_map.insert(
850                        qualified_name(sub_qualifier, sub_field.name()),
851                        Expr::Column(Column::new(qualifier.cloned(), field.name())),
852                    );
853                }
854                let new_predicate = replace_cols_by_name(filter.predicate, &replace_map)?;
855
856                let new_filter = LogicalPlan::Filter(Filter::try_new(
857                    new_predicate,
858                    Arc::clone(&subquery_alias.input),
859                )?);
860                insert_below(LogicalPlan::SubqueryAlias(subquery_alias), new_filter)
861            }
862            LogicalPlan::Projection(projection) => {
863                let predicates = split_conjunction_owned(filter.predicate.clone());
864                let (new_projection, keep_predicate) =
865                    rewrite_projection(predicates, projection)?;
866                if new_projection.transformed {
867                    match keep_predicate {
868                        None => Ok(new_projection),
869                        Some(keep_predicate) => new_projection.map_data(|child_plan| {
870                            Filter::try_new(keep_predicate, Arc::new(child_plan))
871                                .map(LogicalPlan::Filter)
872                        }),
873                    }
874                } else {
875                    filter.input = Arc::new(new_projection.data);
876                    Ok(Transformed::no(LogicalPlan::Filter(filter)))
877                }
878            }
879            LogicalPlan::Unnest(mut unnest) => {
880                let predicates = split_conjunction_owned(filter.predicate.clone());
881                let mut non_unnest_predicates = vec![];
882                let mut unnest_predicates = vec![];
883                let mut unnest_struct_columns = vec![];
884
885                for idx in &unnest.struct_type_columns {
886                    let (sub_qualifier, field) =
887                        unnest.input.schema().qualified_field(*idx);
888                    let field_name = field.name().clone();
889
890                    if let DataType::Struct(children) = field.data_type() {
891                        for child in children {
892                            let child_name = child.name().clone();
893                            unnest_struct_columns.push(Column::new(
894                                sub_qualifier.cloned(),
895                                format!("{field_name}.{child_name}"),
896                            ));
897                        }
898                    }
899                }
900
901                for predicate in predicates {
902                    // collect all the Expr::Column in predicate recursively
903                    let mut accum: HashSet<Column> = HashSet::new();
904                    expr_to_columns(&predicate, &mut accum)?;
905
906                    let contains_list_columns =
907                        unnest.list_type_columns.iter().any(|(_, unnest_list)| {
908                            accum.contains(&unnest_list.output_column)
909                        });
910                    let contains_struct_columns =
911                        unnest_struct_columns.iter().any(|c| accum.contains(c));
912
913                    if contains_list_columns || contains_struct_columns {
914                        unnest_predicates.push(predicate);
915                    } else {
916                        non_unnest_predicates.push(predicate);
917                    }
918                }
919
920                // Unnest predicates should not be pushed down.
921                // If no non-unnest predicates exist, early return
922                if non_unnest_predicates.is_empty() {
923                    filter.input = Arc::new(LogicalPlan::Unnest(unnest));
924                    return Ok(Transformed::no(LogicalPlan::Filter(filter)));
925                }
926
927                // Push down non-unnest filter predicate
928                // Unnest
929                //   Unnest Input (Projection)
930                // -> rewritten to
931                // Unnest
932                //   Filter
933                //     Unnest Input (Projection)
934
935                let unnest_input = std::mem::take(&mut unnest.input);
936
937                let filter_with_unnest_input = LogicalPlan::Filter(Filter::try_new(
938                    conjunction(non_unnest_predicates).unwrap(), // Safe to unwrap since non_unnest_predicates is not empty.
939                    unnest_input,
940                )?);
941
942                // Directly assign new filter plan as the new unnest's input.
943                // The new filter plan will go through another rewrite pass since the rule itself
944                // is applied recursively to all the child from top to down
945                let unnest_plan =
946                    insert_below(LogicalPlan::Unnest(unnest), filter_with_unnest_input)?;
947
948                match conjunction(unnest_predicates) {
949                    None => Ok(unnest_plan),
950                    Some(predicate) => Ok(Transformed::yes(LogicalPlan::Filter(
951                        Filter::try_new(predicate, Arc::new(unnest_plan.data))?,
952                    ))),
953                }
954            }
955            LogicalPlan::Union(ref union) => {
956                let mut inputs = Vec::with_capacity(union.inputs.len());
957                for input in &union.inputs {
958                    let mut replace_map = HashMap::new();
959                    for (i, (qualifier, field)) in input.schema().iter().enumerate() {
960                        let (union_qualifier, union_field) =
961                            union.schema.qualified_field(i);
962                        replace_map.insert(
963                            qualified_name(union_qualifier, union_field.name()),
964                            Expr::Column(Column::new(qualifier.cloned(), field.name())),
965                        );
966                    }
967
968                    let push_predicate =
969                        replace_cols_by_name(filter.predicate.clone(), &replace_map)?;
970                    inputs.push(Arc::new(LogicalPlan::Filter(Filter::try_new(
971                        push_predicate,
972                        Arc::clone(input),
973                    )?)))
974                }
975                Ok(Transformed::yes(LogicalPlan::Union(Union {
976                    inputs,
977                    schema: Arc::clone(&plan_schema),
978                })))
979            }
980            LogicalPlan::Aggregate(agg) => {
981                // We can push down Predicate which in groupby_expr.
982                let group_expr_columns = agg
983                    .group_expr
984                    .iter()
985                    .map(|e| {
986                        let (relation, name) = e.qualified_name();
987                        Column::new(relation, name)
988                    })
989                    .collect::<HashSet<_>>();
990
991                let predicates = split_conjunction_owned(filter.predicate);
992
993                let mut keep_predicates = vec![];
994                let mut push_predicates = vec![];
995                for expr in predicates {
996                    let cols = expr.column_refs();
997                    if cols.iter().all(|c| group_expr_columns.contains(c)) {
998                        push_predicates.push(expr);
999                    } else {
1000                        keep_predicates.push(expr);
1001                    }
1002                }
1003
1004                // As for plan Filter: Column(a+b) > 0 -- Agg: groupby:[Column(a)+Column(b)]
1005                // After push, we need to replace `a+b` with Column(a)+Column(b)
1006                // So we need create a replace_map, add {`a+b` --> Expr(Column(a)+Column(b))}
1007                let mut replace_map = HashMap::new();
1008                for expr in &agg.group_expr {
1009                    replace_map.insert(expr.schema_name().to_string(), expr.clone());
1010                }
1011                let replaced_push_predicates = push_predicates
1012                    .into_iter()
1013                    .map(|expr| replace_cols_by_name(expr, &replace_map))
1014                    .collect::<Result<Vec<_>>>()?;
1015
1016                let agg_input = Arc::clone(&agg.input);
1017                Transformed::yes(LogicalPlan::Aggregate(agg))
1018                    .transform_data(|new_plan| {
1019                        // If we have a filter to push, we push it down to the input of the aggregate
1020                        if let Some(predicate) = conjunction(replaced_push_predicates) {
1021                            let new_filter = make_filter(predicate, agg_input)?;
1022                            insert_below(new_plan, new_filter)
1023                        } else {
1024                            Ok(Transformed::no(new_plan))
1025                        }
1026                    })?
1027                    .map_data(|child_plan| {
1028                        // if there are any remaining predicates we can't push, add them
1029                        // back as a filter
1030                        if let Some(predicate) = conjunction(keep_predicates) {
1031                            make_filter(predicate, Arc::new(child_plan))
1032                        } else {
1033                            Ok(child_plan)
1034                        }
1035                    })
1036            }
1037            // Tries to push filters based on the partition key(s) of the window function(s) used.
1038            // Example:
1039            //   Before:
1040            //     Filter: (a > 1) and (b > 1) and (c > 1)
1041            //      Window: func() PARTITION BY [a] ...
1042            //   ---
1043            //   After:
1044            //     Filter: (b > 1) and (c > 1)
1045            //      Window: func() PARTITION BY [a] ...
1046            //        Filter: (a > 1)
1047            LogicalPlan::Window(window) => {
1048                // Retrieve the set of potential partition keys where we can push filters by.
1049                // Unlike aggregations, where there is only one statement per SELECT, there can be
1050                // multiple window functions, each with potentially different partition keys.
1051                // Therefore, we need to ensure that any potential partition key returned is used in
1052                // ALL window functions. Otherwise, filters cannot be pushed by through that column.
1053                let extract_partition_keys = |func: &WindowFunction| {
1054                    func.params
1055                        .partition_by
1056                        .iter()
1057                        .map(|c| {
1058                            let (relation, name) = c.qualified_name();
1059                            Column::new(relation, name)
1060                        })
1061                        .collect::<HashSet<_>>()
1062                };
1063                let potential_partition_keys = window
1064                    .window_expr
1065                    .iter()
1066                    .map(|e| {
1067                        match e {
1068                            Expr::WindowFunction(window_func) => {
1069                                extract_partition_keys(window_func)
1070                            }
1071                            Expr::Alias(alias) => {
1072                                if let Expr::WindowFunction(window_func) =
1073                                    alias.expr.as_ref()
1074                                {
1075                                    extract_partition_keys(window_func)
1076                                } else {
1077                                    // window functions expressions are only Expr::WindowFunction
1078                                    unreachable!()
1079                                }
1080                            }
1081                            _ => {
1082                                // window functions expressions are only Expr::WindowFunction
1083                                unreachable!()
1084                            }
1085                        }
1086                    })
1087                    // performs the set intersection of the partition keys of all window functions,
1088                    // returning only the common ones
1089                    .reduce(|a, b| &a & &b)
1090                    .unwrap_or_default();
1091
1092                let predicates = split_conjunction_owned(filter.predicate);
1093                let mut keep_predicates = vec![];
1094                let mut push_predicates = vec![];
1095                for expr in predicates {
1096                    let cols = expr.column_refs();
1097                    if cols.iter().all(|c| potential_partition_keys.contains(c)) {
1098                        push_predicates.push(expr);
1099                    } else {
1100                        keep_predicates.push(expr);
1101                    }
1102                }
1103
1104                // Unlike with aggregations, there are no cases where we have to replace, e.g.,
1105                // `a+b` with Column(a)+Column(b). This is because partition expressions are not
1106                // available as standalone columns to the user. For example, while an aggregation on
1107                // `a+b` becomes Column(a + b), in a window partition it becomes
1108                // `func() PARTITION BY [a + b] ...`. Thus, filters on expressions always remain in
1109                // place, so we can use `push_predicates` directly. This is consistent with other
1110                // optimizers, such as the one used by Postgres.
1111
1112                let window_input = Arc::clone(&window.input);
1113                Transformed::yes(LogicalPlan::Window(window))
1114                    .transform_data(|new_plan| {
1115                        // If we have a filter to push, we push it down to the input of the window
1116                        if let Some(predicate) = conjunction(push_predicates) {
1117                            let new_filter = make_filter(predicate, window_input)?;
1118                            insert_below(new_plan, new_filter)
1119                        } else {
1120                            Ok(Transformed::no(new_plan))
1121                        }
1122                    })?
1123                    .map_data(|child_plan| {
1124                        // if there are any remaining predicates we can't push, add them
1125                        // back as a filter
1126                        if let Some(predicate) = conjunction(keep_predicates) {
1127                            make_filter(predicate, Arc::new(child_plan))
1128                        } else {
1129                            Ok(child_plan)
1130                        }
1131                    })
1132            }
1133            LogicalPlan::Join(join) => push_down_join(join, Some(&filter.predicate)),
1134            LogicalPlan::TableScan(scan) => {
1135                let filter_predicates = split_conjunction(&filter.predicate);
1136
1137                let (volatile_filters, non_volatile_filters): (Vec<&Expr>, Vec<&Expr>) =
1138                    filter_predicates
1139                        .into_iter()
1140                        .partition(|pred| pred.is_volatile());
1141
1142                // Check which non-volatile filters are supported by source
1143                let supported_filters = scan
1144                    .source
1145                    .supports_filters_pushdown(non_volatile_filters.as_slice())?;
1146                assert_eq_or_internal_err!(
1147                    non_volatile_filters.len(),
1148                    supported_filters.len(),
1149                    "Vec returned length: {} from supports_filters_pushdown is not the same size as the filters passed, which length is: {}",
1150                    supported_filters.len(),
1151                    non_volatile_filters.len()
1152                );
1153
1154                // Compose scan filters from non-volatile filters of `Exact` or `Inexact` pushdown type
1155                let zip = non_volatile_filters.into_iter().zip(supported_filters);
1156
1157                let new_scan_filters = zip
1158                    .clone()
1159                    .filter(|(_, res)| res != &TableProviderFilterPushDown::Unsupported)
1160                    .map(|(pred, _)| pred);
1161
1162                // Add new scan filters
1163                let new_scan_filters: Vec<Expr> = scan
1164                    .filters
1165                    .iter()
1166                    .chain(new_scan_filters)
1167                    .unique()
1168                    .cloned()
1169                    .collect();
1170
1171                // Compose predicates to be of `Unsupported` or `Inexact` pushdown type, and also include volatile filters
1172                let new_predicate: Vec<Expr> = zip
1173                    .filter(|(_, res)| res != &TableProviderFilterPushDown::Exact)
1174                    .map(|(pred, _)| pred)
1175                    .chain(volatile_filters)
1176                    .cloned()
1177                    .collect();
1178
1179                let new_scan = LogicalPlan::TableScan(TableScan {
1180                    filters: new_scan_filters,
1181                    ..scan
1182                });
1183
1184                Transformed::yes(new_scan).transform_data(|new_scan| {
1185                    if let Some(predicate) = conjunction(new_predicate) {
1186                        make_filter(predicate, Arc::new(new_scan)).map(Transformed::yes)
1187                    } else {
1188                        Ok(Transformed::no(new_scan))
1189                    }
1190                })
1191            }
1192            LogicalPlan::Extension(extension_plan) => {
1193                // This check prevents the Filter from being removed when the extension node has no children,
1194                // so we return the original Filter unchanged.
1195                if extension_plan.node.inputs().is_empty() {
1196                    filter.input = Arc::new(LogicalPlan::Extension(extension_plan));
1197                    return Ok(Transformed::no(LogicalPlan::Filter(filter)));
1198                }
1199                let prevent_cols =
1200                    extension_plan.node.prevent_predicate_push_down_columns();
1201
1202                // determine if we can push any predicates down past the extension node
1203
1204                // each element is true for push, false to keep
1205                let predicate_push_or_keep = split_conjunction(&filter.predicate)
1206                    .iter()
1207                    .map(|expr| {
1208                        let cols = expr.column_refs();
1209                        if cols.iter().any(|c| prevent_cols.contains(&c.name)) {
1210                            Ok(false) // No push (keep)
1211                        } else {
1212                            Ok(true) // push
1213                        }
1214                    })
1215                    .collect::<Result<Vec<_>>>()?;
1216
1217                // all predicates are kept, no changes needed
1218                if predicate_push_or_keep.iter().all(|&x| !x) {
1219                    filter.input = Arc::new(LogicalPlan::Extension(extension_plan));
1220                    return Ok(Transformed::no(LogicalPlan::Filter(filter)));
1221                }
1222
1223                // going to push some predicates down, so split the predicates
1224                let mut keep_predicates = vec![];
1225                let mut push_predicates = vec![];
1226                for (push, expr) in predicate_push_or_keep
1227                    .into_iter()
1228                    .zip(split_conjunction_owned(filter.predicate).into_iter())
1229                {
1230                    if !push {
1231                        keep_predicates.push(expr);
1232                    } else {
1233                        push_predicates.push(expr);
1234                    }
1235                }
1236
1237                let new_children = match conjunction(push_predicates) {
1238                    Some(predicate) => extension_plan
1239                        .node
1240                        .inputs()
1241                        .into_iter()
1242                        .map(|child| {
1243                            Ok(LogicalPlan::Filter(Filter::try_new(
1244                                predicate.clone(),
1245                                Arc::new(child.clone()),
1246                            )?))
1247                        })
1248                        .collect::<Result<Vec<_>>>()?,
1249                    None => extension_plan.node.inputs().into_iter().cloned().collect(),
1250                };
1251                // extension with new inputs.
1252                let child_plan = LogicalPlan::Extension(extension_plan);
1253                let new_extension =
1254                    child_plan.with_new_exprs(child_plan.expressions(), new_children)?;
1255
1256                let new_plan = match conjunction(keep_predicates) {
1257                    Some(predicate) => LogicalPlan::Filter(Filter::try_new(
1258                        predicate,
1259                        Arc::new(new_extension),
1260                    )?),
1261                    None => new_extension,
1262                };
1263                Ok(Transformed::yes(new_plan))
1264            }
1265            child => {
1266                filter.input = Arc::new(child);
1267                Ok(Transformed::no(LogicalPlan::Filter(filter)))
1268            }
1269        }
1270    }
1271}
1272
1273/// Attempts to push `predicate` into a `FilterExec` below `projection
1274///
1275/// # Returns
1276/// (plan, remaining_predicate)
1277///
1278/// `plan` is a LogicalPlan for `projection` with possibly a new FilterExec below it.
1279/// `remaining_predicate` is any part of the predicate that could not be pushed down
1280///
1281/// # Args
1282/// - predicates: Split predicates like `[foo=5, bar=6]`
1283/// - projection: The target projection plan to push down the predicates
1284///
1285/// # Example
1286///
1287/// Pushing a predicate like `foo=5 AND bar=6` with an input plan like this:
1288///
1289/// ```text
1290/// Projection(foo, c+d as bar)
1291/// ```
1292///
1293/// Might result in returning `remaining_predicate` of `bar=6` and a plan like
1294///
1295/// ```text
1296/// Projection(foo, c+d as bar)
1297///  Filter(foo=5)
1298///   ...
1299/// ```
1300fn rewrite_projection(
1301    predicates: Vec<Expr>,
1302    mut projection: Projection,
1303) -> Result<(Transformed<LogicalPlan>, Option<Expr>)> {
1304    // A projection is filter-commutable if it do not contain volatile predicates or contain volatile
1305    // predicates that are not used in the filter. However, we should re-writes all predicate expressions.
1306    // collect projection.
1307    let (volatile_map, non_volatile_map): (HashMap<_, _>, HashMap<_, _>) = projection
1308        .schema
1309        .iter()
1310        .zip(projection.expr.iter())
1311        .map(|((qualifier, field), expr)| {
1312            // strip alias, as they should not be part of filters
1313            let expr = expr.clone().unalias();
1314
1315            (qualified_name(qualifier, field.name()), expr)
1316        })
1317        .partition(|(_, value)| value.is_volatile());
1318
1319    let mut push_predicates = vec![];
1320    let mut keep_predicates = vec![];
1321    for expr in predicates {
1322        if contain(&expr, &volatile_map) {
1323            keep_predicates.push(expr);
1324        } else {
1325            push_predicates.push(expr);
1326        }
1327    }
1328
1329    match conjunction(push_predicates) {
1330        Some(expr) => {
1331            // re-write all filters based on this projection
1332            // E.g. in `Filter: b\n  Projection: a > 1 as b`, we can swap them, but the filter must be "a > 1"
1333            let new_filter = LogicalPlan::Filter(Filter::try_new(
1334                replace_cols_by_name(expr, &non_volatile_map)?,
1335                std::mem::take(&mut projection.input),
1336            )?);
1337
1338            projection.input = Arc::new(new_filter);
1339
1340            Ok((
1341                Transformed::yes(LogicalPlan::Projection(projection)),
1342                conjunction(keep_predicates),
1343            ))
1344        }
1345        None => Ok((Transformed::no(LogicalPlan::Projection(projection)), None)),
1346    }
1347}
1348
1349/// Creates a new LogicalPlan::Filter node.
1350pub fn make_filter(predicate: Expr, input: Arc<LogicalPlan>) -> Result<LogicalPlan> {
1351    Filter::try_new(predicate, input).map(LogicalPlan::Filter)
1352}
1353
1354/// Replace the existing child of the single input node with `new_child`.
1355///
1356/// Starting:
1357/// ```text
1358/// plan
1359///   child
1360/// ```
1361///
1362/// Ending:
1363/// ```text
1364/// plan
1365///   new_child
1366/// ```
1367fn insert_below(
1368    plan: LogicalPlan,
1369    new_child: LogicalPlan,
1370) -> Result<Transformed<LogicalPlan>> {
1371    let mut new_child = Some(new_child);
1372    let transformed_plan = plan.map_children(|_child| {
1373        if let Some(new_child) = new_child.take() {
1374            Ok(Transformed::yes(new_child))
1375        } else {
1376            // already took the new child
1377            internal_err!("node had more than one input")
1378        }
1379    })?;
1380
1381    // make sure we did the actual replacement
1382    assert_or_internal_err!(new_child.is_none(), "node had no inputs");
1383
1384    Ok(transformed_plan)
1385}
1386
1387impl PushDownFilter {
1388    #[expect(missing_docs)]
1389    pub fn new() -> Self {
1390        Self {}
1391    }
1392}
1393
1394/// replaces columns by its name on the projection.
1395pub fn replace_cols_by_name(
1396    e: Expr,
1397    replace_map: &HashMap<String, Expr>,
1398) -> Result<Expr> {
1399    e.transform_up(|expr| {
1400        Ok(if let Expr::Column(c) = &expr {
1401            match replace_map.get(&c.flat_name()) {
1402                Some(new_c) => Transformed::yes(new_c.clone()),
1403                None => Transformed::no(expr),
1404            }
1405        } else {
1406            Transformed::no(expr)
1407        })
1408    })
1409    .data()
1410}
1411
1412/// check whether the expression uses the columns in `check_map`.
1413fn contain(e: &Expr, check_map: &HashMap<String, Expr>) -> bool {
1414    let mut is_contain = false;
1415    e.apply(|expr| {
1416        Ok(if let Expr::Column(c) = &expr {
1417            match check_map.get(&c.flat_name()) {
1418                Some(_) => {
1419                    is_contain = true;
1420                    TreeNodeRecursion::Stop
1421                }
1422                None => TreeNodeRecursion::Continue,
1423            }
1424        } else {
1425            TreeNodeRecursion::Continue
1426        })
1427    })
1428    .unwrap();
1429    is_contain
1430}
1431
1432#[cfg(test)]
1433mod tests {
1434    use std::any::Any;
1435    use std::cmp::Ordering;
1436    use std::fmt::{Debug, Formatter};
1437
1438    use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
1439    use async_trait::async_trait;
1440
1441    use datafusion_common::{DFSchemaRef, DataFusionError, ScalarValue};
1442    use datafusion_expr::expr::{ScalarFunction, WindowFunction};
1443    use datafusion_expr::logical_plan::table_scan;
1444    use datafusion_expr::{
1445        ColumnarValue, ExprFunctionExt, Extension, LogicalPlanBuilder,
1446        ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, TableSource, TableType,
1447        UserDefinedLogicalNodeCore, Volatility, WindowFunctionDefinition, col, in_list,
1448        in_subquery, lit,
1449    };
1450
1451    use crate::OptimizerContext;
1452    use crate::assert_optimized_plan_eq_snapshot;
1453    use crate::optimizer::Optimizer;
1454    use crate::simplify_expressions::SimplifyExpressions;
1455    use crate::test::*;
1456    use datafusion_expr::test::function_stub::sum;
1457    use insta::assert_snapshot;
1458
1459    use super::*;
1460
1461    fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {}
1462
1463    macro_rules! assert_optimized_plan_equal {
1464        (
1465            $plan:expr,
1466            @ $expected:literal $(,)?
1467        ) => {{
1468            let optimizer_ctx = OptimizerContext::new().with_max_passes(1);
1469            let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> = vec![Arc::new(PushDownFilter::new())];
1470            assert_optimized_plan_eq_snapshot!(
1471                optimizer_ctx,
1472                rules,
1473                $plan,
1474                @ $expected,
1475            )
1476        }};
1477    }
1478
1479    macro_rules! assert_optimized_plan_eq_with_rewrite_predicate {
1480        (
1481            $plan:expr,
1482            @ $expected:literal $(,)?
1483        ) => {{
1484            let optimizer = Optimizer::with_rules(vec![
1485                Arc::new(SimplifyExpressions::new()),
1486                Arc::new(PushDownFilter::new()),
1487            ]);
1488            let optimized_plan = optimizer.optimize($plan, &OptimizerContext::new(), observe)?;
1489            assert_snapshot!(optimized_plan, @ $expected);
1490            Ok::<(), DataFusionError>(())
1491        }};
1492    }
1493
1494    #[test]
1495    fn filter_before_projection() -> Result<()> {
1496        let table_scan = test_table_scan()?;
1497        let plan = LogicalPlanBuilder::from(table_scan)
1498            .project(vec![col("a"), col("b")])?
1499            .filter(col("a").eq(lit(1i64)))?
1500            .build()?;
1501        // filter is before projection
1502        assert_optimized_plan_equal!(
1503            plan,
1504            @r"
1505        Projection: test.a, test.b
1506          TableScan: test, full_filters=[test.a = Int64(1)]
1507        "
1508        )
1509    }
1510
1511    #[test]
1512    fn filter_after_limit() -> Result<()> {
1513        let table_scan = test_table_scan()?;
1514        let plan = LogicalPlanBuilder::from(table_scan)
1515            .project(vec![col("a"), col("b")])?
1516            .limit(0, Some(10))?
1517            .filter(col("a").eq(lit(1i64)))?
1518            .build()?;
1519        // filter is before single projection
1520        assert_optimized_plan_equal!(
1521            plan,
1522            @r"
1523        Filter: test.a = Int64(1)
1524          Limit: skip=0, fetch=10
1525            Projection: test.a, test.b
1526              TableScan: test
1527        "
1528        )
1529    }
1530
1531    #[test]
1532    fn filter_no_columns() -> Result<()> {
1533        let table_scan = test_table_scan()?;
1534        let plan = LogicalPlanBuilder::from(table_scan)
1535            .filter(lit(0i64).eq(lit(1i64)))?
1536            .build()?;
1537        assert_optimized_plan_equal!(
1538            plan,
1539            @"TableScan: test, full_filters=[Int64(0) = Int64(1)]"
1540        )
1541    }
1542
1543    #[test]
1544    fn filter_jump_2_plans() -> Result<()> {
1545        let table_scan = test_table_scan()?;
1546        let plan = LogicalPlanBuilder::from(table_scan)
1547            .project(vec![col("a"), col("b"), col("c")])?
1548            .project(vec![col("c"), col("b")])?
1549            .filter(col("a").eq(lit(1i64)))?
1550            .build()?;
1551        // filter is before double projection
1552        assert_optimized_plan_equal!(
1553            plan,
1554            @r"
1555        Projection: test.c, test.b
1556          Projection: test.a, test.b, test.c
1557            TableScan: test, full_filters=[test.a = Int64(1)]
1558        "
1559        )
1560    }
1561
1562    #[test]
1563    fn filter_move_agg() -> Result<()> {
1564        let table_scan = test_table_scan()?;
1565        let plan = LogicalPlanBuilder::from(table_scan)
1566            .aggregate(vec![col("a")], vec![sum(col("b")).alias("total_salary")])?
1567            .filter(col("a").gt(lit(10i64)))?
1568            .build()?;
1569        // filter of key aggregation is commutative
1570        assert_optimized_plan_equal!(
1571            plan,
1572            @r"
1573        Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b) AS total_salary]]
1574          TableScan: test, full_filters=[test.a > Int64(10)]
1575        "
1576        )
1577    }
1578
1579    /// verifies that filters with unusual column names are pushed down through aggregate operators
1580    #[test]
1581    fn filter_move_agg_special() -> Result<()> {
1582        let schema = Schema::new(vec![
1583            Field::new("$a", DataType::UInt32, false),
1584            Field::new("$b", DataType::UInt32, false),
1585            Field::new("$c", DataType::UInt32, false),
1586        ]);
1587        let table_scan = table_scan(Some("test"), &schema, None)?.build()?;
1588
1589        let plan = LogicalPlanBuilder::from(table_scan)
1590            .aggregate(vec![col("$a")], vec![sum(col("$b")).alias("total_salary")])?
1591            .filter(col("$a").gt(lit(10i64)))?
1592            .build()?;
1593        // filter of key aggregation is commutative
1594        assert_optimized_plan_equal!(
1595            plan,
1596            @r"
1597        Aggregate: groupBy=[[test.$a]], aggr=[[sum(test.$b) AS total_salary]]
1598          TableScan: test, full_filters=[test.$a > Int64(10)]
1599        "
1600        )
1601    }
1602
1603    #[test]
1604    fn filter_complex_group_by() -> Result<()> {
1605        let table_scan = test_table_scan()?;
1606        let plan = LogicalPlanBuilder::from(table_scan)
1607            .aggregate(vec![add(col("b"), col("a"))], vec![sum(col("a")), col("b")])?
1608            .filter(col("b").gt(lit(10i64)))?
1609            .build()?;
1610        assert_optimized_plan_equal!(
1611            plan,
1612            @r"
1613        Filter: test.b > Int64(10)
1614          Aggregate: groupBy=[[test.b + test.a]], aggr=[[sum(test.a), test.b]]
1615            TableScan: test
1616        "
1617        )
1618    }
1619
1620    #[test]
1621    fn push_agg_need_replace_expr() -> Result<()> {
1622        let plan = LogicalPlanBuilder::from(test_table_scan()?)
1623            .aggregate(vec![add(col("b"), col("a"))], vec![sum(col("a")), col("b")])?
1624            .filter(col("test.b + test.a").gt(lit(10i64)))?
1625            .build()?;
1626        assert_optimized_plan_equal!(
1627            plan,
1628            @r"
1629        Aggregate: groupBy=[[test.b + test.a]], aggr=[[sum(test.a), test.b]]
1630          TableScan: test, full_filters=[test.b + test.a > Int64(10)]
1631        "
1632        )
1633    }
1634
1635    #[test]
1636    fn filter_keep_agg() -> Result<()> {
1637        let table_scan = test_table_scan()?;
1638        let plan = LogicalPlanBuilder::from(table_scan)
1639            .aggregate(vec![col("a")], vec![sum(col("b")).alias("b")])?
1640            .filter(col("b").gt(lit(10i64)))?
1641            .build()?;
1642        // filter of aggregate is after aggregation since they are non-commutative
1643        assert_optimized_plan_equal!(
1644            plan,
1645            @r"
1646        Filter: b > Int64(10)
1647          Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b) AS b]]
1648            TableScan: test
1649        "
1650        )
1651    }
1652
1653    /// verifies that when partitioning by 'a' and 'b', and filtering by 'b', 'b' is pushed
1654    #[test]
1655    fn filter_move_window() -> Result<()> {
1656        let table_scan = test_table_scan()?;
1657
1658        let window = Expr::from(WindowFunction::new(
1659            WindowFunctionDefinition::WindowUDF(
1660                datafusion_functions_window::rank::rank_udwf(),
1661            ),
1662            vec![],
1663        ))
1664        .partition_by(vec![col("a"), col("b")])
1665        .order_by(vec![col("c").sort(true, true)])
1666        .build()
1667        .unwrap();
1668
1669        let plan = LogicalPlanBuilder::from(table_scan)
1670            .window(vec![window])?
1671            .filter(col("b").gt(lit(10i64)))?
1672            .build()?;
1673
1674        assert_optimized_plan_equal!(
1675            plan,
1676            @r"
1677        WindowAggr: windowExpr=[[rank() PARTITION BY [test.a, test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1678          TableScan: test, full_filters=[test.b > Int64(10)]
1679        "
1680        )
1681    }
1682
1683    /// verifies that filters with unusual identifier names are pushed down through window functions
1684    #[test]
1685    fn filter_window_special_identifier() -> Result<()> {
1686        let schema = Schema::new(vec![
1687            Field::new("$a", DataType::UInt32, false),
1688            Field::new("$b", DataType::UInt32, false),
1689            Field::new("$c", DataType::UInt32, false),
1690        ]);
1691        let table_scan = table_scan(Some("test"), &schema, None)?.build()?;
1692
1693        let window = Expr::from(WindowFunction::new(
1694            WindowFunctionDefinition::WindowUDF(
1695                datafusion_functions_window::rank::rank_udwf(),
1696            ),
1697            vec![],
1698        ))
1699        .partition_by(vec![col("$a"), col("$b")])
1700        .order_by(vec![col("$c").sort(true, true)])
1701        .build()
1702        .unwrap();
1703
1704        let plan = LogicalPlanBuilder::from(table_scan)
1705            .window(vec![window])?
1706            .filter(col("$b").gt(lit(10i64)))?
1707            .build()?;
1708
1709        assert_optimized_plan_equal!(
1710            plan,
1711            @r"
1712        WindowAggr: windowExpr=[[rank() PARTITION BY [test.$a, test.$b] ORDER BY [test.$c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1713          TableScan: test, full_filters=[test.$b > Int64(10)]
1714        "
1715        )
1716    }
1717
1718    /// verifies that when partitioning by 'a' and 'b', and filtering by 'a' and 'b', both 'a' and
1719    /// 'b' are pushed
1720    #[test]
1721    fn filter_move_complex_window() -> Result<()> {
1722        let table_scan = test_table_scan()?;
1723
1724        let window = Expr::from(WindowFunction::new(
1725            WindowFunctionDefinition::WindowUDF(
1726                datafusion_functions_window::rank::rank_udwf(),
1727            ),
1728            vec![],
1729        ))
1730        .partition_by(vec![col("a"), col("b")])
1731        .order_by(vec![col("c").sort(true, true)])
1732        .build()
1733        .unwrap();
1734
1735        let plan = LogicalPlanBuilder::from(table_scan)
1736            .window(vec![window])?
1737            .filter(and(col("a").gt(lit(10i64)), col("b").eq(lit(1i64))))?
1738            .build()?;
1739
1740        assert_optimized_plan_equal!(
1741            plan,
1742            @r"
1743        WindowAggr: windowExpr=[[rank() PARTITION BY [test.a, test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1744          TableScan: test, full_filters=[test.a > Int64(10), test.b = Int64(1)]
1745        "
1746        )
1747    }
1748
1749    /// verifies that when partitioning by 'a' and filtering by 'a' and 'b', only 'a' is pushed
1750    #[test]
1751    fn filter_move_partial_window() -> Result<()> {
1752        let table_scan = test_table_scan()?;
1753
1754        let window = Expr::from(WindowFunction::new(
1755            WindowFunctionDefinition::WindowUDF(
1756                datafusion_functions_window::rank::rank_udwf(),
1757            ),
1758            vec![],
1759        ))
1760        .partition_by(vec![col("a")])
1761        .order_by(vec![col("c").sort(true, true)])
1762        .build()
1763        .unwrap();
1764
1765        let plan = LogicalPlanBuilder::from(table_scan)
1766            .window(vec![window])?
1767            .filter(and(col("a").gt(lit(10i64)), col("b").eq(lit(1i64))))?
1768            .build()?;
1769
1770        assert_optimized_plan_equal!(
1771            plan,
1772            @r"
1773        Filter: test.b = Int64(1)
1774          WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1775            TableScan: test, full_filters=[test.a > Int64(10)]
1776        "
1777        )
1778    }
1779
1780    /// verifies that filters on partition expressions are not pushed, as the single expression
1781    /// column is not available to the user, unlike with aggregations
1782    #[test]
1783    fn filter_expression_keep_window() -> Result<()> {
1784        let table_scan = test_table_scan()?;
1785
1786        let window = Expr::from(WindowFunction::new(
1787            WindowFunctionDefinition::WindowUDF(
1788                datafusion_functions_window::rank::rank_udwf(),
1789            ),
1790            vec![],
1791        ))
1792        .partition_by(vec![add(col("a"), col("b"))]) // PARTITION BY a + b
1793        .order_by(vec![col("c").sort(true, true)])
1794        .build()
1795        .unwrap();
1796
1797        let plan = LogicalPlanBuilder::from(table_scan)
1798            .window(vec![window])?
1799            // unlike with aggregations, single partition column "test.a + test.b" is not available
1800            // to the plan, so we use multiple columns when filtering
1801            .filter(add(col("a"), col("b")).gt(lit(10i64)))?
1802            .build()?;
1803
1804        assert_optimized_plan_equal!(
1805            plan,
1806            @r"
1807        Filter: test.a + test.b > Int64(10)
1808          WindowAggr: windowExpr=[[rank() PARTITION BY [test.a + test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1809            TableScan: test
1810        "
1811        )
1812    }
1813
1814    /// verifies that filters are not pushed on order by columns (that are not used in partitioning)
1815    #[test]
1816    fn filter_order_keep_window() -> Result<()> {
1817        let table_scan = test_table_scan()?;
1818
1819        let window = Expr::from(WindowFunction::new(
1820            WindowFunctionDefinition::WindowUDF(
1821                datafusion_functions_window::rank::rank_udwf(),
1822            ),
1823            vec![],
1824        ))
1825        .partition_by(vec![col("a")])
1826        .order_by(vec![col("c").sort(true, true)])
1827        .build()
1828        .unwrap();
1829
1830        let plan = LogicalPlanBuilder::from(table_scan)
1831            .window(vec![window])?
1832            .filter(col("c").gt(lit(10i64)))?
1833            .build()?;
1834
1835        assert_optimized_plan_equal!(
1836            plan,
1837            @r"
1838        Filter: test.c > Int64(10)
1839          WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1840            TableScan: test
1841        "
1842        )
1843    }
1844
1845    /// verifies that when we use multiple window functions with a common partition key, the filter
1846    /// on that key is pushed
1847    #[test]
1848    fn filter_multiple_windows_common_partitions() -> Result<()> {
1849        let table_scan = test_table_scan()?;
1850
1851        let window1 = Expr::from(WindowFunction::new(
1852            WindowFunctionDefinition::WindowUDF(
1853                datafusion_functions_window::rank::rank_udwf(),
1854            ),
1855            vec![],
1856        ))
1857        .partition_by(vec![col("a")])
1858        .order_by(vec![col("c").sort(true, true)])
1859        .build()
1860        .unwrap();
1861
1862        let window2 = Expr::from(WindowFunction::new(
1863            WindowFunctionDefinition::WindowUDF(
1864                datafusion_functions_window::rank::rank_udwf(),
1865            ),
1866            vec![],
1867        ))
1868        .partition_by(vec![col("b"), col("a")])
1869        .order_by(vec![col("c").sort(true, true)])
1870        .build()
1871        .unwrap();
1872
1873        let plan = LogicalPlanBuilder::from(table_scan)
1874            .window(vec![window1, window2])?
1875            .filter(col("a").gt(lit(10i64)))? // a appears in both window functions
1876            .build()?;
1877
1878        assert_optimized_plan_equal!(
1879            plan,
1880            @r"
1881        WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, rank() PARTITION BY [test.b, test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1882          TableScan: test, full_filters=[test.a > Int64(10)]
1883        "
1884        )
1885    }
1886
1887    /// verifies that when we use multiple window functions with different partitions keys, the
1888    /// filter cannot be pushed
1889    #[test]
1890    fn filter_multiple_windows_disjoint_partitions() -> Result<()> {
1891        let table_scan = test_table_scan()?;
1892
1893        let window1 = Expr::from(WindowFunction::new(
1894            WindowFunctionDefinition::WindowUDF(
1895                datafusion_functions_window::rank::rank_udwf(),
1896            ),
1897            vec![],
1898        ))
1899        .partition_by(vec![col("a")])
1900        .order_by(vec![col("c").sort(true, true)])
1901        .build()
1902        .unwrap();
1903
1904        let window2 = Expr::from(WindowFunction::new(
1905            WindowFunctionDefinition::WindowUDF(
1906                datafusion_functions_window::rank::rank_udwf(),
1907            ),
1908            vec![],
1909        ))
1910        .partition_by(vec![col("b"), col("a")])
1911        .order_by(vec![col("c").sort(true, true)])
1912        .build()
1913        .unwrap();
1914
1915        let plan = LogicalPlanBuilder::from(table_scan)
1916            .window(vec![window1, window2])?
1917            .filter(col("b").gt(lit(10i64)))? // b only appears in one window function
1918            .build()?;
1919
1920        assert_optimized_plan_equal!(
1921            plan,
1922            @r"
1923        Filter: test.b > Int64(10)
1924          WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, rank() PARTITION BY [test.b, test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1925            TableScan: test
1926        "
1927        )
1928    }
1929
1930    /// verifies that a filter is pushed to before a projection, the filter expression is correctly re-written
1931    #[test]
1932    fn alias() -> Result<()> {
1933        let table_scan = test_table_scan()?;
1934        let plan = LogicalPlanBuilder::from(table_scan)
1935            .project(vec![col("a").alias("b"), col("c")])?
1936            .filter(col("b").eq(lit(1i64)))?
1937            .build()?;
1938        // filter is before projection
1939        assert_optimized_plan_equal!(
1940            plan,
1941            @r"
1942        Projection: test.a AS b, test.c
1943          TableScan: test, full_filters=[test.a = Int64(1)]
1944        "
1945        )
1946    }
1947
1948    fn add(left: Expr, right: Expr) -> Expr {
1949        Expr::BinaryExpr(BinaryExpr::new(
1950            Box::new(left),
1951            Operator::Plus,
1952            Box::new(right),
1953        ))
1954    }
1955
1956    fn multiply(left: Expr, right: Expr) -> Expr {
1957        Expr::BinaryExpr(BinaryExpr::new(
1958            Box::new(left),
1959            Operator::Multiply,
1960            Box::new(right),
1961        ))
1962    }
1963
1964    /// verifies that a filter is pushed to before a projection with a complex expression, the filter expression is correctly re-written
1965    #[test]
1966    fn complex_expression() -> Result<()> {
1967        let table_scan = test_table_scan()?;
1968        let plan = LogicalPlanBuilder::from(table_scan)
1969            .project(vec![
1970                add(multiply(col("a"), lit(2)), col("c")).alias("b"),
1971                col("c"),
1972            ])?
1973            .filter(col("b").eq(lit(1i64)))?
1974            .build()?;
1975
1976        // not part of the test, just good to know:
1977        assert_snapshot!(plan,
1978        @r"
1979        Filter: b = Int64(1)
1980          Projection: test.a * Int32(2) + test.c AS b, test.c
1981            TableScan: test
1982        ",
1983        );
1984        // filter is before projection
1985        assert_optimized_plan_equal!(
1986            plan,
1987            @r"
1988        Projection: test.a * Int32(2) + test.c AS b, test.c
1989          TableScan: test, full_filters=[test.a * Int32(2) + test.c = Int64(1)]
1990        "
1991        )
1992    }
1993
1994    /// verifies that when a filter is pushed to after 2 projections, the filter expression is correctly re-written
1995    #[test]
1996    fn complex_plan() -> Result<()> {
1997        let table_scan = test_table_scan()?;
1998        let plan = LogicalPlanBuilder::from(table_scan)
1999            .project(vec![
2000                add(multiply(col("a"), lit(2)), col("c")).alias("b"),
2001                col("c"),
2002            ])?
2003            // second projection where we rename columns, just to make it difficult
2004            .project(vec![multiply(col("b"), lit(3)).alias("a"), col("c")])?
2005            .filter(col("a").eq(lit(1i64)))?
2006            .build()?;
2007
2008        // not part of the test, just good to know:
2009        assert_snapshot!(plan,
2010        @r"
2011        Filter: a = Int64(1)
2012          Projection: b * Int32(3) AS a, test.c
2013            Projection: test.a * Int32(2) + test.c AS b, test.c
2014              TableScan: test
2015        ",
2016        );
2017        // filter is before the projections
2018        assert_optimized_plan_equal!(
2019            plan,
2020            @r"
2021        Projection: b * Int32(3) AS a, test.c
2022          Projection: test.a * Int32(2) + test.c AS b, test.c
2023            TableScan: test, full_filters=[(test.a * Int32(2) + test.c) * Int32(3) = Int64(1)]
2024        "
2025        )
2026    }
2027
    /// Minimal user-defined logical node used by `user_defined_plan` to test filter
    /// push-down through `LogicalPlan::Extension` nodes.
    #[derive(Debug, PartialEq, Eq, Hash)]
    struct NoopPlan {
        /// Child plans of this node; the tests use both one and two children.
        input: Vec<LogicalPlan>,
        /// Schema reported by this node; the tests set it to the input scan's schema.
        schema: DFSchemaRef,
    }
2033
2034    // Manual implementation needed because of `schema` field. Comparison excludes this field.
2035    impl PartialOrd for NoopPlan {
2036        fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
2037            self.input
2038                .partial_cmp(&other.input)
2039                // TODO (https://github.com/apache/datafusion/issues/17477) avoid recomparing all fields
2040                .filter(|cmp| *cmp != Ordering::Equal || self == other)
2041        }
2042    }
2043
2044    impl UserDefinedLogicalNodeCore for NoopPlan {
2045        fn name(&self) -> &str {
2046            "NoopPlan"
2047        }
2048
2049        fn inputs(&self) -> Vec<&LogicalPlan> {
2050            self.input.iter().collect()
2051        }
2052
2053        fn schema(&self) -> &DFSchemaRef {
2054            &self.schema
2055        }
2056
2057        fn expressions(&self) -> Vec<Expr> {
2058            self.input
2059                .iter()
2060                .flat_map(|child| child.expressions())
2061                .collect()
2062        }
2063
2064        fn prevent_predicate_push_down_columns(&self) -> HashSet<String> {
2065            HashSet::from_iter(vec!["c".to_string()])
2066        }
2067
2068        fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
2069            write!(f, "NoopPlan")
2070        }
2071
2072        fn with_exprs_and_inputs(
2073            &self,
2074            _exprs: Vec<Expr>,
2075            inputs: Vec<LogicalPlan>,
2076        ) -> Result<Self> {
2077            Ok(Self {
2078                input: inputs,
2079                schema: Arc::clone(&self.schema),
2080            })
2081        }
2082
2083        fn supports_limit_pushdown(&self) -> bool {
2084            false // Disallow limit push-down by default
2085        }
2086    }
2087
2088    #[test]
2089    fn user_defined_plan() -> Result<()> {
2090        let table_scan = test_table_scan()?;
2091
2092        let custom_plan = LogicalPlan::Extension(Extension {
2093            node: Arc::new(NoopPlan {
2094                input: vec![table_scan.clone()],
2095                schema: Arc::clone(table_scan.schema()),
2096            }),
2097        });
2098        let plan = LogicalPlanBuilder::from(custom_plan)
2099            .filter(col("a").eq(lit(1i64)))?
2100            .build()?;
2101
2102        // Push filter below NoopPlan
2103        assert_optimized_plan_equal!(
2104            plan,
2105            @r"
2106        NoopPlan
2107          TableScan: test, full_filters=[test.a = Int64(1)]
2108        "
2109        )?;
2110
2111        let custom_plan = LogicalPlan::Extension(Extension {
2112            node: Arc::new(NoopPlan {
2113                input: vec![table_scan.clone()],
2114                schema: Arc::clone(table_scan.schema()),
2115            }),
2116        });
2117        let plan = LogicalPlanBuilder::from(custom_plan)
2118            .filter(col("a").eq(lit(1i64)).and(col("c").eq(lit(2i64))))?
2119            .build()?;
2120
2121        // Push only predicate on `a` below NoopPlan
2122        assert_optimized_plan_equal!(
2123            plan,
2124            @r"
2125        Filter: test.c = Int64(2)
2126          NoopPlan
2127            TableScan: test, full_filters=[test.a = Int64(1)]
2128        "
2129        )?;
2130
2131        let custom_plan = LogicalPlan::Extension(Extension {
2132            node: Arc::new(NoopPlan {
2133                input: vec![table_scan.clone(), table_scan.clone()],
2134                schema: Arc::clone(table_scan.schema()),
2135            }),
2136        });
2137        let plan = LogicalPlanBuilder::from(custom_plan)
2138            .filter(col("a").eq(lit(1i64)))?
2139            .build()?;
2140
2141        // Push filter below NoopPlan for each child branch
2142        assert_optimized_plan_equal!(
2143            plan,
2144            @r"
2145        NoopPlan
2146          TableScan: test, full_filters=[test.a = Int64(1)]
2147          TableScan: test, full_filters=[test.a = Int64(1)]
2148        "
2149        )?;
2150
2151        let custom_plan = LogicalPlan::Extension(Extension {
2152            node: Arc::new(NoopPlan {
2153                input: vec![table_scan.clone(), table_scan.clone()],
2154                schema: Arc::clone(table_scan.schema()),
2155            }),
2156        });
2157        let plan = LogicalPlanBuilder::from(custom_plan)
2158            .filter(col("a").eq(lit(1i64)).and(col("c").eq(lit(2i64))))?
2159            .build()?;
2160
2161        // Push only predicate on `a` below NoopPlan
2162        assert_optimized_plan_equal!(
2163            plan,
2164            @r"
2165        Filter: test.c = Int64(2)
2166          NoopPlan
2167            TableScan: test, full_filters=[test.a = Int64(1)]
2168            TableScan: test, full_filters=[test.a = Int64(1)]
2169        "
2170        )
2171    }
2172
2173    /// verifies that when two filters apply after an aggregation that only allows one to be pushed, one is pushed
2174    /// and the other not.
2175    #[test]
2176    fn multi_filter() -> Result<()> {
2177        // the aggregation allows one filter to pass (b), and the other one to not pass (sum(c))
2178        let table_scan = test_table_scan()?;
2179        let plan = LogicalPlanBuilder::from(table_scan)
2180            .project(vec![col("a").alias("b"), col("c")])?
2181            .aggregate(vec![col("b")], vec![sum(col("c"))])?
2182            .filter(col("b").gt(lit(10i64)))?
2183            .filter(col("sum(test.c)").gt(lit(10i64)))?
2184            .build()?;
2185
2186        // not part of the test, just good to know:
2187        assert_snapshot!(plan,
2188        @r"
2189        Filter: sum(test.c) > Int64(10)
2190          Filter: b > Int64(10)
2191            Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
2192              Projection: test.a AS b, test.c
2193                TableScan: test
2194        ",
2195        );
2196        // filter is before the projections
2197        assert_optimized_plan_equal!(
2198            plan,
2199            @r"
2200        Filter: sum(test.c) > Int64(10)
2201          Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
2202            Projection: test.a AS b, test.c
2203              TableScan: test, full_filters=[test.a > Int64(10)]
2204        "
2205        )
2206    }
2207
2208    /// verifies that when a filter with two predicates is applied after an aggregation that only allows one to be pushed, one is pushed
2209    /// and the other not.
2210    #[test]
2211    fn split_filter() -> Result<()> {
2212        // the aggregation allows one filter to pass (b), and the other one to not pass (sum(c))
2213        let table_scan = test_table_scan()?;
2214        let plan = LogicalPlanBuilder::from(table_scan)
2215            .project(vec![col("a").alias("b"), col("c")])?
2216            .aggregate(vec![col("b")], vec![sum(col("c"))])?
2217            .filter(and(
2218                col("sum(test.c)").gt(lit(10i64)),
2219                and(col("b").gt(lit(10i64)), col("sum(test.c)").lt(lit(20i64))),
2220            ))?
2221            .build()?;
2222
2223        // not part of the test, just good to know:
2224        assert_snapshot!(plan,
2225        @r"
2226        Filter: sum(test.c) > Int64(10) AND b > Int64(10) AND sum(test.c) < Int64(20)
2227          Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
2228            Projection: test.a AS b, test.c
2229              TableScan: test
2230        ",
2231        );
2232        // filter is before the projections
2233        assert_optimized_plan_equal!(
2234            plan,
2235            @r"
2236        Filter: sum(test.c) > Int64(10) AND sum(test.c) < Int64(20)
2237          Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
2238            Projection: test.a AS b, test.c
2239              TableScan: test, full_filters=[test.a > Int64(10)]
2240        "
2241        )
2242    }
2243
2244    /// verifies that when two limits are in place, we jump neither
2245    #[test]
2246    fn double_limit() -> Result<()> {
2247        let table_scan = test_table_scan()?;
2248        let plan = LogicalPlanBuilder::from(table_scan)
2249            .project(vec![col("a"), col("b")])?
2250            .limit(0, Some(20))?
2251            .limit(0, Some(10))?
2252            .project(vec![col("a"), col("b")])?
2253            .filter(col("a").eq(lit(1i64)))?
2254            .build()?;
2255        // filter does not just any of the limits
2256        assert_optimized_plan_equal!(
2257            plan,
2258            @r"
2259        Projection: test.a, test.b
2260          Filter: test.a = Int64(1)
2261            Limit: skip=0, fetch=10
2262              Limit: skip=0, fetch=20
2263                Projection: test.a, test.b
2264                  TableScan: test
2265        "
2266        )
2267    }
2268
2269    #[test]
2270    fn union_all() -> Result<()> {
2271        let table_scan = test_table_scan()?;
2272        let table_scan2 = test_table_scan_with_name("test2")?;
2273        let plan = LogicalPlanBuilder::from(table_scan)
2274            .union(LogicalPlanBuilder::from(table_scan2).build()?)?
2275            .filter(col("a").eq(lit(1i64)))?
2276            .build()?;
2277        // filter appears below Union
2278        assert_optimized_plan_equal!(
2279            plan,
2280            @r"
2281        Union
2282          TableScan: test, full_filters=[test.a = Int64(1)]
2283          TableScan: test2, full_filters=[test2.a = Int64(1)]
2284        "
2285        )
2286    }
2287
2288    #[test]
2289    fn union_all_on_projection() -> Result<()> {
2290        let table_scan = test_table_scan()?;
2291        let table = LogicalPlanBuilder::from(table_scan)
2292            .project(vec![col("a").alias("b")])?
2293            .alias("test2")?;
2294
2295        let plan = table
2296            .clone()
2297            .union(table.build()?)?
2298            .filter(col("b").eq(lit(1i64)))?
2299            .build()?;
2300
2301        // filter appears below Union
2302        assert_optimized_plan_equal!(
2303            plan,
2304            @r"
2305        Union
2306          SubqueryAlias: test2
2307            Projection: test.a AS b
2308              TableScan: test, full_filters=[test.a = Int64(1)]
2309          SubqueryAlias: test2
2310            Projection: test.a AS b
2311              TableScan: test, full_filters=[test.a = Int64(1)]
2312        "
2313        )
2314    }
2315
2316    #[test]
2317    fn test_union_different_schema() -> Result<()> {
2318        let left = LogicalPlanBuilder::from(test_table_scan()?)
2319            .project(vec![col("a"), col("b"), col("c")])?
2320            .build()?;
2321
2322        let schema = Schema::new(vec![
2323            Field::new("d", DataType::UInt32, false),
2324            Field::new("e", DataType::UInt32, false),
2325            Field::new("f", DataType::UInt32, false),
2326        ]);
2327        let right = table_scan(Some("test1"), &schema, None)?
2328            .project(vec![col("d"), col("e"), col("f")])?
2329            .build()?;
2330        let filter = and(col("test.a").eq(lit(1)), col("test1.d").gt(lit(2)));
2331        let plan = LogicalPlanBuilder::from(left)
2332            .cross_join(right)?
2333            .project(vec![col("test.a"), col("test1.d")])?
2334            .filter(filter)?
2335            .build()?;
2336
2337        assert_optimized_plan_equal!(
2338            plan,
2339            @r"
2340        Projection: test.a, test1.d
2341          Cross Join: 
2342            Projection: test.a, test.b, test.c
2343              TableScan: test, full_filters=[test.a = Int32(1)]
2344            Projection: test1.d, test1.e, test1.f
2345              TableScan: test1, full_filters=[test1.d > Int32(2)]
2346        "
2347        )
2348    }
2349
2350    #[test]
2351    fn test_project_same_name_different_qualifier() -> Result<()> {
2352        let table_scan = test_table_scan()?;
2353        let left = LogicalPlanBuilder::from(table_scan)
2354            .project(vec![col("a"), col("b"), col("c")])?
2355            .build()?;
2356        let right_table_scan = test_table_scan_with_name("test1")?;
2357        let right = LogicalPlanBuilder::from(right_table_scan)
2358            .project(vec![col("a"), col("b"), col("c")])?
2359            .build()?;
2360        let filter = and(col("test.a").eq(lit(1)), col("test1.a").gt(lit(2)));
2361        let plan = LogicalPlanBuilder::from(left)
2362            .cross_join(right)?
2363            .project(vec![col("test.a"), col("test1.a")])?
2364            .filter(filter)?
2365            .build()?;
2366
2367        assert_optimized_plan_equal!(
2368            plan,
2369            @r"
2370        Projection: test.a, test1.a
2371          Cross Join: 
2372            Projection: test.a, test.b, test.c
2373              TableScan: test, full_filters=[test.a = Int32(1)]
2374            Projection: test1.a, test1.b, test1.c
2375              TableScan: test1, full_filters=[test1.a > Int32(2)]
2376        "
2377        )
2378    }
2379
2380    /// verifies that filters with the same columns are correctly placed
2381    #[test]
2382    fn filter_2_breaks_limits() -> Result<()> {
2383        let table_scan = test_table_scan()?;
2384        let plan = LogicalPlanBuilder::from(table_scan)
2385            .project(vec![col("a")])?
2386            .filter(col("a").lt_eq(lit(1i64)))?
2387            .limit(0, Some(1))?
2388            .project(vec![col("a")])?
2389            .filter(col("a").gt_eq(lit(1i64)))?
2390            .build()?;
2391        // Should be able to move both filters below the projections
2392
2393        // not part of the test
2394        assert_snapshot!(plan,
2395        @r"
2396        Filter: test.a >= Int64(1)
2397          Projection: test.a
2398            Limit: skip=0, fetch=1
2399              Filter: test.a <= Int64(1)
2400                Projection: test.a
2401                  TableScan: test
2402        ",
2403        );
2404        assert_optimized_plan_equal!(
2405            plan,
2406            @r"
2407        Projection: test.a
2408          Filter: test.a >= Int64(1)
2409            Limit: skip=0, fetch=1
2410              Projection: test.a
2411                TableScan: test, full_filters=[test.a <= Int64(1)]
2412        "
2413        )
2414    }
2415
2416    /// verifies that filters to be placed on the same depth are ANDed
2417    #[test]
2418    fn two_filters_on_same_depth() -> Result<()> {
2419        let table_scan = test_table_scan()?;
2420        let plan = LogicalPlanBuilder::from(table_scan)
2421            .limit(0, Some(1))?
2422            .filter(col("a").lt_eq(lit(1i64)))?
2423            .filter(col("a").gt_eq(lit(1i64)))?
2424            .project(vec![col("a")])?
2425            .build()?;
2426
2427        // not part of the test
2428        assert_snapshot!(plan,
2429        @r"
2430        Projection: test.a
2431          Filter: test.a >= Int64(1)
2432            Filter: test.a <= Int64(1)
2433              Limit: skip=0, fetch=1
2434                TableScan: test
2435        ",
2436        );
2437        assert_optimized_plan_equal!(
2438            plan,
2439            @r"
2440        Projection: test.a
2441          Filter: test.a >= Int64(1) AND test.a <= Int64(1)
2442            Limit: skip=0, fetch=1
2443              TableScan: test
2444        "
2445        )
2446    }
2447
2448    /// verifies that filters on a plan with user nodes are not lost
2449    /// (ARROW-10547)
2450    #[test]
2451    fn filters_user_defined_node() -> Result<()> {
2452        let table_scan = test_table_scan()?;
2453        let plan = LogicalPlanBuilder::from(table_scan)
2454            .filter(col("a").lt_eq(lit(1i64)))?
2455            .build()?;
2456
2457        let plan = user_defined::new(plan);
2458
2459        // not part of the test
2460        assert_snapshot!(plan,
2461        @r"
2462        TestUserDefined
2463          Filter: test.a <= Int64(1)
2464            TableScan: test
2465        ",
2466        );
2467        assert_optimized_plan_equal!(
2468            plan,
2469            @r"
2470        TestUserDefined
2471          TableScan: test, full_filters=[test.a <= Int64(1)]
2472        "
2473        )
2474    }
2475
2476    /// post-on-join predicates on a column common to both sides is pushed to both sides
2477    #[test]
2478    fn filter_on_join_on_common_independent() -> Result<()> {
2479        let table_scan = test_table_scan()?;
2480        let left = LogicalPlanBuilder::from(table_scan).build()?;
2481        let right_table_scan = test_table_scan_with_name("test2")?;
2482        let right = LogicalPlanBuilder::from(right_table_scan)
2483            .project(vec![col("a")])?
2484            .build()?;
2485        let plan = LogicalPlanBuilder::from(left)
2486            .join(
2487                right,
2488                JoinType::Inner,
2489                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2490                None,
2491            )?
2492            .filter(col("test.a").lt_eq(lit(1i64)))?
2493            .build()?;
2494
2495        // not part of the test, just good to know:
2496        assert_snapshot!(plan,
2497        @r"
2498        Filter: test.a <= Int64(1)
2499          Inner Join: test.a = test2.a
2500            TableScan: test
2501            Projection: test2.a
2502              TableScan: test2
2503        ",
2504        );
2505        // filter sent to side before the join
2506        assert_optimized_plan_equal!(
2507            plan,
2508            @r"
2509        Inner Join: test.a = test2.a
2510          TableScan: test, full_filters=[test.a <= Int64(1)]
2511          Projection: test2.a
2512            TableScan: test2, full_filters=[test2.a <= Int64(1)]
2513        "
2514        )
2515    }
2516
2517    /// post-using-join predicates on a column common to both sides is pushed to both sides
2518    #[test]
2519    fn filter_using_join_on_common_independent() -> Result<()> {
2520        let table_scan = test_table_scan()?;
2521        let left = LogicalPlanBuilder::from(table_scan).build()?;
2522        let right_table_scan = test_table_scan_with_name("test2")?;
2523        let right = LogicalPlanBuilder::from(right_table_scan)
2524            .project(vec![col("a")])?
2525            .build()?;
2526        let plan = LogicalPlanBuilder::from(left)
2527            .join_using(
2528                right,
2529                JoinType::Inner,
2530                vec![Column::from_name("a".to_string())],
2531            )?
2532            .filter(col("a").lt_eq(lit(1i64)))?
2533            .build()?;
2534
2535        // not part of the test, just good to know:
2536        assert_snapshot!(plan,
2537        @r"
2538        Filter: test.a <= Int64(1)
2539          Inner Join: Using test.a = test2.a
2540            TableScan: test
2541            Projection: test2.a
2542              TableScan: test2
2543        ",
2544        );
2545        // filter sent to side before the join
2546        assert_optimized_plan_equal!(
2547            plan,
2548            @r"
2549        Inner Join: Using test.a = test2.a
2550          TableScan: test, full_filters=[test.a <= Int64(1)]
2551          Projection: test2.a
2552            TableScan: test2, full_filters=[test2.a <= Int64(1)]
2553        "
2554        )
2555    }
2556
2557    /// post-join predicates with columns from both sides are converted to join filters
2558    #[test]
2559    fn filter_join_on_common_dependent() -> Result<()> {
2560        let table_scan = test_table_scan()?;
2561        let left = LogicalPlanBuilder::from(table_scan)
2562            .project(vec![col("a"), col("c")])?
2563            .build()?;
2564        let right_table_scan = test_table_scan_with_name("test2")?;
2565        let right = LogicalPlanBuilder::from(right_table_scan)
2566            .project(vec![col("a"), col("b")])?
2567            .build()?;
2568        let plan = LogicalPlanBuilder::from(left)
2569            .join(
2570                right,
2571                JoinType::Inner,
2572                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2573                None,
2574            )?
2575            .filter(col("c").lt_eq(col("b")))?
2576            .build()?;
2577
2578        // not part of the test, just good to know:
2579        assert_snapshot!(plan,
2580        @r"
2581        Filter: test.c <= test2.b
2582          Inner Join: test.a = test2.a
2583            Projection: test.a, test.c
2584              TableScan: test
2585            Projection: test2.a, test2.b
2586              TableScan: test2
2587        ",
2588        );
2589        // Filter is converted to Join Filter
2590        assert_optimized_plan_equal!(
2591            plan,
2592            @r"
2593        Inner Join: test.a = test2.a Filter: test.c <= test2.b
2594          Projection: test.a, test.c
2595            TableScan: test
2596          Projection: test2.a, test2.b
2597            TableScan: test2
2598        "
2599        )
2600    }
2601
2602    /// post-join predicates with columns from one side of a join are pushed only to that side
2603    #[test]
2604    fn filter_join_on_one_side() -> Result<()> {
2605        let table_scan = test_table_scan()?;
2606        let left = LogicalPlanBuilder::from(table_scan)
2607            .project(vec![col("a"), col("b")])?
2608            .build()?;
2609        let table_scan_right = test_table_scan_with_name("test2")?;
2610        let right = LogicalPlanBuilder::from(table_scan_right)
2611            .project(vec![col("a"), col("c")])?
2612            .build()?;
2613
2614        let plan = LogicalPlanBuilder::from(left)
2615            .join(
2616                right,
2617                JoinType::Inner,
2618                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2619                None,
2620            )?
2621            .filter(col("b").lt_eq(lit(1i64)))?
2622            .build()?;
2623
2624        // not part of the test, just good to know:
2625        assert_snapshot!(plan,
2626        @r"
2627        Filter: test.b <= Int64(1)
2628          Inner Join: test.a = test2.a
2629            Projection: test.a, test.b
2630              TableScan: test
2631            Projection: test2.a, test2.c
2632              TableScan: test2
2633        ",
2634        );
2635        assert_optimized_plan_equal!(
2636            plan,
2637            @r"
2638        Inner Join: test.a = test2.a
2639          Projection: test.a, test.b
2640            TableScan: test, full_filters=[test.b <= Int64(1)]
2641          Projection: test2.a, test2.c
2642            TableScan: test2
2643        "
2644        )
2645    }
2646
2647    /// post-join predicates on the right side of a left join are not duplicated
2648    /// TODO: In this case we can sometimes convert the join to an INNER join
2649    #[test]
2650    fn filter_using_left_join() -> Result<()> {
2651        let table_scan = test_table_scan()?;
2652        let left = LogicalPlanBuilder::from(table_scan).build()?;
2653        let right_table_scan = test_table_scan_with_name("test2")?;
2654        let right = LogicalPlanBuilder::from(right_table_scan)
2655            .project(vec![col("a")])?
2656            .build()?;
2657        let plan = LogicalPlanBuilder::from(left)
2658            .join_using(
2659                right,
2660                JoinType::Left,
2661                vec![Column::from_name("a".to_string())],
2662            )?
2663            .filter(col("test2.a").lt_eq(lit(1i64)))?
2664            .build()?;
2665
2666        // not part of the test, just good to know:
2667        assert_snapshot!(plan,
2668        @r"
2669        Filter: test2.a <= Int64(1)
2670          Left Join: Using test.a = test2.a
2671            TableScan: test
2672            Projection: test2.a
2673              TableScan: test2
2674        ",
2675        );
2676        // filter not duplicated nor pushed down - i.e. noop
2677        assert_optimized_plan_equal!(
2678            plan,
2679            @r"
2680        Filter: test2.a <= Int64(1)
2681          Left Join: Using test.a = test2.a
2682            TableScan: test, full_filters=[test.a <= Int64(1)]
2683            Projection: test2.a
2684              TableScan: test2
2685        "
2686        )
2687    }
2688
2689    /// post-join predicates on the left side of a right join are not duplicated
2690    #[test]
2691    fn filter_using_right_join() -> Result<()> {
2692        let table_scan = test_table_scan()?;
2693        let left = LogicalPlanBuilder::from(table_scan).build()?;
2694        let right_table_scan = test_table_scan_with_name("test2")?;
2695        let right = LogicalPlanBuilder::from(right_table_scan)
2696            .project(vec![col("a")])?
2697            .build()?;
2698        let plan = LogicalPlanBuilder::from(left)
2699            .join_using(
2700                right,
2701                JoinType::Right,
2702                vec![Column::from_name("a".to_string())],
2703            )?
2704            .filter(col("test.a").lt_eq(lit(1i64)))?
2705            .build()?;
2706
2707        // not part of the test, just good to know:
2708        assert_snapshot!(plan,
2709        @r"
2710        Filter: test.a <= Int64(1)
2711          Right Join: Using test.a = test2.a
2712            TableScan: test
2713            Projection: test2.a
2714              TableScan: test2
2715        ",
2716        );
2717        // filter not duplicated nor pushed down - i.e. noop
2718        assert_optimized_plan_equal!(
2719            plan,
2720            @r"
2721        Filter: test.a <= Int64(1)
2722          Right Join: Using test.a = test2.a
2723            TableScan: test
2724            Projection: test2.a
2725              TableScan: test2, full_filters=[test2.a <= Int64(1)]
2726        "
2727        )
2728    }
2729
2730    /// post-left-join predicate on a column common to both sides is only pushed to the left side
2731    /// i.e. - not duplicated to the right side
2732    #[test]
2733    fn filter_using_left_join_on_common() -> Result<()> {
2734        let table_scan = test_table_scan()?;
2735        let left = LogicalPlanBuilder::from(table_scan).build()?;
2736        let right_table_scan = test_table_scan_with_name("test2")?;
2737        let right = LogicalPlanBuilder::from(right_table_scan)
2738            .project(vec![col("a")])?
2739            .build()?;
2740        let plan = LogicalPlanBuilder::from(left)
2741            .join_using(
2742                right,
2743                JoinType::Left,
2744                vec![Column::from_name("a".to_string())],
2745            )?
2746            .filter(col("a").lt_eq(lit(1i64)))?
2747            .build()?;
2748
2749        // not part of the test, just good to know:
2750        assert_snapshot!(plan,
2751        @r"
2752        Filter: test.a <= Int64(1)
2753          Left Join: Using test.a = test2.a
2754            TableScan: test
2755            Projection: test2.a
2756              TableScan: test2
2757        ",
2758        );
2759        // filter sent to left side of the join, not the right
2760        assert_optimized_plan_equal!(
2761            plan,
2762            @r"
2763        Left Join: Using test.a = test2.a
2764          TableScan: test, full_filters=[test.a <= Int64(1)]
2765          Projection: test2.a
2766            TableScan: test2
2767        "
2768        )
2769    }
2770
2771    /// post-right-join predicate on a column common to both sides is only pushed to the right side
2772    /// i.e. - not duplicated to the left side.
2773    #[test]
2774    fn filter_using_right_join_on_common() -> Result<()> {
2775        let table_scan = test_table_scan()?;
2776        let left = LogicalPlanBuilder::from(table_scan).build()?;
2777        let right_table_scan = test_table_scan_with_name("test2")?;
2778        let right = LogicalPlanBuilder::from(right_table_scan)
2779            .project(vec![col("a")])?
2780            .build()?;
2781        let plan = LogicalPlanBuilder::from(left)
2782            .join_using(
2783                right,
2784                JoinType::Right,
2785                vec![Column::from_name("a".to_string())],
2786            )?
2787            .filter(col("test2.a").lt_eq(lit(1i64)))?
2788            .build()?;
2789
2790        // not part of the test, just good to know:
2791        assert_snapshot!(plan,
2792        @r"
2793        Filter: test2.a <= Int64(1)
2794          Right Join: Using test.a = test2.a
2795            TableScan: test
2796            Projection: test2.a
2797              TableScan: test2
2798        ",
2799        );
2800        // filter sent to right side of join, not duplicated to the left
2801        assert_optimized_plan_equal!(
2802            plan,
2803            @r"
2804        Right Join: Using test.a = test2.a
2805          TableScan: test
2806          Projection: test2.a
2807            TableScan: test2, full_filters=[test2.a <= Int64(1)]
2808        "
2809        )
2810    }
2811
2812    /// single table predicate parts of ON condition should be pushed to both inputs
2813    #[test]
2814    fn join_on_with_filter() -> Result<()> {
2815        let table_scan = test_table_scan()?;
2816        let left = LogicalPlanBuilder::from(table_scan)
2817            .project(vec![col("a"), col("b"), col("c")])?
2818            .build()?;
2819        let right_table_scan = test_table_scan_with_name("test2")?;
2820        let right = LogicalPlanBuilder::from(right_table_scan)
2821            .project(vec![col("a"), col("b"), col("c")])?
2822            .build()?;
2823        let filter = col("test.c")
2824            .gt(lit(1u32))
2825            .and(col("test.b").lt(col("test2.b")))
2826            .and(col("test2.c").gt(lit(4u32)));
2827        let plan = LogicalPlanBuilder::from(left)
2828            .join(
2829                right,
2830                JoinType::Inner,
2831                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2832                Some(filter),
2833            )?
2834            .build()?;
2835
2836        // not part of the test, just good to know:
2837        assert_snapshot!(plan,
2838        @r"
2839        Inner Join: test.a = test2.a Filter: test.c > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
2840          Projection: test.a, test.b, test.c
2841            TableScan: test
2842          Projection: test2.a, test2.b, test2.c
2843            TableScan: test2
2844        ",
2845        );
2846        assert_optimized_plan_equal!(
2847            plan,
2848            @r"
2849        Inner Join: test.a = test2.a Filter: test.b < test2.b
2850          Projection: test.a, test.b, test.c
2851            TableScan: test, full_filters=[test.c > UInt32(1)]
2852          Projection: test2.a, test2.b, test2.c
2853            TableScan: test2, full_filters=[test2.c > UInt32(4)]
2854        "
2855        )
2856    }
2857
2858    /// join filter should be completely removed after pushdown
2859    #[test]
2860    fn join_filter_removed() -> Result<()> {
2861        let table_scan = test_table_scan()?;
2862        let left = LogicalPlanBuilder::from(table_scan)
2863            .project(vec![col("a"), col("b"), col("c")])?
2864            .build()?;
2865        let right_table_scan = test_table_scan_with_name("test2")?;
2866        let right = LogicalPlanBuilder::from(right_table_scan)
2867            .project(vec![col("a"), col("b"), col("c")])?
2868            .build()?;
2869        let filter = col("test.b")
2870            .gt(lit(1u32))
2871            .and(col("test2.c").gt(lit(4u32)));
2872        let plan = LogicalPlanBuilder::from(left)
2873            .join(
2874                right,
2875                JoinType::Inner,
2876                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2877                Some(filter),
2878            )?
2879            .build()?;
2880
2881        // not part of the test, just good to know:
2882        assert_snapshot!(plan,
2883        @r"
2884        Inner Join: test.a = test2.a Filter: test.b > UInt32(1) AND test2.c > UInt32(4)
2885          Projection: test.a, test.b, test.c
2886            TableScan: test
2887          Projection: test2.a, test2.b, test2.c
2888            TableScan: test2
2889        ",
2890        );
2891        assert_optimized_plan_equal!(
2892            plan,
2893            @r"
2894        Inner Join: test.a = test2.a
2895          Projection: test.a, test.b, test.c
2896            TableScan: test, full_filters=[test.b > UInt32(1)]
2897          Projection: test2.a, test2.b, test2.c
2898            TableScan: test2, full_filters=[test2.c > UInt32(4)]
2899        "
2900        )
2901    }
2902
2903    /// predicate on join key in filter expression should be pushed down to both inputs
2904    #[test]
2905    fn join_filter_on_common() -> Result<()> {
2906        let table_scan = test_table_scan()?;
2907        let left = LogicalPlanBuilder::from(table_scan)
2908            .project(vec![col("a")])?
2909            .build()?;
2910        let right_table_scan = test_table_scan_with_name("test2")?;
2911        let right = LogicalPlanBuilder::from(right_table_scan)
2912            .project(vec![col("b")])?
2913            .build()?;
2914        let filter = col("test.a").gt(lit(1u32));
2915        let plan = LogicalPlanBuilder::from(left)
2916            .join(
2917                right,
2918                JoinType::Inner,
2919                (vec![Column::from_name("a")], vec![Column::from_name("b")]),
2920                Some(filter),
2921            )?
2922            .build()?;
2923
2924        // not part of the test, just good to know:
2925        assert_snapshot!(plan,
2926        @r"
2927        Inner Join: test.a = test2.b Filter: test.a > UInt32(1)
2928          Projection: test.a
2929            TableScan: test
2930          Projection: test2.b
2931            TableScan: test2
2932        ",
2933        );
2934        assert_optimized_plan_equal!(
2935            plan,
2936            @r"
2937        Inner Join: test.a = test2.b
2938          Projection: test.a
2939            TableScan: test, full_filters=[test.a > UInt32(1)]
2940          Projection: test2.b
2941            TableScan: test2, full_filters=[test2.b > UInt32(1)]
2942        "
2943        )
2944    }
2945
2946    /// single table predicate parts of ON condition should be pushed to right input
2947    #[test]
2948    fn left_join_on_with_filter() -> Result<()> {
2949        let table_scan = test_table_scan()?;
2950        let left = LogicalPlanBuilder::from(table_scan)
2951            .project(vec![col("a"), col("b"), col("c")])?
2952            .build()?;
2953        let right_table_scan = test_table_scan_with_name("test2")?;
2954        let right = LogicalPlanBuilder::from(right_table_scan)
2955            .project(vec![col("a"), col("b"), col("c")])?
2956            .build()?;
2957        let filter = col("test.a")
2958            .gt(lit(1u32))
2959            .and(col("test.b").lt(col("test2.b")))
2960            .and(col("test2.c").gt(lit(4u32)));
2961        let plan = LogicalPlanBuilder::from(left)
2962            .join(
2963                right,
2964                JoinType::Left,
2965                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2966                Some(filter),
2967            )?
2968            .build()?;
2969
2970        // not part of the test, just good to know:
2971        assert_snapshot!(plan,
2972        @r"
2973        Left Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
2974          Projection: test.a, test.b, test.c
2975            TableScan: test
2976          Projection: test2.a, test2.b, test2.c
2977            TableScan: test2
2978        ",
2979        );
2980        assert_optimized_plan_equal!(
2981            plan,
2982            @r"
2983        Left Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b
2984          Projection: test.a, test.b, test.c
2985            TableScan: test
2986          Projection: test2.a, test2.b, test2.c
2987            TableScan: test2, full_filters=[test2.c > UInt32(4)]
2988        "
2989        )
2990    }
2991
2992    /// single table predicate parts of ON condition should be pushed to left input
2993    #[test]
2994    fn right_join_on_with_filter() -> Result<()> {
2995        let table_scan = test_table_scan()?;
2996        let left = LogicalPlanBuilder::from(table_scan)
2997            .project(vec![col("a"), col("b"), col("c")])?
2998            .build()?;
2999        let right_table_scan = test_table_scan_with_name("test2")?;
3000        let right = LogicalPlanBuilder::from(right_table_scan)
3001            .project(vec![col("a"), col("b"), col("c")])?
3002            .build()?;
3003        let filter = col("test.a")
3004            .gt(lit(1u32))
3005            .and(col("test.b").lt(col("test2.b")))
3006            .and(col("test2.c").gt(lit(4u32)));
3007        let plan = LogicalPlanBuilder::from(left)
3008            .join(
3009                right,
3010                JoinType::Right,
3011                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
3012                Some(filter),
3013            )?
3014            .build()?;
3015
3016        // not part of the test, just good to know:
3017        assert_snapshot!(plan,
3018        @r"
3019        Right Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
3020          Projection: test.a, test.b, test.c
3021            TableScan: test
3022          Projection: test2.a, test2.b, test2.c
3023            TableScan: test2
3024        ",
3025        );
3026        assert_optimized_plan_equal!(
3027            plan,
3028            @r"
3029        Right Join: test.a = test2.a Filter: test.b < test2.b AND test2.c > UInt32(4)
3030          Projection: test.a, test.b, test.c
3031            TableScan: test, full_filters=[test.a > UInt32(1)]
3032          Projection: test2.a, test2.b, test2.c
3033            TableScan: test2
3034        "
3035        )
3036    }
3037
3038    /// single table predicate parts of ON condition should not be pushed
3039    #[test]
3040    fn full_join_on_with_filter() -> Result<()> {
3041        let table_scan = test_table_scan()?;
3042        let left = LogicalPlanBuilder::from(table_scan)
3043            .project(vec![col("a"), col("b"), col("c")])?
3044            .build()?;
3045        let right_table_scan = test_table_scan_with_name("test2")?;
3046        let right = LogicalPlanBuilder::from(right_table_scan)
3047            .project(vec![col("a"), col("b"), col("c")])?
3048            .build()?;
3049        let filter = col("test.a")
3050            .gt(lit(1u32))
3051            .and(col("test.b").lt(col("test2.b")))
3052            .and(col("test2.c").gt(lit(4u32)));
3053        let plan = LogicalPlanBuilder::from(left)
3054            .join(
3055                right,
3056                JoinType::Full,
3057                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
3058                Some(filter),
3059            )?
3060            .build()?;
3061
3062        // not part of the test, just good to know:
3063        assert_snapshot!(plan,
3064        @r"
3065        Full Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
3066          Projection: test.a, test.b, test.c
3067            TableScan: test
3068          Projection: test2.a, test2.b, test2.c
3069            TableScan: test2
3070        ",
3071        );
3072        assert_optimized_plan_equal!(
3073            plan,
3074            @r"
3075        Full Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
3076          Projection: test.a, test.b, test.c
3077            TableScan: test
3078          Projection: test2.a, test2.b, test2.c
3079            TableScan: test2
3080        "
3081        )
3082    }
3083
3084    struct PushDownProvider {
3085        pub filter_support: TableProviderFilterPushDown,
3086    }
3087
3088    #[async_trait]
3089    impl TableSource for PushDownProvider {
3090        fn schema(&self) -> SchemaRef {
3091            Arc::new(Schema::new(vec![
3092                Field::new("a", DataType::Int32, true),
3093                Field::new("b", DataType::Int32, true),
3094            ]))
3095        }
3096
3097        fn table_type(&self) -> TableType {
3098            TableType::Base
3099        }
3100
3101        fn supports_filters_pushdown(
3102            &self,
3103            filters: &[&Expr],
3104        ) -> Result<Vec<TableProviderFilterPushDown>> {
3105            Ok((0..filters.len())
3106                .map(|_| self.filter_support.clone())
3107                .collect())
3108        }
3109
3110        fn as_any(&self) -> &dyn Any {
3111            self
3112        }
3113    }
3114
3115    fn table_scan_with_pushdown_provider_builder(
3116        filter_support: TableProviderFilterPushDown,
3117        filters: Vec<Expr>,
3118        projection: Option<Vec<usize>>,
3119    ) -> Result<LogicalPlanBuilder> {
3120        let test_provider = PushDownProvider { filter_support };
3121
3122        let table_scan = LogicalPlan::TableScan(TableScan {
3123            table_name: "test".into(),
3124            filters,
3125            projected_schema: Arc::new(DFSchema::try_from(test_provider.schema())?),
3126            projection,
3127            source: Arc::new(test_provider),
3128            fetch: None,
3129        });
3130
3131        Ok(LogicalPlanBuilder::from(table_scan))
3132    }
3133
3134    fn table_scan_with_pushdown_provider(
3135        filter_support: TableProviderFilterPushDown,
3136    ) -> Result<LogicalPlan> {
3137        table_scan_with_pushdown_provider_builder(filter_support, vec![], None)?
3138            .filter(col("a").eq(lit(1i64)))?
3139            .build()
3140    }
3141
3142    #[test]
3143    fn filter_with_table_provider_exact() -> Result<()> {
3144        let plan = table_scan_with_pushdown_provider(TableProviderFilterPushDown::Exact)?;
3145
3146        assert_optimized_plan_equal!(
3147            plan,
3148            @"TableScan: test, full_filters=[a = Int64(1)]"
3149        )
3150    }
3151
3152    #[test]
3153    fn filter_with_table_provider_inexact() -> Result<()> {
3154        let plan =
3155            table_scan_with_pushdown_provider(TableProviderFilterPushDown::Inexact)?;
3156
3157        assert_optimized_plan_equal!(
3158            plan,
3159            @r"
3160        Filter: a = Int64(1)
3161          TableScan: test, partial_filters=[a = Int64(1)]
3162        "
3163        )
3164    }
3165
3166    #[test]
3167    fn filter_with_table_provider_multiple_invocations() -> Result<()> {
3168        let plan =
3169            table_scan_with_pushdown_provider(TableProviderFilterPushDown::Inexact)?;
3170
3171        let optimized_plan = PushDownFilter::new()
3172            .rewrite(plan, &OptimizerContext::new())
3173            .expect("failed to optimize plan")
3174            .data;
3175
3176        // Optimizing the same plan multiple times should produce the same plan
3177        // each time.
3178        assert_optimized_plan_equal!(
3179            optimized_plan,
3180            @r"
3181        Filter: a = Int64(1)
3182          TableScan: test, partial_filters=[a = Int64(1)]
3183        "
3184        )
3185    }
3186
3187    #[test]
3188    fn filter_with_table_provider_unsupported() -> Result<()> {
3189        let plan =
3190            table_scan_with_pushdown_provider(TableProviderFilterPushDown::Unsupported)?;
3191
3192        assert_optimized_plan_equal!(
3193            plan,
3194            @r"
3195        Filter: a = Int64(1)
3196          TableScan: test
3197        "
3198        )
3199    }
3200
3201    #[test]
3202    fn multi_combined_filter() -> Result<()> {
3203        let plan = table_scan_with_pushdown_provider_builder(
3204            TableProviderFilterPushDown::Inexact,
3205            vec![col("a").eq(lit(10i64)), col("b").gt(lit(11i64))],
3206            Some(vec![0]),
3207        )?
3208        .filter(and(col("a").eq(lit(10i64)), col("b").gt(lit(11i64))))?
3209        .project(vec![col("a"), col("b")])?
3210        .build()?;
3211
3212        assert_optimized_plan_equal!(
3213            plan,
3214            @r"
3215        Projection: a, b
3216          Filter: a = Int64(10) AND b > Int64(11)
3217            TableScan: test projection=[a], partial_filters=[a = Int64(10), b > Int64(11)]
3218        "
3219        )
3220    }
3221
3222    #[test]
3223    fn multi_combined_filter_exact() -> Result<()> {
3224        let plan = table_scan_with_pushdown_provider_builder(
3225            TableProviderFilterPushDown::Exact,
3226            vec![],
3227            Some(vec![0]),
3228        )?
3229        .filter(and(col("a").eq(lit(10i64)), col("b").gt(lit(11i64))))?
3230        .project(vec![col("a"), col("b")])?
3231        .build()?;
3232
3233        assert_optimized_plan_equal!(
3234            plan,
3235            @r"
3236        Projection: a, b
3237          TableScan: test projection=[a], full_filters=[a = Int64(10), b > Int64(11)]
3238        "
3239        )
3240    }
3241
3242    #[test]
3243    fn test_filter_with_alias() -> Result<()> {
3244        // in table scan the true col name is 'test.a',
3245        // but we rename it as 'b', and use col 'b' in filter
3246        // we need rewrite filter col before push down.
3247        let table_scan = test_table_scan()?;
3248        let plan = LogicalPlanBuilder::from(table_scan)
3249            .project(vec![col("a").alias("b"), col("c")])?
3250            .filter(and(col("b").gt(lit(10i64)), col("c").gt(lit(10i64))))?
3251            .build()?;
3252
3253        // filter on col b
3254        assert_snapshot!(plan,
3255        @r"
3256        Filter: b > Int64(10) AND test.c > Int64(10)
3257          Projection: test.a AS b, test.c
3258            TableScan: test
3259        ",
3260        );
3261        // rewrite filter col b to test.a
3262        assert_optimized_plan_equal!(
3263            plan,
3264            @r"
3265        Projection: test.a AS b, test.c
3266          TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]
3267        "
3268        )
3269    }
3270
3271    #[test]
3272    fn test_filter_with_alias_2() -> Result<()> {
3273        // in table scan the true col name is 'test.a',
3274        // but we rename it as 'b', and use col 'b' in filter
3275        // we need rewrite filter col before push down.
3276        let table_scan = test_table_scan()?;
3277        let plan = LogicalPlanBuilder::from(table_scan)
3278            .project(vec![col("a").alias("b"), col("c")])?
3279            .project(vec![col("b"), col("c")])?
3280            .filter(and(col("b").gt(lit(10i64)), col("c").gt(lit(10i64))))?
3281            .build()?;
3282
3283        // filter on col b
3284        assert_snapshot!(plan,
3285        @r"
3286        Filter: b > Int64(10) AND test.c > Int64(10)
3287          Projection: b, test.c
3288            Projection: test.a AS b, test.c
3289              TableScan: test
3290        ",
3291        );
3292        // rewrite filter col b to test.a
3293        assert_optimized_plan_equal!(
3294            plan,
3295            @r"
3296        Projection: b, test.c
3297          Projection: test.a AS b, test.c
3298            TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]
3299        "
3300        )
3301    }
3302
3303    #[test]
3304    fn test_filter_with_multi_alias() -> Result<()> {
3305        let table_scan = test_table_scan()?;
3306        let plan = LogicalPlanBuilder::from(table_scan)
3307            .project(vec![col("a").alias("b"), col("c").alias("d")])?
3308            .filter(and(col("b").gt(lit(10i64)), col("d").gt(lit(10i64))))?
3309            .build()?;
3310
3311        // filter on col b and d
3312        assert_snapshot!(plan,
3313        @r"
3314        Filter: b > Int64(10) AND d > Int64(10)
3315          Projection: test.a AS b, test.c AS d
3316            TableScan: test
3317        ",
3318        );
3319        // rewrite filter col b to test.a, col d to test.c
3320        assert_optimized_plan_equal!(
3321            plan,
3322            @r"
3323        Projection: test.a AS b, test.c AS d
3324          TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]
3325        "
3326        )
3327    }
3328
3329    /// predicate on join key in filter expression should be pushed down to both inputs
3330    #[test]
3331    fn join_filter_with_alias() -> Result<()> {
3332        let table_scan = test_table_scan()?;
3333        let left = LogicalPlanBuilder::from(table_scan)
3334            .project(vec![col("a").alias("c")])?
3335            .build()?;
3336        let right_table_scan = test_table_scan_with_name("test2")?;
3337        let right = LogicalPlanBuilder::from(right_table_scan)
3338            .project(vec![col("b").alias("d")])?
3339            .build()?;
3340        let filter = col("c").gt(lit(1u32));
3341        let plan = LogicalPlanBuilder::from(left)
3342            .join(
3343                right,
3344                JoinType::Inner,
3345                (vec![Column::from_name("c")], vec![Column::from_name("d")]),
3346                Some(filter),
3347            )?
3348            .build()?;
3349
3350        assert_snapshot!(plan,
3351        @r"
3352        Inner Join: c = d Filter: c > UInt32(1)
3353          Projection: test.a AS c
3354            TableScan: test
3355          Projection: test2.b AS d
3356            TableScan: test2
3357        ",
3358        );
3359        // Change filter on col `c`, 'd' to `test.a`, 'test.b'
3360        assert_optimized_plan_equal!(
3361            plan,
3362            @r"
3363        Inner Join: c = d
3364          Projection: test.a AS c
3365            TableScan: test, full_filters=[test.a > UInt32(1)]
3366          Projection: test2.b AS d
3367            TableScan: test2, full_filters=[test2.b > UInt32(1)]
3368        "
3369        )
3370    }
3371
3372    #[test]
3373    fn test_in_filter_with_alias() -> Result<()> {
3374        // in table scan the true col name is 'test.a',
3375        // but we rename it as 'b', and use col 'b' in filter
3376        // we need rewrite filter col before push down.
3377        let table_scan = test_table_scan()?;
3378        let filter_value = vec![lit(1u32), lit(2u32), lit(3u32), lit(4u32)];
3379        let plan = LogicalPlanBuilder::from(table_scan)
3380            .project(vec![col("a").alias("b"), col("c")])?
3381            .filter(in_list(col("b"), filter_value, false))?
3382            .build()?;
3383
3384        // filter on col b
3385        assert_snapshot!(plan,
3386        @r"
3387        Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])
3388          Projection: test.a AS b, test.c
3389            TableScan: test
3390        ",
3391        );
3392        // rewrite filter col b to test.a
3393        assert_optimized_plan_equal!(
3394            plan,
3395            @r"
3396        Projection: test.a AS b, test.c
3397          TableScan: test, full_filters=[test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])]
3398        "
3399        )
3400    }
3401
3402    #[test]
3403    fn test_in_filter_with_alias_2() -> Result<()> {
3404        // in table scan the true col name is 'test.a',
3405        // but we rename it as 'b', and use col 'b' in filter
3406        // we need rewrite filter col before push down.
3407        let table_scan = test_table_scan()?;
3408        let filter_value = vec![lit(1u32), lit(2u32), lit(3u32), lit(4u32)];
3409        let plan = LogicalPlanBuilder::from(table_scan)
3410            .project(vec![col("a").alias("b"), col("c")])?
3411            .project(vec![col("b"), col("c")])?
3412            .filter(in_list(col("b"), filter_value, false))?
3413            .build()?;
3414
3415        // filter on col b
3416        assert_snapshot!(plan,
3417        @r"
3418        Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])
3419          Projection: b, test.c
3420            Projection: test.a AS b, test.c
3421              TableScan: test
3422        ",
3423        );
3424        // rewrite filter col b to test.a
3425        assert_optimized_plan_equal!(
3426            plan,
3427            @r"
3428        Projection: b, test.c
3429          Projection: test.a AS b, test.c
3430            TableScan: test, full_filters=[test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])]
3431        "
3432        )
3433    }
3434
3435    #[test]
3436    fn test_in_subquery_with_alias() -> Result<()> {
3437        // in table scan the true col name is 'test.a',
3438        // but we rename it as 'b', and use col 'b' in subquery filter
3439        let table_scan = test_table_scan()?;
3440        let table_scan_sq = test_table_scan_with_name("sq")?;
3441        let subplan = Arc::new(
3442            LogicalPlanBuilder::from(table_scan_sq)
3443                .project(vec![col("c")])?
3444                .build()?,
3445        );
3446        let plan = LogicalPlanBuilder::from(table_scan)
3447            .project(vec![col("a").alias("b"), col("c")])?
3448            .filter(in_subquery(col("b"), subplan))?
3449            .build()?;
3450
3451        // filter on col b in subquery
3452        assert_snapshot!(plan,
3453        @r"
3454        Filter: b IN (<subquery>)
3455          Subquery:
3456            Projection: sq.c
3457              TableScan: sq
3458          Projection: test.a AS b, test.c
3459            TableScan: test
3460        ",
3461        );
3462        // rewrite filter col b to test.a
3463        assert_optimized_plan_equal!(
3464            plan,
3465            @r"
3466        Projection: test.a AS b, test.c
3467          TableScan: test, full_filters=[test.a IN (<subquery>)]
3468            Subquery:
3469              Projection: sq.c
3470                TableScan: sq
3471        "
3472        )
3473    }
3474
3475    #[test]
3476    fn test_propagation_of_optimized_inner_filters_with_projections() -> Result<()> {
3477        // SELECT a FROM (SELECT 1 AS a) b WHERE b.a = 1
3478        let plan = LogicalPlanBuilder::empty(true)
3479            .project(vec![lit(0i64).alias("a")])?
3480            .alias("b")?
3481            .project(vec![col("b.a")])?
3482            .alias("b")?
3483            .filter(col("b.a").eq(lit(1i64)))?
3484            .project(vec![col("b.a")])?
3485            .build()?;
3486
3487        assert_snapshot!(plan,
3488        @r"
3489        Projection: b.a
3490          Filter: b.a = Int64(1)
3491            SubqueryAlias: b
3492              Projection: b.a
3493                SubqueryAlias: b
3494                  Projection: Int64(0) AS a
3495                    EmptyRelation: rows=1
3496        ",
3497        );
3498        // Ensure that the predicate without any columns (0 = 1) is
3499        // still there.
3500        assert_optimized_plan_equal!(
3501            plan,
3502            @r"
3503        Projection: b.a
3504          SubqueryAlias: b
3505            Projection: b.a
3506              SubqueryAlias: b
3507                Projection: Int64(0) AS a
3508                  Filter: Int64(0) = Int64(1)
3509                    EmptyRelation: rows=1
3510        "
3511        )
3512    }
3513
3514    #[test]
3515    fn test_crossjoin_with_or_clause() -> Result<()> {
3516        // select * from test,test1 where (test.a = test1.a and test.b > 1) or (test.b = test1.b and test.c < 10);
3517        let table_scan = test_table_scan()?;
3518        let left = LogicalPlanBuilder::from(table_scan)
3519            .project(vec![col("a"), col("b"), col("c")])?
3520            .build()?;
3521        let right_table_scan = test_table_scan_with_name("test1")?;
3522        let right = LogicalPlanBuilder::from(right_table_scan)
3523            .project(vec![col("a").alias("d"), col("a").alias("e")])?
3524            .build()?;
3525        let filter = or(
3526            and(col("a").eq(col("d")), col("b").gt(lit(1u32))),
3527            and(col("b").eq(col("e")), col("c").lt(lit(10u32))),
3528        );
3529        let plan = LogicalPlanBuilder::from(left)
3530            .cross_join(right)?
3531            .filter(filter)?
3532            .build()?;
3533
3534        assert_optimized_plan_eq_with_rewrite_predicate!(plan.clone(), @r"
3535        Inner Join:  Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10)
3536          Projection: test.a, test.b, test.c
3537            TableScan: test, full_filters=[test.b > UInt32(1) OR test.c < UInt32(10)]
3538          Projection: test1.a AS d, test1.a AS e
3539            TableScan: test1
3540        ")?;
3541
3542        // Originally global state which can help to avoid duplicate Filters been generated and pushed down.
3543        // Now the global state is removed. Need to double confirm that avoid duplicate Filters.
3544        let optimized_plan = PushDownFilter::new()
3545            .rewrite(plan, &OptimizerContext::new())
3546            .expect("failed to optimize plan")
3547            .data;
3548        assert_optimized_plan_equal!(
3549            optimized_plan,
3550            @r"
3551        Inner Join:  Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10)
3552          Projection: test.a, test.b, test.c
3553            TableScan: test, full_filters=[test.b > UInt32(1) OR test.c < UInt32(10)]
3554          Projection: test1.a AS d, test1.a AS e
3555            TableScan: test1
3556        "
3557        )
3558    }
3559
3560    #[test]
3561    fn left_semi_join() -> Result<()> {
3562        let left = test_table_scan_with_name("test1")?;
3563        let right_table_scan = test_table_scan_with_name("test2")?;
3564        let right = LogicalPlanBuilder::from(right_table_scan)
3565            .project(vec![col("a"), col("b")])?
3566            .build()?;
3567        let plan = LogicalPlanBuilder::from(left)
3568            .join(
3569                right,
3570                JoinType::LeftSemi,
3571                (
3572                    vec![Column::from_qualified_name("test1.a")],
3573                    vec![Column::from_qualified_name("test2.a")],
3574                ),
3575                None,
3576            )?
3577            .filter(col("test2.a").lt_eq(lit(1i64)))?
3578            .build()?;
3579
3580        // not part of the test, just good to know:
3581        assert_snapshot!(plan,
3582        @r"
3583        Filter: test2.a <= Int64(1)
3584          LeftSemi Join: test1.a = test2.a
3585            TableScan: test1
3586            Projection: test2.a, test2.b
3587              TableScan: test2
3588        ",
3589        );
3590        // Inferred the predicate `test1.a <= Int64(1)` and push it down to the left side.
3591        assert_optimized_plan_equal!(
3592            plan,
3593            @r"
3594        Filter: test2.a <= Int64(1)
3595          LeftSemi Join: test1.a = test2.a
3596            TableScan: test1, full_filters=[test1.a <= Int64(1)]
3597            Projection: test2.a, test2.b
3598              TableScan: test2
3599        "
3600        )
3601    }
3602
3603    #[test]
3604    fn left_semi_join_with_filters() -> Result<()> {
3605        let left = test_table_scan_with_name("test1")?;
3606        let right_table_scan = test_table_scan_with_name("test2")?;
3607        let right = LogicalPlanBuilder::from(right_table_scan)
3608            .project(vec![col("a"), col("b")])?
3609            .build()?;
3610        let plan = LogicalPlanBuilder::from(left)
3611            .join(
3612                right,
3613                JoinType::LeftSemi,
3614                (
3615                    vec![Column::from_qualified_name("test1.a")],
3616                    vec![Column::from_qualified_name("test2.a")],
3617                ),
3618                Some(
3619                    col("test1.b")
3620                        .gt(lit(1u32))
3621                        .and(col("test2.b").gt(lit(2u32))),
3622                ),
3623            )?
3624            .build()?;
3625
3626        // not part of the test, just good to know:
3627        assert_snapshot!(plan,
3628        @r"
3629        LeftSemi Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
3630          TableScan: test1
3631          Projection: test2.a, test2.b
3632            TableScan: test2
3633        ",
3634        );
3635        // Both side will be pushed down.
3636        assert_optimized_plan_equal!(
3637            plan,
3638            @r"
3639        LeftSemi Join: test1.a = test2.a
3640          TableScan: test1, full_filters=[test1.b > UInt32(1)]
3641          Projection: test2.a, test2.b
3642            TableScan: test2, full_filters=[test2.b > UInt32(2)]
3643        "
3644        )
3645    }
3646
3647    #[test]
3648    fn right_semi_join() -> Result<()> {
3649        let left = test_table_scan_with_name("test1")?;
3650        let right_table_scan = test_table_scan_with_name("test2")?;
3651        let right = LogicalPlanBuilder::from(right_table_scan)
3652            .project(vec![col("a"), col("b")])?
3653            .build()?;
3654        let plan = LogicalPlanBuilder::from(left)
3655            .join(
3656                right,
3657                JoinType::RightSemi,
3658                (
3659                    vec![Column::from_qualified_name("test1.a")],
3660                    vec![Column::from_qualified_name("test2.a")],
3661                ),
3662                None,
3663            )?
3664            .filter(col("test1.a").lt_eq(lit(1i64)))?
3665            .build()?;
3666
3667        // not part of the test, just good to know:
3668        assert_snapshot!(plan,
3669        @r"
3670        Filter: test1.a <= Int64(1)
3671          RightSemi Join: test1.a = test2.a
3672            TableScan: test1
3673            Projection: test2.a, test2.b
3674              TableScan: test2
3675        ",
3676        );
3677        // Inferred the predicate `test2.a <= Int64(1)` and push it down to the right side.
3678        assert_optimized_plan_equal!(
3679            plan,
3680            @r"
3681        Filter: test1.a <= Int64(1)
3682          RightSemi Join: test1.a = test2.a
3683            TableScan: test1
3684            Projection: test2.a, test2.b
3685              TableScan: test2, full_filters=[test2.a <= Int64(1)]
3686        "
3687        )
3688    }
3689
3690    #[test]
3691    fn right_semi_join_with_filters() -> Result<()> {
3692        let left = test_table_scan_with_name("test1")?;
3693        let right_table_scan = test_table_scan_with_name("test2")?;
3694        let right = LogicalPlanBuilder::from(right_table_scan)
3695            .project(vec![col("a"), col("b")])?
3696            .build()?;
3697        let plan = LogicalPlanBuilder::from(left)
3698            .join(
3699                right,
3700                JoinType::RightSemi,
3701                (
3702                    vec![Column::from_qualified_name("test1.a")],
3703                    vec![Column::from_qualified_name("test2.a")],
3704                ),
3705                Some(
3706                    col("test1.b")
3707                        .gt(lit(1u32))
3708                        .and(col("test2.b").gt(lit(2u32))),
3709                ),
3710            )?
3711            .build()?;
3712
3713        // not part of the test, just good to know:
3714        assert_snapshot!(plan,
3715        @r"
3716        RightSemi Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
3717          TableScan: test1
3718          Projection: test2.a, test2.b
3719            TableScan: test2
3720        ",
3721        );
3722        // Both side will be pushed down.
3723        assert_optimized_plan_equal!(
3724            plan,
3725            @r"
3726        RightSemi Join: test1.a = test2.a
3727          TableScan: test1, full_filters=[test1.b > UInt32(1)]
3728          Projection: test2.a, test2.b
3729            TableScan: test2, full_filters=[test2.b > UInt32(2)]
3730        "
3731        )
3732    }
3733
3734    #[test]
3735    fn left_anti_join() -> Result<()> {
3736        let table_scan = test_table_scan_with_name("test1")?;
3737        let left = LogicalPlanBuilder::from(table_scan)
3738            .project(vec![col("a"), col("b")])?
3739            .build()?;
3740        let right_table_scan = test_table_scan_with_name("test2")?;
3741        let right = LogicalPlanBuilder::from(right_table_scan)
3742            .project(vec![col("a"), col("b")])?
3743            .build()?;
3744        let plan = LogicalPlanBuilder::from(left)
3745            .join(
3746                right,
3747                JoinType::LeftAnti,
3748                (
3749                    vec![Column::from_qualified_name("test1.a")],
3750                    vec![Column::from_qualified_name("test2.a")],
3751                ),
3752                None,
3753            )?
3754            .filter(col("test2.a").gt(lit(2u32)))?
3755            .build()?;
3756
3757        // not part of the test, just good to know:
3758        assert_snapshot!(plan,
3759        @r"
3760        Filter: test2.a > UInt32(2)
3761          LeftAnti Join: test1.a = test2.a
3762            Projection: test1.a, test1.b
3763              TableScan: test1
3764            Projection: test2.a, test2.b
3765              TableScan: test2
3766        ",
3767        );
3768        // For left anti, filter of the right side filter can be pushed down.
3769        assert_optimized_plan_equal!(
3770            plan,
3771            @r"
3772        Filter: test2.a > UInt32(2)
3773          LeftAnti Join: test1.a = test2.a
3774            Projection: test1.a, test1.b
3775              TableScan: test1, full_filters=[test1.a > UInt32(2)]
3776            Projection: test2.a, test2.b
3777              TableScan: test2
3778        "
3779        )
3780    }
3781
3782    #[test]
3783    fn left_anti_join_with_filters() -> Result<()> {
3784        let table_scan = test_table_scan_with_name("test1")?;
3785        let left = LogicalPlanBuilder::from(table_scan)
3786            .project(vec![col("a"), col("b")])?
3787            .build()?;
3788        let right_table_scan = test_table_scan_with_name("test2")?;
3789        let right = LogicalPlanBuilder::from(right_table_scan)
3790            .project(vec![col("a"), col("b")])?
3791            .build()?;
3792        let plan = LogicalPlanBuilder::from(left)
3793            .join(
3794                right,
3795                JoinType::LeftAnti,
3796                (
3797                    vec![Column::from_qualified_name("test1.a")],
3798                    vec![Column::from_qualified_name("test2.a")],
3799                ),
3800                Some(
3801                    col("test1.b")
3802                        .gt(lit(1u32))
3803                        .and(col("test2.b").gt(lit(2u32))),
3804                ),
3805            )?
3806            .build()?;
3807
3808        // not part of the test, just good to know:
3809        assert_snapshot!(plan,
3810        @r"
3811        LeftAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
3812          Projection: test1.a, test1.b
3813            TableScan: test1
3814          Projection: test2.a, test2.b
3815            TableScan: test2
3816        ",
3817        );
3818        // For left anti, filter of the right side filter can be pushed down.
3819        assert_optimized_plan_equal!(
3820            plan,
3821            @r"
3822        LeftAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1)
3823          Projection: test1.a, test1.b
3824            TableScan: test1
3825          Projection: test2.a, test2.b
3826            TableScan: test2, full_filters=[test2.b > UInt32(2)]
3827        "
3828        )
3829    }
3830
3831    #[test]
3832    fn right_anti_join() -> Result<()> {
3833        let table_scan = test_table_scan_with_name("test1")?;
3834        let left = LogicalPlanBuilder::from(table_scan)
3835            .project(vec![col("a"), col("b")])?
3836            .build()?;
3837        let right_table_scan = test_table_scan_with_name("test2")?;
3838        let right = LogicalPlanBuilder::from(right_table_scan)
3839            .project(vec![col("a"), col("b")])?
3840            .build()?;
3841        let plan = LogicalPlanBuilder::from(left)
3842            .join(
3843                right,
3844                JoinType::RightAnti,
3845                (
3846                    vec![Column::from_qualified_name("test1.a")],
3847                    vec![Column::from_qualified_name("test2.a")],
3848                ),
3849                None,
3850            )?
3851            .filter(col("test1.a").gt(lit(2u32)))?
3852            .build()?;
3853
3854        // not part of the test, just good to know:
3855        assert_snapshot!(plan,
3856        @r"
3857        Filter: test1.a > UInt32(2)
3858          RightAnti Join: test1.a = test2.a
3859            Projection: test1.a, test1.b
3860              TableScan: test1
3861            Projection: test2.a, test2.b
3862              TableScan: test2
3863        ",
3864        );
3865        // For right anti, filter of the left side can be pushed down.
3866        assert_optimized_plan_equal!(
3867            plan,
3868            @r"
3869        Filter: test1.a > UInt32(2)
3870          RightAnti Join: test1.a = test2.a
3871            Projection: test1.a, test1.b
3872              TableScan: test1
3873            Projection: test2.a, test2.b
3874              TableScan: test2, full_filters=[test2.a > UInt32(2)]
3875        "
3876        )
3877    }
3878
3879    #[test]
3880    fn right_anti_join_with_filters() -> Result<()> {
3881        let table_scan = test_table_scan_with_name("test1")?;
3882        let left = LogicalPlanBuilder::from(table_scan)
3883            .project(vec![col("a"), col("b")])?
3884            .build()?;
3885        let right_table_scan = test_table_scan_with_name("test2")?;
3886        let right = LogicalPlanBuilder::from(right_table_scan)
3887            .project(vec![col("a"), col("b")])?
3888            .build()?;
3889        let plan = LogicalPlanBuilder::from(left)
3890            .join(
3891                right,
3892                JoinType::RightAnti,
3893                (
3894                    vec![Column::from_qualified_name("test1.a")],
3895                    vec![Column::from_qualified_name("test2.a")],
3896                ),
3897                Some(
3898                    col("test1.b")
3899                        .gt(lit(1u32))
3900                        .and(col("test2.b").gt(lit(2u32))),
3901                ),
3902            )?
3903            .build()?;
3904
3905        // not part of the test, just good to know:
3906        assert_snapshot!(plan,
3907        @r"
3908        RightAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
3909          Projection: test1.a, test1.b
3910            TableScan: test1
3911          Projection: test2.a, test2.b
3912            TableScan: test2
3913        ",
3914        );
3915        // For right anti, filter of the left side can be pushed down.
3916        assert_optimized_plan_equal!(
3917            plan,
3918            @r"
3919        RightAnti Join: test1.a = test2.a Filter: test2.b > UInt32(2)
3920          Projection: test1.a, test1.b
3921            TableScan: test1, full_filters=[test1.b > UInt32(1)]
3922          Projection: test2.a, test2.b
3923            TableScan: test2
3924        "
3925        )
3926    }
3927
3928    #[derive(Debug, PartialEq, Eq, Hash)]
3929    struct TestScalarUDF {
3930        signature: Signature,
3931    }
3932
3933    impl ScalarUDFImpl for TestScalarUDF {
3934        fn as_any(&self) -> &dyn Any {
3935            self
3936        }
3937        fn name(&self) -> &str {
3938            "TestScalarUDF"
3939        }
3940
3941        fn signature(&self) -> &Signature {
3942            &self.signature
3943        }
3944
3945        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
3946            Ok(DataType::Int32)
3947        }
3948
3949        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
3950            Ok(ColumnarValue::Scalar(ScalarValue::from(1)))
3951        }
3952    }
3953
3954    #[test]
3955    fn test_push_down_volatile_function_in_aggregate() -> Result<()> {
3956        // SELECT t.a, t.r FROM (SELECT a, sum(b),  TestScalarUDF()+1 AS r FROM test1 GROUP BY a) AS t WHERE t.a > 5 AND t.r > 0.5;
3957        let table_scan = test_table_scan_with_name("test1")?;
3958        let fun = ScalarUDF::new_from_impl(TestScalarUDF {
3959            signature: Signature::exact(vec![], Volatility::Volatile),
3960        });
3961        let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
3962
3963        let plan = LogicalPlanBuilder::from(table_scan)
3964            .aggregate(vec![col("a")], vec![sum(col("b"))])?
3965            .project(vec![col("a"), sum(col("b")), add(expr, lit(1)).alias("r")])?
3966            .alias("t")?
3967            .filter(col("t.a").gt(lit(5)).and(col("t.r").gt(lit(0.5))))?
3968            .project(vec![col("t.a"), col("t.r")])?
3969            .build()?;
3970
3971        assert_snapshot!(plan,
3972        @r"
3973        Projection: t.a, t.r
3974          Filter: t.a > Int32(5) AND t.r > Float64(0.5)
3975            SubqueryAlias: t
3976              Projection: test1.a, sum(test1.b), TestScalarUDF() + Int32(1) AS r
3977                Aggregate: groupBy=[[test1.a]], aggr=[[sum(test1.b)]]
3978                  TableScan: test1
3979        ",
3980        );
3981        assert_optimized_plan_equal!(
3982            plan,
3983            @r"
3984        Projection: t.a, t.r
3985          SubqueryAlias: t
3986            Filter: r > Float64(0.5)
3987              Projection: test1.a, sum(test1.b), TestScalarUDF() + Int32(1) AS r
3988                Aggregate: groupBy=[[test1.a]], aggr=[[sum(test1.b)]]
3989                  TableScan: test1, full_filters=[test1.a > Int32(5)]
3990        "
3991        )
3992    }
3993
3994    #[test]
3995    fn test_push_down_volatile_function_in_join() -> Result<()> {
3996        // SELECT t.a, t.r FROM (SELECT test1.a AS a, TestScalarUDF() AS r FROM test1 join test2 ON test1.a = test2.a) AS t WHERE t.r > 0.5;
3997        let table_scan = test_table_scan_with_name("test1")?;
3998        let fun = ScalarUDF::new_from_impl(TestScalarUDF {
3999            signature: Signature::exact(vec![], Volatility::Volatile),
4000        });
4001        let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
4002        let left = LogicalPlanBuilder::from(table_scan).build()?;
4003        let right_table_scan = test_table_scan_with_name("test2")?;
4004        let right = LogicalPlanBuilder::from(right_table_scan).build()?;
4005        let plan = LogicalPlanBuilder::from(left)
4006            .join(
4007                right,
4008                JoinType::Inner,
4009                (
4010                    vec![Column::from_qualified_name("test1.a")],
4011                    vec![Column::from_qualified_name("test2.a")],
4012                ),
4013                None,
4014            )?
4015            .project(vec![col("test1.a").alias("a"), expr.alias("r")])?
4016            .alias("t")?
4017            .filter(col("t.r").gt(lit(0.8)))?
4018            .project(vec![col("t.a"), col("t.r")])?
4019            .build()?;
4020
4021        assert_snapshot!(plan,
4022        @r"
4023        Projection: t.a, t.r
4024          Filter: t.r > Float64(0.8)
4025            SubqueryAlias: t
4026              Projection: test1.a AS a, TestScalarUDF() AS r
4027                Inner Join: test1.a = test2.a
4028                  TableScan: test1
4029                  TableScan: test2
4030        ",
4031        );
4032        assert_optimized_plan_equal!(
4033            plan,
4034            @r"
4035        Projection: t.a, t.r
4036          SubqueryAlias: t
4037            Filter: r > Float64(0.8)
4038              Projection: test1.a AS a, TestScalarUDF() AS r
4039                Inner Join: test1.a = test2.a
4040                  TableScan: test1
4041                  TableScan: test2
4042        "
4043        )
4044    }
4045
4046    #[test]
4047    fn test_push_down_volatile_table_scan() -> Result<()> {
4048        // SELECT test.a, test.b FROM test as t WHERE TestScalarUDF() > 0.1;
4049        let table_scan = test_table_scan()?;
4050        let fun = ScalarUDF::new_from_impl(TestScalarUDF {
4051            signature: Signature::exact(vec![], Volatility::Volatile),
4052        });
4053        let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
4054        let plan = LogicalPlanBuilder::from(table_scan)
4055            .project(vec![col("a"), col("b")])?
4056            .filter(expr.gt(lit(0.1)))?
4057            .build()?;
4058
4059        assert_snapshot!(plan,
4060        @r"
4061        Filter: TestScalarUDF() > Float64(0.1)
4062          Projection: test.a, test.b
4063            TableScan: test
4064        ",
4065        );
4066        assert_optimized_plan_equal!(
4067            plan,
4068            @r"
4069        Projection: test.a, test.b
4070          Filter: TestScalarUDF() > Float64(0.1)
4071            TableScan: test
4072        "
4073        )
4074    }
4075
4076    #[test]
4077    fn test_push_down_volatile_mixed_table_scan() -> Result<()> {
4078        // SELECT test.a, test.b FROM test as t WHERE TestScalarUDF() > 0.1 and test.a > 5 and test.b > 10;
4079        let table_scan = test_table_scan()?;
4080        let fun = ScalarUDF::new_from_impl(TestScalarUDF {
4081            signature: Signature::exact(vec![], Volatility::Volatile),
4082        });
4083        let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
4084        let plan = LogicalPlanBuilder::from(table_scan)
4085            .project(vec![col("a"), col("b")])?
4086            .filter(
4087                expr.gt(lit(0.1))
4088                    .and(col("t.a").gt(lit(5)))
4089                    .and(col("t.b").gt(lit(10))),
4090            )?
4091            .build()?;
4092
4093        assert_snapshot!(plan,
4094        @r"
4095        Filter: TestScalarUDF() > Float64(0.1) AND t.a > Int32(5) AND t.b > Int32(10)
4096          Projection: test.a, test.b
4097            TableScan: test
4098        ",
4099        );
4100        assert_optimized_plan_equal!(
4101            plan,
4102            @r"
4103        Projection: test.a, test.b
4104          Filter: TestScalarUDF() > Float64(0.1)
4105            TableScan: test, full_filters=[t.a > Int32(5), t.b > Int32(10)]
4106        "
4107        )
4108    }
4109
4110    #[test]
4111    fn test_push_down_volatile_mixed_unsupported_table_scan() -> Result<()> {
4112        // SELECT test.a, test.b FROM test as t WHERE TestScalarUDF() > 0.1 and test.a > 5 and test.b > 10;
4113        let fun = ScalarUDF::new_from_impl(TestScalarUDF {
4114            signature: Signature::exact(vec![], Volatility::Volatile),
4115        });
4116        let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
4117        let plan = table_scan_with_pushdown_provider_builder(
4118            TableProviderFilterPushDown::Unsupported,
4119            vec![],
4120            None,
4121        )?
4122        .project(vec![col("a"), col("b")])?
4123        .filter(
4124            expr.gt(lit(0.1))
4125                .and(col("t.a").gt(lit(5)))
4126                .and(col("t.b").gt(lit(10))),
4127        )?
4128        .build()?;
4129
4130        assert_snapshot!(plan,
4131        @r"
4132        Filter: TestScalarUDF() > Float64(0.1) AND t.a > Int32(5) AND t.b > Int32(10)
4133          Projection: a, b
4134            TableScan: test
4135        ",
4136        );
4137        assert_optimized_plan_equal!(
4138            plan,
4139            @r"
4140        Projection: a, b
4141          Filter: t.a > Int32(5) AND t.b > Int32(10) AND TestScalarUDF() > Float64(0.1)
4142            TableScan: test
4143        "
4144        )
4145    }
4146
4147    #[test]
4148    fn test_push_down_filter_to_user_defined_node() -> Result<()> {
4149        // Define a custom user-defined logical node
4150        #[derive(Debug, Hash, Eq, PartialEq)]
4151        struct TestUserNode {
4152            schema: DFSchemaRef,
4153        }
4154
4155        impl PartialOrd for TestUserNode {
4156            fn partial_cmp(&self, _other: &Self) -> Option<Ordering> {
4157                None
4158            }
4159        }
4160
4161        impl TestUserNode {
4162            fn new() -> Self {
4163                let schema = Arc::new(
4164                    DFSchema::new_with_metadata(
4165                        vec![(None, Field::new("a", DataType::Int64, false).into())],
4166                        Default::default(),
4167                    )
4168                    .unwrap(),
4169                );
4170
4171                Self { schema }
4172            }
4173        }
4174
4175        impl UserDefinedLogicalNodeCore for TestUserNode {
4176            fn name(&self) -> &str {
4177                "test_node"
4178            }
4179
4180            fn inputs(&self) -> Vec<&LogicalPlan> {
4181                vec![]
4182            }
4183
4184            fn schema(&self) -> &DFSchemaRef {
4185                &self.schema
4186            }
4187
4188            fn expressions(&self) -> Vec<Expr> {
4189                vec![]
4190            }
4191
4192            fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
4193                write!(f, "TestUserNode")
4194            }
4195
4196            fn with_exprs_and_inputs(
4197                &self,
4198                exprs: Vec<Expr>,
4199                inputs: Vec<LogicalPlan>,
4200            ) -> Result<Self> {
4201                assert!(exprs.is_empty());
4202                assert!(inputs.is_empty());
4203                Ok(Self {
4204                    schema: Arc::clone(&self.schema),
4205                })
4206            }
4207        }
4208
4209        // Create a node and build a plan with a filter
4210        let node = LogicalPlan::Extension(Extension {
4211            node: Arc::new(TestUserNode::new()),
4212        });
4213
4214        let plan = LogicalPlanBuilder::from(node).filter(lit(false))?.build()?;
4215
4216        // Check the original plan format (not part of the test assertions)
4217        assert_snapshot!(plan,
4218        @r"
4219        Filter: Boolean(false)
4220          TestUserNode
4221        ",
4222        );
4223        // Check that the filter is pushed down to the user-defined node
4224        assert_optimized_plan_equal!(
4225            plan,
4226            @r"
4227        Filter: Boolean(false)
4228          TestUserNode
4229        "
4230        )
4231    }
4232
4233    #[test]
4234    fn filter_not_pushed_down_through_table_scan_with_fetch() -> Result<()> {
4235        let scan = test_table_scan()?;
4236        let scan_with_fetch = match scan {
4237            LogicalPlan::TableScan(scan) => LogicalPlan::TableScan(TableScan {
4238                fetch: Some(10),
4239                ..scan
4240            }),
4241            _ => unreachable!(),
4242        };
4243        let plan = LogicalPlanBuilder::from(scan_with_fetch)
4244            .filter(col("a").gt(lit(10i64)))?
4245            .build()?;
4246        // Filter must NOT be pushed into the table scan when it has a fetch (limit)
4247        assert_optimized_plan_equal!(
4248            plan,
4249            @r"
4250        Filter: test.a > Int64(10)
4251          TableScan: test, fetch=10
4252        "
4253        )
4254    }
4255
4256    #[test]
4257    fn filter_push_down_through_sort_without_fetch() -> Result<()> {
4258        let table_scan = test_table_scan()?;
4259        let plan = LogicalPlanBuilder::from(table_scan)
4260            .sort(vec![col("a").sort(true, true)])?
4261            .filter(col("a").gt(lit(10i64)))?
4262            .build()?;
4263        // Filter should be pushed below the sort
4264        assert_optimized_plan_equal!(
4265            plan,
4266            @r"
4267        Sort: test.a ASC NULLS FIRST
4268          TableScan: test, full_filters=[test.a > Int64(10)]
4269        "
4270        )
4271    }
4272
4273    #[test]
4274    fn filter_not_pushed_down_through_sort_with_fetch() -> Result<()> {
4275        let table_scan = test_table_scan()?;
4276        let plan = LogicalPlanBuilder::from(table_scan)
4277            .sort_with_limit(vec![col("a").sort(true, true)], Some(5))?
4278            .filter(col("a").gt(lit(10i64)))?
4279            .build()?;
4280        // Filter must NOT be pushed below the sort when it has a fetch (limit),
4281        // because the limit should apply before the filter.
4282        assert_optimized_plan_equal!(
4283            plan,
4284            @r"
4285        Filter: test.a > Int64(10)
4286          Sort: test.a ASC NULLS FIRST, fetch=5
4287            TableScan: test
4288        "
4289        )
4290    }
4291}