datafusion_functions_aggregate/
sum.rs1use ahash::RandomState;
21use arrow::datatypes::DECIMAL32_MAX_PRECISION;
22use arrow::datatypes::DECIMAL64_MAX_PRECISION;
23use datafusion_expr::utils::AggregateOrderSensitivity;
24use datafusion_expr::Expr;
25use std::any::Any;
26use std::mem::size_of_val;
27
28use arrow::array::Array;
29use arrow::array::ArrowNativeTypeOp;
30use arrow::array::{ArrowNumericType, AsArray};
31use arrow::datatypes::{ArrowNativeType, FieldRef};
32use arrow::datatypes::{
33 DataType, Decimal128Type, Decimal256Type, Decimal32Type, Decimal64Type, Float64Type,
34 Int64Type, UInt64Type, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION,
35};
36use arrow::{array::ArrayRef, datatypes::Field};
37use datafusion_common::{
38 exec_err, not_impl_err, utils::take_function_args, HashMap, Result, ScalarValue,
39};
40use datafusion_expr::function::AccumulatorArgs;
41use datafusion_expr::function::StateFieldsArgs;
42use datafusion_expr::utils::format_state_name;
43use datafusion_expr::{
44 Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, ReversedUDAF,
45 SetMonotonicity, Signature, Volatility,
46};
47use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator;
48use datafusion_functions_aggregate_common::aggregate::sum_distinct::DistinctSumAccumulator;
49use datafusion_macros::user_doc;
50
// Registers the `Sum` UDAF with the crate's helper macro. Elsewhere in this file
// `sum_udaf()` is called, so the macro presumably generates that accessor as well
// as a `sum(expression)` expression builder carrying the doc string below
// (NOTE(review): confirm against the macro definition).
make_udaf_expr_and_func!(
    Sum,
    sum,
    expression,
    "Returns the sum of a group of values.",
    sum_udaf
);
58
59pub fn sum_distinct(expr: Expr) -> Expr {
60 Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
61 sum_udaf(),
62 vec![expr],
63 true,
64 None,
65 vec![],
66 None,
67 ))
68}
69
/// Dispatches on the return type stored in `$args.return_field` and invokes
/// `$helper!(ArrowType, data_type)` with the matching Arrow primitive type.
///
/// `$helper` receives an owned `DataType` expression as its second argument.
/// Return types other than `UInt64`, `Int64`, `Float64` and the four decimal
/// widths yield a "not implemented" error.
macro_rules! downcast_sum {
    ($args:ident, $helper:ident) => {{
        // Clone once so every arm can hand an owned `DataType` to the helper.
        let return_type = $args.return_field.data_type().clone();
        match return_type {
            DataType::UInt64 => $helper!(UInt64Type, return_type),
            DataType::Int64 => $helper!(Int64Type, return_type),
            DataType::Float64 => $helper!(Float64Type, return_type),
            DataType::Decimal32(_, _) => $helper!(Decimal32Type, return_type),
            DataType::Decimal64(_, _) => $helper!(Decimal64Type, return_type),
            DataType::Decimal128(_, _) => $helper!(Decimal128Type, return_type),
            DataType::Decimal256(_, _) => $helper!(Decimal256Type, return_type),
            _ => not_impl_err!(
                "Sum not supported for {}: {}",
                $args.name,
                $args.return_field.data_type()
            ),
        }
    }};
}
110
111#[user_doc(
112 doc_section(label = "General Functions"),
113 description = "Returns the sum of all values in the specified column.",
114 syntax_example = "sum(expression)",
115 sql_example = r#"```sql
116> SELECT sum(column_name) FROM table_name;
117+-----------------------+
118| sum(column_name) |
119+-----------------------+
120| 12345 |
121+-----------------------+
122```"#,
123 standard_argument(name = "expression",)
124)]
125#[derive(Debug, PartialEq, Eq, Hash)]
126pub struct Sum {
127 signature: Signature,
128}
129
130impl Sum {
131 pub fn new() -> Self {
132 Self {
133 signature: Signature::user_defined(Volatility::Immutable),
134 }
135 }
136}
137
138impl Default for Sum {
139 fn default() -> Self {
140 Self::new()
141 }
142}
143
144impl AggregateUDFImpl for Sum {
145 fn as_any(&self) -> &dyn Any {
146 self
147 }
148
149 fn name(&self) -> &str {
150 "sum"
151 }
152
153 fn signature(&self) -> &Signature {
154 &self.signature
155 }
156
157 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
158 let [args] = take_function_args(self.name(), arg_types)?;
159
160 fn coerced_type(data_type: &DataType) -> Result<DataType> {
164 match data_type {
165 DataType::Dictionary(_, v) => coerced_type(v),
166 DataType::Decimal32(_, _)
169 | DataType::Decimal64(_, _)
170 | DataType::Decimal128(_, _)
171 | DataType::Decimal256(_, _) => Ok(data_type.clone()),
172 dt if dt.is_signed_integer() => Ok(DataType::Int64),
173 dt if dt.is_unsigned_integer() => Ok(DataType::UInt64),
174 dt if dt.is_floating() => Ok(DataType::Float64),
175 _ => exec_err!("Sum not supported for {data_type}"),
176 }
177 }
178
179 Ok(vec![coerced_type(args)?])
180 }
181
182 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
183 match &arg_types[0] {
184 DataType::Int64 => Ok(DataType::Int64),
185 DataType::UInt64 => Ok(DataType::UInt64),
186 DataType::Float64 => Ok(DataType::Float64),
187 DataType::Decimal32(precision, scale) => {
188 let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 10);
191 Ok(DataType::Decimal32(new_precision, *scale))
192 }
193 DataType::Decimal64(precision, scale) => {
194 let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 10);
197 Ok(DataType::Decimal64(new_precision, *scale))
198 }
199 DataType::Decimal128(precision, scale) => {
200 let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10);
203 Ok(DataType::Decimal128(new_precision, *scale))
204 }
205 DataType::Decimal256(precision, scale) => {
206 let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10);
209 Ok(DataType::Decimal256(new_precision, *scale))
210 }
211 other => {
212 exec_err!("[return_type] SUM not supported for {}", other)
213 }
214 }
215 }
216
217 fn accumulator(&self, args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
218 if args.is_distinct {
219 macro_rules! helper {
220 ($t:ty, $dt:expr) => {
221 Ok(Box::new(DistinctSumAccumulator::<$t>::new(&$dt)))
222 };
223 }
224 downcast_sum!(args, helper)
225 } else {
226 macro_rules! helper {
227 ($t:ty, $dt:expr) => {
228 Ok(Box::new(SumAccumulator::<$t>::new($dt.clone())))
229 };
230 }
231 downcast_sum!(args, helper)
232 }
233 }
234
235 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
236 if args.is_distinct {
237 Ok(vec![Field::new_list(
238 format_state_name(args.name, "sum distinct"),
239 Field::new_list_field(args.return_type().clone(), true),
241 false,
242 )
243 .into()])
244 } else {
245 Ok(vec![Field::new(
246 format_state_name(args.name, "sum"),
247 args.return_type().clone(),
248 true,
249 )
250 .into()])
251 }
252 }
253
254 fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
255 !args.is_distinct
256 }
257
258 fn create_groups_accumulator(
259 &self,
260 args: AccumulatorArgs,
261 ) -> Result<Box<dyn GroupsAccumulator>> {
262 macro_rules! helper {
263 ($t:ty, $dt:expr) => {
264 Ok(Box::new(PrimitiveGroupsAccumulator::<$t, _>::new(
265 &$dt,
266 |x, y| *x = x.add_wrapping(y),
267 )))
268 };
269 }
270 downcast_sum!(args, helper)
271 }
272
273 fn create_sliding_accumulator(
274 &self,
275 args: AccumulatorArgs,
276 ) -> Result<Box<dyn Accumulator>> {
277 if args.is_distinct {
278 macro_rules! helper_distinct {
280 ($t:ty, $dt:expr) => {
281 Ok(Box::new(SlidingDistinctSumAccumulator::try_new(&$dt)?))
282 };
283 }
284 downcast_sum!(args, helper_distinct)
285 } else {
286 macro_rules! helper {
288 ($t:ty, $dt:expr) => {
289 Ok(Box::new(SlidingSumAccumulator::<$t>::new($dt.clone())))
290 };
291 }
292 downcast_sum!(args, helper)
293 }
294 }
295
296 fn reverse_expr(&self) -> ReversedUDAF {
297 ReversedUDAF::Identical
298 }
299
300 fn order_sensitivity(&self) -> AggregateOrderSensitivity {
301 AggregateOrderSensitivity::Insensitive
302 }
303
304 fn documentation(&self) -> Option<&Documentation> {
305 self.doc()
306 }
307
308 fn set_monotonicity(&self, data_type: &DataType) -> SetMonotonicity {
309 match data_type {
312 DataType::UInt8 => SetMonotonicity::Increasing,
313 DataType::UInt16 => SetMonotonicity::Increasing,
314 DataType::UInt32 => SetMonotonicity::Increasing,
315 DataType::UInt64 => SetMonotonicity::Increasing,
316 _ => SetMonotonicity::NotMonotonic,
317 }
318 }
319}
320
321struct SumAccumulator<T: ArrowNumericType> {
323 sum: Option<T::Native>,
324 data_type: DataType,
325}
326
327impl<T: ArrowNumericType> std::fmt::Debug for SumAccumulator<T> {
328 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
329 write!(f, "SumAccumulator({})", self.data_type)
330 }
331}
332
333impl<T: ArrowNumericType> SumAccumulator<T> {
334 fn new(data_type: DataType) -> Self {
335 Self {
336 sum: None,
337 data_type,
338 }
339 }
340}
341
342impl<T: ArrowNumericType> Accumulator for SumAccumulator<T> {
343 fn state(&mut self) -> Result<Vec<ScalarValue>> {
344 Ok(vec![self.evaluate()?])
345 }
346
347 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
348 let values = values[0].as_primitive::<T>();
349 if let Some(x) = arrow::compute::sum(values) {
350 let v = self.sum.get_or_insert_with(|| T::Native::usize_as(0));
351 *v = v.add_wrapping(x);
352 }
353 Ok(())
354 }
355
356 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
357 self.update_batch(states)
358 }
359
360 fn evaluate(&mut self) -> Result<ScalarValue> {
361 ScalarValue::new_primitive::<T>(self.sum, &self.data_type)
362 }
363
364 fn size(&self) -> usize {
365 size_of_val(self)
366 }
367}
368
369struct SlidingSumAccumulator<T: ArrowNumericType> {
373 sum: T::Native,
374 count: u64,
375 data_type: DataType,
376}
377
378impl<T: ArrowNumericType> std::fmt::Debug for SlidingSumAccumulator<T> {
379 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
380 write!(f, "SlidingSumAccumulator({})", self.data_type)
381 }
382}
383
384impl<T: ArrowNumericType> SlidingSumAccumulator<T> {
385 fn new(data_type: DataType) -> Self {
386 Self {
387 sum: T::Native::usize_as(0),
388 count: 0,
389 data_type,
390 }
391 }
392}
393
394impl<T: ArrowNumericType> Accumulator for SlidingSumAccumulator<T> {
395 fn state(&mut self) -> Result<Vec<ScalarValue>> {
396 Ok(vec![self.evaluate()?, self.count.into()])
397 }
398
399 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
400 let values = values[0].as_primitive::<T>();
401 self.count += (values.len() - values.null_count()) as u64;
402 if let Some(x) = arrow::compute::sum(values) {
403 self.sum = self.sum.add_wrapping(x)
404 }
405 Ok(())
406 }
407
408 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
409 let values = states[0].as_primitive::<T>();
410 if let Some(x) = arrow::compute::sum(values) {
411 self.sum = self.sum.add_wrapping(x)
412 }
413 if let Some(x) = arrow::compute::sum(states[1].as_primitive::<UInt64Type>()) {
414 self.count += x;
415 }
416 Ok(())
417 }
418
419 fn evaluate(&mut self) -> Result<ScalarValue> {
420 let v = (self.count != 0).then_some(self.sum);
421 ScalarValue::new_primitive::<T>(v, &self.data_type)
422 }
423
424 fn size(&self) -> usize {
425 size_of_val(self)
426 }
427
428 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
429 let values = values[0].as_primitive::<T>();
430 if let Some(x) = arrow::compute::sum(values) {
431 self.sum = self.sum.sub_wrapping(x)
432 }
433 self.count -= (values.len() - values.null_count()) as u64;
434 Ok(())
435 }
436
437 fn supports_retract_batch(&self) -> bool {
438 true
439 }
440}
441
442#[derive(Debug)]
445pub struct SlidingDistinctSumAccumulator {
446 counts: HashMap<i64, usize, RandomState>,
448 sum: i64,
450 data_type: DataType,
452}
453
454impl SlidingDistinctSumAccumulator {
455 pub fn try_new(data_type: &DataType) -> Result<Self> {
457 if *data_type != DataType::Int64 {
459 return exec_err!("SlidingDistinctSumAccumulator only supports Int64");
460 }
461 Ok(Self {
462 counts: HashMap::default(),
463 sum: 0,
464 data_type: data_type.clone(),
465 })
466 }
467}
468
469impl Accumulator for SlidingDistinctSumAccumulator {
470 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
471 let arr = values[0].as_primitive::<Int64Type>();
472 for &v in arr.values() {
473 let cnt = self.counts.entry(v).or_insert(0);
474 if *cnt == 0 {
475 self.sum = self.sum.wrapping_add(v);
477 }
478 *cnt += 1;
479 }
480 Ok(())
481 }
482
483 fn evaluate(&mut self) -> Result<ScalarValue> {
484 Ok(ScalarValue::Int64(Some(self.sum)))
486 }
487
488 fn size(&self) -> usize {
489 size_of_val(self)
490 }
491
492 fn state(&mut self) -> Result<Vec<ScalarValue>> {
493 let keys = self
495 .counts
496 .keys()
497 .cloned()
498 .map(Some)
499 .map(ScalarValue::Int64)
500 .collect::<Vec<_>>();
501 Ok(vec![ScalarValue::List(ScalarValue::new_list_nullable(
502 &keys,
503 &self.data_type,
504 ))])
505 }
506
507 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
508 let list_arr = states[0].as_list::<i32>();
510 for maybe_inner in list_arr.iter().flatten() {
511 for idx in 0..maybe_inner.len() {
512 if let ScalarValue::Int64(Some(v)) =
513 ScalarValue::try_from_array(&*maybe_inner, idx)?
514 {
515 let cnt = self.counts.entry(v).or_insert(0);
516 if *cnt == 0 {
517 self.sum = self.sum.wrapping_add(v);
518 }
519 *cnt += 1;
520 }
521 }
522 }
523 Ok(())
524 }
525
526 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
527 let arr = values[0].as_primitive::<Int64Type>();
528 for &v in arr.values() {
529 if let Some(cnt) = self.counts.get_mut(&v) {
530 *cnt -= 1;
531 if *cnt == 0 {
532 self.sum = self.sum.wrapping_sub(v);
534 self.counts.remove(&v);
535 }
536 }
537 }
538 Ok(())
539 }
540
541 fn supports_retract_batch(&self) -> bool {
542 true
543 }
544}