datafusion_functions_aggregate/
sum.rs1use ahash::RandomState;
21use datafusion_expr::utils::AggregateOrderSensitivity;
22use std::any::Any;
23use std::collections::HashSet;
24use std::mem::{size_of, size_of_val};
25
26use arrow::array::Array;
27use arrow::array::ArrowNativeTypeOp;
28use arrow::array::{ArrowNumericType, AsArray};
29use arrow::datatypes::ArrowPrimitiveType;
30use arrow::datatypes::{ArrowNativeType, FieldRef};
31use arrow::datatypes::{
32 DataType, Decimal128Type, Decimal256Type, Float64Type, Int64Type, UInt64Type,
33 DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION,
34};
35use arrow::{array::ArrayRef, datatypes::Field};
36use datafusion_common::{
37 exec_err, not_impl_err, utils::take_function_args, Result, ScalarValue,
38};
39use datafusion_expr::function::AccumulatorArgs;
40use datafusion_expr::function::StateFieldsArgs;
41use datafusion_expr::utils::format_state_name;
42use datafusion_expr::{
43 Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, ReversedUDAF,
44 SetMonotonicity, Signature, Volatility,
45};
46use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator;
47use datafusion_functions_aggregate_common::utils::Hashable;
48use datafusion_macros::user_doc;
49
// Registers the `Sum` UDAF with the shared helper macro. Judging by the
// arguments, this presumably generates an expression-building function `sum`
// taking one `expression` and a `sum_udaf` accessor; the string is its doc
// text. NOTE(review): confirm against the macro definition.
make_udaf_expr_and_func!(
    Sum,
    sum,
    expression,
    "Returns the sum of a group of values.",
    sum_udaf
);
57
/// Dispatches on the aggregate's return type to a concrete Arrow primitive type.
///
/// `args` is an [`AccumulatorArgs`]; `helper` is a macro invoked as
/// `helper!(ArrowPrimitiveType, DataType)`, where the second argument is an
/// owned clone of the return type.
///
/// Only `UInt64`, `Int64`, `Float64`, `Decimal128` and `Decimal256` are
/// handled, because `coerce_types`/`return_type` never produce other types.
/// Anything else yields a `not_impl_err!`.
macro_rules! downcast_sum {
    ($args:ident, $helper:ident) => {
        // Only the variant is inspected here, so match on a borrow instead of
        // cloning the `DataType`. Each arm still hands the helper an owned clone.
        match $args.return_field.data_type() {
            DataType::UInt64 => {
                $helper!(UInt64Type, $args.return_field.data_type().clone())
            }
            DataType::Int64 => {
                $helper!(Int64Type, $args.return_field.data_type().clone())
            }
            DataType::Float64 => {
                $helper!(Float64Type, $args.return_field.data_type().clone())
            }
            DataType::Decimal128(_, _) => {
                $helper!(Decimal128Type, $args.return_field.data_type().clone())
            }
            DataType::Decimal256(_, _) => {
                $helper!(Decimal256Type, $args.return_field.data_type().clone())
            }
            _ => {
                not_impl_err!(
                    "Sum not supported for {}: {}",
                    $args.name,
                    $args.return_field.data_type()
                )
            }
        }
    };
}
92
/// The `sum` aggregate function.
///
/// The user-facing documentation (description, syntax and SQL example) is
/// declared in the `user_doc` attribute. `AggregateUDFImpl::documentation`
/// returns `self.doc()`, which presumably is generated from that attribute.
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the sum of all values in the specified column.",
    syntax_example = "sum(expression)",
    sql_example = r#"```sql
> SELECT sum(column_name) FROM table_name;
+-----------------------+
| sum(column_name) |
+-----------------------+
| 12345 |
+-----------------------+
```"#,
    standard_argument(name = "expression",)
)]
#[derive(Debug)]
pub struct Sum {
    /// Signature returned by `AggregateUDFImpl::signature`.
    signature: Signature,
}
111
112impl Sum {
113 pub fn new() -> Self {
114 Self {
115 signature: Signature::user_defined(Volatility::Immutable),
116 }
117 }
118}
119
120impl Default for Sum {
121 fn default() -> Self {
122 Self::new()
123 }
124}
125
126impl AggregateUDFImpl for Sum {
127 fn as_any(&self) -> &dyn Any {
128 self
129 }
130
131 fn name(&self) -> &str {
132 "sum"
133 }
134
135 fn signature(&self) -> &Signature {
136 &self.signature
137 }
138
139 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
140 let [args] = take_function_args(self.name(), arg_types)?;
141
142 fn coerced_type(data_type: &DataType) -> Result<DataType> {
146 match data_type {
147 DataType::Dictionary(_, v) => coerced_type(v),
148 DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => {
151 Ok(data_type.clone())
152 }
153 dt if dt.is_signed_integer() => Ok(DataType::Int64),
154 dt if dt.is_unsigned_integer() => Ok(DataType::UInt64),
155 dt if dt.is_floating() => Ok(DataType::Float64),
156 _ => exec_err!("Sum not supported for {}", data_type),
157 }
158 }
159
160 Ok(vec![coerced_type(args)?])
161 }
162
163 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
164 match &arg_types[0] {
165 DataType::Int64 => Ok(DataType::Int64),
166 DataType::UInt64 => Ok(DataType::UInt64),
167 DataType::Float64 => Ok(DataType::Float64),
168 DataType::Decimal128(precision, scale) => {
169 let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10);
172 Ok(DataType::Decimal128(new_precision, *scale))
173 }
174 DataType::Decimal256(precision, scale) => {
175 let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10);
178 Ok(DataType::Decimal256(new_precision, *scale))
179 }
180 other => {
181 exec_err!("[return_type] SUM not supported for {}", other)
182 }
183 }
184 }
185
186 fn accumulator(&self, args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
187 if args.is_distinct {
188 macro_rules! helper {
189 ($t:ty, $dt:expr) => {
190 Ok(Box::new(DistinctSumAccumulator::<$t>::try_new(&$dt)?))
191 };
192 }
193 downcast_sum!(args, helper)
194 } else {
195 macro_rules! helper {
196 ($t:ty, $dt:expr) => {
197 Ok(Box::new(SumAccumulator::<$t>::new($dt.clone())))
198 };
199 }
200 downcast_sum!(args, helper)
201 }
202 }
203
204 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
205 if args.is_distinct {
206 Ok(vec![Field::new_list(
207 format_state_name(args.name, "sum distinct"),
208 Field::new_list_field(args.return_type().clone(), true),
210 false,
211 )
212 .into()])
213 } else {
214 Ok(vec![Field::new(
215 format_state_name(args.name, "sum"),
216 args.return_type().clone(),
217 true,
218 )
219 .into()])
220 }
221 }
222
223 fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
224 !args.is_distinct
225 }
226
227 fn create_groups_accumulator(
228 &self,
229 args: AccumulatorArgs,
230 ) -> Result<Box<dyn GroupsAccumulator>> {
231 macro_rules! helper {
232 ($t:ty, $dt:expr) => {
233 Ok(Box::new(PrimitiveGroupsAccumulator::<$t, _>::new(
234 &$dt,
235 |x, y| *x = x.add_wrapping(y),
236 )))
237 };
238 }
239 downcast_sum!(args, helper)
240 }
241
242 fn create_sliding_accumulator(
243 &self,
244 args: AccumulatorArgs,
245 ) -> Result<Box<dyn Accumulator>> {
246 macro_rules! helper {
247 ($t:ty, $dt:expr) => {
248 Ok(Box::new(SlidingSumAccumulator::<$t>::new($dt.clone())))
249 };
250 }
251 downcast_sum!(args, helper)
252 }
253
254 fn reverse_expr(&self) -> ReversedUDAF {
255 ReversedUDAF::Identical
256 }
257
258 fn order_sensitivity(&self) -> AggregateOrderSensitivity {
259 AggregateOrderSensitivity::Insensitive
260 }
261
262 fn documentation(&self) -> Option<&Documentation> {
263 self.doc()
264 }
265
266 fn set_monotonicity(&self, data_type: &DataType) -> SetMonotonicity {
267 match data_type {
270 DataType::UInt8 => SetMonotonicity::Increasing,
271 DataType::UInt16 => SetMonotonicity::Increasing,
272 DataType::UInt32 => SetMonotonicity::Increasing,
273 DataType::UInt64 => SetMonotonicity::Increasing,
274 _ => SetMonotonicity::NotMonotonic,
275 }
276 }
277}
278
279struct SumAccumulator<T: ArrowNumericType> {
281 sum: Option<T::Native>,
282 data_type: DataType,
283}
284
285impl<T: ArrowNumericType> std::fmt::Debug for SumAccumulator<T> {
286 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
287 write!(f, "SumAccumulator({})", self.data_type)
288 }
289}
290
291impl<T: ArrowNumericType> SumAccumulator<T> {
292 fn new(data_type: DataType) -> Self {
293 Self {
294 sum: None,
295 data_type,
296 }
297 }
298}
299
300impl<T: ArrowNumericType> Accumulator for SumAccumulator<T> {
301 fn state(&mut self) -> Result<Vec<ScalarValue>> {
302 Ok(vec![self.evaluate()?])
303 }
304
305 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
306 let values = values[0].as_primitive::<T>();
307 if let Some(x) = arrow::compute::sum(values) {
308 let v = self.sum.get_or_insert(T::Native::usize_as(0));
309 *v = v.add_wrapping(x);
310 }
311 Ok(())
312 }
313
314 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
315 self.update_batch(states)
316 }
317
318 fn evaluate(&mut self) -> Result<ScalarValue> {
319 ScalarValue::new_primitive::<T>(self.sum, &self.data_type)
320 }
321
322 fn size(&self) -> usize {
323 size_of_val(self)
324 }
325}
326
327struct SlidingSumAccumulator<T: ArrowNumericType> {
331 sum: T::Native,
332 count: u64,
333 data_type: DataType,
334}
335
336impl<T: ArrowNumericType> std::fmt::Debug for SlidingSumAccumulator<T> {
337 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
338 write!(f, "SlidingSumAccumulator({})", self.data_type)
339 }
340}
341
342impl<T: ArrowNumericType> SlidingSumAccumulator<T> {
343 fn new(data_type: DataType) -> Self {
344 Self {
345 sum: T::Native::usize_as(0),
346 count: 0,
347 data_type,
348 }
349 }
350}
351
352impl<T: ArrowNumericType> Accumulator for SlidingSumAccumulator<T> {
353 fn state(&mut self) -> Result<Vec<ScalarValue>> {
354 Ok(vec![self.evaluate()?, self.count.into()])
355 }
356
357 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
358 let values = values[0].as_primitive::<T>();
359 self.count += (values.len() - values.null_count()) as u64;
360 if let Some(x) = arrow::compute::sum(values) {
361 self.sum = self.sum.add_wrapping(x)
362 }
363 Ok(())
364 }
365
366 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
367 let values = states[0].as_primitive::<T>();
368 if let Some(x) = arrow::compute::sum(values) {
369 self.sum = self.sum.add_wrapping(x)
370 }
371 if let Some(x) = arrow::compute::sum(states[1].as_primitive::<UInt64Type>()) {
372 self.count += x;
373 }
374 Ok(())
375 }
376
377 fn evaluate(&mut self) -> Result<ScalarValue> {
378 let v = (self.count != 0).then_some(self.sum);
379 ScalarValue::new_primitive::<T>(v, &self.data_type)
380 }
381
382 fn size(&self) -> usize {
383 size_of_val(self)
384 }
385
386 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
387 let values = values[0].as_primitive::<T>();
388 if let Some(x) = arrow::compute::sum(values) {
389 self.sum = self.sum.sub_wrapping(x)
390 }
391 self.count -= (values.len() - values.null_count()) as u64;
392 Ok(())
393 }
394
395 fn supports_retract_batch(&self) -> bool {
396 true
397 }
398}
399
400struct DistinctSumAccumulator<T: ArrowPrimitiveType> {
401 values: HashSet<Hashable<T::Native>, RandomState>,
402 data_type: DataType,
403}
404
405impl<T: ArrowPrimitiveType> std::fmt::Debug for DistinctSumAccumulator<T> {
406 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
407 write!(f, "DistinctSumAccumulator({})", self.data_type)
408 }
409}
410
411impl<T: ArrowPrimitiveType> DistinctSumAccumulator<T> {
412 pub fn try_new(data_type: &DataType) -> Result<Self> {
413 Ok(Self {
414 values: HashSet::default(),
415 data_type: data_type.clone(),
416 })
417 }
418}
419
420impl<T: ArrowPrimitiveType> Accumulator for DistinctSumAccumulator<T> {
421 fn state(&mut self) -> Result<Vec<ScalarValue>> {
422 let state_out = {
425 let distinct_values = self
426 .values
427 .iter()
428 .map(|value| {
429 ScalarValue::new_primitive::<T>(Some(value.0), &self.data_type)
430 })
431 .collect::<Result<Vec<_>>>()?;
432
433 vec![ScalarValue::List(ScalarValue::new_list_nullable(
434 &distinct_values,
435 &self.data_type,
436 ))]
437 };
438 Ok(state_out)
439 }
440
441 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
442 if values.is_empty() {
443 return Ok(());
444 }
445
446 let array = values[0].as_primitive::<T>();
447 match array.nulls().filter(|x| x.null_count() > 0) {
448 Some(n) => {
449 for idx in n.valid_indices() {
450 self.values.insert(Hashable(array.value(idx)));
451 }
452 }
453 None => array.values().iter().for_each(|x| {
454 self.values.insert(Hashable(*x));
455 }),
456 }
457 Ok(())
458 }
459
460 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
461 for x in states[0].as_list::<i32>().iter().flatten() {
462 self.update_batch(&[x])?
463 }
464 Ok(())
465 }
466
467 fn evaluate(&mut self) -> Result<ScalarValue> {
468 let mut acc = T::Native::usize_as(0);
469 for distinct_value in self.values.iter() {
470 acc = acc.add_wrapping(distinct_value.0)
471 }
472 let v = (!self.values.is_empty()).then_some(acc);
473 ScalarValue::new_primitive::<T>(v, &self.data_type)
474 }
475
476 fn size(&self) -> usize {
477 size_of_val(self) + self.values.capacity() * size_of::<T::Native>()
478 }
479}