// datafusion_functions_aggregate/sum.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Defines `SUM` and `SUM DISTINCT` aggregate accumulators
use ahash::RandomState;
use datafusion_expr::utils::AggregateOrderSensitivity;
use std::any::Any;
use std::collections::HashSet;
use std::mem::{size_of, size_of_val};

use arrow::array::Array;
use arrow::array::ArrowNativeTypeOp;
use arrow::array::{ArrowNumericType, AsArray};
use arrow::datatypes::ArrowNativeType;
use arrow::datatypes::ArrowPrimitiveType;
use arrow::datatypes::{
    DataType, Decimal128Type, Decimal256Type, Float64Type, Int64Type, UInt64Type,
    DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION,
};
use arrow::{array::ArrayRef, datatypes::Field};
use datafusion_common::{
    exec_err, not_impl_err, utils::take_function_args, Result, ScalarValue,
};
use datafusion_expr::function::AccumulatorArgs;
use datafusion_expr::function::StateFieldsArgs;
use datafusion_expr::utils::format_state_name;
use datafusion_expr::{
    Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, ReversedUDAF,
    SetMonotonicity, Signature, Volatility,
};
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator;
use datafusion_functions_aggregate_common::utils::Hashable;
use datafusion_macros::user_doc;
50make_udaf_expr_and_func!(
51    Sum,
52    sum,
53    expression,
54    "Returns the sum of a group of values.",
55    sum_udaf
56);
57
/// Sum only supports a subset of numeric types, instead relying on type coercion
///
/// This macro is similar to [downcast_primitive](arrow::array::downcast_primitive)
///
/// `args` is [AccumulatorArgs]
/// `helper` is a macro accepting (ArrowPrimitiveType, DataType)
macro_rules! downcast_sum {
    ($args:ident, $helper:ident) => {
        // Dispatch on the *return* type, which `coerce_types` has already
        // narrowed to one of the five variants below.
        match $args.return_type {
            DataType::UInt64 => $helper!(UInt64Type, $args.return_type),
            DataType::Int64 => $helper!(Int64Type, $args.return_type),
            DataType::Float64 => $helper!(Float64Type, $args.return_type),
            DataType::Decimal128(_, _) => $helper!(Decimal128Type, $args.return_type),
            DataType::Decimal256(_, _) => $helper!(Decimal256Type, $args.return_type),
            _ => {
                not_impl_err!(
                    "Sum not supported for {}: {}",
                    $args.name,
                    $args.return_type
                )
            }
        }
    };
}
83#[user_doc(
84    doc_section(label = "General Functions"),
85    description = "Returns the sum of all values in the specified column.",
86    syntax_example = "sum(expression)",
87    sql_example = r#"```sql
88> SELECT sum(column_name) FROM table_name;
89+-----------------------+
90| sum(column_name)       |
91+-----------------------+
92| 12345                 |
93+-----------------------+
94```"#,
95    standard_argument(name = "expression",)
96)]
97#[derive(Debug)]
98pub struct Sum {
99    signature: Signature,
100}
101
102impl Sum {
103    pub fn new() -> Self {
104        Self {
105            signature: Signature::user_defined(Volatility::Immutable),
106        }
107    }
108}
109
110impl Default for Sum {
111    fn default() -> Self {
112        Self::new()
113    }
114}
115
116impl AggregateUDFImpl for Sum {
117    fn as_any(&self) -> &dyn Any {
118        self
119    }
120
121    fn name(&self) -> &str {
122        "sum"
123    }
124
125    fn signature(&self) -> &Signature {
126        &self.signature
127    }
128
129    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
130        let [args] = take_function_args(self.name(), arg_types)?;
131
132        // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc
133        // smallint, int, bigint, real, double precision, decimal, or interval.
134
135        fn coerced_type(data_type: &DataType) -> Result<DataType> {
136            match data_type {
137                DataType::Dictionary(_, v) => coerced_type(v),
138                // in the spark, the result type is DECIMAL(min(38,precision+10), s)
139                // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
140                DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => {
141                    Ok(data_type.clone())
142                }
143                dt if dt.is_signed_integer() => Ok(DataType::Int64),
144                dt if dt.is_unsigned_integer() => Ok(DataType::UInt64),
145                dt if dt.is_floating() => Ok(DataType::Float64),
146                _ => exec_err!("Sum not supported for {}", data_type),
147            }
148        }
149
150        Ok(vec![coerced_type(args)?])
151    }
152
153    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
154        match &arg_types[0] {
155            DataType::Int64 => Ok(DataType::Int64),
156            DataType::UInt64 => Ok(DataType::UInt64),
157            DataType::Float64 => Ok(DataType::Float64),
158            DataType::Decimal128(precision, scale) => {
159                // in the spark, the result type is DECIMAL(min(38,precision+10), s)
160                // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
161                let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10);
162                Ok(DataType::Decimal128(new_precision, *scale))
163            }
164            DataType::Decimal256(precision, scale) => {
165                // in the spark, the result type is DECIMAL(min(38,precision+10), s)
166                // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
167                let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10);
168                Ok(DataType::Decimal256(new_precision, *scale))
169            }
170            other => {
171                exec_err!("[return_type] SUM not supported for {}", other)
172            }
173        }
174    }
175
176    fn accumulator(&self, args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
177        if args.is_distinct {
178            macro_rules! helper {
179                ($t:ty, $dt:expr) => {
180                    Ok(Box::new(DistinctSumAccumulator::<$t>::try_new(&$dt)?))
181                };
182            }
183            downcast_sum!(args, helper)
184        } else {
185            macro_rules! helper {
186                ($t:ty, $dt:expr) => {
187                    Ok(Box::new(SumAccumulator::<$t>::new($dt.clone())))
188                };
189            }
190            downcast_sum!(args, helper)
191        }
192    }
193
194    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
195        if args.is_distinct {
196            Ok(vec![Field::new_list(
197                format_state_name(args.name, "sum distinct"),
198                // See COMMENTS.md to understand why nullable is set to true
199                Field::new_list_field(args.return_type.clone(), true),
200                false,
201            )])
202        } else {
203            Ok(vec![Field::new(
204                format_state_name(args.name, "sum"),
205                args.return_type.clone(),
206                true,
207            )])
208        }
209    }
210
211    fn aliases(&self) -> &[String] {
212        &[]
213    }
214
215    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
216        !args.is_distinct
217    }
218
219    fn create_groups_accumulator(
220        &self,
221        args: AccumulatorArgs,
222    ) -> Result<Box<dyn GroupsAccumulator>> {
223        macro_rules! helper {
224            ($t:ty, $dt:expr) => {
225                Ok(Box::new(PrimitiveGroupsAccumulator::<$t, _>::new(
226                    &$dt,
227                    |x, y| *x = x.add_wrapping(y),
228                )))
229            };
230        }
231        downcast_sum!(args, helper)
232    }
233
234    fn create_sliding_accumulator(
235        &self,
236        args: AccumulatorArgs,
237    ) -> Result<Box<dyn Accumulator>> {
238        macro_rules! helper {
239            ($t:ty, $dt:expr) => {
240                Ok(Box::new(SlidingSumAccumulator::<$t>::new($dt.clone())))
241            };
242        }
243        downcast_sum!(args, helper)
244    }
245
246    fn reverse_expr(&self) -> ReversedUDAF {
247        ReversedUDAF::Identical
248    }
249
250    fn order_sensitivity(&self) -> AggregateOrderSensitivity {
251        AggregateOrderSensitivity::Insensitive
252    }
253
254    fn documentation(&self) -> Option<&Documentation> {
255        self.doc()
256    }
257
258    fn set_monotonicity(&self, data_type: &DataType) -> SetMonotonicity {
259        // `SUM` is only monotonically increasing when its input is unsigned.
260        // TODO: Expand these utilizing statistics.
261        match data_type {
262            DataType::UInt8 => SetMonotonicity::Increasing,
263            DataType::UInt16 => SetMonotonicity::Increasing,
264            DataType::UInt32 => SetMonotonicity::Increasing,
265            DataType::UInt64 => SetMonotonicity::Increasing,
266            _ => SetMonotonicity::NotMonotonic,
267        }
268    }
269}
270
271/// This accumulator computes SUM incrementally
272struct SumAccumulator<T: ArrowNumericType> {
273    sum: Option<T::Native>,
274    data_type: DataType,
275}
276
277impl<T: ArrowNumericType> std::fmt::Debug for SumAccumulator<T> {
278    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
279        write!(f, "SumAccumulator({})", self.data_type)
280    }
281}
282
283impl<T: ArrowNumericType> SumAccumulator<T> {
284    fn new(data_type: DataType) -> Self {
285        Self {
286            sum: None,
287            data_type,
288        }
289    }
290}
291
292impl<T: ArrowNumericType> Accumulator for SumAccumulator<T> {
293    fn state(&mut self) -> Result<Vec<ScalarValue>> {
294        Ok(vec![self.evaluate()?])
295    }
296
297    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
298        let values = values[0].as_primitive::<T>();
299        if let Some(x) = arrow::compute::sum(values) {
300            let v = self.sum.get_or_insert(T::Native::usize_as(0));
301            *v = v.add_wrapping(x);
302        }
303        Ok(())
304    }
305
306    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
307        self.update_batch(states)
308    }
309
310    fn evaluate(&mut self) -> Result<ScalarValue> {
311        ScalarValue::new_primitive::<T>(self.sum, &self.data_type)
312    }
313
314    fn size(&self) -> usize {
315        size_of_val(self)
316    }
317}
318
319/// This accumulator incrementally computes sums over a sliding window
320///
321/// This is separate from [`SumAccumulator`] as requires additional state
322struct SlidingSumAccumulator<T: ArrowNumericType> {
323    sum: T::Native,
324    count: u64,
325    data_type: DataType,
326}
327
328impl<T: ArrowNumericType> std::fmt::Debug for SlidingSumAccumulator<T> {
329    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
330        write!(f, "SlidingSumAccumulator({})", self.data_type)
331    }
332}
333
334impl<T: ArrowNumericType> SlidingSumAccumulator<T> {
335    fn new(data_type: DataType) -> Self {
336        Self {
337            sum: T::Native::usize_as(0),
338            count: 0,
339            data_type,
340        }
341    }
342}
343
344impl<T: ArrowNumericType> Accumulator for SlidingSumAccumulator<T> {
345    fn state(&mut self) -> Result<Vec<ScalarValue>> {
346        Ok(vec![self.evaluate()?, self.count.into()])
347    }
348
349    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
350        let values = values[0].as_primitive::<T>();
351        self.count += (values.len() - values.null_count()) as u64;
352        if let Some(x) = arrow::compute::sum(values) {
353            self.sum = self.sum.add_wrapping(x)
354        }
355        Ok(())
356    }
357
358    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
359        let values = states[0].as_primitive::<T>();
360        if let Some(x) = arrow::compute::sum(values) {
361            self.sum = self.sum.add_wrapping(x)
362        }
363        if let Some(x) = arrow::compute::sum(states[1].as_primitive::<UInt64Type>()) {
364            self.count += x;
365        }
366        Ok(())
367    }
368
369    fn evaluate(&mut self) -> Result<ScalarValue> {
370        let v = (self.count != 0).then_some(self.sum);
371        ScalarValue::new_primitive::<T>(v, &self.data_type)
372    }
373
374    fn size(&self) -> usize {
375        size_of_val(self)
376    }
377
378    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
379        let values = values[0].as_primitive::<T>();
380        if let Some(x) = arrow::compute::sum(values) {
381            self.sum = self.sum.sub_wrapping(x)
382        }
383        self.count -= (values.len() - values.null_count()) as u64;
384        Ok(())
385    }
386
387    fn supports_retract_batch(&self) -> bool {
388        true
389    }
390}
391
392struct DistinctSumAccumulator<T: ArrowPrimitiveType> {
393    values: HashSet<Hashable<T::Native>, RandomState>,
394    data_type: DataType,
395}
396
397impl<T: ArrowPrimitiveType> std::fmt::Debug for DistinctSumAccumulator<T> {
398    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
399        write!(f, "DistinctSumAccumulator({})", self.data_type)
400    }
401}
402
403impl<T: ArrowPrimitiveType> DistinctSumAccumulator<T> {
404    pub fn try_new(data_type: &DataType) -> Result<Self> {
405        Ok(Self {
406            values: HashSet::default(),
407            data_type: data_type.clone(),
408        })
409    }
410}
411
412impl<T: ArrowPrimitiveType> Accumulator for DistinctSumAccumulator<T> {
413    fn state(&mut self) -> Result<Vec<ScalarValue>> {
414        // 1. Stores aggregate state in `ScalarValue::List`
415        // 2. Constructs `ScalarValue::List` state from distinct numeric stored in hash set
416        let state_out = {
417            let distinct_values = self
418                .values
419                .iter()
420                .map(|value| {
421                    ScalarValue::new_primitive::<T>(Some(value.0), &self.data_type)
422                })
423                .collect::<Result<Vec<_>>>()?;
424
425            vec![ScalarValue::List(ScalarValue::new_list_nullable(
426                &distinct_values,
427                &self.data_type,
428            ))]
429        };
430        Ok(state_out)
431    }
432
433    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
434        if values.is_empty() {
435            return Ok(());
436        }
437
438        let array = values[0].as_primitive::<T>();
439        match array.nulls().filter(|x| x.null_count() > 0) {
440            Some(n) => {
441                for idx in n.valid_indices() {
442                    self.values.insert(Hashable(array.value(idx)));
443                }
444            }
445            None => array.values().iter().for_each(|x| {
446                self.values.insert(Hashable(*x));
447            }),
448        }
449        Ok(())
450    }
451
452    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
453        for x in states[0].as_list::<i32>().iter().flatten() {
454            self.update_batch(&[x])?
455        }
456        Ok(())
457    }
458
459    fn evaluate(&mut self) -> Result<ScalarValue> {
460        let mut acc = T::Native::usize_as(0);
461        for distinct_value in self.values.iter() {
462            acc = acc.add_wrapping(distinct_value.0)
463        }
464        let v = (!self.values.is_empty()).then_some(acc);
465        ScalarValue::new_primitive::<T>(v, &self.data_type)
466    }
467
468    fn size(&self) -> usize {
469        size_of_val(self) + self.values.capacity() * size_of::<T::Native>()
470    }
471}