use crate::aggregates::group_values::multi_group_by::{
GroupColumn, Nulls, nulls_equal_to,
};
use crate::aggregates::group_values::null_builder::MaybeNullBufferBuilder;
use arrow::array::ArrowNativeTypeOp;
use arrow::array::{Array, ArrayRef, ArrowPrimitiveType, PrimitiveArray, cast::AsArray};
use arrow::buffer::ScalarBuffer;
use arrow::datatypes::DataType;
use datafusion_common::Result;
use datafusion_execution::memory_pool::proxy::VecAllocExt;
use itertools::izip;
use std::iter;
use std::sync::Arc;
/// A [`GroupColumn`] builder that stores one primitive value per group.
///
/// # Type parameters
///
/// * `T`: the Arrow primitive type of the column; values are stored as `T::Native`.
/// * `NULLABLE`: when `false`, the null buffer is never written and null
///   handling is skipped in every method.
#[derive(Debug)]
pub struct PrimitiveGroupValueBuilder<T: ArrowPrimitiveType, const NULLABLE: bool> {
    /// Logical Arrow type of the column. It is re-applied with `with_data_type`
    /// to every array this builder produces, so parameterized types (for
    /// example a timestamp with a timezone) are preserved rather than replaced
    /// by `T`'s default data type.
    data_type: DataType,
    /// One stored value per group. Slots belonging to null groups hold
    /// `T::default_value()`.
    group_values: Vec<T::Native>,
    /// Validity of each entry in `group_values`. Only appended to when
    /// `NULLABLE` is true.
    nulls: MaybeNullBufferBuilder,
}
impl<T, const NULLABLE: bool> PrimitiveGroupValueBuilder<T, NULLABLE>
where
    T: ArrowPrimitiveType,
{
    /// Creates an empty builder whose output arrays will carry `data_type`.
    pub fn new(data_type: DataType) -> Self {
        Self {
            data_type,
            group_values: vec![],
            nulls: MaybeNullBufferBuilder::new(),
        }
    }

    /// Vectorized equality check for the case where neither the stored groups
    /// nor `array` contain nulls.
    ///
    /// For each position `i`, compares `self.group_values[lhs_rows[i]]` with
    /// `array`'s value at `rhs_rows[i]` and ANDs the outcome into
    /// `equal_to_results[i]`, so a `false` left by an earlier column is never
    /// turned back into `true`. Iteration stops at the shortest of the three
    /// slices.
    ///
    /// # Panics
    ///
    /// * If `NULLABLE` is true and either side may contain nulls.
    /// * If `array` is not a `PrimitiveArray<T>` (from `as_primitive`).
    /// * In debug builds only, on an out-of-bounds row index.
    ///
    /// Callers must pass in-bounds row indices. Release builds skip the bounds
    /// check (see the `unsafe` blocks below).
    fn vectorized_equal_to_non_nullable(
        &self,
        lhs_rows: &[usize],
        array: &ArrayRef,
        rhs_rows: &[usize],
        equal_to_results: &mut [bool],
    ) {
        assert!(
            !NULLABLE || (array.null_count() == 0 && !self.nulls.might_have_nulls()),
            "called with nullable input"
        );
        let array_values = array.as_primitive::<T>().values();
        let iter = izip!(
            lhs_rows.iter(),
            rhs_rows.iter(),
            equal_to_results.iter_mut(),
        );
        for (&lhs_row, &rhs_row, equal_to_result) in iter {
            // The comparison is evaluated unconditionally, with no early
            // `continue` on an already-false result as in the nullable path.
            // NOTE(review): presumably this keeps the loop branch-free so it
            // can auto-vectorize.
            let result = {
                // Debug builds keep the bounds check to catch bad indices.
                let left = if cfg!(debug_assertions) {
                    self.group_values[lhs_row]
                } else {
                    // SAFETY: relies on the caller contract that
                    // `lhs_row < self.group_values.len()`; not checked here in
                    // release builds.
                    unsafe { *self.group_values.get_unchecked(lhs_row) }
                };
                let right = if cfg!(debug_assertions) {
                    array_values[rhs_row]
                } else {
                    // SAFETY: relies on the caller contract that
                    // `rhs_row < array.len()`; not checked here in release
                    // builds.
                    unsafe { *array_values.get_unchecked(rhs_row) }
                };
                left.is_eq(right)
            };
            *equal_to_result = result && *equal_to_result;
        }
    }

    /// Vectorized equality check that accounts for nulls on either side.
    ///
    /// Positions whose `equal_to_results` entry is already `false` (because an
    /// earlier column differed) are skipped. Otherwise null handling is
    /// delegated to `nulls_equal_to`. When it yields `Some`, that is the
    /// result; when it yields `None`, the stored and input values are compared
    /// with `is_eq`.
    ///
    /// # Panics
    ///
    /// * If `NULLABLE` is false.
    /// * If `array` is not a `PrimitiveArray<T>`, or a row index is out of bounds.
    pub fn vectorized_equal_nullable(
        &self,
        lhs_rows: &[usize],
        array: &ArrayRef,
        rhs_rows: &[usize],
        equal_to_results: &mut [bool],
    ) {
        assert!(NULLABLE, "called with non-nullable input");
        let array = array.as_primitive::<T>();
        let iter = izip!(
            lhs_rows.iter(),
            rhs_rows.iter(),
            equal_to_results.iter_mut(),
        );
        for (&lhs_row, &rhs_row, equal_to_result) in iter {
            // An earlier column already proved these rows differ.
            if !*equal_to_result {
                continue;
            }
            let exist_null = self.nulls.is_null(lhs_row);
            let input_null = array.is_null(rhs_row);
            if let Some(result) = nulls_equal_to(exist_null, input_null) {
                *equal_to_result = result;
                continue;
            }
            // Both sides are non-null, so compare the actual values.
            *equal_to_result = self.group_values[lhs_row].is_eq(array.value(rhs_row));
        }
    }
}
/// [`GroupColumn`] implementation for primitive values.
///
/// When `NULLABLE` is `false`, no null information is ever recorded and every
/// method skips null handling.
impl<T: ArrowPrimitiveType, const NULLABLE: bool> GroupColumn
    for PrimitiveGroupValueBuilder<T, NULLABLE>
{
    /// Returns whether stored group `lhs_row` equals row `rhs_row` of `array`.
    ///
    /// For nullable builders, null handling is delegated to `nulls_equal_to`.
    /// When it returns `Some`, that is the answer; otherwise the two values
    /// are compared with `is_eq`.
    fn equal_to(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool {
        if NULLABLE {
            let exist_null = self.nulls.is_null(lhs_row);
            let input_null = array.is_null(rhs_row);
            if let Some(result) = nulls_equal_to(exist_null, input_null) {
                return result;
            }
        }
        self.group_values[lhs_row].is_eq(array.as_primitive::<T>().value(rhs_row))
    }

    /// Appends row `row` of `array` as a new group.
    ///
    /// A null input row is stored as `T::default_value()` with its validity
    /// bit cleared. Non-nullable builders copy the value without consulting
    /// validity.
    fn append_val(&mut self, array: &ArrayRef, row: usize) -> Result<()> {
        if NULLABLE {
            if array.is_null(row) {
                self.nulls.append(true);
                self.group_values.push(T::default_value());
            } else {
                self.nulls.append(false);
                self.group_values.push(array.as_primitive::<T>().value(row));
            }
        } else {
            self.group_values.push(array.as_primitive::<T>().value(row));
        }
        Ok(())
    }

    /// Compares many (stored group, input row) pairs at once, ANDing each
    /// outcome into `equal_to_results`.
    ///
    /// Uses the null-free fast path when the builder is non-nullable, or when
    /// neither side can contain nulls.
    fn vectorized_equal_to(
        &self,
        lhs_rows: &[usize],
        array: &ArrayRef,
        rhs_rows: &[usize],
        equal_to_results: &mut [bool],
    ) {
        if !NULLABLE || (array.null_count() == 0 && !self.nulls.might_have_nulls()) {
            self.vectorized_equal_to_non_nullable(
                lhs_rows,
                array,
                rhs_rows,
                equal_to_results,
            );
        } else {
            self.vectorized_equal_nullable(lhs_rows, array, rhs_rows, equal_to_results);
        }
    }

    /// Appends the given `rows` of `array` as new groups, in order.
    ///
    /// Every branch appends exactly `rows.len()` values.
    fn vectorized_append(&mut self, array: &ArrayRef, rows: &[usize]) -> Result<()> {
        let arr = array.as_primitive::<T>();

        // Classify the *whole* input array (not just `rows`) so the all-valid
        // and all-null cases can avoid per-row null checks.
        let null_count = array.null_count();
        let num_rows = array.len();
        let all_null_or_non_null = if null_count == 0 {
            Nulls::None
        } else if null_count == num_rows {
            Nulls::All
        } else {
            Nulls::Some
        };

        match (NULLABLE, all_null_or_non_null) {
            (true, Nulls::Some) => {
                // Mixed validity needs a per-row check. Reserve once up front
                // instead of growing on each push.
                self.group_values.reserve(rows.len());
                for &row in rows {
                    if array.is_null(row) {
                        self.nulls.append(true);
                        self.group_values.push(T::default_value());
                    } else {
                        self.nulls.append(false);
                        self.group_values.push(arr.value(row));
                    }
                }
            }
            (true, Nulls::None) => {
                self.nulls.append_n(rows.len(), false);
                // `extend` over an exact-size iterator reserves a single time.
                self.group_values
                    .extend(rows.iter().map(|&row| arr.value(row)));
            }
            (true, Nulls::All) => {
                self.nulls.append_n(rows.len(), true);
                self.group_values
                    .extend(iter::repeat_n(T::default_value(), rows.len()));
            }
            (false, _) => {
                self.group_values
                    .extend(rows.iter().map(|&row| arr.value(row)));
            }
        }
        Ok(())
    }

    /// Number of groups stored so far.
    fn len(&self) -> usize {
        self.group_values.len()
    }

    /// Bytes allocated by the value vector plus the null buffer builder.
    fn size(&self) -> usize {
        self.group_values.allocated_size() + self.nulls.allocated_size()
    }

    /// Consumes the builder and returns all groups as a `PrimitiveArray`
    /// carrying the stored `data_type`.
    ///
    /// # Panics
    ///
    /// Panics if a non-nullable builder produced a null buffer.
    fn build(self: Box<Self>) -> ArrayRef {
        let Self {
            data_type,
            group_values,
            nulls,
        } = *self;
        let nulls = nulls.build();
        if !NULLABLE {
            assert!(nulls.is_none(), "unexpected nulls in non nullable input");
        }
        let arr = PrimitiveArray::<T>::new(ScalarBuffer::from(group_values), nulls);
        // Re-apply the logical type, e.g. to keep a timestamp's timezone.
        Arc::new(arr.with_data_type(data_type))
    }

    /// Removes the first `n` groups and returns them as an array. The
    /// remaining groups shift down to start at index 0.
    ///
    /// # Panics
    ///
    /// Panics if `n > self.len()` (from `Vec::drain`).
    fn take_n(&mut self, n: usize) -> ArrayRef {
        let first_n = self.group_values.drain(0..n).collect::<Vec<_>>();
        let first_n_nulls = if NULLABLE { self.nulls.take_n(n) } else { None };
        Arc::new(
            PrimitiveArray::<T>::new(ScalarBuffer::from(first_n), first_n_nulls)
                .with_data_type(self.data_type.clone()),
        )
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::aggregates::group_values::multi_group_by::primitive::PrimitiveGroupValueBuilder;
    use arrow::array::{ArrayRef, Float32Array, Int64Array, NullBufferBuilder};
    use arrow::datatypes::{DataType, Float32Type, Int64Type};

    use super::GroupColumn;

    /// Exercises the row-at-a-time `append_val` / `equal_to` API on a nullable
    /// builder.
    #[test]
    fn test_nullable_primitive_equal_to() {
        let append = |builder: &mut PrimitiveGroupValueBuilder<Float32Type, true>,
                      builder_array: &ArrayRef,
                      append_rows: &[usize]| {
            for &index in append_rows {
                builder.append_val(builder_array, index).unwrap();
            }
        };
        let equal_to = |builder: &PrimitiveGroupValueBuilder<Float32Type, true>,
                        lhs_rows: &[usize],
                        input_array: &ArrayRef,
                        rhs_rows: &[usize],
                        equal_to_results: &mut Vec<bool>| {
            let iter = lhs_rows.iter().zip(rhs_rows.iter());
            for (idx, (&lhs_row, &rhs_row)) in iter.enumerate() {
                equal_to_results[idx] = builder.equal_to(lhs_row, input_array, rhs_row);
            }
        };
        test_nullable_primitive_equal_to_internal(append, equal_to);
    }

    /// Exercises the vectorized `vectorized_append` / `vectorized_equal_to`
    /// API on a nullable builder.
    #[test]
    fn test_nullable_primitive_vectorized_equal_to() {
        let append = |builder: &mut PrimitiveGroupValueBuilder<Float32Type, true>,
                      builder_array: &ArrayRef,
                      append_rows: &[usize]| {
            builder
                .vectorized_append(builder_array, append_rows)
                .unwrap();
        };
        let equal_to = |builder: &PrimitiveGroupValueBuilder<Float32Type, true>,
                        lhs_rows: &[usize],
                        input_array: &ArrayRef,
                        rhs_rows: &[usize],
                        equal_to_results: &mut Vec<bool>| {
            builder.vectorized_equal_to(
                lhs_rows,
                input_array,
                rhs_rows,
                equal_to_results,
            );
        };
        test_nullable_primitive_equal_to_internal(append, equal_to);
    }

    /// Shared scenario for the nullable tests: stored groups vs. an input
    /// array mixing nulls, equal values, differing values and NaN.
    fn test_nullable_primitive_equal_to_internal<A, E>(mut append: A, mut equal_to: E)
    where
        A: FnMut(&mut PrimitiveGroupValueBuilder<Float32Type, true>, &ArrayRef, &[usize]),
        E: FnMut(
            &PrimitiveGroupValueBuilder<Float32Type, true>,
            &[usize],
            &ArrayRef,
            &[usize],
            &mut Vec<bool>,
        ),
    {
        let mut builder =
            PrimitiveGroupValueBuilder::<Float32Type, true>::new(DataType::Float32);
        let builder_array = Arc::new(Float32Array::from(vec![
            None,
            None,
            None,
            Some(1.0),
            Some(2.0),
            Some(f32::NAN),
            Some(3.0),
        ])) as ArrayRef;
        append(&mut builder, &builder_array, &[0, 1, 2, 3, 4, 5, 6]);

        // Take only the raw values, then attach an explicit validity buffer.
        // Row 1 keeps the value 2.0 hidden underneath a null, so the check
        // must honour validity instead of comparing the raw value.
        let (_, values, _nulls) = Float32Array::from(vec![
            Some(1.0),
            Some(2.0),
            None,
            Some(1.0),
            None,
            Some(f32::NAN),
            None,
        ])
        .into_parts();
        // Capacity hint matches the 7 slots appended below.
        let mut nulls = NullBufferBuilder::new(7);
        nulls.append_non_null();
        nulls.append_null(); // masks the 2.0 value at row 1
        nulls.append_null();
        nulls.append_non_null();
        nulls.append_null();
        nulls.append_non_null();
        nulls.append_null();
        let input_array = Arc::new(Float32Array::new(values, nulls.finish())) as ArrayRef;

        let mut equal_to_results = vec![true; builder.len()];
        equal_to(
            &builder,
            &[0, 1, 2, 3, 4, 5, 6],
            &input_array,
            &[0, 1, 2, 3, 4, 5, 6],
            &mut equal_to_results,
        );

        assert!(!equal_to_results[0]); // null vs 1.0
        assert!(equal_to_results[1]); // null vs null (masked 2.0)
        assert!(equal_to_results[2]); // null vs null
        assert!(equal_to_results[3]); // 1.0 vs 1.0
        assert!(!equal_to_results[4]); // 2.0 vs null
        assert!(equal_to_results[5]); // NaN vs NaN compares equal
        assert!(!equal_to_results[6]); // 3.0 vs null
    }

    /// Exercises the row-at-a-time API on a non-nullable builder.
    #[test]
    fn test_not_nullable_primitive_equal_to() {
        let append = |builder: &mut PrimitiveGroupValueBuilder<Int64Type, false>,
                      builder_array: &ArrayRef,
                      append_rows: &[usize]| {
            for &index in append_rows {
                builder.append_val(builder_array, index).unwrap();
            }
        };
        let equal_to = |builder: &PrimitiveGroupValueBuilder<Int64Type, false>,
                        lhs_rows: &[usize],
                        input_array: &ArrayRef,
                        rhs_rows: &[usize],
                        equal_to_results: &mut Vec<bool>| {
            let iter = lhs_rows.iter().zip(rhs_rows.iter());
            for (idx, (&lhs_row, &rhs_row)) in iter.enumerate() {
                equal_to_results[idx] = builder.equal_to(lhs_row, input_array, rhs_row);
            }
        };
        test_not_nullable_primitive_equal_to_internal(append, equal_to);
    }

    /// Exercises the vectorized API on a non-nullable builder.
    #[test]
    fn test_not_nullable_primitive_vectorized_equal_to() {
        let append = |builder: &mut PrimitiveGroupValueBuilder<Int64Type, false>,
                      builder_array: &ArrayRef,
                      append_rows: &[usize]| {
            builder
                .vectorized_append(builder_array, append_rows)
                .unwrap();
        };
        let equal_to = |builder: &PrimitiveGroupValueBuilder<Int64Type, false>,
                        lhs_rows: &[usize],
                        input_array: &ArrayRef,
                        rhs_rows: &[usize],
                        equal_to_results: &mut Vec<bool>| {
            builder.vectorized_equal_to(
                lhs_rows,
                input_array,
                rhs_rows,
                equal_to_results,
            );
        };
        test_not_nullable_primitive_equal_to_internal(append, equal_to);
    }

    /// Shared scenario for the non-nullable tests: one matching row, one
    /// differing row.
    fn test_not_nullable_primitive_equal_to_internal<A, E>(mut append: A, mut equal_to: E)
    where
        A: FnMut(&mut PrimitiveGroupValueBuilder<Int64Type, false>, &ArrayRef, &[usize]),
        E: FnMut(
            &PrimitiveGroupValueBuilder<Int64Type, false>,
            &[usize],
            &ArrayRef,
            &[usize],
            &mut Vec<bool>,
        ),
    {
        let mut builder =
            PrimitiveGroupValueBuilder::<Int64Type, false>::new(DataType::Int64);
        let builder_array =
            Arc::new(Int64Array::from(vec![Some(0), Some(1)])) as ArrayRef;
        append(&mut builder, &builder_array, &[0, 1]);

        let input_array = Arc::new(Int64Array::from(vec![Some(0), Some(2)])) as ArrayRef;
        let mut equal_to_results = vec![true; builder.len()];
        equal_to(
            &builder,
            &[0, 1],
            &input_array,
            &[0, 1],
            &mut equal_to_results,
        );

        assert!(equal_to_results[0]); // 0 vs 0
        assert!(!equal_to_results[1]); // 1 vs 2
    }

    /// Covers the all-null and all-valid fast paths of the vectorized API on
    /// a nullable builder.
    #[test]
    fn test_nullable_primitive_vectorized_operation_special_case() {
        let mut builder =
            PrimitiveGroupValueBuilder::<Int64Type, true>::new(DataType::Int64);

        // All-null input: every stored null group matches every null input row.
        let all_nulls_input_array = Arc::new(Int64Array::from(vec![
            Option::<i64>::None,
            None,
            None,
            None,
            None,
        ])) as _;
        builder
            .vectorized_append(&all_nulls_input_array, &[0, 1, 2, 3, 4])
            .unwrap();
        let mut equal_to_results = vec![true; all_nulls_input_array.len()];
        builder.vectorized_equal_to(
            &[0, 1, 2, 3, 4],
            &all_nulls_input_array,
            &[0, 1, 2, 3, 4],
            &mut equal_to_results,
        );
        assert!(equal_to_results[0]);
        assert!(equal_to_results[1]);
        assert!(equal_to_results[2]);
        assert!(equal_to_results[3]);
        assert!(equal_to_results[4]);

        // All-valid input appended after the nulls: groups 5..=9 must match
        // input rows 0..=4.
        let all_not_nulls_input_array = Arc::new(Int64Array::from(vec![
            Some(1),
            Some(2),
            Some(3),
            Some(4),
            Some(5),
        ])) as _;
        builder
            .vectorized_append(&all_not_nulls_input_array, &[0, 1, 2, 3, 4])
            .unwrap();
        let mut equal_to_results = vec![true; all_not_nulls_input_array.len()];
        builder.vectorized_equal_to(
            &[5, 6, 7, 8, 9],
            &all_not_nulls_input_array,
            &[0, 1, 2, 3, 4],
            &mut equal_to_results,
        );
        assert!(equal_to_results[0]);
        assert!(equal_to_results[1]);
        assert!(equal_to_results[2]);
        assert!(equal_to_results[3]);
        assert!(equal_to_results[4]);
    }
}