use crate::expr::{AggregateFunction, Cast, Sort, WindowFunction};
use crate::{
expr::{BinaryExpr, GroupingSet, TryCast},
Between, Expr, GetIndexedField, Like,
};
use datafusion_common::Result;
/// Result of [`ExpressionVisitor::pre_visit`]: tells the traversal whether to
/// descend into the children of the node that was just visited.
///
/// Visitors are passed by value, so each variant hands the visitor back to the
/// traversal. The type carries no trait bound on `V`; code that needs `V` to be
/// a visitor states that bound where it is used.
pub enum Recursion<V> {
    /// Visit the children of the current node next.
    Continue(V),
    /// Do not visit the children of the current node. In the `Expr`
    /// implementation of `accept`, the node's `post_visit` is skipped as well.
    Stop(V),
}
/// A visitor that is driven over an expression tree by [`ExprVisitable::accept`].
///
/// Every callback takes the visitor by value and hands it back on success, so
/// any state the visitor accumulates is threaded through the walk.
pub trait ExpressionVisitor<E: ExprVisitable = Expr>: Sized {
    /// Hook run on a node after all of its children have been visited.
    ///
    /// The default implementation does nothing and returns the visitor as is.
    fn post_visit(self, _node: &E) -> Result<Self> {
        Ok(self)
    }

    /// Hook run on a node before any of its children.
    ///
    /// Return [`Recursion::Continue`] to descend into the node's children, or
    /// [`Recursion::Stop`] to skip them. The `Expr` implementation of `accept`
    /// also skips `post_visit` for a node that returned `Stop`.
    fn pre_visit(self, node: &E) -> Result<Recursion<Self>>
    where
        Self: ExpressionVisitor;
}
/// Types whose values can be walked by an [`ExpressionVisitor`].
pub trait ExprVisitable: Sized {
    /// Walks `self` with `walker` and returns the walker in its final state.
    fn accept<V>(&self, walker: V) -> Result<V>
    where
        V: ExpressionVisitor<Self>;
}
impl ExprVisitable for Expr {
    /// Walks this expression tree depth-first, driving `visitor`.
    ///
    /// For every node, `pre_visit` runs first. If it returns
    /// [`Recursion::Stop`], the node's children *and* its `post_visit` are
    /// skipped, and the walk resumes with the node's next sibling. Otherwise
    /// the children are visited in a fixed order (e.g. `left` before `right`,
    /// aggregate arguments before the `filter` expression), and `post_visit`
    /// runs last.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by `pre_visit` or `post_visit`. The
    /// walk stops as soon as an error occurs.
    fn accept<V: ExpressionVisitor>(&self, visitor: V) -> Result<V> {
        let visitor = match visitor.pre_visit(self)? {
            Recursion::Continue(visitor) => visitor,
            // Pruned: neither the children nor `post_visit` of this node run.
            Recursion::Stop(visitor) => return Ok(visitor),
        };

        let visitor = match self {
            // Variants with exactly one child expression.
            Expr::Alias(expr, _)
            | Expr::Not(expr)
            | Expr::IsNotNull(expr)
            | Expr::IsTrue(expr)
            | Expr::IsFalse(expr)
            | Expr::IsUnknown(expr)
            | Expr::IsNotTrue(expr)
            | Expr::IsNotFalse(expr)
            | Expr::IsNotUnknown(expr)
            | Expr::IsNull(expr)
            | Expr::Negative(expr)
            | Expr::Cast(Cast { expr, .. })
            | Expr::TryCast(TryCast { expr, .. })
            | Expr::Sort(Sort { expr, .. })
            | Expr::GetIndexedField(GetIndexedField { expr, .. })
            | Expr::InSubquery { expr, .. } => expr.accept(visitor),
            Expr::GroupingSet(GroupingSet::Rollup(exprs))
            | Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs
                .iter()
                .try_fold(visitor, |visitor, e| e.accept(visitor)),
            Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => lists_of_exprs
                .iter()
                .flat_map(|exprs| exprs.iter())
                .try_fold(visitor, |visitor, e| e.accept(visitor)),
            // No child expressions are visited for these variants.
            Expr::Column(_)
            | Expr::ScalarVariable(_, _)
            | Expr::Literal(_)
            | Expr::Exists { .. }
            | Expr::ScalarSubquery(_)
            | Expr::Wildcard
            | Expr::QualifiedWildcard { .. }
            | Expr::Placeholder { .. } => Ok(visitor),
            Expr::BinaryExpr(BinaryExpr { left, right, .. }) => {
                let visitor = left.accept(visitor)?;
                right.accept(visitor)
            }
            Expr::Like(Like { expr, pattern, .. })
            | Expr::ILike(Like { expr, pattern, .. })
            | Expr::SimilarTo(Like { expr, pattern, .. }) => {
                let visitor = expr.accept(visitor)?;
                pattern.accept(visitor)
            }
            Expr::Between(Between {
                expr, low, high, ..
            }) => {
                let visitor = expr.accept(visitor)?;
                let visitor = low.accept(visitor)?;
                high.accept(visitor)
            }
            Expr::Case(case) => {
                // Optional operand, then each WHEN/THEN pair, then optional ELSE.
                let visitor = match &case.expr {
                    Some(expr) => expr.accept(visitor)?,
                    None => visitor,
                };
                let visitor = case.when_then_expr.iter().try_fold(
                    visitor,
                    |visitor, (when, then)| {
                        let visitor = when.accept(visitor)?;
                        then.accept(visitor)
                    },
                )?;
                match &case.else_expr {
                    Some(else_expr) => else_expr.accept(visitor),
                    None => Ok(visitor),
                }
            }
            Expr::ScalarFunction { args, .. } | Expr::ScalarUDF { args, .. } => args
                .iter()
                .try_fold(visitor, |visitor, arg| arg.accept(visitor)),
            // The optional `filter` is visited after the arguments. Chaining it
            // onto the argument iterator avoids deep-cloning `args` and the
            // filter expression on every visit.
            Expr::AggregateFunction(AggregateFunction { args, filter, .. })
            | Expr::AggregateUDF { args, filter, .. } => args
                .iter()
                .chain(filter.as_deref())
                .try_fold(visitor, |visitor, arg| arg.accept(visitor)),
            Expr::WindowFunction(WindowFunction {
                args,
                partition_by,
                order_by,
                ..
            }) => args
                .iter()
                .chain(partition_by.iter())
                .chain(order_by.iter())
                .try_fold(visitor, |visitor, arg| arg.accept(visitor)),
            Expr::InList { expr, list, .. } => {
                let visitor = expr.accept(visitor)?;
                list.iter()
                    .try_fold(visitor, |visitor, arg| arg.accept(visitor))
            }
        }?;

        visitor.post_visit(self)
    }
}
/// Wraps a closure so it can be used as an [`ExpressionVisitor`].
struct VisitorAdapter<F, E> {
    /// Closure invoked on each visited node.
    f: F,
    /// Error returned by `f`, if any. It is stored here rather than propagated
    /// because `E` is an arbitrary caller-chosen error type.
    err: std::result::Result<(), E>,
}
impl<F, E> ExpressionVisitor for VisitorAdapter<F, E>
where
    F: FnMut(&Expr) -> Result<(), E>,
{
    /// Runs `f` on `expr` before its children are visited.
    ///
    /// `Recursion::Stop` only prunes the current subtree. The `Expr`
    /// traversal still moves on to sibling subtrees afterwards. So once an
    /// error has been recorded, every remaining node is pruned without calling
    /// `f`. This keeps the *first* error and guarantees `f` is never called
    /// again after it fails.
    fn pre_visit(mut self, expr: &Expr) -> Result<Recursion<Self>> {
        if self.err.is_err() {
            // An earlier node already failed: don't call `f`, don't overwrite.
            return Ok(Recursion::Stop(self));
        }
        if let Err(e) = (self.f)(expr) {
            // Save the error for the caller; it may not be a DataFusionError.
            self.err = Err(e);
            return Ok(Recursion::Stop(self));
        }
        Ok(Recursion::Continue(self))
    }
}
/// Calls `f` on `expr` and its descendants in pre-order (each node before its
/// children).
///
/// This is a convenience wrapper that lets a plain closure act as an
/// [`ExpressionVisitor`]. The closure may use any error type `E`; it does not
/// have to be a `DataFusionError`.
///
/// # Errors
///
/// If `f` returns `Err` for a node, that node's children are not passed to
/// `f`, and an error returned by `f` is returned from this function.
pub fn inspect_expr_pre<F, E>(expr: &Expr, f: F) -> Result<(), E>
where
    F: FnMut(&Expr) -> Result<(), E>,
{
    let adapter = expr.accept(VisitorAdapter { f, err: Ok(()) }).expect(
        "VisitorAdapter's pre_visit/post_visit always return Ok, so the traversal cannot fail",
    );
    adapter.err
}