use ahash::RandomState;
use arrow::array::{
Array, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, PrimitiveArray,
};
use arrow::compute::SortOptions;
use arrow::datatypes::{
ArrowNativeType, DataType, DecimalType, Field, FieldRef, ToByteSlice,
};
use datafusion_common::cast::{as_list_array, as_primitive_array};
use datafusion_common::utils::SingleRowListArrayBuilder;
use datafusion_common::utils::memory::estimate_memory_size;
use datafusion_common::{
HashSet, Result, ScalarValue, exec_err, internal_datafusion_err,
};
use datafusion_expr_common::accumulator::Accumulator;
use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
use std::sync::Arc;
/// Converts the intermediate state of `accum` into single-row arrays.
///
/// Calls [`Accumulator::state`] and turns every returned scalar into an
/// array of length 1 via `to_array_of_size(1)`.
///
/// # Errors
///
/// Returns the first error produced either by `state()` or by converting one
/// of the scalars into an array.
pub fn get_accum_scalar_values_as_arrays(
    accum: &mut dyn Accumulator,
) -> Result<Vec<ArrayRef>> {
    let state = accum.state()?;
    let mut arrays = Vec::with_capacity(state.len());
    for scalar in &state {
        arrays.push(scalar.to_array_of_size(1)?);
    }
    Ok(arrays)
}
/// Builds one field per ORDER BY expression.
///
/// * `order_bys`: the sort expressions; each field is named after the display
///   string of the corresponding expression.
/// * `data_types`: the data type of each expression, matched positionally
///   with `order_bys`.
///
/// The two slices are zipped, so the result has as many entries as the
/// shorter of the two inputs. Every produced field is marked nullable.
pub fn ordering_fields(
    order_bys: &[PhysicalSortExpr],
    data_types: &[DataType],
) -> Vec<FieldRef> {
    order_bys
        .iter()
        .zip(data_types)
        .map(|(sort_expr, dtype)| {
            let name = sort_expr.expr.to_string();
            Arc::new(Field::new(name, dtype.clone(), true))
        })
        .collect()
}
/// Collects the `options` of every sort expression in `ordering_req`,
/// preserving their order.
pub fn get_sort_options(ordering_req: &LexOrdering) -> Vec<SortOptions> {
    let mut options = Vec::new();
    for sort_expr in ordering_req.iter() {
        options.push(sort_expr.options);
    }
    options
}
/// Newtype that gives a native value `Hash` and `Eq` implementations.
///
/// Hashing is done over the value's byte representation (`to_byte_slice`) and
/// equality is delegated to `ArrowNativeTypeOp::is_eq`, so the wrapper can be
/// used as a hash-set key for any type providing those two operations.
#[derive(Copy, Clone, Debug)]
pub struct Hashable<T>(pub T);

impl<T: ToByteSlice> std::hash::Hash for Hashable<T> {
    /// Feeds the byte-slice view of the wrapped value into `state`.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        let Self(inner) = self;
        std::hash::Hash::hash(inner.to_byte_slice(), state)
    }
}

impl<T: ArrowNativeTypeOp> PartialEq for Hashable<T> {
    /// Two wrappers are equal when `is_eq` says the wrapped values are.
    fn eq(&self, other: &Self) -> bool {
        T::is_eq(self.0, other.0)
    }
}

// Marker impl: declares the `is_eq`-based comparison above to be a full
// equivalence relation so the wrapper can be stored in hash-based collections.
impl<T: ArrowNativeTypeOp> Eq for Hashable<T> {}
/// Computes `sum / count` for decimal values, rescaling the result to a
/// target precision/scale and reporting overflow as an error.
pub struct DecimalAverager<T: DecimalType> {
    /// Scale factor of the incoming sums: `10^sum_scale`.
    sum_mul: T::Native,
    /// Scale factor of the produced averages: `10^target_scale`.
    target_mul: T::Native,
    /// Precision the produced averages are validated against.
    target_precision: u8,
    /// Scale the produced averages are validated against.
    target_scale: i8,
}

impl<T: DecimalType> DecimalAverager<T> {
    /// Creates a new `DecimalAverager`.
    ///
    /// * `sum_scale`: scale of the `sum` values later passed to [`Self::avg`]
    /// * `target_precision`: precision of the output
    /// * `target_scale`: scale of the output
    ///
    /// # Errors
    ///
    /// * Execution error if either scale is negative. A negative `i8` cast to
    ///   `u32` would become a huge exponent, the wrapped power would be zero,
    ///   and [`Self::avg`] would then panic dividing by it.
    /// * Internal error if the constant `10` cannot be represented in
    ///   `T::Native`.
    /// * Execution error if `10^target_scale < 10^sum_scale`, i.e. the sum
    ///   cannot be rescaled to the target without losing digits.
    pub fn try_new(
        sum_scale: i8,
        target_precision: u8,
        target_scale: i8,
    ) -> Result<Self> {
        // Convert the scales to exponents with a checked conversion instead of
        // an `as` cast so that negative scales are rejected up front.
        let (Ok(sum_exp), Ok(target_exp)) =
            (u32::try_from(sum_scale), u32::try_from(target_scale))
        else {
            return exec_err!(
                "DecimalAverager does not support negative scales: sum_scale={}, target_scale={}",
                sum_scale,
                target_scale
            );
        };

        let sum_mul = T::Native::from_usize(10_usize)
            .map(|b| b.pow_wrapping(sum_exp))
            .ok_or_else(|| {
                internal_datafusion_err!("Failed to compute sum_mul in DecimalAverager")
            })?;
        let target_mul = T::Native::from_usize(10_usize)
            .map(|b| b.pow_wrapping(target_exp))
            .ok_or_else(|| {
                internal_datafusion_err!(
                    "Failed to compute target_mul in DecimalAverager"
                )
            })?;

        if target_mul >= sum_mul {
            Ok(Self {
                sum_mul,
                target_mul,
                target_precision,
                target_scale,
            })
        } else {
            // The target scale is smaller than the sum scale: the sum cannot
            // be converted to the output type.
            exec_err!("Arithmetic Overflow in AvgAccumulator")
        }
    }

    /// Returns `sum / count` rescaled from the sum scale to the target scale.
    ///
    /// * `sum`: total, expressed with the `sum_scale` given to
    ///   [`Self::try_new`]
    /// * `count`: number of values that were summed (a plain integer stored in
    ///   `T::Native`, not a scaled decimal)
    ///
    /// # Errors
    ///
    /// Returns an execution error if `count` is zero, if rescaling `sum`
    /// overflows `T::Native`, or if the result does not fit the target
    /// precision.
    #[inline(always)]
    pub fn avg(&self, sum: T::Native, count: T::Native) -> Result<T::Native> {
        // `div_wrapping` panics on a zero divisor; surface that as an error
        // because this function already reports failures through `Result`.
        if count.is_zero() {
            return exec_err!("Division by zero in AvgAccumulator");
        }

        // Rescale first (multiply by 10^(target_scale - sum_scale)), then
        // divide, so the quotient keeps the extra fractional digits.
        let Ok(value) = sum.mul_checked(self.target_mul.div_wrapping(self.sum_mul))
        else {
            return exec_err!("Arithmetic Overflow in AvgAccumulator");
        };
        let new_value = value.div_wrapping(count);

        let validate = T::validate_decimal_precision(
            new_value,
            self.target_precision,
            self.target_scale,
        );
        if validate.is_ok() {
            Ok(new_value)
        } else {
            exec_err!("Arithmetic Overflow in AvgAccumulator")
        }
    }
}
/// Set of distinct native values of an Arrow primitive type, together with
/// the `DataType` used when the values are turned back into an array.
pub struct GenericDistinctBuffer<T: ArrowPrimitiveType> {
    /// The distinct values seen so far, wrapped so they can be hashed.
    pub values: HashSet<Hashable<T::Native>, RandomState>,
    /// Data type applied to the array built from `values`.
    data_type: DataType,
}

impl<T: ArrowPrimitiveType> std::fmt::Debug for GenericDistinctBuffer<T> {
    /// Prints the data type and the number of distinct values, not the
    /// values themselves.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let distinct_count = self.values.len();
        write!(
            f,
            "GenericDistinctBuffer({}, values={})",
            self.data_type, distinct_count
        )
    }
}
impl<T: ArrowPrimitiveType> GenericDistinctBuffer<T> {
    /// Creates an empty buffer whose serialized state uses `data_type`.
    pub fn new(data_type: DataType) -> Self {
        let values = HashSet::default();
        Self { values, data_type }
    }

    /// Serializes the distinct values as a single list scalar.
    ///
    /// The values are copied into a primitive array tagged with the buffer's
    /// data type, which is then wrapped into a one-row list scalar.
    pub fn state(&self) -> Result<Vec<ScalarValue>> {
        let distinct = self.values.iter().map(|wrapped| wrapped.0);
        let array: ArrayRef = Arc::new(
            PrimitiveArray::<T>::from_iter_values(distinct)
                .with_data_type(self.data_type.clone()),
        );
        let list_scalar = SingleRowListArrayBuilder::new(array).build_list_scalar();
        Ok(vec![list_scalar])
    }

    /// Inserts every non-null value of `values[0]` into the set.
    ///
    /// An empty `values` slice is a no-op. In debug builds, passing more than
    /// one array trips an assertion; only the first array is read.
    ///
    /// # Errors
    ///
    /// Fails if the first array cannot be downcast to `PrimitiveArray<T>`.
    pub fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        let Some(first) = values.first() else {
            return Ok(());
        };
        debug_assert_eq!(
            values.len(),
            1,
            "DistinctValuesBuffer::update_batch expects only a single input array"
        );

        let primitive = as_primitive_array::<T>(first)?;
        match primitive.null_count() {
            // No nulls: read the raw value buffer directly.
            0 => self
                .values
                .extend(primitive.values().iter().copied().map(Hashable)),
            // Some nulls: go through the `Option` iterator and drop the `None`s.
            _ => self
                .values
                .extend(primitive.iter().flatten().map(Hashable)),
        }
        Ok(())
    }

    /// Merges serialized states produced by [`Self::state`].
    ///
    /// `states[0]` is read as a list array; each non-null list entry is fed to
    /// [`Self::update_batch`]. An empty `states` slice is a no-op.
    ///
    /// # Errors
    ///
    /// Fails if `states[0]` is not a list array or if any entry is rejected by
    /// `update_batch`.
    pub fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        let Some(first) = states.first() else {
            return Ok(());
        };
        let lists = as_list_array(first)?;
        lists
            .iter()
            .flatten()
            .try_for_each(|list| self.update_batch(&[list]))
    }

    /// Estimates the memory used by this buffer in bytes.
    ///
    /// # Panics
    ///
    /// Panics if `estimate_memory_size` returns an error.
    pub fn size(&self) -> usize {
        let fixed_size = size_of_val(self) + size_of_val(&self.values);
        estimate_memory_size::<T::Native>(self.values.len(), fixed_size).unwrap()
    }
}