datafusion_optimizer/
eliminate_filter.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! [`EliminateFilter`] replaces `where false` or `where null` with an empty relation,
//! and removes `where true` filters entirely.
19
20use datafusion_common::tree_node::Transformed;
21use datafusion_common::{Result, ScalarValue};
22use datafusion_expr::{EmptyRelation, Expr, Filter, LogicalPlan};
23use std::sync::Arc;
24
25use crate::optimizer::ApplyOrder;
26use crate::{OptimizerConfig, OptimizerRule};
27
/// Optimization rule that eliminates filters whose predicate is a constant
/// boolean literal (`true`, `false`, or `NULL`).
///
/// * `WHERE true` is dropped and the filter is replaced by its input.
/// * `WHERE false` and `WHERE NULL` are replaced by an
///   [`LogicalPlan::EmptyRelation`] with the input's schema, since no row
///   can satisfy such a predicate.
///
/// This saves time in planning and executing the query.
/// Note that this rule should be applied after the simplify expressions
/// optimizer rule: it only recognizes predicates that are already a literal
/// boolean value.
#[derive(Default, Debug)]
pub struct EliminateFilter;
35
36impl EliminateFilter {
37    #[allow(missing_docs)]
38    pub fn new() -> Self {
39        Self {}
40    }
41}
42
43impl OptimizerRule for EliminateFilter {
44    fn name(&self) -> &str {
45        "eliminate_filter"
46    }
47
48    fn apply_order(&self) -> Option<ApplyOrder> {
49        Some(ApplyOrder::TopDown)
50    }
51
52    fn supports_rewrite(&self) -> bool {
53        true
54    }
55
56    fn rewrite(
57        &self,
58        plan: LogicalPlan,
59        _config: &dyn OptimizerConfig,
60    ) -> Result<Transformed<LogicalPlan>> {
61        match plan {
62            LogicalPlan::Filter(Filter {
63                predicate: Expr::Literal(ScalarValue::Boolean(v)),
64                input,
65                ..
66            }) => match v {
67                Some(true) => Ok(Transformed::yes(Arc::unwrap_or_clone(input))),
68                Some(false) | None => Ok(Transformed::yes(LogicalPlan::EmptyRelation(
69                    EmptyRelation {
70                        produce_one_row: false,
71                        schema: Arc::clone(input.schema()),
72                    },
73                ))),
74            },
75            _ => Ok(Transformed::no(plan)),
76        }
77    }
78}
79
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::{
        col, lit, logical_plan::builder::LogicalPlanBuilder, Expr, LogicalPlan,
    };

    use crate::eliminate_filter::EliminateFilter;
    use crate::test::*;
    use datafusion_expr::test::function_stub::sum;

    /// Applies [`EliminateFilter`] to `plan` through the shared
    /// `assert_optimized_plan_eq` helper and compares the formatted result
    /// with `expected`.
    fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> {
        assert_optimized_plan_eq(Arc::new(EliminateFilter::new()), plan, expected)
    }

    #[test]
    fn filter_false() -> Result<()> {
        let filter_expr = lit(false);

        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("a")], vec![sum(col("b"))])?
            .filter(filter_expr)?
            .build()?;

        // The filter and everything beneath it (aggregate and scan) collapse
        // into a single empty relation.
        let expected = "EmptyRelation";
        assert_optimized_plan_equal(plan, expected)
    }

    #[test]
    fn filter_null() -> Result<()> {
        let filter_expr = Expr::Literal(ScalarValue::Boolean(None));

        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("a")], vec![sum(col("b"))])?
            .filter(filter_expr)?
            .build()?;

        // A NULL predicate behaves like `false`: the filter and everything
        // beneath it (aggregate and scan) collapse into an empty relation.
        let expected = "EmptyRelation";
        assert_optimized_plan_equal(plan, expected)
    }

    #[test]
    fn filter_false_nested() -> Result<()> {
        let filter_expr = lit(false);

        let table_scan = test_table_scan()?;
        let plan1 = LogicalPlanBuilder::from(table_scan.clone())
            .aggregate(vec![col("a")], vec![sum(col("b"))])?
            .build()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("a")], vec![sum(col("b"))])?
            .filter(filter_expr)?
            .union(plan1)?
            .build()?;

        // The left side of the union is replaced by an empty relation; the
        // right side is untouched.
        let expected = "Union\
            \n  EmptyRelation\
            \n  Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\
            \n    TableScan: test";
        assert_optimized_plan_equal(plan, expected)
    }

    #[test]
    fn filter_true() -> Result<()> {
        let filter_expr = lit(true);

        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("a")], vec![sum(col("b"))])?
            .filter(filter_expr)?
            .build()?;

        // The filter is dropped and its input takes its place.
        let expected = "Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\
        \n  TableScan: test";
        assert_optimized_plan_equal(plan, expected)
    }

    #[test]
    fn filter_true_nested() -> Result<()> {
        let filter_expr = lit(true);

        let table_scan = test_table_scan()?;
        let plan1 = LogicalPlanBuilder::from(table_scan.clone())
            .aggregate(vec![col("a")], vec![sum(col("b"))])?
            .build()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("a")], vec![sum(col("b"))])?
            .filter(filter_expr)?
            .union(plan1)?
            .build()?;

        // Filter is removed
        let expected = "Union\
            \n  Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\
            \n    TableScan: test\
            \n  Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\
            \n    TableScan: test";
        assert_optimized_plan_equal(plan, expected)
    }

    #[test]
    fn filter_from_subquery() -> Result<()> {
        // SELECT a FROM (SELECT a FROM test WHERE FALSE) WHERE TRUE

        let false_filter = lit(false);
        let table_scan = test_table_scan()?;
        let plan1 = LogicalPlanBuilder::from(table_scan)
            .project(vec![col("a")])?
            .filter(false_filter)?
            .build()?;

        let true_filter = lit(true);
        let plan = LogicalPlanBuilder::from(plan1)
            .project(vec![col("a")])?
            .filter(true_filter)?
            .build()?;

        // The outer `WHERE TRUE` filter is removed, and the inner
        // `WHERE FALSE` filter (with the projection and scan beneath it)
        // collapses into an empty relation.
        let expected = "Projection: test.a\
            \n  EmptyRelation";
        assert_optimized_plan_equal(plan, expected)
    }

    #[test]
    fn filter_non_literal_predicate() -> Result<()> {
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .filter(col("a").gt(lit(1u32)))?
            .build()?;

        // Only literal boolean predicates are eliminated; this filter stays.
        let expected = "Filter: test.a > UInt32(1)\
            \n  TableScan: test";
        assert_optimized_plan_equal(plan, expected)
    }
}