use crate::error::Result;
use crate::logical_plan::Expr;
use crate::logical_plan::{and, LogicalPlan};
use crate::optimizer::optimizer::OptimizerRule;
use crate::optimizer::utils;
use std::{
collections::{BTreeMap, HashMap, HashSet},
sync::Arc,
};
/// Optimizer rule that moves each `Filter` node of a logical plan as far
/// towards the plan's leaves as it can, rewriting the predicate when it passes
/// below a projection, so that rows are discarded as early as possible.
#[derive(Debug, Default)]
pub struct FilterPushDown {}
impl OptimizerRule for FilterPushDown {
fn name(&self) -> &str {
return "filter_push_down";
}
fn optimize(&mut self, plan: &LogicalPlan) -> Result<LogicalPlan> {
let result = analyze_plan(plan, 0)?;
let break_points = result.break_points.clone();
let max_depth = break_points.keys().max();
if max_depth.is_none() {
return Ok(plan.clone());
}
let max_depth = *max_depth.unwrap();
let mut new_filtersnew_filters: BTreeMap<usize, Expr> = BTreeMap::new();
for (filter_depth, expr) in result.filters {
let mut filter_columns: HashSet<String> = HashSet::new();
utils::expr_to_column_names(&expr, &mut filter_columns)?;
let mut new_depth = filter_depth;
for depth in filter_depth..max_depth {
if let Some(break_columns) = break_points.get(&depth) {
if filter_columns
.intersection(break_columns)
.peekable()
.peek()
.is_none()
{
new_depth += 1
} else {
break;
}
} else {
new_depth += 1
}
}
let mut new_expression = expr.clone();
for depth_i in filter_depth..new_depth {
if let Some(projection) = result.projections.get(&depth_i) {
new_expression = rewrite(&new_expression, projection)?;
}
}
if let Some(existing_expression) = new_filtersnew_filters.get(&new_depth) {
new_expression = and(existing_expression, &new_expression)
}
new_filtersnew_filters.insert(new_depth, new_expression);
}
optimize_plan(plan, &new_filtersnew_filters, 0)
}
}
/// Per-depth information gathered by a single walk down a logical plan.
///
/// Every map is keyed by node depth, counted from the node the walk started at.
struct AnalysisResult {
    /// Predicate of each `Filter` node.
    pub filters: BTreeMap<usize, Expr>,
    /// For each `Projection` node: output field name -> expression producing it
    /// (with any top-level alias removed).
    pub projections: BTreeMap<usize, HashMap<String, Expr>>,
    /// For each node that can block a filter: the column names a filter must
    /// not reference in order to be moved below that node.
    pub break_points: BTreeMap<usize, HashSet<String>>,
}
fn analyze_plan(plan: &LogicalPlan, depth: usize) -> Result<AnalysisResult> {
match plan {
LogicalPlan::Filter { input, predicate } => {
let mut result = analyze_plan(&input, depth + 1)?;
result.filters.insert(depth, predicate.clone());
Ok(result)
}
LogicalPlan::Projection {
input,
expr,
schema,
} => {
let mut result = analyze_plan(&input, depth + 1)?;
let mut projection = HashMap::new();
schema.fields().iter().enumerate().for_each(|(i, field)| {
let expr = match &expr[i] {
Expr::Alias(expr, _) => expr.as_ref().clone(),
expr => expr.clone(),
};
projection.insert(field.name().clone(), expr);
});
result.projections.insert(depth, projection);
Ok(result)
}
LogicalPlan::Aggregate {
input, aggr_expr, ..
} => {
let mut result = analyze_plan(&input, depth + 1)?;
let mut agg_columns = HashSet::new();
utils::exprlist_to_column_names(aggr_expr, &mut agg_columns)?;
let mut columns = agg_columns.iter().cloned().collect::<HashSet<_>>();
let agg_columns = aggr_expr
.iter()
.map(|x| x.name(input.schema()))
.collect::<Result<HashSet<_>>>()?;
columns.extend(agg_columns);
result.break_points.insert(depth, columns);
Ok(result)
}
LogicalPlan::Sort { input, .. } => analyze_plan(&input, depth + 1),
LogicalPlan::Limit { input, .. } => {
let mut result = analyze_plan(&input, depth + 1)?;
let columns = input
.schema()
.fields()
.iter()
.map(|f| f.name().clone())
.collect::<HashSet<_>>();
result.break_points.insert(depth, columns);
Ok(result)
}
_ => {
let columns = plan
.schema()
.fields()
.iter()
.map(|f| f.name().clone())
.collect::<HashSet<_>>();
let mut break_points = BTreeMap::new();
break_points.insert(depth, columns);
Ok(AnalysisResult {
break_points,
filters: BTreeMap::new(),
projections: BTreeMap::new(),
})
}
}
}
impl FilterPushDown {
#[allow(missing_docs)]
pub fn new() -> Self {
Self {}
}
}
/// Rebuilds `plan`, removing the original `Filter` nodes on the chain that
/// `analyze_plan` walks and inserting the predicates from `new_filters`.
///
/// The predicate stored for depth `d` is wrapped around the rebuilt node at
/// depth `d`. If that node was an original `Filter`, the predicate wraps the
/// filter's rebuilt input, taking the filter's place.
///
/// # Errors
/// Propagates errors from rebuilding plan nodes.
fn optimize_plan(
    plan: &LogicalPlan,
    new_filters: &BTreeMap<usize, Expr>,
    depth: usize,
) -> Result<LogicalPlan> {
    let new_plan = match plan {
        // The original predicate was collected by `analyze_plan`; it comes back
        // via `new_filters` at its new depth.
        LogicalPlan::Filter { input, .. } => optimize_plan(input, new_filters, depth + 1)?,
        // Node kinds `analyze_plan` descends through: rebuild with optimized inputs.
        LogicalPlan::Projection { .. }
        | LogicalPlan::Aggregate { .. }
        | LogicalPlan::Sort { .. }
        | LogicalPlan::Limit { .. } => {
            let expr = utils::expressions(plan);
            let new_inputs = utils::inputs(plan)
                .iter()
                .map(|input| optimize_plan(input, new_filters, depth + 1))
                .collect::<Result<Vec<_>>>()?;
            utils::from_plan(plan, &expr, &new_inputs)?
        }
        // `analyze_plan` stops at any other node and never collects the filters
        // beneath it. Recursing here would strip those filters without ever
        // re-inserting them, so keep the subtree verbatim.
        _ => plan.clone(),
    };
    Ok(match new_filters.get(&depth) {
        Some(predicate) => LogicalPlan::Filter {
            predicate: predicate.clone(),
            input: Arc::new(new_plan),
        },
        None => new_plan,
    })
}
/// Returns a copy of `expr` in which every column named in `projection` is
/// replaced by the expression that produces it. Sub-expressions are rewritten
/// first; columns absent from `projection` are kept unchanged.
///
/// # Errors
/// Propagates errors from sub-expression extraction and expression rebuilding.
fn rewrite(expr: &Expr, projection: &HashMap<String, Expr>) -> Result<Expr> {
    let mut rewritten_children = Vec::new();
    for child in utils::expr_sub_expressions(expr)? {
        rewritten_children.push(rewrite(&child, projection)?);
    }

    if let Expr::Column(name) = expr {
        if let Some(replacement) = projection.get(name) {
            return Ok(replacement.clone());
        }
    }

    utils::rewrite_expression(expr, &rewritten_children)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::logical_plan::col;
    use crate::logical_plan::{lit, sum, Expr, LogicalPlanBuilder, Operator};
    use crate::test::*;

    /// Optimizes `plan` with a fresh `FilterPushDown` rule and asserts that the
    /// `Debug` rendering of the result equals `expected`.
    fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) {
        let mut rule = FilterPushDown::new();
        let optimized_plan = rule.optimize(plan).expect("failed to optimize plan");
        let formatted_plan = format!("{:?}", optimized_plan);
        assert_eq!(formatted_plan, expected);
    }

    /// A filter applied after a projection is moved below the projection.
    #[test]
    fn filter_before_projection() -> Result<()> {
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(&table_scan)
            .project(vec![col("a"), col("b")])?
            .filter(col("a").eq(lit(1i64)))?
            .build()?;
        let expected = "\
Projection: #a, #b\
\n Filter: #a Eq Int64(1)\
\n TableScan: test projection=None";
        assert_optimized_plan_eq(&plan, expected);
        Ok(())
    }

    /// A filter applied after a limit stays above the limit: filtering before
    /// the limit would change which rows are returned.
    #[test]
    fn filter_after_limit() -> Result<()> {
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(&table_scan)
            .project(vec![col("a"), col("b")])?
            .limit(10)?
            .filter(col("a").eq(lit(1i64)))?
            .build()?;
        let expected = "\
Filter: #a Eq Int64(1)\
\n Limit: 10\
\n Projection: #a, #b\
\n TableScan: test projection=None";
        assert_optimized_plan_eq(&plan, expected);
        Ok(())
    }

    /// A filter is moved below two stacked projections, down to the scan.
    #[test]
    fn filter_jump_2_plans() -> Result<()> {
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(&table_scan)
            .project(vec![col("a"), col("b"), col("c")])?
            .project(vec![col("c"), col("b")])?
            .filter(col("a").eq(lit(1i64)))?
            .build()?;
        let expected = "\
Projection: #c, #b\
\n Projection: #a, #b, #c\
\n Filter: #a Eq Int64(1)\
\n TableScan: test projection=None";
        assert_optimized_plan_eq(&plan, expected);
        Ok(())
    }

    /// A filter on the grouping column `a` is moved below the aggregate.
    #[test]
    fn filter_move_agg() -> Result<()> {
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(&table_scan)
            .aggregate(vec![col("a")], vec![sum(col("b")).alias("total_salary")])?
            .filter(col("a").gt(lit(10i64)))?
            .build()?;
        let expected = "\
Aggregate: groupBy=[[#a]], aggr=[[SUM(#b) AS total_salary]]\
\n Filter: #a Gt Int64(10)\
\n TableScan: test projection=None";
        assert_optimized_plan_eq(&plan, expected);
        Ok(())
    }

    /// A filter on `b`, the name of an aggregate output, stays above the aggregate.
    #[test]
    fn filter_keep_agg() -> Result<()> {
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(&table_scan)
            .aggregate(vec![col("a")], vec![sum(col("b")).alias("b")])?
            .filter(col("b").gt(lit(10i64)))?
            .build()?;
        let expected = "\
Filter: #b Gt Int64(10)\
\n Aggregate: groupBy=[[#a]], aggr=[[SUM(#b) AS b]]\
\n TableScan: test projection=None";
        assert_optimized_plan_eq(&plan, expected);
        Ok(())
    }

    /// A filter on the alias `b` is rewritten to the aliased column `#a` when it
    /// moves below the projection.
    #[test]
    fn alias() -> Result<()> {
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(&table_scan)
            .project(vec![col("a").alias("b"), col("c")])?
            .filter(col("b").eq(lit(1i64)))?
            .build()?;
        let expected = "\
Projection: #a AS b, #c\
\n Filter: #a Eq Int64(1)\
\n TableScan: test projection=None";
        assert_optimized_plan_eq(&plan, expected);
        Ok(())
    }

    /// Builds the binary expression `left Plus right`.
    fn add(left: Expr, right: Expr) -> Expr {
        Expr::BinaryExpr {
            left: Box::new(left),
            op: Operator::Plus,
            right: Box::new(right),
        }
    }

    /// Builds the binary expression `left Multiply right`.
    fn multiply(left: Expr, right: Expr) -> Expr {
        Expr::BinaryExpr {
            left: Box::new(left),
            op: Operator::Multiply,
            right: Box::new(right),
        }
    }

    /// A filter on an aliased computed column is rewritten to the full
    /// expression. The unoptimized plan's rendering is checked first.
    #[test]
    fn complex_expression() -> Result<()> {
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(&table_scan)
            .project(vec![
                add(multiply(col("a"), lit(2)), col("c")).alias("b"),
                col("c"),
            ])?
            .filter(col("b").eq(lit(1i64)))?
            .build()?;
        assert_eq!(
            format!("{:?}", plan),
            "\
Filter: #b Eq Int64(1)\
\n Projection: #a Multiply Int32(2) Plus #c AS b, #c\
\n TableScan: test projection=None"
        );
        let expected = "\
Projection: #a Multiply Int32(2) Plus #c AS b, #c\
\n Filter: #a Multiply Int32(2) Plus #c Eq Int64(1)\
\n TableScan: test projection=None";
        assert_optimized_plan_eq(&plan, expected);
        Ok(())
    }

    /// The predicate is rewritten through two stacked projections: `a` is
    /// substituted first, then the `b` it refers to.
    #[test]
    fn complex_plan() -> Result<()> {
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(&table_scan)
            .project(vec![
                add(multiply(col("a"), lit(2)), col("c")).alias("b"),
                col("c"),
            ])?
            .project(vec![multiply(col("b"), lit(3)).alias("a"), col("c")])?
            .filter(col("a").eq(lit(1i64)))?
            .build()?;
        assert_eq!(
            format!("{:?}", plan),
            "\
Filter: #a Eq Int64(1)\
\n Projection: #b Multiply Int32(3) AS a, #c\
\n Projection: #a Multiply Int32(2) Plus #c AS b, #c\
\n TableScan: test projection=None"
        );
        let expected = "\
Projection: #b Multiply Int32(3) AS a, #c\
\n Projection: #a Multiply Int32(2) Plus #c AS b, #c\
\n Filter: #a Multiply Int32(2) Plus #c Multiply Int32(3) Eq Int64(1)\
\n TableScan: test projection=None";
        assert_optimized_plan_eq(&plan, expected);
        Ok(())
    }

    /// Two stacked filters: the one on the aggregate output `SUM(c)` stays above
    /// the aggregate, while the one on grouping column `b` moves below the
    /// aggregate and the projection and is rewritten to `#a`.
    #[test]
    fn multi_filter() -> Result<()> {
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(&table_scan)
            .project(vec![col("a").alias("b"), col("c")])?
            .aggregate(vec![col("b")], vec![sum(col("c"))])?
            .filter(col("b").gt(lit(10i64)))?
            .filter(col("SUM(c)").gt(lit(10i64)))?
            .build()?;
        assert_eq!(
            format!("{:?}", plan),
            "\
Filter: #SUM(c) Gt Int64(10)\
\n Filter: #b Gt Int64(10)\
\n Aggregate: groupBy=[[#b]], aggr=[[SUM(#c)]]\
\n Projection: #a AS b, #c\
\n TableScan: test projection=None"
        );
        let expected = "\
Filter: #SUM(c) Gt Int64(10)\
\n Aggregate: groupBy=[[#b]], aggr=[[SUM(#c)]]\
\n Projection: #a AS b, #c\
\n Filter: #a Gt Int64(10)\
\n TableScan: test projection=None";
        assert_optimized_plan_eq(&plan, expected);
        Ok(())
    }

    /// The filter moves below the outer projection but stops above the first limit.
    #[test]
    fn double_limit() -> Result<()> {
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(&table_scan)
            .project(vec![col("a"), col("b")])?
            .limit(20)?
            .limit(10)?
            .project(vec![col("a"), col("b")])?
            .filter(col("a").eq(lit(1i64)))?
            .build()?;
        let expected = "\
Projection: #a, #b\
\n Filter: #a Eq Int64(1)\
\n Limit: 10\
\n Limit: 20\
\n Projection: #a, #b\
\n TableScan: test projection=None";
        assert_optimized_plan_eq(&plan, expected);
        Ok(())
    }

    /// Two filters separated by a limit are each pushed down only within their
    /// own segment: the outer one stops above the limit, the inner one reaches the scan.
    #[test]
    fn filter_2_breaks_limits() -> Result<()> {
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(&table_scan)
            .project(vec![col("a")])?
            .filter(col("a").lt_eq(lit(1i64)))?
            .limit(1)?
            .project(vec![col("a")])?
            .filter(col("a").gt_eq(lit(1i64)))?
            .build()?;
        assert_eq!(
            format!("{:?}", plan),
            "Filter: #a GtEq Int64(1)\
\n Projection: #a\
\n Limit: 1\
\n Filter: #a LtEq Int64(1)\
\n Projection: #a\
\n TableScan: test projection=None"
        );
        let expected = "\
Projection: #a\
\n Filter: #a GtEq Int64(1)\
\n Limit: 1\
\n Projection: #a\
\n Filter: #a LtEq Int64(1)\
\n TableScan: test projection=None";
        assert_optimized_plan_eq(&plan, expected);
        Ok(())
    }

    /// Two filters that end at the same depth (just above the limit) are merged
    /// into one `And` predicate, with the outer filter's predicate first.
    #[test]
    fn two_filters_on_same_depth() -> Result<()> {
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(&table_scan)
            .limit(1)?
            .filter(col("a").lt_eq(lit(1i64)))?
            .filter(col("a").gt_eq(lit(1i64)))?
            .project(vec![col("a")])?
            .build()?;
        assert_eq!(
            format!("{:?}", plan),
            "Projection: #a\
\n Filter: #a GtEq Int64(1)\
\n Filter: #a LtEq Int64(1)\
\n Limit: 1\
\n TableScan: test projection=None"
        );
        let expected = "\
Projection: #a\
\n Filter: #a GtEq Int64(1) And #a LtEq Int64(1)\
\n Limit: 1\
\n TableScan: test projection=None";
        assert_optimized_plan_eq(&plan, expected);
        Ok(())
    }
}