datafusion_optimizer/
push_down_filter.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`PushDownFilter`] applies filters as early as possible
19
20use std::collections::{HashMap, HashSet};
21use std::sync::Arc;
22
23use indexmap::IndexSet;
24use itertools::Itertools;
25
26use datafusion_common::tree_node::{
27    Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
28};
29use datafusion_common::{
30    internal_err, plan_err, qualified_name, Column, DFSchema, Result,
31};
32use datafusion_expr::expr::WindowFunction;
33use datafusion_expr::expr_rewriter::replace_col;
34use datafusion_expr::logical_plan::{Join, JoinType, LogicalPlan, TableScan, Union};
35use datafusion_expr::utils::{
36    conjunction, expr_to_columns, split_conjunction, split_conjunction_owned,
37};
38use datafusion_expr::{
39    and, or, BinaryExpr, Expr, Filter, Operator, Projection, TableProviderFilterPushDown,
40};
41
42use crate::optimizer::ApplyOrder;
43use crate::simplify_expressions::simplify_predicates;
44use crate::utils::{has_all_column_refs, is_restrict_null_predicate};
45use crate::{OptimizerConfig, OptimizerRule};
46
/// Optimizer rule for pushing (moving) filter expressions down in a plan so
/// they are applied as early as possible.
///
/// # Introduction
///
/// The goal of this rule is to improve query performance by eliminating
/// redundant work.
///
/// For example, given a plan that sorts all values where `a > 10`:
///
/// ```text
///  Filter (a > 10)
///    Sort (a, b)
/// ```
///
/// A better plan is to filter the data *before* the Sort, which sorts fewer
/// rows and therefore does less work overall:
///
/// ```text
///  Sort (a, b)
///    Filter (a > 10)  <-- Filter is moved before the sort
/// ```
///
/// However, it is not always possible to push filters down. For example, given
/// a plan that finds the top 3 values and then keeps only those that are
/// greater than 10, if the filter is pushed below the limit it would produce a
/// different result.
///
/// ```text
///  Filter (a > 10)   <-- can not move this Filter before the limit
///    Limit (fetch=3)
///      Sort (a, b)
/// ```
///
///
/// More formally, a filter-commutative operation is an operation `op` that
/// satisfies `filter(op(data)) = op(filter(data))`.
///
/// The filter-commutative property is plan and column-specific. A filter on `a`
/// can be pushed through an `Aggregate(group_by = [a], agg = [sum(b)])`.
/// However, a filter on `sum(b)` can not be pushed through the same aggregate.
///
/// # Handling Conjunctions
///
/// It is possible to push down only **part** of a filter expression if it is
/// connected with `AND`s (more formally, if it is a "conjunction").
///
/// For example, given the following plan:
///
/// ```text
/// Filter(a > 10 AND sum(b) < 5)
///   Aggregate(group_by = [a], agg = [sum(b)])
/// ```
///
/// The `a > 10` is commutative with the `Aggregate` but `sum(b) < 5` is not.
/// Therefore it is possible to only push part of the expression, resulting in:
///
/// ```text
/// Filter(sum(b) < 5)
///   Aggregate(group_by = [a], agg = [sum(b)])
///     Filter(a > 10)
/// ```
///
/// # Handling Column Aliases
///
/// This optimizer must sometimes rewrite filter expressions when they are
/// pushed, for example if there is a projection that aliases `a+1` to `"b"`:
///
/// ```text
/// Filter (b > 10)
///     Projection: [a+1 AS "b"]  <-- changes the name of `a+1` to `b`
/// ```
///
/// To apply the filter prior to the `Projection`, all references to `b` must be
/// rewritten to `a+1`:
///
/// ```text
/// Projection: [a+1 AS "b"]
///     Filter: (a + 1 > 10)  <--- changed from b to a + 1
/// ```
///
/// # Implementation Notes
///
/// This implementation performs a single pass through the plan, "pushing" down
/// filters. When it passes through a filter, it stores that filter, and when it
/// reaches a plan node that does not commute with that filter, it adds the
/// filter to that place. When it passes through a projection, it re-writes the
/// filter's expression taking into account that projection.
#[derive(Default, Debug)]
pub struct PushDownFilter {}
136
137/// For a given JOIN type, determine whether each input of the join is preserved
138/// for post-join (`WHERE` clause) filters.
139///
140/// It is only correct to push filters below a join for preserved inputs.
141///
142/// # Return Value
143/// A tuple of booleans - (left_preserved, right_preserved).
144///
145/// # "Preserved" input definition
146///
147/// We say a join side is preserved if the join returns all or a subset of the rows from
148/// the relevant side, such that each row of the output table directly maps to a row of
149/// the preserved input table. If a table is not preserved, it can provide extra null rows.
150/// That is, there may be rows in the output table that don't directly map to a row in the
151/// input table.
152///
153/// For example:
154///   - In an inner join, both sides are preserved, because each row of the output
155///     maps directly to a row from each side.
156///
157///   - In a left join, the left side is preserved (we can push predicates) but
158///     the right is not, because there may be rows in the output that don't
159///     directly map to a row in the right input (due to nulls filling where there
160///     is no match on the right).
161pub(crate) fn lr_is_preserved(join_type: JoinType) -> (bool, bool) {
162    match join_type {
163        JoinType::Inner => (true, true),
164        JoinType::Left => (true, false),
165        JoinType::Right => (false, true),
166        JoinType::Full => (false, false),
167        // No columns from the right side of the join can be referenced in output
168        // predicates for semi/anti joins, so whether we specify t/f doesn't matter.
169        JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => (true, false),
170        // No columns from the left side of the join can be referenced in output
171        // predicates for semi/anti joins, so whether we specify t/f doesn't matter.
172        JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => (false, true),
173    }
174}
175
176/// For a given JOIN type, determine whether each input of the join is preserved
177/// for the join condition (`ON` clause filters).
178///
179/// It is only correct to push filters below a join for preserved inputs.
180///
181/// # Return Value
182/// A tuple of booleans - (left_preserved, right_preserved).
183///
184/// See [`lr_is_preserved`] for a definition of "preserved".
185pub(crate) fn on_lr_is_preserved(join_type: JoinType) -> (bool, bool) {
186    match join_type {
187        JoinType::Inner => (true, true),
188        JoinType::Left => (false, true),
189        JoinType::Right => (true, false),
190        JoinType::Full => (false, false),
191        JoinType::LeftSemi | JoinType::RightSemi => (true, true),
192        JoinType::LeftAnti => (false, true),
193        JoinType::RightAnti => (true, false),
194        JoinType::LeftMark => (false, true),
195        JoinType::RightMark => (true, false),
196    }
197}
198
199/// Evaluates the columns referenced in the given expression to see if they refer
200/// only to the left or right columns
201#[derive(Debug)]
202struct ColumnChecker<'a> {
203    /// schema of left join input
204    left_schema: &'a DFSchema,
205    /// columns in left_schema, computed on demand
206    left_columns: Option<HashSet<Column>>,
207    /// schema of right join input
208    right_schema: &'a DFSchema,
209    /// columns in left_schema, computed on demand
210    right_columns: Option<HashSet<Column>>,
211}
212
213impl<'a> ColumnChecker<'a> {
214    fn new(left_schema: &'a DFSchema, right_schema: &'a DFSchema) -> Self {
215        Self {
216            left_schema,
217            left_columns: None,
218            right_schema,
219            right_columns: None,
220        }
221    }
222
223    /// Return true if the expression references only columns from the left side of the join
224    fn is_left_only(&mut self, predicate: &Expr) -> bool {
225        if self.left_columns.is_none() {
226            self.left_columns = Some(schema_columns(self.left_schema));
227        }
228        has_all_column_refs(predicate, self.left_columns.as_ref().unwrap())
229    }
230
231    /// Return true if the expression references only columns from the right side of the join
232    fn is_right_only(&mut self, predicate: &Expr) -> bool {
233        if self.right_columns.is_none() {
234            self.right_columns = Some(schema_columns(self.right_schema));
235        }
236        has_all_column_refs(predicate, self.right_columns.as_ref().unwrap())
237    }
238}
239
240/// Returns all columns in the schema
241fn schema_columns(schema: &DFSchema) -> HashSet<Column> {
242    schema
243        .iter()
244        .flat_map(|(qualifier, field)| {
245            [
246                Column::new(qualifier.cloned(), field.name()),
247                // we need to push down filter using unqualified column as well
248                Column::new_unqualified(field.name()),
249            ]
250        })
251        .collect::<HashSet<_>>()
252}
253
254/// Determine whether the predicate can evaluate as the join conditions
255fn can_evaluate_as_join_condition(predicate: &Expr) -> Result<bool> {
256    let mut is_evaluate = true;
257    predicate.apply(|expr| match expr {
258        Expr::Column(_)
259        | Expr::Literal(_, _)
260        | Expr::Placeholder(_)
261        | Expr::ScalarVariable(_, _) => Ok(TreeNodeRecursion::Jump),
262        Expr::Exists { .. }
263        | Expr::InSubquery(_)
264        | Expr::ScalarSubquery(_)
265        | Expr::OuterReferenceColumn(_, _)
266        | Expr::Unnest(_) => {
267            is_evaluate = false;
268            Ok(TreeNodeRecursion::Stop)
269        }
270        Expr::Alias(_)
271        | Expr::BinaryExpr(_)
272        | Expr::Like(_)
273        | Expr::SimilarTo(_)
274        | Expr::Not(_)
275        | Expr::IsNotNull(_)
276        | Expr::IsNull(_)
277        | Expr::IsTrue(_)
278        | Expr::IsFalse(_)
279        | Expr::IsUnknown(_)
280        | Expr::IsNotTrue(_)
281        | Expr::IsNotFalse(_)
282        | Expr::IsNotUnknown(_)
283        | Expr::Negative(_)
284        | Expr::Between(_)
285        | Expr::Case(_)
286        | Expr::Cast(_)
287        | Expr::TryCast(_)
288        | Expr::InList { .. }
289        | Expr::ScalarFunction(_) => Ok(TreeNodeRecursion::Continue),
290        // TODO: remove the next line after `Expr::Wildcard` is removed
291        #[expect(deprecated)]
292        Expr::AggregateFunction(_)
293        | Expr::WindowFunction(_)
294        | Expr::Wildcard { .. }
295        | Expr::GroupingSet(_) => internal_err!("Unsupported predicate type"),
296    })?;
297    Ok(is_evaluate)
298}
299
300/// examine OR clause to see if any useful clauses can be extracted and push down.
301/// extract at least one qual from each sub clauses of OR clause, then form the quals
302/// to new OR clause as predicate.
303///
304/// # Example
305/// ```text
306/// Filter: (a = c and a < 20) or (b = d and b > 10)
307///     join/crossjoin:
308///          TableScan: projection=[a, b]
309///          TableScan: projection=[c, d]
310/// ```
311///
312/// is optimized to
313///
314/// ```text
315/// Filter: (a = c and a < 20) or (b = d and b > 10)
316///     join/crossjoin:
317///          Filter: (a < 20) or (b > 10)
318///              TableScan: projection=[a, b]
319///          TableScan: projection=[c, d]
320/// ```
321///
322/// In general, predicates of this form:
323///
324/// ```sql
325/// (A AND B) OR (C AND D)
326/// ```
327///
328/// will be transformed to one of:
329///
330/// * `((A AND B) OR (C AND D)) AND (A OR C)`
331/// * `((A AND B) OR (C AND D)) AND ((A AND B) OR C)`
332/// * do nothing.
333fn extract_or_clauses_for_join<'a>(
334    filters: &'a [Expr],
335    schema: &'a DFSchema,
336) -> impl Iterator<Item = Expr> + 'a {
337    let schema_columns = schema_columns(schema);
338
339    // new formed OR clauses and their column references
340    filters.iter().filter_map(move |expr| {
341        if let Expr::BinaryExpr(BinaryExpr {
342            left,
343            op: Operator::Or,
344            right,
345        }) = expr
346        {
347            let left_expr = extract_or_clause(left.as_ref(), &schema_columns);
348            let right_expr = extract_or_clause(right.as_ref(), &schema_columns);
349
350            // If nothing can be extracted from any sub clauses, do nothing for this OR clause.
351            if let (Some(left_expr), Some(right_expr)) = (left_expr, right_expr) {
352                return Some(or(left_expr, right_expr));
353            }
354        }
355        None
356    })
357}
358
359/// extract qual from OR sub-clause.
360///
361/// A qual is extracted if it only contains set of column references in schema_columns.
362///
363/// For AND clause, we extract from both sub-clauses, then make new AND clause by extracted
364/// clauses if both extracted; Otherwise, use the extracted clause from any sub-clauses or None.
365///
366/// For OR clause, we extract from both sub-clauses, then make new OR clause by extracted clauses if both extracted;
367/// Otherwise, return None.
368///
369/// For other clause, apply the rule above to extract clause.
370fn extract_or_clause(expr: &Expr, schema_columns: &HashSet<Column>) -> Option<Expr> {
371    let mut predicate = None;
372
373    match expr {
374        Expr::BinaryExpr(BinaryExpr {
375            left: l_expr,
376            op: Operator::Or,
377            right: r_expr,
378        }) => {
379            let l_expr = extract_or_clause(l_expr, schema_columns);
380            let r_expr = extract_or_clause(r_expr, schema_columns);
381
382            if let (Some(l_expr), Some(r_expr)) = (l_expr, r_expr) {
383                predicate = Some(or(l_expr, r_expr));
384            }
385        }
386        Expr::BinaryExpr(BinaryExpr {
387            left: l_expr,
388            op: Operator::And,
389            right: r_expr,
390        }) => {
391            let l_expr = extract_or_clause(l_expr, schema_columns);
392            let r_expr = extract_or_clause(r_expr, schema_columns);
393
394            match (l_expr, r_expr) {
395                (Some(l_expr), Some(r_expr)) => {
396                    predicate = Some(and(l_expr, r_expr));
397                }
398                (Some(l_expr), None) => {
399                    predicate = Some(l_expr);
400                }
401                (None, Some(r_expr)) => {
402                    predicate = Some(r_expr);
403                }
404                (None, None) => {
405                    predicate = None;
406                }
407            }
408        }
409        _ => {
410            if has_all_column_refs(expr, schema_columns) {
411                predicate = Some(expr.clone());
412            }
413        }
414    }
415
416    predicate
417}
418
419/// push down join/cross-join
420fn push_down_all_join(
421    predicates: Vec<Expr>,
422    inferred_join_predicates: Vec<Expr>,
423    mut join: Join,
424    on_filter: Vec<Expr>,
425) -> Result<Transformed<LogicalPlan>> {
426    let is_inner_join = join.join_type == JoinType::Inner;
427    // Get pushable predicates from current optimizer state
428    let (left_preserved, right_preserved) = lr_is_preserved(join.join_type);
429
430    // The predicates can be divided to three categories:
431    // 1) can push through join to its children(left or right)
432    // 2) can be converted to join conditions if the join type is Inner
433    // 3) should be kept as filter conditions
434    let left_schema = join.left.schema();
435    let right_schema = join.right.schema();
436    let mut left_push = vec![];
437    let mut right_push = vec![];
438    let mut keep_predicates = vec![];
439    let mut join_conditions = vec![];
440    let mut checker = ColumnChecker::new(left_schema, right_schema);
441    for predicate in predicates {
442        if left_preserved && checker.is_left_only(&predicate) {
443            left_push.push(predicate);
444        } else if right_preserved && checker.is_right_only(&predicate) {
445            right_push.push(predicate);
446        } else if is_inner_join && can_evaluate_as_join_condition(&predicate)? {
447            // Here we do not differ it is eq or non-eq predicate, ExtractEquijoinPredicate will extract the eq predicate
448            // and convert to the join on condition
449            join_conditions.push(predicate);
450        } else {
451            keep_predicates.push(predicate);
452        }
453    }
454
455    // For infer predicates, if they can not push through join, just drop them
456    for predicate in inferred_join_predicates {
457        if left_preserved && checker.is_left_only(&predicate) {
458            left_push.push(predicate);
459        } else if right_preserved && checker.is_right_only(&predicate) {
460            right_push.push(predicate);
461        }
462    }
463
464    let mut on_filter_join_conditions = vec![];
465    let (on_left_preserved, on_right_preserved) = on_lr_is_preserved(join.join_type);
466
467    if !on_filter.is_empty() {
468        for on in on_filter {
469            if on_left_preserved && checker.is_left_only(&on) {
470                left_push.push(on)
471            } else if on_right_preserved && checker.is_right_only(&on) {
472                right_push.push(on)
473            } else {
474                on_filter_join_conditions.push(on)
475            }
476        }
477    }
478
479    // Extract from OR clause, generate new predicates for both side of join if possible.
480    // We only track the unpushable predicates above.
481    if left_preserved {
482        left_push.extend(extract_or_clauses_for_join(&keep_predicates, left_schema));
483        left_push.extend(extract_or_clauses_for_join(&join_conditions, left_schema));
484    }
485    if right_preserved {
486        right_push.extend(extract_or_clauses_for_join(&keep_predicates, right_schema));
487        right_push.extend(extract_or_clauses_for_join(&join_conditions, right_schema));
488    }
489
490    // For predicates from join filter, we should check with if a join side is preserved
491    // in term of join filtering.
492    if on_left_preserved {
493        left_push.extend(extract_or_clauses_for_join(
494            &on_filter_join_conditions,
495            left_schema,
496        ));
497    }
498    if on_right_preserved {
499        right_push.extend(extract_or_clauses_for_join(
500            &on_filter_join_conditions,
501            right_schema,
502        ));
503    }
504
505    if let Some(predicate) = conjunction(left_push) {
506        join.left = Arc::new(LogicalPlan::Filter(Filter::try_new(predicate, join.left)?));
507    }
508    if let Some(predicate) = conjunction(right_push) {
509        join.right =
510            Arc::new(LogicalPlan::Filter(Filter::try_new(predicate, join.right)?));
511    }
512
513    // Add any new join conditions as the non join predicates
514    join_conditions.extend(on_filter_join_conditions);
515    join.filter = conjunction(join_conditions);
516
517    // wrap the join on the filter whose predicates must be kept, if any
518    let plan = LogicalPlan::Join(join);
519    let plan = if let Some(predicate) = conjunction(keep_predicates) {
520        LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(plan))?)
521    } else {
522        plan
523    };
524    Ok(Transformed::yes(plan))
525}
526
527fn push_down_join(
528    join: Join,
529    parent_predicate: Option<&Expr>,
530) -> Result<Transformed<LogicalPlan>> {
531    // Split the parent predicate into individual conjunctive parts.
532    let predicates = parent_predicate
533        .map_or_else(Vec::new, |pred| split_conjunction_owned(pred.clone()));
534
535    // Extract conjunctions from the JOIN's ON filter, if present.
536    let on_filters = join
537        .filter
538        .as_ref()
539        .map_or_else(Vec::new, |filter| split_conjunction_owned(filter.clone()));
540
541    // Are there any new join predicates that can be inferred from the filter expressions?
542    let inferred_join_predicates =
543        infer_join_predicates(&join, &predicates, &on_filters)?;
544
545    if on_filters.is_empty()
546        && predicates.is_empty()
547        && inferred_join_predicates.is_empty()
548    {
549        return Ok(Transformed::no(LogicalPlan::Join(join)));
550    }
551
552    push_down_all_join(predicates, inferred_join_predicates, join, on_filters)
553}
554
555/// Extracts any equi-join join predicates from the given filter expressions.
556///
557/// Parameters
558/// * `join` the join in question
559///
560/// * `predicates` the pushed down filter expression
561///
562/// * `on_filters` filters from the join ON clause that have not already been
563///   identified as join predicates
564///
565fn infer_join_predicates(
566    join: &Join,
567    predicates: &[Expr],
568    on_filters: &[Expr],
569) -> Result<Vec<Expr>> {
570    // Only allow both side key is column.
571    let join_col_keys = join
572        .on
573        .iter()
574        .filter_map(|(l, r)| {
575            let left_col = l.try_as_col()?;
576            let right_col = r.try_as_col()?;
577            Some((left_col, right_col))
578        })
579        .collect::<Vec<_>>();
580
581    let join_type = join.join_type;
582
583    let mut inferred_predicates = InferredPredicates::new(join_type);
584
585    infer_join_predicates_from_predicates(
586        &join_col_keys,
587        predicates,
588        &mut inferred_predicates,
589    )?;
590
591    infer_join_predicates_from_on_filters(
592        &join_col_keys,
593        join_type,
594        on_filters,
595        &mut inferred_predicates,
596    )?;
597
598    Ok(inferred_predicates.predicates)
599}
600
601/// Inferred predicates collector.
602/// When the JoinType is not Inner, we need to detect whether the inferred predicate can strictly
603/// filter out NULL, otherwise ignore it. e.g.
604/// ```text
605/// SELECT * FROM t1 LEFT JOIN t2 ON t1.c0 = t2.c0 WHERE t2.c0 IS NULL;
606/// ```
607/// We cannot infer the predicate `t1.c0 IS NULL`, otherwise the predicate will be pushed down to
608/// the left side, resulting in the wrong result.
609struct InferredPredicates {
610    predicates: Vec<Expr>,
611    is_inner_join: bool,
612}
613
614impl InferredPredicates {
615    fn new(join_type: JoinType) -> Self {
616        Self {
617            predicates: vec![],
618            is_inner_join: matches!(join_type, JoinType::Inner),
619        }
620    }
621
622    fn try_build_predicate(
623        &mut self,
624        predicate: Expr,
625        replace_map: &HashMap<&Column, &Column>,
626    ) -> Result<()> {
627        if self.is_inner_join
628            || matches!(
629                is_restrict_null_predicate(
630                    predicate.clone(),
631                    replace_map.keys().cloned()
632                ),
633                Ok(true)
634            )
635        {
636            self.predicates.push(replace_col(predicate, replace_map)?);
637        }
638
639        Ok(())
640    }
641}
642
643/// Infer predicates from the pushed down predicates.
644///
645/// Parameters
646/// * `join_col_keys` column pairs from the join ON clause
647///
648/// * `predicates` the pushed down predicates
649///
650/// * `inferred_predicates` the inferred results
651///
652fn infer_join_predicates_from_predicates(
653    join_col_keys: &[(&Column, &Column)],
654    predicates: &[Expr],
655    inferred_predicates: &mut InferredPredicates,
656) -> Result<()> {
657    infer_join_predicates_impl::<true, true>(
658        join_col_keys,
659        predicates,
660        inferred_predicates,
661    )
662}
663
664/// Infer predicates from the join filter.
665///
666/// Parameters
667/// * `join_col_keys` column pairs from the join ON clause
668///
669/// * `join_type` the JoinType of Join
670///
671/// * `on_filters` filters from the join ON clause that have not already been
672///   identified as join predicates
673///
674/// * `inferred_predicates` the inferred results
675///
676fn infer_join_predicates_from_on_filters(
677    join_col_keys: &[(&Column, &Column)],
678    join_type: JoinType,
679    on_filters: &[Expr],
680    inferred_predicates: &mut InferredPredicates,
681) -> Result<()> {
682    match join_type {
683        JoinType::Full | JoinType::LeftAnti | JoinType::RightAnti => Ok(()),
684        JoinType::Inner => infer_join_predicates_impl::<true, true>(
685            join_col_keys,
686            on_filters,
687            inferred_predicates,
688        ),
689        JoinType::Left | JoinType::LeftSemi | JoinType::LeftMark => {
690            infer_join_predicates_impl::<true, false>(
691                join_col_keys,
692                on_filters,
693                inferred_predicates,
694            )
695        }
696        JoinType::Right | JoinType::RightSemi | JoinType::RightMark => {
697            infer_join_predicates_impl::<false, true>(
698                join_col_keys,
699                on_filters,
700                inferred_predicates,
701            )
702        }
703    }
704}
705
706/// Infer predicates from the given predicates.
707///
708/// Parameters
709/// * `join_col_keys` column pairs from the join ON clause
710///
711/// * `input_predicates` the given predicates. It can be the pushed down predicates,
712///   or it can be the filters of the Join
713///
714/// * `inferred_predicates` the inferred results
715///
716/// * `ENABLE_LEFT_TO_RIGHT` indicates that the right table related predicate can
717///   be inferred from the left table related predicate
718///
719/// * `ENABLE_RIGHT_TO_LEFT` indicates that the left table related predicate can
720///   be inferred from the right table related predicate
721///
722fn infer_join_predicates_impl<
723    const ENABLE_LEFT_TO_RIGHT: bool,
724    const ENABLE_RIGHT_TO_LEFT: bool,
725>(
726    join_col_keys: &[(&Column, &Column)],
727    input_predicates: &[Expr],
728    inferred_predicates: &mut InferredPredicates,
729) -> Result<()> {
730    for predicate in input_predicates {
731        let mut join_cols_to_replace = HashMap::new();
732
733        for &col in &predicate.column_refs() {
734            for (l, r) in join_col_keys.iter() {
735                if ENABLE_LEFT_TO_RIGHT && col == *l {
736                    join_cols_to_replace.insert(col, *r);
737                    break;
738                }
739                if ENABLE_RIGHT_TO_LEFT && col == *r {
740                    join_cols_to_replace.insert(col, *l);
741                    break;
742                }
743            }
744        }
745        if join_cols_to_replace.is_empty() {
746            continue;
747        }
748
749        inferred_predicates
750            .try_build_predicate(predicate.clone(), &join_cols_to_replace)?;
751    }
752    Ok(())
753}
754
755impl OptimizerRule for PushDownFilter {
    /// Returns the name of this rule, `"push_down_filter"`.
    fn name(&self) -> &str {
        "push_down_filter"
    }
759
    /// Declares that this rule is applied in [`ApplyOrder::TopDown`] order.
    fn apply_order(&self) -> Option<ApplyOrder> {
        Some(ApplyOrder::TopDown)
    }
763
    /// Always returns `true`: this rule provides a `rewrite` implementation.
    fn supports_rewrite(&self) -> bool {
        true
    }
767
768    fn rewrite(
769        &self,
770        plan: LogicalPlan,
771        _config: &dyn OptimizerConfig,
772    ) -> Result<Transformed<LogicalPlan>> {
773        if let LogicalPlan::Join(join) = plan {
774            return push_down_join(join, None);
775        };
776
777        let plan_schema = Arc::clone(plan.schema());
778
779        let LogicalPlan::Filter(mut filter) = plan else {
780            return Ok(Transformed::no(plan));
781        };
782
783        let predicate = split_conjunction_owned(filter.predicate.clone());
784        let old_predicate_len = predicate.len();
785        let new_predicates = simplify_predicates(predicate)?;
786        if old_predicate_len != new_predicates.len() {
787            let Some(new_predicate) = conjunction(new_predicates) else {
788                // new_predicates is empty - remove the filter entirely
789                // Return the child plan without the filter
790                return Ok(Transformed::yes(Arc::unwrap_or_clone(filter.input)));
791            };
792            filter.predicate = new_predicate;
793        }
794
795        match Arc::unwrap_or_clone(filter.input) {
796            LogicalPlan::Filter(child_filter) => {
797                let parents_predicates = split_conjunction_owned(filter.predicate);
798
799                // remove duplicated filters
800                let child_predicates = split_conjunction_owned(child_filter.predicate);
801                let new_predicates = parents_predicates
802                    .into_iter()
803                    .chain(child_predicates)
804                    // use IndexSet to remove dupes while preserving predicate order
805                    .collect::<IndexSet<_>>()
806                    .into_iter()
807                    .collect::<Vec<_>>();
808
809                let Some(new_predicate) = conjunction(new_predicates) else {
810                    return plan_err!("at least one expression exists");
811                };
812                let new_filter = LogicalPlan::Filter(Filter::try_new(
813                    new_predicate,
814                    child_filter.input,
815                )?);
816                #[allow(clippy::used_underscore_binding)]
817                self.rewrite(new_filter, _config)
818            }
819            LogicalPlan::Repartition(repartition) => {
820                let new_filter =
821                    Filter::try_new(filter.predicate, Arc::clone(&repartition.input))
822                        .map(LogicalPlan::Filter)?;
823                insert_below(LogicalPlan::Repartition(repartition), new_filter)
824            }
825            LogicalPlan::Distinct(distinct) => {
826                let new_filter =
827                    Filter::try_new(filter.predicate, Arc::clone(distinct.input()))
828                        .map(LogicalPlan::Filter)?;
829                insert_below(LogicalPlan::Distinct(distinct), new_filter)
830            }
831            LogicalPlan::Sort(sort) => {
832                let new_filter =
833                    Filter::try_new(filter.predicate, Arc::clone(&sort.input))
834                        .map(LogicalPlan::Filter)?;
835                insert_below(LogicalPlan::Sort(sort), new_filter)
836            }
837            LogicalPlan::SubqueryAlias(subquery_alias) => {
838                let mut replace_map = HashMap::new();
839                for (i, (qualifier, field)) in
840                    subquery_alias.input.schema().iter().enumerate()
841                {
842                    let (sub_qualifier, sub_field) =
843                        subquery_alias.schema.qualified_field(i);
844                    replace_map.insert(
845                        qualified_name(sub_qualifier, sub_field.name()),
846                        Expr::Column(Column::new(qualifier.cloned(), field.name())),
847                    );
848                }
849                let new_predicate = replace_cols_by_name(filter.predicate, &replace_map)?;
850
851                let new_filter = LogicalPlan::Filter(Filter::try_new(
852                    new_predicate,
853                    Arc::clone(&subquery_alias.input),
854                )?);
855                insert_below(LogicalPlan::SubqueryAlias(subquery_alias), new_filter)
856            }
857            LogicalPlan::Projection(projection) => {
858                let predicates = split_conjunction_owned(filter.predicate.clone());
859                let (new_projection, keep_predicate) =
860                    rewrite_projection(predicates, projection)?;
861                if new_projection.transformed {
862                    match keep_predicate {
863                        None => Ok(new_projection),
864                        Some(keep_predicate) => new_projection.map_data(|child_plan| {
865                            Filter::try_new(keep_predicate, Arc::new(child_plan))
866                                .map(LogicalPlan::Filter)
867                        }),
868                    }
869                } else {
870                    filter.input = Arc::new(new_projection.data);
871                    Ok(Transformed::no(LogicalPlan::Filter(filter)))
872                }
873            }
874            LogicalPlan::Unnest(mut unnest) => {
875                let predicates = split_conjunction_owned(filter.predicate.clone());
876                let mut non_unnest_predicates = vec![];
877                let mut unnest_predicates = vec![];
878                for predicate in predicates {
879                    // collect all the Expr::Column in predicate recursively
880                    let mut accum: HashSet<Column> = HashSet::new();
881                    expr_to_columns(&predicate, &mut accum)?;
882
883                    if unnest.list_type_columns.iter().any(|(_, unnest_list)| {
884                        accum.contains(&unnest_list.output_column)
885                    }) {
886                        unnest_predicates.push(predicate);
887                    } else {
888                        non_unnest_predicates.push(predicate);
889                    }
890                }
891
892                // Unnest predicates should not be pushed down.
893                // If no non-unnest predicates exist, early return
894                if non_unnest_predicates.is_empty() {
895                    filter.input = Arc::new(LogicalPlan::Unnest(unnest));
896                    return Ok(Transformed::no(LogicalPlan::Filter(filter)));
897                }
898
899                // Push down non-unnest filter predicate
900                // Unnest
901                //   Unnest Input (Projection)
902                // -> rewritten to
903                // Unnest
904                //   Filter
905                //     Unnest Input (Projection)
906
907                let unnest_input = std::mem::take(&mut unnest.input);
908
909                let filter_with_unnest_input = LogicalPlan::Filter(Filter::try_new(
910                    conjunction(non_unnest_predicates).unwrap(), // Safe to unwrap since non_unnest_predicates is not empty.
911                    unnest_input,
912                )?);
913
914                // Directly assign new filter plan as the new unnest's input.
915                // The new filter plan will go through another rewrite pass since the rule itself
916                // is applied recursively to all the child from top to down
917                let unnest_plan =
918                    insert_below(LogicalPlan::Unnest(unnest), filter_with_unnest_input)?;
919
920                match conjunction(unnest_predicates) {
921                    None => Ok(unnest_plan),
922                    Some(predicate) => Ok(Transformed::yes(LogicalPlan::Filter(
923                        Filter::try_new(predicate, Arc::new(unnest_plan.data))?,
924                    ))),
925                }
926            }
927            LogicalPlan::Union(ref union) => {
928                let mut inputs = Vec::with_capacity(union.inputs.len());
929                for input in &union.inputs {
930                    let mut replace_map = HashMap::new();
931                    for (i, (qualifier, field)) in input.schema().iter().enumerate() {
932                        let (union_qualifier, union_field) =
933                            union.schema.qualified_field(i);
934                        replace_map.insert(
935                            qualified_name(union_qualifier, union_field.name()),
936                            Expr::Column(Column::new(qualifier.cloned(), field.name())),
937                        );
938                    }
939
940                    let push_predicate =
941                        replace_cols_by_name(filter.predicate.clone(), &replace_map)?;
942                    inputs.push(Arc::new(LogicalPlan::Filter(Filter::try_new(
943                        push_predicate,
944                        Arc::clone(input),
945                    )?)))
946                }
947                Ok(Transformed::yes(LogicalPlan::Union(Union {
948                    inputs,
949                    schema: Arc::clone(&plan_schema),
950                })))
951            }
952            LogicalPlan::Aggregate(agg) => {
953                // We can push down Predicate which in groupby_expr.
954                let group_expr_columns = agg
955                    .group_expr
956                    .iter()
957                    .map(|e| Ok(Column::from_qualified_name(e.schema_name().to_string())))
958                    .collect::<Result<HashSet<_>>>()?;
959
960                let predicates = split_conjunction_owned(filter.predicate);
961
962                let mut keep_predicates = vec![];
963                let mut push_predicates = vec![];
964                for expr in predicates {
965                    let cols = expr.column_refs();
966                    if cols.iter().all(|c| group_expr_columns.contains(c)) {
967                        push_predicates.push(expr);
968                    } else {
969                        keep_predicates.push(expr);
970                    }
971                }
972
973                // As for plan Filter: Column(a+b) > 0 -- Agg: groupby:[Column(a)+Column(b)]
974                // After push, we need to replace `a+b` with Column(a)+Column(b)
975                // So we need create a replace_map, add {`a+b` --> Expr(Column(a)+Column(b))}
976                let mut replace_map = HashMap::new();
977                for expr in &agg.group_expr {
978                    replace_map.insert(expr.schema_name().to_string(), expr.clone());
979                }
980                let replaced_push_predicates = push_predicates
981                    .into_iter()
982                    .map(|expr| replace_cols_by_name(expr, &replace_map))
983                    .collect::<Result<Vec<_>>>()?;
984
985                let agg_input = Arc::clone(&agg.input);
986                Transformed::yes(LogicalPlan::Aggregate(agg))
987                    .transform_data(|new_plan| {
988                        // If we have a filter to push, we push it down to the input of the aggregate
989                        if let Some(predicate) = conjunction(replaced_push_predicates) {
990                            let new_filter = make_filter(predicate, agg_input)?;
991                            insert_below(new_plan, new_filter)
992                        } else {
993                            Ok(Transformed::no(new_plan))
994                        }
995                    })?
996                    .map_data(|child_plan| {
997                        // if there are any remaining predicates we can't push, add them
998                        // back as a filter
999                        if let Some(predicate) = conjunction(keep_predicates) {
1000                            make_filter(predicate, Arc::new(child_plan))
1001                        } else {
1002                            Ok(child_plan)
1003                        }
1004                    })
1005            }
1006            // Tries to push filters based on the partition key(s) of the window function(s) used.
1007            // Example:
1008            //   Before:
1009            //     Filter: (a > 1) and (b > 1) and (c > 1)
1010            //      Window: func() PARTITION BY [a] ...
1011            //   ---
1012            //   After:
1013            //     Filter: (b > 1) and (c > 1)
1014            //      Window: func() PARTITION BY [a] ...
1015            //        Filter: (a > 1)
1016            LogicalPlan::Window(window) => {
1017                // Retrieve the set of potential partition keys where we can push filters by.
1018                // Unlike aggregations, where there is only one statement per SELECT, there can be
1019                // multiple window functions, each with potentially different partition keys.
1020                // Therefore, we need to ensure that any potential partition key returned is used in
1021                // ALL window functions. Otherwise, filters cannot be pushed by through that column.
1022                let extract_partition_keys = |func: &WindowFunction| {
1023                    func.params
1024                        .partition_by
1025                        .iter()
1026                        .map(|c| Column::from_qualified_name(c.schema_name().to_string()))
1027                        .collect::<HashSet<_>>()
1028                };
1029                let potential_partition_keys = window
1030                    .window_expr
1031                    .iter()
1032                    .map(|e| {
1033                        match e {
1034                            Expr::WindowFunction(window_func) => {
1035                                extract_partition_keys(window_func)
1036                            }
1037                            Expr::Alias(alias) => {
1038                                if let Expr::WindowFunction(window_func) =
1039                                    alias.expr.as_ref()
1040                                {
1041                                    extract_partition_keys(window_func)
1042                                } else {
1043                                    // window functions expressions are only Expr::WindowFunction
1044                                    unreachable!()
1045                                }
1046                            }
1047                            _ => {
1048                                // window functions expressions are only Expr::WindowFunction
1049                                unreachable!()
1050                            }
1051                        }
1052                    })
1053                    // performs the set intersection of the partition keys of all window functions,
1054                    // returning only the common ones
1055                    .reduce(|a, b| &a & &b)
1056                    .unwrap_or_default();
1057
1058                let predicates = split_conjunction_owned(filter.predicate);
1059                let mut keep_predicates = vec![];
1060                let mut push_predicates = vec![];
1061                for expr in predicates {
1062                    let cols = expr.column_refs();
1063                    if cols.iter().all(|c| potential_partition_keys.contains(c)) {
1064                        push_predicates.push(expr);
1065                    } else {
1066                        keep_predicates.push(expr);
1067                    }
1068                }
1069
1070                // Unlike with aggregations, there are no cases where we have to replace, e.g.,
1071                // `a+b` with Column(a)+Column(b). This is because partition expressions are not
1072                // available as standalone columns to the user. For example, while an aggregation on
1073                // `a+b` becomes Column(a + b), in a window partition it becomes
1074                // `func() PARTITION BY [a + b] ...`. Thus, filters on expressions always remain in
1075                // place, so we can use `push_predicates` directly. This is consistent with other
1076                // optimizers, such as the one used by Postgres.
1077
1078                let window_input = Arc::clone(&window.input);
1079                Transformed::yes(LogicalPlan::Window(window))
1080                    .transform_data(|new_plan| {
1081                        // If we have a filter to push, we push it down to the input of the window
1082                        if let Some(predicate) = conjunction(push_predicates) {
1083                            let new_filter = make_filter(predicate, window_input)?;
1084                            insert_below(new_plan, new_filter)
1085                        } else {
1086                            Ok(Transformed::no(new_plan))
1087                        }
1088                    })?
1089                    .map_data(|child_plan| {
1090                        // if there are any remaining predicates we can't push, add them
1091                        // back as a filter
1092                        if let Some(predicate) = conjunction(keep_predicates) {
1093                            make_filter(predicate, Arc::new(child_plan))
1094                        } else {
1095                            Ok(child_plan)
1096                        }
1097                    })
1098            }
1099            LogicalPlan::Join(join) => push_down_join(join, Some(&filter.predicate)),
1100            LogicalPlan::TableScan(scan) => {
1101                let filter_predicates = split_conjunction(&filter.predicate);
1102
1103                let (volatile_filters, non_volatile_filters): (Vec<&Expr>, Vec<&Expr>) =
1104                    filter_predicates
1105                        .into_iter()
1106                        .partition(|pred| pred.is_volatile());
1107
1108                // Check which non-volatile filters are supported by source
1109                let supported_filters = scan
1110                    .source
1111                    .supports_filters_pushdown(non_volatile_filters.as_slice())?;
1112                if non_volatile_filters.len() != supported_filters.len() {
1113                    return internal_err!(
1114                        "Vec returned length: {} from supports_filters_pushdown is not the same size as the filters passed, which length is: {}",
1115                        supported_filters.len(),
1116                        non_volatile_filters.len());
1117                }
1118
1119                // Compose scan filters from non-volatile filters of `Exact` or `Inexact` pushdown type
1120                let zip = non_volatile_filters.into_iter().zip(supported_filters);
1121
1122                let new_scan_filters = zip
1123                    .clone()
1124                    .filter(|(_, res)| res != &TableProviderFilterPushDown::Unsupported)
1125                    .map(|(pred, _)| pred);
1126
1127                // Add new scan filters
1128                let new_scan_filters: Vec<Expr> = scan
1129                    .filters
1130                    .iter()
1131                    .chain(new_scan_filters)
1132                    .unique()
1133                    .cloned()
1134                    .collect();
1135
1136                // Compose predicates to be of `Unsupported` or `Inexact` pushdown type, and also include volatile filters
1137                let new_predicate: Vec<Expr> = zip
1138                    .filter(|(_, res)| res != &TableProviderFilterPushDown::Exact)
1139                    .map(|(pred, _)| pred)
1140                    .chain(volatile_filters)
1141                    .cloned()
1142                    .collect();
1143
1144                let new_scan = LogicalPlan::TableScan(TableScan {
1145                    filters: new_scan_filters,
1146                    ..scan
1147                });
1148
1149                Transformed::yes(new_scan).transform_data(|new_scan| {
1150                    if let Some(predicate) = conjunction(new_predicate) {
1151                        make_filter(predicate, Arc::new(new_scan)).map(Transformed::yes)
1152                    } else {
1153                        Ok(Transformed::no(new_scan))
1154                    }
1155                })
1156            }
1157            LogicalPlan::Extension(extension_plan) => {
1158                // This check prevents the Filter from being removed when the extension node has no children,
1159                // so we return the original Filter unchanged.
1160                if extension_plan.node.inputs().is_empty() {
1161                    filter.input = Arc::new(LogicalPlan::Extension(extension_plan));
1162                    return Ok(Transformed::no(LogicalPlan::Filter(filter)));
1163                }
1164                let prevent_cols =
1165                    extension_plan.node.prevent_predicate_push_down_columns();
1166
1167                // determine if we can push any predicates down past the extension node
1168
1169                // each element is true for push, false to keep
1170                let predicate_push_or_keep = split_conjunction(&filter.predicate)
1171                    .iter()
1172                    .map(|expr| {
1173                        let cols = expr.column_refs();
1174                        if cols.iter().any(|c| prevent_cols.contains(&c.name)) {
1175                            Ok(false) // No push (keep)
1176                        } else {
1177                            Ok(true) // push
1178                        }
1179                    })
1180                    .collect::<Result<Vec<_>>>()?;
1181
1182                // all predicates are kept, no changes needed
1183                if predicate_push_or_keep.iter().all(|&x| !x) {
1184                    filter.input = Arc::new(LogicalPlan::Extension(extension_plan));
1185                    return Ok(Transformed::no(LogicalPlan::Filter(filter)));
1186                }
1187
1188                // going to push some predicates down, so split the predicates
1189                let mut keep_predicates = vec![];
1190                let mut push_predicates = vec![];
1191                for (push, expr) in predicate_push_or_keep
1192                    .into_iter()
1193                    .zip(split_conjunction_owned(filter.predicate).into_iter())
1194                {
1195                    if !push {
1196                        keep_predicates.push(expr);
1197                    } else {
1198                        push_predicates.push(expr);
1199                    }
1200                }
1201
1202                let new_children = match conjunction(push_predicates) {
1203                    Some(predicate) => extension_plan
1204                        .node
1205                        .inputs()
1206                        .into_iter()
1207                        .map(|child| {
1208                            Ok(LogicalPlan::Filter(Filter::try_new(
1209                                predicate.clone(),
1210                                Arc::new(child.clone()),
1211                            )?))
1212                        })
1213                        .collect::<Result<Vec<_>>>()?,
1214                    None => extension_plan.node.inputs().into_iter().cloned().collect(),
1215                };
1216                // extension with new inputs.
1217                let child_plan = LogicalPlan::Extension(extension_plan);
1218                let new_extension =
1219                    child_plan.with_new_exprs(child_plan.expressions(), new_children)?;
1220
1221                let new_plan = match conjunction(keep_predicates) {
1222                    Some(predicate) => LogicalPlan::Filter(Filter::try_new(
1223                        predicate,
1224                        Arc::new(new_extension),
1225                    )?),
1226                    None => new_extension,
1227                };
1228                Ok(Transformed::yes(new_plan))
1229            }
1230            child => {
1231                filter.input = Arc::new(child);
1232                Ok(Transformed::no(LogicalPlan::Filter(filter)))
1233            }
1234        }
1235    }
1236}
1237
1238/// Attempts to push `predicate` into a `FilterExec` below `projection
1239///
1240/// # Returns
1241/// (plan, remaining_predicate)
1242///
1243/// `plan` is a LogicalPlan for `projection` with possibly a new FilterExec below it.
1244/// `remaining_predicate` is any part of the predicate that could not be pushed down
1245///
1246/// # Args
1247/// - predicates: Split predicates like `[foo=5, bar=6]`
1248/// - projection: The target projection plan to push down the predicates
1249///
1250/// # Example
1251///
1252/// Pushing a predicate like `foo=5 AND bar=6` with an input plan like this:
1253///
1254/// ```text
1255/// Projection(foo, c+d as bar)
1256/// ```
1257///
1258/// Might result in returning `remaining_predicate` of `bar=6` and a plan like
1259///
1260/// ```text
1261/// Projection(foo, c+d as bar)
1262///  Filter(foo=5)
1263///   ...
1264/// ```
1265fn rewrite_projection(
1266    predicates: Vec<Expr>,
1267    mut projection: Projection,
1268) -> Result<(Transformed<LogicalPlan>, Option<Expr>)> {
1269    // A projection is filter-commutable if it do not contain volatile predicates or contain volatile
1270    // predicates that are not used in the filter. However, we should re-writes all predicate expressions.
1271    // collect projection.
1272    let (volatile_map, non_volatile_map): (HashMap<_, _>, HashMap<_, _>) = projection
1273        .schema
1274        .iter()
1275        .zip(projection.expr.iter())
1276        .map(|((qualifier, field), expr)| {
1277            // strip alias, as they should not be part of filters
1278            let expr = expr.clone().unalias();
1279
1280            (qualified_name(qualifier, field.name()), expr)
1281        })
1282        .partition(|(_, value)| value.is_volatile());
1283
1284    let mut push_predicates = vec![];
1285    let mut keep_predicates = vec![];
1286    for expr in predicates {
1287        if contain(&expr, &volatile_map) {
1288            keep_predicates.push(expr);
1289        } else {
1290            push_predicates.push(expr);
1291        }
1292    }
1293
1294    match conjunction(push_predicates) {
1295        Some(expr) => {
1296            // re-write all filters based on this projection
1297            // E.g. in `Filter: b\n  Projection: a > 1 as b`, we can swap them, but the filter must be "a > 1"
1298            let new_filter = LogicalPlan::Filter(Filter::try_new(
1299                replace_cols_by_name(expr, &non_volatile_map)?,
1300                std::mem::take(&mut projection.input),
1301            )?);
1302
1303            projection.input = Arc::new(new_filter);
1304
1305            Ok((
1306                Transformed::yes(LogicalPlan::Projection(projection)),
1307                conjunction(keep_predicates),
1308            ))
1309        }
1310        None => Ok((Transformed::no(LogicalPlan::Projection(projection)), None)),
1311    }
1312}
1313
1314/// Creates a new LogicalPlan::Filter node.
1315pub fn make_filter(predicate: Expr, input: Arc<LogicalPlan>) -> Result<LogicalPlan> {
1316    Filter::try_new(predicate, input).map(LogicalPlan::Filter)
1317}
1318
1319/// Replace the existing child of the single input node with `new_child`.
1320///
1321/// Starting:
1322/// ```text
1323/// plan
1324///   child
1325/// ```
1326///
1327/// Ending:
1328/// ```text
1329/// plan
1330///   new_child
1331/// ```
1332fn insert_below(
1333    plan: LogicalPlan,
1334    new_child: LogicalPlan,
1335) -> Result<Transformed<LogicalPlan>> {
1336    let mut new_child = Some(new_child);
1337    let transformed_plan = plan.map_children(|_child| {
1338        if let Some(new_child) = new_child.take() {
1339            Ok(Transformed::yes(new_child))
1340        } else {
1341            // already took the new child
1342            internal_err!("node had more than one input")
1343        }
1344    })?;
1345
1346    // make sure we did the actual replacement
1347    if new_child.is_some() {
1348        return internal_err!("node had no  inputs");
1349    }
1350
1351    Ok(transformed_plan)
1352}
1353
1354impl PushDownFilter {
1355    #[allow(missing_docs)]
1356    pub fn new() -> Self {
1357        Self {}
1358    }
1359}
1360
1361/// replaces columns by its name on the projection.
1362pub fn replace_cols_by_name(
1363    e: Expr,
1364    replace_map: &HashMap<String, Expr>,
1365) -> Result<Expr> {
1366    e.transform_up(|expr| {
1367        Ok(if let Expr::Column(c) = &expr {
1368            match replace_map.get(&c.flat_name()) {
1369                Some(new_c) => Transformed::yes(new_c.clone()),
1370                None => Transformed::no(expr),
1371            }
1372        } else {
1373            Transformed::no(expr)
1374        })
1375    })
1376    .data()
1377}
1378
1379/// check whether the expression uses the columns in `check_map`.
1380fn contain(e: &Expr, check_map: &HashMap<String, Expr>) -> bool {
1381    let mut is_contain = false;
1382    e.apply(|expr| {
1383        Ok(if let Expr::Column(c) = &expr {
1384            match check_map.get(&c.flat_name()) {
1385                Some(_) => {
1386                    is_contain = true;
1387                    TreeNodeRecursion::Stop
1388                }
1389                None => TreeNodeRecursion::Continue,
1390            }
1391        } else {
1392            TreeNodeRecursion::Continue
1393        })
1394    })
1395    .unwrap();
1396    is_contain
1397}
1398
1399#[cfg(test)]
1400mod tests {
1401    use std::any::Any;
1402    use std::cmp::Ordering;
1403    use std::fmt::{Debug, Formatter};
1404
1405    use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
1406    use async_trait::async_trait;
1407
1408    use datafusion_common::{DFSchemaRef, DataFusionError, ScalarValue};
1409    use datafusion_expr::expr::{ScalarFunction, WindowFunction};
1410    use datafusion_expr::logical_plan::table_scan;
1411    use datafusion_expr::{
1412        col, in_list, in_subquery, lit, ColumnarValue, ExprFunctionExt, Extension,
1413        LogicalPlanBuilder, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
1414        TableSource, TableType, UserDefinedLogicalNodeCore, Volatility,
1415        WindowFunctionDefinition,
1416    };
1417
1418    use crate::assert_optimized_plan_eq_snapshot;
1419    use crate::optimizer::Optimizer;
1420    use crate::simplify_expressions::SimplifyExpressions;
1421    use crate::test::*;
1422    use crate::OptimizerContext;
1423    use datafusion_expr::test::function_stub::sum;
1424    use insta::assert_snapshot;
1425
1426    use super::*;
1427
1428    fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {}
1429
1430    macro_rules! assert_optimized_plan_equal {
1431        (
1432            $plan:expr,
1433            @ $expected:literal $(,)?
1434        ) => {{
1435            let optimizer_ctx = OptimizerContext::new().with_max_passes(1);
1436            let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> = vec![Arc::new(PushDownFilter::new())];
1437            assert_optimized_plan_eq_snapshot!(
1438                optimizer_ctx,
1439                rules,
1440                $plan,
1441                @ $expected,
1442            )
1443        }};
1444    }
1445
1446    macro_rules! assert_optimized_plan_eq_with_rewrite_predicate {
1447        (
1448            $plan:expr,
1449            @ $expected:literal $(,)?
1450        ) => {{
1451            let optimizer = Optimizer::with_rules(vec![
1452                Arc::new(SimplifyExpressions::new()),
1453                Arc::new(PushDownFilter::new()),
1454            ]);
1455            let optimized_plan = optimizer.optimize($plan, &OptimizerContext::new(), observe)?;
1456            assert_snapshot!(optimized_plan, @ $expected);
1457            Ok::<(), DataFusionError>(())
1458        }};
1459    }
1460
1461    #[test]
1462    fn filter_before_projection() -> Result<()> {
1463        let table_scan = test_table_scan()?;
1464        let plan = LogicalPlanBuilder::from(table_scan)
1465            .project(vec![col("a"), col("b")])?
1466            .filter(col("a").eq(lit(1i64)))?
1467            .build()?;
1468        // filter is before projection
1469        assert_optimized_plan_equal!(
1470            plan,
1471            @r"
1472        Projection: test.a, test.b
1473          TableScan: test, full_filters=[test.a = Int64(1)]
1474        "
1475        )
1476    }
1477
1478    #[test]
1479    fn filter_after_limit() -> Result<()> {
1480        let table_scan = test_table_scan()?;
1481        let plan = LogicalPlanBuilder::from(table_scan)
1482            .project(vec![col("a"), col("b")])?
1483            .limit(0, Some(10))?
1484            .filter(col("a").eq(lit(1i64)))?
1485            .build()?;
1486        // filter is before single projection
1487        assert_optimized_plan_equal!(
1488            plan,
1489            @r"
1490        Filter: test.a = Int64(1)
1491          Limit: skip=0, fetch=10
1492            Projection: test.a, test.b
1493              TableScan: test
1494        "
1495        )
1496    }
1497
1498    #[test]
1499    fn filter_no_columns() -> Result<()> {
1500        let table_scan = test_table_scan()?;
1501        let plan = LogicalPlanBuilder::from(table_scan)
1502            .filter(lit(0i64).eq(lit(1i64)))?
1503            .build()?;
1504        assert_optimized_plan_equal!(
1505            plan,
1506            @"TableScan: test, full_filters=[Int64(0) = Int64(1)]"
1507        )
1508    }
1509
1510    #[test]
1511    fn filter_jump_2_plans() -> Result<()> {
1512        let table_scan = test_table_scan()?;
1513        let plan = LogicalPlanBuilder::from(table_scan)
1514            .project(vec![col("a"), col("b"), col("c")])?
1515            .project(vec![col("c"), col("b")])?
1516            .filter(col("a").eq(lit(1i64)))?
1517            .build()?;
1518        // filter is before double projection
1519        assert_optimized_plan_equal!(
1520            plan,
1521            @r"
1522        Projection: test.c, test.b
1523          Projection: test.a, test.b, test.c
1524            TableScan: test, full_filters=[test.a = Int64(1)]
1525        "
1526        )
1527    }
1528
1529    #[test]
1530    fn filter_move_agg() -> Result<()> {
1531        let table_scan = test_table_scan()?;
1532        let plan = LogicalPlanBuilder::from(table_scan)
1533            .aggregate(vec![col("a")], vec![sum(col("b")).alias("total_salary")])?
1534            .filter(col("a").gt(lit(10i64)))?
1535            .build()?;
1536        // filter of key aggregation is commutative
1537        assert_optimized_plan_equal!(
1538            plan,
1539            @r"
1540        Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b) AS total_salary]]
1541          TableScan: test, full_filters=[test.a > Int64(10)]
1542        "
1543        )
1544    }
1545
1546    #[test]
1547    fn filter_complex_group_by() -> Result<()> {
1548        let table_scan = test_table_scan()?;
1549        let plan = LogicalPlanBuilder::from(table_scan)
1550            .aggregate(vec![add(col("b"), col("a"))], vec![sum(col("a")), col("b")])?
1551            .filter(col("b").gt(lit(10i64)))?
1552            .build()?;
1553        assert_optimized_plan_equal!(
1554            plan,
1555            @r"
1556        Filter: test.b > Int64(10)
1557          Aggregate: groupBy=[[test.b + test.a]], aggr=[[sum(test.a), test.b]]
1558            TableScan: test
1559        "
1560        )
1561    }
1562
1563    #[test]
1564    fn push_agg_need_replace_expr() -> Result<()> {
1565        let plan = LogicalPlanBuilder::from(test_table_scan()?)
1566            .aggregate(vec![add(col("b"), col("a"))], vec![sum(col("a")), col("b")])?
1567            .filter(col("test.b + test.a").gt(lit(10i64)))?
1568            .build()?;
1569        assert_optimized_plan_equal!(
1570            plan,
1571            @r"
1572        Aggregate: groupBy=[[test.b + test.a]], aggr=[[sum(test.a), test.b]]
1573          TableScan: test, full_filters=[test.b + test.a > Int64(10)]
1574        "
1575        )
1576    }
1577
1578    #[test]
1579    fn filter_keep_agg() -> Result<()> {
1580        let table_scan = test_table_scan()?;
1581        let plan = LogicalPlanBuilder::from(table_scan)
1582            .aggregate(vec![col("a")], vec![sum(col("b")).alias("b")])?
1583            .filter(col("b").gt(lit(10i64)))?
1584            .build()?;
1585        // filter of aggregate is after aggregation since they are non-commutative
1586        assert_optimized_plan_equal!(
1587            plan,
1588            @r"
1589        Filter: b > Int64(10)
1590          Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b) AS b]]
1591            TableScan: test
1592        "
1593        )
1594    }
1595
1596    /// verifies that when partitioning by 'a' and 'b', and filtering by 'b', 'b' is pushed
1597    #[test]
1598    fn filter_move_window() -> Result<()> {
1599        let table_scan = test_table_scan()?;
1600
1601        let window = Expr::from(WindowFunction::new(
1602            WindowFunctionDefinition::WindowUDF(
1603                datafusion_functions_window::rank::rank_udwf(),
1604            ),
1605            vec![],
1606        ))
1607        .partition_by(vec![col("a"), col("b")])
1608        .order_by(vec![col("c").sort(true, true)])
1609        .build()
1610        .unwrap();
1611
1612        let plan = LogicalPlanBuilder::from(table_scan)
1613            .window(vec![window])?
1614            .filter(col("b").gt(lit(10i64)))?
1615            .build()?;
1616
1617        assert_optimized_plan_equal!(
1618            plan,
1619            @r"
1620        WindowAggr: windowExpr=[[rank() PARTITION BY [test.a, test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1621          TableScan: test, full_filters=[test.b > Int64(10)]
1622        "
1623        )
1624    }
1625
1626    /// verifies that when partitioning by 'a' and 'b', and filtering by 'a' and 'b', both 'a' and
1627    /// 'b' are pushed
1628    #[test]
1629    fn filter_move_complex_window() -> Result<()> {
1630        let table_scan = test_table_scan()?;
1631
1632        let window = Expr::from(WindowFunction::new(
1633            WindowFunctionDefinition::WindowUDF(
1634                datafusion_functions_window::rank::rank_udwf(),
1635            ),
1636            vec![],
1637        ))
1638        .partition_by(vec![col("a"), col("b")])
1639        .order_by(vec![col("c").sort(true, true)])
1640        .build()
1641        .unwrap();
1642
1643        let plan = LogicalPlanBuilder::from(table_scan)
1644            .window(vec![window])?
1645            .filter(and(col("a").gt(lit(10i64)), col("b").eq(lit(1i64))))?
1646            .build()?;
1647
1648        assert_optimized_plan_equal!(
1649            plan,
1650            @r"
1651        WindowAggr: windowExpr=[[rank() PARTITION BY [test.a, test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1652          TableScan: test, full_filters=[test.a > Int64(10), test.b = Int64(1)]
1653        "
1654        )
1655    }
1656
1657    /// verifies that when partitioning by 'a' and filtering by 'a' and 'b', only 'a' is pushed
1658    #[test]
1659    fn filter_move_partial_window() -> Result<()> {
1660        let table_scan = test_table_scan()?;
1661
1662        let window = Expr::from(WindowFunction::new(
1663            WindowFunctionDefinition::WindowUDF(
1664                datafusion_functions_window::rank::rank_udwf(),
1665            ),
1666            vec![],
1667        ))
1668        .partition_by(vec![col("a")])
1669        .order_by(vec![col("c").sort(true, true)])
1670        .build()
1671        .unwrap();
1672
1673        let plan = LogicalPlanBuilder::from(table_scan)
1674            .window(vec![window])?
1675            .filter(and(col("a").gt(lit(10i64)), col("b").eq(lit(1i64))))?
1676            .build()?;
1677
1678        assert_optimized_plan_equal!(
1679            plan,
1680            @r"
1681        Filter: test.b = Int64(1)
1682          WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1683            TableScan: test, full_filters=[test.a > Int64(10)]
1684        "
1685        )
1686    }
1687
1688    /// verifies that filters on partition expressions are not pushed, as the single expression
1689    /// column is not available to the user, unlike with aggregations
1690    #[test]
1691    fn filter_expression_keep_window() -> Result<()> {
1692        let table_scan = test_table_scan()?;
1693
1694        let window = Expr::from(WindowFunction::new(
1695            WindowFunctionDefinition::WindowUDF(
1696                datafusion_functions_window::rank::rank_udwf(),
1697            ),
1698            vec![],
1699        ))
1700        .partition_by(vec![add(col("a"), col("b"))]) // PARTITION BY a + b
1701        .order_by(vec![col("c").sort(true, true)])
1702        .build()
1703        .unwrap();
1704
1705        let plan = LogicalPlanBuilder::from(table_scan)
1706            .window(vec![window])?
1707            // unlike with aggregations, single partition column "test.a + test.b" is not available
1708            // to the plan, so we use multiple columns when filtering
1709            .filter(add(col("a"), col("b")).gt(lit(10i64)))?
1710            .build()?;
1711
1712        assert_optimized_plan_equal!(
1713            plan,
1714            @r"
1715        Filter: test.a + test.b > Int64(10)
1716          WindowAggr: windowExpr=[[rank() PARTITION BY [test.a + test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1717            TableScan: test
1718        "
1719        )
1720    }
1721
1722    /// verifies that filters are not pushed on order by columns (that are not used in partitioning)
1723    #[test]
1724    fn filter_order_keep_window() -> Result<()> {
1725        let table_scan = test_table_scan()?;
1726
1727        let window = Expr::from(WindowFunction::new(
1728            WindowFunctionDefinition::WindowUDF(
1729                datafusion_functions_window::rank::rank_udwf(),
1730            ),
1731            vec![],
1732        ))
1733        .partition_by(vec![col("a")])
1734        .order_by(vec![col("c").sort(true, true)])
1735        .build()
1736        .unwrap();
1737
1738        let plan = LogicalPlanBuilder::from(table_scan)
1739            .window(vec![window])?
1740            .filter(col("c").gt(lit(10i64)))?
1741            .build()?;
1742
1743        assert_optimized_plan_equal!(
1744            plan,
1745            @r"
1746        Filter: test.c > Int64(10)
1747          WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1748            TableScan: test
1749        "
1750        )
1751    }
1752
1753    /// verifies that when we use multiple window functions with a common partition key, the filter
1754    /// on that key is pushed
1755    #[test]
1756    fn filter_multiple_windows_common_partitions() -> Result<()> {
1757        let table_scan = test_table_scan()?;
1758
1759        let window1 = Expr::from(WindowFunction::new(
1760            WindowFunctionDefinition::WindowUDF(
1761                datafusion_functions_window::rank::rank_udwf(),
1762            ),
1763            vec![],
1764        ))
1765        .partition_by(vec![col("a")])
1766        .order_by(vec![col("c").sort(true, true)])
1767        .build()
1768        .unwrap();
1769
1770        let window2 = Expr::from(WindowFunction::new(
1771            WindowFunctionDefinition::WindowUDF(
1772                datafusion_functions_window::rank::rank_udwf(),
1773            ),
1774            vec![],
1775        ))
1776        .partition_by(vec![col("b"), col("a")])
1777        .order_by(vec![col("c").sort(true, true)])
1778        .build()
1779        .unwrap();
1780
1781        let plan = LogicalPlanBuilder::from(table_scan)
1782            .window(vec![window1, window2])?
1783            .filter(col("a").gt(lit(10i64)))? // a appears in both window functions
1784            .build()?;
1785
1786        assert_optimized_plan_equal!(
1787            plan,
1788            @r"
1789        WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, rank() PARTITION BY [test.b, test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1790          TableScan: test, full_filters=[test.a > Int64(10)]
1791        "
1792        )
1793    }
1794
1795    /// verifies that when we use multiple window functions with different partitions keys, the
1796    /// filter cannot be pushed
1797    #[test]
1798    fn filter_multiple_windows_disjoint_partitions() -> Result<()> {
1799        let table_scan = test_table_scan()?;
1800
1801        let window1 = Expr::from(WindowFunction::new(
1802            WindowFunctionDefinition::WindowUDF(
1803                datafusion_functions_window::rank::rank_udwf(),
1804            ),
1805            vec![],
1806        ))
1807        .partition_by(vec![col("a")])
1808        .order_by(vec![col("c").sort(true, true)])
1809        .build()
1810        .unwrap();
1811
1812        let window2 = Expr::from(WindowFunction::new(
1813            WindowFunctionDefinition::WindowUDF(
1814                datafusion_functions_window::rank::rank_udwf(),
1815            ),
1816            vec![],
1817        ))
1818        .partition_by(vec![col("b"), col("a")])
1819        .order_by(vec![col("c").sort(true, true)])
1820        .build()
1821        .unwrap();
1822
1823        let plan = LogicalPlanBuilder::from(table_scan)
1824            .window(vec![window1, window2])?
1825            .filter(col("b").gt(lit(10i64)))? // b only appears in one window function
1826            .build()?;
1827
1828        assert_optimized_plan_equal!(
1829            plan,
1830            @r"
1831        Filter: test.b > Int64(10)
1832          WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, rank() PARTITION BY [test.b, test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1833            TableScan: test
1834        "
1835        )
1836    }
1837
1838    /// verifies that a filter is pushed to before a projection, the filter expression is correctly re-written
1839    #[test]
1840    fn alias() -> Result<()> {
1841        let table_scan = test_table_scan()?;
1842        let plan = LogicalPlanBuilder::from(table_scan)
1843            .project(vec![col("a").alias("b"), col("c")])?
1844            .filter(col("b").eq(lit(1i64)))?
1845            .build()?;
1846        // filter is before projection
1847        assert_optimized_plan_equal!(
1848            plan,
1849            @r"
1850        Projection: test.a AS b, test.c
1851          TableScan: test, full_filters=[test.a = Int64(1)]
1852        "
1853        )
1854    }
1855
1856    fn add(left: Expr, right: Expr) -> Expr {
1857        Expr::BinaryExpr(BinaryExpr::new(
1858            Box::new(left),
1859            Operator::Plus,
1860            Box::new(right),
1861        ))
1862    }
1863
1864    fn multiply(left: Expr, right: Expr) -> Expr {
1865        Expr::BinaryExpr(BinaryExpr::new(
1866            Box::new(left),
1867            Operator::Multiply,
1868            Box::new(right),
1869        ))
1870    }
1871
1872    /// verifies that a filter is pushed to before a projection with a complex expression, the filter expression is correctly re-written
1873    #[test]
1874    fn complex_expression() -> Result<()> {
1875        let table_scan = test_table_scan()?;
1876        let plan = LogicalPlanBuilder::from(table_scan)
1877            .project(vec![
1878                add(multiply(col("a"), lit(2)), col("c")).alias("b"),
1879                col("c"),
1880            ])?
1881            .filter(col("b").eq(lit(1i64)))?
1882            .build()?;
1883
1884        // not part of the test, just good to know:
1885        assert_snapshot!(plan,
1886        @r"
1887        Filter: b = Int64(1)
1888          Projection: test.a * Int32(2) + test.c AS b, test.c
1889            TableScan: test
1890        ",
1891        );
1892        // filter is before projection
1893        assert_optimized_plan_equal!(
1894            plan,
1895            @r"
1896        Projection: test.a * Int32(2) + test.c AS b, test.c
1897          TableScan: test, full_filters=[test.a * Int32(2) + test.c = Int64(1)]
1898        "
1899        )
1900    }
1901
1902    /// verifies that when a filter is pushed to after 2 projections, the filter expression is correctly re-written
1903    #[test]
1904    fn complex_plan() -> Result<()> {
1905        let table_scan = test_table_scan()?;
1906        let plan = LogicalPlanBuilder::from(table_scan)
1907            .project(vec![
1908                add(multiply(col("a"), lit(2)), col("c")).alias("b"),
1909                col("c"),
1910            ])?
1911            // second projection where we rename columns, just to make it difficult
1912            .project(vec![multiply(col("b"), lit(3)).alias("a"), col("c")])?
1913            .filter(col("a").eq(lit(1i64)))?
1914            .build()?;
1915
1916        // not part of the test, just good to know:
1917        assert_snapshot!(plan,
1918        @r"
1919        Filter: a = Int64(1)
1920          Projection: b * Int32(3) AS a, test.c
1921            Projection: test.a * Int32(2) + test.c AS b, test.c
1922              TableScan: test
1923        ",
1924        );
1925        // filter is before the projections
1926        assert_optimized_plan_equal!(
1927            plan,
1928            @r"
1929        Projection: b * Int32(3) AS a, test.c
1930          Projection: test.a * Int32(2) + test.c AS b, test.c
1931            TableScan: test, full_filters=[(test.a * Int32(2) + test.c) * Int32(3) = Int64(1)]
1932        "
1933        )
1934    }
1935
1936    #[derive(Debug, PartialEq, Eq, Hash)]
1937    struct NoopPlan {
1938        input: Vec<LogicalPlan>,
1939        schema: DFSchemaRef,
1940    }
1941
1942    // Manual implementation needed because of `schema` field. Comparison excludes this field.
1943    impl PartialOrd for NoopPlan {
1944        fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
1945            self.input.partial_cmp(&other.input)
1946        }
1947    }
1948
1949    impl UserDefinedLogicalNodeCore for NoopPlan {
1950        fn name(&self) -> &str {
1951            "NoopPlan"
1952        }
1953
1954        fn inputs(&self) -> Vec<&LogicalPlan> {
1955            self.input.iter().collect()
1956        }
1957
1958        fn schema(&self) -> &DFSchemaRef {
1959            &self.schema
1960        }
1961
1962        fn expressions(&self) -> Vec<Expr> {
1963            self.input
1964                .iter()
1965                .flat_map(|child| child.expressions())
1966                .collect()
1967        }
1968
1969        fn prevent_predicate_push_down_columns(&self) -> HashSet<String> {
1970            HashSet::from_iter(vec!["c".to_string()])
1971        }
1972
1973        fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
1974            write!(f, "NoopPlan")
1975        }
1976
1977        fn with_exprs_and_inputs(
1978            &self,
1979            _exprs: Vec<Expr>,
1980            inputs: Vec<LogicalPlan>,
1981        ) -> Result<Self> {
1982            Ok(Self {
1983                input: inputs,
1984                schema: Arc::clone(&self.schema),
1985            })
1986        }
1987
1988        fn supports_limit_pushdown(&self) -> bool {
1989            false // Disallow limit push-down by default
1990        }
1991    }
1992
1993    #[test]
1994    fn user_defined_plan() -> Result<()> {
1995        let table_scan = test_table_scan()?;
1996
1997        let custom_plan = LogicalPlan::Extension(Extension {
1998            node: Arc::new(NoopPlan {
1999                input: vec![table_scan.clone()],
2000                schema: Arc::clone(table_scan.schema()),
2001            }),
2002        });
2003        let plan = LogicalPlanBuilder::from(custom_plan)
2004            .filter(col("a").eq(lit(1i64)))?
2005            .build()?;
2006
2007        // Push filter below NoopPlan
2008        assert_optimized_plan_equal!(
2009            plan,
2010            @r"
2011        NoopPlan
2012          TableScan: test, full_filters=[test.a = Int64(1)]
2013        "
2014        )?;
2015
2016        let custom_plan = LogicalPlan::Extension(Extension {
2017            node: Arc::new(NoopPlan {
2018                input: vec![table_scan.clone()],
2019                schema: Arc::clone(table_scan.schema()),
2020            }),
2021        });
2022        let plan = LogicalPlanBuilder::from(custom_plan)
2023            .filter(col("a").eq(lit(1i64)).and(col("c").eq(lit(2i64))))?
2024            .build()?;
2025
2026        // Push only predicate on `a` below NoopPlan
2027        assert_optimized_plan_equal!(
2028            plan,
2029            @r"
2030        Filter: test.c = Int64(2)
2031          NoopPlan
2032            TableScan: test, full_filters=[test.a = Int64(1)]
2033        "
2034        )?;
2035
2036        let custom_plan = LogicalPlan::Extension(Extension {
2037            node: Arc::new(NoopPlan {
2038                input: vec![table_scan.clone(), table_scan.clone()],
2039                schema: Arc::clone(table_scan.schema()),
2040            }),
2041        });
2042        let plan = LogicalPlanBuilder::from(custom_plan)
2043            .filter(col("a").eq(lit(1i64)))?
2044            .build()?;
2045
2046        // Push filter below NoopPlan for each child branch
2047        assert_optimized_plan_equal!(
2048            plan,
2049            @r"
2050        NoopPlan
2051          TableScan: test, full_filters=[test.a = Int64(1)]
2052          TableScan: test, full_filters=[test.a = Int64(1)]
2053        "
2054        )?;
2055
2056        let custom_plan = LogicalPlan::Extension(Extension {
2057            node: Arc::new(NoopPlan {
2058                input: vec![table_scan.clone(), table_scan.clone()],
2059                schema: Arc::clone(table_scan.schema()),
2060            }),
2061        });
2062        let plan = LogicalPlanBuilder::from(custom_plan)
2063            .filter(col("a").eq(lit(1i64)).and(col("c").eq(lit(2i64))))?
2064            .build()?;
2065
2066        // Push only predicate on `a` below NoopPlan
2067        assert_optimized_plan_equal!(
2068            plan,
2069            @r"
2070        Filter: test.c = Int64(2)
2071          NoopPlan
2072            TableScan: test, full_filters=[test.a = Int64(1)]
2073            TableScan: test, full_filters=[test.a = Int64(1)]
2074        "
2075        )
2076    }
2077
2078    /// verifies that when two filters apply after an aggregation that only allows one to be pushed, one is pushed
2079    /// and the other not.
2080    #[test]
2081    fn multi_filter() -> Result<()> {
2082        // the aggregation allows one filter to pass (b), and the other one to not pass (sum(c))
2083        let table_scan = test_table_scan()?;
2084        let plan = LogicalPlanBuilder::from(table_scan)
2085            .project(vec![col("a").alias("b"), col("c")])?
2086            .aggregate(vec![col("b")], vec![sum(col("c"))])?
2087            .filter(col("b").gt(lit(10i64)))?
2088            .filter(col("sum(test.c)").gt(lit(10i64)))?
2089            .build()?;
2090
2091        // not part of the test, just good to know:
2092        assert_snapshot!(plan,
2093        @r"
2094        Filter: sum(test.c) > Int64(10)
2095          Filter: b > Int64(10)
2096            Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
2097              Projection: test.a AS b, test.c
2098                TableScan: test
2099        ",
2100        );
2101        // filter is before the projections
2102        assert_optimized_plan_equal!(
2103            plan,
2104            @r"
2105        Filter: sum(test.c) > Int64(10)
2106          Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
2107            Projection: test.a AS b, test.c
2108              TableScan: test, full_filters=[test.a > Int64(10)]
2109        "
2110        )
2111    }
2112
2113    /// verifies that when a filter with two predicates is applied after an aggregation that only allows one to be pushed, one is pushed
2114    /// and the other not.
2115    #[test]
2116    fn split_filter() -> Result<()> {
2117        // the aggregation allows one filter to pass (b), and the other one to not pass (sum(c))
2118        let table_scan = test_table_scan()?;
2119        let plan = LogicalPlanBuilder::from(table_scan)
2120            .project(vec![col("a").alias("b"), col("c")])?
2121            .aggregate(vec![col("b")], vec![sum(col("c"))])?
2122            .filter(and(
2123                col("sum(test.c)").gt(lit(10i64)),
2124                and(col("b").gt(lit(10i64)), col("sum(test.c)").lt(lit(20i64))),
2125            ))?
2126            .build()?;
2127
2128        // not part of the test, just good to know:
2129        assert_snapshot!(plan,
2130        @r"
2131        Filter: sum(test.c) > Int64(10) AND b > Int64(10) AND sum(test.c) < Int64(20)
2132          Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
2133            Projection: test.a AS b, test.c
2134              TableScan: test
2135        ",
2136        );
2137        // filter is before the projections
2138        assert_optimized_plan_equal!(
2139            plan,
2140            @r"
2141        Filter: sum(test.c) > Int64(10) AND sum(test.c) < Int64(20)
2142          Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
2143            Projection: test.a AS b, test.c
2144              TableScan: test, full_filters=[test.a > Int64(10)]
2145        "
2146        )
2147    }
2148
2149    /// verifies that when two limits are in place, we jump neither
2150    #[test]
2151    fn double_limit() -> Result<()> {
2152        let table_scan = test_table_scan()?;
2153        let plan = LogicalPlanBuilder::from(table_scan)
2154            .project(vec![col("a"), col("b")])?
2155            .limit(0, Some(20))?
2156            .limit(0, Some(10))?
2157            .project(vec![col("a"), col("b")])?
2158            .filter(col("a").eq(lit(1i64)))?
2159            .build()?;
2160        // filter does not just any of the limits
2161        assert_optimized_plan_equal!(
2162            plan,
2163            @r"
2164        Projection: test.a, test.b
2165          Filter: test.a = Int64(1)
2166            Limit: skip=0, fetch=10
2167              Limit: skip=0, fetch=20
2168                Projection: test.a, test.b
2169                  TableScan: test
2170        "
2171        )
2172    }
2173
2174    #[test]
2175    fn union_all() -> Result<()> {
2176        let table_scan = test_table_scan()?;
2177        let table_scan2 = test_table_scan_with_name("test2")?;
2178        let plan = LogicalPlanBuilder::from(table_scan)
2179            .union(LogicalPlanBuilder::from(table_scan2).build()?)?
2180            .filter(col("a").eq(lit(1i64)))?
2181            .build()?;
2182        // filter appears below Union
2183        assert_optimized_plan_equal!(
2184            plan,
2185            @r"
2186        Union
2187          TableScan: test, full_filters=[test.a = Int64(1)]
2188          TableScan: test2, full_filters=[test2.a = Int64(1)]
2189        "
2190        )
2191    }
2192
2193    #[test]
2194    fn union_all_on_projection() -> Result<()> {
2195        let table_scan = test_table_scan()?;
2196        let table = LogicalPlanBuilder::from(table_scan)
2197            .project(vec![col("a").alias("b")])?
2198            .alias("test2")?;
2199
2200        let plan = table
2201            .clone()
2202            .union(table.build()?)?
2203            .filter(col("b").eq(lit(1i64)))?
2204            .build()?;
2205
2206        // filter appears below Union
2207        assert_optimized_plan_equal!(
2208            plan,
2209            @r"
2210        Union
2211          SubqueryAlias: test2
2212            Projection: test.a AS b
2213              TableScan: test, full_filters=[test.a = Int64(1)]
2214          SubqueryAlias: test2
2215            Projection: test.a AS b
2216              TableScan: test, full_filters=[test.a = Int64(1)]
2217        "
2218        )
2219    }
2220
2221    #[test]
2222    fn test_union_different_schema() -> Result<()> {
2223        let left = LogicalPlanBuilder::from(test_table_scan()?)
2224            .project(vec![col("a"), col("b"), col("c")])?
2225            .build()?;
2226
2227        let schema = Schema::new(vec![
2228            Field::new("d", DataType::UInt32, false),
2229            Field::new("e", DataType::UInt32, false),
2230            Field::new("f", DataType::UInt32, false),
2231        ]);
2232        let right = table_scan(Some("test1"), &schema, None)?
2233            .project(vec![col("d"), col("e"), col("f")])?
2234            .build()?;
2235        let filter = and(col("test.a").eq(lit(1)), col("test1.d").gt(lit(2)));
2236        let plan = LogicalPlanBuilder::from(left)
2237            .cross_join(right)?
2238            .project(vec![col("test.a"), col("test1.d")])?
2239            .filter(filter)?
2240            .build()?;
2241
2242        assert_optimized_plan_equal!(
2243            plan,
2244            @r"
2245        Projection: test.a, test1.d
2246          Cross Join: 
2247            Projection: test.a, test.b, test.c
2248              TableScan: test, full_filters=[test.a = Int32(1)]
2249            Projection: test1.d, test1.e, test1.f
2250              TableScan: test1, full_filters=[test1.d > Int32(2)]
2251        "
2252        )
2253    }
2254
2255    #[test]
2256    fn test_project_same_name_different_qualifier() -> Result<()> {
2257        let table_scan = test_table_scan()?;
2258        let left = LogicalPlanBuilder::from(table_scan)
2259            .project(vec![col("a"), col("b"), col("c")])?
2260            .build()?;
2261        let right_table_scan = test_table_scan_with_name("test1")?;
2262        let right = LogicalPlanBuilder::from(right_table_scan)
2263            .project(vec![col("a"), col("b"), col("c")])?
2264            .build()?;
2265        let filter = and(col("test.a").eq(lit(1)), col("test1.a").gt(lit(2)));
2266        let plan = LogicalPlanBuilder::from(left)
2267            .cross_join(right)?
2268            .project(vec![col("test.a"), col("test1.a")])?
2269            .filter(filter)?
2270            .build()?;
2271
2272        assert_optimized_plan_equal!(
2273            plan,
2274            @r"
2275        Projection: test.a, test1.a
2276          Cross Join: 
2277            Projection: test.a, test.b, test.c
2278              TableScan: test, full_filters=[test.a = Int32(1)]
2279            Projection: test1.a, test1.b, test1.c
2280              TableScan: test1, full_filters=[test1.a > Int32(2)]
2281        "
2282        )
2283    }
2284
2285    /// verifies that filters with the same columns are correctly placed
2286    #[test]
2287    fn filter_2_breaks_limits() -> Result<()> {
2288        let table_scan = test_table_scan()?;
2289        let plan = LogicalPlanBuilder::from(table_scan)
2290            .project(vec![col("a")])?
2291            .filter(col("a").lt_eq(lit(1i64)))?
2292            .limit(0, Some(1))?
2293            .project(vec![col("a")])?
2294            .filter(col("a").gt_eq(lit(1i64)))?
2295            .build()?;
2296        // Should be able to move both filters below the projections
2297
2298        // not part of the test
2299        assert_snapshot!(plan,
2300        @r"
2301        Filter: test.a >= Int64(1)
2302          Projection: test.a
2303            Limit: skip=0, fetch=1
2304              Filter: test.a <= Int64(1)
2305                Projection: test.a
2306                  TableScan: test
2307        ",
2308        );
2309        assert_optimized_plan_equal!(
2310            plan,
2311            @r"
2312        Projection: test.a
2313          Filter: test.a >= Int64(1)
2314            Limit: skip=0, fetch=1
2315              Projection: test.a
2316                TableScan: test, full_filters=[test.a <= Int64(1)]
2317        "
2318        )
2319    }
2320
2321    /// verifies that filters to be placed on the same depth are ANDed
2322    #[test]
2323    fn two_filters_on_same_depth() -> Result<()> {
2324        let table_scan = test_table_scan()?;
2325        let plan = LogicalPlanBuilder::from(table_scan)
2326            .limit(0, Some(1))?
2327            .filter(col("a").lt_eq(lit(1i64)))?
2328            .filter(col("a").gt_eq(lit(1i64)))?
2329            .project(vec![col("a")])?
2330            .build()?;
2331
2332        // not part of the test
2333        assert_snapshot!(plan,
2334        @r"
2335        Projection: test.a
2336          Filter: test.a >= Int64(1)
2337            Filter: test.a <= Int64(1)
2338              Limit: skip=0, fetch=1
2339                TableScan: test
2340        ",
2341        );
2342        assert_optimized_plan_equal!(
2343            plan,
2344            @r"
2345        Projection: test.a
2346          Filter: test.a >= Int64(1) AND test.a <= Int64(1)
2347            Limit: skip=0, fetch=1
2348              TableScan: test
2349        "
2350        )
2351    }
2352
2353    /// verifies that filters on a plan with user nodes are not lost
2354    /// (ARROW-10547)
2355    #[test]
2356    fn filters_user_defined_node() -> Result<()> {
2357        let table_scan = test_table_scan()?;
2358        let plan = LogicalPlanBuilder::from(table_scan)
2359            .filter(col("a").lt_eq(lit(1i64)))?
2360            .build()?;
2361
2362        let plan = user_defined::new(plan);
2363
2364        // not part of the test
2365        assert_snapshot!(plan,
2366        @r"
2367        TestUserDefined
2368          Filter: test.a <= Int64(1)
2369            TableScan: test
2370        ",
2371        );
2372        assert_optimized_plan_equal!(
2373            plan,
2374            @r"
2375        TestUserDefined
2376          TableScan: test, full_filters=[test.a <= Int64(1)]
2377        "
2378        )
2379    }
2380
2381    /// post-on-join predicates on a column common to both sides is pushed to both sides
2382    #[test]
2383    fn filter_on_join_on_common_independent() -> Result<()> {
2384        let table_scan = test_table_scan()?;
2385        let left = LogicalPlanBuilder::from(table_scan).build()?;
2386        let right_table_scan = test_table_scan_with_name("test2")?;
2387        let right = LogicalPlanBuilder::from(right_table_scan)
2388            .project(vec![col("a")])?
2389            .build()?;
2390        let plan = LogicalPlanBuilder::from(left)
2391            .join(
2392                right,
2393                JoinType::Inner,
2394                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2395                None,
2396            )?
2397            .filter(col("test.a").lt_eq(lit(1i64)))?
2398            .build()?;
2399
2400        // not part of the test, just good to know:
2401        assert_snapshot!(plan,
2402        @r"
2403        Filter: test.a <= Int64(1)
2404          Inner Join: test.a = test2.a
2405            TableScan: test
2406            Projection: test2.a
2407              TableScan: test2
2408        ",
2409        );
2410        // filter sent to side before the join
2411        assert_optimized_plan_equal!(
2412            plan,
2413            @r"
2414        Inner Join: test.a = test2.a
2415          TableScan: test, full_filters=[test.a <= Int64(1)]
2416          Projection: test2.a
2417            TableScan: test2, full_filters=[test2.a <= Int64(1)]
2418        "
2419        )
2420    }
2421
2422    /// post-using-join predicates on a column common to both sides is pushed to both sides
2423    #[test]
2424    fn filter_using_join_on_common_independent() -> Result<()> {
2425        let table_scan = test_table_scan()?;
2426        let left = LogicalPlanBuilder::from(table_scan).build()?;
2427        let right_table_scan = test_table_scan_with_name("test2")?;
2428        let right = LogicalPlanBuilder::from(right_table_scan)
2429            .project(vec![col("a")])?
2430            .build()?;
2431        let plan = LogicalPlanBuilder::from(left)
2432            .join_using(
2433                right,
2434                JoinType::Inner,
2435                vec![Column::from_name("a".to_string())],
2436            )?
2437            .filter(col("a").lt_eq(lit(1i64)))?
2438            .build()?;
2439
2440        // not part of the test, just good to know:
2441        assert_snapshot!(plan,
2442        @r"
2443        Filter: test.a <= Int64(1)
2444          Inner Join: Using test.a = test2.a
2445            TableScan: test
2446            Projection: test2.a
2447              TableScan: test2
2448        ",
2449        );
2450        // filter sent to side before the join
2451        assert_optimized_plan_equal!(
2452            plan,
2453            @r"
2454        Inner Join: Using test.a = test2.a
2455          TableScan: test, full_filters=[test.a <= Int64(1)]
2456          Projection: test2.a
2457            TableScan: test2, full_filters=[test2.a <= Int64(1)]
2458        "
2459        )
2460    }
2461
2462    /// post-join predicates with columns from both sides are converted to join filters
2463    #[test]
2464    fn filter_join_on_common_dependent() -> Result<()> {
2465        let table_scan = test_table_scan()?;
2466        let left = LogicalPlanBuilder::from(table_scan)
2467            .project(vec![col("a"), col("c")])?
2468            .build()?;
2469        let right_table_scan = test_table_scan_with_name("test2")?;
2470        let right = LogicalPlanBuilder::from(right_table_scan)
2471            .project(vec![col("a"), col("b")])?
2472            .build()?;
2473        let plan = LogicalPlanBuilder::from(left)
2474            .join(
2475                right,
2476                JoinType::Inner,
2477                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2478                None,
2479            )?
2480            .filter(col("c").lt_eq(col("b")))?
2481            .build()?;
2482
2483        // not part of the test, just good to know:
2484        assert_snapshot!(plan,
2485        @r"
2486        Filter: test.c <= test2.b
2487          Inner Join: test.a = test2.a
2488            Projection: test.a, test.c
2489              TableScan: test
2490            Projection: test2.a, test2.b
2491              TableScan: test2
2492        ",
2493        );
2494        // Filter is converted to Join Filter
2495        assert_optimized_plan_equal!(
2496            plan,
2497            @r"
2498        Inner Join: test.a = test2.a Filter: test.c <= test2.b
2499          Projection: test.a, test.c
2500            TableScan: test
2501          Projection: test2.a, test2.b
2502            TableScan: test2
2503        "
2504        )
2505    }
2506
2507    /// post-join predicates with columns from one side of a join are pushed only to that side
2508    #[test]
2509    fn filter_join_on_one_side() -> Result<()> {
2510        let table_scan = test_table_scan()?;
2511        let left = LogicalPlanBuilder::from(table_scan)
2512            .project(vec![col("a"), col("b")])?
2513            .build()?;
2514        let table_scan_right = test_table_scan_with_name("test2")?;
2515        let right = LogicalPlanBuilder::from(table_scan_right)
2516            .project(vec![col("a"), col("c")])?
2517            .build()?;
2518
2519        let plan = LogicalPlanBuilder::from(left)
2520            .join(
2521                right,
2522                JoinType::Inner,
2523                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2524                None,
2525            )?
2526            .filter(col("b").lt_eq(lit(1i64)))?
2527            .build()?;
2528
2529        // not part of the test, just good to know:
2530        assert_snapshot!(plan,
2531        @r"
2532        Filter: test.b <= Int64(1)
2533          Inner Join: test.a = test2.a
2534            Projection: test.a, test.b
2535              TableScan: test
2536            Projection: test2.a, test2.c
2537              TableScan: test2
2538        ",
2539        );
2540        assert_optimized_plan_equal!(
2541            plan,
2542            @r"
2543        Inner Join: test.a = test2.a
2544          Projection: test.a, test.b
2545            TableScan: test, full_filters=[test.b <= Int64(1)]
2546          Projection: test2.a, test2.c
2547            TableScan: test2
2548        "
2549        )
2550    }
2551
2552    /// post-join predicates on the right side of a left join are not duplicated
2553    /// TODO: In this case we can sometimes convert the join to an INNER join
2554    #[test]
2555    fn filter_using_left_join() -> Result<()> {
2556        let table_scan = test_table_scan()?;
2557        let left = LogicalPlanBuilder::from(table_scan).build()?;
2558        let right_table_scan = test_table_scan_with_name("test2")?;
2559        let right = LogicalPlanBuilder::from(right_table_scan)
2560            .project(vec![col("a")])?
2561            .build()?;
2562        let plan = LogicalPlanBuilder::from(left)
2563            .join_using(
2564                right,
2565                JoinType::Left,
2566                vec![Column::from_name("a".to_string())],
2567            )?
2568            .filter(col("test2.a").lt_eq(lit(1i64)))?
2569            .build()?;
2570
2571        // not part of the test, just good to know:
2572        assert_snapshot!(plan,
2573        @r"
2574        Filter: test2.a <= Int64(1)
2575          Left Join: Using test.a = test2.a
2576            TableScan: test
2577            Projection: test2.a
2578              TableScan: test2
2579        ",
2580        );
2581        // filter not duplicated nor pushed down - i.e. noop
2582        assert_optimized_plan_equal!(
2583            plan,
2584            @r"
2585        Filter: test2.a <= Int64(1)
2586          Left Join: Using test.a = test2.a
2587            TableScan: test, full_filters=[test.a <= Int64(1)]
2588            Projection: test2.a
2589              TableScan: test2
2590        "
2591        )
2592    }
2593
2594    /// post-join predicates on the left side of a right join are not duplicated
2595    #[test]
2596    fn filter_using_right_join() -> Result<()> {
2597        let table_scan = test_table_scan()?;
2598        let left = LogicalPlanBuilder::from(table_scan).build()?;
2599        let right_table_scan = test_table_scan_with_name("test2")?;
2600        let right = LogicalPlanBuilder::from(right_table_scan)
2601            .project(vec![col("a")])?
2602            .build()?;
2603        let plan = LogicalPlanBuilder::from(left)
2604            .join_using(
2605                right,
2606                JoinType::Right,
2607                vec![Column::from_name("a".to_string())],
2608            )?
2609            .filter(col("test.a").lt_eq(lit(1i64)))?
2610            .build()?;
2611
2612        // not part of the test, just good to know:
2613        assert_snapshot!(plan,
2614        @r"
2615        Filter: test.a <= Int64(1)
2616          Right Join: Using test.a = test2.a
2617            TableScan: test
2618            Projection: test2.a
2619              TableScan: test2
2620        ",
2621        );
2622        // filter not duplicated nor pushed down - i.e. noop
2623        assert_optimized_plan_equal!(
2624            plan,
2625            @r"
2626        Filter: test.a <= Int64(1)
2627          Right Join: Using test.a = test2.a
2628            TableScan: test
2629            Projection: test2.a
2630              TableScan: test2, full_filters=[test2.a <= Int64(1)]
2631        "
2632        )
2633    }
2634
2635    /// post-left-join predicate on a column common to both sides is only pushed to the left side
2636    /// i.e. - not duplicated to the right side
2637    #[test]
2638    fn filter_using_left_join_on_common() -> Result<()> {
2639        let table_scan = test_table_scan()?;
2640        let left = LogicalPlanBuilder::from(table_scan).build()?;
2641        let right_table_scan = test_table_scan_with_name("test2")?;
2642        let right = LogicalPlanBuilder::from(right_table_scan)
2643            .project(vec![col("a")])?
2644            .build()?;
2645        let plan = LogicalPlanBuilder::from(left)
2646            .join_using(
2647                right,
2648                JoinType::Left,
2649                vec![Column::from_name("a".to_string())],
2650            )?
2651            .filter(col("a").lt_eq(lit(1i64)))?
2652            .build()?;
2653
2654        // not part of the test, just good to know:
2655        assert_snapshot!(plan,
2656        @r"
2657        Filter: test.a <= Int64(1)
2658          Left Join: Using test.a = test2.a
2659            TableScan: test
2660            Projection: test2.a
2661              TableScan: test2
2662        ",
2663        );
2664        // filter sent to left side of the join, not the right
2665        assert_optimized_plan_equal!(
2666            plan,
2667            @r"
2668        Left Join: Using test.a = test2.a
2669          TableScan: test, full_filters=[test.a <= Int64(1)]
2670          Projection: test2.a
2671            TableScan: test2
2672        "
2673        )
2674    }
2675
2676    /// post-right-join predicate on a column common to both sides is only pushed to the right side
2677    /// i.e. - not duplicated to the left side.
2678    #[test]
2679    fn filter_using_right_join_on_common() -> Result<()> {
2680        let table_scan = test_table_scan()?;
2681        let left = LogicalPlanBuilder::from(table_scan).build()?;
2682        let right_table_scan = test_table_scan_with_name("test2")?;
2683        let right = LogicalPlanBuilder::from(right_table_scan)
2684            .project(vec![col("a")])?
2685            .build()?;
2686        let plan = LogicalPlanBuilder::from(left)
2687            .join_using(
2688                right,
2689                JoinType::Right,
2690                vec![Column::from_name("a".to_string())],
2691            )?
2692            .filter(col("test2.a").lt_eq(lit(1i64)))?
2693            .build()?;
2694
2695        // not part of the test, just good to know:
2696        assert_snapshot!(plan,
2697        @r"
2698        Filter: test2.a <= Int64(1)
2699          Right Join: Using test.a = test2.a
2700            TableScan: test
2701            Projection: test2.a
2702              TableScan: test2
2703        ",
2704        );
2705        // filter sent to right side of join, not duplicated to the left
2706        assert_optimized_plan_equal!(
2707            plan,
2708            @r"
2709        Right Join: Using test.a = test2.a
2710          TableScan: test
2711          Projection: test2.a
2712            TableScan: test2, full_filters=[test2.a <= Int64(1)]
2713        "
2714        )
2715    }
2716
2717    /// single table predicate parts of ON condition should be pushed to both inputs
2718    #[test]
2719    fn join_on_with_filter() -> Result<()> {
2720        let table_scan = test_table_scan()?;
2721        let left = LogicalPlanBuilder::from(table_scan)
2722            .project(vec![col("a"), col("b"), col("c")])?
2723            .build()?;
2724        let right_table_scan = test_table_scan_with_name("test2")?;
2725        let right = LogicalPlanBuilder::from(right_table_scan)
2726            .project(vec![col("a"), col("b"), col("c")])?
2727            .build()?;
2728        let filter = col("test.c")
2729            .gt(lit(1u32))
2730            .and(col("test.b").lt(col("test2.b")))
2731            .and(col("test2.c").gt(lit(4u32)));
2732        let plan = LogicalPlanBuilder::from(left)
2733            .join(
2734                right,
2735                JoinType::Inner,
2736                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2737                Some(filter),
2738            )?
2739            .build()?;
2740
2741        // not part of the test, just good to know:
2742        assert_snapshot!(plan,
2743        @r"
2744        Inner Join: test.a = test2.a Filter: test.c > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
2745          Projection: test.a, test.b, test.c
2746            TableScan: test
2747          Projection: test2.a, test2.b, test2.c
2748            TableScan: test2
2749        ",
2750        );
2751        assert_optimized_plan_equal!(
2752            plan,
2753            @r"
2754        Inner Join: test.a = test2.a Filter: test.b < test2.b
2755          Projection: test.a, test.b, test.c
2756            TableScan: test, full_filters=[test.c > UInt32(1)]
2757          Projection: test2.a, test2.b, test2.c
2758            TableScan: test2, full_filters=[test2.c > UInt32(4)]
2759        "
2760        )
2761    }
2762
2763    /// join filter should be completely removed after pushdown
2764    #[test]
2765    fn join_filter_removed() -> Result<()> {
2766        let table_scan = test_table_scan()?;
2767        let left = LogicalPlanBuilder::from(table_scan)
2768            .project(vec![col("a"), col("b"), col("c")])?
2769            .build()?;
2770        let right_table_scan = test_table_scan_with_name("test2")?;
2771        let right = LogicalPlanBuilder::from(right_table_scan)
2772            .project(vec![col("a"), col("b"), col("c")])?
2773            .build()?;
2774        let filter = col("test.b")
2775            .gt(lit(1u32))
2776            .and(col("test2.c").gt(lit(4u32)));
2777        let plan = LogicalPlanBuilder::from(left)
2778            .join(
2779                right,
2780                JoinType::Inner,
2781                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2782                Some(filter),
2783            )?
2784            .build()?;
2785
2786        // not part of the test, just good to know:
2787        assert_snapshot!(plan,
2788        @r"
2789        Inner Join: test.a = test2.a Filter: test.b > UInt32(1) AND test2.c > UInt32(4)
2790          Projection: test.a, test.b, test.c
2791            TableScan: test
2792          Projection: test2.a, test2.b, test2.c
2793            TableScan: test2
2794        ",
2795        );
2796        assert_optimized_plan_equal!(
2797            plan,
2798            @r"
2799        Inner Join: test.a = test2.a
2800          Projection: test.a, test.b, test.c
2801            TableScan: test, full_filters=[test.b > UInt32(1)]
2802          Projection: test2.a, test2.b, test2.c
2803            TableScan: test2, full_filters=[test2.c > UInt32(4)]
2804        "
2805        )
2806    }
2807
2808    /// predicate on join key in filter expression should be pushed down to both inputs
2809    #[test]
2810    fn join_filter_on_common() -> Result<()> {
2811        let table_scan = test_table_scan()?;
2812        let left = LogicalPlanBuilder::from(table_scan)
2813            .project(vec![col("a")])?
2814            .build()?;
2815        let right_table_scan = test_table_scan_with_name("test2")?;
2816        let right = LogicalPlanBuilder::from(right_table_scan)
2817            .project(vec![col("b")])?
2818            .build()?;
2819        let filter = col("test.a").gt(lit(1u32));
2820        let plan = LogicalPlanBuilder::from(left)
2821            .join(
2822                right,
2823                JoinType::Inner,
2824                (vec![Column::from_name("a")], vec![Column::from_name("b")]),
2825                Some(filter),
2826            )?
2827            .build()?;
2828
2829        // not part of the test, just good to know:
2830        assert_snapshot!(plan,
2831        @r"
2832        Inner Join: test.a = test2.b Filter: test.a > UInt32(1)
2833          Projection: test.a
2834            TableScan: test
2835          Projection: test2.b
2836            TableScan: test2
2837        ",
2838        );
2839        assert_optimized_plan_equal!(
2840            plan,
2841            @r"
2842        Inner Join: test.a = test2.b
2843          Projection: test.a
2844            TableScan: test, full_filters=[test.a > UInt32(1)]
2845          Projection: test2.b
2846            TableScan: test2, full_filters=[test2.b > UInt32(1)]
2847        "
2848        )
2849    }
2850
2851    /// single table predicate parts of ON condition should be pushed to right input
2852    #[test]
2853    fn left_join_on_with_filter() -> Result<()> {
2854        let table_scan = test_table_scan()?;
2855        let left = LogicalPlanBuilder::from(table_scan)
2856            .project(vec![col("a"), col("b"), col("c")])?
2857            .build()?;
2858        let right_table_scan = test_table_scan_with_name("test2")?;
2859        let right = LogicalPlanBuilder::from(right_table_scan)
2860            .project(vec![col("a"), col("b"), col("c")])?
2861            .build()?;
2862        let filter = col("test.a")
2863            .gt(lit(1u32))
2864            .and(col("test.b").lt(col("test2.b")))
2865            .and(col("test2.c").gt(lit(4u32)));
2866        let plan = LogicalPlanBuilder::from(left)
2867            .join(
2868                right,
2869                JoinType::Left,
2870                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2871                Some(filter),
2872            )?
2873            .build()?;
2874
2875        // not part of the test, just good to know:
2876        assert_snapshot!(plan,
2877        @r"
2878        Left Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
2879          Projection: test.a, test.b, test.c
2880            TableScan: test
2881          Projection: test2.a, test2.b, test2.c
2882            TableScan: test2
2883        ",
2884        );
2885        assert_optimized_plan_equal!(
2886            plan,
2887            @r"
2888        Left Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b
2889          Projection: test.a, test.b, test.c
2890            TableScan: test
2891          Projection: test2.a, test2.b, test2.c
2892            TableScan: test2, full_filters=[test2.c > UInt32(4)]
2893        "
2894        )
2895    }
2896
2897    /// single table predicate parts of ON condition should be pushed to left input
2898    #[test]
2899    fn right_join_on_with_filter() -> Result<()> {
2900        let table_scan = test_table_scan()?;
2901        let left = LogicalPlanBuilder::from(table_scan)
2902            .project(vec![col("a"), col("b"), col("c")])?
2903            .build()?;
2904        let right_table_scan = test_table_scan_with_name("test2")?;
2905        let right = LogicalPlanBuilder::from(right_table_scan)
2906            .project(vec![col("a"), col("b"), col("c")])?
2907            .build()?;
2908        let filter = col("test.a")
2909            .gt(lit(1u32))
2910            .and(col("test.b").lt(col("test2.b")))
2911            .and(col("test2.c").gt(lit(4u32)));
2912        let plan = LogicalPlanBuilder::from(left)
2913            .join(
2914                right,
2915                JoinType::Right,
2916                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2917                Some(filter),
2918            )?
2919            .build()?;
2920
2921        // not part of the test, just good to know:
2922        assert_snapshot!(plan,
2923        @r"
2924        Right Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
2925          Projection: test.a, test.b, test.c
2926            TableScan: test
2927          Projection: test2.a, test2.b, test2.c
2928            TableScan: test2
2929        ",
2930        );
2931        assert_optimized_plan_equal!(
2932            plan,
2933            @r"
2934        Right Join: test.a = test2.a Filter: test.b < test2.b AND test2.c > UInt32(4)
2935          Projection: test.a, test.b, test.c
2936            TableScan: test, full_filters=[test.a > UInt32(1)]
2937          Projection: test2.a, test2.b, test2.c
2938            TableScan: test2
2939        "
2940        )
2941    }
2942
2943    /// single table predicate parts of ON condition should not be pushed
2944    #[test]
2945    fn full_join_on_with_filter() -> Result<()> {
2946        let table_scan = test_table_scan()?;
2947        let left = LogicalPlanBuilder::from(table_scan)
2948            .project(vec![col("a"), col("b"), col("c")])?
2949            .build()?;
2950        let right_table_scan = test_table_scan_with_name("test2")?;
2951        let right = LogicalPlanBuilder::from(right_table_scan)
2952            .project(vec![col("a"), col("b"), col("c")])?
2953            .build()?;
2954        let filter = col("test.a")
2955            .gt(lit(1u32))
2956            .and(col("test.b").lt(col("test2.b")))
2957            .and(col("test2.c").gt(lit(4u32)));
2958        let plan = LogicalPlanBuilder::from(left)
2959            .join(
2960                right,
2961                JoinType::Full,
2962                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2963                Some(filter),
2964            )?
2965            .build()?;
2966
2967        // not part of the test, just good to know:
2968        assert_snapshot!(plan,
2969        @r"
2970        Full Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
2971          Projection: test.a, test.b, test.c
2972            TableScan: test
2973          Projection: test2.a, test2.b, test2.c
2974            TableScan: test2
2975        ",
2976        );
2977        assert_optimized_plan_equal!(
2978            plan,
2979            @r"
2980        Full Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
2981          Projection: test.a, test.b, test.c
2982            TableScan: test
2983          Projection: test2.a, test2.b, test2.c
2984            TableScan: test2
2985        "
2986        )
2987    }
2988
2989    struct PushDownProvider {
2990        pub filter_support: TableProviderFilterPushDown,
2991    }
2992
2993    #[async_trait]
2994    impl TableSource for PushDownProvider {
2995        fn schema(&self) -> SchemaRef {
2996            Arc::new(Schema::new(vec![
2997                Field::new("a", DataType::Int32, true),
2998                Field::new("b", DataType::Int32, true),
2999            ]))
3000        }
3001
3002        fn table_type(&self) -> TableType {
3003            TableType::Base
3004        }
3005
3006        fn supports_filters_pushdown(
3007            &self,
3008            filters: &[&Expr],
3009        ) -> Result<Vec<TableProviderFilterPushDown>> {
3010            Ok((0..filters.len())
3011                .map(|_| self.filter_support.clone())
3012                .collect())
3013        }
3014
3015        fn as_any(&self) -> &dyn Any {
3016            self
3017        }
3018    }
3019
3020    fn table_scan_with_pushdown_provider_builder(
3021        filter_support: TableProviderFilterPushDown,
3022        filters: Vec<Expr>,
3023        projection: Option<Vec<usize>>,
3024    ) -> Result<LogicalPlanBuilder> {
3025        let test_provider = PushDownProvider { filter_support };
3026
3027        let table_scan = LogicalPlan::TableScan(TableScan {
3028            table_name: "test".into(),
3029            filters,
3030            projected_schema: Arc::new(DFSchema::try_from(
3031                (*test_provider.schema()).clone(),
3032            )?),
3033            projection,
3034            source: Arc::new(test_provider),
3035            fetch: None,
3036        });
3037
3038        Ok(LogicalPlanBuilder::from(table_scan))
3039    }
3040
3041    fn table_scan_with_pushdown_provider(
3042        filter_support: TableProviderFilterPushDown,
3043    ) -> Result<LogicalPlan> {
3044        table_scan_with_pushdown_provider_builder(filter_support, vec![], None)?
3045            .filter(col("a").eq(lit(1i64)))?
3046            .build()
3047    }
3048
3049    #[test]
3050    fn filter_with_table_provider_exact() -> Result<()> {
3051        let plan = table_scan_with_pushdown_provider(TableProviderFilterPushDown::Exact)?;
3052
3053        assert_optimized_plan_equal!(
3054            plan,
3055            @"TableScan: test, full_filters=[a = Int64(1)]"
3056        )
3057    }
3058
3059    #[test]
3060    fn filter_with_table_provider_inexact() -> Result<()> {
3061        let plan =
3062            table_scan_with_pushdown_provider(TableProviderFilterPushDown::Inexact)?;
3063
3064        assert_optimized_plan_equal!(
3065            plan,
3066            @r"
3067        Filter: a = Int64(1)
3068          TableScan: test, partial_filters=[a = Int64(1)]
3069        "
3070        )
3071    }
3072
3073    #[test]
3074    fn filter_with_table_provider_multiple_invocations() -> Result<()> {
3075        let plan =
3076            table_scan_with_pushdown_provider(TableProviderFilterPushDown::Inexact)?;
3077
3078        let optimized_plan = PushDownFilter::new()
3079            .rewrite(plan, &OptimizerContext::new())
3080            .expect("failed to optimize plan")
3081            .data;
3082
3083        // Optimizing the same plan multiple times should produce the same plan
3084        // each time.
3085        assert_optimized_plan_equal!(
3086            optimized_plan,
3087            @r"
3088        Filter: a = Int64(1)
3089          TableScan: test, partial_filters=[a = Int64(1)]
3090        "
3091        )
3092    }
3093
3094    #[test]
3095    fn filter_with_table_provider_unsupported() -> Result<()> {
3096        let plan =
3097            table_scan_with_pushdown_provider(TableProviderFilterPushDown::Unsupported)?;
3098
3099        assert_optimized_plan_equal!(
3100            plan,
3101            @r"
3102        Filter: a = Int64(1)
3103          TableScan: test
3104        "
3105        )
3106    }
3107
3108    #[test]
3109    fn multi_combined_filter() -> Result<()> {
3110        let plan = table_scan_with_pushdown_provider_builder(
3111            TableProviderFilterPushDown::Inexact,
3112            vec![col("a").eq(lit(10i64)), col("b").gt(lit(11i64))],
3113            Some(vec![0]),
3114        )?
3115        .filter(and(col("a").eq(lit(10i64)), col("b").gt(lit(11i64))))?
3116        .project(vec![col("a"), col("b")])?
3117        .build()?;
3118
3119        assert_optimized_plan_equal!(
3120            plan,
3121            @r"
3122        Projection: a, b
3123          Filter: a = Int64(10) AND b > Int64(11)
3124            TableScan: test projection=[a], partial_filters=[a = Int64(10), b > Int64(11)]
3125        "
3126        )
3127    }
3128
3129    #[test]
3130    fn multi_combined_filter_exact() -> Result<()> {
3131        let plan = table_scan_with_pushdown_provider_builder(
3132            TableProviderFilterPushDown::Exact,
3133            vec![],
3134            Some(vec![0]),
3135        )?
3136        .filter(and(col("a").eq(lit(10i64)), col("b").gt(lit(11i64))))?
3137        .project(vec![col("a"), col("b")])?
3138        .build()?;
3139
3140        assert_optimized_plan_equal!(
3141            plan,
3142            @r"
3143        Projection: a, b
3144          TableScan: test projection=[a], full_filters=[a = Int64(10), b > Int64(11)]
3145        "
3146        )
3147    }
3148
3149    #[test]
3150    fn test_filter_with_alias() -> Result<()> {
3151        // in table scan the true col name is 'test.a',
3152        // but we rename it as 'b', and use col 'b' in filter
3153        // we need rewrite filter col before push down.
3154        let table_scan = test_table_scan()?;
3155        let plan = LogicalPlanBuilder::from(table_scan)
3156            .project(vec![col("a").alias("b"), col("c")])?
3157            .filter(and(col("b").gt(lit(10i64)), col("c").gt(lit(10i64))))?
3158            .build()?;
3159
3160        // filter on col b
3161        assert_snapshot!(plan,
3162        @r"
3163        Filter: b > Int64(10) AND test.c > Int64(10)
3164          Projection: test.a AS b, test.c
3165            TableScan: test
3166        ",
3167        );
3168        // rewrite filter col b to test.a
3169        assert_optimized_plan_equal!(
3170            plan,
3171            @r"
3172        Projection: test.a AS b, test.c
3173          TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]
3174        "
3175        )
3176    }
3177
3178    #[test]
3179    fn test_filter_with_alias_2() -> Result<()> {
3180        // in table scan the true col name is 'test.a',
3181        // but we rename it as 'b', and use col 'b' in filter
3182        // we need rewrite filter col before push down.
3183        let table_scan = test_table_scan()?;
3184        let plan = LogicalPlanBuilder::from(table_scan)
3185            .project(vec![col("a").alias("b"), col("c")])?
3186            .project(vec![col("b"), col("c")])?
3187            .filter(and(col("b").gt(lit(10i64)), col("c").gt(lit(10i64))))?
3188            .build()?;
3189
3190        // filter on col b
3191        assert_snapshot!(plan,
3192        @r"
3193        Filter: b > Int64(10) AND test.c > Int64(10)
3194          Projection: b, test.c
3195            Projection: test.a AS b, test.c
3196              TableScan: test
3197        ",
3198        );
3199        // rewrite filter col b to test.a
3200        assert_optimized_plan_equal!(
3201            plan,
3202            @r"
3203        Projection: b, test.c
3204          Projection: test.a AS b, test.c
3205            TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]
3206        "
3207        )
3208    }
3209
3210    #[test]
3211    fn test_filter_with_multi_alias() -> Result<()> {
3212        let table_scan = test_table_scan()?;
3213        let plan = LogicalPlanBuilder::from(table_scan)
3214            .project(vec![col("a").alias("b"), col("c").alias("d")])?
3215            .filter(and(col("b").gt(lit(10i64)), col("d").gt(lit(10i64))))?
3216            .build()?;
3217
3218        // filter on col b and d
3219        assert_snapshot!(plan,
3220        @r"
3221        Filter: b > Int64(10) AND d > Int64(10)
3222          Projection: test.a AS b, test.c AS d
3223            TableScan: test
3224        ",
3225        );
3226        // rewrite filter col b to test.a, col d to test.c
3227        assert_optimized_plan_equal!(
3228            plan,
3229            @r"
3230        Projection: test.a AS b, test.c AS d
3231          TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]
3232        "
3233        )
3234    }
3235
3236    /// predicate on join key in filter expression should be pushed down to both inputs
3237    #[test]
3238    fn join_filter_with_alias() -> Result<()> {
3239        let table_scan = test_table_scan()?;
3240        let left = LogicalPlanBuilder::from(table_scan)
3241            .project(vec![col("a").alias("c")])?
3242            .build()?;
3243        let right_table_scan = test_table_scan_with_name("test2")?;
3244        let right = LogicalPlanBuilder::from(right_table_scan)
3245            .project(vec![col("b").alias("d")])?
3246            .build()?;
3247        let filter = col("c").gt(lit(1u32));
3248        let plan = LogicalPlanBuilder::from(left)
3249            .join(
3250                right,
3251                JoinType::Inner,
3252                (vec![Column::from_name("c")], vec![Column::from_name("d")]),
3253                Some(filter),
3254            )?
3255            .build()?;
3256
3257        assert_snapshot!(plan,
3258        @r"
3259        Inner Join: c = d Filter: c > UInt32(1)
3260          Projection: test.a AS c
3261            TableScan: test
3262          Projection: test2.b AS d
3263            TableScan: test2
3264        ",
3265        );
3266        // Change filter on col `c`, 'd' to `test.a`, 'test.b'
3267        assert_optimized_plan_equal!(
3268            plan,
3269            @r"
3270        Inner Join: c = d
3271          Projection: test.a AS c
3272            TableScan: test, full_filters=[test.a > UInt32(1)]
3273          Projection: test2.b AS d
3274            TableScan: test2, full_filters=[test2.b > UInt32(1)]
3275        "
3276        )
3277    }
3278
3279    #[test]
3280    fn test_in_filter_with_alias() -> Result<()> {
3281        // in table scan the true col name is 'test.a',
3282        // but we rename it as 'b', and use col 'b' in filter
3283        // we need rewrite filter col before push down.
3284        let table_scan = test_table_scan()?;
3285        let filter_value = vec![lit(1u32), lit(2u32), lit(3u32), lit(4u32)];
3286        let plan = LogicalPlanBuilder::from(table_scan)
3287            .project(vec![col("a").alias("b"), col("c")])?
3288            .filter(in_list(col("b"), filter_value, false))?
3289            .build()?;
3290
3291        // filter on col b
3292        assert_snapshot!(plan,
3293        @r"
3294        Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])
3295          Projection: test.a AS b, test.c
3296            TableScan: test
3297        ",
3298        );
3299        // rewrite filter col b to test.a
3300        assert_optimized_plan_equal!(
3301            plan,
3302            @r"
3303        Projection: test.a AS b, test.c
3304          TableScan: test, full_filters=[test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])]
3305        "
3306        )
3307    }
3308
3309    #[test]
3310    fn test_in_filter_with_alias_2() -> Result<()> {
3311        // in table scan the true col name is 'test.a',
3312        // but we rename it as 'b', and use col 'b' in filter
3313        // we need rewrite filter col before push down.
3314        let table_scan = test_table_scan()?;
3315        let filter_value = vec![lit(1u32), lit(2u32), lit(3u32), lit(4u32)];
3316        let plan = LogicalPlanBuilder::from(table_scan)
3317            .project(vec![col("a").alias("b"), col("c")])?
3318            .project(vec![col("b"), col("c")])?
3319            .filter(in_list(col("b"), filter_value, false))?
3320            .build()?;
3321
3322        // filter on col b
3323        assert_snapshot!(plan,
3324        @r"
3325        Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])
3326          Projection: b, test.c
3327            Projection: test.a AS b, test.c
3328              TableScan: test
3329        ",
3330        );
3331        // rewrite filter col b to test.a
3332        assert_optimized_plan_equal!(
3333            plan,
3334            @r"
3335        Projection: b, test.c
3336          Projection: test.a AS b, test.c
3337            TableScan: test, full_filters=[test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])]
3338        "
3339        )
3340    }
3341
3342    #[test]
3343    fn test_in_subquery_with_alias() -> Result<()> {
3344        // in table scan the true col name is 'test.a',
3345        // but we rename it as 'b', and use col 'b' in subquery filter
3346        let table_scan = test_table_scan()?;
3347        let table_scan_sq = test_table_scan_with_name("sq")?;
3348        let subplan = Arc::new(
3349            LogicalPlanBuilder::from(table_scan_sq)
3350                .project(vec![col("c")])?
3351                .build()?,
3352        );
3353        let plan = LogicalPlanBuilder::from(table_scan)
3354            .project(vec![col("a").alias("b"), col("c")])?
3355            .filter(in_subquery(col("b"), subplan))?
3356            .build()?;
3357
3358        // filter on col b in subquery
3359        assert_snapshot!(plan,
3360        @r"
3361        Filter: b IN (<subquery>)
3362          Subquery:
3363            Projection: sq.c
3364              TableScan: sq
3365          Projection: test.a AS b, test.c
3366            TableScan: test
3367        ",
3368        );
3369        // rewrite filter col b to test.a
3370        assert_optimized_plan_equal!(
3371            plan,
3372            @r"
3373        Projection: test.a AS b, test.c
3374          TableScan: test, full_filters=[test.a IN (<subquery>)]
3375            Subquery:
3376              Projection: sq.c
3377                TableScan: sq
3378        "
3379        )
3380    }
3381
3382    #[test]
3383    fn test_propagation_of_optimized_inner_filters_with_projections() -> Result<()> {
3384        // SELECT a FROM (SELECT 1 AS a) b WHERE b.a = 1
3385        let plan = LogicalPlanBuilder::empty(true)
3386            .project(vec![lit(0i64).alias("a")])?
3387            .alias("b")?
3388            .project(vec![col("b.a")])?
3389            .alias("b")?
3390            .filter(col("b.a").eq(lit(1i64)))?
3391            .project(vec![col("b.a")])?
3392            .build()?;
3393
3394        assert_snapshot!(plan,
3395        @r"
3396        Projection: b.a
3397          Filter: b.a = Int64(1)
3398            SubqueryAlias: b
3399              Projection: b.a
3400                SubqueryAlias: b
3401                  Projection: Int64(0) AS a
3402                    EmptyRelation
3403        ",
3404        );
3405        // Ensure that the predicate without any columns (0 = 1) is
3406        // still there.
3407        assert_optimized_plan_equal!(
3408            plan,
3409            @r"
3410        Projection: b.a
3411          SubqueryAlias: b
3412            Projection: b.a
3413              SubqueryAlias: b
3414                Projection: Int64(0) AS a
3415                  Filter: Int64(0) = Int64(1)
3416                    EmptyRelation
3417        "
3418        )
3419    }
3420
3421    #[test]
3422    fn test_crossjoin_with_or_clause() -> Result<()> {
3423        // select * from test,test1 where (test.a = test1.a and test.b > 1) or (test.b = test1.b and test.c < 10);
3424        let table_scan = test_table_scan()?;
3425        let left = LogicalPlanBuilder::from(table_scan)
3426            .project(vec![col("a"), col("b"), col("c")])?
3427            .build()?;
3428        let right_table_scan = test_table_scan_with_name("test1")?;
3429        let right = LogicalPlanBuilder::from(right_table_scan)
3430            .project(vec![col("a").alias("d"), col("a").alias("e")])?
3431            .build()?;
3432        let filter = or(
3433            and(col("a").eq(col("d")), col("b").gt(lit(1u32))),
3434            and(col("b").eq(col("e")), col("c").lt(lit(10u32))),
3435        );
3436        let plan = LogicalPlanBuilder::from(left)
3437            .cross_join(right)?
3438            .filter(filter)?
3439            .build()?;
3440
3441        assert_optimized_plan_eq_with_rewrite_predicate!(plan.clone(), @r"
3442        Inner Join:  Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10)
3443          Projection: test.a, test.b, test.c
3444            TableScan: test, full_filters=[test.b > UInt32(1) OR test.c < UInt32(10)]
3445          Projection: test1.a AS d, test1.a AS e
3446            TableScan: test1
3447        ")?;
3448
3449        // Originally global state which can help to avoid duplicate Filters been generated and pushed down.
3450        // Now the global state is removed. Need to double confirm that avoid duplicate Filters.
3451        let optimized_plan = PushDownFilter::new()
3452            .rewrite(plan, &OptimizerContext::new())
3453            .expect("failed to optimize plan")
3454            .data;
3455        assert_optimized_plan_equal!(
3456            optimized_plan,
3457            @r"
3458        Inner Join:  Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10)
3459          Projection: test.a, test.b, test.c
3460            TableScan: test, full_filters=[test.b > UInt32(1) OR test.c < UInt32(10)]
3461          Projection: test1.a AS d, test1.a AS e
3462            TableScan: test1
3463        "
3464        )
3465    }
3466
3467    #[test]
3468    fn left_semi_join() -> Result<()> {
3469        let left = test_table_scan_with_name("test1")?;
3470        let right_table_scan = test_table_scan_with_name("test2")?;
3471        let right = LogicalPlanBuilder::from(right_table_scan)
3472            .project(vec![col("a"), col("b")])?
3473            .build()?;
3474        let plan = LogicalPlanBuilder::from(left)
3475            .join(
3476                right,
3477                JoinType::LeftSemi,
3478                (
3479                    vec![Column::from_qualified_name("test1.a")],
3480                    vec![Column::from_qualified_name("test2.a")],
3481                ),
3482                None,
3483            )?
3484            .filter(col("test2.a").lt_eq(lit(1i64)))?
3485            .build()?;
3486
3487        // not part of the test, just good to know:
3488        assert_snapshot!(plan,
3489        @r"
3490        Filter: test2.a <= Int64(1)
3491          LeftSemi Join: test1.a = test2.a
3492            TableScan: test1
3493            Projection: test2.a, test2.b
3494              TableScan: test2
3495        ",
3496        );
3497        // Inferred the predicate `test1.a <= Int64(1)` and push it down to the left side.
3498        assert_optimized_plan_equal!(
3499            plan,
3500            @r"
3501        Filter: test2.a <= Int64(1)
3502          LeftSemi Join: test1.a = test2.a
3503            TableScan: test1, full_filters=[test1.a <= Int64(1)]
3504            Projection: test2.a, test2.b
3505              TableScan: test2
3506        "
3507        )
3508    }
3509
3510    #[test]
3511    fn left_semi_join_with_filters() -> Result<()> {
3512        let left = test_table_scan_with_name("test1")?;
3513        let right_table_scan = test_table_scan_with_name("test2")?;
3514        let right = LogicalPlanBuilder::from(right_table_scan)
3515            .project(vec![col("a"), col("b")])?
3516            .build()?;
3517        let plan = LogicalPlanBuilder::from(left)
3518            .join(
3519                right,
3520                JoinType::LeftSemi,
3521                (
3522                    vec![Column::from_qualified_name("test1.a")],
3523                    vec![Column::from_qualified_name("test2.a")],
3524                ),
3525                Some(
3526                    col("test1.b")
3527                        .gt(lit(1u32))
3528                        .and(col("test2.b").gt(lit(2u32))),
3529                ),
3530            )?
3531            .build()?;
3532
3533        // not part of the test, just good to know:
3534        assert_snapshot!(plan,
3535        @r"
3536        LeftSemi Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
3537          TableScan: test1
3538          Projection: test2.a, test2.b
3539            TableScan: test2
3540        ",
3541        );
3542        // Both side will be pushed down.
3543        assert_optimized_plan_equal!(
3544            plan,
3545            @r"
3546        LeftSemi Join: test1.a = test2.a
3547          TableScan: test1, full_filters=[test1.b > UInt32(1)]
3548          Projection: test2.a, test2.b
3549            TableScan: test2, full_filters=[test2.b > UInt32(2)]
3550        "
3551        )
3552    }
3553
3554    #[test]
3555    fn right_semi_join() -> Result<()> {
3556        let left = test_table_scan_with_name("test1")?;
3557        let right_table_scan = test_table_scan_with_name("test2")?;
3558        let right = LogicalPlanBuilder::from(right_table_scan)
3559            .project(vec![col("a"), col("b")])?
3560            .build()?;
3561        let plan = LogicalPlanBuilder::from(left)
3562            .join(
3563                right,
3564                JoinType::RightSemi,
3565                (
3566                    vec![Column::from_qualified_name("test1.a")],
3567                    vec![Column::from_qualified_name("test2.a")],
3568                ),
3569                None,
3570            )?
3571            .filter(col("test1.a").lt_eq(lit(1i64)))?
3572            .build()?;
3573
3574        // not part of the test, just good to know:
3575        assert_snapshot!(plan,
3576        @r"
3577        Filter: test1.a <= Int64(1)
3578          RightSemi Join: test1.a = test2.a
3579            TableScan: test1
3580            Projection: test2.a, test2.b
3581              TableScan: test2
3582        ",
3583        );
3584        // Inferred the predicate `test2.a <= Int64(1)` and push it down to the right side.
3585        assert_optimized_plan_equal!(
3586            plan,
3587            @r"
3588        Filter: test1.a <= Int64(1)
3589          RightSemi Join: test1.a = test2.a
3590            TableScan: test1
3591            Projection: test2.a, test2.b
3592              TableScan: test2, full_filters=[test2.a <= Int64(1)]
3593        "
3594        )
3595    }
3596
3597    #[test]
3598    fn right_semi_join_with_filters() -> Result<()> {
3599        let left = test_table_scan_with_name("test1")?;
3600        let right_table_scan = test_table_scan_with_name("test2")?;
3601        let right = LogicalPlanBuilder::from(right_table_scan)
3602            .project(vec![col("a"), col("b")])?
3603            .build()?;
3604        let plan = LogicalPlanBuilder::from(left)
3605            .join(
3606                right,
3607                JoinType::RightSemi,
3608                (
3609                    vec![Column::from_qualified_name("test1.a")],
3610                    vec![Column::from_qualified_name("test2.a")],
3611                ),
3612                Some(
3613                    col("test1.b")
3614                        .gt(lit(1u32))
3615                        .and(col("test2.b").gt(lit(2u32))),
3616                ),
3617            )?
3618            .build()?;
3619
3620        // not part of the test, just good to know:
3621        assert_snapshot!(plan,
3622        @r"
3623        RightSemi Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
3624          TableScan: test1
3625          Projection: test2.a, test2.b
3626            TableScan: test2
3627        ",
3628        );
3629        // Both side will be pushed down.
3630        assert_optimized_plan_equal!(
3631            plan,
3632            @r"
3633        RightSemi Join: test1.a = test2.a
3634          TableScan: test1, full_filters=[test1.b > UInt32(1)]
3635          Projection: test2.a, test2.b
3636            TableScan: test2, full_filters=[test2.b > UInt32(2)]
3637        "
3638        )
3639    }
3640
3641    #[test]
3642    fn left_anti_join() -> Result<()> {
3643        let table_scan = test_table_scan_with_name("test1")?;
3644        let left = LogicalPlanBuilder::from(table_scan)
3645            .project(vec![col("a"), col("b")])?
3646            .build()?;
3647        let right_table_scan = test_table_scan_with_name("test2")?;
3648        let right = LogicalPlanBuilder::from(right_table_scan)
3649            .project(vec![col("a"), col("b")])?
3650            .build()?;
3651        let plan = LogicalPlanBuilder::from(left)
3652            .join(
3653                right,
3654                JoinType::LeftAnti,
3655                (
3656                    vec![Column::from_qualified_name("test1.a")],
3657                    vec![Column::from_qualified_name("test2.a")],
3658                ),
3659                None,
3660            )?
3661            .filter(col("test2.a").gt(lit(2u32)))?
3662            .build()?;
3663
3664        // not part of the test, just good to know:
3665        assert_snapshot!(plan,
3666        @r"
3667        Filter: test2.a > UInt32(2)
3668          LeftAnti Join: test1.a = test2.a
3669            Projection: test1.a, test1.b
3670              TableScan: test1
3671            Projection: test2.a, test2.b
3672              TableScan: test2
3673        ",
3674        );
3675        // For left anti, filter of the right side filter can be pushed down.
3676        assert_optimized_plan_equal!(
3677            plan,
3678            @r"
3679        Filter: test2.a > UInt32(2)
3680          LeftAnti Join: test1.a = test2.a
3681            Projection: test1.a, test1.b
3682              TableScan: test1, full_filters=[test1.a > UInt32(2)]
3683            Projection: test2.a, test2.b
3684              TableScan: test2
3685        "
3686        )
3687    }
3688
3689    #[test]
3690    fn left_anti_join_with_filters() -> Result<()> {
3691        let table_scan = test_table_scan_with_name("test1")?;
3692        let left = LogicalPlanBuilder::from(table_scan)
3693            .project(vec![col("a"), col("b")])?
3694            .build()?;
3695        let right_table_scan = test_table_scan_with_name("test2")?;
3696        let right = LogicalPlanBuilder::from(right_table_scan)
3697            .project(vec![col("a"), col("b")])?
3698            .build()?;
3699        let plan = LogicalPlanBuilder::from(left)
3700            .join(
3701                right,
3702                JoinType::LeftAnti,
3703                (
3704                    vec![Column::from_qualified_name("test1.a")],
3705                    vec![Column::from_qualified_name("test2.a")],
3706                ),
3707                Some(
3708                    col("test1.b")
3709                        .gt(lit(1u32))
3710                        .and(col("test2.b").gt(lit(2u32))),
3711                ),
3712            )?
3713            .build()?;
3714
3715        // not part of the test, just good to know:
3716        assert_snapshot!(plan,
3717        @r"
3718        LeftAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
3719          Projection: test1.a, test1.b
3720            TableScan: test1
3721          Projection: test2.a, test2.b
3722            TableScan: test2
3723        ",
3724        );
3725        // For left anti, filter of the right side filter can be pushed down.
3726        assert_optimized_plan_equal!(
3727            plan,
3728            @r"
3729        LeftAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1)
3730          Projection: test1.a, test1.b
3731            TableScan: test1
3732          Projection: test2.a, test2.b
3733            TableScan: test2, full_filters=[test2.b > UInt32(2)]
3734        "
3735        )
3736    }
3737
3738    #[test]
3739    fn right_anti_join() -> Result<()> {
3740        let table_scan = test_table_scan_with_name("test1")?;
3741        let left = LogicalPlanBuilder::from(table_scan)
3742            .project(vec![col("a"), col("b")])?
3743            .build()?;
3744        let right_table_scan = test_table_scan_with_name("test2")?;
3745        let right = LogicalPlanBuilder::from(right_table_scan)
3746            .project(vec![col("a"), col("b")])?
3747            .build()?;
3748        let plan = LogicalPlanBuilder::from(left)
3749            .join(
3750                right,
3751                JoinType::RightAnti,
3752                (
3753                    vec![Column::from_qualified_name("test1.a")],
3754                    vec![Column::from_qualified_name("test2.a")],
3755                ),
3756                None,
3757            )?
3758            .filter(col("test1.a").gt(lit(2u32)))?
3759            .build()?;
3760
3761        // not part of the test, just good to know:
3762        assert_snapshot!(plan,
3763        @r"
3764        Filter: test1.a > UInt32(2)
3765          RightAnti Join: test1.a = test2.a
3766            Projection: test1.a, test1.b
3767              TableScan: test1
3768            Projection: test2.a, test2.b
3769              TableScan: test2
3770        ",
3771        );
3772        // For right anti, filter of the left side can be pushed down.
3773        assert_optimized_plan_equal!(
3774            plan,
3775            @r"
3776        Filter: test1.a > UInt32(2)
3777          RightAnti Join: test1.a = test2.a
3778            Projection: test1.a, test1.b
3779              TableScan: test1
3780            Projection: test2.a, test2.b
3781              TableScan: test2, full_filters=[test2.a > UInt32(2)]
3782        "
3783        )
3784    }
3785
3786    #[test]
3787    fn right_anti_join_with_filters() -> Result<()> {
3788        let table_scan = test_table_scan_with_name("test1")?;
3789        let left = LogicalPlanBuilder::from(table_scan)
3790            .project(vec![col("a"), col("b")])?
3791            .build()?;
3792        let right_table_scan = test_table_scan_with_name("test2")?;
3793        let right = LogicalPlanBuilder::from(right_table_scan)
3794            .project(vec![col("a"), col("b")])?
3795            .build()?;
3796        let plan = LogicalPlanBuilder::from(left)
3797            .join(
3798                right,
3799                JoinType::RightAnti,
3800                (
3801                    vec![Column::from_qualified_name("test1.a")],
3802                    vec![Column::from_qualified_name("test2.a")],
3803                ),
3804                Some(
3805                    col("test1.b")
3806                        .gt(lit(1u32))
3807                        .and(col("test2.b").gt(lit(2u32))),
3808                ),
3809            )?
3810            .build()?;
3811
3812        // not part of the test, just good to know:
3813        assert_snapshot!(plan,
3814        @r"
3815        RightAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
3816          Projection: test1.a, test1.b
3817            TableScan: test1
3818          Projection: test2.a, test2.b
3819            TableScan: test2
3820        ",
3821        );
3822        // For right anti, filter of the left side can be pushed down.
3823        assert_optimized_plan_equal!(
3824            plan,
3825            @r"
3826        RightAnti Join: test1.a = test2.a Filter: test2.b > UInt32(2)
3827          Projection: test1.a, test1.b
3828            TableScan: test1, full_filters=[test1.b > UInt32(1)]
3829          Projection: test2.a, test2.b
3830            TableScan: test2
3831        "
3832        )
3833    }
3834
    /// Scalar UDF used by the volatile-function tests below. Each test supplies
    /// the signature (and therefore the volatility) when constructing it.
    #[derive(Debug)]
    struct TestScalarUDF {
        /// Signature handed back from `ScalarUDFImpl::signature`.
        signature: Signature,
    }
3839
3840    impl ScalarUDFImpl for TestScalarUDF {
3841        fn as_any(&self) -> &dyn Any {
3842            self
3843        }
3844        fn name(&self) -> &str {
3845            "TestScalarUDF"
3846        }
3847
3848        fn signature(&self) -> &Signature {
3849            &self.signature
3850        }
3851
3852        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
3853            Ok(DataType::Int32)
3854        }
3855
3856        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
3857            Ok(ColumnarValue::Scalar(ScalarValue::from(1)))
3858        }
3859    }
3860
3861    #[test]
3862    fn test_push_down_volatile_function_in_aggregate() -> Result<()> {
3863        // SELECT t.a, t.r FROM (SELECT a, sum(b),  TestScalarUDF()+1 AS r FROM test1 GROUP BY a) AS t WHERE t.a > 5 AND t.r > 0.5;
3864        let table_scan = test_table_scan_with_name("test1")?;
3865        let fun = ScalarUDF::new_from_impl(TestScalarUDF {
3866            signature: Signature::exact(vec![], Volatility::Volatile),
3867        });
3868        let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
3869
3870        let plan = LogicalPlanBuilder::from(table_scan)
3871            .aggregate(vec![col("a")], vec![sum(col("b"))])?
3872            .project(vec![col("a"), sum(col("b")), add(expr, lit(1)).alias("r")])?
3873            .alias("t")?
3874            .filter(col("t.a").gt(lit(5)).and(col("t.r").gt(lit(0.5))))?
3875            .project(vec![col("t.a"), col("t.r")])?
3876            .build()?;
3877
3878        assert_snapshot!(plan,
3879        @r"
3880        Projection: t.a, t.r
3881          Filter: t.a > Int32(5) AND t.r > Float64(0.5)
3882            SubqueryAlias: t
3883              Projection: test1.a, sum(test1.b), TestScalarUDF() + Int32(1) AS r
3884                Aggregate: groupBy=[[test1.a]], aggr=[[sum(test1.b)]]
3885                  TableScan: test1
3886        ",
3887        );
3888        assert_optimized_plan_equal!(
3889            plan,
3890            @r"
3891        Projection: t.a, t.r
3892          SubqueryAlias: t
3893            Filter: r > Float64(0.5)
3894              Projection: test1.a, sum(test1.b), TestScalarUDF() + Int32(1) AS r
3895                Aggregate: groupBy=[[test1.a]], aggr=[[sum(test1.b)]]
3896                  TableScan: test1, full_filters=[test1.a > Int32(5)]
3897        "
3898        )
3899    }
3900
3901    #[test]
3902    fn test_push_down_volatile_function_in_join() -> Result<()> {
3903        // SELECT t.a, t.r FROM (SELECT test1.a AS a, TestScalarUDF() AS r FROM test1 join test2 ON test1.a = test2.a) AS t WHERE t.r > 0.5;
3904        let table_scan = test_table_scan_with_name("test1")?;
3905        let fun = ScalarUDF::new_from_impl(TestScalarUDF {
3906            signature: Signature::exact(vec![], Volatility::Volatile),
3907        });
3908        let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
3909        let left = LogicalPlanBuilder::from(table_scan).build()?;
3910        let right_table_scan = test_table_scan_with_name("test2")?;
3911        let right = LogicalPlanBuilder::from(right_table_scan).build()?;
3912        let plan = LogicalPlanBuilder::from(left)
3913            .join(
3914                right,
3915                JoinType::Inner,
3916                (
3917                    vec![Column::from_qualified_name("test1.a")],
3918                    vec![Column::from_qualified_name("test2.a")],
3919                ),
3920                None,
3921            )?
3922            .project(vec![col("test1.a").alias("a"), expr.alias("r")])?
3923            .alias("t")?
3924            .filter(col("t.r").gt(lit(0.8)))?
3925            .project(vec![col("t.a"), col("t.r")])?
3926            .build()?;
3927
3928        assert_snapshot!(plan,
3929        @r"
3930        Projection: t.a, t.r
3931          Filter: t.r > Float64(0.8)
3932            SubqueryAlias: t
3933              Projection: test1.a AS a, TestScalarUDF() AS r
3934                Inner Join: test1.a = test2.a
3935                  TableScan: test1
3936                  TableScan: test2
3937        ",
3938        );
3939        assert_optimized_plan_equal!(
3940            plan,
3941            @r"
3942        Projection: t.a, t.r
3943          SubqueryAlias: t
3944            Filter: r > Float64(0.8)
3945              Projection: test1.a AS a, TestScalarUDF() AS r
3946                Inner Join: test1.a = test2.a
3947                  TableScan: test1
3948                  TableScan: test2
3949        "
3950        )
3951    }
3952
3953    #[test]
3954    fn test_push_down_volatile_table_scan() -> Result<()> {
3955        // SELECT test.a, test.b FROM test as t WHERE TestScalarUDF() > 0.1;
3956        let table_scan = test_table_scan()?;
3957        let fun = ScalarUDF::new_from_impl(TestScalarUDF {
3958            signature: Signature::exact(vec![], Volatility::Volatile),
3959        });
3960        let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
3961        let plan = LogicalPlanBuilder::from(table_scan)
3962            .project(vec![col("a"), col("b")])?
3963            .filter(expr.gt(lit(0.1)))?
3964            .build()?;
3965
3966        assert_snapshot!(plan,
3967        @r"
3968        Filter: TestScalarUDF() > Float64(0.1)
3969          Projection: test.a, test.b
3970            TableScan: test
3971        ",
3972        );
3973        assert_optimized_plan_equal!(
3974            plan,
3975            @r"
3976        Projection: test.a, test.b
3977          Filter: TestScalarUDF() > Float64(0.1)
3978            TableScan: test
3979        "
3980        )
3981    }
3982
3983    #[test]
3984    fn test_push_down_volatile_mixed_table_scan() -> Result<()> {
3985        // SELECT test.a, test.b FROM test as t WHERE TestScalarUDF() > 0.1 and test.a > 5 and test.b > 10;
3986        let table_scan = test_table_scan()?;
3987        let fun = ScalarUDF::new_from_impl(TestScalarUDF {
3988            signature: Signature::exact(vec![], Volatility::Volatile),
3989        });
3990        let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
3991        let plan = LogicalPlanBuilder::from(table_scan)
3992            .project(vec![col("a"), col("b")])?
3993            .filter(
3994                expr.gt(lit(0.1))
3995                    .and(col("t.a").gt(lit(5)))
3996                    .and(col("t.b").gt(lit(10))),
3997            )?
3998            .build()?;
3999
4000        assert_snapshot!(plan,
4001        @r"
4002        Filter: TestScalarUDF() > Float64(0.1) AND t.a > Int32(5) AND t.b > Int32(10)
4003          Projection: test.a, test.b
4004            TableScan: test
4005        ",
4006        );
4007        assert_optimized_plan_equal!(
4008            plan,
4009            @r"
4010        Projection: test.a, test.b
4011          Filter: TestScalarUDF() > Float64(0.1)
4012            TableScan: test, full_filters=[t.a > Int32(5), t.b > Int32(10)]
4013        "
4014        )
4015    }
4016
4017    #[test]
4018    fn test_push_down_volatile_mixed_unsupported_table_scan() -> Result<()> {
4019        // SELECT test.a, test.b FROM test as t WHERE TestScalarUDF() > 0.1 and test.a > 5 and test.b > 10;
4020        let fun = ScalarUDF::new_from_impl(TestScalarUDF {
4021            signature: Signature::exact(vec![], Volatility::Volatile),
4022        });
4023        let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
4024        let plan = table_scan_with_pushdown_provider_builder(
4025            TableProviderFilterPushDown::Unsupported,
4026            vec![],
4027            None,
4028        )?
4029        .project(vec![col("a"), col("b")])?
4030        .filter(
4031            expr.gt(lit(0.1))
4032                .and(col("t.a").gt(lit(5)))
4033                .and(col("t.b").gt(lit(10))),
4034        )?
4035        .build()?;
4036
4037        assert_snapshot!(plan,
4038        @r"
4039        Filter: TestScalarUDF() > Float64(0.1) AND t.a > Int32(5) AND t.b > Int32(10)
4040          Projection: a, b
4041            TableScan: test
4042        ",
4043        );
4044        assert_optimized_plan_equal!(
4045            plan,
4046            @r"
4047        Projection: a, b
4048          Filter: t.a > Int32(5) AND t.b > Int32(10) AND TestScalarUDF() > Float64(0.1)
4049            TableScan: test
4050        "
4051        )
4052    }
4053
4054    #[test]
4055    fn test_push_down_filter_to_user_defined_node() -> Result<()> {
4056        // Define a custom user-defined logical node
4057        #[derive(Debug, Hash, Eq, PartialEq)]
4058        struct TestUserNode {
4059            schema: DFSchemaRef,
4060        }
4061
4062        impl PartialOrd for TestUserNode {
4063            fn partial_cmp(&self, _other: &Self) -> Option<Ordering> {
4064                None
4065            }
4066        }
4067
4068        impl TestUserNode {
4069            fn new() -> Self {
4070                let schema = Arc::new(
4071                    DFSchema::new_with_metadata(
4072                        vec![(None, Field::new("a", DataType::Int64, false).into())],
4073                        Default::default(),
4074                    )
4075                    .unwrap(),
4076                );
4077
4078                Self { schema }
4079            }
4080        }
4081
4082        impl UserDefinedLogicalNodeCore for TestUserNode {
4083            fn name(&self) -> &str {
4084                "test_node"
4085            }
4086
4087            fn inputs(&self) -> Vec<&LogicalPlan> {
4088                vec![]
4089            }
4090
4091            fn schema(&self) -> &DFSchemaRef {
4092                &self.schema
4093            }
4094
4095            fn expressions(&self) -> Vec<Expr> {
4096                vec![]
4097            }
4098
4099            fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
4100                write!(f, "TestUserNode")
4101            }
4102
4103            fn with_exprs_and_inputs(
4104                &self,
4105                exprs: Vec<Expr>,
4106                inputs: Vec<LogicalPlan>,
4107            ) -> Result<Self> {
4108                assert!(exprs.is_empty());
4109                assert!(inputs.is_empty());
4110                Ok(Self {
4111                    schema: Arc::clone(&self.schema),
4112                })
4113            }
4114        }
4115
4116        // Create a node and build a plan with a filter
4117        let node = LogicalPlan::Extension(Extension {
4118            node: Arc::new(TestUserNode::new()),
4119        });
4120
4121        let plan = LogicalPlanBuilder::from(node).filter(lit(false))?.build()?;
4122
4123        // Check the original plan format (not part of the test assertions)
4124        assert_snapshot!(plan,
4125        @r"
4126        Filter: Boolean(false)
4127          TestUserNode
4128        ",
4129        );
4130        // Check that the filter is pushed down to the user-defined node
4131        assert_optimized_plan_equal!(
4132            plan,
4133            @r"
4134        Filter: Boolean(false)
4135          TestUserNode
4136        "
4137        )
4138    }
4139}