datafusion_functions_aggregate_common/utils.rs

use ahash::RandomState;
use arrow::array::{
    Array, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, PrimitiveArray,
};
use arrow::compute::SortOptions;
use arrow::datatypes::{
    ArrowNativeType, DataType, DecimalType, Field, FieldRef, ToByteSlice,
};
use datafusion_common::cast::{as_list_array, as_primitive_array};
use datafusion_common::utils::SingleRowListArrayBuilder;
use datafusion_common::utils::memory::estimate_memory_size;
use datafusion_common::{
    HashSet, Result, ScalarValue, exec_err, internal_datafusion_err,
};
use datafusion_expr_common::accumulator::Accumulator;
use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
use std::sync::Arc;

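/// Converts the state of `accum` into a `Vec` of arrays, one single-element
/// array per state value.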
pub fn get_accum_scalar_values_as_arrays(
    accum: &mut dyn Accumulator,
) -> Result<Vec<ArrayRef>> {
    accum
        .state()?
        .iter()
        .map(|s| s.to_array_of_size(1))
        .collect()
}

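/// Constructs the nullable [`Field`]s describing an aggregate's ordering
/// requirement, pairing each sort expression with its data type.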
pub fn ordering_fields(
    order_bys: &[PhysicalSortExpr],
    data_types: &[DataType],
) -> Vec<FieldRef> {
    order_bys
        .iter()
        .zip(data_types.iter())
        .map(|(sort_expr, dtype)| {
            Field::new(
                sort_expr.expr.to_string().as_str(),
                dtype.clone(),
                true,
            )
        })
        .map(Arc::new)
        .collect()
}

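/// Collects the [`SortOptions`] of each expression in `ordering_req`.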
pub fn get_sort_options(ordering_req: &LexOrdering) -> Vec<SortOptions> {
    ordering_req.iter().map(|item| item.options).collect()
}

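/// Wrapper that implements `Hash`/`Eq` for Arrow native values by hashing
/// their raw bytes and comparing with `ArrowNativeTypeOp::is_eq`, which makes
/// floating point values usable as `HashSet` keys.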
#[derive(Copy, Clone, Debug)]
pub struct Hashable<T>(pub T);

impl<T: ToByteSlice> std::hash::Hash for Hashable<T> {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.0.to_byte_slice().hash(state)
    }
}

impl<T: ArrowNativeTypeOp> PartialEq for Hashable<T> {
    fn eq(&self, other: &Self) -> bool {
        self.0.is_eq(other.0)
    }
}

impl<T: ArrowNativeTypeOp> Eq for Hashable<T> {}

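/// Helper for computing decimal averages: rescales `sum / count` to a target
/// precision and scale, reporting results that do not fit as an execution
/// error.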
pub struct DecimalAverager<T: DecimalType> {
    /// Scale factor of the incoming sums (10^sum_scale)
    sum_mul: T::Native,
    /// Scale factor of the output (10^target_scale)
    target_mul: T::Native,
    /// Precision of the output decimal type
    target_precision: u8,
    /// Scale of the output decimal type
    target_scale: i8,
}

impl<T: DecimalType> DecimalAverager<T> {
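    /// Creates a `DecimalAverager` from the scale of the incoming sums and
    /// the precision and scale of the output decimal type.
    ///
    /// Errors if the scale factors cannot be computed, for example when the
    /// target scale is smaller than the sum scale.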
    pub fn try_new(
        sum_scale: i8,
        target_precision: u8,
        target_scale: i8,
    ) -> Result<Self> {
        let sum_mul = T::Native::from_usize(10_usize)
            .map(|b| b.pow_wrapping(sum_scale as u32))
            .ok_or_else(|| {
                internal_datafusion_err!("Failed to compute sum_mul in DecimalAverager")
            })?;

        let target_mul = T::Native::from_usize(10_usize)
            .map(|b| b.pow_wrapping(target_scale as u32))
            .ok_or_else(|| {
                internal_datafusion_err!(
                    "Failed to compute target_mul in DecimalAverager"
                )
            })?;

        if target_mul >= sum_mul {
            Ok(Self {
                sum_mul,
                target_mul,
                target_precision,
                target_scale,
            })
        } else {
            exec_err!("Arithmetic Overflow in AvgAccumulator")
        }
    }

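    /// Returns `sum / count`, rescaled to the target scale; errors if the
    /// result overflows the target precision and scale.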
    #[inline(always)]
    pub fn avg(&self, sum: T::Native, count: T::Native) -> Result<T::Native> {
        if let Ok(value) = sum.mul_checked(self.target_mul.div_wrapping(self.sum_mul)) {
            let new_value = value.div_wrapping(count);

            let validate = T::validate_decimal_precision(
                new_value,
                self.target_precision,
                self.target_scale,
            );

            if validate.is_ok() {
                Ok(new_value)
            } else {
                exec_err!("Arithmetic Overflow in AvgAccumulator")
            }
        } else {
            exec_err!("Arithmetic Overflow in AvgAccumulator")
        }
    }
}

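/// Stores distinct primitive values for `DISTINCT` aggregates in a `HashSet`
/// of [`Hashable`] native values.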
pub struct GenericDistinctBuffer<T: ArrowPrimitiveType> {
    /// Distinct values observed so far
    pub values: HashSet<Hashable<T::Native>, RandomState>,
    data_type: DataType,
}

impl<T: ArrowPrimitiveType> std::fmt::Debug for GenericDistinctBuffer<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "GenericDistinctBuffer({}, values={})",
            self.data_type,
            self.values.len()
        )
    }
}

impl<T: ArrowPrimitiveType> GenericDistinctBuffer<T> {
    pub fn new(data_type: DataType) -> Self {
        Self {
            values: HashSet::default(),
            data_type,
        }
    }

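    /// Returns the distinct values collected so far as a single
    /// `ScalarValue::List`, suitable for use as intermediate accumulator
    /// state.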
    pub fn state(&self) -> Result<Vec<ScalarValue>> {
        let arr = Arc::new(
            PrimitiveArray::<T>::from_iter_values(self.values.iter().map(|v| v.0))
                .with_data_type(self.data_type.clone()),
        );
        Ok(vec![
            SingleRowListArrayBuilder::new(arr).build_list_scalar(),
        ])
    }

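    /// Inserts every non-null value of the single input array into the set of
    /// distinct values.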
    pub fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        if values.is_empty() {
            return Ok(());
        }

        debug_assert_eq!(
            values.len(),
            1,
            "GenericDistinctBuffer::update_batch expects only a single input array"
        );

        let arr = as_primitive_array::<T>(&values[0])?;
        if arr.null_count() > 0 {
            self.values.extend(arr.iter().flatten().map(Hashable));
        } else {
            // no nulls: consume the values buffer directly, skipping
            // per-element null checks
            self.values
                .extend(arr.values().iter().cloned().map(Hashable));
        }

        Ok(())
    }

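    /// Merges state produced by [`Self::state`]: each list in the input list
    /// array is fed back through [`Self::update_batch`].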
    pub fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        if states.is_empty() {
            return Ok(());
        }

        let array = as_list_array(&states[0])?;
        for list in array.iter().flatten() {
            self.update_batch(&[list])?;
        }

        Ok(())
    }

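    /// Returns an estimate of the memory used by this buffer, including the
    /// `HashSet` of distinct values.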
    pub fn size(&self) -> usize {
        let num_elements = self.values.len();
        let fixed_size = size_of_val(self) + size_of_val(&self.values);
        estimate_memory_size::<T::Native>(num_elements, fixed_size).unwrap()
    }
}