1use ahash::RandomState;
21use arrow::array::{Array, ArrayRef, ArrowNativeTypeOp, ArrowNumericType, AsArray};
22use arrow::datatypes::Field;
23use arrow::datatypes::{
24 ArrowNativeType, DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION,
25 DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, DataType, Decimal32Type,
26 Decimal64Type, Decimal128Type, Decimal256Type, DurationMicrosecondType,
27 DurationMillisecondType, DurationNanosecondType, DurationSecondType, FieldRef,
28 Float64Type, Int64Type, TimeUnit, UInt64Type,
29};
30use datafusion_common::types::{
31 NativeType, logical_float64, logical_int8, logical_int16, logical_int32,
32 logical_int64, logical_uint8, logical_uint16, logical_uint32, logical_uint64,
33};
34use datafusion_common::{HashMap, Result, ScalarValue, exec_err, not_impl_err};
35use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
36use datafusion_expr::utils::{AggregateOrderSensitivity, format_state_name};
37use datafusion_expr::{
38 Accumulator, AggregateUDFImpl, Coercion, Documentation, Expr, GroupsAccumulator,
39 ReversedUDAF, SetMonotonicity, Signature, TypeSignature, TypeSignatureClass,
40 Volatility,
41};
42use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator;
43use datafusion_functions_aggregate_common::aggregate::sum_distinct::DistinctSumAccumulator;
44use datafusion_macros::user_doc;
45use std::any::Any;
46use std::mem::size_of_val;
47
// Declares the expression-level helpers for the `Sum` UDAF. The macro is
// defined elsewhere; it presumably generates the `sum(expression)` builder and
// the `sum_udaf()` accessor (the latter is called by `sum_distinct`), using the
// string below as the helper's description — TODO confirm against the macro.
make_udaf_expr_and_func!(
    Sum,
    sum,
    expression,
    "Returns the sum of a group of values.",
    sum_udaf
);
55
56pub fn sum_distinct(expr: Expr) -> Expr {
57 Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
58 sum_udaf(),
59 vec![expr],
60 true,
61 None,
62 vec![],
63 None,
64 ))
65}
66
// Dispatches on the return type of an aggregate to a concrete Arrow primitive
// type and expands `$helper!(ArrowType, data_type)` for it.
//
// `$args` must expose `return_field` and `name` (as `AccumulatorArgs` does).
// The `data_type` handed to `$helper` is an owned clone of the return type, so
// helpers may either borrow it (`&$dt`) or move it (`$dt`). Return types that
// have no specialised implementation yield a `not_impl_err!`.
macro_rules! downcast_sum {
    ($args:ident, $helper:ident) => {{
        // Clone the return type once for the helper; the match itself only
        // needs to inspect it by reference.
        let data_type = $args.return_field.data_type().clone();
        match $args.return_field.data_type() {
            DataType::UInt64 => $helper!(UInt64Type, data_type),
            DataType::Int64 => $helper!(Int64Type, data_type),
            DataType::Float64 => $helper!(Float64Type, data_type),
            DataType::Decimal32(_, _) => $helper!(Decimal32Type, data_type),
            DataType::Decimal64(_, _) => $helper!(Decimal64Type, data_type),
            DataType::Decimal128(_, _) => $helper!(Decimal128Type, data_type),
            DataType::Decimal256(_, _) => $helper!(Decimal256Type, data_type),
            DataType::Duration(TimeUnit::Second) => {
                $helper!(DurationSecondType, data_type)
            }
            DataType::Duration(TimeUnit::Millisecond) => {
                $helper!(DurationMillisecondType, data_type)
            }
            DataType::Duration(TimeUnit::Microsecond) => {
                $helper!(DurationMicrosecondType, data_type)
            }
            DataType::Duration(TimeUnit::Nanosecond) => {
                $helper!(DurationNanosecondType, data_type)
            }
            _ => {
                not_impl_err!(
                    "Sum not supported for {}: {}",
                    $args.name,
                    $args.return_field.data_type()
                )
            }
        }
    }};
}
128
// The `sum` aggregate UDF. The `user_doc` attribute carries the SQL reference
// documentation that `AggregateUDFImpl::documentation` exposes via `self.doc()`.
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the sum of all values in the specified column.",
    syntax_example = "sum(expression)",
    sql_example = r#"```sql
> SELECT sum(column_name) FROM table_name;
+-----------------------+
| sum(column_name) |
+-----------------------+
| 12345 |
+-----------------------+
```"#,
    standard_argument(name = "expression",)
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct Sum {
    // Accepted argument types; constructed in `Sum::new`.
    signature: Signature,
}
147
148impl Sum {
149 pub fn new() -> Self {
150 Self {
151 signature: Signature::one_of(
154 vec![
155 TypeSignature::Coercible(vec![Coercion::new_exact(
156 TypeSignatureClass::Decimal,
157 )]),
158 TypeSignature::Coercible(vec![Coercion::new_implicit(
160 TypeSignatureClass::Native(logical_uint64()),
161 vec![
162 TypeSignatureClass::Native(logical_uint8()),
163 TypeSignatureClass::Native(logical_uint16()),
164 TypeSignatureClass::Native(logical_uint32()),
165 ],
166 NativeType::UInt64,
167 )]),
168 TypeSignature::Coercible(vec![Coercion::new_implicit(
170 TypeSignatureClass::Native(logical_int64()),
171 vec![
172 TypeSignatureClass::Native(logical_int8()),
173 TypeSignatureClass::Native(logical_int16()),
174 TypeSignatureClass::Native(logical_int32()),
175 ],
176 NativeType::Int64,
177 )]),
178 TypeSignature::Coercible(vec![Coercion::new_implicit(
180 TypeSignatureClass::Native(logical_float64()),
181 vec![TypeSignatureClass::Float],
182 NativeType::Float64,
183 )]),
184 TypeSignature::Coercible(vec![Coercion::new_exact(
185 TypeSignatureClass::Duration,
186 )]),
187 ],
188 Volatility::Immutable,
189 ),
190 }
191 }
192}
193
194impl Default for Sum {
195 fn default() -> Self {
196 Self::new()
197 }
198}
199
200impl AggregateUDFImpl for Sum {
201 fn as_any(&self) -> &dyn Any {
202 self
203 }
204
205 fn name(&self) -> &str {
206 "sum"
207 }
208
209 fn signature(&self) -> &Signature {
210 &self.signature
211 }
212
213 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
214 match &arg_types[0] {
215 DataType::Int64 => Ok(DataType::Int64),
216 DataType::UInt64 => Ok(DataType::UInt64),
217 DataType::Float64 => Ok(DataType::Float64),
218 DataType::Decimal32(precision, scale) => {
221 let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 10);
222 Ok(DataType::Decimal32(new_precision, *scale))
223 }
224 DataType::Decimal64(precision, scale) => {
225 let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 10);
226 Ok(DataType::Decimal64(new_precision, *scale))
227 }
228 DataType::Decimal128(precision, scale) => {
229 let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10);
230 Ok(DataType::Decimal128(new_precision, *scale))
231 }
232 DataType::Decimal256(precision, scale) => {
233 let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10);
234 Ok(DataType::Decimal256(new_precision, *scale))
235 }
236 DataType::Duration(time_unit) => Ok(DataType::Duration(*time_unit)),
237 other => {
238 exec_err!("[return_type] SUM not supported for {}", other)
239 }
240 }
241 }
242
243 fn accumulator(&self, args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
244 if args.is_distinct {
245 macro_rules! helper {
246 ($t:ty, $dt:expr) => {
247 Ok(Box::new(DistinctSumAccumulator::<$t>::new(&$dt)))
248 };
249 }
250 downcast_sum!(args, helper)
251 } else {
252 macro_rules! helper {
253 ($t:ty, $dt:expr) => {
254 Ok(Box::new(SumAccumulator::<$t>::new($dt.clone())))
255 };
256 }
257 downcast_sum!(args, helper)
258 }
259 }
260
261 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
262 if args.is_distinct {
263 Ok(vec![
264 Field::new_list(
265 format_state_name(args.name, "sum distinct"),
266 Field::new_list_field(args.return_type().clone(), true),
268 false,
269 )
270 .into(),
271 ])
272 } else {
273 Ok(vec![
274 Field::new(
275 format_state_name(args.name, "sum"),
276 args.return_type().clone(),
277 true,
278 )
279 .into(),
280 ])
281 }
282 }
283
284 fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
285 !args.is_distinct
286 }
287
288 fn create_groups_accumulator(
289 &self,
290 args: AccumulatorArgs,
291 ) -> Result<Box<dyn GroupsAccumulator>> {
292 macro_rules! helper {
293 ($t:ty, $dt:expr) => {
294 Ok(Box::new(PrimitiveGroupsAccumulator::<$t, _>::new(
295 &$dt,
296 |x, y| *x = x.add_wrapping(y),
297 )))
298 };
299 }
300 downcast_sum!(args, helper)
301 }
302
303 fn create_sliding_accumulator(
304 &self,
305 args: AccumulatorArgs,
306 ) -> Result<Box<dyn Accumulator>> {
307 if args.is_distinct {
308 macro_rules! helper_distinct {
310 ($t:ty, $dt:expr) => {
311 Ok(Box::new(SlidingDistinctSumAccumulator::try_new(&$dt)?))
312 };
313 }
314 downcast_sum!(args, helper_distinct)
315 } else {
316 macro_rules! helper {
318 ($t:ty, $dt:expr) => {
319 Ok(Box::new(SlidingSumAccumulator::<$t>::new($dt.clone())))
320 };
321 }
322 downcast_sum!(args, helper)
323 }
324 }
325
326 fn reverse_expr(&self) -> ReversedUDAF {
327 ReversedUDAF::Identical
328 }
329
330 fn order_sensitivity(&self) -> AggregateOrderSensitivity {
331 AggregateOrderSensitivity::Insensitive
332 }
333
334 fn documentation(&self) -> Option<&Documentation> {
335 self.doc()
336 }
337
338 fn set_monotonicity(&self, data_type: &DataType) -> SetMonotonicity {
339 match data_type {
342 DataType::UInt8 => SetMonotonicity::Increasing,
343 DataType::UInt16 => SetMonotonicity::Increasing,
344 DataType::UInt32 => SetMonotonicity::Increasing,
345 DataType::UInt64 => SetMonotonicity::Increasing,
346 _ => SetMonotonicity::NotMonotonic,
347 }
348 }
349}
350
/// Non-distinct `SUM` accumulator over the primitive Arrow type `T`.
struct SumAccumulator<T: ArrowNumericType> {
    /// Wrapping running total; `None` until some batch yields a sum, so an
    /// accumulator that never saw data evaluates to NULL.
    sum: Option<T::Native>,
    /// Output type used when building the result `ScalarValue`.
    data_type: DataType,
}
356
357impl<T: ArrowNumericType> std::fmt::Debug for SumAccumulator<T> {
358 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
359 write!(f, "SumAccumulator({})", self.data_type)
360 }
361}
362
363impl<T: ArrowNumericType> SumAccumulator<T> {
364 fn new(data_type: DataType) -> Self {
365 Self {
366 sum: None,
367 data_type,
368 }
369 }
370}
371
372impl<T: ArrowNumericType> Accumulator for SumAccumulator<T> {
373 fn state(&mut self) -> Result<Vec<ScalarValue>> {
374 Ok(vec![self.evaluate()?])
375 }
376
377 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
378 let values = values[0].as_primitive::<T>();
379 if let Some(x) = arrow::compute::sum(values) {
380 let v = self.sum.get_or_insert_with(|| T::Native::usize_as(0));
381 *v = v.add_wrapping(x);
382 }
383 Ok(())
384 }
385
386 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
387 self.update_batch(states)
388 }
389
390 fn evaluate(&mut self) -> Result<ScalarValue> {
391 ScalarValue::new_primitive::<T>(self.sum, &self.data_type)
392 }
393
394 fn size(&self) -> usize {
395 size_of_val(self)
396 }
397}
398
/// Retractable (non-distinct) `SUM` accumulator over the primitive Arrow type
/// `T`, created for sliding window frames.
struct SlidingSumAccumulator<T: ArrowNumericType> {
    /// Wrapping running total of all values added and not yet retracted.
    sum: T::Native,
    /// Number of non-null values currently contributing to `sum`; the result
    /// is NULL while this is zero.
    count: u64,
    /// Output type used when building the result `ScalarValue`.
    data_type: DataType,
}
407
408impl<T: ArrowNumericType> std::fmt::Debug for SlidingSumAccumulator<T> {
409 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
410 write!(f, "SlidingSumAccumulator({})", self.data_type)
411 }
412}
413
414impl<T: ArrowNumericType> SlidingSumAccumulator<T> {
415 fn new(data_type: DataType) -> Self {
416 Self {
417 sum: T::Native::usize_as(0),
418 count: 0,
419 data_type,
420 }
421 }
422}
423
424impl<T: ArrowNumericType> Accumulator for SlidingSumAccumulator<T> {
425 fn state(&mut self) -> Result<Vec<ScalarValue>> {
426 Ok(vec![self.evaluate()?, self.count.into()])
427 }
428
429 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
430 let values = values[0].as_primitive::<T>();
431 self.count += (values.len() - values.null_count()) as u64;
432 if let Some(x) = arrow::compute::sum(values) {
433 self.sum = self.sum.add_wrapping(x)
434 }
435 Ok(())
436 }
437
438 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
439 let values = states[0].as_primitive::<T>();
440 if let Some(x) = arrow::compute::sum(values) {
441 self.sum = self.sum.add_wrapping(x)
442 }
443 if let Some(x) = arrow::compute::sum(states[1].as_primitive::<UInt64Type>()) {
444 self.count += x;
445 }
446 Ok(())
447 }
448
449 fn evaluate(&mut self) -> Result<ScalarValue> {
450 let v = (self.count != 0).then_some(self.sum);
451 ScalarValue::new_primitive::<T>(v, &self.data_type)
452 }
453
454 fn size(&self) -> usize {
455 size_of_val(self)
456 }
457
458 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
459 let values = values[0].as_primitive::<T>();
460 if let Some(x) = arrow::compute::sum(values) {
461 self.sum = self.sum.sub_wrapping(x)
462 }
463 self.count -= (values.len() - values.null_count()) as u64;
464 Ok(())
465 }
466
467 fn supports_retract_batch(&self) -> bool {
468 true
469 }
470}
471
/// Sliding-window accumulator for `SUM(DISTINCT ...)` over `Int64` values.
///
/// A multiplicity is kept per distinct value so values can be retracted as the
/// window moves; a value contributes to `sum` while its count is non-zero.
#[derive(Debug)]
pub struct SlidingDistinctSumAccumulator {
    /// Occurrences of each distinct value currently in the window; entries
    /// are removed when their count drops to zero.
    counts: HashMap<i64, usize, RandomState>,
    /// Wrapping sum of the distinct keys present in `counts`.
    sum: i64,
    /// Output data type; `try_new` only accepts `DataType::Int64`.
    data_type: DataType,
}
483
484impl SlidingDistinctSumAccumulator {
485 pub fn try_new(data_type: &DataType) -> Result<Self> {
487 if *data_type != DataType::Int64 {
489 return exec_err!("SlidingDistinctSumAccumulator only supports Int64");
490 }
491 Ok(Self {
492 counts: HashMap::default(),
493 sum: 0,
494 data_type: data_type.clone(),
495 })
496 }
497}
498
499impl Accumulator for SlidingDistinctSumAccumulator {
500 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
501 let arr = values[0].as_primitive::<Int64Type>();
502 for &v in arr.values() {
503 let cnt = self.counts.entry(v).or_insert(0);
504 if *cnt == 0 {
505 self.sum = self.sum.wrapping_add(v);
507 }
508 *cnt += 1;
509 }
510 Ok(())
511 }
512
513 fn evaluate(&mut self) -> Result<ScalarValue> {
514 Ok(ScalarValue::Int64(Some(self.sum)))
516 }
517
518 fn size(&self) -> usize {
519 size_of_val(self)
520 }
521
522 fn state(&mut self) -> Result<Vec<ScalarValue>> {
523 let keys = self
525 .counts
526 .keys()
527 .cloned()
528 .map(Some)
529 .map(ScalarValue::Int64)
530 .collect::<Vec<_>>();
531 Ok(vec![ScalarValue::List(ScalarValue::new_list_nullable(
532 &keys,
533 &self.data_type,
534 ))])
535 }
536
537 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
538 let list_arr = states[0].as_list::<i32>();
540 for maybe_inner in list_arr.iter().flatten() {
541 for idx in 0..maybe_inner.len() {
542 if let ScalarValue::Int64(Some(v)) =
543 ScalarValue::try_from_array(&*maybe_inner, idx)?
544 {
545 let cnt = self.counts.entry(v).or_insert(0);
546 if *cnt == 0 {
547 self.sum = self.sum.wrapping_add(v);
548 }
549 *cnt += 1;
550 }
551 }
552 }
553 Ok(())
554 }
555
556 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
557 let arr = values[0].as_primitive::<Int64Type>();
558 for &v in arr.values() {
559 if let Some(cnt) = self.counts.get_mut(&v) {
560 *cnt -= 1;
561 if *cnt == 0 {
562 self.sum = self.sum.wrapping_sub(v);
564 self.counts.remove(&v);
565 }
566 }
567 }
568 Ok(())
569 }
570
571 fn supports_retract_batch(&self) -> bool {
572 true
573 }
574}