datafusion_functions_aggregate/
sum.rs1use ahash::RandomState;
21use datafusion_expr::utils::AggregateOrderSensitivity;
22use std::any::Any;
23use std::collections::HashSet;
24use std::mem::{size_of, size_of_val};
25
26use arrow::array::Array;
27use arrow::array::ArrowNativeTypeOp;
28use arrow::array::{ArrowNumericType, AsArray};
29use arrow::datatypes::ArrowNativeType;
30use arrow::datatypes::ArrowPrimitiveType;
31use arrow::datatypes::{
32 DataType, Decimal128Type, Decimal256Type, Float64Type, Int64Type, UInt64Type,
33 DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION,
34};
35use arrow::{array::ArrayRef, datatypes::Field};
36use datafusion_common::{
37 exec_err, not_impl_err, utils::take_function_args, Result, ScalarValue,
38};
39use datafusion_expr::function::AccumulatorArgs;
40use datafusion_expr::function::StateFieldsArgs;
41use datafusion_expr::utils::format_state_name;
42use datafusion_expr::{
43 Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, ReversedUDAF,
44 SetMonotonicity, Signature, Volatility,
45};
46use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator;
47use datafusion_functions_aggregate_common::utils::Hashable;
48use datafusion_macros::user_doc;
49
50make_udaf_expr_and_func!(
51 Sum,
52 sum,
53 expression,
54 "Returns the sum of a group of values.",
55 sum_udaf
56);
57
/// Dispatches on an aggregate's return type to the matching Arrow primitive type.
///
/// * `$acc_args` – identifier of a value exposing `return_type` and `name` fields (the call
///   sites in this file pass their `AccumulatorArgs`).
/// * `$build` – name of a macro that is invoked as `$build!(ArrowType, data_type_expr)`.
///
/// Only `Int64`, `UInt64`, `Float64`, `Decimal128` and `Decimal256` are dispatched; every
/// other return type expands to a `not_impl_err!`.
macro_rules! downcast_sum {
    ($acc_args:ident, $build:ident) => {
        match $acc_args.return_type {
            DataType::Int64 => $build!(Int64Type, $acc_args.return_type),
            DataType::UInt64 => $build!(UInt64Type, $acc_args.return_type),
            DataType::Float64 => $build!(Float64Type, $acc_args.return_type),
            DataType::Decimal128(..) => $build!(Decimal128Type, $acc_args.return_type),
            DataType::Decimal256(..) => $build!(Decimal256Type, $acc_args.return_type),
            _ => not_impl_err!(
                "Sum not supported for {}: {}",
                $acc_args.name,
                $acc_args.return_type
            ),
        }
    };
}
82
83#[user_doc(
84 doc_section(label = "General Functions"),
85 description = "Returns the sum of all values in the specified column.",
86 syntax_example = "sum(expression)",
87 sql_example = r#"```sql
88> SELECT sum(column_name) FROM table_name;
89+-----------------------+
90| sum(column_name) |
91+-----------------------+
92| 12345 |
93+-----------------------+
94```"#,
95 standard_argument(name = "expression",)
96)]
97#[derive(Debug)]
98pub struct Sum {
99 signature: Signature,
100}
101
102impl Sum {
103 pub fn new() -> Self {
104 Self {
105 signature: Signature::user_defined(Volatility::Immutable),
106 }
107 }
108}
109
110impl Default for Sum {
111 fn default() -> Self {
112 Self::new()
113 }
114}
115
116impl AggregateUDFImpl for Sum {
117 fn as_any(&self) -> &dyn Any {
118 self
119 }
120
121 fn name(&self) -> &str {
122 "sum"
123 }
124
125 fn signature(&self) -> &Signature {
126 &self.signature
127 }
128
129 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
130 let [args] = take_function_args(self.name(), arg_types)?;
131
132 fn coerced_type(data_type: &DataType) -> Result<DataType> {
136 match data_type {
137 DataType::Dictionary(_, v) => coerced_type(v),
138 DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => {
141 Ok(data_type.clone())
142 }
143 dt if dt.is_signed_integer() => Ok(DataType::Int64),
144 dt if dt.is_unsigned_integer() => Ok(DataType::UInt64),
145 dt if dt.is_floating() => Ok(DataType::Float64),
146 _ => exec_err!("Sum not supported for {}", data_type),
147 }
148 }
149
150 Ok(vec![coerced_type(args)?])
151 }
152
153 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
154 match &arg_types[0] {
155 DataType::Int64 => Ok(DataType::Int64),
156 DataType::UInt64 => Ok(DataType::UInt64),
157 DataType::Float64 => Ok(DataType::Float64),
158 DataType::Decimal128(precision, scale) => {
159 let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10);
162 Ok(DataType::Decimal128(new_precision, *scale))
163 }
164 DataType::Decimal256(precision, scale) => {
165 let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10);
168 Ok(DataType::Decimal256(new_precision, *scale))
169 }
170 other => {
171 exec_err!("[return_type] SUM not supported for {}", other)
172 }
173 }
174 }
175
176 fn accumulator(&self, args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
177 if args.is_distinct {
178 macro_rules! helper {
179 ($t:ty, $dt:expr) => {
180 Ok(Box::new(DistinctSumAccumulator::<$t>::try_new(&$dt)?))
181 };
182 }
183 downcast_sum!(args, helper)
184 } else {
185 macro_rules! helper {
186 ($t:ty, $dt:expr) => {
187 Ok(Box::new(SumAccumulator::<$t>::new($dt.clone())))
188 };
189 }
190 downcast_sum!(args, helper)
191 }
192 }
193
194 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
195 if args.is_distinct {
196 Ok(vec![Field::new_list(
197 format_state_name(args.name, "sum distinct"),
198 Field::new_list_field(args.return_type.clone(), true),
200 false,
201 )])
202 } else {
203 Ok(vec![Field::new(
204 format_state_name(args.name, "sum"),
205 args.return_type.clone(),
206 true,
207 )])
208 }
209 }
210
211 fn aliases(&self) -> &[String] {
212 &[]
213 }
214
215 fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
216 !args.is_distinct
217 }
218
219 fn create_groups_accumulator(
220 &self,
221 args: AccumulatorArgs,
222 ) -> Result<Box<dyn GroupsAccumulator>> {
223 macro_rules! helper {
224 ($t:ty, $dt:expr) => {
225 Ok(Box::new(PrimitiveGroupsAccumulator::<$t, _>::new(
226 &$dt,
227 |x, y| *x = x.add_wrapping(y),
228 )))
229 };
230 }
231 downcast_sum!(args, helper)
232 }
233
234 fn create_sliding_accumulator(
235 &self,
236 args: AccumulatorArgs,
237 ) -> Result<Box<dyn Accumulator>> {
238 macro_rules! helper {
239 ($t:ty, $dt:expr) => {
240 Ok(Box::new(SlidingSumAccumulator::<$t>::new($dt.clone())))
241 };
242 }
243 downcast_sum!(args, helper)
244 }
245
246 fn reverse_expr(&self) -> ReversedUDAF {
247 ReversedUDAF::Identical
248 }
249
250 fn order_sensitivity(&self) -> AggregateOrderSensitivity {
251 AggregateOrderSensitivity::Insensitive
252 }
253
254 fn documentation(&self) -> Option<&Documentation> {
255 self.doc()
256 }
257
258 fn set_monotonicity(&self, data_type: &DataType) -> SetMonotonicity {
259 match data_type {
262 DataType::UInt8 => SetMonotonicity::Increasing,
263 DataType::UInt16 => SetMonotonicity::Increasing,
264 DataType::UInt32 => SetMonotonicity::Increasing,
265 DataType::UInt64 => SetMonotonicity::Increasing,
266 _ => SetMonotonicity::NotMonotonic,
267 }
268 }
269}
270
271struct SumAccumulator<T: ArrowNumericType> {
273 sum: Option<T::Native>,
274 data_type: DataType,
275}
276
277impl<T: ArrowNumericType> std::fmt::Debug for SumAccumulator<T> {
278 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
279 write!(f, "SumAccumulator({})", self.data_type)
280 }
281}
282
283impl<T: ArrowNumericType> SumAccumulator<T> {
284 fn new(data_type: DataType) -> Self {
285 Self {
286 sum: None,
287 data_type,
288 }
289 }
290}
291
292impl<T: ArrowNumericType> Accumulator for SumAccumulator<T> {
293 fn state(&mut self) -> Result<Vec<ScalarValue>> {
294 Ok(vec![self.evaluate()?])
295 }
296
297 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
298 let values = values[0].as_primitive::<T>();
299 if let Some(x) = arrow::compute::sum(values) {
300 let v = self.sum.get_or_insert(T::Native::usize_as(0));
301 *v = v.add_wrapping(x);
302 }
303 Ok(())
304 }
305
306 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
307 self.update_batch(states)
308 }
309
310 fn evaluate(&mut self) -> Result<ScalarValue> {
311 ScalarValue::new_primitive::<T>(self.sum, &self.data_type)
312 }
313
314 fn size(&self) -> usize {
315 size_of_val(self)
316 }
317}
318
319struct SlidingSumAccumulator<T: ArrowNumericType> {
323 sum: T::Native,
324 count: u64,
325 data_type: DataType,
326}
327
328impl<T: ArrowNumericType> std::fmt::Debug for SlidingSumAccumulator<T> {
329 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
330 write!(f, "SlidingSumAccumulator({})", self.data_type)
331 }
332}
333
334impl<T: ArrowNumericType> SlidingSumAccumulator<T> {
335 fn new(data_type: DataType) -> Self {
336 Self {
337 sum: T::Native::usize_as(0),
338 count: 0,
339 data_type,
340 }
341 }
342}
343
344impl<T: ArrowNumericType> Accumulator for SlidingSumAccumulator<T> {
345 fn state(&mut self) -> Result<Vec<ScalarValue>> {
346 Ok(vec![self.evaluate()?, self.count.into()])
347 }
348
349 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
350 let values = values[0].as_primitive::<T>();
351 self.count += (values.len() - values.null_count()) as u64;
352 if let Some(x) = arrow::compute::sum(values) {
353 self.sum = self.sum.add_wrapping(x)
354 }
355 Ok(())
356 }
357
358 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
359 let values = states[0].as_primitive::<T>();
360 if let Some(x) = arrow::compute::sum(values) {
361 self.sum = self.sum.add_wrapping(x)
362 }
363 if let Some(x) = arrow::compute::sum(states[1].as_primitive::<UInt64Type>()) {
364 self.count += x;
365 }
366 Ok(())
367 }
368
369 fn evaluate(&mut self) -> Result<ScalarValue> {
370 let v = (self.count != 0).then_some(self.sum);
371 ScalarValue::new_primitive::<T>(v, &self.data_type)
372 }
373
374 fn size(&self) -> usize {
375 size_of_val(self)
376 }
377
378 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
379 let values = values[0].as_primitive::<T>();
380 if let Some(x) = arrow::compute::sum(values) {
381 self.sum = self.sum.sub_wrapping(x)
382 }
383 self.count -= (values.len() - values.null_count()) as u64;
384 Ok(())
385 }
386
387 fn supports_retract_batch(&self) -> bool {
388 true
389 }
390}
391
392struct DistinctSumAccumulator<T: ArrowPrimitiveType> {
393 values: HashSet<Hashable<T::Native>, RandomState>,
394 data_type: DataType,
395}
396
397impl<T: ArrowPrimitiveType> std::fmt::Debug for DistinctSumAccumulator<T> {
398 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
399 write!(f, "DistinctSumAccumulator({})", self.data_type)
400 }
401}
402
403impl<T: ArrowPrimitiveType> DistinctSumAccumulator<T> {
404 pub fn try_new(data_type: &DataType) -> Result<Self> {
405 Ok(Self {
406 values: HashSet::default(),
407 data_type: data_type.clone(),
408 })
409 }
410}
411
412impl<T: ArrowPrimitiveType> Accumulator for DistinctSumAccumulator<T> {
413 fn state(&mut self) -> Result<Vec<ScalarValue>> {
414 let state_out = {
417 let distinct_values = self
418 .values
419 .iter()
420 .map(|value| {
421 ScalarValue::new_primitive::<T>(Some(value.0), &self.data_type)
422 })
423 .collect::<Result<Vec<_>>>()?;
424
425 vec![ScalarValue::List(ScalarValue::new_list_nullable(
426 &distinct_values,
427 &self.data_type,
428 ))]
429 };
430 Ok(state_out)
431 }
432
433 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
434 if values.is_empty() {
435 return Ok(());
436 }
437
438 let array = values[0].as_primitive::<T>();
439 match array.nulls().filter(|x| x.null_count() > 0) {
440 Some(n) => {
441 for idx in n.valid_indices() {
442 self.values.insert(Hashable(array.value(idx)));
443 }
444 }
445 None => array.values().iter().for_each(|x| {
446 self.values.insert(Hashable(*x));
447 }),
448 }
449 Ok(())
450 }
451
452 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
453 for x in states[0].as_list::<i32>().iter().flatten() {
454 self.update_batch(&[x])?
455 }
456 Ok(())
457 }
458
459 fn evaluate(&mut self) -> Result<ScalarValue> {
460 let mut acc = T::Native::usize_as(0);
461 for distinct_value in self.values.iter() {
462 acc = acc.add_wrapping(distinct_value.0)
463 }
464 let v = (!self.values.is_empty()).then_some(acc);
465 ScalarValue::new_primitive::<T>(v, &self.data_type)
466 }
467
468 fn size(&self) -> usize {
469 size_of_val(self) + self.values.capacity() * size_of::<T::Native>()
470 }
471}