use ahash::RandomState;
use arrow::array::{Array, ArrayRef, ArrowNativeTypeOp, ArrowNumericType, AsArray};
use arrow::datatypes::Field;
use arrow::datatypes::{
ArrowNativeType, DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION,
DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, DataType, Decimal32Type,
Decimal64Type, Decimal128Type, Decimal256Type, DurationMicrosecondType,
DurationMillisecondType, DurationNanosecondType, DurationSecondType, FieldRef,
Float64Type, Int64Type, TimeUnit, UInt64Type,
};
use datafusion_common::types::{
NativeType, logical_float64, logical_int8, logical_int16, logical_int32,
logical_int64, logical_uint8, logical_uint16, logical_uint32, logical_uint64,
};
use datafusion_common::{HashMap, Result, ScalarValue, exec_err, not_impl_err};
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion_expr::utils::{AggregateOrderSensitivity, format_state_name};
use datafusion_expr::{
Accumulator, AggregateUDFImpl, Coercion, Documentation, Expr, GroupsAccumulator,
ReversedUDAF, SetMonotonicity, Signature, TypeSignature, TypeSignatureClass,
Volatility,
};
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator;
use datafusion_functions_aggregate_common::aggregate::sum_distinct::DistinctSumAccumulator;
use datafusion_macros::user_doc;
use std::any::Any;
use std::mem::size_of_val;
// Expands (via the crate's `make_udaf_expr_and_func!` macro) into the `sum`
// expression helper taking one `expression` argument and the `sum_udaf()`
// accessor for the `Sum` UDAF; `sum_distinct` below builds on `sum_udaf()`.
make_udaf_expr_and_func!(
Sum,
sum,
expression,
"Returns the sum of a group of values.",
sum_udaf
);
/// Builds a `SUM(DISTINCT expr)` aggregate expression.
///
/// Same as the generated `sum` helper but with the distinct flag set; the
/// trailing constructor arguments are left empty (presumably filter, ordering
/// and null treatment — confirm against `AggregateFunction::new_udf`).
pub fn sum_distinct(expr: Expr) -> Expr {
    let arguments = vec![expr];
    let distinct = true;
    let aggregate = datafusion_expr::expr::AggregateFunction::new_udf(
        sum_udaf(),
        arguments,
        distinct,
        None,
        vec![],
        None,
    );
    Expr::AggregateFunction(aggregate)
}
// Dispatches on the SUM return type carried by `$args.return_field`, invoking
// `$helper!(ArrowType, owned_data_type)` for each supported type and producing
// a not-implemented error for anything else.
macro_rules! downcast_sum {
    ($args:ident, $helper:ident) => {{
        // One owned copy of the return type; each arm hands it to `$helper!`.
        let sum_type = $args.return_field.data_type().clone();
        match sum_type {
            DataType::UInt64 => $helper!(UInt64Type, sum_type),
            DataType::Int64 => $helper!(Int64Type, sum_type),
            DataType::Float64 => $helper!(Float64Type, sum_type),
            DataType::Decimal32(_, _) => $helper!(Decimal32Type, sum_type),
            DataType::Decimal64(_, _) => $helper!(Decimal64Type, sum_type),
            DataType::Decimal128(_, _) => $helper!(Decimal128Type, sum_type),
            DataType::Decimal256(_, _) => $helper!(Decimal256Type, sum_type),
            DataType::Duration(TimeUnit::Second) => {
                $helper!(DurationSecondType, sum_type)
            }
            DataType::Duration(TimeUnit::Millisecond) => {
                $helper!(DurationMillisecondType, sum_type)
            }
            DataType::Duration(TimeUnit::Microsecond) => {
                $helper!(DurationMicrosecondType, sum_type)
            }
            DataType::Duration(TimeUnit::Nanosecond) => {
                $helper!(DurationNanosecondType, sum_type)
            }
            _ => not_impl_err!("Sum not supported for {}: {}", $args.name, sum_type),
        }
    }};
}
#[user_doc(
doc_section(label = "General Functions"),
description = "Returns the sum of all values in the specified column.",
syntax_example = "sum(expression)",
sql_example = r#"```sql
> SELECT sum(column_name) FROM table_name;
+-----------------------+
| sum(column_name) |
+-----------------------+
| 12345 |
+-----------------------+
```"#,
standard_argument(name = "expression",)
)]
#[derive(Debug, PartialEq, Eq, Hash)]
/// The `sum` aggregate function (also used for `SUM(DISTINCT ..)`).
///
/// Accepted input types and their coercions are described by `signature`,
/// which is built in `Sum::new`.
pub struct Sum {
/// Accepted argument types (decimals, integers widened to 64-bit, floats,
/// durations).
signature: Signature,
}
impl Sum {
    /// Creates the `sum` UDAF with its immutable type signature.
    ///
    /// Accepted argument families, in order:
    /// decimals (as-is), unsigned ints widened to `UInt64`, signed ints widened
    /// to `Int64`, floats widened to `Float64`, and durations (as-is).
    pub fn new() -> Self {
        let decimal = TypeSignature::Coercible(vec![Coercion::new_exact(
            TypeSignatureClass::Decimal,
        )]);

        let unsigned_sources = vec![
            TypeSignatureClass::Native(logical_uint8()),
            TypeSignatureClass::Native(logical_uint16()),
            TypeSignatureClass::Native(logical_uint32()),
        ];
        let unsigned = TypeSignature::Coercible(vec![Coercion::new_implicit(
            TypeSignatureClass::Native(logical_uint64()),
            unsigned_sources,
            NativeType::UInt64,
        )]);

        let signed_sources = vec![
            TypeSignatureClass::Native(logical_int8()),
            TypeSignatureClass::Native(logical_int16()),
            TypeSignatureClass::Native(logical_int32()),
        ];
        let signed = TypeSignature::Coercible(vec![Coercion::new_implicit(
            TypeSignatureClass::Native(logical_int64()),
            signed_sources,
            NativeType::Int64,
        )]);

        let float = TypeSignature::Coercible(vec![Coercion::new_implicit(
            TypeSignatureClass::Native(logical_float64()),
            vec![TypeSignatureClass::Float],
            NativeType::Float64,
        )]);

        let duration = TypeSignature::Coercible(vec![Coercion::new_exact(
            TypeSignatureClass::Duration,
        )]);

        let signature = Signature::one_of(
            vec![decimal, unsigned, signed, float, duration],
            Volatility::Immutable,
        );
        Self { signature }
    }
}
impl Default for Sum {
fn default() -> Self {
Self::new()
}
}
impl AggregateUDFImpl for Sum {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// SQL-visible function name.
    fn name(&self) -> &str {
        "sum"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// Maps the (coerced) argument type to the SUM result type.
    ///
    /// * `Int64`, `UInt64`, `Float64` and `Duration(unit)` are returned unchanged.
    /// * Decimals keep their scale and widen precision by 10 digits, capped at
    ///   the maximum precision of the same decimal width.
    ///
    /// # Errors
    /// Returns an execution error for any other type, or when `arg_types` is
    /// empty.
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        // Avoid a panic on a malformed call; the signature normally guarantees
        // exactly one argument.
        let Some(arg_type) = arg_types.first() else {
            return exec_err!("[return_type] SUM requires one argument");
        };
        match arg_type {
            DataType::Int64 => Ok(DataType::Int64),
            DataType::UInt64 => Ok(DataType::UInt64),
            DataType::Float64 => Ok(DataType::Float64),
            DataType::Decimal32(precision, scale) => {
                let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 10);
                Ok(DataType::Decimal32(new_precision, *scale))
            }
            DataType::Decimal64(precision, scale) => {
                let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 10);
                Ok(DataType::Decimal64(new_precision, *scale))
            }
            DataType::Decimal128(precision, scale) => {
                let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10);
                Ok(DataType::Decimal128(new_precision, *scale))
            }
            DataType::Decimal256(precision, scale) => {
                let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10);
                Ok(DataType::Decimal256(new_precision, *scale))
            }
            DataType::Duration(time_unit) => Ok(DataType::Duration(*time_unit)),
            other => {
                exec_err!("[return_type] SUM not supported for {}", other)
            }
        }
    }

    /// Row-by-row accumulator: `DistinctSumAccumulator` for `SUM(DISTINCT ..)`,
    /// otherwise `SumAccumulator`, instantiated for the return type.
    fn accumulator(&self, args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        if args.is_distinct {
            macro_rules! helper {
                ($t:ty, $dt:expr) => {
                    Ok(Box::new(DistinctSumAccumulator::<$t>::new(&$dt)))
                };
            }
            downcast_sum!(args, helper)
        } else {
            // `$dt` is already an owned `DataType`, so it is moved, not cloned.
            macro_rules! helper {
                ($t:ty, $dt:expr) => {
                    Ok(Box::new(SumAccumulator::<$t>::new($dt)))
                };
            }
            downcast_sum!(args, helper)
        }
    }

    /// Intermediate state schema.
    ///
    /// Distinct: a single non-nullable list of (nullable) return-type values.
    /// Non-distinct: a single nullable field holding the partial sum.
    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
        if args.is_distinct {
            Ok(vec![
                Field::new_list(
                    format_state_name(args.name, "sum distinct"),
                    Field::new_list_field(args.return_type().clone(), true),
                    false,
                )
                .into(),
            ])
        } else {
            Ok(vec![
                Field::new(
                    format_state_name(args.name, "sum"),
                    args.return_type().clone(),
                    true,
                )
                .into(),
            ])
        }
    }

    /// The grouped accumulator is only provided for the non-distinct case.
    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
        !args.is_distinct
    }

    /// Grouped accumulator combining values per group with wrapping addition.
    fn create_groups_accumulator(
        &self,
        args: AccumulatorArgs,
    ) -> Result<Box<dyn GroupsAccumulator>> {
        macro_rules! helper {
            ($t:ty, $dt:expr) => {
                Ok(Box::new(PrimitiveGroupsAccumulator::<$t, _>::new(
                    &$dt,
                    |x, y| *x = x.add_wrapping(y),
                )))
            };
        }
        downcast_sum!(args, helper)
    }

    /// Retractable accumulator for sliding window frames.
    ///
    /// The distinct variant is only implemented for `Int64`;
    /// `SlidingDistinctSumAccumulator::try_new` errors for any other type.
    fn create_sliding_accumulator(
        &self,
        args: AccumulatorArgs,
    ) -> Result<Box<dyn Accumulator>> {
        if args.is_distinct {
            macro_rules! helper_distinct {
                ($t:ty, $dt:expr) => {
                    Ok(Box::new(SlidingDistinctSumAccumulator::try_new(&$dt)?))
                };
            }
            downcast_sum!(args, helper_distinct)
        } else {
            macro_rules! helper {
                ($t:ty, $dt:expr) => {
                    Ok(Box::new(SlidingSumAccumulator::<$t>::new($dt)))
                };
            }
            downcast_sum!(args, helper)
        }
    }

    /// Summation does not depend on input order, so the reversed aggregate is
    /// the same function.
    fn reverse_expr(&self) -> ReversedUDAF {
        ReversedUDAF::Identical
    }

    fn order_sensitivity(&self) -> AggregateOrderSensitivity {
        AggregateOrderSensitivity::Insensitive
    }

    fn documentation(&self) -> Option<&Documentation> {
        self.doc()
    }

    /// Only unsigned types are reported as increasing; any other type is
    /// treated as not monotonic.
    fn set_monotonicity(&self, data_type: &DataType) -> SetMonotonicity {
        match data_type {
            DataType::UInt8 => SetMonotonicity::Increasing,
            DataType::UInt16 => SetMonotonicity::Increasing,
            DataType::UInt32 => SetMonotonicity::Increasing,
            DataType::UInt64 => SetMonotonicity::Increasing,
            _ => SetMonotonicity::NotMonotonic,
        }
    }
}
/// Non-distinct, non-sliding `SUM` accumulator for primitive arrow types.
///
/// `sum` stays `None` until `arrow::compute::sum` first reports a batch total,
/// so an input with no non-null values is evaluated from a `None` sum.
struct SumAccumulator<T: ArrowNumericType> {
/// Running total, updated with wrapping addition.
sum: Option<T::Native>,
/// Output type passed to `ScalarValue::new_primitive` when evaluating.
data_type: DataType,
}
impl<T: ArrowNumericType> std::fmt::Debug for SumAccumulator<T> {
    /// Shows only the output data type, not the running sum.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!("SumAccumulator({})", self.data_type))
    }
}
impl<T: ArrowNumericType> SumAccumulator<T> {
    /// Creates an accumulator that has not seen any values yet.
    fn new(data_type: DataType) -> Self {
        let sum = None;
        Self { data_type, sum }
    }
}
impl<T: ArrowNumericType> Accumulator for SumAccumulator<T> {
    /// The intermediate state is the current (possibly `None`) sum.
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        let current = self.evaluate()?;
        Ok(vec![current])
    }

    /// Folds the batch total of `values[0]` into the running sum.
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        let batch = values[0].as_primitive::<T>();
        // No batch total: leave the running sum (including `None`) untouched.
        let Some(batch_total) = arrow::compute::sum(batch) else {
            return Ok(());
        };
        let running = self.sum.unwrap_or_else(|| T::Native::usize_as(0));
        self.sum = Some(running.add_wrapping(batch_total));
        Ok(())
    }

    /// Partial states are sums themselves, so merging reuses `update_batch`.
    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        self.update_batch(states)
    }

    fn evaluate(&mut self) -> Result<ScalarValue> {
        let total = self.sum;
        ScalarValue::new_primitive::<T>(total, &self.data_type)
    }

    fn size(&self) -> usize {
        std::mem::size_of_val(self)
    }
}
/// Retractable `SUM` accumulator for sliding window frames.
///
/// `sum` is always initialized; whether the window is empty is tracked by
/// `count` instead.
struct SlidingSumAccumulator<T: ArrowNumericType> {
/// Running total, updated with wrapping add/sub.
sum: T::Native,
/// Number of non-null values currently included; `evaluate` yields a `None`
/// value when this is 0.
count: u64,
/// Output type passed to `ScalarValue::new_primitive` when evaluating.
data_type: DataType,
}
impl<T: ArrowNumericType> std::fmt::Debug for SlidingSumAccumulator<T> {
    /// Shows only the output data type, not the running sum or count.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!("SlidingSumAccumulator({})", self.data_type))
    }
}
impl<T: ArrowNumericType> SlidingSumAccumulator<T> {
    /// Creates an empty window: zero sum and zero counted values.
    fn new(data_type: DataType) -> Self {
        let zero = T::Native::usize_as(0);
        Self {
            data_type,
            count: 0,
            sum: zero,
        }
    }
}
impl<T: ArrowNumericType> Accumulator for SlidingSumAccumulator<T> {
    /// State is `[sum (None when empty), count as UInt64]`.
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        let sum = self.evaluate()?;
        let count = ScalarValue::from(self.count);
        Ok(vec![sum, count])
    }

    /// Adds the batch total and the number of non-null values in `values[0]`.
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        let batch = values[0].as_primitive::<T>();
        let non_null = batch.len() - batch.null_count();
        self.count += non_null as u64;
        self.sum = arrow::compute::sum(batch)
            .map_or(self.sum, |batch_total| self.sum.add_wrapping(batch_total));
        Ok(())
    }

    /// Combines partial `[sum, count]` states from other accumulators.
    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        let partial_sums = states[0].as_primitive::<T>();
        self.sum = arrow::compute::sum(partial_sums)
            .map_or(self.sum, |partial| self.sum.add_wrapping(partial));
        let partial_counts = states[1].as_primitive::<UInt64Type>();
        self.count += arrow::compute::sum(partial_counts).unwrap_or(0);
        Ok(())
    }

    /// Returns the sum, or a `None` value when no non-null values are counted.
    fn evaluate(&mut self) -> Result<ScalarValue> {
        let value = if self.count == 0 { None } else { Some(self.sum) };
        ScalarValue::new_primitive::<T>(value, &self.data_type)
    }

    fn size(&self) -> usize {
        std::mem::size_of_val(self)
    }

    /// Removes rows leaving the window: subtracts their total and their
    /// non-null count.
    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        let batch = values[0].as_primitive::<T>();
        self.sum = arrow::compute::sum(batch)
            .map_or(self.sum, |batch_total| self.sum.sub_wrapping(batch_total));
        let non_null = batch.len() - batch.null_count();
        self.count -= non_null as u64;
        Ok(())
    }

    fn supports_retract_batch(&self) -> bool {
        true
    }
}
/// Retractable `SUM(DISTINCT ..)` accumulator for sliding window frames.
///
/// Only `Int64` is supported (enforced by `try_new`). Each distinct value
/// carries an occurrence count so it can be removed from `sum` once its last
/// occurrence leaves the window.
#[derive(Debug)]
pub struct SlidingDistinctSumAccumulator {
/// Occurrence count per distinct value currently in the window.
counts: HashMap<i64, usize, RandomState>,
/// Wrapping sum of the keys of `counts`.
sum: i64,
/// Output/state element type (always `DataType::Int64`).
data_type: DataType,
}
impl SlidingDistinctSumAccumulator {
    /// Creates an empty accumulator.
    ///
    /// # Errors
    /// Returns an execution error unless `data_type` is `DataType::Int64`.
    pub fn try_new(data_type: &DataType) -> Result<Self> {
        match data_type {
            DataType::Int64 => Ok(Self {
                counts: HashMap::default(),
                sum: 0,
                data_type: data_type.clone(),
            }),
            _ => exec_err!("SlidingDistinctSumAccumulator only supports Int64"),
        }
    }
}
impl Accumulator for SlidingDistinctSumAccumulator {
    /// Adds each non-null value of `values[0]` to the window.
    ///
    /// A value contributes to `sum` only on its first occurrence; later
    /// occurrences just increase its count.
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        let arr = values[0].as_primitive::<Int64Type>();
        // `iter()` yields `None` for null slots. The raw `values()` buffer would
        // expose whatever bits sit under a null and count them as real values.
        for v in arr.iter().flatten() {
            let cnt = self.counts.entry(v).or_insert(0);
            if *cnt == 0 {
                self.sum = self.sum.wrapping_add(v);
            }
            *cnt += 1;
        }
        Ok(())
    }

    /// Sum of the distinct values in the window, or NULL when the window holds
    /// no non-null values (matching the other SUM accumulators).
    fn evaluate(&mut self) -> Result<ScalarValue> {
        if self.counts.is_empty() {
            return Ok(ScalarValue::Int64(None));
        }
        Ok(ScalarValue::Int64(Some(self.sum)))
    }

    /// Struct size plus the heap allocated by the `counts` table.
    fn size(&self) -> usize {
        size_of_val(self) + self.counts.capacity() * std::mem::size_of::<(i64, usize)>()
    }

    /// State is a single list of the distinct values currently held.
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        let keys = self
            .counts
            .keys()
            .cloned()
            .map(Some)
            .map(ScalarValue::Int64)
            .collect::<Vec<_>>();
        Ok(vec![ScalarValue::List(ScalarValue::new_list_nullable(
            &keys,
            &self.data_type,
        ))])
    }

    /// Merges lists of distinct values produced by `state`; each listed value
    /// counts as one occurrence. Non-`Int64` or null elements are skipped.
    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        let list_arr = states[0].as_list::<i32>();
        for maybe_inner in list_arr.iter().flatten() {
            for idx in 0..maybe_inner.len() {
                if let ScalarValue::Int64(Some(v)) =
                    ScalarValue::try_from_array(&*maybe_inner, idx)?
                {
                    let cnt = self.counts.entry(v).or_insert(0);
                    if *cnt == 0 {
                        self.sum = self.sum.wrapping_add(v);
                    }
                    *cnt += 1;
                }
            }
        }
        Ok(())
    }

    /// Removes one occurrence of each non-null value of `values[0]`; a value is
    /// subtracted from `sum` (and dropped) when its last occurrence leaves.
    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        let arr = values[0].as_primitive::<Int64Type>();
        // Skip nulls for symmetry with `update_batch`.
        for v in arr.iter().flatten() {
            if let Some(cnt) = self.counts.get_mut(&v) {
                *cnt -= 1;
                if *cnt == 0 {
                    self.sum = self.sum.wrapping_sub(v);
                    self.counts.remove(&v);
                }
            }
        }
        Ok(())
    }

    fn supports_retract_batch(&self) -> bool {
        true
    }
}