datafusion_functions_aggregate_common/
utils.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

use std::sync::Arc;

use arrow::array::{ArrayRef, AsArray};
use arrow::datatypes::{ArrowNativeType, FieldRef};
use arrow::{
    array::ArrowNativeTypeOp,
    compute::SortOptions,
    datatypes::{
        DataType, Decimal128Type, DecimalType, Field, TimeUnit, TimestampMicrosecondType,
        TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType,
        ToByteSlice,
    },
};
use datafusion_common::{exec_err, DataFusionError, Result};
use datafusion_expr_common::accumulator::Accumulator;
use datafusion_physical_expr_common::sort_expr::LexOrdering;

/// Convert scalar values from an accumulator into arrays.
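///
/// Illustrative sketch (not a doctest; `accum` is assumed to be any
/// [`Accumulator`] implementation obtained elsewhere):
///
/// ```ignore
/// // `accum: Box<dyn Accumulator>` built by some aggregate function
/// let state_arrays = get_accum_scalar_values_as_arrays(accum.as_mut())?;
/// // each intermediate state value becomes a single-element array
/// assert!(state_arrays.iter().all(|a| a.len() == 1));
/// ```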
pub fn get_accum_scalar_values_as_arrays(
    accum: &mut dyn Accumulator,
) -> Result<Vec<ArrayRef>> {
    accum
        .state()?
        .iter()
        .map(|s| s.to_array_of_size(1))
        .collect()
}

/// Adjust array type metadata if needed
///
/// Since `Decimal128Arrays` created from `Vec<NativeType>` have
/// default precision and scale, this function adjusts the output to
/// match `data_type`, if necessary
#[deprecated(since = "44.0.0", note = "use PrimitiveArray::with_datatype")]
pub fn adjust_output_array(data_type: &DataType, array: ArrayRef) -> Result<ArrayRef> {
    let array = match data_type {
        DataType::Decimal128(p, s) => Arc::new(
            array
                .as_primitive::<Decimal128Type>()
                .clone()
                .with_precision_and_scale(*p, *s)?,
        ) as ArrayRef,
        DataType::Timestamp(TimeUnit::Nanosecond, tz) => Arc::new(
            array
                .as_primitive::<TimestampNanosecondType>()
                .clone()
                .with_timezone_opt(tz.clone()),
        ),
        DataType::Timestamp(TimeUnit::Microsecond, tz) => Arc::new(
            array
                .as_primitive::<TimestampMicrosecondType>()
                .clone()
                .with_timezone_opt(tz.clone()),
        ),
        DataType::Timestamp(TimeUnit::Millisecond, tz) => Arc::new(
            array
                .as_primitive::<TimestampMillisecondType>()
                .clone()
                .with_timezone_opt(tz.clone()),
        ),
        DataType::Timestamp(TimeUnit::Second, tz) => Arc::new(
            array
                .as_primitive::<TimestampSecondType>()
                .clone()
                .with_timezone_opt(tz.clone()),
        ),
        // no adjustment needed for other arrays
        _ => array,
    };
    Ok(array)
}

/// Construct corresponding fields for lexicographical ordering requirement expression
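///
/// Illustrative sketch (not a doctest; building the `LexOrdering` and its
/// column expression is assumed to happen elsewhere):
///
/// ```ignore
/// // `ordering: LexOrdering` over a single expression, e.g. column "a"
/// let fields = ordering_fields(&ordering, &[DataType::Int32]);
/// assert_eq!(fields.len(), 1);
/// // fields are nullable because some partitions may be empty
/// assert!(fields[0].is_nullable());
/// ```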
pub fn ordering_fields(
    ordering_req: &LexOrdering,
    // Data type of each expression in the ordering requirement
    data_types: &[DataType],
) -> Vec<FieldRef> {
    ordering_req
        .iter()
        .zip(data_types.iter())
        .map(|(sort_expr, dtype)| {
            Field::new(
                sort_expr.expr.to_string().as_str(),
                dtype.clone(),
                // Some partitions may be empty, hence the field should be nullable.
                true,
            )
        })
        .map(Arc::new)
        .collect()
}

/// Selects the sort option attribute from all the given `PhysicalSortExpr`s.
pub fn get_sort_options(ordering_req: &LexOrdering) -> Vec<SortOptions> {
    ordering_req.iter().map(|item| item.options).collect()
}

/// A wrapper around a type to provide hash for floats
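///
/// `f32`/`f64` do not implement `Hash`, so values are hashed via their raw
/// byte representation and compared with `ArrowNativeTypeOp::is_eq`.
/// Illustrative sketch (not a doctest):
///
/// ```ignore
/// use std::collections::HashSet;
///
/// let mut seen = HashSet::new();
/// assert!(seen.insert(Hashable(1.0_f64)));
/// // the same float value hashes and compares equal, so it is deduplicated
/// assert!(!seen.insert(Hashable(1.0_f64)));
/// ```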
#[derive(Copy, Clone, Debug)]
pub struct Hashable<T>(pub T);

impl<T: ToByteSlice> std::hash::Hash for Hashable<T> {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.0.to_byte_slice().hash(state)
    }
}

impl<T: ArrowNativeTypeOp> PartialEq for Hashable<T> {
    fn eq(&self, other: &Self) -> bool {
        self.0.is_eq(other.0)
    }
}

impl<T: ArrowNativeTypeOp> Eq for Hashable<T> {}

/// Computes averages for `Decimal128`/`Decimal256` values, checking for overflow
///
/// This is needed because different precisions for `Decimal128`/`Decimal256` can
/// store different ranges of values, and thus the sum/count may not fit in
/// the target type.
///
/// For example, with a precision of 3 the maximum representable value is `999`
/// and the minimum is `-999`.
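///
/// Illustrative sketch (not a doctest; the decimal values below are made up):
///
/// ```ignore
/// use arrow::datatypes::Decimal128Type;
///
/// // sums were accumulated at scale 2: 1.00 + 2.00 + 3.50 = 6.50 => 650_i128
/// // the output should be Decimal128(10, 4)
/// let averager = DecimalAverager::<Decimal128Type>::try_new(2, 10, 4)?;
/// let avg = averager.avg(650_i128, 3_i128)?;
/// assert_eq!(avg, 21666); // 2.1666 at scale 4 (650 * 100 / 3, truncated)
/// ```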
pub struct DecimalAverager<T: DecimalType> {
    /// scale factor for sum values (10^sum_scale)
    sum_mul: T::Native,
    /// scale factor for target (10^target_scale)
    target_mul: T::Native,
    /// the output precision
    target_precision: u8,
}

impl<T: DecimalType> DecimalAverager<T> {
    /// Create a new `DecimalAverager`:
    ///
    /// * sum_scale: the scale of `sum` values passed to [`Self::avg`]
    /// * target_precision: the output precision
    /// * target_scale: the output scale
    ///
    /// Errors if the resulting data cannot be stored.
    pub fn try_new(
        sum_scale: i8,
        target_precision: u8,
        target_scale: i8,
    ) -> Result<Self> {
        let sum_mul = T::Native::from_usize(10_usize)
            .map(|b| b.pow_wrapping(sum_scale as u32))
            .ok_or(DataFusionError::Internal(
                "Failed to compute sum_mul in DecimalAverager".to_string(),
            ))?;

        let target_mul = T::Native::from_usize(10_usize)
            .map(|b| b.pow_wrapping(target_scale as u32))
            .ok_or(DataFusionError::Internal(
                "Failed to compute target_mul in DecimalAverager".to_string(),
            ))?;

        if target_mul >= sum_mul {
            Ok(Self {
                sum_mul,
                target_mul,
                target_precision,
            })
        } else {
            // the target scale is smaller than the sum scale, so the sum cannot
            // be rescaled to the target type without losing precision
            exec_err!("Arithmetic Overflow in AvgAccumulator")
        }
    }

    /// Returns `sum`/`count` as an i128/i256 Decimal128/Decimal256 with
    /// `target_scale` and `target_precision`, reporting overflow.
    ///
    /// * sum: The total sum value stored as Decimal128 with sum_scale
    ///   (passed to `Self::try_new`)
    /// * count: total count, stored as a i128/i256 (*NOT* a Decimal128/Decimal256 value)
    #[inline(always)]
    pub fn avg(&self, sum: T::Native, count: T::Native) -> Result<T::Native> {
        if let Ok(value) = sum.mul_checked(self.target_mul.div_wrapping(self.sum_mul)) {
            let new_value = value.div_wrapping(count);

            let validate =
                T::validate_decimal_precision(new_value, self.target_precision);

            if validate.is_ok() {
                Ok(new_value)
            } else {
                exec_err!("Arithmetic Overflow in AvgAccumulator")
            }
        } else {
            // rescaling the sum to the target scale overflowed the native type
            exec_err!("Arithmetic Overflow in AvgAccumulator")
        }
    }
}