datafusion_functions_aggregate_common/utils.rs
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::array::{ArrayRef, ArrowNativeTypeOp};
19use arrow::compute::SortOptions;
20use arrow::datatypes::{
21 ArrowNativeType, DataType, DecimalType, Field, FieldRef, ToByteSlice,
22};
23use datafusion_common::{exec_err, DataFusionError, Result};
24use datafusion_expr_common::accumulator::Accumulator;
25use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
26use std::sync::Arc;
27
28/// Convert scalar values from an accumulator into arrays.
29pub fn get_accum_scalar_values_as_arrays(
30 accum: &mut dyn Accumulator,
31) -> Result<Vec<ArrayRef>> {
32 accum
33 .state()?
34 .iter()
35 .map(|s| s.to_array_of_size(1))
36 .collect()
37}
38
39/// Construct corresponding fields for the expressions in an ORDER BY clause.
40pub fn ordering_fields(
41 order_bys: &[PhysicalSortExpr],
42 // Data type of each expression in the ordering requirement
43 data_types: &[DataType],
44) -> Vec<FieldRef> {
45 order_bys
46 .iter()
47 .zip(data_types.iter())
48 .map(|(sort_expr, dtype)| {
49 Field::new(
50 sort_expr.expr.to_string().as_str(),
51 dtype.clone(),
52 // Multi partitions may be empty hence field should be nullable.
53 true,
54 )
55 })
56 .map(Arc::new)
57 .collect()
58}
59
60/// Selects the sort option attribute from all the given `PhysicalSortExpr`s.
61pub fn get_sort_options(ordering_req: &LexOrdering) -> Vec<SortOptions> {
62 ordering_req.iter().map(|item| item.options).collect()
63}
64
65/// A wrapper around a type to provide hash for floats
66#[derive(Copy, Clone, Debug)]
67pub struct Hashable<T>(pub T);
68
69impl<T: ToByteSlice> std::hash::Hash for Hashable<T> {
70 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
71 self.0.to_byte_slice().hash(state)
72 }
73}
74
75impl<T: ArrowNativeTypeOp> PartialEq for Hashable<T> {
76 fn eq(&self, other: &Self) -> bool {
77 self.0.is_eq(other.0)
78 }
79}
80
81impl<T: ArrowNativeTypeOp> Eq for Hashable<T> {}
82
83/// Computes averages for `Decimal128`/`Decimal256` values, checking for overflow
84///
85/// This is needed because different precisions for Decimal128/Decimal256 can
86/// store different ranges of values and thus sum/count may not fit in
87/// the target type.
88///
89/// For example, the precision is 3, the max of value is `999` and the min
90/// value is `-999`
91pub struct DecimalAverager<T: DecimalType> {
92 /// scale factor for sum values (10^sum_scale)
93 sum_mul: T::Native,
94 /// scale factor for target (10^target_scale)
95 target_mul: T::Native,
96 /// the output precision
97 target_precision: u8,
98}
99
100impl<T: DecimalType> DecimalAverager<T> {
101 /// Create a new `DecimalAverager`:
102 ///
103 /// * sum_scale: the scale of `sum` values passed to [`Self::avg`]
104 /// * target_precision: the output precision
105 /// * target_scale: the output scale
106 ///
107 /// Errors if the resulting data can not be stored
108 pub fn try_new(
109 sum_scale: i8,
110 target_precision: u8,
111 target_scale: i8,
112 ) -> Result<Self> {
113 let sum_mul = T::Native::from_usize(10_usize)
114 .map(|b| b.pow_wrapping(sum_scale as u32))
115 .ok_or(DataFusionError::Internal(
116 "Failed to compute sum_mul in DecimalAverager".to_string(),
117 ))?;
118
119 let target_mul = T::Native::from_usize(10_usize)
120 .map(|b| b.pow_wrapping(target_scale as u32))
121 .ok_or(DataFusionError::Internal(
122 "Failed to compute target_mul in DecimalAverager".to_string(),
123 ))?;
124
125 if target_mul >= sum_mul {
126 Ok(Self {
127 sum_mul,
128 target_mul,
129 target_precision,
130 })
131 } else {
132 // can't convert the lit decimal to the returned data type
133 exec_err!("Arithmetic Overflow in AvgAccumulator")
134 }
135 }
136
137 /// Returns the `sum`/`count` as a i128/i256 Decimal128/Decimal256 with
138 /// target_scale and target_precision and reporting overflow.
139 ///
140 /// * sum: The total sum value stored as Decimal128 with sum_scale
141 /// (passed to `Self::try_new`)
142 /// * count: total count, stored as a i128/i256 (*NOT* a Decimal128/Decimal256 value)
143 #[inline(always)]
144 pub fn avg(&self, sum: T::Native, count: T::Native) -> Result<T::Native> {
145 if let Ok(value) = sum.mul_checked(self.target_mul.div_wrapping(self.sum_mul)) {
146 let new_value = value.div_wrapping(count);
147
148 let validate =
149 T::validate_decimal_precision(new_value, self.target_precision);
150
151 if validate.is_ok() {
152 Ok(new_value)
153 } else {
154 exec_err!("Arithmetic Overflow in AvgAccumulator")
155 }
156 } else {
157 // can't convert the lit decimal to the returned data type
158 exec_err!("Arithmetic Overflow in AvgAccumulator")
159 }
160 }
161}