datafusion_functions_aggregate_common/
utils.rs1use std::sync::Arc;
19
20use arrow::array::{ArrayRef, AsArray};
21use arrow::datatypes::{ArrowNativeType, FieldRef};
22use arrow::{
23 array::ArrowNativeTypeOp,
24 compute::SortOptions,
25 datatypes::{
26 DataType, Decimal128Type, DecimalType, Field, TimeUnit, TimestampMicrosecondType,
27 TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType,
28 ToByteSlice,
29 },
30};
31use datafusion_common::{exec_err, DataFusionError, Result};
32use datafusion_expr_common::accumulator::Accumulator;
33use datafusion_physical_expr_common::sort_expr::LexOrdering;
34
35pub fn get_accum_scalar_values_as_arrays(
37 accum: &mut dyn Accumulator,
38) -> Result<Vec<ArrayRef>> {
39 accum
40 .state()?
41 .iter()
42 .map(|s| s.to_array_of_size(1))
43 .collect()
44}
45
46#[deprecated(since = "44.0.0", note = "use PrimitiveArray::with_datatype")]
52pub fn adjust_output_array(data_type: &DataType, array: ArrayRef) -> Result<ArrayRef> {
53 let array = match data_type {
54 DataType::Decimal128(p, s) => Arc::new(
55 array
56 .as_primitive::<Decimal128Type>()
57 .clone()
58 .with_precision_and_scale(*p, *s)?,
59 ) as ArrayRef,
60 DataType::Timestamp(TimeUnit::Nanosecond, tz) => Arc::new(
61 array
62 .as_primitive::<TimestampNanosecondType>()
63 .clone()
64 .with_timezone_opt(tz.clone()),
65 ),
66 DataType::Timestamp(TimeUnit::Microsecond, tz) => Arc::new(
67 array
68 .as_primitive::<TimestampMicrosecondType>()
69 .clone()
70 .with_timezone_opt(tz.clone()),
71 ),
72 DataType::Timestamp(TimeUnit::Millisecond, tz) => Arc::new(
73 array
74 .as_primitive::<TimestampMillisecondType>()
75 .clone()
76 .with_timezone_opt(tz.clone()),
77 ),
78 DataType::Timestamp(TimeUnit::Second, tz) => Arc::new(
79 array
80 .as_primitive::<TimestampSecondType>()
81 .clone()
82 .with_timezone_opt(tz.clone()),
83 ),
84 _ => array,
86 };
87 Ok(array)
88}
89
90pub fn ordering_fields(
92 ordering_req: &LexOrdering,
93 data_types: &[DataType],
95) -> Vec<FieldRef> {
96 ordering_req
97 .iter()
98 .zip(data_types.iter())
99 .map(|(sort_expr, dtype)| {
100 Field::new(
101 sort_expr.expr.to_string().as_str(),
102 dtype.clone(),
103 true,
105 )
106 })
107 .map(Arc::new)
108 .collect()
109}
110
111pub fn get_sort_options(ordering_req: &LexOrdering) -> Vec<SortOptions> {
113 ordering_req.iter().map(|item| item.options).collect()
114}
115
116#[derive(Copy, Clone, Debug)]
118pub struct Hashable<T>(pub T);
119
120impl<T: ToByteSlice> std::hash::Hash for Hashable<T> {
121 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
122 self.0.to_byte_slice().hash(state)
123 }
124}
125
126impl<T: ArrowNativeTypeOp> PartialEq for Hashable<T> {
127 fn eq(&self, other: &Self) -> bool {
128 self.0.is_eq(other.0)
129 }
130}
131
132impl<T: ArrowNativeTypeOp> Eq for Hashable<T> {}
133
134pub struct DecimalAverager<T: DecimalType> {
143 sum_mul: T::Native,
145 target_mul: T::Native,
147 target_precision: u8,
149}
150
151impl<T: DecimalType> DecimalAverager<T> {
152 pub fn try_new(
160 sum_scale: i8,
161 target_precision: u8,
162 target_scale: i8,
163 ) -> Result<Self> {
164 let sum_mul = T::Native::from_usize(10_usize)
165 .map(|b| b.pow_wrapping(sum_scale as u32))
166 .ok_or(DataFusionError::Internal(
167 "Failed to compute sum_mul in DecimalAverager".to_string(),
168 ))?;
169
170 let target_mul = T::Native::from_usize(10_usize)
171 .map(|b| b.pow_wrapping(target_scale as u32))
172 .ok_or(DataFusionError::Internal(
173 "Failed to compute target_mul in DecimalAverager".to_string(),
174 ))?;
175
176 if target_mul >= sum_mul {
177 Ok(Self {
178 sum_mul,
179 target_mul,
180 target_precision,
181 })
182 } else {
183 exec_err!("Arithmetic Overflow in AvgAccumulator")
185 }
186 }
187
188 #[inline(always)]
195 pub fn avg(&self, sum: T::Native, count: T::Native) -> Result<T::Native> {
196 if let Ok(value) = sum.mul_checked(self.target_mul.div_wrapping(self.sum_mul)) {
197 let new_value = value.div_wrapping(count);
198
199 let validate =
200 T::validate_decimal_precision(new_value, self.target_precision);
201
202 if validate.is_ok() {
203 Ok(new_value)
204 } else {
205 exec_err!("Arithmetic Overflow in AvgAccumulator")
206 }
207 } else {
208 exec_err!("Arithmetic Overflow in AvgAccumulator")
210 }
211 }
212}