datafusion_functions_aggregate_common/utils.rs
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::array::{ArrayRef, ArrowNativeTypeOp};
19use arrow::compute::SortOptions;
20use arrow::datatypes::{
21 ArrowNativeType, DataType, DecimalType, Field, FieldRef, ToByteSlice,
22};
23use datafusion_common::{exec_err, internal_datafusion_err, Result};
24use datafusion_expr_common::accumulator::Accumulator;
25use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
26use std::sync::Arc;
27
28/// Convert scalar values from an accumulator into arrays.
29pub fn get_accum_scalar_values_as_arrays(
30 accum: &mut dyn Accumulator,
31) -> Result<Vec<ArrayRef>> {
32 accum
33 .state()?
34 .iter()
35 .map(|s| s.to_array_of_size(1))
36 .collect()
37}
38
39/// Construct corresponding fields for the expressions in an ORDER BY clause.
40pub fn ordering_fields(
41 order_bys: &[PhysicalSortExpr],
42 // Data type of each expression in the ordering requirement
43 data_types: &[DataType],
44) -> Vec<FieldRef> {
45 order_bys
46 .iter()
47 .zip(data_types.iter())
48 .map(|(sort_expr, dtype)| {
49 Field::new(
50 sort_expr.expr.to_string().as_str(),
51 dtype.clone(),
52 // Multi partitions may be empty hence field should be nullable.
53 true,
54 )
55 })
56 .map(Arc::new)
57 .collect()
58}
59
60/// Selects the sort option attribute from all the given `PhysicalSortExpr`s.
61pub fn get_sort_options(ordering_req: &LexOrdering) -> Vec<SortOptions> {
62 ordering_req.iter().map(|item| item.options).collect()
63}
64
65/// A wrapper around a type to provide hash for floats
66#[derive(Copy, Clone, Debug)]
67pub struct Hashable<T>(pub T);
68
69impl<T: ToByteSlice> std::hash::Hash for Hashable<T> {
70 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
71 self.0.to_byte_slice().hash(state)
72 }
73}
74
75impl<T: ArrowNativeTypeOp> PartialEq for Hashable<T> {
76 fn eq(&self, other: &Self) -> bool {
77 self.0.is_eq(other.0)
78 }
79}
80
81impl<T: ArrowNativeTypeOp> Eq for Hashable<T> {}
82
83/// Computes averages for `Decimal128`/`Decimal256` values, checking for overflow
84///
85/// This is needed because different precisions for Decimal128/Decimal256 can
86/// store different ranges of values and thus sum/count may not fit in
87/// the target type.
88///
89/// For example, the precision is 3, the max of value is `999` and the min
90/// value is `-999`
91pub struct DecimalAverager<T: DecimalType> {
92 /// scale factor for sum values (10^sum_scale)
93 sum_mul: T::Native,
94 /// scale factor for target (10^target_scale)
95 target_mul: T::Native,
96 /// the output precision
97 target_precision: u8,
98 /// the output scale
99 target_scale: i8,
100}
101
102impl<T: DecimalType> DecimalAverager<T> {
103 /// Create a new `DecimalAverager`:
104 ///
105 /// * sum_scale: the scale of `sum` values passed to [`Self::avg`]
106 /// * target_precision: the output precision
107 /// * target_scale: the output scale
108 ///
109 /// Errors if the resulting data can not be stored
110 pub fn try_new(
111 sum_scale: i8,
112 target_precision: u8,
113 target_scale: i8,
114 ) -> Result<Self> {
115 let sum_mul = T::Native::from_usize(10_usize)
116 .map(|b| b.pow_wrapping(sum_scale as u32))
117 .ok_or_else(|| {
118 internal_datafusion_err!("Failed to compute sum_mul in DecimalAverager")
119 })?;
120
121 let target_mul = T::Native::from_usize(10_usize)
122 .map(|b| b.pow_wrapping(target_scale as u32))
123 .ok_or_else(|| {
124 internal_datafusion_err!(
125 "Failed to compute target_mul in DecimalAverager"
126 )
127 })?;
128
129 if target_mul >= sum_mul {
130 Ok(Self {
131 sum_mul,
132 target_mul,
133 target_precision,
134 target_scale,
135 })
136 } else {
137 // can't convert the lit decimal to the returned data type
138 exec_err!("Arithmetic Overflow in AvgAccumulator")
139 }
140 }
141
142 /// Returns the `sum`/`count` as a i128/i256 Decimal128/Decimal256 with
143 /// target_scale and target_precision and reporting overflow.
144 ///
145 /// * sum: The total sum value stored as Decimal128 with sum_scale
146 /// (passed to `Self::try_new`)
147 /// * count: total count, stored as a i128/i256 (*NOT* a Decimal128/Decimal256 value)
148 #[inline(always)]
149 pub fn avg(&self, sum: T::Native, count: T::Native) -> Result<T::Native> {
150 if let Ok(value) = sum.mul_checked(self.target_mul.div_wrapping(self.sum_mul)) {
151 let new_value = value.div_wrapping(count);
152
153 let validate = T::validate_decimal_precision(
154 new_value,
155 self.target_precision,
156 self.target_scale,
157 );
158
159 if validate.is_ok() {
160 Ok(new_value)
161 } else {
162 exec_err!("Arithmetic Overflow in AvgAccumulator")
163 }
164 } else {
165 // can't convert the lit decimal to the returned data type
166 exec_err!("Arithmetic Overflow in AvgAccumulator")
167 }
168 }
169}