// datafusion_functions_aggregate/sum.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Defines `SUM` and `SUM DISTINCT` aggregate accumulators
19
20use ahash::RandomState;
21use datafusion_expr::utils::AggregateOrderSensitivity;
22use std::any::Any;
23use std::mem::size_of_val;
24
25use arrow::array::Array;
26use arrow::array::ArrowNativeTypeOp;
27use arrow::array::{ArrowNumericType, AsArray};
28use arrow::datatypes::{ArrowNativeType, FieldRef};
29use arrow::datatypes::{
30    DataType, Decimal128Type, Decimal256Type, Float64Type, Int64Type, UInt64Type,
31    DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION,
32};
33use arrow::{array::ArrayRef, datatypes::Field};
34use datafusion_common::{
35    exec_err, not_impl_err, utils::take_function_args, HashMap, Result, ScalarValue,
36};
37use datafusion_expr::function::AccumulatorArgs;
38use datafusion_expr::function::StateFieldsArgs;
39use datafusion_expr::utils::format_state_name;
40use datafusion_expr::{
41    Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, ReversedUDAF,
42    SetMonotonicity, Signature, Volatility,
43};
44use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator;
45use datafusion_functions_aggregate_common::aggregate::sum_distinct::DistinctSumAccumulator;
46use datafusion_macros::user_doc;
47
// Generates the fluent `sum(expression)` builder function and the shared
// `sum_udaf()` accessor for the `Sum` UDAF defined below — see the
// `make_udaf_expr_and_func!` macro definition for details (TODO confirm
// exact expansion against the macro's crate).
make_udaf_expr_and_func!(
    Sum,
    sum,
    expression,
    "Returns the sum of a group of values.",
    sum_udaf
);
55
/// Sum only supports a subset of numeric types, instead relying on type coercion
///
/// This macro is similar to [downcast_primitive](arrow::array::downcast_primitive)
///
/// `args` is [AccumulatorArgs]
/// `helper` is a macro accepting (ArrowPrimitiveType, DataType)
macro_rules! downcast_sum {
    ($args:ident, $helper:ident) => {
        // Dispatch on the *return* type — `coerce_types` has already widened
        // the input to one of the five supported families below, so the
        // return type fully determines the accumulator's native type.
        match $args.return_field.data_type().clone() {
            DataType::UInt64 => {
                $helper!(UInt64Type, $args.return_field.data_type().clone())
            }
            DataType::Int64 => {
                $helper!(Int64Type, $args.return_field.data_type().clone())
            }
            DataType::Float64 => {
                $helper!(Float64Type, $args.return_field.data_type().clone())
            }
            DataType::Decimal128(_, _) => {
                $helper!(Decimal128Type, $args.return_field.data_type().clone())
            }
            DataType::Decimal256(_, _) => {
                $helper!(Decimal256Type, $args.return_field.data_type().clone())
            }
            // Should be unreachable after coercion, but report a clear error
            // rather than panicking if an unsupported type slips through.
            _ => {
                not_impl_err!(
                    "Sum not supported for {}: {}",
                    $args.name,
                    $args.return_field.data_type()
                )
            }
        }
    };
}
90
/// The `SUM` aggregate UDF: sums Int64/UInt64/Float64/Decimal inputs
/// (other numerics are widened by [`AggregateUDFImpl::coerce_types`]).
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the sum of all values in the specified column.",
    syntax_example = "sum(expression)",
    sql_example = r#"```sql
> SELECT sum(column_name) FROM table_name;
+-----------------------+
| sum(column_name)       |
+-----------------------+
| 12345                 |
+-----------------------+
```"#,
    standard_argument(name = "expression",)
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct Sum {
    // User-defined signature: accepts any type; real validation/widening
    // happens in `coerce_types`.
    signature: Signature,
}
109
110impl Sum {
111    pub fn new() -> Self {
112        Self {
113            signature: Signature::user_defined(Volatility::Immutable),
114        }
115    }
116}
117
118impl Default for Sum {
119    fn default() -> Self {
120        Self::new()
121    }
122}
123
impl AggregateUDFImpl for Sum {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "sum"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// Widens the single argument to a type SUM natively supports:
    /// Int64 / UInt64 / Float64, or the unchanged Decimal128/Decimal256.
    /// Dictionary inputs are coerced based on their value type.
    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
        let [args] = take_function_args(self.name(), arg_types)?;

        // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc
        // smallint, int, bigint, real, double precision, decimal, or interval.

        fn coerced_type(data_type: &DataType) -> Result<DataType> {
            match data_type {
                // Unwrap dictionaries and coerce the value type recursively.
                DataType::Dictionary(_, v) => coerced_type(v),
                // in the spark, the result type is DECIMAL(min(38,precision+10), s)
                // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
                DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => {
                    Ok(data_type.clone())
                }
                // Widen every integer to 64 bits and every float to f64 so
                // each accumulator handles exactly one native type per family.
                dt if dt.is_signed_integer() => Ok(DataType::Int64),
                dt if dt.is_unsigned_integer() => Ok(DataType::UInt64),
                dt if dt.is_floating() => Ok(DataType::Float64),
                _ => exec_err!("Sum not supported for {}", data_type),
            }
        }

        Ok(vec![coerced_type(args)?])
    }

    /// Output type for the (already coerced) input type. Decimals gain 10
    /// digits of precision, capped at the type's maximum, mirroring Spark.
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        match &arg_types[0] {
            DataType::Int64 => Ok(DataType::Int64),
            DataType::UInt64 => Ok(DataType::UInt64),
            DataType::Float64 => Ok(DataType::Float64),
            DataType::Decimal128(precision, scale) => {
                // in the spark, the result type is DECIMAL(min(38,precision+10), s)
                // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
                let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10);
                Ok(DataType::Decimal128(new_precision, *scale))
            }
            DataType::Decimal256(precision, scale) => {
                // in the spark, the result type is DECIMAL(min(38,precision+10), s)
                // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
                let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10);
                Ok(DataType::Decimal256(new_precision, *scale))
            }
            other => {
                exec_err!("[return_type] SUM not supported for {}", other)
            }
        }
    }

    /// Row-at-a-time accumulator: hash-set based for DISTINCT, otherwise a
    /// simple running sum.
    fn accumulator(&self, args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        if args.is_distinct {
            macro_rules! helper {
                ($t:ty, $dt:expr) => {
                    Ok(Box::new(DistinctSumAccumulator::<$t>::new(&$dt)))
                };
            }
            downcast_sum!(args, helper)
        } else {
            macro_rules! helper {
                ($t:ty, $dt:expr) => {
                    Ok(Box::new(SumAccumulator::<$t>::new($dt.clone())))
                };
            }
            downcast_sum!(args, helper)
        }
    }

    /// Intermediate state schema: DISTINCT serializes the set of seen values
    /// as a list; otherwise a single nullable partial sum.
    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
        if args.is_distinct {
            Ok(vec![Field::new_list(
                format_state_name(args.name, "sum distinct"),
                // See COMMENTS.md to understand why nullable is set to true
                Field::new_list_field(args.return_type().clone(), true),
                false,
            )
            .into()])
        } else {
            Ok(vec![Field::new(
                format_state_name(args.name, "sum"),
                args.return_type().clone(),
                true,
            )
            .into()])
        }
    }

    /// The vectorized groups accumulator is only available for the
    /// non-DISTINCT path.
    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
        !args.is_distinct
    }

    /// Vectorized accumulator that maintains one wrapping partial sum per
    /// group.
    fn create_groups_accumulator(
        &self,
        args: AccumulatorArgs,
    ) -> Result<Box<dyn GroupsAccumulator>> {
        macro_rules! helper {
            ($t:ty, $dt:expr) => {
                Ok(Box::new(PrimitiveGroupsAccumulator::<$t, _>::new(
                    &$dt,
                    |x, y| *x = x.add_wrapping(y),
                )))
            };
        }
        downcast_sum!(args, helper)
    }

    /// Accumulator for sliding window frames (supports `retract_batch`).
    fn create_sliding_accumulator(
        &self,
        args: AccumulatorArgs,
    ) -> Result<Box<dyn Accumulator>> {
        if args.is_distinct {
            // Distinct path: sliding-window distinct sum
            // (`SlidingDistinctSumAccumulator` currently supports Int64 only).
            macro_rules! helper_distinct {
                ($t:ty, $dt:expr) => {
                    Ok(Box::new(SlidingDistinctSumAccumulator::try_new(&$dt)?))
                };
            }
            downcast_sum!(args, helper_distinct)
        } else {
            // Non-distinct path: sliding sum with retraction support.
            macro_rules! helper {
                ($t:ty, $dt:expr) => {
                    Ok(Box::new(SlidingSumAccumulator::<$t>::new($dt.clone())))
                };
            }
            downcast_sum!(args, helper)
        }
    }

    // SUM is commutative, so reversing the input order changes nothing.
    fn reverse_expr(&self) -> ReversedUDAF {
        ReversedUDAF::Identical
    }

    fn order_sensitivity(&self) -> AggregateOrderSensitivity {
        AggregateOrderSensitivity::Insensitive
    }

    fn documentation(&self) -> Option<&Documentation> {
        self.doc()
    }

    fn set_monotonicity(&self, data_type: &DataType) -> SetMonotonicity {
        // `SUM` is only monotonically increasing when its input is unsigned.
        // TODO: Expand these utilizing statistics.
        match data_type {
            DataType::UInt8 => SetMonotonicity::Increasing,
            DataType::UInt16 => SetMonotonicity::Increasing,
            DataType::UInt32 => SetMonotonicity::Increasing,
            DataType::UInt64 => SetMonotonicity::Increasing,
            _ => SetMonotonicity::NotMonotonic,
        }
    }
}
287
/// This accumulator computes SUM incrementally
struct SumAccumulator<T: ArrowNumericType> {
    /// Running total; `None` until the first non-null value is seen, so an
    /// empty or all-null input evaluates to NULL.
    sum: Option<T::Native>,
    /// Concrete output type (carries decimal precision/scale when relevant).
    data_type: DataType,
}
293
294impl<T: ArrowNumericType> std::fmt::Debug for SumAccumulator<T> {
295    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
296        write!(f, "SumAccumulator({})", self.data_type)
297    }
298}
299
300impl<T: ArrowNumericType> SumAccumulator<T> {
301    fn new(data_type: DataType) -> Self {
302        Self {
303            sum: None,
304            data_type,
305        }
306    }
307}
308
309impl<T: ArrowNumericType> Accumulator for SumAccumulator<T> {
310    fn state(&mut self) -> Result<Vec<ScalarValue>> {
311        Ok(vec![self.evaluate()?])
312    }
313
314    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
315        let values = values[0].as_primitive::<T>();
316        if let Some(x) = arrow::compute::sum(values) {
317            let v = self.sum.get_or_insert(T::Native::usize_as(0));
318            *v = v.add_wrapping(x);
319        }
320        Ok(())
321    }
322
323    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
324        self.update_batch(states)
325    }
326
327    fn evaluate(&mut self) -> Result<ScalarValue> {
328        ScalarValue::new_primitive::<T>(self.sum, &self.data_type)
329    }
330
331    fn size(&self) -> usize {
332        size_of_val(self)
333    }
334}
335
/// This accumulator incrementally computes sums over a sliding window
///
/// This is separate from [`SumAccumulator`] as requires additional state
struct SlidingSumAccumulator<T: ArrowNumericType> {
    /// Running (wrapping) sum of every value currently in the window.
    sum: T::Native,
    /// Number of non-null values in the window; 0 means `evaluate` is NULL.
    count: u64,
    /// Concrete output type (carries decimal precision/scale when relevant).
    data_type: DataType,
}
344
345impl<T: ArrowNumericType> std::fmt::Debug for SlidingSumAccumulator<T> {
346    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
347        write!(f, "SlidingSumAccumulator({})", self.data_type)
348    }
349}
350
351impl<T: ArrowNumericType> SlidingSumAccumulator<T> {
352    fn new(data_type: DataType) -> Self {
353        Self {
354            sum: T::Native::usize_as(0),
355            count: 0,
356            data_type,
357        }
358    }
359}
360
361impl<T: ArrowNumericType> Accumulator for SlidingSumAccumulator<T> {
362    fn state(&mut self) -> Result<Vec<ScalarValue>> {
363        Ok(vec![self.evaluate()?, self.count.into()])
364    }
365
366    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
367        let values = values[0].as_primitive::<T>();
368        self.count += (values.len() - values.null_count()) as u64;
369        if let Some(x) = arrow::compute::sum(values) {
370            self.sum = self.sum.add_wrapping(x)
371        }
372        Ok(())
373    }
374
375    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
376        let values = states[0].as_primitive::<T>();
377        if let Some(x) = arrow::compute::sum(values) {
378            self.sum = self.sum.add_wrapping(x)
379        }
380        if let Some(x) = arrow::compute::sum(states[1].as_primitive::<UInt64Type>()) {
381            self.count += x;
382        }
383        Ok(())
384    }
385
386    fn evaluate(&mut self) -> Result<ScalarValue> {
387        let v = (self.count != 0).then_some(self.sum);
388        ScalarValue::new_primitive::<T>(v, &self.data_type)
389    }
390
391    fn size(&self) -> usize {
392        size_of_val(self)
393    }
394
395    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
396        let values = values[0].as_primitive::<T>();
397        if let Some(x) = arrow::compute::sum(values) {
398            self.sum = self.sum.sub_wrapping(x)
399        }
400        self.count -= (values.len() - values.null_count()) as u64;
401        Ok(())
402    }
403
404    fn supports_retract_batch(&self) -> bool {
405        true
406    }
407}
408
/// A sliding-window accumulator for `SUM(DISTINCT)` over Int64 columns.
/// Maintains a running sum so that `evaluate()` is O(1).
#[derive(Debug)]
pub struct SlidingDistinctSumAccumulator {
    /// Map each distinct value → its current count in the window
    counts: HashMap<i64, usize, RandomState>,
    /// Running sum of all distinct keys currently in the window
    sum: i64,
    /// Data type (must be Int64)
    data_type: DataType,
}
420
421impl SlidingDistinctSumAccumulator {
422    /// Create a new accumulator; only `DataType::Int64` is supported.
423    pub fn try_new(data_type: &DataType) -> Result<Self> {
424        // TODO support other numeric types
425        if *data_type != DataType::Int64 {
426            return exec_err!("SlidingDistinctSumAccumulator only supports Int64");
427        }
428        Ok(Self {
429            counts: HashMap::default(),
430            sum: 0,
431            data_type: data_type.clone(),
432        })
433    }
434}
435
436impl Accumulator for SlidingDistinctSumAccumulator {
437    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
438        let arr = values[0].as_primitive::<Int64Type>();
439        for &v in arr.values() {
440            let cnt = self.counts.entry(v).or_insert(0);
441            if *cnt == 0 {
442                // first occurrence in window
443                self.sum = self.sum.wrapping_add(v);
444            }
445            *cnt += 1;
446        }
447        Ok(())
448    }
449
450    fn evaluate(&mut self) -> Result<ScalarValue> {
451        // O(1) wrap of running sum
452        Ok(ScalarValue::Int64(Some(self.sum)))
453    }
454
455    fn size(&self) -> usize {
456        size_of_val(self)
457    }
458
459    fn state(&mut self) -> Result<Vec<ScalarValue>> {
460        // Serialize distinct keys for cross-partition merge if needed
461        let keys = self
462            .counts
463            .keys()
464            .cloned()
465            .map(Some)
466            .map(ScalarValue::Int64)
467            .collect::<Vec<_>>();
468        Ok(vec![ScalarValue::List(ScalarValue::new_list_nullable(
469            &keys,
470            &self.data_type,
471        ))])
472    }
473
474    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
475        // Merge distinct keys from other partitions
476        let list_arr = states[0].as_list::<i32>();
477        for maybe_inner in list_arr.iter().flatten() {
478            for idx in 0..maybe_inner.len() {
479                if let ScalarValue::Int64(Some(v)) =
480                    ScalarValue::try_from_array(&*maybe_inner, idx)?
481                {
482                    let cnt = self.counts.entry(v).or_insert(0);
483                    if *cnt == 0 {
484                        self.sum = self.sum.wrapping_add(v);
485                    }
486                    *cnt += 1;
487                }
488            }
489        }
490        Ok(())
491    }
492
493    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
494        let arr = values[0].as_primitive::<Int64Type>();
495        for &v in arr.values() {
496            if let Some(cnt) = self.counts.get_mut(&v) {
497                *cnt -= 1;
498                if *cnt == 0 {
499                    // last copy leaving window
500                    self.sum = self.sum.wrapping_sub(v);
501                    self.counts.remove(&v);
502                }
503            }
504        }
505        Ok(())
506    }
507
508    fn supports_retract_batch(&self) -> bool {
509        true
510    }
511}