Skip to main content

datafusion_optimizer/analyzer/
resolve_grouping_function.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! Analyzer rule that replaces calls to the `grouping` aggregate function
//! with expressions derived from the internal grouping id.
20
21use std::cmp::Ordering;
22use std::collections::HashMap;
23use std::sync::Arc;
24
25use crate::analyzer::AnalyzerRule;
26
27use arrow::datatypes::DataType;
28use datafusion_common::config::ConfigOptions;
29use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
30use datafusion_common::{
31    Column, DFSchema, Result, ScalarValue, internal_datafusion_err, plan_err,
32};
33use datafusion_expr::expr::{AggregateFunction, Alias};
34use datafusion_expr::logical_plan::LogicalPlan;
35use datafusion_expr::utils::grouping_set_to_exprlist;
36use datafusion_expr::{
37    Aggregate, Expr, Projection, bitwise_and, bitwise_or, bitwise_shift_left,
38    bitwise_shift_right, cast,
39};
40use itertools::Itertools;
41
/// Analyzer rule that rewrites `grouping` aggregate function calls into
/// expressions computed from the internal grouping id.
#[derive(Debug, Default)]
pub struct ResolveGroupingFunction;

impl ResolveGroupingFunction {
    /// Creates a new instance of the rule.
    pub fn new() -> Self {
        Self
    }
}
51
52impl AnalyzerRule for ResolveGroupingFunction {
53    fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result<LogicalPlan> {
54        plan.transform_up(analyze_internal).data()
55    }
56
57    fn name(&self) -> &str {
58        "resolve_grouping_function"
59    }
60}
61
62/// Create a map from grouping expr to index in the internal grouping id.
63///
64/// For more details on how the grouping id bitmap works the documentation for
65/// [[Aggregate::INTERNAL_GROUPING_ID]]
66fn group_expr_to_bitmap_index(group_expr: &[Expr]) -> Result<HashMap<&Expr, usize>> {
67    Ok(grouping_set_to_exprlist(group_expr)?
68        .into_iter()
69        .rev()
70        .enumerate()
71        .map(|(idx, v)| (v, idx))
72        .collect::<HashMap<_, _>>())
73}
74
75#[allow(clippy::allow_attributes, clippy::mutable_key_type)] // Expr contains Arc with interior mutability but is intentionally used as hash key
76fn replace_grouping_exprs(
77    input: Arc<LogicalPlan>,
78    schema: &DFSchema,
79    group_expr: Vec<Expr>,
80    aggr_expr: Vec<Expr>,
81) -> Result<LogicalPlan> {
82    // Create HashMap from Expr to index in the grouping_id bitmap
83    let is_grouping_set = matches!(group_expr.as_slice(), [Expr::GroupingSet(_)]);
84    let group_expr_to_bitmap_index = group_expr_to_bitmap_index(&group_expr)?;
85    let columns = schema.columns();
86    let mut new_agg_expr = Vec::new();
87    let mut projection_exprs = Vec::new();
88    let grouping_id_len = if is_grouping_set { 1 } else { 0 };
89    let group_expr_len = columns.len() - aggr_expr.len() - grouping_id_len;
90    projection_exprs.extend(
91        columns
92            .iter()
93            .take(group_expr_len)
94            .map(|column| Expr::Column(column.clone())),
95    );
96    for (expr, column) in aggr_expr
97        .into_iter()
98        .zip(columns.into_iter().skip(group_expr_len + grouping_id_len))
99    {
100        let grouping_id_type = is_grouping_set
101            .then(|| {
102                schema
103                    .field_with_name(None, Aggregate::INTERNAL_GROUPING_ID)
104                    .map(|f| f.data_type().clone())
105            })
106            .transpose()?;
107        match expr {
108            Expr::AggregateFunction(ref function) if is_grouping_function(&expr) => {
109                let grouping_expr = grouping_function_on_id(
110                    function,
111                    &group_expr_to_bitmap_index,
112                    grouping_id_type,
113                )?;
114                projection_exprs.push(Expr::Alias(Alias::new(
115                    grouping_expr,
116                    column.relation,
117                    column.name,
118                )));
119            }
120            Expr::Alias(Alias {
121                ref relation,
122                ref name,
123                ..
124            }) if is_grouping_function(&expr) => {
125                let function = unwrap_alias_to_grouping_function(&expr)?;
126                let grouping_expr = grouping_function_on_id(
127                    function,
128                    &group_expr_to_bitmap_index,
129                    grouping_id_type,
130                )?;
131                // Preserve the outermost user-provided alias
132                projection_exprs.push(Expr::Alias(Alias::new(
133                    grouping_expr,
134                    relation.clone(),
135                    name.clone(),
136                )));
137            }
138            _ => {
139                projection_exprs.push(Expr::Column(column));
140                new_agg_expr.push(expr);
141            }
142        }
143    }
144    // Recreate aggregate without grouping functions
145    let new_aggregate =
146        LogicalPlan::Aggregate(Aggregate::try_new(input, group_expr, new_agg_expr)?);
147    // Create projection with grouping functions calculations
148    let projection = LogicalPlan::Projection(Projection::try_new(
149        projection_exprs,
150        new_aggregate.into(),
151    )?);
152    Ok(projection)
153}
154
155fn analyze_internal(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
156    // rewrite any subqueries in the plan first
157    let transformed_plan =
158        plan.map_subqueries(|plan| plan.transform_up(analyze_internal))?;
159
160    let transformed_plan = transformed_plan.transform_data(|plan| match plan {
161        LogicalPlan::Aggregate(Aggregate {
162            input,
163            group_expr,
164            aggr_expr,
165            schema,
166            ..
167        }) if contains_grouping_function(&aggr_expr) => Ok(Transformed::yes(
168            replace_grouping_exprs(input, schema.as_ref(), group_expr, aggr_expr)?,
169        )),
170        _ => Ok(Transformed::no(plan)),
171    })?;
172
173    Ok(transformed_plan)
174}
175
176/// Recursively unwrap `Expr::Alias` nodes to reach the inner `AggregateFunction`.
177/// Returns an error if the innermost expression is not an `AggregateFunction`,
178/// which should not happen if `is_grouping_function` returned true.
179fn unwrap_alias_to_grouping_function(expr: &Expr) -> Result<&AggregateFunction> {
180    match expr {
181        Expr::AggregateFunction(function) => Ok(function),
182        Expr::Alias(Alias { expr, .. }) => unwrap_alias_to_grouping_function(expr),
183        _ => plan_err!("Expected grouping aggregate function inside alias, got {expr}"),
184    }
185}
186
187fn is_grouping_function(expr: &Expr) -> bool {
188    // TODO: Do something better than name here should grouping be a built
189    // in expression?
190    match expr {
191        Expr::AggregateFunction(AggregateFunction { func, .. }) => {
192            func.name() == "grouping"
193        }
194        Expr::Alias(Alias { expr, .. }) => is_grouping_function(expr),
195        _ => false,
196    }
197}
198
199fn contains_grouping_function(exprs: &[Expr]) -> bool {
200    exprs.iter().any(is_grouping_function)
201}
202
203/// Validate that the arguments to the grouping function are in the group by clause.
204#[allow(clippy::allow_attributes, clippy::mutable_key_type)] // Expr contains Arc with interior mutability but is intentionally used as hash key
205fn validate_args(
206    function: &AggregateFunction,
207    group_by_expr: &HashMap<&Expr, usize>,
208) -> Result<()> {
209    let expr_not_in_group_by = function
210        .params
211        .args
212        .iter()
213        .find(|expr| !group_by_expr.contains_key(expr));
214    if let Some(expr) = expr_not_in_group_by {
215        plan_err!(
216            "Argument {} to grouping function is not in grouping columns {}",
217            expr,
218            group_by_expr.keys().map(|e| e.to_string()).join(", ")
219        )
220    } else {
221        Ok(())
222    }
223}
224
225#[allow(clippy::allow_attributes, clippy::mutable_key_type)] // Expr contains Arc with interior mutability but is intentionally used as hash key
226fn grouping_function_on_id(
227    function: &AggregateFunction,
228    group_by_expr: &HashMap<&Expr, usize>,
229    // None means not a grouping set (result is always 0).
230    grouping_id_type: Option<DataType>,
231) -> Result<Expr> {
232    validate_args(function, group_by_expr)?;
233    let args = &function.params.args;
234
235    // Postgres allows grouping function for group by without grouping sets, the result is then
236    // always 0
237    let Some(grouping_id_type) = grouping_id_type else {
238        return Ok(Expr::Literal(ScalarValue::from(0i32), None));
239    };
240
241    // Use the actual __grouping_id column type to size literals correctly. This
242    // accounts for duplicate-ordinal bits that `Aggregate::grouping_id_type`
243    // packs into the high bits of the column, which a simple count of grouping
244    // expressions would miss.
245    let literal = |value: usize| match &grouping_id_type {
246        DataType::UInt8 => Expr::Literal(ScalarValue::from(value as u8), None),
247        DataType::UInt16 => Expr::Literal(ScalarValue::from(value as u16), None),
248        DataType::UInt32 => Expr::Literal(ScalarValue::from(value as u32), None),
249        DataType::UInt64 => Expr::Literal(ScalarValue::from(value as u64), None),
250        other => panic!("unexpected __grouping_id type: {other}"),
251    };
252    let grouping_id_column = Expr::Column(Column::from(Aggregate::INTERNAL_GROUPING_ID));
253    if args.len() == group_by_expr.len()
254        && args
255            .iter()
256            .rev()
257            .enumerate()
258            .all(|(idx, expr)| group_by_expr.get(expr) == Some(&idx))
259    {
260        let n = group_by_expr.len();
261        // Mask the ordinal bits above position `n` so only the semantic bitmask is visible.
262        // checked_shl returns None when n >= 64 (all bits are semantic), mapping to u64::MAX.
263        let semantic_mask: u64 = 1u64.checked_shl(n as u32).map_or(u64::MAX, |m| m - 1);
264        let masked_id =
265            bitwise_and(grouping_id_column.clone(), literal(semantic_mask as usize));
266        return Ok(cast(masked_id, DataType::Int32));
267    }
268
269    args.iter()
270        .rev()
271        .enumerate()
272        .map(|(arg_idx, expr)| {
273            group_by_expr.get(expr).map(|group_by_idx| {
274                let group_by_bit =
275                    bitwise_and(grouping_id_column.clone(), literal(1 << group_by_idx));
276                match group_by_idx.cmp(&arg_idx) {
277                    Ordering::Less => {
278                        bitwise_shift_left(group_by_bit, literal(arg_idx - group_by_idx))
279                    }
280                    Ordering::Greater => {
281                        bitwise_shift_right(group_by_bit, literal(group_by_idx - arg_idx))
282                    }
283                    Ordering::Equal => group_by_bit,
284                }
285            })
286        })
287        .collect::<Option<Vec<_>>>()
288        .and_then(|bit_exprs| {
289            bit_exprs
290                .into_iter()
291                .reduce(bitwise_or)
292                .map(|expr| cast(expr, DataType::Int32))
293        })
294        .ok_or_else(|| {
295            internal_datafusion_err!("Grouping sets should contains at least one element")
296        })
297}