datafusion_functions_aggregate/
sum.rs1use ahash::RandomState;
21use datafusion_expr::utils::AggregateOrderSensitivity;
22use std::any::Any;
23use std::mem::size_of_val;
24
25use arrow::array::Array;
26use arrow::array::ArrowNativeTypeOp;
27use arrow::array::{ArrowNumericType, AsArray};
28use arrow::datatypes::{ArrowNativeType, FieldRef};
29use arrow::datatypes::{
30 DataType, Decimal128Type, Decimal256Type, Float64Type, Int64Type, UInt64Type,
31 DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION,
32};
33use arrow::{array::ArrayRef, datatypes::Field};
34use datafusion_common::{
35 exec_err, not_impl_err, utils::take_function_args, HashMap, Result, ScalarValue,
36};
37use datafusion_expr::function::AccumulatorArgs;
38use datafusion_expr::function::StateFieldsArgs;
39use datafusion_expr::utils::format_state_name;
40use datafusion_expr::{
41 Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, ReversedUDAF,
42 SetMonotonicity, Signature, Volatility,
43};
44use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator;
45use datafusion_functions_aggregate_common::aggregate::sum_distinct::DistinctSumAccumulator;
46use datafusion_macros::user_doc;
47
// Declares the `Sum` UDAF through `make_udaf_expr_and_func!`, passing the
// `sum` and `sum_udaf` identifiers, a single `expression` argument and the
// description string below.
// NOTE(review): the macro is defined elsewhere; presumably it generates the
// `sum(expression)` expression builder and the `sum_udaf()` accessor — confirm.
make_udaf_expr_and_func!(
    Sum,
    sum,
    expression,
    "Returns the sum of a group of values.",
    sum_udaf
);
55
/// Dispatches on the data type of `$args.return_field` and expands
/// `$helper!(ArrowType, data_type)` for the matching Arrow primitive type.
///
/// * `$args`   - identifier of a value exposing `return_field` and `name`.
/// * `$helper` - name of a macro taking `($t:ty, $dt:expr)`; `$dt` is an owned
///   clone of the return data type.
///
/// Expands to a `not_impl_err!` for any return type other than `UInt64`,
/// `Int64`, `Float64`, `Decimal128` or `Decimal256`.
macro_rules! downcast_sum {
    ($args:ident, $helper:ident) => {
        // Match on the borrowed type: only the selected arm needs an owned
        // copy, so there is no reason to clone just to inspect the variant.
        match $args.return_field.data_type() {
            DataType::UInt64 => {
                $helper!(UInt64Type, $args.return_field.data_type().clone())
            }
            DataType::Int64 => {
                $helper!(Int64Type, $args.return_field.data_type().clone())
            }
            DataType::Float64 => {
                $helper!(Float64Type, $args.return_field.data_type().clone())
            }
            DataType::Decimal128(_, _) => {
                $helper!(Decimal128Type, $args.return_field.data_type().clone())
            }
            DataType::Decimal256(_, _) => {
                $helper!(Decimal256Type, $args.return_field.data_type().clone())
            }
            _ => {
                not_impl_err!(
                    "Sum not supported for {}: {}",
                    $args.name,
                    $args.return_field.data_type()
                )
            }
        }
    };
}
90
/// Aggregate UDF implementation registered under the name `sum`.
///
/// The only state is the function [`Signature`], which `Sum::new` builds as a
/// user-defined, immutable signature; argument coercion is therefore handled
/// by this type's `coerce_types` implementation.
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the sum of all values in the specified column.",
    syntax_example = "sum(expression)",
    sql_example = r#"```sql
> SELECT sum(column_name) FROM table_name;
+-----------------------+
| sum(column_name) |
+-----------------------+
| 12345 |
+-----------------------+
```"#,
    standard_argument(name = "expression",)
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct Sum {
    // Signature returned by `AggregateUDFImpl::signature`.
    signature: Signature,
}
109
110impl Sum {
111 pub fn new() -> Self {
112 Self {
113 signature: Signature::user_defined(Volatility::Immutable),
114 }
115 }
116}
117
118impl Default for Sum {
119 fn default() -> Self {
120 Self::new()
121 }
122}
123
124impl AggregateUDFImpl for Sum {
125 fn as_any(&self) -> &dyn Any {
126 self
127 }
128
129 fn name(&self) -> &str {
130 "sum"
131 }
132
133 fn signature(&self) -> &Signature {
134 &self.signature
135 }
136
137 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
138 let [args] = take_function_args(self.name(), arg_types)?;
139
140 fn coerced_type(data_type: &DataType) -> Result<DataType> {
144 match data_type {
145 DataType::Dictionary(_, v) => coerced_type(v),
146 DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => {
149 Ok(data_type.clone())
150 }
151 dt if dt.is_signed_integer() => Ok(DataType::Int64),
152 dt if dt.is_unsigned_integer() => Ok(DataType::UInt64),
153 dt if dt.is_floating() => Ok(DataType::Float64),
154 _ => exec_err!("Sum not supported for {}", data_type),
155 }
156 }
157
158 Ok(vec![coerced_type(args)?])
159 }
160
161 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
162 match &arg_types[0] {
163 DataType::Int64 => Ok(DataType::Int64),
164 DataType::UInt64 => Ok(DataType::UInt64),
165 DataType::Float64 => Ok(DataType::Float64),
166 DataType::Decimal128(precision, scale) => {
167 let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10);
170 Ok(DataType::Decimal128(new_precision, *scale))
171 }
172 DataType::Decimal256(precision, scale) => {
173 let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10);
176 Ok(DataType::Decimal256(new_precision, *scale))
177 }
178 other => {
179 exec_err!("[return_type] SUM not supported for {}", other)
180 }
181 }
182 }
183
184 fn accumulator(&self, args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
185 if args.is_distinct {
186 macro_rules! helper {
187 ($t:ty, $dt:expr) => {
188 Ok(Box::new(DistinctSumAccumulator::<$t>::new(&$dt)))
189 };
190 }
191 downcast_sum!(args, helper)
192 } else {
193 macro_rules! helper {
194 ($t:ty, $dt:expr) => {
195 Ok(Box::new(SumAccumulator::<$t>::new($dt.clone())))
196 };
197 }
198 downcast_sum!(args, helper)
199 }
200 }
201
202 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
203 if args.is_distinct {
204 Ok(vec![Field::new_list(
205 format_state_name(args.name, "sum distinct"),
206 Field::new_list_field(args.return_type().clone(), true),
208 false,
209 )
210 .into()])
211 } else {
212 Ok(vec![Field::new(
213 format_state_name(args.name, "sum"),
214 args.return_type().clone(),
215 true,
216 )
217 .into()])
218 }
219 }
220
221 fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
222 !args.is_distinct
223 }
224
225 fn create_groups_accumulator(
226 &self,
227 args: AccumulatorArgs,
228 ) -> Result<Box<dyn GroupsAccumulator>> {
229 macro_rules! helper {
230 ($t:ty, $dt:expr) => {
231 Ok(Box::new(PrimitiveGroupsAccumulator::<$t, _>::new(
232 &$dt,
233 |x, y| *x = x.add_wrapping(y),
234 )))
235 };
236 }
237 downcast_sum!(args, helper)
238 }
239
240 fn create_sliding_accumulator(
241 &self,
242 args: AccumulatorArgs,
243 ) -> Result<Box<dyn Accumulator>> {
244 if args.is_distinct {
245 macro_rules! helper_distinct {
247 ($t:ty, $dt:expr) => {
248 Ok(Box::new(SlidingDistinctSumAccumulator::try_new(&$dt)?))
249 };
250 }
251 downcast_sum!(args, helper_distinct)
252 } else {
253 macro_rules! helper {
255 ($t:ty, $dt:expr) => {
256 Ok(Box::new(SlidingSumAccumulator::<$t>::new($dt.clone())))
257 };
258 }
259 downcast_sum!(args, helper)
260 }
261 }
262
263 fn reverse_expr(&self) -> ReversedUDAF {
264 ReversedUDAF::Identical
265 }
266
267 fn order_sensitivity(&self) -> AggregateOrderSensitivity {
268 AggregateOrderSensitivity::Insensitive
269 }
270
271 fn documentation(&self) -> Option<&Documentation> {
272 self.doc()
273 }
274
275 fn set_monotonicity(&self, data_type: &DataType) -> SetMonotonicity {
276 match data_type {
279 DataType::UInt8 => SetMonotonicity::Increasing,
280 DataType::UInt16 => SetMonotonicity::Increasing,
281 DataType::UInt32 => SetMonotonicity::Increasing,
282 DataType::UInt64 => SetMonotonicity::Increasing,
283 _ => SetMonotonicity::NotMonotonic,
284 }
285 }
286}
287
/// Accumulator backing non-`DISTINCT` `sum` for one Arrow numeric type `T`.
struct SumAccumulator<T: ArrowNumericType> {
    // Running wrapping sum; stays `None` until `update_batch` sees a batch
    // for which `arrow::compute::sum` yields a value.
    sum: Option<T::Native>,
    // Data type used to build the `ScalarValue` returned by `evaluate`.
    data_type: DataType,
}
293
294impl<T: ArrowNumericType> std::fmt::Debug for SumAccumulator<T> {
295 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
296 write!(f, "SumAccumulator({})", self.data_type)
297 }
298}
299
300impl<T: ArrowNumericType> SumAccumulator<T> {
301 fn new(data_type: DataType) -> Self {
302 Self {
303 sum: None,
304 data_type,
305 }
306 }
307}
308
309impl<T: ArrowNumericType> Accumulator for SumAccumulator<T> {
310 fn state(&mut self) -> Result<Vec<ScalarValue>> {
311 Ok(vec![self.evaluate()?])
312 }
313
314 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
315 let values = values[0].as_primitive::<T>();
316 if let Some(x) = arrow::compute::sum(values) {
317 let v = self.sum.get_or_insert(T::Native::usize_as(0));
318 *v = v.add_wrapping(x);
319 }
320 Ok(())
321 }
322
323 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
324 self.update_batch(states)
325 }
326
327 fn evaluate(&mut self) -> Result<ScalarValue> {
328 ScalarValue::new_primitive::<T>(self.sum, &self.data_type)
329 }
330
331 fn size(&self) -> usize {
332 size_of_val(self)
333 }
334}
335
/// Retractable `sum` accumulator for one Arrow numeric type `T`.
struct SlidingSumAccumulator<T: ArrowNumericType> {
    // Running wrapping sum; starts at zero.
    sum: T::Native,
    // Number of non-null values currently accumulated: increased by
    // `update_batch`, decreased by `retract_batch`. `evaluate` yields a null
    // scalar while this is zero.
    count: u64,
    // Data type used to build the `ScalarValue` returned by `evaluate`.
    data_type: DataType,
}
344
345impl<T: ArrowNumericType> std::fmt::Debug for SlidingSumAccumulator<T> {
346 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
347 write!(f, "SlidingSumAccumulator({})", self.data_type)
348 }
349}
350
351impl<T: ArrowNumericType> SlidingSumAccumulator<T> {
352 fn new(data_type: DataType) -> Self {
353 Self {
354 sum: T::Native::usize_as(0),
355 count: 0,
356 data_type,
357 }
358 }
359}
360
361impl<T: ArrowNumericType> Accumulator for SlidingSumAccumulator<T> {
362 fn state(&mut self) -> Result<Vec<ScalarValue>> {
363 Ok(vec![self.evaluate()?, self.count.into()])
364 }
365
366 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
367 let values = values[0].as_primitive::<T>();
368 self.count += (values.len() - values.null_count()) as u64;
369 if let Some(x) = arrow::compute::sum(values) {
370 self.sum = self.sum.add_wrapping(x)
371 }
372 Ok(())
373 }
374
375 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
376 let values = states[0].as_primitive::<T>();
377 if let Some(x) = arrow::compute::sum(values) {
378 self.sum = self.sum.add_wrapping(x)
379 }
380 if let Some(x) = arrow::compute::sum(states[1].as_primitive::<UInt64Type>()) {
381 self.count += x;
382 }
383 Ok(())
384 }
385
386 fn evaluate(&mut self) -> Result<ScalarValue> {
387 let v = (self.count != 0).then_some(self.sum);
388 ScalarValue::new_primitive::<T>(v, &self.data_type)
389 }
390
391 fn size(&self) -> usize {
392 size_of_val(self)
393 }
394
395 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
396 let values = values[0].as_primitive::<T>();
397 if let Some(x) = arrow::compute::sum(values) {
398 self.sum = self.sum.sub_wrapping(x)
399 }
400 self.count -= (values.len() - values.null_count()) as u64;
401 Ok(())
402 }
403
404 fn supports_retract_batch(&self) -> bool {
405 true
406 }
407}
408
/// Retractable `sum(DISTINCT ...)` accumulator over `i64` values.
///
/// `try_new` only accepts `DataType::Int64`.
#[derive(Debug)]
pub struct SlidingDistinctSumAccumulator {
    // Occurrence count per value currently accumulated; an entry is removed
    // when retraction brings its count back to zero.
    counts: HashMap<i64, usize, RandomState>,
    // Wrapping sum of the values tracked in `counts` (each counted once).
    sum: i64,
    // Data type used as the element type of the list emitted by `state`.
    data_type: DataType,
}
420
421impl SlidingDistinctSumAccumulator {
422 pub fn try_new(data_type: &DataType) -> Result<Self> {
424 if *data_type != DataType::Int64 {
426 return exec_err!("SlidingDistinctSumAccumulator only supports Int64");
427 }
428 Ok(Self {
429 counts: HashMap::default(),
430 sum: 0,
431 data_type: data_type.clone(),
432 })
433 }
434}
435
436impl Accumulator for SlidingDistinctSumAccumulator {
437 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
438 let arr = values[0].as_primitive::<Int64Type>();
439 for &v in arr.values() {
440 let cnt = self.counts.entry(v).or_insert(0);
441 if *cnt == 0 {
442 self.sum = self.sum.wrapping_add(v);
444 }
445 *cnt += 1;
446 }
447 Ok(())
448 }
449
450 fn evaluate(&mut self) -> Result<ScalarValue> {
451 Ok(ScalarValue::Int64(Some(self.sum)))
453 }
454
455 fn size(&self) -> usize {
456 size_of_val(self)
457 }
458
459 fn state(&mut self) -> Result<Vec<ScalarValue>> {
460 let keys = self
462 .counts
463 .keys()
464 .cloned()
465 .map(Some)
466 .map(ScalarValue::Int64)
467 .collect::<Vec<_>>();
468 Ok(vec![ScalarValue::List(ScalarValue::new_list_nullable(
469 &keys,
470 &self.data_type,
471 ))])
472 }
473
474 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
475 let list_arr = states[0].as_list::<i32>();
477 for maybe_inner in list_arr.iter().flatten() {
478 for idx in 0..maybe_inner.len() {
479 if let ScalarValue::Int64(Some(v)) =
480 ScalarValue::try_from_array(&*maybe_inner, idx)?
481 {
482 let cnt = self.counts.entry(v).or_insert(0);
483 if *cnt == 0 {
484 self.sum = self.sum.wrapping_add(v);
485 }
486 *cnt += 1;
487 }
488 }
489 }
490 Ok(())
491 }
492
493 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
494 let arr = values[0].as_primitive::<Int64Type>();
495 for &v in arr.values() {
496 if let Some(cnt) = self.counts.get_mut(&v) {
497 *cnt -= 1;
498 if *cnt == 0 {
499 self.sum = self.sum.wrapping_sub(v);
501 self.counts.remove(&v);
502 }
503 }
504 }
505 Ok(())
506 }
507
508 fn supports_retract_batch(&self) -> bool {
509 true
510 }
511}