use crate::aggregates::group_values::GroupValues;
use ahash::RandomState;
use arrow::array::types::{IntervalDayTime, IntervalMonthDayNano};
use arrow::array::{
ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, NullBufferBuilder, PrimitiveArray,
cast::AsArray,
};
use arrow::datatypes::{DataType, i256};
use datafusion_common::Result;
use datafusion_execution::memory_pool::proxy::VecAllocExt;
use datafusion_expr::EmitTo;
use half::f16;
use hashbrown::hash_table::HashTable;
use std::mem::size_of;
use std::sync::Arc;
/// Hashes a primitive group key using the aggregation's [`RandomState`].
///
/// Implemented for the integer, interval and floating-point native types via
/// the `hash_integer!` and `hash_float!` macros in this module.
pub(crate) trait HashValue {
    /// Returns the 64-bit hash of `self` computed with `random_state`.
    fn hash(&self, random_state: &RandomState) -> u64;
}
/// Implements [`HashValue`] for each listed type by hashing the value itself
/// with the supplied [`RandomState`].
///
/// When the `force_hash_collisions` feature is enabled, every value hashes to
/// `0`, so all keys collide.
macro_rules! hash_integer {
    ($($t:ty),+) => {
        $(
            impl HashValue for $t {
                fn hash(&self, state: &RandomState) -> u64 {
                    if cfg!(feature = "force_hash_collisions") {
                        0
                    } else {
                        state.hash_one(self)
                    }
                }
            }
        )+
    };
}

hash_integer!(i8, i16, i32, i64, i128, i256, u8, u16, u32, u64);
hash_integer!(IntervalDayTime, IntervalMonthDayNano);
/// Implements [`HashValue`] for each listed floating-point type by hashing its
/// raw bit pattern (`to_bits()`).
///
/// Because the bits are hashed, values with different bit representations
/// (for example `0.0` and `-0.0`) produce different hashes. When the
/// `force_hash_collisions` feature is enabled, every value hashes to `0`.
macro_rules! hash_float {
    ($($t:ty),+) => {
        $(
            impl HashValue for $t {
                fn hash(&self, state: &RandomState) -> u64 {
                    if cfg!(feature = "force_hash_collisions") {
                        0
                    } else {
                        let bits = self.to_bits();
                        state.hash_one(bits)
                    }
                }
            }
        )+
    };
}

hash_float!(f16, f32, f64);
/// Group-by state for a single primitive column.
///
/// Each distinct non-null key is assigned a group index equal to its position
/// in `values`. All null rows share one group, whose index is kept in
/// `null_group`.
pub struct GroupValuesPrimitive<T: ArrowPrimitiveType> {
    /// Logical type of the grouped column; reapplied to every emitted array.
    data_type: DataType,
    /// One entry per group, indexed by group index. The null group's slot
    /// holds a default placeholder value.
    values: Vec<T::Native>,
    /// Hash table of `(group index, hash)` pairs for the non-null keys.
    map: HashTable<(usize, u64)>,
    /// Group index assigned to null rows, once the first null has been seen.
    null_group: Option<usize>,
    /// Hasher state used to hash keys before probing `map`.
    random_state: RandomState,
}
impl<T: ArrowPrimitiveType> GroupValuesPrimitive<T> {
    /// Creates an empty group table for a column of type `data_type`.
    ///
    /// # Panics
    ///
    /// Panics if `data_type` is not compatible with `PrimitiveArray<T>`.
    pub fn new(data_type: DataType) -> Self {
        assert!(PrimitiveArray::<T>::is_compatible(&data_type));

        // Both buffers start with room for this many groups.
        const INITIAL_CAPACITY: usize = 128;

        Self {
            null_group: None,
            values: Vec::with_capacity(INITIAL_CAPACITY),
            map: HashTable::with_capacity(INITIAL_CAPACITY),
            random_state: crate::aggregates::AGGREGATION_HASH_SEED,
            data_type,
        }
    }
}
impl<T: ArrowPrimitiveType> GroupValues for GroupValuesPrimitive<T>
where
    T::Native: HashValue,
{
    /// Assigns a group index to every row of the single input column.
    ///
    /// `groups` is cleared, then receives exactly one group index per input
    /// row, in row order. A key not seen before is appended to `values` and
    /// gets the next free index. All null rows share a single group, which is
    /// created the first time a null is seen.
    ///
    /// # Panics
    ///
    /// Panics if `cols` does not contain exactly one array.
    fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec<usize>) -> Result<()> {
        assert_eq!(cols.len(), 1);
        groups.clear();

        for v in cols[0].as_primitive::<T>() {
            let group_id = match v {
                None => *self.null_group.get_or_insert_with(|| {
                    // The null group still takes a slot in `values` so that group
                    // indices stay dense. Its value is a placeholder that `emit`
                    // marks as null.
                    let group_id = self.values.len();
                    self.values.push(Default::default());
                    group_id
                }),
                Some(key) => {
                    let state = &self.random_state;
                    let hash = key.hash(state);
                    let insert = self.map.entry(
                        hash,
                        // The cached hash is compared first, so most non-matching
                        // entries never touch `values`.
                        //
                        // SAFETY: every index stored in `map` was `values.len()`
                        // immediately before the matching `values.push`.
                        // `emit` and `clear_shrink` remove or renumber map entries
                        // in step with `values`, so `g < self.values.len()` holds.
                        |&(g, h)| unsafe {
                            hash == h && self.values.get_unchecked(g).is_eq(key)
                        },
                        |&(_, h)| h,
                    );
                    match insert {
                        hashbrown::hash_table::Entry::Occupied(o) => o.get().0,
                        hashbrown::hash_table::Entry::Vacant(v) => {
                            let g = self.values.len();
                            v.insert((g, hash));
                            self.values.push(key);
                            g
                        }
                    }
                }
            };
            groups.push(group_id)
        }
        Ok(())
    }

    /// Returns the estimated memory, in bytes, held by the hash-table slots and
    /// the `values` allocation.
    fn size(&self) -> usize {
        self.map.capacity() * size_of::<(usize, u64)>() + self.values.allocated_size()
    }

    /// Returns `true` when no groups (including the null group) exist.
    fn is_empty(&self) -> bool {
        self.values.is_empty()
    }

    /// Returns the number of groups, including the null group if present.
    fn len(&self) -> usize {
        self.values.len()
    }

    /// Emits group keys as a single primitive array with `data_type` applied.
    ///
    /// - `EmitTo::All` drains every group and resets the table.
    /// - `EmitTo::First(n)` emits groups `0..n` and renumbers the remaining
    ///   groups down by `n`, both in `map` and in `null_group`.
    fn emit(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
        /// Builds a `PrimitiveArray` from `values`. Only the slot at `null_idx`
        /// (if any) is marked null.
        fn build_primitive<T: ArrowPrimitiveType>(
            values: Vec<T::Native>,
            null_idx: Option<usize>,
        ) -> PrimitiveArray<T> {
            let nulls = null_idx.map(|null_idx| {
                let mut buffer = NullBufferBuilder::new(values.len());
                buffer.append_n_non_nulls(null_idx);
                buffer.append_null();
                buffer.append_n_non_nulls(values.len() - null_idx - 1);
                // A null was appended above, so the builder has a validity
                // bitmap to return.
                buffer.finish().unwrap()
            });
            PrimitiveArray::<T>::new(values.into(), nulls)
        }

        let array: PrimitiveArray<T> = match emit_to {
            EmitTo::All => {
                self.map.clear();
                build_primitive(std::mem::take(&mut self.values), self.null_group.take())
            }
            EmitTo::First(n) => {
                // Keep only the groups at index >= n, renumbered to start at 0.
                self.map.retain(|entry| {
                    let group_idx = entry.0;
                    match group_idx.checked_sub(n) {
                        Some(sub) => {
                            entry.0 = sub;
                            true
                        }
                        None => false,
                    }
                });
                // A null group inside the emitted prefix leaves with the output.
                // Otherwise it stays and is renumbered like the other groups.
                let null_group = match &mut self.null_group {
                    Some(v) if *v >= n => {
                        *v -= n;
                        None
                    }
                    Some(_) => self.null_group.take(),
                    None => None,
                };
                // `split_off` leaves the first `n` values in `self.values`. The
                // swap keeps the remainder and hands the prefix to the output.
                let mut split = self.values.split_off(n);
                std::mem::swap(&mut self.values, &mut split);
                build_primitive(split, null_group)
            }
        };

        Ok(vec![Arc::new(array.with_data_type(self.data_type.clone()))])
    }

    /// Drops all groups and shrinks the internal buffers toward a capacity of
    /// `num_rows`.
    fn clear_shrink(&mut self, num_rows: usize) {
        self.values.clear();
        self.values.shrink_to(num_rows);
        self.map.clear();
        // The table is empty after `clear`, so there is nothing to rehash and
        // the hasher callback's result is irrelevant.
        self.map.shrink_to(num_rows, |_| 0);
        // The null group's index pointed into the now-cleared `values`. Keeping
        // it would hand a stale, out-of-range index to the next null row. That
        // index could also collide with the next new non-null group, and it would
        // make `emit` compute an invalid null position.
        self.null_group = None;
    }
}