1use arrow::array::{Array, ArrayRef, ArrowNativeTypeOp, ArrowNumericType, AsArray};
21use arrow::datatypes::Field;
22use arrow::datatypes::{
23 ArrowNativeType, DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION,
24 DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, DataType, Decimal32Type,
25 Decimal64Type, Decimal128Type, Decimal256Type, DurationMicrosecondType,
26 DurationMillisecondType, DurationNanosecondType, DurationSecondType, FieldRef,
27 Float64Type, Int64Type, TimeUnit, UInt64Type,
28};
29use datafusion_common::hash_utils::RandomState;
30use datafusion_common::internal_err;
31use datafusion_common::types::{
32 NativeType, logical_float64, logical_int8, logical_int16, logical_int32,
33 logical_int64, logical_uint8, logical_uint16, logical_uint32, logical_uint64,
34};
35use datafusion_common::{HashMap, Result, ScalarValue, exec_err, not_impl_err};
36use datafusion_expr::expr::AggregateFunction;
37use datafusion_expr::expr_fn::cast;
38use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
39use datafusion_expr::utils::{AggregateOrderSensitivity, format_state_name};
40use datafusion_expr::{
41 Accumulator, AggregateUDFImpl, Coercion, Documentation, Expr, GroupsAccumulator,
42 Operator, ReversedUDAF, SetMonotonicity, Signature, TypeSignature,
43 TypeSignatureClass, Volatility,
44};
45use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator;
46use datafusion_functions_aggregate_common::aggregate::sum_distinct::DistinctSumAccumulator;
47use datafusion_macros::user_doc;
48use std::mem::size_of_val;
49
// Registers the `Sum` UDAF with the shared expression/UDF macro. Among the
// generated items is the `sum_udaf()` accessor (used by `sum_distinct` below);
// presumably a `sum(expression)` builder is generated as well — see the macro
// definition to confirm the exact set of items.
make_udaf_expr_and_func!(
    Sum,
    sum,
    expression,
    "Returns the sum of a group of values.",
    sum_udaf
);
57
58pub fn sum_distinct(expr: Expr) -> Expr {
59 Expr::AggregateFunction(AggregateFunction::new_udf(
60 sum_udaf(),
61 vec![expr],
62 true,
63 None,
64 vec![],
65 None,
66 ))
67}
68
// Dispatches on `$args.return_field`'s data type and expands the caller's
// `$helper!(ArrowType, data_type)` macro with the matching Arrow primitive
// type and an owned clone of the return type. Unsupported return types yield a
// `not_impl_err!`.
macro_rules! downcast_sum {
    ($args:ident, $helper:ident) => {{
        let sum_return_type = $args.return_field.data_type();
        match sum_return_type {
            DataType::UInt64 => $helper!(UInt64Type, sum_return_type.clone()),
            DataType::Int64 => $helper!(Int64Type, sum_return_type.clone()),
            DataType::Float64 => $helper!(Float64Type, sum_return_type.clone()),
            DataType::Decimal32(_, _) => {
                $helper!(Decimal32Type, sum_return_type.clone())
            }
            DataType::Decimal64(_, _) => {
                $helper!(Decimal64Type, sum_return_type.clone())
            }
            DataType::Decimal128(_, _) => {
                $helper!(Decimal128Type, sum_return_type.clone())
            }
            DataType::Decimal256(_, _) => {
                $helper!(Decimal256Type, sum_return_type.clone())
            }
            DataType::Duration(TimeUnit::Second) => {
                $helper!(DurationSecondType, sum_return_type.clone())
            }
            DataType::Duration(TimeUnit::Millisecond) => {
                $helper!(DurationMillisecondType, sum_return_type.clone())
            }
            DataType::Duration(TimeUnit::Microsecond) => {
                $helper!(DurationMicrosecondType, sum_return_type.clone())
            }
            DataType::Duration(TimeUnit::Nanosecond) => {
                $helper!(DurationNanosecondType, sum_return_type.clone())
            }
            _ => not_impl_err!(
                "Sum not supported for {}: {}",
                $args.name,
                sum_return_type
            ),
        }
    }};
}
130
// The `sum` aggregate UDF. The `user_doc` attribute supplies the user-facing
// documentation returned by `AggregateUDFImpl::documentation`.
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the sum of all values in the specified column.",
    syntax_example = "sum(expression)",
    sql_example = r#"```sql
> SELECT sum(column_name) FROM table_name;
+-----------------------+
| sum(column_name) |
+-----------------------+
| 12345 |
+-----------------------+
```"#,
    standard_argument(name = "expression",)
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct Sum {
    // Accepted argument types and their implicit coercions (built in `Sum::new`).
    signature: Signature,
}
149
150impl Sum {
151 pub fn new() -> Self {
152 Self {
153 signature: Signature::one_of(
156 vec![
157 TypeSignature::Coercible(vec![Coercion::new_exact(
158 TypeSignatureClass::Decimal,
159 )]),
160 TypeSignature::Coercible(vec![Coercion::new_implicit(
162 TypeSignatureClass::Native(logical_uint64()),
163 vec![
164 TypeSignatureClass::Native(logical_uint8()),
165 TypeSignatureClass::Native(logical_uint16()),
166 TypeSignatureClass::Native(logical_uint32()),
167 ],
168 NativeType::UInt64,
169 )]),
170 TypeSignature::Coercible(vec![Coercion::new_implicit(
172 TypeSignatureClass::Native(logical_int64()),
173 vec![
174 TypeSignatureClass::Native(logical_int8()),
175 TypeSignatureClass::Native(logical_int16()),
176 TypeSignatureClass::Native(logical_int32()),
177 ],
178 NativeType::Int64,
179 )]),
180 TypeSignature::Coercible(vec![Coercion::new_implicit(
182 TypeSignatureClass::Native(logical_float64()),
183 vec![TypeSignatureClass::Float],
184 NativeType::Float64,
185 )]),
186 TypeSignature::Coercible(vec![Coercion::new_exact(
187 TypeSignatureClass::Duration,
188 )]),
189 ],
190 Volatility::Immutable,
191 ),
192 }
193 }
194}
195
196impl Default for Sum {
197 fn default() -> Self {
198 Self::new()
199 }
200}
201
202impl AggregateUDFImpl for Sum {
203 fn name(&self) -> &str {
204 "sum"
205 }
206
207 fn signature(&self) -> &Signature {
208 &self.signature
209 }
210
211 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
212 match &arg_types[0] {
213 DataType::Int64 => Ok(DataType::Int64),
214 DataType::UInt64 => Ok(DataType::UInt64),
215 DataType::Float64 => Ok(DataType::Float64),
216 DataType::Decimal32(precision, scale) => {
219 let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 10);
220 Ok(DataType::Decimal32(new_precision, *scale))
221 }
222 DataType::Decimal64(precision, scale) => {
223 let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 10);
224 Ok(DataType::Decimal64(new_precision, *scale))
225 }
226 DataType::Decimal128(precision, scale) => {
227 let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10);
228 Ok(DataType::Decimal128(new_precision, *scale))
229 }
230 DataType::Decimal256(precision, scale) => {
231 let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10);
232 Ok(DataType::Decimal256(new_precision, *scale))
233 }
234 DataType::Duration(time_unit) => Ok(DataType::Duration(*time_unit)),
235 other => {
236 exec_err!("[return_type] SUM not supported for {}", other)
237 }
238 }
239 }
240
241 fn accumulator(&self, args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
242 if args.is_distinct {
243 macro_rules! helper {
244 ($t:ty, $dt:expr) => {
245 Ok(Box::new(DistinctSumAccumulator::<$t>::new(&$dt)))
246 };
247 }
248 downcast_sum!(args, helper)
249 } else {
250 macro_rules! helper {
251 ($t:ty, $dt:expr) => {
252 Ok(Box::new(SumAccumulator::<$t>::new($dt.clone())))
253 };
254 }
255 downcast_sum!(args, helper)
256 }
257 }
258
259 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
260 if args.is_distinct {
261 Ok(vec![
262 Field::new_list(
263 format_state_name(args.name, "sum distinct"),
264 Field::new_list_field(args.return_type().clone(), true),
266 false,
267 )
268 .into(),
269 ])
270 } else {
271 Ok(vec![
272 Field::new(
273 format_state_name(args.name, "sum"),
274 args.return_type().clone(),
275 true,
276 )
277 .into(),
278 ])
279 }
280 }
281
282 fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
283 !args.is_distinct
284 }
285
286 fn create_groups_accumulator(
287 &self,
288 args: AccumulatorArgs,
289 ) -> Result<Box<dyn GroupsAccumulator>> {
290 macro_rules! helper {
291 ($t:ty, $dt:expr) => {
292 Ok(Box::new(PrimitiveGroupsAccumulator::<$t, _>::new(
293 &$dt,
294 |x, y| *x = x.add_wrapping(y),
295 )))
296 };
297 }
298 downcast_sum!(args, helper)
299 }
300
301 fn create_sliding_accumulator(
302 &self,
303 args: AccumulatorArgs,
304 ) -> Result<Box<dyn Accumulator>> {
305 if args.is_distinct {
306 macro_rules! helper_distinct {
308 ($t:ty, $dt:expr) => {
309 Ok(Box::new(SlidingDistinctSumAccumulator::try_new(&$dt)?))
310 };
311 }
312 downcast_sum!(args, helper_distinct)
313 } else {
314 macro_rules! helper {
316 ($t:ty, $dt:expr) => {
317 Ok(Box::new(SlidingSumAccumulator::<$t>::new($dt.clone())))
318 };
319 }
320 downcast_sum!(args, helper)
321 }
322 }
323
324 fn reverse_expr(&self) -> ReversedUDAF {
325 ReversedUDAF::Identical
326 }
327
328 fn order_sensitivity(&self) -> AggregateOrderSensitivity {
329 AggregateOrderSensitivity::Insensitive
330 }
331
332 fn documentation(&self) -> Option<&Documentation> {
333 self.doc()
334 }
335
336 fn set_monotonicity(&self, data_type: &DataType) -> SetMonotonicity {
337 match data_type {
340 DataType::UInt8 => SetMonotonicity::Increasing,
341 DataType::UInt16 => SetMonotonicity::Increasing,
342 DataType::UInt32 => SetMonotonicity::Increasing,
343 DataType::UInt64 => SetMonotonicity::Increasing,
344 _ => SetMonotonicity::NotMonotonic,
345 }
346 }
347
348 fn simplify_expr_op_literal(
353 &self,
354 agg_function: &AggregateFunction,
355 arg: &Expr,
356 op: Operator,
357 lit: &Expr,
358 _arg_is_left: bool,
360 ) -> Result<Option<Expr>> {
361 if op != Operator::Plus {
362 return Ok(None);
363 }
364
365 let lit_type = match &lit {
366 Expr::Literal(value, _) => value.data_type(),
367 _ => {
368 return internal_err!(
369 "Sum::simplify_expr_op_literal got a non literal argument"
370 );
371 }
372 };
373 if lit_type == DataType::Null {
374 return Ok(None);
375 }
376
377 let mut sum_agg = agg_function.clone();
379 sum_agg.params.args = vec![arg.clone()];
380 let sum_agg = Expr::AggregateFunction(sum_agg);
381
382 let count_agg = cast(crate::count::count(arg.clone()), lit_type);
384
385 Ok(Some(sum_agg + (lit.clone() * count_agg)))
387 }
388}
389
390struct SumAccumulator<T: ArrowNumericType> {
392 sum: Option<T::Native>,
393 data_type: DataType,
394}
395
396impl<T: ArrowNumericType> std::fmt::Debug for SumAccumulator<T> {
397 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
398 write!(f, "SumAccumulator({})", self.data_type)
399 }
400}
401
402impl<T: ArrowNumericType> SumAccumulator<T> {
403 fn new(data_type: DataType) -> Self {
404 Self {
405 sum: None,
406 data_type,
407 }
408 }
409}
410
411impl<T: ArrowNumericType> Accumulator for SumAccumulator<T> {
412 fn state(&mut self) -> Result<Vec<ScalarValue>> {
413 Ok(vec![self.evaluate()?])
414 }
415
416 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
417 let values = values[0].as_primitive::<T>();
418 if let Some(x) = arrow::compute::sum(values) {
419 let v = self.sum.get_or_insert_with(|| T::Native::usize_as(0));
420 *v = v.add_wrapping(x);
421 }
422 Ok(())
423 }
424
425 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
426 self.update_batch(states)
427 }
428
429 fn evaluate(&mut self) -> Result<ScalarValue> {
430 ScalarValue::new_primitive::<T>(self.sum, &self.data_type)
431 }
432
433 fn size(&self) -> usize {
434 size_of_val(self)
435 }
436}
437
438struct SlidingSumAccumulator<T: ArrowNumericType> {
442 sum: T::Native,
443 count: u64,
444 data_type: DataType,
445}
446
447impl<T: ArrowNumericType> std::fmt::Debug for SlidingSumAccumulator<T> {
448 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
449 write!(f, "SlidingSumAccumulator({})", self.data_type)
450 }
451}
452
453impl<T: ArrowNumericType> SlidingSumAccumulator<T> {
454 fn new(data_type: DataType) -> Self {
455 Self {
456 sum: T::Native::usize_as(0),
457 count: 0,
458 data_type,
459 }
460 }
461}
462
463impl<T: ArrowNumericType> Accumulator for SlidingSumAccumulator<T> {
464 fn state(&mut self) -> Result<Vec<ScalarValue>> {
465 Ok(vec![self.evaluate()?, self.count.into()])
466 }
467
468 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
469 let values = values[0].as_primitive::<T>();
470 self.count += (values.len() - values.null_count()) as u64;
471 if let Some(x) = arrow::compute::sum(values) {
472 self.sum = self.sum.add_wrapping(x)
473 }
474 Ok(())
475 }
476
477 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
478 let values = states[0].as_primitive::<T>();
479 if let Some(x) = arrow::compute::sum(values) {
480 self.sum = self.sum.add_wrapping(x)
481 }
482 if let Some(x) = arrow::compute::sum(states[1].as_primitive::<UInt64Type>()) {
483 self.count += x;
484 }
485 Ok(())
486 }
487
488 fn evaluate(&mut self) -> Result<ScalarValue> {
489 let v = (self.count != 0).then_some(self.sum);
490 ScalarValue::new_primitive::<T>(v, &self.data_type)
491 }
492
493 fn size(&self) -> usize {
494 size_of_val(self)
495 }
496
497 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
498 let values = values[0].as_primitive::<T>();
499 if let Some(x) = arrow::compute::sum(values) {
500 self.sum = self.sum.sub_wrapping(x)
501 }
502 self.count -= (values.len() - values.null_count()) as u64;
503 Ok(())
504 }
505
506 fn supports_retract_batch(&self) -> bool {
507 true
508 }
509}
510
/// `SUM(DISTINCT ...)` accumulator for sliding window frames over `Int64`.
///
/// Keeps an occurrence count for every distinct value in the current frame so
/// that a value only leaves the sum once all of its occurrences are retracted.
#[derive(Debug)]
pub struct SlidingDistinctSumAccumulator {
    /// Occurrence count of each distinct value currently in the frame.
    counts: HashMap<i64, usize, RandomState>,
    /// Running wrapping sum of the distinct values (the keys of `counts`).
    sum: i64,
    /// Element type of the state list; `try_new` only accepts `Int64`.
    data_type: DataType,
}
522
523impl SlidingDistinctSumAccumulator {
524 pub fn try_new(data_type: &DataType) -> Result<Self> {
526 if *data_type != DataType::Int64 {
528 return exec_err!("SlidingDistinctSumAccumulator only supports Int64");
529 }
530 Ok(Self {
531 counts: HashMap::default(),
532 sum: 0,
533 data_type: data_type.clone(),
534 })
535 }
536}
537
538impl Accumulator for SlidingDistinctSumAccumulator {
539 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
540 let arr = values[0].as_primitive::<Int64Type>();
541 for &v in arr.values() {
542 let cnt = self.counts.entry(v).or_insert(0);
543 if *cnt == 0 {
544 self.sum = self.sum.wrapping_add(v);
546 }
547 *cnt += 1;
548 }
549 Ok(())
550 }
551
552 fn evaluate(&mut self) -> Result<ScalarValue> {
553 Ok(ScalarValue::Int64(Some(self.sum)))
555 }
556
557 fn size(&self) -> usize {
558 size_of_val(self)
559 }
560
561 fn state(&mut self) -> Result<Vec<ScalarValue>> {
562 let keys = self
564 .counts
565 .keys()
566 .cloned()
567 .map(Some)
568 .map(ScalarValue::Int64)
569 .collect::<Vec<_>>();
570 Ok(vec![ScalarValue::List(ScalarValue::new_list_nullable(
571 &keys,
572 &self.data_type,
573 ))])
574 }
575
576 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
577 let list_arr = states[0].as_list::<i32>();
579 for maybe_inner in list_arr.iter().flatten() {
580 for idx in 0..maybe_inner.len() {
581 if let ScalarValue::Int64(Some(v)) =
582 ScalarValue::try_from_array(&*maybe_inner, idx)?
583 {
584 let cnt = self.counts.entry(v).or_insert(0);
585 if *cnt == 0 {
586 self.sum = self.sum.wrapping_add(v);
587 }
588 *cnt += 1;
589 }
590 }
591 }
592 Ok(())
593 }
594
595 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
596 let arr = values[0].as_primitive::<Int64Type>();
597 for &v in arr.values() {
598 if let Some(cnt) = self.counts.get_mut(&v) {
599 *cnt -= 1;
600 if *cnt == 0 {
601 self.sum = self.sum.wrapping_sub(v);
603 self.counts.remove(&v);
604 }
605 }
606 }
607 Ok(())
608 }
609
610 fn supports_retract_batch(&self) -> bool {
611 true
612 }
613}