use crate::{OptimizerConfig, OptimizerRule};
use arrow::datatypes::DataType;
use datafusion_common::{DFField, DFSchema, Result};
use datafusion_expr::{
col,
expr::GroupingSet,
expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion},
expr_visitor::{ExprVisitable, ExpressionVisitor, Recursion},
logical_plan::{Aggregate, Filter, LogicalPlan, Projection, Sort, Window},
utils::from_plan,
Expr, ExprSchemable,
};
use std::collections::{HashMap, HashSet};
use std::fmt::Write;
use std::sync::Arc;
/// Registry of sub-expressions seen while scanning the expressions of one plan
/// node, keyed by their textual [`Identifier`]. Each value is
/// `(first expression registered under this identifier, number of occurrences, data type)`.
///
/// NOTE(review): the data type stored here is whatever `data_type` the
/// identifier visitor was created with (the type of the enclosing top-level
/// expression), not necessarily the type of the sub-expression itself.
type ExprSet = HashMap<Identifier, (Expr, usize, DataType)>;
/// Structural description of an expression: a node's own description followed
/// by the identifiers of its children. An empty identifier marks a node that is
/// not a candidate for elimination.
type Identifier = String;
/// Optimizer rule that eliminates common sub-expressions.
///
/// Within a `Projection`, `Filter`, `Window`, `Aggregate` or `Sort` node, an
/// eligible sub-expression that appears more than once is computed a single
/// time in a `Projection` inserted below the node and is then referenced by
/// column. Sub-expressions are not shared across different plan nodes.
pub struct CommonSubexprEliminate {}
impl OptimizerRule for CommonSubexprEliminate {
    /// Runs common sub-expression elimination over `plan` and all of its inputs.
    fn optimize(
        &self,
        plan: &LogicalPlan,
        optimizer_config: &mut OptimizerConfig,
    ) -> Result<LogicalPlan> {
        // The rewrite itself only needs shared access to the configuration.
        let config: &OptimizerConfig = optimizer_config;
        optimize(plan, config)
    }

    /// Name under which this rule is reported by the optimizer.
    fn name(&self) -> &str {
        "common_sub_expression_eliminate"
    }
}
impl Default for CommonSubexprEliminate {
fn default() -> Self {
Self::new()
}
}
impl CommonSubexprEliminate {
    /// Creates a new common sub-expression elimination rule.
    ///
    /// The rule is stateless; all bookkeeping is done per plan node while
    /// optimizing.
    pub fn new() -> Self {
        Self {}
    }
}
/// Applies common sub-expression elimination to `plan` and, recursively, to all
/// of its inputs.
///
/// For `Projection`, `Filter`, `Window`, `Aggregate` and `Sort` nodes, the
/// node's own expressions are scanned with a fresh [`ExprSet`]. Every eligible
/// sub-expression that occurs more than once is rewritten into a column
/// reference. When at least one such rewrite happens, a `Projection` computing
/// those sub-expressions is inserted directly below the node (see
/// `rewrite_expr` / `build_project_plan`).
///
/// All other plan variants are rebuilt unchanged around their recursively
/// optimized inputs. Sub-expressions are never shared across plan nodes.
fn optimize(
    plan: &LogicalPlan,
    optimizer_config: &OptimizerConfig,
) -> Result<LogicalPlan> {
    let mut expr_set = ExprSet::new();
    match plan {
        LogicalPlan::Projection(Projection {
            expr,
            input,
            schema,
            alias,
        }) => {
            let arrays = to_arrays(expr, input, &mut expr_set)?;
            let (mut new_expr, new_input) = rewrite_expr(
                &[expr],
                &[&arrays],
                input,
                &mut expr_set,
                schema,
                optimizer_config,
            )?;
            Ok(LogicalPlan::Projection(Projection::try_new_with_schema(
                new_expr.pop().unwrap(),
                Arc::new(new_input),
                schema.clone(),
                alias.clone(),
            )?))
        }
        LogicalPlan::Filter(Filter { predicate, input }) => {
            // Resolve the predicate type against the filter's own schema
            // (borrowed — there is no need to deep-clone the `DFSchema`), and
            // only fall back to the merge of `plan.all_schemas()` if that fails.
            let data_type = if let Ok(data_type) = predicate.get_type(plan.schema()) {
                data_type
            } else {
                let all_schema = plan.all_schemas().into_iter().fold(
                    DFSchema::empty(),
                    |mut lhs, rhs| {
                        lhs.merge(rhs);
                        lhs
                    },
                );
                predicate.get_type(&all_schema)?
            };
            let mut id_array = vec![];
            expr_to_identifier(predicate, &mut expr_set, &mut id_array, data_type)?;
            let (mut new_expr, new_input) = rewrite_expr(
                &[&[predicate.clone()]],
                &[&[id_array]],
                input,
                &mut expr_set,
                input.schema(),
                optimizer_config,
            )?;
            Ok(LogicalPlan::Filter(Filter {
                predicate: new_expr.pop().unwrap().pop().unwrap(),
                input: Arc::new(new_input),
            }))
        }
        LogicalPlan::Window(Window {
            input,
            window_expr,
            schema,
        }) => {
            let arrays = to_arrays(window_expr, input, &mut expr_set)?;
            let (mut new_expr, new_input) = rewrite_expr(
                &[window_expr],
                &[&arrays],
                input,
                &mut expr_set,
                schema,
                optimizer_config,
            )?;
            Ok(LogicalPlan::Window(Window {
                input: Arc::new(new_input),
                window_expr: new_expr.pop().unwrap(),
                schema: schema.clone(),
            }))
        }
        LogicalPlan::Aggregate(Aggregate {
            group_expr,
            aggr_expr,
            input,
            schema,
        }) => {
            // Group and aggregate expressions share one `expr_set`, so a
            // sub-expression repeated across both lists is also eliminated.
            let group_arrays = to_arrays(group_expr, input, &mut expr_set)?;
            let aggr_arrays = to_arrays(aggr_expr, input, &mut expr_set)?;
            let (mut new_expr, new_input) = rewrite_expr(
                &[group_expr, aggr_expr],
                &[&group_arrays, &aggr_arrays],
                input,
                &mut expr_set,
                schema,
                optimizer_config,
            )?;
            // `new_expr` is `[group, aggr]`; pop in reverse order.
            let new_aggr_expr = new_expr.pop().unwrap();
            let new_group_expr = new_expr.pop().unwrap();
            Ok(LogicalPlan::Aggregate(Aggregate {
                input: Arc::new(new_input),
                group_expr: new_group_expr,
                aggr_expr: new_aggr_expr,
                schema: schema.clone(),
            }))
        }
        LogicalPlan::Sort(Sort { expr, input }) => {
            let arrays = to_arrays(expr, input, &mut expr_set)?;
            let (mut new_expr, new_input) = rewrite_expr(
                &[expr],
                &[&arrays],
                input,
                &mut expr_set,
                input.schema(),
                optimizer_config,
            )?;
            Ok(LogicalPlan::Sort(Sort {
                expr: new_expr.pop().unwrap(),
                input: Arc::new(new_input),
            }))
        }
        LogicalPlan::Join { .. }
        | LogicalPlan::CrossJoin(_)
        | LogicalPlan::Repartition(_)
        | LogicalPlan::Union(_)
        | LogicalPlan::TableScan { .. }
        | LogicalPlan::Values(_)
        | LogicalPlan::EmptyRelation(_)
        | LogicalPlan::Subquery(_)
        | LogicalPlan::SubqueryAlias(_)
        | LogicalPlan::Limit(_)
        | LogicalPlan::CreateExternalTable(_)
        | LogicalPlan::Explain { .. }
        | LogicalPlan::Analyze { .. }
        | LogicalPlan::CreateMemoryTable(_)
        | LogicalPlan::CreateView(_)
        | LogicalPlan::CreateCatalogSchema(_)
        | LogicalPlan::CreateCatalog(_)
        | LogicalPlan::DropTable(_)
        | LogicalPlan::Distinct(_)
        | LogicalPlan::Extension { .. } => {
            // No elimination at this node: keep its expressions and only
            // optimize the inputs.
            let expr = plan.expressions();
            let inputs = plan.inputs();
            let new_inputs = inputs
                .iter()
                .map(|input_plan| optimize(input_plan, optimizer_config))
                .collect::<Result<Vec<_>>>()?;
            from_plan(plan, &expr, &new_inputs)
        }
    }
}
/// Computes the identifier array of every expression in `expr`, registering
/// each sub-expression in `expr_set`.
///
/// Each expression is typed against `input`'s schema, and that type is what
/// gets recorded for its sub-expressions. One identifier array is returned per
/// input expression, in the same order. The first typing or visiting error is
/// returned immediately.
fn to_arrays(
    expr: &[Expr],
    input: &LogicalPlan,
    expr_set: &mut ExprSet,
) -> Result<Vec<Vec<(usize, String)>>> {
    let input_schema = input.schema();
    let mut arrays = Vec::with_capacity(expr.len());
    for single in expr {
        let expr_type = single.get_type(input_schema)?;
        let mut ids = Vec::new();
        expr_to_identifier(single, expr_set, &mut ids, expr_type)?;
        arrays.push(ids);
    }
    Ok(arrays)
}
/// Builds a `Projection` on top of `input` with these columns, in this order:
/// 1. Every common sub-expression in `affected_id`, aliased to its identifier.
/// 2. All columns of `input` whose qualified name is not already produced.
///
/// The identifiers are emitted in sorted order so the resulting plan and its
/// schema are deterministic. Iterating the `HashSet` directly yields an
/// arbitrary order that can change from run to run.
///
/// NOTE(review): each new field takes its data type from `expr_set`, which
/// records the type of the enclosing top-level expression rather than of the
/// sub-expression itself. The trailing `true` passed to `DFField::new` is
/// presumably `nullable` — confirm.
///
/// # Panics
/// Panics if an identifier in `affected_id` has no entry in `expr_set`.
fn build_project_plan(
    input: LogicalPlan,
    affected_id: HashSet<Identifier>,
    expr_set: &ExprSet,
) -> Result<LogicalPlan> {
    // Sort for a deterministic column order (HashSet order is unspecified).
    let mut affected_id: Vec<Identifier> = affected_id.into_iter().collect();
    affected_id.sort_unstable();

    let mut project_exprs = vec![];
    let mut fields = vec![];
    let mut fields_set = HashSet::new();
    for id in affected_id {
        let (expr, _, data_type) = expr_set
            .get(&id)
            .expect("every affected identifier must be registered in expr_set");
        let field = DFField::new(None, &id, data_type.clone(), true);
        fields_set.insert(field.name().to_owned());
        fields.push(field);
        project_exprs.push(expr.clone().alias(&id));
    }

    // Pass through the input columns, skipping any whose name is already
    // produced above (e.g. when `input` is itself a CSE projection).
    for field in input.schema().fields() {
        if fields_set.insert(field.qualified_name()) {
            fields.push(field.clone());
            project_exprs.push(Expr::Column(field.qualified_column()));
        }
    }

    let schema = DFSchema::new_with_metadata(fields, HashMap::new())?;
    Ok(LogicalPlan::Projection(Projection::try_new_with_schema(
        project_exprs,
        Arc::new(input),
        Arc::new(schema),
        None,
    )?))
}
#[inline]
/// Rewrites every expression in `exprs_list` (paired index-for-index with its
/// identifier array in `arrays_list`), replacing repeated sub-expressions with
/// column references.
///
/// `input` is optimized recursively. If anything was replaced, it is wrapped in
/// a projection that materializes the replaced sub-expressions.
///
/// Returns the rewritten expression lists (same shape as `exprs_list`) and the
/// new input plan.
fn rewrite_expr(
    exprs_list: &[&[Expr]],
    arrays_list: &[&[Vec<(usize, String)>]],
    input: &LogicalPlan,
    expr_set: &mut ExprSet,
    schema: &DFSchema,
    optimizer_config: &OptimizerConfig,
) -> Result<(Vec<Vec<Expr>>, LogicalPlan)> {
    let mut affected_id = HashSet::<Identifier>::new();

    let mut rewrote_exprs = Vec::with_capacity(exprs_list.len());
    for (exprs, arrays) in exprs_list.iter().zip(arrays_list.iter()) {
        let mut rewritten_group = Vec::with_capacity(exprs.len());
        for (original, id_array) in exprs.iter().zip(arrays.iter()) {
            let replaced = replace_common_expr(
                original.clone(),
                id_array,
                expr_set,
                &mut affected_id,
                schema,
            )?;
            rewritten_group.push(replaced);
        }
        rewrote_exprs.push(rewritten_group);
    }

    let optimized_input = optimize(input, optimizer_config)?;
    let new_input = if affected_id.is_empty() {
        optimized_input
    } else {
        build_project_plan(optimized_input, affected_id, expr_set)?
    };
    Ok((rewrote_exprs, new_input))
}
/// Expression visitor that assigns an [`Identifier`] to every node of an
/// expression tree and counts how often each eligible identifier occurs in
/// the shared [`ExprSet`].
struct ExprIdentifierVisitor<'a> {
/// Shared registry; the occurrence count of each eligible node's identifier
/// is incremented here.
expr_set: &'a mut ExprSet,
/// Output: one `(series_number, identifier)` slot per node, in the order the
/// nodes are entered (`pre_visit`); each slot is filled in by `post_visit`.
id_array: &'a mut Vec<(usize, Identifier)>,
/// Data type recorded in `expr_set` for every sub-expression this visitor
/// registers.
data_type: DataType,
/// Enter marks and finished child identifiers, used to concatenate the
/// identifiers of a node's children.
visit_stack: Vec<VisitRecord>,
/// Number of nodes entered so far; also the index of the next `id_array` slot.
node_count: usize,
/// Number of nodes exited so far. A node's series number is assigned on exit,
/// so descendants always get smaller numbers than their ancestors.
series_number: usize,
}
/// Entry on the `visit_stack` of `ExprIdentifierVisitor`.
enum VisitRecord {
/// Pushed when a node is entered; holds that node's index into `id_array`.
EnterMark(usize),
/// Identifier of a child node that has already been fully visited.
ExprItem(Identifier),
}
impl ExprIdentifierVisitor<'_> {
    /// Returns the description of `expr` itself, excluding its children.
    ///
    /// The visitor forms a node's identifier by appending its children's
    /// identifiers to this description. The description must therefore capture
    /// every property of the node that is *not* represented by a visited child
    /// expression. Otherwise two different expressions share an identifier and
    /// the rewriter wrongly merges them.
    fn desc_expr(expr: &Expr) -> String {
        let mut desc = String::new();
        match expr {
            Expr::Column(column) => {
                desc.push_str("Column-");
                desc.push_str(&column.flat_name());
            }
            Expr::ScalarVariable(_, var_names) => {
                desc.push_str("ScalarVariable-");
                desc.push_str(&var_names.join("."));
            }
            Expr::Alias(_, alias) => {
                desc.push_str("Alias-");
                desc.push_str(alias);
            }
            Expr::Literal(value) => {
                desc.push_str("Literal");
                desc.push_str(&value.to_string());
            }
            Expr::BinaryExpr { op, .. } => {
                desc.push_str("BinaryExpr-");
                desc.push_str(&op.to_string());
            }
            Expr::Not(_) => {
                desc.push_str("Not-");
            }
            Expr::IsNotNull(_) => {
                desc.push_str("IsNotNull-");
            }
            Expr::IsNull(_) => {
                desc.push_str("IsNull-");
            }
            Expr::Negative(_) => {
                desc.push_str("Negative-");
            }
            Expr::Between { negated, .. } => {
                desc.push_str("Between-");
                desc.push_str(&negated.to_string());
            }
            Expr::Case {
                expr: base,
                else_expr,
                ..
            } => {
                desc.push_str("Case-");
                // The children's identifiers are simply concatenated, so which
                // optional parts exist cannot be recovered from them. Without
                // this, `CASE a WHEN b THEN c END` and
                // `CASE WHEN a THEN b ELSE c END` would share an identifier.
                let _ = write!(desc, "{}{}", base.is_some(), else_expr.is_some());
            }
            Expr::Cast { data_type, .. } => {
                desc.push_str("Cast-");
                let _ = write!(desc, "{:?}", data_type);
            }
            Expr::TryCast { data_type, .. } => {
                desc.push_str("TryCast-");
                let _ = write!(desc, "{:?}", data_type);
            }
            Expr::Sort {
                asc, nulls_first, ..
            } => {
                desc.push_str("Sort-");
                let _ = write!(desc, "{}{}", asc, nulls_first);
            }
            Expr::ScalarFunction { fun, .. } => {
                desc.push_str("ScalarFunction-");
                desc.push_str(&fun.to_string());
            }
            Expr::ScalarUDF { fun, .. } => {
                desc.push_str("ScalarUDF-");
                desc.push_str(&fun.name);
            }
            Expr::WindowFunction {
                fun, window_frame, ..
            } => {
                desc.push_str("WindowFunction-");
                desc.push_str(&fun.to_string());
                let _ = write!(desc, "{:?}", window_frame);
            }
            Expr::AggregateFunction { fun, distinct, .. } => {
                desc.push_str("AggregateFunction-");
                desc.push_str(&fun.to_string());
                desc.push_str(&distinct.to_string());
            }
            Expr::AggregateUDF { fun, .. } => {
                desc.push_str("AggregateUDF-");
                desc.push_str(&fun.name);
            }
            Expr::InList { negated, .. } => {
                desc.push_str("InList-");
                desc.push_str(&negated.to_string());
            }
            // Subquery plans are not visited as child expressions. Include the
            // plan text, otherwise every `EXISTS (...)` (resp. `IN (...)`,
            // scalar subquery) in a node would collapse into a single one.
            Expr::Exists { subquery, negated } => {
                desc.push_str("Exists-");
                desc.push_str(&negated.to_string());
                let _ = write!(desc, "{:?}", subquery.subquery);
            }
            Expr::InSubquery {
                subquery, negated, ..
            } => {
                desc.push_str("InSubquery-");
                desc.push_str(&negated.to_string());
                let _ = write!(desc, "{:?}", subquery.subquery);
            }
            Expr::ScalarSubquery(subquery) => {
                desc.push_str("ScalarSubquery-");
                let _ = write!(desc, "{:?}", subquery.subquery);
            }
            Expr::Wildcard => {
                desc.push_str("Wildcard-");
            }
            Expr::QualifiedWildcard { qualifier } => {
                desc.push_str("QualifiedWildcard-");
                desc.push_str(qualifier);
            }
            Expr::GetIndexedField { key, .. } => {
                desc.push_str("GetIndexedField-");
                desc.push_str(&key.to_string());
            }
            Expr::GroupingSet(grouping_set) => match grouping_set {
                GroupingSet::Rollup(exprs) => {
                    desc.push_str("Rollup");
                    for expr in exprs {
                        desc.push('-');
                        desc.push_str(&Self::desc_expr(expr));
                    }
                }
                GroupingSet::Cube(exprs) => {
                    desc.push_str("Cube");
                    for expr in exprs {
                        desc.push('-');
                        desc.push_str(&Self::desc_expr(expr));
                    }
                }
                GroupingSet::GroupingSets(lists_of_exprs) => {
                    desc.push_str("GroupingSets");
                    for exprs in lists_of_exprs {
                        desc.push('(');
                        for expr in exprs {
                            desc.push('-');
                            desc.push_str(&Self::desc_expr(expr));
                        }
                        desc.push(')');
                    }
                }
            },
        }
        desc
    }

    /// Pops the visit stack down to (and including) the most recent enter mark.
    ///
    /// Returns that mark's `id_array` index together with the concatenation of
    /// the child identifiers popped on the way. They come out in reverse visit
    /// order, i.e. the last child first.
    fn pop_enter_mark(&mut self) -> (usize, Identifier) {
        let mut desc = String::new();
        while let Some(item) = self.visit_stack.pop() {
            match item {
                VisitRecord::EnterMark(idx) => {
                    return (idx, desc);
                }
                VisitRecord::ExprItem(s) => {
                    desc.push_str(&s);
                }
            }
        }
        unreachable!("Enter mark should paired with node number");
    }
}
impl ExpressionVisitor for ExprIdentifierVisitor<'_> {
    /// Reserves an `id_array` slot for the node being entered and marks where
    /// its children's identifiers start on the visit stack.
    fn pre_visit(mut self, _expr: &Expr) -> Result<Recursion<Self>> {
        self.visit_stack
            .push(VisitRecord::EnterMark(self.node_count));
        self.node_count += 1;
        // Placeholder filled in by `post_visit`; an empty identifier means
        // "not a candidate for elimination".
        self.id_array.push((0, "".to_string()));
        Ok(Recursion::Continue(self))
    }

    /// Builds the identifier of `expr` from its own description followed by
    /// the identifiers of its children, and records it in `id_array`.
    /// Eligible nodes are also counted in `expr_set`.
    ///
    /// Leaves (`Literal`, `Column`, `ScalarVariable`, `Wildcard`) and the
    /// wrappers `Alias` and `Sort` are never eliminated themselves. Their
    /// identifier still flows into their parent's identifier, and it includes
    /// what they wrap. Otherwise two parents that differ only inside an
    /// `Alias`/`Sort` (e.g. `ORDER BY a` vs `ORDER BY b` in a window function)
    /// would get the same identifier and be merged.
    fn post_visit(mut self, expr: &Expr) -> Result<Self> {
        self.series_number += 1;
        let (idx, sub_expr_desc) = self.pop_enter_mark();

        // Leaves have no children, so `sub_expr_desc` is empty for them.
        let mut desc = Self::desc_expr(expr);
        desc.push_str(&sub_expr_desc);

        if matches!(
            expr,
            Expr::Literal(..)
                | Expr::Column(..)
                | Expr::ScalarVariable(..)
                | Expr::Alias(..)
                | Expr::Sort { .. }
                | Expr::Wildcard
        ) {
            // Not eligible: keep the empty identifier but record the series
            // number, which the rewriter uses to step over entries.
            self.id_array[idx].0 = self.series_number;
            self.visit_stack.push(VisitRecord::ExprItem(desc));
            return Ok(self);
        }

        self.id_array[idx] = (self.series_number, desc.clone());
        self.visit_stack.push(VisitRecord::ExprItem(desc.clone()));

        // NOTE(review): the visitor's `data_type` (the type the caller computed
        // for the top-level expression) is recorded for every sub-expression,
        // not the sub-expression's own type.
        let data_type = self.data_type.clone();
        self.expr_set
            .entry(desc)
            .or_insert_with(|| (expr.clone(), 0, data_type))
            .1 += 1;
        Ok(self)
    }
}
/// Walks `expr` and appends one `(series_number, identifier)` entry per node to
/// `id_array`, in the order nodes are entered.
///
/// The occurrence count of every eligible sub-expression is bumped in
/// `expr_set`, and `data_type` is stored for every new entry.
fn expr_to_identifier(
    expr: &Expr,
    expr_set: &mut ExprSet,
    id_array: &mut Vec<(usize, Identifier)>,
    data_type: DataType,
) -> Result<()> {
    let visitor = ExprIdentifierVisitor {
        expr_set,
        id_array,
        data_type,
        visit_stack: Vec::new(),
        node_count: 0,
        series_number: 0,
    };
    // Only the side effects on `expr_set` / `id_array` matter.
    expr.accept(visitor).map(|_finished| ())
}
/// Expression rewriter that replaces each sub-expression whose identifier
/// occurs more than once in the [`ExprSet`] with a reference to the column
/// `build_project_plan` materializes for it.
struct CommonSubexprRewriter<'a> {
/// Occurrence counts produced by `ExprIdentifierVisitor` (only read here).
expr_set: &'a mut ExprSet,
/// Identifier array produced by `expr_to_identifier` for the expression
/// being rewritten.
id_array: &'a [(usize, Identifier)],
/// Output: identifiers selected for replacement, which need a column in the
/// projection below.
affected_id: &'a mut HashSet<Identifier>,
/// Schema used to compute the original name of a replaced expression.
schema: &'a DFSchema,
/// Series number of the most recently replaced node.
max_series_number: usize,
/// Index of the next `id_array` entry to consume.
curr_index: usize,
}
impl ExprRewriter for CommonSubexprRewriter<'_> {
/// Decides, before descending into a node, whether it should be replaced.
///
/// NOTE(review): relies on `RewriteRecursion` semantics from
/// `datafusion_expr`. Presumably:
/// - `Mutate` calls `mutate` on this node without visiting its children.
/// - `Skip` keeps descending but does not mutate this node.
/// - `Stop` does not rewrite below this point.
/// Confirm against the library docs.
fn pre_visit(&mut self, _: &Expr) -> Result<RewriteRecursion> {
// Stop once every identifier has been consumed, or when the next entry's
// series number is below that of the most recently replaced node.
if self.curr_index >= self.id_array.len()
|| self.max_series_number > self.id_array[self.curr_index].0
{
return Ok(RewriteRecursion::Stop);
}
let curr_id = &self.id_array[self.curr_index].1;
// Empty identifier: this node is not a candidate; move to the next entry.
if curr_id.is_empty() {
self.curr_index += 1;
return Ok(RewriteRecursion::Skip);
}
// The identifier visitor registers every non-empty identifier it emits, so
// a missing entry here would indicate a bug (this `unwrap` panics).
let (_, counter, _) = self.expr_set.get(curr_id).unwrap();
if *counter > 1 {
// Repeated sub-expression: remember it so a column is materialized.
self.affected_id.insert(curr_id.clone());
Ok(RewriteRecursion::Mutate)
} else {
self.curr_index += 1;
Ok(RewriteRecursion::Skip)
}
}
/// Replaces the current node with a reference to its materialized column,
/// aliased to the node's original name, when it is a repeated sub-expression.
/// Otherwise `expr` is returned unchanged.
fn mutate(&mut self, expr: Expr) -> Result<Expr> {
// No identifier entry left for this node: leave it as is.
if self.curr_index >= self.id_array.len() {
return Ok(expr);
}
let (series_number, id) = &self.id_array[self.curr_index];
self.curr_index += 1;
// Leave unchanged when any of these holds:
// - the entry's series number is below the last replaced node's;
// - the node has no identifier;
// - the sub-expression occurs only once.
// `id.is_empty()` short-circuits before the lookup, so the `unwrap` only
// runs for non-empty identifiers.
if *series_number < self.max_series_number
|| id.is_empty()
|| self.expr_set.get(id).unwrap().1 <= 1
{
return Ok(expr);
}
self.max_series_number = *series_number;
// Descendants of this node were exited earlier and so have smaller series
// numbers; skip their entries since the whole subtree is being replaced.
while self.curr_index < self.id_array.len()
&& *series_number > self.id_array[self.curr_index].0
{
self.curr_index += 1;
}
// Alias the column back to the original expression's name so the node's
// output column names are preserved.
let expr_name = expr.name(self.schema)?;
Ok(col(id).alias(&expr_name))
}
}
/// Rewrites `expr`, replacing every sub-expression that occurs more than once
/// in `expr_set` with a reference to its materialized column.
///
/// `id_array` must be the identifier array produced for this same `expr`.
/// Identifiers that get replaced are added to `affected_id`. `schema` is used
/// to compute the original names of replaced expressions.
fn replace_common_expr(
    expr: Expr,
    id_array: &[(usize, Identifier)],
    expr_set: &mut ExprSet,
    affected_id: &mut HashSet<Identifier>,
    schema: &DFSchema,
) -> Result<Expr> {
    let mut rewriter = CommonSubexprRewriter {
        expr_set,
        id_array,
        affected_id,
        schema,
        max_series_number: 0,
        curr_index: 0,
    };
    expr.rewrite(&mut rewriter)
}
// Unit tests for the common sub-expression elimination rule: identifier
// generation, end-to-end plan rewrites, and de-duplication of projected fields.
#[cfg(test)]
mod test {
use super::*;
use crate::test::*;
use datafusion_expr::logical_plan::JoinType;
use datafusion_expr::{
avg, binary_expr, col, lit, logical_plan::builder::LogicalPlanBuilder, sum,
Operator,
};
use std::iter;
// Runs the rule on `plan` and compares the Debug (indented) rendering of the
// result with `expected`.
fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) {
let optimizer = CommonSubexprEliminate {};
let optimized_plan = optimizer
.optimize(plan, &mut OptimizerConfig::new())
.expect("failed to optimize plan");
let formatted_plan = format!("{:?}", optimized_plan);
assert_eq!(formatted_plan, expected);
}
// Checks the (series_number, identifier) pairs produced for each node, in
// entry order; leaves get an empty identifier.
#[test]
fn id_array_visitor() -> Result<()> {
let expr = binary_expr(
binary_expr(
sum(binary_expr(col("a"), Operator::Plus, lit("1"))),
Operator::Minus,
avg(col("c")),
),
Operator::Multiply,
lit(2),
);
let mut id_array = vec![];
expr_to_identifier(&expr, &mut HashMap::new(), &mut id_array, DataType::Int64)?;
let expected = vec![
(9, "BinaryExpr-*Literal2BinaryExpr--AggregateFunction-AVGfalseColumn-cAggregateFunction-SUMfalseBinaryExpr-+Literal1Column-a"),
(7, "BinaryExpr--AggregateFunction-AVGfalseColumn-cAggregateFunction-SUMfalseBinaryExpr-+Literal1Column-a"),
(4, "AggregateFunction-SUMfalseBinaryExpr-+Literal1Column-a"), (3, "BinaryExpr-+Literal1Column-a"),
(1, ""),
(2, ""),
(6, "AggregateFunction-AVGfalseColumn-c"),
(5, ""),
(8, ""),
]
.into_iter()
.map(|(number, id)| (number, id.into()))
.collect::<Vec<_>>();
assert_eq!(id_array, expected);
Ok(())
}
// A sub-expression shared by two aggregate arguments is computed once in a
// projection below the aggregate.
#[test]
fn tpch_q1_simplified() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.aggregate(
iter::empty::<Expr>(),
vec![
sum(binary_expr(
col("a"),
Operator::Multiply,
binary_expr(lit(1), Operator::Minus, col("b")),
)),
sum(binary_expr(
binary_expr(
col("a"),
Operator::Multiply,
binary_expr(lit(1), Operator::Minus, col("b")),
),
Operator::Multiply,
binary_expr(lit(1), Operator::Plus, col("c")),
)),
],
)?
.build()?;
let expected = "Aggregate: groupBy=[[]], aggr=[[SUM(#BinaryExpr-*BinaryExpr--Column-test.bLiteral1Column-test.a AS test.a * Int32(1) - test.b), SUM(#BinaryExpr-*BinaryExpr--Column-test.bLiteral1Column-test.a AS test.a * Int32(1) - test.b * Int32(1) + #test.c)]]\
\n Projection: #test.a * Int32(1) - #test.b AS BinaryExpr-*BinaryExpr--Column-test.bLiteral1Column-test.a, #test.a, #test.b, #test.c\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected);
Ok(())
}
// A repeated aggregate call is itself eliminated.
#[test]
fn aggregate() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.aggregate(
iter::empty::<Expr>(),
vec![
binary_expr(lit(1), Operator::Plus, avg(col("a"))),
binary_expr(lit(1), Operator::Minus, avg(col("a"))),
],
)?
.build()?;
let expected = "Aggregate: groupBy=[[]], aggr=[[Int32(1) + #AggregateFunction-AVGfalseColumn-test.a AS AVG(test.a), Int32(1) - #AggregateFunction-AVGfalseColumn-test.a AS AVG(test.a)]]\
\n Projection: AVG(#test.a) AS AggregateFunction-AVGfalseColumn-test.a, #test.a, #test.b, #test.c\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected);
Ok(())
}
// Identical expressions under different aliases share one computed column.
#[test]
fn subexpr_in_same_order() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![
binary_expr(lit(1), Operator::Plus, col("a")).alias("first"),
binary_expr(lit(1), Operator::Plus, col("a")).alias("second"),
])?
.build()?;
let expected = "Projection: #BinaryExpr-+Column-test.aLiteral1 AS Int32(1) + test.a AS first, #BinaryExpr-+Column-test.aLiteral1 AS Int32(1) + test.a AS second\
\n Projection: Int32(1) + #test.a AS BinaryExpr-+Column-test.aLiteral1, #test.a, #test.b, #test.c\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected);
Ok(())
}
// Operand order matters: `1 + a` and `a + 1` are not treated as common.
#[test]
fn subexpr_in_different_order() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![
binary_expr(lit(1), Operator::Plus, col("a")),
binary_expr(col("a"), Operator::Plus, lit(1)),
])?
.build()?;
let expected = "Projection: Int32(1) + #test.a, #test.a + Int32(1)\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected);
Ok(())
}
// Sub-expressions are not shared across different plan nodes.
#[test]
fn cross_plans_subexpr() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![binary_expr(lit(1), Operator::Plus, col("a"))])?
.project(vec![binary_expr(lit(1), Operator::Plus, col("a"))])?
.build()?;
let expected = "Projection: #Int32(1) + test.a\
\n Projection: Int32(1) + #test.a\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected);
Ok(())
}
// Stacking two CSE projections must not duplicate field names.
#[test]
fn redundant_project_fields() {
let table_scan = test_table_scan().unwrap();
let affected_id: HashSet<Identifier> =
["c+a".to_string(), "d+a".to_string()].into_iter().collect();
let expr_set = [
("c+a".to_string(), (col("c+a"), 1, DataType::UInt32)),
("d+a".to_string(), (col("d+a"), 1, DataType::UInt32)),
]
.into_iter()
.collect();
let project =
build_project_plan(table_scan, affected_id.clone(), &expr_set).unwrap();
let project_2 = build_project_plan(project, affected_id, &expr_set).unwrap();
let mut field_set = HashSet::new();
for field in project_2.schema().fields() {
assert!(field_set.insert(field.qualified_name()));
}
}
// Same as above, but over a join whose inputs have overlapping column names.
#[test]
fn redundant_project_fields_join_input() {
let table_scan_1 = test_table_scan_with_name("test1").unwrap();
let table_scan_2 = test_table_scan_with_name("test2").unwrap();
let join = LogicalPlanBuilder::from(table_scan_1)
.join(&table_scan_2, JoinType::Inner, (vec!["a"], vec!["a"]), None)
.unwrap()
.build()
.unwrap();
let affected_id: HashSet<Identifier> =
["c+a".to_string(), "d+a".to_string()].into_iter().collect();
let expr_set = [
("c+a".to_string(), (col("c+a"), 1, DataType::UInt32)),
("d+a".to_string(), (col("d+a"), 1, DataType::UInt32)),
]
.into_iter()
.collect();
let project = build_project_plan(join, affected_id.clone(), &expr_set).unwrap();
let project_2 = build_project_plan(project, affected_id, &expr_set).unwrap();
let mut field_set = HashSet::new();
for field in project_2.schema().fields() {
assert!(field_set.insert(field.qualified_name()));
}
}
}