datafusion_optimizer/analyzer/
resolve_grouping_function.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! Analyzer rule that replaces calls to the `grouping` aggregate function
//! with expressions computed from the internal grouping id column.
20
21use std::cmp::Ordering;
22use std::collections::HashMap;
23use std::sync::Arc;
24
25use crate::analyzer::AnalyzerRule;
26
27use arrow::datatypes::DataType;
28use datafusion_common::config::ConfigOptions;
29use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
30use datafusion_common::{
31    internal_datafusion_err, plan_err, Column, DFSchemaRef, Result, ScalarValue,
32};
33use datafusion_expr::expr::{AggregateFunction, Alias};
34use datafusion_expr::logical_plan::LogicalPlan;
35use datafusion_expr::utils::grouping_set_to_exprlist;
36use datafusion_expr::{
37    bitwise_and, bitwise_or, bitwise_shift_left, bitwise_shift_right, cast, Aggregate,
38    Expr, Projection,
39};
40use itertools::Itertools;
41
/// Analyzer rule that replaces the `grouping` aggregate function with a value
/// derived from the internal grouping id.
///
/// The rule carries no state; all work happens in its `AnalyzerRule`
/// implementation.
#[derive(Debug, Default)]
pub struct ResolveGroupingFunction;

impl ResolveGroupingFunction {
    /// Creates a new instance of the rule.
    pub fn new() -> Self {
        Self
    }
}
51
52impl AnalyzerRule for ResolveGroupingFunction {
53    fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result<LogicalPlan> {
54        plan.transform_up(analyze_internal).data()
55    }
56
57    fn name(&self) -> &str {
58        "resolve_grouping_function"
59    }
60}
61
62/// Create a map from grouping expr to index in the internal grouping id.
63///
64/// For more details on how the grouping id bitmap works the documentation for
65/// [[Aggregate::INTERNAL_GROUPING_ID]]
66fn group_expr_to_bitmap_index(group_expr: &[Expr]) -> Result<HashMap<&Expr, usize>> {
67    Ok(grouping_set_to_exprlist(group_expr)?
68        .into_iter()
69        .rev()
70        .enumerate()
71        .map(|(idx, v)| (v, idx))
72        .collect::<HashMap<_, _>>())
73}
74
75fn replace_grouping_exprs(
76    input: Arc<LogicalPlan>,
77    schema: DFSchemaRef,
78    group_expr: Vec<Expr>,
79    aggr_expr: Vec<Expr>,
80) -> Result<LogicalPlan> {
81    // Create HashMap from Expr to index in the grouping_id bitmap
82    let is_grouping_set = matches!(group_expr.as_slice(), [Expr::GroupingSet(_)]);
83    let group_expr_to_bitmap_index = group_expr_to_bitmap_index(&group_expr)?;
84    let columns = schema.columns();
85    let mut new_agg_expr = Vec::new();
86    let mut projection_exprs = Vec::new();
87    let grouping_id_len = if is_grouping_set { 1 } else { 0 };
88    let group_expr_len = columns.len() - aggr_expr.len() - grouping_id_len;
89    projection_exprs.extend(
90        columns
91            .iter()
92            .take(group_expr_len)
93            .map(|column| Expr::Column(column.clone())),
94    );
95    for (expr, column) in aggr_expr
96        .into_iter()
97        .zip(columns.into_iter().skip(group_expr_len + grouping_id_len))
98    {
99        match expr {
100            Expr::AggregateFunction(ref function) if is_grouping_function(&expr) => {
101                let grouping_expr = grouping_function_on_id(
102                    function,
103                    &group_expr_to_bitmap_index,
104                    is_grouping_set,
105                )?;
106                projection_exprs.push(Expr::Alias(Alias::new(
107                    grouping_expr,
108                    column.relation,
109                    column.name,
110                )));
111            }
112            _ => {
113                projection_exprs.push(Expr::Column(column));
114                new_agg_expr.push(expr);
115            }
116        }
117    }
118    // Recreate aggregate without grouping functions
119    let new_aggregate =
120        LogicalPlan::Aggregate(Aggregate::try_new(input, group_expr, new_agg_expr)?);
121    // Create projection with grouping functions calculations
122    let projection = LogicalPlan::Projection(Projection::try_new(
123        projection_exprs,
124        new_aggregate.into(),
125    )?);
126    Ok(projection)
127}
128
129fn analyze_internal(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
130    // rewrite any subqueries in the plan first
131    let transformed_plan =
132        plan.map_subqueries(|plan| plan.transform_up(analyze_internal))?;
133
134    let transformed_plan = transformed_plan.transform_data(|plan| match plan {
135        LogicalPlan::Aggregate(Aggregate {
136            input,
137            group_expr,
138            aggr_expr,
139            schema,
140            ..
141        }) if contains_grouping_function(&aggr_expr) => Ok(Transformed::yes(
142            replace_grouping_exprs(input, schema, group_expr, aggr_expr)?,
143        )),
144        _ => Ok(Transformed::no(plan)),
145    })?;
146
147    Ok(transformed_plan)
148}
149
150fn is_grouping_function(expr: &Expr) -> bool {
151    // TODO: Do something better than name here should grouping be a built
152    // in expression?
153    matches!(expr, Expr::AggregateFunction(AggregateFunction { ref func, .. }) if func.name() == "grouping")
154}
155
156fn contains_grouping_function(exprs: &[Expr]) -> bool {
157    exprs.iter().any(is_grouping_function)
158}
159
160/// Validate that the arguments to the grouping function are in the group by clause.
161fn validate_args(
162    function: &AggregateFunction,
163    group_by_expr: &HashMap<&Expr, usize>,
164) -> Result<()> {
165    let expr_not_in_group_by = function
166        .params
167        .args
168        .iter()
169        .find(|expr| !group_by_expr.contains_key(expr));
170    if let Some(expr) = expr_not_in_group_by {
171        plan_err!(
172            "Argument {} to grouping function is not in grouping columns {}",
173            expr,
174            group_by_expr.keys().map(|e| e.to_string()).join(", ")
175        )
176    } else {
177        Ok(())
178    }
179}
180
181fn grouping_function_on_id(
182    function: &AggregateFunction,
183    group_by_expr: &HashMap<&Expr, usize>,
184    is_grouping_set: bool,
185) -> Result<Expr> {
186    validate_args(function, group_by_expr)?;
187    let args = &function.params.args;
188
189    // Postgres allows grouping function for group by without grouping sets, the result is then
190    // always 0
191    if !is_grouping_set {
192        return Ok(Expr::Literal(ScalarValue::from(0i32)));
193    }
194
195    let group_by_expr_count = group_by_expr.len();
196    let literal = |value: usize| {
197        if group_by_expr_count < 8 {
198            Expr::Literal(ScalarValue::from(value as u8))
199        } else if group_by_expr_count < 16 {
200            Expr::Literal(ScalarValue::from(value as u16))
201        } else if group_by_expr_count < 32 {
202            Expr::Literal(ScalarValue::from(value as u32))
203        } else {
204            Expr::Literal(ScalarValue::from(value as u64))
205        }
206    };
207
208    let grouping_id_column = Expr::Column(Column::from(Aggregate::INTERNAL_GROUPING_ID));
209    // The grouping call is exactly our internal grouping id
210    if args.len() == group_by_expr_count
211        && args
212            .iter()
213            .rev()
214            .enumerate()
215            .all(|(idx, expr)| group_by_expr.get(expr) == Some(&idx))
216    {
217        return Ok(cast(grouping_id_column, DataType::Int32));
218    }
219
220    args.iter()
221        .rev()
222        .enumerate()
223        .map(|(arg_idx, expr)| {
224            group_by_expr.get(expr).map(|group_by_idx| {
225                let group_by_bit =
226                    bitwise_and(grouping_id_column.clone(), literal(1 << group_by_idx));
227                match group_by_idx.cmp(&arg_idx) {
228                    Ordering::Less => {
229                        bitwise_shift_left(group_by_bit, literal(arg_idx - group_by_idx))
230                    }
231                    Ordering::Greater => {
232                        bitwise_shift_right(group_by_bit, literal(group_by_idx - arg_idx))
233                    }
234                    Ordering::Equal => group_by_bit,
235                }
236            })
237        })
238        .collect::<Option<Vec<_>>>()
239        .and_then(|bit_exprs| {
240            bit_exprs
241                .into_iter()
242                .reduce(bitwise_or)
243                .map(|expr| cast(expr, DataType::Int32))
244        })
245        .ok_or_else(|| {
246            internal_datafusion_err!("Grouping sets should contains at least one element")
247        })
248}