datafusion_optimizer/
push_down_filter.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`PushDownFilter`] applies filters as early as possible
19
20use std::collections::{HashMap, HashSet};
21use std::sync::Arc;
22
23use arrow::datatypes::DataType;
24use indexmap::IndexSet;
25use itertools::Itertools;
26
27use datafusion_common::tree_node::{
28    Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
29};
30use datafusion_common::{
31    internal_err, plan_err, qualified_name, Column, DFSchema, Result,
32};
33use datafusion_expr::expr::WindowFunction;
34use datafusion_expr::expr_rewriter::replace_col;
35use datafusion_expr::logical_plan::{Join, JoinType, LogicalPlan, TableScan, Union};
36use datafusion_expr::utils::{
37    conjunction, expr_to_columns, split_conjunction, split_conjunction_owned,
38};
39use datafusion_expr::{
40    and, or, BinaryExpr, Expr, Filter, Operator, Projection, TableProviderFilterPushDown,
41};
42
43use crate::optimizer::ApplyOrder;
44use crate::simplify_expressions::simplify_predicates;
45use crate::utils::{has_all_column_refs, is_restrict_null_predicate};
46use crate::{OptimizerConfig, OptimizerRule};
47
/// Optimizer rule for pushing (moving) filter expressions down in a plan so
/// they are applied as early as possible.
///
/// # Introduction
///
/// The goal of this rule is to improve query performance by eliminating
/// redundant work.
///
/// For example, given a plan that sorts all values where `a > 10`:
///
/// ```text
///  Filter (a > 10)
///    Sort (a, b)
/// ```
///
/// A better plan is to filter the data *before* the Sort, which sorts fewer
/// rows and therefore does less work overall:
///
/// ```text
///  Sort (a, b)
///    Filter (a > 10)  <-- Filter is moved before the sort
/// ```
///
/// However it is not always possible to push filters down. For example, given a
/// plan that finds the top 3 values and then keeps only those that are greater
/// than 10, if the filter is pushed below the limit it would produce a
/// different result.
///
/// ```text
///  Filter (a > 10)   <-- can not move this Filter before the limit
///    Limit (fetch=3)
///      Sort (a, b)
/// ```
///
/// More formally, a filter-commutative operation is an operation `op` that
/// satisfies `filter(op(data)) = op(filter(data))`.
///
/// The filter-commutative property is plan and column-specific. A filter on `a`
/// can be pushed through an `Aggregate(group_by = [a], agg = [sum(b)])`. However,
/// a filter on `sum(b)` can not be pushed through the same aggregate.
///
/// # Handling Conjunctions
///
/// It is possible to only push down **part** of a filter expression if it is
/// connected with `AND`s (more formally if it is a "conjunction").
///
/// For example, given the following plan:
///
/// ```text
/// Filter(a > 10 AND sum(b) < 5)
///   Aggregate(group_by = [a], agg = [sum(b)])
/// ```
///
/// The `a > 10` is commutative with the `Aggregate` but `sum(b) < 5` is not.
/// Therefore it is possible to only push part of the expression, resulting in:
///
/// ```text
/// Filter(sum(b) < 5)
///   Aggregate(group_by = [a], agg = [sum(b)])
///     Filter(a > 10)
/// ```
///
/// # Handling Column Aliases
///
/// This optimizer must sometimes rewrite filter expressions when they are
/// pushed, for example if there is a projection that aliases `a+1` to `"b"`:
///
/// ```text
/// Filter (b > 10)
///     Projection: [a+1 AS "b"]  <-- changes the name of `a+1` to `b`
/// ```
///
/// To apply the filter prior to the `Projection`, all references to `b` must be
/// rewritten to `a+1`:
///
/// ```text
/// Projection: a+1 AS "b"
///     Filter: (a + 1 > 10)  <--- changed from b to a + 1
/// ```
///
/// # Implementation Notes
///
/// This implementation performs a single pass through the plan, "pushing" down
/// filters. When it passes through a filter, it stores that filter, and when it
/// reaches a plan node that does not commute with that filter, it adds the
/// filter to that place. When it passes through a projection, it re-writes the
/// filter's expression taking into account that projection.
#[derive(Default, Debug)]
pub struct PushDownFilter {}
137
138/// For a given JOIN type, determine whether each input of the join is preserved
139/// for post-join (`WHERE` clause) filters.
140///
141/// It is only correct to push filters below a join for preserved inputs.
142///
143/// # Return Value
144/// A tuple of booleans - (left_preserved, right_preserved).
145///
146/// # "Preserved" input definition
147///
148/// We say a join side is preserved if the join returns all or a subset of the rows from
149/// the relevant side, such that each row of the output table directly maps to a row of
150/// the preserved input table. If a table is not preserved, it can provide extra null rows.
151/// That is, there may be rows in the output table that don't directly map to a row in the
152/// input table.
153///
154/// For example:
155///   - In an inner join, both sides are preserved, because each row of the output
156///     maps directly to a row from each side.
157///
158///   - In a left join, the left side is preserved (we can push predicates) but
159///     the right is not, because there may be rows in the output that don't
160///     directly map to a row in the right input (due to nulls filling where there
161///     is no match on the right).
162pub(crate) fn lr_is_preserved(join_type: JoinType) -> (bool, bool) {
163    match join_type {
164        JoinType::Inner => (true, true),
165        JoinType::Left => (true, false),
166        JoinType::Right => (false, true),
167        JoinType::Full => (false, false),
168        // No columns from the right side of the join can be referenced in output
169        // predicates for semi/anti joins, so whether we specify t/f doesn't matter.
170        JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => (true, false),
171        // No columns from the left side of the join can be referenced in output
172        // predicates for semi/anti joins, so whether we specify t/f doesn't matter.
173        JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => (false, true),
174    }
175}
176
177/// For a given JOIN type, determine whether each input of the join is preserved
178/// for the join condition (`ON` clause filters).
179///
180/// It is only correct to push filters below a join for preserved inputs.
181///
182/// # Return Value
183/// A tuple of booleans - (left_preserved, right_preserved).
184///
185/// See [`lr_is_preserved`] for a definition of "preserved".
186pub(crate) fn on_lr_is_preserved(join_type: JoinType) -> (bool, bool) {
187    match join_type {
188        JoinType::Inner => (true, true),
189        JoinType::Left => (false, true),
190        JoinType::Right => (true, false),
191        JoinType::Full => (false, false),
192        JoinType::LeftSemi | JoinType::RightSemi => (true, true),
193        JoinType::LeftAnti => (false, true),
194        JoinType::RightAnti => (true, false),
195        JoinType::LeftMark => (false, true),
196        JoinType::RightMark => (true, false),
197    }
198}
199
200/// Evaluates the columns referenced in the given expression to see if they refer
201/// only to the left or right columns
202#[derive(Debug)]
203struct ColumnChecker<'a> {
204    /// schema of left join input
205    left_schema: &'a DFSchema,
206    /// columns in left_schema, computed on demand
207    left_columns: Option<HashSet<Column>>,
208    /// schema of right join input
209    right_schema: &'a DFSchema,
210    /// columns in left_schema, computed on demand
211    right_columns: Option<HashSet<Column>>,
212}
213
214impl<'a> ColumnChecker<'a> {
215    fn new(left_schema: &'a DFSchema, right_schema: &'a DFSchema) -> Self {
216        Self {
217            left_schema,
218            left_columns: None,
219            right_schema,
220            right_columns: None,
221        }
222    }
223
224    /// Return true if the expression references only columns from the left side of the join
225    fn is_left_only(&mut self, predicate: &Expr) -> bool {
226        if self.left_columns.is_none() {
227            self.left_columns = Some(schema_columns(self.left_schema));
228        }
229        has_all_column_refs(predicate, self.left_columns.as_ref().unwrap())
230    }
231
232    /// Return true if the expression references only columns from the right side of the join
233    fn is_right_only(&mut self, predicate: &Expr) -> bool {
234        if self.right_columns.is_none() {
235            self.right_columns = Some(schema_columns(self.right_schema));
236        }
237        has_all_column_refs(predicate, self.right_columns.as_ref().unwrap())
238    }
239}
240
241/// Returns all columns in the schema
242fn schema_columns(schema: &DFSchema) -> HashSet<Column> {
243    schema
244        .iter()
245        .flat_map(|(qualifier, field)| {
246            [
247                Column::new(qualifier.cloned(), field.name()),
248                // we need to push down filter using unqualified column as well
249                Column::new_unqualified(field.name()),
250            ]
251        })
252        .collect::<HashSet<_>>()
253}
254
255/// Determine whether the predicate can evaluate as the join conditions
256fn can_evaluate_as_join_condition(predicate: &Expr) -> Result<bool> {
257    let mut is_evaluate = true;
258    predicate.apply(|expr| match expr {
259        Expr::Column(_)
260        | Expr::Literal(_, _)
261        | Expr::Placeholder(_)
262        | Expr::ScalarVariable(_, _) => Ok(TreeNodeRecursion::Jump),
263        Expr::Exists { .. }
264        | Expr::InSubquery(_)
265        | Expr::ScalarSubquery(_)
266        | Expr::OuterReferenceColumn(_, _)
267        | Expr::Unnest(_) => {
268            is_evaluate = false;
269            Ok(TreeNodeRecursion::Stop)
270        }
271        Expr::Alias(_)
272        | Expr::BinaryExpr(_)
273        | Expr::Like(_)
274        | Expr::SimilarTo(_)
275        | Expr::Not(_)
276        | Expr::IsNotNull(_)
277        | Expr::IsNull(_)
278        | Expr::IsTrue(_)
279        | Expr::IsFalse(_)
280        | Expr::IsUnknown(_)
281        | Expr::IsNotTrue(_)
282        | Expr::IsNotFalse(_)
283        | Expr::IsNotUnknown(_)
284        | Expr::Negative(_)
285        | Expr::Between(_)
286        | Expr::Case(_)
287        | Expr::Cast(_)
288        | Expr::TryCast(_)
289        | Expr::InList { .. }
290        | Expr::ScalarFunction(_) => Ok(TreeNodeRecursion::Continue),
291        // TODO: remove the next line after `Expr::Wildcard` is removed
292        #[expect(deprecated)]
293        Expr::AggregateFunction(_)
294        | Expr::WindowFunction(_)
295        | Expr::Wildcard { .. }
296        | Expr::GroupingSet(_) => internal_err!("Unsupported predicate type"),
297    })?;
298    Ok(is_evaluate)
299}
300
301/// examine OR clause to see if any useful clauses can be extracted and push down.
302/// extract at least one qual from each sub clauses of OR clause, then form the quals
303/// to new OR clause as predicate.
304///
305/// # Example
306/// ```text
307/// Filter: (a = c and a < 20) or (b = d and b > 10)
308///     join/crossjoin:
309///          TableScan: projection=[a, b]
310///          TableScan: projection=[c, d]
311/// ```
312///
313/// is optimized to
314///
315/// ```text
316/// Filter: (a = c and a < 20) or (b = d and b > 10)
317///     join/crossjoin:
318///          Filter: (a < 20) or (b > 10)
319///              TableScan: projection=[a, b]
320///          TableScan: projection=[c, d]
321/// ```
322///
323/// In general, predicates of this form:
324///
325/// ```sql
326/// (A AND B) OR (C AND D)
327/// ```
328///
329/// will be transformed to one of:
330///
331/// * `((A AND B) OR (C AND D)) AND (A OR C)`
332/// * `((A AND B) OR (C AND D)) AND ((A AND B) OR C)`
333/// * do nothing.
334fn extract_or_clauses_for_join<'a>(
335    filters: &'a [Expr],
336    schema: &'a DFSchema,
337) -> impl Iterator<Item = Expr> + 'a {
338    let schema_columns = schema_columns(schema);
339
340    // new formed OR clauses and their column references
341    filters.iter().filter_map(move |expr| {
342        if let Expr::BinaryExpr(BinaryExpr {
343            left,
344            op: Operator::Or,
345            right,
346        }) = expr
347        {
348            let left_expr = extract_or_clause(left.as_ref(), &schema_columns);
349            let right_expr = extract_or_clause(right.as_ref(), &schema_columns);
350
351            // If nothing can be extracted from any sub clauses, do nothing for this OR clause.
352            if let (Some(left_expr), Some(right_expr)) = (left_expr, right_expr) {
353                return Some(or(left_expr, right_expr));
354            }
355        }
356        None
357    })
358}
359
360/// extract qual from OR sub-clause.
361///
362/// A qual is extracted if it only contains set of column references in schema_columns.
363///
364/// For AND clause, we extract from both sub-clauses, then make new AND clause by extracted
365/// clauses if both extracted; Otherwise, use the extracted clause from any sub-clauses or None.
366///
367/// For OR clause, we extract from both sub-clauses, then make new OR clause by extracted clauses if both extracted;
368/// Otherwise, return None.
369///
370/// For other clause, apply the rule above to extract clause.
371fn extract_or_clause(expr: &Expr, schema_columns: &HashSet<Column>) -> Option<Expr> {
372    let mut predicate = None;
373
374    match expr {
375        Expr::BinaryExpr(BinaryExpr {
376            left: l_expr,
377            op: Operator::Or,
378            right: r_expr,
379        }) => {
380            let l_expr = extract_or_clause(l_expr, schema_columns);
381            let r_expr = extract_or_clause(r_expr, schema_columns);
382
383            if let (Some(l_expr), Some(r_expr)) = (l_expr, r_expr) {
384                predicate = Some(or(l_expr, r_expr));
385            }
386        }
387        Expr::BinaryExpr(BinaryExpr {
388            left: l_expr,
389            op: Operator::And,
390            right: r_expr,
391        }) => {
392            let l_expr = extract_or_clause(l_expr, schema_columns);
393            let r_expr = extract_or_clause(r_expr, schema_columns);
394
395            match (l_expr, r_expr) {
396                (Some(l_expr), Some(r_expr)) => {
397                    predicate = Some(and(l_expr, r_expr));
398                }
399                (Some(l_expr), None) => {
400                    predicate = Some(l_expr);
401                }
402                (None, Some(r_expr)) => {
403                    predicate = Some(r_expr);
404                }
405                (None, None) => {
406                    predicate = None;
407                }
408            }
409        }
410        _ => {
411            if has_all_column_refs(expr, schema_columns) {
412                predicate = Some(expr.clone());
413            }
414        }
415    }
416
417    predicate
418}
419
/// Distributes predicates around a join (including cross joins). Each
/// predicate goes to one of these places:
/// - into the left or right input, where that is allowed;
/// - into the join condition, if it is eligible and this is an inner join;
/// - into a `Filter` kept above the join, for everything else.
///
/// # Parameters
/// * `predicates` - conjuncts of the filter directly above the join (as split
///   by [`push_down_join`]). They are routed using [`lr_is_preserved`].
/// * `inferred_join_predicates` - predicates derived through the equi-join
///   keys. One is pushed to a side when it references only that side;
///   otherwise it is dropped.
/// * `join` - the join being rewritten; its `left`, `right` and `filter` are
///   replaced.
/// * `on_filter` - conjuncts of the join's `ON` filter. They are routed using
///   [`on_lr_is_preserved`].
///
/// # Returns
/// Always `Transformed::yes`, even when no predicate actually moved.
///
/// # Errors
/// Propagates errors from [`can_evaluate_as_join_condition`] (for example an
/// aggregate or window function inside a predicate) and from `Filter::try_new`.
fn push_down_all_join(
    predicates: Vec<Expr>,
    inferred_join_predicates: Vec<Expr>,
    mut join: Join,
    on_filter: Vec<Expr>,
) -> Result<Transformed<LogicalPlan>> {
    let is_inner_join = join.join_type == JoinType::Inner;
    // Which inputs may receive predicates coming from above the join
    // (WHERE-clause semantics); see `lr_is_preserved`.
    let (left_preserved, right_preserved) = lr_is_preserved(join.join_type);

    // The predicates can be divided to three categories:
    // 1) can push through join to its children(left or right)
    // 2) can be converted to join conditions if the join type is Inner
    // 3) should be kept as filter conditions
    let left_schema = join.left.schema();
    let right_schema = join.right.schema();
    let mut left_push = vec![];
    let mut right_push = vec![];
    let mut keep_predicates = vec![];
    let mut join_conditions = vec![];
    let mut checker = ColumnChecker::new(left_schema, right_schema);
    for predicate in predicates {
        if left_preserved && checker.is_left_only(&predicate) {
            left_push.push(predicate);
        } else if right_preserved && checker.is_right_only(&predicate) {
            right_push.push(predicate);
        } else if is_inner_join && can_evaluate_as_join_condition(&predicate)? {
            // Equality and non-equality predicates are not distinguished here;
            // ExtractEquijoinPredicate later extracts the equality ones and
            // converts them to the join `on` condition.
            join_conditions.push(predicate);
        } else {
            keep_predicates.push(predicate);
        }
    }

    // Inferred predicates are only used if they can be pushed into one input;
    // any that cannot be pushed through the join are simply dropped.
    for predicate in inferred_join_predicates {
        if left_preserved && checker.is_left_only(&predicate) {
            left_push.push(predicate);
        } else if right_preserved && checker.is_right_only(&predicate) {
            right_push.push(predicate);
        }
    }

    // Predicates from the join's own ON clause use the ON-clause preservation
    // rules; those that cannot be pushed stay in the join filter.
    let mut on_filter_join_conditions = vec![];
    let (on_left_preserved, on_right_preserved) = on_lr_is_preserved(join.join_type);

    if !on_filter.is_empty() {
        for on in on_filter {
            if on_left_preserved && checker.is_left_only(&on) {
                left_push.push(on)
            } else if on_right_preserved && checker.is_right_only(&on) {
                right_push.push(on)
            } else {
                on_filter_join_conditions.push(on)
            }
        }
    }

    // Extract from OR clause, generate new predicates for both side of join if possible.
    // Only the predicates that could not be pushed above are examined. The
    // extracted ORs are pushed in addition to (not instead of) the originals,
    // which remain in `keep_predicates` / `join_conditions`.
    if left_preserved {
        left_push.extend(extract_or_clauses_for_join(&keep_predicates, left_schema));
        left_push.extend(extract_or_clauses_for_join(&join_conditions, left_schema));
    }
    if right_preserved {
        right_push.extend(extract_or_clauses_for_join(&keep_predicates, right_schema));
        right_push.extend(extract_or_clauses_for_join(&join_conditions, right_schema));
    }

    // Predicates that came from the ON filter are checked against the ON-clause
    // preservation rules instead of the WHERE-clause ones.
    if on_left_preserved {
        left_push.extend(extract_or_clauses_for_join(
            &on_filter_join_conditions,
            left_schema,
        ));
    }
    if on_right_preserved {
        right_push.extend(extract_or_clauses_for_join(
            &on_filter_join_conditions,
            right_schema,
        ));
    }

    // Wrap each input in a Filter holding everything pushed to it, if anything.
    if let Some(predicate) = conjunction(left_push) {
        join.left = Arc::new(LogicalPlan::Filter(Filter::try_new(predicate, join.left)?));
    }
    if let Some(predicate) = conjunction(right_push) {
        join.right =
            Arc::new(LogicalPlan::Filter(Filter::try_new(predicate, join.right)?));
    }

    // The join filter becomes: WHERE predicates converted to join conditions,
    // followed by the ON predicates that could not be pushed (None if both are empty).
    join_conditions.extend(on_filter_join_conditions);
    join.filter = conjunction(join_conditions);

    // wrap the join on the filter whose predicates must be kept, if any
    let plan = LogicalPlan::Join(join);
    let plan = if let Some(predicate) = conjunction(keep_predicates) {
        LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(plan))?)
    } else {
        plan
    };
    Ok(Transformed::yes(plan))
}
527
528fn push_down_join(
529    join: Join,
530    parent_predicate: Option<&Expr>,
531) -> Result<Transformed<LogicalPlan>> {
532    // Split the parent predicate into individual conjunctive parts.
533    let predicates = parent_predicate
534        .map_or_else(Vec::new, |pred| split_conjunction_owned(pred.clone()));
535
536    // Extract conjunctions from the JOIN's ON filter, if present.
537    let on_filters = join
538        .filter
539        .as_ref()
540        .map_or_else(Vec::new, |filter| split_conjunction_owned(filter.clone()));
541
542    // Are there any new join predicates that can be inferred from the filter expressions?
543    let inferred_join_predicates =
544        infer_join_predicates(&join, &predicates, &on_filters)?;
545
546    if on_filters.is_empty()
547        && predicates.is_empty()
548        && inferred_join_predicates.is_empty()
549    {
550        return Ok(Transformed::no(LogicalPlan::Join(join)));
551    }
552
553    push_down_all_join(predicates, inferred_join_predicates, join, on_filters)
554}
555
556/// Extracts any equi-join join predicates from the given filter expressions.
557///
558/// Parameters
559/// * `join` the join in question
560///
561/// * `predicates` the pushed down filter expression
562///
563/// * `on_filters` filters from the join ON clause that have not already been
564///   identified as join predicates
565fn infer_join_predicates(
566    join: &Join,
567    predicates: &[Expr],
568    on_filters: &[Expr],
569) -> Result<Vec<Expr>> {
570    // Only allow both side key is column.
571    let join_col_keys = join
572        .on
573        .iter()
574        .filter_map(|(l, r)| {
575            let left_col = l.try_as_col()?;
576            let right_col = r.try_as_col()?;
577            Some((left_col, right_col))
578        })
579        .collect::<Vec<_>>();
580
581    let join_type = join.join_type;
582
583    let mut inferred_predicates = InferredPredicates::new(join_type);
584
585    infer_join_predicates_from_predicates(
586        &join_col_keys,
587        predicates,
588        &mut inferred_predicates,
589    )?;
590
591    infer_join_predicates_from_on_filters(
592        &join_col_keys,
593        join_type,
594        on_filters,
595        &mut inferred_predicates,
596    )?;
597
598    Ok(inferred_predicates.predicates)
599}
600
601/// Inferred predicates collector.
602/// When the JoinType is not Inner, we need to detect whether the inferred predicate can strictly
603/// filter out NULL, otherwise ignore it. e.g.
604/// ```text
605/// SELECT * FROM t1 LEFT JOIN t2 ON t1.c0 = t2.c0 WHERE t2.c0 IS NULL;
606/// ```
607/// We cannot infer the predicate `t1.c0 IS NULL`, otherwise the predicate will be pushed down to
608/// the left side, resulting in the wrong result.
609struct InferredPredicates {
610    predicates: Vec<Expr>,
611    is_inner_join: bool,
612}
613
614impl InferredPredicates {
615    fn new(join_type: JoinType) -> Self {
616        Self {
617            predicates: vec![],
618            is_inner_join: matches!(join_type, JoinType::Inner),
619        }
620    }
621
622    fn try_build_predicate(
623        &mut self,
624        predicate: Expr,
625        replace_map: &HashMap<&Column, &Column>,
626    ) -> Result<()> {
627        if self.is_inner_join
628            || matches!(
629                is_restrict_null_predicate(
630                    predicate.clone(),
631                    replace_map.keys().cloned()
632                ),
633                Ok(true)
634            )
635        {
636            self.predicates.push(replace_col(predicate, replace_map)?);
637        }
638
639        Ok(())
640    }
641}
642
643/// Infer predicates from the pushed down predicates.
644///
645/// Parameters
646/// * `join_col_keys` column pairs from the join ON clause
647///
648/// * `predicates` the pushed down predicates
649///
650/// * `inferred_predicates` the inferred results
651fn infer_join_predicates_from_predicates(
652    join_col_keys: &[(&Column, &Column)],
653    predicates: &[Expr],
654    inferred_predicates: &mut InferredPredicates,
655) -> Result<()> {
656    infer_join_predicates_impl::<true, true>(
657        join_col_keys,
658        predicates,
659        inferred_predicates,
660    )
661}
662
663/// Infer predicates from the join filter.
664///
665/// Parameters
666/// * `join_col_keys` column pairs from the join ON clause
667///
668/// * `join_type` the JoinType of Join
669///
670/// * `on_filters` filters from the join ON clause that have not already been
671///   identified as join predicates
672///
673/// * `inferred_predicates` the inferred results
674fn infer_join_predicates_from_on_filters(
675    join_col_keys: &[(&Column, &Column)],
676    join_type: JoinType,
677    on_filters: &[Expr],
678    inferred_predicates: &mut InferredPredicates,
679) -> Result<()> {
680    match join_type {
681        JoinType::Full | JoinType::LeftAnti | JoinType::RightAnti => Ok(()),
682        JoinType::Inner => infer_join_predicates_impl::<true, true>(
683            join_col_keys,
684            on_filters,
685            inferred_predicates,
686        ),
687        JoinType::Left | JoinType::LeftSemi | JoinType::LeftMark => {
688            infer_join_predicates_impl::<true, false>(
689                join_col_keys,
690                on_filters,
691                inferred_predicates,
692            )
693        }
694        JoinType::Right | JoinType::RightSemi | JoinType::RightMark => {
695            infer_join_predicates_impl::<false, true>(
696                join_col_keys,
697                on_filters,
698                inferred_predicates,
699            )
700        }
701    }
702}
703
704/// Infer predicates from the given predicates.
705///
706/// Parameters
707/// * `join_col_keys` column pairs from the join ON clause
708///
709/// * `input_predicates` the given predicates. It can be the pushed down predicates,
710///   or it can be the filters of the Join
711///
712/// * `inferred_predicates` the inferred results
713///
714/// * `ENABLE_LEFT_TO_RIGHT` indicates that the right table related predicate can
715///   be inferred from the left table related predicate
716///
717/// * `ENABLE_RIGHT_TO_LEFT` indicates that the left table related predicate can
718///   be inferred from the right table related predicate
719fn infer_join_predicates_impl<
720    const ENABLE_LEFT_TO_RIGHT: bool,
721    const ENABLE_RIGHT_TO_LEFT: bool,
722>(
723    join_col_keys: &[(&Column, &Column)],
724    input_predicates: &[Expr],
725    inferred_predicates: &mut InferredPredicates,
726) -> Result<()> {
727    for predicate in input_predicates {
728        let mut join_cols_to_replace = HashMap::new();
729
730        for &col in &predicate.column_refs() {
731            for (l, r) in join_col_keys.iter() {
732                if ENABLE_LEFT_TO_RIGHT && col == *l {
733                    join_cols_to_replace.insert(col, *r);
734                    break;
735                }
736                if ENABLE_RIGHT_TO_LEFT && col == *r {
737                    join_cols_to_replace.insert(col, *l);
738                    break;
739                }
740            }
741        }
742        if join_cols_to_replace.is_empty() {
743            continue;
744        }
745
746        inferred_predicates
747            .try_build_predicate(predicate.clone(), &join_cols_to_replace)?;
748    }
749    Ok(())
750}
751
752impl OptimizerRule for PushDownFilter {
753    fn name(&self) -> &str {
754        "push_down_filter"
755    }
756
757    fn apply_order(&self) -> Option<ApplyOrder> {
758        Some(ApplyOrder::TopDown)
759    }
760
761    fn supports_rewrite(&self) -> bool {
762        true
763    }
764
765    fn rewrite(
766        &self,
767        plan: LogicalPlan,
768        _config: &dyn OptimizerConfig,
769    ) -> Result<Transformed<LogicalPlan>> {
770        if let LogicalPlan::Join(join) = plan {
771            return push_down_join(join, None);
772        };
773
774        let plan_schema = Arc::clone(plan.schema());
775
776        let LogicalPlan::Filter(mut filter) = plan else {
777            return Ok(Transformed::no(plan));
778        };
779
780        let predicate = split_conjunction_owned(filter.predicate.clone());
781        let old_predicate_len = predicate.len();
782        let new_predicates = simplify_predicates(predicate)?;
783        if old_predicate_len != new_predicates.len() {
784            let Some(new_predicate) = conjunction(new_predicates) else {
785                // new_predicates is empty - remove the filter entirely
786                // Return the child plan without the filter
787                return Ok(Transformed::yes(Arc::unwrap_or_clone(filter.input)));
788            };
789            filter.predicate = new_predicate;
790        }
791
792        match Arc::unwrap_or_clone(filter.input) {
793            LogicalPlan::Filter(child_filter) => {
794                let parents_predicates = split_conjunction_owned(filter.predicate);
795
796                // remove duplicated filters
797                let child_predicates = split_conjunction_owned(child_filter.predicate);
798                let new_predicates = parents_predicates
799                    .into_iter()
800                    .chain(child_predicates)
801                    // use IndexSet to remove dupes while preserving predicate order
802                    .collect::<IndexSet<_>>()
803                    .into_iter()
804                    .collect::<Vec<_>>();
805
806                let Some(new_predicate) = conjunction(new_predicates) else {
807                    return plan_err!("at least one expression exists");
808                };
809                let new_filter = LogicalPlan::Filter(Filter::try_new(
810                    new_predicate,
811                    child_filter.input,
812                )?);
813                #[allow(clippy::used_underscore_binding)]
814                self.rewrite(new_filter, _config)
815            }
816            LogicalPlan::Repartition(repartition) => {
817                let new_filter =
818                    Filter::try_new(filter.predicate, Arc::clone(&repartition.input))
819                        .map(LogicalPlan::Filter)?;
820                insert_below(LogicalPlan::Repartition(repartition), new_filter)
821            }
822            LogicalPlan::Distinct(distinct) => {
823                let new_filter =
824                    Filter::try_new(filter.predicate, Arc::clone(distinct.input()))
825                        .map(LogicalPlan::Filter)?;
826                insert_below(LogicalPlan::Distinct(distinct), new_filter)
827            }
828            LogicalPlan::Sort(sort) => {
829                let new_filter =
830                    Filter::try_new(filter.predicate, Arc::clone(&sort.input))
831                        .map(LogicalPlan::Filter)?;
832                insert_below(LogicalPlan::Sort(sort), new_filter)
833            }
834            LogicalPlan::SubqueryAlias(subquery_alias) => {
835                let mut replace_map = HashMap::new();
836                for (i, (qualifier, field)) in
837                    subquery_alias.input.schema().iter().enumerate()
838                {
839                    let (sub_qualifier, sub_field) =
840                        subquery_alias.schema.qualified_field(i);
841                    replace_map.insert(
842                        qualified_name(sub_qualifier, sub_field.name()),
843                        Expr::Column(Column::new(qualifier.cloned(), field.name())),
844                    );
845                }
846                let new_predicate = replace_cols_by_name(filter.predicate, &replace_map)?;
847
848                let new_filter = LogicalPlan::Filter(Filter::try_new(
849                    new_predicate,
850                    Arc::clone(&subquery_alias.input),
851                )?);
852                insert_below(LogicalPlan::SubqueryAlias(subquery_alias), new_filter)
853            }
854            LogicalPlan::Projection(projection) => {
855                let predicates = split_conjunction_owned(filter.predicate.clone());
856                let (new_projection, keep_predicate) =
857                    rewrite_projection(predicates, projection)?;
858                if new_projection.transformed {
859                    match keep_predicate {
860                        None => Ok(new_projection),
861                        Some(keep_predicate) => new_projection.map_data(|child_plan| {
862                            Filter::try_new(keep_predicate, Arc::new(child_plan))
863                                .map(LogicalPlan::Filter)
864                        }),
865                    }
866                } else {
867                    filter.input = Arc::new(new_projection.data);
868                    Ok(Transformed::no(LogicalPlan::Filter(filter)))
869                }
870            }
871            LogicalPlan::Unnest(mut unnest) => {
872                let predicates = split_conjunction_owned(filter.predicate.clone());
873                let mut non_unnest_predicates = vec![];
874                let mut unnest_predicates = vec![];
875                let mut unnest_struct_columns = vec![];
876
877                for idx in &unnest.struct_type_columns {
878                    let (sub_qualifier, field) =
879                        unnest.input.schema().qualified_field(*idx);
880                    let field_name = field.name().clone();
881
882                    if let DataType::Struct(children) = field.data_type() {
883                        for child in children {
884                            let child_name = child.name().clone();
885                            unnest_struct_columns.push(Column::new(
886                                sub_qualifier.cloned(),
887                                format!("{field_name}.{child_name}"),
888                            ));
889                        }
890                    }
891                }
892
893                for predicate in predicates {
894                    // collect all the Expr::Column in predicate recursively
895                    let mut accum: HashSet<Column> = HashSet::new();
896                    expr_to_columns(&predicate, &mut accum)?;
897
898                    let contains_list_columns =
899                        unnest.list_type_columns.iter().any(|(_, unnest_list)| {
900                            accum.contains(&unnest_list.output_column)
901                        });
902                    let contains_struct_columns =
903                        unnest_struct_columns.iter().any(|c| accum.contains(c));
904
905                    if contains_list_columns || contains_struct_columns {
906                        unnest_predicates.push(predicate);
907                    } else {
908                        non_unnest_predicates.push(predicate);
909                    }
910                }
911
912                // Unnest predicates should not be pushed down.
913                // If no non-unnest predicates exist, early return
914                if non_unnest_predicates.is_empty() {
915                    filter.input = Arc::new(LogicalPlan::Unnest(unnest));
916                    return Ok(Transformed::no(LogicalPlan::Filter(filter)));
917                }
918
919                // Push down non-unnest filter predicate
920                // Unnest
921                //   Unnest Input (Projection)
922                // -> rewritten to
923                // Unnest
924                //   Filter
925                //     Unnest Input (Projection)
926
927                let unnest_input = std::mem::take(&mut unnest.input);
928
929                let filter_with_unnest_input = LogicalPlan::Filter(Filter::try_new(
930                    conjunction(non_unnest_predicates).unwrap(), // Safe to unwrap since non_unnest_predicates is not empty.
931                    unnest_input,
932                )?);
933
934                // Directly assign new filter plan as the new unnest's input.
935                // The new filter plan will go through another rewrite pass since the rule itself
936                // is applied recursively to all the child from top to down
937                let unnest_plan =
938                    insert_below(LogicalPlan::Unnest(unnest), filter_with_unnest_input)?;
939
940                match conjunction(unnest_predicates) {
941                    None => Ok(unnest_plan),
942                    Some(predicate) => Ok(Transformed::yes(LogicalPlan::Filter(
943                        Filter::try_new(predicate, Arc::new(unnest_plan.data))?,
944                    ))),
945                }
946            }
947            LogicalPlan::Union(ref union) => {
948                let mut inputs = Vec::with_capacity(union.inputs.len());
949                for input in &union.inputs {
950                    let mut replace_map = HashMap::new();
951                    for (i, (qualifier, field)) in input.schema().iter().enumerate() {
952                        let (union_qualifier, union_field) =
953                            union.schema.qualified_field(i);
954                        replace_map.insert(
955                            qualified_name(union_qualifier, union_field.name()),
956                            Expr::Column(Column::new(qualifier.cloned(), field.name())),
957                        );
958                    }
959
960                    let push_predicate =
961                        replace_cols_by_name(filter.predicate.clone(), &replace_map)?;
962                    inputs.push(Arc::new(LogicalPlan::Filter(Filter::try_new(
963                        push_predicate,
964                        Arc::clone(input),
965                    )?)))
966                }
967                Ok(Transformed::yes(LogicalPlan::Union(Union {
968                    inputs,
969                    schema: Arc::clone(&plan_schema),
970                })))
971            }
972            LogicalPlan::Aggregate(agg) => {
973                // We can push down Predicate which in groupby_expr.
974                let group_expr_columns = agg
975                    .group_expr
976                    .iter()
977                    .map(|e| {
978                        let (relation, name) = e.qualified_name();
979                        Column::new(relation, name)
980                    })
981                    .collect::<HashSet<_>>();
982
983                let predicates = split_conjunction_owned(filter.predicate);
984
985                let mut keep_predicates = vec![];
986                let mut push_predicates = vec![];
987                for expr in predicates {
988                    let cols = expr.column_refs();
989                    if cols.iter().all(|c| group_expr_columns.contains(c)) {
990                        push_predicates.push(expr);
991                    } else {
992                        keep_predicates.push(expr);
993                    }
994                }
995
996                // As for plan Filter: Column(a+b) > 0 -- Agg: groupby:[Column(a)+Column(b)]
997                // After push, we need to replace `a+b` with Column(a)+Column(b)
998                // So we need create a replace_map, add {`a+b` --> Expr(Column(a)+Column(b))}
999                let mut replace_map = HashMap::new();
1000                for expr in &agg.group_expr {
1001                    replace_map.insert(expr.schema_name().to_string(), expr.clone());
1002                }
1003                let replaced_push_predicates = push_predicates
1004                    .into_iter()
1005                    .map(|expr| replace_cols_by_name(expr, &replace_map))
1006                    .collect::<Result<Vec<_>>>()?;
1007
1008                let agg_input = Arc::clone(&agg.input);
1009                Transformed::yes(LogicalPlan::Aggregate(agg))
1010                    .transform_data(|new_plan| {
1011                        // If we have a filter to push, we push it down to the input of the aggregate
1012                        if let Some(predicate) = conjunction(replaced_push_predicates) {
1013                            let new_filter = make_filter(predicate, agg_input)?;
1014                            insert_below(new_plan, new_filter)
1015                        } else {
1016                            Ok(Transformed::no(new_plan))
1017                        }
1018                    })?
1019                    .map_data(|child_plan| {
1020                        // if there are any remaining predicates we can't push, add them
1021                        // back as a filter
1022                        if let Some(predicate) = conjunction(keep_predicates) {
1023                            make_filter(predicate, Arc::new(child_plan))
1024                        } else {
1025                            Ok(child_plan)
1026                        }
1027                    })
1028            }
1029            // Tries to push filters based on the partition key(s) of the window function(s) used.
1030            // Example:
1031            //   Before:
1032            //     Filter: (a > 1) and (b > 1) and (c > 1)
1033            //      Window: func() PARTITION BY [a] ...
1034            //   ---
1035            //   After:
1036            //     Filter: (b > 1) and (c > 1)
1037            //      Window: func() PARTITION BY [a] ...
1038            //        Filter: (a > 1)
1039            LogicalPlan::Window(window) => {
1040                // Retrieve the set of potential partition keys where we can push filters by.
1041                // Unlike aggregations, where there is only one statement per SELECT, there can be
1042                // multiple window functions, each with potentially different partition keys.
1043                // Therefore, we need to ensure that any potential partition key returned is used in
1044                // ALL window functions. Otherwise, filters cannot be pushed by through that column.
1045                let extract_partition_keys = |func: &WindowFunction| {
1046                    func.params
1047                        .partition_by
1048                        .iter()
1049                        .map(|c| {
1050                            let (relation, name) = c.qualified_name();
1051                            Column::new(relation, name)
1052                        })
1053                        .collect::<HashSet<_>>()
1054                };
1055                let potential_partition_keys = window
1056                    .window_expr
1057                    .iter()
1058                    .map(|e| {
1059                        match e {
1060                            Expr::WindowFunction(window_func) => {
1061                                extract_partition_keys(window_func)
1062                            }
1063                            Expr::Alias(alias) => {
1064                                if let Expr::WindowFunction(window_func) =
1065                                    alias.expr.as_ref()
1066                                {
1067                                    extract_partition_keys(window_func)
1068                                } else {
1069                                    // window functions expressions are only Expr::WindowFunction
1070                                    unreachable!()
1071                                }
1072                            }
1073                            _ => {
1074                                // window functions expressions are only Expr::WindowFunction
1075                                unreachable!()
1076                            }
1077                        }
1078                    })
1079                    // performs the set intersection of the partition keys of all window functions,
1080                    // returning only the common ones
1081                    .reduce(|a, b| &a & &b)
1082                    .unwrap_or_default();
1083
1084                let predicates = split_conjunction_owned(filter.predicate);
1085                let mut keep_predicates = vec![];
1086                let mut push_predicates = vec![];
1087                for expr in predicates {
1088                    let cols = expr.column_refs();
1089                    if cols.iter().all(|c| potential_partition_keys.contains(c)) {
1090                        push_predicates.push(expr);
1091                    } else {
1092                        keep_predicates.push(expr);
1093                    }
1094                }
1095
1096                // Unlike with aggregations, there are no cases where we have to replace, e.g.,
1097                // `a+b` with Column(a)+Column(b). This is because partition expressions are not
1098                // available as standalone columns to the user. For example, while an aggregation on
1099                // `a+b` becomes Column(a + b), in a window partition it becomes
1100                // `func() PARTITION BY [a + b] ...`. Thus, filters on expressions always remain in
1101                // place, so we can use `push_predicates` directly. This is consistent with other
1102                // optimizers, such as the one used by Postgres.
1103
1104                let window_input = Arc::clone(&window.input);
1105                Transformed::yes(LogicalPlan::Window(window))
1106                    .transform_data(|new_plan| {
1107                        // If we have a filter to push, we push it down to the input of the window
1108                        if let Some(predicate) = conjunction(push_predicates) {
1109                            let new_filter = make_filter(predicate, window_input)?;
1110                            insert_below(new_plan, new_filter)
1111                        } else {
1112                            Ok(Transformed::no(new_plan))
1113                        }
1114                    })?
1115                    .map_data(|child_plan| {
1116                        // if there are any remaining predicates we can't push, add them
1117                        // back as a filter
1118                        if let Some(predicate) = conjunction(keep_predicates) {
1119                            make_filter(predicate, Arc::new(child_plan))
1120                        } else {
1121                            Ok(child_plan)
1122                        }
1123                    })
1124            }
1125            LogicalPlan::Join(join) => push_down_join(join, Some(&filter.predicate)),
1126            LogicalPlan::TableScan(scan) => {
1127                let filter_predicates = split_conjunction(&filter.predicate);
1128
1129                let (volatile_filters, non_volatile_filters): (Vec<&Expr>, Vec<&Expr>) =
1130                    filter_predicates
1131                        .into_iter()
1132                        .partition(|pred| pred.is_volatile());
1133
1134                // Check which non-volatile filters are supported by source
1135                let supported_filters = scan
1136                    .source
1137                    .supports_filters_pushdown(non_volatile_filters.as_slice())?;
1138                if non_volatile_filters.len() != supported_filters.len() {
1139                    return internal_err!(
1140                        "Vec returned length: {} from supports_filters_pushdown is not the same size as the filters passed, which length is: {}",
1141                        supported_filters.len(),
1142                        non_volatile_filters.len());
1143                }
1144
1145                // Compose scan filters from non-volatile filters of `Exact` or `Inexact` pushdown type
1146                let zip = non_volatile_filters.into_iter().zip(supported_filters);
1147
1148                let new_scan_filters = zip
1149                    .clone()
1150                    .filter(|(_, res)| res != &TableProviderFilterPushDown::Unsupported)
1151                    .map(|(pred, _)| pred);
1152
1153                // Add new scan filters
1154                let new_scan_filters: Vec<Expr> = scan
1155                    .filters
1156                    .iter()
1157                    .chain(new_scan_filters)
1158                    .unique()
1159                    .cloned()
1160                    .collect();
1161
1162                // Compose predicates to be of `Unsupported` or `Inexact` pushdown type, and also include volatile filters
1163                let new_predicate: Vec<Expr> = zip
1164                    .filter(|(_, res)| res != &TableProviderFilterPushDown::Exact)
1165                    .map(|(pred, _)| pred)
1166                    .chain(volatile_filters)
1167                    .cloned()
1168                    .collect();
1169
1170                let new_scan = LogicalPlan::TableScan(TableScan {
1171                    filters: new_scan_filters,
1172                    ..scan
1173                });
1174
1175                Transformed::yes(new_scan).transform_data(|new_scan| {
1176                    if let Some(predicate) = conjunction(new_predicate) {
1177                        make_filter(predicate, Arc::new(new_scan)).map(Transformed::yes)
1178                    } else {
1179                        Ok(Transformed::no(new_scan))
1180                    }
1181                })
1182            }
1183            LogicalPlan::Extension(extension_plan) => {
1184                // This check prevents the Filter from being removed when the extension node has no children,
1185                // so we return the original Filter unchanged.
1186                if extension_plan.node.inputs().is_empty() {
1187                    filter.input = Arc::new(LogicalPlan::Extension(extension_plan));
1188                    return Ok(Transformed::no(LogicalPlan::Filter(filter)));
1189                }
1190                let prevent_cols =
1191                    extension_plan.node.prevent_predicate_push_down_columns();
1192
1193                // determine if we can push any predicates down past the extension node
1194
1195                // each element is true for push, false to keep
1196                let predicate_push_or_keep = split_conjunction(&filter.predicate)
1197                    .iter()
1198                    .map(|expr| {
1199                        let cols = expr.column_refs();
1200                        if cols.iter().any(|c| prevent_cols.contains(&c.name)) {
1201                            Ok(false) // No push (keep)
1202                        } else {
1203                            Ok(true) // push
1204                        }
1205                    })
1206                    .collect::<Result<Vec<_>>>()?;
1207
1208                // all predicates are kept, no changes needed
1209                if predicate_push_or_keep.iter().all(|&x| !x) {
1210                    filter.input = Arc::new(LogicalPlan::Extension(extension_plan));
1211                    return Ok(Transformed::no(LogicalPlan::Filter(filter)));
1212                }
1213
1214                // going to push some predicates down, so split the predicates
1215                let mut keep_predicates = vec![];
1216                let mut push_predicates = vec![];
1217                for (push, expr) in predicate_push_or_keep
1218                    .into_iter()
1219                    .zip(split_conjunction_owned(filter.predicate).into_iter())
1220                {
1221                    if !push {
1222                        keep_predicates.push(expr);
1223                    } else {
1224                        push_predicates.push(expr);
1225                    }
1226                }
1227
1228                let new_children = match conjunction(push_predicates) {
1229                    Some(predicate) => extension_plan
1230                        .node
1231                        .inputs()
1232                        .into_iter()
1233                        .map(|child| {
1234                            Ok(LogicalPlan::Filter(Filter::try_new(
1235                                predicate.clone(),
1236                                Arc::new(child.clone()),
1237                            )?))
1238                        })
1239                        .collect::<Result<Vec<_>>>()?,
1240                    None => extension_plan.node.inputs().into_iter().cloned().collect(),
1241                };
1242                // extension with new inputs.
1243                let child_plan = LogicalPlan::Extension(extension_plan);
1244                let new_extension =
1245                    child_plan.with_new_exprs(child_plan.expressions(), new_children)?;
1246
1247                let new_plan = match conjunction(keep_predicates) {
1248                    Some(predicate) => LogicalPlan::Filter(Filter::try_new(
1249                        predicate,
1250                        Arc::new(new_extension),
1251                    )?),
1252                    None => new_extension,
1253                };
1254                Ok(Transformed::yes(new_plan))
1255            }
1256            child => {
1257                filter.input = Arc::new(child);
1258                Ok(Transformed::no(LogicalPlan::Filter(filter)))
1259            }
1260        }
1261    }
1262}
1263
1264/// Attempts to push `predicate` into a `FilterExec` below `projection
1265///
1266/// # Returns
1267/// (plan, remaining_predicate)
1268///
1269/// `plan` is a LogicalPlan for `projection` with possibly a new FilterExec below it.
1270/// `remaining_predicate` is any part of the predicate that could not be pushed down
1271///
1272/// # Args
1273/// - predicates: Split predicates like `[foo=5, bar=6]`
1274/// - projection: The target projection plan to push down the predicates
1275///
1276/// # Example
1277///
1278/// Pushing a predicate like `foo=5 AND bar=6` with an input plan like this:
1279///
1280/// ```text
1281/// Projection(foo, c+d as bar)
1282/// ```
1283///
1284/// Might result in returning `remaining_predicate` of `bar=6` and a plan like
1285///
1286/// ```text
1287/// Projection(foo, c+d as bar)
1288///  Filter(foo=5)
1289///   ...
1290/// ```
1291fn rewrite_projection(
1292    predicates: Vec<Expr>,
1293    mut projection: Projection,
1294) -> Result<(Transformed<LogicalPlan>, Option<Expr>)> {
1295    // A projection is filter-commutable if it do not contain volatile predicates or contain volatile
1296    // predicates that are not used in the filter. However, we should re-writes all predicate expressions.
1297    // collect projection.
1298    let (volatile_map, non_volatile_map): (HashMap<_, _>, HashMap<_, _>) = projection
1299        .schema
1300        .iter()
1301        .zip(projection.expr.iter())
1302        .map(|((qualifier, field), expr)| {
1303            // strip alias, as they should not be part of filters
1304            let expr = expr.clone().unalias();
1305
1306            (qualified_name(qualifier, field.name()), expr)
1307        })
1308        .partition(|(_, value)| value.is_volatile());
1309
1310    let mut push_predicates = vec![];
1311    let mut keep_predicates = vec![];
1312    for expr in predicates {
1313        if contain(&expr, &volatile_map) {
1314            keep_predicates.push(expr);
1315        } else {
1316            push_predicates.push(expr);
1317        }
1318    }
1319
1320    match conjunction(push_predicates) {
1321        Some(expr) => {
1322            // re-write all filters based on this projection
1323            // E.g. in `Filter: b\n  Projection: a > 1 as b`, we can swap them, but the filter must be "a > 1"
1324            let new_filter = LogicalPlan::Filter(Filter::try_new(
1325                replace_cols_by_name(expr, &non_volatile_map)?,
1326                std::mem::take(&mut projection.input),
1327            )?);
1328
1329            projection.input = Arc::new(new_filter);
1330
1331            Ok((
1332                Transformed::yes(LogicalPlan::Projection(projection)),
1333                conjunction(keep_predicates),
1334            ))
1335        }
1336        None => Ok((Transformed::no(LogicalPlan::Projection(projection)), None)),
1337    }
1338}
1339
1340/// Creates a new LogicalPlan::Filter node.
1341pub fn make_filter(predicate: Expr, input: Arc<LogicalPlan>) -> Result<LogicalPlan> {
1342    Filter::try_new(predicate, input).map(LogicalPlan::Filter)
1343}
1344
1345/// Replace the existing child of the single input node with `new_child`.
1346///
1347/// Starting:
1348/// ```text
1349/// plan
1350///   child
1351/// ```
1352///
1353/// Ending:
1354/// ```text
1355/// plan
1356///   new_child
1357/// ```
1358fn insert_below(
1359    plan: LogicalPlan,
1360    new_child: LogicalPlan,
1361) -> Result<Transformed<LogicalPlan>> {
1362    let mut new_child = Some(new_child);
1363    let transformed_plan = plan.map_children(|_child| {
1364        if let Some(new_child) = new_child.take() {
1365            Ok(Transformed::yes(new_child))
1366        } else {
1367            // already took the new child
1368            internal_err!("node had more than one input")
1369        }
1370    })?;
1371
1372    // make sure we did the actual replacement
1373    if new_child.is_some() {
1374        return internal_err!("node had no  inputs");
1375    }
1376
1377    Ok(transformed_plan)
1378}
1379
1380impl PushDownFilter {
1381    #[allow(missing_docs)]
1382    pub fn new() -> Self {
1383        Self {}
1384    }
1385}
1386
1387/// replaces columns by its name on the projection.
1388pub fn replace_cols_by_name(
1389    e: Expr,
1390    replace_map: &HashMap<String, Expr>,
1391) -> Result<Expr> {
1392    e.transform_up(|expr| {
1393        Ok(if let Expr::Column(c) = &expr {
1394            match replace_map.get(&c.flat_name()) {
1395                Some(new_c) => Transformed::yes(new_c.clone()),
1396                None => Transformed::no(expr),
1397            }
1398        } else {
1399            Transformed::no(expr)
1400        })
1401    })
1402    .data()
1403}
1404
1405/// check whether the expression uses the columns in `check_map`.
1406fn contain(e: &Expr, check_map: &HashMap<String, Expr>) -> bool {
1407    let mut is_contain = false;
1408    e.apply(|expr| {
1409        Ok(if let Expr::Column(c) = &expr {
1410            match check_map.get(&c.flat_name()) {
1411                Some(_) => {
1412                    is_contain = true;
1413                    TreeNodeRecursion::Stop
1414                }
1415                None => TreeNodeRecursion::Continue,
1416            }
1417        } else {
1418            TreeNodeRecursion::Continue
1419        })
1420    })
1421    .unwrap();
1422    is_contain
1423}
1424
1425#[cfg(test)]
1426mod tests {
1427    use std::any::Any;
1428    use std::cmp::Ordering;
1429    use std::fmt::{Debug, Formatter};
1430
1431    use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
1432    use async_trait::async_trait;
1433
1434    use datafusion_common::{DFSchemaRef, DataFusionError, ScalarValue};
1435    use datafusion_expr::expr::{ScalarFunction, WindowFunction};
1436    use datafusion_expr::logical_plan::table_scan;
1437    use datafusion_expr::{
1438        col, in_list, in_subquery, lit, ColumnarValue, ExprFunctionExt, Extension,
1439        LogicalPlanBuilder, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
1440        TableSource, TableType, UserDefinedLogicalNodeCore, Volatility,
1441        WindowFunctionDefinition,
1442    };
1443
1444    use crate::assert_optimized_plan_eq_snapshot;
1445    use crate::optimizer::Optimizer;
1446    use crate::simplify_expressions::SimplifyExpressions;
1447    use crate::test::*;
1448    use crate::OptimizerContext;
1449    use datafusion_expr::test::function_stub::sum;
1450    use insta::assert_snapshot;
1451
1452    use super::*;
1453
1454    fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {}
1455
1456    macro_rules! assert_optimized_plan_equal {
1457        (
1458            $plan:expr,
1459            @ $expected:literal $(,)?
1460        ) => {{
1461            let optimizer_ctx = OptimizerContext::new().with_max_passes(1);
1462            let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> = vec![Arc::new(PushDownFilter::new())];
1463            assert_optimized_plan_eq_snapshot!(
1464                optimizer_ctx,
1465                rules,
1466                $plan,
1467                @ $expected,
1468            )
1469        }};
1470    }
1471
1472    macro_rules! assert_optimized_plan_eq_with_rewrite_predicate {
1473        (
1474            $plan:expr,
1475            @ $expected:literal $(,)?
1476        ) => {{
1477            let optimizer = Optimizer::with_rules(vec![
1478                Arc::new(SimplifyExpressions::new()),
1479                Arc::new(PushDownFilter::new()),
1480            ]);
1481            let optimized_plan = optimizer.optimize($plan, &OptimizerContext::new(), observe)?;
1482            assert_snapshot!(optimized_plan, @ $expected);
1483            Ok::<(), DataFusionError>(())
1484        }};
1485    }
1486
1487    #[test]
1488    fn filter_before_projection() -> Result<()> {
1489        let table_scan = test_table_scan()?;
1490        let plan = LogicalPlanBuilder::from(table_scan)
1491            .project(vec![col("a"), col("b")])?
1492            .filter(col("a").eq(lit(1i64)))?
1493            .build()?;
1494        // filter is before projection
1495        assert_optimized_plan_equal!(
1496            plan,
1497            @r"
1498        Projection: test.a, test.b
1499          TableScan: test, full_filters=[test.a = Int64(1)]
1500        "
1501        )
1502    }
1503
1504    #[test]
1505    fn filter_after_limit() -> Result<()> {
1506        let table_scan = test_table_scan()?;
1507        let plan = LogicalPlanBuilder::from(table_scan)
1508            .project(vec![col("a"), col("b")])?
1509            .limit(0, Some(10))?
1510            .filter(col("a").eq(lit(1i64)))?
1511            .build()?;
1512        // filter is before single projection
1513        assert_optimized_plan_equal!(
1514            plan,
1515            @r"
1516        Filter: test.a = Int64(1)
1517          Limit: skip=0, fetch=10
1518            Projection: test.a, test.b
1519              TableScan: test
1520        "
1521        )
1522    }
1523
1524    #[test]
1525    fn filter_no_columns() -> Result<()> {
1526        let table_scan = test_table_scan()?;
1527        let plan = LogicalPlanBuilder::from(table_scan)
1528            .filter(lit(0i64).eq(lit(1i64)))?
1529            .build()?;
1530        assert_optimized_plan_equal!(
1531            plan,
1532            @"TableScan: test, full_filters=[Int64(0) = Int64(1)]"
1533        )
1534    }
1535
1536    #[test]
1537    fn filter_jump_2_plans() -> Result<()> {
1538        let table_scan = test_table_scan()?;
1539        let plan = LogicalPlanBuilder::from(table_scan)
1540            .project(vec![col("a"), col("b"), col("c")])?
1541            .project(vec![col("c"), col("b")])?
1542            .filter(col("a").eq(lit(1i64)))?
1543            .build()?;
1544        // filter is before double projection
1545        assert_optimized_plan_equal!(
1546            plan,
1547            @r"
1548        Projection: test.c, test.b
1549          Projection: test.a, test.b, test.c
1550            TableScan: test, full_filters=[test.a = Int64(1)]
1551        "
1552        )
1553    }
1554
1555    #[test]
1556    fn filter_move_agg() -> Result<()> {
1557        let table_scan = test_table_scan()?;
1558        let plan = LogicalPlanBuilder::from(table_scan)
1559            .aggregate(vec![col("a")], vec![sum(col("b")).alias("total_salary")])?
1560            .filter(col("a").gt(lit(10i64)))?
1561            .build()?;
1562        // filter of key aggregation is commutative
1563        assert_optimized_plan_equal!(
1564            plan,
1565            @r"
1566        Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b) AS total_salary]]
1567          TableScan: test, full_filters=[test.a > Int64(10)]
1568        "
1569        )
1570    }
1571
1572    /// verifies that filters with unusual column names are pushed down through aggregate operators
1573    #[test]
1574    fn filter_move_agg_special() -> Result<()> {
1575        let schema = Schema::new(vec![
1576            Field::new("$a", DataType::UInt32, false),
1577            Field::new("$b", DataType::UInt32, false),
1578            Field::new("$c", DataType::UInt32, false),
1579        ]);
1580        let table_scan = table_scan(Some("test"), &schema, None)?.build()?;
1581
1582        let plan = LogicalPlanBuilder::from(table_scan)
1583            .aggregate(vec![col("$a")], vec![sum(col("$b")).alias("total_salary")])?
1584            .filter(col("$a").gt(lit(10i64)))?
1585            .build()?;
1586        // filter of key aggregation is commutative
1587        assert_optimized_plan_equal!(
1588            plan,
1589            @r"
1590        Aggregate: groupBy=[[test.$a]], aggr=[[sum(test.$b) AS total_salary]]
1591          TableScan: test, full_filters=[test.$a > Int64(10)]
1592        "
1593        )
1594    }
1595
1596    #[test]
1597    fn filter_complex_group_by() -> Result<()> {
1598        let table_scan = test_table_scan()?;
1599        let plan = LogicalPlanBuilder::from(table_scan)
1600            .aggregate(vec![add(col("b"), col("a"))], vec![sum(col("a")), col("b")])?
1601            .filter(col("b").gt(lit(10i64)))?
1602            .build()?;
1603        assert_optimized_plan_equal!(
1604            plan,
1605            @r"
1606        Filter: test.b > Int64(10)
1607          Aggregate: groupBy=[[test.b + test.a]], aggr=[[sum(test.a), test.b]]
1608            TableScan: test
1609        "
1610        )
1611    }
1612
1613    #[test]
1614    fn push_agg_need_replace_expr() -> Result<()> {
1615        let plan = LogicalPlanBuilder::from(test_table_scan()?)
1616            .aggregate(vec![add(col("b"), col("a"))], vec![sum(col("a")), col("b")])?
1617            .filter(col("test.b + test.a").gt(lit(10i64)))?
1618            .build()?;
1619        assert_optimized_plan_equal!(
1620            plan,
1621            @r"
1622        Aggregate: groupBy=[[test.b + test.a]], aggr=[[sum(test.a), test.b]]
1623          TableScan: test, full_filters=[test.b + test.a > Int64(10)]
1624        "
1625        )
1626    }
1627
1628    #[test]
1629    fn filter_keep_agg() -> Result<()> {
1630        let table_scan = test_table_scan()?;
1631        let plan = LogicalPlanBuilder::from(table_scan)
1632            .aggregate(vec![col("a")], vec![sum(col("b")).alias("b")])?
1633            .filter(col("b").gt(lit(10i64)))?
1634            .build()?;
1635        // filter of aggregate is after aggregation since they are non-commutative
1636        assert_optimized_plan_equal!(
1637            plan,
1638            @r"
1639        Filter: b > Int64(10)
1640          Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b) AS b]]
1641            TableScan: test
1642        "
1643        )
1644    }
1645
1646    /// verifies that when partitioning by 'a' and 'b', and filtering by 'b', 'b' is pushed
1647    #[test]
1648    fn filter_move_window() -> Result<()> {
1649        let table_scan = test_table_scan()?;
1650
1651        let window = Expr::from(WindowFunction::new(
1652            WindowFunctionDefinition::WindowUDF(
1653                datafusion_functions_window::rank::rank_udwf(),
1654            ),
1655            vec![],
1656        ))
1657        .partition_by(vec![col("a"), col("b")])
1658        .order_by(vec![col("c").sort(true, true)])
1659        .build()
1660        .unwrap();
1661
1662        let plan = LogicalPlanBuilder::from(table_scan)
1663            .window(vec![window])?
1664            .filter(col("b").gt(lit(10i64)))?
1665            .build()?;
1666
1667        assert_optimized_plan_equal!(
1668            plan,
1669            @r"
1670        WindowAggr: windowExpr=[[rank() PARTITION BY [test.a, test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1671          TableScan: test, full_filters=[test.b > Int64(10)]
1672        "
1673        )
1674    }
1675
1676    /// verifies that filters with unusual identifier names are pushed down through window functions
1677    #[test]
1678    fn filter_window_special_identifier() -> Result<()> {
1679        let schema = Schema::new(vec![
1680            Field::new("$a", DataType::UInt32, false),
1681            Field::new("$b", DataType::UInt32, false),
1682            Field::new("$c", DataType::UInt32, false),
1683        ]);
1684        let table_scan = table_scan(Some("test"), &schema, None)?.build()?;
1685
1686        let window = Expr::from(WindowFunction::new(
1687            WindowFunctionDefinition::WindowUDF(
1688                datafusion_functions_window::rank::rank_udwf(),
1689            ),
1690            vec![],
1691        ))
1692        .partition_by(vec![col("$a"), col("$b")])
1693        .order_by(vec![col("$c").sort(true, true)])
1694        .build()
1695        .unwrap();
1696
1697        let plan = LogicalPlanBuilder::from(table_scan)
1698            .window(vec![window])?
1699            .filter(col("$b").gt(lit(10i64)))?
1700            .build()?;
1701
1702        assert_optimized_plan_equal!(
1703            plan,
1704            @r"
1705        WindowAggr: windowExpr=[[rank() PARTITION BY [test.$a, test.$b] ORDER BY [test.$c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1706          TableScan: test, full_filters=[test.$b > Int64(10)]
1707        "
1708        )
1709    }
1710
1711    /// verifies that when partitioning by 'a' and 'b', and filtering by 'a' and 'b', both 'a' and
1712    /// 'b' are pushed
1713    #[test]
1714    fn filter_move_complex_window() -> Result<()> {
1715        let table_scan = test_table_scan()?;
1716
1717        let window = Expr::from(WindowFunction::new(
1718            WindowFunctionDefinition::WindowUDF(
1719                datafusion_functions_window::rank::rank_udwf(),
1720            ),
1721            vec![],
1722        ))
1723        .partition_by(vec![col("a"), col("b")])
1724        .order_by(vec![col("c").sort(true, true)])
1725        .build()
1726        .unwrap();
1727
1728        let plan = LogicalPlanBuilder::from(table_scan)
1729            .window(vec![window])?
1730            .filter(and(col("a").gt(lit(10i64)), col("b").eq(lit(1i64))))?
1731            .build()?;
1732
1733        assert_optimized_plan_equal!(
1734            plan,
1735            @r"
1736        WindowAggr: windowExpr=[[rank() PARTITION BY [test.a, test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1737          TableScan: test, full_filters=[test.a > Int64(10), test.b = Int64(1)]
1738        "
1739        )
1740    }
1741
1742    /// verifies that when partitioning by 'a' and filtering by 'a' and 'b', only 'a' is pushed
1743    #[test]
1744    fn filter_move_partial_window() -> Result<()> {
1745        let table_scan = test_table_scan()?;
1746
1747        let window = Expr::from(WindowFunction::new(
1748            WindowFunctionDefinition::WindowUDF(
1749                datafusion_functions_window::rank::rank_udwf(),
1750            ),
1751            vec![],
1752        ))
1753        .partition_by(vec![col("a")])
1754        .order_by(vec![col("c").sort(true, true)])
1755        .build()
1756        .unwrap();
1757
1758        let plan = LogicalPlanBuilder::from(table_scan)
1759            .window(vec![window])?
1760            .filter(and(col("a").gt(lit(10i64)), col("b").eq(lit(1i64))))?
1761            .build()?;
1762
1763        assert_optimized_plan_equal!(
1764            plan,
1765            @r"
1766        Filter: test.b = Int64(1)
1767          WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1768            TableScan: test, full_filters=[test.a > Int64(10)]
1769        "
1770        )
1771    }
1772
1773    /// verifies that filters on partition expressions are not pushed, as the single expression
1774    /// column is not available to the user, unlike with aggregations
1775    #[test]
1776    fn filter_expression_keep_window() -> Result<()> {
1777        let table_scan = test_table_scan()?;
1778
1779        let window = Expr::from(WindowFunction::new(
1780            WindowFunctionDefinition::WindowUDF(
1781                datafusion_functions_window::rank::rank_udwf(),
1782            ),
1783            vec![],
1784        ))
1785        .partition_by(vec![add(col("a"), col("b"))]) // PARTITION BY a + b
1786        .order_by(vec![col("c").sort(true, true)])
1787        .build()
1788        .unwrap();
1789
1790        let plan = LogicalPlanBuilder::from(table_scan)
1791            .window(vec![window])?
1792            // unlike with aggregations, single partition column "test.a + test.b" is not available
1793            // to the plan, so we use multiple columns when filtering
1794            .filter(add(col("a"), col("b")).gt(lit(10i64)))?
1795            .build()?;
1796
1797        assert_optimized_plan_equal!(
1798            plan,
1799            @r"
1800        Filter: test.a + test.b > Int64(10)
1801          WindowAggr: windowExpr=[[rank() PARTITION BY [test.a + test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1802            TableScan: test
1803        "
1804        )
1805    }
1806
1807    /// verifies that filters are not pushed on order by columns (that are not used in partitioning)
1808    #[test]
1809    fn filter_order_keep_window() -> Result<()> {
1810        let table_scan = test_table_scan()?;
1811
1812        let window = Expr::from(WindowFunction::new(
1813            WindowFunctionDefinition::WindowUDF(
1814                datafusion_functions_window::rank::rank_udwf(),
1815            ),
1816            vec![],
1817        ))
1818        .partition_by(vec![col("a")])
1819        .order_by(vec![col("c").sort(true, true)])
1820        .build()
1821        .unwrap();
1822
1823        let plan = LogicalPlanBuilder::from(table_scan)
1824            .window(vec![window])?
1825            .filter(col("c").gt(lit(10i64)))?
1826            .build()?;
1827
1828        assert_optimized_plan_equal!(
1829            plan,
1830            @r"
1831        Filter: test.c > Int64(10)
1832          WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1833            TableScan: test
1834        "
1835        )
1836    }
1837
1838    /// verifies that when we use multiple window functions with a common partition key, the filter
1839    /// on that key is pushed
1840    #[test]
1841    fn filter_multiple_windows_common_partitions() -> Result<()> {
1842        let table_scan = test_table_scan()?;
1843
1844        let window1 = Expr::from(WindowFunction::new(
1845            WindowFunctionDefinition::WindowUDF(
1846                datafusion_functions_window::rank::rank_udwf(),
1847            ),
1848            vec![],
1849        ))
1850        .partition_by(vec![col("a")])
1851        .order_by(vec![col("c").sort(true, true)])
1852        .build()
1853        .unwrap();
1854
1855        let window2 = Expr::from(WindowFunction::new(
1856            WindowFunctionDefinition::WindowUDF(
1857                datafusion_functions_window::rank::rank_udwf(),
1858            ),
1859            vec![],
1860        ))
1861        .partition_by(vec![col("b"), col("a")])
1862        .order_by(vec![col("c").sort(true, true)])
1863        .build()
1864        .unwrap();
1865
1866        let plan = LogicalPlanBuilder::from(table_scan)
1867            .window(vec![window1, window2])?
1868            .filter(col("a").gt(lit(10i64)))? // a appears in both window functions
1869            .build()?;
1870
1871        assert_optimized_plan_equal!(
1872            plan,
1873            @r"
1874        WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, rank() PARTITION BY [test.b, test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1875          TableScan: test, full_filters=[test.a > Int64(10)]
1876        "
1877        )
1878    }
1879
1880    /// verifies that when we use multiple window functions with different partitions keys, the
1881    /// filter cannot be pushed
1882    #[test]
1883    fn filter_multiple_windows_disjoint_partitions() -> Result<()> {
1884        let table_scan = test_table_scan()?;
1885
1886        let window1 = Expr::from(WindowFunction::new(
1887            WindowFunctionDefinition::WindowUDF(
1888                datafusion_functions_window::rank::rank_udwf(),
1889            ),
1890            vec![],
1891        ))
1892        .partition_by(vec![col("a")])
1893        .order_by(vec![col("c").sort(true, true)])
1894        .build()
1895        .unwrap();
1896
1897        let window2 = Expr::from(WindowFunction::new(
1898            WindowFunctionDefinition::WindowUDF(
1899                datafusion_functions_window::rank::rank_udwf(),
1900            ),
1901            vec![],
1902        ))
1903        .partition_by(vec![col("b"), col("a")])
1904        .order_by(vec![col("c").sort(true, true)])
1905        .build()
1906        .unwrap();
1907
1908        let plan = LogicalPlanBuilder::from(table_scan)
1909            .window(vec![window1, window2])?
1910            .filter(col("b").gt(lit(10i64)))? // b only appears in one window function
1911            .build()?;
1912
1913        assert_optimized_plan_equal!(
1914            plan,
1915            @r"
1916        Filter: test.b > Int64(10)
1917          WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, rank() PARTITION BY [test.b, test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1918            TableScan: test
1919        "
1920        )
1921    }
1922
1923    /// verifies that a filter is pushed to before a projection, the filter expression is correctly re-written
1924    #[test]
1925    fn alias() -> Result<()> {
1926        let table_scan = test_table_scan()?;
1927        let plan = LogicalPlanBuilder::from(table_scan)
1928            .project(vec![col("a").alias("b"), col("c")])?
1929            .filter(col("b").eq(lit(1i64)))?
1930            .build()?;
1931        // filter is before projection
1932        assert_optimized_plan_equal!(
1933            plan,
1934            @r"
1935        Projection: test.a AS b, test.c
1936          TableScan: test, full_filters=[test.a = Int64(1)]
1937        "
1938        )
1939    }
1940
1941    fn add(left: Expr, right: Expr) -> Expr {
1942        Expr::BinaryExpr(BinaryExpr::new(
1943            Box::new(left),
1944            Operator::Plus,
1945            Box::new(right),
1946        ))
1947    }
1948
1949    fn multiply(left: Expr, right: Expr) -> Expr {
1950        Expr::BinaryExpr(BinaryExpr::new(
1951            Box::new(left),
1952            Operator::Multiply,
1953            Box::new(right),
1954        ))
1955    }
1956
1957    /// verifies that a filter is pushed to before a projection with a complex expression, the filter expression is correctly re-written
1958    #[test]
1959    fn complex_expression() -> Result<()> {
1960        let table_scan = test_table_scan()?;
1961        let plan = LogicalPlanBuilder::from(table_scan)
1962            .project(vec![
1963                add(multiply(col("a"), lit(2)), col("c")).alias("b"),
1964                col("c"),
1965            ])?
1966            .filter(col("b").eq(lit(1i64)))?
1967            .build()?;
1968
1969        // not part of the test, just good to know:
1970        assert_snapshot!(plan,
1971        @r"
1972        Filter: b = Int64(1)
1973          Projection: test.a * Int32(2) + test.c AS b, test.c
1974            TableScan: test
1975        ",
1976        );
1977        // filter is before projection
1978        assert_optimized_plan_equal!(
1979            plan,
1980            @r"
1981        Projection: test.a * Int32(2) + test.c AS b, test.c
1982          TableScan: test, full_filters=[test.a * Int32(2) + test.c = Int64(1)]
1983        "
1984        )
1985    }
1986
1987    /// verifies that when a filter is pushed to after 2 projections, the filter expression is correctly re-written
1988    #[test]
1989    fn complex_plan() -> Result<()> {
1990        let table_scan = test_table_scan()?;
1991        let plan = LogicalPlanBuilder::from(table_scan)
1992            .project(vec![
1993                add(multiply(col("a"), lit(2)), col("c")).alias("b"),
1994                col("c"),
1995            ])?
1996            // second projection where we rename columns, just to make it difficult
1997            .project(vec![multiply(col("b"), lit(3)).alias("a"), col("c")])?
1998            .filter(col("a").eq(lit(1i64)))?
1999            .build()?;
2000
2001        // not part of the test, just good to know:
2002        assert_snapshot!(plan,
2003        @r"
2004        Filter: a = Int64(1)
2005          Projection: b * Int32(3) AS a, test.c
2006            Projection: test.a * Int32(2) + test.c AS b, test.c
2007              TableScan: test
2008        ",
2009        );
2010        // filter is before the projections
2011        assert_optimized_plan_equal!(
2012            plan,
2013            @r"
2014        Projection: b * Int32(3) AS a, test.c
2015          Projection: test.a * Int32(2) + test.c AS b, test.c
2016            TableScan: test, full_filters=[(test.a * Int32(2) + test.c) * Int32(3) = Int64(1)]
2017        "
2018        )
2019    }
2020
    /// Minimal user-defined logical node used by these tests: it only holds child plans and a
    /// schema, and is used to check how filters are pushed through extension nodes.
    #[derive(Debug, PartialEq, Eq, Hash)]
    struct NoopPlan {
        /// Child plans of this node (the tests use one or two table scans).
        input: Vec<LogicalPlan>,
        /// Schema the node reports as its own (the tests pass the child scan's schema).
        schema: DFSchemaRef,
    }
2026
2027    // Manual implementation needed because of `schema` field. Comparison excludes this field.
2028    impl PartialOrd for NoopPlan {
2029        fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
2030            self.input
2031                .partial_cmp(&other.input)
2032                // TODO (https://github.com/apache/datafusion/issues/17477) avoid recomparing all fields
2033                .filter(|cmp| *cmp != Ordering::Equal || self == other)
2034        }
2035    }
2036
2037    impl UserDefinedLogicalNodeCore for NoopPlan {
2038        fn name(&self) -> &str {
2039            "NoopPlan"
2040        }
2041
2042        fn inputs(&self) -> Vec<&LogicalPlan> {
2043            self.input.iter().collect()
2044        }
2045
2046        fn schema(&self) -> &DFSchemaRef {
2047            &self.schema
2048        }
2049
2050        fn expressions(&self) -> Vec<Expr> {
2051            self.input
2052                .iter()
2053                .flat_map(|child| child.expressions())
2054                .collect()
2055        }
2056
2057        fn prevent_predicate_push_down_columns(&self) -> HashSet<String> {
2058            HashSet::from_iter(vec!["c".to_string()])
2059        }
2060
2061        fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
2062            write!(f, "NoopPlan")
2063        }
2064
2065        fn with_exprs_and_inputs(
2066            &self,
2067            _exprs: Vec<Expr>,
2068            inputs: Vec<LogicalPlan>,
2069        ) -> Result<Self> {
2070            Ok(Self {
2071                input: inputs,
2072                schema: Arc::clone(&self.schema),
2073            })
2074        }
2075
2076        fn supports_limit_pushdown(&self) -> bool {
2077            false // Disallow limit push-down by default
2078        }
2079    }
2080
2081    #[test]
2082    fn user_defined_plan() -> Result<()> {
2083        let table_scan = test_table_scan()?;
2084
2085        let custom_plan = LogicalPlan::Extension(Extension {
2086            node: Arc::new(NoopPlan {
2087                input: vec![table_scan.clone()],
2088                schema: Arc::clone(table_scan.schema()),
2089            }),
2090        });
2091        let plan = LogicalPlanBuilder::from(custom_plan)
2092            .filter(col("a").eq(lit(1i64)))?
2093            .build()?;
2094
2095        // Push filter below NoopPlan
2096        assert_optimized_plan_equal!(
2097            plan,
2098            @r"
2099        NoopPlan
2100          TableScan: test, full_filters=[test.a = Int64(1)]
2101        "
2102        )?;
2103
2104        let custom_plan = LogicalPlan::Extension(Extension {
2105            node: Arc::new(NoopPlan {
2106                input: vec![table_scan.clone()],
2107                schema: Arc::clone(table_scan.schema()),
2108            }),
2109        });
2110        let plan = LogicalPlanBuilder::from(custom_plan)
2111            .filter(col("a").eq(lit(1i64)).and(col("c").eq(lit(2i64))))?
2112            .build()?;
2113
2114        // Push only predicate on `a` below NoopPlan
2115        assert_optimized_plan_equal!(
2116            plan,
2117            @r"
2118        Filter: test.c = Int64(2)
2119          NoopPlan
2120            TableScan: test, full_filters=[test.a = Int64(1)]
2121        "
2122        )?;
2123
2124        let custom_plan = LogicalPlan::Extension(Extension {
2125            node: Arc::new(NoopPlan {
2126                input: vec![table_scan.clone(), table_scan.clone()],
2127                schema: Arc::clone(table_scan.schema()),
2128            }),
2129        });
2130        let plan = LogicalPlanBuilder::from(custom_plan)
2131            .filter(col("a").eq(lit(1i64)))?
2132            .build()?;
2133
2134        // Push filter below NoopPlan for each child branch
2135        assert_optimized_plan_equal!(
2136            plan,
2137            @r"
2138        NoopPlan
2139          TableScan: test, full_filters=[test.a = Int64(1)]
2140          TableScan: test, full_filters=[test.a = Int64(1)]
2141        "
2142        )?;
2143
2144        let custom_plan = LogicalPlan::Extension(Extension {
2145            node: Arc::new(NoopPlan {
2146                input: vec![table_scan.clone(), table_scan.clone()],
2147                schema: Arc::clone(table_scan.schema()),
2148            }),
2149        });
2150        let plan = LogicalPlanBuilder::from(custom_plan)
2151            .filter(col("a").eq(lit(1i64)).and(col("c").eq(lit(2i64))))?
2152            .build()?;
2153
2154        // Push only predicate on `a` below NoopPlan
2155        assert_optimized_plan_equal!(
2156            plan,
2157            @r"
2158        Filter: test.c = Int64(2)
2159          NoopPlan
2160            TableScan: test, full_filters=[test.a = Int64(1)]
2161            TableScan: test, full_filters=[test.a = Int64(1)]
2162        "
2163        )
2164    }
2165
2166    /// verifies that when two filters apply after an aggregation that only allows one to be pushed, one is pushed
2167    /// and the other not.
2168    #[test]
2169    fn multi_filter() -> Result<()> {
2170        // the aggregation allows one filter to pass (b), and the other one to not pass (sum(c))
2171        let table_scan = test_table_scan()?;
2172        let plan = LogicalPlanBuilder::from(table_scan)
2173            .project(vec![col("a").alias("b"), col("c")])?
2174            .aggregate(vec![col("b")], vec![sum(col("c"))])?
2175            .filter(col("b").gt(lit(10i64)))?
2176            .filter(col("sum(test.c)").gt(lit(10i64)))?
2177            .build()?;
2178
2179        // not part of the test, just good to know:
2180        assert_snapshot!(plan,
2181        @r"
2182        Filter: sum(test.c) > Int64(10)
2183          Filter: b > Int64(10)
2184            Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
2185              Projection: test.a AS b, test.c
2186                TableScan: test
2187        ",
2188        );
2189        // filter is before the projections
2190        assert_optimized_plan_equal!(
2191            plan,
2192            @r"
2193        Filter: sum(test.c) > Int64(10)
2194          Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
2195            Projection: test.a AS b, test.c
2196              TableScan: test, full_filters=[test.a > Int64(10)]
2197        "
2198        )
2199    }
2200
2201    /// verifies that when a filter with two predicates is applied after an aggregation that only allows one to be pushed, one is pushed
2202    /// and the other not.
2203    #[test]
2204    fn split_filter() -> Result<()> {
2205        // the aggregation allows one filter to pass (b), and the other one to not pass (sum(c))
2206        let table_scan = test_table_scan()?;
2207        let plan = LogicalPlanBuilder::from(table_scan)
2208            .project(vec![col("a").alias("b"), col("c")])?
2209            .aggregate(vec![col("b")], vec![sum(col("c"))])?
2210            .filter(and(
2211                col("sum(test.c)").gt(lit(10i64)),
2212                and(col("b").gt(lit(10i64)), col("sum(test.c)").lt(lit(20i64))),
2213            ))?
2214            .build()?;
2215
2216        // not part of the test, just good to know:
2217        assert_snapshot!(plan,
2218        @r"
2219        Filter: sum(test.c) > Int64(10) AND b > Int64(10) AND sum(test.c) < Int64(20)
2220          Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
2221            Projection: test.a AS b, test.c
2222              TableScan: test
2223        ",
2224        );
2225        // filter is before the projections
2226        assert_optimized_plan_equal!(
2227            plan,
2228            @r"
2229        Filter: sum(test.c) > Int64(10) AND sum(test.c) < Int64(20)
2230          Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
2231            Projection: test.a AS b, test.c
2232              TableScan: test, full_filters=[test.a > Int64(10)]
2233        "
2234        )
2235    }
2236
2237    /// verifies that when two limits are in place, we jump neither
2238    #[test]
2239    fn double_limit() -> Result<()> {
2240        let table_scan = test_table_scan()?;
2241        let plan = LogicalPlanBuilder::from(table_scan)
2242            .project(vec![col("a"), col("b")])?
2243            .limit(0, Some(20))?
2244            .limit(0, Some(10))?
2245            .project(vec![col("a"), col("b")])?
2246            .filter(col("a").eq(lit(1i64)))?
2247            .build()?;
2248        // filter does not just any of the limits
2249        assert_optimized_plan_equal!(
2250            plan,
2251            @r"
2252        Projection: test.a, test.b
2253          Filter: test.a = Int64(1)
2254            Limit: skip=0, fetch=10
2255              Limit: skip=0, fetch=20
2256                Projection: test.a, test.b
2257                  TableScan: test
2258        "
2259        )
2260    }
2261
2262    #[test]
2263    fn union_all() -> Result<()> {
2264        let table_scan = test_table_scan()?;
2265        let table_scan2 = test_table_scan_with_name("test2")?;
2266        let plan = LogicalPlanBuilder::from(table_scan)
2267            .union(LogicalPlanBuilder::from(table_scan2).build()?)?
2268            .filter(col("a").eq(lit(1i64)))?
2269            .build()?;
2270        // filter appears below Union
2271        assert_optimized_plan_equal!(
2272            plan,
2273            @r"
2274        Union
2275          TableScan: test, full_filters=[test.a = Int64(1)]
2276          TableScan: test2, full_filters=[test2.a = Int64(1)]
2277        "
2278        )
2279    }
2280
2281    #[test]
2282    fn union_all_on_projection() -> Result<()> {
2283        let table_scan = test_table_scan()?;
2284        let table = LogicalPlanBuilder::from(table_scan)
2285            .project(vec![col("a").alias("b")])?
2286            .alias("test2")?;
2287
2288        let plan = table
2289            .clone()
2290            .union(table.build()?)?
2291            .filter(col("b").eq(lit(1i64)))?
2292            .build()?;
2293
2294        // filter appears below Union
2295        assert_optimized_plan_equal!(
2296            plan,
2297            @r"
2298        Union
2299          SubqueryAlias: test2
2300            Projection: test.a AS b
2301              TableScan: test, full_filters=[test.a = Int64(1)]
2302          SubqueryAlias: test2
2303            Projection: test.a AS b
2304              TableScan: test, full_filters=[test.a = Int64(1)]
2305        "
2306        )
2307    }
2308
2309    #[test]
2310    fn test_union_different_schema() -> Result<()> {
2311        let left = LogicalPlanBuilder::from(test_table_scan()?)
2312            .project(vec![col("a"), col("b"), col("c")])?
2313            .build()?;
2314
2315        let schema = Schema::new(vec![
2316            Field::new("d", DataType::UInt32, false),
2317            Field::new("e", DataType::UInt32, false),
2318            Field::new("f", DataType::UInt32, false),
2319        ]);
2320        let right = table_scan(Some("test1"), &schema, None)?
2321            .project(vec![col("d"), col("e"), col("f")])?
2322            .build()?;
2323        let filter = and(col("test.a").eq(lit(1)), col("test1.d").gt(lit(2)));
2324        let plan = LogicalPlanBuilder::from(left)
2325            .cross_join(right)?
2326            .project(vec![col("test.a"), col("test1.d")])?
2327            .filter(filter)?
2328            .build()?;
2329
2330        assert_optimized_plan_equal!(
2331            plan,
2332            @r"
2333        Projection: test.a, test1.d
2334          Cross Join: 
2335            Projection: test.a, test.b, test.c
2336              TableScan: test, full_filters=[test.a = Int32(1)]
2337            Projection: test1.d, test1.e, test1.f
2338              TableScan: test1, full_filters=[test1.d > Int32(2)]
2339        "
2340        )
2341    }
2342
2343    #[test]
2344    fn test_project_same_name_different_qualifier() -> Result<()> {
2345        let table_scan = test_table_scan()?;
2346        let left = LogicalPlanBuilder::from(table_scan)
2347            .project(vec![col("a"), col("b"), col("c")])?
2348            .build()?;
2349        let right_table_scan = test_table_scan_with_name("test1")?;
2350        let right = LogicalPlanBuilder::from(right_table_scan)
2351            .project(vec![col("a"), col("b"), col("c")])?
2352            .build()?;
2353        let filter = and(col("test.a").eq(lit(1)), col("test1.a").gt(lit(2)));
2354        let plan = LogicalPlanBuilder::from(left)
2355            .cross_join(right)?
2356            .project(vec![col("test.a"), col("test1.a")])?
2357            .filter(filter)?
2358            .build()?;
2359
2360        assert_optimized_plan_equal!(
2361            plan,
2362            @r"
2363        Projection: test.a, test1.a
2364          Cross Join: 
2365            Projection: test.a, test.b, test.c
2366              TableScan: test, full_filters=[test.a = Int32(1)]
2367            Projection: test1.a, test1.b, test1.c
2368              TableScan: test1, full_filters=[test1.a > Int32(2)]
2369        "
2370        )
2371    }
2372
2373    /// verifies that filters with the same columns are correctly placed
2374    #[test]
2375    fn filter_2_breaks_limits() -> Result<()> {
2376        let table_scan = test_table_scan()?;
2377        let plan = LogicalPlanBuilder::from(table_scan)
2378            .project(vec![col("a")])?
2379            .filter(col("a").lt_eq(lit(1i64)))?
2380            .limit(0, Some(1))?
2381            .project(vec![col("a")])?
2382            .filter(col("a").gt_eq(lit(1i64)))?
2383            .build()?;
2384        // Should be able to move both filters below the projections
2385
2386        // not part of the test
2387        assert_snapshot!(plan,
2388        @r"
2389        Filter: test.a >= Int64(1)
2390          Projection: test.a
2391            Limit: skip=0, fetch=1
2392              Filter: test.a <= Int64(1)
2393                Projection: test.a
2394                  TableScan: test
2395        ",
2396        );
2397        assert_optimized_plan_equal!(
2398            plan,
2399            @r"
2400        Projection: test.a
2401          Filter: test.a >= Int64(1)
2402            Limit: skip=0, fetch=1
2403              Projection: test.a
2404                TableScan: test, full_filters=[test.a <= Int64(1)]
2405        "
2406        )
2407    }
2408
2409    /// verifies that filters to be placed on the same depth are ANDed
2410    #[test]
2411    fn two_filters_on_same_depth() -> Result<()> {
2412        let table_scan = test_table_scan()?;
2413        let plan = LogicalPlanBuilder::from(table_scan)
2414            .limit(0, Some(1))?
2415            .filter(col("a").lt_eq(lit(1i64)))?
2416            .filter(col("a").gt_eq(lit(1i64)))?
2417            .project(vec![col("a")])?
2418            .build()?;
2419
2420        // not part of the test
2421        assert_snapshot!(plan,
2422        @r"
2423        Projection: test.a
2424          Filter: test.a >= Int64(1)
2425            Filter: test.a <= Int64(1)
2426              Limit: skip=0, fetch=1
2427                TableScan: test
2428        ",
2429        );
2430        assert_optimized_plan_equal!(
2431            plan,
2432            @r"
2433        Projection: test.a
2434          Filter: test.a >= Int64(1) AND test.a <= Int64(1)
2435            Limit: skip=0, fetch=1
2436              TableScan: test
2437        "
2438        )
2439    }
2440
2441    /// verifies that filters on a plan with user nodes are not lost
2442    /// (ARROW-10547)
2443    #[test]
2444    fn filters_user_defined_node() -> Result<()> {
2445        let table_scan = test_table_scan()?;
2446        let plan = LogicalPlanBuilder::from(table_scan)
2447            .filter(col("a").lt_eq(lit(1i64)))?
2448            .build()?;
2449
2450        let plan = user_defined::new(plan);
2451
2452        // not part of the test
2453        assert_snapshot!(plan,
2454        @r"
2455        TestUserDefined
2456          Filter: test.a <= Int64(1)
2457            TableScan: test
2458        ",
2459        );
2460        assert_optimized_plan_equal!(
2461            plan,
2462            @r"
2463        TestUserDefined
2464          TableScan: test, full_filters=[test.a <= Int64(1)]
2465        "
2466        )
2467    }
2468
2469    /// post-on-join predicates on a column common to both sides is pushed to both sides
2470    #[test]
2471    fn filter_on_join_on_common_independent() -> Result<()> {
2472        let table_scan = test_table_scan()?;
2473        let left = LogicalPlanBuilder::from(table_scan).build()?;
2474        let right_table_scan = test_table_scan_with_name("test2")?;
2475        let right = LogicalPlanBuilder::from(right_table_scan)
2476            .project(vec![col("a")])?
2477            .build()?;
2478        let plan = LogicalPlanBuilder::from(left)
2479            .join(
2480                right,
2481                JoinType::Inner,
2482                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2483                None,
2484            )?
2485            .filter(col("test.a").lt_eq(lit(1i64)))?
2486            .build()?;
2487
2488        // not part of the test, just good to know:
2489        assert_snapshot!(plan,
2490        @r"
2491        Filter: test.a <= Int64(1)
2492          Inner Join: test.a = test2.a
2493            TableScan: test
2494            Projection: test2.a
2495              TableScan: test2
2496        ",
2497        );
2498        // filter sent to side before the join
2499        assert_optimized_plan_equal!(
2500            plan,
2501            @r"
2502        Inner Join: test.a = test2.a
2503          TableScan: test, full_filters=[test.a <= Int64(1)]
2504          Projection: test2.a
2505            TableScan: test2, full_filters=[test2.a <= Int64(1)]
2506        "
2507        )
2508    }
2509
2510    /// post-using-join predicates on a column common to both sides is pushed to both sides
2511    #[test]
2512    fn filter_using_join_on_common_independent() -> Result<()> {
2513        let table_scan = test_table_scan()?;
2514        let left = LogicalPlanBuilder::from(table_scan).build()?;
2515        let right_table_scan = test_table_scan_with_name("test2")?;
2516        let right = LogicalPlanBuilder::from(right_table_scan)
2517            .project(vec![col("a")])?
2518            .build()?;
2519        let plan = LogicalPlanBuilder::from(left)
2520            .join_using(
2521                right,
2522                JoinType::Inner,
2523                vec![Column::from_name("a".to_string())],
2524            )?
2525            .filter(col("a").lt_eq(lit(1i64)))?
2526            .build()?;
2527
2528        // not part of the test, just good to know:
2529        assert_snapshot!(plan,
2530        @r"
2531        Filter: test.a <= Int64(1)
2532          Inner Join: Using test.a = test2.a
2533            TableScan: test
2534            Projection: test2.a
2535              TableScan: test2
2536        ",
2537        );
2538        // filter sent to side before the join
2539        assert_optimized_plan_equal!(
2540            plan,
2541            @r"
2542        Inner Join: Using test.a = test2.a
2543          TableScan: test, full_filters=[test.a <= Int64(1)]
2544          Projection: test2.a
2545            TableScan: test2, full_filters=[test2.a <= Int64(1)]
2546        "
2547        )
2548    }
2549
2550    /// post-join predicates with columns from both sides are converted to join filters
2551    #[test]
2552    fn filter_join_on_common_dependent() -> Result<()> {
2553        let table_scan = test_table_scan()?;
2554        let left = LogicalPlanBuilder::from(table_scan)
2555            .project(vec![col("a"), col("c")])?
2556            .build()?;
2557        let right_table_scan = test_table_scan_with_name("test2")?;
2558        let right = LogicalPlanBuilder::from(right_table_scan)
2559            .project(vec![col("a"), col("b")])?
2560            .build()?;
2561        let plan = LogicalPlanBuilder::from(left)
2562            .join(
2563                right,
2564                JoinType::Inner,
2565                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2566                None,
2567            )?
2568            .filter(col("c").lt_eq(col("b")))?
2569            .build()?;
2570
2571        // not part of the test, just good to know:
2572        assert_snapshot!(plan,
2573        @r"
2574        Filter: test.c <= test2.b
2575          Inner Join: test.a = test2.a
2576            Projection: test.a, test.c
2577              TableScan: test
2578            Projection: test2.a, test2.b
2579              TableScan: test2
2580        ",
2581        );
2582        // Filter is converted to Join Filter
2583        assert_optimized_plan_equal!(
2584            plan,
2585            @r"
2586        Inner Join: test.a = test2.a Filter: test.c <= test2.b
2587          Projection: test.a, test.c
2588            TableScan: test
2589          Projection: test2.a, test2.b
2590            TableScan: test2
2591        "
2592        )
2593    }
2594
2595    /// post-join predicates with columns from one side of a join are pushed only to that side
2596    #[test]
2597    fn filter_join_on_one_side() -> Result<()> {
2598        let table_scan = test_table_scan()?;
2599        let left = LogicalPlanBuilder::from(table_scan)
2600            .project(vec![col("a"), col("b")])?
2601            .build()?;
2602        let table_scan_right = test_table_scan_with_name("test2")?;
2603        let right = LogicalPlanBuilder::from(table_scan_right)
2604            .project(vec![col("a"), col("c")])?
2605            .build()?;
2606
2607        let plan = LogicalPlanBuilder::from(left)
2608            .join(
2609                right,
2610                JoinType::Inner,
2611                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2612                None,
2613            )?
2614            .filter(col("b").lt_eq(lit(1i64)))?
2615            .build()?;
2616
2617        // not part of the test, just good to know:
2618        assert_snapshot!(plan,
2619        @r"
2620        Filter: test.b <= Int64(1)
2621          Inner Join: test.a = test2.a
2622            Projection: test.a, test.b
2623              TableScan: test
2624            Projection: test2.a, test2.c
2625              TableScan: test2
2626        ",
2627        );
2628        assert_optimized_plan_equal!(
2629            plan,
2630            @r"
2631        Inner Join: test.a = test2.a
2632          Projection: test.a, test.b
2633            TableScan: test, full_filters=[test.b <= Int64(1)]
2634          Projection: test2.a, test2.c
2635            TableScan: test2
2636        "
2637        )
2638    }
2639
2640    /// post-join predicates on the right side of a left join are not duplicated
2641    /// TODO: In this case we can sometimes convert the join to an INNER join
2642    #[test]
2643    fn filter_using_left_join() -> Result<()> {
2644        let table_scan = test_table_scan()?;
2645        let left = LogicalPlanBuilder::from(table_scan).build()?;
2646        let right_table_scan = test_table_scan_with_name("test2")?;
2647        let right = LogicalPlanBuilder::from(right_table_scan)
2648            .project(vec![col("a")])?
2649            .build()?;
2650        let plan = LogicalPlanBuilder::from(left)
2651            .join_using(
2652                right,
2653                JoinType::Left,
2654                vec![Column::from_name("a".to_string())],
2655            )?
2656            .filter(col("test2.a").lt_eq(lit(1i64)))?
2657            .build()?;
2658
2659        // not part of the test, just good to know:
2660        assert_snapshot!(plan,
2661        @r"
2662        Filter: test2.a <= Int64(1)
2663          Left Join: Using test.a = test2.a
2664            TableScan: test
2665            Projection: test2.a
2666              TableScan: test2
2667        ",
2668        );
2669        // filter not duplicated nor pushed down - i.e. noop
2670        assert_optimized_plan_equal!(
2671            plan,
2672            @r"
2673        Filter: test2.a <= Int64(1)
2674          Left Join: Using test.a = test2.a
2675            TableScan: test, full_filters=[test.a <= Int64(1)]
2676            Projection: test2.a
2677              TableScan: test2
2678        "
2679        )
2680    }
2681
2682    /// post-join predicates on the left side of a right join are not duplicated
2683    #[test]
2684    fn filter_using_right_join() -> Result<()> {
2685        let table_scan = test_table_scan()?;
2686        let left = LogicalPlanBuilder::from(table_scan).build()?;
2687        let right_table_scan = test_table_scan_with_name("test2")?;
2688        let right = LogicalPlanBuilder::from(right_table_scan)
2689            .project(vec![col("a")])?
2690            .build()?;
2691        let plan = LogicalPlanBuilder::from(left)
2692            .join_using(
2693                right,
2694                JoinType::Right,
2695                vec![Column::from_name("a".to_string())],
2696            )?
2697            .filter(col("test.a").lt_eq(lit(1i64)))?
2698            .build()?;
2699
2700        // not part of the test, just good to know:
2701        assert_snapshot!(plan,
2702        @r"
2703        Filter: test.a <= Int64(1)
2704          Right Join: Using test.a = test2.a
2705            TableScan: test
2706            Projection: test2.a
2707              TableScan: test2
2708        ",
2709        );
2710        // filter not duplicated nor pushed down - i.e. noop
2711        assert_optimized_plan_equal!(
2712            plan,
2713            @r"
2714        Filter: test.a <= Int64(1)
2715          Right Join: Using test.a = test2.a
2716            TableScan: test
2717            Projection: test2.a
2718              TableScan: test2, full_filters=[test2.a <= Int64(1)]
2719        "
2720        )
2721    }
2722
2723    /// post-left-join predicate on a column common to both sides is only pushed to the left side
2724    /// i.e. - not duplicated to the right side
2725    #[test]
2726    fn filter_using_left_join_on_common() -> Result<()> {
2727        let table_scan = test_table_scan()?;
2728        let left = LogicalPlanBuilder::from(table_scan).build()?;
2729        let right_table_scan = test_table_scan_with_name("test2")?;
2730        let right = LogicalPlanBuilder::from(right_table_scan)
2731            .project(vec![col("a")])?
2732            .build()?;
2733        let plan = LogicalPlanBuilder::from(left)
2734            .join_using(
2735                right,
2736                JoinType::Left,
2737                vec![Column::from_name("a".to_string())],
2738            )?
2739            .filter(col("a").lt_eq(lit(1i64)))?
2740            .build()?;
2741
2742        // not part of the test, just good to know:
2743        assert_snapshot!(plan,
2744        @r"
2745        Filter: test.a <= Int64(1)
2746          Left Join: Using test.a = test2.a
2747            TableScan: test
2748            Projection: test2.a
2749              TableScan: test2
2750        ",
2751        );
2752        // filter sent to left side of the join, not the right
2753        assert_optimized_plan_equal!(
2754            plan,
2755            @r"
2756        Left Join: Using test.a = test2.a
2757          TableScan: test, full_filters=[test.a <= Int64(1)]
2758          Projection: test2.a
2759            TableScan: test2
2760        "
2761        )
2762    }
2763
2764    /// post-right-join predicate on a column common to both sides is only pushed to the right side
2765    /// i.e. - not duplicated to the left side.
2766    #[test]
2767    fn filter_using_right_join_on_common() -> Result<()> {
2768        let table_scan = test_table_scan()?;
2769        let left = LogicalPlanBuilder::from(table_scan).build()?;
2770        let right_table_scan = test_table_scan_with_name("test2")?;
2771        let right = LogicalPlanBuilder::from(right_table_scan)
2772            .project(vec![col("a")])?
2773            .build()?;
2774        let plan = LogicalPlanBuilder::from(left)
2775            .join_using(
2776                right,
2777                JoinType::Right,
2778                vec![Column::from_name("a".to_string())],
2779            )?
2780            .filter(col("test2.a").lt_eq(lit(1i64)))?
2781            .build()?;
2782
2783        // not part of the test, just good to know:
2784        assert_snapshot!(plan,
2785        @r"
2786        Filter: test2.a <= Int64(1)
2787          Right Join: Using test.a = test2.a
2788            TableScan: test
2789            Projection: test2.a
2790              TableScan: test2
2791        ",
2792        );
2793        // filter sent to right side of join, not duplicated to the left
2794        assert_optimized_plan_equal!(
2795            plan,
2796            @r"
2797        Right Join: Using test.a = test2.a
2798          TableScan: test
2799          Projection: test2.a
2800            TableScan: test2, full_filters=[test2.a <= Int64(1)]
2801        "
2802        )
2803    }
2804
2805    /// single table predicate parts of ON condition should be pushed to both inputs
2806    #[test]
2807    fn join_on_with_filter() -> Result<()> {
2808        let table_scan = test_table_scan()?;
2809        let left = LogicalPlanBuilder::from(table_scan)
2810            .project(vec![col("a"), col("b"), col("c")])?
2811            .build()?;
2812        let right_table_scan = test_table_scan_with_name("test2")?;
2813        let right = LogicalPlanBuilder::from(right_table_scan)
2814            .project(vec![col("a"), col("b"), col("c")])?
2815            .build()?;
2816        let filter = col("test.c")
2817            .gt(lit(1u32))
2818            .and(col("test.b").lt(col("test2.b")))
2819            .and(col("test2.c").gt(lit(4u32)));
2820        let plan = LogicalPlanBuilder::from(left)
2821            .join(
2822                right,
2823                JoinType::Inner,
2824                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2825                Some(filter),
2826            )?
2827            .build()?;
2828
2829        // not part of the test, just good to know:
2830        assert_snapshot!(plan,
2831        @r"
2832        Inner Join: test.a = test2.a Filter: test.c > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
2833          Projection: test.a, test.b, test.c
2834            TableScan: test
2835          Projection: test2.a, test2.b, test2.c
2836            TableScan: test2
2837        ",
2838        );
2839        assert_optimized_plan_equal!(
2840            plan,
2841            @r"
2842        Inner Join: test.a = test2.a Filter: test.b < test2.b
2843          Projection: test.a, test.b, test.c
2844            TableScan: test, full_filters=[test.c > UInt32(1)]
2845          Projection: test2.a, test2.b, test2.c
2846            TableScan: test2, full_filters=[test2.c > UInt32(4)]
2847        "
2848        )
2849    }
2850
2851    /// join filter should be completely removed after pushdown
2852    #[test]
2853    fn join_filter_removed() -> Result<()> {
2854        let table_scan = test_table_scan()?;
2855        let left = LogicalPlanBuilder::from(table_scan)
2856            .project(vec![col("a"), col("b"), col("c")])?
2857            .build()?;
2858        let right_table_scan = test_table_scan_with_name("test2")?;
2859        let right = LogicalPlanBuilder::from(right_table_scan)
2860            .project(vec![col("a"), col("b"), col("c")])?
2861            .build()?;
2862        let filter = col("test.b")
2863            .gt(lit(1u32))
2864            .and(col("test2.c").gt(lit(4u32)));
2865        let plan = LogicalPlanBuilder::from(left)
2866            .join(
2867                right,
2868                JoinType::Inner,
2869                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2870                Some(filter),
2871            )?
2872            .build()?;
2873
2874        // not part of the test, just good to know:
2875        assert_snapshot!(plan,
2876        @r"
2877        Inner Join: test.a = test2.a Filter: test.b > UInt32(1) AND test2.c > UInt32(4)
2878          Projection: test.a, test.b, test.c
2879            TableScan: test
2880          Projection: test2.a, test2.b, test2.c
2881            TableScan: test2
2882        ",
2883        );
2884        assert_optimized_plan_equal!(
2885            plan,
2886            @r"
2887        Inner Join: test.a = test2.a
2888          Projection: test.a, test.b, test.c
2889            TableScan: test, full_filters=[test.b > UInt32(1)]
2890          Projection: test2.a, test2.b, test2.c
2891            TableScan: test2, full_filters=[test2.c > UInt32(4)]
2892        "
2893        )
2894    }
2895
2896    /// predicate on join key in filter expression should be pushed down to both inputs
2897    #[test]
2898    fn join_filter_on_common() -> Result<()> {
2899        let table_scan = test_table_scan()?;
2900        let left = LogicalPlanBuilder::from(table_scan)
2901            .project(vec![col("a")])?
2902            .build()?;
2903        let right_table_scan = test_table_scan_with_name("test2")?;
2904        let right = LogicalPlanBuilder::from(right_table_scan)
2905            .project(vec![col("b")])?
2906            .build()?;
2907        let filter = col("test.a").gt(lit(1u32));
2908        let plan = LogicalPlanBuilder::from(left)
2909            .join(
2910                right,
2911                JoinType::Inner,
2912                (vec![Column::from_name("a")], vec![Column::from_name("b")]),
2913                Some(filter),
2914            )?
2915            .build()?;
2916
2917        // not part of the test, just good to know:
2918        assert_snapshot!(plan,
2919        @r"
2920        Inner Join: test.a = test2.b Filter: test.a > UInt32(1)
2921          Projection: test.a
2922            TableScan: test
2923          Projection: test2.b
2924            TableScan: test2
2925        ",
2926        );
2927        assert_optimized_plan_equal!(
2928            plan,
2929            @r"
2930        Inner Join: test.a = test2.b
2931          Projection: test.a
2932            TableScan: test, full_filters=[test.a > UInt32(1)]
2933          Projection: test2.b
2934            TableScan: test2, full_filters=[test2.b > UInt32(1)]
2935        "
2936        )
2937    }
2938
2939    /// single table predicate parts of ON condition should be pushed to right input
2940    #[test]
2941    fn left_join_on_with_filter() -> Result<()> {
2942        let table_scan = test_table_scan()?;
2943        let left = LogicalPlanBuilder::from(table_scan)
2944            .project(vec![col("a"), col("b"), col("c")])?
2945            .build()?;
2946        let right_table_scan = test_table_scan_with_name("test2")?;
2947        let right = LogicalPlanBuilder::from(right_table_scan)
2948            .project(vec![col("a"), col("b"), col("c")])?
2949            .build()?;
2950        let filter = col("test.a")
2951            .gt(lit(1u32))
2952            .and(col("test.b").lt(col("test2.b")))
2953            .and(col("test2.c").gt(lit(4u32)));
2954        let plan = LogicalPlanBuilder::from(left)
2955            .join(
2956                right,
2957                JoinType::Left,
2958                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2959                Some(filter),
2960            )?
2961            .build()?;
2962
2963        // not part of the test, just good to know:
2964        assert_snapshot!(plan,
2965        @r"
2966        Left Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
2967          Projection: test.a, test.b, test.c
2968            TableScan: test
2969          Projection: test2.a, test2.b, test2.c
2970            TableScan: test2
2971        ",
2972        );
2973        assert_optimized_plan_equal!(
2974            plan,
2975            @r"
2976        Left Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b
2977          Projection: test.a, test.b, test.c
2978            TableScan: test
2979          Projection: test2.a, test2.b, test2.c
2980            TableScan: test2, full_filters=[test2.c > UInt32(4)]
2981        "
2982        )
2983    }
2984
2985    /// single table predicate parts of ON condition should be pushed to left input
2986    #[test]
2987    fn right_join_on_with_filter() -> Result<()> {
2988        let table_scan = test_table_scan()?;
2989        let left = LogicalPlanBuilder::from(table_scan)
2990            .project(vec![col("a"), col("b"), col("c")])?
2991            .build()?;
2992        let right_table_scan = test_table_scan_with_name("test2")?;
2993        let right = LogicalPlanBuilder::from(right_table_scan)
2994            .project(vec![col("a"), col("b"), col("c")])?
2995            .build()?;
2996        let filter = col("test.a")
2997            .gt(lit(1u32))
2998            .and(col("test.b").lt(col("test2.b")))
2999            .and(col("test2.c").gt(lit(4u32)));
3000        let plan = LogicalPlanBuilder::from(left)
3001            .join(
3002                right,
3003                JoinType::Right,
3004                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
3005                Some(filter),
3006            )?
3007            .build()?;
3008
3009        // not part of the test, just good to know:
3010        assert_snapshot!(plan,
3011        @r"
3012        Right Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
3013          Projection: test.a, test.b, test.c
3014            TableScan: test
3015          Projection: test2.a, test2.b, test2.c
3016            TableScan: test2
3017        ",
3018        );
3019        assert_optimized_plan_equal!(
3020            plan,
3021            @r"
3022        Right Join: test.a = test2.a Filter: test.b < test2.b AND test2.c > UInt32(4)
3023          Projection: test.a, test.b, test.c
3024            TableScan: test, full_filters=[test.a > UInt32(1)]
3025          Projection: test2.a, test2.b, test2.c
3026            TableScan: test2
3027        "
3028        )
3029    }
3030
3031    /// single table predicate parts of ON condition should not be pushed
3032    #[test]
3033    fn full_join_on_with_filter() -> Result<()> {
3034        let table_scan = test_table_scan()?;
3035        let left = LogicalPlanBuilder::from(table_scan)
3036            .project(vec![col("a"), col("b"), col("c")])?
3037            .build()?;
3038        let right_table_scan = test_table_scan_with_name("test2")?;
3039        let right = LogicalPlanBuilder::from(right_table_scan)
3040            .project(vec![col("a"), col("b"), col("c")])?
3041            .build()?;
3042        let filter = col("test.a")
3043            .gt(lit(1u32))
3044            .and(col("test.b").lt(col("test2.b")))
3045            .and(col("test2.c").gt(lit(4u32)));
3046        let plan = LogicalPlanBuilder::from(left)
3047            .join(
3048                right,
3049                JoinType::Full,
3050                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
3051                Some(filter),
3052            )?
3053            .build()?;
3054
3055        // not part of the test, just good to know:
3056        assert_snapshot!(plan,
3057        @r"
3058        Full Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
3059          Projection: test.a, test.b, test.c
3060            TableScan: test
3061          Projection: test2.a, test2.b, test2.c
3062            TableScan: test2
3063        ",
3064        );
3065        assert_optimized_plan_equal!(
3066            plan,
3067            @r"
3068        Full Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
3069          Projection: test.a, test.b, test.c
3070            TableScan: test
3071          Projection: test2.a, test2.b, test2.c
3072            TableScan: test2
3073        "
3074        )
3075    }
3076
    /// Test-only table source that gives the same `TableProviderFilterPushDown`
    /// answer for every filter it is asked about.
    struct PushDownProvider {
        /// Answer returned for each filter by `supports_filters_pushdown`.
        pub filter_support: TableProviderFilterPushDown,
    }
3080
3081    #[async_trait]
3082    impl TableSource for PushDownProvider {
3083        fn schema(&self) -> SchemaRef {
3084            Arc::new(Schema::new(vec![
3085                Field::new("a", DataType::Int32, true),
3086                Field::new("b", DataType::Int32, true),
3087            ]))
3088        }
3089
3090        fn table_type(&self) -> TableType {
3091            TableType::Base
3092        }
3093
3094        fn supports_filters_pushdown(
3095            &self,
3096            filters: &[&Expr],
3097        ) -> Result<Vec<TableProviderFilterPushDown>> {
3098            Ok((0..filters.len())
3099                .map(|_| self.filter_support.clone())
3100                .collect())
3101        }
3102
3103        fn as_any(&self) -> &dyn Any {
3104            self
3105        }
3106    }
3107
3108    fn table_scan_with_pushdown_provider_builder(
3109        filter_support: TableProviderFilterPushDown,
3110        filters: Vec<Expr>,
3111        projection: Option<Vec<usize>>,
3112    ) -> Result<LogicalPlanBuilder> {
3113        let test_provider = PushDownProvider { filter_support };
3114
3115        let table_scan = LogicalPlan::TableScan(TableScan {
3116            table_name: "test".into(),
3117            filters,
3118            projected_schema: Arc::new(DFSchema::try_from(test_provider.schema())?),
3119            projection,
3120            source: Arc::new(test_provider),
3121            fetch: None,
3122        });
3123
3124        Ok(LogicalPlanBuilder::from(table_scan))
3125    }
3126
3127    fn table_scan_with_pushdown_provider(
3128        filter_support: TableProviderFilterPushDown,
3129    ) -> Result<LogicalPlan> {
3130        table_scan_with_pushdown_provider_builder(filter_support, vec![], None)?
3131            .filter(col("a").eq(lit(1i64)))?
3132            .build()
3133    }
3134
3135    #[test]
3136    fn filter_with_table_provider_exact() -> Result<()> {
3137        let plan = table_scan_with_pushdown_provider(TableProviderFilterPushDown::Exact)?;
3138
3139        assert_optimized_plan_equal!(
3140            plan,
3141            @"TableScan: test, full_filters=[a = Int64(1)]"
3142        )
3143    }
3144
3145    #[test]
3146    fn filter_with_table_provider_inexact() -> Result<()> {
3147        let plan =
3148            table_scan_with_pushdown_provider(TableProviderFilterPushDown::Inexact)?;
3149
3150        assert_optimized_plan_equal!(
3151            plan,
3152            @r"
3153        Filter: a = Int64(1)
3154          TableScan: test, partial_filters=[a = Int64(1)]
3155        "
3156        )
3157    }
3158
3159    #[test]
3160    fn filter_with_table_provider_multiple_invocations() -> Result<()> {
3161        let plan =
3162            table_scan_with_pushdown_provider(TableProviderFilterPushDown::Inexact)?;
3163
3164        let optimized_plan = PushDownFilter::new()
3165            .rewrite(plan, &OptimizerContext::new())
3166            .expect("failed to optimize plan")
3167            .data;
3168
3169        // Optimizing the same plan multiple times should produce the same plan
3170        // each time.
3171        assert_optimized_plan_equal!(
3172            optimized_plan,
3173            @r"
3174        Filter: a = Int64(1)
3175          TableScan: test, partial_filters=[a = Int64(1)]
3176        "
3177        )
3178    }
3179
3180    #[test]
3181    fn filter_with_table_provider_unsupported() -> Result<()> {
3182        let plan =
3183            table_scan_with_pushdown_provider(TableProviderFilterPushDown::Unsupported)?;
3184
3185        assert_optimized_plan_equal!(
3186            plan,
3187            @r"
3188        Filter: a = Int64(1)
3189          TableScan: test
3190        "
3191        )
3192    }
3193
3194    #[test]
3195    fn multi_combined_filter() -> Result<()> {
3196        let plan = table_scan_with_pushdown_provider_builder(
3197            TableProviderFilterPushDown::Inexact,
3198            vec![col("a").eq(lit(10i64)), col("b").gt(lit(11i64))],
3199            Some(vec![0]),
3200        )?
3201        .filter(and(col("a").eq(lit(10i64)), col("b").gt(lit(11i64))))?
3202        .project(vec![col("a"), col("b")])?
3203        .build()?;
3204
3205        assert_optimized_plan_equal!(
3206            plan,
3207            @r"
3208        Projection: a, b
3209          Filter: a = Int64(10) AND b > Int64(11)
3210            TableScan: test projection=[a], partial_filters=[a = Int64(10), b > Int64(11)]
3211        "
3212        )
3213    }
3214
3215    #[test]
3216    fn multi_combined_filter_exact() -> Result<()> {
3217        let plan = table_scan_with_pushdown_provider_builder(
3218            TableProviderFilterPushDown::Exact,
3219            vec![],
3220            Some(vec![0]),
3221        )?
3222        .filter(and(col("a").eq(lit(10i64)), col("b").gt(lit(11i64))))?
3223        .project(vec![col("a"), col("b")])?
3224        .build()?;
3225
3226        assert_optimized_plan_equal!(
3227            plan,
3228            @r"
3229        Projection: a, b
3230          TableScan: test projection=[a], full_filters=[a = Int64(10), b > Int64(11)]
3231        "
3232        )
3233    }
3234
3235    #[test]
3236    fn test_filter_with_alias() -> Result<()> {
3237        // in table scan the true col name is 'test.a',
3238        // but we rename it as 'b', and use col 'b' in filter
3239        // we need rewrite filter col before push down.
3240        let table_scan = test_table_scan()?;
3241        let plan = LogicalPlanBuilder::from(table_scan)
3242            .project(vec![col("a").alias("b"), col("c")])?
3243            .filter(and(col("b").gt(lit(10i64)), col("c").gt(lit(10i64))))?
3244            .build()?;
3245
3246        // filter on col b
3247        assert_snapshot!(plan,
3248        @r"
3249        Filter: b > Int64(10) AND test.c > Int64(10)
3250          Projection: test.a AS b, test.c
3251            TableScan: test
3252        ",
3253        );
3254        // rewrite filter col b to test.a
3255        assert_optimized_plan_equal!(
3256            plan,
3257            @r"
3258        Projection: test.a AS b, test.c
3259          TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]
3260        "
3261        )
3262    }
3263
3264    #[test]
3265    fn test_filter_with_alias_2() -> Result<()> {
3266        // in table scan the true col name is 'test.a',
3267        // but we rename it as 'b', and use col 'b' in filter
3268        // we need rewrite filter col before push down.
3269        let table_scan = test_table_scan()?;
3270        let plan = LogicalPlanBuilder::from(table_scan)
3271            .project(vec![col("a").alias("b"), col("c")])?
3272            .project(vec![col("b"), col("c")])?
3273            .filter(and(col("b").gt(lit(10i64)), col("c").gt(lit(10i64))))?
3274            .build()?;
3275
3276        // filter on col b
3277        assert_snapshot!(plan,
3278        @r"
3279        Filter: b > Int64(10) AND test.c > Int64(10)
3280          Projection: b, test.c
3281            Projection: test.a AS b, test.c
3282              TableScan: test
3283        ",
3284        );
3285        // rewrite filter col b to test.a
3286        assert_optimized_plan_equal!(
3287            plan,
3288            @r"
3289        Projection: b, test.c
3290          Projection: test.a AS b, test.c
3291            TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]
3292        "
3293        )
3294    }
3295
3296    #[test]
3297    fn test_filter_with_multi_alias() -> Result<()> {
3298        let table_scan = test_table_scan()?;
3299        let plan = LogicalPlanBuilder::from(table_scan)
3300            .project(vec![col("a").alias("b"), col("c").alias("d")])?
3301            .filter(and(col("b").gt(lit(10i64)), col("d").gt(lit(10i64))))?
3302            .build()?;
3303
3304        // filter on col b and d
3305        assert_snapshot!(plan,
3306        @r"
3307        Filter: b > Int64(10) AND d > Int64(10)
3308          Projection: test.a AS b, test.c AS d
3309            TableScan: test
3310        ",
3311        );
3312        // rewrite filter col b to test.a, col d to test.c
3313        assert_optimized_plan_equal!(
3314            plan,
3315            @r"
3316        Projection: test.a AS b, test.c AS d
3317          TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]
3318        "
3319        )
3320    }
3321
3322    /// predicate on join key in filter expression should be pushed down to both inputs
3323    #[test]
3324    fn join_filter_with_alias() -> Result<()> {
3325        let table_scan = test_table_scan()?;
3326        let left = LogicalPlanBuilder::from(table_scan)
3327            .project(vec![col("a").alias("c")])?
3328            .build()?;
3329        let right_table_scan = test_table_scan_with_name("test2")?;
3330        let right = LogicalPlanBuilder::from(right_table_scan)
3331            .project(vec![col("b").alias("d")])?
3332            .build()?;
3333        let filter = col("c").gt(lit(1u32));
3334        let plan = LogicalPlanBuilder::from(left)
3335            .join(
3336                right,
3337                JoinType::Inner,
3338                (vec![Column::from_name("c")], vec![Column::from_name("d")]),
3339                Some(filter),
3340            )?
3341            .build()?;
3342
3343        assert_snapshot!(plan,
3344        @r"
3345        Inner Join: c = d Filter: c > UInt32(1)
3346          Projection: test.a AS c
3347            TableScan: test
3348          Projection: test2.b AS d
3349            TableScan: test2
3350        ",
3351        );
3352        // Change filter on col `c`, 'd' to `test.a`, 'test.b'
3353        assert_optimized_plan_equal!(
3354            plan,
3355            @r"
3356        Inner Join: c = d
3357          Projection: test.a AS c
3358            TableScan: test, full_filters=[test.a > UInt32(1)]
3359          Projection: test2.b AS d
3360            TableScan: test2, full_filters=[test2.b > UInt32(1)]
3361        "
3362        )
3363    }
3364
3365    #[test]
3366    fn test_in_filter_with_alias() -> Result<()> {
3367        // in table scan the true col name is 'test.a',
3368        // but we rename it as 'b', and use col 'b' in filter
3369        // we need rewrite filter col before push down.
3370        let table_scan = test_table_scan()?;
3371        let filter_value = vec![lit(1u32), lit(2u32), lit(3u32), lit(4u32)];
3372        let plan = LogicalPlanBuilder::from(table_scan)
3373            .project(vec![col("a").alias("b"), col("c")])?
3374            .filter(in_list(col("b"), filter_value, false))?
3375            .build()?;
3376
3377        // filter on col b
3378        assert_snapshot!(plan,
3379        @r"
3380        Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])
3381          Projection: test.a AS b, test.c
3382            TableScan: test
3383        ",
3384        );
3385        // rewrite filter col b to test.a
3386        assert_optimized_plan_equal!(
3387            plan,
3388            @r"
3389        Projection: test.a AS b, test.c
3390          TableScan: test, full_filters=[test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])]
3391        "
3392        )
3393    }
3394
3395    #[test]
3396    fn test_in_filter_with_alias_2() -> Result<()> {
3397        // in table scan the true col name is 'test.a',
3398        // but we rename it as 'b', and use col 'b' in filter
3399        // we need rewrite filter col before push down.
3400        let table_scan = test_table_scan()?;
3401        let filter_value = vec![lit(1u32), lit(2u32), lit(3u32), lit(4u32)];
3402        let plan = LogicalPlanBuilder::from(table_scan)
3403            .project(vec![col("a").alias("b"), col("c")])?
3404            .project(vec![col("b"), col("c")])?
3405            .filter(in_list(col("b"), filter_value, false))?
3406            .build()?;
3407
3408        // filter on col b
3409        assert_snapshot!(plan,
3410        @r"
3411        Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])
3412          Projection: b, test.c
3413            Projection: test.a AS b, test.c
3414              TableScan: test
3415        ",
3416        );
3417        // rewrite filter col b to test.a
3418        assert_optimized_plan_equal!(
3419            plan,
3420            @r"
3421        Projection: b, test.c
3422          Projection: test.a AS b, test.c
3423            TableScan: test, full_filters=[test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])]
3424        "
3425        )
3426    }
3427
3428    #[test]
3429    fn test_in_subquery_with_alias() -> Result<()> {
3430        // in table scan the true col name is 'test.a',
3431        // but we rename it as 'b', and use col 'b' in subquery filter
3432        let table_scan = test_table_scan()?;
3433        let table_scan_sq = test_table_scan_with_name("sq")?;
3434        let subplan = Arc::new(
3435            LogicalPlanBuilder::from(table_scan_sq)
3436                .project(vec![col("c")])?
3437                .build()?,
3438        );
3439        let plan = LogicalPlanBuilder::from(table_scan)
3440            .project(vec![col("a").alias("b"), col("c")])?
3441            .filter(in_subquery(col("b"), subplan))?
3442            .build()?;
3443
3444        // filter on col b in subquery
3445        assert_snapshot!(plan,
3446        @r"
3447        Filter: b IN (<subquery>)
3448          Subquery:
3449            Projection: sq.c
3450              TableScan: sq
3451          Projection: test.a AS b, test.c
3452            TableScan: test
3453        ",
3454        );
3455        // rewrite filter col b to test.a
3456        assert_optimized_plan_equal!(
3457            plan,
3458            @r"
3459        Projection: test.a AS b, test.c
3460          TableScan: test, full_filters=[test.a IN (<subquery>)]
3461            Subquery:
3462              Projection: sq.c
3463                TableScan: sq
3464        "
3465        )
3466    }
3467
3468    #[test]
3469    fn test_propagation_of_optimized_inner_filters_with_projections() -> Result<()> {
3470        // SELECT a FROM (SELECT 1 AS a) b WHERE b.a = 1
3471        let plan = LogicalPlanBuilder::empty(true)
3472            .project(vec![lit(0i64).alias("a")])?
3473            .alias("b")?
3474            .project(vec![col("b.a")])?
3475            .alias("b")?
3476            .filter(col("b.a").eq(lit(1i64)))?
3477            .project(vec![col("b.a")])?
3478            .build()?;
3479
3480        assert_snapshot!(plan,
3481        @r"
3482        Projection: b.a
3483          Filter: b.a = Int64(1)
3484            SubqueryAlias: b
3485              Projection: b.a
3486                SubqueryAlias: b
3487                  Projection: Int64(0) AS a
3488                    EmptyRelation: rows=1
3489        ",
3490        );
3491        // Ensure that the predicate without any columns (0 = 1) is
3492        // still there.
3493        assert_optimized_plan_equal!(
3494            plan,
3495            @r"
3496        Projection: b.a
3497          SubqueryAlias: b
3498            Projection: b.a
3499              SubqueryAlias: b
3500                Projection: Int64(0) AS a
3501                  Filter: Int64(0) = Int64(1)
3502                    EmptyRelation: rows=1
3503        "
3504        )
3505    }
3506
3507    #[test]
3508    fn test_crossjoin_with_or_clause() -> Result<()> {
3509        // select * from test,test1 where (test.a = test1.a and test.b > 1) or (test.b = test1.b and test.c < 10);
3510        let table_scan = test_table_scan()?;
3511        let left = LogicalPlanBuilder::from(table_scan)
3512            .project(vec![col("a"), col("b"), col("c")])?
3513            .build()?;
3514        let right_table_scan = test_table_scan_with_name("test1")?;
3515        let right = LogicalPlanBuilder::from(right_table_scan)
3516            .project(vec![col("a").alias("d"), col("a").alias("e")])?
3517            .build()?;
3518        let filter = or(
3519            and(col("a").eq(col("d")), col("b").gt(lit(1u32))),
3520            and(col("b").eq(col("e")), col("c").lt(lit(10u32))),
3521        );
3522        let plan = LogicalPlanBuilder::from(left)
3523            .cross_join(right)?
3524            .filter(filter)?
3525            .build()?;
3526
3527        assert_optimized_plan_eq_with_rewrite_predicate!(plan.clone(), @r"
3528        Inner Join:  Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10)
3529          Projection: test.a, test.b, test.c
3530            TableScan: test, full_filters=[test.b > UInt32(1) OR test.c < UInt32(10)]
3531          Projection: test1.a AS d, test1.a AS e
3532            TableScan: test1
3533        ")?;
3534
3535        // Originally global state which can help to avoid duplicate Filters been generated and pushed down.
3536        // Now the global state is removed. Need to double confirm that avoid duplicate Filters.
3537        let optimized_plan = PushDownFilter::new()
3538            .rewrite(plan, &OptimizerContext::new())
3539            .expect("failed to optimize plan")
3540            .data;
3541        assert_optimized_plan_equal!(
3542            optimized_plan,
3543            @r"
3544        Inner Join:  Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10)
3545          Projection: test.a, test.b, test.c
3546            TableScan: test, full_filters=[test.b > UInt32(1) OR test.c < UInt32(10)]
3547          Projection: test1.a AS d, test1.a AS e
3548            TableScan: test1
3549        "
3550        )
3551    }
3552
3553    #[test]
3554    fn left_semi_join() -> Result<()> {
3555        let left = test_table_scan_with_name("test1")?;
3556        let right_table_scan = test_table_scan_with_name("test2")?;
3557        let right = LogicalPlanBuilder::from(right_table_scan)
3558            .project(vec![col("a"), col("b")])?
3559            .build()?;
3560        let plan = LogicalPlanBuilder::from(left)
3561            .join(
3562                right,
3563                JoinType::LeftSemi,
3564                (
3565                    vec![Column::from_qualified_name("test1.a")],
3566                    vec![Column::from_qualified_name("test2.a")],
3567                ),
3568                None,
3569            )?
3570            .filter(col("test2.a").lt_eq(lit(1i64)))?
3571            .build()?;
3572
3573        // not part of the test, just good to know:
3574        assert_snapshot!(plan,
3575        @r"
3576        Filter: test2.a <= Int64(1)
3577          LeftSemi Join: test1.a = test2.a
3578            TableScan: test1
3579            Projection: test2.a, test2.b
3580              TableScan: test2
3581        ",
3582        );
3583        // Inferred the predicate `test1.a <= Int64(1)` and push it down to the left side.
3584        assert_optimized_plan_equal!(
3585            plan,
3586            @r"
3587        Filter: test2.a <= Int64(1)
3588          LeftSemi Join: test1.a = test2.a
3589            TableScan: test1, full_filters=[test1.a <= Int64(1)]
3590            Projection: test2.a, test2.b
3591              TableScan: test2
3592        "
3593        )
3594    }
3595
3596    #[test]
3597    fn left_semi_join_with_filters() -> Result<()> {
3598        let left = test_table_scan_with_name("test1")?;
3599        let right_table_scan = test_table_scan_with_name("test2")?;
3600        let right = LogicalPlanBuilder::from(right_table_scan)
3601            .project(vec![col("a"), col("b")])?
3602            .build()?;
3603        let plan = LogicalPlanBuilder::from(left)
3604            .join(
3605                right,
3606                JoinType::LeftSemi,
3607                (
3608                    vec![Column::from_qualified_name("test1.a")],
3609                    vec![Column::from_qualified_name("test2.a")],
3610                ),
3611                Some(
3612                    col("test1.b")
3613                        .gt(lit(1u32))
3614                        .and(col("test2.b").gt(lit(2u32))),
3615                ),
3616            )?
3617            .build()?;
3618
3619        // not part of the test, just good to know:
3620        assert_snapshot!(plan,
3621        @r"
3622        LeftSemi Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
3623          TableScan: test1
3624          Projection: test2.a, test2.b
3625            TableScan: test2
3626        ",
3627        );
3628        // Both side will be pushed down.
3629        assert_optimized_plan_equal!(
3630            plan,
3631            @r"
3632        LeftSemi Join: test1.a = test2.a
3633          TableScan: test1, full_filters=[test1.b > UInt32(1)]
3634          Projection: test2.a, test2.b
3635            TableScan: test2, full_filters=[test2.b > UInt32(2)]
3636        "
3637        )
3638    }
3639
3640    #[test]
3641    fn right_semi_join() -> Result<()> {
3642        let left = test_table_scan_with_name("test1")?;
3643        let right_table_scan = test_table_scan_with_name("test2")?;
3644        let right = LogicalPlanBuilder::from(right_table_scan)
3645            .project(vec![col("a"), col("b")])?
3646            .build()?;
3647        let plan = LogicalPlanBuilder::from(left)
3648            .join(
3649                right,
3650                JoinType::RightSemi,
3651                (
3652                    vec![Column::from_qualified_name("test1.a")],
3653                    vec![Column::from_qualified_name("test2.a")],
3654                ),
3655                None,
3656            )?
3657            .filter(col("test1.a").lt_eq(lit(1i64)))?
3658            .build()?;
3659
3660        // not part of the test, just good to know:
3661        assert_snapshot!(plan,
3662        @r"
3663        Filter: test1.a <= Int64(1)
3664          RightSemi Join: test1.a = test2.a
3665            TableScan: test1
3666            Projection: test2.a, test2.b
3667              TableScan: test2
3668        ",
3669        );
3670        // Inferred the predicate `test2.a <= Int64(1)` and push it down to the right side.
3671        assert_optimized_plan_equal!(
3672            plan,
3673            @r"
3674        Filter: test1.a <= Int64(1)
3675          RightSemi Join: test1.a = test2.a
3676            TableScan: test1
3677            Projection: test2.a, test2.b
3678              TableScan: test2, full_filters=[test2.a <= Int64(1)]
3679        "
3680        )
3681    }
3682
3683    #[test]
3684    fn right_semi_join_with_filters() -> Result<()> {
3685        let left = test_table_scan_with_name("test1")?;
3686        let right_table_scan = test_table_scan_with_name("test2")?;
3687        let right = LogicalPlanBuilder::from(right_table_scan)
3688            .project(vec![col("a"), col("b")])?
3689            .build()?;
3690        let plan = LogicalPlanBuilder::from(left)
3691            .join(
3692                right,
3693                JoinType::RightSemi,
3694                (
3695                    vec![Column::from_qualified_name("test1.a")],
3696                    vec![Column::from_qualified_name("test2.a")],
3697                ),
3698                Some(
3699                    col("test1.b")
3700                        .gt(lit(1u32))
3701                        .and(col("test2.b").gt(lit(2u32))),
3702                ),
3703            )?
3704            .build()?;
3705
3706        // not part of the test, just good to know:
3707        assert_snapshot!(plan,
3708        @r"
3709        RightSemi Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
3710          TableScan: test1
3711          Projection: test2.a, test2.b
3712            TableScan: test2
3713        ",
3714        );
3715        // Both side will be pushed down.
3716        assert_optimized_plan_equal!(
3717            plan,
3718            @r"
3719        RightSemi Join: test1.a = test2.a
3720          TableScan: test1, full_filters=[test1.b > UInt32(1)]
3721          Projection: test2.a, test2.b
3722            TableScan: test2, full_filters=[test2.b > UInt32(2)]
3723        "
3724        )
3725    }
3726
3727    #[test]
3728    fn left_anti_join() -> Result<()> {
3729        let table_scan = test_table_scan_with_name("test1")?;
3730        let left = LogicalPlanBuilder::from(table_scan)
3731            .project(vec![col("a"), col("b")])?
3732            .build()?;
3733        let right_table_scan = test_table_scan_with_name("test2")?;
3734        let right = LogicalPlanBuilder::from(right_table_scan)
3735            .project(vec![col("a"), col("b")])?
3736            .build()?;
3737        let plan = LogicalPlanBuilder::from(left)
3738            .join(
3739                right,
3740                JoinType::LeftAnti,
3741                (
3742                    vec![Column::from_qualified_name("test1.a")],
3743                    vec![Column::from_qualified_name("test2.a")],
3744                ),
3745                None,
3746            )?
3747            .filter(col("test2.a").gt(lit(2u32)))?
3748            .build()?;
3749
3750        // not part of the test, just good to know:
3751        assert_snapshot!(plan,
3752        @r"
3753        Filter: test2.a > UInt32(2)
3754          LeftAnti Join: test1.a = test2.a
3755            Projection: test1.a, test1.b
3756              TableScan: test1
3757            Projection: test2.a, test2.b
3758              TableScan: test2
3759        ",
3760        );
3761        // For left anti, filter of the right side filter can be pushed down.
3762        assert_optimized_plan_equal!(
3763            plan,
3764            @r"
3765        Filter: test2.a > UInt32(2)
3766          LeftAnti Join: test1.a = test2.a
3767            Projection: test1.a, test1.b
3768              TableScan: test1, full_filters=[test1.a > UInt32(2)]
3769            Projection: test2.a, test2.b
3770              TableScan: test2
3771        "
3772        )
3773    }
3774
3775    #[test]
3776    fn left_anti_join_with_filters() -> Result<()> {
3777        let table_scan = test_table_scan_with_name("test1")?;
3778        let left = LogicalPlanBuilder::from(table_scan)
3779            .project(vec![col("a"), col("b")])?
3780            .build()?;
3781        let right_table_scan = test_table_scan_with_name("test2")?;
3782        let right = LogicalPlanBuilder::from(right_table_scan)
3783            .project(vec![col("a"), col("b")])?
3784            .build()?;
3785        let plan = LogicalPlanBuilder::from(left)
3786            .join(
3787                right,
3788                JoinType::LeftAnti,
3789                (
3790                    vec![Column::from_qualified_name("test1.a")],
3791                    vec![Column::from_qualified_name("test2.a")],
3792                ),
3793                Some(
3794                    col("test1.b")
3795                        .gt(lit(1u32))
3796                        .and(col("test2.b").gt(lit(2u32))),
3797                ),
3798            )?
3799            .build()?;
3800
3801        // not part of the test, just good to know:
3802        assert_snapshot!(plan,
3803        @r"
3804        LeftAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
3805          Projection: test1.a, test1.b
3806            TableScan: test1
3807          Projection: test2.a, test2.b
3808            TableScan: test2
3809        ",
3810        );
3811        // For left anti, filter of the right side filter can be pushed down.
3812        assert_optimized_plan_equal!(
3813            plan,
3814            @r"
3815        LeftAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1)
3816          Projection: test1.a, test1.b
3817            TableScan: test1
3818          Projection: test2.a, test2.b
3819            TableScan: test2, full_filters=[test2.b > UInt32(2)]
3820        "
3821        )
3822    }
3823
3824    #[test]
3825    fn right_anti_join() -> Result<()> {
3826        let table_scan = test_table_scan_with_name("test1")?;
3827        let left = LogicalPlanBuilder::from(table_scan)
3828            .project(vec![col("a"), col("b")])?
3829            .build()?;
3830        let right_table_scan = test_table_scan_with_name("test2")?;
3831        let right = LogicalPlanBuilder::from(right_table_scan)
3832            .project(vec![col("a"), col("b")])?
3833            .build()?;
3834        let plan = LogicalPlanBuilder::from(left)
3835            .join(
3836                right,
3837                JoinType::RightAnti,
3838                (
3839                    vec![Column::from_qualified_name("test1.a")],
3840                    vec![Column::from_qualified_name("test2.a")],
3841                ),
3842                None,
3843            )?
3844            .filter(col("test1.a").gt(lit(2u32)))?
3845            .build()?;
3846
3847        // not part of the test, just good to know:
3848        assert_snapshot!(plan,
3849        @r"
3850        Filter: test1.a > UInt32(2)
3851          RightAnti Join: test1.a = test2.a
3852            Projection: test1.a, test1.b
3853              TableScan: test1
3854            Projection: test2.a, test2.b
3855              TableScan: test2
3856        ",
3857        );
3858        // For right anti, filter of the left side can be pushed down.
3859        assert_optimized_plan_equal!(
3860            plan,
3861            @r"
3862        Filter: test1.a > UInt32(2)
3863          RightAnti Join: test1.a = test2.a
3864            Projection: test1.a, test1.b
3865              TableScan: test1
3866            Projection: test2.a, test2.b
3867              TableScan: test2, full_filters=[test2.a > UInt32(2)]
3868        "
3869        )
3870    }
3871
3872    #[test]
3873    fn right_anti_join_with_filters() -> Result<()> {
3874        let table_scan = test_table_scan_with_name("test1")?;
3875        let left = LogicalPlanBuilder::from(table_scan)
3876            .project(vec![col("a"), col("b")])?
3877            .build()?;
3878        let right_table_scan = test_table_scan_with_name("test2")?;
3879        let right = LogicalPlanBuilder::from(right_table_scan)
3880            .project(vec![col("a"), col("b")])?
3881            .build()?;
3882        let plan = LogicalPlanBuilder::from(left)
3883            .join(
3884                right,
3885                JoinType::RightAnti,
3886                (
3887                    vec![Column::from_qualified_name("test1.a")],
3888                    vec![Column::from_qualified_name("test2.a")],
3889                ),
3890                Some(
3891                    col("test1.b")
3892                        .gt(lit(1u32))
3893                        .and(col("test2.b").gt(lit(2u32))),
3894                ),
3895            )?
3896            .build()?;
3897
3898        // not part of the test, just good to know:
3899        assert_snapshot!(plan,
3900        @r"
3901        RightAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
3902          Projection: test1.a, test1.b
3903            TableScan: test1
3904          Projection: test2.a, test2.b
3905            TableScan: test2
3906        ",
3907        );
3908        // For right anti, filter of the left side can be pushed down.
3909        assert_optimized_plan_equal!(
3910            plan,
3911            @r"
3912        RightAnti Join: test1.a = test2.a Filter: test2.b > UInt32(2)
3913          Projection: test1.a, test1.b
3914            TableScan: test1, full_filters=[test1.b > UInt32(1)]
3915          Projection: test2.a, test2.b
3916            TableScan: test2
3917        "
3918        )
3919    }
3920
    /// Test-only scalar UDF. Its volatility comes from the `signature` it is
    /// built with; the volatility tests construct it with
    /// `Signature::exact(vec![], Volatility::Volatile)`.
    #[derive(Debug, PartialEq, Eq, Hash)]
    struct TestScalarUDF {
        /// Signature returned from `ScalarUDFImpl::signature`.
        signature: Signature,
    }
3925
3926    impl ScalarUDFImpl for TestScalarUDF {
3927        fn as_any(&self) -> &dyn Any {
3928            self
3929        }
3930        fn name(&self) -> &str {
3931            "TestScalarUDF"
3932        }
3933
3934        fn signature(&self) -> &Signature {
3935            &self.signature
3936        }
3937
3938        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
3939            Ok(DataType::Int32)
3940        }
3941
3942        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
3943            Ok(ColumnarValue::Scalar(ScalarValue::from(1)))
3944        }
3945    }
3946
3947    #[test]
3948    fn test_push_down_volatile_function_in_aggregate() -> Result<()> {
3949        // SELECT t.a, t.r FROM (SELECT a, sum(b),  TestScalarUDF()+1 AS r FROM test1 GROUP BY a) AS t WHERE t.a > 5 AND t.r > 0.5;
3950        let table_scan = test_table_scan_with_name("test1")?;
3951        let fun = ScalarUDF::new_from_impl(TestScalarUDF {
3952            signature: Signature::exact(vec![], Volatility::Volatile),
3953        });
3954        let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
3955
3956        let plan = LogicalPlanBuilder::from(table_scan)
3957            .aggregate(vec![col("a")], vec![sum(col("b"))])?
3958            .project(vec![col("a"), sum(col("b")), add(expr, lit(1)).alias("r")])?
3959            .alias("t")?
3960            .filter(col("t.a").gt(lit(5)).and(col("t.r").gt(lit(0.5))))?
3961            .project(vec![col("t.a"), col("t.r")])?
3962            .build()?;
3963
3964        assert_snapshot!(plan,
3965        @r"
3966        Projection: t.a, t.r
3967          Filter: t.a > Int32(5) AND t.r > Float64(0.5)
3968            SubqueryAlias: t
3969              Projection: test1.a, sum(test1.b), TestScalarUDF() + Int32(1) AS r
3970                Aggregate: groupBy=[[test1.a]], aggr=[[sum(test1.b)]]
3971                  TableScan: test1
3972        ",
3973        );
3974        assert_optimized_plan_equal!(
3975            plan,
3976            @r"
3977        Projection: t.a, t.r
3978          SubqueryAlias: t
3979            Filter: r > Float64(0.5)
3980              Projection: test1.a, sum(test1.b), TestScalarUDF() + Int32(1) AS r
3981                Aggregate: groupBy=[[test1.a]], aggr=[[sum(test1.b)]]
3982                  TableScan: test1, full_filters=[test1.a > Int32(5)]
3983        "
3984        )
3985    }
3986
3987    #[test]
3988    fn test_push_down_volatile_function_in_join() -> Result<()> {
3989        // SELECT t.a, t.r FROM (SELECT test1.a AS a, TestScalarUDF() AS r FROM test1 join test2 ON test1.a = test2.a) AS t WHERE t.r > 0.5;
3990        let table_scan = test_table_scan_with_name("test1")?;
3991        let fun = ScalarUDF::new_from_impl(TestScalarUDF {
3992            signature: Signature::exact(vec![], Volatility::Volatile),
3993        });
3994        let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
3995        let left = LogicalPlanBuilder::from(table_scan).build()?;
3996        let right_table_scan = test_table_scan_with_name("test2")?;
3997        let right = LogicalPlanBuilder::from(right_table_scan).build()?;
3998        let plan = LogicalPlanBuilder::from(left)
3999            .join(
4000                right,
4001                JoinType::Inner,
4002                (
4003                    vec![Column::from_qualified_name("test1.a")],
4004                    vec![Column::from_qualified_name("test2.a")],
4005                ),
4006                None,
4007            )?
4008            .project(vec![col("test1.a").alias("a"), expr.alias("r")])?
4009            .alias("t")?
4010            .filter(col("t.r").gt(lit(0.8)))?
4011            .project(vec![col("t.a"), col("t.r")])?
4012            .build()?;
4013
4014        assert_snapshot!(plan,
4015        @r"
4016        Projection: t.a, t.r
4017          Filter: t.r > Float64(0.8)
4018            SubqueryAlias: t
4019              Projection: test1.a AS a, TestScalarUDF() AS r
4020                Inner Join: test1.a = test2.a
4021                  TableScan: test1
4022                  TableScan: test2
4023        ",
4024        );
4025        assert_optimized_plan_equal!(
4026            plan,
4027            @r"
4028        Projection: t.a, t.r
4029          SubqueryAlias: t
4030            Filter: r > Float64(0.8)
4031              Projection: test1.a AS a, TestScalarUDF() AS r
4032                Inner Join: test1.a = test2.a
4033                  TableScan: test1
4034                  TableScan: test2
4035        "
4036        )
4037    }
4038
4039    #[test]
4040    fn test_push_down_volatile_table_scan() -> Result<()> {
4041        // SELECT test.a, test.b FROM test as t WHERE TestScalarUDF() > 0.1;
4042        let table_scan = test_table_scan()?;
4043        let fun = ScalarUDF::new_from_impl(TestScalarUDF {
4044            signature: Signature::exact(vec![], Volatility::Volatile),
4045        });
4046        let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
4047        let plan = LogicalPlanBuilder::from(table_scan)
4048            .project(vec![col("a"), col("b")])?
4049            .filter(expr.gt(lit(0.1)))?
4050            .build()?;
4051
4052        assert_snapshot!(plan,
4053        @r"
4054        Filter: TestScalarUDF() > Float64(0.1)
4055          Projection: test.a, test.b
4056            TableScan: test
4057        ",
4058        );
4059        assert_optimized_plan_equal!(
4060            plan,
4061            @r"
4062        Projection: test.a, test.b
4063          Filter: TestScalarUDF() > Float64(0.1)
4064            TableScan: test
4065        "
4066        )
4067    }
4068
4069    #[test]
4070    fn test_push_down_volatile_mixed_table_scan() -> Result<()> {
4071        // SELECT test.a, test.b FROM test as t WHERE TestScalarUDF() > 0.1 and test.a > 5 and test.b > 10;
4072        let table_scan = test_table_scan()?;
4073        let fun = ScalarUDF::new_from_impl(TestScalarUDF {
4074            signature: Signature::exact(vec![], Volatility::Volatile),
4075        });
4076        let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
4077        let plan = LogicalPlanBuilder::from(table_scan)
4078            .project(vec![col("a"), col("b")])?
4079            .filter(
4080                expr.gt(lit(0.1))
4081                    .and(col("t.a").gt(lit(5)))
4082                    .and(col("t.b").gt(lit(10))),
4083            )?
4084            .build()?;
4085
4086        assert_snapshot!(plan,
4087        @r"
4088        Filter: TestScalarUDF() > Float64(0.1) AND t.a > Int32(5) AND t.b > Int32(10)
4089          Projection: test.a, test.b
4090            TableScan: test
4091        ",
4092        );
4093        assert_optimized_plan_equal!(
4094            plan,
4095            @r"
4096        Projection: test.a, test.b
4097          Filter: TestScalarUDF() > Float64(0.1)
4098            TableScan: test, full_filters=[t.a > Int32(5), t.b > Int32(10)]
4099        "
4100        )
4101    }
4102
4103    #[test]
4104    fn test_push_down_volatile_mixed_unsupported_table_scan() -> Result<()> {
4105        // SELECT test.a, test.b FROM test as t WHERE TestScalarUDF() > 0.1 and test.a > 5 and test.b > 10;
4106        let fun = ScalarUDF::new_from_impl(TestScalarUDF {
4107            signature: Signature::exact(vec![], Volatility::Volatile),
4108        });
4109        let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
4110        let plan = table_scan_with_pushdown_provider_builder(
4111            TableProviderFilterPushDown::Unsupported,
4112            vec![],
4113            None,
4114        )?
4115        .project(vec![col("a"), col("b")])?
4116        .filter(
4117            expr.gt(lit(0.1))
4118                .and(col("t.a").gt(lit(5)))
4119                .and(col("t.b").gt(lit(10))),
4120        )?
4121        .build()?;
4122
4123        assert_snapshot!(plan,
4124        @r"
4125        Filter: TestScalarUDF() > Float64(0.1) AND t.a > Int32(5) AND t.b > Int32(10)
4126          Projection: a, b
4127            TableScan: test
4128        ",
4129        );
4130        assert_optimized_plan_equal!(
4131            plan,
4132            @r"
4133        Projection: a, b
4134          Filter: t.a > Int32(5) AND t.b > Int32(10) AND TestScalarUDF() > Float64(0.1)
4135            TableScan: test
4136        "
4137        )
4138    }
4139
4140    #[test]
4141    fn test_push_down_filter_to_user_defined_node() -> Result<()> {
4142        // Define a custom user-defined logical node
4143        #[derive(Debug, Hash, Eq, PartialEq)]
4144        struct TestUserNode {
4145            schema: DFSchemaRef,
4146        }
4147
4148        impl PartialOrd for TestUserNode {
4149            fn partial_cmp(&self, _other: &Self) -> Option<Ordering> {
4150                None
4151            }
4152        }
4153
4154        impl TestUserNode {
4155            fn new() -> Self {
4156                let schema = Arc::new(
4157                    DFSchema::new_with_metadata(
4158                        vec![(None, Field::new("a", DataType::Int64, false).into())],
4159                        Default::default(),
4160                    )
4161                    .unwrap(),
4162                );
4163
4164                Self { schema }
4165            }
4166        }
4167
4168        impl UserDefinedLogicalNodeCore for TestUserNode {
4169            fn name(&self) -> &str {
4170                "test_node"
4171            }
4172
4173            fn inputs(&self) -> Vec<&LogicalPlan> {
4174                vec![]
4175            }
4176
4177            fn schema(&self) -> &DFSchemaRef {
4178                &self.schema
4179            }
4180
4181            fn expressions(&self) -> Vec<Expr> {
4182                vec![]
4183            }
4184
4185            fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
4186                write!(f, "TestUserNode")
4187            }
4188
4189            fn with_exprs_and_inputs(
4190                &self,
4191                exprs: Vec<Expr>,
4192                inputs: Vec<LogicalPlan>,
4193            ) -> Result<Self> {
4194                assert!(exprs.is_empty());
4195                assert!(inputs.is_empty());
4196                Ok(Self {
4197                    schema: Arc::clone(&self.schema),
4198                })
4199            }
4200        }
4201
4202        // Create a node and build a plan with a filter
4203        let node = LogicalPlan::Extension(Extension {
4204            node: Arc::new(TestUserNode::new()),
4205        });
4206
4207        let plan = LogicalPlanBuilder::from(node).filter(lit(false))?.build()?;
4208
4209        // Check the original plan format (not part of the test assertions)
4210        assert_snapshot!(plan,
4211        @r"
4212        Filter: Boolean(false)
4213          TestUserNode
4214        ",
4215        );
4216        // Check that the filter is pushed down to the user-defined node
4217        assert_optimized_plan_equal!(
4218            plan,
4219            @r"
4220        Filter: Boolean(false)
4221          TestUserNode
4222        "
4223        )
4224    }
4225}