datafusion_optimizer/
eliminate_filter.rs1use datafusion_common::tree_node::Transformed;
21use datafusion_common::{Result, ScalarValue};
22use datafusion_expr::{EmptyRelation, Expr, Filter, LogicalPlan};
23use std::sync::Arc;
24
25use crate::optimizer::ApplyOrder;
26use crate::{OptimizerConfig, OptimizerRule};
27
/// Optimizer rule that removes `Filter` nodes whose predicate is a boolean
/// literal.
///
/// * `WHERE TRUE` keeps every row, so the filter is replaced by its input.
/// * `WHERE FALSE` / `WHERE NULL` keeps no row, so the filter is replaced by an
///   empty relation with the input's schema.
///
/// Only literal predicates are matched. The rule therefore catches the most
/// cases when predicates have already been constant-folded by an earlier pass.
#[derive(Default, Debug)]
pub struct EliminateFilter;

impl EliminateFilter {
    /// Creates a new instance of the rule (equivalent to `EliminateFilter::default()`).
    pub fn new() -> Self {
        Self
    }
}
42
43impl OptimizerRule for EliminateFilter {
44 fn name(&self) -> &str {
45 "eliminate_filter"
46 }
47
48 fn apply_order(&self) -> Option<ApplyOrder> {
49 Some(ApplyOrder::TopDown)
50 }
51
52 fn supports_rewrite(&self) -> bool {
53 true
54 }
55
56 fn rewrite(
57 &self,
58 plan: LogicalPlan,
59 _config: &dyn OptimizerConfig,
60 ) -> Result<Transformed<LogicalPlan>> {
61 match plan {
62 LogicalPlan::Filter(Filter {
63 predicate: Expr::Literal(ScalarValue::Boolean(v), _),
64 input,
65 ..
66 }) => match v {
67 Some(true) => Ok(Transformed::yes(Arc::unwrap_or_clone(input))),
68 Some(false) | None => Ok(Transformed::yes(LogicalPlan::EmptyRelation(
69 EmptyRelation {
70 produce_one_row: false,
71 schema: Arc::clone(input.schema()),
72 },
73 ))),
74 },
75 _ => Ok(Transformed::no(plan)),
76 }
77 }
78}
79
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::assert_optimized_plan_eq_snapshot;
    use crate::OptimizerContext;
    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::{col, lit, logical_plan::builder::LogicalPlanBuilder, Expr};

    use crate::eliminate_filter::EliminateFilter;
    use crate::test::*;
    use datafusion_expr::test::function_stub::sum;

    /// Optimizes `$plan` with only [`EliminateFilter`] (a single optimizer
    /// pass) and compares the result with the inline snapshot `$expected`.
    macro_rules! assert_optimized_plan_equal {
        (
            $plan:expr,
            @ $expected:literal $(,)?
        ) => {{
            let optimizer_ctx = OptimizerContext::new().with_max_passes(1);
            let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> =
                vec![Arc::new(EliminateFilter::new())];
            assert_optimized_plan_eq_snapshot!(
                optimizer_ctx,
                rules,
                $plan,
                @ $expected,
            )
        }};
    }

    #[test]
    fn filter_false() -> Result<()> {
        let filter_expr = lit(false);

        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("a")], vec![sum(col("b"))])?
            .filter(filter_expr)?
            .build()?;

        assert_optimized_plan_equal!(plan, @"EmptyRelation: rows=0")
    }

    #[test]
    fn filter_null() -> Result<()> {
        let filter_expr = Expr::Literal(ScalarValue::Boolean(None), None);

        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("a")], vec![sum(col("b"))])?
            .filter(filter_expr)?
            .build()?;

        assert_optimized_plan_equal!(plan, @"EmptyRelation: rows=0")
    }

    #[test]
    fn filter_false_nested() -> Result<()> {
        let filter_expr = lit(false);

        let table_scan = test_table_scan()?;
        let plan1 = LogicalPlanBuilder::from(table_scan.clone())
            .aggregate(vec![col("a")], vec![sum(col("b"))])?
            .build()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("a")], vec![sum(col("b"))])?
            .filter(filter_expr)?
            .union(plan1)?
            .build()?;

        // Only the filtered union branch collapses to an empty relation.
        assert_optimized_plan_equal!(plan, @r"
        Union
          EmptyRelation: rows=0
          Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]
            TableScan: test
        ")
    }

    #[test]
    fn filter_true() -> Result<()> {
        let filter_expr = lit(true);

        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("a")], vec![sum(col("b"))])?
            .filter(filter_expr)?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]
          TableScan: test
        ")
    }

    #[test]
    fn filter_true_nested() -> Result<()> {
        let filter_expr = lit(true);

        let table_scan = test_table_scan()?;
        let plan1 = LogicalPlanBuilder::from(table_scan.clone())
            .aggregate(vec![col("a")], vec![sum(col("b"))])?
            .build()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("a")], vec![sum(col("b"))])?
            .filter(filter_expr)?
            .union(plan1)?
            .build()?;

        // The `WHERE TRUE` filter disappears; both branches are otherwise kept.
        assert_optimized_plan_equal!(plan, @r"
        Union
          Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]
            TableScan: test
          Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]
            TableScan: test
        ")
    }

    #[test]
    fn filter_from_subquery() -> Result<()> {
        // SELECT a FROM (SELECT a FROM test WHERE FALSE) WHERE TRUE
        let false_filter = lit(false);
        let table_scan = test_table_scan()?;
        let plan1 = LogicalPlanBuilder::from(table_scan)
            .project(vec![col("a")])?
            .filter(false_filter)?
            .build()?;

        let true_filter = lit(true);
        let plan = LogicalPlanBuilder::from(plan1)
            .project(vec![col("a")])?
            .filter(true_filter)?
            .build()?;

        // The outer TRUE filter is removed; the inner FALSE filter becomes empty.
        assert_optimized_plan_equal!(plan, @r"
        Projection: test.a
          EmptyRelation: rows=0
        ")
    }

    #[test]
    fn filter_non_literal_predicate_kept() -> Result<()> {
        // A predicate that is not a boolean literal must not be touched.
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .filter(col("a").eq(lit(1u32)))?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Filter: test.a = UInt32(1)
          TableScan: test
        ")
    }
}