datafusion_functions_aggregate/
sum.rs1use ahash::RandomState;
21use datafusion_expr::utils::AggregateOrderSensitivity;
22use std::any::Any;
23use std::collections::HashSet;
24use std::mem::{size_of, size_of_val};
25
26use arrow::array::Array;
27use arrow::array::ArrowNativeTypeOp;
28use arrow::array::{ArrowNumericType, AsArray};
29use arrow::datatypes::ArrowPrimitiveType;
30use arrow::datatypes::{ArrowNativeType, FieldRef};
31use arrow::datatypes::{
32 DataType, Decimal128Type, Decimal256Type, Float64Type, Int64Type, UInt64Type,
33 DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION,
34};
35use arrow::{array::ArrayRef, datatypes::Field};
36use datafusion_common::{
37 exec_err, not_impl_err, utils::take_function_args, Result, ScalarValue,
38};
39use datafusion_expr::function::AccumulatorArgs;
40use datafusion_expr::function::StateFieldsArgs;
41use datafusion_expr::utils::format_state_name;
42use datafusion_expr::{
43 Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, ReversedUDAF,
44 SetMonotonicity, Signature, Volatility,
45};
46use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator;
47use datafusion_functions_aggregate_common::utils::Hashable;
48use datafusion_macros::user_doc;
49
// Registers the `Sum` aggregate through `make_udaf_expr_and_func!`, passing the
// expression-builder name `sum`, its single argument `expression`, a short
// description, and the accessor name `sum_udaf`.
// NOTE(review): the exact items generated (presumably a `sum(expression)` Expr
// helper and a `sum_udaf()` accessor) are defined by the macro, not visible here.
make_udaf_expr_and_func!(
    Sum,
    sum,
    expression,
    "Returns the sum of a group of values.",
    sum_udaf
);
57
/// Dispatches on the aggregate's return type and expands
/// `$helper!(ArrowType, data_type)` for the matching Arrow primitive type.
///
/// `$args` must expose `return_field.data_type()` and `name`. The helper
/// receives an owned clone of the return `DataType` as its second argument.
/// Unsupported return types produce a `not_impl_err!`.
macro_rules! downcast_sum {
    ($args:ident, $helper:ident) => {
        // Match on the borrowed type: cloning the whole `DataType` just to
        // inspect its variant is unnecessary.
        match $args.return_field.data_type() {
            DataType::UInt64 => {
                $helper!(UInt64Type, $args.return_field.data_type().clone())
            }
            DataType::Int64 => {
                $helper!(Int64Type, $args.return_field.data_type().clone())
            }
            DataType::Float64 => {
                $helper!(Float64Type, $args.return_field.data_type().clone())
            }
            DataType::Decimal128(_, _) => {
                $helper!(Decimal128Type, $args.return_field.data_type().clone())
            }
            DataType::Decimal256(_, _) => {
                $helper!(Decimal256Type, $args.return_field.data_type().clone())
            }
            _ => {
                not_impl_err!(
                    "Sum not supported for {}: {}",
                    $args.name,
                    $args.return_field.data_type()
                )
            }
        }
    };
}
92
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the sum of all values in the specified column.",
    syntax_example = "sum(expression)",
    sql_example = r#"```sql
> SELECT sum(column_name) FROM table_name;
+-----------------------+
| sum(column_name) |
+-----------------------+
| 12345 |
+-----------------------+
```"#,
    standard_argument(name = "expression",)
)]
/// The `sum` aggregate function (`AggregateUDFImpl`), supporting plain,
/// `DISTINCT`, grouped and sliding-window accumulation.
#[derive(Debug)]
pub struct Sum {
    // Signature reported by `AggregateUDFImpl::signature`.
    signature: Signature,
}
111
112impl Sum {
113 pub fn new() -> Self {
114 Self {
115 signature: Signature::user_defined(Volatility::Immutable),
116 }
117 }
118}
119
120impl Default for Sum {
121 fn default() -> Self {
122 Self::new()
123 }
124}
125
126impl AggregateUDFImpl for Sum {
127 fn as_any(&self) -> &dyn Any {
128 self
129 }
130
131 fn name(&self) -> &str {
132 "sum"
133 }
134
135 fn signature(&self) -> &Signature {
136 &self.signature
137 }
138
139 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
140 let [args] = take_function_args(self.name(), arg_types)?;
141
142 fn coerced_type(data_type: &DataType) -> Result<DataType> {
146 match data_type {
147 DataType::Dictionary(_, v) => coerced_type(v),
148 DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => {
151 Ok(data_type.clone())
152 }
153 dt if dt.is_signed_integer() => Ok(DataType::Int64),
154 dt if dt.is_unsigned_integer() => Ok(DataType::UInt64),
155 dt if dt.is_floating() => Ok(DataType::Float64),
156 _ => exec_err!("Sum not supported for {}", data_type),
157 }
158 }
159
160 Ok(vec![coerced_type(args)?])
161 }
162
163 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
164 match &arg_types[0] {
165 DataType::Int64 => Ok(DataType::Int64),
166 DataType::UInt64 => Ok(DataType::UInt64),
167 DataType::Float64 => Ok(DataType::Float64),
168 DataType::Decimal128(precision, scale) => {
169 let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10);
172 Ok(DataType::Decimal128(new_precision, *scale))
173 }
174 DataType::Decimal256(precision, scale) => {
175 let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10);
178 Ok(DataType::Decimal256(new_precision, *scale))
179 }
180 other => {
181 exec_err!("[return_type] SUM not supported for {}", other)
182 }
183 }
184 }
185
186 fn accumulator(&self, args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
187 if args.is_distinct {
188 macro_rules! helper {
189 ($t:ty, $dt:expr) => {
190 Ok(Box::new(DistinctSumAccumulator::<$t>::try_new(&$dt)?))
191 };
192 }
193 downcast_sum!(args, helper)
194 } else {
195 macro_rules! helper {
196 ($t:ty, $dt:expr) => {
197 Ok(Box::new(SumAccumulator::<$t>::new($dt.clone())))
198 };
199 }
200 downcast_sum!(args, helper)
201 }
202 }
203
204 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
205 if args.is_distinct {
206 Ok(vec![Field::new_list(
207 format_state_name(args.name, "sum distinct"),
208 Field::new_list_field(args.return_type().clone(), true),
210 false,
211 )
212 .into()])
213 } else {
214 Ok(vec![Field::new(
215 format_state_name(args.name, "sum"),
216 args.return_type().clone(),
217 true,
218 )
219 .into()])
220 }
221 }
222
223 fn aliases(&self) -> &[String] {
224 &[]
225 }
226
227 fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
228 !args.is_distinct
229 }
230
231 fn create_groups_accumulator(
232 &self,
233 args: AccumulatorArgs,
234 ) -> Result<Box<dyn GroupsAccumulator>> {
235 macro_rules! helper {
236 ($t:ty, $dt:expr) => {
237 Ok(Box::new(PrimitiveGroupsAccumulator::<$t, _>::new(
238 &$dt,
239 |x, y| *x = x.add_wrapping(y),
240 )))
241 };
242 }
243 downcast_sum!(args, helper)
244 }
245
246 fn create_sliding_accumulator(
247 &self,
248 args: AccumulatorArgs,
249 ) -> Result<Box<dyn Accumulator>> {
250 macro_rules! helper {
251 ($t:ty, $dt:expr) => {
252 Ok(Box::new(SlidingSumAccumulator::<$t>::new($dt.clone())))
253 };
254 }
255 downcast_sum!(args, helper)
256 }
257
258 fn reverse_expr(&self) -> ReversedUDAF {
259 ReversedUDAF::Identical
260 }
261
262 fn order_sensitivity(&self) -> AggregateOrderSensitivity {
263 AggregateOrderSensitivity::Insensitive
264 }
265
266 fn documentation(&self) -> Option<&Documentation> {
267 self.doc()
268 }
269
270 fn set_monotonicity(&self, data_type: &DataType) -> SetMonotonicity {
271 match data_type {
274 DataType::UInt8 => SetMonotonicity::Increasing,
275 DataType::UInt16 => SetMonotonicity::Increasing,
276 DataType::UInt32 => SetMonotonicity::Increasing,
277 DataType::UInt64 => SetMonotonicity::Increasing,
278 _ => SetMonotonicity::NotMonotonic,
279 }
280 }
281}
282
283struct SumAccumulator<T: ArrowNumericType> {
285 sum: Option<T::Native>,
286 data_type: DataType,
287}
288
289impl<T: ArrowNumericType> std::fmt::Debug for SumAccumulator<T> {
290 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
291 write!(f, "SumAccumulator({})", self.data_type)
292 }
293}
294
295impl<T: ArrowNumericType> SumAccumulator<T> {
296 fn new(data_type: DataType) -> Self {
297 Self {
298 sum: None,
299 data_type,
300 }
301 }
302}
303
304impl<T: ArrowNumericType> Accumulator for SumAccumulator<T> {
305 fn state(&mut self) -> Result<Vec<ScalarValue>> {
306 Ok(vec![self.evaluate()?])
307 }
308
309 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
310 let values = values[0].as_primitive::<T>();
311 if let Some(x) = arrow::compute::sum(values) {
312 let v = self.sum.get_or_insert(T::Native::usize_as(0));
313 *v = v.add_wrapping(x);
314 }
315 Ok(())
316 }
317
318 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
319 self.update_batch(states)
320 }
321
322 fn evaluate(&mut self) -> Result<ScalarValue> {
323 ScalarValue::new_primitive::<T>(self.sum, &self.data_type)
324 }
325
326 fn size(&self) -> usize {
327 size_of_val(self)
328 }
329}
330
331struct SlidingSumAccumulator<T: ArrowNumericType> {
335 sum: T::Native,
336 count: u64,
337 data_type: DataType,
338}
339
340impl<T: ArrowNumericType> std::fmt::Debug for SlidingSumAccumulator<T> {
341 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
342 write!(f, "SlidingSumAccumulator({})", self.data_type)
343 }
344}
345
346impl<T: ArrowNumericType> SlidingSumAccumulator<T> {
347 fn new(data_type: DataType) -> Self {
348 Self {
349 sum: T::Native::usize_as(0),
350 count: 0,
351 data_type,
352 }
353 }
354}
355
356impl<T: ArrowNumericType> Accumulator for SlidingSumAccumulator<T> {
357 fn state(&mut self) -> Result<Vec<ScalarValue>> {
358 Ok(vec![self.evaluate()?, self.count.into()])
359 }
360
361 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
362 let values = values[0].as_primitive::<T>();
363 self.count += (values.len() - values.null_count()) as u64;
364 if let Some(x) = arrow::compute::sum(values) {
365 self.sum = self.sum.add_wrapping(x)
366 }
367 Ok(())
368 }
369
370 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
371 let values = states[0].as_primitive::<T>();
372 if let Some(x) = arrow::compute::sum(values) {
373 self.sum = self.sum.add_wrapping(x)
374 }
375 if let Some(x) = arrow::compute::sum(states[1].as_primitive::<UInt64Type>()) {
376 self.count += x;
377 }
378 Ok(())
379 }
380
381 fn evaluate(&mut self) -> Result<ScalarValue> {
382 let v = (self.count != 0).then_some(self.sum);
383 ScalarValue::new_primitive::<T>(v, &self.data_type)
384 }
385
386 fn size(&self) -> usize {
387 size_of_val(self)
388 }
389
390 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
391 let values = values[0].as_primitive::<T>();
392 if let Some(x) = arrow::compute::sum(values) {
393 self.sum = self.sum.sub_wrapping(x)
394 }
395 self.count -= (values.len() - values.null_count()) as u64;
396 Ok(())
397 }
398
399 fn supports_retract_batch(&self) -> bool {
400 true
401 }
402}
403
404struct DistinctSumAccumulator<T: ArrowPrimitiveType> {
405 values: HashSet<Hashable<T::Native>, RandomState>,
406 data_type: DataType,
407}
408
409impl<T: ArrowPrimitiveType> std::fmt::Debug for DistinctSumAccumulator<T> {
410 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
411 write!(f, "DistinctSumAccumulator({})", self.data_type)
412 }
413}
414
415impl<T: ArrowPrimitiveType> DistinctSumAccumulator<T> {
416 pub fn try_new(data_type: &DataType) -> Result<Self> {
417 Ok(Self {
418 values: HashSet::default(),
419 data_type: data_type.clone(),
420 })
421 }
422}
423
424impl<T: ArrowPrimitiveType> Accumulator for DistinctSumAccumulator<T> {
425 fn state(&mut self) -> Result<Vec<ScalarValue>> {
426 let state_out = {
429 let distinct_values = self
430 .values
431 .iter()
432 .map(|value| {
433 ScalarValue::new_primitive::<T>(Some(value.0), &self.data_type)
434 })
435 .collect::<Result<Vec<_>>>()?;
436
437 vec![ScalarValue::List(ScalarValue::new_list_nullable(
438 &distinct_values,
439 &self.data_type,
440 ))]
441 };
442 Ok(state_out)
443 }
444
445 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
446 if values.is_empty() {
447 return Ok(());
448 }
449
450 let array = values[0].as_primitive::<T>();
451 match array.nulls().filter(|x| x.null_count() > 0) {
452 Some(n) => {
453 for idx in n.valid_indices() {
454 self.values.insert(Hashable(array.value(idx)));
455 }
456 }
457 None => array.values().iter().for_each(|x| {
458 self.values.insert(Hashable(*x));
459 }),
460 }
461 Ok(())
462 }
463
464 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
465 for x in states[0].as_list::<i32>().iter().flatten() {
466 self.update_batch(&[x])?
467 }
468 Ok(())
469 }
470
471 fn evaluate(&mut self) -> Result<ScalarValue> {
472 let mut acc = T::Native::usize_as(0);
473 for distinct_value in self.values.iter() {
474 acc = acc.add_wrapping(distinct_value.0)
475 }
476 let v = (!self.values.is_empty()).then_some(acc);
477 ScalarValue::new_primitive::<T>(v, &self.data_type)
478 }
479
480 fn size(&self) -> usize {
481 size_of_val(self) + self.values.capacity() * size_of::<T::Native>()
482 }
483}