use crate::error::Result;
use crate::logical_plan::{DFField, DFSchema, DFSchemaRef, LogicalPlan, ToDFSchema};
use crate::optimizer::optimizer::OptimizerRule;
use crate::optimizer::utils;
use arrow::datatypes::Schema;
use arrow::error::Result as ArrowResult;
use std::{collections::HashSet, sync::Arc};
use utils::optimize_explain;
/// Optimizer rule that narrows every node of a logical plan to the columns
/// that are actually required by its ancestors, so that table scans end up
/// with an explicit, minimal projection.
///
/// The rule carries no state; it can be created with [`Default::default`] or
/// with its `new` constructor.
#[derive(Default)]
pub struct ProjectionPushDown {}
impl OptimizerRule for ProjectionPushDown {
    /// Optimizes `plan` by pushing column requirements towards its leaves.
    ///
    /// The initial set of required columns is the name of every field in the
    /// plan's output schema; the traversal starts with `has_projection`
    /// set to `false`.
    fn optimize(&self, plan: &LogicalPlan) -> Result<LogicalPlan> {
        let mut required_columns: HashSet<String> = HashSet::new();
        for field in plan.schema().fields().iter() {
            required_columns.insert(field.name().clone());
        }
        optimize_plan(self, plan, &required_columns, false)
    }

    /// Identifier of this rule.
    fn name(&self) -> &str {
        "projection_push_down"
    }
}
impl ProjectionPushDown {
    /// Creates a new projection push-down rule.
    ///
    /// The rule has no configuration, so every instance behaves the same.
    #[must_use]
    pub fn new() -> Self {
        Self {}
    }
}
/// Computes the column indices a table scan should read, together with the
/// schema made of exactly those columns.
///
/// * `schema` - full schema of the table being scanned.
/// * `required_columns` - names requested by the nodes above the scan. Names
///   that do not exist in `schema` are silently skipped.
/// * `has_projection` - selects the fallback used when no required name
///   matches a table column (see below).
///
/// Fallback when nothing matches:
/// * `has_projection == true` and the table has at least one column: read
///   only column `0`, so the scan still reads one column (this is what a
///   projection made only of literals ends up with).
/// * otherwise: read every column of the table. For a table with no columns
///   this yields an empty projection instead of indexing out of bounds.
///
/// The returned indices are sorted in ascending order.
///
/// # Errors
///
/// Returns an error if the projected fields cannot be turned into a
/// `DFSchemaRef`.
fn get_projected_schema(
    schema: &Schema,
    required_columns: &HashSet<String>,
    has_projection: bool,
) -> Result<(Vec<usize>, DFSchemaRef)> {
    // Resolve names to indices; unknown names are dropped rather than reported.
    let mut projection: Vec<usize> = required_columns
        .iter()
        .map(|name| schema.index_of(name))
        .filter_map(ArrowResult::ok)
        .collect();

    if projection.is_empty() {
        if has_projection && !schema.fields().is_empty() {
            // Keep a single column; guarded so a zero-column table never
            // produces the out-of-range index 0.
            projection.push(0);
        } else {
            projection = (0..schema.fields().len()).collect();
        }
    }

    // `HashSet` iteration order is unspecified; sorting makes the result
    // deterministic.
    projection.sort_unstable();

    let projected_fields: Vec<DFField> = projection
        .iter()
        .map(|&index| DFField::from(schema.fields()[index].clone()))
        .collect();

    Ok((projection, projected_fields.to_dfschema_ref()?))
}
/// Recursively rewrites `plan` so that each node keeps only what is needed to
/// produce `required_columns`.
///
/// * `optimizer` - the rule instance; only forwarded to recursive calls and
///   to `optimize_explain`.
/// * `plan` - the node to rewrite.
/// * `required_columns` - column names requested by the ancestors of `plan`.
/// * `has_projection` - forwarded down to the table scan, where it selects
///   the fallback used when no required column matches. It is set to `true`
///   below `Projection`, `Join` and `Aggregate` nodes and passed through
///   unchanged by the remaining nodes.
///
/// # Errors
///
/// Propagates any error raised while collecting column names, building
/// schemas, or rewriting child plans.
fn optimize_plan(
    optimizer: &ProjectionPushDown,
    plan: &LogicalPlan,
    required_columns: &HashSet<String>,
    has_projection: bool,
) -> Result<LogicalPlan> {
    // Requirements handed to the children: the caller's set plus whatever this
    // node itself needs.
    let mut new_required_columns = required_columns.clone();

    match plan {
        LogicalPlan::Projection {
            input,
            expr,
            schema,
        } => {
            // Keep only the expressions whose output field is required, and
            // record the columns those expressions read.
            let mut kept_exprs = Vec::new();
            let mut kept_fields = Vec::new();
            for (position, field) in schema.fields().iter().enumerate() {
                if !required_columns.contains(field.name()) {
                    continue;
                }
                kept_exprs.push(expr[position].clone());
                kept_fields.push(field.clone());
                utils::expr_to_column_names(&expr[position], &mut new_required_columns)?;
            }

            let optimized_input =
                optimize_plan(optimizer, &input, &new_required_columns, true)?;

            // With no expression left, the projection node itself is dropped.
            if kept_fields.is_empty() {
                return Ok(optimized_input);
            }
            Ok(LogicalPlan::Projection {
                expr: kept_exprs,
                input: Arc::new(optimized_input),
                schema: DFSchemaRef::new(DFSchema::new(kept_fields)?),
            })
        }
        LogicalPlan::Join {
            left,
            right,
            on,
            join_type,
            schema,
        } => {
            // Both sides of every join key must survive below the join.
            for (left_key, right_key) in on {
                new_required_columns.insert(left_key.to_owned());
                new_required_columns.insert(right_key.to_owned());
            }

            let optimized_left =
                optimize_plan(optimizer, &left, &new_required_columns, true)?;
            let optimized_right =
                optimize_plan(optimizer, &right, &new_required_columns, true)?;

            Ok(LogicalPlan::Join {
                left: Arc::new(optimized_left),
                right: Arc::new(optimized_right),
                join_type: *join_type,
                on: on.clone(),
                schema: schema.clone(),
            })
        }
        LogicalPlan::Aggregate {
            schema,
            input,
            group_expr,
            aggr_expr,
            ..
        } => {
            // Columns read by grouping expressions are always required.
            utils::exprlist_to_column_names(group_expr, &mut new_required_columns)?;

            // Keep only the aggregate expressions whose output name is
            // required, and record the columns they read.
            let mut kept_aggr_exprs = Vec::new();
            for aggregate in aggr_expr.iter() {
                let output_name = aggregate.name(&schema)?;
                if required_columns.contains(&output_name) {
                    kept_aggr_exprs.push(aggregate.clone());
                    new_required_columns.insert(output_name);
                    utils::expr_to_column_names(aggregate, &mut new_required_columns)?;
                }
            }

            // Output schema: the original fields whose name is in the
            // accumulated requirement set.
            let new_schema = DFSchema::new(
                schema
                    .fields()
                    .iter()
                    .filter(|field| new_required_columns.contains(field.name()))
                    .cloned()
                    .collect(),
            )?;

            let optimized_input =
                optimize_plan(optimizer, &input, &new_required_columns, true)?;

            Ok(LogicalPlan::Aggregate {
                group_expr: group_expr.clone(),
                aggr_expr: kept_aggr_exprs,
                input: Arc::new(optimized_input),
                schema: DFSchemaRef::new(new_schema),
            })
        }
        LogicalPlan::TableScan {
            table_name,
            source,
            filters,
            limit,
            ..
        } => {
            // Leaf: turn the required names into concrete column indices. Any
            // projection already stored on the scan is replaced.
            let (projection, projected_schema) =
                get_projected_schema(&source.schema(), required_columns, has_projection)?;

            Ok(LogicalPlan::TableScan {
                table_name: table_name.to_string(),
                source: source.clone(),
                projection: Some(projection),
                projected_schema,
                filters: filters.clone(),
                limit: *limit,
            })
        }
        LogicalPlan::Explain {
            verbose,
            plan,
            stringified_plans,
            schema,
        } => {
            // Delegated entirely to the shared explain helper.
            let explain_schema = schema.as_ref().to_owned().into();
            optimize_explain(
                optimizer,
                *verbose,
                &*plan,
                stringified_plans,
                &explain_schema,
            )
        }
        LogicalPlan::Limit { .. }
        | LogicalPlan::Filter { .. }
        | LogicalPlan::Repartition { .. }
        | LogicalPlan::EmptyRelation { .. }
        | LogicalPlan::Sort { .. }
        | LogicalPlan::CreateExternalTable { .. }
        | LogicalPlan::Union { .. }
        | LogicalPlan::Extension { .. } => {
            // Generic nodes: require every column their own expressions read,
            // optimize each child, then rebuild the node with the same
            // expressions and the new children.
            let expr = plan.expressions();
            utils::exprlist_to_column_names(&expr, &mut new_required_columns)?;

            let children = plan.inputs();
            let mut optimized_children = Vec::with_capacity(children.len());
            for child in children.iter() {
                optimized_children.push(optimize_plan(
                    optimizer,
                    child,
                    &new_required_columns,
                    has_projection,
                )?);
            }

            utils::from_plan(plan, &expr, &optimized_children)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::logical_plan::{col, lit};
    use crate::logical_plan::{max, min, Expr, LogicalPlanBuilder};
    use crate::test::*;
    use arrow::datatypes::DataType;

    /// `MAX(b)` without grouping: the scan reads only column index 1.
    #[test]
    fn aggregate_no_group_by() -> Result<()> {
        let scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(&scan)
            .aggregate(vec![], vec![max(col("b"))])?
            .build()?;

        assert_optimized_plan_eq(
            &plan,
            "Aggregate: groupBy=[[]], aggr=[[MAX(#b)]]\
            \n TableScan: test projection=Some([1])",
        );
        Ok(())
    }

    /// `MAX(b)` grouped by `c`: the scan reads column indices 1 and 2.
    #[test]
    fn aggregate_group_by() -> Result<()> {
        let scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(&scan)
            .aggregate(vec![col("c")], vec![max(col("b"))])?
            .build()?;

        assert_optimized_plan_eq(
            &plan,
            "Aggregate: groupBy=[[#c]], aggr=[[MAX(#b)]]\
            \n TableScan: test projection=Some([1, 2])",
        );
        Ok(())
    }

    /// A filter on `c` below the aggregate adds `c` to the scan projection.
    #[test]
    fn aggregate_no_group_by_with_filter() -> Result<()> {
        let scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(&scan)
            .filter(col("c"))?
            .aggregate(vec![], vec![max(col("b"))])?
            .build()?;

        assert_optimized_plan_eq(
            &plan,
            "Aggregate: groupBy=[[]], aggr=[[MAX(#b)]]\
            \n Filter: #c\
            \n TableScan: test projection=Some([1, 2])",
        );
        Ok(())
    }

    /// A cast of `c` only requires column index 2 from the scan.
    #[test]
    fn cast() -> Result<()> {
        let scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(&scan)
            .project(vec![Expr::Cast {
                expr: Box::new(col("c")),
                data_type: DataType::Float64,
            }])?
            .build()?;

        assert_optimized_plan_eq(
            &plan,
            "Projection: CAST(#c AS Float64)\
            \n TableScan: test projection=Some([2])",
        );
        Ok(())
    }

    /// Projecting `a` and `b` restricts the scan to indices 0 and 1.
    #[test]
    fn table_scan_projected_schema() -> Result<()> {
        let scan = test_table_scan()?;
        assert_eq!(3, scan.schema().fields().len());
        assert_fields_eq(&scan, vec!["a", "b", "c"]);

        let plan = LogicalPlanBuilder::from(&scan)
            .project(vec![col("a"), col("b")])?
            .build()?;
        assert_fields_eq(&plan, vec!["a", "b"]);

        assert_optimized_plan_eq(
            &plan,
            "Projection: #a, #b\
            \n TableScan: test projection=Some([0, 1])",
        );
        Ok(())
    }

    /// The requirement passes through a limit; scan indices come out sorted
    /// even though the projection lists `c` before `a`.
    #[test]
    fn table_limit() -> Result<()> {
        let scan = test_table_scan()?;
        assert_eq!(3, scan.schema().fields().len());
        assert_fields_eq(&scan, vec!["a", "b", "c"]);

        let plan = LogicalPlanBuilder::from(&scan)
            .project(vec![col("c"), col("a")])?
            .limit(5)?
            .build()?;
        assert_fields_eq(&plan, vec!["c", "a"]);

        assert_optimized_plan_eq(
            &plan,
            "Limit: 5\
            \n Projection: #c, #a\
            \n TableScan: test projection=Some([0, 2])",
        );
        Ok(())
    }

    /// A bare scan is given an explicit projection of all three columns.
    #[test]
    fn table_scan_without_projection() -> Result<()> {
        let scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(&scan).build()?;

        assert_optimized_plan_eq(&plan, "TableScan: test projection=Some([0, 1, 2])");
        Ok(())
    }

    /// A projection of literals only still makes the scan read column 0.
    #[test]
    fn table_scan_with_literal_projection() -> Result<()> {
        let scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(&scan)
            .project(vec![lit(1_i64), lit(2_i64)])?
            .build()?;

        assert_optimized_plan_eq(
            &plan,
            "Projection: Int64(1), Int64(2)\
            \n TableScan: test projection=Some([0])",
        );
        Ok(())
    }

    /// Column `b` is projected but never used above, so it is removed from
    /// both the projection and the scan.
    #[test]
    fn table_unused_column() -> Result<()> {
        let scan = test_table_scan()?;
        assert_eq!(3, scan.schema().fields().len());
        assert_fields_eq(&scan, vec!["a", "b", "c"]);

        let plan = LogicalPlanBuilder::from(&scan)
            .project(vec![col("c"), col("a"), col("b")])?
            .filter(col("c").gt(lit(1)))?
            .aggregate(vec![col("c")], vec![max(col("a"))])?
            .build()?;
        assert_fields_eq(&plan, vec!["c", "MAX(a)"]);

        assert_optimized_plan_eq(
            &plan,
            "Aggregate: groupBy=[[#c]], aggr=[[MAX(#a)]]\
            \n Filter: #c Gt Int32(1)\
            \n Projection: #c, #a\
            \n TableScan: test projection=Some([0, 2])",
        );
        Ok(())
    }

    /// The inner projection of `b` feeds nothing in the outer projection, so
    /// it is removed entirely.
    #[test]
    fn table_unused_projection() -> Result<()> {
        let scan = test_table_scan()?;
        assert_eq!(3, scan.schema().fields().len());
        assert_fields_eq(&scan, vec!["a", "b", "c"]);

        let plan = LogicalPlanBuilder::from(&scan)
            .project(vec![col("b")])?
            .project(vec![lit(1).alias("a")])?
            .build()?;
        assert_fields_eq(&plan, vec!["a"]);

        assert_optimized_plan_eq(
            &plan,
            "Projection: Int32(1) AS a\
            \n TableScan: test projection=Some([0])",
        );
        Ok(())
    }

    /// Running the rule on its own output produces the same formatted plan.
    #[test]
    fn test_double_optimization() -> Result<()> {
        let scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(&scan)
            .project(vec![col("b")])?
            .project(vec![lit(1).alias("a")])?
            .build()?;

        let first_pass = optimize(&plan).expect("failed to optimize plan");
        let second_pass = optimize(&first_pass).expect("failed to optimize plan");

        assert_eq!(format!("{:?}", first_pass), format!("{:?}", second_pass));
        Ok(())
    }

    /// `MIN(b)` is computed but never referenced above the aggregate, so it
    /// is dropped from the aggregate expressions.
    #[test]
    fn table_unused_aggregate() -> Result<()> {
        let scan = test_table_scan()?;
        assert_eq!(3, scan.schema().fields().len());
        assert_fields_eq(&scan, vec!["a", "b", "c"]);

        let plan = LogicalPlanBuilder::from(&scan)
            .aggregate(vec![col("a"), col("c")], vec![max(col("b")), min(col("b"))])?
            .filter(col("c").gt(lit(1)))?
            .project(vec![col("c"), col("a"), col("MAX(b)")])?
            .build()?;
        assert_fields_eq(&plan, vec!["c", "a", "MAX(b)"]);

        assert_optimized_plan_eq(
            &plan,
            "Projection: #c, #a, #MAX(b)\
            \n Filter: #c Gt Int32(1)\
            \n Aggregate: groupBy=[[#a, #c]], aggr=[[MAX(#b)]]\
            \n TableScan: test projection=Some([0, 1, 2])",
        );
        Ok(())
    }

    /// Optimizes `plan` and asserts that its `Debug` rendering equals
    /// `expected`. Panics if optimization fails.
    fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) {
        let optimized = optimize(plan).expect("failed to optimize plan");
        assert_eq!(format!("{:?}", optimized), expected);
    }

    /// Applies a fresh `ProjectionPushDown` rule to `plan`.
    fn optimize(plan: &LogicalPlan) -> Result<LogicalPlan> {
        ProjectionPushDown::new().optimize(plan)
    }
}