use crate::expr::{
AggregateFunction, AggregateFunctionDefinition, Alias, Between, BinaryExpr, Case,
Cast, GetIndexedField, GroupingSet, InList, InSubquery, Like, Placeholder,
ScalarFunction, ScalarFunctionDefinition, Sort, TryCast, Unnest, WindowFunction,
};
use crate::{Expr, GetFieldAccess};
use datafusion_common::tree_node::{TreeNode, VisitRecursion};
use datafusion_common::{internal_err, DataFusionError, Result};
impl TreeNode for Expr {
fn apply_children<F: FnMut(&Self) -> Result<VisitRecursion>>(
&self,
op: &mut F,
) -> Result<VisitRecursion> {
let children = match self {
Expr::Alias(Alias{expr,..})
| Expr::Not(expr)
| Expr::IsNotNull(expr)
| Expr::IsTrue(expr)
| Expr::IsFalse(expr)
| Expr::IsUnknown(expr)
| Expr::IsNotTrue(expr)
| Expr::IsNotFalse(expr)
| Expr::IsNotUnknown(expr)
| Expr::IsNull(expr)
| Expr::Negative(expr)
| Expr::Cast(Cast { expr, .. })
| Expr::TryCast(TryCast { expr, .. })
| Expr::Sort(Sort { expr, .. })
| Expr::InSubquery(InSubquery{ expr, .. }) => vec![expr.as_ref()],
Expr::GetIndexedField(GetIndexedField { expr, field }) => {
let expr = expr.as_ref();
match field {
GetFieldAccess::ListIndex {key} => vec![key.as_ref(), expr],
GetFieldAccess::ListRange {start, stop, stride} => {
vec![start.as_ref(), stop.as_ref(),stride.as_ref(), expr]
}
GetFieldAccess::NamedStructField { .. } => vec![expr],
}
}
Expr::Unnest(Unnest { exprs }) |
Expr::GroupingSet(GroupingSet::Rollup(exprs))
| Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs.iter().collect(),
Expr::ScalarFunction (ScalarFunction{ args, .. } ) => {
args.iter().collect()
}
Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => {
lists_of_exprs.iter().flatten().collect()
}
Expr::Column(_)
| Expr::OuterReferenceColumn(_, _)
| Expr::ScalarVariable(_, _)
| Expr::Literal(_)
| Expr::Exists {..}
| Expr::ScalarSubquery(_)
| Expr::Wildcard {..}
| Expr::Placeholder (_) => vec![],
Expr::BinaryExpr(BinaryExpr { left, right, .. }) => {
vec![left.as_ref(), right.as_ref()]
}
Expr::Like(Like { expr, pattern, .. })
| Expr::SimilarTo(Like { expr, pattern, .. }) => {
vec![expr.as_ref(), pattern.as_ref()]
}
Expr::Between(Between {
expr, low, high, ..
}) => vec![expr.as_ref(), low.as_ref(), high.as_ref()],
Expr::Case(case) => {
let mut expr_vec = vec![];
if let Some(expr) = case.expr.as_ref() {
expr_vec.push(expr.as_ref());
};
for (when, then) in case.when_then_expr.iter() {
expr_vec.push(when.as_ref());
expr_vec.push(then.as_ref());
}
if let Some(else_expr) = case.else_expr.as_ref() {
expr_vec.push(else_expr.as_ref());
}
expr_vec
}
Expr::AggregateFunction(AggregateFunction { args, filter, order_by, .. })
=> {
let mut expr_vec = args.iter().collect::<Vec<_>>();
if let Some(f) = filter {
expr_vec.push(f.as_ref());
}
if let Some(order_by) = order_by {
expr_vec.extend(order_by);
}
expr_vec
}
Expr::WindowFunction(WindowFunction {
args,
partition_by,
order_by,
..
}) => {
let mut expr_vec = args.iter().collect::<Vec<_>>();
expr_vec.extend(partition_by);
expr_vec.extend(order_by);
expr_vec
}
Expr::InList(InList { expr, list, .. }) => {
let mut expr_vec = vec![expr.as_ref()];
expr_vec.extend(list);
expr_vec
}
};
for child in children {
match op(child)? {
VisitRecursion::Continue => {}
VisitRecursion::Skip => return Ok(VisitRecursion::Continue),
VisitRecursion::Stop => return Ok(VisitRecursion::Stop),
}
}
Ok(VisitRecursion::Continue)
}
fn map_children<F: FnMut(Self) -> Result<Self>>(
self,
mut transform: F,
) -> Result<Self> {
Ok(match self {
Expr::Column(_)
| Expr::Wildcard { .. }
| Expr::Placeholder(Placeholder { .. })
| Expr::OuterReferenceColumn(_, _)
| Expr::Exists { .. }
| Expr::ScalarSubquery(_)
| Expr::ScalarVariable(_, _)
| Expr::Unnest(_)
| Expr::Literal(_) => self,
Expr::Alias(Alias {
expr,
relation,
name,
}) => Expr::Alias(Alias::new(transform(*expr)?, relation, name)),
Expr::InSubquery(InSubquery {
expr,
subquery,
negated,
}) => Expr::InSubquery(InSubquery::new(
transform_boxed(expr, &mut transform)?,
subquery,
negated,
)),
Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
Expr::BinaryExpr(BinaryExpr::new(
transform_boxed(left, &mut transform)?,
op,
transform_boxed(right, &mut transform)?,
))
}
Expr::Like(Like {
negated,
expr,
pattern,
escape_char,
case_insensitive,
}) => Expr::Like(Like::new(
negated,
transform_boxed(expr, &mut transform)?,
transform_boxed(pattern, &mut transform)?,
escape_char,
case_insensitive,
)),
Expr::SimilarTo(Like {
negated,
expr,
pattern,
escape_char,
case_insensitive,
}) => Expr::SimilarTo(Like::new(
negated,
transform_boxed(expr, &mut transform)?,
transform_boxed(pattern, &mut transform)?,
escape_char,
case_insensitive,
)),
Expr::Not(expr) => Expr::Not(transform_boxed(expr, &mut transform)?),
Expr::IsNotNull(expr) => {
Expr::IsNotNull(transform_boxed(expr, &mut transform)?)
}
Expr::IsNull(expr) => Expr::IsNull(transform_boxed(expr, &mut transform)?),
Expr::IsTrue(expr) => Expr::IsTrue(transform_boxed(expr, &mut transform)?),
Expr::IsFalse(expr) => Expr::IsFalse(transform_boxed(expr, &mut transform)?),
Expr::IsUnknown(expr) => {
Expr::IsUnknown(transform_boxed(expr, &mut transform)?)
}
Expr::IsNotTrue(expr) => {
Expr::IsNotTrue(transform_boxed(expr, &mut transform)?)
}
Expr::IsNotFalse(expr) => {
Expr::IsNotFalse(transform_boxed(expr, &mut transform)?)
}
Expr::IsNotUnknown(expr) => {
Expr::IsNotUnknown(transform_boxed(expr, &mut transform)?)
}
Expr::Negative(expr) => {
Expr::Negative(transform_boxed(expr, &mut transform)?)
}
Expr::Between(Between {
expr,
negated,
low,
high,
}) => Expr::Between(Between::new(
transform_boxed(expr, &mut transform)?,
negated,
transform_boxed(low, &mut transform)?,
transform_boxed(high, &mut transform)?,
)),
Expr::Case(case) => {
let expr = transform_option_box(case.expr, &mut transform)?;
let when_then_expr = case
.when_then_expr
.into_iter()
.map(|(when, then)| {
Ok((
transform_boxed(when, &mut transform)?,
transform_boxed(then, &mut transform)?,
))
})
.collect::<Result<Vec<_>>>()?;
let else_expr = transform_option_box(case.else_expr, &mut transform)?;
Expr::Case(Case::new(expr, when_then_expr, else_expr))
}
Expr::Cast(Cast { expr, data_type }) => {
Expr::Cast(Cast::new(transform_boxed(expr, &mut transform)?, data_type))
}
Expr::TryCast(TryCast { expr, data_type }) => Expr::TryCast(TryCast::new(
transform_boxed(expr, &mut transform)?,
data_type,
)),
Expr::Sort(Sort {
expr,
asc,
nulls_first,
}) => Expr::Sort(Sort::new(
transform_boxed(expr, &mut transform)?,
asc,
nulls_first,
)),
Expr::ScalarFunction(ScalarFunction { func_def, args }) => match func_def {
ScalarFunctionDefinition::BuiltIn(fun) => Expr::ScalarFunction(
ScalarFunction::new(fun, transform_vec(args, &mut transform)?),
),
ScalarFunctionDefinition::UDF(fun) => Expr::ScalarFunction(
ScalarFunction::new_udf(fun, transform_vec(args, &mut transform)?),
),
ScalarFunctionDefinition::Name(_) => {
return internal_err!("Function `Expr` with name should be resolved.")
}
},
Expr::WindowFunction(WindowFunction {
args,
fun,
partition_by,
order_by,
window_frame,
}) => Expr::WindowFunction(WindowFunction::new(
fun,
transform_vec(args, &mut transform)?,
transform_vec(partition_by, &mut transform)?,
transform_vec(order_by, &mut transform)?,
window_frame,
)),
Expr::AggregateFunction(AggregateFunction {
args,
func_def,
distinct,
filter,
order_by,
}) => match func_def {
AggregateFunctionDefinition::BuiltIn(fun) => {
Expr::AggregateFunction(AggregateFunction::new(
fun,
transform_vec(args, &mut transform)?,
distinct,
transform_option_box(filter, &mut transform)?,
transform_option_vec(order_by, &mut transform)?,
))
}
AggregateFunctionDefinition::UDF(fun) => {
let order_by = order_by
.map(|order_by| transform_vec(order_by, &mut transform))
.transpose()?;
Expr::AggregateFunction(AggregateFunction::new_udf(
fun,
transform_vec(args, &mut transform)?,
false,
transform_option_box(filter, &mut transform)?,
transform_option_vec(order_by, &mut transform)?,
))
}
AggregateFunctionDefinition::Name(_) => {
return internal_err!("Function `Expr` with name should be resolved.")
}
},
Expr::GroupingSet(grouping_set) => match grouping_set {
GroupingSet::Rollup(exprs) => Expr::GroupingSet(GroupingSet::Rollup(
transform_vec(exprs, &mut transform)?,
)),
GroupingSet::Cube(exprs) => Expr::GroupingSet(GroupingSet::Cube(
transform_vec(exprs, &mut transform)?,
)),
GroupingSet::GroupingSets(lists_of_exprs) => {
Expr::GroupingSet(GroupingSet::GroupingSets(
lists_of_exprs
.into_iter()
.map(|exprs| transform_vec(exprs, &mut transform))
.collect::<Result<Vec<_>>>()?,
))
}
},
Expr::InList(InList {
expr,
list,
negated,
}) => Expr::InList(InList::new(
transform_boxed(expr, &mut transform)?,
transform_vec(list, &mut transform)?,
negated,
)),
Expr::GetIndexedField(GetIndexedField { expr, field }) => {
Expr::GetIndexedField(GetIndexedField::new(
transform_boxed(expr, &mut transform)?,
field,
))
}
})
}
}
fn transform_boxed<F: FnMut(Expr) -> Result<Expr>>(
boxed_expr: Box<Expr>,
transform: &mut F,
) -> Result<Box<Expr>> {
transform(*boxed_expr).map(Box::new)
}
fn transform_option_box<F: FnMut(Expr) -> Result<Expr>>(
option_box: Option<Box<Expr>>,
transform: &mut F,
) -> Result<Option<Box<Expr>>> {
option_box
.map(|expr| transform_boxed(expr, transform))
.transpose()
}
fn transform_option_vec<F: FnMut(Expr) -> Result<Expr>>(
option_box: Option<Vec<Expr>>,
transform: &mut F,
) -> Result<Option<Vec<Expr>>> {
option_box
.map(|exprs| transform_vec(exprs, transform))
.transpose()
}
fn transform_vec<F: FnMut(Expr) -> Result<Expr>>(
v: Vec<Expr>,
transform: &mut F,
) -> Result<Vec<Expr>> {
v.into_iter().map(transform).collect()
}