use datafusion_common::HashMap;
use datafusion_expr::expr::AggregateFunctionParams;
use datafusion_expr::{BinaryExpr, Expr};
use datafusion_expr_common::operator::Operator;
/// Minimum number of rewrite candidates that must share the same argument
/// expression (the operand combined with a literal) before
/// `rewrite_multiple_linear_aggregates` attempts to rewrite any of them.
const DUPLICATE_THRESHOLD: usize = 2;
/// Rewrites aggregates that apply a literal operation to a shared argument.
///
/// An entry of `agg_expr` is a *candidate* when all of the following hold:
/// - it is an `Expr::AggregateFunction`;
/// - its parameters pass `candidate_linear_param`, meaning a single
///   non-volatile argument and no `DISTINCT`, `FILTER`, `ORDER BY` or null
///   treatment;
/// - that argument has the shape `<expr> <op> <literal>` or
///   `<literal> <op> <expr>`.
///
/// Candidates are grouped by `<expr>`. For each candidate whose group has at
/// least `DUPLICATE_THRESHOLD` members, the aggregate's UDF is asked for a
/// replacement through `simplify_expr_op_literal`. Every replacement it
/// returns overwrites the original entry in place. The entry is aliased back
/// to its original output name (via `alias_if_changed`) so the output schema
/// is preserved.
///
/// Returns `Ok(true)` if at least one entry was replaced. Returns `Ok(false)`
/// otherwise, in which case `agg_expr` is left untouched.
///
/// # Errors
///
/// Propagates any error from `simplify_expr_op_literal`, `name_for_alias` or
/// `alias_if_changed`.
pub(super) fn rewrite_multiple_linear_aggregates(
    agg_expr: &mut [Expr],
) -> datafusion_common::Result<bool> {
    // Extract every candidate exactly once, as
    // (position, aggregate, `expr op literal` split).
    let candidates: Vec<_> = agg_expr
        .iter()
        .enumerate()
        .filter_map(|(idx, agg)| {
            let Expr::AggregateFunction(agg_function) = agg else {
                return None;
            };
            let arg = candidate_linear_param(&agg_function.params)?;
            let expr_literal = ExprLiteral::try_new(arg)?;
            Some((idx, agg_function, expr_literal))
        })
        .collect();

    // Count how many candidates share each argument expression.
    let mut common_args = HashMap::new();
    for (_, _, expr_literal) in &candidates {
        *common_args.entry(expr_literal.expr()).or_insert(0) += 1;
    }

    // Ask each UDF for a replacement, but only for arguments that are shared
    // by enough candidates.
    let mut new_aggs = vec![];
    for (idx, agg_function, expr_literal) in candidates {
        let occurrences = common_args
            .get(expr_literal.expr())
            .copied()
            .unwrap_or(0);
        if occurrences < DUPLICATE_THRESHOLD {
            continue;
        }
        if let Some(new_agg_function) = agg_function.func.simplify_expr_op_literal(
            agg_function,
            expr_literal.expr(),
            expr_literal.op(),
            expr_literal.lit(),
            expr_literal.arg_is_left(),
        )? {
            new_aggs.push((idx, new_agg_function));
        }
    }
    if new_aggs.is_empty() {
        return Ok(false);
    }

    // `common_args` holds keys borrowed from `agg_expr`. Drop it explicitly so
    // that borrow has ended before `agg_expr` is mutated below. This is
    // presumably required because the map's destructor would otherwise count
    // as a use of the borrowed keys at the end of the scope.
    drop(common_args);

    for (idx, new_agg) in new_aggs {
        let orig_name = agg_expr[idx].name_for_alias()?;
        agg_expr[idx] = new_agg.alias_if_changed(orig_name)?;
    }
    Ok(true)
}
/// Returns the sole argument of an aggregate call when that call is eligible
/// for the linear-aggregate rewrite, and `None` otherwise.
///
/// A call is eligible when all of the following hold:
/// - it is not `DISTINCT`;
/// - it has no `FILTER` clause;
/// - its `ORDER BY` list is empty;
/// - it has no null treatment;
/// - it takes exactly one argument, and that argument is not volatile.
fn candidate_linear_param(params: &AggregateFunctionParams) -> Option<&Expr> {
    // Exhaustive destructuring (no `..`): adding a field to
    // `AggregateFunctionParams` becomes a compile error here, forcing this
    // eligibility check to be revisited.
    let AggregateFunctionParams {
        args,
        distinct,
        filter,
        order_by,
        null_treatment,
    } = params;

    let has_modifiers =
        *distinct || filter.is_some() || !order_by.is_empty() || null_treatment.is_some();
    if has_modifiers {
        return None;
    }

    match &args[..] {
        [only_arg] if !only_arg.is_volatile() => Some(only_arg),
        _ => None,
    }
}
/// A binary expression with one literal operand, split into the literal, the
/// operator, and the other operand (`arg`).
///
/// The variant records which side the literal was on, so the original
/// orientation (`arg op lit` vs. `lit op arg`) can be recovered.
#[derive(Debug, Clone)]
pub enum ExprLiteral<'a> {
    /// `arg <op> lit`: the literal is the right-hand operand.
    ArgOpLit {
        arg: &'a Expr,
        op: Operator,
        lit: &'a Expr,
    },
    /// `lit <op> arg`: the literal is the left-hand operand.
    LitOpArg {
        lit: &'a Expr,
        op: Operator,
        arg: &'a Expr,
    },
}

impl<'a> ExprLiteral<'a> {
    /// Splits `expr` if it is a binary expression with a literal operand.
    ///
    /// Any operator is accepted. The left operand is tested first, so when
    /// both operands are literals the result is `LitOpArg`, with the
    /// right-hand literal as `arg`. Returns `None` for non-binary expressions
    /// and for binary expressions where neither operand is a literal.
    fn try_new(expr: &'a Expr) -> Option<Self> {
        let Expr::BinaryExpr(BinaryExpr { left, op, right }) = expr else {
            return None;
        };
        let op = *op;
        let lhs: &'a Expr = left;
        let rhs: &'a Expr = right;
        let is_literal = |e: &Expr| matches!(e, Expr::Literal(..));

        if is_literal(lhs) {
            Some(Self::LitOpArg {
                lit: lhs,
                op,
                arg: rhs,
            })
        } else if is_literal(rhs) {
            Some(Self::ArgOpLit {
                arg: lhs,
                op,
                lit: rhs,
            })
        } else {
            None
        }
    }

    /// The operand paired with the literal.
    fn expr(&self) -> &'a Expr {
        match *self {
            Self::ArgOpLit { arg, .. } | Self::LitOpArg { arg, .. } => arg,
        }
    }

    /// The literal operand.
    fn lit(&self) -> &'a Expr {
        match *self {
            Self::ArgOpLit { lit, .. } | Self::LitOpArg { lit, .. } => lit,
        }
    }

    /// The binary operator joining the two operands.
    fn op(&self) -> Operator {
        match *self {
            Self::ArgOpLit { op, .. } | Self::LitOpArg { op, .. } => op,
        }
    }

    /// `true` when `arg` was the left-hand operand (i.e. `arg op lit`).
    fn arg_is_left(&self) -> bool {
        match self {
            Self::ArgOpLit { .. } => true,
            Self::LitOpArg { .. } => false,
        }
    }
}