Skip to main content

datafusion_optimizer/
push_down_filter.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`PushDownFilter`] applies filters as early as possible
19
20use std::collections::{HashMap, HashSet};
21use std::sync::Arc;
22
23use arrow::datatypes::DataType;
24use indexmap::IndexSet;
25use itertools::Itertools;
26use log::{Level, debug, log_enabled};
27
28use datafusion_common::instant::Instant;
29use datafusion_common::tree_node::{
30    Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
31};
32use datafusion_common::{
33    Column, DFSchema, Result, assert_eq_or_internal_err, assert_or_internal_err,
34    internal_err, plan_err, qualified_name,
35};
36use datafusion_expr::expr::WindowFunction;
37use datafusion_expr::expr_rewriter::replace_col;
38use datafusion_expr::logical_plan::{Join, JoinType, LogicalPlan, TableScan, Union};
39use datafusion_expr::utils::{
40    conjunction, expr_to_columns, split_conjunction, split_conjunction_owned,
41};
42use datafusion_expr::{
43    BinaryExpr, Expr, Filter, Operator, Projection, TableProviderFilterPushDown, and, or,
44};
45
46use crate::optimizer::ApplyOrder;
47use crate::simplify_expressions::simplify_predicates;
48use crate::utils::{
49    ColumnReference, has_all_column_refs, is_restrict_null_predicate, schema_columns,
50};
51use crate::{OptimizerConfig, OptimizerRule};
52use datafusion_expr::ExpressionPlacement;
53
/// Optimizer rule for pushing (moving) filter expressions down in a plan so
/// they are applied as early as possible.
///
/// # Introduction
///
/// The goal of this rule is to improve query performance by eliminating
/// redundant work.
///
/// For example, given a plan that sorts all values where `a > 10`:
///
/// ```text
///  Filter (a > 10)
///    Sort (a, b)
/// ```
///
/// A better plan is to filter the data *before* the Sort, which sorts fewer
/// rows and therefore does less work overall:
///
/// ```text
///  Sort (a, b)
///    Filter (a > 10)  <-- Filter is moved before the sort
/// ```
///
/// However, it is not always possible to push filters down. For example, given
/// a plan that finds the top 3 values and then keeps only those that are
/// greater than 10, pushing the filter below the limit would produce a
/// different result.
///
/// ```text
///  Filter (a > 10)   <-- cannot move this Filter before the limit
///    Limit (fetch=3)
///      Sort (a, b)
/// ```
///
/// More formally, a filter can only be moved below an operation `op` that is
/// *filter-commutative*, i.e. one that satisfies
/// `filter(op(data)) = op(filter(data))`.
///
/// The filter-commutative property is plan- and column-specific. A filter on
/// `a` can be pushed through an `Aggregate(group_by = [a], agg = [sum(b)])`.
/// However, a filter on `sum(b)` cannot be pushed through the same aggregate.
///
/// # Handling Conjunctions
///
/// It is possible to push down only **part** of a filter expression if it is
/// connected with `AND`s (more formally, if it is a "conjunction").
///
/// For example, given the following plan:
///
/// ```text
/// Filter(a > 10 AND sum(b) < 5)
///   Aggregate(group_by = [a], agg = [sum(b)])
/// ```
///
/// The `a > 10` is commutative with the `Aggregate` but `sum(b) < 5` is not.
/// Therefore it is possible to push down only part of the expression,
/// resulting in:
///
/// ```text
/// Filter(sum(b) < 5)
///   Aggregate(group_by = [a], agg = [sum(b)])
///     Filter(a > 10)
/// ```
///
/// # Handling Column Aliases
///
/// This optimizer must sometimes rewrite filter expressions when they are
/// pushed. For example, consider a projection that aliases `a+1` to `"b"`:
///
/// ```text
/// Filter (b > 10)
///     Projection: [a+1 AS "b"]  <-- changes the name of `a+1` to `b`
/// ```
///
/// To push this filter below the `Projection`, all references to `b` must be
/// rewritten to `a+1`:
///
/// ```text
/// Projection: [a+1 AS "b"]
///     Filter: (a+1 > 10)  <--- changed from b to a+1
/// ```
///
/// # Implementation Notes
///
/// This implementation performs a single pass through the plan, "pushing" down
/// filters. When it passes through a filter, it stores that filter, and when it
/// reaches a plan node that does not commute with that filter, it re-inserts
/// the filter at that point. When it passes through a projection, it rewrites
/// the filter's expression taking that projection into account.
#[derive(Default, Debug)]
pub struct PushDownFilter {}
143
144/// For a given JOIN type, determine whether each input of the join is preserved
145/// for post-join (`WHERE` clause) filters.
146///
147/// It is only correct to push filters below a join for preserved inputs.
148///
149/// # Return Value
150/// A tuple of booleans - (left_preserved, right_preserved).
151///
152/// # "Preserved" input definition
153///
154/// We say a join side is preserved if the join returns all or a subset of the rows from
155/// the relevant side, such that each row of the output table directly maps to a row of
156/// the preserved input table. If a table is not preserved, it can provide extra null rows.
157/// That is, there may be rows in the output table that don't directly map to a row in the
158/// input table.
159///
160/// For example:
161///   - In an inner join, both sides are preserved, because each row of the output
162///     maps directly to a row from each side.
163///
164///   - In a left join, the left side is preserved (we can push predicates) but
165///     the right is not, because there may be rows in the output that don't
166///     directly map to a row in the right input (due to nulls filling where there
167///     is no match on the right).
168pub(crate) fn lr_is_preserved(join_type: JoinType) -> (bool, bool) {
169    match join_type {
170        JoinType::Inner => (true, true),
171        JoinType::Left => (true, false),
172        JoinType::Right => (false, true),
173        JoinType::Full => (false, false),
174        // No columns from the right side of the join can be referenced in output
175        // predicates for semi/anti joins, so whether we specify t/f doesn't matter.
176        JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => (true, false),
177        // No columns from the left side of the join can be referenced in output
178        // predicates for semi/anti joins, so whether we specify t/f doesn't matter.
179        JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => (false, true),
180    }
181}
182
183/// See [`JoinType::on_lr_is_preserved`] for details.
184pub(crate) fn on_lr_is_preserved(join_type: JoinType) -> (bool, bool) {
185    join_type.on_lr_is_preserved()
186}
187
188/// Evaluates the columns referenced in the given expression to see if they refer
189/// only to the left or right columns
190#[derive(Debug)]
191struct ColumnChecker<'a> {
192    /// schema of left join input
193    left_schema: &'a DFSchema,
194    /// columns in left_schema, computed on demand
195    left_columns: Option<HashSet<ColumnReference<'a>>>,
196    /// schema of right join input
197    right_schema: &'a DFSchema,
198    /// columns in left_schema, computed on demand
199    right_columns: Option<HashSet<ColumnReference<'a>>>,
200}
201
202impl<'a> ColumnChecker<'a> {
203    fn new(left_schema: &'a DFSchema, right_schema: &'a DFSchema) -> Self {
204        Self {
205            left_schema,
206            left_columns: None,
207            right_schema,
208            right_columns: None,
209        }
210    }
211
212    /// Return true if the expression references only columns from the left side of the join
213    fn is_left_only(&mut self, predicate: &Expr) -> bool {
214        if self.left_columns.is_none() {
215            self.left_columns = Some(schema_columns(self.left_schema));
216        }
217        has_all_column_refs(predicate, self.left_columns.as_ref().unwrap())
218    }
219
220    /// Return true if the expression references only columns from the right side of the join
221    fn is_right_only(&mut self, predicate: &Expr) -> bool {
222        if self.right_columns.is_none() {
223            self.right_columns = Some(schema_columns(self.right_schema));
224        }
225        has_all_column_refs(predicate, self.right_columns.as_ref().unwrap())
226    }
227}
228
/// Determine whether the predicate can be evaluated as a join condition
/// (i.e. placed in the join's filter).
///
/// The expression tree is walked top-down. The answer is `Ok(false)` as soon as
/// a node that cannot appear in a join condition is found: a subquery
/// (`EXISTS`, `IN (subquery)`, set comparison, scalar subquery), an outer
/// reference column, or `UNNEST`. Otherwise the answer is `Ok(true)`.
///
/// # Errors
///
/// Returns an internal error if the walk reaches an aggregate function, window
/// function, wildcard or grouping set before it finds a non-evaluable node.
/// These are not expected to appear in a filter predicate.
fn can_evaluate_as_join_condition(predicate: &Expr) -> Result<bool> {
    // Flipped to false (and the walk stopped) on the first non-evaluable node.
    let mut is_evaluate = true;
    predicate.apply(|expr| match expr {
        // Leaf expressions: always evaluable, nothing beneath them to visit.
        Expr::Column(_)
        | Expr::Literal(_, _)
        | Expr::Placeholder(_)
        | Expr::ScalarVariable(_, _) => Ok(TreeNodeRecursion::Jump),
        // Expressions that cannot be part of a join condition.
        Expr::Exists { .. }
        | Expr::InSubquery(_)
        | Expr::SetComparison(_)
        | Expr::ScalarSubquery(_)
        | Expr::OuterReferenceColumn(_, _)
        | Expr::Unnest(_) => {
            is_evaluate = false;
            Ok(TreeNodeRecursion::Stop)
        }
        // Evaluable themselves; their children still need to be checked.
        Expr::Alias(_)
        | Expr::BinaryExpr(_)
        | Expr::Like(_)
        | Expr::SimilarTo(_)
        | Expr::Not(_)
        | Expr::IsNotNull(_)
        | Expr::IsNull(_)
        | Expr::IsTrue(_)
        | Expr::IsFalse(_)
        | Expr::IsUnknown(_)
        | Expr::IsNotTrue(_)
        | Expr::IsNotFalse(_)
        | Expr::IsNotUnknown(_)
        | Expr::Negative(_)
        | Expr::Between(_)
        | Expr::Case(_)
        | Expr::Cast(_)
        | Expr::TryCast(_)
        | Expr::InList { .. }
        | Expr::ScalarFunction(_)
        | Expr::HigherOrderFunction(_)
        | Expr::Lambda(_)
        | Expr::LambdaVariable(_) => Ok(TreeNodeRecursion::Continue),
        // TODO: remove the next line after `Expr::Wildcard` is removed
        #[expect(deprecated)]
        Expr::AggregateFunction(_)
        | Expr::WindowFunction(_)
        | Expr::Wildcard { .. }
        | Expr::GroupingSet(_) => internal_err!("Unsupported predicate type"),
    })?;
    Ok(is_evaluate)
}
278
279/// examine OR clause to see if any useful clauses can be extracted and push down.
280/// extract at least one qual from each sub clauses of OR clause, then form the quals
281/// to new OR clause as predicate.
282///
283/// # Example
284/// ```text
285/// Filter: (a = c and a < 20) or (b = d and b > 10)
286///     join/crossjoin:
287///          TableScan: projection=[a, b]
288///          TableScan: projection=[c, d]
289/// ```
290///
291/// is optimized to
292///
293/// ```text
294/// Filter: (a = c and a < 20) or (b = d and b > 10)
295///     join/crossjoin:
296///          Filter: (a < 20) or (b > 10)
297///              TableScan: projection=[a, b]
298///          TableScan: projection=[c, d]
299/// ```
300///
301/// In general, predicates of this form:
302///
303/// ```sql
304/// (A AND B) OR (C AND D)
305/// ```
306///
307/// will be transformed to one of:
308///
309/// * `((A AND B) OR (C AND D)) AND (A OR C)`
310/// * `((A AND B) OR (C AND D)) AND ((A AND B) OR C)`
311/// * do nothing.
312fn extract_or_clauses_for_join<'a>(
313    filters: &'a [Expr],
314    schema_cols: &'a HashSet<ColumnReference>,
315) -> impl Iterator<Item = Expr> + 'a {
316    // new formed OR clauses and their column references
317    filters.iter().filter_map(move |expr| {
318        if let Expr::BinaryExpr(BinaryExpr {
319            left,
320            op: Operator::Or,
321            right,
322        }) = expr
323        {
324            let left_expr = extract_or_clause(left.as_ref(), schema_cols);
325            let right_expr = extract_or_clause(right.as_ref(), schema_cols);
326
327            // If nothing can be extracted from any sub clauses, do nothing for this OR clause.
328            if let (Some(left_expr), Some(right_expr)) = (left_expr, right_expr) {
329                return Some(or(left_expr, right_expr));
330            }
331        }
332        None
333    })
334}
335
336/// extract qual from OR sub-clause.
337///
338/// A qual is extracted if it only contains set of column references in schema_columns.
339///
340/// For AND clause, we extract from both sub-clauses, then make new AND clause by extracted
341/// clauses if both extracted; Otherwise, use the extracted clause from any sub-clauses or None.
342///
343/// For OR clause, we extract from both sub-clauses, then make new OR clause by extracted clauses if both extracted;
344/// Otherwise, return None.
345///
346/// For other clause, apply the rule above to extract clause.
347fn extract_or_clause(
348    expr: &Expr,
349    schema_columns: &HashSet<ColumnReference>,
350) -> Option<Expr> {
351    let mut predicate = None;
352
353    match expr {
354        Expr::BinaryExpr(BinaryExpr {
355            left: l_expr,
356            op: Operator::Or,
357            right: r_expr,
358        }) => {
359            let l_expr = extract_or_clause(l_expr, schema_columns);
360            let r_expr = extract_or_clause(r_expr, schema_columns);
361
362            if let (Some(l_expr), Some(r_expr)) = (l_expr, r_expr) {
363                predicate = Some(or(l_expr, r_expr));
364            }
365        }
366        Expr::BinaryExpr(BinaryExpr {
367            left: l_expr,
368            op: Operator::And,
369            right: r_expr,
370        }) => {
371            let l_expr = extract_or_clause(l_expr, schema_columns);
372            let r_expr = extract_or_clause(r_expr, schema_columns);
373
374            match (l_expr, r_expr) {
375                (Some(l_expr), Some(r_expr)) => {
376                    predicate = Some(and(l_expr, r_expr));
377                }
378                (Some(l_expr), None) => {
379                    predicate = Some(l_expr);
380                }
381                (None, Some(r_expr)) => {
382                    predicate = Some(r_expr);
383                }
384                (None, None) => {
385                    predicate = None;
386                }
387            }
388        }
389        _ => {
390            if has_all_column_refs(expr, schema_columns) {
391                predicate = Some(expr.clone());
392            }
393        }
394    }
395
396    predicate
397}
398
399/// push down join/cross-join
400fn push_down_all_join(
401    predicates: Vec<Expr>,
402    inferred_join_predicates: Vec<Expr>,
403    mut join: Join,
404    on_filter: Vec<Expr>,
405) -> Result<Transformed<LogicalPlan>> {
406    let is_inner_join = join.join_type == JoinType::Inner;
407    // Get pushable predicates from current optimizer state
408    let (left_preserved, right_preserved) = lr_is_preserved(join.join_type);
409
410    // The predicates can be divided to three categories:
411    // 1) can push through join to its children(left or right)
412    // 2) can be converted to join conditions if the join type is Inner
413    // 3) should be kept as filter conditions
414    let left_schema = join.left.schema();
415    let right_schema = join.right.schema();
416
417    let left_schema_columns = schema_columns(left_schema.as_ref());
418    let right_schema_columns = schema_columns(right_schema.as_ref());
419
420    let mut left_push = vec![];
421    let mut right_push = vec![];
422    let mut keep_predicates = vec![];
423    let mut join_conditions = vec![];
424    let mut checker = ColumnChecker::new(left_schema, right_schema);
425    for predicate in predicates {
426        if left_preserved && checker.is_left_only(&predicate) {
427            left_push.push(predicate);
428        } else if right_preserved && checker.is_right_only(&predicate) {
429            right_push.push(predicate);
430        } else if is_inner_join && can_evaluate_as_join_condition(&predicate)? {
431            // Here we do not differ it is eq or non-eq predicate, ExtractEquijoinPredicate will extract the eq predicate
432            // and convert to the join on condition
433            join_conditions.push(predicate);
434        } else {
435            keep_predicates.push(predicate);
436        }
437    }
438
439    // Push predicates inferred from the join expression
440    for predicate in inferred_join_predicates {
441        if checker.is_left_only(&predicate) {
442            left_push.push(predicate);
443        } else if checker.is_right_only(&predicate) {
444            right_push.push(predicate);
445        }
446    }
447
448    let mut on_filter_join_conditions = vec![];
449    let (on_left_preserved, on_right_preserved) = on_lr_is_preserved(join.join_type);
450
451    if !on_filter.is_empty() {
452        for on in on_filter {
453            if on_left_preserved && checker.is_left_only(&on) {
454                left_push.push(on)
455            } else if on_right_preserved && checker.is_right_only(&on) {
456                right_push.push(on)
457            } else {
458                on_filter_join_conditions.push(on)
459            }
460        }
461    }
462
463    // Extract from OR clause, generate new predicates for both side of join if possible.
464    // We only track the unpushable predicates above.
465    if left_preserved {
466        left_push.extend(extract_or_clauses_for_join(
467            &keep_predicates,
468            &left_schema_columns,
469        ));
470        left_push.extend(extract_or_clauses_for_join(
471            &join_conditions,
472            &left_schema_columns,
473        ));
474    }
475    if right_preserved {
476        right_push.extend(extract_or_clauses_for_join(
477            &keep_predicates,
478            &right_schema_columns,
479        ));
480        right_push.extend(extract_or_clauses_for_join(
481            &join_conditions,
482            &right_schema_columns,
483        ));
484    }
485
486    // For predicates from join filter, we should check with if a join side is preserved
487    // in term of join filtering.
488    if on_left_preserved {
489        left_push.extend(extract_or_clauses_for_join(
490            &on_filter_join_conditions,
491            &left_schema_columns,
492        ));
493    }
494    if on_right_preserved {
495        right_push.extend(extract_or_clauses_for_join(
496            &on_filter_join_conditions,
497            &right_schema_columns,
498        ));
499    }
500
501    if let Some(predicate) = conjunction(left_push) {
502        join.left = Arc::new(LogicalPlan::Filter(Filter::try_new(predicate, join.left)?));
503    }
504    if let Some(predicate) = conjunction(right_push) {
505        join.right =
506            Arc::new(LogicalPlan::Filter(Filter::try_new(predicate, join.right)?));
507    }
508
509    // Add any new join conditions as the non join predicates
510    join_conditions.extend(on_filter_join_conditions);
511    join.filter = conjunction(join_conditions);
512
513    // wrap the join on the filter whose predicates must be kept, if any
514    let plan = LogicalPlan::Join(join);
515    let plan = if let Some(predicate) = conjunction(keep_predicates) {
516        LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(plan))?)
517    } else {
518        plan
519    };
520    Ok(Transformed::yes(plan))
521}
522
523fn push_down_join(
524    join: Join,
525    parent_predicate: Option<&Expr>,
526) -> Result<Transformed<LogicalPlan>> {
527    // Split the parent predicate into individual conjunctive parts.
528    let predicates = parent_predicate
529        .map_or_else(Vec::new, |pred| split_conjunction_owned(pred.clone()));
530
531    // Extract conjunctions from the JOIN's ON filter, if present.
532    let on_filters = join
533        .filter
534        .as_ref()
535        .map_or_else(Vec::new, |filter| split_conjunction_owned(filter.clone()));
536
537    // Are there any new join predicates that can be inferred from the filter expressions?
538    let inferred_join_predicates = with_debug_timing("infer_join_predicates", || {
539        infer_join_predicates(&join, &predicates, &on_filters)
540    })?;
541
542    if log_enabled!(Level::Debug) {
543        debug!(
544            "push_down_filter: join_type={:?}, parent_predicates={}, on_filters={}, inferred_join_predicates={}",
545            join.join_type,
546            predicates.len(),
547            on_filters.len(),
548            inferred_join_predicates.len()
549        );
550    }
551
552    if on_filters.is_empty()
553        && predicates.is_empty()
554        && inferred_join_predicates.is_empty()
555    {
556        return Ok(Transformed::no(LogicalPlan::Join(join)));
557    }
558
559    push_down_all_join(predicates, inferred_join_predicates, join, on_filters)
560}
561
562/// Extracts any equi-join join predicates from the given filter expressions.
563///
564/// Parameters
565/// * `join` the join in question
566///
567/// * `predicates` the pushed down filter expression
568///
569/// * `on_filters` filters from the join ON clause that have not already been
570///   identified as join predicates
571fn infer_join_predicates(
572    join: &Join,
573    predicates: &[Expr],
574    on_filters: &[Expr],
575) -> Result<Vec<Expr>> {
576    // Only allow both side key is column.
577    let join_col_keys = join
578        .on
579        .iter()
580        .filter_map(|(l, r)| {
581            let left_col = l.try_as_col()?;
582            let right_col = r.try_as_col()?;
583            Some((left_col, right_col))
584        })
585        .collect::<Vec<_>>();
586
587    let join_type = join.join_type;
588
589    let mut inferred_predicates = InferredPredicates::new(join_type);
590
591    infer_join_predicates_from_predicates(
592        &join_col_keys,
593        predicates,
594        &mut inferred_predicates,
595    )?;
596
597    infer_join_predicates_from_on_filters(
598        &join_col_keys,
599        join_type,
600        on_filters,
601        &mut inferred_predicates,
602    )?;
603
604    Ok(inferred_predicates.predicates)
605}
606
607/// Inferred predicates collector.
608/// When the JoinType is not Inner, we need to detect whether the inferred predicate can strictly
609/// filter out NULL, otherwise ignore it. e.g.
610/// ```text
611/// SELECT * FROM t1 LEFT JOIN t2 ON t1.c0 = t2.c0 WHERE t2.c0 IS NULL;
612/// ```
613/// We cannot infer the predicate `t1.c0 IS NULL`, otherwise the predicate will be pushed down to
614/// the left side, resulting in the wrong result.
615struct InferredPredicates {
616    predicates: Vec<Expr>,
617    is_inner_join: bool,
618}
619
620impl InferredPredicates {
621    fn new(join_type: JoinType) -> Self {
622        Self {
623            predicates: vec![],
624            is_inner_join: join_type == JoinType::Inner,
625        }
626    }
627
628    fn try_build_predicate(
629        &mut self,
630        predicate: Expr,
631        replace_map: &HashMap<&Column, &Column>,
632    ) -> Result<()> {
633        if self.is_inner_join
634            || matches!(
635                is_restrict_null_predicate(
636                    predicate.clone(),
637                    replace_map.keys().cloned()
638                ),
639                Ok(true)
640            )
641        {
642            self.predicates.push(replace_col(predicate, replace_map)?);
643        }
644
645        Ok(())
646    }
647}
648
649/// Infer predicates from the pushed down predicates.
650///
651/// Parameters
652/// * `join_col_keys` column pairs from the join ON clause
653///
654/// * `predicates` the pushed down predicates
655///
656/// * `inferred_predicates` the inferred results
657fn infer_join_predicates_from_predicates(
658    join_col_keys: &[(&Column, &Column)],
659    predicates: &[Expr],
660    inferred_predicates: &mut InferredPredicates,
661) -> Result<()> {
662    infer_join_predicates_impl::<true, true>(
663        join_col_keys,
664        predicates,
665        inferred_predicates,
666    )
667}
668
669/// Infer predicates from the join filter.
670///
671/// Parameters
672/// * `join_col_keys` column pairs from the join ON clause
673///
674/// * `join_type` the JoinType of Join
675///
676/// * `on_filters` filters from the join ON clause that have not already been
677///   identified as join predicates
678///
679/// * `inferred_predicates` the inferred results
680fn infer_join_predicates_from_on_filters(
681    join_col_keys: &[(&Column, &Column)],
682    join_type: JoinType,
683    on_filters: &[Expr],
684    inferred_predicates: &mut InferredPredicates,
685) -> Result<()> {
686    match join_type {
687        JoinType::Full | JoinType::LeftAnti | JoinType::RightAnti => Ok(()),
688        JoinType::Inner => infer_join_predicates_impl::<true, true>(
689            join_col_keys,
690            on_filters,
691            inferred_predicates,
692        ),
693        JoinType::Left | JoinType::LeftSemi | JoinType::LeftMark => {
694            infer_join_predicates_impl::<true, false>(
695                join_col_keys,
696                on_filters,
697                inferred_predicates,
698            )
699        }
700        JoinType::Right | JoinType::RightSemi | JoinType::RightMark => {
701            infer_join_predicates_impl::<false, true>(
702                join_col_keys,
703                on_filters,
704                inferred_predicates,
705            )
706        }
707    }
708}
709
710/// Infer predicates from the given predicates.
711///
712/// Parameters
713/// * `join_col_keys` column pairs from the join ON clause
714///
715/// * `input_predicates` the given predicates. It can be the pushed down predicates,
716///   or it can be the filters of the Join
717///
718/// * `inferred_predicates` the inferred results
719///
720/// * `ENABLE_LEFT_TO_RIGHT` indicates that the right table related predicate can
721///   be inferred from the left table related predicate
722///
723/// * `ENABLE_RIGHT_TO_LEFT` indicates that the left table related predicate can
724///   be inferred from the right table related predicate
725fn infer_join_predicates_impl<
726    const ENABLE_LEFT_TO_RIGHT: bool,
727    const ENABLE_RIGHT_TO_LEFT: bool,
728>(
729    join_col_keys: &[(&Column, &Column)],
730    input_predicates: &[Expr],
731    inferred_predicates: &mut InferredPredicates,
732) -> Result<()> {
733    for predicate in input_predicates {
734        let mut join_cols_to_replace = HashMap::new();
735
736        for &col in &predicate.column_refs() {
737            for (l, r) in join_col_keys.iter() {
738                if ENABLE_LEFT_TO_RIGHT && col == *l {
739                    join_cols_to_replace.insert(col, *r);
740                    break;
741                }
742                if ENABLE_RIGHT_TO_LEFT && col == *r {
743                    join_cols_to_replace.insert(col, *l);
744                    break;
745                }
746            }
747        }
748        if join_cols_to_replace.is_empty() {
749            continue;
750        }
751
752        inferred_predicates
753            .try_build_predicate(predicate.clone(), &join_cols_to_replace)?;
754    }
755    Ok(())
756}
757
758impl OptimizerRule for PushDownFilter {
759    fn name(&self) -> &str {
760        "push_down_filter"
761    }
762
763    fn apply_order(&self) -> Option<ApplyOrder> {
764        Some(ApplyOrder::TopDown)
765    }
766
767    fn supports_rewrite(&self) -> bool {
768        true
769    }
770
771    fn rewrite(
772        &self,
773        plan: LogicalPlan,
774        config: &dyn OptimizerConfig,
775    ) -> Result<Transformed<LogicalPlan>> {
776        let _ = config.options();
777        if let LogicalPlan::Join(join) = plan {
778            return push_down_join(join, None);
779        };
780
781        let plan_schema = Arc::clone(plan.schema());
782
783        let LogicalPlan::Filter(mut filter) = plan else {
784            return Ok(Transformed::no(plan));
785        };
786
787        let predicate = split_conjunction_owned(filter.predicate.clone());
788        let old_predicate_len = predicate.len();
789        let new_predicates =
790            with_debug_timing("simplify_predicates", || simplify_predicates(predicate))?;
791        if log_enabled!(Level::Debug) {
792            debug!(
793                "push_down_filter: simplify_predicates old_count={}, new_count={}",
794                old_predicate_len,
795                new_predicates.len()
796            );
797        }
798        if old_predicate_len != new_predicates.len() {
799            let Some(new_predicate) = conjunction(new_predicates) else {
800                // new_predicates is empty - remove the filter entirely
801                // Return the child plan without the filter
802                return Ok(Transformed::yes(Arc::unwrap_or_clone(filter.input)));
803            };
804            filter.predicate = new_predicate;
805        }
806
807        // If the child has a fetch (limit) or skip (offset), pushing a filter
808        // below it would change semantics: the limit/offset should apply before
809        // the filter, not after.
810        if filter.input.fetch()?.is_some() || filter.input.skip()?.is_some() {
811            return Ok(Transformed::no(LogicalPlan::Filter(filter)));
812        }
813
814        match Arc::unwrap_or_clone(filter.input) {
815            LogicalPlan::Filter(child_filter) => {
816                // child filters first to preserve execution order
817                let new_predicates = split_conjunction_owned(child_filter.predicate)
818                    .into_iter()
819                    .chain(split_conjunction_owned(filter.predicate))
820                    // use IndexSet to remove duplicates while preserving predicate order
821                    .collect::<IndexSet<_>>();
822
823                let Some(new_predicate) = conjunction(new_predicates) else {
824                    return plan_err!("at least one expression exists");
825                };
826
827                let new_filter = LogicalPlan::Filter(Filter::try_new(
828                    new_predicate,
829                    child_filter.input,
830                )?);
831
832                self.rewrite(new_filter, config)
833            }
834            LogicalPlan::Repartition(repartition) => {
835                let new_filter =
836                    Filter::try_new(filter.predicate, Arc::clone(&repartition.input))
837                        .map(LogicalPlan::Filter)?;
838                insert_below(LogicalPlan::Repartition(repartition), new_filter)
839            }
840            LogicalPlan::Distinct(distinct) => {
841                let new_filter =
842                    Filter::try_new(filter.predicate, Arc::clone(distinct.input()))
843                        .map(LogicalPlan::Filter)?;
844                insert_below(LogicalPlan::Distinct(distinct), new_filter)
845            }
846            LogicalPlan::Sort(sort) => {
847                let new_filter =
848                    Filter::try_new(filter.predicate, Arc::clone(&sort.input))
849                        .map(LogicalPlan::Filter)?;
850                insert_below(LogicalPlan::Sort(sort), new_filter)
851            }
852            LogicalPlan::SubqueryAlias(subquery_alias) => {
853                let mut replace_map = HashMap::new();
854                for (i, (qualifier, field)) in
855                    subquery_alias.input.schema().iter().enumerate()
856                {
857                    let (sub_qualifier, sub_field) =
858                        subquery_alias.schema.qualified_field(i);
859                    replace_map.insert(
860                        qualified_name(sub_qualifier, sub_field.name()),
861                        Expr::Column(Column::new(qualifier.cloned(), field.name())),
862                    );
863                }
864                let new_predicate = replace_cols_by_name(filter.predicate, &replace_map)?;
865
866                let new_filter = LogicalPlan::Filter(Filter::try_new(
867                    new_predicate,
868                    Arc::clone(&subquery_alias.input),
869                )?);
870                insert_below(LogicalPlan::SubqueryAlias(subquery_alias), new_filter)
871            }
872            LogicalPlan::Projection(projection) => {
873                let predicates = split_conjunction_owned(filter.predicate.clone());
874                let (new_projection, keep_predicate) =
875                    rewrite_projection(predicates, projection)?;
876                if new_projection.transformed {
877                    match keep_predicate {
878                        None => Ok(new_projection),
879                        Some(keep_predicate) => new_projection.map_data(|child_plan| {
880                            Filter::try_new(keep_predicate, Arc::new(child_plan))
881                                .map(LogicalPlan::Filter)
882                        }),
883                    }
884                } else {
885                    filter.input = Arc::new(new_projection.data);
886                    Ok(Transformed::no(LogicalPlan::Filter(filter)))
887                }
888            }
889            LogicalPlan::Unnest(mut unnest) => {
890                let predicates = split_conjunction_owned(filter.predicate.clone());
891                let mut non_unnest_predicates = vec![];
892                let mut unnest_predicates = vec![];
893                let mut unnest_struct_columns = vec![];
894
895                for idx in &unnest.struct_type_columns {
896                    let (sub_qualifier, field) =
897                        unnest.input.schema().qualified_field(*idx);
898                    let field_name = field.name().clone();
899
900                    if let DataType::Struct(children) = field.data_type() {
901                        for child in children {
902                            let child_name = child.name().clone();
903                            unnest_struct_columns.push(Column::new(
904                                sub_qualifier.cloned(),
905                                format!("{field_name}.{child_name}"),
906                            ));
907                        }
908                    }
909                }
910
911                for predicate in predicates {
912                    // collect all the Expr::Column in predicate recursively
913                    let mut accum: HashSet<Column> = HashSet::new();
914                    expr_to_columns(&predicate, &mut accum)?;
915
916                    let contains_list_columns =
917                        unnest.list_type_columns.iter().any(|(_, unnest_list)| {
918                            accum.contains(&unnest_list.output_column)
919                        });
920                    let contains_struct_columns =
921                        unnest_struct_columns.iter().any(|c| accum.contains(c));
922
923                    if contains_list_columns || contains_struct_columns {
924                        unnest_predicates.push(predicate);
925                    } else {
926                        non_unnest_predicates.push(predicate);
927                    }
928                }
929
930                // Unnest predicates should not be pushed down.
931                // If no non-unnest predicates exist, early return
932                if non_unnest_predicates.is_empty() {
933                    filter.input = Arc::new(LogicalPlan::Unnest(unnest));
934                    return Ok(Transformed::no(LogicalPlan::Filter(filter)));
935                }
936
937                // Push down non-unnest filter predicate
938                // Unnest
939                //   Unnest Input (Projection)
940                // -> rewritten to
941                // Unnest
942                //   Filter
943                //     Unnest Input (Projection)
944
945                let unnest_input = std::mem::take(&mut unnest.input);
946
947                let filter_with_unnest_input = LogicalPlan::Filter(Filter::try_new(
948                    conjunction(non_unnest_predicates).unwrap(), // Safe to unwrap since non_unnest_predicates is not empty.
949                    unnest_input,
950                )?);
951
952                // Directly assign new filter plan as the new unnest's input.
953                // The new filter plan will go through another rewrite pass since the rule itself
954                // is applied recursively to all the child from top to down
955                let unnest_plan =
956                    insert_below(LogicalPlan::Unnest(unnest), filter_with_unnest_input)?;
957
958                match conjunction(unnest_predicates) {
959                    None => Ok(unnest_plan),
960                    Some(predicate) => Ok(Transformed::yes(LogicalPlan::Filter(
961                        Filter::try_new(predicate, Arc::new(unnest_plan.data))?,
962                    ))),
963                }
964            }
965            LogicalPlan::Union(ref union) => {
966                let mut inputs = Vec::with_capacity(union.inputs.len());
967                for input in &union.inputs {
968                    let mut replace_map = HashMap::new();
969                    for (i, (qualifier, field)) in input.schema().iter().enumerate() {
970                        let (union_qualifier, union_field) =
971                            union.schema.qualified_field(i);
972                        replace_map.insert(
973                            qualified_name(union_qualifier, union_field.name()),
974                            Expr::Column(Column::new(qualifier.cloned(), field.name())),
975                        );
976                    }
977
978                    let push_predicate =
979                        replace_cols_by_name(filter.predicate.clone(), &replace_map)?;
980                    inputs.push(Arc::new(LogicalPlan::Filter(Filter::try_new(
981                        push_predicate,
982                        Arc::clone(input),
983                    )?)))
984                }
985                Ok(Transformed::yes(LogicalPlan::Union(Union {
986                    inputs,
987                    schema: Arc::clone(&plan_schema),
988                })))
989            }
990            LogicalPlan::Aggregate(agg) => {
991                // We can push down Predicate which in groupby_expr.
992                let group_expr_columns = agg
993                    .group_expr
994                    .iter()
995                    .map(|e| {
996                        let (relation, name) = e.qualified_name();
997                        Column::new(relation, name)
998                    })
999                    .collect::<HashSet<_>>();
1000
1001                let predicates = split_conjunction_owned(filter.predicate);
1002
1003                let mut keep_predicates = vec![];
1004                let mut push_predicates = vec![];
1005                for expr in predicates {
1006                    let cols = expr.column_refs();
1007                    if cols.iter().all(|c| group_expr_columns.contains(c)) {
1008                        push_predicates.push(expr);
1009                    } else {
1010                        keep_predicates.push(expr);
1011                    }
1012                }
1013
1014                // As for plan Filter: Column(a+b) > 0 -- Agg: groupby:[Column(a)+Column(b)]
1015                // After push, we need to replace `a+b` with Column(a)+Column(b)
1016                // So we need create a replace_map, add {`a+b` --> Expr(Column(a)+Column(b))}
1017                let mut replace_map = HashMap::new();
1018                for expr in &agg.group_expr {
1019                    replace_map.insert(expr.schema_name().to_string(), expr.clone());
1020                }
1021                let replaced_push_predicates = push_predicates
1022                    .into_iter()
1023                    .map(|expr| replace_cols_by_name(expr, &replace_map))
1024                    .collect::<Result<Vec<_>>>()?;
1025
1026                let agg_input = Arc::clone(&agg.input);
1027                Transformed::yes(LogicalPlan::Aggregate(agg))
1028                    .transform_data(|new_plan| {
1029                        // If we have a filter to push, we push it down to the input of the aggregate
1030                        if let Some(predicate) = conjunction(replaced_push_predicates) {
1031                            let new_filter = make_filter(predicate, agg_input)?;
1032                            insert_below(new_plan, new_filter)
1033                        } else {
1034                            Ok(Transformed::no(new_plan))
1035                        }
1036                    })?
1037                    .map_data(|child_plan| {
1038                        // if there are any remaining predicates we can't push, add them
1039                        // back as a filter
1040                        if let Some(predicate) = conjunction(keep_predicates) {
1041                            make_filter(predicate, Arc::new(child_plan))
1042                        } else {
1043                            Ok(child_plan)
1044                        }
1045                    })
1046            }
1047            // Tries to push filters based on the partition key(s) of the window function(s) used.
1048            // Example:
1049            //   Before:
1050            //     Filter: (a > 1) and (b > 1) and (c > 1)
1051            //      Window: func() PARTITION BY [a] ...
1052            //   ---
1053            //   After:
1054            //     Filter: (b > 1) and (c > 1)
1055            //      Window: func() PARTITION BY [a] ...
1056            //        Filter: (a > 1)
1057            LogicalPlan::Window(window) => {
1058                // Retrieve the set of potential partition keys where we can push filters by.
1059                // Unlike aggregations, where there is only one statement per SELECT, there can be
1060                // multiple window functions, each with potentially different partition keys.
1061                // Therefore, we need to ensure that any potential partition key returned is used in
1062                // ALL window functions. Otherwise, filters cannot be pushed by through that column.
1063                let extract_partition_keys = |func: &WindowFunction| {
1064                    func.params
1065                        .partition_by
1066                        .iter()
1067                        .map(|c| {
1068                            let (relation, name) = c.qualified_name();
1069                            Column::new(relation, name)
1070                        })
1071                        .collect::<HashSet<_>>()
1072                };
1073                let potential_partition_keys = window
1074                    .window_expr
1075                    .iter()
1076                    .map(|e| {
1077                        match e {
1078                            Expr::WindowFunction(window_func) => {
1079                                extract_partition_keys(window_func)
1080                            }
1081                            Expr::Alias(alias) => {
1082                                if let Expr::WindowFunction(window_func) =
1083                                    alias.expr.as_ref()
1084                                {
1085                                    extract_partition_keys(window_func)
1086                                } else {
1087                                    // window functions expressions are only Expr::WindowFunction
1088                                    unreachable!()
1089                                }
1090                            }
1091                            _ => {
1092                                // window functions expressions are only Expr::WindowFunction
1093                                unreachable!()
1094                            }
1095                        }
1096                    })
1097                    // performs the set intersection of the partition keys of all window functions,
1098                    // returning only the common ones
1099                    .reduce(|a, b| &a & &b)
1100                    .unwrap_or_default();
1101
1102                let predicates = split_conjunction_owned(filter.predicate);
1103                let mut keep_predicates = vec![];
1104                let mut push_predicates = vec![];
1105                for expr in predicates {
1106                    let cols = expr.column_refs();
1107                    if cols.iter().all(|c| potential_partition_keys.contains(c)) {
1108                        push_predicates.push(expr);
1109                    } else {
1110                        keep_predicates.push(expr);
1111                    }
1112                }
1113
1114                // Unlike with aggregations, there are no cases where we have to replace, e.g.,
1115                // `a+b` with Column(a)+Column(b). This is because partition expressions are not
1116                // available as standalone columns to the user. For example, while an aggregation on
1117                // `a+b` becomes Column(a + b), in a window partition it becomes
1118                // `func() PARTITION BY [a + b] ...`. Thus, filters on expressions always remain in
1119                // place, so we can use `push_predicates` directly. This is consistent with other
1120                // optimizers, such as the one used by Postgres.
1121
1122                let window_input = Arc::clone(&window.input);
1123                Transformed::yes(LogicalPlan::Window(window))
1124                    .transform_data(|new_plan| {
1125                        // If we have a filter to push, we push it down to the input of the window
1126                        if let Some(predicate) = conjunction(push_predicates) {
1127                            let new_filter = make_filter(predicate, window_input)?;
1128                            insert_below(new_plan, new_filter)
1129                        } else {
1130                            Ok(Transformed::no(new_plan))
1131                        }
1132                    })?
1133                    .map_data(|child_plan| {
1134                        // if there are any remaining predicates we can't push, add them
1135                        // back as a filter
1136                        if let Some(predicate) = conjunction(keep_predicates) {
1137                            make_filter(predicate, Arc::new(child_plan))
1138                        } else {
1139                            Ok(child_plan)
1140                        }
1141                    })
1142            }
1143            LogicalPlan::Join(join) => push_down_join(join, Some(&filter.predicate)),
1144            LogicalPlan::TableScan(scan) => {
1145                let filter_predicates = split_conjunction(&filter.predicate);
1146
1147                // Filters containing scalar subqueries cannot be pushed to
1148                // providers because the subquery result is not available
1149                // until execution time.
1150                let (subquery_filters, pushdown_candidates): (Vec<&Expr>, Vec<&Expr>) =
1151                    filter_predicates
1152                        .into_iter()
1153                        .partition(|pred| pred.contains_scalar_subquery());
1154
1155                let (volatile_filters, non_volatile_filters): (Vec<&Expr>, Vec<&Expr>) =
1156                    pushdown_candidates
1157                        .into_iter()
1158                        .partition(|pred| pred.is_volatile());
1159
1160                // Check which non-volatile filters are supported by source
1161                let supported_filters = scan
1162                    .source
1163                    .supports_filters_pushdown(non_volatile_filters.as_slice())?;
1164                assert_eq_or_internal_err!(
1165                    non_volatile_filters.len(),
1166                    supported_filters.len(),
1167                    "Vec returned length: {} from supports_filters_pushdown is not the same size as the filters passed, which length is: {}",
1168                    supported_filters.len(),
1169                    non_volatile_filters.len()
1170                );
1171
1172                // Compose scan filters from non-volatile filters of `Exact` or `Inexact` pushdown type
1173                let zip = non_volatile_filters.into_iter().zip(supported_filters);
1174
1175                let new_scan_filters = zip
1176                    .clone()
1177                    .filter(|(_, res)| res != &TableProviderFilterPushDown::Unsupported)
1178                    .map(|(pred, _)| pred);
1179
1180                // Add new scan filters
1181                let new_scan_filters: Vec<Expr> = scan
1182                    .filters
1183                    .iter()
1184                    .chain(new_scan_filters)
1185                    .unique()
1186                    .cloned()
1187                    .collect();
1188
1189                // Compose predicates to be of `Unsupported` or `Inexact` pushdown type,
1190                // and also include volatile and subquery-containing filters
1191                let new_predicate: Vec<Expr> = zip
1192                    .filter(|(_, res)| res != &TableProviderFilterPushDown::Exact)
1193                    .map(|(pred, _)| pred)
1194                    .chain(volatile_filters)
1195                    .chain(subquery_filters)
1196                    .cloned()
1197                    .collect();
1198
1199                let new_scan = LogicalPlan::TableScan(TableScan {
1200                    filters: new_scan_filters,
1201                    ..scan
1202                });
1203
1204                Transformed::yes(new_scan).transform_data(|new_scan| {
1205                    if let Some(predicate) = conjunction(new_predicate) {
1206                        make_filter(predicate, Arc::new(new_scan)).map(Transformed::yes)
1207                    } else {
1208                        Ok(Transformed::no(new_scan))
1209                    }
1210                })
1211            }
1212            LogicalPlan::Extension(extension_plan) => {
1213                // This check prevents the Filter from being removed when the extension node has no children,
1214                // so we return the original Filter unchanged.
1215                if extension_plan.node.inputs().is_empty() {
1216                    filter.input = Arc::new(LogicalPlan::Extension(extension_plan));
1217                    return Ok(Transformed::no(LogicalPlan::Filter(filter)));
1218                }
1219                let prevent_cols =
1220                    extension_plan.node.prevent_predicate_push_down_columns();
1221
1222                // determine if we can push any predicates down past the extension node
1223
1224                // each element is true for push, false to keep
1225                let predicate_push_or_keep = split_conjunction(&filter.predicate)
1226                    .iter()
1227                    .map(|expr| {
1228                        let cols = expr.column_refs();
1229                        if cols.iter().any(|c| prevent_cols.contains(&c.name)) {
1230                            Ok(false) // No push (keep)
1231                        } else {
1232                            Ok(true) // push
1233                        }
1234                    })
1235                    .collect::<Result<Vec<_>>>()?;
1236
1237                // all predicates are kept, no changes needed
1238                if predicate_push_or_keep.iter().all(|&x| !x) {
1239                    filter.input = Arc::new(LogicalPlan::Extension(extension_plan));
1240                    return Ok(Transformed::no(LogicalPlan::Filter(filter)));
1241                }
1242
1243                // going to push some predicates down, so split the predicates
1244                let mut keep_predicates = vec![];
1245                let mut push_predicates = vec![];
1246                for (push, expr) in predicate_push_or_keep
1247                    .into_iter()
1248                    .zip(split_conjunction_owned(filter.predicate))
1249                {
1250                    if !push {
1251                        keep_predicates.push(expr);
1252                    } else {
1253                        push_predicates.push(expr);
1254                    }
1255                }
1256
1257                let new_children = match conjunction(push_predicates) {
1258                    Some(predicate) => extension_plan
1259                        .node
1260                        .inputs()
1261                        .into_iter()
1262                        .map(|child| {
1263                            Ok(LogicalPlan::Filter(Filter::try_new(
1264                                predicate.clone(),
1265                                Arc::new(child.clone()),
1266                            )?))
1267                        })
1268                        .collect::<Result<Vec<_>>>()?,
1269                    None => extension_plan.node.inputs().into_iter().cloned().collect(),
1270                };
1271                // extension with new inputs.
1272                let child_plan = LogicalPlan::Extension(extension_plan);
1273                let new_extension =
1274                    child_plan.with_new_exprs(child_plan.expressions(), new_children)?;
1275
1276                let new_plan = match conjunction(keep_predicates) {
1277                    Some(predicate) => LogicalPlan::Filter(Filter::try_new(
1278                        predicate,
1279                        Arc::new(new_extension),
1280                    )?),
1281                    None => new_extension,
1282                };
1283                Ok(Transformed::yes(new_plan))
1284            }
1285            child => {
1286                filter.input = Arc::new(child);
1287                Ok(Transformed::no(LogicalPlan::Filter(filter)))
1288            }
1289        }
1290    }
1291}
1292
1293/// Attempts to push `predicate` into a `FilterExec` below `projection
1294///
1295/// # Returns
1296/// (plan, remaining_predicate)
1297///
1298/// `plan` is a LogicalPlan for `projection` with possibly a new FilterExec below it.
1299/// `remaining_predicate` is any part of the predicate that could not be pushed down
1300///
1301/// # Args
1302/// - predicates: Split predicates like `[foo=5, bar=6]`
1303/// - projection: The target projection plan to push down the predicates
1304///
1305/// # Example
1306///
1307/// Pushing a predicate like `foo=5 AND bar=6` with an input plan like this:
1308///
1309/// ```text
1310/// Projection(foo, c+d as bar)
1311/// ```
1312///
1313/// Might result in returning `remaining_predicate` of `bar=6` and a plan like
1314///
1315/// ```text
1316/// Projection(foo, c+d as bar)
1317///  Filter(foo=5)
1318///   ...
1319/// ```
1320fn rewrite_projection(
1321    predicates: Vec<Expr>,
1322    mut projection: Projection,
1323) -> Result<(Transformed<LogicalPlan>, Option<Expr>)> {
1324    // Partition projection expressions into non-pushable vs pushable.
1325    // Non-pushable expressions are volatile (must not be duplicated) or
1326    // MoveTowardsLeafNodes (cheap expressions like get_field where re-inlining
1327    // into a filter causes optimizer instability — ExtractLeafExpressions will
1328    // undo the push-down, creating an infinite loop that runs until the
1329    // iteration limit is hit).
1330    let (non_pushable_map, pushable_map): (HashMap<_, _>, HashMap<_, _>) = projection
1331        .schema
1332        .iter()
1333        .zip(projection.expr.iter())
1334        .map(|((qualifier, field), expr)| {
1335            // strip alias, as they should not be part of filters
1336            let expr = expr.clone().unalias();
1337
1338            (qualified_name(qualifier, field.name()), expr)
1339        })
1340        .partition(|(_, value)| {
1341            value.is_volatile()
1342                || value.placement() == ExpressionPlacement::MoveTowardsLeafNodes
1343        });
1344
1345    let mut push_predicates = vec![];
1346    let mut keep_predicates = vec![];
1347    for expr in predicates {
1348        if contain(&expr, &non_pushable_map) {
1349            keep_predicates.push(expr);
1350        } else {
1351            push_predicates.push(expr);
1352        }
1353    }
1354
1355    match conjunction(push_predicates) {
1356        Some(expr) => {
1357            // re-write all filters based on this projection
1358            // E.g. in `Filter: b\n  Projection: a > 1 as b`, we can swap them, but the filter must be "a > 1"
1359            let new_filter = LogicalPlan::Filter(Filter::try_new(
1360                replace_cols_by_name(expr, &pushable_map)?,
1361                std::mem::take(&mut projection.input),
1362            )?);
1363
1364            projection.input = Arc::new(new_filter);
1365
1366            Ok((
1367                Transformed::yes(LogicalPlan::Projection(projection)),
1368                conjunction(keep_predicates),
1369            ))
1370        }
1371        None => Ok((
1372            Transformed::no(LogicalPlan::Projection(projection)),
1373            conjunction(keep_predicates),
1374        )),
1375    }
1376}
1377
1378/// Creates a new LogicalPlan::Filter node.
1379pub fn make_filter(predicate: Expr, input: Arc<LogicalPlan>) -> Result<LogicalPlan> {
1380    Filter::try_new(predicate, input).map(LogicalPlan::Filter)
1381}
1382
1383/// Replace the existing child of the single input node with `new_child`.
1384///
1385/// Starting:
1386/// ```text
1387/// plan
1388///   child
1389/// ```
1390///
1391/// Ending:
1392/// ```text
1393/// plan
1394///   new_child
1395/// ```
1396fn insert_below(
1397    plan: LogicalPlan,
1398    new_child: LogicalPlan,
1399) -> Result<Transformed<LogicalPlan>> {
1400    let mut new_child = Some(new_child);
1401    let transformed_plan = plan.map_children(|_child| {
1402        if let Some(new_child) = new_child.take() {
1403            Ok(Transformed::yes(new_child))
1404        } else {
1405            // already took the new child
1406            internal_err!("node had more than one input")
1407        }
1408    })?;
1409
1410    // make sure we did the actual replacement
1411    assert_or_internal_err!(new_child.is_none(), "node had no inputs");
1412
1413    Ok(transformed_plan)
1414}
1415
1416impl PushDownFilter {
1417    #[expect(missing_docs)]
1418    pub fn new() -> Self {
1419        Self {}
1420    }
1421}
1422
1423fn with_debug_timing<T, F>(label: &'static str, f: F) -> Result<T>
1424where
1425    F: FnOnce() -> Result<T>,
1426{
1427    if !log_enabled!(Level::Debug) {
1428        return f();
1429    }
1430    let start = Instant::now();
1431    let result = f();
1432    debug!(
1433        "push_down_filter_timing: section={label}, elapsed_us={}",
1434        start.elapsed().as_micros()
1435    );
1436    result
1437}
1438
1439/// replaces columns by its name on the projection.
1440pub fn replace_cols_by_name(
1441    e: Expr,
1442    replace_map: &HashMap<String, Expr>,
1443) -> Result<Expr> {
1444    e.transform_up(|expr| {
1445        Ok(if let Expr::Column(c) = &expr {
1446            match replace_map.get(&c.flat_name()) {
1447                Some(new_c) => Transformed::yes(new_c.clone()),
1448                None => Transformed::no(expr),
1449            }
1450        } else {
1451            Transformed::no(expr)
1452        })
1453    })
1454    .data()
1455}
1456
1457/// check whether the expression uses the columns in `check_map`.
1458fn contain(e: &Expr, check_map: &HashMap<String, Expr>) -> bool {
1459    let mut is_contain = false;
1460    e.apply(|expr| {
1461        Ok(if let Expr::Column(c) = &expr {
1462            match check_map.get(&c.flat_name()) {
1463                Some(_) => {
1464                    is_contain = true;
1465                    TreeNodeRecursion::Stop
1466                }
1467                None => TreeNodeRecursion::Continue,
1468            }
1469        } else {
1470            TreeNodeRecursion::Continue
1471        })
1472    })
1473    .unwrap();
1474    is_contain
1475}
1476
1477#[cfg(test)]
1478mod tests {
1479    use std::cmp::Ordering;
1480    use std::fmt::{Debug, Formatter};
1481
1482    use arrow::datatypes::{Field, Schema, SchemaRef};
1483    use async_trait::async_trait;
1484
1485    use datafusion_common::{DFSchemaRef, DataFusionError, ScalarValue};
1486    use datafusion_expr::expr::ScalarFunction;
1487    use datafusion_expr::logical_plan::table_scan;
1488    use datafusion_expr::{
1489        ColumnarValue, ExprFunctionExt, Extension, LogicalPlanBuilder,
1490        ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, TableSource, TableType,
1491        UserDefinedLogicalNodeCore, Volatility, WindowFunctionDefinition, col, in_list,
1492        in_subquery, lit,
1493    };
1494
1495    use crate::OptimizerContext;
1496    use crate::assert_optimized_plan_eq_snapshot;
1497    use crate::optimizer::Optimizer;
1498    use crate::simplify_expressions::SimplifyExpressions;
1499    use crate::test::udfs::leaf_udf_expr;
1500    use crate::test::*;
1501    use datafusion_expr::test::function_stub::sum;
1502    use insta::assert_snapshot;
1503
1504    use super::*;
1505
1506    fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {}
1507
1508    macro_rules! assert_optimized_plan_equal {
1509        (
1510            $plan:expr,
1511            @ $expected:literal $(,)?
1512        ) => {{
1513            let optimizer_ctx = OptimizerContext::new().with_max_passes(1);
1514            let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> = vec![Arc::new(PushDownFilter::new())];
1515            assert_optimized_plan_eq_snapshot!(
1516                optimizer_ctx,
1517                rules,
1518                $plan,
1519                @ $expected,
1520            )
1521        }};
1522    }
1523
1524    macro_rules! assert_optimized_plan_eq_with_rewrite_predicate {
1525        (
1526            $plan:expr,
1527            @ $expected:literal $(,)?
1528        ) => {{
1529            let optimizer = Optimizer::with_rules(vec![
1530                Arc::new(SimplifyExpressions::new()),
1531                Arc::new(PushDownFilter::new()),
1532            ]);
1533            let optimized_plan = optimizer.optimize($plan, &OptimizerContext::new(), observe)?;
1534            assert_snapshot!(optimized_plan, @ $expected);
1535            Ok::<(), DataFusionError>(())
1536        }};
1537    }
1538
1539    #[test]
1540    fn filter_before_projection() -> Result<()> {
1541        let table_scan = test_table_scan()?;
1542        let plan = LogicalPlanBuilder::from(table_scan)
1543            .project(vec![col("a"), col("b")])?
1544            .filter(col("a").eq(lit(1i64)))?
1545            .build()?;
1546        // filter is before projection
1547        assert_optimized_plan_equal!(
1548            plan,
1549            @r"
1550        Projection: test.a, test.b
1551          TableScan: test, full_filters=[test.a = Int64(1)]
1552        "
1553        )
1554    }
1555
1556    #[test]
1557    fn filter_after_limit() -> Result<()> {
1558        let table_scan = test_table_scan()?;
1559        let plan = LogicalPlanBuilder::from(table_scan)
1560            .project(vec![col("a"), col("b")])?
1561            .limit(0, Some(10))?
1562            .filter(col("a").eq(lit(1i64)))?
1563            .build()?;
1564        // filter is before single projection
1565        assert_optimized_plan_equal!(
1566            plan,
1567            @r"
1568        Filter: test.a = Int64(1)
1569          Limit: skip=0, fetch=10
1570            Projection: test.a, test.b
1571              TableScan: test
1572        "
1573        )
1574    }
1575
1576    #[test]
1577    fn filter_no_columns() -> Result<()> {
1578        let table_scan = test_table_scan()?;
1579        let plan = LogicalPlanBuilder::from(table_scan)
1580            .filter(lit(0i64).eq(lit(1i64)))?
1581            .build()?;
1582        assert_optimized_plan_equal!(
1583            plan,
1584            @"TableScan: test, full_filters=[Int64(0) = Int64(1)]"
1585        )
1586    }
1587
1588    #[test]
1589    fn filter_jump_2_plans() -> Result<()> {
1590        let table_scan = test_table_scan()?;
1591        let plan = LogicalPlanBuilder::from(table_scan)
1592            .project(vec![col("a"), col("b"), col("c")])?
1593            .project(vec![col("c"), col("b")])?
1594            .filter(col("a").eq(lit(1i64)))?
1595            .build()?;
1596        // filter is before double projection
1597        assert_optimized_plan_equal!(
1598            plan,
1599            @r"
1600        Projection: test.c, test.b
1601          Projection: test.a, test.b, test.c
1602            TableScan: test, full_filters=[test.a = Int64(1)]
1603        "
1604        )
1605    }
1606
1607    #[test]
1608    fn filter_move_agg() -> Result<()> {
1609        let table_scan = test_table_scan()?;
1610        let plan = LogicalPlanBuilder::from(table_scan)
1611            .aggregate(vec![col("a")], vec![sum(col("b")).alias("total_salary")])?
1612            .filter(col("a").gt(lit(10i64)))?
1613            .build()?;
1614        // filter of key aggregation is commutative
1615        assert_optimized_plan_equal!(
1616            plan,
1617            @r"
1618        Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b) AS total_salary]]
1619          TableScan: test, full_filters=[test.a > Int64(10)]
1620        "
1621        )
1622    }
1623
1624    /// verifies that filters with unusual column names are pushed down through aggregate operators
1625    #[test]
1626    fn filter_move_agg_special() -> Result<()> {
1627        let schema = Schema::new(vec![
1628            Field::new("$a", DataType::UInt32, false),
1629            Field::new("$b", DataType::UInt32, false),
1630            Field::new("$c", DataType::UInt32, false),
1631        ]);
1632        let table_scan = table_scan(Some("test"), &schema, None)?.build()?;
1633
1634        let plan = LogicalPlanBuilder::from(table_scan)
1635            .aggregate(vec![col("$a")], vec![sum(col("$b")).alias("total_salary")])?
1636            .filter(col("$a").gt(lit(10i64)))?
1637            .build()?;
1638        // filter of key aggregation is commutative
1639        assert_optimized_plan_equal!(
1640            plan,
1641            @r"
1642        Aggregate: groupBy=[[test.$a]], aggr=[[sum(test.$b) AS total_salary]]
1643          TableScan: test, full_filters=[test.$a > Int64(10)]
1644        "
1645        )
1646    }
1647
1648    #[test]
1649    fn filter_complex_group_by() -> Result<()> {
1650        let table_scan = test_table_scan()?;
1651        let plan = LogicalPlanBuilder::from(table_scan)
1652            .aggregate(vec![add(col("b"), col("a"))], vec![sum(col("a")), col("b")])?
1653            .filter(col("b").gt(lit(10i64)))?
1654            .build()?;
1655        assert_optimized_plan_equal!(
1656            plan,
1657            @r"
1658        Filter: test.b > Int64(10)
1659          Aggregate: groupBy=[[test.b + test.a]], aggr=[[sum(test.a), test.b]]
1660            TableScan: test
1661        "
1662        )
1663    }
1664
1665    #[test]
1666    fn push_agg_need_replace_expr() -> Result<()> {
1667        let plan = LogicalPlanBuilder::from(test_table_scan()?)
1668            .aggregate(vec![add(col("b"), col("a"))], vec![sum(col("a")), col("b")])?
1669            .filter(col("test.b + test.a").gt(lit(10i64)))?
1670            .build()?;
1671        assert_optimized_plan_equal!(
1672            plan,
1673            @r"
1674        Aggregate: groupBy=[[test.b + test.a]], aggr=[[sum(test.a), test.b]]
1675          TableScan: test, full_filters=[test.b + test.a > Int64(10)]
1676        "
1677        )
1678    }
1679
1680    #[test]
1681    fn filter_keep_agg() -> Result<()> {
1682        let table_scan = test_table_scan()?;
1683        let plan = LogicalPlanBuilder::from(table_scan)
1684            .aggregate(vec![col("a")], vec![sum(col("b")).alias("b")])?
1685            .filter(col("b").gt(lit(10i64)))?
1686            .build()?;
1687        // filter of aggregate is after aggregation since they are non-commutative
1688        assert_optimized_plan_equal!(
1689            plan,
1690            @r"
1691        Filter: b > Int64(10)
1692          Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b) AS b]]
1693            TableScan: test
1694        "
1695        )
1696    }
1697
1698    /// verifies that when partitioning by 'a' and 'b', and filtering by 'b', 'b' is pushed
1699    #[test]
1700    fn filter_move_window() -> Result<()> {
1701        let table_scan = test_table_scan()?;
1702
1703        let window = Expr::from(WindowFunction::new(
1704            WindowFunctionDefinition::WindowUDF(
1705                datafusion_functions_window::rank::rank_udwf(),
1706            ),
1707            vec![],
1708        ))
1709        .partition_by(vec![col("a"), col("b")])
1710        .order_by(vec![col("c").sort(true, true)])
1711        .build()
1712        .unwrap();
1713
1714        let plan = LogicalPlanBuilder::from(table_scan)
1715            .window(vec![window])?
1716            .filter(col("b").gt(lit(10i64)))?
1717            .build()?;
1718
1719        assert_optimized_plan_equal!(
1720            plan,
1721            @r"
1722        WindowAggr: windowExpr=[[rank() PARTITION BY [test.a, test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1723          TableScan: test, full_filters=[test.b > Int64(10)]
1724        "
1725        )
1726    }
1727
1728    /// verifies that filters with unusual identifier names are pushed down through window functions
1729    #[test]
1730    fn filter_window_special_identifier() -> Result<()> {
1731        let schema = Schema::new(vec![
1732            Field::new("$a", DataType::UInt32, false),
1733            Field::new("$b", DataType::UInt32, false),
1734            Field::new("$c", DataType::UInt32, false),
1735        ]);
1736        let table_scan = table_scan(Some("test"), &schema, None)?.build()?;
1737
1738        let window = Expr::from(WindowFunction::new(
1739            WindowFunctionDefinition::WindowUDF(
1740                datafusion_functions_window::rank::rank_udwf(),
1741            ),
1742            vec![],
1743        ))
1744        .partition_by(vec![col("$a"), col("$b")])
1745        .order_by(vec![col("$c").sort(true, true)])
1746        .build()
1747        .unwrap();
1748
1749        let plan = LogicalPlanBuilder::from(table_scan)
1750            .window(vec![window])?
1751            .filter(col("$b").gt(lit(10i64)))?
1752            .build()?;
1753
1754        assert_optimized_plan_equal!(
1755            plan,
1756            @r"
1757        WindowAggr: windowExpr=[[rank() PARTITION BY [test.$a, test.$b] ORDER BY [test.$c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1758          TableScan: test, full_filters=[test.$b > Int64(10)]
1759        "
1760        )
1761    }
1762
1763    /// verifies that when partitioning by 'a' and 'b', and filtering by 'a' and 'b', both 'a' and
1764    /// 'b' are pushed
1765    #[test]
1766    fn filter_move_complex_window() -> Result<()> {
1767        let table_scan = test_table_scan()?;
1768
1769        let window = Expr::from(WindowFunction::new(
1770            WindowFunctionDefinition::WindowUDF(
1771                datafusion_functions_window::rank::rank_udwf(),
1772            ),
1773            vec![],
1774        ))
1775        .partition_by(vec![col("a"), col("b")])
1776        .order_by(vec![col("c").sort(true, true)])
1777        .build()
1778        .unwrap();
1779
1780        let plan = LogicalPlanBuilder::from(table_scan)
1781            .window(vec![window])?
1782            .filter(and(col("a").gt(lit(10i64)), col("b").eq(lit(1i64))))?
1783            .build()?;
1784
1785        assert_optimized_plan_equal!(
1786            plan,
1787            @r"
1788        WindowAggr: windowExpr=[[rank() PARTITION BY [test.a, test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1789          TableScan: test, full_filters=[test.a > Int64(10), test.b = Int64(1)]
1790        "
1791        )
1792    }
1793
1794    /// verifies that when partitioning by 'a' and filtering by 'a' and 'b', only 'a' is pushed
1795    #[test]
1796    fn filter_move_partial_window() -> Result<()> {
1797        let table_scan = test_table_scan()?;
1798
1799        let window = Expr::from(WindowFunction::new(
1800            WindowFunctionDefinition::WindowUDF(
1801                datafusion_functions_window::rank::rank_udwf(),
1802            ),
1803            vec![],
1804        ))
1805        .partition_by(vec![col("a")])
1806        .order_by(vec![col("c").sort(true, true)])
1807        .build()
1808        .unwrap();
1809
1810        let plan = LogicalPlanBuilder::from(table_scan)
1811            .window(vec![window])?
1812            .filter(and(col("a").gt(lit(10i64)), col("b").eq(lit(1i64))))?
1813            .build()?;
1814
1815        assert_optimized_plan_equal!(
1816            plan,
1817            @r"
1818        Filter: test.b = Int64(1)
1819          WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1820            TableScan: test, full_filters=[test.a > Int64(10)]
1821        "
1822        )
1823    }
1824
1825    /// verifies that filters on partition expressions are not pushed, as the single expression
1826    /// column is not available to the user, unlike with aggregations
1827    #[test]
1828    fn filter_expression_keep_window() -> Result<()> {
1829        let table_scan = test_table_scan()?;
1830
1831        let window = Expr::from(WindowFunction::new(
1832            WindowFunctionDefinition::WindowUDF(
1833                datafusion_functions_window::rank::rank_udwf(),
1834            ),
1835            vec![],
1836        ))
1837        .partition_by(vec![add(col("a"), col("b"))]) // PARTITION BY a + b
1838        .order_by(vec![col("c").sort(true, true)])
1839        .build()
1840        .unwrap();
1841
1842        let plan = LogicalPlanBuilder::from(table_scan)
1843            .window(vec![window])?
1844            // unlike with aggregations, single partition column "test.a + test.b" is not available
1845            // to the plan, so we use multiple columns when filtering
1846            .filter(add(col("a"), col("b")).gt(lit(10i64)))?
1847            .build()?;
1848
1849        assert_optimized_plan_equal!(
1850            plan,
1851            @r"
1852        Filter: test.a + test.b > Int64(10)
1853          WindowAggr: windowExpr=[[rank() PARTITION BY [test.a + test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1854            TableScan: test
1855        "
1856        )
1857    }
1858
1859    /// verifies that filters are not pushed on order by columns (that are not used in partitioning)
1860    #[test]
1861    fn filter_order_keep_window() -> Result<()> {
1862        let table_scan = test_table_scan()?;
1863
1864        let window = Expr::from(WindowFunction::new(
1865            WindowFunctionDefinition::WindowUDF(
1866                datafusion_functions_window::rank::rank_udwf(),
1867            ),
1868            vec![],
1869        ))
1870        .partition_by(vec![col("a")])
1871        .order_by(vec![col("c").sort(true, true)])
1872        .build()
1873        .unwrap();
1874
1875        let plan = LogicalPlanBuilder::from(table_scan)
1876            .window(vec![window])?
1877            .filter(col("c").gt(lit(10i64)))?
1878            .build()?;
1879
1880        assert_optimized_plan_equal!(
1881            plan,
1882            @r"
1883        Filter: test.c > Int64(10)
1884          WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1885            TableScan: test
1886        "
1887        )
1888    }
1889
1890    /// verifies that when we use multiple window functions with a common partition key, the filter
1891    /// on that key is pushed
1892    #[test]
1893    fn filter_multiple_windows_common_partitions() -> Result<()> {
1894        let table_scan = test_table_scan()?;
1895
1896        let window1 = Expr::from(WindowFunction::new(
1897            WindowFunctionDefinition::WindowUDF(
1898                datafusion_functions_window::rank::rank_udwf(),
1899            ),
1900            vec![],
1901        ))
1902        .partition_by(vec![col("a")])
1903        .order_by(vec![col("c").sort(true, true)])
1904        .build()
1905        .unwrap();
1906
1907        let window2 = Expr::from(WindowFunction::new(
1908            WindowFunctionDefinition::WindowUDF(
1909                datafusion_functions_window::rank::rank_udwf(),
1910            ),
1911            vec![],
1912        ))
1913        .partition_by(vec![col("b"), col("a")])
1914        .order_by(vec![col("c").sort(true, true)])
1915        .build()
1916        .unwrap();
1917
1918        let plan = LogicalPlanBuilder::from(table_scan)
1919            .window(vec![window1, window2])?
1920            .filter(col("a").gt(lit(10i64)))? // a appears in both window functions
1921            .build()?;
1922
1923        assert_optimized_plan_equal!(
1924            plan,
1925            @r"
1926        WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, rank() PARTITION BY [test.b, test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1927          TableScan: test, full_filters=[test.a > Int64(10)]
1928        "
1929        )
1930    }
1931
1932    /// verifies that when we use multiple window functions with different partitions keys, the
1933    /// filter cannot be pushed
1934    #[test]
1935    fn filter_multiple_windows_disjoint_partitions() -> Result<()> {
1936        let table_scan = test_table_scan()?;
1937
1938        let window1 = Expr::from(WindowFunction::new(
1939            WindowFunctionDefinition::WindowUDF(
1940                datafusion_functions_window::rank::rank_udwf(),
1941            ),
1942            vec![],
1943        ))
1944        .partition_by(vec![col("a")])
1945        .order_by(vec![col("c").sort(true, true)])
1946        .build()
1947        .unwrap();
1948
1949        let window2 = Expr::from(WindowFunction::new(
1950            WindowFunctionDefinition::WindowUDF(
1951                datafusion_functions_window::rank::rank_udwf(),
1952            ),
1953            vec![],
1954        ))
1955        .partition_by(vec![col("b"), col("a")])
1956        .order_by(vec![col("c").sort(true, true)])
1957        .build()
1958        .unwrap();
1959
1960        let plan = LogicalPlanBuilder::from(table_scan)
1961            .window(vec![window1, window2])?
1962            .filter(col("b").gt(lit(10i64)))? // b only appears in one window function
1963            .build()?;
1964
1965        assert_optimized_plan_equal!(
1966            plan,
1967            @r"
1968        Filter: test.b > Int64(10)
1969          WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, rank() PARTITION BY [test.b, test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1970            TableScan: test
1971        "
1972        )
1973    }
1974
1975    /// verifies that a filter is pushed to before a projection, the filter expression is correctly re-written
1976    #[test]
1977    fn alias() -> Result<()> {
1978        let table_scan = test_table_scan()?;
1979        let plan = LogicalPlanBuilder::from(table_scan)
1980            .project(vec![col("a").alias("b"), col("c")])?
1981            .filter(col("b").eq(lit(1i64)))?
1982            .build()?;
1983        // filter is before projection
1984        assert_optimized_plan_equal!(
1985            plan,
1986            @r"
1987        Projection: test.a AS b, test.c
1988          TableScan: test, full_filters=[test.a = Int64(1)]
1989        "
1990        )
1991    }
1992
1993    fn add(left: Expr, right: Expr) -> Expr {
1994        Expr::BinaryExpr(BinaryExpr::new(
1995            Box::new(left),
1996            Operator::Plus,
1997            Box::new(right),
1998        ))
1999    }
2000
2001    fn multiply(left: Expr, right: Expr) -> Expr {
2002        Expr::BinaryExpr(BinaryExpr::new(
2003            Box::new(left),
2004            Operator::Multiply,
2005            Box::new(right),
2006        ))
2007    }
2008
2009    /// verifies that a filter is pushed to before a projection with a complex expression, the filter expression is correctly re-written
2010    #[test]
2011    fn complex_expression() -> Result<()> {
2012        let table_scan = test_table_scan()?;
2013        let plan = LogicalPlanBuilder::from(table_scan)
2014            .project(vec![
2015                add(multiply(col("a"), lit(2)), col("c")).alias("b"),
2016                col("c"),
2017            ])?
2018            .filter(col("b").eq(lit(1i64)))?
2019            .build()?;
2020
2021        // not part of the test, just good to know:
2022        assert_snapshot!(plan,
2023        @r"
2024        Filter: b = Int64(1)
2025          Projection: test.a * Int32(2) + test.c AS b, test.c
2026            TableScan: test
2027        ",
2028        );
2029        // filter is before projection
2030        assert_optimized_plan_equal!(
2031            plan,
2032            @r"
2033        Projection: test.a * Int32(2) + test.c AS b, test.c
2034          TableScan: test, full_filters=[test.a * Int32(2) + test.c = Int64(1)]
2035        "
2036        )
2037    }
2038
2039    /// verifies that when a filter is pushed to after 2 projections, the filter expression is correctly re-written
2040    #[test]
2041    fn complex_plan() -> Result<()> {
2042        let table_scan = test_table_scan()?;
2043        let plan = LogicalPlanBuilder::from(table_scan)
2044            .project(vec![
2045                add(multiply(col("a"), lit(2)), col("c")).alias("b"),
2046                col("c"),
2047            ])?
2048            // second projection where we rename columns, just to make it difficult
2049            .project(vec![multiply(col("b"), lit(3)).alias("a"), col("c")])?
2050            .filter(col("a").eq(lit(1i64)))?
2051            .build()?;
2052
2053        // not part of the test, just good to know:
2054        assert_snapshot!(plan,
2055        @r"
2056        Filter: a = Int64(1)
2057          Projection: b * Int32(3) AS a, test.c
2058            Projection: test.a * Int32(2) + test.c AS b, test.c
2059              TableScan: test
2060        ",
2061        );
2062        // filter is before the projections
2063        assert_optimized_plan_equal!(
2064            plan,
2065            @r"
2066        Projection: b * Int32(3) AS a, test.c
2067          Projection: test.a * Int32(2) + test.c AS b, test.c
2068            TableScan: test, full_filters=[(test.a * Int32(2) + test.c) * Int32(3) = Int64(1)]
2069        "
2070        )
2071    }
2072
2073    #[derive(Debug, PartialEq, Eq, Hash)]
2074    struct NoopPlan {
2075        input: Vec<LogicalPlan>,
2076        schema: DFSchemaRef,
2077    }
2078
2079    // Manual implementation needed because of `schema` field. Comparison excludes this field.
2080    impl PartialOrd for NoopPlan {
2081        fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
2082            self.input
2083                .partial_cmp(&other.input)
2084                // TODO (https://github.com/apache/datafusion/issues/17477) avoid recomparing all fields
2085                .filter(|cmp| *cmp != Ordering::Equal || self == other)
2086        }
2087    }
2088
2089    impl UserDefinedLogicalNodeCore for NoopPlan {
2090        fn name(&self) -> &str {
2091            "NoopPlan"
2092        }
2093
2094        fn inputs(&self) -> Vec<&LogicalPlan> {
2095            self.input.iter().collect()
2096        }
2097
2098        fn schema(&self) -> &DFSchemaRef {
2099            &self.schema
2100        }
2101
2102        fn expressions(&self) -> Vec<Expr> {
2103            self.input
2104                .iter()
2105                .flat_map(|child| child.expressions())
2106                .collect()
2107        }
2108
2109        fn prevent_predicate_push_down_columns(&self) -> HashSet<String> {
2110            HashSet::from_iter(vec!["c".to_string()])
2111        }
2112
2113        fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
2114            write!(f, "NoopPlan")
2115        }
2116
2117        fn with_exprs_and_inputs(
2118            &self,
2119            _exprs: Vec<Expr>,
2120            inputs: Vec<LogicalPlan>,
2121        ) -> Result<Self> {
2122            Ok(Self {
2123                input: inputs,
2124                schema: Arc::clone(&self.schema),
2125            })
2126        }
2127
2128        fn supports_limit_pushdown(&self) -> bool {
2129            false // Disallow limit push-down by default
2130        }
2131    }
2132
2133    #[test]
2134    fn user_defined_plan() -> Result<()> {
2135        let table_scan = test_table_scan()?;
2136
2137        let custom_plan = LogicalPlan::Extension(Extension {
2138            node: Arc::new(NoopPlan {
2139                input: vec![table_scan.clone()],
2140                schema: Arc::clone(table_scan.schema()),
2141            }),
2142        });
2143        let plan = LogicalPlanBuilder::from(custom_plan)
2144            .filter(col("a").eq(lit(1i64)))?
2145            .build()?;
2146
2147        // Push filter below NoopPlan
2148        assert_optimized_plan_equal!(
2149            plan,
2150            @r"
2151        NoopPlan
2152          TableScan: test, full_filters=[test.a = Int64(1)]
2153        "
2154        )?;
2155
2156        let custom_plan = LogicalPlan::Extension(Extension {
2157            node: Arc::new(NoopPlan {
2158                input: vec![table_scan.clone()],
2159                schema: Arc::clone(table_scan.schema()),
2160            }),
2161        });
2162        let plan = LogicalPlanBuilder::from(custom_plan)
2163            .filter(col("a").eq(lit(1i64)).and(col("c").eq(lit(2i64))))?
2164            .build()?;
2165
2166        // Push only predicate on `a` below NoopPlan
2167        assert_optimized_plan_equal!(
2168            plan,
2169            @r"
2170        Filter: test.c = Int64(2)
2171          NoopPlan
2172            TableScan: test, full_filters=[test.a = Int64(1)]
2173        "
2174        )?;
2175
2176        let custom_plan = LogicalPlan::Extension(Extension {
2177            node: Arc::new(NoopPlan {
2178                input: vec![table_scan.clone(), table_scan.clone()],
2179                schema: Arc::clone(table_scan.schema()),
2180            }),
2181        });
2182        let plan = LogicalPlanBuilder::from(custom_plan)
2183            .filter(col("a").eq(lit(1i64)))?
2184            .build()?;
2185
2186        // Push filter below NoopPlan for each child branch
2187        assert_optimized_plan_equal!(
2188            plan,
2189            @r"
2190        NoopPlan
2191          TableScan: test, full_filters=[test.a = Int64(1)]
2192          TableScan: test, full_filters=[test.a = Int64(1)]
2193        "
2194        )?;
2195
2196        let custom_plan = LogicalPlan::Extension(Extension {
2197            node: Arc::new(NoopPlan {
2198                input: vec![table_scan.clone(), table_scan.clone()],
2199                schema: Arc::clone(table_scan.schema()),
2200            }),
2201        });
2202        let plan = LogicalPlanBuilder::from(custom_plan)
2203            .filter(col("a").eq(lit(1i64)).and(col("c").eq(lit(2i64))))?
2204            .build()?;
2205
2206        // Push only predicate on `a` below NoopPlan
2207        assert_optimized_plan_equal!(
2208            plan,
2209            @r"
2210        Filter: test.c = Int64(2)
2211          NoopPlan
2212            TableScan: test, full_filters=[test.a = Int64(1)]
2213            TableScan: test, full_filters=[test.a = Int64(1)]
2214        "
2215        )
2216    }
2217
2218    /// verifies that when two filters apply after an aggregation that only allows one to be pushed, one is pushed
2219    /// and the other not.
2220    #[test]
2221    fn multi_filter() -> Result<()> {
2222        // the aggregation allows one filter to pass (b), and the other one to not pass (sum(c))
2223        let table_scan = test_table_scan()?;
2224        let plan = LogicalPlanBuilder::from(table_scan)
2225            .project(vec![col("a").alias("b"), col("c")])?
2226            .aggregate(vec![col("b")], vec![sum(col("c"))])?
2227            .filter(col("b").gt(lit(10i64)))?
2228            .filter(col("sum(test.c)").gt(lit(10i64)))?
2229            .build()?;
2230
2231        // not part of the test, just good to know:
2232        assert_snapshot!(plan,
2233        @r"
2234        Filter: sum(test.c) > Int64(10)
2235          Filter: b > Int64(10)
2236            Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
2237              Projection: test.a AS b, test.c
2238                TableScan: test
2239        ",
2240        );
2241        // filter is before the projections
2242        assert_optimized_plan_equal!(
2243            plan,
2244            @r"
2245        Filter: sum(test.c) > Int64(10)
2246          Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
2247            Projection: test.a AS b, test.c
2248              TableScan: test, full_filters=[test.a > Int64(10)]
2249        "
2250        )
2251    }
2252
2253    /// verifies that when a filter with two predicates is applied after an aggregation that only allows one to be pushed, one is pushed
2254    /// and the other not.
2255    #[test]
2256    fn split_filter() -> Result<()> {
2257        // the aggregation allows one filter to pass (b), and the other one to not pass (sum(c))
2258        let table_scan = test_table_scan()?;
2259        let plan = LogicalPlanBuilder::from(table_scan)
2260            .project(vec![col("a").alias("b"), col("c")])?
2261            .aggregate(vec![col("b")], vec![sum(col("c"))])?
2262            .filter(and(
2263                col("sum(test.c)").gt(lit(10i64)),
2264                and(col("b").gt(lit(10i64)), col("sum(test.c)").lt(lit(20i64))),
2265            ))?
2266            .build()?;
2267
2268        // not part of the test, just good to know:
2269        assert_snapshot!(plan,
2270        @r"
2271        Filter: sum(test.c) > Int64(10) AND b > Int64(10) AND sum(test.c) < Int64(20)
2272          Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
2273            Projection: test.a AS b, test.c
2274              TableScan: test
2275        ",
2276        );
2277        // filter is before the projections
2278        assert_optimized_plan_equal!(
2279            plan,
2280            @r"
2281        Filter: sum(test.c) > Int64(10) AND sum(test.c) < Int64(20)
2282          Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
2283            Projection: test.a AS b, test.c
2284              TableScan: test, full_filters=[test.a > Int64(10)]
2285        "
2286        )
2287    }
2288
2289    /// verifies that when two limits are in place, we jump neither
2290    #[test]
2291    fn double_limit() -> Result<()> {
2292        let table_scan = test_table_scan()?;
2293        let plan = LogicalPlanBuilder::from(table_scan)
2294            .project(vec![col("a"), col("b")])?
2295            .limit(0, Some(20))?
2296            .limit(0, Some(10))?
2297            .project(vec![col("a"), col("b")])?
2298            .filter(col("a").eq(lit(1i64)))?
2299            .build()?;
2300        // filter does not just any of the limits
2301        assert_optimized_plan_equal!(
2302            plan,
2303            @r"
2304        Projection: test.a, test.b
2305          Filter: test.a = Int64(1)
2306            Limit: skip=0, fetch=10
2307              Limit: skip=0, fetch=20
2308                Projection: test.a, test.b
2309                  TableScan: test
2310        "
2311        )
2312    }
2313
2314    #[test]
2315    fn union_all() -> Result<()> {
2316        let table_scan = test_table_scan()?;
2317        let table_scan2 = test_table_scan_with_name("test2")?;
2318        let plan = LogicalPlanBuilder::from(table_scan)
2319            .union(LogicalPlanBuilder::from(table_scan2).build()?)?
2320            .filter(col("a").eq(lit(1i64)))?
2321            .build()?;
2322        // filter appears below Union
2323        assert_optimized_plan_equal!(
2324            plan,
2325            @r"
2326        Union
2327          TableScan: test, full_filters=[test.a = Int64(1)]
2328          TableScan: test2, full_filters=[test2.a = Int64(1)]
2329        "
2330        )
2331    }
2332
2333    #[test]
2334    fn union_all_on_projection() -> Result<()> {
2335        let table_scan = test_table_scan()?;
2336        let table = LogicalPlanBuilder::from(table_scan)
2337            .project(vec![col("a").alias("b")])?
2338            .alias("test2")?;
2339
2340        let plan = table
2341            .clone()
2342            .union(table.build()?)?
2343            .filter(col("b").eq(lit(1i64)))?
2344            .build()?;
2345
2346        // filter appears below Union
2347        assert_optimized_plan_equal!(
2348            plan,
2349            @r"
2350        Union
2351          SubqueryAlias: test2
2352            Projection: test.a AS b
2353              TableScan: test, full_filters=[test.a = Int64(1)]
2354          SubqueryAlias: test2
2355            Projection: test.a AS b
2356              TableScan: test, full_filters=[test.a = Int64(1)]
2357        "
2358        )
2359    }
2360
2361    #[test]
2362    fn test_union_different_schema() -> Result<()> {
2363        let left = LogicalPlanBuilder::from(test_table_scan()?)
2364            .project(vec![col("a"), col("b"), col("c")])?
2365            .build()?;
2366
2367        let schema = Schema::new(vec![
2368            Field::new("d", DataType::UInt32, false),
2369            Field::new("e", DataType::UInt32, false),
2370            Field::new("f", DataType::UInt32, false),
2371        ]);
2372        let right = table_scan(Some("test1"), &schema, None)?
2373            .project(vec![col("d"), col("e"), col("f")])?
2374            .build()?;
2375        let filter = and(col("test.a").eq(lit(1)), col("test1.d").gt(lit(2)));
2376        let plan = LogicalPlanBuilder::from(left)
2377            .cross_join(right)?
2378            .project(vec![col("test.a"), col("test1.d")])?
2379            .filter(filter)?
2380            .build()?;
2381
2382        assert_optimized_plan_equal!(
2383            plan,
2384            @r"
2385        Projection: test.a, test1.d
2386          Cross Join:
2387            Projection: test.a, test.b, test.c
2388              TableScan: test, full_filters=[test.a = Int32(1)]
2389            Projection: test1.d, test1.e, test1.f
2390              TableScan: test1, full_filters=[test1.d > Int32(2)]
2391        "
2392        )
2393    }
2394
2395    #[test]
2396    fn test_project_same_name_different_qualifier() -> Result<()> {
2397        let table_scan = test_table_scan()?;
2398        let left = LogicalPlanBuilder::from(table_scan)
2399            .project(vec![col("a"), col("b"), col("c")])?
2400            .build()?;
2401        let right_table_scan = test_table_scan_with_name("test1")?;
2402        let right = LogicalPlanBuilder::from(right_table_scan)
2403            .project(vec![col("a"), col("b"), col("c")])?
2404            .build()?;
2405        let filter = and(col("test.a").eq(lit(1)), col("test1.a").gt(lit(2)));
2406        let plan = LogicalPlanBuilder::from(left)
2407            .cross_join(right)?
2408            .project(vec![col("test.a"), col("test1.a")])?
2409            .filter(filter)?
2410            .build()?;
2411
2412        assert_optimized_plan_equal!(
2413            plan,
2414            @r"
2415        Projection: test.a, test1.a
2416          Cross Join:
2417            Projection: test.a, test.b, test.c
2418              TableScan: test, full_filters=[test.a = Int32(1)]
2419            Projection: test1.a, test1.b, test1.c
2420              TableScan: test1, full_filters=[test1.a > Int32(2)]
2421        "
2422        )
2423    }
2424
2425    /// verifies that filters with the same columns are correctly placed
2426    #[test]
2427    fn filter_2_breaks_limits() -> Result<()> {
2428        let table_scan = test_table_scan()?;
2429        let plan = LogicalPlanBuilder::from(table_scan)
2430            .project(vec![col("a")])?
2431            .filter(col("a").lt_eq(lit(1i64)))?
2432            .limit(0, Some(1))?
2433            .project(vec![col("a")])?
2434            .filter(col("a").gt_eq(lit(1i64)))?
2435            .build()?;
2436        // Should be able to move both filters below the projections
2437
2438        // not part of the test
2439        assert_snapshot!(plan,
2440        @r"
2441        Filter: test.a >= Int64(1)
2442          Projection: test.a
2443            Limit: skip=0, fetch=1
2444              Filter: test.a <= Int64(1)
2445                Projection: test.a
2446                  TableScan: test
2447        ",
2448        );
2449        assert_optimized_plan_equal!(
2450            plan,
2451            @r"
2452        Projection: test.a
2453          Filter: test.a >= Int64(1)
2454            Limit: skip=0, fetch=1
2455              Projection: test.a
2456                TableScan: test, full_filters=[test.a <= Int64(1)]
2457        "
2458        )
2459    }
2460
2461    /// verifies that filters to be placed on the same depth are ANDed
2462    #[test]
2463    fn two_filters_on_same_depth() -> Result<()> {
2464        let table_scan = test_table_scan()?;
2465        let plan = LogicalPlanBuilder::from(table_scan)
2466            .limit(0, Some(1))?
2467            .filter(col("a").lt_eq(lit(1i64)))?
2468            .filter(col("a").gt_eq(lit(1i64)))?
2469            .project(vec![col("a")])?
2470            .build()?;
2471
2472        // not part of the test
2473        assert_snapshot!(plan,
2474        @r"
2475        Projection: test.a
2476          Filter: test.a >= Int64(1)
2477            Filter: test.a <= Int64(1)
2478              Limit: skip=0, fetch=1
2479                TableScan: test
2480        ",
2481        );
2482        assert_optimized_plan_equal!(
2483            plan,
2484            @r"
2485        Projection: test.a
2486          Filter: test.a <= Int64(1) AND test.a >= Int64(1)
2487            Limit: skip=0, fetch=1
2488              TableScan: test
2489        "
2490        )
2491    }
2492
2493    /// verifies that filters on a plan with user nodes are not lost
2494    /// (ARROW-10547)
2495    #[test]
2496    fn filters_user_defined_node() -> Result<()> {
2497        let table_scan = test_table_scan()?;
2498        let plan = LogicalPlanBuilder::from(table_scan)
2499            .filter(col("a").lt_eq(lit(1i64)))?
2500            .build()?;
2501
2502        let plan = user_defined::new(plan);
2503
2504        // not part of the test
2505        assert_snapshot!(plan,
2506        @r"
2507        TestUserDefined
2508          Filter: test.a <= Int64(1)
2509            TableScan: test
2510        ",
2511        );
2512        assert_optimized_plan_equal!(
2513            plan,
2514            @r"
2515        TestUserDefined
2516          TableScan: test, full_filters=[test.a <= Int64(1)]
2517        "
2518        )
2519    }
2520
2521    /// post-on-join predicates on a column common to both sides is pushed to both sides
2522    #[test]
2523    fn filter_on_join_on_common_independent() -> Result<()> {
2524        let table_scan = test_table_scan()?;
2525        let left = LogicalPlanBuilder::from(table_scan).build()?;
2526        let right_table_scan = test_table_scan_with_name("test2")?;
2527        let right = LogicalPlanBuilder::from(right_table_scan)
2528            .project(vec![col("a")])?
2529            .build()?;
2530        let plan = LogicalPlanBuilder::from(left)
2531            .join(
2532                right,
2533                JoinType::Inner,
2534                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2535                None,
2536            )?
2537            .filter(col("test.a").lt_eq(lit(1i64)))?
2538            .build()?;
2539
2540        // not part of the test, just good to know:
2541        assert_snapshot!(plan,
2542        @r"
2543        Filter: test.a <= Int64(1)
2544          Inner Join: test.a = test2.a
2545            TableScan: test
2546            Projection: test2.a
2547              TableScan: test2
2548        ",
2549        );
2550        // filter sent to side before the join
2551        assert_optimized_plan_equal!(
2552            plan,
2553            @r"
2554        Inner Join: test.a = test2.a
2555          TableScan: test, full_filters=[test.a <= Int64(1)]
2556          Projection: test2.a
2557            TableScan: test2, full_filters=[test2.a <= Int64(1)]
2558        "
2559        )
2560    }
2561
2562    /// post-using-join predicates on a column common to both sides is pushed to both sides
2563    #[test]
2564    fn filter_using_join_on_common_independent() -> Result<()> {
2565        let table_scan = test_table_scan()?;
2566        let left = LogicalPlanBuilder::from(table_scan).build()?;
2567        let right_table_scan = test_table_scan_with_name("test2")?;
2568        let right = LogicalPlanBuilder::from(right_table_scan)
2569            .project(vec![col("a")])?
2570            .build()?;
2571        let plan = LogicalPlanBuilder::from(left)
2572            .join_using(
2573                right,
2574                JoinType::Inner,
2575                vec![Column::from_name("a".to_string())],
2576            )?
2577            .filter(col("a").lt_eq(lit(1i64)))?
2578            .build()?;
2579
2580        // not part of the test, just good to know:
2581        assert_snapshot!(plan,
2582        @r"
2583        Filter: test.a <= Int64(1)
2584          Inner Join: Using test.a = test2.a
2585            TableScan: test
2586            Projection: test2.a
2587              TableScan: test2
2588        ",
2589        );
2590        // filter sent to side before the join
2591        assert_optimized_plan_equal!(
2592            plan,
2593            @r"
2594        Inner Join: Using test.a = test2.a
2595          TableScan: test, full_filters=[test.a <= Int64(1)]
2596          Projection: test2.a
2597            TableScan: test2, full_filters=[test2.a <= Int64(1)]
2598        "
2599        )
2600    }
2601
2602    /// post-join predicates with columns from both sides are converted to join filters
2603    #[test]
2604    fn filter_join_on_common_dependent() -> Result<()> {
2605        let table_scan = test_table_scan()?;
2606        let left = LogicalPlanBuilder::from(table_scan)
2607            .project(vec![col("a"), col("c")])?
2608            .build()?;
2609        let right_table_scan = test_table_scan_with_name("test2")?;
2610        let right = LogicalPlanBuilder::from(right_table_scan)
2611            .project(vec![col("a"), col("b")])?
2612            .build()?;
2613        let plan = LogicalPlanBuilder::from(left)
2614            .join(
2615                right,
2616                JoinType::Inner,
2617                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2618                None,
2619            )?
2620            .filter(col("c").lt_eq(col("b")))?
2621            .build()?;
2622
2623        // not part of the test, just good to know:
2624        assert_snapshot!(plan,
2625        @r"
2626        Filter: test.c <= test2.b
2627          Inner Join: test.a = test2.a
2628            Projection: test.a, test.c
2629              TableScan: test
2630            Projection: test2.a, test2.b
2631              TableScan: test2
2632        ",
2633        );
2634        // Filter is converted to Join Filter
2635        assert_optimized_plan_equal!(
2636            plan,
2637            @r"
2638        Inner Join: test.a = test2.a Filter: test.c <= test2.b
2639          Projection: test.a, test.c
2640            TableScan: test
2641          Projection: test2.a, test2.b
2642            TableScan: test2
2643        "
2644        )
2645    }
2646
2647    /// post-join predicates with columns from one side of a join are pushed only to that side
2648    #[test]
2649    fn filter_join_on_one_side() -> Result<()> {
2650        let table_scan = test_table_scan()?;
2651        let left = LogicalPlanBuilder::from(table_scan)
2652            .project(vec![col("a"), col("b")])?
2653            .build()?;
2654        let table_scan_right = test_table_scan_with_name("test2")?;
2655        let right = LogicalPlanBuilder::from(table_scan_right)
2656            .project(vec![col("a"), col("c")])?
2657            .build()?;
2658
2659        let plan = LogicalPlanBuilder::from(left)
2660            .join(
2661                right,
2662                JoinType::Inner,
2663                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2664                None,
2665            )?
2666            .filter(col("b").lt_eq(lit(1i64)))?
2667            .build()?;
2668
2669        // not part of the test, just good to know:
2670        assert_snapshot!(plan,
2671        @r"
2672        Filter: test.b <= Int64(1)
2673          Inner Join: test.a = test2.a
2674            Projection: test.a, test.b
2675              TableScan: test
2676            Projection: test2.a, test2.c
2677              TableScan: test2
2678        ",
2679        );
2680        assert_optimized_plan_equal!(
2681            plan,
2682            @r"
2683        Inner Join: test.a = test2.a
2684          Projection: test.a, test.b
2685            TableScan: test, full_filters=[test.b <= Int64(1)]
2686          Projection: test2.a, test2.c
2687            TableScan: test2
2688        "
2689        )
2690    }
2691
2692    /// post-join predicates on the right side of a left join are not duplicated
2693    /// TODO: In this case we can sometimes convert the join to an INNER join
2694    #[test]
2695    fn filter_using_left_join() -> Result<()> {
2696        let table_scan = test_table_scan()?;
2697        let left = LogicalPlanBuilder::from(table_scan).build()?;
2698        let right_table_scan = test_table_scan_with_name("test2")?;
2699        let right = LogicalPlanBuilder::from(right_table_scan)
2700            .project(vec![col("a")])?
2701            .build()?;
2702        let plan = LogicalPlanBuilder::from(left)
2703            .join_using(
2704                right,
2705                JoinType::Left,
2706                vec![Column::from_name("a".to_string())],
2707            )?
2708            .filter(col("test2.a").lt_eq(lit(1i64)))?
2709            .build()?;
2710
2711        // not part of the test, just good to know:
2712        assert_snapshot!(plan,
2713        @r"
2714        Filter: test2.a <= Int64(1)
2715          Left Join: Using test.a = test2.a
2716            TableScan: test
2717            Projection: test2.a
2718              TableScan: test2
2719        ",
2720        );
2721        // filter not duplicated nor pushed down - i.e. noop
2722        assert_optimized_plan_equal!(
2723            plan,
2724            @r"
2725        Filter: test2.a <= Int64(1)
2726          Left Join: Using test.a = test2.a
2727            TableScan: test, full_filters=[test.a <= Int64(1)]
2728            Projection: test2.a
2729              TableScan: test2
2730        "
2731        )
2732    }
2733
2734    /// post-join predicates on the left side of a right join are not duplicated
2735    #[test]
2736    fn filter_using_right_join() -> Result<()> {
2737        let table_scan = test_table_scan()?;
2738        let left = LogicalPlanBuilder::from(table_scan).build()?;
2739        let right_table_scan = test_table_scan_with_name("test2")?;
2740        let right = LogicalPlanBuilder::from(right_table_scan)
2741            .project(vec![col("a")])?
2742            .build()?;
2743        let plan = LogicalPlanBuilder::from(left)
2744            .join_using(
2745                right,
2746                JoinType::Right,
2747                vec![Column::from_name("a".to_string())],
2748            )?
2749            .filter(col("test.a").lt_eq(lit(1i64)))?
2750            .build()?;
2751
2752        // not part of the test, just good to know:
2753        assert_snapshot!(plan,
2754        @r"
2755        Filter: test.a <= Int64(1)
2756          Right Join: Using test.a = test2.a
2757            TableScan: test
2758            Projection: test2.a
2759              TableScan: test2
2760        ",
2761        );
2762        // filter not duplicated nor pushed down - i.e. noop
2763        assert_optimized_plan_equal!(
2764            plan,
2765            @r"
2766        Filter: test.a <= Int64(1)
2767          Right Join: Using test.a = test2.a
2768            TableScan: test
2769            Projection: test2.a
2770              TableScan: test2, full_filters=[test2.a <= Int64(1)]
2771        "
2772        )
2773    }
2774
2775    /// post-left-join predicate on a column common to both sides is pushed to both sides
2776    #[test]
2777    fn filter_using_left_join_on_common() -> Result<()> {
2778        let table_scan = test_table_scan()?;
2779        let left = LogicalPlanBuilder::from(table_scan).build()?;
2780        let right_table_scan = test_table_scan_with_name("test2")?;
2781        let right = LogicalPlanBuilder::from(right_table_scan)
2782            .project(vec![col("a")])?
2783            .build()?;
2784        let plan = LogicalPlanBuilder::from(left)
2785            .join_using(
2786                right,
2787                JoinType::Left,
2788                vec![Column::from_name("a".to_string())],
2789            )?
2790            .filter(col("a").lt_eq(lit(1i64)))?
2791            .build()?;
2792
2793        // not part of the test, just good to know:
2794        assert_snapshot!(plan,
2795        @r"
2796        Filter: test.a <= Int64(1)
2797          Left Join: Using test.a = test2.a
2798            TableScan: test
2799            Projection: test2.a
2800              TableScan: test2
2801        ",
2802        );
2803        // filter sent to left side of the join and to the right
2804        assert_optimized_plan_equal!(
2805            plan,
2806            @r"
2807        Left Join: Using test.a = test2.a
2808          TableScan: test, full_filters=[test.a <= Int64(1)]
2809          Projection: test2.a
2810            TableScan: test2, full_filters=[test2.a <= Int64(1)]
2811        "
2812        )
2813    }
2814
2815    /// post-right-join predicate on a column common to both sides is pushed to both sides
2816    #[test]
2817    fn filter_using_right_join_on_common() -> Result<()> {
2818        let table_scan = test_table_scan()?;
2819        let left = LogicalPlanBuilder::from(table_scan).build()?;
2820        let right_table_scan = test_table_scan_with_name("test2")?;
2821        let right = LogicalPlanBuilder::from(right_table_scan)
2822            .project(vec![col("a")])?
2823            .build()?;
2824        let plan = LogicalPlanBuilder::from(left)
2825            .join_using(
2826                right,
2827                JoinType::Right,
2828                vec![Column::from_name("a".to_string())],
2829            )?
2830            .filter(col("test2.a").lt_eq(lit(1i64)))?
2831            .build()?;
2832
2833        // not part of the test, just good to know:
2834        assert_snapshot!(plan,
2835        @r"
2836        Filter: test2.a <= Int64(1)
2837          Right Join: Using test.a = test2.a
2838            TableScan: test
2839            Projection: test2.a
2840              TableScan: test2
2841        ",
2842        );
2843        // filter sent to right side of join, sent to the left as well
2844        assert_optimized_plan_equal!(
2845            plan,
2846            @r"
2847        Right Join: Using test.a = test2.a
2848          TableScan: test, full_filters=[test.a <= Int64(1)]
2849          Projection: test2.a
2850            TableScan: test2, full_filters=[test2.a <= Int64(1)]
2851        "
2852        )
2853    }
2854
2855    /// single table predicate parts of ON condition should be pushed to both inputs
2856    #[test]
2857    fn join_on_with_filter() -> Result<()> {
2858        let table_scan = test_table_scan()?;
2859        let left = LogicalPlanBuilder::from(table_scan)
2860            .project(vec![col("a"), col("b"), col("c")])?
2861            .build()?;
2862        let right_table_scan = test_table_scan_with_name("test2")?;
2863        let right = LogicalPlanBuilder::from(right_table_scan)
2864            .project(vec![col("a"), col("b"), col("c")])?
2865            .build()?;
2866        let filter = col("test.c")
2867            .gt(lit(1u32))
2868            .and(col("test.b").lt(col("test2.b")))
2869            .and(col("test2.c").gt(lit(4u32)));
2870        let plan = LogicalPlanBuilder::from(left)
2871            .join(
2872                right,
2873                JoinType::Inner,
2874                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2875                Some(filter),
2876            )?
2877            .build()?;
2878
2879        // not part of the test, just good to know:
2880        assert_snapshot!(plan,
2881        @r"
2882        Inner Join: test.a = test2.a Filter: test.c > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
2883          Projection: test.a, test.b, test.c
2884            TableScan: test
2885          Projection: test2.a, test2.b, test2.c
2886            TableScan: test2
2887        ",
2888        );
2889        assert_optimized_plan_equal!(
2890            plan,
2891            @r"
2892        Inner Join: test.a = test2.a Filter: test.b < test2.b
2893          Projection: test.a, test.b, test.c
2894            TableScan: test, full_filters=[test.c > UInt32(1)]
2895          Projection: test2.a, test2.b, test2.c
2896            TableScan: test2, full_filters=[test2.c > UInt32(4)]
2897        "
2898        )
2899    }
2900
2901    /// join filter should be completely removed after pushdown
2902    #[test]
2903    fn join_filter_removed() -> Result<()> {
2904        let table_scan = test_table_scan()?;
2905        let left = LogicalPlanBuilder::from(table_scan)
2906            .project(vec![col("a"), col("b"), col("c")])?
2907            .build()?;
2908        let right_table_scan = test_table_scan_with_name("test2")?;
2909        let right = LogicalPlanBuilder::from(right_table_scan)
2910            .project(vec![col("a"), col("b"), col("c")])?
2911            .build()?;
2912        let filter = col("test.b")
2913            .gt(lit(1u32))
2914            .and(col("test2.c").gt(lit(4u32)));
2915        let plan = LogicalPlanBuilder::from(left)
2916            .join(
2917                right,
2918                JoinType::Inner,
2919                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2920                Some(filter),
2921            )?
2922            .build()?;
2923
2924        // not part of the test, just good to know:
2925        assert_snapshot!(plan,
2926        @r"
2927        Inner Join: test.a = test2.a Filter: test.b > UInt32(1) AND test2.c > UInt32(4)
2928          Projection: test.a, test.b, test.c
2929            TableScan: test
2930          Projection: test2.a, test2.b, test2.c
2931            TableScan: test2
2932        ",
2933        );
2934        assert_optimized_plan_equal!(
2935            plan,
2936            @r"
2937        Inner Join: test.a = test2.a
2938          Projection: test.a, test.b, test.c
2939            TableScan: test, full_filters=[test.b > UInt32(1)]
2940          Projection: test2.a, test2.b, test2.c
2941            TableScan: test2, full_filters=[test2.c > UInt32(4)]
2942        "
2943        )
2944    }
2945
2946    /// predicate on join key in filter expression should be pushed down to both inputs
2947    #[test]
2948    fn join_filter_on_common() -> Result<()> {
2949        let table_scan = test_table_scan()?;
2950        let left = LogicalPlanBuilder::from(table_scan)
2951            .project(vec![col("a")])?
2952            .build()?;
2953        let right_table_scan = test_table_scan_with_name("test2")?;
2954        let right = LogicalPlanBuilder::from(right_table_scan)
2955            .project(vec![col("b")])?
2956            .build()?;
2957        let filter = col("test.a").gt(lit(1u32));
2958        let plan = LogicalPlanBuilder::from(left)
2959            .join(
2960                right,
2961                JoinType::Inner,
2962                (vec![Column::from_name("a")], vec![Column::from_name("b")]),
2963                Some(filter),
2964            )?
2965            .build()?;
2966
2967        // not part of the test, just good to know:
2968        assert_snapshot!(plan,
2969        @r"
2970        Inner Join: test.a = test2.b Filter: test.a > UInt32(1)
2971          Projection: test.a
2972            TableScan: test
2973          Projection: test2.b
2974            TableScan: test2
2975        ",
2976        );
2977        assert_optimized_plan_equal!(
2978            plan,
2979            @r"
2980        Inner Join: test.a = test2.b
2981          Projection: test.a
2982            TableScan: test, full_filters=[test.a > UInt32(1)]
2983          Projection: test2.b
2984            TableScan: test2, full_filters=[test2.b > UInt32(1)]
2985        "
2986        )
2987    }
2988
2989    /// single table predicate parts of ON condition should be pushed to right input
2990    #[test]
2991    fn left_join_on_with_filter() -> Result<()> {
2992        let table_scan = test_table_scan()?;
2993        let left = LogicalPlanBuilder::from(table_scan)
2994            .project(vec![col("a"), col("b"), col("c")])?
2995            .build()?;
2996        let right_table_scan = test_table_scan_with_name("test2")?;
2997        let right = LogicalPlanBuilder::from(right_table_scan)
2998            .project(vec![col("a"), col("b"), col("c")])?
2999            .build()?;
3000        let filter = col("test.a")
3001            .gt(lit(1u32))
3002            .and(col("test.b").lt(col("test2.b")))
3003            .and(col("test2.c").gt(lit(4u32)));
3004        let plan = LogicalPlanBuilder::from(left)
3005            .join(
3006                right,
3007                JoinType::Left,
3008                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
3009                Some(filter),
3010            )?
3011            .build()?;
3012
3013        // not part of the test, just good to know:
3014        assert_snapshot!(plan,
3015        @r"
3016        Left Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
3017          Projection: test.a, test.b, test.c
3018            TableScan: test
3019          Projection: test2.a, test2.b, test2.c
3020            TableScan: test2
3021        ",
3022        );
3023        assert_optimized_plan_equal!(
3024            plan,
3025            @r"
3026        Left Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b
3027          Projection: test.a, test.b, test.c
3028            TableScan: test
3029          Projection: test2.a, test2.b, test2.c
3030            TableScan: test2, full_filters=[test2.a > UInt32(1), test2.c > UInt32(4)]
3031        "
3032        )
3033    }
3034
3035    /// single table predicate parts of ON condition should be pushed to left input
3036    #[test]
3037    fn right_join_on_with_filter() -> Result<()> {
3038        let table_scan = test_table_scan()?;
3039        let left = LogicalPlanBuilder::from(table_scan)
3040            .project(vec![col("a"), col("b"), col("c")])?
3041            .build()?;
3042        let right_table_scan = test_table_scan_with_name("test2")?;
3043        let right = LogicalPlanBuilder::from(right_table_scan)
3044            .project(vec![col("a"), col("b"), col("c")])?
3045            .build()?;
3046        let filter = col("test.a")
3047            .gt(lit(1u32))
3048            .and(col("test.b").lt(col("test2.b")))
3049            .and(col("test2.c").gt(lit(4u32)));
3050        let plan = LogicalPlanBuilder::from(left)
3051            .join(
3052                right,
3053                JoinType::Right,
3054                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
3055                Some(filter),
3056            )?
3057            .build()?;
3058
3059        // not part of the test, just good to know:
3060        assert_snapshot!(plan,
3061        @r"
3062        Right Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
3063          Projection: test.a, test.b, test.c
3064            TableScan: test
3065          Projection: test2.a, test2.b, test2.c
3066            TableScan: test2
3067        ",
3068        );
3069        assert_optimized_plan_equal!(
3070            plan,
3071            @r"
3072        Right Join: test.a = test2.a Filter: test.b < test2.b AND test2.c > UInt32(4)
3073          Projection: test.a, test.b, test.c
3074            TableScan: test, full_filters=[test.a > UInt32(1)]
3075          Projection: test2.a, test2.b, test2.c
3076            TableScan: test2
3077        "
3078        )
3079    }
3080
3081    /// single table predicate parts of ON condition should not be pushed
3082    #[test]
3083    fn full_join_on_with_filter() -> Result<()> {
3084        let table_scan = test_table_scan()?;
3085        let left = LogicalPlanBuilder::from(table_scan)
3086            .project(vec![col("a"), col("b"), col("c")])?
3087            .build()?;
3088        let right_table_scan = test_table_scan_with_name("test2")?;
3089        let right = LogicalPlanBuilder::from(right_table_scan)
3090            .project(vec![col("a"), col("b"), col("c")])?
3091            .build()?;
3092        let filter = col("test.a")
3093            .gt(lit(1u32))
3094            .and(col("test.b").lt(col("test2.b")))
3095            .and(col("test2.c").gt(lit(4u32)));
3096        let plan = LogicalPlanBuilder::from(left)
3097            .join(
3098                right,
3099                JoinType::Full,
3100                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
3101                Some(filter),
3102            )?
3103            .build()?;
3104
3105        // not part of the test, just good to know:
3106        assert_snapshot!(plan,
3107        @r"
3108        Full Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
3109          Projection: test.a, test.b, test.c
3110            TableScan: test
3111          Projection: test2.a, test2.b, test2.c
3112            TableScan: test2
3113        ",
3114        );
3115        assert_optimized_plan_equal!(
3116            plan,
3117            @r"
3118        Full Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
3119          Projection: test.a, test.b, test.c
3120            TableScan: test
3121          Projection: test2.a, test2.b, test2.c
3122            TableScan: test2
3123        "
3124        )
3125    }
3126
3127    struct PushDownProvider {
3128        pub filter_support: TableProviderFilterPushDown,
3129    }
3130
3131    #[async_trait]
3132    impl TableSource for PushDownProvider {
3133        fn schema(&self) -> SchemaRef {
3134            Arc::new(Schema::new(vec![
3135                Field::new("a", DataType::Int32, true),
3136                Field::new("b", DataType::Int32, true),
3137            ]))
3138        }
3139
3140        fn table_type(&self) -> TableType {
3141            TableType::Base
3142        }
3143
3144        fn supports_filters_pushdown(
3145            &self,
3146            filters: &[&Expr],
3147        ) -> Result<Vec<TableProviderFilterPushDown>> {
3148            Ok((0..filters.len())
3149                .map(|_| self.filter_support.clone())
3150                .collect())
3151        }
3152    }
3153
3154    fn table_scan_with_pushdown_provider_builder(
3155        filter_support: TableProviderFilterPushDown,
3156        filters: Vec<Expr>,
3157        projection: Option<Vec<usize>>,
3158    ) -> Result<LogicalPlanBuilder> {
3159        let test_provider = PushDownProvider { filter_support };
3160
3161        let table_scan = LogicalPlan::TableScan(TableScan {
3162            table_name: "test".into(),
3163            filters,
3164            projected_schema: Arc::new(DFSchema::try_from(test_provider.schema())?),
3165            projection,
3166            source: Arc::new(test_provider),
3167            fetch: None,
3168        });
3169
3170        Ok(LogicalPlanBuilder::from(table_scan))
3171    }
3172
3173    fn table_scan_with_pushdown_provider(
3174        filter_support: TableProviderFilterPushDown,
3175    ) -> Result<LogicalPlan> {
3176        table_scan_with_pushdown_provider_builder(filter_support, vec![], None)?
3177            .filter(col("a").eq(lit(1i64)))?
3178            .build()
3179    }
3180
3181    #[test]
3182    fn filter_with_table_provider_exact() -> Result<()> {
3183        let plan = table_scan_with_pushdown_provider(TableProviderFilterPushDown::Exact)?;
3184
3185        assert_optimized_plan_equal!(
3186            plan,
3187            @"TableScan: test, full_filters=[a = Int64(1)]"
3188        )
3189    }
3190
3191    #[test]
3192    fn filter_with_table_provider_inexact() -> Result<()> {
3193        let plan =
3194            table_scan_with_pushdown_provider(TableProviderFilterPushDown::Inexact)?;
3195
3196        assert_optimized_plan_equal!(
3197            plan,
3198            @r"
3199        Filter: a = Int64(1)
3200          TableScan: test, partial_filters=[a = Int64(1)]
3201        "
3202        )
3203    }
3204
3205    #[test]
3206    fn filter_with_table_provider_multiple_invocations() -> Result<()> {
3207        let plan =
3208            table_scan_with_pushdown_provider(TableProviderFilterPushDown::Inexact)?;
3209
3210        let optimized_plan = PushDownFilter::new()
3211            .rewrite(plan, &OptimizerContext::new())
3212            .expect("failed to optimize plan")
3213            .data;
3214
3215        // Optimizing the same plan multiple times should produce the same plan
3216        // each time.
3217        assert_optimized_plan_equal!(
3218            optimized_plan,
3219            @r"
3220        Filter: a = Int64(1)
3221          TableScan: test, partial_filters=[a = Int64(1)]
3222        "
3223        )
3224    }
3225
3226    #[test]
3227    fn filter_with_table_provider_unsupported() -> Result<()> {
3228        let plan =
3229            table_scan_with_pushdown_provider(TableProviderFilterPushDown::Unsupported)?;
3230
3231        assert_optimized_plan_equal!(
3232            plan,
3233            @r"
3234        Filter: a = Int64(1)
3235          TableScan: test
3236        "
3237        )
3238    }
3239
3240    #[test]
3241    fn multi_combined_filter() -> Result<()> {
3242        let plan = table_scan_with_pushdown_provider_builder(
3243            TableProviderFilterPushDown::Inexact,
3244            vec![col("a").eq(lit(10i64)), col("b").gt(lit(11i64))],
3245            Some(vec![0]),
3246        )?
3247        .filter(and(col("a").eq(lit(10i64)), col("b").gt(lit(11i64))))?
3248        .project(vec![col("a"), col("b")])?
3249        .build()?;
3250
3251        assert_optimized_plan_equal!(
3252            plan,
3253            @r"
3254        Projection: a, b
3255          Filter: a = Int64(10) AND b > Int64(11)
3256            TableScan: test projection=[a], partial_filters=[a = Int64(10), b > Int64(11)]
3257        "
3258        )
3259    }
3260
3261    #[test]
3262    fn multi_combined_two_filters() -> Result<()> {
3263        let plan = table_scan_with_pushdown_provider_builder(
3264            TableProviderFilterPushDown::Inexact,
3265            vec![col("a").eq(lit(10i64)), col("b").gt(lit(11i64))],
3266            Some(vec![0]),
3267        )?
3268        .filter(col("a").eq(lit(10i64)))?
3269        .filter(col("b").gt(lit(11i64)))?
3270        .project(vec![col("a"), col("b")])?
3271        .build()?;
3272
3273        assert_optimized_plan_equal!(
3274            plan,
3275            @r"
3276        Projection: a, b
3277          Filter: a = Int64(10) AND b > Int64(11)
3278            TableScan: test projection=[a], partial_filters=[a = Int64(10), b > Int64(11)]
3279        "
3280        )
3281    }
3282
3283    #[test]
3284    fn multi_combined_filter_exact() -> Result<()> {
3285        let plan = table_scan_with_pushdown_provider_builder(
3286            TableProviderFilterPushDown::Exact,
3287            vec![],
3288            Some(vec![0]),
3289        )?
3290        .filter(and(col("a").eq(lit(10i64)), col("b").gt(lit(11i64))))?
3291        .project(vec![col("a"), col("b")])?
3292        .build()?;
3293
3294        assert_optimized_plan_equal!(
3295            plan,
3296            @r"
3297        Projection: a, b
3298          TableScan: test projection=[a], full_filters=[a = Int64(10), b > Int64(11)]
3299        "
3300        )
3301    }
3302
3303    #[test]
3304    fn multi_combined_two_filters_exact() -> Result<()> {
3305        let plan = table_scan_with_pushdown_provider_builder(
3306            TableProviderFilterPushDown::Exact,
3307            vec![],
3308            Some(vec![0]),
3309        )?
3310        .filter(col("a").eq(lit(10i64)))?
3311        .filter(col("b").gt(lit(11i64)))?
3312        .project(vec![col("a"), col("b")])?
3313        .build()?;
3314
3315        assert_optimized_plan_equal!(
3316            plan,
3317            @r"
3318        Projection: a, b
3319          TableScan: test projection=[a], full_filters=[a = Int64(10), b > Int64(11)]
3320        "
3321        )
3322    }
3323
3324    #[test]
3325    fn test_filter_with_alias() -> Result<()> {
3326        // in table scan the true col name is 'test.a',
3327        // but we rename it as 'b', and use col 'b' in filter
3328        // we need rewrite filter col before push down.
3329        let table_scan = test_table_scan()?;
3330        let plan = LogicalPlanBuilder::from(table_scan)
3331            .project(vec![col("a").alias("b"), col("c")])?
3332            .filter(and(col("b").gt(lit(10i64)), col("c").gt(lit(10i64))))?
3333            .build()?;
3334
3335        // filter on col b
3336        assert_snapshot!(plan,
3337        @r"
3338        Filter: b > Int64(10) AND test.c > Int64(10)
3339          Projection: test.a AS b, test.c
3340            TableScan: test
3341        ",
3342        );
3343        // rewrite filter col b to test.a
3344        assert_optimized_plan_equal!(
3345            plan,
3346            @r"
3347        Projection: test.a AS b, test.c
3348          TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]
3349        "
3350        )
3351    }
3352
3353    #[test]
3354    fn test_filter_with_alias_2() -> Result<()> {
3355        // in table scan the true col name is 'test.a',
3356        // but we rename it as 'b', and use col 'b' in filter
3357        // we need rewrite filter col before push down.
3358        let table_scan = test_table_scan()?;
3359        let plan = LogicalPlanBuilder::from(table_scan)
3360            .project(vec![col("a").alias("b"), col("c")])?
3361            .project(vec![col("b"), col("c")])?
3362            .filter(and(col("b").gt(lit(10i64)), col("c").gt(lit(10i64))))?
3363            .build()?;
3364
3365        // filter on col b
3366        assert_snapshot!(plan,
3367        @r"
3368        Filter: b > Int64(10) AND test.c > Int64(10)
3369          Projection: b, test.c
3370            Projection: test.a AS b, test.c
3371              TableScan: test
3372        ",
3373        );
3374        // rewrite filter col b to test.a
3375        assert_optimized_plan_equal!(
3376            plan,
3377            @r"
3378        Projection: b, test.c
3379          Projection: test.a AS b, test.c
3380            TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]
3381        "
3382        )
3383    }
3384
3385    #[test]
3386    fn test_filter_with_multi_alias() -> Result<()> {
3387        let table_scan = test_table_scan()?;
3388        let plan = LogicalPlanBuilder::from(table_scan)
3389            .project(vec![col("a").alias("b"), col("c").alias("d")])?
3390            .filter(and(col("b").gt(lit(10i64)), col("d").gt(lit(10i64))))?
3391            .build()?;
3392
3393        // filter on col b and d
3394        assert_snapshot!(plan,
3395        @r"
3396        Filter: b > Int64(10) AND d > Int64(10)
3397          Projection: test.a AS b, test.c AS d
3398            TableScan: test
3399        ",
3400        );
3401        // rewrite filter col b to test.a, col d to test.c
3402        assert_optimized_plan_equal!(
3403            plan,
3404            @r"
3405        Projection: test.a AS b, test.c AS d
3406          TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]
3407        "
3408        )
3409    }
3410
3411    /// predicate on join key in filter expression should be pushed down to both inputs
3412    #[test]
3413    fn join_filter_with_alias() -> Result<()> {
3414        let table_scan = test_table_scan()?;
3415        let left = LogicalPlanBuilder::from(table_scan)
3416            .project(vec![col("a").alias("c")])?
3417            .build()?;
3418        let right_table_scan = test_table_scan_with_name("test2")?;
3419        let right = LogicalPlanBuilder::from(right_table_scan)
3420            .project(vec![col("b").alias("d")])?
3421            .build()?;
3422        let filter = col("c").gt(lit(1u32));
3423        let plan = LogicalPlanBuilder::from(left)
3424            .join(
3425                right,
3426                JoinType::Inner,
3427                (vec![Column::from_name("c")], vec![Column::from_name("d")]),
3428                Some(filter),
3429            )?
3430            .build()?;
3431
3432        assert_snapshot!(plan,
3433        @r"
3434        Inner Join: c = d Filter: c > UInt32(1)
3435          Projection: test.a AS c
3436            TableScan: test
3437          Projection: test2.b AS d
3438            TableScan: test2
3439        ",
3440        );
3441        // Change filter on col `c`, 'd' to `test.a`, 'test.b'
3442        assert_optimized_plan_equal!(
3443            plan,
3444            @r"
3445        Inner Join: c = d
3446          Projection: test.a AS c
3447            TableScan: test, full_filters=[test.a > UInt32(1)]
3448          Projection: test2.b AS d
3449            TableScan: test2, full_filters=[test2.b > UInt32(1)]
3450        "
3451        )
3452    }
3453
3454    #[test]
3455    fn test_in_filter_with_alias() -> Result<()> {
3456        // in table scan the true col name is 'test.a',
3457        // but we rename it as 'b', and use col 'b' in filter
3458        // we need rewrite filter col before push down.
3459        let table_scan = test_table_scan()?;
3460        let filter_value = vec![lit(1u32), lit(2u32), lit(3u32), lit(4u32)];
3461        let plan = LogicalPlanBuilder::from(table_scan)
3462            .project(vec![col("a").alias("b"), col("c")])?
3463            .filter(in_list(col("b"), filter_value, false))?
3464            .build()?;
3465
3466        // filter on col b
3467        assert_snapshot!(plan,
3468        @r"
3469        Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])
3470          Projection: test.a AS b, test.c
3471            TableScan: test
3472        ",
3473        );
3474        // rewrite filter col b to test.a
3475        assert_optimized_plan_equal!(
3476            plan,
3477            @r"
3478        Projection: test.a AS b, test.c
3479          TableScan: test, full_filters=[test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])]
3480        "
3481        )
3482    }
3483
3484    #[test]
3485    fn test_in_filter_with_alias_2() -> Result<()> {
3486        // in table scan the true col name is 'test.a',
3487        // but we rename it as 'b', and use col 'b' in filter
3488        // we need rewrite filter col before push down.
3489        let table_scan = test_table_scan()?;
3490        let filter_value = vec![lit(1u32), lit(2u32), lit(3u32), lit(4u32)];
3491        let plan = LogicalPlanBuilder::from(table_scan)
3492            .project(vec![col("a").alias("b"), col("c")])?
3493            .project(vec![col("b"), col("c")])?
3494            .filter(in_list(col("b"), filter_value, false))?
3495            .build()?;
3496
3497        // filter on col b
3498        assert_snapshot!(plan,
3499        @r"
3500        Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])
3501          Projection: b, test.c
3502            Projection: test.a AS b, test.c
3503              TableScan: test
3504        ",
3505        );
3506        // rewrite filter col b to test.a
3507        assert_optimized_plan_equal!(
3508            plan,
3509            @r"
3510        Projection: b, test.c
3511          Projection: test.a AS b, test.c
3512            TableScan: test, full_filters=[test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])]
3513        "
3514        )
3515    }
3516
3517    #[test]
3518    fn test_in_subquery_with_alias() -> Result<()> {
3519        // in table scan the true col name is 'test.a',
3520        // but we rename it as 'b', and use col 'b' in subquery filter
3521        let table_scan = test_table_scan()?;
3522        let table_scan_sq = test_table_scan_with_name("sq")?;
3523        let subplan = Arc::new(
3524            LogicalPlanBuilder::from(table_scan_sq)
3525                .project(vec![col("c")])?
3526                .build()?,
3527        );
3528        let plan = LogicalPlanBuilder::from(table_scan)
3529            .project(vec![col("a").alias("b"), col("c")])?
3530            .filter(in_subquery(col("b"), subplan))?
3531            .build()?;
3532
3533        // filter on col b in subquery
3534        assert_snapshot!(plan,
3535        @r"
3536        Filter: b IN (<subquery>)
3537          Subquery:
3538            Projection: sq.c
3539              TableScan: sq
3540          Projection: test.a AS b, test.c
3541            TableScan: test
3542        ",
3543        );
3544        // rewrite filter col b to test.a
3545        assert_optimized_plan_equal!(
3546            plan,
3547            @r"
3548        Projection: test.a AS b, test.c
3549          TableScan: test, full_filters=[test.a IN (<subquery>)]
3550            Subquery:
3551              Projection: sq.c
3552                TableScan: sq
3553        "
3554        )
3555    }
3556
3557    #[test]
3558    fn test_propagation_of_optimized_inner_filters_with_projections() -> Result<()> {
3559        // SELECT a FROM (SELECT 1 AS a) b WHERE b.a = 1
3560        let plan = LogicalPlanBuilder::empty(true)
3561            .project(vec![lit(0i64).alias("a")])?
3562            .alias("b")?
3563            .project(vec![col("b.a")])?
3564            .alias("b")?
3565            .filter(col("b.a").eq(lit(1i64)))?
3566            .project(vec![col("b.a")])?
3567            .build()?;
3568
3569        assert_snapshot!(plan,
3570        @r"
3571        Projection: b.a
3572          Filter: b.a = Int64(1)
3573            SubqueryAlias: b
3574              Projection: b.a
3575                SubqueryAlias: b
3576                  Projection: Int64(0) AS a
3577                    EmptyRelation: rows=1
3578        ",
3579        );
3580        // Ensure that the predicate without any columns (0 = 1) is
3581        // still there.
3582        assert_optimized_plan_equal!(
3583            plan,
3584            @r"
3585        Projection: b.a
3586          SubqueryAlias: b
3587            Projection: b.a
3588              SubqueryAlias: b
3589                Projection: Int64(0) AS a
3590                  Filter: Int64(0) = Int64(1)
3591                    EmptyRelation: rows=1
3592        "
3593        )
3594    }
3595
3596    #[test]
3597    fn test_crossjoin_with_or_clause() -> Result<()> {
3598        // select * from test,test1 where (test.a = test1.a and test.b > 1) or (test.b = test1.b and test.c < 10);
3599        let table_scan = test_table_scan()?;
3600        let left = LogicalPlanBuilder::from(table_scan)
3601            .project(vec![col("a"), col("b"), col("c")])?
3602            .build()?;
3603        let right_table_scan = test_table_scan_with_name("test1")?;
3604        let right = LogicalPlanBuilder::from(right_table_scan)
3605            .project(vec![col("a").alias("d"), col("a").alias("e")])?
3606            .build()?;
3607        let filter = or(
3608            and(col("a").eq(col("d")), col("b").gt(lit(1u32))),
3609            and(col("b").eq(col("e")), col("c").lt(lit(10u32))),
3610        );
3611        let plan = LogicalPlanBuilder::from(left)
3612            .cross_join(right)?
3613            .filter(filter)?
3614            .build()?;
3615
3616        assert_optimized_plan_eq_with_rewrite_predicate!(plan.clone(), @r"
3617        Inner Join:  Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10)
3618          Projection: test.a, test.b, test.c
3619            TableScan: test, full_filters=[test.b > UInt32(1) OR test.c < UInt32(10)]
3620          Projection: test1.a AS d, test1.a AS e
3621            TableScan: test1
3622        ")?;
3623
3624        // Originally global state which can help to avoid duplicate Filters been generated and pushed down.
3625        // Now the global state is removed. Need to double confirm that avoid duplicate Filters.
3626        let optimized_plan = PushDownFilter::new()
3627            .rewrite(plan, &OptimizerContext::new())
3628            .expect("failed to optimize plan")
3629            .data;
3630        assert_optimized_plan_equal!(
3631            optimized_plan,
3632            @r"
3633        Inner Join:  Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10)
3634          Projection: test.a, test.b, test.c
3635            TableScan: test, full_filters=[test.b > UInt32(1) OR test.c < UInt32(10)]
3636          Projection: test1.a AS d, test1.a AS e
3637            TableScan: test1
3638        "
3639        )
3640    }
3641
3642    #[test]
3643    fn left_semi_join() -> Result<()> {
3644        let left = test_table_scan_with_name("test1")?;
3645        let right_table_scan = test_table_scan_with_name("test2")?;
3646        let right = LogicalPlanBuilder::from(right_table_scan)
3647            .project(vec![col("a"), col("b")])?
3648            .build()?;
3649        let plan = LogicalPlanBuilder::from(left)
3650            .join(
3651                right,
3652                JoinType::LeftSemi,
3653                (
3654                    vec![Column::from_qualified_name("test1.a")],
3655                    vec![Column::from_qualified_name("test2.a")],
3656                ),
3657                None,
3658            )?
3659            .filter(col("test2.a").lt_eq(lit(1i64)))?
3660            .build()?;
3661
3662        // not part of the test, just good to know:
3663        assert_snapshot!(plan,
3664        @r"
3665        Filter: test2.a <= Int64(1)
3666          LeftSemi Join: test1.a = test2.a
3667            TableScan: test1
3668            Projection: test2.a, test2.b
3669              TableScan: test2
3670        ",
3671        );
3672        // Inferred the predicate `test1.a <= Int64(1)` and push it down to the left side.
3673        assert_optimized_plan_equal!(
3674            plan,
3675            @r"
3676        Filter: test2.a <= Int64(1)
3677          LeftSemi Join: test1.a = test2.a
3678            TableScan: test1, full_filters=[test1.a <= Int64(1)]
3679            Projection: test2.a, test2.b
3680              TableScan: test2
3681        "
3682        )
3683    }
3684
3685    #[test]
3686    fn left_semi_join_with_filters() -> Result<()> {
3687        let left = test_table_scan_with_name("test1")?;
3688        let right_table_scan = test_table_scan_with_name("test2")?;
3689        let right = LogicalPlanBuilder::from(right_table_scan)
3690            .project(vec![col("a"), col("b")])?
3691            .build()?;
3692        let plan = LogicalPlanBuilder::from(left)
3693            .join(
3694                right,
3695                JoinType::LeftSemi,
3696                (
3697                    vec![Column::from_qualified_name("test1.a")],
3698                    vec![Column::from_qualified_name("test2.a")],
3699                ),
3700                Some(
3701                    col("test1.b")
3702                        .gt(lit(1u32))
3703                        .and(col("test2.b").gt(lit(2u32))),
3704                ),
3705            )?
3706            .build()?;
3707
3708        // not part of the test, just good to know:
3709        assert_snapshot!(plan,
3710        @r"
3711        LeftSemi Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
3712          TableScan: test1
3713          Projection: test2.a, test2.b
3714            TableScan: test2
3715        ",
3716        );
3717        // Both side will be pushed down.
3718        assert_optimized_plan_equal!(
3719            plan,
3720            @r"
3721        LeftSemi Join: test1.a = test2.a
3722          TableScan: test1, full_filters=[test1.b > UInt32(1)]
3723          Projection: test2.a, test2.b
3724            TableScan: test2, full_filters=[test2.b > UInt32(2)]
3725        "
3726        )
3727    }
3728
3729    #[test]
3730    fn right_semi_join() -> Result<()> {
3731        let left = test_table_scan_with_name("test1")?;
3732        let right_table_scan = test_table_scan_with_name("test2")?;
3733        let right = LogicalPlanBuilder::from(right_table_scan)
3734            .project(vec![col("a"), col("b")])?
3735            .build()?;
3736        let plan = LogicalPlanBuilder::from(left)
3737            .join(
3738                right,
3739                JoinType::RightSemi,
3740                (
3741                    vec![Column::from_qualified_name("test1.a")],
3742                    vec![Column::from_qualified_name("test2.a")],
3743                ),
3744                None,
3745            )?
3746            .filter(col("test1.a").lt_eq(lit(1i64)))?
3747            .build()?;
3748
3749        // not part of the test, just good to know:
3750        assert_snapshot!(plan,
3751        @r"
3752        Filter: test1.a <= Int64(1)
3753          RightSemi Join: test1.a = test2.a
3754            TableScan: test1
3755            Projection: test2.a, test2.b
3756              TableScan: test2
3757        ",
3758        );
3759        // Inferred the predicate `test2.a <= Int64(1)` and push it down to the right side.
3760        assert_optimized_plan_equal!(
3761            plan,
3762            @r"
3763        Filter: test1.a <= Int64(1)
3764          RightSemi Join: test1.a = test2.a
3765            TableScan: test1
3766            Projection: test2.a, test2.b
3767              TableScan: test2, full_filters=[test2.a <= Int64(1)]
3768        "
3769        )
3770    }
3771
3772    #[test]
3773    fn right_semi_join_with_filters() -> Result<()> {
3774        let left = test_table_scan_with_name("test1")?;
3775        let right_table_scan = test_table_scan_with_name("test2")?;
3776        let right = LogicalPlanBuilder::from(right_table_scan)
3777            .project(vec![col("a"), col("b")])?
3778            .build()?;
3779        let plan = LogicalPlanBuilder::from(left)
3780            .join(
3781                right,
3782                JoinType::RightSemi,
3783                (
3784                    vec![Column::from_qualified_name("test1.a")],
3785                    vec![Column::from_qualified_name("test2.a")],
3786                ),
3787                Some(
3788                    col("test1.b")
3789                        .gt(lit(1u32))
3790                        .and(col("test2.b").gt(lit(2u32))),
3791                ),
3792            )?
3793            .build()?;
3794
3795        // not part of the test, just good to know:
3796        assert_snapshot!(plan,
3797        @r"
3798        RightSemi Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
3799          TableScan: test1
3800          Projection: test2.a, test2.b
3801            TableScan: test2
3802        ",
3803        );
3804        // Both side will be pushed down.
3805        assert_optimized_plan_equal!(
3806            plan,
3807            @r"
3808        RightSemi Join: test1.a = test2.a
3809          TableScan: test1, full_filters=[test1.b > UInt32(1)]
3810          Projection: test2.a, test2.b
3811            TableScan: test2, full_filters=[test2.b > UInt32(2)]
3812        "
3813        )
3814    }
3815
3816    #[test]
3817    fn left_anti_join() -> Result<()> {
3818        let table_scan = test_table_scan_with_name("test1")?;
3819        let left = LogicalPlanBuilder::from(table_scan)
3820            .project(vec![col("a"), col("b")])?
3821            .build()?;
3822        let right_table_scan = test_table_scan_with_name("test2")?;
3823        let right = LogicalPlanBuilder::from(right_table_scan)
3824            .project(vec![col("a"), col("b")])?
3825            .build()?;
3826        let plan = LogicalPlanBuilder::from(left)
3827            .join(
3828                right,
3829                JoinType::LeftAnti,
3830                (
3831                    vec![Column::from_qualified_name("test1.a")],
3832                    vec![Column::from_qualified_name("test2.a")],
3833                ),
3834                None,
3835            )?
3836            .filter(col("test2.a").gt(lit(2u32)))?
3837            .build()?;
3838
3839        // not part of the test, just good to know:
3840        assert_snapshot!(plan,
3841        @r"
3842        Filter: test2.a > UInt32(2)
3843          LeftAnti Join: test1.a = test2.a
3844            Projection: test1.a, test1.b
3845              TableScan: test1
3846            Projection: test2.a, test2.b
3847              TableScan: test2
3848        ",
3849        );
3850        // For left anti, filter of the right side filter can be pushed down.
3851        assert_optimized_plan_equal!(
3852            plan,
3853            @r"
3854        Filter: test2.a > UInt32(2)
3855          LeftAnti Join: test1.a = test2.a
3856            Projection: test1.a, test1.b
3857              TableScan: test1, full_filters=[test1.a > UInt32(2)]
3858            Projection: test2.a, test2.b
3859              TableScan: test2
3860        "
3861        )
3862    }
3863
3864    #[test]
3865    fn left_anti_join_with_filters() -> Result<()> {
3866        let table_scan = test_table_scan_with_name("test1")?;
3867        let left = LogicalPlanBuilder::from(table_scan)
3868            .project(vec![col("a"), col("b")])?
3869            .build()?;
3870        let right_table_scan = test_table_scan_with_name("test2")?;
3871        let right = LogicalPlanBuilder::from(right_table_scan)
3872            .project(vec![col("a"), col("b")])?
3873            .build()?;
3874        let plan = LogicalPlanBuilder::from(left)
3875            .join(
3876                right,
3877                JoinType::LeftAnti,
3878                (
3879                    vec![Column::from_qualified_name("test1.a")],
3880                    vec![Column::from_qualified_name("test2.a")],
3881                ),
3882                Some(
3883                    col("test1.b")
3884                        .gt(lit(1u32))
3885                        .and(col("test2.b").gt(lit(2u32))),
3886                ),
3887            )?
3888            .build()?;
3889
3890        // not part of the test, just good to know:
3891        assert_snapshot!(plan,
3892        @r"
3893        LeftAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
3894          Projection: test1.a, test1.b
3895            TableScan: test1
3896          Projection: test2.a, test2.b
3897            TableScan: test2
3898        ",
3899        );
3900        // For left anti, filter of the right side filter can be pushed down.
3901        assert_optimized_plan_equal!(
3902            plan,
3903            @r"
3904        LeftAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1)
3905          Projection: test1.a, test1.b
3906            TableScan: test1
3907          Projection: test2.a, test2.b
3908            TableScan: test2, full_filters=[test2.b > UInt32(2)]
3909        "
3910        )
3911    }
3912
3913    #[test]
3914    fn right_anti_join() -> Result<()> {
3915        let table_scan = test_table_scan_with_name("test1")?;
3916        let left = LogicalPlanBuilder::from(table_scan)
3917            .project(vec![col("a"), col("b")])?
3918            .build()?;
3919        let right_table_scan = test_table_scan_with_name("test2")?;
3920        let right = LogicalPlanBuilder::from(right_table_scan)
3921            .project(vec![col("a"), col("b")])?
3922            .build()?;
3923        let plan = LogicalPlanBuilder::from(left)
3924            .join(
3925                right,
3926                JoinType::RightAnti,
3927                (
3928                    vec![Column::from_qualified_name("test1.a")],
3929                    vec![Column::from_qualified_name("test2.a")],
3930                ),
3931                None,
3932            )?
3933            .filter(col("test1.a").gt(lit(2u32)))?
3934            .build()?;
3935
3936        // not part of the test, just good to know:
3937        assert_snapshot!(plan,
3938        @r"
3939        Filter: test1.a > UInt32(2)
3940          RightAnti Join: test1.a = test2.a
3941            Projection: test1.a, test1.b
3942              TableScan: test1
3943            Projection: test2.a, test2.b
3944              TableScan: test2
3945        ",
3946        );
3947        // For right anti, filter of the left side can be pushed down.
3948        assert_optimized_plan_equal!(
3949            plan,
3950            @r"
3951        Filter: test1.a > UInt32(2)
3952          RightAnti Join: test1.a = test2.a
3953            Projection: test1.a, test1.b
3954              TableScan: test1
3955            Projection: test2.a, test2.b
3956              TableScan: test2, full_filters=[test2.a > UInt32(2)]
3957        "
3958        )
3959    }
3960
3961    #[test]
3962    fn right_anti_join_with_filters() -> Result<()> {
3963        let table_scan = test_table_scan_with_name("test1")?;
3964        let left = LogicalPlanBuilder::from(table_scan)
3965            .project(vec![col("a"), col("b")])?
3966            .build()?;
3967        let right_table_scan = test_table_scan_with_name("test2")?;
3968        let right = LogicalPlanBuilder::from(right_table_scan)
3969            .project(vec![col("a"), col("b")])?
3970            .build()?;
3971        let plan = LogicalPlanBuilder::from(left)
3972            .join(
3973                right,
3974                JoinType::RightAnti,
3975                (
3976                    vec![Column::from_qualified_name("test1.a")],
3977                    vec![Column::from_qualified_name("test2.a")],
3978                ),
3979                Some(
3980                    col("test1.b")
3981                        .gt(lit(1u32))
3982                        .and(col("test2.b").gt(lit(2u32))),
3983                ),
3984            )?
3985            .build()?;
3986
3987        // not part of the test, just good to know:
3988        assert_snapshot!(plan,
3989        @r"
3990        RightAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
3991          Projection: test1.a, test1.b
3992            TableScan: test1
3993          Projection: test2.a, test2.b
3994            TableScan: test2
3995        ",
3996        );
3997        // For right anti, filter of the left side can be pushed down.
3998        assert_optimized_plan_equal!(
3999            plan,
4000            @r"
4001        RightAnti Join: test1.a = test2.a Filter: test2.b > UInt32(2)
4002          Projection: test1.a, test1.b
4003            TableScan: test1, full_filters=[test1.b > UInt32(1)]
4004          Projection: test2.a, test2.b
4005            TableScan: test2
4006        "
4007        )
4008    }
4009
4010    #[derive(Debug, PartialEq, Eq, Hash)]
4011    struct TestScalarUDF {
4012        signature: Signature,
4013    }
4014
4015    impl ScalarUDFImpl for TestScalarUDF {
4016        fn name(&self) -> &str {
4017            "TestScalarUDF"
4018        }
4019
4020        fn signature(&self) -> &Signature {
4021            &self.signature
4022        }
4023
4024        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
4025            Ok(DataType::Int32)
4026        }
4027
4028        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
4029            Ok(ColumnarValue::Scalar(ScalarValue::from(1)))
4030        }
4031    }
4032
4033    #[test]
4034    fn test_push_down_volatile_function_in_aggregate() -> Result<()> {
4035        // SELECT t.a, t.r FROM (SELECT a, sum(b),  TestScalarUDF()+1 AS r FROM test1 GROUP BY a) AS t WHERE t.a > 5 AND t.r > 0.5;
4036        let table_scan = test_table_scan_with_name("test1")?;
4037        let fun = ScalarUDF::new_from_impl(TestScalarUDF {
4038            signature: Signature::exact(vec![], Volatility::Volatile),
4039        });
4040        let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
4041
4042        let plan = LogicalPlanBuilder::from(table_scan)
4043            .aggregate(vec![col("a")], vec![sum(col("b"))])?
4044            .project(vec![col("a"), sum(col("b")), add(expr, lit(1)).alias("r")])?
4045            .alias("t")?
4046            .filter(col("t.a").gt(lit(5)).and(col("t.r").gt(lit(0.5))))?
4047            .project(vec![col("t.a"), col("t.r")])?
4048            .build()?;
4049
4050        assert_snapshot!(plan,
4051        @r"
4052        Projection: t.a, t.r
4053          Filter: t.a > Int32(5) AND t.r > Float64(0.5)
4054            SubqueryAlias: t
4055              Projection: test1.a, sum(test1.b), TestScalarUDF() + Int32(1) AS r
4056                Aggregate: groupBy=[[test1.a]], aggr=[[sum(test1.b)]]
4057                  TableScan: test1
4058        ",
4059        );
4060        assert_optimized_plan_equal!(
4061            plan,
4062            @r"
4063        Projection: t.a, t.r
4064          SubqueryAlias: t
4065            Filter: r > Float64(0.5)
4066              Projection: test1.a, sum(test1.b), TestScalarUDF() + Int32(1) AS r
4067                Aggregate: groupBy=[[test1.a]], aggr=[[sum(test1.b)]]
4068                  TableScan: test1, full_filters=[test1.a > Int32(5)]
4069        "
4070        )
4071    }
4072
#[test]
fn test_push_down_volatile_function_in_join() -> Result<()> {
    // SELECT t.a, t.r
    // FROM (SELECT test1.a AS a, TestScalarUDF() AS r
    //       FROM test1 JOIN test2 ON test1.a = test2.a) AS t
    // WHERE t.r > 0.8;
    //
    // The only predicate reads `r`, the output of a volatile UDF. The optimized
    // snapshot keeps it directly above the projection that computes `r`; nothing
    // is pushed into the join or its inputs.
    let fun = ScalarUDF::new_from_impl(TestScalarUDF {
        signature: Signature::exact(vec![], Volatility::Volatile),
    });
    let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
    let left = test_table_scan_with_name("test1")?;
    let right = test_table_scan_with_name("test2")?;
    let plan = LogicalPlanBuilder::from(left)
        .join(
            right,
            JoinType::Inner,
            (
                vec![Column::from_qualified_name("test1.a")],
                vec![Column::from_qualified_name("test2.a")],
            ),
            None,
        )?
        .project(vec![col("test1.a").alias("a"), expr.alias("r")])?
        .alias("t")?
        .filter(col("t.r").gt(lit(0.8)))?
        .project(vec![col("t.a"), col("t.r")])?
        .build()?;

    assert_snapshot!(plan,
    @r"
        Projection: t.a, t.r
          Filter: t.r > Float64(0.8)
            SubqueryAlias: t
              Projection: test1.a AS a, TestScalarUDF() AS r
                Inner Join: test1.a = test2.a
                  TableScan: test1
                  TableScan: test2
        ",
    );
    assert_optimized_plan_equal!(
        plan,
        @r"
        Projection: t.a, t.r
          SubqueryAlias: t
            Filter: r > Float64(0.8)
              Projection: test1.a AS a, TestScalarUDF() AS r
                Inner Join: test1.a = test2.a
                  TableScan: test1
                  TableScan: test2
        "
    )
}
4124
#[test]
fn test_push_down_volatile_table_scan() -> Result<()> {
    // Plan shape: Filter(TestScalarUDF() > 0.1) over Projection(a, b) over
    // TableScan(test), i.e. roughly
    //   SELECT * FROM (SELECT a, b FROM test) WHERE TestScalarUDF() > 0.1;
    // (no `t` alias is introduced anywhere in this plan).
    //
    // The predicate calls a volatile UDF. The optimized snapshot shows it moved
    // below the projection but kept as a separate Filter instead of becoming a
    // scan filter.
    let table_scan = test_table_scan()?;
    let fun = ScalarUDF::new_from_impl(TestScalarUDF {
        signature: Signature::exact(vec![], Volatility::Volatile),
    });
    let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
    let plan = LogicalPlanBuilder::from(table_scan)
        .project(vec![col("a"), col("b")])?
        .filter(expr.gt(lit(0.1)))?
        .build()?;

    assert_snapshot!(plan,
    @r"
        Filter: TestScalarUDF() > Float64(0.1)
          Projection: test.a, test.b
            TableScan: test
        ",
    );
    assert_optimized_plan_equal!(
        plan,
        @r"
        Projection: test.a, test.b
          Filter: TestScalarUDF() > Float64(0.1)
            TableScan: test
        "
    )
}
4154
#[test]
fn test_push_down_volatile_mixed_table_scan() -> Result<()> {
    // Plan shape: Filter(TestScalarUDF() > 0.1 AND t.a > 5 AND t.b > 10) over
    // Projection(a, b) over TableScan(test).
    //
    // The optimized snapshot shows:
    // - the Filter moves below the projection;
    // - the volatile conjunct stays in that Filter;
    // - the two column predicates become `full_filters` of the scan.
    //
    // NOTE(review): the column predicates use the qualifier `t`, but this plan
    // never introduces a `t` alias (the scan is `test`). Both snapshots show them
    // carried through verbatim as `t.a` / `t.b`. `test.a` / `test.b` were
    // presumably intended; confirm before relying on this test for column
    // resolution.
    let table_scan = test_table_scan()?;
    let fun = ScalarUDF::new_from_impl(TestScalarUDF {
        signature: Signature::exact(vec![], Volatility::Volatile),
    });
    let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
    let plan = LogicalPlanBuilder::from(table_scan)
        .project(vec![col("a"), col("b")])?
        .filter(
            expr.gt(lit(0.1))
                .and(col("t.a").gt(lit(5)))
                .and(col("t.b").gt(lit(10))),
        )?
        .build()?;

    assert_snapshot!(plan,
    @r"
        Filter: TestScalarUDF() > Float64(0.1) AND t.a > Int32(5) AND t.b > Int32(10)
          Projection: test.a, test.b
            TableScan: test
        ",
    );
    assert_optimized_plan_equal!(
        plan,
        @r"
        Projection: test.a, test.b
          Filter: TestScalarUDF() > Float64(0.1)
            TableScan: test, full_filters=[t.a > Int32(5), t.b > Int32(10)]
        "
    )
}
4188
#[test]
fn test_push_down_volatile_mixed_unsupported_table_scan() -> Result<()> {
    // Same predicate as the mixed volatile test, but the scan comes from a
    // provider that reports `TableProviderFilterPushDown::Unsupported`.
    //
    // Plan shape: Filter(TestScalarUDF() > 0.1 AND t.a > 5 AND t.b > 10) over
    // Projection(a, b) over TableScan(test).
    //
    // The optimized snapshot shows:
    // - the whole conjunction moves below the projection as one Filter, with the
    //   column predicates ahead of the volatile one;
    // - nothing is handed to the scan.
    //
    // NOTE(review): the column predicates use the qualifier `t`, but no `t`
    // alias exists in this plan. `a` / `b` (or `test.a` / `test.b`) were
    // presumably intended; confirm.
    let fun = ScalarUDF::new_from_impl(TestScalarUDF {
        signature: Signature::exact(vec![], Volatility::Volatile),
    });
    let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
    let plan = table_scan_with_pushdown_provider_builder(
        TableProviderFilterPushDown::Unsupported,
        vec![],
        None,
    )?
    .project(vec![col("a"), col("b")])?
    .filter(
        expr.gt(lit(0.1))
            .and(col("t.a").gt(lit(5)))
            .and(col("t.b").gt(lit(10))),
    )?
    .build()?;

    assert_snapshot!(plan,
    @r"
        Filter: TestScalarUDF() > Float64(0.1) AND t.a > Int32(5) AND t.b > Int32(10)
          Projection: a, b
            TableScan: test
        ",
    );
    assert_optimized_plan_equal!(
        plan,
        @r"
        Projection: a, b
          Filter: t.a > Int32(5) AND t.b > Int32(10) AND TestScalarUDF() > Float64(0.1)
            TableScan: test
        "
    )
}
4225
#[test]
fn test_push_down_filter_to_user_defined_node() -> Result<()> {
    // Minimal user-defined logical node.
    // - It is a leaf: no inputs and no expressions.
    // - Its schema is a single non-nullable `Int64` column named `a`.
    #[derive(Debug, Hash, Eq, PartialEq)]
    struct TestUserNode {
        schema: DFSchemaRef,
    }

    // NOTE(review): `PartialOrd` is presumably required by
    // `UserDefinedLogicalNodeCore`. This node never needs a meaningful ordering,
    // so every comparison reports "incomparable".
    impl PartialOrd for TestUserNode {
        fn partial_cmp(&self, _other: &Self) -> Option<Ordering> {
            None
        }
    }

    impl TestUserNode {
        fn new() -> Self {
            let schema = Arc::new(
                DFSchema::new_with_metadata(
                    vec![(None, Field::new("a", DataType::Int64, false).into())],
                    Default::default(),
                )
                .unwrap(),
            );

            Self { schema }
        }
    }

    impl UserDefinedLogicalNodeCore for TestUserNode {
        fn name(&self) -> &str {
            "test_node"
        }

        fn inputs(&self) -> Vec<&LogicalPlan> {
            vec![]
        }

        fn schema(&self) -> &DFSchemaRef {
            &self.schema
        }

        fn expressions(&self) -> Vec<Expr> {
            vec![]
        }

        fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
            write!(f, "TestUserNode")
        }

        fn with_exprs_and_inputs(
            &self,
            exprs: Vec<Expr>,
            inputs: Vec<LogicalPlan>,
        ) -> Result<Self> {
            // The node reports no expressions and no inputs, so a rebuild must
            // receive none either.
            assert!(exprs.is_empty());
            assert!(inputs.is_empty());
            Ok(Self {
                schema: Arc::clone(&self.schema),
            })
        }
    }

    // Put `Filter(false)` directly on top of the leaf extension node.
    let node = LogicalPlan::Extension(Extension {
        node: Arc::new(TestUserNode::new()),
    });

    let plan = LogicalPlanBuilder::from(node).filter(lit(false))?.build()?;

    // Snapshot of the plan before optimization.
    assert_snapshot!(plan,
    @r"
        Filter: Boolean(false)
          TestUserNode
        ",
    );
    // `TestUserNode` has no inputs, so there is nothing below it for the filter
    // to move into. The optimized plan must be identical to the input plan.
    assert_optimized_plan_equal!(
        plan,
        @r"
        Filter: Boolean(false)
          TestUserNode
        "
    )
}
4311
/// Filters must NOT be pushed through projections whose expressions are
/// MoveTowardsLeafNodes (cheap expressions such as get_field).
///
/// Re-inlining such an expression into a filter gains nothing and makes the
/// optimizer unstable: ExtractLeafExpressions would undo the push-down, and the
/// two would keep undoing each other until the iteration limit is reached.
#[test]
fn filter_not_pushed_through_move_towards_leaves_projection() -> Result<()> {
    // `val` is produced by a MoveTowardsLeafNodes expression.
    let leaf_projection = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![leaf_udf_expr(col("a")).alias("val"), col("b"), col("c")])?
        .build()?;

    // A predicate on `val` therefore has to stay above the projection.
    let plan = LogicalPlanBuilder::from(leaf_projection)
        .filter(col("val").gt(lit(150i64)))?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
        Filter: val > Int64(150)
          Projection: leaf_udf(test.a) AS val, test.b, test.c
            TableScan: test
        "
    )
}
4345
/// Mixed predicates: the plain-column conjunct is pushed through the projection,
/// while the one over a MoveTowardsLeafNodes expression is kept above it.
#[test]
fn filter_mixed_predicates_partial_push() -> Result<()> {
    // `val > 150` reads a MoveTowardsLeafNodes output; `b > 5` reads a plain column.
    let predicate = col("val").gt(lit(150i64)).and(col("b").gt(lit(5i64)));
    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![leaf_udf_expr(col("a")).alias("val"), col("b"), col("c")])?
        .filter(predicate)?
        .build()?;

    // Only `b > 5` reaches the scan; `val > 150` remains above the projection.
    assert_optimized_plan_equal!(
        plan,
        @r"
        Filter: val > Int64(150)
          Projection: leaf_udf(test.a) AS val, test.b, test.c
            TableScan: test, full_filters=[test.b > Int64(5)]
        "
    )
}
4375
#[test]
fn filter_not_pushed_down_through_table_scan_with_fetch() -> Result<()> {
    let LogicalPlan::TableScan(scan) = test_table_scan()? else {
        unreachable!()
    };
    let limited_scan = LogicalPlan::TableScan(TableScan {
        fetch: Some(10),
        ..scan
    });
    let plan = LogicalPlanBuilder::from(limited_scan)
        .filter(col("a").gt(lit(10i64)))?
        .build()?;

    // The scan carries a fetch (limit), so the filter must NOT be folded into
    // it; the predicate stays in its own Filter above the scan.
    assert_optimized_plan_equal!(
        plan,
        @r"
        Filter: test.a > Int64(10)
          TableScan: test, fetch=10
        "
    )
}
4398
#[test]
fn filter_push_down_through_sort_without_fetch() -> Result<()> {
    let sort_key = col("a").sort(true, true);
    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .sort(vec![sort_key])?
        .filter(col("a").gt(lit(10i64)))?
        .build()?;

    // A sort with no fetch does not change which rows exist, so the predicate
    // sinks below it and all the way into the scan.
    assert_optimized_plan_equal!(
        plan,
        @r"
        Sort: test.a ASC NULLS FIRST
          TableScan: test, full_filters=[test.a > Int64(10)]
        "
    )
}
4415
#[test]
fn filter_not_pushed_down_through_sort_with_fetch() -> Result<()> {
    // ORDER BY a ASC NULLS FIRST LIMIT 5, followed by a filter on `a > 10`.
    let top_five = LogicalPlanBuilder::from(test_table_scan()?)
        .sort_with_limit(vec![col("a").sort(true, true)], Some(5))?;
    let plan = top_five.filter(col("a").gt(lit(10i64)))?.build()?;

    // The limit has to be applied before the filter. Pushing the filter below
    // the sort would change which rows survive, so it must stay on top.
    assert_optimized_plan_equal!(
        plan,
        @r"
        Filter: test.a > Int64(10)
          Sort: test.a ASC NULLS FIRST, fetch=5
            TableScan: test
        "
    )
}
4434}