datafusion_functions_aggregate/sum.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Defines `SUM` and `SUM DISTINCT` aggregate accumulators

use ahash::RandomState;
use arrow::datatypes::DECIMAL32_MAX_PRECISION;
use arrow::datatypes::DECIMAL64_MAX_PRECISION;
use datafusion_expr::utils::AggregateOrderSensitivity;
use datafusion_expr::Expr;
use std::any::Any;
use std::mem::size_of_val;

use arrow::array::Array;
use arrow::array::ArrowNativeTypeOp;
use arrow::array::{ArrowNumericType, AsArray};
use arrow::datatypes::{ArrowNativeType, FieldRef};
use arrow::datatypes::{
    DataType, Decimal128Type, Decimal256Type, Decimal32Type, Decimal64Type, Float64Type,
    Int64Type, UInt64Type, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION,
};
use arrow::{array::ArrayRef, datatypes::Field};
use datafusion_common::{
    exec_err, not_impl_err, utils::take_function_args, HashMap, Result, ScalarValue,
};
use datafusion_expr::function::AccumulatorArgs;
use datafusion_expr::function::StateFieldsArgs;
use datafusion_expr::utils::format_state_name;
use datafusion_expr::{
    Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, ReversedUDAF,
    SetMonotonicity, Signature, Volatility,
};
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator;
use datafusion_functions_aggregate_common::aggregate::sum_distinct::DistinctSumAccumulator;
use datafusion_macros::user_doc;

make_udaf_expr_and_func!(
    Sum,
    sum,
    expression,
    "Returns the sum of a group of values.",
    sum_udaf
);

/// Creates a `SUM(DISTINCT expr)` aggregate expression.
pub fn sum_distinct(expr: Expr) -> Expr {
    Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
        sum_udaf(),
        vec![expr],
        true,
        None,
        vec![],
        None,
    ))
}

/// Sum only supports a subset of numeric types, relying on type coercion
/// to map other numeric inputs onto that subset.
///
/// This macro is similar to [downcast_primitive](arrow::array::downcast_primitive)
///
/// `args` is [AccumulatorArgs]
/// `helper` is a macro accepting (ArrowPrimitiveType, DataType)
macro_rules! downcast_sum {
    ($args:ident, $helper:ident) => {
        match $args.return_field.data_type().clone() {
            DataType::UInt64 => {
                $helper!(UInt64Type, $args.return_field.data_type().clone())
            }
            DataType::Int64 => {
                $helper!(Int64Type, $args.return_field.data_type().clone())
            }
            DataType::Float64 => {
                $helper!(Float64Type, $args.return_field.data_type().clone())
            }
            DataType::Decimal32(_, _) => {
                $helper!(Decimal32Type, $args.return_field.data_type().clone())
            }
            DataType::Decimal64(_, _) => {
                $helper!(Decimal64Type, $args.return_field.data_type().clone())
            }
            DataType::Decimal128(_, _) => {
                $helper!(Decimal128Type, $args.return_field.data_type().clone())
            }
            DataType::Decimal256(_, _) => {
                $helper!(Decimal256Type, $args.return_field.data_type().clone())
            }
            _ => {
                not_impl_err!(
                    "Sum not supported for {}: {}",
                    $args.name,
                    $args.return_field.data_type()
                )
            }
        }
    };
}
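
// An illustrative sketch (not part of the original source): invoking
// `downcast_sum!(args, helper)` with the non-distinct `helper` from
// `accumulator` below expands, for a `Float64` return type, to roughly
//
//     DataType::Float64 => {
//         Ok(Box::new(SumAccumulator::<Float64Type>::new(
//             args.return_field.data_type().clone(),
//         )))
//     }
//
// so each call site only supplies the per-type constructor.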

#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the sum of all values in the specified column.",
    syntax_example = "sum(expression)",
    sql_example = r#"```sql
> SELECT sum(column_name) FROM table_name;
+-----------------------+
| sum(column_name)      |
+-----------------------+
| 12345                 |
+-----------------------+
```"#,
    standard_argument(name = "expression",)
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct Sum {
    signature: Signature,
}

impl Sum {
    pub fn new() -> Self {
        Self {
            signature: Signature::user_defined(Volatility::Immutable),
        }
    }
}

impl Default for Sum {
    fn default() -> Self {
        Self::new()
    }
}

impl AggregateUDFImpl for Sum {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "sum"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
        let [args] = take_function_args(self.name(), arg_types)?;

        // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html:
        // smallint, int, bigint, real, double precision, decimal, or interval.

        fn coerced_type(data_type: &DataType) -> Result<DataType> {
            match data_type {
                DataType::Dictionary(_, v) => coerced_type(v),
                // Decimals pass through unchanged here; widening the precision
                // to DECIMAL(min(38, precision + 10), s) (following Spark)
                // happens in `return_type`.
                // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
                DataType::Decimal32(_, _)
                | DataType::Decimal64(_, _)
                | DataType::Decimal128(_, _)
                | DataType::Decimal256(_, _) => Ok(data_type.clone()),
                dt if dt.is_signed_integer() => Ok(DataType::Int64),
                dt if dt.is_unsigned_integer() => Ok(DataType::UInt64),
                dt if dt.is_floating() => Ok(DataType::Float64),
                _ => exec_err!("Sum not supported for {data_type}"),
            }
        }

        Ok(vec![coerced_type(args)?])
    }
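
    // Illustrative examples (not part of the original source) of the coercion
    // above: Int8/Int16/Int32 coerce to Int64, UInt8/UInt16/UInt32 to UInt64,
    // Float16/Float32 to Float64, decimal inputs keep their original type, and
    // dictionary-encoded inputs coerce according to their value type.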

    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        // For decimal inputs, follow Spark, where the result type is
        // DECIMAL(min(38, precision + 10), s), generalized here to each
        // decimal width's maximum precision.
        // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
        match &arg_types[0] {
            DataType::Int64 => Ok(DataType::Int64),
            DataType::UInt64 => Ok(DataType::UInt64),
            DataType::Float64 => Ok(DataType::Float64),
            DataType::Decimal32(precision, scale) => {
                let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 10);
                Ok(DataType::Decimal32(new_precision, *scale))
            }
            DataType::Decimal64(precision, scale) => {
                let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 10);
                Ok(DataType::Decimal64(new_precision, *scale))
            }
            DataType::Decimal128(precision, scale) => {
                let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10);
                Ok(DataType::Decimal128(new_precision, *scale))
            }
            DataType::Decimal256(precision, scale) => {
                let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10);
                Ok(DataType::Decimal256(new_precision, *scale))
            }
            other => {
                exec_err!("[return_type] SUM not supported for {}", other)
            }
        }
    }
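
    // A worked example (not part of the original source): summing a
    // Decimal128(10, 2) column yields Decimal128(20, 2), since
    // min(38, 10 + 10) = 20; summing Decimal128(35, 2) yields
    // Decimal128(38, 2), since min(38, 35 + 10) caps at the type's
    // maximum precision of 38.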

    fn accumulator(&self, args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        if args.is_distinct {
            macro_rules! helper {
                ($t:ty, $dt:expr) => {
                    Ok(Box::new(DistinctSumAccumulator::<$t>::new(&$dt)))
                };
            }
            downcast_sum!(args, helper)
        } else {
            macro_rules! helper {
                ($t:ty, $dt:expr) => {
                    Ok(Box::new(SumAccumulator::<$t>::new($dt.clone())))
                };
            }
            downcast_sum!(args, helper)
        }
    }

    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
        if args.is_distinct {
            Ok(vec![Field::new_list(
                format_state_name(args.name, "sum distinct"),
                // See COMMENTS.md to understand why nullable is set to true
                Field::new_list_field(args.return_type().clone(), true),
                false,
            )
            .into()])
        } else {
            Ok(vec![Field::new(
                format_state_name(args.name, "sum"),
                args.return_type().clone(),
                true,
            )
            .into()])
        }
    }

    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
        !args.is_distinct
    }

    fn create_groups_accumulator(
        &self,
        args: AccumulatorArgs,
    ) -> Result<Box<dyn GroupsAccumulator>> {
        macro_rules! helper {
            ($t:ty, $dt:expr) => {
                Ok(Box::new(PrimitiveGroupsAccumulator::<$t, _>::new(
                    &$dt,
                    |x, y| *x = x.add_wrapping(y),
                )))
            };
        }
        downcast_sum!(args, helper)
    }

    fn create_sliding_accumulator(
        &self,
        args: AccumulatorArgs,
    ) -> Result<Box<dyn Accumulator>> {
        if args.is_distinct {
            // distinct path: use the sliding-window distinct sum
            macro_rules! helper_distinct {
                ($t:ty, $dt:expr) => {
                    Ok(Box::new(SlidingDistinctSumAccumulator::try_new(&$dt)?))
                };
            }
            downcast_sum!(args, helper_distinct)
        } else {
            // non-distinct path: existing sliding sum
            macro_rules! helper {
                ($t:ty, $dt:expr) => {
                    Ok(Box::new(SlidingSumAccumulator::<$t>::new($dt.clone())))
                };
            }
            downcast_sum!(args, helper)
        }
    }

    fn reverse_expr(&self) -> ReversedUDAF {
        ReversedUDAF::Identical
    }

    fn order_sensitivity(&self) -> AggregateOrderSensitivity {
        AggregateOrderSensitivity::Insensitive
    }

    fn documentation(&self) -> Option<&Documentation> {
        self.doc()
    }

    fn set_monotonicity(&self, data_type: &DataType) -> SetMonotonicity {
        // `SUM` is only monotonically increasing when its input is unsigned.
        // TODO: Expand these utilizing statistics.
        match data_type {
            DataType::UInt8
            | DataType::UInt16
            | DataType::UInt32
            | DataType::UInt64 => SetMonotonicity::Increasing,
            _ => SetMonotonicity::NotMonotonic,
        }
    }
}
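
// A hedged usage sketch (not part of the original source, and assuming the
// top-level `datafusion` crate's DataFrame API): the `sum` expression function
// generated by `make_udaf_expr_and_func!` above can be used like
//
//     use datafusion::prelude::*;
//     use datafusion::functions_aggregate::expr_fn::sum;
//
//     async fn total(ctx: &SessionContext) -> datafusion::error::Result<()> {
//         let df = ctx.table("t").await?;
//         // SELECT sum(v) FROM t
//         df.aggregate(vec![], vec![sum(col("v"))])?.show().await?;
//         Ok(())
//     }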

/// This accumulator computes SUM incrementally
struct SumAccumulator<T: ArrowNumericType> {
    sum: Option<T::Native>,
    data_type: DataType,
}

impl<T: ArrowNumericType> std::fmt::Debug for SumAccumulator<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "SumAccumulator({})", self.data_type)
    }
}

impl<T: ArrowNumericType> SumAccumulator<T> {
    fn new(data_type: DataType) -> Self {
        Self {
            sum: None,
            data_type,
        }
    }
}

impl<T: ArrowNumericType> Accumulator for SumAccumulator<T> {
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        Ok(vec![self.evaluate()?])
    }

    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        let values = values[0].as_primitive::<T>();
        if let Some(x) = arrow::compute::sum(values) {
            let v = self.sum.get_or_insert_with(|| T::Native::usize_as(0));
            *v = v.add_wrapping(x);
        }
        Ok(())
    }

    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        self.update_batch(states)
    }

    fn evaluate(&mut self) -> Result<ScalarValue> {
        ScalarValue::new_primitive::<T>(self.sum, &self.data_type)
    }

    fn size(&self) -> usize {
        size_of_val(self)
    }
}
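
// An illustrative sketch (not part of the original source): feeding a batch
// through the accumulator and reading the result. `sum` stays `None` (SQL
// NULL) until the first non-null value arrives, since `arrow::compute::sum`
// returns `None` for all-null input.
//
//     let mut acc = SumAccumulator::<Int64Type>::new(DataType::Int64);
//     let batch: ArrayRef = Arc::new(Int64Array::from(vec![Some(1), None, Some(2)]));
//     acc.update_batch(&[batch])?;
//     assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(3)));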

/// This accumulator incrementally computes sums over a sliding window
///
/// This is separate from [`SumAccumulator`] as it requires additional state
struct SlidingSumAccumulator<T: ArrowNumericType> {
    sum: T::Native,
    count: u64,
    data_type: DataType,
}

impl<T: ArrowNumericType> std::fmt::Debug for SlidingSumAccumulator<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "SlidingSumAccumulator({})", self.data_type)
    }
}

impl<T: ArrowNumericType> SlidingSumAccumulator<T> {
    fn new(data_type: DataType) -> Self {
        Self {
            sum: T::Native::usize_as(0),
            count: 0,
            data_type,
        }
    }
}

impl<T: ArrowNumericType> Accumulator for SlidingSumAccumulator<T> {
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        Ok(vec![self.evaluate()?, self.count.into()])
    }

    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        let values = values[0].as_primitive::<T>();
        self.count += (values.len() - values.null_count()) as u64;
        if let Some(x) = arrow::compute::sum(values) {
            self.sum = self.sum.add_wrapping(x)
        }
        Ok(())
    }

    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        let values = states[0].as_primitive::<T>();
        if let Some(x) = arrow::compute::sum(values) {
            self.sum = self.sum.add_wrapping(x)
        }
        if let Some(x) = arrow::compute::sum(states[1].as_primitive::<UInt64Type>()) {
            self.count += x;
        }
        Ok(())
    }

    fn evaluate(&mut self) -> Result<ScalarValue> {
        let v = (self.count != 0).then_some(self.sum);
        ScalarValue::new_primitive::<T>(v, &self.data_type)
    }

    fn size(&self) -> usize {
        size_of_val(self)
    }

    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        let values = values[0].as_primitive::<T>();
        if let Some(x) = arrow::compute::sum(values) {
            self.sum = self.sum.sub_wrapping(x)
        }
        self.count -= (values.len() - values.null_count()) as u64;
        Ok(())
    }

    fn supports_retract_batch(&self) -> bool {
        true
    }
}
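
// An illustrative invariant (not part of the original source): updating with
// [1, 2, 3] and then retracting [1] leaves `sum == 5` and `count == 2`, so a
// sliding window frame is maintained in O(1) per row instead of recomputing
// the whole frame. `count` exists so `evaluate` can return NULL for empty
// frames.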

/// A sliding-window accumulator for `SUM(DISTINCT)` over Int64 columns.
/// Maintains a running sum so that `evaluate()` is O(1).
#[derive(Debug)]
pub struct SlidingDistinctSumAccumulator {
    /// Map each distinct value → its current count in the window
    counts: HashMap<i64, usize, RandomState>,
    /// Running sum of all distinct keys currently in the window
    sum: i64,
    /// Data type (must be Int64)
    data_type: DataType,
}

impl SlidingDistinctSumAccumulator {
    /// Create a new accumulator; only `DataType::Int64` is supported.
    pub fn try_new(data_type: &DataType) -> Result<Self> {
        // TODO support other numeric types
        if *data_type != DataType::Int64 {
            return exec_err!("SlidingDistinctSumAccumulator only supports Int64");
        }
        Ok(Self {
            counts: HashMap::default(),
            sum: 0,
            data_type: data_type.clone(),
        })
    }
}

impl Accumulator for SlidingDistinctSumAccumulator {
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        let arr = values[0].as_primitive::<Int64Type>();
        // Iterate only valid (non-null) slots; NULLs do not participate in SUM
        for v in arr.iter().flatten() {
            let cnt = self.counts.entry(v).or_insert(0);
            if *cnt == 0 {
                // first occurrence in window
                self.sum = self.sum.wrapping_add(v);
            }
            *cnt += 1;
        }
        Ok(())
    }

    fn evaluate(&mut self) -> Result<ScalarValue> {
        // O(1): return the running sum of the distinct values in the window
        Ok(ScalarValue::Int64(Some(self.sum)))
    }

    fn size(&self) -> usize {
        size_of_val(self)
    }

    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        // Serialize distinct keys for cross-partition merge if needed
        let keys = self
            .counts
            .keys()
            .cloned()
            .map(Some)
            .map(ScalarValue::Int64)
            .collect::<Vec<_>>();
        Ok(vec![ScalarValue::List(ScalarValue::new_list_nullable(
            &keys,
            &self.data_type,
        ))])
    }

    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        // Merge distinct keys from other partitions
        let list_arr = states[0].as_list::<i32>();
        for maybe_inner in list_arr.iter().flatten() {
            for idx in 0..maybe_inner.len() {
                if let ScalarValue::Int64(Some(v)) =
                    ScalarValue::try_from_array(&*maybe_inner, idx)?
                {
                    let cnt = self.counts.entry(v).or_insert(0);
                    if *cnt == 0 {
                        self.sum = self.sum.wrapping_add(v);
                    }
                    *cnt += 1;
                }
            }
        }
        Ok(())
    }

    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        let arr = values[0].as_primitive::<Int64Type>();
        // Match update_batch: only valid (non-null) slots were counted
        for v in arr.iter().flatten() {
            if let Some(cnt) = self.counts.get_mut(&v) {
                *cnt -= 1;
                if *cnt == 0 {
                    // last copy leaving window
                    self.sum = self.sum.wrapping_sub(v);
                    self.counts.remove(&v);
                }
            }
        }
        Ok(())
    }

    fn supports_retract_batch(&self) -> bool {
        true
    }
}
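
// A minimal test sketch (not part of the original source) exercising the
// sliding distinct sum over a window: duplicates are counted once, and
// retracting the last copy of a value removes it from the running sum.
#[cfg(test)]
mod sliding_distinct_sum_tests {
    use super::*;
    use arrow::array::Int64Array;
    use std::sync::Arc;

    #[test]
    fn update_then_retract() -> Result<()> {
        let mut acc = SlidingDistinctSumAccumulator::try_new(&DataType::Int64)?;

        // Window initially holds [1, 1, 2]: distinct values {1, 2}, sum = 3
        let batch: ArrayRef = Arc::new(Int64Array::from(vec![1, 1, 2]));
        acc.update_batch(&[batch])?;
        assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(3)));

        // Retract one copy of 1: a 1 remains in the window, so sum stays 3
        let gone: ArrayRef = Arc::new(Int64Array::from(vec![1]));
        acc.retract_batch(&[gone])?;
        assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(3)));

        // Retract the last 1: only 2 remains, sum = 2
        let gone: ArrayRef = Arc::new(Int64Array::from(vec![1]));
        acc.retract_batch(&[gone])?;
        assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(2)));

        Ok(())
    }
}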