use crate::{OptimizerConfig, OptimizerRule};
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::{
logical_plan::{EmptyRelation, Filter, LogicalPlan},
utils::from_plan,
Expr,
};
/// Optimizer rule that removes `Filter` nodes whose predicate is a constant
/// boolean literal.
///
/// A constant `false` filter can produce no rows and is replaced by an empty
/// relation with the same schema. A constant `true` filter passes every row
/// through, so it is replaced by its input.
///
/// The rule is a stateless unit struct. Construct it with
/// [`EliminateFilter::new`] or `EliminateFilter::default()`.
#[derive(Default, Debug)]
pub struct EliminateFilter;
impl EliminateFilter {
#[allow(missing_docs)]
pub fn new() -> Self {
Self {}
}
}
impl OptimizerRule for EliminateFilter {
    /// Removes filters whose predicate is a constant boolean literal.
    ///
    /// * `Filter(true)` is replaced by its input. The input is optimized
    ///   recursively so that constant filters nested below it are also removed.
    /// * `Filter(false)` and `Filter(NULL)` keep no rows and are replaced by
    ///   an `EmptyRelation` that carries the input's schema. This leaves the
    ///   plan's output schema unchanged.
    /// * Any other node is rebuilt with `from_plan`, after the rule has been
    ///   applied to each of its inputs.
    ///
    /// # Errors
    ///
    /// Propagates any error from optimizing a child plan or from `from_plan`.
    fn optimize(
        &self,
        plan: &LogicalPlan,
        optimizer_config: &OptimizerConfig,
    ) -> Result<LogicalPlan> {
        match plan {
            LogicalPlan::Filter(Filter {
                predicate: Expr::Literal(ScalarValue::Boolean(value)),
                input,
            }) => match value {
                // `WHERE true` is a no-op. The input may itself contain
                // constant filters, so keep optimizing instead of cloning it
                // unchanged.
                Some(true) => self.optimize(input, optimizer_config),
                // `WHERE false` rejects every row. `WHERE NULL` does too,
                // because a NULL predicate never selects a row.
                Some(false) | None => Ok(LogicalPlan::EmptyRelation(EmptyRelation {
                    produce_one_row: false,
                    schema: input.schema().clone(),
                })),
            },
            _ => {
                let new_inputs = plan
                    .inputs()
                    .into_iter()
                    .map(|child| self.optimize(child, optimizer_config))
                    .collect::<Result<Vec<_>>>()?;
                from_plan(plan, &plan.expressions(), &new_inputs)
            }
        }
    }

    fn name(&self) -> &str {
        "eliminate_filter"
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::*;
    use datafusion_expr::{col, logical_plan::builder::LogicalPlanBuilder, sum};

    /// Runs `EliminateFilter` on `plan` and checks two things: the `Debug`
    /// rendering of the result equals `expected`, and the output schema is
    /// unchanged.
    fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) {
        let rule = EliminateFilter::new();
        let optimized_plan = rule
            .optimize(plan, &OptimizerConfig::new())
            .expect("failed to optimize plan");
        let formatted_plan = format!("{:?}", optimized_plan);
        assert_eq!(formatted_plan, expected);
        assert_eq!(plan.schema(), optimized_plan.schema());
    }

    /// A top-level `Filter(false)` collapses to an `EmptyRelation`.
    #[test]
    fn filter_false() {
        let filter_expr = Expr::Literal(ScalarValue::Boolean(Some(false)));
        let table_scan = test_table_scan().unwrap();
        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("a")], vec![sum(col("b"))])
            .unwrap()
            .filter(filter_expr)
            .unwrap()
            .build()
            .unwrap();
        let expected = "EmptyRelation";
        assert_optimized_plan_eq(&plan, expected);
    }

    /// A `Filter(false)` under a `Union` is replaced, and the other branch
    /// stays intact.
    #[test]
    fn filter_false_nested() {
        let filter_expr = Expr::Literal(ScalarValue::Boolean(Some(false)));
        let table_scan = test_table_scan().unwrap();
        let plan1 = LogicalPlanBuilder::from(table_scan.clone())
            .aggregate(vec![col("a")], vec![sum(col("b"))])
            .unwrap()
            .build()
            .unwrap();
        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("a")], vec![sum(col("b"))])
            .unwrap()
            .filter(filter_expr)
            .unwrap()
            .union(plan1)
            .unwrap()
            .build()
            .unwrap();
        let expected = "Union\
            \n EmptyRelation\
            \n Aggregate: groupBy=[[#test.a]], aggr=[[SUM(#test.b)]]\
            \n TableScan: test";
        assert_optimized_plan_eq(&plan, expected);
    }

    /// A top-level `Filter(true)` is replaced by its input.
    #[test]
    fn filter_true() {
        let filter_expr = Expr::Literal(ScalarValue::Boolean(Some(true)));
        let table_scan = test_table_scan().unwrap();
        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("a")], vec![sum(col("b"))])
            .unwrap()
            .filter(filter_expr)
            .unwrap()
            .build()
            .unwrap();
        let expected = "Aggregate: groupBy=[[#test.a]], aggr=[[SUM(#test.b)]]\
            \n TableScan: test";
        assert_optimized_plan_eq(&plan, expected);
    }

    /// A `Filter(true)` under a `Union` is replaced by its input.
    #[test]
    fn filter_true_nested() {
        let filter_expr = Expr::Literal(ScalarValue::Boolean(Some(true)));
        let table_scan = test_table_scan().unwrap();
        let plan1 = LogicalPlanBuilder::from(table_scan.clone())
            .aggregate(vec![col("a")], vec![sum(col("b"))])
            .unwrap()
            .build()
            .unwrap();
        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("a")], vec![sum(col("b"))])
            .unwrap()
            .filter(filter_expr)
            .unwrap()
            .union(plan1)
            .unwrap()
            .build()
            .unwrap();
        let expected = "Union\
            \n Aggregate: groupBy=[[#test.a]], aggr=[[SUM(#test.b)]]\
            \n TableScan: test\
            \n Aggregate: groupBy=[[#test.a]], aggr=[[SUM(#test.b)]]\
            \n TableScan: test";
        assert_optimized_plan_eq(&plan, expected);
    }
}