use crate::utils::{build_bool_state, is_valid_decimal_precision};
use arrow::array::{
cast::AsArray, types::Decimal128Type, Array, ArrayRef, BooleanArray, Decimal128Array,
};
use arrow::datatypes::{DataType, Field, FieldRef};
use arrow::{array::BooleanBufferBuilder, buffer::NullBuffer};
use datafusion::common::{DataFusionError, Result as DFResult, ScalarValue};
use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion::logical_expr::Volatility::Immutable;
use datafusion::logical_expr::{
Accumulator, AggregateUDFImpl, EmitTo, GroupsAccumulator, ReversedUDAF, Signature,
};
use std::{any::Any, ops::BitAnd, sync::Arc};
/// Decimal `SUM` aggregate function, registered under the name `"sum"`.
///
/// Instances are created through `SumDecimal::try_new`, which only accepts a
/// `DataType::Decimal128` and caches its precision and scale.
#[derive(Debug)]
pub struct SumDecimal {
/// Signature of the aggregate; built as `Signature::user_defined(Immutable)`.
signature: Signature,
/// Data type passed to `try_new`; reported as the return type and used for the sum state field.
result_type: DataType,
/// Precision extracted from `result_type`.
precision: u8,
/// Scale extracted from `result_type`.
scale: i8,
}
impl SumDecimal {
    /// Builds the aggregate for the given result type.
    ///
    /// The precision and scale are read from `data_type` and cached next to it.
    ///
    /// # Errors
    ///
    /// Returns `DataFusionError::Internal` when `data_type` is not a
    /// `DataType::Decimal128`.
    pub fn try_new(data_type: DataType) -> DFResult<Self> {
        // Only `u8`/`i8` are copied out of the pattern, so `data_type` itself stays usable.
        let DataType::Decimal128(precision, scale) = data_type else {
            return Err(DataFusionError::Internal(
                "Invalid data type for SumDecimal".into(),
            ));
        };
        let signature = Signature::user_defined(Immutable);
        Ok(Self {
            signature,
            result_type: data_type,
            precision,
            scale,
        })
    }
}
impl AggregateUDFImpl for SumDecimal {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Name under which the aggregate is exposed.
    fn name(&self) -> &str {
        "sum"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// The result type is fixed at construction time; the argument types are ignored.
    fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
        Ok(self.result_type.clone())
    }

    /// Always nullable: the accumulators in this file emit null for an empty or
    /// overflowed sum.
    fn is_nullable(&self) -> bool {
        true
    }

    /// Describes the two-column intermediate state: the running sum (named after
    /// the function, typed as the result type) and a non-nullable `is_empty` flag.
    fn state_fields(&self, _args: StateFieldsArgs) -> DFResult<Vec<FieldRef>> {
        let sum_field: FieldRef = Arc::new(Field::new(
            self.name(),
            self.result_type.clone(),
            self.is_nullable(),
        ));
        let is_empty_field: FieldRef = Arc::new(Field::new("is_empty", DataType::Boolean, false));
        Ok(vec![sum_field, is_empty_field])
    }

    /// Creates the row-at-a-time accumulator.
    fn accumulator(&self, _args: AccumulatorArgs) -> DFResult<Box<dyn Accumulator>> {
        let acc = SumDecimalAccumulator::new(self.precision, self.scale);
        Ok(Box::new(acc))
    }

    /// A grouped accumulator is always available.
    fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
        true
    }

    /// Creates the grouped accumulator.
    fn create_groups_accumulator(
        &self,
        _args: AccumulatorArgs,
    ) -> DFResult<Box<dyn GroupsAccumulator>> {
        let acc = SumDecimalGroupsAccumulator::new(self.result_type.clone(), self.precision);
        Ok(Box::new(acc))
    }

    /// The default value is a null decimal with this aggregate's precision and scale.
    fn default_value(&self, _data_type: &DataType) -> DFResult<ScalarValue> {
        let decimal_type = DataType::Decimal128(self.precision, self.scale);
        ScalarValue::new_primitive::<Decimal128Type>(None, &decimal_type)
    }

    fn reverse_expr(&self) -> ReversedUDAF {
        ReversedUDAF::Identical
    }
}
/// Row-at-a-time accumulator state for the decimal sum.
///
/// `!is_empty && !is_not_null` is the combination the code treats as "the sum
/// overflowed".
#[derive(Debug)]
struct SumDecimalAccumulator {
/// Running total as an unscaled `i128`.
sum: i128,
/// True until a batch containing at least one non-null value has been seen.
is_empty: bool,
/// Cleared when an addition overflows `i128` or no longer fits `precision`.
is_not_null: bool,
/// Decimal precision used for the overflow check and for emitted scalars.
precision: u8,
/// Decimal scale used for emitted scalars.
scale: i8,
}
impl SumDecimalAccumulator {
    /// Creates an accumulator with a zero sum that is marked empty and not overflowed.
    fn new(precision: u8, scale: i8) -> Self {
        Self {
            sum: 0,
            is_empty: true,
            is_not_null: true,
            precision,
            scale,
        }
    }

    /// Adds the value stored at `idx` of `values` to the running sum.
    ///
    /// `idx` must be in bounds; callers are expected to skip null slots.
    ///
    /// If the addition overflows `i128` or the result no longer fits the configured
    /// precision, the accumulator is flagged as overflowed (`is_not_null = false`)
    /// and the sum is left untouched. The flag is sticky: once it is cleared, every
    /// later value is ignored. Without that, a later value that happens to fit would
    /// re-arm the flag and the overflowing value would be silently dropped from an
    /// apparently valid result.
    fn update_single(&mut self, values: &Decimal128Array, idx: usize) {
        if !self.is_not_null {
            // Already overflowed; the result is null no matter what follows.
            return;
        }
        // SAFETY: the caller in this file iterates `0..values.len()`, so `idx` is in bounds.
        let v = unsafe { values.value_unchecked(idx) };
        let (new_sum, is_overflow) = self.sum.overflowing_add(v);
        if is_overflow || !is_valid_decimal_precision(new_sum, self.precision) {
            self.is_not_null = false;
            return;
        }
        self.sum = new_sum;
    }
}
impl Accumulator for SumDecimalAccumulator {
    /// Folds one batch of input values into the running sum.
    ///
    /// Null slots are skipped. As soon as the sum is flagged as overflowed, the
    /// rest of this batch and all later batches are ignored, so an overflow can
    /// never be "undone" by a later value that happens to fit.
    ///
    /// # Panics
    ///
    /// Panics if `values` does not hold exactly one array, or if that array is not
    /// a `Decimal128` primitive array.
    fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> {
        assert_eq!(
            values.len(),
            1,
            "Expect only one element in 'values' but found {}",
            values.len()
        );
        // Non-empty with a null sum means a previous addition overflowed.
        if !self.is_empty && !self.is_not_null {
            return Ok(());
        }
        let values = &values[0];
        let data = values.as_primitive::<Decimal128Type>();
        // Stays empty only while every value seen so far was null.
        self.is_empty = self.is_empty && values.len() == values.null_count();
        let has_nulls = values.null_count() != 0;
        for i in 0..data.len() {
            if has_nulls && data.is_null(i) {
                continue;
            }
            self.update_single(data, i);
            if !self.is_not_null {
                // Overflow is final; the remaining values cannot change the outcome.
                break;
            }
        }
        Ok(())
    }

    /// Produces the final value.
    ///
    /// Returns a null decimal when no non-null value was seen, when an overflow
    /// was recorded, or when the sum does not fit the precision; otherwise returns
    /// the sum as a decimal scalar.
    fn evaluate(&mut self) -> DFResult<ScalarValue> {
        if self.is_empty
            || !self.is_not_null
            || !is_valid_decimal_precision(self.sum, self.precision)
        {
            ScalarValue::new_primitive::<Decimal128Type>(
                None,
                &DataType::Decimal128(self.precision, self.scale),
            )
        } else {
            ScalarValue::try_new_decimal128(self.sum, self.precision, self.scale)
        }
    }

    /// Size of this accumulator in bytes; it owns no heap data.
    fn size(&self) -> usize {
        std::mem::size_of_val(self)
    }

    /// Serializes the intermediate state as `[sum, is_empty]`.
    ///
    /// The sum is emitted as null when the overflow flag is set, which is how
    /// `merge_batch` recognises an overflowed partial state.
    fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
        let sum = if self.is_not_null {
            ScalarValue::try_new_decimal128(self.sum, self.precision, self.scale)?
        } else {
            ScalarValue::new_primitive::<Decimal128Type>(
                None,
                &DataType::Decimal128(self.precision, self.scale),
            )?
        };
        Ok(vec![sum, ScalarValue::from(self.is_empty)])
    }

    /// Merges one partial state (`[sum, is_empty]`, one row each) into this one.
    ///
    /// The merged state is overflowed if either side was overflowed, or if adding
    /// the two partial sums overflows `i128` or exceeds the precision. The addition
    /// is checked explicitly because a plain `+=` panics in debug builds and wraps
    /// silently in release builds.
    ///
    /// # Panics
    ///
    /// Panics if `states` does not hold exactly two single-row arrays, or if they
    /// do not have the expected `Decimal128` / `Boolean` types.
    fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
        assert_eq!(
            states.len(),
            2,
            "Expect two element in 'states' but found {}",
            states.len()
        );
        assert_eq!(states[0].len(), 1);
        assert_eq!(states[1].len(), 1);
        let that_sum = states[0].as_primitive::<Decimal128Type>();
        let that_is_empty = states[1]
            .as_any()
            .downcast_ref::<BooleanArray>()
            .expect("'is_empty' state must be a BooleanArray")
            .value(0);
        let this_overflow = !self.is_empty && !self.is_not_null;
        let that_overflow = !that_is_empty && that_sum.is_null(0);
        self.is_not_null = !this_overflow && !that_overflow;
        self.is_empty = self.is_empty && that_is_empty;
        if self.is_not_null {
            let (merged, wrapped) = self.sum.overflowing_add(that_sum.value(0));
            if wrapped || !is_valid_decimal_precision(merged, self.precision) {
                // Same rule as `update_single`: flag the overflow and keep the old sum.
                self.is_not_null = false;
            } else {
                self.sum = merged;
            }
        }
        Ok(())
    }
}
/// Grouped accumulator state for the decimal sum; every field is indexed by group.
///
/// A group with `is_empty == false` and `is_not_null == false` is treated as
/// overflowed.
struct SumDecimalGroupsAccumulator {
/// One bit per group; `true` means the group's sum is not null. New groups start as `true`.
is_not_null: BooleanBufferBuilder,
/// One bit per group; `true` until a non-null value has been added to the group.
is_empty: BooleanBufferBuilder,
/// Running total per group as an unscaled `i128`.
sum: Vec<i128>,
/// Data type applied to the emitted decimal arrays.
result_type: DataType,
/// Decimal precision used for the overflow checks.
precision: u8,
}
impl SumDecimalGroupsAccumulator {
    /// Creates an accumulator that tracks no groups yet.
    fn new(result_type: DataType, precision: u8) -> Self {
        let is_not_null = BooleanBufferBuilder::new(0);
        let is_empty = BooleanBufferBuilder::new(0);
        Self {
            is_not_null,
            is_empty,
            sum: Vec::new(),
            result_type,
            precision,
        }
    }

    /// Reports whether the group at `index` has overflowed, i.e. it has seen a
    /// value (`is_empty` cleared) and its sum is flagged null.
    fn is_overflow(&self, index: usize) -> bool {
        !(self.is_empty.get_bit(index) || self.is_not_null.get_bit(index))
    }

    /// Adds `value` to the sum of `group_index` and marks the group as non-empty.
    ///
    /// The (possibly wrapped) total is always stored. If the addition wrapped or
    /// the total does not fit the precision, the group is flagged null. This method
    /// never sets the flag back to `true`.
    #[inline]
    fn update_single(&mut self, group_index: usize, value: i128) {
        self.is_empty.set_bit(group_index, false);
        let slot = &mut self.sum[group_index];
        let (total, wrapped) = slot.overflowing_add(value);
        *slot = total;
        let fits = !wrapped && is_valid_decimal_precision(total, self.precision);
        if !fits {
            self.is_not_null.set_bit(group_index, false);
        }
    }
}
/// Grows `builder` to at least `capacity` bits, filling every new bit with `true`.
///
/// A builder that is already long enough is left untouched.
fn ensure_bit_capacity(builder: &mut BooleanBufferBuilder, capacity: usize) {
    let missing = capacity.saturating_sub(builder.len());
    if missing > 0 {
        builder.append_n(missing, true);
    }
}
impl GroupsAccumulator for SumDecimalGroupsAccumulator {
    /// Adds each non-null input value to the sum of its group.
    ///
    /// `group_indices[i]` is the group that row `i` of `values[0]` belongs to. The
    /// per-group state is first grown to `total_num_groups`.
    ///
    /// # Panics
    ///
    /// Panics if a filter is supplied, if `values` does not hold exactly one array,
    /// or if that array is not a `Decimal128` primitive array.
    fn update_batch(
        &mut self,
        values: &[ArrayRef],
        group_indices: &[usize],
        opt_filter: Option<&BooleanArray>,
        total_num_groups: usize,
    ) -> DFResult<()> {
        assert!(opt_filter.is_none(), "opt_filter is not supported yet");
        assert_eq!(values.len(), 1);
        let values = values[0].as_primitive::<Decimal128Type>();
        let data = values.values();
        // Make room for groups that appear for the first time in this batch.
        self.sum.resize(total_num_groups, 0);
        ensure_bit_capacity(&mut self.is_empty, total_num_groups);
        ensure_bit_capacity(&mut self.is_not_null, total_num_groups);
        let iter = group_indices.iter().zip(data.iter());
        if values.null_count() == 0 {
            // No nulls: skip the per-row validity lookup.
            for (&group_index, &value) in iter {
                self.update_single(group_index, value);
            }
        } else {
            for (idx, (&group_index, &value)) in iter.enumerate() {
                if values.is_null(idx) {
                    continue;
                }
                self.update_single(group_index, value);
            }
        }
        Ok(())
    }

    /// Emits the final sums for the groups selected by `emit_to`.
    ///
    /// A group's value is non-null only if it saw at least one non-null input, was
    /// not flagged as overflowed, and its sum fits the precision.
    fn evaluate(&mut self, emit_to: EmitTo) -> DFResult<ArrayRef> {
        let result = emit_to.take_needed(&mut self.sum);
        // Final precision check: flag any emitted sum that does not fit.
        result.iter().enumerate().for_each(|(i, &v)| {
            if !is_valid_decimal_precision(v, self.precision) {
                self.is_not_null.set_bit(i, false);
            }
        });
        let nulls = build_bool_state(&mut self.is_not_null, &emit_to);
        let is_empty = build_bool_state(&mut self.is_empty, &emit_to);
        // Valid = not empty AND not null.
        let x = (!&is_empty).bitand(&nulls);
        let result = Decimal128Array::new(result.into(), Some(NullBuffer::new(x)))
            .with_data_type(self.result_type.clone());
        Ok(Arc::new(result))
    }

    /// Emits the intermediate state for the groups selected by `emit_to` as
    /// `[sum, is_empty]`.
    ///
    /// The sum column uses `is_not_null` as its validity mask, so an overflowed
    /// group is encoded as "null sum, not empty".
    fn state(&mut self, emit_to: EmitTo) -> DFResult<Vec<ArrayRef>> {
        let nulls = build_bool_state(&mut self.is_not_null, &emit_to);
        let nulls = Some(NullBuffer::new(nulls));
        let sum = emit_to.take_needed(&mut self.sum);
        let sum = Decimal128Array::new(sum.into(), nulls.clone())
            .with_data_type(self.result_type.clone());
        let is_empty = build_bool_state(&mut self.is_empty, &emit_to);
        let is_empty = BooleanArray::new(is_empty, None);
        Ok(vec![
            Arc::new(sum) as ArrayRef,
            Arc::new(is_empty) as ArrayRef,
        ])
    }

    /// Merges partial states (`values[0]` = sums, `values[1]` = `is_empty` flags)
    /// into the groups named by `group_indices`.
    ///
    /// A group ends up overflowed if either side was overflowed, or if adding the
    /// two partial sums overflows `i128` or exceeds the precision. The addition is
    /// checked explicitly because a plain `+=` panics in debug builds and wraps
    /// silently in release builds, where a wrapped value could land back inside the
    /// valid range and be emitted as a correct-looking result.
    ///
    /// # Panics
    ///
    /// Panics if `values` does not hold exactly two arrays of the expected types,
    /// if a filter is supplied, or if the state arrays do not have one row per
    /// entry of `group_indices`.
    fn merge_batch(
        &mut self,
        values: &[ArrayRef],
        group_indices: &[usize],
        opt_filter: Option<&BooleanArray>,
        total_num_groups: usize,
    ) -> DFResult<()> {
        assert_eq!(
            values.len(),
            2,
            "Expected two arrays: 'sum' and 'is_empty', but found {}",
            values.len()
        );
        assert!(opt_filter.is_none(), "opt_filter is not supported yet");
        self.sum.resize(total_num_groups, 0);
        ensure_bit_capacity(&mut self.is_empty, total_num_groups);
        ensure_bit_capacity(&mut self.is_not_null, total_num_groups);
        let that_sum = values[0].as_primitive::<Decimal128Type>();
        let that_is_empty = values[1]
            .as_any()
            .downcast_ref::<BooleanArray>()
            .expect("'is_empty' state must be a BooleanArray");
        // These checks make the unchecked reads below sound.
        assert_eq!(
            that_sum.len(),
            group_indices.len(),
            "'sum' state must have one row per group index"
        );
        assert_eq!(
            that_is_empty.len(),
            group_indices.len(),
            "'is_empty' state must have one row per group index"
        );
        for (idx, &group_index) in group_indices.iter().enumerate() {
            let this_overflow = self.is_overflow(group_index);
            // SAFETY: `idx < group_indices.len()`, which equals the array length (asserted above).
            let that_is_empty = unsafe { that_is_empty.value_unchecked(idx) };
            let that_overflow = !that_is_empty && that_sum.is_null(idx);
            let is_overflow = this_overflow || that_overflow;
            self.is_not_null.set_bit(group_index, !is_overflow);
            self.is_empty.set_bit(
                group_index,
                self.is_empty.get_bit(group_index) && that_is_empty,
            );
            if !is_overflow {
                // SAFETY: same bound as above.
                let that_value = unsafe { that_sum.value_unchecked(idx) };
                let (merged, wrapped) = self.sum[group_index].overflowing_add(that_value);
                // Mirror `update_single`: store the total and flag the group if it does not fit.
                self.sum[group_index] = merged;
                if wrapped || !is_valid_decimal_precision(merged, self.precision) {
                    self.is_not_null.set_bit(group_index, false);
                }
            }
        }
        Ok(())
    }

    /// Approximate heap usage in bytes: the sum vector plus the two bit buffers
    /// (their capacities divided by 8).
    fn size(&self) -> usize {
        self.sum.capacity() * std::mem::size_of::<i128>()
            + self.is_empty.capacity() / 8
            + self.is_not_null.capacity() / 8
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::builder::{Decimal128Builder, StringBuilder};
    use arrow::array::RecordBatch;
    use arrow::datatypes::*;
    use datafusion::common::Result;
    use datafusion::datasource::memory::MemorySourceConfig;
    use datafusion::datasource::source::DataSourceExec;
    use datafusion::execution::TaskContext;
    use datafusion::logical_expr::AggregateUDF;
    use datafusion::physical_expr::aggregate::AggregateExprBuilder;
    use datafusion::physical_expr::expressions::Column;
    use datafusion::physical_expr::PhysicalExpr;
    use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy};
    use datafusion::physical_plan::ExecutionPlan;
    use futures::StreamExt;

    /// `try_new` must reject anything that is not a `Decimal128`.
    #[test]
    fn invalid_data_type() {
        let outcome = SumDecimal::try_new(DataType::Int32);
        assert!(outcome.is_err());
    }

    /// Runs a partial grouped aggregation over ten copies of the same batch and
    /// checks that every output batch is produced without error.
    #[tokio::test]
    async fn sum_no_overflow() -> Result<()> {
        let num_rows = 8192;
        let batch = create_record_batch(num_rows);
        let schema = batch.schema();
        let batches: Vec<RecordBatch> = (0..10).map(|_| batch.clone()).collect();
        let partitions = &[batches];

        let group_col: Arc<dyn PhysicalExpr> = Arc::new(Column::new("c0", 0));
        let value_col: Arc<dyn PhysicalExpr> = Arc::new(Column::new("c1", 1));

        let source = MemorySourceConfig::try_new(partitions, Arc::clone(&schema), None).unwrap();
        let scan: Arc<dyn ExecutionPlan> = Arc::new(DataSourceExec::new(Arc::new(source)));

        let sum_decimal = SumDecimal::try_new(DataType::Decimal128(8, 2))?;
        let aggregate_udf = Arc::new(AggregateUDF::new_from_impl(sum_decimal));
        let aggr_expr = AggregateExprBuilder::new(aggregate_udf, vec![value_col])
            .schema(Arc::clone(&schema))
            .alias("sum")
            .with_ignore_nulls(false)
            .with_distinct(false)
            .build()?;

        let group_by = PhysicalGroupBy::new_single(vec![(group_col, "c0".to_string())]);
        let aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Partial,
            group_by,
            vec![aggr_expr.into()],
            vec![None],
            scan,
            Arc::clone(&schema),
        )?);

        let task_ctx = Arc::new(TaskContext::default());
        let mut stream = aggregate.execute(0, task_ctx).unwrap();
        while let Some(next) = stream.next().await {
            let _batch = next?;
        }
        Ok(())
    }

    /// Builds a two-column batch: `c0` holds strings cycling through 1024 distinct
    /// values and `c1` holds the row index as a decimal.
    fn create_record_batch(num_rows: usize) -> RecordBatch {
        let mut decimal_builder = Decimal128Builder::with_capacity(num_rows);
        let mut string_builder = StringBuilder::with_capacity(num_rows, num_rows * 32);
        (0..num_rows).for_each(|i| {
            decimal_builder.append_value(i as i128);
            string_builder.append_value(format!("this is string #{}", i % 1024));
        });

        let string_array: ArrayRef = Arc::new(string_builder.finish());
        let decimal_array: ArrayRef = Arc::new(decimal_builder.finish());

        let schema = Schema::new(vec![
            Field::new("c0", DataType::Utf8, false),
            Field::new("c1", DataType::Decimal128(38, 10), false),
        ]);
        let columns: Vec<ArrayRef> = vec![string_array, decimal_array];
        RecordBatch::try_new(Arc::new(schema), columns).unwrap()
    }
}