datafusion_optimizer/
eliminate_filter.rs1use datafusion_common::tree_node::Transformed;
21use datafusion_common::{Result, ScalarValue};
22use datafusion_expr::{EmptyRelation, Expr, Filter, LogicalPlan};
23use std::sync::Arc;
24
25use crate::optimizer::ApplyOrder;
26use crate::{OptimizerConfig, OptimizerRule};
27
/// Optimizer rule that removes `Filter` nodes whose predicate is a boolean
/// literal.
///
/// * A `true` predicate keeps every row, so the filter is replaced by its
///   input plan.
/// * A `false` or `NULL` predicate keeps no rows, so the filter is replaced by
///   an [`EmptyRelation`] that produces no rows and carries the input's schema.
///
/// Only predicates that are already literal expressions are matched; a
/// predicate that merely *evaluates* to a constant is left untouched.
/// NOTE(review): this presumably relies on an earlier expression-simplification
/// pass folding such predicates into literals — confirm the rule ordering.
#[derive(Default, Debug)]
pub struct EliminateFilter;

impl EliminateFilter {
    /// Creates the rule. Equivalent to [`EliminateFilter::default`].
    pub fn new() -> Self {
        Self
    }
}
42
43impl OptimizerRule for EliminateFilter {
44 fn name(&self) -> &str {
45 "eliminate_filter"
46 }
47
48 fn apply_order(&self) -> Option<ApplyOrder> {
49 Some(ApplyOrder::TopDown)
50 }
51
52 fn supports_rewrite(&self) -> bool {
53 true
54 }
55
56 fn rewrite(
57 &self,
58 plan: LogicalPlan,
59 _config: &dyn OptimizerConfig,
60 ) -> Result<Transformed<LogicalPlan>> {
61 match plan {
62 LogicalPlan::Filter(Filter {
63 predicate: Expr::Literal(ScalarValue::Boolean(v)),
64 input,
65 ..
66 }) => match v {
67 Some(true) => Ok(Transformed::yes(Arc::unwrap_or_clone(input))),
68 Some(false) | None => Ok(Transformed::yes(LogicalPlan::EmptyRelation(
69 EmptyRelation {
70 produce_one_row: false,
71 schema: Arc::clone(input.schema()),
72 },
73 ))),
74 },
75 _ => Ok(Transformed::no(plan)),
76 }
77 }
78}
79
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::{
        col, lit, logical_plan::builder::LogicalPlanBuilder, Expr, LogicalPlan,
    };

    use crate::eliminate_filter::EliminateFilter;
    use crate::test::*;
    use datafusion_expr::test::function_stub::sum;

    /// Optimizes `plan` with only [`EliminateFilter`] enabled (via the crate's
    /// `assert_optimized_plan_eq` test helper) and checks the result against
    /// `expected`.
    fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> {
        assert_optimized_plan_eq(Arc::new(EliminateFilter::new()), plan, expected)
    }

    #[test]
    fn filter_false() -> Result<()> {
        let filter_expr = lit(false);

        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("a")], vec![sum(col("b"))])?
            .filter(filter_expr)?
            .build()?;

        // An always-false filter collapses the whole plan.
        let expected = "EmptyRelation";
        assert_optimized_plan_equal(plan, expected)
    }

    #[test]
    fn filter_null() -> Result<()> {
        let filter_expr = Expr::Literal(ScalarValue::Boolean(None));

        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("a")], vec![sum(col("b"))])?
            .filter(filter_expr)?
            .build()?;

        // A NULL predicate is treated like `false`.
        let expected = "EmptyRelation";
        assert_optimized_plan_equal(plan, expected)
    }

    #[test]
    fn filter_false_nested() -> Result<()> {
        let filter_expr = lit(false);

        let table_scan = test_table_scan()?;
        let plan1 = LogicalPlanBuilder::from(table_scan.clone())
            .aggregate(vec![col("a")], vec![sum(col("b"))])?
            .build()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("a")], vec![sum(col("b"))])?
            .filter(filter_expr)?
            .union(plan1)?
            .build()?;

        // Only the filtered (left) side of the union becomes empty.
        let expected = "Union\
            \n  EmptyRelation\
            \n  Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\
            \n    TableScan: test";
        assert_optimized_plan_equal(plan, expected)
    }

    #[test]
    fn filter_true() -> Result<()> {
        let filter_expr = lit(true);

        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("a")], vec![sum(col("b"))])?
            .filter(filter_expr)?
            .build()?;

        // An always-true filter is dropped, leaving its input.
        let expected = "Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\
            \n  TableScan: test";
        assert_optimized_plan_equal(plan, expected)
    }

    #[test]
    fn filter_true_nested() -> Result<()> {
        let filter_expr = lit(true);

        let table_scan = test_table_scan()?;
        let plan1 = LogicalPlanBuilder::from(table_scan.clone())
            .aggregate(vec![col("a")], vec![sum(col("b"))])?
            .build()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("a")], vec![sum(col("b"))])?
            .filter(filter_expr)?
            .union(plan1)?
            .build()?;

        // The filter under the union is removed; both sides remain.
        let expected = "Union\
            \n  Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\
            \n    TableScan: test\
            \n  Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\
            \n    TableScan: test";
        assert_optimized_plan_equal(plan, expected)
    }

    #[test]
    fn filter_from_subquery() -> Result<()> {
        // Inner filter is always false; outer filter is always true.
        let false_filter = lit(false);
        let table_scan = test_table_scan()?;
        let plan1 = LogicalPlanBuilder::from(table_scan)
            .project(vec![col("a")])?
            .filter(false_filter)?
            .build()?;

        let true_filter = lit(true);
        let plan = LogicalPlanBuilder::from(plan1)
            .project(vec![col("a")])?
            .filter(true_filter)?
            .build()?;

        // The outer filter disappears and the inner one becomes an empty relation.
        let expected = "Projection: test.a\
            \n  EmptyRelation";
        assert_optimized_plan_equal(plan, expected)
    }
}