datafusion_optimizer/
push_down_filter.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`PushDownFilter`] applies filters as early as possible
19
20use std::collections::{HashMap, HashSet};
21use std::sync::Arc;
22
23use arrow::datatypes::DataType;
24use indexmap::IndexSet;
25use itertools::Itertools;
26
27use datafusion_common::tree_node::{
28    Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
29};
30use datafusion_common::{
31    Column, DFSchema, Result, assert_eq_or_internal_err, assert_or_internal_err,
32    internal_err, plan_err, qualified_name,
33};
34use datafusion_expr::expr::WindowFunction;
35use datafusion_expr::expr_rewriter::replace_col;
36use datafusion_expr::logical_plan::{Join, JoinType, LogicalPlan, TableScan, Union};
37use datafusion_expr::utils::{
38    conjunction, expr_to_columns, split_conjunction, split_conjunction_owned,
39};
40use datafusion_expr::{
41    BinaryExpr, Expr, Filter, Operator, Projection, TableProviderFilterPushDown, and, or,
42};
43
44use crate::optimizer::ApplyOrder;
45use crate::simplify_expressions::simplify_predicates;
46use crate::utils::{has_all_column_refs, is_restrict_null_predicate};
47use crate::{OptimizerConfig, OptimizerRule};
48
/// Optimizer rule for pushing (moving) filter expressions down in a plan so
/// they are applied as early as possible.
///
/// # Introduction
///
/// The goal of this rule is to improve query performance by eliminating
/// redundant work.
///
/// For example, given a plan that sorts all values where `a > 10`:
///
/// ```text
///  Filter (a > 10)
///    Sort (a, b)
/// ```
///
/// A better plan is to filter the data *before* the Sort, which sorts fewer
/// rows and therefore does less work overall:
///
/// ```text
///  Sort (a, b)
///    Filter (a > 10)  <-- Filter is moved before the sort
/// ```
///
/// However, it is not always possible to push filters down. For example, given
/// a plan that finds the top 3 values and then keeps only those that are
/// greater than 10, pushing the filter below the limit would produce a
/// different result.
///
/// ```text
///  Filter (a > 10)   <-- can not move this Filter before the limit
///    Limit (fetch=3)
///      Sort (a, b)
/// ```
///
/// More formally, a filter-commutative operation is an operation `op` that
/// satisfies `filter(op(data)) = op(filter(data))`.
///
/// The filter-commutative property is plan- and column-specific. A filter on
/// `a` can be pushed through an `Aggregate(group_by = [a], agg = [sum(b)])`.
/// However, a filter on `sum(b)` can not be pushed through the same aggregate.
///
/// # Handling Conjunctions
///
/// It is possible to push down only **part** of a filter expression if it is
/// connected with `AND`s (more formally, if it is a "conjunction").
///
/// For example, given the following plan:
///
/// ```text
/// Filter(a > 10 AND sum(b) < 5)
///   Aggregate(group_by = [a], agg = [sum(b)])
/// ```
///
/// The `a > 10` is commutative with the `Aggregate` but `sum(b) < 5` is not.
/// Therefore it is possible to push only part of the expression, resulting in:
///
/// ```text
/// Filter(sum(b) < 5)
///   Aggregate(group_by = [a], agg = [sum(b)])
///     Filter(a > 10)
/// ```
///
/// # Handling Column Aliases
///
/// This optimizer must sometimes rewrite filter expressions when they are
/// pushed, for example if there is a projection that aliases `a+1` to `"b"`:
///
/// ```text
/// Filter (b > 10)
///     Projection: [a+1 AS "b"]  <-- changes the name of `a+1` to `b`
/// ```
///
/// To apply the filter prior to the `Projection`, all references to `b` must be
/// rewritten to `a+1`:
///
/// ```text
/// Projection: a+1 AS "b"
///     Filter: (a+1 > 10)  <--- changed from b to a+1
/// ```
///
/// # Implementation Notes
///
/// This implementation performs a single pass through the plan, "pushing" down
/// filters. When it passes through a filter, it stores that filter, and when it
/// reaches a plan node that does not commute with that filter, it adds the
/// filter to that place. When it passes through a projection, it re-writes the
/// filter's expression taking into account that projection.
#[derive(Default, Debug)]
pub struct PushDownFilter {}
138
139/// For a given JOIN type, determine whether each input of the join is preserved
140/// for post-join (`WHERE` clause) filters.
141///
142/// It is only correct to push filters below a join for preserved inputs.
143///
144/// # Return Value
145/// A tuple of booleans - (left_preserved, right_preserved).
146///
147/// # "Preserved" input definition
148///
149/// We say a join side is preserved if the join returns all or a subset of the rows from
150/// the relevant side, such that each row of the output table directly maps to a row of
151/// the preserved input table. If a table is not preserved, it can provide extra null rows.
152/// That is, there may be rows in the output table that don't directly map to a row in the
153/// input table.
154///
155/// For example:
156///   - In an inner join, both sides are preserved, because each row of the output
157///     maps directly to a row from each side.
158///
159///   - In a left join, the left side is preserved (we can push predicates) but
160///     the right is not, because there may be rows in the output that don't
161///     directly map to a row in the right input (due to nulls filling where there
162///     is no match on the right).
163pub(crate) fn lr_is_preserved(join_type: JoinType) -> (bool, bool) {
164    match join_type {
165        JoinType::Inner => (true, true),
166        JoinType::Left => (true, false),
167        JoinType::Right => (false, true),
168        JoinType::Full => (false, false),
169        // No columns from the right side of the join can be referenced in output
170        // predicates for semi/anti joins, so whether we specify t/f doesn't matter.
171        JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => (true, false),
172        // No columns from the left side of the join can be referenced in output
173        // predicates for semi/anti joins, so whether we specify t/f doesn't matter.
174        JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => (false, true),
175    }
176}
177
178/// For a given JOIN type, determine whether each input of the join is preserved
179/// for the join condition (`ON` clause filters).
180///
181/// It is only correct to push filters below a join for preserved inputs.
182///
183/// # Return Value
184/// A tuple of booleans - (left_preserved, right_preserved).
185///
186/// See [`lr_is_preserved`] for a definition of "preserved".
187pub(crate) fn on_lr_is_preserved(join_type: JoinType) -> (bool, bool) {
188    match join_type {
189        JoinType::Inner => (true, true),
190        JoinType::Left => (false, true),
191        JoinType::Right => (true, false),
192        JoinType::Full => (false, false),
193        JoinType::LeftSemi | JoinType::RightSemi => (true, true),
194        JoinType::LeftAnti => (false, true),
195        JoinType::RightAnti => (true, false),
196        JoinType::LeftMark => (false, true),
197        JoinType::RightMark => (true, false),
198    }
199}
200
201/// Evaluates the columns referenced in the given expression to see if they refer
202/// only to the left or right columns
203#[derive(Debug)]
204struct ColumnChecker<'a> {
205    /// schema of left join input
206    left_schema: &'a DFSchema,
207    /// columns in left_schema, computed on demand
208    left_columns: Option<HashSet<Column>>,
209    /// schema of right join input
210    right_schema: &'a DFSchema,
211    /// columns in left_schema, computed on demand
212    right_columns: Option<HashSet<Column>>,
213}
214
215impl<'a> ColumnChecker<'a> {
216    fn new(left_schema: &'a DFSchema, right_schema: &'a DFSchema) -> Self {
217        Self {
218            left_schema,
219            left_columns: None,
220            right_schema,
221            right_columns: None,
222        }
223    }
224
225    /// Return true if the expression references only columns from the left side of the join
226    fn is_left_only(&mut self, predicate: &Expr) -> bool {
227        if self.left_columns.is_none() {
228            self.left_columns = Some(schema_columns(self.left_schema));
229        }
230        has_all_column_refs(predicate, self.left_columns.as_ref().unwrap())
231    }
232
233    /// Return true if the expression references only columns from the right side of the join
234    fn is_right_only(&mut self, predicate: &Expr) -> bool {
235        if self.right_columns.is_none() {
236            self.right_columns = Some(schema_columns(self.right_schema));
237        }
238        has_all_column_refs(predicate, self.right_columns.as_ref().unwrap())
239    }
240}
241
242/// Returns all columns in the schema
243fn schema_columns(schema: &DFSchema) -> HashSet<Column> {
244    schema
245        .iter()
246        .flat_map(|(qualifier, field)| {
247            [
248                Column::new(qualifier.cloned(), field.name()),
249                // we need to push down filter using unqualified column as well
250                Column::new_unqualified(field.name()),
251            ]
252        })
253        .collect::<HashSet<_>>()
254}
255
/// Determine whether the predicate can evaluate as the join conditions
///
/// Walks `predicate` top-down and stops at the first node that settles the
/// question:
/// * a subquery (`EXISTS`, `IN (<subquery>)`, scalar subquery), an outer
///   reference column, or an `UNNEST` makes the result `Ok(false)`;
/// * an aggregate function, window function, wildcard, or grouping set makes
///   the call return an internal error ("Unsupported predicate type").
///
/// If the walk finishes without hitting either kind of node, returns `Ok(true)`.
/// Which outcome wins when both kinds are present depends on which one the
/// traversal reaches first.
fn can_evaluate_as_join_condition(predicate: &Expr) -> Result<bool> {
    let mut is_evaluate = true;
    predicate.apply(|expr| match expr {
        // Leaf-like nodes: nothing below them to inspect.
        Expr::Column(_)
        | Expr::Literal(_, _)
        | Expr::Placeholder(_)
        | Expr::ScalarVariable(_, _) => Ok(TreeNodeRecursion::Jump),
        // Constructs that cannot be placed in a join filter; record and stop.
        Expr::Exists { .. }
        | Expr::InSubquery(_)
        | Expr::ScalarSubquery(_)
        | Expr::OuterReferenceColumn(_, _)
        | Expr::Unnest(_) => {
            is_evaluate = false;
            Ok(TreeNodeRecursion::Stop)
        }
        // Ordinary row-level expressions: keep inspecting their children.
        Expr::Alias(_)
        | Expr::BinaryExpr(_)
        | Expr::Like(_)
        | Expr::SimilarTo(_)
        | Expr::Not(_)
        | Expr::IsNotNull(_)
        | Expr::IsNull(_)
        | Expr::IsTrue(_)
        | Expr::IsFalse(_)
        | Expr::IsUnknown(_)
        | Expr::IsNotTrue(_)
        | Expr::IsNotFalse(_)
        | Expr::IsNotUnknown(_)
        | Expr::Negative(_)
        | Expr::Between(_)
        | Expr::Case(_)
        | Expr::Cast(_)
        | Expr::TryCast(_)
        | Expr::InList { .. }
        | Expr::ScalarFunction(_) => Ok(TreeNodeRecursion::Continue),
        // These are not expected inside a filter predicate at all; treat their
        // presence as an internal error rather than silently answering.
        // TODO: remove the next line after `Expr::Wildcard` is removed
        #[expect(deprecated)]
        Expr::AggregateFunction(_)
        | Expr::WindowFunction(_)
        | Expr::Wildcard { .. }
        | Expr::GroupingSet(_) => internal_err!("Unsupported predicate type"),
    })?;
    Ok(is_evaluate)
}
301
302/// examine OR clause to see if any useful clauses can be extracted and push down.
303/// extract at least one qual from each sub clauses of OR clause, then form the quals
304/// to new OR clause as predicate.
305///
306/// # Example
307/// ```text
308/// Filter: (a = c and a < 20) or (b = d and b > 10)
309///     join/crossjoin:
310///          TableScan: projection=[a, b]
311///          TableScan: projection=[c, d]
312/// ```
313///
314/// is optimized to
315///
316/// ```text
317/// Filter: (a = c and a < 20) or (b = d and b > 10)
318///     join/crossjoin:
319///          Filter: (a < 20) or (b > 10)
320///              TableScan: projection=[a, b]
321///          TableScan: projection=[c, d]
322/// ```
323///
324/// In general, predicates of this form:
325///
326/// ```sql
327/// (A AND B) OR (C AND D)
328/// ```
329///
330/// will be transformed to one of:
331///
332/// * `((A AND B) OR (C AND D)) AND (A OR C)`
333/// * `((A AND B) OR (C AND D)) AND ((A AND B) OR C)`
334/// * do nothing.
335fn extract_or_clauses_for_join<'a>(
336    filters: &'a [Expr],
337    schema: &'a DFSchema,
338) -> impl Iterator<Item = Expr> + 'a {
339    let schema_columns = schema_columns(schema);
340
341    // new formed OR clauses and their column references
342    filters.iter().filter_map(move |expr| {
343        if let Expr::BinaryExpr(BinaryExpr {
344            left,
345            op: Operator::Or,
346            right,
347        }) = expr
348        {
349            let left_expr = extract_or_clause(left.as_ref(), &schema_columns);
350            let right_expr = extract_or_clause(right.as_ref(), &schema_columns);
351
352            // If nothing can be extracted from any sub clauses, do nothing for this OR clause.
353            if let (Some(left_expr), Some(right_expr)) = (left_expr, right_expr) {
354                return Some(or(left_expr, right_expr));
355            }
356        }
357        None
358    })
359}
360
361/// extract qual from OR sub-clause.
362///
363/// A qual is extracted if it only contains set of column references in schema_columns.
364///
365/// For AND clause, we extract from both sub-clauses, then make new AND clause by extracted
366/// clauses if both extracted; Otherwise, use the extracted clause from any sub-clauses or None.
367///
368/// For OR clause, we extract from both sub-clauses, then make new OR clause by extracted clauses if both extracted;
369/// Otherwise, return None.
370///
371/// For other clause, apply the rule above to extract clause.
372fn extract_or_clause(expr: &Expr, schema_columns: &HashSet<Column>) -> Option<Expr> {
373    let mut predicate = None;
374
375    match expr {
376        Expr::BinaryExpr(BinaryExpr {
377            left: l_expr,
378            op: Operator::Or,
379            right: r_expr,
380        }) => {
381            let l_expr = extract_or_clause(l_expr, schema_columns);
382            let r_expr = extract_or_clause(r_expr, schema_columns);
383
384            if let (Some(l_expr), Some(r_expr)) = (l_expr, r_expr) {
385                predicate = Some(or(l_expr, r_expr));
386            }
387        }
388        Expr::BinaryExpr(BinaryExpr {
389            left: l_expr,
390            op: Operator::And,
391            right: r_expr,
392        }) => {
393            let l_expr = extract_or_clause(l_expr, schema_columns);
394            let r_expr = extract_or_clause(r_expr, schema_columns);
395
396            match (l_expr, r_expr) {
397                (Some(l_expr), Some(r_expr)) => {
398                    predicate = Some(and(l_expr, r_expr));
399                }
400                (Some(l_expr), None) => {
401                    predicate = Some(l_expr);
402                }
403                (None, Some(r_expr)) => {
404                    predicate = Some(r_expr);
405                }
406                (None, None) => {
407                    predicate = None;
408                }
409            }
410        }
411        _ => {
412            if has_all_column_refs(expr, schema_columns) {
413                predicate = Some(expr.clone());
414            }
415        }
416    }
417
418    predicate
419}
420
/// push down join/cross-join
///
/// Distributes the predicates associated with `join` across its inputs:
///
/// * `predicates` — conjuncts of a filter sitting above the join (`WHERE`
///   semantics). Each is pushed to the left/right input if it references only
///   that input's columns and that input is preserved per [`lr_is_preserved`].
///   For an inner join, a predicate that cannot be pushed but passes
///   [`can_evaluate_as_join_condition`] is moved into the join filter.
///   Anything else is kept in a `Filter` wrapped around the join.
/// * `inferred_join_predicates` — predicates derived through the join keys.
///   They follow the same push rules as `predicates` but are dropped, not kept,
///   when they cannot be pushed.
/// * `on_filter` — conjuncts of the join's `ON` filter. Each is pushed to an input
///   per [`on_lr_is_preserved`]; otherwise it remains in the join filter.
///
/// In addition, the `OR` predicates that could not be pushed are mined with
/// [`extract_or_clauses_for_join`] for weaker single-input predicates, which are
/// pushed to inputs that are preserved for the respective clause kind.
///
/// The join filter is replaced by the conjunction of the remaining join
/// conditions. Always returns `Transformed::yes`, even if nothing moved.
///
/// NOTE: pushed predicates are handed to `conjunction` in the order they are
/// appended to `left_push` / `right_push`, so the statement order below shapes
/// the generated filter expressions.
fn push_down_all_join(
    predicates: Vec<Expr>,
    inferred_join_predicates: Vec<Expr>,
    mut join: Join,
    on_filter: Vec<Expr>,
) -> Result<Transformed<LogicalPlan>> {
    let is_inner_join = join.join_type == JoinType::Inner;
    // Get pushable predicates from current optimizer state
    let (left_preserved, right_preserved) = lr_is_preserved(join.join_type);

    // The predicates can be divided to three categories:
    // 1) can push through join to its children(left or right)
    // 2) can be converted to join conditions if the join type is Inner
    // 3) should be kept as filter conditions
    let left_schema = join.left.schema();
    let right_schema = join.right.schema();
    let mut left_push = vec![];
    let mut right_push = vec![];
    let mut keep_predicates = vec![];
    let mut join_conditions = vec![];
    let mut checker = ColumnChecker::new(left_schema, right_schema);
    for predicate in predicates {
        if left_preserved && checker.is_left_only(&predicate) {
            left_push.push(predicate);
        } else if right_preserved && checker.is_right_only(&predicate) {
            right_push.push(predicate);
        } else if is_inner_join && can_evaluate_as_join_condition(&predicate)? {
            // Equality and non-equality predicates are not distinguished here;
            // ExtractEquijoinPredicate later extracts the equality predicates
            // and converts them to join `on` conditions.
            join_conditions.push(predicate);
        } else {
            keep_predicates.push(predicate);
        }
    }

    // For infer predicates, if they can not push through join, just drop them
    for predicate in inferred_join_predicates {
        if left_preserved && checker.is_left_only(&predicate) {
            left_push.push(predicate);
        } else if right_preserved && checker.is_right_only(&predicate) {
            right_push.push(predicate);
        }
    }

    // ON-clause conjuncts use the `ON`-specific preservation rules; those that
    // cannot be pushed stay part of the join filter.
    let mut on_filter_join_conditions = vec![];
    let (on_left_preserved, on_right_preserved) = on_lr_is_preserved(join.join_type);

    if !on_filter.is_empty() {
        for on in on_filter {
            if on_left_preserved && checker.is_left_only(&on) {
                left_push.push(on)
            } else if on_right_preserved && checker.is_right_only(&on) {
                right_push.push(on)
            } else {
                on_filter_join_conditions.push(on)
            }
        }
    }

    // Extract from OR clause, generate new predicates for both side of join if possible.
    // We only track the unpushable predicates above.
    if left_preserved {
        left_push.extend(extract_or_clauses_for_join(&keep_predicates, left_schema));
        left_push.extend(extract_or_clauses_for_join(&join_conditions, left_schema));
    }
    if right_preserved {
        right_push.extend(extract_or_clauses_for_join(&keep_predicates, right_schema));
        right_push.extend(extract_or_clauses_for_join(&join_conditions, right_schema));
    }

    // For predicates from join filter, we should check with if a join side is preserved
    // in term of join filtering.
    if on_left_preserved {
        left_push.extend(extract_or_clauses_for_join(
            &on_filter_join_conditions,
            left_schema,
        ));
    }
    if on_right_preserved {
        right_push.extend(extract_or_clauses_for_join(
            &on_filter_join_conditions,
            right_schema,
        ));
    }

    // Wrap each input in a Filter holding everything pushed to that side.
    if let Some(predicate) = conjunction(left_push) {
        join.left = Arc::new(LogicalPlan::Filter(Filter::try_new(predicate, join.left)?));
    }
    if let Some(predicate) = conjunction(right_push) {
        join.right =
            Arc::new(LogicalPlan::Filter(Filter::try_new(predicate, join.right)?));
    }

    // Add any new join conditions as the non join predicates
    join_conditions.extend(on_filter_join_conditions);
    join.filter = conjunction(join_conditions);

    // wrap the join on the filter whose predicates must be kept, if any
    let plan = LogicalPlan::Join(join);
    let plan = if let Some(predicate) = conjunction(keep_predicates) {
        LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(plan))?)
    } else {
        plan
    };
    Ok(Transformed::yes(plan))
}
528
529fn push_down_join(
530    join: Join,
531    parent_predicate: Option<&Expr>,
532) -> Result<Transformed<LogicalPlan>> {
533    // Split the parent predicate into individual conjunctive parts.
534    let predicates = parent_predicate
535        .map_or_else(Vec::new, |pred| split_conjunction_owned(pred.clone()));
536
537    // Extract conjunctions from the JOIN's ON filter, if present.
538    let on_filters = join
539        .filter
540        .as_ref()
541        .map_or_else(Vec::new, |filter| split_conjunction_owned(filter.clone()));
542
543    // Are there any new join predicates that can be inferred from the filter expressions?
544    let inferred_join_predicates =
545        infer_join_predicates(&join, &predicates, &on_filters)?;
546
547    if on_filters.is_empty()
548        && predicates.is_empty()
549        && inferred_join_predicates.is_empty()
550    {
551        return Ok(Transformed::no(LogicalPlan::Join(join)));
552    }
553
554    push_down_all_join(predicates, inferred_join_predicates, join, on_filters)
555}
556
557/// Extracts any equi-join join predicates from the given filter expressions.
558///
559/// Parameters
560/// * `join` the join in question
561///
562/// * `predicates` the pushed down filter expression
563///
564/// * `on_filters` filters from the join ON clause that have not already been
565///   identified as join predicates
566fn infer_join_predicates(
567    join: &Join,
568    predicates: &[Expr],
569    on_filters: &[Expr],
570) -> Result<Vec<Expr>> {
571    // Only allow both side key is column.
572    let join_col_keys = join
573        .on
574        .iter()
575        .filter_map(|(l, r)| {
576            let left_col = l.try_as_col()?;
577            let right_col = r.try_as_col()?;
578            Some((left_col, right_col))
579        })
580        .collect::<Vec<_>>();
581
582    let join_type = join.join_type;
583
584    let mut inferred_predicates = InferredPredicates::new(join_type);
585
586    infer_join_predicates_from_predicates(
587        &join_col_keys,
588        predicates,
589        &mut inferred_predicates,
590    )?;
591
592    infer_join_predicates_from_on_filters(
593        &join_col_keys,
594        join_type,
595        on_filters,
596        &mut inferred_predicates,
597    )?;
598
599    Ok(inferred_predicates.predicates)
600}
601
602/// Inferred predicates collector.
603/// When the JoinType is not Inner, we need to detect whether the inferred predicate can strictly
604/// filter out NULL, otherwise ignore it. e.g.
605/// ```text
606/// SELECT * FROM t1 LEFT JOIN t2 ON t1.c0 = t2.c0 WHERE t2.c0 IS NULL;
607/// ```
608/// We cannot infer the predicate `t1.c0 IS NULL`, otherwise the predicate will be pushed down to
609/// the left side, resulting in the wrong result.
610struct InferredPredicates {
611    predicates: Vec<Expr>,
612    is_inner_join: bool,
613}
614
615impl InferredPredicates {
616    fn new(join_type: JoinType) -> Self {
617        Self {
618            predicates: vec![],
619            is_inner_join: matches!(join_type, JoinType::Inner),
620        }
621    }
622
623    fn try_build_predicate(
624        &mut self,
625        predicate: Expr,
626        replace_map: &HashMap<&Column, &Column>,
627    ) -> Result<()> {
628        if self.is_inner_join
629            || matches!(
630                is_restrict_null_predicate(
631                    predicate.clone(),
632                    replace_map.keys().cloned()
633                ),
634                Ok(true)
635            )
636        {
637            self.predicates.push(replace_col(predicate, replace_map)?);
638        }
639
640        Ok(())
641    }
642}
643
644/// Infer predicates from the pushed down predicates.
645///
646/// Parameters
647/// * `join_col_keys` column pairs from the join ON clause
648///
649/// * `predicates` the pushed down predicates
650///
651/// * `inferred_predicates` the inferred results
652fn infer_join_predicates_from_predicates(
653    join_col_keys: &[(&Column, &Column)],
654    predicates: &[Expr],
655    inferred_predicates: &mut InferredPredicates,
656) -> Result<()> {
657    infer_join_predicates_impl::<true, true>(
658        join_col_keys,
659        predicates,
660        inferred_predicates,
661    )
662}
663
664/// Infer predicates from the join filter.
665///
666/// Parameters
667/// * `join_col_keys` column pairs from the join ON clause
668///
669/// * `join_type` the JoinType of Join
670///
671/// * `on_filters` filters from the join ON clause that have not already been
672///   identified as join predicates
673///
674/// * `inferred_predicates` the inferred results
675fn infer_join_predicates_from_on_filters(
676    join_col_keys: &[(&Column, &Column)],
677    join_type: JoinType,
678    on_filters: &[Expr],
679    inferred_predicates: &mut InferredPredicates,
680) -> Result<()> {
681    match join_type {
682        JoinType::Full | JoinType::LeftAnti | JoinType::RightAnti => Ok(()),
683        JoinType::Inner => infer_join_predicates_impl::<true, true>(
684            join_col_keys,
685            on_filters,
686            inferred_predicates,
687        ),
688        JoinType::Left | JoinType::LeftSemi | JoinType::LeftMark => {
689            infer_join_predicates_impl::<true, false>(
690                join_col_keys,
691                on_filters,
692                inferred_predicates,
693            )
694        }
695        JoinType::Right | JoinType::RightSemi | JoinType::RightMark => {
696            infer_join_predicates_impl::<false, true>(
697                join_col_keys,
698                on_filters,
699                inferred_predicates,
700            )
701        }
702    }
703}
704
705/// Infer predicates from the given predicates.
706///
707/// Parameters
708/// * `join_col_keys` column pairs from the join ON clause
709///
710/// * `input_predicates` the given predicates. It can be the pushed down predicates,
711///   or it can be the filters of the Join
712///
713/// * `inferred_predicates` the inferred results
714///
715/// * `ENABLE_LEFT_TO_RIGHT` indicates that the right table related predicate can
716///   be inferred from the left table related predicate
717///
718/// * `ENABLE_RIGHT_TO_LEFT` indicates that the left table related predicate can
719///   be inferred from the right table related predicate
720fn infer_join_predicates_impl<
721    const ENABLE_LEFT_TO_RIGHT: bool,
722    const ENABLE_RIGHT_TO_LEFT: bool,
723>(
724    join_col_keys: &[(&Column, &Column)],
725    input_predicates: &[Expr],
726    inferred_predicates: &mut InferredPredicates,
727) -> Result<()> {
728    for predicate in input_predicates {
729        let mut join_cols_to_replace = HashMap::new();
730
731        for &col in &predicate.column_refs() {
732            for (l, r) in join_col_keys.iter() {
733                if ENABLE_LEFT_TO_RIGHT && col == *l {
734                    join_cols_to_replace.insert(col, *r);
735                    break;
736                }
737                if ENABLE_RIGHT_TO_LEFT && col == *r {
738                    join_cols_to_replace.insert(col, *l);
739                    break;
740                }
741            }
742        }
743        if join_cols_to_replace.is_empty() {
744            continue;
745        }
746
747        inferred_predicates
748            .try_build_predicate(predicate.clone(), &join_cols_to_replace)?;
749    }
750    Ok(())
751}
752
753impl OptimizerRule for PushDownFilter {
    /// Identifier of this rule: `"push_down_filter"`.
    fn name(&self) -> &str {
        "push_down_filter"
    }

    /// Requests `ApplyOrder::TopDown` traversal from the optimizer driver.
    /// NOTE(review): presumably parent nodes are rewritten before their
    /// children, so a filter keeps moving down on subsequent visits — confirm
    /// against the `ApplyOrder` documentation.
    fn apply_order(&self) -> Option<ApplyOrder> {
        Some(ApplyOrder::TopDown)
    }

    /// Returns `true`: this rule implements `rewrite` (defined in this impl).
    fn supports_rewrite(&self) -> bool {
        true
    }
765
766    fn rewrite(
767        &self,
768        plan: LogicalPlan,
769        config: &dyn OptimizerConfig,
770    ) -> Result<Transformed<LogicalPlan>> {
771        let _ = config.options();
772        if let LogicalPlan::Join(join) = plan {
773            return push_down_join(join, None);
774        };
775
776        let plan_schema = Arc::clone(plan.schema());
777
778        let LogicalPlan::Filter(mut filter) = plan else {
779            return Ok(Transformed::no(plan));
780        };
781
782        let predicate = split_conjunction_owned(filter.predicate.clone());
783        let old_predicate_len = predicate.len();
784        let new_predicates = simplify_predicates(predicate)?;
785        if old_predicate_len != new_predicates.len() {
786            let Some(new_predicate) = conjunction(new_predicates) else {
787                // new_predicates is empty - remove the filter entirely
788                // Return the child plan without the filter
789                return Ok(Transformed::yes(Arc::unwrap_or_clone(filter.input)));
790            };
791            filter.predicate = new_predicate;
792        }
793
794        match Arc::unwrap_or_clone(filter.input) {
795            LogicalPlan::Filter(child_filter) => {
796                let parents_predicates = split_conjunction_owned(filter.predicate);
797
798                // remove duplicated filters
799                let child_predicates = split_conjunction_owned(child_filter.predicate);
800                let new_predicates = parents_predicates
801                    .into_iter()
802                    .chain(child_predicates)
803                    // use IndexSet to remove dupes while preserving predicate order
804                    .collect::<IndexSet<_>>()
805                    .into_iter()
806                    .collect::<Vec<_>>();
807
808                let Some(new_predicate) = conjunction(new_predicates) else {
809                    return plan_err!("at least one expression exists");
810                };
811                let new_filter = LogicalPlan::Filter(Filter::try_new(
812                    new_predicate,
813                    child_filter.input,
814                )?);
815                self.rewrite(new_filter, config)
816            }
817            LogicalPlan::Repartition(repartition) => {
818                let new_filter =
819                    Filter::try_new(filter.predicate, Arc::clone(&repartition.input))
820                        .map(LogicalPlan::Filter)?;
821                insert_below(LogicalPlan::Repartition(repartition), new_filter)
822            }
823            LogicalPlan::Distinct(distinct) => {
824                let new_filter =
825                    Filter::try_new(filter.predicate, Arc::clone(distinct.input()))
826                        .map(LogicalPlan::Filter)?;
827                insert_below(LogicalPlan::Distinct(distinct), new_filter)
828            }
829            LogicalPlan::Sort(sort) => {
830                let new_filter =
831                    Filter::try_new(filter.predicate, Arc::clone(&sort.input))
832                        .map(LogicalPlan::Filter)?;
833                insert_below(LogicalPlan::Sort(sort), new_filter)
834            }
835            LogicalPlan::SubqueryAlias(subquery_alias) => {
836                let mut replace_map = HashMap::new();
837                for (i, (qualifier, field)) in
838                    subquery_alias.input.schema().iter().enumerate()
839                {
840                    let (sub_qualifier, sub_field) =
841                        subquery_alias.schema.qualified_field(i);
842                    replace_map.insert(
843                        qualified_name(sub_qualifier, sub_field.name()),
844                        Expr::Column(Column::new(qualifier.cloned(), field.name())),
845                    );
846                }
847                let new_predicate = replace_cols_by_name(filter.predicate, &replace_map)?;
848
849                let new_filter = LogicalPlan::Filter(Filter::try_new(
850                    new_predicate,
851                    Arc::clone(&subquery_alias.input),
852                )?);
853                insert_below(LogicalPlan::SubqueryAlias(subquery_alias), new_filter)
854            }
855            LogicalPlan::Projection(projection) => {
856                let predicates = split_conjunction_owned(filter.predicate.clone());
857                let (new_projection, keep_predicate) =
858                    rewrite_projection(predicates, projection)?;
859                if new_projection.transformed {
860                    match keep_predicate {
861                        None => Ok(new_projection),
862                        Some(keep_predicate) => new_projection.map_data(|child_plan| {
863                            Filter::try_new(keep_predicate, Arc::new(child_plan))
864                                .map(LogicalPlan::Filter)
865                        }),
866                    }
867                } else {
868                    filter.input = Arc::new(new_projection.data);
869                    Ok(Transformed::no(LogicalPlan::Filter(filter)))
870                }
871            }
872            LogicalPlan::Unnest(mut unnest) => {
873                let predicates = split_conjunction_owned(filter.predicate.clone());
874                let mut non_unnest_predicates = vec![];
875                let mut unnest_predicates = vec![];
876                let mut unnest_struct_columns = vec![];
877
878                for idx in &unnest.struct_type_columns {
879                    let (sub_qualifier, field) =
880                        unnest.input.schema().qualified_field(*idx);
881                    let field_name = field.name().clone();
882
883                    if let DataType::Struct(children) = field.data_type() {
884                        for child in children {
885                            let child_name = child.name().clone();
886                            unnest_struct_columns.push(Column::new(
887                                sub_qualifier.cloned(),
888                                format!("{field_name}.{child_name}"),
889                            ));
890                        }
891                    }
892                }
893
894                for predicate in predicates {
895                    // collect all the Expr::Column in predicate recursively
896                    let mut accum: HashSet<Column> = HashSet::new();
897                    expr_to_columns(&predicate, &mut accum)?;
898
899                    let contains_list_columns =
900                        unnest.list_type_columns.iter().any(|(_, unnest_list)| {
901                            accum.contains(&unnest_list.output_column)
902                        });
903                    let contains_struct_columns =
904                        unnest_struct_columns.iter().any(|c| accum.contains(c));
905
906                    if contains_list_columns || contains_struct_columns {
907                        unnest_predicates.push(predicate);
908                    } else {
909                        non_unnest_predicates.push(predicate);
910                    }
911                }
912
913                // Unnest predicates should not be pushed down.
914                // If no non-unnest predicates exist, early return
915                if non_unnest_predicates.is_empty() {
916                    filter.input = Arc::new(LogicalPlan::Unnest(unnest));
917                    return Ok(Transformed::no(LogicalPlan::Filter(filter)));
918                }
919
920                // Push down non-unnest filter predicate
921                // Unnest
922                //   Unnest Input (Projection)
923                // -> rewritten to
924                // Unnest
925                //   Filter
926                //     Unnest Input (Projection)
927
928                let unnest_input = std::mem::take(&mut unnest.input);
929
930                let filter_with_unnest_input = LogicalPlan::Filter(Filter::try_new(
931                    conjunction(non_unnest_predicates).unwrap(), // Safe to unwrap since non_unnest_predicates is not empty.
932                    unnest_input,
933                )?);
934
935                // Directly assign new filter plan as the new unnest's input.
936                // The new filter plan will go through another rewrite pass since the rule itself
937                // is applied recursively to all the child from top to down
938                let unnest_plan =
939                    insert_below(LogicalPlan::Unnest(unnest), filter_with_unnest_input)?;
940
941                match conjunction(unnest_predicates) {
942                    None => Ok(unnest_plan),
943                    Some(predicate) => Ok(Transformed::yes(LogicalPlan::Filter(
944                        Filter::try_new(predicate, Arc::new(unnest_plan.data))?,
945                    ))),
946                }
947            }
948            LogicalPlan::Union(ref union) => {
949                let mut inputs = Vec::with_capacity(union.inputs.len());
950                for input in &union.inputs {
951                    let mut replace_map = HashMap::new();
952                    for (i, (qualifier, field)) in input.schema().iter().enumerate() {
953                        let (union_qualifier, union_field) =
954                            union.schema.qualified_field(i);
955                        replace_map.insert(
956                            qualified_name(union_qualifier, union_field.name()),
957                            Expr::Column(Column::new(qualifier.cloned(), field.name())),
958                        );
959                    }
960
961                    let push_predicate =
962                        replace_cols_by_name(filter.predicate.clone(), &replace_map)?;
963                    inputs.push(Arc::new(LogicalPlan::Filter(Filter::try_new(
964                        push_predicate,
965                        Arc::clone(input),
966                    )?)))
967                }
968                Ok(Transformed::yes(LogicalPlan::Union(Union {
969                    inputs,
970                    schema: Arc::clone(&plan_schema),
971                })))
972            }
973            LogicalPlan::Aggregate(agg) => {
974                // We can push down Predicate which in groupby_expr.
975                let group_expr_columns = agg
976                    .group_expr
977                    .iter()
978                    .map(|e| {
979                        let (relation, name) = e.qualified_name();
980                        Column::new(relation, name)
981                    })
982                    .collect::<HashSet<_>>();
983
984                let predicates = split_conjunction_owned(filter.predicate);
985
986                let mut keep_predicates = vec![];
987                let mut push_predicates = vec![];
988                for expr in predicates {
989                    let cols = expr.column_refs();
990                    if cols.iter().all(|c| group_expr_columns.contains(c)) {
991                        push_predicates.push(expr);
992                    } else {
993                        keep_predicates.push(expr);
994                    }
995                }
996
997                // As for plan Filter: Column(a+b) > 0 -- Agg: groupby:[Column(a)+Column(b)]
998                // After push, we need to replace `a+b` with Column(a)+Column(b)
999                // So we need create a replace_map, add {`a+b` --> Expr(Column(a)+Column(b))}
1000                let mut replace_map = HashMap::new();
1001                for expr in &agg.group_expr {
1002                    replace_map.insert(expr.schema_name().to_string(), expr.clone());
1003                }
1004                let replaced_push_predicates = push_predicates
1005                    .into_iter()
1006                    .map(|expr| replace_cols_by_name(expr, &replace_map))
1007                    .collect::<Result<Vec<_>>>()?;
1008
1009                let agg_input = Arc::clone(&agg.input);
1010                Transformed::yes(LogicalPlan::Aggregate(agg))
1011                    .transform_data(|new_plan| {
1012                        // If we have a filter to push, we push it down to the input of the aggregate
1013                        if let Some(predicate) = conjunction(replaced_push_predicates) {
1014                            let new_filter = make_filter(predicate, agg_input)?;
1015                            insert_below(new_plan, new_filter)
1016                        } else {
1017                            Ok(Transformed::no(new_plan))
1018                        }
1019                    })?
1020                    .map_data(|child_plan| {
1021                        // if there are any remaining predicates we can't push, add them
1022                        // back as a filter
1023                        if let Some(predicate) = conjunction(keep_predicates) {
1024                            make_filter(predicate, Arc::new(child_plan))
1025                        } else {
1026                            Ok(child_plan)
1027                        }
1028                    })
1029            }
1030            // Tries to push filters based on the partition key(s) of the window function(s) used.
1031            // Example:
1032            //   Before:
1033            //     Filter: (a > 1) and (b > 1) and (c > 1)
1034            //      Window: func() PARTITION BY [a] ...
1035            //   ---
1036            //   After:
1037            //     Filter: (b > 1) and (c > 1)
1038            //      Window: func() PARTITION BY [a] ...
1039            //        Filter: (a > 1)
1040            LogicalPlan::Window(window) => {
1041                // Retrieve the set of potential partition keys where we can push filters by.
1042                // Unlike aggregations, where there is only one statement per SELECT, there can be
1043                // multiple window functions, each with potentially different partition keys.
1044                // Therefore, we need to ensure that any potential partition key returned is used in
1045                // ALL window functions. Otherwise, filters cannot be pushed by through that column.
1046                let extract_partition_keys = |func: &WindowFunction| {
1047                    func.params
1048                        .partition_by
1049                        .iter()
1050                        .map(|c| {
1051                            let (relation, name) = c.qualified_name();
1052                            Column::new(relation, name)
1053                        })
1054                        .collect::<HashSet<_>>()
1055                };
1056                let potential_partition_keys = window
1057                    .window_expr
1058                    .iter()
1059                    .map(|e| {
1060                        match e {
1061                            Expr::WindowFunction(window_func) => {
1062                                extract_partition_keys(window_func)
1063                            }
1064                            Expr::Alias(alias) => {
1065                                if let Expr::WindowFunction(window_func) =
1066                                    alias.expr.as_ref()
1067                                {
1068                                    extract_partition_keys(window_func)
1069                                } else {
1070                                    // window functions expressions are only Expr::WindowFunction
1071                                    unreachable!()
1072                                }
1073                            }
1074                            _ => {
1075                                // window functions expressions are only Expr::WindowFunction
1076                                unreachable!()
1077                            }
1078                        }
1079                    })
1080                    // performs the set intersection of the partition keys of all window functions,
1081                    // returning only the common ones
1082                    .reduce(|a, b| &a & &b)
1083                    .unwrap_or_default();
1084
1085                let predicates = split_conjunction_owned(filter.predicate);
1086                let mut keep_predicates = vec![];
1087                let mut push_predicates = vec![];
1088                for expr in predicates {
1089                    let cols = expr.column_refs();
1090                    if cols.iter().all(|c| potential_partition_keys.contains(c)) {
1091                        push_predicates.push(expr);
1092                    } else {
1093                        keep_predicates.push(expr);
1094                    }
1095                }
1096
1097                // Unlike with aggregations, there are no cases where we have to replace, e.g.,
1098                // `a+b` with Column(a)+Column(b). This is because partition expressions are not
1099                // available as standalone columns to the user. For example, while an aggregation on
1100                // `a+b` becomes Column(a + b), in a window partition it becomes
1101                // `func() PARTITION BY [a + b] ...`. Thus, filters on expressions always remain in
1102                // place, so we can use `push_predicates` directly. This is consistent with other
1103                // optimizers, such as the one used by Postgres.
1104
1105                let window_input = Arc::clone(&window.input);
1106                Transformed::yes(LogicalPlan::Window(window))
1107                    .transform_data(|new_plan| {
1108                        // If we have a filter to push, we push it down to the input of the window
1109                        if let Some(predicate) = conjunction(push_predicates) {
1110                            let new_filter = make_filter(predicate, window_input)?;
1111                            insert_below(new_plan, new_filter)
1112                        } else {
1113                            Ok(Transformed::no(new_plan))
1114                        }
1115                    })?
1116                    .map_data(|child_plan| {
1117                        // if there are any remaining predicates we can't push, add them
1118                        // back as a filter
1119                        if let Some(predicate) = conjunction(keep_predicates) {
1120                            make_filter(predicate, Arc::new(child_plan))
1121                        } else {
1122                            Ok(child_plan)
1123                        }
1124                    })
1125            }
1126            LogicalPlan::Join(join) => push_down_join(join, Some(&filter.predicate)),
1127            LogicalPlan::TableScan(scan) => {
1128                let filter_predicates = split_conjunction(&filter.predicate);
1129
1130                let (volatile_filters, non_volatile_filters): (Vec<&Expr>, Vec<&Expr>) =
1131                    filter_predicates
1132                        .into_iter()
1133                        .partition(|pred| pred.is_volatile());
1134
1135                // Check which non-volatile filters are supported by source
1136                let supported_filters = scan
1137                    .source
1138                    .supports_filters_pushdown(non_volatile_filters.as_slice())?;
1139                assert_eq_or_internal_err!(
1140                    non_volatile_filters.len(),
1141                    supported_filters.len(),
1142                    "Vec returned length: {} from supports_filters_pushdown is not the same size as the filters passed, which length is: {}",
1143                    supported_filters.len(),
1144                    non_volatile_filters.len()
1145                );
1146
1147                // Compose scan filters from non-volatile filters of `Exact` or `Inexact` pushdown type
1148                let zip = non_volatile_filters.into_iter().zip(supported_filters);
1149
1150                let new_scan_filters = zip
1151                    .clone()
1152                    .filter(|(_, res)| res != &TableProviderFilterPushDown::Unsupported)
1153                    .map(|(pred, _)| pred);
1154
1155                // Add new scan filters
1156                let new_scan_filters: Vec<Expr> = scan
1157                    .filters
1158                    .iter()
1159                    .chain(new_scan_filters)
1160                    .unique()
1161                    .cloned()
1162                    .collect();
1163
1164                // Compose predicates to be of `Unsupported` or `Inexact` pushdown type, and also include volatile filters
1165                let new_predicate: Vec<Expr> = zip
1166                    .filter(|(_, res)| res != &TableProviderFilterPushDown::Exact)
1167                    .map(|(pred, _)| pred)
1168                    .chain(volatile_filters)
1169                    .cloned()
1170                    .collect();
1171
1172                let new_scan = LogicalPlan::TableScan(TableScan {
1173                    filters: new_scan_filters,
1174                    ..scan
1175                });
1176
1177                Transformed::yes(new_scan).transform_data(|new_scan| {
1178                    if let Some(predicate) = conjunction(new_predicate) {
1179                        make_filter(predicate, Arc::new(new_scan)).map(Transformed::yes)
1180                    } else {
1181                        Ok(Transformed::no(new_scan))
1182                    }
1183                })
1184            }
1185            LogicalPlan::Extension(extension_plan) => {
1186                // This check prevents the Filter from being removed when the extension node has no children,
1187                // so we return the original Filter unchanged.
1188                if extension_plan.node.inputs().is_empty() {
1189                    filter.input = Arc::new(LogicalPlan::Extension(extension_plan));
1190                    return Ok(Transformed::no(LogicalPlan::Filter(filter)));
1191                }
1192                let prevent_cols =
1193                    extension_plan.node.prevent_predicate_push_down_columns();
1194
1195                // determine if we can push any predicates down past the extension node
1196
1197                // each element is true for push, false to keep
1198                let predicate_push_or_keep = split_conjunction(&filter.predicate)
1199                    .iter()
1200                    .map(|expr| {
1201                        let cols = expr.column_refs();
1202                        if cols.iter().any(|c| prevent_cols.contains(&c.name)) {
1203                            Ok(false) // No push (keep)
1204                        } else {
1205                            Ok(true) // push
1206                        }
1207                    })
1208                    .collect::<Result<Vec<_>>>()?;
1209
1210                // all predicates are kept, no changes needed
1211                if predicate_push_or_keep.iter().all(|&x| !x) {
1212                    filter.input = Arc::new(LogicalPlan::Extension(extension_plan));
1213                    return Ok(Transformed::no(LogicalPlan::Filter(filter)));
1214                }
1215
1216                // going to push some predicates down, so split the predicates
1217                let mut keep_predicates = vec![];
1218                let mut push_predicates = vec![];
1219                for (push, expr) in predicate_push_or_keep
1220                    .into_iter()
1221                    .zip(split_conjunction_owned(filter.predicate).into_iter())
1222                {
1223                    if !push {
1224                        keep_predicates.push(expr);
1225                    } else {
1226                        push_predicates.push(expr);
1227                    }
1228                }
1229
1230                let new_children = match conjunction(push_predicates) {
1231                    Some(predicate) => extension_plan
1232                        .node
1233                        .inputs()
1234                        .into_iter()
1235                        .map(|child| {
1236                            Ok(LogicalPlan::Filter(Filter::try_new(
1237                                predicate.clone(),
1238                                Arc::new(child.clone()),
1239                            )?))
1240                        })
1241                        .collect::<Result<Vec<_>>>()?,
1242                    None => extension_plan.node.inputs().into_iter().cloned().collect(),
1243                };
1244                // extension with new inputs.
1245                let child_plan = LogicalPlan::Extension(extension_plan);
1246                let new_extension =
1247                    child_plan.with_new_exprs(child_plan.expressions(), new_children)?;
1248
1249                let new_plan = match conjunction(keep_predicates) {
1250                    Some(predicate) => LogicalPlan::Filter(Filter::try_new(
1251                        predicate,
1252                        Arc::new(new_extension),
1253                    )?),
1254                    None => new_extension,
1255                };
1256                Ok(Transformed::yes(new_plan))
1257            }
1258            child => {
1259                filter.input = Arc::new(child);
1260                Ok(Transformed::no(LogicalPlan::Filter(filter)))
1261            }
1262        }
1263    }
1264}
1265
1266/// Attempts to push `predicate` into a `FilterExec` below `projection
1267///
1268/// # Returns
1269/// (plan, remaining_predicate)
1270///
1271/// `plan` is a LogicalPlan for `projection` with possibly a new FilterExec below it.
1272/// `remaining_predicate` is any part of the predicate that could not be pushed down
1273///
1274/// # Args
1275/// - predicates: Split predicates like `[foo=5, bar=6]`
1276/// - projection: The target projection plan to push down the predicates
1277///
1278/// # Example
1279///
1280/// Pushing a predicate like `foo=5 AND bar=6` with an input plan like this:
1281///
1282/// ```text
1283/// Projection(foo, c+d as bar)
1284/// ```
1285///
1286/// Might result in returning `remaining_predicate` of `bar=6` and a plan like
1287///
1288/// ```text
1289/// Projection(foo, c+d as bar)
1290///  Filter(foo=5)
1291///   ...
1292/// ```
1293fn rewrite_projection(
1294    predicates: Vec<Expr>,
1295    mut projection: Projection,
1296) -> Result<(Transformed<LogicalPlan>, Option<Expr>)> {
1297    // A projection is filter-commutable if it do not contain volatile predicates or contain volatile
1298    // predicates that are not used in the filter. However, we should re-writes all predicate expressions.
1299    // collect projection.
1300    let (volatile_map, non_volatile_map): (HashMap<_, _>, HashMap<_, _>) = projection
1301        .schema
1302        .iter()
1303        .zip(projection.expr.iter())
1304        .map(|((qualifier, field), expr)| {
1305            // strip alias, as they should not be part of filters
1306            let expr = expr.clone().unalias();
1307
1308            (qualified_name(qualifier, field.name()), expr)
1309        })
1310        .partition(|(_, value)| value.is_volatile());
1311
1312    let mut push_predicates = vec![];
1313    let mut keep_predicates = vec![];
1314    for expr in predicates {
1315        if contain(&expr, &volatile_map) {
1316            keep_predicates.push(expr);
1317        } else {
1318            push_predicates.push(expr);
1319        }
1320    }
1321
1322    match conjunction(push_predicates) {
1323        Some(expr) => {
1324            // re-write all filters based on this projection
1325            // E.g. in `Filter: b\n  Projection: a > 1 as b`, we can swap them, but the filter must be "a > 1"
1326            let new_filter = LogicalPlan::Filter(Filter::try_new(
1327                replace_cols_by_name(expr, &non_volatile_map)?,
1328                std::mem::take(&mut projection.input),
1329            )?);
1330
1331            projection.input = Arc::new(new_filter);
1332
1333            Ok((
1334                Transformed::yes(LogicalPlan::Projection(projection)),
1335                conjunction(keep_predicates),
1336            ))
1337        }
1338        None => Ok((Transformed::no(LogicalPlan::Projection(projection)), None)),
1339    }
1340}
1341
1342/// Creates a new LogicalPlan::Filter node.
1343pub fn make_filter(predicate: Expr, input: Arc<LogicalPlan>) -> Result<LogicalPlan> {
1344    Filter::try_new(predicate, input).map(LogicalPlan::Filter)
1345}
1346
1347/// Replace the existing child of the single input node with `new_child`.
1348///
1349/// Starting:
1350/// ```text
1351/// plan
1352///   child
1353/// ```
1354///
1355/// Ending:
1356/// ```text
1357/// plan
1358///   new_child
1359/// ```
1360fn insert_below(
1361    plan: LogicalPlan,
1362    new_child: LogicalPlan,
1363) -> Result<Transformed<LogicalPlan>> {
1364    let mut new_child = Some(new_child);
1365    let transformed_plan = plan.map_children(|_child| {
1366        if let Some(new_child) = new_child.take() {
1367            Ok(Transformed::yes(new_child))
1368        } else {
1369            // already took the new child
1370            internal_err!("node had more than one input")
1371        }
1372    })?;
1373
1374    // make sure we did the actual replacement
1375    assert_or_internal_err!(new_child.is_none(), "node had no inputs");
1376
1377    Ok(transformed_plan)
1378}
1379
1380impl PushDownFilter {
1381    #[expect(missing_docs)]
1382    pub fn new() -> Self {
1383        Self {}
1384    }
1385}
1386
1387/// replaces columns by its name on the projection.
1388pub fn replace_cols_by_name(
1389    e: Expr,
1390    replace_map: &HashMap<String, Expr>,
1391) -> Result<Expr> {
1392    e.transform_up(|expr| {
1393        Ok(if let Expr::Column(c) = &expr {
1394            match replace_map.get(&c.flat_name()) {
1395                Some(new_c) => Transformed::yes(new_c.clone()),
1396                None => Transformed::no(expr),
1397            }
1398        } else {
1399            Transformed::no(expr)
1400        })
1401    })
1402    .data()
1403}
1404
1405/// check whether the expression uses the columns in `check_map`.
1406fn contain(e: &Expr, check_map: &HashMap<String, Expr>) -> bool {
1407    let mut is_contain = false;
1408    e.apply(|expr| {
1409        Ok(if let Expr::Column(c) = &expr {
1410            match check_map.get(&c.flat_name()) {
1411                Some(_) => {
1412                    is_contain = true;
1413                    TreeNodeRecursion::Stop
1414                }
1415                None => TreeNodeRecursion::Continue,
1416            }
1417        } else {
1418            TreeNodeRecursion::Continue
1419        })
1420    })
1421    .unwrap();
1422    is_contain
1423}
1424
1425#[cfg(test)]
1426mod tests {
1427    use std::any::Any;
1428    use std::cmp::Ordering;
1429    use std::fmt::{Debug, Formatter};
1430
1431    use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
1432    use async_trait::async_trait;
1433
1434    use datafusion_common::{DFSchemaRef, DataFusionError, ScalarValue};
1435    use datafusion_expr::expr::{ScalarFunction, WindowFunction};
1436    use datafusion_expr::logical_plan::table_scan;
1437    use datafusion_expr::{
1438        ColumnarValue, ExprFunctionExt, Extension, LogicalPlanBuilder,
1439        ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, TableSource, TableType,
1440        UserDefinedLogicalNodeCore, Volatility, WindowFunctionDefinition, col, in_list,
1441        in_subquery, lit,
1442    };
1443
1444    use crate::OptimizerContext;
1445    use crate::assert_optimized_plan_eq_snapshot;
1446    use crate::optimizer::Optimizer;
1447    use crate::simplify_expressions::SimplifyExpressions;
1448    use crate::test::*;
1449    use datafusion_expr::test::function_stub::sum;
1450    use insta::assert_snapshot;
1451
1452    use super::*;
1453
1454    fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {}
1455
1456    macro_rules! assert_optimized_plan_equal {
1457        (
1458            $plan:expr,
1459            @ $expected:literal $(,)?
1460        ) => {{
1461            let optimizer_ctx = OptimizerContext::new().with_max_passes(1);
1462            let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> = vec![Arc::new(PushDownFilter::new())];
1463            assert_optimized_plan_eq_snapshot!(
1464                optimizer_ctx,
1465                rules,
1466                $plan,
1467                @ $expected,
1468            )
1469        }};
1470    }
1471
1472    macro_rules! assert_optimized_plan_eq_with_rewrite_predicate {
1473        (
1474            $plan:expr,
1475            @ $expected:literal $(,)?
1476        ) => {{
1477            let optimizer = Optimizer::with_rules(vec![
1478                Arc::new(SimplifyExpressions::new()),
1479                Arc::new(PushDownFilter::new()),
1480            ]);
1481            let optimized_plan = optimizer.optimize($plan, &OptimizerContext::new(), observe)?;
1482            assert_snapshot!(optimized_plan, @ $expected);
1483            Ok::<(), DataFusionError>(())
1484        }};
1485    }
1486
1487    #[test]
1488    fn filter_before_projection() -> Result<()> {
1489        let table_scan = test_table_scan()?;
1490        let plan = LogicalPlanBuilder::from(table_scan)
1491            .project(vec![col("a"), col("b")])?
1492            .filter(col("a").eq(lit(1i64)))?
1493            .build()?;
1494        // filter is before projection
1495        assert_optimized_plan_equal!(
1496            plan,
1497            @r"
1498        Projection: test.a, test.b
1499          TableScan: test, full_filters=[test.a = Int64(1)]
1500        "
1501        )
1502    }
1503
1504    #[test]
1505    fn filter_after_limit() -> Result<()> {
1506        let table_scan = test_table_scan()?;
1507        let plan = LogicalPlanBuilder::from(table_scan)
1508            .project(vec![col("a"), col("b")])?
1509            .limit(0, Some(10))?
1510            .filter(col("a").eq(lit(1i64)))?
1511            .build()?;
1512        // filter is before single projection
1513        assert_optimized_plan_equal!(
1514            plan,
1515            @r"
1516        Filter: test.a = Int64(1)
1517          Limit: skip=0, fetch=10
1518            Projection: test.a, test.b
1519              TableScan: test
1520        "
1521        )
1522    }
1523
1524    #[test]
1525    fn filter_no_columns() -> Result<()> {
1526        let table_scan = test_table_scan()?;
1527        let plan = LogicalPlanBuilder::from(table_scan)
1528            .filter(lit(0i64).eq(lit(1i64)))?
1529            .build()?;
1530        assert_optimized_plan_equal!(
1531            plan,
1532            @"TableScan: test, full_filters=[Int64(0) = Int64(1)]"
1533        )
1534    }
1535
1536    #[test]
1537    fn filter_jump_2_plans() -> Result<()> {
1538        let table_scan = test_table_scan()?;
1539        let plan = LogicalPlanBuilder::from(table_scan)
1540            .project(vec![col("a"), col("b"), col("c")])?
1541            .project(vec![col("c"), col("b")])?
1542            .filter(col("a").eq(lit(1i64)))?
1543            .build()?;
1544        // filter is before double projection
1545        assert_optimized_plan_equal!(
1546            plan,
1547            @r"
1548        Projection: test.c, test.b
1549          Projection: test.a, test.b, test.c
1550            TableScan: test, full_filters=[test.a = Int64(1)]
1551        "
1552        )
1553    }
1554
1555    #[test]
1556    fn filter_move_agg() -> Result<()> {
1557        let table_scan = test_table_scan()?;
1558        let plan = LogicalPlanBuilder::from(table_scan)
1559            .aggregate(vec![col("a")], vec![sum(col("b")).alias("total_salary")])?
1560            .filter(col("a").gt(lit(10i64)))?
1561            .build()?;
1562        // filter of key aggregation is commutative
1563        assert_optimized_plan_equal!(
1564            plan,
1565            @r"
1566        Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b) AS total_salary]]
1567          TableScan: test, full_filters=[test.a > Int64(10)]
1568        "
1569        )
1570    }
1571
1572    /// verifies that filters with unusual column names are pushed down through aggregate operators
1573    #[test]
1574    fn filter_move_agg_special() -> Result<()> {
1575        let schema = Schema::new(vec![
1576            Field::new("$a", DataType::UInt32, false),
1577            Field::new("$b", DataType::UInt32, false),
1578            Field::new("$c", DataType::UInt32, false),
1579        ]);
1580        let table_scan = table_scan(Some("test"), &schema, None)?.build()?;
1581
1582        let plan = LogicalPlanBuilder::from(table_scan)
1583            .aggregate(vec![col("$a")], vec![sum(col("$b")).alias("total_salary")])?
1584            .filter(col("$a").gt(lit(10i64)))?
1585            .build()?;
1586        // filter of key aggregation is commutative
1587        assert_optimized_plan_equal!(
1588            plan,
1589            @r"
1590        Aggregate: groupBy=[[test.$a]], aggr=[[sum(test.$b) AS total_salary]]
1591          TableScan: test, full_filters=[test.$a > Int64(10)]
1592        "
1593        )
1594    }
1595
1596    #[test]
1597    fn filter_complex_group_by() -> Result<()> {
1598        let table_scan = test_table_scan()?;
1599        let plan = LogicalPlanBuilder::from(table_scan)
1600            .aggregate(vec![add(col("b"), col("a"))], vec![sum(col("a")), col("b")])?
1601            .filter(col("b").gt(lit(10i64)))?
1602            .build()?;
1603        assert_optimized_plan_equal!(
1604            plan,
1605            @r"
1606        Filter: test.b > Int64(10)
1607          Aggregate: groupBy=[[test.b + test.a]], aggr=[[sum(test.a), test.b]]
1608            TableScan: test
1609        "
1610        )
1611    }
1612
1613    #[test]
1614    fn push_agg_need_replace_expr() -> Result<()> {
1615        let plan = LogicalPlanBuilder::from(test_table_scan()?)
1616            .aggregate(vec![add(col("b"), col("a"))], vec![sum(col("a")), col("b")])?
1617            .filter(col("test.b + test.a").gt(lit(10i64)))?
1618            .build()?;
1619        assert_optimized_plan_equal!(
1620            plan,
1621            @r"
1622        Aggregate: groupBy=[[test.b + test.a]], aggr=[[sum(test.a), test.b]]
1623          TableScan: test, full_filters=[test.b + test.a > Int64(10)]
1624        "
1625        )
1626    }
1627
1628    #[test]
1629    fn filter_keep_agg() -> Result<()> {
1630        let table_scan = test_table_scan()?;
1631        let plan = LogicalPlanBuilder::from(table_scan)
1632            .aggregate(vec![col("a")], vec![sum(col("b")).alias("b")])?
1633            .filter(col("b").gt(lit(10i64)))?
1634            .build()?;
1635        // filter of aggregate is after aggregation since they are non-commutative
1636        assert_optimized_plan_equal!(
1637            plan,
1638            @r"
1639        Filter: b > Int64(10)
1640          Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b) AS b]]
1641            TableScan: test
1642        "
1643        )
1644    }
1645
1646    /// verifies that when partitioning by 'a' and 'b', and filtering by 'b', 'b' is pushed
1647    #[test]
1648    fn filter_move_window() -> Result<()> {
1649        let table_scan = test_table_scan()?;
1650
1651        let window = Expr::from(WindowFunction::new(
1652            WindowFunctionDefinition::WindowUDF(
1653                datafusion_functions_window::rank::rank_udwf(),
1654            ),
1655            vec![],
1656        ))
1657        .partition_by(vec![col("a"), col("b")])
1658        .order_by(vec![col("c").sort(true, true)])
1659        .build()
1660        .unwrap();
1661
1662        let plan = LogicalPlanBuilder::from(table_scan)
1663            .window(vec![window])?
1664            .filter(col("b").gt(lit(10i64)))?
1665            .build()?;
1666
1667        assert_optimized_plan_equal!(
1668            plan,
1669            @r"
1670        WindowAggr: windowExpr=[[rank() PARTITION BY [test.a, test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1671          TableScan: test, full_filters=[test.b > Int64(10)]
1672        "
1673        )
1674    }
1675
1676    /// verifies that filters with unusual identifier names are pushed down through window functions
1677    #[test]
1678    fn filter_window_special_identifier() -> Result<()> {
1679        let schema = Schema::new(vec![
1680            Field::new("$a", DataType::UInt32, false),
1681            Field::new("$b", DataType::UInt32, false),
1682            Field::new("$c", DataType::UInt32, false),
1683        ]);
1684        let table_scan = table_scan(Some("test"), &schema, None)?.build()?;
1685
1686        let window = Expr::from(WindowFunction::new(
1687            WindowFunctionDefinition::WindowUDF(
1688                datafusion_functions_window::rank::rank_udwf(),
1689            ),
1690            vec![],
1691        ))
1692        .partition_by(vec![col("$a"), col("$b")])
1693        .order_by(vec![col("$c").sort(true, true)])
1694        .build()
1695        .unwrap();
1696
1697        let plan = LogicalPlanBuilder::from(table_scan)
1698            .window(vec![window])?
1699            .filter(col("$b").gt(lit(10i64)))?
1700            .build()?;
1701
1702        assert_optimized_plan_equal!(
1703            plan,
1704            @r"
1705        WindowAggr: windowExpr=[[rank() PARTITION BY [test.$a, test.$b] ORDER BY [test.$c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1706          TableScan: test, full_filters=[test.$b > Int64(10)]
1707        "
1708        )
1709    }
1710
1711    /// verifies that when partitioning by 'a' and 'b', and filtering by 'a' and 'b', both 'a' and
1712    /// 'b' are pushed
1713    #[test]
1714    fn filter_move_complex_window() -> Result<()> {
1715        let table_scan = test_table_scan()?;
1716
1717        let window = Expr::from(WindowFunction::new(
1718            WindowFunctionDefinition::WindowUDF(
1719                datafusion_functions_window::rank::rank_udwf(),
1720            ),
1721            vec![],
1722        ))
1723        .partition_by(vec![col("a"), col("b")])
1724        .order_by(vec![col("c").sort(true, true)])
1725        .build()
1726        .unwrap();
1727
1728        let plan = LogicalPlanBuilder::from(table_scan)
1729            .window(vec![window])?
1730            .filter(and(col("a").gt(lit(10i64)), col("b").eq(lit(1i64))))?
1731            .build()?;
1732
1733        assert_optimized_plan_equal!(
1734            plan,
1735            @r"
1736        WindowAggr: windowExpr=[[rank() PARTITION BY [test.a, test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1737          TableScan: test, full_filters=[test.a > Int64(10), test.b = Int64(1)]
1738        "
1739        )
1740    }
1741
1742    /// verifies that when partitioning by 'a' and filtering by 'a' and 'b', only 'a' is pushed
1743    #[test]
1744    fn filter_move_partial_window() -> Result<()> {
1745        let table_scan = test_table_scan()?;
1746
1747        let window = Expr::from(WindowFunction::new(
1748            WindowFunctionDefinition::WindowUDF(
1749                datafusion_functions_window::rank::rank_udwf(),
1750            ),
1751            vec![],
1752        ))
1753        .partition_by(vec![col("a")])
1754        .order_by(vec![col("c").sort(true, true)])
1755        .build()
1756        .unwrap();
1757
1758        let plan = LogicalPlanBuilder::from(table_scan)
1759            .window(vec![window])?
1760            .filter(and(col("a").gt(lit(10i64)), col("b").eq(lit(1i64))))?
1761            .build()?;
1762
1763        assert_optimized_plan_equal!(
1764            plan,
1765            @r"
1766        Filter: test.b = Int64(1)
1767          WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1768            TableScan: test, full_filters=[test.a > Int64(10)]
1769        "
1770        )
1771    }
1772
1773    /// verifies that filters on partition expressions are not pushed, as the single expression
1774    /// column is not available to the user, unlike with aggregations
1775    #[test]
1776    fn filter_expression_keep_window() -> Result<()> {
1777        let table_scan = test_table_scan()?;
1778
1779        let window = Expr::from(WindowFunction::new(
1780            WindowFunctionDefinition::WindowUDF(
1781                datafusion_functions_window::rank::rank_udwf(),
1782            ),
1783            vec![],
1784        ))
1785        .partition_by(vec![add(col("a"), col("b"))]) // PARTITION BY a + b
1786        .order_by(vec![col("c").sort(true, true)])
1787        .build()
1788        .unwrap();
1789
1790        let plan = LogicalPlanBuilder::from(table_scan)
1791            .window(vec![window])?
1792            // unlike with aggregations, single partition column "test.a + test.b" is not available
1793            // to the plan, so we use multiple columns when filtering
1794            .filter(add(col("a"), col("b")).gt(lit(10i64)))?
1795            .build()?;
1796
1797        assert_optimized_plan_equal!(
1798            plan,
1799            @r"
1800        Filter: test.a + test.b > Int64(10)
1801          WindowAggr: windowExpr=[[rank() PARTITION BY [test.a + test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1802            TableScan: test
1803        "
1804        )
1805    }
1806
1807    /// verifies that filters are not pushed on order by columns (that are not used in partitioning)
1808    #[test]
1809    fn filter_order_keep_window() -> Result<()> {
1810        let table_scan = test_table_scan()?;
1811
1812        let window = Expr::from(WindowFunction::new(
1813            WindowFunctionDefinition::WindowUDF(
1814                datafusion_functions_window::rank::rank_udwf(),
1815            ),
1816            vec![],
1817        ))
1818        .partition_by(vec![col("a")])
1819        .order_by(vec![col("c").sort(true, true)])
1820        .build()
1821        .unwrap();
1822
1823        let plan = LogicalPlanBuilder::from(table_scan)
1824            .window(vec![window])?
1825            .filter(col("c").gt(lit(10i64)))?
1826            .build()?;
1827
1828        assert_optimized_plan_equal!(
1829            plan,
1830            @r"
1831        Filter: test.c > Int64(10)
1832          WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1833            TableScan: test
1834        "
1835        )
1836    }
1837
1838    /// verifies that when we use multiple window functions with a common partition key, the filter
1839    /// on that key is pushed
1840    #[test]
1841    fn filter_multiple_windows_common_partitions() -> Result<()> {
1842        let table_scan = test_table_scan()?;
1843
1844        let window1 = Expr::from(WindowFunction::new(
1845            WindowFunctionDefinition::WindowUDF(
1846                datafusion_functions_window::rank::rank_udwf(),
1847            ),
1848            vec![],
1849        ))
1850        .partition_by(vec![col("a")])
1851        .order_by(vec![col("c").sort(true, true)])
1852        .build()
1853        .unwrap();
1854
1855        let window2 = Expr::from(WindowFunction::new(
1856            WindowFunctionDefinition::WindowUDF(
1857                datafusion_functions_window::rank::rank_udwf(),
1858            ),
1859            vec![],
1860        ))
1861        .partition_by(vec![col("b"), col("a")])
1862        .order_by(vec![col("c").sort(true, true)])
1863        .build()
1864        .unwrap();
1865
1866        let plan = LogicalPlanBuilder::from(table_scan)
1867            .window(vec![window1, window2])?
1868            .filter(col("a").gt(lit(10i64)))? // a appears in both window functions
1869            .build()?;
1870
1871        assert_optimized_plan_equal!(
1872            plan,
1873            @r"
1874        WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, rank() PARTITION BY [test.b, test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1875          TableScan: test, full_filters=[test.a > Int64(10)]
1876        "
1877        )
1878    }
1879
1880    /// verifies that when we use multiple window functions with different partitions keys, the
1881    /// filter cannot be pushed
1882    #[test]
1883    fn filter_multiple_windows_disjoint_partitions() -> Result<()> {
1884        let table_scan = test_table_scan()?;
1885
1886        let window1 = Expr::from(WindowFunction::new(
1887            WindowFunctionDefinition::WindowUDF(
1888                datafusion_functions_window::rank::rank_udwf(),
1889            ),
1890            vec![],
1891        ))
1892        .partition_by(vec![col("a")])
1893        .order_by(vec![col("c").sort(true, true)])
1894        .build()
1895        .unwrap();
1896
1897        let window2 = Expr::from(WindowFunction::new(
1898            WindowFunctionDefinition::WindowUDF(
1899                datafusion_functions_window::rank::rank_udwf(),
1900            ),
1901            vec![],
1902        ))
1903        .partition_by(vec![col("b"), col("a")])
1904        .order_by(vec![col("c").sort(true, true)])
1905        .build()
1906        .unwrap();
1907
1908        let plan = LogicalPlanBuilder::from(table_scan)
1909            .window(vec![window1, window2])?
1910            .filter(col("b").gt(lit(10i64)))? // b only appears in one window function
1911            .build()?;
1912
1913        assert_optimized_plan_equal!(
1914            plan,
1915            @r"
1916        Filter: test.b > Int64(10)
1917          WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, rank() PARTITION BY [test.b, test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1918            TableScan: test
1919        "
1920        )
1921    }
1922
1923    /// verifies that a filter is pushed to before a projection, the filter expression is correctly re-written
1924    #[test]
1925    fn alias() -> Result<()> {
1926        let table_scan = test_table_scan()?;
1927        let plan = LogicalPlanBuilder::from(table_scan)
1928            .project(vec![col("a").alias("b"), col("c")])?
1929            .filter(col("b").eq(lit(1i64)))?
1930            .build()?;
1931        // filter is before projection
1932        assert_optimized_plan_equal!(
1933            plan,
1934            @r"
1935        Projection: test.a AS b, test.c
1936          TableScan: test, full_filters=[test.a = Int64(1)]
1937        "
1938        )
1939    }
1940
1941    fn add(left: Expr, right: Expr) -> Expr {
1942        Expr::BinaryExpr(BinaryExpr::new(
1943            Box::new(left),
1944            Operator::Plus,
1945            Box::new(right),
1946        ))
1947    }
1948
1949    fn multiply(left: Expr, right: Expr) -> Expr {
1950        Expr::BinaryExpr(BinaryExpr::new(
1951            Box::new(left),
1952            Operator::Multiply,
1953            Box::new(right),
1954        ))
1955    }
1956
1957    /// verifies that a filter is pushed to before a projection with a complex expression, the filter expression is correctly re-written
1958    #[test]
1959    fn complex_expression() -> Result<()> {
1960        let table_scan = test_table_scan()?;
1961        let plan = LogicalPlanBuilder::from(table_scan)
1962            .project(vec![
1963                add(multiply(col("a"), lit(2)), col("c")).alias("b"),
1964                col("c"),
1965            ])?
1966            .filter(col("b").eq(lit(1i64)))?
1967            .build()?;
1968
1969        // not part of the test, just good to know:
1970        assert_snapshot!(plan,
1971        @r"
1972        Filter: b = Int64(1)
1973          Projection: test.a * Int32(2) + test.c AS b, test.c
1974            TableScan: test
1975        ",
1976        );
1977        // filter is before projection
1978        assert_optimized_plan_equal!(
1979            plan,
1980            @r"
1981        Projection: test.a * Int32(2) + test.c AS b, test.c
1982          TableScan: test, full_filters=[test.a * Int32(2) + test.c = Int64(1)]
1983        "
1984        )
1985    }
1986
1987    /// verifies that when a filter is pushed to after 2 projections, the filter expression is correctly re-written
1988    #[test]
1989    fn complex_plan() -> Result<()> {
1990        let table_scan = test_table_scan()?;
1991        let plan = LogicalPlanBuilder::from(table_scan)
1992            .project(vec![
1993                add(multiply(col("a"), lit(2)), col("c")).alias("b"),
1994                col("c"),
1995            ])?
1996            // second projection where we rename columns, just to make it difficult
1997            .project(vec![multiply(col("b"), lit(3)).alias("a"), col("c")])?
1998            .filter(col("a").eq(lit(1i64)))?
1999            .build()?;
2000
2001        // not part of the test, just good to know:
2002        assert_snapshot!(plan,
2003        @r"
2004        Filter: a = Int64(1)
2005          Projection: b * Int32(3) AS a, test.c
2006            Projection: test.a * Int32(2) + test.c AS b, test.c
2007              TableScan: test
2008        ",
2009        );
2010        // filter is before the projections
2011        assert_optimized_plan_equal!(
2012            plan,
2013            @r"
2014        Projection: b * Int32(3) AS a, test.c
2015          Projection: test.a * Int32(2) + test.c AS b, test.c
2016            TableScan: test, full_filters=[(test.a * Int32(2) + test.c) * Int32(3) = Int64(1)]
2017        "
2018        )
2019    }
2020
2021    #[derive(Debug, PartialEq, Eq, Hash)]
2022    struct NoopPlan {
2023        input: Vec<LogicalPlan>,
2024        schema: DFSchemaRef,
2025    }
2026
2027    // Manual implementation needed because of `schema` field. Comparison excludes this field.
2028    impl PartialOrd for NoopPlan {
2029        fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
2030            self.input
2031                .partial_cmp(&other.input)
2032                // TODO (https://github.com/apache/datafusion/issues/17477) avoid recomparing all fields
2033                .filter(|cmp| *cmp != Ordering::Equal || self == other)
2034        }
2035    }
2036
2037    impl UserDefinedLogicalNodeCore for NoopPlan {
2038        fn name(&self) -> &str {
2039            "NoopPlan"
2040        }
2041
2042        fn inputs(&self) -> Vec<&LogicalPlan> {
2043            self.input.iter().collect()
2044        }
2045
2046        fn schema(&self) -> &DFSchemaRef {
2047            &self.schema
2048        }
2049
2050        fn expressions(&self) -> Vec<Expr> {
2051            self.input
2052                .iter()
2053                .flat_map(|child| child.expressions())
2054                .collect()
2055        }
2056
2057        fn prevent_predicate_push_down_columns(&self) -> HashSet<String> {
2058            HashSet::from_iter(vec!["c".to_string()])
2059        }
2060
2061        fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
2062            write!(f, "NoopPlan")
2063        }
2064
2065        fn with_exprs_and_inputs(
2066            &self,
2067            _exprs: Vec<Expr>,
2068            inputs: Vec<LogicalPlan>,
2069        ) -> Result<Self> {
2070            Ok(Self {
2071                input: inputs,
2072                schema: Arc::clone(&self.schema),
2073            })
2074        }
2075
2076        fn supports_limit_pushdown(&self) -> bool {
2077            false // Disallow limit push-down by default
2078        }
2079    }
2080
2081    #[test]
2082    fn user_defined_plan() -> Result<()> {
2083        let table_scan = test_table_scan()?;
2084
2085        let custom_plan = LogicalPlan::Extension(Extension {
2086            node: Arc::new(NoopPlan {
2087                input: vec![table_scan.clone()],
2088                schema: Arc::clone(table_scan.schema()),
2089            }),
2090        });
2091        let plan = LogicalPlanBuilder::from(custom_plan)
2092            .filter(col("a").eq(lit(1i64)))?
2093            .build()?;
2094
2095        // Push filter below NoopPlan
2096        assert_optimized_plan_equal!(
2097            plan,
2098            @r"
2099        NoopPlan
2100          TableScan: test, full_filters=[test.a = Int64(1)]
2101        "
2102        )?;
2103
2104        let custom_plan = LogicalPlan::Extension(Extension {
2105            node: Arc::new(NoopPlan {
2106                input: vec![table_scan.clone()],
2107                schema: Arc::clone(table_scan.schema()),
2108            }),
2109        });
2110        let plan = LogicalPlanBuilder::from(custom_plan)
2111            .filter(col("a").eq(lit(1i64)).and(col("c").eq(lit(2i64))))?
2112            .build()?;
2113
2114        // Push only predicate on `a` below NoopPlan
2115        assert_optimized_plan_equal!(
2116            plan,
2117            @r"
2118        Filter: test.c = Int64(2)
2119          NoopPlan
2120            TableScan: test, full_filters=[test.a = Int64(1)]
2121        "
2122        )?;
2123
2124        let custom_plan = LogicalPlan::Extension(Extension {
2125            node: Arc::new(NoopPlan {
2126                input: vec![table_scan.clone(), table_scan.clone()],
2127                schema: Arc::clone(table_scan.schema()),
2128            }),
2129        });
2130        let plan = LogicalPlanBuilder::from(custom_plan)
2131            .filter(col("a").eq(lit(1i64)))?
2132            .build()?;
2133
2134        // Push filter below NoopPlan for each child branch
2135        assert_optimized_plan_equal!(
2136            plan,
2137            @r"
2138        NoopPlan
2139          TableScan: test, full_filters=[test.a = Int64(1)]
2140          TableScan: test, full_filters=[test.a = Int64(1)]
2141        "
2142        )?;
2143
2144        let custom_plan = LogicalPlan::Extension(Extension {
2145            node: Arc::new(NoopPlan {
2146                input: vec![table_scan.clone(), table_scan.clone()],
2147                schema: Arc::clone(table_scan.schema()),
2148            }),
2149        });
2150        let plan = LogicalPlanBuilder::from(custom_plan)
2151            .filter(col("a").eq(lit(1i64)).and(col("c").eq(lit(2i64))))?
2152            .build()?;
2153
2154        // Push only predicate on `a` below NoopPlan
2155        assert_optimized_plan_equal!(
2156            plan,
2157            @r"
2158        Filter: test.c = Int64(2)
2159          NoopPlan
2160            TableScan: test, full_filters=[test.a = Int64(1)]
2161            TableScan: test, full_filters=[test.a = Int64(1)]
2162        "
2163        )
2164    }
2165
2166    /// verifies that when two filters apply after an aggregation that only allows one to be pushed, one is pushed
2167    /// and the other not.
2168    #[test]
2169    fn multi_filter() -> Result<()> {
2170        // the aggregation allows one filter to pass (b), and the other one to not pass (sum(c))
2171        let table_scan = test_table_scan()?;
2172        let plan = LogicalPlanBuilder::from(table_scan)
2173            .project(vec![col("a").alias("b"), col("c")])?
2174            .aggregate(vec![col("b")], vec![sum(col("c"))])?
2175            .filter(col("b").gt(lit(10i64)))?
2176            .filter(col("sum(test.c)").gt(lit(10i64)))?
2177            .build()?;
2178
2179        // not part of the test, just good to know:
2180        assert_snapshot!(plan,
2181        @r"
2182        Filter: sum(test.c) > Int64(10)
2183          Filter: b > Int64(10)
2184            Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
2185              Projection: test.a AS b, test.c
2186                TableScan: test
2187        ",
2188        );
2189        // filter is before the projections
2190        assert_optimized_plan_equal!(
2191            plan,
2192            @r"
2193        Filter: sum(test.c) > Int64(10)
2194          Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
2195            Projection: test.a AS b, test.c
2196              TableScan: test, full_filters=[test.a > Int64(10)]
2197        "
2198        )
2199    }
2200
2201    /// verifies that when a filter with two predicates is applied after an aggregation that only allows one to be pushed, one is pushed
2202    /// and the other not.
2203    #[test]
2204    fn split_filter() -> Result<()> {
2205        // the aggregation allows one filter to pass (b), and the other one to not pass (sum(c))
2206        let table_scan = test_table_scan()?;
2207        let plan = LogicalPlanBuilder::from(table_scan)
2208            .project(vec![col("a").alias("b"), col("c")])?
2209            .aggregate(vec![col("b")], vec![sum(col("c"))])?
2210            .filter(and(
2211                col("sum(test.c)").gt(lit(10i64)),
2212                and(col("b").gt(lit(10i64)), col("sum(test.c)").lt(lit(20i64))),
2213            ))?
2214            .build()?;
2215
2216        // not part of the test, just good to know:
2217        assert_snapshot!(plan,
2218        @r"
2219        Filter: sum(test.c) > Int64(10) AND b > Int64(10) AND sum(test.c) < Int64(20)
2220          Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
2221            Projection: test.a AS b, test.c
2222              TableScan: test
2223        ",
2224        );
2225        // filter is before the projections
2226        assert_optimized_plan_equal!(
2227            plan,
2228            @r"
2229        Filter: sum(test.c) > Int64(10) AND sum(test.c) < Int64(20)
2230          Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
2231            Projection: test.a AS b, test.c
2232              TableScan: test, full_filters=[test.a > Int64(10)]
2233        "
2234        )
2235    }
2236
2237    /// verifies that when two limits are in place, we jump neither
2238    #[test]
2239    fn double_limit() -> Result<()> {
2240        let table_scan = test_table_scan()?;
2241        let plan = LogicalPlanBuilder::from(table_scan)
2242            .project(vec![col("a"), col("b")])?
2243            .limit(0, Some(20))?
2244            .limit(0, Some(10))?
2245            .project(vec![col("a"), col("b")])?
2246            .filter(col("a").eq(lit(1i64)))?
2247            .build()?;
2248        // filter does not just any of the limits
2249        assert_optimized_plan_equal!(
2250            plan,
2251            @r"
2252        Projection: test.a, test.b
2253          Filter: test.a = Int64(1)
2254            Limit: skip=0, fetch=10
2255              Limit: skip=0, fetch=20
2256                Projection: test.a, test.b
2257                  TableScan: test
2258        "
2259        )
2260    }
2261
2262    #[test]
2263    fn union_all() -> Result<()> {
2264        let table_scan = test_table_scan()?;
2265        let table_scan2 = test_table_scan_with_name("test2")?;
2266        let plan = LogicalPlanBuilder::from(table_scan)
2267            .union(LogicalPlanBuilder::from(table_scan2).build()?)?
2268            .filter(col("a").eq(lit(1i64)))?
2269            .build()?;
2270        // filter appears below Union
2271        assert_optimized_plan_equal!(
2272            plan,
2273            @r"
2274        Union
2275          TableScan: test, full_filters=[test.a = Int64(1)]
2276          TableScan: test2, full_filters=[test2.a = Int64(1)]
2277        "
2278        )
2279    }
2280
2281    #[test]
2282    fn union_all_on_projection() -> Result<()> {
2283        let table_scan = test_table_scan()?;
2284        let table = LogicalPlanBuilder::from(table_scan)
2285            .project(vec![col("a").alias("b")])?
2286            .alias("test2")?;
2287
2288        let plan = table
2289            .clone()
2290            .union(table.build()?)?
2291            .filter(col("b").eq(lit(1i64)))?
2292            .build()?;
2293
2294        // filter appears below Union
2295        assert_optimized_plan_equal!(
2296            plan,
2297            @r"
2298        Union
2299          SubqueryAlias: test2
2300            Projection: test.a AS b
2301              TableScan: test, full_filters=[test.a = Int64(1)]
2302          SubqueryAlias: test2
2303            Projection: test.a AS b
2304              TableScan: test, full_filters=[test.a = Int64(1)]
2305        "
2306        )
2307    }
2308
2309    #[test]
2310    fn test_union_different_schema() -> Result<()> {
2311        let left = LogicalPlanBuilder::from(test_table_scan()?)
2312            .project(vec![col("a"), col("b"), col("c")])?
2313            .build()?;
2314
2315        let schema = Schema::new(vec![
2316            Field::new("d", DataType::UInt32, false),
2317            Field::new("e", DataType::UInt32, false),
2318            Field::new("f", DataType::UInt32, false),
2319        ]);
2320        let right = table_scan(Some("test1"), &schema, None)?
2321            .project(vec![col("d"), col("e"), col("f")])?
2322            .build()?;
2323        let filter = and(col("test.a").eq(lit(1)), col("test1.d").gt(lit(2)));
2324        let plan = LogicalPlanBuilder::from(left)
2325            .cross_join(right)?
2326            .project(vec![col("test.a"), col("test1.d")])?
2327            .filter(filter)?
2328            .build()?;
2329
2330        assert_optimized_plan_equal!(
2331            plan,
2332            @r"
2333        Projection: test.a, test1.d
2334          Cross Join: 
2335            Projection: test.a, test.b, test.c
2336              TableScan: test, full_filters=[test.a = Int32(1)]
2337            Projection: test1.d, test1.e, test1.f
2338              TableScan: test1, full_filters=[test1.d > Int32(2)]
2339        "
2340        )
2341    }
2342
2343    #[test]
2344    fn test_project_same_name_different_qualifier() -> Result<()> {
2345        let table_scan = test_table_scan()?;
2346        let left = LogicalPlanBuilder::from(table_scan)
2347            .project(vec![col("a"), col("b"), col("c")])?
2348            .build()?;
2349        let right_table_scan = test_table_scan_with_name("test1")?;
2350        let right = LogicalPlanBuilder::from(right_table_scan)
2351            .project(vec![col("a"), col("b"), col("c")])?
2352            .build()?;
2353        let filter = and(col("test.a").eq(lit(1)), col("test1.a").gt(lit(2)));
2354        let plan = LogicalPlanBuilder::from(left)
2355            .cross_join(right)?
2356            .project(vec![col("test.a"), col("test1.a")])?
2357            .filter(filter)?
2358            .build()?;
2359
2360        assert_optimized_plan_equal!(
2361            plan,
2362            @r"
2363        Projection: test.a, test1.a
2364          Cross Join: 
2365            Projection: test.a, test.b, test.c
2366              TableScan: test, full_filters=[test.a = Int32(1)]
2367            Projection: test1.a, test1.b, test1.c
2368              TableScan: test1, full_filters=[test1.a > Int32(2)]
2369        "
2370        )
2371    }
2372
2373    /// verifies that filters with the same columns are correctly placed
2374    #[test]
2375    fn filter_2_breaks_limits() -> Result<()> {
2376        let table_scan = test_table_scan()?;
2377        let plan = LogicalPlanBuilder::from(table_scan)
2378            .project(vec![col("a")])?
2379            .filter(col("a").lt_eq(lit(1i64)))?
2380            .limit(0, Some(1))?
2381            .project(vec![col("a")])?
2382            .filter(col("a").gt_eq(lit(1i64)))?
2383            .build()?;
2384        // Should be able to move both filters below the projections
2385
2386        // not part of the test
2387        assert_snapshot!(plan,
2388        @r"
2389        Filter: test.a >= Int64(1)
2390          Projection: test.a
2391            Limit: skip=0, fetch=1
2392              Filter: test.a <= Int64(1)
2393                Projection: test.a
2394                  TableScan: test
2395        ",
2396        );
2397        assert_optimized_plan_equal!(
2398            plan,
2399            @r"
2400        Projection: test.a
2401          Filter: test.a >= Int64(1)
2402            Limit: skip=0, fetch=1
2403              Projection: test.a
2404                TableScan: test, full_filters=[test.a <= Int64(1)]
2405        "
2406        )
2407    }
2408
2409    /// verifies that filters to be placed on the same depth are ANDed
2410    #[test]
2411    fn two_filters_on_same_depth() -> Result<()> {
2412        let table_scan = test_table_scan()?;
2413        let plan = LogicalPlanBuilder::from(table_scan)
2414            .limit(0, Some(1))?
2415            .filter(col("a").lt_eq(lit(1i64)))?
2416            .filter(col("a").gt_eq(lit(1i64)))?
2417            .project(vec![col("a")])?
2418            .build()?;
2419
2420        // not part of the test
2421        assert_snapshot!(plan,
2422        @r"
2423        Projection: test.a
2424          Filter: test.a >= Int64(1)
2425            Filter: test.a <= Int64(1)
2426              Limit: skip=0, fetch=1
2427                TableScan: test
2428        ",
2429        );
2430        assert_optimized_plan_equal!(
2431            plan,
2432            @r"
2433        Projection: test.a
2434          Filter: test.a >= Int64(1) AND test.a <= Int64(1)
2435            Limit: skip=0, fetch=1
2436              TableScan: test
2437        "
2438        )
2439    }
2440
2441    /// verifies that filters on a plan with user nodes are not lost
2442    /// (ARROW-10547)
2443    #[test]
2444    fn filters_user_defined_node() -> Result<()> {
2445        let table_scan = test_table_scan()?;
2446        let plan = LogicalPlanBuilder::from(table_scan)
2447            .filter(col("a").lt_eq(lit(1i64)))?
2448            .build()?;
2449
2450        let plan = user_defined::new(plan);
2451
2452        // not part of the test
2453        assert_snapshot!(plan,
2454        @r"
2455        TestUserDefined
2456          Filter: test.a <= Int64(1)
2457            TableScan: test
2458        ",
2459        );
2460        assert_optimized_plan_equal!(
2461            plan,
2462            @r"
2463        TestUserDefined
2464          TableScan: test, full_filters=[test.a <= Int64(1)]
2465        "
2466        )
2467    }
2468
2469    /// post-on-join predicates on a column common to both sides is pushed to both sides
2470    #[test]
2471    fn filter_on_join_on_common_independent() -> Result<()> {
2472        let table_scan = test_table_scan()?;
2473        let left = LogicalPlanBuilder::from(table_scan).build()?;
2474        let right_table_scan = test_table_scan_with_name("test2")?;
2475        let right = LogicalPlanBuilder::from(right_table_scan)
2476            .project(vec![col("a")])?
2477            .build()?;
2478        let plan = LogicalPlanBuilder::from(left)
2479            .join(
2480                right,
2481                JoinType::Inner,
2482                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2483                None,
2484            )?
2485            .filter(col("test.a").lt_eq(lit(1i64)))?
2486            .build()?;
2487
2488        // not part of the test, just good to know:
2489        assert_snapshot!(plan,
2490        @r"
2491        Filter: test.a <= Int64(1)
2492          Inner Join: test.a = test2.a
2493            TableScan: test
2494            Projection: test2.a
2495              TableScan: test2
2496        ",
2497        );
2498        // filter sent to side before the join
2499        assert_optimized_plan_equal!(
2500            plan,
2501            @r"
2502        Inner Join: test.a = test2.a
2503          TableScan: test, full_filters=[test.a <= Int64(1)]
2504          Projection: test2.a
2505            TableScan: test2, full_filters=[test2.a <= Int64(1)]
2506        "
2507        )
2508    }
2509
2510    /// post-using-join predicates on a column common to both sides is pushed to both sides
2511    #[test]
2512    fn filter_using_join_on_common_independent() -> Result<()> {
2513        let table_scan = test_table_scan()?;
2514        let left = LogicalPlanBuilder::from(table_scan).build()?;
2515        let right_table_scan = test_table_scan_with_name("test2")?;
2516        let right = LogicalPlanBuilder::from(right_table_scan)
2517            .project(vec![col("a")])?
2518            .build()?;
2519        let plan = LogicalPlanBuilder::from(left)
2520            .join_using(
2521                right,
2522                JoinType::Inner,
2523                vec![Column::from_name("a".to_string())],
2524            )?
2525            .filter(col("a").lt_eq(lit(1i64)))?
2526            .build()?;
2527
2528        // not part of the test, just good to know:
2529        assert_snapshot!(plan,
2530        @r"
2531        Filter: test.a <= Int64(1)
2532          Inner Join: Using test.a = test2.a
2533            TableScan: test
2534            Projection: test2.a
2535              TableScan: test2
2536        ",
2537        );
2538        // filter sent to side before the join
2539        assert_optimized_plan_equal!(
2540            plan,
2541            @r"
2542        Inner Join: Using test.a = test2.a
2543          TableScan: test, full_filters=[test.a <= Int64(1)]
2544          Projection: test2.a
2545            TableScan: test2, full_filters=[test2.a <= Int64(1)]
2546        "
2547        )
2548    }
2549
2550    /// post-join predicates with columns from both sides are converted to join filters
2551    #[test]
2552    fn filter_join_on_common_dependent() -> Result<()> {
2553        let table_scan = test_table_scan()?;
2554        let left = LogicalPlanBuilder::from(table_scan)
2555            .project(vec![col("a"), col("c")])?
2556            .build()?;
2557        let right_table_scan = test_table_scan_with_name("test2")?;
2558        let right = LogicalPlanBuilder::from(right_table_scan)
2559            .project(vec![col("a"), col("b")])?
2560            .build()?;
2561        let plan = LogicalPlanBuilder::from(left)
2562            .join(
2563                right,
2564                JoinType::Inner,
2565                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2566                None,
2567            )?
2568            .filter(col("c").lt_eq(col("b")))?
2569            .build()?;
2570
2571        // not part of the test, just good to know:
2572        assert_snapshot!(plan,
2573        @r"
2574        Filter: test.c <= test2.b
2575          Inner Join: test.a = test2.a
2576            Projection: test.a, test.c
2577              TableScan: test
2578            Projection: test2.a, test2.b
2579              TableScan: test2
2580        ",
2581        );
2582        // Filter is converted to Join Filter
2583        assert_optimized_plan_equal!(
2584            plan,
2585            @r"
2586        Inner Join: test.a = test2.a Filter: test.c <= test2.b
2587          Projection: test.a, test.c
2588            TableScan: test
2589          Projection: test2.a, test2.b
2590            TableScan: test2
2591        "
2592        )
2593    }
2594
2595    /// post-join predicates with columns from one side of a join are pushed only to that side
2596    #[test]
2597    fn filter_join_on_one_side() -> Result<()> {
2598        let table_scan = test_table_scan()?;
2599        let left = LogicalPlanBuilder::from(table_scan)
2600            .project(vec![col("a"), col("b")])?
2601            .build()?;
2602        let table_scan_right = test_table_scan_with_name("test2")?;
2603        let right = LogicalPlanBuilder::from(table_scan_right)
2604            .project(vec![col("a"), col("c")])?
2605            .build()?;
2606
2607        let plan = LogicalPlanBuilder::from(left)
2608            .join(
2609                right,
2610                JoinType::Inner,
2611                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2612                None,
2613            )?
2614            .filter(col("b").lt_eq(lit(1i64)))?
2615            .build()?;
2616
2617        // not part of the test, just good to know:
2618        assert_snapshot!(plan,
2619        @r"
2620        Filter: test.b <= Int64(1)
2621          Inner Join: test.a = test2.a
2622            Projection: test.a, test.b
2623              TableScan: test
2624            Projection: test2.a, test2.c
2625              TableScan: test2
2626        ",
2627        );
2628        assert_optimized_plan_equal!(
2629            plan,
2630            @r"
2631        Inner Join: test.a = test2.a
2632          Projection: test.a, test.b
2633            TableScan: test, full_filters=[test.b <= Int64(1)]
2634          Projection: test2.a, test2.c
2635            TableScan: test2
2636        "
2637        )
2638    }
2639
2640    /// post-join predicates on the right side of a left join are not duplicated
2641    /// TODO: In this case we can sometimes convert the join to an INNER join
2642    #[test]
2643    fn filter_using_left_join() -> Result<()> {
2644        let table_scan = test_table_scan()?;
2645        let left = LogicalPlanBuilder::from(table_scan).build()?;
2646        let right_table_scan = test_table_scan_with_name("test2")?;
2647        let right = LogicalPlanBuilder::from(right_table_scan)
2648            .project(vec![col("a")])?
2649            .build()?;
2650        let plan = LogicalPlanBuilder::from(left)
2651            .join_using(
2652                right,
2653                JoinType::Left,
2654                vec![Column::from_name("a".to_string())],
2655            )?
2656            .filter(col("test2.a").lt_eq(lit(1i64)))?
2657            .build()?;
2658
2659        // not part of the test, just good to know:
2660        assert_snapshot!(plan,
2661        @r"
2662        Filter: test2.a <= Int64(1)
2663          Left Join: Using test.a = test2.a
2664            TableScan: test
2665            Projection: test2.a
2666              TableScan: test2
2667        ",
2668        );
2669        // filter not duplicated nor pushed down - i.e. noop
2670        assert_optimized_plan_equal!(
2671            plan,
2672            @r"
2673        Filter: test2.a <= Int64(1)
2674          Left Join: Using test.a = test2.a
2675            TableScan: test, full_filters=[test.a <= Int64(1)]
2676            Projection: test2.a
2677              TableScan: test2
2678        "
2679        )
2680    }
2681
2682    /// post-join predicates on the left side of a right join are not duplicated
2683    #[test]
2684    fn filter_using_right_join() -> Result<()> {
2685        let table_scan = test_table_scan()?;
2686        let left = LogicalPlanBuilder::from(table_scan).build()?;
2687        let right_table_scan = test_table_scan_with_name("test2")?;
2688        let right = LogicalPlanBuilder::from(right_table_scan)
2689            .project(vec![col("a")])?
2690            .build()?;
2691        let plan = LogicalPlanBuilder::from(left)
2692            .join_using(
2693                right,
2694                JoinType::Right,
2695                vec![Column::from_name("a".to_string())],
2696            )?
2697            .filter(col("test.a").lt_eq(lit(1i64)))?
2698            .build()?;
2699
2700        // not part of the test, just good to know:
2701        assert_snapshot!(plan,
2702        @r"
2703        Filter: test.a <= Int64(1)
2704          Right Join: Using test.a = test2.a
2705            TableScan: test
2706            Projection: test2.a
2707              TableScan: test2
2708        ",
2709        );
2710        // filter not duplicated nor pushed down - i.e. noop
2711        assert_optimized_plan_equal!(
2712            plan,
2713            @r"
2714        Filter: test.a <= Int64(1)
2715          Right Join: Using test.a = test2.a
2716            TableScan: test
2717            Projection: test2.a
2718              TableScan: test2, full_filters=[test2.a <= Int64(1)]
2719        "
2720        )
2721    }
2722
2723    /// post-left-join predicate on a column common to both sides is only pushed to the left side
2724    /// i.e. - not duplicated to the right side
2725    #[test]
2726    fn filter_using_left_join_on_common() -> Result<()> {
2727        let table_scan = test_table_scan()?;
2728        let left = LogicalPlanBuilder::from(table_scan).build()?;
2729        let right_table_scan = test_table_scan_with_name("test2")?;
2730        let right = LogicalPlanBuilder::from(right_table_scan)
2731            .project(vec![col("a")])?
2732            .build()?;
2733        let plan = LogicalPlanBuilder::from(left)
2734            .join_using(
2735                right,
2736                JoinType::Left,
2737                vec![Column::from_name("a".to_string())],
2738            )?
2739            .filter(col("a").lt_eq(lit(1i64)))?
2740            .build()?;
2741
2742        // not part of the test, just good to know:
2743        assert_snapshot!(plan,
2744        @r"
2745        Filter: test.a <= Int64(1)
2746          Left Join: Using test.a = test2.a
2747            TableScan: test
2748            Projection: test2.a
2749              TableScan: test2
2750        ",
2751        );
2752        // filter sent to left side of the join, not the right
2753        assert_optimized_plan_equal!(
2754            plan,
2755            @r"
2756        Left Join: Using test.a = test2.a
2757          TableScan: test, full_filters=[test.a <= Int64(1)]
2758          Projection: test2.a
2759            TableScan: test2
2760        "
2761        )
2762    }
2763
2764    /// post-right-join predicate on a column common to both sides is only pushed to the right side
2765    /// i.e. - not duplicated to the left side.
2766    #[test]
2767    fn filter_using_right_join_on_common() -> Result<()> {
2768        let table_scan = test_table_scan()?;
2769        let left = LogicalPlanBuilder::from(table_scan).build()?;
2770        let right_table_scan = test_table_scan_with_name("test2")?;
2771        let right = LogicalPlanBuilder::from(right_table_scan)
2772            .project(vec![col("a")])?
2773            .build()?;
2774        let plan = LogicalPlanBuilder::from(left)
2775            .join_using(
2776                right,
2777                JoinType::Right,
2778                vec![Column::from_name("a".to_string())],
2779            )?
2780            .filter(col("test2.a").lt_eq(lit(1i64)))?
2781            .build()?;
2782
2783        // not part of the test, just good to know:
2784        assert_snapshot!(plan,
2785        @r"
2786        Filter: test2.a <= Int64(1)
2787          Right Join: Using test.a = test2.a
2788            TableScan: test
2789            Projection: test2.a
2790              TableScan: test2
2791        ",
2792        );
2793        // filter sent to right side of join, not duplicated to the left
2794        assert_optimized_plan_equal!(
2795            plan,
2796            @r"
2797        Right Join: Using test.a = test2.a
2798          TableScan: test
2799          Projection: test2.a
2800            TableScan: test2, full_filters=[test2.a <= Int64(1)]
2801        "
2802        )
2803    }
2804
2805    /// single table predicate parts of ON condition should be pushed to both inputs
2806    #[test]
2807    fn join_on_with_filter() -> Result<()> {
2808        let table_scan = test_table_scan()?;
2809        let left = LogicalPlanBuilder::from(table_scan)
2810            .project(vec![col("a"), col("b"), col("c")])?
2811            .build()?;
2812        let right_table_scan = test_table_scan_with_name("test2")?;
2813        let right = LogicalPlanBuilder::from(right_table_scan)
2814            .project(vec![col("a"), col("b"), col("c")])?
2815            .build()?;
2816        let filter = col("test.c")
2817            .gt(lit(1u32))
2818            .and(col("test.b").lt(col("test2.b")))
2819            .and(col("test2.c").gt(lit(4u32)));
2820        let plan = LogicalPlanBuilder::from(left)
2821            .join(
2822                right,
2823                JoinType::Inner,
2824                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2825                Some(filter),
2826            )?
2827            .build()?;
2828
2829        // not part of the test, just good to know:
2830        assert_snapshot!(plan,
2831        @r"
2832        Inner Join: test.a = test2.a Filter: test.c > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
2833          Projection: test.a, test.b, test.c
2834            TableScan: test
2835          Projection: test2.a, test2.b, test2.c
2836            TableScan: test2
2837        ",
2838        );
2839        assert_optimized_plan_equal!(
2840            plan,
2841            @r"
2842        Inner Join: test.a = test2.a Filter: test.b < test2.b
2843          Projection: test.a, test.b, test.c
2844            TableScan: test, full_filters=[test.c > UInt32(1)]
2845          Projection: test2.a, test2.b, test2.c
2846            TableScan: test2, full_filters=[test2.c > UInt32(4)]
2847        "
2848        )
2849    }
2850
2851    /// join filter should be completely removed after pushdown
2852    #[test]
2853    fn join_filter_removed() -> Result<()> {
2854        let table_scan = test_table_scan()?;
2855        let left = LogicalPlanBuilder::from(table_scan)
2856            .project(vec![col("a"), col("b"), col("c")])?
2857            .build()?;
2858        let right_table_scan = test_table_scan_with_name("test2")?;
2859        let right = LogicalPlanBuilder::from(right_table_scan)
2860            .project(vec![col("a"), col("b"), col("c")])?
2861            .build()?;
2862        let filter = col("test.b")
2863            .gt(lit(1u32))
2864            .and(col("test2.c").gt(lit(4u32)));
2865        let plan = LogicalPlanBuilder::from(left)
2866            .join(
2867                right,
2868                JoinType::Inner,
2869                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2870                Some(filter),
2871            )?
2872            .build()?;
2873
2874        // not part of the test, just good to know:
2875        assert_snapshot!(plan,
2876        @r"
2877        Inner Join: test.a = test2.a Filter: test.b > UInt32(1) AND test2.c > UInt32(4)
2878          Projection: test.a, test.b, test.c
2879            TableScan: test
2880          Projection: test2.a, test2.b, test2.c
2881            TableScan: test2
2882        ",
2883        );
2884        assert_optimized_plan_equal!(
2885            plan,
2886            @r"
2887        Inner Join: test.a = test2.a
2888          Projection: test.a, test.b, test.c
2889            TableScan: test, full_filters=[test.b > UInt32(1)]
2890          Projection: test2.a, test2.b, test2.c
2891            TableScan: test2, full_filters=[test2.c > UInt32(4)]
2892        "
2893        )
2894    }
2895
2896    /// predicate on join key in filter expression should be pushed down to both inputs
2897    #[test]
2898    fn join_filter_on_common() -> Result<()> {
2899        let table_scan = test_table_scan()?;
2900        let left = LogicalPlanBuilder::from(table_scan)
2901            .project(vec![col("a")])?
2902            .build()?;
2903        let right_table_scan = test_table_scan_with_name("test2")?;
2904        let right = LogicalPlanBuilder::from(right_table_scan)
2905            .project(vec![col("b")])?
2906            .build()?;
2907        let filter = col("test.a").gt(lit(1u32));
2908        let plan = LogicalPlanBuilder::from(left)
2909            .join(
2910                right,
2911                JoinType::Inner,
2912                (vec![Column::from_name("a")], vec![Column::from_name("b")]),
2913                Some(filter),
2914            )?
2915            .build()?;
2916
2917        // not part of the test, just good to know:
2918        assert_snapshot!(plan,
2919        @r"
2920        Inner Join: test.a = test2.b Filter: test.a > UInt32(1)
2921          Projection: test.a
2922            TableScan: test
2923          Projection: test2.b
2924            TableScan: test2
2925        ",
2926        );
2927        assert_optimized_plan_equal!(
2928            plan,
2929            @r"
2930        Inner Join: test.a = test2.b
2931          Projection: test.a
2932            TableScan: test, full_filters=[test.a > UInt32(1)]
2933          Projection: test2.b
2934            TableScan: test2, full_filters=[test2.b > UInt32(1)]
2935        "
2936        )
2937    }
2938
2939    /// single table predicate parts of ON condition should be pushed to right input
2940    #[test]
2941    fn left_join_on_with_filter() -> Result<()> {
2942        let table_scan = test_table_scan()?;
2943        let left = LogicalPlanBuilder::from(table_scan)
2944            .project(vec![col("a"), col("b"), col("c")])?
2945            .build()?;
2946        let right_table_scan = test_table_scan_with_name("test2")?;
2947        let right = LogicalPlanBuilder::from(right_table_scan)
2948            .project(vec![col("a"), col("b"), col("c")])?
2949            .build()?;
2950        let filter = col("test.a")
2951            .gt(lit(1u32))
2952            .and(col("test.b").lt(col("test2.b")))
2953            .and(col("test2.c").gt(lit(4u32)));
2954        let plan = LogicalPlanBuilder::from(left)
2955            .join(
2956                right,
2957                JoinType::Left,
2958                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2959                Some(filter),
2960            )?
2961            .build()?;
2962
2963        // not part of the test, just good to know:
2964        assert_snapshot!(plan,
2965        @r"
2966        Left Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
2967          Projection: test.a, test.b, test.c
2968            TableScan: test
2969          Projection: test2.a, test2.b, test2.c
2970            TableScan: test2
2971        ",
2972        );
2973        assert_optimized_plan_equal!(
2974            plan,
2975            @r"
2976        Left Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b
2977          Projection: test.a, test.b, test.c
2978            TableScan: test
2979          Projection: test2.a, test2.b, test2.c
2980            TableScan: test2, full_filters=[test2.c > UInt32(4)]
2981        "
2982        )
2983    }
2984
2985    /// single table predicate parts of ON condition should be pushed to left input
2986    #[test]
2987    fn right_join_on_with_filter() -> Result<()> {
2988        let table_scan = test_table_scan()?;
2989        let left = LogicalPlanBuilder::from(table_scan)
2990            .project(vec![col("a"), col("b"), col("c")])?
2991            .build()?;
2992        let right_table_scan = test_table_scan_with_name("test2")?;
2993        let right = LogicalPlanBuilder::from(right_table_scan)
2994            .project(vec![col("a"), col("b"), col("c")])?
2995            .build()?;
2996        let filter = col("test.a")
2997            .gt(lit(1u32))
2998            .and(col("test.b").lt(col("test2.b")))
2999            .and(col("test2.c").gt(lit(4u32)));
3000        let plan = LogicalPlanBuilder::from(left)
3001            .join(
3002                right,
3003                JoinType::Right,
3004                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
3005                Some(filter),
3006            )?
3007            .build()?;
3008
3009        // not part of the test, just good to know:
3010        assert_snapshot!(plan,
3011        @r"
3012        Right Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
3013          Projection: test.a, test.b, test.c
3014            TableScan: test
3015          Projection: test2.a, test2.b, test2.c
3016            TableScan: test2
3017        ",
3018        );
3019        assert_optimized_plan_equal!(
3020            plan,
3021            @r"
3022        Right Join: test.a = test2.a Filter: test.b < test2.b AND test2.c > UInt32(4)
3023          Projection: test.a, test.b, test.c
3024            TableScan: test, full_filters=[test.a > UInt32(1)]
3025          Projection: test2.a, test2.b, test2.c
3026            TableScan: test2
3027        "
3028        )
3029    }
3030
3031    /// single table predicate parts of ON condition should not be pushed
3032    #[test]
3033    fn full_join_on_with_filter() -> Result<()> {
3034        let table_scan = test_table_scan()?;
3035        let left = LogicalPlanBuilder::from(table_scan)
3036            .project(vec![col("a"), col("b"), col("c")])?
3037            .build()?;
3038        let right_table_scan = test_table_scan_with_name("test2")?;
3039        let right = LogicalPlanBuilder::from(right_table_scan)
3040            .project(vec![col("a"), col("b"), col("c")])?
3041            .build()?;
3042        let filter = col("test.a")
3043            .gt(lit(1u32))
3044            .and(col("test.b").lt(col("test2.b")))
3045            .and(col("test2.c").gt(lit(4u32)));
3046        let plan = LogicalPlanBuilder::from(left)
3047            .join(
3048                right,
3049                JoinType::Full,
3050                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
3051                Some(filter),
3052            )?
3053            .build()?;
3054
3055        // not part of the test, just good to know:
3056        assert_snapshot!(plan,
3057        @r"
3058        Full Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
3059          Projection: test.a, test.b, test.c
3060            TableScan: test
3061          Projection: test2.a, test2.b, test2.c
3062            TableScan: test2
3063        ",
3064        );
3065        assert_optimized_plan_equal!(
3066            plan,
3067            @r"
3068        Full Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
3069          Projection: test.a, test.b, test.c
3070            TableScan: test
3071          Projection: test2.a, test2.b, test2.c
3072            TableScan: test2
3073        "
3074        )
3075    }
3076
3077    struct PushDownProvider {
3078        pub filter_support: TableProviderFilterPushDown,
3079    }
3080
3081    #[async_trait]
3082    impl TableSource for PushDownProvider {
3083        fn schema(&self) -> SchemaRef {
3084            Arc::new(Schema::new(vec![
3085                Field::new("a", DataType::Int32, true),
3086                Field::new("b", DataType::Int32, true),
3087            ]))
3088        }
3089
3090        fn table_type(&self) -> TableType {
3091            TableType::Base
3092        }
3093
3094        fn supports_filters_pushdown(
3095            &self,
3096            filters: &[&Expr],
3097        ) -> Result<Vec<TableProviderFilterPushDown>> {
3098            Ok((0..filters.len())
3099                .map(|_| self.filter_support.clone())
3100                .collect())
3101        }
3102
3103        fn as_any(&self) -> &dyn Any {
3104            self
3105        }
3106    }
3107
3108    fn table_scan_with_pushdown_provider_builder(
3109        filter_support: TableProviderFilterPushDown,
3110        filters: Vec<Expr>,
3111        projection: Option<Vec<usize>>,
3112    ) -> Result<LogicalPlanBuilder> {
3113        let test_provider = PushDownProvider { filter_support };
3114
3115        let table_scan = LogicalPlan::TableScan(TableScan {
3116            table_name: "test".into(),
3117            filters,
3118            projected_schema: Arc::new(DFSchema::try_from(test_provider.schema())?),
3119            projection,
3120            source: Arc::new(test_provider),
3121            fetch: None,
3122        });
3123
3124        Ok(LogicalPlanBuilder::from(table_scan))
3125    }
3126
3127    fn table_scan_with_pushdown_provider(
3128        filter_support: TableProviderFilterPushDown,
3129    ) -> Result<LogicalPlan> {
3130        table_scan_with_pushdown_provider_builder(filter_support, vec![], None)?
3131            .filter(col("a").eq(lit(1i64)))?
3132            .build()
3133    }
3134
3135    #[test]
3136    fn filter_with_table_provider_exact() -> Result<()> {
3137        let plan = table_scan_with_pushdown_provider(TableProviderFilterPushDown::Exact)?;
3138
3139        assert_optimized_plan_equal!(
3140            plan,
3141            @"TableScan: test, full_filters=[a = Int64(1)]"
3142        )
3143    }
3144
3145    #[test]
3146    fn filter_with_table_provider_inexact() -> Result<()> {
3147        let plan =
3148            table_scan_with_pushdown_provider(TableProviderFilterPushDown::Inexact)?;
3149
3150        assert_optimized_plan_equal!(
3151            plan,
3152            @r"
3153        Filter: a = Int64(1)
3154          TableScan: test, partial_filters=[a = Int64(1)]
3155        "
3156        )
3157    }
3158
3159    #[test]
3160    fn filter_with_table_provider_multiple_invocations() -> Result<()> {
3161        let plan =
3162            table_scan_with_pushdown_provider(TableProviderFilterPushDown::Inexact)?;
3163
3164        let optimized_plan = PushDownFilter::new()
3165            .rewrite(plan, &OptimizerContext::new())
3166            .expect("failed to optimize plan")
3167            .data;
3168
3169        // Optimizing the same plan multiple times should produce the same plan
3170        // each time.
3171        assert_optimized_plan_equal!(
3172            optimized_plan,
3173            @r"
3174        Filter: a = Int64(1)
3175          TableScan: test, partial_filters=[a = Int64(1)]
3176        "
3177        )
3178    }
3179
3180    #[test]
3181    fn filter_with_table_provider_unsupported() -> Result<()> {
3182        let plan =
3183            table_scan_with_pushdown_provider(TableProviderFilterPushDown::Unsupported)?;
3184
3185        assert_optimized_plan_equal!(
3186            plan,
3187            @r"
3188        Filter: a = Int64(1)
3189          TableScan: test
3190        "
3191        )
3192    }
3193
3194    #[test]
3195    fn multi_combined_filter() -> Result<()> {
3196        let plan = table_scan_with_pushdown_provider_builder(
3197            TableProviderFilterPushDown::Inexact,
3198            vec![col("a").eq(lit(10i64)), col("b").gt(lit(11i64))],
3199            Some(vec![0]),
3200        )?
3201        .filter(and(col("a").eq(lit(10i64)), col("b").gt(lit(11i64))))?
3202        .project(vec![col("a"), col("b")])?
3203        .build()?;
3204
3205        assert_optimized_plan_equal!(
3206            plan,
3207            @r"
3208        Projection: a, b
3209          Filter: a = Int64(10) AND b > Int64(11)
3210            TableScan: test projection=[a], partial_filters=[a = Int64(10), b > Int64(11)]
3211        "
3212        )
3213    }
3214
3215    #[test]
3216    fn multi_combined_filter_exact() -> Result<()> {
3217        let plan = table_scan_with_pushdown_provider_builder(
3218            TableProviderFilterPushDown::Exact,
3219            vec![],
3220            Some(vec![0]),
3221        )?
3222        .filter(and(col("a").eq(lit(10i64)), col("b").gt(lit(11i64))))?
3223        .project(vec![col("a"), col("b")])?
3224        .build()?;
3225
3226        assert_optimized_plan_equal!(
3227            plan,
3228            @r"
3229        Projection: a, b
3230          TableScan: test projection=[a], full_filters=[a = Int64(10), b > Int64(11)]
3231        "
3232        )
3233    }
3234
3235    #[test]
3236    fn test_filter_with_alias() -> Result<()> {
3237        // in table scan the true col name is 'test.a',
3238        // but we rename it as 'b', and use col 'b' in filter
3239        // we need rewrite filter col before push down.
3240        let table_scan = test_table_scan()?;
3241        let plan = LogicalPlanBuilder::from(table_scan)
3242            .project(vec![col("a").alias("b"), col("c")])?
3243            .filter(and(col("b").gt(lit(10i64)), col("c").gt(lit(10i64))))?
3244            .build()?;
3245
3246        // filter on col b
3247        assert_snapshot!(plan,
3248        @r"
3249        Filter: b > Int64(10) AND test.c > Int64(10)
3250          Projection: test.a AS b, test.c
3251            TableScan: test
3252        ",
3253        );
3254        // rewrite filter col b to test.a
3255        assert_optimized_plan_equal!(
3256            plan,
3257            @r"
3258        Projection: test.a AS b, test.c
3259          TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]
3260        "
3261        )
3262    }
3263
3264    #[test]
3265    fn test_filter_with_alias_2() -> Result<()> {
3266        // in table scan the true col name is 'test.a',
3267        // but we rename it as 'b', and use col 'b' in filter
3268        // we need rewrite filter col before push down.
3269        let table_scan = test_table_scan()?;
3270        let plan = LogicalPlanBuilder::from(table_scan)
3271            .project(vec![col("a").alias("b"), col("c")])?
3272            .project(vec![col("b"), col("c")])?
3273            .filter(and(col("b").gt(lit(10i64)), col("c").gt(lit(10i64))))?
3274            .build()?;
3275
3276        // filter on col b
3277        assert_snapshot!(plan,
3278        @r"
3279        Filter: b > Int64(10) AND test.c > Int64(10)
3280          Projection: b, test.c
3281            Projection: test.a AS b, test.c
3282              TableScan: test
3283        ",
3284        );
3285        // rewrite filter col b to test.a
3286        assert_optimized_plan_equal!(
3287            plan,
3288            @r"
3289        Projection: b, test.c
3290          Projection: test.a AS b, test.c
3291            TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]
3292        "
3293        )
3294    }
3295
3296    #[test]
3297    fn test_filter_with_multi_alias() -> Result<()> {
3298        let table_scan = test_table_scan()?;
3299        let plan = LogicalPlanBuilder::from(table_scan)
3300            .project(vec![col("a").alias("b"), col("c").alias("d")])?
3301            .filter(and(col("b").gt(lit(10i64)), col("d").gt(lit(10i64))))?
3302            .build()?;
3303
3304        // filter on col b and d
3305        assert_snapshot!(plan,
3306        @r"
3307        Filter: b > Int64(10) AND d > Int64(10)
3308          Projection: test.a AS b, test.c AS d
3309            TableScan: test
3310        ",
3311        );
3312        // rewrite filter col b to test.a, col d to test.c
3313        assert_optimized_plan_equal!(
3314            plan,
3315            @r"
3316        Projection: test.a AS b, test.c AS d
3317          TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]
3318        "
3319        )
3320    }
3321
3322    /// predicate on join key in filter expression should be pushed down to both inputs
3323    #[test]
3324    fn join_filter_with_alias() -> Result<()> {
3325        let table_scan = test_table_scan()?;
3326        let left = LogicalPlanBuilder::from(table_scan)
3327            .project(vec![col("a").alias("c")])?
3328            .build()?;
3329        let right_table_scan = test_table_scan_with_name("test2")?;
3330        let right = LogicalPlanBuilder::from(right_table_scan)
3331            .project(vec![col("b").alias("d")])?
3332            .build()?;
3333        let filter = col("c").gt(lit(1u32));
3334        let plan = LogicalPlanBuilder::from(left)
3335            .join(
3336                right,
3337                JoinType::Inner,
3338                (vec![Column::from_name("c")], vec![Column::from_name("d")]),
3339                Some(filter),
3340            )?
3341            .build()?;
3342
3343        assert_snapshot!(plan,
3344        @r"
3345        Inner Join: c = d Filter: c > UInt32(1)
3346          Projection: test.a AS c
3347            TableScan: test
3348          Projection: test2.b AS d
3349            TableScan: test2
3350        ",
3351        );
3352        // Change filter on col `c`, 'd' to `test.a`, 'test.b'
3353        assert_optimized_plan_equal!(
3354            plan,
3355            @r"
3356        Inner Join: c = d
3357          Projection: test.a AS c
3358            TableScan: test, full_filters=[test.a > UInt32(1)]
3359          Projection: test2.b AS d
3360            TableScan: test2, full_filters=[test2.b > UInt32(1)]
3361        "
3362        )
3363    }
3364
3365    #[test]
3366    fn test_in_filter_with_alias() -> Result<()> {
3367        // in table scan the true col name is 'test.a',
3368        // but we rename it as 'b', and use col 'b' in filter
3369        // we need rewrite filter col before push down.
3370        let table_scan = test_table_scan()?;
3371        let filter_value = vec![lit(1u32), lit(2u32), lit(3u32), lit(4u32)];
3372        let plan = LogicalPlanBuilder::from(table_scan)
3373            .project(vec![col("a").alias("b"), col("c")])?
3374            .filter(in_list(col("b"), filter_value, false))?
3375            .build()?;
3376
3377        // filter on col b
3378        assert_snapshot!(plan,
3379        @r"
3380        Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])
3381          Projection: test.a AS b, test.c
3382            TableScan: test
3383        ",
3384        );
3385        // rewrite filter col b to test.a
3386        assert_optimized_plan_equal!(
3387            plan,
3388            @r"
3389        Projection: test.a AS b, test.c
3390          TableScan: test, full_filters=[test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])]
3391        "
3392        )
3393    }
3394
3395    #[test]
3396    fn test_in_filter_with_alias_2() -> Result<()> {
3397        // in table scan the true col name is 'test.a',
3398        // but we rename it as 'b', and use col 'b' in filter
3399        // we need rewrite filter col before push down.
3400        let table_scan = test_table_scan()?;
3401        let filter_value = vec![lit(1u32), lit(2u32), lit(3u32), lit(4u32)];
3402        let plan = LogicalPlanBuilder::from(table_scan)
3403            .project(vec![col("a").alias("b"), col("c")])?
3404            .project(vec![col("b"), col("c")])?
3405            .filter(in_list(col("b"), filter_value, false))?
3406            .build()?;
3407
3408        // filter on col b
3409        assert_snapshot!(plan,
3410        @r"
3411        Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])
3412          Projection: b, test.c
3413            Projection: test.a AS b, test.c
3414              TableScan: test
3415        ",
3416        );
3417        // rewrite filter col b to test.a
3418        assert_optimized_plan_equal!(
3419            plan,
3420            @r"
3421        Projection: b, test.c
3422          Projection: test.a AS b, test.c
3423            TableScan: test, full_filters=[test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])]
3424        "
3425        )
3426    }
3427
3428    #[test]
3429    fn test_in_subquery_with_alias() -> Result<()> {
3430        // in table scan the true col name is 'test.a',
3431        // but we rename it as 'b', and use col 'b' in subquery filter
3432        let table_scan = test_table_scan()?;
3433        let table_scan_sq = test_table_scan_with_name("sq")?;
3434        let subplan = Arc::new(
3435            LogicalPlanBuilder::from(table_scan_sq)
3436                .project(vec![col("c")])?
3437                .build()?,
3438        );
3439        let plan = LogicalPlanBuilder::from(table_scan)
3440            .project(vec![col("a").alias("b"), col("c")])?
3441            .filter(in_subquery(col("b"), subplan))?
3442            .build()?;
3443
3444        // filter on col b in subquery
3445        assert_snapshot!(plan,
3446        @r"
3447        Filter: b IN (<subquery>)
3448          Subquery:
3449            Projection: sq.c
3450              TableScan: sq
3451          Projection: test.a AS b, test.c
3452            TableScan: test
3453        ",
3454        );
3455        // rewrite filter col b to test.a
3456        assert_optimized_plan_equal!(
3457            plan,
3458            @r"
3459        Projection: test.a AS b, test.c
3460          TableScan: test, full_filters=[test.a IN (<subquery>)]
3461            Subquery:
3462              Projection: sq.c
3463                TableScan: sq
3464        "
3465        )
3466    }
3467
3468    #[test]
3469    fn test_propagation_of_optimized_inner_filters_with_projections() -> Result<()> {
3470        // SELECT a FROM (SELECT 1 AS a) b WHERE b.a = 1
3471        let plan = LogicalPlanBuilder::empty(true)
3472            .project(vec![lit(0i64).alias("a")])?
3473            .alias("b")?
3474            .project(vec![col("b.a")])?
3475            .alias("b")?
3476            .filter(col("b.a").eq(lit(1i64)))?
3477            .project(vec![col("b.a")])?
3478            .build()?;
3479
3480        assert_snapshot!(plan,
3481        @r"
3482        Projection: b.a
3483          Filter: b.a = Int64(1)
3484            SubqueryAlias: b
3485              Projection: b.a
3486                SubqueryAlias: b
3487                  Projection: Int64(0) AS a
3488                    EmptyRelation: rows=1
3489        ",
3490        );
3491        // Ensure that the predicate without any columns (0 = 1) is
3492        // still there.
3493        assert_optimized_plan_equal!(
3494            plan,
3495            @r"
3496        Projection: b.a
3497          SubqueryAlias: b
3498            Projection: b.a
3499              SubqueryAlias: b
3500                Projection: Int64(0) AS a
3501                  Filter: Int64(0) = Int64(1)
3502                    EmptyRelation: rows=1
3503        "
3504        )
3505    }
3506
3507    #[test]
3508    fn test_crossjoin_with_or_clause() -> Result<()> {
3509        // select * from test,test1 where (test.a = test1.a and test.b > 1) or (test.b = test1.b and test.c < 10);
3510        let table_scan = test_table_scan()?;
3511        let left = LogicalPlanBuilder::from(table_scan)
3512            .project(vec![col("a"), col("b"), col("c")])?
3513            .build()?;
3514        let right_table_scan = test_table_scan_with_name("test1")?;
3515        let right = LogicalPlanBuilder::from(right_table_scan)
3516            .project(vec![col("a").alias("d"), col("a").alias("e")])?
3517            .build()?;
3518        let filter = or(
3519            and(col("a").eq(col("d")), col("b").gt(lit(1u32))),
3520            and(col("b").eq(col("e")), col("c").lt(lit(10u32))),
3521        );
3522        let plan = LogicalPlanBuilder::from(left)
3523            .cross_join(right)?
3524            .filter(filter)?
3525            .build()?;
3526
3527        assert_optimized_plan_eq_with_rewrite_predicate!(plan.clone(), @r"
3528        Inner Join:  Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10)
3529          Projection: test.a, test.b, test.c
3530            TableScan: test, full_filters=[test.b > UInt32(1) OR test.c < UInt32(10)]
3531          Projection: test1.a AS d, test1.a AS e
3532            TableScan: test1
3533        ")?;
3534
3535        // Originally global state which can help to avoid duplicate Filters been generated and pushed down.
3536        // Now the global state is removed. Need to double confirm that avoid duplicate Filters.
3537        let optimized_plan = PushDownFilter::new()
3538            .rewrite(plan, &OptimizerContext::new())
3539            .expect("failed to optimize plan")
3540            .data;
3541        assert_optimized_plan_equal!(
3542            optimized_plan,
3543            @r"
3544        Inner Join:  Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10)
3545          Projection: test.a, test.b, test.c
3546            TableScan: test, full_filters=[test.b > UInt32(1) OR test.c < UInt32(10)]
3547          Projection: test1.a AS d, test1.a AS e
3548            TableScan: test1
3549        "
3550        )
3551    }
3552
3553    #[test]
3554    fn left_semi_join() -> Result<()> {
3555        let left = test_table_scan_with_name("test1")?;
3556        let right_table_scan = test_table_scan_with_name("test2")?;
3557        let right = LogicalPlanBuilder::from(right_table_scan)
3558            .project(vec![col("a"), col("b")])?
3559            .build()?;
3560        let plan = LogicalPlanBuilder::from(left)
3561            .join(
3562                right,
3563                JoinType::LeftSemi,
3564                (
3565                    vec![Column::from_qualified_name("test1.a")],
3566                    vec![Column::from_qualified_name("test2.a")],
3567                ),
3568                None,
3569            )?
3570            .filter(col("test2.a").lt_eq(lit(1i64)))?
3571            .build()?;
3572
3573        // not part of the test, just good to know:
3574        assert_snapshot!(plan,
3575        @r"
3576        Filter: test2.a <= Int64(1)
3577          LeftSemi Join: test1.a = test2.a
3578            TableScan: test1
3579            Projection: test2.a, test2.b
3580              TableScan: test2
3581        ",
3582        );
3583        // Inferred the predicate `test1.a <= Int64(1)` and push it down to the left side.
3584        assert_optimized_plan_equal!(
3585            plan,
3586            @r"
3587        Filter: test2.a <= Int64(1)
3588          LeftSemi Join: test1.a = test2.a
3589            TableScan: test1, full_filters=[test1.a <= Int64(1)]
3590            Projection: test2.a, test2.b
3591              TableScan: test2
3592        "
3593        )
3594    }
3595
3596    #[test]
3597    fn left_semi_join_with_filters() -> Result<()> {
3598        let left = test_table_scan_with_name("test1")?;
3599        let right_table_scan = test_table_scan_with_name("test2")?;
3600        let right = LogicalPlanBuilder::from(right_table_scan)
3601            .project(vec![col("a"), col("b")])?
3602            .build()?;
3603        let plan = LogicalPlanBuilder::from(left)
3604            .join(
3605                right,
3606                JoinType::LeftSemi,
3607                (
3608                    vec![Column::from_qualified_name("test1.a")],
3609                    vec![Column::from_qualified_name("test2.a")],
3610                ),
3611                Some(
3612                    col("test1.b")
3613                        .gt(lit(1u32))
3614                        .and(col("test2.b").gt(lit(2u32))),
3615                ),
3616            )?
3617            .build()?;
3618
3619        // not part of the test, just good to know:
3620        assert_snapshot!(plan,
3621        @r"
3622        LeftSemi Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
3623          TableScan: test1
3624          Projection: test2.a, test2.b
3625            TableScan: test2
3626        ",
3627        );
3628        // Both side will be pushed down.
3629        assert_optimized_plan_equal!(
3630            plan,
3631            @r"
3632        LeftSemi Join: test1.a = test2.a
3633          TableScan: test1, full_filters=[test1.b > UInt32(1)]
3634          Projection: test2.a, test2.b
3635            TableScan: test2, full_filters=[test2.b > UInt32(2)]
3636        "
3637        )
3638    }
3639
3640    #[test]
3641    fn right_semi_join() -> Result<()> {
3642        let left = test_table_scan_with_name("test1")?;
3643        let right_table_scan = test_table_scan_with_name("test2")?;
3644        let right = LogicalPlanBuilder::from(right_table_scan)
3645            .project(vec![col("a"), col("b")])?
3646            .build()?;
3647        let plan = LogicalPlanBuilder::from(left)
3648            .join(
3649                right,
3650                JoinType::RightSemi,
3651                (
3652                    vec![Column::from_qualified_name("test1.a")],
3653                    vec![Column::from_qualified_name("test2.a")],
3654                ),
3655                None,
3656            )?
3657            .filter(col("test1.a").lt_eq(lit(1i64)))?
3658            .build()?;
3659
3660        // not part of the test, just good to know:
3661        assert_snapshot!(plan,
3662        @r"
3663        Filter: test1.a <= Int64(1)
3664          RightSemi Join: test1.a = test2.a
3665            TableScan: test1
3666            Projection: test2.a, test2.b
3667              TableScan: test2
3668        ",
3669        );
3670        // Inferred the predicate `test2.a <= Int64(1)` and push it down to the right side.
3671        assert_optimized_plan_equal!(
3672            plan,
3673            @r"
3674        Filter: test1.a <= Int64(1)
3675          RightSemi Join: test1.a = test2.a
3676            TableScan: test1
3677            Projection: test2.a, test2.b
3678              TableScan: test2, full_filters=[test2.a <= Int64(1)]
3679        "
3680        )
3681    }
3682
3683    #[test]
3684    fn right_semi_join_with_filters() -> Result<()> {
3685        let left = test_table_scan_with_name("test1")?;
3686        let right_table_scan = test_table_scan_with_name("test2")?;
3687        let right = LogicalPlanBuilder::from(right_table_scan)
3688            .project(vec![col("a"), col("b")])?
3689            .build()?;
3690        let plan = LogicalPlanBuilder::from(left)
3691            .join(
3692                right,
3693                JoinType::RightSemi,
3694                (
3695                    vec![Column::from_qualified_name("test1.a")],
3696                    vec![Column::from_qualified_name("test2.a")],
3697                ),
3698                Some(
3699                    col("test1.b")
3700                        .gt(lit(1u32))
3701                        .and(col("test2.b").gt(lit(2u32))),
3702                ),
3703            )?
3704            .build()?;
3705
3706        // not part of the test, just good to know:
3707        assert_snapshot!(plan,
3708        @r"
3709        RightSemi Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
3710          TableScan: test1
3711          Projection: test2.a, test2.b
3712            TableScan: test2
3713        ",
3714        );
3715        // Both side will be pushed down.
3716        assert_optimized_plan_equal!(
3717            plan,
3718            @r"
3719        RightSemi Join: test1.a = test2.a
3720          TableScan: test1, full_filters=[test1.b > UInt32(1)]
3721          Projection: test2.a, test2.b
3722            TableScan: test2, full_filters=[test2.b > UInt32(2)]
3723        "
3724        )
3725    }
3726
3727    #[test]
3728    fn left_anti_join() -> Result<()> {
3729        let table_scan = test_table_scan_with_name("test1")?;
3730        let left = LogicalPlanBuilder::from(table_scan)
3731            .project(vec![col("a"), col("b")])?
3732            .build()?;
3733        let right_table_scan = test_table_scan_with_name("test2")?;
3734        let right = LogicalPlanBuilder::from(right_table_scan)
3735            .project(vec![col("a"), col("b")])?
3736            .build()?;
3737        let plan = LogicalPlanBuilder::from(left)
3738            .join(
3739                right,
3740                JoinType::LeftAnti,
3741                (
3742                    vec![Column::from_qualified_name("test1.a")],
3743                    vec![Column::from_qualified_name("test2.a")],
3744                ),
3745                None,
3746            )?
3747            .filter(col("test2.a").gt(lit(2u32)))?
3748            .build()?;
3749
3750        // not part of the test, just good to know:
3751        assert_snapshot!(plan,
3752        @r"
3753        Filter: test2.a > UInt32(2)
3754          LeftAnti Join: test1.a = test2.a
3755            Projection: test1.a, test1.b
3756              TableScan: test1
3757            Projection: test2.a, test2.b
3758              TableScan: test2
3759        ",
3760        );
3761        // For left anti, filter of the right side filter can be pushed down.
3762        assert_optimized_plan_equal!(
3763            plan,
3764            @r"
3765        Filter: test2.a > UInt32(2)
3766          LeftAnti Join: test1.a = test2.a
3767            Projection: test1.a, test1.b
3768              TableScan: test1, full_filters=[test1.a > UInt32(2)]
3769            Projection: test2.a, test2.b
3770              TableScan: test2
3771        "
3772        )
3773    }
3774
3775    #[test]
3776    fn left_anti_join_with_filters() -> Result<()> {
3777        let table_scan = test_table_scan_with_name("test1")?;
3778        let left = LogicalPlanBuilder::from(table_scan)
3779            .project(vec![col("a"), col("b")])?
3780            .build()?;
3781        let right_table_scan = test_table_scan_with_name("test2")?;
3782        let right = LogicalPlanBuilder::from(right_table_scan)
3783            .project(vec![col("a"), col("b")])?
3784            .build()?;
3785        let plan = LogicalPlanBuilder::from(left)
3786            .join(
3787                right,
3788                JoinType::LeftAnti,
3789                (
3790                    vec![Column::from_qualified_name("test1.a")],
3791                    vec![Column::from_qualified_name("test2.a")],
3792                ),
3793                Some(
3794                    col("test1.b")
3795                        .gt(lit(1u32))
3796                        .and(col("test2.b").gt(lit(2u32))),
3797                ),
3798            )?
3799            .build()?;
3800
3801        // not part of the test, just good to know:
3802        assert_snapshot!(plan,
3803        @r"
3804        LeftAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
3805          Projection: test1.a, test1.b
3806            TableScan: test1
3807          Projection: test2.a, test2.b
3808            TableScan: test2
3809        ",
3810        );
3811        // For left anti, filter of the right side filter can be pushed down.
3812        assert_optimized_plan_equal!(
3813            plan,
3814            @r"
3815        LeftAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1)
3816          Projection: test1.a, test1.b
3817            TableScan: test1
3818          Projection: test2.a, test2.b
3819            TableScan: test2, full_filters=[test2.b > UInt32(2)]
3820        "
3821        )
3822    }
3823
3824    #[test]
3825    fn right_anti_join() -> Result<()> {
3826        let table_scan = test_table_scan_with_name("test1")?;
3827        let left = LogicalPlanBuilder::from(table_scan)
3828            .project(vec![col("a"), col("b")])?
3829            .build()?;
3830        let right_table_scan = test_table_scan_with_name("test2")?;
3831        let right = LogicalPlanBuilder::from(right_table_scan)
3832            .project(vec![col("a"), col("b")])?
3833            .build()?;
3834        let plan = LogicalPlanBuilder::from(left)
3835            .join(
3836                right,
3837                JoinType::RightAnti,
3838                (
3839                    vec![Column::from_qualified_name("test1.a")],
3840                    vec![Column::from_qualified_name("test2.a")],
3841                ),
3842                None,
3843            )?
3844            .filter(col("test1.a").gt(lit(2u32)))?
3845            .build()?;
3846
3847        // not part of the test, just good to know:
3848        assert_snapshot!(plan,
3849        @r"
3850        Filter: test1.a > UInt32(2)
3851          RightAnti Join: test1.a = test2.a
3852            Projection: test1.a, test1.b
3853              TableScan: test1
3854            Projection: test2.a, test2.b
3855              TableScan: test2
3856        ",
3857        );
3858        // For right anti, filter of the left side can be pushed down.
3859        assert_optimized_plan_equal!(
3860            plan,
3861            @r"
3862        Filter: test1.a > UInt32(2)
3863          RightAnti Join: test1.a = test2.a
3864            Projection: test1.a, test1.b
3865              TableScan: test1
3866            Projection: test2.a, test2.b
3867              TableScan: test2, full_filters=[test2.a > UInt32(2)]
3868        "
3869        )
3870    }
3871
3872    #[test]
3873    fn right_anti_join_with_filters() -> Result<()> {
3874        let table_scan = test_table_scan_with_name("test1")?;
3875        let left = LogicalPlanBuilder::from(table_scan)
3876            .project(vec![col("a"), col("b")])?
3877            .build()?;
3878        let right_table_scan = test_table_scan_with_name("test2")?;
3879        let right = LogicalPlanBuilder::from(right_table_scan)
3880            .project(vec![col("a"), col("b")])?
3881            .build()?;
3882        let plan = LogicalPlanBuilder::from(left)
3883            .join(
3884                right,
3885                JoinType::RightAnti,
3886                (
3887                    vec![Column::from_qualified_name("test1.a")],
3888                    vec![Column::from_qualified_name("test2.a")],
3889                ),
3890                Some(
3891                    col("test1.b")
3892                        .gt(lit(1u32))
3893                        .and(col("test2.b").gt(lit(2u32))),
3894                ),
3895            )?
3896            .build()?;
3897
3898        // not part of the test, just good to know:
3899        assert_snapshot!(plan,
3900        @r"
3901        RightAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
3902          Projection: test1.a, test1.b
3903            TableScan: test1
3904          Projection: test2.a, test2.b
3905            TableScan: test2
3906        ",
3907        );
3908        // For right anti, filter of the left side can be pushed down.
3909        assert_optimized_plan_equal!(
3910            plan,
3911            @r"
3912        RightAnti Join: test1.a = test2.a Filter: test2.b > UInt32(2)
3913          Projection: test1.a, test1.b
3914            TableScan: test1, full_filters=[test1.b > UInt32(1)]
3915          Projection: test2.a, test2.b
3916            TableScan: test2
3917        "
3918        )
3919    }
3920
3921    #[derive(Debug, PartialEq, Eq, Hash)]
3922    struct TestScalarUDF {
3923        signature: Signature,
3924    }
3925
3926    impl ScalarUDFImpl for TestScalarUDF {
3927        fn as_any(&self) -> &dyn Any {
3928            self
3929        }
3930        fn name(&self) -> &str {
3931            "TestScalarUDF"
3932        }
3933
3934        fn signature(&self) -> &Signature {
3935            &self.signature
3936        }
3937
3938        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
3939            Ok(DataType::Int32)
3940        }
3941
3942        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
3943            Ok(ColumnarValue::Scalar(ScalarValue::from(1)))
3944        }
3945    }
3946
3947    #[test]
3948    fn test_push_down_volatile_function_in_aggregate() -> Result<()> {
3949        // SELECT t.a, t.r FROM (SELECT a, sum(b),  TestScalarUDF()+1 AS r FROM test1 GROUP BY a) AS t WHERE t.a > 5 AND t.r > 0.5;
3950        let table_scan = test_table_scan_with_name("test1")?;
3951        let fun = ScalarUDF::new_from_impl(TestScalarUDF {
3952            signature: Signature::exact(vec![], Volatility::Volatile),
3953        });
3954        let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
3955
3956        let plan = LogicalPlanBuilder::from(table_scan)
3957            .aggregate(vec![col("a")], vec![sum(col("b"))])?
3958            .project(vec![col("a"), sum(col("b")), add(expr, lit(1)).alias("r")])?
3959            .alias("t")?
3960            .filter(col("t.a").gt(lit(5)).and(col("t.r").gt(lit(0.5))))?
3961            .project(vec![col("t.a"), col("t.r")])?
3962            .build()?;
3963
3964        assert_snapshot!(plan,
3965        @r"
3966        Projection: t.a, t.r
3967          Filter: t.a > Int32(5) AND t.r > Float64(0.5)
3968            SubqueryAlias: t
3969              Projection: test1.a, sum(test1.b), TestScalarUDF() + Int32(1) AS r
3970                Aggregate: groupBy=[[test1.a]], aggr=[[sum(test1.b)]]
3971                  TableScan: test1
3972        ",
3973        );
3974        assert_optimized_plan_equal!(
3975            plan,
3976            @r"
3977        Projection: t.a, t.r
3978          SubqueryAlias: t
3979            Filter: r > Float64(0.5)
3980              Projection: test1.a, sum(test1.b), TestScalarUDF() + Int32(1) AS r
3981                Aggregate: groupBy=[[test1.a]], aggr=[[sum(test1.b)]]
3982                  TableScan: test1, full_filters=[test1.a > Int32(5)]
3983        "
3984        )
3985    }
3986
3987    #[test]
3988    fn test_push_down_volatile_function_in_join() -> Result<()> {
3989        // SELECT t.a, t.r FROM (SELECT test1.a AS a, TestScalarUDF() AS r FROM test1 join test2 ON test1.a = test2.a) AS t WHERE t.r > 0.5;
3990        let table_scan = test_table_scan_with_name("test1")?;
3991        let fun = ScalarUDF::new_from_impl(TestScalarUDF {
3992            signature: Signature::exact(vec![], Volatility::Volatile),
3993        });
3994        let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
3995        let left = LogicalPlanBuilder::from(table_scan).build()?;
3996        let right_table_scan = test_table_scan_with_name("test2")?;
3997        let right = LogicalPlanBuilder::from(right_table_scan).build()?;
3998        let plan = LogicalPlanBuilder::from(left)
3999            .join(
4000                right,
4001                JoinType::Inner,
4002                (
4003                    vec![Column::from_qualified_name("test1.a")],
4004                    vec![Column::from_qualified_name("test2.a")],
4005                ),
4006                None,
4007            )?
4008            .project(vec![col("test1.a").alias("a"), expr.alias("r")])?
4009            .alias("t")?
4010            .filter(col("t.r").gt(lit(0.8)))?
4011            .project(vec![col("t.a"), col("t.r")])?
4012            .build()?;
4013
4014        assert_snapshot!(plan,
4015        @r"
4016        Projection: t.a, t.r
4017          Filter: t.r > Float64(0.8)
4018            SubqueryAlias: t
4019              Projection: test1.a AS a, TestScalarUDF() AS r
4020                Inner Join: test1.a = test2.a
4021                  TableScan: test1
4022                  TableScan: test2
4023        ",
4024        );
4025        assert_optimized_plan_equal!(
4026            plan,
4027            @r"
4028        Projection: t.a, t.r
4029          SubqueryAlias: t
4030            Filter: r > Float64(0.8)
4031              Projection: test1.a AS a, TestScalarUDF() AS r
4032                Inner Join: test1.a = test2.a
4033                  TableScan: test1
4034                  TableScan: test2
4035        "
4036        )
4037    }
4038
4039    #[test]
4040    fn test_push_down_volatile_table_scan() -> Result<()> {
4041        // SELECT test.a, test.b FROM test as t WHERE TestScalarUDF() > 0.1;
4042        let table_scan = test_table_scan()?;
4043        let fun = ScalarUDF::new_from_impl(TestScalarUDF {
4044            signature: Signature::exact(vec![], Volatility::Volatile),
4045        });
4046        let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
4047        let plan = LogicalPlanBuilder::from(table_scan)
4048            .project(vec![col("a"), col("b")])?
4049            .filter(expr.gt(lit(0.1)))?
4050            .build()?;
4051
4052        assert_snapshot!(plan,
4053        @r"
4054        Filter: TestScalarUDF() > Float64(0.1)
4055          Projection: test.a, test.b
4056            TableScan: test
4057        ",
4058        );
4059        assert_optimized_plan_equal!(
4060            plan,
4061            @r"
4062        Projection: test.a, test.b
4063          Filter: TestScalarUDF() > Float64(0.1)
4064            TableScan: test
4065        "
4066        )
4067    }
4068
4069    #[test]
4070    fn test_push_down_volatile_mixed_table_scan() -> Result<()> {
4071        // SELECT test.a, test.b FROM test as t WHERE TestScalarUDF() > 0.1 and test.a > 5 and test.b > 10;
4072        let table_scan = test_table_scan()?;
4073        let fun = ScalarUDF::new_from_impl(TestScalarUDF {
4074            signature: Signature::exact(vec![], Volatility::Volatile),
4075        });
4076        let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
4077        let plan = LogicalPlanBuilder::from(table_scan)
4078            .project(vec![col("a"), col("b")])?
4079            .filter(
4080                expr.gt(lit(0.1))
4081                    .and(col("t.a").gt(lit(5)))
4082                    .and(col("t.b").gt(lit(10))),
4083            )?
4084            .build()?;
4085
4086        assert_snapshot!(plan,
4087        @r"
4088        Filter: TestScalarUDF() > Float64(0.1) AND t.a > Int32(5) AND t.b > Int32(10)
4089          Projection: test.a, test.b
4090            TableScan: test
4091        ",
4092        );
4093        assert_optimized_plan_equal!(
4094            plan,
4095            @r"
4096        Projection: test.a, test.b
4097          Filter: TestScalarUDF() > Float64(0.1)
4098            TableScan: test, full_filters=[t.a > Int32(5), t.b > Int32(10)]
4099        "
4100        )
4101    }
4102
4103    #[test]
4104    fn test_push_down_volatile_mixed_unsupported_table_scan() -> Result<()> {
4105        // SELECT test.a, test.b FROM test as t WHERE TestScalarUDF() > 0.1 and test.a > 5 and test.b > 10;
4106        let fun = ScalarUDF::new_from_impl(TestScalarUDF {
4107            signature: Signature::exact(vec![], Volatility::Volatile),
4108        });
4109        let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
4110        let plan = table_scan_with_pushdown_provider_builder(
4111            TableProviderFilterPushDown::Unsupported,
4112            vec![],
4113            None,
4114        )?
4115        .project(vec![col("a"), col("b")])?
4116        .filter(
4117            expr.gt(lit(0.1))
4118                .and(col("t.a").gt(lit(5)))
4119                .and(col("t.b").gt(lit(10))),
4120        )?
4121        .build()?;
4122
4123        assert_snapshot!(plan,
4124        @r"
4125        Filter: TestScalarUDF() > Float64(0.1) AND t.a > Int32(5) AND t.b > Int32(10)
4126          Projection: a, b
4127            TableScan: test
4128        ",
4129        );
4130        assert_optimized_plan_equal!(
4131            plan,
4132            @r"
4133        Projection: a, b
4134          Filter: t.a > Int32(5) AND t.b > Int32(10) AND TestScalarUDF() > Float64(0.1)
4135            TableScan: test
4136        "
4137        )
4138    }
4139
4140    #[test]
4141    fn test_push_down_filter_to_user_defined_node() -> Result<()> {
4142        // Define a custom user-defined logical node
4143        #[derive(Debug, Hash, Eq, PartialEq)]
4144        struct TestUserNode {
4145            schema: DFSchemaRef,
4146        }
4147
4148        impl PartialOrd for TestUserNode {
4149            fn partial_cmp(&self, _other: &Self) -> Option<Ordering> {
4150                None
4151            }
4152        }
4153
4154        impl TestUserNode {
4155            fn new() -> Self {
4156                let schema = Arc::new(
4157                    DFSchema::new_with_metadata(
4158                        vec![(None, Field::new("a", DataType::Int64, false).into())],
4159                        Default::default(),
4160                    )
4161                    .unwrap(),
4162                );
4163
4164                Self { schema }
4165            }
4166        }
4167
4168        impl UserDefinedLogicalNodeCore for TestUserNode {
4169            fn name(&self) -> &str {
4170                "test_node"
4171            }
4172
4173            fn inputs(&self) -> Vec<&LogicalPlan> {
4174                vec![]
4175            }
4176
4177            fn schema(&self) -> &DFSchemaRef {
4178                &self.schema
4179            }
4180
4181            fn expressions(&self) -> Vec<Expr> {
4182                vec![]
4183            }
4184
4185            fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
4186                write!(f, "TestUserNode")
4187            }
4188
4189            fn with_exprs_and_inputs(
4190                &self,
4191                exprs: Vec<Expr>,
4192                inputs: Vec<LogicalPlan>,
4193            ) -> Result<Self> {
4194                assert!(exprs.is_empty());
4195                assert!(inputs.is_empty());
4196                Ok(Self {
4197                    schema: Arc::clone(&self.schema),
4198                })
4199            }
4200        }
4201
4202        // Create a node and build a plan with a filter
4203        let node = LogicalPlan::Extension(Extension {
4204            node: Arc::new(TestUserNode::new()),
4205        });
4206
4207        let plan = LogicalPlanBuilder::from(node).filter(lit(false))?.build()?;
4208
4209        // Check the original plan format (not part of the test assertions)
4210        assert_snapshot!(plan,
4211        @r"
4212        Filter: Boolean(false)
4213          TestUserNode
4214        ",
4215        );
4216        // Check that the filter is pushed down to the user-defined node
4217        assert_optimized_plan_equal!(
4218            plan,
4219            @r"
4220        Filter: Boolean(false)
4221          TestUserNode
4222        "
4223        )
4224    }
4225}