datafusion_optimizer/analyzer/
resolve_grouping_function.rs1use std::cmp::Ordering;
22use std::collections::HashMap;
23use std::sync::Arc;
24
25use crate::analyzer::AnalyzerRule;
26
27use arrow::datatypes::DataType;
28use datafusion_common::config::ConfigOptions;
29use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
30use datafusion_common::{
31 Column, DFSchema, Result, ScalarValue, internal_datafusion_err, plan_err,
32};
33use datafusion_expr::expr::{AggregateFunction, Alias};
34use datafusion_expr::logical_plan::LogicalPlan;
35use datafusion_expr::utils::grouping_set_to_exprlist;
36use datafusion_expr::{
37 Aggregate, Expr, Projection, bitwise_and, bitwise_or, bitwise_shift_left,
38 bitwise_shift_right, cast,
39};
40use itertools::Itertools;
41
/// Analyzer rule that takes `grouping(...)` calls out of aggregates and
/// computes them instead in a projection over the internal grouping id.
#[derive(Debug, Default)]
pub struct ResolveGroupingFunction;
45
46impl ResolveGroupingFunction {
47 pub fn new() -> Self {
48 Self {}
49 }
50}
51
52impl AnalyzerRule for ResolveGroupingFunction {
53 fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result<LogicalPlan> {
54 plan.transform_up(analyze_internal).data()
55 }
56
57 fn name(&self) -> &str {
58 "resolve_grouping_function"
59 }
60}
61
62fn group_expr_to_bitmap_index(group_expr: &[Expr]) -> Result<HashMap<&Expr, usize>> {
67 Ok(grouping_set_to_exprlist(group_expr)?
68 .into_iter()
69 .rev()
70 .enumerate()
71 .map(|(idx, v)| (v, idx))
72 .collect::<HashMap<_, _>>())
73}
74
75#[allow(clippy::allow_attributes, clippy::mutable_key_type)] fn replace_grouping_exprs(
77 input: Arc<LogicalPlan>,
78 schema: &DFSchema,
79 group_expr: Vec<Expr>,
80 aggr_expr: Vec<Expr>,
81) -> Result<LogicalPlan> {
82 let is_grouping_set = matches!(group_expr.as_slice(), [Expr::GroupingSet(_)]);
84 let group_expr_to_bitmap_index = group_expr_to_bitmap_index(&group_expr)?;
85 let columns = schema.columns();
86 let mut new_agg_expr = Vec::new();
87 let mut projection_exprs = Vec::new();
88 let grouping_id_len = if is_grouping_set { 1 } else { 0 };
89 let group_expr_len = columns.len() - aggr_expr.len() - grouping_id_len;
90 projection_exprs.extend(
91 columns
92 .iter()
93 .take(group_expr_len)
94 .map(|column| Expr::Column(column.clone())),
95 );
96 for (expr, column) in aggr_expr
97 .into_iter()
98 .zip(columns.into_iter().skip(group_expr_len + grouping_id_len))
99 {
100 let grouping_id_type = is_grouping_set
101 .then(|| {
102 schema
103 .field_with_name(None, Aggregate::INTERNAL_GROUPING_ID)
104 .map(|f| f.data_type().clone())
105 })
106 .transpose()?;
107 match expr {
108 Expr::AggregateFunction(ref function) if is_grouping_function(&expr) => {
109 let grouping_expr = grouping_function_on_id(
110 function,
111 &group_expr_to_bitmap_index,
112 grouping_id_type,
113 )?;
114 projection_exprs.push(Expr::Alias(Alias::new(
115 grouping_expr,
116 column.relation,
117 column.name,
118 )));
119 }
120 Expr::Alias(Alias {
121 ref relation,
122 ref name,
123 ..
124 }) if is_grouping_function(&expr) => {
125 let function = unwrap_alias_to_grouping_function(&expr)?;
126 let grouping_expr = grouping_function_on_id(
127 function,
128 &group_expr_to_bitmap_index,
129 grouping_id_type,
130 )?;
131 projection_exprs.push(Expr::Alias(Alias::new(
133 grouping_expr,
134 relation.clone(),
135 name.clone(),
136 )));
137 }
138 _ => {
139 projection_exprs.push(Expr::Column(column));
140 new_agg_expr.push(expr);
141 }
142 }
143 }
144 let new_aggregate =
146 LogicalPlan::Aggregate(Aggregate::try_new(input, group_expr, new_agg_expr)?);
147 let projection = LogicalPlan::Projection(Projection::try_new(
149 projection_exprs,
150 new_aggregate.into(),
151 )?);
152 Ok(projection)
153}
154
155fn analyze_internal(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
156 let transformed_plan =
158 plan.map_subqueries(|plan| plan.transform_up(analyze_internal))?;
159
160 let transformed_plan = transformed_plan.transform_data(|plan| match plan {
161 LogicalPlan::Aggregate(Aggregate {
162 input,
163 group_expr,
164 aggr_expr,
165 schema,
166 ..
167 }) if contains_grouping_function(&aggr_expr) => Ok(Transformed::yes(
168 replace_grouping_exprs(input, schema.as_ref(), group_expr, aggr_expr)?,
169 )),
170 _ => Ok(Transformed::no(plan)),
171 })?;
172
173 Ok(transformed_plan)
174}
175
176fn unwrap_alias_to_grouping_function(expr: &Expr) -> Result<&AggregateFunction> {
180 match expr {
181 Expr::AggregateFunction(function) => Ok(function),
182 Expr::Alias(Alias { expr, .. }) => unwrap_alias_to_grouping_function(expr),
183 _ => plan_err!("Expected grouping aggregate function inside alias, got {expr}"),
184 }
185}
186
187fn is_grouping_function(expr: &Expr) -> bool {
188 match expr {
191 Expr::AggregateFunction(AggregateFunction { func, .. }) => {
192 func.name() == "grouping"
193 }
194 Expr::Alias(Alias { expr, .. }) => is_grouping_function(expr),
195 _ => false,
196 }
197}
198
199fn contains_grouping_function(exprs: &[Expr]) -> bool {
200 exprs.iter().any(is_grouping_function)
201}
202
203#[allow(clippy::allow_attributes, clippy::mutable_key_type)] fn validate_args(
206 function: &AggregateFunction,
207 group_by_expr: &HashMap<&Expr, usize>,
208) -> Result<()> {
209 let expr_not_in_group_by = function
210 .params
211 .args
212 .iter()
213 .find(|expr| !group_by_expr.contains_key(expr));
214 if let Some(expr) = expr_not_in_group_by {
215 plan_err!(
216 "Argument {} to grouping function is not in grouping columns {}",
217 expr,
218 group_by_expr.keys().map(|e| e.to_string()).join(", ")
219 )
220 } else {
221 Ok(())
222 }
223}
224
225#[allow(clippy::allow_attributes, clippy::mutable_key_type)] fn grouping_function_on_id(
227 function: &AggregateFunction,
228 group_by_expr: &HashMap<&Expr, usize>,
229 grouping_id_type: Option<DataType>,
231) -> Result<Expr> {
232 validate_args(function, group_by_expr)?;
233 let args = &function.params.args;
234
235 let Some(grouping_id_type) = grouping_id_type else {
238 return Ok(Expr::Literal(ScalarValue::from(0i32), None));
239 };
240
241 let literal = |value: usize| match &grouping_id_type {
246 DataType::UInt8 => Expr::Literal(ScalarValue::from(value as u8), None),
247 DataType::UInt16 => Expr::Literal(ScalarValue::from(value as u16), None),
248 DataType::UInt32 => Expr::Literal(ScalarValue::from(value as u32), None),
249 DataType::UInt64 => Expr::Literal(ScalarValue::from(value as u64), None),
250 other => panic!("unexpected __grouping_id type: {other}"),
251 };
252 let grouping_id_column = Expr::Column(Column::from(Aggregate::INTERNAL_GROUPING_ID));
253 if args.len() == group_by_expr.len()
254 && args
255 .iter()
256 .rev()
257 .enumerate()
258 .all(|(idx, expr)| group_by_expr.get(expr) == Some(&idx))
259 {
260 let n = group_by_expr.len();
261 let semantic_mask: u64 = 1u64.checked_shl(n as u32).map_or(u64::MAX, |m| m - 1);
264 let masked_id =
265 bitwise_and(grouping_id_column.clone(), literal(semantic_mask as usize));
266 return Ok(cast(masked_id, DataType::Int32));
267 }
268
269 args.iter()
270 .rev()
271 .enumerate()
272 .map(|(arg_idx, expr)| {
273 group_by_expr.get(expr).map(|group_by_idx| {
274 let group_by_bit =
275 bitwise_and(grouping_id_column.clone(), literal(1 << group_by_idx));
276 match group_by_idx.cmp(&arg_idx) {
277 Ordering::Less => {
278 bitwise_shift_left(group_by_bit, literal(arg_idx - group_by_idx))
279 }
280 Ordering::Greater => {
281 bitwise_shift_right(group_by_bit, literal(group_by_idx - arg_idx))
282 }
283 Ordering::Equal => group_by_bit,
284 }
285 })
286 })
287 .collect::<Option<Vec<_>>>()
288 .and_then(|bit_exprs| {
289 bit_exprs
290 .into_iter()
291 .reduce(bitwise_or)
292 .map(|expr| cast(expr, DataType::Int32))
293 })
294 .ok_or_else(|| {
295 internal_datafusion_err!("Grouping sets should contains at least one element")
296 })
297}