// datafusion_optimizer/analyzer/resolve_grouping_function.rs

use std::cmp::{Ordering, Reverse};
use std::collections::HashMap;
use std::sync::Arc;

use crate::analyzer::AnalyzerRule;

use arrow::datatypes::DataType;
use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::{
    internal_datafusion_err, plan_err, Column, DFSchemaRef, Result, ScalarValue,
};
use datafusion_expr::expr::{AggregateFunction, Alias};
use datafusion_expr::logical_plan::LogicalPlan;
use datafusion_expr::utils::grouping_set_to_exprlist;
use datafusion_expr::{
    bitwise_and, bitwise_or, bitwise_shift_left, bitwise_shift_right, cast, Aggregate,
    Expr, Projection,
};
use itertools::Itertools;
41
/// Analyzer rule that replaces calls to the `grouping(...)` aggregate
/// function with expressions computed from the aggregate's internal grouping
/// id column. When the GROUP BY is not a grouping set, the replacement is the
/// constant `0`.
#[derive(Debug, Default)]
pub struct ResolveGroupingFunction;

impl ResolveGroupingFunction {
    /// Creates the rule; equivalent to [`ResolveGroupingFunction::default`].
    pub fn new() -> Self {
        Self
    }
}
51
52impl AnalyzerRule for ResolveGroupingFunction {
53 fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result<LogicalPlan> {
54 plan.transform_up(analyze_internal).data()
55 }
56
57 fn name(&self) -> &str {
58 "resolve_grouping_function"
59 }
60}
61
62fn group_expr_to_bitmap_index(group_expr: &[Expr]) -> Result<HashMap<&Expr, usize>> {
67 Ok(grouping_set_to_exprlist(group_expr)?
68 .into_iter()
69 .rev()
70 .enumerate()
71 .map(|(idx, v)| (v, idx))
72 .collect::<HashMap<_, _>>())
73}
74
75fn replace_grouping_exprs(
76 input: Arc<LogicalPlan>,
77 schema: DFSchemaRef,
78 group_expr: Vec<Expr>,
79 aggr_expr: Vec<Expr>,
80) -> Result<LogicalPlan> {
81 let is_grouping_set = matches!(group_expr.as_slice(), [Expr::GroupingSet(_)]);
83 let group_expr_to_bitmap_index = group_expr_to_bitmap_index(&group_expr)?;
84 let columns = schema.columns();
85 let mut new_agg_expr = Vec::new();
86 let mut projection_exprs = Vec::new();
87 let grouping_id_len = if is_grouping_set { 1 } else { 0 };
88 let group_expr_len = columns.len() - aggr_expr.len() - grouping_id_len;
89 projection_exprs.extend(
90 columns
91 .iter()
92 .take(group_expr_len)
93 .map(|column| Expr::Column(column.clone())),
94 );
95 for (expr, column) in aggr_expr
96 .into_iter()
97 .zip(columns.into_iter().skip(group_expr_len + grouping_id_len))
98 {
99 match expr {
100 Expr::AggregateFunction(ref function) if is_grouping_function(&expr) => {
101 let grouping_expr = grouping_function_on_id(
102 function,
103 &group_expr_to_bitmap_index,
104 is_grouping_set,
105 )?;
106 projection_exprs.push(Expr::Alias(Alias::new(
107 grouping_expr,
108 column.relation,
109 column.name,
110 )));
111 }
112 _ => {
113 projection_exprs.push(Expr::Column(column));
114 new_agg_expr.push(expr);
115 }
116 }
117 }
118 let new_aggregate =
120 LogicalPlan::Aggregate(Aggregate::try_new(input, group_expr, new_agg_expr)?);
121 let projection = LogicalPlan::Projection(Projection::try_new(
123 projection_exprs,
124 new_aggregate.into(),
125 )?);
126 Ok(projection)
127}
128
129fn analyze_internal(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
130 let transformed_plan =
132 plan.map_subqueries(|plan| plan.transform_up(analyze_internal))?;
133
134 let transformed_plan = transformed_plan.transform_data(|plan| match plan {
135 LogicalPlan::Aggregate(Aggregate {
136 input,
137 group_expr,
138 aggr_expr,
139 schema,
140 ..
141 }) if contains_grouping_function(&aggr_expr) => Ok(Transformed::yes(
142 replace_grouping_exprs(input, schema, group_expr, aggr_expr)?,
143 )),
144 _ => Ok(Transformed::no(plan)),
145 })?;
146
147 Ok(transformed_plan)
148}
149
150fn is_grouping_function(expr: &Expr) -> bool {
151 matches!(expr, Expr::AggregateFunction(AggregateFunction { ref func, .. }) if func.name() == "grouping")
154}
155
156fn contains_grouping_function(exprs: &[Expr]) -> bool {
157 exprs.iter().any(is_grouping_function)
158}
159
160fn validate_args(
162 function: &AggregateFunction,
163 group_by_expr: &HashMap<&Expr, usize>,
164) -> Result<()> {
165 let expr_not_in_group_by = function
166 .params
167 .args
168 .iter()
169 .find(|expr| !group_by_expr.contains_key(expr));
170 if let Some(expr) = expr_not_in_group_by {
171 plan_err!(
172 "Argument {} to grouping function is not in grouping columns {}",
173 expr,
174 group_by_expr.keys().map(|e| e.to_string()).join(", ")
175 )
176 } else {
177 Ok(())
178 }
179}
180
181fn grouping_function_on_id(
182 function: &AggregateFunction,
183 group_by_expr: &HashMap<&Expr, usize>,
184 is_grouping_set: bool,
185) -> Result<Expr> {
186 validate_args(function, group_by_expr)?;
187 let args = &function.params.args;
188
189 if !is_grouping_set {
192 return Ok(Expr::Literal(ScalarValue::from(0i32)));
193 }
194
195 let group_by_expr_count = group_by_expr.len();
196 let literal = |value: usize| {
197 if group_by_expr_count < 8 {
198 Expr::Literal(ScalarValue::from(value as u8))
199 } else if group_by_expr_count < 16 {
200 Expr::Literal(ScalarValue::from(value as u16))
201 } else if group_by_expr_count < 32 {
202 Expr::Literal(ScalarValue::from(value as u32))
203 } else {
204 Expr::Literal(ScalarValue::from(value as u64))
205 }
206 };
207
208 let grouping_id_column = Expr::Column(Column::from(Aggregate::INTERNAL_GROUPING_ID));
209 if args.len() == group_by_expr_count
211 && args
212 .iter()
213 .rev()
214 .enumerate()
215 .all(|(idx, expr)| group_by_expr.get(expr) == Some(&idx))
216 {
217 return Ok(cast(grouping_id_column, DataType::Int32));
218 }
219
220 args.iter()
221 .rev()
222 .enumerate()
223 .map(|(arg_idx, expr)| {
224 group_by_expr.get(expr).map(|group_by_idx| {
225 let group_by_bit =
226 bitwise_and(grouping_id_column.clone(), literal(1 << group_by_idx));
227 match group_by_idx.cmp(&arg_idx) {
228 Ordering::Less => {
229 bitwise_shift_left(group_by_bit, literal(arg_idx - group_by_idx))
230 }
231 Ordering::Greater => {
232 bitwise_shift_right(group_by_bit, literal(group_by_idx - arg_idx))
233 }
234 Ordering::Equal => group_by_bit,
235 }
236 })
237 })
238 .collect::<Option<Vec<_>>>()
239 .and_then(|bit_exprs| {
240 bit_exprs
241 .into_iter()
242 .reduce(bitwise_or)
243 .map(|expr| cast(expr, DataType::Int32))
244 })
245 .ok_or_else(|| {
246 internal_datafusion_err!("Grouping sets should contains at least one element")
247 })
248}