use vortex_error::VortexResult;
use vortex_mask::AllOr;
use vortex_mask::Mask;
use super::Sum;
use super::primitive::sum_float_all;
use super::primitive::sum_signed_all;
use super::primitive::sum_unsigned_all;
use crate::ArrayRef;
use crate::ExecutionCtx;
use crate::IntoArray;
use crate::aggregate_fn::AggregateFnRef;
use crate::aggregate_fn::GroupRanges;
use crate::aggregate_fn::GroupedArray;
use crate::aggregate_fn::kernels::DynGroupedAggregateKernel;
use crate::arrays::Primitive;
use crate::arrays::PrimitiveArray;
use crate::dtype::NativePType;
use crate::match_each_native_ptype;
/// Grouped-aggregate kernel that evaluates `Sum` directly over groups whose
/// elements are a primitive array.
///
/// The kernel declines (returns `Ok(None)`) for any aggregate function other than
/// `Sum`. `try_grouped_sum` likewise declines when the elements are not primitive.
#[derive(Debug)]
pub(crate) struct PrimitiveGroupedSumEncodingKernel;

impl DynGroupedAggregateKernel for PrimitiveGroupedSumEncodingKernel {
    fn grouped_aggregate(
        &self,
        aggregate_fn: &AggregateFnRef,
        groups: &GroupedArray,
        ctx: &mut ExecutionCtx,
    ) -> VortexResult<Option<ArrayRef>> {
        // Only `Sum` is handled here. `Ok(None)` leaves other functions to the
        // caller (presumably another kernel or a generic fallback; not visible here).
        if aggregate_fn.is::<Sum>() {
            try_grouped_sum(groups, ctx)
        } else {
            Ok(None)
        }
    }
}
/// Computes per-group sums when the grouped elements are a primitive array.
///
/// Returns `Ok(None)` when the elements are not primitive, so that a different code
/// path can handle them. Otherwise it resolves the group ranges and group validity,
/// then delegates to `grouped_sum`.
pub(super) fn try_grouped_sum(
    groups: &GroupedArray,
    ctx: &mut ExecutionCtx,
) -> VortexResult<Option<ArrayRef>> {
    let elements = groups.elements();
    if !elements.is::<Primitive>() {
        return Ok(None);
    }
    let primitive = elements.clone().downcast::<Primitive>();

    let ranges = groups.group_ranges(ctx)?;
    let validity = groups.group_validity(ctx)?;

    grouped_sum(&primitive, &ranges, &validity, ctx).map(Some)
}
/// Sums every group of a primitive `elements` array.
///
/// The accumulator is widened according to the element kind:
/// - unsigned integers accumulate into `u64`;
/// - signed integers accumulate into `i64`;
/// - floating point accumulates into `f64`.
///
/// A group's result is null when the group is invalid in `group_validity`, or when
/// the integer accumulator reports overflow. Null elements inside a valid group are
/// skipped.
fn grouped_sum(
    elements: &PrimitiveArray,
    group_ranges: &GroupRanges,
    group_validity: &Mask,
    ctx: &mut ExecutionCtx,
) -> VortexResult<ArrayRef> {
    let validity_mask = elements
        .as_ref()
        .validity()?
        .execute_mask(elements.as_ref().len(), ctx)?;
    // When no element is null, each group can be summed as one contiguous slice
    // without slicing the mask per group.
    let dense = matches!(validity_mask.slices(), AllOr::All);

    let sums = match_each_native_ptype!(elements.ptype(),
        unsigned: |T| {
            collect_sums::<T, u64>(
                elements.as_slice::<T>(),
                group_ranges,
                group_validity,
                &validity_mask,
                dense,
                sum_unsigned_all,
            )
        },
        signed: |T| {
            collect_sums::<T, i64>(
                elements.as_slice::<T>(),
                group_ranges,
                group_validity,
                &validity_mask,
                dense,
                sum_signed_all,
            )
        },
        floating: |T| {
            // Float summation never reports overflow, so the run always returns `false`.
            collect_sums::<T, f64>(
                elements.as_slice::<T>(),
                group_ranges,
                group_validity,
                &validity_mask,
                dense,
                |acc, slice| { sum_float_all(acc, slice); false },
            )
        }
    );
    Ok(sums.into_array())
}
/// Builds a nullable primitive array with one sum per group.
///
/// `sum_run` adds a contiguous run of values into the accumulator and returns
/// `true` when the accumulation overflowed. An overflowed group becomes null, as
/// does any group that is invalid in `group_validity`. When `all_valid` is set, the
/// whole group slice is summed directly. Otherwise only the valid runs of
/// `elem_mask` are summed.
fn collect_sums<T: NativePType, A: NativePType + Default>(
    values: &[T],
    group_ranges: &GroupRanges,
    group_validity: &Mask,
    elem_mask: &Mask,
    all_valid: bool,
    sum_run: impl Fn(&mut A, &[T]) -> bool,
) -> PrimitiveArray {
    // Sums one valid group, yielding `None` if the accumulator overflowed.
    let sum_group = |start: usize, len: usize| -> Option<A> {
        let mut total = A::default();
        let overflowed = if all_valid {
            sum_run(&mut total, &values[start..start + len])
        } else {
            sum_masked_group(&mut total, values, start, len, elem_mask, &sum_run)
        };
        if overflowed { None } else { Some(total) }
    };

    let per_group = group_ranges
        .iter()
        .enumerate()
        .map(|(group_idx, (start, len))| {
            if group_validity.value(group_idx) {
                sum_group(start, len)
            } else {
                None
            }
        });

    PrimitiveArray::from_option_iter(per_group)
}
/// Adds the valid elements of `values[offset..offset + size]` into `acc`.
///
/// The element mask is sliced to the group's range and handled by case:
/// - a fully valid group is summed as one run;
/// - a fully null group contributes nothing;
/// - otherwise each valid run is summed in order.
///
/// Returns `true` as soon as `sum_run` reports overflow. The remaining runs are
/// then not summed.
fn sum_masked_group<T: NativePType, A>(
    acc: &mut A,
    values: &[T],
    offset: usize,
    size: usize,
    elem_mask: &Mask,
    sum_run: &impl Fn(&mut A, &[T]) -> bool,
) -> bool {
    let end = offset + size;
    match elem_mask.slice(offset..end).slices() {
        AllOr::None => false,
        AllOr::All => sum_run(acc, &values[offset..end]),
        // Run bounds are relative to the sliced mask, so shift them by `offset`.
        AllOr::Some(runs) => runs.iter().any(|&(run_start, run_end)| {
            sum_run(acc, &values[offset + run_start..offset + run_end])
        }),
    }
}
#[cfg(test)]
mod tests {
    // `offsets_sizes` narrows `usize` offsets/sizes to `i32` with `as`. The test
    // ranges are tiny, so no truncation can actually occur.
    #![allow(clippy::cast_possible_truncation)]
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;
    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::aggregate_fn::DynGroupedAccumulator;
    use crate::aggregate_fn::EmptyOptions;
    use crate::aggregate_fn::GroupedAccumulator;
    use crate::aggregate_fn::fns::sum::Sum;
    use crate::aggregate_fn::fns::sum::sum;
    use crate::arrays::FixedSizeListArray;
    use crate::arrays::ListViewArray;
    use crate::arrays::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::builders::builder_with_capacity;
    use crate::dtype::DType;
    use crate::dtype::Nullability::NonNullable;
    use crate::dtype::Nullability::Nullable;
    use crate::dtype::PType;
    use crate::validity::Validity;

    /// Runs the grouped `Sum` accumulator over the list-like `groups` array and
    /// returns the finished per-group results.
    fn grouped_sum_actual(groups: &ArrayRef, elem_dtype: &DType) -> VortexResult<ArrayRef> {
        let mut acc = GroupedAccumulator::try_new(Sum, EmptyOptions, elem_dtype.clone())?;
        acc.accumulate_list(groups, &mut LEGACY_SESSION.create_execution_ctx())?;
        acc.finish()
    }

    /// Builds the expected grouped result independently of the grouped kernel.
    ///
    /// Each valid group is sliced out of `elements` and summed with the scalar `sum`
    /// function. Invalid groups become null. The output dtype is `Sum`'s partial
    /// dtype for `elem_dtype`.
    fn grouped_sum_reference(
        elements: &ArrayRef,
        ranges: &[(usize, usize)],
        group_valid: &[bool],
        elem_dtype: &DType,
    ) -> VortexResult<ArrayRef> {
        use crate::aggregate_fn::AggregateFnVTable;
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let sum_dtype = Sum
            .partial_dtype(&EmptyOptions, elem_dtype)
            .expect("sum partial dtype");
        let mut builder = builder_with_capacity(&sum_dtype, ranges.len());
        for (i, &(offset, size)) in ranges.iter().enumerate() {
            if group_valid[i] {
                let slice = elements.slice(offset..offset + size)?;
                builder.append_scalar(&sum(&slice, &mut ctx)?)?;
            } else {
                builder.append_null();
            }
        }
        Ok(builder.finish())
    }

    /// Splits `(offset, size)` pairs into separate `i32` offsets and sizes arrays.
    fn offsets_sizes(ranges: &[(usize, usize)]) -> (ArrayRef, ArrayRef) {
        let offsets = PrimitiveArray::from_iter(ranges.iter().map(|&(o, _)| o as i32));
        let sizes = PrimitiveArray::from_iter(ranges.iter().map(|&(_, s)| s as i32));
        (offsets.into_array(), sizes.into_array())
    }

    /// Builds a `ListViewArray` over `elements` from `(offset, size)` ranges.
    ///
    /// The list validity is `NonNullable` when every group is valid; otherwise it
    /// is taken from `group_valid`.
    fn listview(
        elements: ArrayRef,
        ranges: &[(usize, usize)],
        group_valid: &[bool],
    ) -> VortexResult<ArrayRef> {
        let (offsets, sizes) = offsets_sizes(ranges);
        let validity = if group_valid.iter().all(|&v| v) {
            Validity::NonNullable
        } else {
            Validity::from_iter(group_valid.iter().copied())
        };
        Ok(ListViewArray::try_new(elements, offsets, sizes, validity)?.into_array())
    }

    // Unsigned elements widen to u64: [1,2] -> 3, [3] -> 3, [4,5,6] -> 15.
    #[test]
    fn listview_matches_reference_unsigned() -> VortexResult<()> {
        let elements =
            PrimitiveArray::new(buffer![1u32, 2, 3, 4, 5, 6], Validity::NonNullable).into_array();
        let elem_dtype = DType::Primitive(PType::U32, NonNullable);
        let ranges = [(0, 2), (2, 1), (3, 3)];
        let valid = [true, true, true];
        let groups = listview(elements.clone(), &ranges, &valid)?;
        let actual = grouped_sum_actual(&groups, &elem_dtype)?;
        let expected = grouped_sum_reference(&elements, &ranges, &valid, &elem_dtype)?;
        let direct = PrimitiveArray::from_option_iter([Some(3u64), Some(3u64), Some(15u64)]);
        assert_arrays_eq!(&actual, &direct.into_array());
        assert_arrays_eq!(&actual, &expected);
        Ok(())
    }

    // Offsets need not be increasing. A null group yields null even though its
    // elements are valid.
    #[test]
    fn listview_out_of_order_offsets_with_null_group() -> VortexResult<()> {
        let elements =
            PrimitiveArray::new(buffer![10i32, 20, 30, 40, 50, 60], Validity::NonNullable)
                .into_array();
        let elem_dtype = DType::Primitive(PType::I32, NonNullable);
        let ranges = [(4, 2), (0, 2), (2, 2)];
        let valid = [true, false, true];
        let groups = listview(elements.clone(), &ranges, &valid)?;
        let actual = grouped_sum_actual(&groups, &elem_dtype)?;
        let expected = grouped_sum_reference(&elements, &ranges, &valid, &elem_dtype)?;
        let direct = PrimitiveArray::from_option_iter([Some(110i64), None, Some(70i64)]);
        assert_arrays_eq!(&actual, &direct.into_array());
        assert_arrays_eq!(&actual, &expected);
        Ok(())
    }

    // Null elements are skipped. Groups that are entirely null, or empty, sum to
    // 0 rather than null.
    #[test]
    fn listview_interior_and_full_nulls() -> VortexResult<()> {
        let elements =
            PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), None, None, Some(9)])
                .into_array();
        let elem_dtype = DType::Primitive(PType::I32, Nullable);
        let ranges = [(0, 3), (3, 2), (5, 0), (5, 1)];
        let valid = [true, true, true, true];
        let groups = listview(elements.clone(), &ranges, &valid)?;
        let actual = grouped_sum_actual(&groups, &elem_dtype)?;
        let expected = grouped_sum_reference(&elements, &ranges, &valid, &elem_dtype)?;
        let direct =
            PrimitiveArray::from_option_iter([Some(4i64), Some(0i64), Some(0i64), Some(9i64)]);
        assert_arrays_eq!(&actual, &direct.into_array());
        assert_arrays_eq!(&actual, &expected);
        Ok(())
    }

    // i64::MAX + 1 overflows, so that group is null; the other group is unaffected.
    #[test]
    fn listview_overflow_group_is_null() -> VortexResult<()> {
        let elements =
            PrimitiveArray::new(buffer![i64::MAX, 1, 2, 3], Validity::NonNullable).into_array();
        let elem_dtype = DType::Primitive(PType::I64, NonNullable);
        let ranges = [(0, 2), (2, 2)];
        let valid = [true, true];
        let groups = listview(elements.clone(), &ranges, &valid)?;
        let actual = grouped_sum_actual(&groups, &elem_dtype)?;
        let expected = grouped_sum_reference(&elements, &ranges, &valid, &elem_dtype)?;
        let direct = PrimitiveArray::from_option_iter([None, Some(5i64)]);
        assert_arrays_eq!(&actual, &direct.into_array());
        assert_arrays_eq!(&actual, &expected);
        Ok(())
    }

    // Group 0 is [1.0, NaN, 2.0] and is expected to be 3.0, so a NaN input is
    // expected not to poison the sum. Group 1 is [inf, -inf, 4.0] and is expected
    // to be NaN. Scalars are compared individually because the second result is NaN.
    #[test]
    fn listview_float_nan_and_inf() -> VortexResult<()> {
        let elements = PrimitiveArray::new(
            buffer![1.0f64, f64::NAN, 2.0, f64::INFINITY, f64::NEG_INFINITY, 4.0],
            Validity::NonNullable,
        )
        .into_array();
        let elem_dtype = DType::Primitive(PType::F64, NonNullable);
        let ranges = [(0, 3), (3, 3)];
        let valid = [true, true];
        let groups = listview(elements.clone(), &ranges, &valid)?;
        let actual = grouped_sum_actual(&groups, &elem_dtype)?;
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let expected = grouped_sum_reference(&elements, &ranges, &valid, &elem_dtype)?;
        let g0 = actual.execute_scalar(0, &mut ctx)?;
        assert_eq!(g0.as_primitive().typed_value::<f64>(), Some(3.0));
        assert_eq!(
            g0.as_primitive().typed_value::<f64>(),
            expected
                .execute_scalar(0, &mut ctx)?
                .as_primitive()
                .typed_value::<f64>()
        );
        let g1 = actual.execute_scalar(1, &mut ctx)?;
        assert!(g1.as_primitive().typed_value::<f64>().unwrap().is_nan());
        assert!(
            expected
                .execute_scalar(1, &mut ctx)?
                .as_primitive()
                .typed_value::<f64>()
                .unwrap()
                .is_nan()
        );
        Ok(())
    }

    // Fixed-size lists of width 2: [i64::MAX, 1] overflows to null and [2, 3]
    // sums to 5. There is no float/NaN input here; that case is covered by
    // `listview_float_nan_and_inf`.
    #[test]
    fn fixed_size_overflow_group_is_null() -> VortexResult<()> {
        let elements =
            PrimitiveArray::new(buffer![i64::MAX, 1, 2, 3], Validity::NonNullable).into_array();
        let elem_dtype = DType::Primitive(PType::I64, NonNullable);
        let groups = FixedSizeListArray::try_new(elements.clone(), 2, Validity::NonNullable, 2)?
            .into_array();
        let actual = grouped_sum_actual(&groups, &elem_dtype)?;
        let expected =
            grouped_sum_reference(&elements, &[(0, 2), (2, 2)], &[true, true], &elem_dtype)?;
        let direct = PrimitiveArray::from_option_iter([None, Some(5i64)]);
        assert_arrays_eq!(&actual, &direct.into_array());
        assert_arrays_eq!(&actual, &expected);
        Ok(())
    }
}