datafusion_optimizer/
push_down_filter.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`PushDownFilter`] applies filters as early as possible
19
20use std::collections::{HashMap, HashSet};
21use std::sync::Arc;
22
23use arrow::datatypes::DataType;
24use indexmap::IndexSet;
25use itertools::Itertools;
26
27use datafusion_common::tree_node::{
28    Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
29};
30use datafusion_common::{
31    internal_err, plan_err, qualified_name, Column, DFSchema, Result,
32};
33use datafusion_expr::expr::WindowFunction;
34use datafusion_expr::expr_rewriter::replace_col;
35use datafusion_expr::logical_plan::{Join, JoinType, LogicalPlan, TableScan, Union};
36use datafusion_expr::utils::{
37    conjunction, expr_to_columns, split_conjunction, split_conjunction_owned,
38};
39use datafusion_expr::{
40    and, or, BinaryExpr, Expr, Filter, Operator, Projection, TableProviderFilterPushDown,
41};
42
43use crate::optimizer::ApplyOrder;
44use crate::simplify_expressions::simplify_predicates;
45use crate::utils::{has_all_column_refs, is_restrict_null_predicate};
46use crate::{OptimizerConfig, OptimizerRule};
47
/// Optimizer rule for pushing (moving) filter expressions down in a plan so
/// they are applied as early as possible.
///
/// # Introduction
///
/// The goal of this rule is to improve query performance by eliminating
/// redundant work.
///
/// For example, given a plan that sorts all values where `a > 10`:
///
/// ```text
///  Filter (a > 10)
///    Sort (a, b)
/// ```
///
/// A better plan is to filter the data *before* the Sort, which sorts fewer
/// rows and therefore does less work overall:
///
/// ```text
///  Sort (a, b)
///    Filter (a > 10)  <-- Filter is moved before the sort
/// ```
///
/// However, it is not always possible to push filters down. For example, given
/// a plan that finds the top 3 values and then keeps only those that are
/// greater than 10, pushing the filter below the limit would produce a
/// different result:
///
/// ```text
///  Filter (a > 10)   <-- can not move this Filter before the limit
///    Limit (fetch=3)
///      Sort (a, b)
/// ```
///
/// More formally, a filter-commutative operation is an operation `op` that
/// satisfies `filter(op(data)) = op(filter(data))`.
///
/// The filter-commutative property is plan- and column-specific. A filter on
/// `a` can be pushed through an `Aggregate(group_by = [a], agg = [sum(b)])`.
/// However, a filter on `sum(b)` can not be pushed through the same aggregate.
///
/// # Handling Conjunctions
///
/// It is possible to push down only **part** of a filter expression if it is
/// connected with `AND`s (more formally, if it is a "conjunction").
///
/// For example, given the following plan:
///
/// ```text
/// Filter(a > 10 AND sum(b) < 5)
///   Aggregate(group_by = [a], agg = [sum(b)])
/// ```
///
/// The `a > 10` is commutative with the `Aggregate` but `sum(b) < 5` is not.
/// Therefore it is possible to push only part of the expression, resulting in:
///
/// ```text
/// Filter(sum(b) < 5)
///   Aggregate(group_by = [a], agg = [sum(b)])
///     Filter(a > 10)
/// ```
///
/// # Handling Column Aliases
///
/// This optimizer must sometimes rewrite filter expressions when they are
/// pushed, for example if there is a projection that aliases `a+1` to `"b"`:
///
/// ```text
/// Filter (b > 10)
///     Projection: [a+1 AS "b"]  <-- changes the name of `a+1` to `b`
/// ```
///
/// To apply the filter prior to the `Projection`, all references to `b` must be
/// rewritten to `a+1`:
///
/// ```text
/// Projection: a + 1 AS "b"
///     Filter: (a + 1 > 10)  <--- changed from b to a + 1
/// ```
///
/// # Implementation Notes
///
/// This implementation performs a single pass through the plan, "pushing" down
/// filters. When it passes through a filter, it stores that filter, and when it
/// reaches a plan node that does not commute with that filter, it adds the
/// filter at that place. When it passes through a projection, it rewrites the
/// filter's expression taking that projection into account.
#[derive(Default, Debug)]
pub struct PushDownFilter {}
137
138/// For a given JOIN type, determine whether each input of the join is preserved
139/// for post-join (`WHERE` clause) filters.
140///
141/// It is only correct to push filters below a join for preserved inputs.
142///
143/// # Return Value
144/// A tuple of booleans - (left_preserved, right_preserved).
145///
146/// # "Preserved" input definition
147///
148/// We say a join side is preserved if the join returns all or a subset of the rows from
149/// the relevant side, such that each row of the output table directly maps to a row of
150/// the preserved input table. If a table is not preserved, it can provide extra null rows.
151/// That is, there may be rows in the output table that don't directly map to a row in the
152/// input table.
153///
154/// For example:
155///   - In an inner join, both sides are preserved, because each row of the output
156///     maps directly to a row from each side.
157///
158///   - In a left join, the left side is preserved (we can push predicates) but
159///     the right is not, because there may be rows in the output that don't
160///     directly map to a row in the right input (due to nulls filling where there
161///     is no match on the right).
162pub(crate) fn lr_is_preserved(join_type: JoinType) -> (bool, bool) {
163    match join_type {
164        JoinType::Inner => (true, true),
165        JoinType::Left => (true, false),
166        JoinType::Right => (false, true),
167        JoinType::Full => (false, false),
168        // No columns from the right side of the join can be referenced in output
169        // predicates for semi/anti joins, so whether we specify t/f doesn't matter.
170        JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => (true, false),
171        // No columns from the left side of the join can be referenced in output
172        // predicates for semi/anti joins, so whether we specify t/f doesn't matter.
173        JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => (false, true),
174    }
175}
176
177/// For a given JOIN type, determine whether each input of the join is preserved
178/// for the join condition (`ON` clause filters).
179///
180/// It is only correct to push filters below a join for preserved inputs.
181///
182/// # Return Value
183/// A tuple of booleans - (left_preserved, right_preserved).
184///
185/// See [`lr_is_preserved`] for a definition of "preserved".
186pub(crate) fn on_lr_is_preserved(join_type: JoinType) -> (bool, bool) {
187    match join_type {
188        JoinType::Inner => (true, true),
189        JoinType::Left => (false, true),
190        JoinType::Right => (true, false),
191        JoinType::Full => (false, false),
192        JoinType::LeftSemi | JoinType::RightSemi => (true, true),
193        JoinType::LeftAnti => (false, true),
194        JoinType::RightAnti => (true, false),
195        JoinType::LeftMark => (false, true),
196        JoinType::RightMark => (true, false),
197    }
198}
199
200/// Evaluates the columns referenced in the given expression to see if they refer
201/// only to the left or right columns
202#[derive(Debug)]
203struct ColumnChecker<'a> {
204    /// schema of left join input
205    left_schema: &'a DFSchema,
206    /// columns in left_schema, computed on demand
207    left_columns: Option<HashSet<Column>>,
208    /// schema of right join input
209    right_schema: &'a DFSchema,
210    /// columns in left_schema, computed on demand
211    right_columns: Option<HashSet<Column>>,
212}
213
214impl<'a> ColumnChecker<'a> {
215    fn new(left_schema: &'a DFSchema, right_schema: &'a DFSchema) -> Self {
216        Self {
217            left_schema,
218            left_columns: None,
219            right_schema,
220            right_columns: None,
221        }
222    }
223
224    /// Return true if the expression references only columns from the left side of the join
225    fn is_left_only(&mut self, predicate: &Expr) -> bool {
226        if self.left_columns.is_none() {
227            self.left_columns = Some(schema_columns(self.left_schema));
228        }
229        has_all_column_refs(predicate, self.left_columns.as_ref().unwrap())
230    }
231
232    /// Return true if the expression references only columns from the right side of the join
233    fn is_right_only(&mut self, predicate: &Expr) -> bool {
234        if self.right_columns.is_none() {
235            self.right_columns = Some(schema_columns(self.right_schema));
236        }
237        has_all_column_refs(predicate, self.right_columns.as_ref().unwrap())
238    }
239}
240
241/// Returns all columns in the schema
242fn schema_columns(schema: &DFSchema) -> HashSet<Column> {
243    schema
244        .iter()
245        .flat_map(|(qualifier, field)| {
246            [
247                Column::new(qualifier.cloned(), field.name()),
248                // we need to push down filter using unqualified column as well
249                Column::new_unqualified(field.name()),
250            ]
251        })
252        .collect::<HashSet<_>>()
253}
254
255/// Determine whether the predicate can evaluate as the join conditions
256fn can_evaluate_as_join_condition(predicate: &Expr) -> Result<bool> {
257    let mut is_evaluate = true;
258    predicate.apply(|expr| match expr {
259        Expr::Column(_)
260        | Expr::Literal(_, _)
261        | Expr::Placeholder(_)
262        | Expr::ScalarVariable(_, _) => Ok(TreeNodeRecursion::Jump),
263        Expr::Exists { .. }
264        | Expr::InSubquery(_)
265        | Expr::ScalarSubquery(_)
266        | Expr::OuterReferenceColumn(_, _)
267        | Expr::Unnest(_) => {
268            is_evaluate = false;
269            Ok(TreeNodeRecursion::Stop)
270        }
271        Expr::Alias(_)
272        | Expr::BinaryExpr(_)
273        | Expr::Like(_)
274        | Expr::SimilarTo(_)
275        | Expr::Not(_)
276        | Expr::IsNotNull(_)
277        | Expr::IsNull(_)
278        | Expr::IsTrue(_)
279        | Expr::IsFalse(_)
280        | Expr::IsUnknown(_)
281        | Expr::IsNotTrue(_)
282        | Expr::IsNotFalse(_)
283        | Expr::IsNotUnknown(_)
284        | Expr::Negative(_)
285        | Expr::Between(_)
286        | Expr::Case(_)
287        | Expr::Cast(_)
288        | Expr::TryCast(_)
289        | Expr::InList { .. }
290        | Expr::ScalarFunction(_) => Ok(TreeNodeRecursion::Continue),
291        // TODO: remove the next line after `Expr::Wildcard` is removed
292        #[expect(deprecated)]
293        Expr::AggregateFunction(_)
294        | Expr::WindowFunction(_)
295        | Expr::Wildcard { .. }
296        | Expr::GroupingSet(_) => internal_err!("Unsupported predicate type"),
297    })?;
298    Ok(is_evaluate)
299}
300
301/// examine OR clause to see if any useful clauses can be extracted and push down.
302/// extract at least one qual from each sub clauses of OR clause, then form the quals
303/// to new OR clause as predicate.
304///
305/// # Example
306/// ```text
307/// Filter: (a = c and a < 20) or (b = d and b > 10)
308///     join/crossjoin:
309///          TableScan: projection=[a, b]
310///          TableScan: projection=[c, d]
311/// ```
312///
313/// is optimized to
314///
315/// ```text
316/// Filter: (a = c and a < 20) or (b = d and b > 10)
317///     join/crossjoin:
318///          Filter: (a < 20) or (b > 10)
319///              TableScan: projection=[a, b]
320///          TableScan: projection=[c, d]
321/// ```
322///
323/// In general, predicates of this form:
324///
325/// ```sql
326/// (A AND B) OR (C AND D)
327/// ```
328///
329/// will be transformed to one of:
330///
331/// * `((A AND B) OR (C AND D)) AND (A OR C)`
332/// * `((A AND B) OR (C AND D)) AND ((A AND B) OR C)`
333/// * do nothing.
334fn extract_or_clauses_for_join<'a>(
335    filters: &'a [Expr],
336    schema: &'a DFSchema,
337) -> impl Iterator<Item = Expr> + 'a {
338    let schema_columns = schema_columns(schema);
339
340    // new formed OR clauses and their column references
341    filters.iter().filter_map(move |expr| {
342        if let Expr::BinaryExpr(BinaryExpr {
343            left,
344            op: Operator::Or,
345            right,
346        }) = expr
347        {
348            let left_expr = extract_or_clause(left.as_ref(), &schema_columns);
349            let right_expr = extract_or_clause(right.as_ref(), &schema_columns);
350
351            // If nothing can be extracted from any sub clauses, do nothing for this OR clause.
352            if let (Some(left_expr), Some(right_expr)) = (left_expr, right_expr) {
353                return Some(or(left_expr, right_expr));
354            }
355        }
356        None
357    })
358}
359
360/// extract qual from OR sub-clause.
361///
362/// A qual is extracted if it only contains set of column references in schema_columns.
363///
364/// For AND clause, we extract from both sub-clauses, then make new AND clause by extracted
365/// clauses if both extracted; Otherwise, use the extracted clause from any sub-clauses or None.
366///
367/// For OR clause, we extract from both sub-clauses, then make new OR clause by extracted clauses if both extracted;
368/// Otherwise, return None.
369///
370/// For other clause, apply the rule above to extract clause.
371fn extract_or_clause(expr: &Expr, schema_columns: &HashSet<Column>) -> Option<Expr> {
372    let mut predicate = None;
373
374    match expr {
375        Expr::BinaryExpr(BinaryExpr {
376            left: l_expr,
377            op: Operator::Or,
378            right: r_expr,
379        }) => {
380            let l_expr = extract_or_clause(l_expr, schema_columns);
381            let r_expr = extract_or_clause(r_expr, schema_columns);
382
383            if let (Some(l_expr), Some(r_expr)) = (l_expr, r_expr) {
384                predicate = Some(or(l_expr, r_expr));
385            }
386        }
387        Expr::BinaryExpr(BinaryExpr {
388            left: l_expr,
389            op: Operator::And,
390            right: r_expr,
391        }) => {
392            let l_expr = extract_or_clause(l_expr, schema_columns);
393            let r_expr = extract_or_clause(r_expr, schema_columns);
394
395            match (l_expr, r_expr) {
396                (Some(l_expr), Some(r_expr)) => {
397                    predicate = Some(and(l_expr, r_expr));
398                }
399                (Some(l_expr), None) => {
400                    predicate = Some(l_expr);
401                }
402                (None, Some(r_expr)) => {
403                    predicate = Some(r_expr);
404                }
405                (None, None) => {
406                    predicate = None;
407                }
408            }
409        }
410        _ => {
411            if has_all_column_refs(expr, schema_columns) {
412                predicate = Some(expr.clone());
413            }
414        }
415    }
416
417    predicate
418}
419
420/// push down join/cross-join
421fn push_down_all_join(
422    predicates: Vec<Expr>,
423    inferred_join_predicates: Vec<Expr>,
424    mut join: Join,
425    on_filter: Vec<Expr>,
426) -> Result<Transformed<LogicalPlan>> {
427    let is_inner_join = join.join_type == JoinType::Inner;
428    // Get pushable predicates from current optimizer state
429    let (left_preserved, right_preserved) = lr_is_preserved(join.join_type);
430
431    // The predicates can be divided to three categories:
432    // 1) can push through join to its children(left or right)
433    // 2) can be converted to join conditions if the join type is Inner
434    // 3) should be kept as filter conditions
435    let left_schema = join.left.schema();
436    let right_schema = join.right.schema();
437    let mut left_push = vec![];
438    let mut right_push = vec![];
439    let mut keep_predicates = vec![];
440    let mut join_conditions = vec![];
441    let mut checker = ColumnChecker::new(left_schema, right_schema);
442    for predicate in predicates {
443        if left_preserved && checker.is_left_only(&predicate) {
444            left_push.push(predicate);
445        } else if right_preserved && checker.is_right_only(&predicate) {
446            right_push.push(predicate);
447        } else if is_inner_join && can_evaluate_as_join_condition(&predicate)? {
448            // Here we do not differ it is eq or non-eq predicate, ExtractEquijoinPredicate will extract the eq predicate
449            // and convert to the join on condition
450            join_conditions.push(predicate);
451        } else {
452            keep_predicates.push(predicate);
453        }
454    }
455
456    // For infer predicates, if they can not push through join, just drop them
457    for predicate in inferred_join_predicates {
458        if left_preserved && checker.is_left_only(&predicate) {
459            left_push.push(predicate);
460        } else if right_preserved && checker.is_right_only(&predicate) {
461            right_push.push(predicate);
462        }
463    }
464
465    let mut on_filter_join_conditions = vec![];
466    let (on_left_preserved, on_right_preserved) = on_lr_is_preserved(join.join_type);
467
468    if !on_filter.is_empty() {
469        for on in on_filter {
470            if on_left_preserved && checker.is_left_only(&on) {
471                left_push.push(on)
472            } else if on_right_preserved && checker.is_right_only(&on) {
473                right_push.push(on)
474            } else {
475                on_filter_join_conditions.push(on)
476            }
477        }
478    }
479
480    // Extract from OR clause, generate new predicates for both side of join if possible.
481    // We only track the unpushable predicates above.
482    if left_preserved {
483        left_push.extend(extract_or_clauses_for_join(&keep_predicates, left_schema));
484        left_push.extend(extract_or_clauses_for_join(&join_conditions, left_schema));
485    }
486    if right_preserved {
487        right_push.extend(extract_or_clauses_for_join(&keep_predicates, right_schema));
488        right_push.extend(extract_or_clauses_for_join(&join_conditions, right_schema));
489    }
490
491    // For predicates from join filter, we should check with if a join side is preserved
492    // in term of join filtering.
493    if on_left_preserved {
494        left_push.extend(extract_or_clauses_for_join(
495            &on_filter_join_conditions,
496            left_schema,
497        ));
498    }
499    if on_right_preserved {
500        right_push.extend(extract_or_clauses_for_join(
501            &on_filter_join_conditions,
502            right_schema,
503        ));
504    }
505
506    if let Some(predicate) = conjunction(left_push) {
507        join.left = Arc::new(LogicalPlan::Filter(Filter::try_new(predicate, join.left)?));
508    }
509    if let Some(predicate) = conjunction(right_push) {
510        join.right =
511            Arc::new(LogicalPlan::Filter(Filter::try_new(predicate, join.right)?));
512    }
513
514    // Add any new join conditions as the non join predicates
515    join_conditions.extend(on_filter_join_conditions);
516    join.filter = conjunction(join_conditions);
517
518    // wrap the join on the filter whose predicates must be kept, if any
519    let plan = LogicalPlan::Join(join);
520    let plan = if let Some(predicate) = conjunction(keep_predicates) {
521        LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(plan))?)
522    } else {
523        plan
524    };
525    Ok(Transformed::yes(plan))
526}
527
528fn push_down_join(
529    join: Join,
530    parent_predicate: Option<&Expr>,
531) -> Result<Transformed<LogicalPlan>> {
532    // Split the parent predicate into individual conjunctive parts.
533    let predicates = parent_predicate
534        .map_or_else(Vec::new, |pred| split_conjunction_owned(pred.clone()));
535
536    // Extract conjunctions from the JOIN's ON filter, if present.
537    let on_filters = join
538        .filter
539        .as_ref()
540        .map_or_else(Vec::new, |filter| split_conjunction_owned(filter.clone()));
541
542    // Are there any new join predicates that can be inferred from the filter expressions?
543    let inferred_join_predicates =
544        infer_join_predicates(&join, &predicates, &on_filters)?;
545
546    if on_filters.is_empty()
547        && predicates.is_empty()
548        && inferred_join_predicates.is_empty()
549    {
550        return Ok(Transformed::no(LogicalPlan::Join(join)));
551    }
552
553    push_down_all_join(predicates, inferred_join_predicates, join, on_filters)
554}
555
556/// Extracts any equi-join join predicates from the given filter expressions.
557///
558/// Parameters
559/// * `join` the join in question
560///
561/// * `predicates` the pushed down filter expression
562///
563/// * `on_filters` filters from the join ON clause that have not already been
564///   identified as join predicates
565///
566fn infer_join_predicates(
567    join: &Join,
568    predicates: &[Expr],
569    on_filters: &[Expr],
570) -> Result<Vec<Expr>> {
571    // Only allow both side key is column.
572    let join_col_keys = join
573        .on
574        .iter()
575        .filter_map(|(l, r)| {
576            let left_col = l.try_as_col()?;
577            let right_col = r.try_as_col()?;
578            Some((left_col, right_col))
579        })
580        .collect::<Vec<_>>();
581
582    let join_type = join.join_type;
583
584    let mut inferred_predicates = InferredPredicates::new(join_type);
585
586    infer_join_predicates_from_predicates(
587        &join_col_keys,
588        predicates,
589        &mut inferred_predicates,
590    )?;
591
592    infer_join_predicates_from_on_filters(
593        &join_col_keys,
594        join_type,
595        on_filters,
596        &mut inferred_predicates,
597    )?;
598
599    Ok(inferred_predicates.predicates)
600}
601
602/// Inferred predicates collector.
603/// When the JoinType is not Inner, we need to detect whether the inferred predicate can strictly
604/// filter out NULL, otherwise ignore it. e.g.
605/// ```text
606/// SELECT * FROM t1 LEFT JOIN t2 ON t1.c0 = t2.c0 WHERE t2.c0 IS NULL;
607/// ```
608/// We cannot infer the predicate `t1.c0 IS NULL`, otherwise the predicate will be pushed down to
609/// the left side, resulting in the wrong result.
610struct InferredPredicates {
611    predicates: Vec<Expr>,
612    is_inner_join: bool,
613}
614
615impl InferredPredicates {
616    fn new(join_type: JoinType) -> Self {
617        Self {
618            predicates: vec![],
619            is_inner_join: matches!(join_type, JoinType::Inner),
620        }
621    }
622
623    fn try_build_predicate(
624        &mut self,
625        predicate: Expr,
626        replace_map: &HashMap<&Column, &Column>,
627    ) -> Result<()> {
628        if self.is_inner_join
629            || matches!(
630                is_restrict_null_predicate(
631                    predicate.clone(),
632                    replace_map.keys().cloned()
633                ),
634                Ok(true)
635            )
636        {
637            self.predicates.push(replace_col(predicate, replace_map)?);
638        }
639
640        Ok(())
641    }
642}
643
644/// Infer predicates from the pushed down predicates.
645///
646/// Parameters
647/// * `join_col_keys` column pairs from the join ON clause
648///
649/// * `predicates` the pushed down predicates
650///
651/// * `inferred_predicates` the inferred results
652///
653fn infer_join_predicates_from_predicates(
654    join_col_keys: &[(&Column, &Column)],
655    predicates: &[Expr],
656    inferred_predicates: &mut InferredPredicates,
657) -> Result<()> {
658    infer_join_predicates_impl::<true, true>(
659        join_col_keys,
660        predicates,
661        inferred_predicates,
662    )
663}
664
665/// Infer predicates from the join filter.
666///
667/// Parameters
668/// * `join_col_keys` column pairs from the join ON clause
669///
670/// * `join_type` the JoinType of Join
671///
672/// * `on_filters` filters from the join ON clause that have not already been
673///   identified as join predicates
674///
675/// * `inferred_predicates` the inferred results
676///
677fn infer_join_predicates_from_on_filters(
678    join_col_keys: &[(&Column, &Column)],
679    join_type: JoinType,
680    on_filters: &[Expr],
681    inferred_predicates: &mut InferredPredicates,
682) -> Result<()> {
683    match join_type {
684        JoinType::Full | JoinType::LeftAnti | JoinType::RightAnti => Ok(()),
685        JoinType::Inner => infer_join_predicates_impl::<true, true>(
686            join_col_keys,
687            on_filters,
688            inferred_predicates,
689        ),
690        JoinType::Left | JoinType::LeftSemi | JoinType::LeftMark => {
691            infer_join_predicates_impl::<true, false>(
692                join_col_keys,
693                on_filters,
694                inferred_predicates,
695            )
696        }
697        JoinType::Right | JoinType::RightSemi | JoinType::RightMark => {
698            infer_join_predicates_impl::<false, true>(
699                join_col_keys,
700                on_filters,
701                inferred_predicates,
702            )
703        }
704    }
705}
706
707/// Infer predicates from the given predicates.
708///
709/// Parameters
710/// * `join_col_keys` column pairs from the join ON clause
711///
712/// * `input_predicates` the given predicates. It can be the pushed down predicates,
713///   or it can be the filters of the Join
714///
715/// * `inferred_predicates` the inferred results
716///
717/// * `ENABLE_LEFT_TO_RIGHT` indicates that the right table related predicate can
718///   be inferred from the left table related predicate
719///
720/// * `ENABLE_RIGHT_TO_LEFT` indicates that the left table related predicate can
721///   be inferred from the right table related predicate
722///
723fn infer_join_predicates_impl<
724    const ENABLE_LEFT_TO_RIGHT: bool,
725    const ENABLE_RIGHT_TO_LEFT: bool,
726>(
727    join_col_keys: &[(&Column, &Column)],
728    input_predicates: &[Expr],
729    inferred_predicates: &mut InferredPredicates,
730) -> Result<()> {
731    for predicate in input_predicates {
732        let mut join_cols_to_replace = HashMap::new();
733
734        for &col in &predicate.column_refs() {
735            for (l, r) in join_col_keys.iter() {
736                if ENABLE_LEFT_TO_RIGHT && col == *l {
737                    join_cols_to_replace.insert(col, *r);
738                    break;
739                }
740                if ENABLE_RIGHT_TO_LEFT && col == *r {
741                    join_cols_to_replace.insert(col, *l);
742                    break;
743                }
744            }
745        }
746        if join_cols_to_replace.is_empty() {
747            continue;
748        }
749
750        inferred_predicates
751            .try_build_predicate(predicate.clone(), &join_cols_to_replace)?;
752    }
753    Ok(())
754}
755
756impl OptimizerRule for PushDownFilter {
757    fn name(&self) -> &str {
758        "push_down_filter"
759    }
760
761    fn apply_order(&self) -> Option<ApplyOrder> {
762        Some(ApplyOrder::TopDown)
763    }
764
765    fn supports_rewrite(&self) -> bool {
766        true
767    }
768
769    fn rewrite(
770        &self,
771        plan: LogicalPlan,
772        _config: &dyn OptimizerConfig,
773    ) -> Result<Transformed<LogicalPlan>> {
774        if let LogicalPlan::Join(join) = plan {
775            return push_down_join(join, None);
776        };
777
778        let plan_schema = Arc::clone(plan.schema());
779
780        let LogicalPlan::Filter(mut filter) = plan else {
781            return Ok(Transformed::no(plan));
782        };
783
784        let predicate = split_conjunction_owned(filter.predicate.clone());
785        let old_predicate_len = predicate.len();
786        let new_predicates = simplify_predicates(predicate)?;
787        if old_predicate_len != new_predicates.len() {
788            let Some(new_predicate) = conjunction(new_predicates) else {
789                // new_predicates is empty - remove the filter entirely
790                // Return the child plan without the filter
791                return Ok(Transformed::yes(Arc::unwrap_or_clone(filter.input)));
792            };
793            filter.predicate = new_predicate;
794        }
795
796        match Arc::unwrap_or_clone(filter.input) {
797            LogicalPlan::Filter(child_filter) => {
798                let parents_predicates = split_conjunction_owned(filter.predicate);
799
800                // remove duplicated filters
801                let child_predicates = split_conjunction_owned(child_filter.predicate);
802                let new_predicates = parents_predicates
803                    .into_iter()
804                    .chain(child_predicates)
805                    // use IndexSet to remove dupes while preserving predicate order
806                    .collect::<IndexSet<_>>()
807                    .into_iter()
808                    .collect::<Vec<_>>();
809
810                let Some(new_predicate) = conjunction(new_predicates) else {
811                    return plan_err!("at least one expression exists");
812                };
813                let new_filter = LogicalPlan::Filter(Filter::try_new(
814                    new_predicate,
815                    child_filter.input,
816                )?);
817                #[allow(clippy::used_underscore_binding)]
818                self.rewrite(new_filter, _config)
819            }
820            LogicalPlan::Repartition(repartition) => {
821                let new_filter =
822                    Filter::try_new(filter.predicate, Arc::clone(&repartition.input))
823                        .map(LogicalPlan::Filter)?;
824                insert_below(LogicalPlan::Repartition(repartition), new_filter)
825            }
826            LogicalPlan::Distinct(distinct) => {
827                let new_filter =
828                    Filter::try_new(filter.predicate, Arc::clone(distinct.input()))
829                        .map(LogicalPlan::Filter)?;
830                insert_below(LogicalPlan::Distinct(distinct), new_filter)
831            }
832            LogicalPlan::Sort(sort) => {
833                let new_filter =
834                    Filter::try_new(filter.predicate, Arc::clone(&sort.input))
835                        .map(LogicalPlan::Filter)?;
836                insert_below(LogicalPlan::Sort(sort), new_filter)
837            }
838            LogicalPlan::SubqueryAlias(subquery_alias) => {
839                let mut replace_map = HashMap::new();
840                for (i, (qualifier, field)) in
841                    subquery_alias.input.schema().iter().enumerate()
842                {
843                    let (sub_qualifier, sub_field) =
844                        subquery_alias.schema.qualified_field(i);
845                    replace_map.insert(
846                        qualified_name(sub_qualifier, sub_field.name()),
847                        Expr::Column(Column::new(qualifier.cloned(), field.name())),
848                    );
849                }
850                let new_predicate = replace_cols_by_name(filter.predicate, &replace_map)?;
851
852                let new_filter = LogicalPlan::Filter(Filter::try_new(
853                    new_predicate,
854                    Arc::clone(&subquery_alias.input),
855                )?);
856                insert_below(LogicalPlan::SubqueryAlias(subquery_alias), new_filter)
857            }
858            LogicalPlan::Projection(projection) => {
859                let predicates = split_conjunction_owned(filter.predicate.clone());
860                let (new_projection, keep_predicate) =
861                    rewrite_projection(predicates, projection)?;
862                if new_projection.transformed {
863                    match keep_predicate {
864                        None => Ok(new_projection),
865                        Some(keep_predicate) => new_projection.map_data(|child_plan| {
866                            Filter::try_new(keep_predicate, Arc::new(child_plan))
867                                .map(LogicalPlan::Filter)
868                        }),
869                    }
870                } else {
871                    filter.input = Arc::new(new_projection.data);
872                    Ok(Transformed::no(LogicalPlan::Filter(filter)))
873                }
874            }
875            LogicalPlan::Unnest(mut unnest) => {
876                let predicates = split_conjunction_owned(filter.predicate.clone());
877                let mut non_unnest_predicates = vec![];
878                let mut unnest_predicates = vec![];
879                let mut unnest_struct_columns = vec![];
880
881                for idx in &unnest.struct_type_columns {
882                    let (sub_qualifier, field) =
883                        unnest.input.schema().qualified_field(*idx);
884                    let field_name = field.name().clone();
885
886                    if let DataType::Struct(children) = field.data_type() {
887                        for child in children {
888                            let child_name = child.name().clone();
889                            unnest_struct_columns.push(Column::new(
890                                sub_qualifier.cloned(),
891                                format!("{field_name}.{child_name}"),
892                            ));
893                        }
894                    }
895                }
896
897                for predicate in predicates {
898                    // collect all the Expr::Column in predicate recursively
899                    let mut accum: HashSet<Column> = HashSet::new();
900                    expr_to_columns(&predicate, &mut accum)?;
901
902                    let contains_list_columns =
903                        unnest.list_type_columns.iter().any(|(_, unnest_list)| {
904                            accum.contains(&unnest_list.output_column)
905                        });
906                    let contains_struct_columns =
907                        unnest_struct_columns.iter().any(|c| accum.contains(c));
908
909                    if contains_list_columns || contains_struct_columns {
910                        unnest_predicates.push(predicate);
911                    } else {
912                        non_unnest_predicates.push(predicate);
913                    }
914                }
915
916                // Unnest predicates should not be pushed down.
917                // If no non-unnest predicates exist, early return
918                if non_unnest_predicates.is_empty() {
919                    filter.input = Arc::new(LogicalPlan::Unnest(unnest));
920                    return Ok(Transformed::no(LogicalPlan::Filter(filter)));
921                }
922
923                // Push down non-unnest filter predicate
924                // Unnest
925                //   Unnest Input (Projection)
926                // -> rewritten to
927                // Unnest
928                //   Filter
929                //     Unnest Input (Projection)
930
931                let unnest_input = std::mem::take(&mut unnest.input);
932
933                let filter_with_unnest_input = LogicalPlan::Filter(Filter::try_new(
934                    conjunction(non_unnest_predicates).unwrap(), // Safe to unwrap since non_unnest_predicates is not empty.
935                    unnest_input,
936                )?);
937
938                // Directly assign new filter plan as the new unnest's input.
939                // The new filter plan will go through another rewrite pass since the rule itself
940                // is applied recursively to all the child from top to down
941                let unnest_plan =
942                    insert_below(LogicalPlan::Unnest(unnest), filter_with_unnest_input)?;
943
944                match conjunction(unnest_predicates) {
945                    None => Ok(unnest_plan),
946                    Some(predicate) => Ok(Transformed::yes(LogicalPlan::Filter(
947                        Filter::try_new(predicate, Arc::new(unnest_plan.data))?,
948                    ))),
949                }
950            }
951            LogicalPlan::Union(ref union) => {
952                let mut inputs = Vec::with_capacity(union.inputs.len());
953                for input in &union.inputs {
954                    let mut replace_map = HashMap::new();
955                    for (i, (qualifier, field)) in input.schema().iter().enumerate() {
956                        let (union_qualifier, union_field) =
957                            union.schema.qualified_field(i);
958                        replace_map.insert(
959                            qualified_name(union_qualifier, union_field.name()),
960                            Expr::Column(Column::new(qualifier.cloned(), field.name())),
961                        );
962                    }
963
964                    let push_predicate =
965                        replace_cols_by_name(filter.predicate.clone(), &replace_map)?;
966                    inputs.push(Arc::new(LogicalPlan::Filter(Filter::try_new(
967                        push_predicate,
968                        Arc::clone(input),
969                    )?)))
970                }
971                Ok(Transformed::yes(LogicalPlan::Union(Union {
972                    inputs,
973                    schema: Arc::clone(&plan_schema),
974                })))
975            }
976            LogicalPlan::Aggregate(agg) => {
977                // We can push down Predicate which in groupby_expr.
978                let group_expr_columns = agg
979                    .group_expr
980                    .iter()
981                    .map(|e| Ok(Column::from_qualified_name(e.schema_name().to_string())))
982                    .collect::<Result<HashSet<_>>>()?;
983
984                let predicates = split_conjunction_owned(filter.predicate);
985
986                let mut keep_predicates = vec![];
987                let mut push_predicates = vec![];
988                for expr in predicates {
989                    let cols = expr.column_refs();
990                    if cols.iter().all(|c| group_expr_columns.contains(c)) {
991                        push_predicates.push(expr);
992                    } else {
993                        keep_predicates.push(expr);
994                    }
995                }
996
997                // As for plan Filter: Column(a+b) > 0 -- Agg: groupby:[Column(a)+Column(b)]
998                // After push, we need to replace `a+b` with Column(a)+Column(b)
999                // So we need create a replace_map, add {`a+b` --> Expr(Column(a)+Column(b))}
1000                let mut replace_map = HashMap::new();
1001                for expr in &agg.group_expr {
1002                    replace_map.insert(expr.schema_name().to_string(), expr.clone());
1003                }
1004                let replaced_push_predicates = push_predicates
1005                    .into_iter()
1006                    .map(|expr| replace_cols_by_name(expr, &replace_map))
1007                    .collect::<Result<Vec<_>>>()?;
1008
1009                let agg_input = Arc::clone(&agg.input);
1010                Transformed::yes(LogicalPlan::Aggregate(agg))
1011                    .transform_data(|new_plan| {
1012                        // If we have a filter to push, we push it down to the input of the aggregate
1013                        if let Some(predicate) = conjunction(replaced_push_predicates) {
1014                            let new_filter = make_filter(predicate, agg_input)?;
1015                            insert_below(new_plan, new_filter)
1016                        } else {
1017                            Ok(Transformed::no(new_plan))
1018                        }
1019                    })?
1020                    .map_data(|child_plan| {
1021                        // if there are any remaining predicates we can't push, add them
1022                        // back as a filter
1023                        if let Some(predicate) = conjunction(keep_predicates) {
1024                            make_filter(predicate, Arc::new(child_plan))
1025                        } else {
1026                            Ok(child_plan)
1027                        }
1028                    })
1029            }
1030            // Tries to push filters based on the partition key(s) of the window function(s) used.
1031            // Example:
1032            //   Before:
1033            //     Filter: (a > 1) and (b > 1) and (c > 1)
1034            //      Window: func() PARTITION BY [a] ...
1035            //   ---
1036            //   After:
1037            //     Filter: (b > 1) and (c > 1)
1038            //      Window: func() PARTITION BY [a] ...
1039            //        Filter: (a > 1)
1040            LogicalPlan::Window(window) => {
1041                // Retrieve the set of potential partition keys where we can push filters by.
1042                // Unlike aggregations, where there is only one statement per SELECT, there can be
1043                // multiple window functions, each with potentially different partition keys.
1044                // Therefore, we need to ensure that any potential partition key returned is used in
1045                // ALL window functions. Otherwise, filters cannot be pushed by through that column.
1046                let extract_partition_keys = |func: &WindowFunction| {
1047                    func.params
1048                        .partition_by
1049                        .iter()
1050                        .map(|c| Column::from_qualified_name(c.schema_name().to_string()))
1051                        .collect::<HashSet<_>>()
1052                };
1053                let potential_partition_keys = window
1054                    .window_expr
1055                    .iter()
1056                    .map(|e| {
1057                        match e {
1058                            Expr::WindowFunction(window_func) => {
1059                                extract_partition_keys(window_func)
1060                            }
1061                            Expr::Alias(alias) => {
1062                                if let Expr::WindowFunction(window_func) =
1063                                    alias.expr.as_ref()
1064                                {
1065                                    extract_partition_keys(window_func)
1066                                } else {
1067                                    // window functions expressions are only Expr::WindowFunction
1068                                    unreachable!()
1069                                }
1070                            }
1071                            _ => {
1072                                // window functions expressions are only Expr::WindowFunction
1073                                unreachable!()
1074                            }
1075                        }
1076                    })
1077                    // performs the set intersection of the partition keys of all window functions,
1078                    // returning only the common ones
1079                    .reduce(|a, b| &a & &b)
1080                    .unwrap_or_default();
1081
1082                let predicates = split_conjunction_owned(filter.predicate);
1083                let mut keep_predicates = vec![];
1084                let mut push_predicates = vec![];
1085                for expr in predicates {
1086                    let cols = expr.column_refs();
1087                    if cols.iter().all(|c| potential_partition_keys.contains(c)) {
1088                        push_predicates.push(expr);
1089                    } else {
1090                        keep_predicates.push(expr);
1091                    }
1092                }
1093
1094                // Unlike with aggregations, there are no cases where we have to replace, e.g.,
1095                // `a+b` with Column(a)+Column(b). This is because partition expressions are not
1096                // available as standalone columns to the user. For example, while an aggregation on
1097                // `a+b` becomes Column(a + b), in a window partition it becomes
1098                // `func() PARTITION BY [a + b] ...`. Thus, filters on expressions always remain in
1099                // place, so we can use `push_predicates` directly. This is consistent with other
1100                // optimizers, such as the one used by Postgres.
1101
1102                let window_input = Arc::clone(&window.input);
1103                Transformed::yes(LogicalPlan::Window(window))
1104                    .transform_data(|new_plan| {
1105                        // If we have a filter to push, we push it down to the input of the window
1106                        if let Some(predicate) = conjunction(push_predicates) {
1107                            let new_filter = make_filter(predicate, window_input)?;
1108                            insert_below(new_plan, new_filter)
1109                        } else {
1110                            Ok(Transformed::no(new_plan))
1111                        }
1112                    })?
1113                    .map_data(|child_plan| {
1114                        // if there are any remaining predicates we can't push, add them
1115                        // back as a filter
1116                        if let Some(predicate) = conjunction(keep_predicates) {
1117                            make_filter(predicate, Arc::new(child_plan))
1118                        } else {
1119                            Ok(child_plan)
1120                        }
1121                    })
1122            }
1123            LogicalPlan::Join(join) => push_down_join(join, Some(&filter.predicate)),
1124            LogicalPlan::TableScan(scan) => {
1125                let filter_predicates = split_conjunction(&filter.predicate);
1126
1127                let (volatile_filters, non_volatile_filters): (Vec<&Expr>, Vec<&Expr>) =
1128                    filter_predicates
1129                        .into_iter()
1130                        .partition(|pred| pred.is_volatile());
1131
1132                // Check which non-volatile filters are supported by source
1133                let supported_filters = scan
1134                    .source
1135                    .supports_filters_pushdown(non_volatile_filters.as_slice())?;
1136                if non_volatile_filters.len() != supported_filters.len() {
1137                    return internal_err!(
1138                        "Vec returned length: {} from supports_filters_pushdown is not the same size as the filters passed, which length is: {}",
1139                        supported_filters.len(),
1140                        non_volatile_filters.len());
1141                }
1142
1143                // Compose scan filters from non-volatile filters of `Exact` or `Inexact` pushdown type
1144                let zip = non_volatile_filters.into_iter().zip(supported_filters);
1145
1146                let new_scan_filters = zip
1147                    .clone()
1148                    .filter(|(_, res)| res != &TableProviderFilterPushDown::Unsupported)
1149                    .map(|(pred, _)| pred);
1150
1151                // Add new scan filters
1152                let new_scan_filters: Vec<Expr> = scan
1153                    .filters
1154                    .iter()
1155                    .chain(new_scan_filters)
1156                    .unique()
1157                    .cloned()
1158                    .collect();
1159
1160                // Compose predicates to be of `Unsupported` or `Inexact` pushdown type, and also include volatile filters
1161                let new_predicate: Vec<Expr> = zip
1162                    .filter(|(_, res)| res != &TableProviderFilterPushDown::Exact)
1163                    .map(|(pred, _)| pred)
1164                    .chain(volatile_filters)
1165                    .cloned()
1166                    .collect();
1167
1168                let new_scan = LogicalPlan::TableScan(TableScan {
1169                    filters: new_scan_filters,
1170                    ..scan
1171                });
1172
1173                Transformed::yes(new_scan).transform_data(|new_scan| {
1174                    if let Some(predicate) = conjunction(new_predicate) {
1175                        make_filter(predicate, Arc::new(new_scan)).map(Transformed::yes)
1176                    } else {
1177                        Ok(Transformed::no(new_scan))
1178                    }
1179                })
1180            }
1181            LogicalPlan::Extension(extension_plan) => {
1182                // This check prevents the Filter from being removed when the extension node has no children,
1183                // so we return the original Filter unchanged.
1184                if extension_plan.node.inputs().is_empty() {
1185                    filter.input = Arc::new(LogicalPlan::Extension(extension_plan));
1186                    return Ok(Transformed::no(LogicalPlan::Filter(filter)));
1187                }
1188                let prevent_cols =
1189                    extension_plan.node.prevent_predicate_push_down_columns();
1190
1191                // determine if we can push any predicates down past the extension node
1192
1193                // each element is true for push, false to keep
1194                let predicate_push_or_keep = split_conjunction(&filter.predicate)
1195                    .iter()
1196                    .map(|expr| {
1197                        let cols = expr.column_refs();
1198                        if cols.iter().any(|c| prevent_cols.contains(&c.name)) {
1199                            Ok(false) // No push (keep)
1200                        } else {
1201                            Ok(true) // push
1202                        }
1203                    })
1204                    .collect::<Result<Vec<_>>>()?;
1205
1206                // all predicates are kept, no changes needed
1207                if predicate_push_or_keep.iter().all(|&x| !x) {
1208                    filter.input = Arc::new(LogicalPlan::Extension(extension_plan));
1209                    return Ok(Transformed::no(LogicalPlan::Filter(filter)));
1210                }
1211
1212                // going to push some predicates down, so split the predicates
1213                let mut keep_predicates = vec![];
1214                let mut push_predicates = vec![];
1215                for (push, expr) in predicate_push_or_keep
1216                    .into_iter()
1217                    .zip(split_conjunction_owned(filter.predicate).into_iter())
1218                {
1219                    if !push {
1220                        keep_predicates.push(expr);
1221                    } else {
1222                        push_predicates.push(expr);
1223                    }
1224                }
1225
1226                let new_children = match conjunction(push_predicates) {
1227                    Some(predicate) => extension_plan
1228                        .node
1229                        .inputs()
1230                        .into_iter()
1231                        .map(|child| {
1232                            Ok(LogicalPlan::Filter(Filter::try_new(
1233                                predicate.clone(),
1234                                Arc::new(child.clone()),
1235                            )?))
1236                        })
1237                        .collect::<Result<Vec<_>>>()?,
1238                    None => extension_plan.node.inputs().into_iter().cloned().collect(),
1239                };
1240                // extension with new inputs.
1241                let child_plan = LogicalPlan::Extension(extension_plan);
1242                let new_extension =
1243                    child_plan.with_new_exprs(child_plan.expressions(), new_children)?;
1244
1245                let new_plan = match conjunction(keep_predicates) {
1246                    Some(predicate) => LogicalPlan::Filter(Filter::try_new(
1247                        predicate,
1248                        Arc::new(new_extension),
1249                    )?),
1250                    None => new_extension,
1251                };
1252                Ok(Transformed::yes(new_plan))
1253            }
1254            child => {
1255                filter.input = Arc::new(child);
1256                Ok(Transformed::no(LogicalPlan::Filter(filter)))
1257            }
1258        }
1259    }
1260}
1261
1262/// Attempts to push `predicate` into a `FilterExec` below `projection
1263///
1264/// # Returns
1265/// (plan, remaining_predicate)
1266///
1267/// `plan` is a LogicalPlan for `projection` with possibly a new FilterExec below it.
1268/// `remaining_predicate` is any part of the predicate that could not be pushed down
1269///
1270/// # Args
1271/// - predicates: Split predicates like `[foo=5, bar=6]`
1272/// - projection: The target projection plan to push down the predicates
1273///
1274/// # Example
1275///
1276/// Pushing a predicate like `foo=5 AND bar=6` with an input plan like this:
1277///
1278/// ```text
1279/// Projection(foo, c+d as bar)
1280/// ```
1281///
1282/// Might result in returning `remaining_predicate` of `bar=6` and a plan like
1283///
1284/// ```text
1285/// Projection(foo, c+d as bar)
1286///  Filter(foo=5)
1287///   ...
1288/// ```
1289fn rewrite_projection(
1290    predicates: Vec<Expr>,
1291    mut projection: Projection,
1292) -> Result<(Transformed<LogicalPlan>, Option<Expr>)> {
1293    // A projection is filter-commutable if it do not contain volatile predicates or contain volatile
1294    // predicates that are not used in the filter. However, we should re-writes all predicate expressions.
1295    // collect projection.
1296    let (volatile_map, non_volatile_map): (HashMap<_, _>, HashMap<_, _>) = projection
1297        .schema
1298        .iter()
1299        .zip(projection.expr.iter())
1300        .map(|((qualifier, field), expr)| {
1301            // strip alias, as they should not be part of filters
1302            let expr = expr.clone().unalias();
1303
1304            (qualified_name(qualifier, field.name()), expr)
1305        })
1306        .partition(|(_, value)| value.is_volatile());
1307
1308    let mut push_predicates = vec![];
1309    let mut keep_predicates = vec![];
1310    for expr in predicates {
1311        if contain(&expr, &volatile_map) {
1312            keep_predicates.push(expr);
1313        } else {
1314            push_predicates.push(expr);
1315        }
1316    }
1317
1318    match conjunction(push_predicates) {
1319        Some(expr) => {
1320            // re-write all filters based on this projection
1321            // E.g. in `Filter: b\n  Projection: a > 1 as b`, we can swap them, but the filter must be "a > 1"
1322            let new_filter = LogicalPlan::Filter(Filter::try_new(
1323                replace_cols_by_name(expr, &non_volatile_map)?,
1324                std::mem::take(&mut projection.input),
1325            )?);
1326
1327            projection.input = Arc::new(new_filter);
1328
1329            Ok((
1330                Transformed::yes(LogicalPlan::Projection(projection)),
1331                conjunction(keep_predicates),
1332            ))
1333        }
1334        None => Ok((Transformed::no(LogicalPlan::Projection(projection)), None)),
1335    }
1336}
1337
1338/// Creates a new LogicalPlan::Filter node.
1339pub fn make_filter(predicate: Expr, input: Arc<LogicalPlan>) -> Result<LogicalPlan> {
1340    Filter::try_new(predicate, input).map(LogicalPlan::Filter)
1341}
1342
1343/// Replace the existing child of the single input node with `new_child`.
1344///
1345/// Starting:
1346/// ```text
1347/// plan
1348///   child
1349/// ```
1350///
1351/// Ending:
1352/// ```text
1353/// plan
1354///   new_child
1355/// ```
1356fn insert_below(
1357    plan: LogicalPlan,
1358    new_child: LogicalPlan,
1359) -> Result<Transformed<LogicalPlan>> {
1360    let mut new_child = Some(new_child);
1361    let transformed_plan = plan.map_children(|_child| {
1362        if let Some(new_child) = new_child.take() {
1363            Ok(Transformed::yes(new_child))
1364        } else {
1365            // already took the new child
1366            internal_err!("node had more than one input")
1367        }
1368    })?;
1369
1370    // make sure we did the actual replacement
1371    if new_child.is_some() {
1372        return internal_err!("node had no  inputs");
1373    }
1374
1375    Ok(transformed_plan)
1376}
1377
1378impl PushDownFilter {
1379    #[allow(missing_docs)]
1380    pub fn new() -> Self {
1381        Self {}
1382    }
1383}
1384
1385/// replaces columns by its name on the projection.
1386pub fn replace_cols_by_name(
1387    e: Expr,
1388    replace_map: &HashMap<String, Expr>,
1389) -> Result<Expr> {
1390    e.transform_up(|expr| {
1391        Ok(if let Expr::Column(c) = &expr {
1392            match replace_map.get(&c.flat_name()) {
1393                Some(new_c) => Transformed::yes(new_c.clone()),
1394                None => Transformed::no(expr),
1395            }
1396        } else {
1397            Transformed::no(expr)
1398        })
1399    })
1400    .data()
1401}
1402
1403/// check whether the expression uses the columns in `check_map`.
1404fn contain(e: &Expr, check_map: &HashMap<String, Expr>) -> bool {
1405    let mut is_contain = false;
1406    e.apply(|expr| {
1407        Ok(if let Expr::Column(c) = &expr {
1408            match check_map.get(&c.flat_name()) {
1409                Some(_) => {
1410                    is_contain = true;
1411                    TreeNodeRecursion::Stop
1412                }
1413                None => TreeNodeRecursion::Continue,
1414            }
1415        } else {
1416            TreeNodeRecursion::Continue
1417        })
1418    })
1419    .unwrap();
1420    is_contain
1421}
1422
1423#[cfg(test)]
1424mod tests {
1425    use std::any::Any;
1426    use std::cmp::Ordering;
1427    use std::fmt::{Debug, Formatter};
1428
1429    use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
1430    use async_trait::async_trait;
1431
1432    use datafusion_common::{DFSchemaRef, DataFusionError, ScalarValue};
1433    use datafusion_expr::expr::{ScalarFunction, WindowFunction};
1434    use datafusion_expr::logical_plan::table_scan;
1435    use datafusion_expr::{
1436        col, in_list, in_subquery, lit, ColumnarValue, ExprFunctionExt, Extension,
1437        LogicalPlanBuilder, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
1438        TableSource, TableType, UserDefinedLogicalNodeCore, Volatility,
1439        WindowFunctionDefinition,
1440    };
1441
1442    use crate::assert_optimized_plan_eq_snapshot;
1443    use crate::optimizer::Optimizer;
1444    use crate::simplify_expressions::SimplifyExpressions;
1445    use crate::test::*;
1446    use crate::OptimizerContext;
1447    use datafusion_expr::test::function_stub::sum;
1448    use insta::assert_snapshot;
1449
1450    use super::*;
1451
1452    fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {}
1453
1454    macro_rules! assert_optimized_plan_equal {
1455        (
1456            $plan:expr,
1457            @ $expected:literal $(,)?
1458        ) => {{
1459            let optimizer_ctx = OptimizerContext::new().with_max_passes(1);
1460            let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> = vec![Arc::new(PushDownFilter::new())];
1461            assert_optimized_plan_eq_snapshot!(
1462                optimizer_ctx,
1463                rules,
1464                $plan,
1465                @ $expected,
1466            )
1467        }};
1468    }
1469
1470    macro_rules! assert_optimized_plan_eq_with_rewrite_predicate {
1471        (
1472            $plan:expr,
1473            @ $expected:literal $(,)?
1474        ) => {{
1475            let optimizer = Optimizer::with_rules(vec![
1476                Arc::new(SimplifyExpressions::new()),
1477                Arc::new(PushDownFilter::new()),
1478            ]);
1479            let optimized_plan = optimizer.optimize($plan, &OptimizerContext::new(), observe)?;
1480            assert_snapshot!(optimized_plan, @ $expected);
1481            Ok::<(), DataFusionError>(())
1482        }};
1483    }
1484
1485    #[test]
1486    fn filter_before_projection() -> Result<()> {
1487        let table_scan = test_table_scan()?;
1488        let plan = LogicalPlanBuilder::from(table_scan)
1489            .project(vec![col("a"), col("b")])?
1490            .filter(col("a").eq(lit(1i64)))?
1491            .build()?;
1492        // filter is before projection
1493        assert_optimized_plan_equal!(
1494            plan,
1495            @r"
1496        Projection: test.a, test.b
1497          TableScan: test, full_filters=[test.a = Int64(1)]
1498        "
1499        )
1500    }
1501
1502    #[test]
1503    fn filter_after_limit() -> Result<()> {
1504        let table_scan = test_table_scan()?;
1505        let plan = LogicalPlanBuilder::from(table_scan)
1506            .project(vec![col("a"), col("b")])?
1507            .limit(0, Some(10))?
1508            .filter(col("a").eq(lit(1i64)))?
1509            .build()?;
1510        // filter is before single projection
1511        assert_optimized_plan_equal!(
1512            plan,
1513            @r"
1514        Filter: test.a = Int64(1)
1515          Limit: skip=0, fetch=10
1516            Projection: test.a, test.b
1517              TableScan: test
1518        "
1519        )
1520    }
1521
1522    #[test]
1523    fn filter_no_columns() -> Result<()> {
1524        let table_scan = test_table_scan()?;
1525        let plan = LogicalPlanBuilder::from(table_scan)
1526            .filter(lit(0i64).eq(lit(1i64)))?
1527            .build()?;
1528        assert_optimized_plan_equal!(
1529            plan,
1530            @"TableScan: test, full_filters=[Int64(0) = Int64(1)]"
1531        )
1532    }
1533
1534    #[test]
1535    fn filter_jump_2_plans() -> Result<()> {
1536        let table_scan = test_table_scan()?;
1537        let plan = LogicalPlanBuilder::from(table_scan)
1538            .project(vec![col("a"), col("b"), col("c")])?
1539            .project(vec![col("c"), col("b")])?
1540            .filter(col("a").eq(lit(1i64)))?
1541            .build()?;
1542        // filter is before double projection
1543        assert_optimized_plan_equal!(
1544            plan,
1545            @r"
1546        Projection: test.c, test.b
1547          Projection: test.a, test.b, test.c
1548            TableScan: test, full_filters=[test.a = Int64(1)]
1549        "
1550        )
1551    }
1552
    /// Verifies that a filter on a group-by column (`a`) is pushed below the
    /// aggregate and into the table scan.
    #[test]
    fn filter_move_agg() -> Result<()> {
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("a")], vec![sum(col("b")).alias("total_salary")])?
            .filter(col("a").gt(lit(10i64)))?
            .build()?;
        // filter of key aggregation is commutative
        assert_optimized_plan_equal!(
            plan,
            @r"
        Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b) AS total_salary]]
          TableScan: test, full_filters=[test.a > Int64(10)]
        "
        )
    }
1569
    /// Verifies that a filter on `b` stays above an aggregate that groups by the
    /// expression `b + a` (not by `b` itself), even though `b` also appears in the
    /// aggregate's expression list.
    #[test]
    fn filter_complex_group_by() -> Result<()> {
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![add(col("b"), col("a"))], vec![sum(col("a")), col("b")])?
            .filter(col("b").gt(lit(10i64)))?
            .build()?;
        // expected: the filter is not pushed, the plan is unchanged
        assert_optimized_plan_equal!(
            plan,
            @r"
        Filter: test.b > Int64(10)
          Aggregate: groupBy=[[test.b + test.a]], aggr=[[sum(test.a), test.b]]
            TableScan: test
        "
        )
    }
1586
    /// Verifies that a filter on the aggregate's grouping-expression output column
    /// (named `test.b + test.a`) is pushed below the aggregate. In the scan it is
    /// rewritten as the expression `test.b + test.a` itself.
    #[test]
    fn push_agg_need_replace_expr() -> Result<()> {
        let plan = LogicalPlanBuilder::from(test_table_scan()?)
            .aggregate(vec![add(col("b"), col("a"))], vec![sum(col("a")), col("b")])?
            // refer to the group-by output column by its display name
            .filter(col("test.b + test.a").gt(lit(10i64)))?
            .build()?;
        assert_optimized_plan_equal!(
            plan,
            @r"
        Aggregate: groupBy=[[test.b + test.a]], aggr=[[sum(test.a), test.b]]
          TableScan: test, full_filters=[test.b + test.a > Int64(10)]
        "
        )
    }
1601
    /// Verifies that a filter on an aggregate result (`sum(b)` aliased as `b`) is
    /// not pushed below the aggregate, even though the alias shadows the input
    /// column name `b`.
    #[test]
    fn filter_keep_agg() -> Result<()> {
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("a")], vec![sum(col("b")).alias("b")])?
            .filter(col("b").gt(lit(10i64)))?
            .build()?;
        // filter of aggregate is after aggregation since they are non-commutative
        assert_optimized_plan_equal!(
            plan,
            @r"
        Filter: b > Int64(10)
          Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b) AS b]]
            TableScan: test
        "
        )
    }
1619
1620    /// verifies that when partitioning by 'a' and 'b', and filtering by 'b', 'b' is pushed
1621    #[test]
1622    fn filter_move_window() -> Result<()> {
1623        let table_scan = test_table_scan()?;
1624
1625        let window = Expr::from(WindowFunction::new(
1626            WindowFunctionDefinition::WindowUDF(
1627                datafusion_functions_window::rank::rank_udwf(),
1628            ),
1629            vec![],
1630        ))
1631        .partition_by(vec![col("a"), col("b")])
1632        .order_by(vec![col("c").sort(true, true)])
1633        .build()
1634        .unwrap();
1635
1636        let plan = LogicalPlanBuilder::from(table_scan)
1637            .window(vec![window])?
1638            .filter(col("b").gt(lit(10i64)))?
1639            .build()?;
1640
1641        assert_optimized_plan_equal!(
1642            plan,
1643            @r"
1644        WindowAggr: windowExpr=[[rank() PARTITION BY [test.a, test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1645          TableScan: test, full_filters=[test.b > Int64(10)]
1646        "
1647        )
1648    }
1649
1650    /// verifies that when partitioning by 'a' and 'b', and filtering by 'a' and 'b', both 'a' and
1651    /// 'b' are pushed
1652    #[test]
1653    fn filter_move_complex_window() -> Result<()> {
1654        let table_scan = test_table_scan()?;
1655
1656        let window = Expr::from(WindowFunction::new(
1657            WindowFunctionDefinition::WindowUDF(
1658                datafusion_functions_window::rank::rank_udwf(),
1659            ),
1660            vec![],
1661        ))
1662        .partition_by(vec![col("a"), col("b")])
1663        .order_by(vec![col("c").sort(true, true)])
1664        .build()
1665        .unwrap();
1666
1667        let plan = LogicalPlanBuilder::from(table_scan)
1668            .window(vec![window])?
1669            .filter(and(col("a").gt(lit(10i64)), col("b").eq(lit(1i64))))?
1670            .build()?;
1671
1672        assert_optimized_plan_equal!(
1673            plan,
1674            @r"
1675        WindowAggr: windowExpr=[[rank() PARTITION BY [test.a, test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1676          TableScan: test, full_filters=[test.a > Int64(10), test.b = Int64(1)]
1677        "
1678        )
1679    }
1680
1681    /// verifies that when partitioning by 'a' and filtering by 'a' and 'b', only 'a' is pushed
1682    #[test]
1683    fn filter_move_partial_window() -> Result<()> {
1684        let table_scan = test_table_scan()?;
1685
1686        let window = Expr::from(WindowFunction::new(
1687            WindowFunctionDefinition::WindowUDF(
1688                datafusion_functions_window::rank::rank_udwf(),
1689            ),
1690            vec![],
1691        ))
1692        .partition_by(vec![col("a")])
1693        .order_by(vec![col("c").sort(true, true)])
1694        .build()
1695        .unwrap();
1696
1697        let plan = LogicalPlanBuilder::from(table_scan)
1698            .window(vec![window])?
1699            .filter(and(col("a").gt(lit(10i64)), col("b").eq(lit(1i64))))?
1700            .build()?;
1701
1702        assert_optimized_plan_equal!(
1703            plan,
1704            @r"
1705        Filter: test.b = Int64(1)
1706          WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1707            TableScan: test, full_filters=[test.a > Int64(10)]
1708        "
1709        )
1710    }
1711
1712    /// verifies that filters on partition expressions are not pushed, as the single expression
1713    /// column is not available to the user, unlike with aggregations
1714    #[test]
1715    fn filter_expression_keep_window() -> Result<()> {
1716        let table_scan = test_table_scan()?;
1717
1718        let window = Expr::from(WindowFunction::new(
1719            WindowFunctionDefinition::WindowUDF(
1720                datafusion_functions_window::rank::rank_udwf(),
1721            ),
1722            vec![],
1723        ))
1724        .partition_by(vec![add(col("a"), col("b"))]) // PARTITION BY a + b
1725        .order_by(vec![col("c").sort(true, true)])
1726        .build()
1727        .unwrap();
1728
1729        let plan = LogicalPlanBuilder::from(table_scan)
1730            .window(vec![window])?
1731            // unlike with aggregations, single partition column "test.a + test.b" is not available
1732            // to the plan, so we use multiple columns when filtering
1733            .filter(add(col("a"), col("b")).gt(lit(10i64)))?
1734            .build()?;
1735
1736        assert_optimized_plan_equal!(
1737            plan,
1738            @r"
1739        Filter: test.a + test.b > Int64(10)
1740          WindowAggr: windowExpr=[[rank() PARTITION BY [test.a + test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1741            TableScan: test
1742        "
1743        )
1744    }
1745
1746    /// verifies that filters are not pushed on order by columns (that are not used in partitioning)
1747    #[test]
1748    fn filter_order_keep_window() -> Result<()> {
1749        let table_scan = test_table_scan()?;
1750
1751        let window = Expr::from(WindowFunction::new(
1752            WindowFunctionDefinition::WindowUDF(
1753                datafusion_functions_window::rank::rank_udwf(),
1754            ),
1755            vec![],
1756        ))
1757        .partition_by(vec![col("a")])
1758        .order_by(vec![col("c").sort(true, true)])
1759        .build()
1760        .unwrap();
1761
1762        let plan = LogicalPlanBuilder::from(table_scan)
1763            .window(vec![window])?
1764            .filter(col("c").gt(lit(10i64)))?
1765            .build()?;
1766
1767        assert_optimized_plan_equal!(
1768            plan,
1769            @r"
1770        Filter: test.c > Int64(10)
1771          WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1772            TableScan: test
1773        "
1774        )
1775    }
1776
1777    /// verifies that when we use multiple window functions with a common partition key, the filter
1778    /// on that key is pushed
1779    #[test]
1780    fn filter_multiple_windows_common_partitions() -> Result<()> {
1781        let table_scan = test_table_scan()?;
1782
1783        let window1 = Expr::from(WindowFunction::new(
1784            WindowFunctionDefinition::WindowUDF(
1785                datafusion_functions_window::rank::rank_udwf(),
1786            ),
1787            vec![],
1788        ))
1789        .partition_by(vec![col("a")])
1790        .order_by(vec![col("c").sort(true, true)])
1791        .build()
1792        .unwrap();
1793
1794        let window2 = Expr::from(WindowFunction::new(
1795            WindowFunctionDefinition::WindowUDF(
1796                datafusion_functions_window::rank::rank_udwf(),
1797            ),
1798            vec![],
1799        ))
1800        .partition_by(vec![col("b"), col("a")])
1801        .order_by(vec![col("c").sort(true, true)])
1802        .build()
1803        .unwrap();
1804
1805        let plan = LogicalPlanBuilder::from(table_scan)
1806            .window(vec![window1, window2])?
1807            .filter(col("a").gt(lit(10i64)))? // a appears in both window functions
1808            .build()?;
1809
1810        assert_optimized_plan_equal!(
1811            plan,
1812            @r"
1813        WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, rank() PARTITION BY [test.b, test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1814          TableScan: test, full_filters=[test.a > Int64(10)]
1815        "
1816        )
1817    }
1818
1819    /// verifies that when we use multiple window functions with different partitions keys, the
1820    /// filter cannot be pushed
1821    #[test]
1822    fn filter_multiple_windows_disjoint_partitions() -> Result<()> {
1823        let table_scan = test_table_scan()?;
1824
1825        let window1 = Expr::from(WindowFunction::new(
1826            WindowFunctionDefinition::WindowUDF(
1827                datafusion_functions_window::rank::rank_udwf(),
1828            ),
1829            vec![],
1830        ))
1831        .partition_by(vec![col("a")])
1832        .order_by(vec![col("c").sort(true, true)])
1833        .build()
1834        .unwrap();
1835
1836        let window2 = Expr::from(WindowFunction::new(
1837            WindowFunctionDefinition::WindowUDF(
1838                datafusion_functions_window::rank::rank_udwf(),
1839            ),
1840            vec![],
1841        ))
1842        .partition_by(vec![col("b"), col("a")])
1843        .order_by(vec![col("c").sort(true, true)])
1844        .build()
1845        .unwrap();
1846
1847        let plan = LogicalPlanBuilder::from(table_scan)
1848            .window(vec![window1, window2])?
1849            .filter(col("b").gt(lit(10i64)))? // b only appears in one window function
1850            .build()?;
1851
1852        assert_optimized_plan_equal!(
1853            plan,
1854            @r"
1855        Filter: test.b > Int64(10)
1856          WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, rank() PARTITION BY [test.b, test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1857            TableScan: test
1858        "
1859        )
1860    }
1861
    /// Verifies that a filter is pushed below a projection and that the aliased
    /// column it references (`b`, defined as `test.a AS b`) is rewritten to the
    /// underlying expression `test.a`.
    #[test]
    fn alias() -> Result<()> {
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .project(vec![col("a").alias("b"), col("c")])?
            .filter(col("b").eq(lit(1i64)))?
            .build()?;
        // filter is before projection
        assert_optimized_plan_equal!(
            plan,
            @r"
        Projection: test.a AS b, test.c
          TableScan: test, full_filters=[test.a = Int64(1)]
        "
        )
    }
1879
1880    fn add(left: Expr, right: Expr) -> Expr {
1881        Expr::BinaryExpr(BinaryExpr::new(
1882            Box::new(left),
1883            Operator::Plus,
1884            Box::new(right),
1885        ))
1886    }
1887
1888    fn multiply(left: Expr, right: Expr) -> Expr {
1889        Expr::BinaryExpr(BinaryExpr::new(
1890            Box::new(left),
1891            Operator::Multiply,
1892            Box::new(right),
1893        ))
1894    }
1895
    /// Verifies that a filter on a projected alias whose definition is a compound
    /// expression (`a * 2 + c AS b`) is pushed below the projection, with `b`
    /// replaced by the full expression.
    #[test]
    fn complex_expression() -> Result<()> {
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .project(vec![
                add(multiply(col("a"), lit(2)), col("c")).alias("b"),
                col("c"),
            ])?
            .filter(col("b").eq(lit(1i64)))?
            .build()?;

        // not part of the test, just good to know:
        assert_snapshot!(plan,
        @r"
        Filter: b = Int64(1)
          Projection: test.a * Int32(2) + test.c AS b, test.c
            TableScan: test
        ",
        );
        // filter is before projection
        assert_optimized_plan_equal!(
            plan,
            @r"
        Projection: test.a * Int32(2) + test.c AS b, test.c
          TableScan: test, full_filters=[test.a * Int32(2) + test.c = Int64(1)]
        "
        )
    }
1925
    /// Verifies that when a filter is pushed below two projections, the filter
    /// expression is rewritten through both. The outer alias `a` (`b * 3`) expands
    /// via the inner alias `b` (`a * 2 + c`) into an expression on the scan columns.
    #[test]
    fn complex_plan() -> Result<()> {
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .project(vec![
                add(multiply(col("a"), lit(2)), col("c")).alias("b"),
                col("c"),
            ])?
            // second projection where we rename columns, just to make it difficult
            .project(vec![multiply(col("b"), lit(3)).alias("a"), col("c")])?
            .filter(col("a").eq(lit(1i64)))?
            .build()?;

        // not part of the test, just good to know:
        assert_snapshot!(plan,
        @r"
        Filter: a = Int64(1)
          Projection: b * Int32(3) AS a, test.c
            Projection: test.a * Int32(2) + test.c AS b, test.c
              TableScan: test
        ",
        );
        // filter is before the projections
        assert_optimized_plan_equal!(
            plan,
            @r"
        Projection: b * Int32(3) AS a, test.c
          Projection: test.a * Int32(2) + test.c AS b, test.c
            TableScan: test, full_filters=[(test.a * Int32(2) + test.c) * Int32(3) = Int64(1)]
        "
        )
    }
1959
1960    #[derive(Debug, PartialEq, Eq, Hash)]
1961    struct NoopPlan {
1962        input: Vec<LogicalPlan>,
1963        schema: DFSchemaRef,
1964    }
1965
1966    // Manual implementation needed because of `schema` field. Comparison excludes this field.
1967    impl PartialOrd for NoopPlan {
1968        fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
1969            self.input.partial_cmp(&other.input)
1970        }
1971    }
1972
1973    impl UserDefinedLogicalNodeCore for NoopPlan {
1974        fn name(&self) -> &str {
1975            "NoopPlan"
1976        }
1977
1978        fn inputs(&self) -> Vec<&LogicalPlan> {
1979            self.input.iter().collect()
1980        }
1981
1982        fn schema(&self) -> &DFSchemaRef {
1983            &self.schema
1984        }
1985
1986        fn expressions(&self) -> Vec<Expr> {
1987            self.input
1988                .iter()
1989                .flat_map(|child| child.expressions())
1990                .collect()
1991        }
1992
1993        fn prevent_predicate_push_down_columns(&self) -> HashSet<String> {
1994            HashSet::from_iter(vec!["c".to_string()])
1995        }
1996
1997        fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
1998            write!(f, "NoopPlan")
1999        }
2000
2001        fn with_exprs_and_inputs(
2002            &self,
2003            _exprs: Vec<Expr>,
2004            inputs: Vec<LogicalPlan>,
2005        ) -> Result<Self> {
2006            Ok(Self {
2007                input: inputs,
2008                schema: Arc::clone(&self.schema),
2009            })
2010        }
2011
2012        fn supports_limit_pushdown(&self) -> bool {
2013            false // Disallow limit push-down by default
2014        }
2015    }
2016
2017    #[test]
2018    fn user_defined_plan() -> Result<()> {
2019        let table_scan = test_table_scan()?;
2020
2021        let custom_plan = LogicalPlan::Extension(Extension {
2022            node: Arc::new(NoopPlan {
2023                input: vec![table_scan.clone()],
2024                schema: Arc::clone(table_scan.schema()),
2025            }),
2026        });
2027        let plan = LogicalPlanBuilder::from(custom_plan)
2028            .filter(col("a").eq(lit(1i64)))?
2029            .build()?;
2030
2031        // Push filter below NoopPlan
2032        assert_optimized_plan_equal!(
2033            plan,
2034            @r"
2035        NoopPlan
2036          TableScan: test, full_filters=[test.a = Int64(1)]
2037        "
2038        )?;
2039
2040        let custom_plan = LogicalPlan::Extension(Extension {
2041            node: Arc::new(NoopPlan {
2042                input: vec![table_scan.clone()],
2043                schema: Arc::clone(table_scan.schema()),
2044            }),
2045        });
2046        let plan = LogicalPlanBuilder::from(custom_plan)
2047            .filter(col("a").eq(lit(1i64)).and(col("c").eq(lit(2i64))))?
2048            .build()?;
2049
2050        // Push only predicate on `a` below NoopPlan
2051        assert_optimized_plan_equal!(
2052            plan,
2053            @r"
2054        Filter: test.c = Int64(2)
2055          NoopPlan
2056            TableScan: test, full_filters=[test.a = Int64(1)]
2057        "
2058        )?;
2059
2060        let custom_plan = LogicalPlan::Extension(Extension {
2061            node: Arc::new(NoopPlan {
2062                input: vec![table_scan.clone(), table_scan.clone()],
2063                schema: Arc::clone(table_scan.schema()),
2064            }),
2065        });
2066        let plan = LogicalPlanBuilder::from(custom_plan)
2067            .filter(col("a").eq(lit(1i64)))?
2068            .build()?;
2069
2070        // Push filter below NoopPlan for each child branch
2071        assert_optimized_plan_equal!(
2072            plan,
2073            @r"
2074        NoopPlan
2075          TableScan: test, full_filters=[test.a = Int64(1)]
2076          TableScan: test, full_filters=[test.a = Int64(1)]
2077        "
2078        )?;
2079
2080        let custom_plan = LogicalPlan::Extension(Extension {
2081            node: Arc::new(NoopPlan {
2082                input: vec![table_scan.clone(), table_scan.clone()],
2083                schema: Arc::clone(table_scan.schema()),
2084            }),
2085        });
2086        let plan = LogicalPlanBuilder::from(custom_plan)
2087            .filter(col("a").eq(lit(1i64)).and(col("c").eq(lit(2i64))))?
2088            .build()?;
2089
2090        // Push only predicate on `a` below NoopPlan
2091        assert_optimized_plan_equal!(
2092            plan,
2093            @r"
2094        Filter: test.c = Int64(2)
2095          NoopPlan
2096            TableScan: test, full_filters=[test.a = Int64(1)]
2097            TableScan: test, full_filters=[test.a = Int64(1)]
2098        "
2099        )
2100    }
2101
    /// verifies that when two filters apply after an aggregation that only allows one to be pushed, one is pushed
    /// and the other not.
    #[test]
    fn multi_filter() -> Result<()> {
        // the aggregation allows one filter to pass (b), and the other one to not pass (sum(c))
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .project(vec![col("a").alias("b"), col("c")])?
            .aggregate(vec![col("b")], vec![sum(col("c"))])?
            .filter(col("b").gt(lit(10i64)))?
            .filter(col("sum(test.c)").gt(lit(10i64)))?
            .build()?;

        // not part of the test, just good to know:
        assert_snapshot!(plan,
        @r"
        Filter: sum(test.c) > Int64(10)
          Filter: b > Int64(10)
            Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
              Projection: test.a AS b, test.c
                TableScan: test
        ",
        );
        // the filter on group-by column `b` is pushed through the aggregate and the
        // projection (rewritten to `test.a`) into the scan; the filter on
        // `sum(test.c)` stays above the aggregate
        assert_optimized_plan_equal!(
            plan,
            @r"
        Filter: sum(test.c) > Int64(10)
          Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
            Projection: test.a AS b, test.c
              TableScan: test, full_filters=[test.a > Int64(10)]
        "
        )
    }
2136
    /// verifies that when a filter with two predicates is applied after an aggregation that only allows one to be pushed, one is pushed
    /// and the other not.
    #[test]
    fn split_filter() -> Result<()> {
        // the aggregation allows one filter to pass (b), and the other one to not pass (sum(c))
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .project(vec![col("a").alias("b"), col("c")])?
            .aggregate(vec![col("b")], vec![sum(col("c"))])?
            .filter(and(
                col("sum(test.c)").gt(lit(10i64)),
                and(col("b").gt(lit(10i64)), col("sum(test.c)").lt(lit(20i64))),
            ))?
            .build()?;

        // not part of the test, just good to know:
        assert_snapshot!(plan,
        @r"
        Filter: sum(test.c) > Int64(10) AND b > Int64(10) AND sum(test.c) < Int64(20)
          Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
            Projection: test.a AS b, test.c
              TableScan: test
        ",
        );
        // the conjunct on `b` is split off and pushed into the scan (as `test.a`);
        // the conjuncts on `sum(test.c)` remain above the aggregate
        assert_optimized_plan_equal!(
            plan,
            @r"
        Filter: sum(test.c) > Int64(10) AND sum(test.c) < Int64(20)
          Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
            Projection: test.a AS b, test.c
              TableScan: test, full_filters=[test.a > Int64(10)]
        "
        )
    }
2172
    /// Verifies that when two limits are in place, a filter jumps neither of them.
    /// It still moves below the outer projection.
    #[test]
    fn double_limit() -> Result<()> {
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .project(vec![col("a"), col("b")])?
            .limit(0, Some(20))?
            .limit(0, Some(10))?
            .project(vec![col("a"), col("b")])?
            .filter(col("a").eq(lit(1i64)))?
            .build()?;
        // filter does not jump any of the limits
        assert_optimized_plan_equal!(
            plan,
            @r"
        Projection: test.a, test.b
          Filter: test.a = Int64(1)
            Limit: skip=0, fetch=10
              Limit: skip=0, fetch=20
                Projection: test.a, test.b
                  TableScan: test
        "
        )
    }
2197
    /// Verifies that a filter above a `Union` is pushed into every input. Column
    /// references are rewritten to each input's own qualifier (`test.a`, `test2.a`).
    #[test]
    fn union_all() -> Result<()> {
        let table_scan = test_table_scan()?;
        let table_scan2 = test_table_scan_with_name("test2")?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .union(LogicalPlanBuilder::from(table_scan2).build()?)?
            .filter(col("a").eq(lit(1i64)))?
            .build()?;
        // filter appears below Union
        assert_optimized_plan_equal!(
            plan,
            @r"
        Union
          TableScan: test, full_filters=[test.a = Int64(1)]
          TableScan: test2, full_filters=[test2.a = Int64(1)]
        "
        )
    }
2216
    /// Verifies that a filter above a `Union` of two aliased projections is pushed
    /// through each `SubqueryAlias` and `Projection`. The projected alias `b` is
    /// rewritten back to `test.a` in both scans.
    #[test]
    fn union_all_on_projection() -> Result<()> {
        let table_scan = test_table_scan()?;
        let table = LogicalPlanBuilder::from(table_scan)
            .project(vec![col("a").alias("b")])?
            .alias("test2")?;

        let plan = table
            .clone()
            .union(table.build()?)?
            .filter(col("b").eq(lit(1i64)))?
            .build()?;

        // filter appears below Union
        assert_optimized_plan_equal!(
            plan,
            @r"
        Union
          SubqueryAlias: test2
            Projection: test.a AS b
              TableScan: test, full_filters=[test.a = Int64(1)]
          SubqueryAlias: test2
            Projection: test.a AS b
              TableScan: test, full_filters=[test.a = Int64(1)]
        "
        )
    }
2244
    /// Despite its name, this test builds a cross join (not a union) of two inputs
    /// with disjoint column names (`test.{a,b,c}` and `test1.{d,e,f}`). It checks
    /// that each conjunct of the filter is pushed only into the input that owns its
    /// column.
    #[test]
    fn test_union_different_schema() -> Result<()> {
        let left = LogicalPlanBuilder::from(test_table_scan()?)
            .project(vec![col("a"), col("b"), col("c")])?
            .build()?;

        let schema = Schema::new(vec![
            Field::new("d", DataType::UInt32, false),
            Field::new("e", DataType::UInt32, false),
            Field::new("f", DataType::UInt32, false),
        ]);
        let right = table_scan(Some("test1"), &schema, None)?
            .project(vec![col("d"), col("e"), col("f")])?
            .build()?;
        let filter = and(col("test.a").eq(lit(1)), col("test1.d").gt(lit(2)));
        let plan = LogicalPlanBuilder::from(left)
            .cross_join(right)?
            .project(vec![col("test.a"), col("test1.d")])?
            .filter(filter)?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Projection: test.a, test1.d
          Cross Join: 
            Projection: test.a, test.b, test.c
              TableScan: test, full_filters=[test.a = Int32(1)]
            Projection: test1.d, test1.e, test1.f
              TableScan: test1, full_filters=[test1.d > Int32(2)]
        "
        )
    }
2278
    /// Verifies push-down through a cross join whose two inputs share column names
    /// (`a`, `b`, `c`) but differ in qualifier (`test` vs `test1`). Each conjunct
    /// must be routed by its qualifier to the matching side.
    #[test]
    fn test_project_same_name_different_qualifier() -> Result<()> {
        let table_scan = test_table_scan()?;
        let left = LogicalPlanBuilder::from(table_scan)
            .project(vec![col("a"), col("b"), col("c")])?
            .build()?;
        let right_table_scan = test_table_scan_with_name("test1")?;
        let right = LogicalPlanBuilder::from(right_table_scan)
            .project(vec![col("a"), col("b"), col("c")])?
            .build()?;
        let filter = and(col("test.a").eq(lit(1)), col("test1.a").gt(lit(2)));
        let plan = LogicalPlanBuilder::from(left)
            .cross_join(right)?
            .project(vec![col("test.a"), col("test1.a")])?
            .filter(filter)?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Projection: test.a, test1.a
          Cross Join: 
            Projection: test.a, test.b, test.c
              TableScan: test, full_filters=[test.a = Int32(1)]
            Projection: test1.a, test1.b, test1.c
              TableScan: test1, full_filters=[test1.a > Int32(2)]
        "
        )
    }
2308
    /// Verifies placement of two filters on the same column separated by a limit.
    /// The lower filter (`a <= 1`) is pushed into the scan. The upper filter
    /// (`a >= 1`) moves below the outer projection but stays above the limit.
    #[test]
    fn filter_2_breaks_limits() -> Result<()> {
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .project(vec![col("a")])?
            .filter(col("a").lt_eq(lit(1i64)))?
            .limit(0, Some(1))?
            .project(vec![col("a")])?
            .filter(col("a").gt_eq(lit(1i64)))?
            .build()?;
        // Each filter should move below the projection directly under it, but the
        // upper filter must not cross the limit.

        // not part of the test
        assert_snapshot!(plan,
        @r"
        Filter: test.a >= Int64(1)
          Projection: test.a
            Limit: skip=0, fetch=1
              Filter: test.a <= Int64(1)
                Projection: test.a
                  TableScan: test
        ",
        );
        assert_optimized_plan_equal!(
            plan,
            @r"
        Projection: test.a
          Filter: test.a >= Int64(1)
            Limit: skip=0, fetch=1
              Projection: test.a
                TableScan: test, full_filters=[test.a <= Int64(1)]
        "
        )
    }
2344
    /// Verifies that filters that end up at the same depth are ANDed into a single
    /// `Filter`. Here both filters sit directly above a limit, which they cannot
    /// cross, so they are merged there.
    #[test]
    fn two_filters_on_same_depth() -> Result<()> {
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .limit(0, Some(1))?
            .filter(col("a").lt_eq(lit(1i64)))?
            .filter(col("a").gt_eq(lit(1i64)))?
            .project(vec![col("a")])?
            .build()?;

        // not part of the test
        assert_snapshot!(plan,
        @r"
        Projection: test.a
          Filter: test.a >= Int64(1)
            Filter: test.a <= Int64(1)
              Limit: skip=0, fetch=1
                TableScan: test
        ",
        );
        assert_optimized_plan_equal!(
            plan,
            @r"
        Projection: test.a
          Filter: test.a >= Int64(1) AND test.a <= Int64(1)
            Limit: skip=0, fetch=1
              TableScan: test
        "
        )
    }
2376
    /// Verifies that filters on a plan with user nodes are not lost
    /// (ARROW-10547). The filter under the `TestUserDefined` node is still pushed
    /// into the scan.
    #[test]
    fn filters_user_defined_node() -> Result<()> {
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .filter(col("a").lt_eq(lit(1i64)))?
            .build()?;

        // wrap the filtered plan in a user-defined node
        let plan = user_defined::new(plan);

        // not part of the test
        assert_snapshot!(plan,
        @r"
        TestUserDefined
          Filter: test.a <= Int64(1)
            TableScan: test
        ",
        );
        assert_optimized_plan_equal!(
            plan,
            @r"
        TestUserDefined
          TableScan: test, full_filters=[test.a <= Int64(1)]
        "
        )
    }
2404
    /// Verifies that a post-join predicate on the inner-join key (`test.a`) is
    /// pushed to both sides of an `ON` join. On the right side it is rewritten to
    /// the matching key `test2.a`.
    #[test]
    fn filter_on_join_on_common_independent() -> Result<()> {
        let table_scan = test_table_scan()?;
        let left = LogicalPlanBuilder::from(table_scan).build()?;
        let right_table_scan = test_table_scan_with_name("test2")?;
        let right = LogicalPlanBuilder::from(right_table_scan)
            .project(vec![col("a")])?
            .build()?;
        let plan = LogicalPlanBuilder::from(left)
            .join(
                right,
                JoinType::Inner,
                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
                None,
            )?
            .filter(col("test.a").lt_eq(lit(1i64)))?
            .build()?;

        // not part of the test, just good to know:
        assert_snapshot!(plan,
        @r"
        Filter: test.a <= Int64(1)
          Inner Join: test.a = test2.a
            TableScan: test
            Projection: test2.a
              TableScan: test2
        ",
        );
        // filter sent to side before the join
        assert_optimized_plan_equal!(
            plan,
            @r"
        Inner Join: test.a = test2.a
          TableScan: test, full_filters=[test.a <= Int64(1)]
          Projection: test2.a
            TableScan: test2, full_filters=[test2.a <= Int64(1)]
        "
        )
    }
2445
2446    /// post-using-join predicates on a column common to both sides is pushed to both sides
2447    #[test]
2448    fn filter_using_join_on_common_independent() -> Result<()> {
2449        let table_scan = test_table_scan()?;
2450        let left = LogicalPlanBuilder::from(table_scan).build()?;
2451        let right_table_scan = test_table_scan_with_name("test2")?;
2452        let right = LogicalPlanBuilder::from(right_table_scan)
2453            .project(vec![col("a")])?
2454            .build()?;
2455        let plan = LogicalPlanBuilder::from(left)
2456            .join_using(
2457                right,
2458                JoinType::Inner,
2459                vec![Column::from_name("a".to_string())],
2460            )?
2461            .filter(col("a").lt_eq(lit(1i64)))?
2462            .build()?;
2463
2464        // not part of the test, just good to know:
2465        assert_snapshot!(plan,
2466        @r"
2467        Filter: test.a <= Int64(1)
2468          Inner Join: Using test.a = test2.a
2469            TableScan: test
2470            Projection: test2.a
2471              TableScan: test2
2472        ",
2473        );
2474        // filter sent to side before the join
2475        assert_optimized_plan_equal!(
2476            plan,
2477            @r"
2478        Inner Join: Using test.a = test2.a
2479          TableScan: test, full_filters=[test.a <= Int64(1)]
2480          Projection: test2.a
2481            TableScan: test2, full_filters=[test2.a <= Int64(1)]
2482        "
2483        )
2484    }
2485
2486    /// post-join predicates with columns from both sides are converted to join filters
2487    #[test]
2488    fn filter_join_on_common_dependent() -> Result<()> {
2489        let table_scan = test_table_scan()?;
2490        let left = LogicalPlanBuilder::from(table_scan)
2491            .project(vec![col("a"), col("c")])?
2492            .build()?;
2493        let right_table_scan = test_table_scan_with_name("test2")?;
2494        let right = LogicalPlanBuilder::from(right_table_scan)
2495            .project(vec![col("a"), col("b")])?
2496            .build()?;
2497        let plan = LogicalPlanBuilder::from(left)
2498            .join(
2499                right,
2500                JoinType::Inner,
2501                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2502                None,
2503            )?
2504            .filter(col("c").lt_eq(col("b")))?
2505            .build()?;
2506
2507        // not part of the test, just good to know:
2508        assert_snapshot!(plan,
2509        @r"
2510        Filter: test.c <= test2.b
2511          Inner Join: test.a = test2.a
2512            Projection: test.a, test.c
2513              TableScan: test
2514            Projection: test2.a, test2.b
2515              TableScan: test2
2516        ",
2517        );
2518        // Filter is converted to Join Filter
2519        assert_optimized_plan_equal!(
2520            plan,
2521            @r"
2522        Inner Join: test.a = test2.a Filter: test.c <= test2.b
2523          Projection: test.a, test.c
2524            TableScan: test
2525          Projection: test2.a, test2.b
2526            TableScan: test2
2527        "
2528        )
2529    }
2530
2531    /// post-join predicates with columns from one side of a join are pushed only to that side
2532    #[test]
2533    fn filter_join_on_one_side() -> Result<()> {
2534        let table_scan = test_table_scan()?;
2535        let left = LogicalPlanBuilder::from(table_scan)
2536            .project(vec![col("a"), col("b")])?
2537            .build()?;
2538        let table_scan_right = test_table_scan_with_name("test2")?;
2539        let right = LogicalPlanBuilder::from(table_scan_right)
2540            .project(vec![col("a"), col("c")])?
2541            .build()?;
2542
2543        let plan = LogicalPlanBuilder::from(left)
2544            .join(
2545                right,
2546                JoinType::Inner,
2547                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2548                None,
2549            )?
2550            .filter(col("b").lt_eq(lit(1i64)))?
2551            .build()?;
2552
2553        // not part of the test, just good to know:
2554        assert_snapshot!(plan,
2555        @r"
2556        Filter: test.b <= Int64(1)
2557          Inner Join: test.a = test2.a
2558            Projection: test.a, test.b
2559              TableScan: test
2560            Projection: test2.a, test2.c
2561              TableScan: test2
2562        ",
2563        );
2564        assert_optimized_plan_equal!(
2565            plan,
2566            @r"
2567        Inner Join: test.a = test2.a
2568          Projection: test.a, test.b
2569            TableScan: test, full_filters=[test.b <= Int64(1)]
2570          Projection: test2.a, test2.c
2571            TableScan: test2
2572        "
2573        )
2574    }
2575
2576    /// post-join predicates on the right side of a left join are not duplicated
2577    /// TODO: In this case we can sometimes convert the join to an INNER join
2578    #[test]
2579    fn filter_using_left_join() -> Result<()> {
2580        let table_scan = test_table_scan()?;
2581        let left = LogicalPlanBuilder::from(table_scan).build()?;
2582        let right_table_scan = test_table_scan_with_name("test2")?;
2583        let right = LogicalPlanBuilder::from(right_table_scan)
2584            .project(vec![col("a")])?
2585            .build()?;
2586        let plan = LogicalPlanBuilder::from(left)
2587            .join_using(
2588                right,
2589                JoinType::Left,
2590                vec![Column::from_name("a".to_string())],
2591            )?
2592            .filter(col("test2.a").lt_eq(lit(1i64)))?
2593            .build()?;
2594
2595        // not part of the test, just good to know:
2596        assert_snapshot!(plan,
2597        @r"
2598        Filter: test2.a <= Int64(1)
2599          Left Join: Using test.a = test2.a
2600            TableScan: test
2601            Projection: test2.a
2602              TableScan: test2
2603        ",
2604        );
2605        // filter not duplicated nor pushed down - i.e. noop
2606        assert_optimized_plan_equal!(
2607            plan,
2608            @r"
2609        Filter: test2.a <= Int64(1)
2610          Left Join: Using test.a = test2.a
2611            TableScan: test, full_filters=[test.a <= Int64(1)]
2612            Projection: test2.a
2613              TableScan: test2
2614        "
2615        )
2616    }
2617
2618    /// post-join predicates on the left side of a right join are not duplicated
2619    #[test]
2620    fn filter_using_right_join() -> Result<()> {
2621        let table_scan = test_table_scan()?;
2622        let left = LogicalPlanBuilder::from(table_scan).build()?;
2623        let right_table_scan = test_table_scan_with_name("test2")?;
2624        let right = LogicalPlanBuilder::from(right_table_scan)
2625            .project(vec![col("a")])?
2626            .build()?;
2627        let plan = LogicalPlanBuilder::from(left)
2628            .join_using(
2629                right,
2630                JoinType::Right,
2631                vec![Column::from_name("a".to_string())],
2632            )?
2633            .filter(col("test.a").lt_eq(lit(1i64)))?
2634            .build()?;
2635
2636        // not part of the test, just good to know:
2637        assert_snapshot!(plan,
2638        @r"
2639        Filter: test.a <= Int64(1)
2640          Right Join: Using test.a = test2.a
2641            TableScan: test
2642            Projection: test2.a
2643              TableScan: test2
2644        ",
2645        );
2646        // filter not duplicated nor pushed down - i.e. noop
2647        assert_optimized_plan_equal!(
2648            plan,
2649            @r"
2650        Filter: test.a <= Int64(1)
2651          Right Join: Using test.a = test2.a
2652            TableScan: test
2653            Projection: test2.a
2654              TableScan: test2, full_filters=[test2.a <= Int64(1)]
2655        "
2656        )
2657    }
2658
2659    /// post-left-join predicate on a column common to both sides is only pushed to the left side
2660    /// i.e. - not duplicated to the right side
2661    #[test]
2662    fn filter_using_left_join_on_common() -> Result<()> {
2663        let table_scan = test_table_scan()?;
2664        let left = LogicalPlanBuilder::from(table_scan).build()?;
2665        let right_table_scan = test_table_scan_with_name("test2")?;
2666        let right = LogicalPlanBuilder::from(right_table_scan)
2667            .project(vec![col("a")])?
2668            .build()?;
2669        let plan = LogicalPlanBuilder::from(left)
2670            .join_using(
2671                right,
2672                JoinType::Left,
2673                vec![Column::from_name("a".to_string())],
2674            )?
2675            .filter(col("a").lt_eq(lit(1i64)))?
2676            .build()?;
2677
2678        // not part of the test, just good to know:
2679        assert_snapshot!(plan,
2680        @r"
2681        Filter: test.a <= Int64(1)
2682          Left Join: Using test.a = test2.a
2683            TableScan: test
2684            Projection: test2.a
2685              TableScan: test2
2686        ",
2687        );
2688        // filter sent to left side of the join, not the right
2689        assert_optimized_plan_equal!(
2690            plan,
2691            @r"
2692        Left Join: Using test.a = test2.a
2693          TableScan: test, full_filters=[test.a <= Int64(1)]
2694          Projection: test2.a
2695            TableScan: test2
2696        "
2697        )
2698    }
2699
2700    /// post-right-join predicate on a column common to both sides is only pushed to the right side
2701    /// i.e. - not duplicated to the left side.
2702    #[test]
2703    fn filter_using_right_join_on_common() -> Result<()> {
2704        let table_scan = test_table_scan()?;
2705        let left = LogicalPlanBuilder::from(table_scan).build()?;
2706        let right_table_scan = test_table_scan_with_name("test2")?;
2707        let right = LogicalPlanBuilder::from(right_table_scan)
2708            .project(vec![col("a")])?
2709            .build()?;
2710        let plan = LogicalPlanBuilder::from(left)
2711            .join_using(
2712                right,
2713                JoinType::Right,
2714                vec![Column::from_name("a".to_string())],
2715            )?
2716            .filter(col("test2.a").lt_eq(lit(1i64)))?
2717            .build()?;
2718
2719        // not part of the test, just good to know:
2720        assert_snapshot!(plan,
2721        @r"
2722        Filter: test2.a <= Int64(1)
2723          Right Join: Using test.a = test2.a
2724            TableScan: test
2725            Projection: test2.a
2726              TableScan: test2
2727        ",
2728        );
2729        // filter sent to right side of join, not duplicated to the left
2730        assert_optimized_plan_equal!(
2731            plan,
2732            @r"
2733        Right Join: Using test.a = test2.a
2734          TableScan: test
2735          Projection: test2.a
2736            TableScan: test2, full_filters=[test2.a <= Int64(1)]
2737        "
2738        )
2739    }
2740
2741    /// single table predicate parts of ON condition should be pushed to both inputs
2742    #[test]
2743    fn join_on_with_filter() -> Result<()> {
2744        let table_scan = test_table_scan()?;
2745        let left = LogicalPlanBuilder::from(table_scan)
2746            .project(vec![col("a"), col("b"), col("c")])?
2747            .build()?;
2748        let right_table_scan = test_table_scan_with_name("test2")?;
2749        let right = LogicalPlanBuilder::from(right_table_scan)
2750            .project(vec![col("a"), col("b"), col("c")])?
2751            .build()?;
2752        let filter = col("test.c")
2753            .gt(lit(1u32))
2754            .and(col("test.b").lt(col("test2.b")))
2755            .and(col("test2.c").gt(lit(4u32)));
2756        let plan = LogicalPlanBuilder::from(left)
2757            .join(
2758                right,
2759                JoinType::Inner,
2760                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2761                Some(filter),
2762            )?
2763            .build()?;
2764
2765        // not part of the test, just good to know:
2766        assert_snapshot!(plan,
2767        @r"
2768        Inner Join: test.a = test2.a Filter: test.c > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
2769          Projection: test.a, test.b, test.c
2770            TableScan: test
2771          Projection: test2.a, test2.b, test2.c
2772            TableScan: test2
2773        ",
2774        );
2775        assert_optimized_plan_equal!(
2776            plan,
2777            @r"
2778        Inner Join: test.a = test2.a Filter: test.b < test2.b
2779          Projection: test.a, test.b, test.c
2780            TableScan: test, full_filters=[test.c > UInt32(1)]
2781          Projection: test2.a, test2.b, test2.c
2782            TableScan: test2, full_filters=[test2.c > UInt32(4)]
2783        "
2784        )
2785    }
2786
2787    /// join filter should be completely removed after pushdown
2788    #[test]
2789    fn join_filter_removed() -> Result<()> {
2790        let table_scan = test_table_scan()?;
2791        let left = LogicalPlanBuilder::from(table_scan)
2792            .project(vec![col("a"), col("b"), col("c")])?
2793            .build()?;
2794        let right_table_scan = test_table_scan_with_name("test2")?;
2795        let right = LogicalPlanBuilder::from(right_table_scan)
2796            .project(vec![col("a"), col("b"), col("c")])?
2797            .build()?;
2798        let filter = col("test.b")
2799            .gt(lit(1u32))
2800            .and(col("test2.c").gt(lit(4u32)));
2801        let plan = LogicalPlanBuilder::from(left)
2802            .join(
2803                right,
2804                JoinType::Inner,
2805                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2806                Some(filter),
2807            )?
2808            .build()?;
2809
2810        // not part of the test, just good to know:
2811        assert_snapshot!(plan,
2812        @r"
2813        Inner Join: test.a = test2.a Filter: test.b > UInt32(1) AND test2.c > UInt32(4)
2814          Projection: test.a, test.b, test.c
2815            TableScan: test
2816          Projection: test2.a, test2.b, test2.c
2817            TableScan: test2
2818        ",
2819        );
2820        assert_optimized_plan_equal!(
2821            plan,
2822            @r"
2823        Inner Join: test.a = test2.a
2824          Projection: test.a, test.b, test.c
2825            TableScan: test, full_filters=[test.b > UInt32(1)]
2826          Projection: test2.a, test2.b, test2.c
2827            TableScan: test2, full_filters=[test2.c > UInt32(4)]
2828        "
2829        )
2830    }
2831
2832    /// predicate on join key in filter expression should be pushed down to both inputs
2833    #[test]
2834    fn join_filter_on_common() -> Result<()> {
2835        let table_scan = test_table_scan()?;
2836        let left = LogicalPlanBuilder::from(table_scan)
2837            .project(vec![col("a")])?
2838            .build()?;
2839        let right_table_scan = test_table_scan_with_name("test2")?;
2840        let right = LogicalPlanBuilder::from(right_table_scan)
2841            .project(vec![col("b")])?
2842            .build()?;
2843        let filter = col("test.a").gt(lit(1u32));
2844        let plan = LogicalPlanBuilder::from(left)
2845            .join(
2846                right,
2847                JoinType::Inner,
2848                (vec![Column::from_name("a")], vec![Column::from_name("b")]),
2849                Some(filter),
2850            )?
2851            .build()?;
2852
2853        // not part of the test, just good to know:
2854        assert_snapshot!(plan,
2855        @r"
2856        Inner Join: test.a = test2.b Filter: test.a > UInt32(1)
2857          Projection: test.a
2858            TableScan: test
2859          Projection: test2.b
2860            TableScan: test2
2861        ",
2862        );
2863        assert_optimized_plan_equal!(
2864            plan,
2865            @r"
2866        Inner Join: test.a = test2.b
2867          Projection: test.a
2868            TableScan: test, full_filters=[test.a > UInt32(1)]
2869          Projection: test2.b
2870            TableScan: test2, full_filters=[test2.b > UInt32(1)]
2871        "
2872        )
2873    }
2874
2875    /// single table predicate parts of ON condition should be pushed to right input
2876    #[test]
2877    fn left_join_on_with_filter() -> Result<()> {
2878        let table_scan = test_table_scan()?;
2879        let left = LogicalPlanBuilder::from(table_scan)
2880            .project(vec![col("a"), col("b"), col("c")])?
2881            .build()?;
2882        let right_table_scan = test_table_scan_with_name("test2")?;
2883        let right = LogicalPlanBuilder::from(right_table_scan)
2884            .project(vec![col("a"), col("b"), col("c")])?
2885            .build()?;
2886        let filter = col("test.a")
2887            .gt(lit(1u32))
2888            .and(col("test.b").lt(col("test2.b")))
2889            .and(col("test2.c").gt(lit(4u32)));
2890        let plan = LogicalPlanBuilder::from(left)
2891            .join(
2892                right,
2893                JoinType::Left,
2894                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2895                Some(filter),
2896            )?
2897            .build()?;
2898
2899        // not part of the test, just good to know:
2900        assert_snapshot!(plan,
2901        @r"
2902        Left Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
2903          Projection: test.a, test.b, test.c
2904            TableScan: test
2905          Projection: test2.a, test2.b, test2.c
2906            TableScan: test2
2907        ",
2908        );
2909        assert_optimized_plan_equal!(
2910            plan,
2911            @r"
2912        Left Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b
2913          Projection: test.a, test.b, test.c
2914            TableScan: test
2915          Projection: test2.a, test2.b, test2.c
2916            TableScan: test2, full_filters=[test2.c > UInt32(4)]
2917        "
2918        )
2919    }
2920
2921    /// single table predicate parts of ON condition should be pushed to left input
2922    #[test]
2923    fn right_join_on_with_filter() -> Result<()> {
2924        let table_scan = test_table_scan()?;
2925        let left = LogicalPlanBuilder::from(table_scan)
2926            .project(vec![col("a"), col("b"), col("c")])?
2927            .build()?;
2928        let right_table_scan = test_table_scan_with_name("test2")?;
2929        let right = LogicalPlanBuilder::from(right_table_scan)
2930            .project(vec![col("a"), col("b"), col("c")])?
2931            .build()?;
2932        let filter = col("test.a")
2933            .gt(lit(1u32))
2934            .and(col("test.b").lt(col("test2.b")))
2935            .and(col("test2.c").gt(lit(4u32)));
2936        let plan = LogicalPlanBuilder::from(left)
2937            .join(
2938                right,
2939                JoinType::Right,
2940                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2941                Some(filter),
2942            )?
2943            .build()?;
2944
2945        // not part of the test, just good to know:
2946        assert_snapshot!(plan,
2947        @r"
2948        Right Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
2949          Projection: test.a, test.b, test.c
2950            TableScan: test
2951          Projection: test2.a, test2.b, test2.c
2952            TableScan: test2
2953        ",
2954        );
2955        assert_optimized_plan_equal!(
2956            plan,
2957            @r"
2958        Right Join: test.a = test2.a Filter: test.b < test2.b AND test2.c > UInt32(4)
2959          Projection: test.a, test.b, test.c
2960            TableScan: test, full_filters=[test.a > UInt32(1)]
2961          Projection: test2.a, test2.b, test2.c
2962            TableScan: test2
2963        "
2964        )
2965    }
2966
2967    /// single table predicate parts of ON condition should not be pushed
2968    #[test]
2969    fn full_join_on_with_filter() -> Result<()> {
2970        let table_scan = test_table_scan()?;
2971        let left = LogicalPlanBuilder::from(table_scan)
2972            .project(vec![col("a"), col("b"), col("c")])?
2973            .build()?;
2974        let right_table_scan = test_table_scan_with_name("test2")?;
2975        let right = LogicalPlanBuilder::from(right_table_scan)
2976            .project(vec![col("a"), col("b"), col("c")])?
2977            .build()?;
2978        let filter = col("test.a")
2979            .gt(lit(1u32))
2980            .and(col("test.b").lt(col("test2.b")))
2981            .and(col("test2.c").gt(lit(4u32)));
2982        let plan = LogicalPlanBuilder::from(left)
2983            .join(
2984                right,
2985                JoinType::Full,
2986                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2987                Some(filter),
2988            )?
2989            .build()?;
2990
2991        // not part of the test, just good to know:
2992        assert_snapshot!(plan,
2993        @r"
2994        Full Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
2995          Projection: test.a, test.b, test.c
2996            TableScan: test
2997          Projection: test2.a, test2.b, test2.c
2998            TableScan: test2
2999        ",
3000        );
3001        assert_optimized_plan_equal!(
3002            plan,
3003            @r"
3004        Full Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
3005          Projection: test.a, test.b, test.c
3006            TableScan: test
3007          Projection: test2.a, test2.b, test2.c
3008            TableScan: test2
3009        "
3010        )
3011    }
3012
3013    struct PushDownProvider {
3014        pub filter_support: TableProviderFilterPushDown,
3015    }
3016
3017    #[async_trait]
3018    impl TableSource for PushDownProvider {
3019        fn schema(&self) -> SchemaRef {
3020            Arc::new(Schema::new(vec![
3021                Field::new("a", DataType::Int32, true),
3022                Field::new("b", DataType::Int32, true),
3023            ]))
3024        }
3025
3026        fn table_type(&self) -> TableType {
3027            TableType::Base
3028        }
3029
3030        fn supports_filters_pushdown(
3031            &self,
3032            filters: &[&Expr],
3033        ) -> Result<Vec<TableProviderFilterPushDown>> {
3034            Ok((0..filters.len())
3035                .map(|_| self.filter_support.clone())
3036                .collect())
3037        }
3038
3039        fn as_any(&self) -> &dyn Any {
3040            self
3041        }
3042    }
3043
3044    fn table_scan_with_pushdown_provider_builder(
3045        filter_support: TableProviderFilterPushDown,
3046        filters: Vec<Expr>,
3047        projection: Option<Vec<usize>>,
3048    ) -> Result<LogicalPlanBuilder> {
3049        let test_provider = PushDownProvider { filter_support };
3050
3051        let table_scan = LogicalPlan::TableScan(TableScan {
3052            table_name: "test".into(),
3053            filters,
3054            projected_schema: Arc::new(DFSchema::try_from(test_provider.schema())?),
3055            projection,
3056            source: Arc::new(test_provider),
3057            fetch: None,
3058        });
3059
3060        Ok(LogicalPlanBuilder::from(table_scan))
3061    }
3062
3063    fn table_scan_with_pushdown_provider(
3064        filter_support: TableProviderFilterPushDown,
3065    ) -> Result<LogicalPlan> {
3066        table_scan_with_pushdown_provider_builder(filter_support, vec![], None)?
3067            .filter(col("a").eq(lit(1i64)))?
3068            .build()
3069    }
3070
3071    #[test]
3072    fn filter_with_table_provider_exact() -> Result<()> {
3073        let plan = table_scan_with_pushdown_provider(TableProviderFilterPushDown::Exact)?;
3074
3075        assert_optimized_plan_equal!(
3076            plan,
3077            @"TableScan: test, full_filters=[a = Int64(1)]"
3078        )
3079    }
3080
3081    #[test]
3082    fn filter_with_table_provider_inexact() -> Result<()> {
3083        let plan =
3084            table_scan_with_pushdown_provider(TableProviderFilterPushDown::Inexact)?;
3085
3086        assert_optimized_plan_equal!(
3087            plan,
3088            @r"
3089        Filter: a = Int64(1)
3090          TableScan: test, partial_filters=[a = Int64(1)]
3091        "
3092        )
3093    }
3094
3095    #[test]
3096    fn filter_with_table_provider_multiple_invocations() -> Result<()> {
3097        let plan =
3098            table_scan_with_pushdown_provider(TableProviderFilterPushDown::Inexact)?;
3099
3100        let optimized_plan = PushDownFilter::new()
3101            .rewrite(plan, &OptimizerContext::new())
3102            .expect("failed to optimize plan")
3103            .data;
3104
3105        // Optimizing the same plan multiple times should produce the same plan
3106        // each time.
3107        assert_optimized_plan_equal!(
3108            optimized_plan,
3109            @r"
3110        Filter: a = Int64(1)
3111          TableScan: test, partial_filters=[a = Int64(1)]
3112        "
3113        )
3114    }
3115
3116    #[test]
3117    fn filter_with_table_provider_unsupported() -> Result<()> {
3118        let plan =
3119            table_scan_with_pushdown_provider(TableProviderFilterPushDown::Unsupported)?;
3120
3121        assert_optimized_plan_equal!(
3122            plan,
3123            @r"
3124        Filter: a = Int64(1)
3125          TableScan: test
3126        "
3127        )
3128    }
3129
3130    #[test]
3131    fn multi_combined_filter() -> Result<()> {
3132        let plan = table_scan_with_pushdown_provider_builder(
3133            TableProviderFilterPushDown::Inexact,
3134            vec![col("a").eq(lit(10i64)), col("b").gt(lit(11i64))],
3135            Some(vec![0]),
3136        )?
3137        .filter(and(col("a").eq(lit(10i64)), col("b").gt(lit(11i64))))?
3138        .project(vec![col("a"), col("b")])?
3139        .build()?;
3140
3141        assert_optimized_plan_equal!(
3142            plan,
3143            @r"
3144        Projection: a, b
3145          Filter: a = Int64(10) AND b > Int64(11)
3146            TableScan: test projection=[a], partial_filters=[a = Int64(10), b > Int64(11)]
3147        "
3148        )
3149    }
3150
3151    #[test]
3152    fn multi_combined_filter_exact() -> Result<()> {
3153        let plan = table_scan_with_pushdown_provider_builder(
3154            TableProviderFilterPushDown::Exact,
3155            vec![],
3156            Some(vec![0]),
3157        )?
3158        .filter(and(col("a").eq(lit(10i64)), col("b").gt(lit(11i64))))?
3159        .project(vec![col("a"), col("b")])?
3160        .build()?;
3161
3162        assert_optimized_plan_equal!(
3163            plan,
3164            @r"
3165        Projection: a, b
3166          TableScan: test projection=[a], full_filters=[a = Int64(10), b > Int64(11)]
3167        "
3168        )
3169    }
3170
3171    #[test]
3172    fn test_filter_with_alias() -> Result<()> {
3173        // in table scan the true col name is 'test.a',
3174        // but we rename it as 'b', and use col 'b' in filter
3175        // we need rewrite filter col before push down.
3176        let table_scan = test_table_scan()?;
3177        let plan = LogicalPlanBuilder::from(table_scan)
3178            .project(vec![col("a").alias("b"), col("c")])?
3179            .filter(and(col("b").gt(lit(10i64)), col("c").gt(lit(10i64))))?
3180            .build()?;
3181
3182        // filter on col b
3183        assert_snapshot!(plan,
3184        @r"
3185        Filter: b > Int64(10) AND test.c > Int64(10)
3186          Projection: test.a AS b, test.c
3187            TableScan: test
3188        ",
3189        );
3190        // rewrite filter col b to test.a
3191        assert_optimized_plan_equal!(
3192            plan,
3193            @r"
3194        Projection: test.a AS b, test.c
3195          TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]
3196        "
3197        )
3198    }
3199
3200    #[test]
3201    fn test_filter_with_alias_2() -> Result<()> {
3202        // in table scan the true col name is 'test.a',
3203        // but we rename it as 'b', and use col 'b' in filter
3204        // we need rewrite filter col before push down.
3205        let table_scan = test_table_scan()?;
3206        let plan = LogicalPlanBuilder::from(table_scan)
3207            .project(vec![col("a").alias("b"), col("c")])?
3208            .project(vec![col("b"), col("c")])?
3209            .filter(and(col("b").gt(lit(10i64)), col("c").gt(lit(10i64))))?
3210            .build()?;
3211
3212        // filter on col b
3213        assert_snapshot!(plan,
3214        @r"
3215        Filter: b > Int64(10) AND test.c > Int64(10)
3216          Projection: b, test.c
3217            Projection: test.a AS b, test.c
3218              TableScan: test
3219        ",
3220        );
3221        // rewrite filter col b to test.a
3222        assert_optimized_plan_equal!(
3223            plan,
3224            @r"
3225        Projection: b, test.c
3226          Projection: test.a AS b, test.c
3227            TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]
3228        "
3229        )
3230    }
3231
3232    #[test]
3233    fn test_filter_with_multi_alias() -> Result<()> {
3234        let table_scan = test_table_scan()?;
3235        let plan = LogicalPlanBuilder::from(table_scan)
3236            .project(vec![col("a").alias("b"), col("c").alias("d")])?
3237            .filter(and(col("b").gt(lit(10i64)), col("d").gt(lit(10i64))))?
3238            .build()?;
3239
3240        // filter on col b and d
3241        assert_snapshot!(plan,
3242        @r"
3243        Filter: b > Int64(10) AND d > Int64(10)
3244          Projection: test.a AS b, test.c AS d
3245            TableScan: test
3246        ",
3247        );
3248        // rewrite filter col b to test.a, col d to test.c
3249        assert_optimized_plan_equal!(
3250            plan,
3251            @r"
3252        Projection: test.a AS b, test.c AS d
3253          TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]
3254        "
3255        )
3256    }
3257
3258    /// predicate on join key in filter expression should be pushed down to both inputs
3259    #[test]
3260    fn join_filter_with_alias() -> Result<()> {
3261        let table_scan = test_table_scan()?;
3262        let left = LogicalPlanBuilder::from(table_scan)
3263            .project(vec![col("a").alias("c")])?
3264            .build()?;
3265        let right_table_scan = test_table_scan_with_name("test2")?;
3266        let right = LogicalPlanBuilder::from(right_table_scan)
3267            .project(vec![col("b").alias("d")])?
3268            .build()?;
3269        let filter = col("c").gt(lit(1u32));
3270        let plan = LogicalPlanBuilder::from(left)
3271            .join(
3272                right,
3273                JoinType::Inner,
3274                (vec![Column::from_name("c")], vec![Column::from_name("d")]),
3275                Some(filter),
3276            )?
3277            .build()?;
3278
3279        assert_snapshot!(plan,
3280        @r"
3281        Inner Join: c = d Filter: c > UInt32(1)
3282          Projection: test.a AS c
3283            TableScan: test
3284          Projection: test2.b AS d
3285            TableScan: test2
3286        ",
3287        );
3288        // Change filter on col `c`, 'd' to `test.a`, 'test.b'
3289        assert_optimized_plan_equal!(
3290            plan,
3291            @r"
3292        Inner Join: c = d
3293          Projection: test.a AS c
3294            TableScan: test, full_filters=[test.a > UInt32(1)]
3295          Projection: test2.b AS d
3296            TableScan: test2, full_filters=[test2.b > UInt32(1)]
3297        "
3298        )
3299    }
3300
3301    #[test]
3302    fn test_in_filter_with_alias() -> Result<()> {
3303        // in table scan the true col name is 'test.a',
3304        // but we rename it as 'b', and use col 'b' in filter
3305        // we need rewrite filter col before push down.
3306        let table_scan = test_table_scan()?;
3307        let filter_value = vec![lit(1u32), lit(2u32), lit(3u32), lit(4u32)];
3308        let plan = LogicalPlanBuilder::from(table_scan)
3309            .project(vec![col("a").alias("b"), col("c")])?
3310            .filter(in_list(col("b"), filter_value, false))?
3311            .build()?;
3312
3313        // filter on col b
3314        assert_snapshot!(plan,
3315        @r"
3316        Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])
3317          Projection: test.a AS b, test.c
3318            TableScan: test
3319        ",
3320        );
3321        // rewrite filter col b to test.a
3322        assert_optimized_plan_equal!(
3323            plan,
3324            @r"
3325        Projection: test.a AS b, test.c
3326          TableScan: test, full_filters=[test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])]
3327        "
3328        )
3329    }
3330
3331    #[test]
3332    fn test_in_filter_with_alias_2() -> Result<()> {
3333        // in table scan the true col name is 'test.a',
3334        // but we rename it as 'b', and use col 'b' in filter
3335        // we need rewrite filter col before push down.
3336        let table_scan = test_table_scan()?;
3337        let filter_value = vec![lit(1u32), lit(2u32), lit(3u32), lit(4u32)];
3338        let plan = LogicalPlanBuilder::from(table_scan)
3339            .project(vec![col("a").alias("b"), col("c")])?
3340            .project(vec![col("b"), col("c")])?
3341            .filter(in_list(col("b"), filter_value, false))?
3342            .build()?;
3343
3344        // filter on col b
3345        assert_snapshot!(plan,
3346        @r"
3347        Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])
3348          Projection: b, test.c
3349            Projection: test.a AS b, test.c
3350              TableScan: test
3351        ",
3352        );
3353        // rewrite filter col b to test.a
3354        assert_optimized_plan_equal!(
3355            plan,
3356            @r"
3357        Projection: b, test.c
3358          Projection: test.a AS b, test.c
3359            TableScan: test, full_filters=[test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])]
3360        "
3361        )
3362    }
3363
3364    #[test]
3365    fn test_in_subquery_with_alias() -> Result<()> {
3366        // in table scan the true col name is 'test.a',
3367        // but we rename it as 'b', and use col 'b' in subquery filter
3368        let table_scan = test_table_scan()?;
3369        let table_scan_sq = test_table_scan_with_name("sq")?;
3370        let subplan = Arc::new(
3371            LogicalPlanBuilder::from(table_scan_sq)
3372                .project(vec![col("c")])?
3373                .build()?,
3374        );
3375        let plan = LogicalPlanBuilder::from(table_scan)
3376            .project(vec![col("a").alias("b"), col("c")])?
3377            .filter(in_subquery(col("b"), subplan))?
3378            .build()?;
3379
3380        // filter on col b in subquery
3381        assert_snapshot!(plan,
3382        @r"
3383        Filter: b IN (<subquery>)
3384          Subquery:
3385            Projection: sq.c
3386              TableScan: sq
3387          Projection: test.a AS b, test.c
3388            TableScan: test
3389        ",
3390        );
3391        // rewrite filter col b to test.a
3392        assert_optimized_plan_equal!(
3393            plan,
3394            @r"
3395        Projection: test.a AS b, test.c
3396          TableScan: test, full_filters=[test.a IN (<subquery>)]
3397            Subquery:
3398              Projection: sq.c
3399                TableScan: sq
3400        "
3401        )
3402    }
3403
3404    #[test]
3405    fn test_propagation_of_optimized_inner_filters_with_projections() -> Result<()> {
3406        // SELECT a FROM (SELECT 1 AS a) b WHERE b.a = 1
3407        let plan = LogicalPlanBuilder::empty(true)
3408            .project(vec![lit(0i64).alias("a")])?
3409            .alias("b")?
3410            .project(vec![col("b.a")])?
3411            .alias("b")?
3412            .filter(col("b.a").eq(lit(1i64)))?
3413            .project(vec![col("b.a")])?
3414            .build()?;
3415
3416        assert_snapshot!(plan,
3417        @r"
3418        Projection: b.a
3419          Filter: b.a = Int64(1)
3420            SubqueryAlias: b
3421              Projection: b.a
3422                SubqueryAlias: b
3423                  Projection: Int64(0) AS a
3424                    EmptyRelation: rows=1
3425        ",
3426        );
3427        // Ensure that the predicate without any columns (0 = 1) is
3428        // still there.
3429        assert_optimized_plan_equal!(
3430            plan,
3431            @r"
3432        Projection: b.a
3433          SubqueryAlias: b
3434            Projection: b.a
3435              SubqueryAlias: b
3436                Projection: Int64(0) AS a
3437                  Filter: Int64(0) = Int64(1)
3438                    EmptyRelation: rows=1
3439        "
3440        )
3441    }
3442
3443    #[test]
3444    fn test_crossjoin_with_or_clause() -> Result<()> {
3445        // select * from test,test1 where (test.a = test1.a and test.b > 1) or (test.b = test1.b and test.c < 10);
3446        let table_scan = test_table_scan()?;
3447        let left = LogicalPlanBuilder::from(table_scan)
3448            .project(vec![col("a"), col("b"), col("c")])?
3449            .build()?;
3450        let right_table_scan = test_table_scan_with_name("test1")?;
3451        let right = LogicalPlanBuilder::from(right_table_scan)
3452            .project(vec![col("a").alias("d"), col("a").alias("e")])?
3453            .build()?;
3454        let filter = or(
3455            and(col("a").eq(col("d")), col("b").gt(lit(1u32))),
3456            and(col("b").eq(col("e")), col("c").lt(lit(10u32))),
3457        );
3458        let plan = LogicalPlanBuilder::from(left)
3459            .cross_join(right)?
3460            .filter(filter)?
3461            .build()?;
3462
3463        assert_optimized_plan_eq_with_rewrite_predicate!(plan.clone(), @r"
3464        Inner Join:  Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10)
3465          Projection: test.a, test.b, test.c
3466            TableScan: test, full_filters=[test.b > UInt32(1) OR test.c < UInt32(10)]
3467          Projection: test1.a AS d, test1.a AS e
3468            TableScan: test1
3469        ")?;
3470
3471        // Originally global state which can help to avoid duplicate Filters been generated and pushed down.
3472        // Now the global state is removed. Need to double confirm that avoid duplicate Filters.
3473        let optimized_plan = PushDownFilter::new()
3474            .rewrite(plan, &OptimizerContext::new())
3475            .expect("failed to optimize plan")
3476            .data;
3477        assert_optimized_plan_equal!(
3478            optimized_plan,
3479            @r"
3480        Inner Join:  Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10)
3481          Projection: test.a, test.b, test.c
3482            TableScan: test, full_filters=[test.b > UInt32(1) OR test.c < UInt32(10)]
3483          Projection: test1.a AS d, test1.a AS e
3484            TableScan: test1
3485        "
3486        )
3487    }
3488
3489    #[test]
3490    fn left_semi_join() -> Result<()> {
3491        let left = test_table_scan_with_name("test1")?;
3492        let right_table_scan = test_table_scan_with_name("test2")?;
3493        let right = LogicalPlanBuilder::from(right_table_scan)
3494            .project(vec![col("a"), col("b")])?
3495            .build()?;
3496        let plan = LogicalPlanBuilder::from(left)
3497            .join(
3498                right,
3499                JoinType::LeftSemi,
3500                (
3501                    vec![Column::from_qualified_name("test1.a")],
3502                    vec![Column::from_qualified_name("test2.a")],
3503                ),
3504                None,
3505            )?
3506            .filter(col("test2.a").lt_eq(lit(1i64)))?
3507            .build()?;
3508
3509        // not part of the test, just good to know:
3510        assert_snapshot!(plan,
3511        @r"
3512        Filter: test2.a <= Int64(1)
3513          LeftSemi Join: test1.a = test2.a
3514            TableScan: test1
3515            Projection: test2.a, test2.b
3516              TableScan: test2
3517        ",
3518        );
3519        // Inferred the predicate `test1.a <= Int64(1)` and push it down to the left side.
3520        assert_optimized_plan_equal!(
3521            plan,
3522            @r"
3523        Filter: test2.a <= Int64(1)
3524          LeftSemi Join: test1.a = test2.a
3525            TableScan: test1, full_filters=[test1.a <= Int64(1)]
3526            Projection: test2.a, test2.b
3527              TableScan: test2
3528        "
3529        )
3530    }
3531
3532    #[test]
3533    fn left_semi_join_with_filters() -> Result<()> {
3534        let left = test_table_scan_with_name("test1")?;
3535        let right_table_scan = test_table_scan_with_name("test2")?;
3536        let right = LogicalPlanBuilder::from(right_table_scan)
3537            .project(vec![col("a"), col("b")])?
3538            .build()?;
3539        let plan = LogicalPlanBuilder::from(left)
3540            .join(
3541                right,
3542                JoinType::LeftSemi,
3543                (
3544                    vec![Column::from_qualified_name("test1.a")],
3545                    vec![Column::from_qualified_name("test2.a")],
3546                ),
3547                Some(
3548                    col("test1.b")
3549                        .gt(lit(1u32))
3550                        .and(col("test2.b").gt(lit(2u32))),
3551                ),
3552            )?
3553            .build()?;
3554
3555        // not part of the test, just good to know:
3556        assert_snapshot!(plan,
3557        @r"
3558        LeftSemi Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
3559          TableScan: test1
3560          Projection: test2.a, test2.b
3561            TableScan: test2
3562        ",
3563        );
3564        // Both side will be pushed down.
3565        assert_optimized_plan_equal!(
3566            plan,
3567            @r"
3568        LeftSemi Join: test1.a = test2.a
3569          TableScan: test1, full_filters=[test1.b > UInt32(1)]
3570          Projection: test2.a, test2.b
3571            TableScan: test2, full_filters=[test2.b > UInt32(2)]
3572        "
3573        )
3574    }
3575
3576    #[test]
3577    fn right_semi_join() -> Result<()> {
3578        let left = test_table_scan_with_name("test1")?;
3579        let right_table_scan = test_table_scan_with_name("test2")?;
3580        let right = LogicalPlanBuilder::from(right_table_scan)
3581            .project(vec![col("a"), col("b")])?
3582            .build()?;
3583        let plan = LogicalPlanBuilder::from(left)
3584            .join(
3585                right,
3586                JoinType::RightSemi,
3587                (
3588                    vec![Column::from_qualified_name("test1.a")],
3589                    vec![Column::from_qualified_name("test2.a")],
3590                ),
3591                None,
3592            )?
3593            .filter(col("test1.a").lt_eq(lit(1i64)))?
3594            .build()?;
3595
3596        // not part of the test, just good to know:
3597        assert_snapshot!(plan,
3598        @r"
3599        Filter: test1.a <= Int64(1)
3600          RightSemi Join: test1.a = test2.a
3601            TableScan: test1
3602            Projection: test2.a, test2.b
3603              TableScan: test2
3604        ",
3605        );
3606        // Inferred the predicate `test2.a <= Int64(1)` and push it down to the right side.
3607        assert_optimized_plan_equal!(
3608            plan,
3609            @r"
3610        Filter: test1.a <= Int64(1)
3611          RightSemi Join: test1.a = test2.a
3612            TableScan: test1
3613            Projection: test2.a, test2.b
3614              TableScan: test2, full_filters=[test2.a <= Int64(1)]
3615        "
3616        )
3617    }
3618
3619    #[test]
3620    fn right_semi_join_with_filters() -> Result<()> {
3621        let left = test_table_scan_with_name("test1")?;
3622        let right_table_scan = test_table_scan_with_name("test2")?;
3623        let right = LogicalPlanBuilder::from(right_table_scan)
3624            .project(vec![col("a"), col("b")])?
3625            .build()?;
3626        let plan = LogicalPlanBuilder::from(left)
3627            .join(
3628                right,
3629                JoinType::RightSemi,
3630                (
3631                    vec![Column::from_qualified_name("test1.a")],
3632                    vec![Column::from_qualified_name("test2.a")],
3633                ),
3634                Some(
3635                    col("test1.b")
3636                        .gt(lit(1u32))
3637                        .and(col("test2.b").gt(lit(2u32))),
3638                ),
3639            )?
3640            .build()?;
3641
3642        // not part of the test, just good to know:
3643        assert_snapshot!(plan,
3644        @r"
3645        RightSemi Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
3646          TableScan: test1
3647          Projection: test2.a, test2.b
3648            TableScan: test2
3649        ",
3650        );
3651        // Both side will be pushed down.
3652        assert_optimized_plan_equal!(
3653            plan,
3654            @r"
3655        RightSemi Join: test1.a = test2.a
3656          TableScan: test1, full_filters=[test1.b > UInt32(1)]
3657          Projection: test2.a, test2.b
3658            TableScan: test2, full_filters=[test2.b > UInt32(2)]
3659        "
3660        )
3661    }
3662
3663    #[test]
3664    fn left_anti_join() -> Result<()> {
3665        let table_scan = test_table_scan_with_name("test1")?;
3666        let left = LogicalPlanBuilder::from(table_scan)
3667            .project(vec![col("a"), col("b")])?
3668            .build()?;
3669        let right_table_scan = test_table_scan_with_name("test2")?;
3670        let right = LogicalPlanBuilder::from(right_table_scan)
3671            .project(vec![col("a"), col("b")])?
3672            .build()?;
3673        let plan = LogicalPlanBuilder::from(left)
3674            .join(
3675                right,
3676                JoinType::LeftAnti,
3677                (
3678                    vec![Column::from_qualified_name("test1.a")],
3679                    vec![Column::from_qualified_name("test2.a")],
3680                ),
3681                None,
3682            )?
3683            .filter(col("test2.a").gt(lit(2u32)))?
3684            .build()?;
3685
3686        // not part of the test, just good to know:
3687        assert_snapshot!(plan,
3688        @r"
3689        Filter: test2.a > UInt32(2)
3690          LeftAnti Join: test1.a = test2.a
3691            Projection: test1.a, test1.b
3692              TableScan: test1
3693            Projection: test2.a, test2.b
3694              TableScan: test2
3695        ",
3696        );
3697        // For left anti, filter of the right side filter can be pushed down.
3698        assert_optimized_plan_equal!(
3699            plan,
3700            @r"
3701        Filter: test2.a > UInt32(2)
3702          LeftAnti Join: test1.a = test2.a
3703            Projection: test1.a, test1.b
3704              TableScan: test1, full_filters=[test1.a > UInt32(2)]
3705            Projection: test2.a, test2.b
3706              TableScan: test2
3707        "
3708        )
3709    }
3710
3711    #[test]
3712    fn left_anti_join_with_filters() -> Result<()> {
3713        let table_scan = test_table_scan_with_name("test1")?;
3714        let left = LogicalPlanBuilder::from(table_scan)
3715            .project(vec![col("a"), col("b")])?
3716            .build()?;
3717        let right_table_scan = test_table_scan_with_name("test2")?;
3718        let right = LogicalPlanBuilder::from(right_table_scan)
3719            .project(vec![col("a"), col("b")])?
3720            .build()?;
3721        let plan = LogicalPlanBuilder::from(left)
3722            .join(
3723                right,
3724                JoinType::LeftAnti,
3725                (
3726                    vec![Column::from_qualified_name("test1.a")],
3727                    vec![Column::from_qualified_name("test2.a")],
3728                ),
3729                Some(
3730                    col("test1.b")
3731                        .gt(lit(1u32))
3732                        .and(col("test2.b").gt(lit(2u32))),
3733                ),
3734            )?
3735            .build()?;
3736
3737        // not part of the test, just good to know:
3738        assert_snapshot!(plan,
3739        @r"
3740        LeftAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
3741          Projection: test1.a, test1.b
3742            TableScan: test1
3743          Projection: test2.a, test2.b
3744            TableScan: test2
3745        ",
3746        );
3747        // For left anti, filter of the right side filter can be pushed down.
3748        assert_optimized_plan_equal!(
3749            plan,
3750            @r"
3751        LeftAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1)
3752          Projection: test1.a, test1.b
3753            TableScan: test1
3754          Projection: test2.a, test2.b
3755            TableScan: test2, full_filters=[test2.b > UInt32(2)]
3756        "
3757        )
3758    }
3759
3760    #[test]
3761    fn right_anti_join() -> Result<()> {
3762        let table_scan = test_table_scan_with_name("test1")?;
3763        let left = LogicalPlanBuilder::from(table_scan)
3764            .project(vec![col("a"), col("b")])?
3765            .build()?;
3766        let right_table_scan = test_table_scan_with_name("test2")?;
3767        let right = LogicalPlanBuilder::from(right_table_scan)
3768            .project(vec![col("a"), col("b")])?
3769            .build()?;
3770        let plan = LogicalPlanBuilder::from(left)
3771            .join(
3772                right,
3773                JoinType::RightAnti,
3774                (
3775                    vec![Column::from_qualified_name("test1.a")],
3776                    vec![Column::from_qualified_name("test2.a")],
3777                ),
3778                None,
3779            )?
3780            .filter(col("test1.a").gt(lit(2u32)))?
3781            .build()?;
3782
3783        // not part of the test, just good to know:
3784        assert_snapshot!(plan,
3785        @r"
3786        Filter: test1.a > UInt32(2)
3787          RightAnti Join: test1.a = test2.a
3788            Projection: test1.a, test1.b
3789              TableScan: test1
3790            Projection: test2.a, test2.b
3791              TableScan: test2
3792        ",
3793        );
3794        // For right anti, filter of the left side can be pushed down.
3795        assert_optimized_plan_equal!(
3796            plan,
3797            @r"
3798        Filter: test1.a > UInt32(2)
3799          RightAnti Join: test1.a = test2.a
3800            Projection: test1.a, test1.b
3801              TableScan: test1
3802            Projection: test2.a, test2.b
3803              TableScan: test2, full_filters=[test2.a > UInt32(2)]
3804        "
3805        )
3806    }
3807
3808    #[test]
3809    fn right_anti_join_with_filters() -> Result<()> {
3810        let table_scan = test_table_scan_with_name("test1")?;
3811        let left = LogicalPlanBuilder::from(table_scan)
3812            .project(vec![col("a"), col("b")])?
3813            .build()?;
3814        let right_table_scan = test_table_scan_with_name("test2")?;
3815        let right = LogicalPlanBuilder::from(right_table_scan)
3816            .project(vec![col("a"), col("b")])?
3817            .build()?;
3818        let plan = LogicalPlanBuilder::from(left)
3819            .join(
3820                right,
3821                JoinType::RightAnti,
3822                (
3823                    vec![Column::from_qualified_name("test1.a")],
3824                    vec![Column::from_qualified_name("test2.a")],
3825                ),
3826                Some(
3827                    col("test1.b")
3828                        .gt(lit(1u32))
3829                        .and(col("test2.b").gt(lit(2u32))),
3830                ),
3831            )?
3832            .build()?;
3833
3834        // not part of the test, just good to know:
3835        assert_snapshot!(plan,
3836        @r"
3837        RightAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
3838          Projection: test1.a, test1.b
3839            TableScan: test1
3840          Projection: test2.a, test2.b
3841            TableScan: test2
3842        ",
3843        );
3844        // For right anti, filter of the left side can be pushed down.
3845        assert_optimized_plan_equal!(
3846            plan,
3847            @r"
3848        RightAnti Join: test1.a = test2.a Filter: test2.b > UInt32(2)
3849          Projection: test1.a, test1.b
3850            TableScan: test1, full_filters=[test1.b > UInt32(1)]
3851          Projection: test2.a, test2.b
3852            TableScan: test2
3853        "
3854        )
3855    }
3856
3857    #[derive(Debug, PartialEq, Eq, Hash)]
3858    struct TestScalarUDF {
3859        signature: Signature,
3860    }
3861
3862    impl ScalarUDFImpl for TestScalarUDF {
3863        fn as_any(&self) -> &dyn Any {
3864            self
3865        }
3866        fn name(&self) -> &str {
3867            "TestScalarUDF"
3868        }
3869
3870        fn signature(&self) -> &Signature {
3871            &self.signature
3872        }
3873
3874        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
3875            Ok(DataType::Int32)
3876        }
3877
3878        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
3879            Ok(ColumnarValue::Scalar(ScalarValue::from(1)))
3880        }
3881    }
3882
3883    #[test]
3884    fn test_push_down_volatile_function_in_aggregate() -> Result<()> {
3885        // SELECT t.a, t.r FROM (SELECT a, sum(b),  TestScalarUDF()+1 AS r FROM test1 GROUP BY a) AS t WHERE t.a > 5 AND t.r > 0.5;
3886        let table_scan = test_table_scan_with_name("test1")?;
3887        let fun = ScalarUDF::new_from_impl(TestScalarUDF {
3888            signature: Signature::exact(vec![], Volatility::Volatile),
3889        });
3890        let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
3891
3892        let plan = LogicalPlanBuilder::from(table_scan)
3893            .aggregate(vec![col("a")], vec![sum(col("b"))])?
3894            .project(vec![col("a"), sum(col("b")), add(expr, lit(1)).alias("r")])?
3895            .alias("t")?
3896            .filter(col("t.a").gt(lit(5)).and(col("t.r").gt(lit(0.5))))?
3897            .project(vec![col("t.a"), col("t.r")])?
3898            .build()?;
3899
3900        assert_snapshot!(plan,
3901        @r"
3902        Projection: t.a, t.r
3903          Filter: t.a > Int32(5) AND t.r > Float64(0.5)
3904            SubqueryAlias: t
3905              Projection: test1.a, sum(test1.b), TestScalarUDF() + Int32(1) AS r
3906                Aggregate: groupBy=[[test1.a]], aggr=[[sum(test1.b)]]
3907                  TableScan: test1
3908        ",
3909        );
3910        assert_optimized_plan_equal!(
3911            plan,
3912            @r"
3913        Projection: t.a, t.r
3914          SubqueryAlias: t
3915            Filter: r > Float64(0.5)
3916              Projection: test1.a, sum(test1.b), TestScalarUDF() + Int32(1) AS r
3917                Aggregate: groupBy=[[test1.a]], aggr=[[sum(test1.b)]]
3918                  TableScan: test1, full_filters=[test1.a > Int32(5)]
3919        "
3920        )
3921    }
3922
3923    #[test]
3924    fn test_push_down_volatile_function_in_join() -> Result<()> {
3925        // SELECT t.a, t.r FROM (SELECT test1.a AS a, TestScalarUDF() AS r FROM test1 join test2 ON test1.a = test2.a) AS t WHERE t.r > 0.5;
3926        let table_scan = test_table_scan_with_name("test1")?;
3927        let fun = ScalarUDF::new_from_impl(TestScalarUDF {
3928            signature: Signature::exact(vec![], Volatility::Volatile),
3929        });
3930        let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
3931        let left = LogicalPlanBuilder::from(table_scan).build()?;
3932        let right_table_scan = test_table_scan_with_name("test2")?;
3933        let right = LogicalPlanBuilder::from(right_table_scan).build()?;
3934        let plan = LogicalPlanBuilder::from(left)
3935            .join(
3936                right,
3937                JoinType::Inner,
3938                (
3939                    vec![Column::from_qualified_name("test1.a")],
3940                    vec![Column::from_qualified_name("test2.a")],
3941                ),
3942                None,
3943            )?
3944            .project(vec![col("test1.a").alias("a"), expr.alias("r")])?
3945            .alias("t")?
3946            .filter(col("t.r").gt(lit(0.8)))?
3947            .project(vec![col("t.a"), col("t.r")])?
3948            .build()?;
3949
3950        assert_snapshot!(plan,
3951        @r"
3952        Projection: t.a, t.r
3953          Filter: t.r > Float64(0.8)
3954            SubqueryAlias: t
3955              Projection: test1.a AS a, TestScalarUDF() AS r
3956                Inner Join: test1.a = test2.a
3957                  TableScan: test1
3958                  TableScan: test2
3959        ",
3960        );
3961        assert_optimized_plan_equal!(
3962            plan,
3963            @r"
3964        Projection: t.a, t.r
3965          SubqueryAlias: t
3966            Filter: r > Float64(0.8)
3967              Projection: test1.a AS a, TestScalarUDF() AS r
3968                Inner Join: test1.a = test2.a
3969                  TableScan: test1
3970                  TableScan: test2
3971        "
3972        )
3973    }
3974
3975    #[test]
3976    fn test_push_down_volatile_table_scan() -> Result<()> {
3977        // SELECT test.a, test.b FROM test as t WHERE TestScalarUDF() > 0.1;
3978        let table_scan = test_table_scan()?;
3979        let fun = ScalarUDF::new_from_impl(TestScalarUDF {
3980            signature: Signature::exact(vec![], Volatility::Volatile),
3981        });
3982        let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
3983        let plan = LogicalPlanBuilder::from(table_scan)
3984            .project(vec![col("a"), col("b")])?
3985            .filter(expr.gt(lit(0.1)))?
3986            .build()?;
3987
3988        assert_snapshot!(plan,
3989        @r"
3990        Filter: TestScalarUDF() > Float64(0.1)
3991          Projection: test.a, test.b
3992            TableScan: test
3993        ",
3994        );
3995        assert_optimized_plan_equal!(
3996            plan,
3997            @r"
3998        Projection: test.a, test.b
3999          Filter: TestScalarUDF() > Float64(0.1)
4000            TableScan: test
4001        "
4002        )
4003    }
4004
4005    #[test]
4006    fn test_push_down_volatile_mixed_table_scan() -> Result<()> {
4007        // SELECT test.a, test.b FROM test as t WHERE TestScalarUDF() > 0.1 and test.a > 5 and test.b > 10;
4008        let table_scan = test_table_scan()?;
4009        let fun = ScalarUDF::new_from_impl(TestScalarUDF {
4010            signature: Signature::exact(vec![], Volatility::Volatile),
4011        });
4012        let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
4013        let plan = LogicalPlanBuilder::from(table_scan)
4014            .project(vec![col("a"), col("b")])?
4015            .filter(
4016                expr.gt(lit(0.1))
4017                    .and(col("t.a").gt(lit(5)))
4018                    .and(col("t.b").gt(lit(10))),
4019            )?
4020            .build()?;
4021
4022        assert_snapshot!(plan,
4023        @r"
4024        Filter: TestScalarUDF() > Float64(0.1) AND t.a > Int32(5) AND t.b > Int32(10)
4025          Projection: test.a, test.b
4026            TableScan: test
4027        ",
4028        );
4029        assert_optimized_plan_equal!(
4030            plan,
4031            @r"
4032        Projection: test.a, test.b
4033          Filter: TestScalarUDF() > Float64(0.1)
4034            TableScan: test, full_filters=[t.a > Int32(5), t.b > Int32(10)]
4035        "
4036        )
4037    }
4038
4039    #[test]
4040    fn test_push_down_volatile_mixed_unsupported_table_scan() -> Result<()> {
4041        // SELECT test.a, test.b FROM test as t WHERE TestScalarUDF() > 0.1 and test.a > 5 and test.b > 10;
4042        let fun = ScalarUDF::new_from_impl(TestScalarUDF {
4043            signature: Signature::exact(vec![], Volatility::Volatile),
4044        });
4045        let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
4046        let plan = table_scan_with_pushdown_provider_builder(
4047            TableProviderFilterPushDown::Unsupported,
4048            vec![],
4049            None,
4050        )?
4051        .project(vec![col("a"), col("b")])?
4052        .filter(
4053            expr.gt(lit(0.1))
4054                .and(col("t.a").gt(lit(5)))
4055                .and(col("t.b").gt(lit(10))),
4056        )?
4057        .build()?;
4058
4059        assert_snapshot!(plan,
4060        @r"
4061        Filter: TestScalarUDF() > Float64(0.1) AND t.a > Int32(5) AND t.b > Int32(10)
4062          Projection: a, b
4063            TableScan: test
4064        ",
4065        );
4066        assert_optimized_plan_equal!(
4067            plan,
4068            @r"
4069        Projection: a, b
4070          Filter: t.a > Int32(5) AND t.b > Int32(10) AND TestScalarUDF() > Float64(0.1)
4071            TableScan: test
4072        "
4073        )
4074    }
4075
4076    #[test]
4077    fn test_push_down_filter_to_user_defined_node() -> Result<()> {
4078        // Define a custom user-defined logical node
4079        #[derive(Debug, Hash, Eq, PartialEq)]
4080        struct TestUserNode {
4081            schema: DFSchemaRef,
4082        }
4083
4084        impl PartialOrd for TestUserNode {
4085            fn partial_cmp(&self, _other: &Self) -> Option<Ordering> {
4086                None
4087            }
4088        }
4089
4090        impl TestUserNode {
4091            fn new() -> Self {
4092                let schema = Arc::new(
4093                    DFSchema::new_with_metadata(
4094                        vec![(None, Field::new("a", DataType::Int64, false).into())],
4095                        Default::default(),
4096                    )
4097                    .unwrap(),
4098                );
4099
4100                Self { schema }
4101            }
4102        }
4103
4104        impl UserDefinedLogicalNodeCore for TestUserNode {
4105            fn name(&self) -> &str {
4106                "test_node"
4107            }
4108
4109            fn inputs(&self) -> Vec<&LogicalPlan> {
4110                vec![]
4111            }
4112
4113            fn schema(&self) -> &DFSchemaRef {
4114                &self.schema
4115            }
4116
4117            fn expressions(&self) -> Vec<Expr> {
4118                vec![]
4119            }
4120
4121            fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
4122                write!(f, "TestUserNode")
4123            }
4124
4125            fn with_exprs_and_inputs(
4126                &self,
4127                exprs: Vec<Expr>,
4128                inputs: Vec<LogicalPlan>,
4129            ) -> Result<Self> {
4130                assert!(exprs.is_empty());
4131                assert!(inputs.is_empty());
4132                Ok(Self {
4133                    schema: Arc::clone(&self.schema),
4134                })
4135            }
4136        }
4137
4138        // Create a node and build a plan with a filter
4139        let node = LogicalPlan::Extension(Extension {
4140            node: Arc::new(TestUserNode::new()),
4141        });
4142
4143        let plan = LogicalPlanBuilder::from(node).filter(lit(false))?.build()?;
4144
4145        // Check the original plan format (not part of the test assertions)
4146        assert_snapshot!(plan,
4147        @r"
4148        Filter: Boolean(false)
4149          TestUserNode
4150        ",
4151        );
4152        // Check that the filter is pushed down to the user-defined node
4153        assert_optimized_plan_equal!(
4154            plan,
4155            @r"
4156        Filter: Boolean(false)
4157          TestUserNode
4158        "
4159        )
4160    }
4161}