datafusion_functions_aggregate/
sum.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Defines `SUM` and `SUM DISTINCT` aggregate accumulators
19
20use ahash::RandomState;
21use datafusion_expr::utils::AggregateOrderSensitivity;
22use std::any::Any;
23use std::collections::HashSet;
24use std::mem::{size_of, size_of_val};
25
26use arrow::array::Array;
27use arrow::array::ArrowNativeTypeOp;
28use arrow::array::{ArrowNumericType, AsArray};
29use arrow::datatypes::ArrowPrimitiveType;
30use arrow::datatypes::{ArrowNativeType, FieldRef};
31use arrow::datatypes::{
32    DataType, Decimal128Type, Decimal256Type, Float64Type, Int64Type, UInt64Type,
33    DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION,
34};
35use arrow::{array::ArrayRef, datatypes::Field};
36use datafusion_common::{
37    exec_err, not_impl_err, utils::take_function_args, Result, ScalarValue,
38};
39use datafusion_expr::function::AccumulatorArgs;
40use datafusion_expr::function::StateFieldsArgs;
41use datafusion_expr::utils::format_state_name;
42use datafusion_expr::{
43    Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, ReversedUDAF,
44    SetMonotonicity, Signature, Volatility,
45};
46use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator;
47use datafusion_functions_aggregate_common::utils::Hashable;
48use datafusion_macros::user_doc;
49
// Invokes `make_udaf_expr_and_func!` (defined outside this file) for [`Sum`],
// passing the function name `sum`, its single argument `expression`, a
// one-line description, and the identifier `sum_udaf`.
// NOTE(review): presumably this generates the `sum(expression)` expression
// helper and the `sum_udaf()` accessor -- confirm against the macro definition.
make_udaf_expr_and_func!(
    Sum,
    sum,
    expression,
    "Returns the sum of a group of values.",
    sum_udaf
);
57
/// Dispatches on the return type of a `SUM` call.
///
/// Sum only supports a subset of numeric types (`UInt64`, `Int64`, `Float64`,
/// `Decimal128` and `Decimal256`) and relies on type coercion for the rest.
/// Any other return type yields a "not implemented" error.
///
/// This macro is similar to [downcast_primitive](arrow::array::downcast_primitive)
///
/// * `args` is an [AccumulatorArgs]; its `return_field` selects the branch and
///   its `name` is used in the error message.
/// * `helper` is a macro accepting `(ArrowPrimitiveType, DataType)`; it is
///   invoked with the matching Arrow type and an owned copy of the return type.
macro_rules! downcast_sum {
    ($args:ident, $helper:ident) => {{
        let return_type = $args.return_field.data_type();
        match return_type {
            DataType::UInt64 => $helper!(UInt64Type, return_type.clone()),
            DataType::Int64 => $helper!(Int64Type, return_type.clone()),
            DataType::Float64 => $helper!(Float64Type, return_type.clone()),
            DataType::Decimal128(_, _) => $helper!(Decimal128Type, return_type.clone()),
            DataType::Decimal256(_, _) => $helper!(Decimal256Type, return_type.clone()),
            _ => not_impl_err!(
                "Sum not supported for {}: {}",
                $args.name,
                return_type
            ),
        }
    }};
}
92
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the sum of all values in the specified column.",
    syntax_example = "sum(expression)",
    sql_example = r#"```sql
> SELECT sum(column_name) FROM table_name;
+-----------------------+
| sum(column_name)       |
+-----------------------+
| 12345                 |
+-----------------------+
```"#,
    standard_argument(name = "expression",)
)]
#[derive(Debug)]
/// `SUM` aggregate function.
///
/// The user-facing documentation is declared in the `#[user_doc]` attribute
/// above; the aggregate behavior lives in the `AggregateUDFImpl` impl.
pub struct Sum {
    /// Function signature; built as `Signature::user_defined` in [`Sum::new`].
    signature: Signature,
}
111
112impl Sum {
113    pub fn new() -> Self {
114        Self {
115            signature: Signature::user_defined(Volatility::Immutable),
116        }
117    }
118}
119
120impl Default for Sum {
121    fn default() -> Self {
122        Self::new()
123    }
124}
125
126impl AggregateUDFImpl for Sum {
127    fn as_any(&self) -> &dyn Any {
128        self
129    }
130
131    fn name(&self) -> &str {
132        "sum"
133    }
134
135    fn signature(&self) -> &Signature {
136        &self.signature
137    }
138
139    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
140        let [args] = take_function_args(self.name(), arg_types)?;
141
142        // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc
143        // smallint, int, bigint, real, double precision, decimal, or interval.
144
145        fn coerced_type(data_type: &DataType) -> Result<DataType> {
146            match data_type {
147                DataType::Dictionary(_, v) => coerced_type(v),
148                // in the spark, the result type is DECIMAL(min(38,precision+10), s)
149                // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
150                DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => {
151                    Ok(data_type.clone())
152                }
153                dt if dt.is_signed_integer() => Ok(DataType::Int64),
154                dt if dt.is_unsigned_integer() => Ok(DataType::UInt64),
155                dt if dt.is_floating() => Ok(DataType::Float64),
156                _ => exec_err!("Sum not supported for {}", data_type),
157            }
158        }
159
160        Ok(vec![coerced_type(args)?])
161    }
162
163    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
164        match &arg_types[0] {
165            DataType::Int64 => Ok(DataType::Int64),
166            DataType::UInt64 => Ok(DataType::UInt64),
167            DataType::Float64 => Ok(DataType::Float64),
168            DataType::Decimal128(precision, scale) => {
169                // in the spark, the result type is DECIMAL(min(38,precision+10), s)
170                // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
171                let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10);
172                Ok(DataType::Decimal128(new_precision, *scale))
173            }
174            DataType::Decimal256(precision, scale) => {
175                // in the spark, the result type is DECIMAL(min(38,precision+10), s)
176                // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
177                let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10);
178                Ok(DataType::Decimal256(new_precision, *scale))
179            }
180            other => {
181                exec_err!("[return_type] SUM not supported for {}", other)
182            }
183        }
184    }
185
186    fn accumulator(&self, args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
187        if args.is_distinct {
188            macro_rules! helper {
189                ($t:ty, $dt:expr) => {
190                    Ok(Box::new(DistinctSumAccumulator::<$t>::try_new(&$dt)?))
191                };
192            }
193            downcast_sum!(args, helper)
194        } else {
195            macro_rules! helper {
196                ($t:ty, $dt:expr) => {
197                    Ok(Box::new(SumAccumulator::<$t>::new($dt.clone())))
198                };
199            }
200            downcast_sum!(args, helper)
201        }
202    }
203
204    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
205        if args.is_distinct {
206            Ok(vec![Field::new_list(
207                format_state_name(args.name, "sum distinct"),
208                // See COMMENTS.md to understand why nullable is set to true
209                Field::new_list_field(args.return_type().clone(), true),
210                false,
211            )
212            .into()])
213        } else {
214            Ok(vec![Field::new(
215                format_state_name(args.name, "sum"),
216                args.return_type().clone(),
217                true,
218            )
219            .into()])
220        }
221    }
222
223    fn aliases(&self) -> &[String] {
224        &[]
225    }
226
227    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
228        !args.is_distinct
229    }
230
231    fn create_groups_accumulator(
232        &self,
233        args: AccumulatorArgs,
234    ) -> Result<Box<dyn GroupsAccumulator>> {
235        macro_rules! helper {
236            ($t:ty, $dt:expr) => {
237                Ok(Box::new(PrimitiveGroupsAccumulator::<$t, _>::new(
238                    &$dt,
239                    |x, y| *x = x.add_wrapping(y),
240                )))
241            };
242        }
243        downcast_sum!(args, helper)
244    }
245
246    fn create_sliding_accumulator(
247        &self,
248        args: AccumulatorArgs,
249    ) -> Result<Box<dyn Accumulator>> {
250        macro_rules! helper {
251            ($t:ty, $dt:expr) => {
252                Ok(Box::new(SlidingSumAccumulator::<$t>::new($dt.clone())))
253            };
254        }
255        downcast_sum!(args, helper)
256    }
257
258    fn reverse_expr(&self) -> ReversedUDAF {
259        ReversedUDAF::Identical
260    }
261
262    fn order_sensitivity(&self) -> AggregateOrderSensitivity {
263        AggregateOrderSensitivity::Insensitive
264    }
265
266    fn documentation(&self) -> Option<&Documentation> {
267        self.doc()
268    }
269
270    fn set_monotonicity(&self, data_type: &DataType) -> SetMonotonicity {
271        // `SUM` is only monotonically increasing when its input is unsigned.
272        // TODO: Expand these utilizing statistics.
273        match data_type {
274            DataType::UInt8 => SetMonotonicity::Increasing,
275            DataType::UInt16 => SetMonotonicity::Increasing,
276            DataType::UInt32 => SetMonotonicity::Increasing,
277            DataType::UInt64 => SetMonotonicity::Increasing,
278            _ => SetMonotonicity::NotMonotonic,
279        }
280    }
281}
282
/// This accumulator computes SUM incrementally
struct SumAccumulator<T: ArrowNumericType> {
    /// Running total; starts as `None` and is set once a batch yields a sum.
    sum: Option<T::Native>,
    /// Data type handed to `ScalarValue::new_primitive` when producing the result.
    data_type: DataType,
}
288
289impl<T: ArrowNumericType> std::fmt::Debug for SumAccumulator<T> {
290    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
291        write!(f, "SumAccumulator({})", self.data_type)
292    }
293}
294
295impl<T: ArrowNumericType> SumAccumulator<T> {
296    fn new(data_type: DataType) -> Self {
297        Self {
298            sum: None,
299            data_type,
300        }
301    }
302}
303
304impl<T: ArrowNumericType> Accumulator for SumAccumulator<T> {
305    fn state(&mut self) -> Result<Vec<ScalarValue>> {
306        Ok(vec![self.evaluate()?])
307    }
308
309    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
310        let values = values[0].as_primitive::<T>();
311        if let Some(x) = arrow::compute::sum(values) {
312            let v = self.sum.get_or_insert(T::Native::usize_as(0));
313            *v = v.add_wrapping(x);
314        }
315        Ok(())
316    }
317
318    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
319        self.update_batch(states)
320    }
321
322    fn evaluate(&mut self) -> Result<ScalarValue> {
323        ScalarValue::new_primitive::<T>(self.sum, &self.data_type)
324    }
325
326    fn size(&self) -> usize {
327        size_of_val(self)
328    }
329}
330
/// This accumulator incrementally computes sums over a sliding window
///
/// This is separate from [`SumAccumulator`] as it requires additional state:
/// a count of the accumulated values, which decides whether the result is null.
struct SlidingSumAccumulator<T: ArrowNumericType> {
    /// Running total, maintained with wrapping addition and subtraction.
    sum: T::Native,
    /// Number of non-null values added and not yet retracted; the result is
    /// null while this is zero.
    count: u64,
    /// Data type handed to `ScalarValue::new_primitive` when producing the result.
    data_type: DataType,
}
339
340impl<T: ArrowNumericType> std::fmt::Debug for SlidingSumAccumulator<T> {
341    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
342        write!(f, "SlidingSumAccumulator({})", self.data_type)
343    }
344}
345
346impl<T: ArrowNumericType> SlidingSumAccumulator<T> {
347    fn new(data_type: DataType) -> Self {
348        Self {
349            sum: T::Native::usize_as(0),
350            count: 0,
351            data_type,
352        }
353    }
354}
355
356impl<T: ArrowNumericType> Accumulator for SlidingSumAccumulator<T> {
357    fn state(&mut self) -> Result<Vec<ScalarValue>> {
358        Ok(vec![self.evaluate()?, self.count.into()])
359    }
360
361    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
362        let values = values[0].as_primitive::<T>();
363        self.count += (values.len() - values.null_count()) as u64;
364        if let Some(x) = arrow::compute::sum(values) {
365            self.sum = self.sum.add_wrapping(x)
366        }
367        Ok(())
368    }
369
370    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
371        let values = states[0].as_primitive::<T>();
372        if let Some(x) = arrow::compute::sum(values) {
373            self.sum = self.sum.add_wrapping(x)
374        }
375        if let Some(x) = arrow::compute::sum(states[1].as_primitive::<UInt64Type>()) {
376            self.count += x;
377        }
378        Ok(())
379    }
380
381    fn evaluate(&mut self) -> Result<ScalarValue> {
382        let v = (self.count != 0).then_some(self.sum);
383        ScalarValue::new_primitive::<T>(v, &self.data_type)
384    }
385
386    fn size(&self) -> usize {
387        size_of_val(self)
388    }
389
390    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
391        let values = values[0].as_primitive::<T>();
392        if let Some(x) = arrow::compute::sum(values) {
393            self.sum = self.sum.sub_wrapping(x)
394        }
395        self.count -= (values.len() - values.null_count()) as u64;
396        Ok(())
397    }
398
399    fn supports_retract_batch(&self) -> bool {
400        true
401    }
402}
403
/// Accumulator for `SUM(DISTINCT ..)`: collects the distinct input values in a
/// hash set and adds them up when the result is evaluated.
struct DistinctSumAccumulator<T: ArrowPrimitiveType> {
    /// Distinct non-null values seen so far, wrapped in `Hashable` so they can
    /// be stored in a `HashSet`.
    values: HashSet<Hashable<T::Native>, RandomState>,
    /// Data type handed to `ScalarValue::new_primitive` for state and result.
    data_type: DataType,
}
408
409impl<T: ArrowPrimitiveType> std::fmt::Debug for DistinctSumAccumulator<T> {
410    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
411        write!(f, "DistinctSumAccumulator({})", self.data_type)
412    }
413}
414
415impl<T: ArrowPrimitiveType> DistinctSumAccumulator<T> {
416    pub fn try_new(data_type: &DataType) -> Result<Self> {
417        Ok(Self {
418            values: HashSet::default(),
419            data_type: data_type.clone(),
420        })
421    }
422}
423
424impl<T: ArrowPrimitiveType> Accumulator for DistinctSumAccumulator<T> {
425    fn state(&mut self) -> Result<Vec<ScalarValue>> {
426        // 1. Stores aggregate state in `ScalarValue::List`
427        // 2. Constructs `ScalarValue::List` state from distinct numeric stored in hash set
428        let state_out = {
429            let distinct_values = self
430                .values
431                .iter()
432                .map(|value| {
433                    ScalarValue::new_primitive::<T>(Some(value.0), &self.data_type)
434                })
435                .collect::<Result<Vec<_>>>()?;
436
437            vec![ScalarValue::List(ScalarValue::new_list_nullable(
438                &distinct_values,
439                &self.data_type,
440            ))]
441        };
442        Ok(state_out)
443    }
444
445    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
446        if values.is_empty() {
447            return Ok(());
448        }
449
450        let array = values[0].as_primitive::<T>();
451        match array.nulls().filter(|x| x.null_count() > 0) {
452            Some(n) => {
453                for idx in n.valid_indices() {
454                    self.values.insert(Hashable(array.value(idx)));
455                }
456            }
457            None => array.values().iter().for_each(|x| {
458                self.values.insert(Hashable(*x));
459            }),
460        }
461        Ok(())
462    }
463
464    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
465        for x in states[0].as_list::<i32>().iter().flatten() {
466            self.update_batch(&[x])?
467        }
468        Ok(())
469    }
470
471    fn evaluate(&mut self) -> Result<ScalarValue> {
472        let mut acc = T::Native::usize_as(0);
473        for distinct_value in self.values.iter() {
474            acc = acc.add_wrapping(distinct_value.0)
475        }
476        let v = (!self.values.is_empty()).then_some(acc);
477        ScalarValue::new_primitive::<T>(v, &self.data_type)
478    }
479
480    fn size(&self) -> usize {
481        size_of_val(self) + self.values.capacity() * size_of::<T::Native>()
482    }
483}