datafusion_functions_aggregate/
sum.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Defines `SUM` and `SUM DISTINCT` aggregate accumulators
19
20use ahash::RandomState;
21use arrow::array::{Array, ArrayRef, ArrowNativeTypeOp, ArrowNumericType, AsArray};
22use arrow::datatypes::Field;
23use arrow::datatypes::{
24    ArrowNativeType, DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION,
25    DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, DataType, Decimal32Type,
26    Decimal64Type, Decimal128Type, Decimal256Type, DurationMicrosecondType,
27    DurationMillisecondType, DurationNanosecondType, DurationSecondType, FieldRef,
28    Float64Type, Int64Type, TimeUnit, UInt64Type,
29};
30use datafusion_common::types::{
31    NativeType, logical_float64, logical_int8, logical_int16, logical_int32,
32    logical_int64, logical_uint8, logical_uint16, logical_uint32, logical_uint64,
33};
34use datafusion_common::{HashMap, Result, ScalarValue, exec_err, not_impl_err};
35use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
36use datafusion_expr::utils::{AggregateOrderSensitivity, format_state_name};
37use datafusion_expr::{
38    Accumulator, AggregateUDFImpl, Coercion, Documentation, Expr, GroupsAccumulator,
39    ReversedUDAF, SetMonotonicity, Signature, TypeSignature, TypeSignatureClass,
40    Volatility,
41};
42use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator;
43use datafusion_functions_aggregate_common::aggregate::sum_distinct::DistinctSumAccumulator;
44use datafusion_macros::user_doc;
45use std::any::Any;
46use std::mem::size_of_val;
47
// Invokes the project macro `make_udaf_expr_and_func!` for the `Sum` UDAF,
// passing the identifiers `sum` and `sum_udaf`, the argument name
// `expression`, and the one-line description string.
// NOTE(review): presumably this generates the `sum(expression)` expression
// builder and the `sum_udaf()` accessor that `sum_distinct` in this file
// calls — confirm against the macro definition.
make_udaf_expr_and_func!(
    Sum,
    sum,
    expression,
    "Returns the sum of a group of values.",
    sum_udaf
);
55
56pub fn sum_distinct(expr: Expr) -> Expr {
57    Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
58        sum_udaf(),
59        vec![expr],
60        true,
61        None,
62        vec![],
63        None,
64    ))
65}
66
/// Dispatches on the aggregate's return type to the matching Arrow primitive
/// type and hands both to a helper macro.
///
/// Sum only supports a subset of numeric types, instead relying on type
/// coercion, so only that subset is listed here; any other return type
/// produces a "not implemented" error.
///
/// This macro is similar to [downcast_primitive](arrow::array::downcast_primitive)
///
/// * `args` is [AccumulatorArgs]
/// * `helper` is a macro accepting (ArrowPrimitiveType, DataType)
macro_rules! downcast_sum {
    ($args:ident, $helper:ident) => {{
        // Clone the return type once; every arm forwards this same value.
        let return_type = $args.return_field.data_type().clone();
        match &return_type {
            DataType::UInt64 => $helper!(UInt64Type, return_type),
            DataType::Int64 => $helper!(Int64Type, return_type),
            DataType::Float64 => $helper!(Float64Type, return_type),
            DataType::Decimal32(_, _) => $helper!(Decimal32Type, return_type),
            DataType::Decimal64(_, _) => $helper!(Decimal64Type, return_type),
            DataType::Decimal128(_, _) => $helper!(Decimal128Type, return_type),
            DataType::Decimal256(_, _) => $helper!(Decimal256Type, return_type),
            DataType::Duration(TimeUnit::Second) => {
                $helper!(DurationSecondType, return_type)
            }
            DataType::Duration(TimeUnit::Millisecond) => {
                $helper!(DurationMillisecondType, return_type)
            }
            DataType::Duration(TimeUnit::Microsecond) => {
                $helper!(DurationMicrosecondType, return_type)
            }
            DataType::Duration(TimeUnit::Nanosecond) => {
                $helper!(DurationNanosecondType, return_type)
            }
            _ => not_impl_err!(
                "Sum not supported for {}: {}",
                $args.name,
                return_type
            ),
        }
    }};
}
128
129#[user_doc(
130    doc_section(label = "General Functions"),
131    description = "Returns the sum of all values in the specified column.",
132    syntax_example = "sum(expression)",
133    sql_example = r#"```sql
134> SELECT sum(column_name) FROM table_name;
135+-----------------------+
136| sum(column_name)       |
137+-----------------------+
138| 12345                 |
139+-----------------------+
140```"#,
141    standard_argument(name = "expression",)
142)]
143#[derive(Debug, PartialEq, Eq, Hash)]
144pub struct Sum {
145    signature: Signature,
146}
147
148impl Sum {
149    pub fn new() -> Self {
150        Self {
151            // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc
152            // smallint, int, bigint, real, double precision, decimal, or interval.
153            signature: Signature::one_of(
154                vec![
155                    TypeSignature::Coercible(vec![Coercion::new_exact(
156                        TypeSignatureClass::Decimal,
157                    )]),
158                    // Unsigned to u64
159                    TypeSignature::Coercible(vec![Coercion::new_implicit(
160                        TypeSignatureClass::Native(logical_uint64()),
161                        vec![
162                            TypeSignatureClass::Native(logical_uint8()),
163                            TypeSignatureClass::Native(logical_uint16()),
164                            TypeSignatureClass::Native(logical_uint32()),
165                        ],
166                        NativeType::UInt64,
167                    )]),
168                    // Signed to i64
169                    TypeSignature::Coercible(vec![Coercion::new_implicit(
170                        TypeSignatureClass::Native(logical_int64()),
171                        vec![
172                            TypeSignatureClass::Native(logical_int8()),
173                            TypeSignatureClass::Native(logical_int16()),
174                            TypeSignatureClass::Native(logical_int32()),
175                        ],
176                        NativeType::Int64,
177                    )]),
178                    // Floats to f64
179                    TypeSignature::Coercible(vec![Coercion::new_implicit(
180                        TypeSignatureClass::Native(logical_float64()),
181                        vec![TypeSignatureClass::Float],
182                        NativeType::Float64,
183                    )]),
184                    TypeSignature::Coercible(vec![Coercion::new_exact(
185                        TypeSignatureClass::Duration,
186                    )]),
187                ],
188                Volatility::Immutable,
189            ),
190        }
191    }
192}
193
194impl Default for Sum {
195    fn default() -> Self {
196        Self::new()
197    }
198}
199
200impl AggregateUDFImpl for Sum {
201    fn as_any(&self) -> &dyn Any {
202        self
203    }
204
205    fn name(&self) -> &str {
206        "sum"
207    }
208
209    fn signature(&self) -> &Signature {
210        &self.signature
211    }
212
213    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
214        match &arg_types[0] {
215            DataType::Int64 => Ok(DataType::Int64),
216            DataType::UInt64 => Ok(DataType::UInt64),
217            DataType::Float64 => Ok(DataType::Float64),
218            // In the spark, the result type is DECIMAL(min(38,precision+10), s)
219            // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
220            DataType::Decimal32(precision, scale) => {
221                let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 10);
222                Ok(DataType::Decimal32(new_precision, *scale))
223            }
224            DataType::Decimal64(precision, scale) => {
225                let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 10);
226                Ok(DataType::Decimal64(new_precision, *scale))
227            }
228            DataType::Decimal128(precision, scale) => {
229                let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10);
230                Ok(DataType::Decimal128(new_precision, *scale))
231            }
232            DataType::Decimal256(precision, scale) => {
233                let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10);
234                Ok(DataType::Decimal256(new_precision, *scale))
235            }
236            DataType::Duration(time_unit) => Ok(DataType::Duration(*time_unit)),
237            other => {
238                exec_err!("[return_type] SUM not supported for {}", other)
239            }
240        }
241    }
242
243    fn accumulator(&self, args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
244        if args.is_distinct {
245            macro_rules! helper {
246                ($t:ty, $dt:expr) => {
247                    Ok(Box::new(DistinctSumAccumulator::<$t>::new(&$dt)))
248                };
249            }
250            downcast_sum!(args, helper)
251        } else {
252            macro_rules! helper {
253                ($t:ty, $dt:expr) => {
254                    Ok(Box::new(SumAccumulator::<$t>::new($dt.clone())))
255                };
256            }
257            downcast_sum!(args, helper)
258        }
259    }
260
261    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
262        if args.is_distinct {
263            Ok(vec![
264                Field::new_list(
265                    format_state_name(args.name, "sum distinct"),
266                    // See COMMENTS.md to understand why nullable is set to true
267                    Field::new_list_field(args.return_type().clone(), true),
268                    false,
269                )
270                .into(),
271            ])
272        } else {
273            Ok(vec![
274                Field::new(
275                    format_state_name(args.name, "sum"),
276                    args.return_type().clone(),
277                    true,
278                )
279                .into(),
280            ])
281        }
282    }
283
284    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
285        !args.is_distinct
286    }
287
288    fn create_groups_accumulator(
289        &self,
290        args: AccumulatorArgs,
291    ) -> Result<Box<dyn GroupsAccumulator>> {
292        macro_rules! helper {
293            ($t:ty, $dt:expr) => {
294                Ok(Box::new(PrimitiveGroupsAccumulator::<$t, _>::new(
295                    &$dt,
296                    |x, y| *x = x.add_wrapping(y),
297                )))
298            };
299        }
300        downcast_sum!(args, helper)
301    }
302
303    fn create_sliding_accumulator(
304        &self,
305        args: AccumulatorArgs,
306    ) -> Result<Box<dyn Accumulator>> {
307        if args.is_distinct {
308            // distinct path: use our sliding‐window distinct‐sum
309            macro_rules! helper_distinct {
310                ($t:ty, $dt:expr) => {
311                    Ok(Box::new(SlidingDistinctSumAccumulator::try_new(&$dt)?))
312                };
313            }
314            downcast_sum!(args, helper_distinct)
315        } else {
316            // non‐distinct path: existing sliding sum
317            macro_rules! helper {
318                ($t:ty, $dt:expr) => {
319                    Ok(Box::new(SlidingSumAccumulator::<$t>::new($dt.clone())))
320                };
321            }
322            downcast_sum!(args, helper)
323        }
324    }
325
326    fn reverse_expr(&self) -> ReversedUDAF {
327        ReversedUDAF::Identical
328    }
329
330    fn order_sensitivity(&self) -> AggregateOrderSensitivity {
331        AggregateOrderSensitivity::Insensitive
332    }
333
334    fn documentation(&self) -> Option<&Documentation> {
335        self.doc()
336    }
337
338    fn set_monotonicity(&self, data_type: &DataType) -> SetMonotonicity {
339        // `SUM` is only monotonically increasing when its input is unsigned.
340        // TODO: Expand these utilizing statistics.
341        match data_type {
342            DataType::UInt8 => SetMonotonicity::Increasing,
343            DataType::UInt16 => SetMonotonicity::Increasing,
344            DataType::UInt32 => SetMonotonicity::Increasing,
345            DataType::UInt64 => SetMonotonicity::Increasing,
346            _ => SetMonotonicity::NotMonotonic,
347        }
348    }
349}
350
351/// This accumulator computes SUM incrementally
352struct SumAccumulator<T: ArrowNumericType> {
353    sum: Option<T::Native>,
354    data_type: DataType,
355}
356
357impl<T: ArrowNumericType> std::fmt::Debug for SumAccumulator<T> {
358    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
359        write!(f, "SumAccumulator({})", self.data_type)
360    }
361}
362
363impl<T: ArrowNumericType> SumAccumulator<T> {
364    fn new(data_type: DataType) -> Self {
365        Self {
366            sum: None,
367            data_type,
368        }
369    }
370}
371
372impl<T: ArrowNumericType> Accumulator for SumAccumulator<T> {
373    fn state(&mut self) -> Result<Vec<ScalarValue>> {
374        Ok(vec![self.evaluate()?])
375    }
376
377    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
378        let values = values[0].as_primitive::<T>();
379        if let Some(x) = arrow::compute::sum(values) {
380            let v = self.sum.get_or_insert_with(|| T::Native::usize_as(0));
381            *v = v.add_wrapping(x);
382        }
383        Ok(())
384    }
385
386    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
387        self.update_batch(states)
388    }
389
390    fn evaluate(&mut self) -> Result<ScalarValue> {
391        ScalarValue::new_primitive::<T>(self.sum, &self.data_type)
392    }
393
394    fn size(&self) -> usize {
395        size_of_val(self)
396    }
397}
398
399/// This accumulator incrementally computes sums over a sliding window
400///
401/// This is separate from [`SumAccumulator`] as requires additional state
402struct SlidingSumAccumulator<T: ArrowNumericType> {
403    sum: T::Native,
404    count: u64,
405    data_type: DataType,
406}
407
408impl<T: ArrowNumericType> std::fmt::Debug for SlidingSumAccumulator<T> {
409    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
410        write!(f, "SlidingSumAccumulator({})", self.data_type)
411    }
412}
413
414impl<T: ArrowNumericType> SlidingSumAccumulator<T> {
415    fn new(data_type: DataType) -> Self {
416        Self {
417            sum: T::Native::usize_as(0),
418            count: 0,
419            data_type,
420        }
421    }
422}
423
424impl<T: ArrowNumericType> Accumulator for SlidingSumAccumulator<T> {
425    fn state(&mut self) -> Result<Vec<ScalarValue>> {
426        Ok(vec![self.evaluate()?, self.count.into()])
427    }
428
429    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
430        let values = values[0].as_primitive::<T>();
431        self.count += (values.len() - values.null_count()) as u64;
432        if let Some(x) = arrow::compute::sum(values) {
433            self.sum = self.sum.add_wrapping(x)
434        }
435        Ok(())
436    }
437
438    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
439        let values = states[0].as_primitive::<T>();
440        if let Some(x) = arrow::compute::sum(values) {
441            self.sum = self.sum.add_wrapping(x)
442        }
443        if let Some(x) = arrow::compute::sum(states[1].as_primitive::<UInt64Type>()) {
444            self.count += x;
445        }
446        Ok(())
447    }
448
449    fn evaluate(&mut self) -> Result<ScalarValue> {
450        let v = (self.count != 0).then_some(self.sum);
451        ScalarValue::new_primitive::<T>(v, &self.data_type)
452    }
453
454    fn size(&self) -> usize {
455        size_of_val(self)
456    }
457
458    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
459        let values = values[0].as_primitive::<T>();
460        if let Some(x) = arrow::compute::sum(values) {
461            self.sum = self.sum.sub_wrapping(x)
462        }
463        self.count -= (values.len() - values.null_count()) as u64;
464        Ok(())
465    }
466
467    fn supports_retract_batch(&self) -> bool {
468        true
469    }
470}
471
472/// A sliding‐window accumulator for `SUM(DISTINCT)` over Int64 columns.
473/// Maintains a running sum so that `evaluate()` is O(1).
474#[derive(Debug)]
475pub struct SlidingDistinctSumAccumulator {
476    /// Map each distinct value → its current count in the window
477    counts: HashMap<i64, usize, RandomState>,
478    /// Running sum of all distinct keys currently in the window
479    sum: i64,
480    /// Data type (must be Int64)
481    data_type: DataType,
482}
483
484impl SlidingDistinctSumAccumulator {
485    /// Create a new accumulator; only `DataType::Int64` is supported.
486    pub fn try_new(data_type: &DataType) -> Result<Self> {
487        // TODO support other numeric types
488        if *data_type != DataType::Int64 {
489            return exec_err!("SlidingDistinctSumAccumulator only supports Int64");
490        }
491        Ok(Self {
492            counts: HashMap::default(),
493            sum: 0,
494            data_type: data_type.clone(),
495        })
496    }
497}
498
499impl Accumulator for SlidingDistinctSumAccumulator {
500    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
501        let arr = values[0].as_primitive::<Int64Type>();
502        for &v in arr.values() {
503            let cnt = self.counts.entry(v).or_insert(0);
504            if *cnt == 0 {
505                // first occurrence in window
506                self.sum = self.sum.wrapping_add(v);
507            }
508            *cnt += 1;
509        }
510        Ok(())
511    }
512
513    fn evaluate(&mut self) -> Result<ScalarValue> {
514        // O(1) wrap of running sum
515        Ok(ScalarValue::Int64(Some(self.sum)))
516    }
517
518    fn size(&self) -> usize {
519        size_of_val(self)
520    }
521
522    fn state(&mut self) -> Result<Vec<ScalarValue>> {
523        // Serialize distinct keys for cross-partition merge if needed
524        let keys = self
525            .counts
526            .keys()
527            .cloned()
528            .map(Some)
529            .map(ScalarValue::Int64)
530            .collect::<Vec<_>>();
531        Ok(vec![ScalarValue::List(ScalarValue::new_list_nullable(
532            &keys,
533            &self.data_type,
534        ))])
535    }
536
537    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
538        // Merge distinct keys from other partitions
539        let list_arr = states[0].as_list::<i32>();
540        for maybe_inner in list_arr.iter().flatten() {
541            for idx in 0..maybe_inner.len() {
542                if let ScalarValue::Int64(Some(v)) =
543                    ScalarValue::try_from_array(&*maybe_inner, idx)?
544                {
545                    let cnt = self.counts.entry(v).or_insert(0);
546                    if *cnt == 0 {
547                        self.sum = self.sum.wrapping_add(v);
548                    }
549                    *cnt += 1;
550                }
551            }
552        }
553        Ok(())
554    }
555
556    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
557        let arr = values[0].as_primitive::<Int64Type>();
558        for &v in arr.values() {
559            if let Some(cnt) = self.counts.get_mut(&v) {
560                *cnt -= 1;
561                if *cnt == 0 {
562                    // last copy leaving window
563                    self.sum = self.sum.wrapping_sub(v);
564                    self.counts.remove(&v);
565                }
566            }
567        }
568        Ok(())
569    }
570
571    fn supports_retract_batch(&self) -> bool {
572        true
573    }
574}