Skip to main content

datafusion_functions_aggregate/
sum.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Defines `SUM` and `SUM DISTINCT` aggregate accumulators
19
20use arrow::array::{Array, ArrayRef, ArrowNativeTypeOp, ArrowNumericType, AsArray};
21use arrow::datatypes::Field;
22use arrow::datatypes::{
23    ArrowNativeType, DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION,
24    DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, DataType, Decimal32Type,
25    Decimal64Type, Decimal128Type, Decimal256Type, DurationMicrosecondType,
26    DurationMillisecondType, DurationNanosecondType, DurationSecondType, FieldRef,
27    Float64Type, Int64Type, TimeUnit, UInt64Type,
28};
29use datafusion_common::hash_utils::RandomState;
30use datafusion_common::internal_err;
31use datafusion_common::types::{
32    NativeType, logical_float64, logical_int8, logical_int16, logical_int32,
33    logical_int64, logical_uint8, logical_uint16, logical_uint32, logical_uint64,
34};
35use datafusion_common::{HashMap, Result, ScalarValue, exec_err, not_impl_err};
36use datafusion_expr::expr::AggregateFunction;
37use datafusion_expr::expr_fn::cast;
38use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
39use datafusion_expr::utils::{AggregateOrderSensitivity, format_state_name};
40use datafusion_expr::{
41    Accumulator, AggregateUDFImpl, Coercion, Documentation, Expr, GroupsAccumulator,
42    Operator, ReversedUDAF, SetMonotonicity, Signature, TypeSignature,
43    TypeSignatureClass, Volatility,
44};
45use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator;
46use datafusion_functions_aggregate_common::aggregate::sum_distinct::DistinctSumAccumulator;
47use datafusion_macros::user_doc;
48use std::mem::size_of_val;
49
50make_udaf_expr_and_func!(
51    Sum,
52    sum,
53    expression,
54    "Returns the sum of a group of values.",
55    sum_udaf
56);
57
58pub fn sum_distinct(expr: Expr) -> Expr {
59    Expr::AggregateFunction(AggregateFunction::new_udf(
60        sum_udaf(),
61        vec![expr],
62        true,
63        None,
64        vec![],
65        None,
66    ))
67}
68
/// Dispatches on the aggregate's return type to a concrete Arrow primitive type.
///
/// `SUM` only supports a subset of numeric types and relies on type coercion
/// for everything else, so only those return types are handled here; any other
/// type produces a "not implemented" error. This is similar to
/// [downcast_primitive](arrow::array::downcast_primitive), restricted to that subset.
///
/// * `args` is an [AccumulatorArgs] binding.
/// * `helper` is a macro accepting `(ArrowPrimitiveType, DataType)`; the second
///   argument is a local `DataType` holding a clone of the return type, which
///   the helper may borrow or clone.
macro_rules! downcast_sum {
    ($args:ident, $helper:ident) => {{
        let ret_type: DataType = $args.return_field.data_type().clone();
        match &ret_type {
            DataType::UInt64 => $helper!(UInt64Type, ret_type),
            DataType::Int64 => $helper!(Int64Type, ret_type),
            DataType::Float64 => $helper!(Float64Type, ret_type),
            DataType::Decimal32(_, _) => $helper!(Decimal32Type, ret_type),
            DataType::Decimal64(_, _) => $helper!(Decimal64Type, ret_type),
            DataType::Decimal128(_, _) => $helper!(Decimal128Type, ret_type),
            DataType::Decimal256(_, _) => $helper!(Decimal256Type, ret_type),
            DataType::Duration(TimeUnit::Second) => {
                $helper!(DurationSecondType, ret_type)
            }
            DataType::Duration(TimeUnit::Millisecond) => {
                $helper!(DurationMillisecondType, ret_type)
            }
            DataType::Duration(TimeUnit::Microsecond) => {
                $helper!(DurationMicrosecondType, ret_type)
            }
            DataType::Duration(TimeUnit::Nanosecond) => {
                $helper!(DurationNanosecondType, ret_type)
            }
            _ => not_impl_err!(
                "Sum not supported for {}: {}",
                $args.name,
                $args.return_field.data_type()
            ),
        }
    }};
}
130
131#[user_doc(
132    doc_section(label = "General Functions"),
133    description = "Returns the sum of all values in the specified column.",
134    syntax_example = "sum(expression)",
135    sql_example = r#"```sql
136> SELECT sum(column_name) FROM table_name;
137+-----------------------+
138| sum(column_name)       |
139+-----------------------+
140| 12345                 |
141+-----------------------+
142```"#,
143    standard_argument(name = "expression",)
144)]
145#[derive(Debug, PartialEq, Eq, Hash)]
146pub struct Sum {
147    signature: Signature,
148}
149
150impl Sum {
151    pub fn new() -> Self {
152        Self {
153            // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc
154            // smallint, int, bigint, real, double precision, decimal, or interval.
155            signature: Signature::one_of(
156                vec![
157                    TypeSignature::Coercible(vec![Coercion::new_exact(
158                        TypeSignatureClass::Decimal,
159                    )]),
160                    // Unsigned to u64
161                    TypeSignature::Coercible(vec![Coercion::new_implicit(
162                        TypeSignatureClass::Native(logical_uint64()),
163                        vec![
164                            TypeSignatureClass::Native(logical_uint8()),
165                            TypeSignatureClass::Native(logical_uint16()),
166                            TypeSignatureClass::Native(logical_uint32()),
167                        ],
168                        NativeType::UInt64,
169                    )]),
170                    // Signed to i64
171                    TypeSignature::Coercible(vec![Coercion::new_implicit(
172                        TypeSignatureClass::Native(logical_int64()),
173                        vec![
174                            TypeSignatureClass::Native(logical_int8()),
175                            TypeSignatureClass::Native(logical_int16()),
176                            TypeSignatureClass::Native(logical_int32()),
177                        ],
178                        NativeType::Int64,
179                    )]),
180                    // Floats to f64
181                    TypeSignature::Coercible(vec![Coercion::new_implicit(
182                        TypeSignatureClass::Native(logical_float64()),
183                        vec![TypeSignatureClass::Float],
184                        NativeType::Float64,
185                    )]),
186                    TypeSignature::Coercible(vec![Coercion::new_exact(
187                        TypeSignatureClass::Duration,
188                    )]),
189                ],
190                Volatility::Immutable,
191            ),
192        }
193    }
194}
195
196impl Default for Sum {
197    fn default() -> Self {
198        Self::new()
199    }
200}
201
202impl AggregateUDFImpl for Sum {
203    fn name(&self) -> &str {
204        "sum"
205    }
206
207    fn signature(&self) -> &Signature {
208        &self.signature
209    }
210
211    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
212        match &arg_types[0] {
213            DataType::Int64 => Ok(DataType::Int64),
214            DataType::UInt64 => Ok(DataType::UInt64),
215            DataType::Float64 => Ok(DataType::Float64),
216            // In the spark, the result type is DECIMAL(min(38,precision+10), s)
217            // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
218            DataType::Decimal32(precision, scale) => {
219                let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 10);
220                Ok(DataType::Decimal32(new_precision, *scale))
221            }
222            DataType::Decimal64(precision, scale) => {
223                let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 10);
224                Ok(DataType::Decimal64(new_precision, *scale))
225            }
226            DataType::Decimal128(precision, scale) => {
227                let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10);
228                Ok(DataType::Decimal128(new_precision, *scale))
229            }
230            DataType::Decimal256(precision, scale) => {
231                let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10);
232                Ok(DataType::Decimal256(new_precision, *scale))
233            }
234            DataType::Duration(time_unit) => Ok(DataType::Duration(*time_unit)),
235            other => {
236                exec_err!("[return_type] SUM not supported for {}", other)
237            }
238        }
239    }
240
241    fn accumulator(&self, args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
242        if args.is_distinct {
243            macro_rules! helper {
244                ($t:ty, $dt:expr) => {
245                    Ok(Box::new(DistinctSumAccumulator::<$t>::new(&$dt)))
246                };
247            }
248            downcast_sum!(args, helper)
249        } else {
250            macro_rules! helper {
251                ($t:ty, $dt:expr) => {
252                    Ok(Box::new(SumAccumulator::<$t>::new($dt.clone())))
253                };
254            }
255            downcast_sum!(args, helper)
256        }
257    }
258
259    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
260        if args.is_distinct {
261            Ok(vec![
262                Field::new_list(
263                    format_state_name(args.name, "sum distinct"),
264                    // See COMMENTS.md to understand why nullable is set to true
265                    Field::new_list_field(args.return_type().clone(), true),
266                    false,
267                )
268                .into(),
269            ])
270        } else {
271            Ok(vec![
272                Field::new(
273                    format_state_name(args.name, "sum"),
274                    args.return_type().clone(),
275                    true,
276                )
277                .into(),
278            ])
279        }
280    }
281
282    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
283        !args.is_distinct
284    }
285
286    fn create_groups_accumulator(
287        &self,
288        args: AccumulatorArgs,
289    ) -> Result<Box<dyn GroupsAccumulator>> {
290        macro_rules! helper {
291            ($t:ty, $dt:expr) => {
292                Ok(Box::new(PrimitiveGroupsAccumulator::<$t, _>::new(
293                    &$dt,
294                    |x, y| *x = x.add_wrapping(y),
295                )))
296            };
297        }
298        downcast_sum!(args, helper)
299    }
300
301    fn create_sliding_accumulator(
302        &self,
303        args: AccumulatorArgs,
304    ) -> Result<Box<dyn Accumulator>> {
305        if args.is_distinct {
306            // distinct path: use our sliding‐window distinct‐sum
307            macro_rules! helper_distinct {
308                ($t:ty, $dt:expr) => {
309                    Ok(Box::new(SlidingDistinctSumAccumulator::try_new(&$dt)?))
310                };
311            }
312            downcast_sum!(args, helper_distinct)
313        } else {
314            // non‐distinct path: existing sliding sum
315            macro_rules! helper {
316                ($t:ty, $dt:expr) => {
317                    Ok(Box::new(SlidingSumAccumulator::<$t>::new($dt.clone())))
318                };
319            }
320            downcast_sum!(args, helper)
321        }
322    }
323
324    fn reverse_expr(&self) -> ReversedUDAF {
325        ReversedUDAF::Identical
326    }
327
328    fn order_sensitivity(&self) -> AggregateOrderSensitivity {
329        AggregateOrderSensitivity::Insensitive
330    }
331
332    fn documentation(&self) -> Option<&Documentation> {
333        self.doc()
334    }
335
336    fn set_monotonicity(&self, data_type: &DataType) -> SetMonotonicity {
337        // `SUM` is only monotonically increasing when its input is unsigned.
338        // TODO: Expand these utilizing statistics.
339        match data_type {
340            DataType::UInt8 => SetMonotonicity::Increasing,
341            DataType::UInt16 => SetMonotonicity::Increasing,
342            DataType::UInt32 => SetMonotonicity::Increasing,
343            DataType::UInt64 => SetMonotonicity::Increasing,
344            _ => SetMonotonicity::NotMonotonic,
345        }
346    }
347
348    /// Implement ClickBench Q29 specific optimization:
349    /// `SUM(arg + constant)` --> `SUM(arg) + constant * COUNT(arg)`
350    ///
351    /// See background on [`AggregateUDFImpl::simplify_expr_op_literal`]
352    fn simplify_expr_op_literal(
353        &self,
354        agg_function: &AggregateFunction,
355        arg: &Expr,
356        op: Operator,
357        lit: &Expr,
358        // Only support '+' so the order of the args doesn't matter
359        _arg_is_left: bool,
360    ) -> Result<Option<Expr>> {
361        if op != Operator::Plus {
362            return Ok(None);
363        }
364
365        let lit_type = match &lit {
366            Expr::Literal(value, _) => value.data_type(),
367            _ => {
368                return internal_err!(
369                    "Sum::simplify_expr_op_literal got a non literal argument"
370                );
371            }
372        };
373        if lit_type == DataType::Null {
374            return Ok(None);
375        }
376
377        // Build up SUM(arg)
378        let mut sum_agg = agg_function.clone();
379        sum_agg.params.args = vec![arg.clone()];
380        let sum_agg = Expr::AggregateFunction(sum_agg);
381
382        // COUNT(arg) - cast to the correct type
383        let count_agg = cast(crate::count::count(arg.clone()), lit_type);
384
385        // SUM(arg) + lit * COUNT(arg)
386        Ok(Some(sum_agg + (lit.clone() * count_agg)))
387    }
388}
389
390/// This accumulator computes SUM incrementally
391struct SumAccumulator<T: ArrowNumericType> {
392    sum: Option<T::Native>,
393    data_type: DataType,
394}
395
396impl<T: ArrowNumericType> std::fmt::Debug for SumAccumulator<T> {
397    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
398        write!(f, "SumAccumulator({})", self.data_type)
399    }
400}
401
402impl<T: ArrowNumericType> SumAccumulator<T> {
403    fn new(data_type: DataType) -> Self {
404        Self {
405            sum: None,
406            data_type,
407        }
408    }
409}
410
411impl<T: ArrowNumericType> Accumulator for SumAccumulator<T> {
412    fn state(&mut self) -> Result<Vec<ScalarValue>> {
413        Ok(vec![self.evaluate()?])
414    }
415
416    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
417        let values = values[0].as_primitive::<T>();
418        if let Some(x) = arrow::compute::sum(values) {
419            let v = self.sum.get_or_insert_with(|| T::Native::usize_as(0));
420            *v = v.add_wrapping(x);
421        }
422        Ok(())
423    }
424
425    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
426        self.update_batch(states)
427    }
428
429    fn evaluate(&mut self) -> Result<ScalarValue> {
430        ScalarValue::new_primitive::<T>(self.sum, &self.data_type)
431    }
432
433    fn size(&self) -> usize {
434        size_of_val(self)
435    }
436}
437
438/// This accumulator incrementally computes sums over a sliding window
439///
440/// This is separate from [`SumAccumulator`] as requires additional state
441struct SlidingSumAccumulator<T: ArrowNumericType> {
442    sum: T::Native,
443    count: u64,
444    data_type: DataType,
445}
446
447impl<T: ArrowNumericType> std::fmt::Debug for SlidingSumAccumulator<T> {
448    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
449        write!(f, "SlidingSumAccumulator({})", self.data_type)
450    }
451}
452
453impl<T: ArrowNumericType> SlidingSumAccumulator<T> {
454    fn new(data_type: DataType) -> Self {
455        Self {
456            sum: T::Native::usize_as(0),
457            count: 0,
458            data_type,
459        }
460    }
461}
462
463impl<T: ArrowNumericType> Accumulator for SlidingSumAccumulator<T> {
464    fn state(&mut self) -> Result<Vec<ScalarValue>> {
465        Ok(vec![self.evaluate()?, self.count.into()])
466    }
467
468    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
469        let values = values[0].as_primitive::<T>();
470        self.count += (values.len() - values.null_count()) as u64;
471        if let Some(x) = arrow::compute::sum(values) {
472            self.sum = self.sum.add_wrapping(x)
473        }
474        Ok(())
475    }
476
477    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
478        let values = states[0].as_primitive::<T>();
479        if let Some(x) = arrow::compute::sum(values) {
480            self.sum = self.sum.add_wrapping(x)
481        }
482        if let Some(x) = arrow::compute::sum(states[1].as_primitive::<UInt64Type>()) {
483            self.count += x;
484        }
485        Ok(())
486    }
487
488    fn evaluate(&mut self) -> Result<ScalarValue> {
489        let v = (self.count != 0).then_some(self.sum);
490        ScalarValue::new_primitive::<T>(v, &self.data_type)
491    }
492
493    fn size(&self) -> usize {
494        size_of_val(self)
495    }
496
497    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
498        let values = values[0].as_primitive::<T>();
499        if let Some(x) = arrow::compute::sum(values) {
500            self.sum = self.sum.sub_wrapping(x)
501        }
502        self.count -= (values.len() - values.null_count()) as u64;
503        Ok(())
504    }
505
506    fn supports_retract_batch(&self) -> bool {
507        true
508    }
509}
510
511/// A sliding‐window accumulator for `SUM(DISTINCT)` over Int64 columns.
512/// Maintains a running sum so that `evaluate()` is O(1).
513#[derive(Debug)]
514pub struct SlidingDistinctSumAccumulator {
515    /// Map each distinct value → its current count in the window
516    counts: HashMap<i64, usize, RandomState>,
517    /// Running sum of all distinct keys currently in the window
518    sum: i64,
519    /// Data type (must be Int64)
520    data_type: DataType,
521}
522
523impl SlidingDistinctSumAccumulator {
524    /// Create a new accumulator; only `DataType::Int64` is supported.
525    pub fn try_new(data_type: &DataType) -> Result<Self> {
526        // TODO support other numeric types
527        if *data_type != DataType::Int64 {
528            return exec_err!("SlidingDistinctSumAccumulator only supports Int64");
529        }
530        Ok(Self {
531            counts: HashMap::default(),
532            sum: 0,
533            data_type: data_type.clone(),
534        })
535    }
536}
537
538impl Accumulator for SlidingDistinctSumAccumulator {
539    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
540        let arr = values[0].as_primitive::<Int64Type>();
541        for &v in arr.values() {
542            let cnt = self.counts.entry(v).or_insert(0);
543            if *cnt == 0 {
544                // first occurrence in window
545                self.sum = self.sum.wrapping_add(v);
546            }
547            *cnt += 1;
548        }
549        Ok(())
550    }
551
552    fn evaluate(&mut self) -> Result<ScalarValue> {
553        // O(1) wrap of running sum
554        Ok(ScalarValue::Int64(Some(self.sum)))
555    }
556
557    fn size(&self) -> usize {
558        size_of_val(self)
559    }
560
561    fn state(&mut self) -> Result<Vec<ScalarValue>> {
562        // Serialize distinct keys for cross-partition merge if needed
563        let keys = self
564            .counts
565            .keys()
566            .cloned()
567            .map(Some)
568            .map(ScalarValue::Int64)
569            .collect::<Vec<_>>();
570        Ok(vec![ScalarValue::List(ScalarValue::new_list_nullable(
571            &keys,
572            &self.data_type,
573        ))])
574    }
575
576    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
577        // Merge distinct keys from other partitions
578        let list_arr = states[0].as_list::<i32>();
579        for maybe_inner in list_arr.iter().flatten() {
580            for idx in 0..maybe_inner.len() {
581                if let ScalarValue::Int64(Some(v)) =
582                    ScalarValue::try_from_array(&*maybe_inner, idx)?
583                {
584                    let cnt = self.counts.entry(v).or_insert(0);
585                    if *cnt == 0 {
586                        self.sum = self.sum.wrapping_add(v);
587                    }
588                    *cnt += 1;
589                }
590            }
591        }
592        Ok(())
593    }
594
595    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
596        let arr = values[0].as_primitive::<Int64Type>();
597        for &v in arr.values() {
598            if let Some(cnt) = self.counts.get_mut(&v) {
599                *cnt -= 1;
600                if *cnt == 0 {
601                    // last copy leaving window
602                    self.sum = self.sum.wrapping_sub(v);
603                    self.counts.remove(&v);
604                }
605            }
606        }
607        Ok(())
608    }
609
610    fn supports_retract_batch(&self) -> bool {
611        true
612    }
613}