use datafusion_common::tree_node::Transformed;
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::{EmptyRelation, Expr, Filter, LogicalPlan};
use std::sync::Arc;
use crate::optimizer::ApplyOrder;
use crate::{OptimizerConfig, OptimizerRule};
/// Optimizer rule that removes `Filter` nodes whose predicate is a boolean
/// literal.
///
/// * A `true` predicate is a no-op, so the filter is replaced by its input.
/// * A `false` or `NULL` predicate can let no row through, so the filter is
///   replaced by an empty relation that reuses the input's schema.
///
/// Filters with any other predicate are left untouched.
#[derive(Default, Debug)]
pub struct EliminateFilter;

impl EliminateFilter {
    /// Creates a new [`EliminateFilter`] rule.
    pub fn new() -> Self {
        Self
    }
}
impl OptimizerRule for EliminateFilter {
    /// Stable identifier of this rule.
    fn name(&self) -> &str {
        "eliminate_filter"
    }

    /// Nodes are visited from the root toward the leaves.
    fn apply_order(&self) -> Option<ApplyOrder> {
        Some(ApplyOrder::TopDown)
    }

    /// This rule is implemented through [`OptimizerRule::rewrite`].
    fn supports_rewrite(&self) -> bool {
        true
    }

    /// Rewrites a single plan node.
    ///
    /// Only a `Filter` whose predicate is a `Boolean` literal is changed:
    /// * `Some(true)`  -> the filter's input, unwrapped from its `Arc`;
    /// * `Some(false)` or `None` (SQL `NULL`) -> an `EmptyRelation` that
    ///   produces no rows and carries the input's schema.
    ///
    /// Every other node is returned as-is and reported as not transformed.
    fn rewrite(
        &self,
        plan: LogicalPlan,
        _config: &dyn OptimizerConfig,
    ) -> Result<Transformed<LogicalPlan>> {
        match plan {
            LogicalPlan::Filter(Filter {
                predicate: Expr::Literal(ScalarValue::Boolean(keep_rows), _),
                input,
                ..
            }) => {
                let replacement = if keep_rows == Some(true) {
                    // An always-true predicate keeps every row: drop the node.
                    Arc::unwrap_or_clone(input)
                } else {
                    // `false` and `NULL` both reject every row. A filter does not
                    // change its input's schema, so the empty relation reuses it.
                    LogicalPlan::EmptyRelation(EmptyRelation {
                        produce_one_row: false,
                        schema: Arc::clone(input.schema()),
                    })
                };
                Ok(Transformed::yes(replacement))
            }
            untouched => Ok(Transformed::no(untouched)),
        }
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::OptimizerContext;
    use crate::assert_optimized_plan_eq_snapshot;
    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::{Expr, col, lit, logical_plan::builder::LogicalPlanBuilder};

    use crate::eliminate_filter::EliminateFilter;
    use crate::test::*;
    use datafusion_expr::test::function_stub::sum;

    /// Optimizes `$plan` with only [`EliminateFilter`], for a single pass, and
    /// checks the result against the inline `@` snapshot through
    /// `assert_optimized_plan_eq_snapshot!`.
    macro_rules! assert_optimized_plan_equal {
        (
            $plan:expr,
            @ $expected:literal $(,)?
        ) => {{
            let optimizer_ctx = OptimizerContext::new().with_max_passes(1);
            let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> =
                vec![Arc::new(EliminateFilter::new())];
            assert_optimized_plan_eq_snapshot!(
                optimizer_ctx,
                rules,
                $plan,
                @ $expected,
            )
        }};
    }

    #[test]
    fn filter_false() -> Result<()> {
        let filter_expr = lit(false);

        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("a")], vec![sum(col("b"))])?
            .filter(filter_expr)?
            .build()?;

        // A constant-false predicate replaces the whole subtree.
        assert_optimized_plan_equal!(plan, @"EmptyRelation: rows=0")
    }

    #[test]
    fn filter_null() -> Result<()> {
        let filter_expr = Expr::Literal(ScalarValue::Boolean(None), None);

        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("a")], vec![sum(col("b"))])?
            .filter(filter_expr)?
            .build()?;

        // A NULL predicate is treated the same as `false`.
        assert_optimized_plan_equal!(plan, @"EmptyRelation: rows=0")
    }

    #[test]
    fn filter_false_nested() -> Result<()> {
        let filter_expr = lit(false);
        let table_scan = test_table_scan()?;
        let plan1 = LogicalPlanBuilder::from(table_scan.clone())
            .aggregate(vec![col("a")], vec![sum(col("b"))])?
            .build()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("a")], vec![sum(col("b"))])?
            .filter(filter_expr)?
            .union(plan1)?
            .build()?;

        // Only the filtered branch of the union becomes empty.
        assert_optimized_plan_equal!(plan, @r"
        Union
          EmptyRelation: rows=0
          Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]
            TableScan: test
        ")
    }

    #[test]
    fn filter_true() -> Result<()> {
        let filter_expr = lit(true);

        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("a")], vec![sum(col("b"))])?
            .filter(filter_expr)?
            .build()?;

        // A constant-true filter is removed, leaving its input.
        assert_optimized_plan_equal!(plan, @r"
        Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]
          TableScan: test
        ")
    }

    #[test]
    fn filter_true_nested() -> Result<()> {
        let filter_expr = lit(true);
        let table_scan = test_table_scan()?;
        let plan1 = LogicalPlanBuilder::from(table_scan.clone())
            .aggregate(vec![col("a")], vec![sum(col("b"))])?
            .build()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("a")], vec![sum(col("b"))])?
            .filter(filter_expr)?
            .union(plan1)?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Union
          Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]
            TableScan: test
          Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]
            TableScan: test
        ")
    }

    #[test]
    fn filter_from_subquery() -> Result<()> {
        // SELECT a FROM (SELECT a FROM test WHERE FALSE) WHERE TRUE
        let false_filter = lit(false);
        let table_scan = test_table_scan()?;
        let plan1 = LogicalPlanBuilder::from(table_scan)
            .project(vec![col("a")])?
            .filter(false_filter)?
            .build()?;

        let true_filter = lit(true);
        let plan = LogicalPlanBuilder::from(plan1)
            .project(vec![col("a")])?
            .filter(true_filter)?
            .build()?;

        // Both literal filters are eliminated in one top-down pass.
        assert_optimized_plan_equal!(plan, @r"
        Projection: test.a
          EmptyRelation: rows=0
        ")
    }

    #[test]
    fn filter_non_literal_unchanged() -> Result<()> {
        let filter_expr = col("a").gt(lit(1u32));

        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("a")], vec![sum(col("b"))])?
            .filter(filter_expr)?
            .build()?;

        // Only boolean-literal predicates are rewritten; everything else passes through.
        assert_optimized_plan_equal!(plan, @r"
        Filter: test.a > UInt32(1)
          Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]
            TableScan: test
        ")
    }
}