use crate::expr::{
AggregateFunction, Between, BinaryExpr, Case, Cast, GetIndexedField, GroupingSet,
Like, Sort, TryCast, WindowFunction,
};
use crate::logical_plan::{Aggregate, Projection};
use crate::utils::grouping_set_to_exprlist;
use crate::{Expr, ExprSchemable, LogicalPlan};
use datafusion_common::Result;
use datafusion_common::{Column, DFSchema};
use std::collections::HashMap;
use std::collections::HashSet;
use std::sync::Arc;
/// Decision returned by `ExprRewriter::pre_visit` that tells
/// `ExprRewritable::rewrite` how to proceed with the node it was given.
///
/// The per-variant behavior below describes the `Expr` implementation of
/// `ExprRewritable` in this file.
pub enum RewriteRecursion {
/// Rewrite the node's children first, then call `mutate` on the rebuilt node.
Continue,
/// Call `mutate` on the node immediately and return its result; `rewrite`
/// itself does not descend into the children.
Mutate,
/// Return the node as-is: no descent into children and no call to `mutate`.
Stop,
/// Rewrite the node's children, but do not call `mutate` on the node itself.
Skip,
}
/// Callbacks used to rewrite a value of type `E` (by default `Expr`).
///
/// A rewriter is passed to `ExprRewritable::rewrite`, which calls
/// `pre_visit` on a node to decide how to proceed and `mutate` to produce
/// the replacement for a node.
pub trait ExprRewriter<E: ExprRewritable = Expr>: Sized {
/// Invoked on a node before it is rewritten; the returned
/// `RewriteRecursion` selects how the rewrite continues.
///
/// The default implementation always returns `RewriteRecursion::Continue`.
fn pre_visit(&mut self, _expr: &E) -> Result<RewriteRecursion> {
Ok(RewriteRecursion::Continue)
}
/// Consumes `expr` and returns the expression that should take its place
/// (returning `expr` itself leaves the node unchanged).
fn mutate(&mut self, expr: E) -> Result<E>;
}
/// A type that can be rewritten by an `ExprRewriter`.
pub trait ExprRewritable: Sized {
/// Consumes `self`, applies `rewriter` to it and returns the rewritten
/// value, or the first error reported along the way.
fn rewrite<R: ExprRewriter<Self>>(self, rewriter: &mut R) -> Result<Self>;
}
impl ExprRewritable for Expr {
fn rewrite<R>(self, rewriter: &mut R) -> Result<Self>
where
R: ExprRewriter<Self>,
{
let need_mutate = match rewriter.pre_visit(&self)? {
RewriteRecursion::Mutate => return rewriter.mutate(self),
RewriteRecursion::Stop => return Ok(self),
RewriteRecursion::Continue => true,
RewriteRecursion::Skip => false,
};
let expr = match self {
Expr::Alias(expr, name) => Expr::Alias(rewrite_boxed(expr, rewriter)?, name),
Expr::Column(_) => self.clone(),
Expr::Exists { .. } => self.clone(),
Expr::InSubquery {
expr,
subquery,
negated,
} => Expr::InSubquery {
expr: rewrite_boxed(expr, rewriter)?,
subquery,
negated,
},
Expr::ScalarSubquery(_) => self.clone(),
Expr::ScalarVariable(ty, names) => Expr::ScalarVariable(ty, names),
Expr::Literal(value) => Expr::Literal(value),
Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
Expr::BinaryExpr(BinaryExpr::new(
rewrite_boxed(left, rewriter)?,
op,
rewrite_boxed(right, rewriter)?,
))
}
Expr::Like(Like {
negated,
expr,
pattern,
escape_char,
}) => Expr::Like(Like::new(
negated,
rewrite_boxed(expr, rewriter)?,
rewrite_boxed(pattern, rewriter)?,
escape_char,
)),
Expr::ILike(Like {
negated,
expr,
pattern,
escape_char,
}) => Expr::ILike(Like::new(
negated,
rewrite_boxed(expr, rewriter)?,
rewrite_boxed(pattern, rewriter)?,
escape_char,
)),
Expr::SimilarTo(Like {
negated,
expr,
pattern,
escape_char,
}) => Expr::SimilarTo(Like::new(
negated,
rewrite_boxed(expr, rewriter)?,
rewrite_boxed(pattern, rewriter)?,
escape_char,
)),
Expr::Not(expr) => Expr::Not(rewrite_boxed(expr, rewriter)?),
Expr::IsNotNull(expr) => Expr::IsNotNull(rewrite_boxed(expr, rewriter)?),
Expr::IsNull(expr) => Expr::IsNull(rewrite_boxed(expr, rewriter)?),
Expr::IsTrue(expr) => Expr::IsTrue(rewrite_boxed(expr, rewriter)?),
Expr::IsFalse(expr) => Expr::IsFalse(rewrite_boxed(expr, rewriter)?),
Expr::IsUnknown(expr) => Expr::IsUnknown(rewrite_boxed(expr, rewriter)?),
Expr::IsNotTrue(expr) => Expr::IsNotTrue(rewrite_boxed(expr, rewriter)?),
Expr::IsNotFalse(expr) => Expr::IsNotFalse(rewrite_boxed(expr, rewriter)?),
Expr::IsNotUnknown(expr) => {
Expr::IsNotUnknown(rewrite_boxed(expr, rewriter)?)
}
Expr::Negative(expr) => Expr::Negative(rewrite_boxed(expr, rewriter)?),
Expr::Between(Between {
expr,
negated,
low,
high,
}) => Expr::Between(Between::new(
rewrite_boxed(expr, rewriter)?,
negated,
rewrite_boxed(low, rewriter)?,
rewrite_boxed(high, rewriter)?,
)),
Expr::Case(case) => {
let expr = rewrite_option_box(case.expr, rewriter)?;
let when_then_expr = case
.when_then_expr
.into_iter()
.map(|(when, then)| {
Ok((
rewrite_boxed(when, rewriter)?,
rewrite_boxed(then, rewriter)?,
))
})
.collect::<Result<Vec<_>>>()?;
let else_expr = rewrite_option_box(case.else_expr, rewriter)?;
Expr::Case(Case::new(expr, when_then_expr, else_expr))
}
Expr::Cast(Cast { expr, data_type }) => {
Expr::Cast(Cast::new(rewrite_boxed(expr, rewriter)?, data_type))
}
Expr::TryCast(TryCast { expr, data_type }) => {
Expr::TryCast(TryCast::new(rewrite_boxed(expr, rewriter)?, data_type))
}
Expr::Sort(Sort {
expr,
asc,
nulls_first,
}) => Expr::Sort(Sort::new(rewrite_boxed(expr, rewriter)?, asc, nulls_first)),
Expr::ScalarFunction { args, fun } => Expr::ScalarFunction {
args: rewrite_vec(args, rewriter)?,
fun,
},
Expr::ScalarUDF { args, fun } => Expr::ScalarUDF {
args: rewrite_vec(args, rewriter)?,
fun,
},
Expr::WindowFunction(WindowFunction {
args,
fun,
partition_by,
order_by,
window_frame,
}) => Expr::WindowFunction(WindowFunction::new(
fun,
rewrite_vec(args, rewriter)?,
rewrite_vec(partition_by, rewriter)?,
rewrite_vec(order_by, rewriter)?,
window_frame,
)),
Expr::AggregateFunction(AggregateFunction {
args,
fun,
distinct,
filter,
}) => Expr::AggregateFunction(AggregateFunction::new(
fun,
rewrite_vec(args, rewriter)?,
distinct,
filter,
)),
Expr::GroupingSet(grouping_set) => match grouping_set {
GroupingSet::Rollup(exprs) => {
Expr::GroupingSet(GroupingSet::Rollup(rewrite_vec(exprs, rewriter)?))
}
GroupingSet::Cube(exprs) => {
Expr::GroupingSet(GroupingSet::Cube(rewrite_vec(exprs, rewriter)?))
}
GroupingSet::GroupingSets(lists_of_exprs) => {
Expr::GroupingSet(GroupingSet::GroupingSets(
lists_of_exprs
.iter()
.map(|exprs| rewrite_vec(exprs.clone(), rewriter))
.collect::<Result<Vec<_>>>()?,
))
}
},
Expr::AggregateUDF { args, fun, filter } => Expr::AggregateUDF {
args: rewrite_vec(args, rewriter)?,
fun,
filter,
},
Expr::InList {
expr,
list,
negated,
} => Expr::InList {
expr: rewrite_boxed(expr, rewriter)?,
list: rewrite_vec(list, rewriter)?,
negated,
},
Expr::Wildcard => Expr::Wildcard,
Expr::QualifiedWildcard { qualifier } => {
Expr::QualifiedWildcard { qualifier }
}
Expr::GetIndexedField(GetIndexedField { key, expr }) => {
Expr::GetIndexedField(GetIndexedField::new(
rewrite_boxed(expr, rewriter)?,
key,
))
}
Expr::Placeholder { id, data_type } => Expr::Placeholder { id, data_type },
};
if need_mutate {
rewriter.mutate(expr)
} else {
Ok(expr)
}
}
}
/// Rewrites the expression held in `boxed_expr` with `rewriter` and returns
/// the result in a new `Box`.
///
/// Errors from the rewrite are returned unchanged.
#[allow(clippy::boxed_local)]
fn rewrite_boxed<R>(boxed_expr: Box<Expr>, rewriter: &mut R) -> Result<Box<Expr>>
where
    R: ExprRewriter,
{
    // Move the expression out of its box, rewrite it, then box the outcome.
    (*boxed_expr).rewrite(rewriter).map(Box::new)
}
/// Rewrites an optional boxed expression.
///
/// `None` is returned as `Ok(None)`; `Some(expr)` is rewritten through
/// `rewrite_boxed` and wrapped in `Some` again.
fn rewrite_option_box<R>(
    option_box: Option<Box<Expr>>,
    rewriter: &mut R,
) -> Result<Option<Box<Expr>>>
where
    R: ExprRewriter,
{
    match option_box {
        Some(inner) => rewrite_boxed(inner, rewriter).map(Some),
        None => Ok(None),
    }
}
/// Rewrites every expression of `v` in order, stopping at the first error.
fn rewrite_vec<R>(v: Vec<Expr>, rewriter: &mut R) -> Result<Vec<Expr>>
where
    R: ExprRewriter,
{
    let mut rewritten = Vec::with_capacity(v.len());
    for item in v {
        rewritten.push(item.rewrite(rewriter)?);
    }
    Ok(rewritten)
}
/// Applies `rewrite_sort_col_by_aggs` to the inner expression of every
/// `Expr::Sort` in `exprs`, keeping its `asc` / `nulls_first` settings.
///
/// Items that do not convert to an `Expr::Sort` are returned unchanged.
/// The first error encountered aborts the whole operation.
pub fn rewrite_sort_cols_by_aggs(
    exprs: impl IntoIterator<Item = impl Into<Expr>>,
    plan: &LogicalPlan,
) -> Result<Vec<Expr>> {
    let mut rewritten = Vec::new();
    for item in exprs {
        let next = match item.into() {
            Expr::Sort(Sort {
                expr,
                asc,
                nulls_first,
            }) => {
                let inner = rewrite_sort_col_by_aggs(*expr, plan)?;
                Expr::Sort(Sort::new(Box::new(inner), asc, nulls_first))
            }
            other => other,
        };
        rewritten.push(next);
    }
    Ok(rewritten)
}
/// Rewrites `expr` so that sub-expressions matching an aggregate or grouping
/// expression of `plan` become column references.
///
/// * For a `LogicalPlan::Aggregate`, every sub-expression that (after column
///   normalization against `plan`) equals one of the plan's aggregate
///   expressions or one of its expanded grouping expressions is replaced by
///   an `Expr::Column` built from that expression's field.
/// * For a `LogicalPlan::Projection`, the rewrite is retried on the
///   projection's first input.
/// * For every other plan, `expr` is returned unchanged.
fn rewrite_sort_col_by_aggs(expr: Expr, plan: &LogicalPlan) -> Result<Expr> {
    match plan {
        LogicalPlan::Aggregate(Aggregate {
            input,
            aggr_expr,
            group_expr,
            ..
        }) => {
            // Rewriter that swaps matching sub-expressions for column references.
            struct Rewriter<'a> {
                plan: &'a LogicalPlan,
                input: &'a LogicalPlan,
                aggr_expr: &'a [Expr],
                distinct_group_exprs: &'a [Expr],
            }

            impl<'a> ExprRewriter for Rewriter<'a> {
                fn mutate(&mut self, expr: Expr) -> Result<Expr> {
                    // A sub-expression that cannot be normalized against the
                    // plan is deliberately left as it is instead of failing.
                    let normalized_expr = match normalize_col(expr.clone(), self.plan) {
                        Ok(normalized) => normalized,
                        Err(_) => return Ok(expr),
                    };

                    let found_agg = self
                        .aggr_expr
                        .iter()
                        .chain(self.distinct_group_exprs)
                        .find(|a| **a == normalized_expr);

                    match found_agg {
                        Some(found_agg) => {
                            let agg = normalize_col(found_agg.clone(), self.plan)?;
                            // The field is resolved against the schema of the
                            // aggregate's input.
                            let col = Expr::Column(
                                agg.to_field(self.input.schema())
                                    .map(|f| f.qualified_column())?,
                            );
                            Ok(col)
                        }
                        None => Ok(expr),
                    }
                }
            }

            let distinct_group_exprs = grouping_set_to_exprlist(group_expr.as_slice())?;
            expr.rewrite(&mut Rewriter {
                plan,
                input,
                aggr_expr,
                distinct_group_exprs: &distinct_group_exprs,
            })
        }
        LogicalPlan::Projection(_) => rewrite_sort_col_by_aggs(expr, plan.inputs()[0]),
        _ => Ok(expr),
    }
}
/// Normalizes the columns of `expr` using all schemas and the using-columns
/// reported by `plan`.
///
/// Delegates to `normalize_col_with_schemas`; an error from
/// `plan.using_columns()` or from the normalization is returned.
pub fn normalize_col(expr: Expr, plan: &LogicalPlan) -> Result<Expr> {
    let schemas = plan.all_schemas();
    let using_columns = plan.using_columns()?;
    normalize_col_with_schemas(expr, &schemas, &using_columns)
}
/// Rewrites `expr`, replacing every `Expr::Column` with the result of
/// `Column::normalize_with_schemas(schemas, using_columns)`.
///
/// All other expression nodes are kept as they are. The first column that
/// fails to normalize aborts the rewrite with that error.
pub fn normalize_col_with_schemas(
    expr: Expr,
    schemas: &[&Arc<DFSchema>],
    using_columns: &[HashSet<Column>],
) -> Result<Expr> {
    // Rewriter that only touches column nodes.
    struct SchemaColumnNormalizer<'a> {
        schemas: &'a [&'a Arc<DFSchema>],
        using_columns: &'a [HashSet<Column>],
    }

    impl<'a> ExprRewriter for SchemaColumnNormalizer<'a> {
        fn mutate(&mut self, expr: Expr) -> Result<Expr> {
            match expr {
                Expr::Column(column) => {
                    let normalized =
                        column.normalize_with_schemas(self.schemas, self.using_columns)?;
                    Ok(Expr::Column(normalized))
                }
                other => Ok(other),
            }
        }
    }

    let mut normalizer = SchemaColumnNormalizer {
        schemas,
        using_columns,
    };
    expr.rewrite(&mut normalizer)
}
/// Converts each item of `exprs` into an `Expr` and normalizes it against
/// `plan` with `normalize_col`, stopping at the first error.
pub fn normalize_cols(
    exprs: impl IntoIterator<Item = impl Into<Expr>>,
    plan: &LogicalPlan,
) -> Result<Vec<Expr>> {
    let mut normalized = Vec::new();
    for item in exprs {
        normalized.push(normalize_col(item.into(), plan)?);
    }
    Ok(normalized)
}
/// Rewrites `e`, substituting every `Expr::Column` whose column is a key of
/// `replace_map` with a copy of the mapped column.
///
/// Columns without an entry in the map, and all non-column expressions, are
/// left unchanged.
pub fn replace_col(e: Expr, replace_map: &HashMap<&Column, &Column>) -> Result<Expr> {
    // Rewriter that looks each column up in the replacement map.
    struct MapColumnReplacer<'a> {
        replace_map: &'a HashMap<&'a Column, &'a Column>,
    }

    impl<'a> ExprRewriter for MapColumnReplacer<'a> {
        fn mutate(&mut self, expr: Expr) -> Result<Expr> {
            let replacement = match &expr {
                Expr::Column(column) => self
                    .replace_map
                    .get(column)
                    .map(|new_col| Expr::Column((*new_col).to_owned())),
                _ => None,
            };
            // Fall back to the original node when nothing was substituted.
            Ok(replacement.unwrap_or(expr))
        }
    }

    let mut replacer = MapColumnReplacer { replace_map };
    e.rewrite(&mut replacer)
}
/// Rewrites `expr`, turning every `Expr::Column` into a column with the same
/// name and `relation: None`.
///
/// # Panics
///
/// Panics if the underlying rewrite returns an error; the rewriter used here
/// never produces one itself.
pub fn unnormalize_col(expr: Expr) -> Expr {
    // Rewriter that drops the relation from each column.
    struct QualifierStripper;

    impl ExprRewriter for QualifierStripper {
        fn mutate(&mut self, expr: Expr) -> Result<Expr> {
            match expr {
                Expr::Column(Column { name, .. }) => Ok(Expr::Column(Column {
                    relation: None,
                    name,
                })),
                other => Ok(other),
            }
        }
    }

    let mut stripper = QualifierStripper;
    expr.rewrite(&mut stripper)
        .expect("Unnormalize is infallable")
}
/// Applies `unnormalize_col` to every expression of `exprs`, preserving order.
#[inline]
pub fn unnormalize_cols(exprs: impl IntoIterator<Item = Expr>) -> Vec<Expr> {
    let mut unqualified = Vec::new();
    unqualified.extend(exprs.into_iter().map(unnormalize_col));
    unqualified
}
/// Returns a plan whose output expressions are coerced to the field types of
/// `schema` via `coerce_exprs_for_schema`.
///
/// * A `LogicalPlan::Projection` is rebuilt over the same input with its
///   expressions coerced.
/// * For any other plan, one column expression per output field is coerced;
///   if every result still converts to a column (`try_into_col` succeeds)
///   a clone of `plan` is returned, otherwise the coerced expressions are
///   wrapped in a new projection on top of a clone of `plan`.
pub fn coerce_plan_expr_for_schema(
    plan: &LogicalPlan,
    schema: &DFSchema,
) -> Result<LogicalPlan> {
    // Existing projection: coerce its expressions in place.
    if let LogicalPlan::Projection(Projection { expr, input, .. }) = plan {
        let coerced = coerce_exprs_for_schema(expr.clone(), input.schema(), schema)?;
        let rebuilt = Projection::try_new(coerced, input.clone())?;
        return Ok(LogicalPlan::Projection(rebuilt));
    }

    // Any other plan: start from one column reference per output field.
    let column_exprs: Vec<Expr> = plan
        .schema()
        .fields()
        .iter()
        .map(|field| Expr::Column(field.qualified_column()))
        .collect();
    let coerced = coerce_exprs_for_schema(column_exprs, plan.schema(), schema)?;

    // A projection is only needed when at least one expression stopped being
    // a plain column.
    let all_plain_columns = coerced.iter().all(|e| e.try_into_col().is_ok());
    if all_plain_columns {
        Ok(plan.clone())
    } else {
        let wrapped = Projection::try_new(coerced, Arc::new(plan.clone()))?;
        Ok(LogicalPlan::Projection(wrapped))
    }
}
/// Casts each expression to the data type of the field at the same position
/// in `dst_schema`.
///
/// An expression whose type (computed against `src_schema`) already equals
/// the target type is returned unchanged. For an `Expr::Alias`, the cast is
/// applied to the aliased expression and the alias is re-applied on top.
///
/// Errors from `get_type` or `cast_to` are returned for the first failing
/// expression.
///
/// NOTE(review): fields are looked up by position with `dst_schema.field(idx)`,
/// so `dst_schema` presumably needs at least as many fields as there are
/// expressions - confirm against callers.
fn coerce_exprs_for_schema(
    exprs: Vec<Expr>,
    src_schema: &DFSchema,
    dst_schema: &DFSchema,
) -> Result<Vec<Expr>> {
    exprs
        .into_iter()
        .enumerate()
        .map(|(idx, expr)| {
            let new_type = dst_schema.field(idx).data_type();
            if new_type != &expr.get_type(src_schema)? {
                match expr {
                    Expr::Alias(e, alias) => {
                        Ok(e.cast_to(new_type, src_schema)?.alias(alias))
                    }
                    _ => expr.cast_to(new_type, src_schema),
                }
            } else {
                // `expr` is owned here, so it is returned without cloning.
                Ok(expr)
            }
        })
        .collect::<Result<_>>()
}
// Unit tests for the rewriter traversal and the column (un)normalization helpers.
#[cfg(test)]
mod test {
use super::*;
use crate::{col, lit};
use arrow::datatypes::DataType;
use datafusion_common::{DFField, DFSchema, ScalarValue};
// Rewriter that changes nothing and records every `pre_visit` / `mutate`
// call as a string, so tests can assert on the traversal order.
#[derive(Default)]
struct RecordingRewriter {
v: Vec<String>,
}
impl ExprRewriter for RecordingRewriter {
fn mutate(&mut self, expr: Expr) -> Result<Expr> {
self.v.push(format!("Mutated {expr:?}"));
Ok(expr)
}
fn pre_visit(&mut self, expr: &Expr) -> Result<RewriteRecursion> {
self.v.push(format!("Previsited {expr:?}"));
Ok(RewriteRecursion::Continue)
}
}
// A matching literal nested inside a binary expression is rewritten, while
// a non-matching literal is left alone.
#[test]
fn rewriter_rewrite() {
let mut rewriter = FooBarRewriter {};
let rewritten = col("state").eq(lit("foo")).rewrite(&mut rewriter).unwrap();
assert_eq!(rewritten, col("state").eq(lit("bar")));
let rewritten = col("state").eq(lit("baz")).rewrite(&mut rewriter).unwrap();
assert_eq!(rewritten, col("state").eq(lit("baz")));
}
// Rewriter that replaces the Utf8 literal "foo" with "bar" and leaves every
// other expression unchanged.
struct FooBarRewriter {}
impl ExprRewriter for FooBarRewriter {
fn mutate(&mut self, expr: Expr) -> Result<Expr> {
match expr {
Expr::Literal(ScalarValue::Utf8(Some(utf8_val))) => {
let utf8_val = if utf8_val == "foo" {
"bar".to_string()
} else {
utf8_val
};
Ok(lit(utf8_val))
}
expr => Ok(expr),
}
}
}
// Each unqualified column is qualified with the relation of the schema that
// contains a field of that name.
#[test]
fn normalize_cols() {
let expr = col("a") + col("b") + col("c");
let schema_a = make_schema_with_empty_metadata(vec![
make_field("tableA", "a"),
make_field("tableA", "aa"),
]);
let schema_c = make_schema_with_empty_metadata(vec![
make_field("tableC", "cc"),
make_field("tableC", "c"),
]);
let schema_b = make_schema_with_empty_metadata(vec![make_field("tableB", "b")]);
let schema_f = make_schema_with_empty_metadata(vec![
make_field("tableC", "f"),
make_field("tableC", "ff"),
]);
let schemas = vec![schema_c, schema_f, schema_b, schema_a]
.into_iter()
.map(Arc::new)
.collect::<Vec<_>>();
let schemas = schemas.iter().collect::<Vec<_>>();
let normalized_expr = normalize_col_with_schemas(expr, &schemas, &[]).unwrap();
assert_eq!(
normalized_expr,
col("tableA.a") + col("tableB.b") + col("tableC.c")
);
}
// When several schemas contain the column name, the schema listed first in
// `schemas` (tableA2 here) is the one used.
#[test]
fn normalize_cols_priority() {
let expr = col("a") + col("b");
let schema_a = make_schema_with_empty_metadata(vec![make_field("tableA", "a")]);
let schema_b = make_schema_with_empty_metadata(vec![make_field("tableB", "b")]);
let schema_a2 = make_schema_with_empty_metadata(vec![make_field("tableA2", "a")]);
let schemas = vec![schema_a2, schema_b, schema_a]
.into_iter()
.map(Arc::new)
.collect::<Vec<_>>();
let schemas = schemas.iter().collect::<Vec<_>>();
let normalized_expr = normalize_col_with_schemas(expr, &schemas, &[]).unwrap();
assert_eq!(normalized_expr, col("tableA2.a") + col("tableB.b"));
}
// A column that is found in none of the schemas yields a schema error.
#[test]
fn normalize_cols_non_exist() {
let expr = col("a") + col("b");
let schema_a = make_schema_with_empty_metadata(vec![make_field("tableA", "a")]);
let schemas = vec![schema_a].into_iter().map(Arc::new).collect::<Vec<_>>();
let schemas = schemas.iter().collect::<Vec<_>>();
let error = normalize_col_with_schemas(expr, &schemas, &[])
.unwrap_err()
.to_string();
assert_eq!(
error,
"Schema error: No field named 'b'. Valid fields are 'tableA'.'a'."
);
}
// Qualified columns lose their relation qualifier.
#[test]
fn unnormalize_cols() {
let expr = col("tableA.a") + col("tableB.b");
let unnormalized_expr = unnormalize_col(expr);
assert_eq!(unnormalized_expr, col("a") + col("b"));
}
// Builds a `DFSchema` from `fields` with an empty metadata map.
fn make_schema_with_empty_metadata(fields: Vec<DFField>) -> DFSchema {
DFSchema::new_with_metadata(fields, HashMap::new()).unwrap()
}
// Builds an `Int8` field named `column`, qualified by `relation`.
fn make_field(relation: &str, column: &str) -> DFField {
DFField::new(Some(relation), column, DataType::Int8, false)
}
// With `Continue` returned everywhere, a node is pre-visited before its
// children and mutated only after all of its children were mutated.
#[test]
fn rewriter_visit() {
let mut rewriter = RecordingRewriter::default();
col("state").eq(lit("CO")).rewrite(&mut rewriter).unwrap();
assert_eq!(
rewriter.v,
vec![
"Previsited state = Utf8(\"CO\")",
"Previsited state",
"Mutated state",
"Previsited Utf8(\"CO\")",
"Mutated Utf8(\"CO\")",
"Mutated state = Utf8(\"CO\")"
]
)
}
}