datafusion_functions_aggregate_common/
utils.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::array::{ArrayRef, ArrowNativeTypeOp};
19use arrow::compute::SortOptions;
20use arrow::datatypes::{
21    ArrowNativeType, DataType, DecimalType, Field, FieldRef, ToByteSlice,
22};
23use datafusion_common::{exec_err, DataFusionError, Result};
24use datafusion_expr_common::accumulator::Accumulator;
25use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
26use std::sync::Arc;
27
28/// Convert scalar values from an accumulator into arrays.
29pub fn get_accum_scalar_values_as_arrays(
30    accum: &mut dyn Accumulator,
31) -> Result<Vec<ArrayRef>> {
32    accum
33        .state()?
34        .iter()
35        .map(|s| s.to_array_of_size(1))
36        .collect()
37}
38
39/// Construct corresponding fields for the expressions in an ORDER BY clause.
40pub fn ordering_fields(
41    order_bys: &[PhysicalSortExpr],
42    // Data type of each expression in the ordering requirement
43    data_types: &[DataType],
44) -> Vec<FieldRef> {
45    order_bys
46        .iter()
47        .zip(data_types.iter())
48        .map(|(sort_expr, dtype)| {
49            Field::new(
50                sort_expr.expr.to_string().as_str(),
51                dtype.clone(),
52                // Multi partitions may be empty hence field should be nullable.
53                true,
54            )
55        })
56        .map(Arc::new)
57        .collect()
58}
59
60/// Selects the sort option attribute from all the given `PhysicalSortExpr`s.
61pub fn get_sort_options(ordering_req: &LexOrdering) -> Vec<SortOptions> {
62    ordering_req.iter().map(|item| item.options).collect()
63}
64
65/// A wrapper around a type to provide hash for floats
66#[derive(Copy, Clone, Debug)]
67pub struct Hashable<T>(pub T);
68
69impl<T: ToByteSlice> std::hash::Hash for Hashable<T> {
70    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
71        self.0.to_byte_slice().hash(state)
72    }
73}
74
75impl<T: ArrowNativeTypeOp> PartialEq for Hashable<T> {
76    fn eq(&self, other: &Self) -> bool {
77        self.0.is_eq(other.0)
78    }
79}
80
81impl<T: ArrowNativeTypeOp> Eq for Hashable<T> {}
82
83/// Computes averages for `Decimal128`/`Decimal256` values, checking for overflow
84///
85/// This is needed because different precisions for Decimal128/Decimal256 can
86/// store different ranges of values and thus sum/count may not fit in
87/// the target type.
88///
89/// For example, the precision is 3, the max of value is `999` and the min
90/// value is `-999`
91pub struct DecimalAverager<T: DecimalType> {
92    /// scale factor for sum values (10^sum_scale)
93    sum_mul: T::Native,
94    /// scale factor for target (10^target_scale)
95    target_mul: T::Native,
96    /// the output precision
97    target_precision: u8,
98}
99
100impl<T: DecimalType> DecimalAverager<T> {
101    /// Create a new `DecimalAverager`:
102    ///
103    /// * sum_scale: the scale of `sum` values passed to [`Self::avg`]
104    /// * target_precision: the output precision
105    /// * target_scale: the output scale
106    ///
107    /// Errors if the resulting data can not be stored
108    pub fn try_new(
109        sum_scale: i8,
110        target_precision: u8,
111        target_scale: i8,
112    ) -> Result<Self> {
113        let sum_mul = T::Native::from_usize(10_usize)
114            .map(|b| b.pow_wrapping(sum_scale as u32))
115            .ok_or(DataFusionError::Internal(
116                "Failed to compute sum_mul in DecimalAverager".to_string(),
117            ))?;
118
119        let target_mul = T::Native::from_usize(10_usize)
120            .map(|b| b.pow_wrapping(target_scale as u32))
121            .ok_or(DataFusionError::Internal(
122                "Failed to compute target_mul in DecimalAverager".to_string(),
123            ))?;
124
125        if target_mul >= sum_mul {
126            Ok(Self {
127                sum_mul,
128                target_mul,
129                target_precision,
130            })
131        } else {
132            // can't convert the lit decimal to the returned data type
133            exec_err!("Arithmetic Overflow in AvgAccumulator")
134        }
135    }
136
137    /// Returns the `sum`/`count` as a i128/i256 Decimal128/Decimal256 with
138    /// target_scale and target_precision and reporting overflow.
139    ///
140    /// * sum: The total sum value stored as Decimal128 with sum_scale
141    ///   (passed to `Self::try_new`)
142    /// * count: total count, stored as a i128/i256 (*NOT* a Decimal128/Decimal256 value)
143    #[inline(always)]
144    pub fn avg(&self, sum: T::Native, count: T::Native) -> Result<T::Native> {
145        if let Ok(value) = sum.mul_checked(self.target_mul.div_wrapping(self.sum_mul)) {
146            let new_value = value.div_wrapping(count);
147
148            let validate =
149                T::validate_decimal_precision(new_value, self.target_precision);
150
151            if validate.is_ok() {
152                Ok(new_value)
153            } else {
154                exec_err!("Arithmetic Overflow in AvgAccumulator")
155            }
156        } else {
157            // can't convert the lit decimal to the returned data type
158            exec_err!("Arithmetic Overflow in AvgAccumulator")
159        }
160    }
161}