use crate::aggregates::group_values::GroupValues;
use ahash::RandomState;
use arrow::array::{Array, ArrayRef, ListArray, StructArray};
use arrow::compute::cast;
use arrow::datatypes::{DataType, SchemaRef};
use arrow::row::{RowConverter, Rows, SortField};
use datafusion_common::Result;
use datafusion_common::hash_utils::create_hashes;
use datafusion_execution::memory_pool::proxy::{HashTableAllocExt, VecAllocExt};
use datafusion_expr::EmitTo;
use hashbrown::hash_table::HashTable;
use log::debug;
use std::mem::size_of;
use std::sync::Arc;
/// A [`GroupValues`] implementation that stores group keys in the
/// row-oriented format produced by [`RowConverter`], probing them with a
/// hash table keyed on `(hash, group_index)`.
pub struct GroupValuesRows {
    /// Schema of the grouping columns
    schema: SchemaRef,
    /// Converts grouping columns to/from the comparable row format
    row_converter: RowConverter,
    /// Hash table of `(row_hash, group_index)`. Only the hash and index are
    /// stored; equality is resolved by comparing the encoded rows themselves,
    /// so hash collisions are handled at probe time.
    map: HashTable<(u64, usize)>,
    /// Tracked allocation size of `map`, in bytes (kept up to date via
    /// `insert_accounted` and reset in `clear_shrink`)
    map_size: usize,
    /// Row-encoded storage of all distinct group keys seen so far;
    /// `None` until the first batch is interned (and while emitting)
    group_values: Option<Rows>,
    /// Reusable scratch buffer of per-row hashes for the current input batch
    hashes_buffer: Vec<u64>,
    /// Reusable scratch buffer holding the row encoding of the current batch
    rows_buffer: Rows,
    /// Hash state; seeded from `crate::aggregates::AGGREGATION_HASH_SEED`
    /// so hashes are consistent across instances
    random_state: RandomState,
}
impl GroupValuesRows {
pub fn try_new(schema: SchemaRef) -> Result<Self> {
debug!("Creating GroupValuesRows for schema: {schema}");
let row_converter = RowConverter::new(
schema
.fields()
.iter()
.map(|f| SortField::new(f.data_type().clone()))
.collect(),
)?;
let map = HashTable::with_capacity(0);
let starting_rows_capacity = 1000;
let starting_data_capacity = 64 * starting_rows_capacity;
let rows_buffer =
row_converter.empty_rows(starting_rows_capacity, starting_data_capacity);
Ok(Self {
schema,
row_converter,
map,
map_size: 0,
group_values: None,
hashes_buffer: Default::default(),
rows_buffer,
random_state: crate::aggregates::AGGREGATION_HASH_SEED,
})
}
}
impl GroupValues for GroupValuesRows {
    /// Maps each input row of `cols` to a group index, appending any
    /// previously-unseen keys to the internal store. On return, `groups[i]`
    /// holds the group index for input row `i`.
    fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec<usize>) -> Result<()> {
        // Re-encode the incoming batch into row format, reusing the buffer.
        let group_rows = &mut self.rows_buffer;
        group_rows.clear();
        self.row_converter.append(group_rows, cols)?;
        let n_rows = group_rows.num_rows();
        // Take ownership of the accumulated group storage; created lazily on
        // first use (or after `emit(All)` took it).
        let mut group_values = match self.group_values.take() {
            Some(group_values) => group_values,
            None => self.row_converter.empty_rows(0, 0),
        };
        groups.clear();
        // Compute one hash per input row, reusing the scratch buffer.
        let batch_hashes = &mut self.hashes_buffer;
        batch_hashes.clear();
        batch_hashes.resize(n_rows, 0);
        create_hashes(cols, &self.random_state, batch_hashes)?;
        for (row, &target_hash) in batch_hashes.iter().enumerate() {
            // Probe the table: a candidate matches only if both the stored
            // hash AND the full row encoding are equal (the hash alone may
            // collide).
            let entry = self.map.find_mut(target_hash, |(exist_hash, group_idx)| {
                target_hash == *exist_hash
                    && group_rows.row(row) == group_values.row(*group_idx)
            });
            let group_idx = match entry {
                Some((_hash, group_idx)) => *group_idx,
                None => {
                    // New group: its index is the current row count. Append
                    // the encoded key and record it in the map, folding any
                    // table growth into `map_size` for memory accounting.
                    let group_idx = group_values.num_rows();
                    group_values.push(group_rows.row(row));
                    self.map.insert_accounted(
                        (target_hash, group_idx),
                        |(hash, _group_index)| *hash,
                        &mut self.map_size,
                    );
                    group_idx
                }
            };
            groups.push(group_idx);
        }
        self.group_values = Some(group_values);
        Ok(())
    }
    /// Total memory used by this instance, in bytes (converter, stored
    /// groups, hash table, and both scratch buffers).
    fn size(&self) -> usize {
        let group_values_size = self.group_values.as_ref().map(|v| v.size()).unwrap_or(0);
        self.row_converter.size()
            + group_values_size
            + self.map_size
            + self.rows_buffer.size()
            + self.hashes_buffer.allocated_size()
    }
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
    /// Number of distinct groups interned so far.
    fn len(&self) -> usize {
        self.group_values
            .as_ref()
            .map(|group_values| group_values.num_rows())
            .unwrap_or(0)
    }
    /// Converts stored group keys back into arrays.
    ///
    /// `EmitTo::All` emits every group and clears the state;
    /// `EmitTo::First(n)` emits the first `n` groups, keeps the rest, and
    /// shifts all surviving group indexes down by `n`.
    ///
    /// # Panics
    /// Panics if called before any groups have been interned.
    fn emit(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
        let mut group_values = self
            .group_values
            .take()
            .expect("Can not emit from empty rows");
        let mut output = match emit_to {
            EmitTo::All => {
                let output = self.row_converter.convert_rows(&group_values)?;
                group_values.clear();
                self.map.clear();
                output
            }
            EmitTo::First(n) => {
                let groups_rows = group_values.iter().take(n);
                let output = self.row_converter.convert_rows(groups_rows)?;
                // Rebuild the row store from the rows that remain after the
                // first `n` are emitted.
                let mut new_group_values = self.row_converter.empty_rows(0, 0);
                for row in group_values.iter().skip(n) {
                    new_group_values.push(row);
                }
                std::mem::swap(&mut new_group_values, &mut group_values);
                // Remap surviving map entries: drop groups < n, shift the
                // rest down by n so they index the rebuilt store.
                self.map.retain(|(_exists_hash, group_idx)| {
                    match group_idx.checked_sub(n) {
                        Some(sub) => {
                            *group_idx = sub;
                            true
                        }
                        None => false,
                    }
                });
                output
            }
        };
        // `convert_rows` materializes dictionary columns as plain values;
        // re-encode any column whose expected output type is (or contains)
        // a dictionary so the emitted schema matches `self.schema`.
        for (field, array) in self.schema.fields.iter().zip(&mut output) {
            let expected = field.data_type();
            *array = dictionary_encode_if_necessary(array, expected)?;
        }
        self.group_values = Some(group_values);
        Ok(output)
    }
    /// Clears all state, shrinking allocations toward `num_rows` capacity so
    /// the instance can be reused for a new partition.
    fn clear_shrink(&mut self, num_rows: usize) {
        self.group_values = self.group_values.take().map(|mut rows| {
            rows.clear();
            rows
        });
        self.map.clear();
        // Re-derive the accounted map size from the post-shrink capacity.
        self.map.shrink_to(num_rows, |_| 0); self.map_size = self.map.capacity() * size_of::<(u64, usize)>();
        self.hashes_buffer.clear();
        self.hashes_buffer.shrink_to(num_rows);
    }
}
/// Recursively casts `array` so its type matches `expected`, restoring
/// dictionary encoding where the expected type requires it.
///
/// `RowConverter::convert_rows` materializes dictionary columns as plain
/// values, so this walks struct and list children to re-apply the expected
/// encoding; arrays whose types already agree are returned as-is (shared).
fn dictionary_encode_if_necessary(
    array: &ArrayRef,
    expected: &DataType,
) -> Result<ArrayRef> {
    match (expected, array.data_type()) {
        // Struct: recurse into every child, matching the expected field types.
        (DataType::Struct(expected_fields), _) => {
            let struct_array = array.as_any().downcast_ref::<StructArray>().unwrap();
            let mut encoded_children = Vec::with_capacity(expected_fields.len());
            for (expected_field, child) in
                expected_fields.iter().zip(struct_array.columns())
            {
                encoded_children.push(dictionary_encode_if_necessary(
                    child,
                    expected_field.data_type(),
                )?);
            }
            let rebuilt = StructArray::try_new(
                expected_fields.clone(),
                encoded_children,
                struct_array.nulls().cloned(),
            )?;
            Ok(Arc::new(rebuilt))
        }
        // List: recurse into the values, keeping offsets and null buffer.
        (DataType::List(expected_field), &DataType::List(_)) => {
            let list_array = array.as_any().downcast_ref::<ListArray>().unwrap();
            let encoded_values = dictionary_encode_if_necessary(
                list_array.values(),
                expected_field.data_type(),
            )?;
            let rebuilt = ListArray::try_new(
                Arc::<arrow::datatypes::Field>::clone(expected_field),
                list_array.offsets().clone(),
                encoded_values,
                list_array.nulls().cloned(),
            )?;
            Ok(Arc::new(rebuilt))
        }
        // Expected a dictionary: cast the plain values back into it.
        (DataType::Dictionary(_, _), _) => Ok(cast(array.as_ref(), expected)?),
        // Anything else needs no re-encoding: share the same array.
        (_, _) => Ok(Arc::<dyn Array>::clone(array)),
    }
}