datafusion_functions_aggregate_common/utils.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

use ahash::RandomState;
use arrow::array::{
    Array, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, PrimitiveArray,
};
use arrow::compute::SortOptions;
use arrow::datatypes::{
    ArrowNativeType, DataType, DecimalType, Field, FieldRef, ToByteSlice,
};
use datafusion_common::cast::{as_list_array, as_primitive_array};
use datafusion_common::utils::SingleRowListArrayBuilder;
use datafusion_common::utils::memory::estimate_memory_size;
use datafusion_common::{
    HashSet, Result, ScalarValue, exec_err, internal_datafusion_err,
};
use datafusion_expr_common::accumulator::Accumulator;
use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
use std::sync::Arc;

/// Convert the scalar state values of an accumulator into single-element arrays.
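///
/// # Example
///
/// A minimal usage sketch, assuming `accum` is some boxed, already-updated
/// [`Accumulator`] (left uncompiled since it needs a concrete implementation):
///
/// ```ignore
/// // Each state value becomes an array of length 1.
/// let arrays = get_accum_scalar_values_as_arrays(accum.as_mut())?;
/// assert!(arrays.iter().all(|array| array.len() == 1));
/// ```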
pub fn get_accum_scalar_values_as_arrays(
    accum: &mut dyn Accumulator,
) -> Result<Vec<ArrayRef>> {
    accum
        .state()?
        .iter()
        .map(|s| s.to_array_of_size(1))
        .collect()
}

/// Construct corresponding fields for the expressions in an ORDER BY clause.
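///
/// # Example
///
/// A minimal usage sketch, assuming `col_a` is some `Arc<dyn PhysicalExpr>`
/// whose output type is `Int32` (left uncompiled since it needs a concrete
/// expression):
///
/// ```ignore
/// let order_bys = [PhysicalSortExpr::new(col_a, SortOptions::default())];
/// let fields = ordering_fields(&order_bys, &[DataType::Int32]);
/// assert_eq!(fields[0].data_type(), &DataType::Int32);
/// assert!(fields[0].is_nullable());
/// ```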
pub fn ordering_fields(
    order_bys: &[PhysicalSortExpr],
    // Data type of each expression in the ordering requirement
    data_types: &[DataType],
) -> Vec<FieldRef> {
    order_bys
        .iter()
        .zip(data_types.iter())
        .map(|(sort_expr, dtype)| {
            Field::new(
                sort_expr.expr.to_string().as_str(),
                dtype.clone(),
                // Some partitions may be empty, so the field must be nullable.
                true,
            )
        })
        .map(Arc::new)
        .collect()
}

/// Selects the sort option attribute from all the given `PhysicalSortExpr`s.
pub fn get_sort_options(ordering_req: &LexOrdering) -> Vec<SortOptions> {
    ordering_req.iter().map(|item| item.options).collect()
}

/// A wrapper around a type to provide `Hash` for floats, which do not
/// implement it directly.
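///
/// # Example
///
/// A minimal usage sketch with `f64`, assuming this module is reachable as
/// `datafusion_functions_aggregate_common::utils`:
///
/// ```
/// use datafusion_functions_aggregate_common::utils::Hashable;
/// use std::collections::HashSet;
///
/// let mut set = HashSet::new();
/// set.insert(Hashable(1.0_f64));
/// set.insert(Hashable(1.0_f64)); // already present, so the set is unchanged
/// assert_eq!(set.len(), 1);
/// ```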
#[derive(Copy, Clone, Debug)]
pub struct Hashable<T>(pub T);

impl<T: ToByteSlice> std::hash::Hash for Hashable<T> {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.0.to_byte_slice().hash(state)
    }
}

impl<T: ArrowNativeTypeOp> PartialEq for Hashable<T> {
    fn eq(&self, other: &Self) -> bool {
        self.0.is_eq(other.0)
    }
}

impl<T: ArrowNativeTypeOp> Eq for Hashable<T> {}

/// Computes averages for `Decimal128`/`Decimal256` values, checking for overflow.
///
/// This is needed because different precisions for `Decimal128`/`Decimal256`
/// can store different ranges of values, so the sum and count may not fit in
/// the target type.
///
/// For example, with precision 3 the maximum representable value is `999` and
/// the minimum is `-999`.
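///
/// # Example
///
/// A minimal sketch with `Decimal128Type`, assuming this module is reachable
/// as `datafusion_functions_aggregate_common::utils`:
///
/// ```
/// use arrow::datatypes::Decimal128Type;
/// use datafusion_functions_aggregate_common::utils::DecimalAverager;
///
/// // sum = 150 at scale 2 (i.e. 1.50), over count = 2 rows
/// let averager = DecimalAverager::<Decimal128Type>::try_new(2, 10, 4).unwrap();
/// // 1.50 / 2 = 0.75, represented as 7500 at target scale 4
/// assert_eq!(averager.avg(150_i128, 2_i128).unwrap(), 7500_i128);
/// ```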
pub struct DecimalAverager<T: DecimalType> {
    /// scale factor for sum values (10^sum_scale)
    sum_mul: T::Native,
    /// scale factor for target (10^target_scale)
    target_mul: T::Native,
    /// the output precision
    target_precision: u8,
    /// the output scale
    target_scale: i8,
}

impl<T: DecimalType> DecimalAverager<T> {
    /// Create a new `DecimalAverager`:
    ///
    /// * sum_scale: the scale of `sum` values passed to [`Self::avg`]
    /// * target_precision: the output precision
    /// * target_scale: the output scale
    ///
    /// Errors if the resulting data cannot be stored.
    pub fn try_new(
        sum_scale: i8,
        target_precision: u8,
        target_scale: i8,
    ) -> Result<Self> {
        let sum_mul = T::Native::from_usize(10_usize)
            .map(|b| b.pow_wrapping(sum_scale as u32))
            .ok_or_else(|| {
                internal_datafusion_err!("Failed to compute sum_mul in DecimalAverager")
            })?;

        let target_mul = T::Native::from_usize(10_usize)
            .map(|b| b.pow_wrapping(target_scale as u32))
            .ok_or_else(|| {
                internal_datafusion_err!(
                    "Failed to compute target_mul in DecimalAverager"
                )
            })?;

        if target_mul >= sum_mul {
            Ok(Self {
                sum_mul,
                target_mul,
                target_precision,
                target_scale,
            })
        } else {
            // target scale is smaller than the sum scale, so the average
            // cannot be represented in the returned data type
            exec_err!("Arithmetic Overflow in AvgAccumulator")
        }
    }

    /// Returns `sum / count` as an i128/i256 `Decimal128`/`Decimal256` value
    /// with `target_scale` and `target_precision`, reporting overflow as an
    /// error.
    ///
    /// * sum: the total sum, stored as a `Decimal128`/`Decimal256` with
    ///   `sum_scale` (passed to `Self::try_new`)
    /// * count: total count, stored as an i128/i256 (*NOT* a
    ///   `Decimal128`/`Decimal256` value)
    #[inline(always)]
    pub fn avg(&self, sum: T::Native, count: T::Native) -> Result<T::Native> {
        if let Ok(value) = sum.mul_checked(self.target_mul.div_wrapping(self.sum_mul)) {
            let new_value = value.div_wrapping(count);

            let validate = T::validate_decimal_precision(
                new_value,
                self.target_precision,
                self.target_scale,
            );

            if validate.is_ok() {
                Ok(new_value)
            } else {
                exec_err!("Arithmetic Overflow in AvgAccumulator")
            }
        } else {
            // rescaling the sum overflowed the native type
            exec_err!("Arithmetic Overflow in AvgAccumulator")
        }
    }
}

/// Generic way to collect distinct primitive values for accumulators.
///
/// Distinct values accumulate in a `HashSet` as `update_batch` runs; the
/// intermediate state is a single `List` scalar, produced by `state` and
/// consumed by `merge_batch`, so expensive `ScalarValue` conversions and
/// allocations are avoided during `update_batch`.
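///
/// # Example
///
/// A minimal sketch with `Int32Type`, assuming this module is reachable as
/// `datafusion_functions_aggregate_common::utils`:
///
/// ```
/// use std::sync::Arc;
/// use arrow::array::{ArrayRef, Int32Array};
/// use arrow::datatypes::{DataType, Int32Type};
/// use datafusion_functions_aggregate_common::utils::GenericDistinctBuffer;
///
/// let mut buffer = GenericDistinctBuffer::<Int32Type>::new(DataType::Int32);
/// let batch: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 2, 3]));
/// buffer.update_batch(&[batch]).unwrap();
/// assert_eq!(buffer.values.len(), 3); // distinct values: 1, 2, 3
/// ```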
pub struct GenericDistinctBuffer<T: ArrowPrimitiveType> {
    pub values: HashSet<Hashable<T::Native>, RandomState>,
    data_type: DataType,
}

impl<T: ArrowPrimitiveType> std::fmt::Debug for GenericDistinctBuffer<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "GenericDistinctBuffer({}, values={})",
            self.data_type,
            self.values.len()
        )
    }
}

impl<T: ArrowPrimitiveType> GenericDistinctBuffer<T> {
    pub fn new(data_type: DataType) -> Self {
        Self {
            values: HashSet::default(),
            data_type,
        }
    }

    /// Mirrors [`Accumulator::state`].
    pub fn state(&self) -> Result<Vec<ScalarValue>> {
        let arr = Arc::new(
            PrimitiveArray::<T>::from_iter_values(self.values.iter().map(|v| v.0))
                // Ideally we'd just use T::DATA_TYPE but this misses things like
                // decimal scale/precision and timestamp timezones, which need to
                // match up with Accumulator::state_fields
                .with_data_type(self.data_type.clone()),
        );
        Ok(vec![
            SingleRowListArrayBuilder::new(arr).build_list_scalar(),
        ])
    }

    /// Mirrors [`Accumulator::update_batch`].
    pub fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        if values.is_empty() {
            return Ok(());
        }

        debug_assert_eq!(
            values.len(),
            1,
            "GenericDistinctBuffer::update_batch expects only a single input array"
        );

        let arr = as_primitive_array::<T>(&values[0])?;
        if arr.null_count() > 0 {
            self.values.extend(arr.iter().flatten().map(Hashable));
        } else {
            self.values
                .extend(arr.values().iter().cloned().map(Hashable));
        }

        Ok(())
    }

    /// Mirrors [`Accumulator::merge_batch`].
    pub fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        if states.is_empty() {
            return Ok(());
        }

        let array = as_list_array(&states[0])?;
        for list in array.iter().flatten() {
            self.update_batch(&[list])?;
        }

        Ok(())
    }

    /// Mirrors [`Accumulator::size`].
    pub fn size(&self) -> usize {
        let num_elements = self.values.len();
        let fixed_size = size_of_val(self) + size_of_val(&self.values);
        estimate_memory_size::<T::Native>(num_elements, fixed_size).unwrap()
    }
}