use std::cmp::Ordering;
use std::collections::HashMap;
use std::sync::Arc;
use crate::analyzer::AnalyzerRule;
use arrow::datatypes::DataType;
use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::{
Column, DFSchema, Result, ScalarValue, internal_datafusion_err, plan_err,
};
use datafusion_expr::expr::{AggregateFunction, Alias};
use datafusion_expr::logical_plan::LogicalPlan;
use datafusion_expr::utils::grouping_set_to_exprlist;
use datafusion_expr::{
Aggregate, Expr, Projection, bitwise_and, bitwise_or, bitwise_shift_left,
bitwise_shift_right, cast,
};
use itertools::Itertools;
/// Analyzer rule that replaces `grouping(...)` aggregate calls with expressions computed
/// from the aggregate's internal grouping id column.
#[derive(Default, Debug)]
pub struct ResolveGroupingFunction;

impl ResolveGroupingFunction {
    /// Creates the rule. Equivalent to [`ResolveGroupingFunction::default`].
    pub fn new() -> Self {
        Self
    }
}
impl AnalyzerRule for ResolveGroupingFunction {
    /// Name under which this rule is registered with the analyzer.
    fn name(&self) -> &str {
        "resolve_grouping_function"
    }

    /// Rewrites the plan bottom-up (`transform_up`), resolving `grouping(...)` calls in
    /// every `Aggregate` node. The configuration is not consulted.
    fn analyze(&self, plan: LogicalPlan, _config: &ConfigOptions) -> Result<LogicalPlan> {
        let rewritten = plan.transform_up(analyze_internal);
        rewritten.data()
    }
}
/// Maps each grouping expression to its bit position in the grouping id.
///
/// Positions are assigned from the end of the list returned by `grouping_set_to_exprlist`.
/// The last expression gets bit 0, the one before it gets bit 1, and so on.
fn group_expr_to_bitmap_index(group_expr: &[Expr]) -> Result<HashMap<&Expr, usize>> {
    let flattened = grouping_set_to_exprlist(group_expr)?;
    Ok(flattened.into_iter().rev().zip(0..).collect())
}
#[allow(clippy::allow_attributes, clippy::mutable_key_type)] fn replace_grouping_exprs(
input: Arc<LogicalPlan>,
schema: &DFSchema,
group_expr: Vec<Expr>,
aggr_expr: Vec<Expr>,
) -> Result<LogicalPlan> {
let is_grouping_set = matches!(group_expr.as_slice(), [Expr::GroupingSet(_)]);
let group_expr_to_bitmap_index = group_expr_to_bitmap_index(&group_expr)?;
let columns = schema.columns();
let mut new_agg_expr = Vec::new();
let mut projection_exprs = Vec::new();
let grouping_id_len = if is_grouping_set { 1 } else { 0 };
let group_expr_len = columns.len() - aggr_expr.len() - grouping_id_len;
projection_exprs.extend(
columns
.iter()
.take(group_expr_len)
.map(|column| Expr::Column(column.clone())),
);
for (expr, column) in aggr_expr
.into_iter()
.zip(columns.into_iter().skip(group_expr_len + grouping_id_len))
{
let grouping_id_type = is_grouping_set
.then(|| {
schema
.field_with_name(None, Aggregate::INTERNAL_GROUPING_ID)
.map(|f| f.data_type().clone())
})
.transpose()?;
match expr {
Expr::AggregateFunction(ref function) if is_grouping_function(&expr) => {
let grouping_expr = grouping_function_on_id(
function,
&group_expr_to_bitmap_index,
grouping_id_type,
)?;
projection_exprs.push(Expr::Alias(Alias::new(
grouping_expr,
column.relation,
column.name,
)));
}
Expr::Alias(Alias {
ref relation,
ref name,
..
}) if is_grouping_function(&expr) => {
let function = unwrap_alias_to_grouping_function(&expr)?;
let grouping_expr = grouping_function_on_id(
function,
&group_expr_to_bitmap_index,
grouping_id_type,
)?;
projection_exprs.push(Expr::Alias(Alias::new(
grouping_expr,
relation.clone(),
name.clone(),
)));
}
_ => {
projection_exprs.push(Expr::Column(column));
new_agg_expr.push(expr);
}
}
}
let new_aggregate =
LogicalPlan::Aggregate(Aggregate::try_new(input, group_expr, new_agg_expr)?);
let projection = LogicalPlan::Projection(Projection::try_new(
projection_exprs,
new_aggregate.into(),
)?);
Ok(projection)
}
/// Rewrites a single plan node.
///
/// Subqueries are processed first, recursively and bottom-up. If the node itself is an
/// `Aggregate` whose aggregate expressions call `grouping`, it is replaced by the
/// aggregate + projection pair built by `replace_grouping_exprs`. Any other node is
/// returned unchanged.
fn analyze_internal(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
    let with_subqueries =
        plan.map_subqueries(|subquery| subquery.transform_up(analyze_internal))?;
    with_subqueries.transform_data(|node| {
        let outcome = match node {
            LogicalPlan::Aggregate(Aggregate {
                input,
                group_expr,
                aggr_expr,
                schema,
                ..
            }) if contains_grouping_function(&aggr_expr) => {
                let rewritten =
                    replace_grouping_exprs(input, &schema, group_expr, aggr_expr)?;
                Transformed::yes(rewritten)
            }
            unchanged => Transformed::no(unchanged),
        };
        Ok(outcome)
    })
}
/// Strips any number of `Alias` wrappers from `expr` and returns the aggregate function
/// underneath.
///
/// # Errors
/// Returns a plan error naming the innermost non-alias expression if that expression is
/// not an aggregate function.
fn unwrap_alias_to_grouping_function(expr: &Expr) -> Result<&AggregateFunction> {
    let mut expr = expr;
    loop {
        match expr {
            Expr::Alias(alias) => expr = &*alias.expr,
            Expr::AggregateFunction(function) => return Ok(function),
            _ => {
                return plan_err!(
                    "Expected grouping aggregate function inside alias, got {expr}"
                );
            }
        }
    }
}
/// Returns `true` if `expr` is a call to the aggregate named `grouping`, possibly wrapped
/// in one or more aliases.
fn is_grouping_function(expr: &Expr) -> bool {
    match expr {
        Expr::Alias(alias) => is_grouping_function(&alias.expr),
        Expr::AggregateFunction(aggregate) => aggregate.func.name() == "grouping",
        _ => false,
    }
}
fn contains_grouping_function(exprs: &[Expr]) -> bool {
exprs.iter().any(is_grouping_function)
}
/// Checks that every argument of the `grouping` call is one of the grouping expressions.
///
/// `group_by_expr` maps each grouping expression to its bit index in the grouping id.
///
/// # Errors
/// Returns a plan error naming the first argument that is not a grouping expression. The
/// message also lists all grouping expressions.
// clippy flags `HashMap<&Expr, _>` as a mutable key type; the keys are never mutated here.
#[allow(clippy::allow_attributes, clippy::mutable_key_type)]
fn validate_args(
    function: &AggregateFunction,
    group_by_expr: &HashMap<&Expr, usize>,
) -> Result<()> {
    let Some(expr) = function
        .params
        .args
        .iter()
        .find(|expr| !group_by_expr.contains_key(expr))
    else {
        return Ok(());
    };
    // HashMap iteration order is unspecified. Sort by descending bit index so the listed
    // columns, and therefore the error message, are deterministic.
    let grouping_columns = group_by_expr
        .iter()
        .sorted_by(|(_, a), (_, b)| b.cmp(a))
        .map(|(e, _)| e.to_string())
        .join(", ");
    plan_err!(
        "Argument {} to grouping function is not in grouping columns {}",
        expr,
        grouping_columns
    )
}
#[allow(clippy::allow_attributes, clippy::mutable_key_type)] fn grouping_function_on_id(
function: &AggregateFunction,
group_by_expr: &HashMap<&Expr, usize>,
grouping_id_type: Option<DataType>,
) -> Result<Expr> {
validate_args(function, group_by_expr)?;
let args = &function.params.args;
let Some(grouping_id_type) = grouping_id_type else {
return Ok(Expr::Literal(ScalarValue::from(0i32), None));
};
let literal = |value: usize| match &grouping_id_type {
DataType::UInt8 => Expr::Literal(ScalarValue::from(value as u8), None),
DataType::UInt16 => Expr::Literal(ScalarValue::from(value as u16), None),
DataType::UInt32 => Expr::Literal(ScalarValue::from(value as u32), None),
DataType::UInt64 => Expr::Literal(ScalarValue::from(value as u64), None),
other => panic!("unexpected __grouping_id type: {other}"),
};
let grouping_id_column = Expr::Column(Column::from(Aggregate::INTERNAL_GROUPING_ID));
if args.len() == group_by_expr.len()
&& args
.iter()
.rev()
.enumerate()
.all(|(idx, expr)| group_by_expr.get(expr) == Some(&idx))
{
let n = group_by_expr.len();
let semantic_mask: u64 = 1u64.checked_shl(n as u32).map_or(u64::MAX, |m| m - 1);
let masked_id =
bitwise_and(grouping_id_column.clone(), literal(semantic_mask as usize));
return Ok(cast(masked_id, DataType::Int32));
}
args.iter()
.rev()
.enumerate()
.map(|(arg_idx, expr)| {
group_by_expr.get(expr).map(|group_by_idx| {
let group_by_bit =
bitwise_and(grouping_id_column.clone(), literal(1 << group_by_idx));
match group_by_idx.cmp(&arg_idx) {
Ordering::Less => {
bitwise_shift_left(group_by_bit, literal(arg_idx - group_by_idx))
}
Ordering::Greater => {
bitwise_shift_right(group_by_bit, literal(group_by_idx - arg_idx))
}
Ordering::Equal => group_by_bit,
}
})
})
.collect::<Option<Vec<_>>>()
.and_then(|bit_exprs| {
bit_exprs
.into_iter()
.reduce(bitwise_or)
.map(|expr| cast(expr, DataType::Int32))
})
.ok_or_else(|| {
internal_datafusion_err!("Grouping sets should contains at least one element")
})
}