datafusion_functions_aggregate/
sum.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Defines `SUM` and `SUM DISTINCT` aggregate accumulators
19
20use ahash::RandomState;
21use datafusion_expr::utils::AggregateOrderSensitivity;
22use std::any::Any;
23use std::collections::HashSet;
24use std::mem::{size_of, size_of_val};
25
26use arrow::array::Array;
27use arrow::array::ArrowNativeTypeOp;
28use arrow::array::{ArrowNumericType, AsArray};
29use arrow::datatypes::ArrowPrimitiveType;
30use arrow::datatypes::{ArrowNativeType, FieldRef};
31use arrow::datatypes::{
32    DataType, Decimal128Type, Decimal256Type, Float64Type, Int64Type, UInt64Type,
33    DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION,
34};
35use arrow::{array::ArrayRef, datatypes::Field};
36use datafusion_common::{
37    exec_err, not_impl_err, utils::take_function_args, Result, ScalarValue,
38};
39use datafusion_expr::function::AccumulatorArgs;
40use datafusion_expr::function::StateFieldsArgs;
41use datafusion_expr::utils::format_state_name;
42use datafusion_expr::{
43    Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, ReversedUDAF,
44    SetMonotonicity, Signature, Volatility,
45};
46use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator;
47use datafusion_functions_aggregate_common::utils::Hashable;
48use datafusion_macros::user_doc;
49
50make_udaf_expr_and_func!(
51    Sum,
52    sum,
53    expression,
54    "Returns the sum of a group of values.",
55    sum_udaf
56);
57
/// Dispatches on the aggregate's return type to a concrete Arrow primitive type.
///
/// Sum only supports a subset of numeric types, instead relying on type coercion
/// to bring every other accepted input onto one of them.
///
/// This macro is similar to [downcast_primitive](arrow::array::downcast_primitive)
///
/// `args` is [AccumulatorArgs]
/// `helper` is a macro accepting (ArrowPrimitiveType, DataType)
///
/// Expands to an expression: either whatever `helper` produces for the matched
/// type, or the error built by `not_impl_err!` for any other return type.
macro_rules! downcast_sum {
    ($args:ident, $helper:ident) => {
        // Match on the borrowed type; only the selected arm needs an owned copy,
        // so the scrutinee itself is not cloned.
        match $args.return_field.data_type() {
            DataType::UInt64 => {
                $helper!(UInt64Type, $args.return_field.data_type().clone())
            }
            DataType::Int64 => {
                $helper!(Int64Type, $args.return_field.data_type().clone())
            }
            DataType::Float64 => {
                $helper!(Float64Type, $args.return_field.data_type().clone())
            }
            DataType::Decimal128(_, _) => {
                $helper!(Decimal128Type, $args.return_field.data_type().clone())
            }
            DataType::Decimal256(_, _) => {
                $helper!(Decimal256Type, $args.return_field.data_type().clone())
            }
            _ => {
                not_impl_err!(
                    "Sum not supported for {}: {}",
                    $args.name,
                    $args.return_field.data_type()
                )
            }
        }
    };
}
92
93#[user_doc(
94    doc_section(label = "General Functions"),
95    description = "Returns the sum of all values in the specified column.",
96    syntax_example = "sum(expression)",
97    sql_example = r#"```sql
98> SELECT sum(column_name) FROM table_name;
99+-----------------------+
100| sum(column_name)       |
101+-----------------------+
102| 12345                 |
103+-----------------------+
104```"#,
105    standard_argument(name = "expression",)
106)]
107#[derive(Debug)]
108pub struct Sum {
109    signature: Signature,
110}
111
112impl Sum {
113    pub fn new() -> Self {
114        Self {
115            signature: Signature::user_defined(Volatility::Immutable),
116        }
117    }
118}
119
120impl Default for Sum {
121    fn default() -> Self {
122        Self::new()
123    }
124}
125
126impl AggregateUDFImpl for Sum {
127    fn as_any(&self) -> &dyn Any {
128        self
129    }
130
131    fn name(&self) -> &str {
132        "sum"
133    }
134
135    fn signature(&self) -> &Signature {
136        &self.signature
137    }
138
139    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
140        let [args] = take_function_args(self.name(), arg_types)?;
141
142        // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc
143        // smallint, int, bigint, real, double precision, decimal, or interval.
144
145        fn coerced_type(data_type: &DataType) -> Result<DataType> {
146            match data_type {
147                DataType::Dictionary(_, v) => coerced_type(v),
148                // in the spark, the result type is DECIMAL(min(38,precision+10), s)
149                // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
150                DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => {
151                    Ok(data_type.clone())
152                }
153                dt if dt.is_signed_integer() => Ok(DataType::Int64),
154                dt if dt.is_unsigned_integer() => Ok(DataType::UInt64),
155                dt if dt.is_floating() => Ok(DataType::Float64),
156                _ => exec_err!("Sum not supported for {}", data_type),
157            }
158        }
159
160        Ok(vec![coerced_type(args)?])
161    }
162
163    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
164        match &arg_types[0] {
165            DataType::Int64 => Ok(DataType::Int64),
166            DataType::UInt64 => Ok(DataType::UInt64),
167            DataType::Float64 => Ok(DataType::Float64),
168            DataType::Decimal128(precision, scale) => {
169                // in the spark, the result type is DECIMAL(min(38,precision+10), s)
170                // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
171                let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10);
172                Ok(DataType::Decimal128(new_precision, *scale))
173            }
174            DataType::Decimal256(precision, scale) => {
175                // in the spark, the result type is DECIMAL(min(38,precision+10), s)
176                // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
177                let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10);
178                Ok(DataType::Decimal256(new_precision, *scale))
179            }
180            other => {
181                exec_err!("[return_type] SUM not supported for {}", other)
182            }
183        }
184    }
185
186    fn accumulator(&self, args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
187        if args.is_distinct {
188            macro_rules! helper {
189                ($t:ty, $dt:expr) => {
190                    Ok(Box::new(DistinctSumAccumulator::<$t>::try_new(&$dt)?))
191                };
192            }
193            downcast_sum!(args, helper)
194        } else {
195            macro_rules! helper {
196                ($t:ty, $dt:expr) => {
197                    Ok(Box::new(SumAccumulator::<$t>::new($dt.clone())))
198                };
199            }
200            downcast_sum!(args, helper)
201        }
202    }
203
204    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
205        if args.is_distinct {
206            Ok(vec![Field::new_list(
207                format_state_name(args.name, "sum distinct"),
208                // See COMMENTS.md to understand why nullable is set to true
209                Field::new_list_field(args.return_type().clone(), true),
210                false,
211            )
212            .into()])
213        } else {
214            Ok(vec![Field::new(
215                format_state_name(args.name, "sum"),
216                args.return_type().clone(),
217                true,
218            )
219            .into()])
220        }
221    }
222
223    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
224        !args.is_distinct
225    }
226
227    fn create_groups_accumulator(
228        &self,
229        args: AccumulatorArgs,
230    ) -> Result<Box<dyn GroupsAccumulator>> {
231        macro_rules! helper {
232            ($t:ty, $dt:expr) => {
233                Ok(Box::new(PrimitiveGroupsAccumulator::<$t, _>::new(
234                    &$dt,
235                    |x, y| *x = x.add_wrapping(y),
236                )))
237            };
238        }
239        downcast_sum!(args, helper)
240    }
241
242    fn create_sliding_accumulator(
243        &self,
244        args: AccumulatorArgs,
245    ) -> Result<Box<dyn Accumulator>> {
246        macro_rules! helper {
247            ($t:ty, $dt:expr) => {
248                Ok(Box::new(SlidingSumAccumulator::<$t>::new($dt.clone())))
249            };
250        }
251        downcast_sum!(args, helper)
252    }
253
254    fn reverse_expr(&self) -> ReversedUDAF {
255        ReversedUDAF::Identical
256    }
257
258    fn order_sensitivity(&self) -> AggregateOrderSensitivity {
259        AggregateOrderSensitivity::Insensitive
260    }
261
262    fn documentation(&self) -> Option<&Documentation> {
263        self.doc()
264    }
265
266    fn set_monotonicity(&self, data_type: &DataType) -> SetMonotonicity {
267        // `SUM` is only monotonically increasing when its input is unsigned.
268        // TODO: Expand these utilizing statistics.
269        match data_type {
270            DataType::UInt8 => SetMonotonicity::Increasing,
271            DataType::UInt16 => SetMonotonicity::Increasing,
272            DataType::UInt32 => SetMonotonicity::Increasing,
273            DataType::UInt64 => SetMonotonicity::Increasing,
274            _ => SetMonotonicity::NotMonotonic,
275        }
276    }
277}
278
279/// This accumulator computes SUM incrementally
280struct SumAccumulator<T: ArrowNumericType> {
281    sum: Option<T::Native>,
282    data_type: DataType,
283}
284
285impl<T: ArrowNumericType> std::fmt::Debug for SumAccumulator<T> {
286    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
287        write!(f, "SumAccumulator({})", self.data_type)
288    }
289}
290
291impl<T: ArrowNumericType> SumAccumulator<T> {
292    fn new(data_type: DataType) -> Self {
293        Self {
294            sum: None,
295            data_type,
296        }
297    }
298}
299
300impl<T: ArrowNumericType> Accumulator for SumAccumulator<T> {
301    fn state(&mut self) -> Result<Vec<ScalarValue>> {
302        Ok(vec![self.evaluate()?])
303    }
304
305    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
306        let values = values[0].as_primitive::<T>();
307        if let Some(x) = arrow::compute::sum(values) {
308            let v = self.sum.get_or_insert(T::Native::usize_as(0));
309            *v = v.add_wrapping(x);
310        }
311        Ok(())
312    }
313
314    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
315        self.update_batch(states)
316    }
317
318    fn evaluate(&mut self) -> Result<ScalarValue> {
319        ScalarValue::new_primitive::<T>(self.sum, &self.data_type)
320    }
321
322    fn size(&self) -> usize {
323        size_of_val(self)
324    }
325}
326
327/// This accumulator incrementally computes sums over a sliding window
328///
329/// This is separate from [`SumAccumulator`] as requires additional state
330struct SlidingSumAccumulator<T: ArrowNumericType> {
331    sum: T::Native,
332    count: u64,
333    data_type: DataType,
334}
335
336impl<T: ArrowNumericType> std::fmt::Debug for SlidingSumAccumulator<T> {
337    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
338        write!(f, "SlidingSumAccumulator({})", self.data_type)
339    }
340}
341
342impl<T: ArrowNumericType> SlidingSumAccumulator<T> {
343    fn new(data_type: DataType) -> Self {
344        Self {
345            sum: T::Native::usize_as(0),
346            count: 0,
347            data_type,
348        }
349    }
350}
351
352impl<T: ArrowNumericType> Accumulator for SlidingSumAccumulator<T> {
353    fn state(&mut self) -> Result<Vec<ScalarValue>> {
354        Ok(vec![self.evaluate()?, self.count.into()])
355    }
356
357    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
358        let values = values[0].as_primitive::<T>();
359        self.count += (values.len() - values.null_count()) as u64;
360        if let Some(x) = arrow::compute::sum(values) {
361            self.sum = self.sum.add_wrapping(x)
362        }
363        Ok(())
364    }
365
366    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
367        let values = states[0].as_primitive::<T>();
368        if let Some(x) = arrow::compute::sum(values) {
369            self.sum = self.sum.add_wrapping(x)
370        }
371        if let Some(x) = arrow::compute::sum(states[1].as_primitive::<UInt64Type>()) {
372            self.count += x;
373        }
374        Ok(())
375    }
376
377    fn evaluate(&mut self) -> Result<ScalarValue> {
378        let v = (self.count != 0).then_some(self.sum);
379        ScalarValue::new_primitive::<T>(v, &self.data_type)
380    }
381
382    fn size(&self) -> usize {
383        size_of_val(self)
384    }
385
386    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
387        let values = values[0].as_primitive::<T>();
388        if let Some(x) = arrow::compute::sum(values) {
389            self.sum = self.sum.sub_wrapping(x)
390        }
391        self.count -= (values.len() - values.null_count()) as u64;
392        Ok(())
393    }
394
395    fn supports_retract_batch(&self) -> bool {
396        true
397    }
398}
399
400struct DistinctSumAccumulator<T: ArrowPrimitiveType> {
401    values: HashSet<Hashable<T::Native>, RandomState>,
402    data_type: DataType,
403}
404
405impl<T: ArrowPrimitiveType> std::fmt::Debug for DistinctSumAccumulator<T> {
406    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
407        write!(f, "DistinctSumAccumulator({})", self.data_type)
408    }
409}
410
411impl<T: ArrowPrimitiveType> DistinctSumAccumulator<T> {
412    pub fn try_new(data_type: &DataType) -> Result<Self> {
413        Ok(Self {
414            values: HashSet::default(),
415            data_type: data_type.clone(),
416        })
417    }
418}
419
420impl<T: ArrowPrimitiveType> Accumulator for DistinctSumAccumulator<T> {
421    fn state(&mut self) -> Result<Vec<ScalarValue>> {
422        // 1. Stores aggregate state in `ScalarValue::List`
423        // 2. Constructs `ScalarValue::List` state from distinct numeric stored in hash set
424        let state_out = {
425            let distinct_values = self
426                .values
427                .iter()
428                .map(|value| {
429                    ScalarValue::new_primitive::<T>(Some(value.0), &self.data_type)
430                })
431                .collect::<Result<Vec<_>>>()?;
432
433            vec![ScalarValue::List(ScalarValue::new_list_nullable(
434                &distinct_values,
435                &self.data_type,
436            ))]
437        };
438        Ok(state_out)
439    }
440
441    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
442        if values.is_empty() {
443            return Ok(());
444        }
445
446        let array = values[0].as_primitive::<T>();
447        match array.nulls().filter(|x| x.null_count() > 0) {
448            Some(n) => {
449                for idx in n.valid_indices() {
450                    self.values.insert(Hashable(array.value(idx)));
451                }
452            }
453            None => array.values().iter().for_each(|x| {
454                self.values.insert(Hashable(*x));
455            }),
456        }
457        Ok(())
458    }
459
460    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
461        for x in states[0].as_list::<i32>().iter().flatten() {
462            self.update_batch(&[x])?
463        }
464        Ok(())
465    }
466
467    fn evaluate(&mut self) -> Result<ScalarValue> {
468        let mut acc = T::Native::usize_as(0);
469        for distinct_value in self.values.iter() {
470            acc = acc.add_wrapping(distinct_value.0)
471        }
472        let v = (!self.values.is_empty()).then_some(acc);
473        ScalarValue::new_primitive::<T>(v, &self.data_type)
474    }
475
476    fn size(&self) -> usize {
477        size_of_val(self) + self.values.capacity() * size_of::<T::Native>()
478    }
479}