datafusion_functions_aggregate/
average.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! Defines the `Avg` aggregate (also registered under the alias `mean`) and its accumulators
19
20use arrow::array::{
21    Array, ArrayRef, ArrowNativeTypeOp, ArrowNumericType, ArrowPrimitiveType, AsArray,
22    BooleanArray, PrimitiveArray, PrimitiveBuilder, UInt64Array,
23};
24
25use arrow::compute::sum;
26use arrow::datatypes::{
27    i256, ArrowNativeType, DataType, Decimal128Type, Decimal256Type, DecimalType,
28    DurationMicrosecondType, DurationMillisecondType, DurationNanosecondType,
29    DurationSecondType, Field, FieldRef, Float64Type, TimeUnit, UInt64Type,
30};
31use datafusion_common::{
32    exec_err, not_impl_err, utils::take_function_args, Result, ScalarValue,
33};
34use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
35use datafusion_expr::type_coercion::aggregates::{avg_return_type, coerce_avg_type};
36use datafusion_expr::utils::format_state_name;
37use datafusion_expr::Volatility::Immutable;
38use datafusion_expr::{
39    Accumulator, AggregateUDFImpl, Documentation, EmitTo, GroupsAccumulator,
40    ReversedUDAF, Signature,
41};
42
43use datafusion_functions_aggregate_common::aggregate::avg_distinct::Float64DistinctAvgAccumulator;
44use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::NullState;
45use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::{
46    filtered_null_mask, set_nulls,
47};
48
49use datafusion_functions_aggregate_common::utils::DecimalAverager;
50use datafusion_macros::user_doc;
51use log::debug;
52use std::any::Any;
53use std::fmt::Debug;
54use std::mem::{size_of, size_of_val};
55use std::sync::Arc;
56
// Registers the `Avg` UDAF with the crate's `make_udaf_expr_and_func!` macro,
// naming the expression helper `avg` (taking one `expression` argument) and the
// UDAF accessor `avg_udaf`. The string is the helper's documentation text.
make_udaf_expr_and_func!(
    Avg,
    avg,
    expression,
    "Returns the avg of a group of values.",
    avg_udaf
);
64
65#[user_doc(
66    doc_section(label = "General Functions"),
67    description = "Returns the average of numeric values in the specified column.",
68    syntax_example = "avg(expression)",
69    sql_example = r#"```sql
70> SELECT avg(column_name) FROM table_name;
71+---------------------------+
72| avg(column_name)           |
73+---------------------------+
74| 42.75                      |
75+---------------------------+
76```"#,
77    standard_argument(name = "expression",)
78)]
79#[derive(Debug, PartialEq, Eq, Hash)]
80pub struct Avg {
81    signature: Signature,
82    aliases: Vec<String>,
83}
84
85impl Avg {
86    pub fn new() -> Self {
87        Self {
88            signature: Signature::user_defined(Immutable),
89            aliases: vec![String::from("mean")],
90        }
91    }
92}
93
94impl Default for Avg {
95    fn default() -> Self {
96        Self::new()
97    }
98}
99
100impl AggregateUDFImpl for Avg {
101    fn as_any(&self) -> &dyn Any {
102        self
103    }
104
105    fn name(&self) -> &str {
106        "avg"
107    }
108
109    fn signature(&self) -> &Signature {
110        &self.signature
111    }
112
113    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
114        avg_return_type(self.name(), &arg_types[0])
115    }
116
117    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
118        let data_type = acc_args.exprs[0].data_type(acc_args.schema)?;
119        use DataType::*;
120
121        // instantiate specialized accumulator based for the type
122        if acc_args.is_distinct {
123            match &data_type {
124                // Numeric types are converted to Float64 via `coerce_avg_type` during logical plan creation
125                Float64 => Ok(Box::new(Float64DistinctAvgAccumulator::default())),
126                _ => exec_err!("AVG(DISTINCT) for {} not supported", data_type),
127            }
128        } else {
129            match (&data_type, acc_args.return_field.data_type()) {
130                (Float64, Float64) => Ok(Box::<AvgAccumulator>::default()),
131                (
132                    Decimal128(sum_precision, sum_scale),
133                    Decimal128(target_precision, target_scale),
134                ) => Ok(Box::new(DecimalAvgAccumulator::<Decimal128Type> {
135                    sum: None,
136                    count: 0,
137                    sum_scale: *sum_scale,
138                    sum_precision: *sum_precision,
139                    target_precision: *target_precision,
140                    target_scale: *target_scale,
141                })),
142
143                (
144                    Decimal256(sum_precision, sum_scale),
145                    Decimal256(target_precision, target_scale),
146                ) => Ok(Box::new(DecimalAvgAccumulator::<Decimal256Type> {
147                    sum: None,
148                    count: 0,
149                    sum_scale: *sum_scale,
150                    sum_precision: *sum_precision,
151                    target_precision: *target_precision,
152                    target_scale: *target_scale,
153                })),
154
155                (Duration(time_unit), Duration(result_unit)) => {
156                    Ok(Box::new(DurationAvgAccumulator {
157                        sum: None,
158                        count: 0,
159                        time_unit: *time_unit,
160                        result_unit: *result_unit,
161                    }))
162                }
163
164                _ => exec_err!(
165                    "AvgAccumulator for ({} --> {})",
166                    &data_type,
167                    acc_args.return_field.data_type()
168                ),
169            }
170        }
171    }
172
173    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
174        if args.is_distinct {
175            // Copied from datafusion_functions_aggregate::sum::Sum::state_fields
176            // since the accumulator uses DistinctSumAccumulator internally.
177            Ok(vec![Field::new_list(
178                format_state_name(args.name, "avg distinct"),
179                Field::new_list_field(args.return_type().clone(), true),
180                false,
181            )
182            .into()])
183        } else {
184            Ok(vec![
185                Field::new(
186                    format_state_name(args.name, "count"),
187                    DataType::UInt64,
188                    true,
189                ),
190                Field::new(
191                    format_state_name(args.name, "sum"),
192                    args.input_fields[0].data_type().clone(),
193                    true,
194                ),
195            ]
196            .into_iter()
197            .map(Arc::new)
198            .collect())
199        }
200    }
201
202    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
203        matches!(
204            args.return_field.data_type(),
205            DataType::Float64 | DataType::Decimal128(_, _) | DataType::Duration(_)
206        ) && !args.is_distinct
207    }
208
209    fn create_groups_accumulator(
210        &self,
211        args: AccumulatorArgs,
212    ) -> Result<Box<dyn GroupsAccumulator>> {
213        use DataType::*;
214
215        let data_type = args.exprs[0].data_type(args.schema)?;
216        // instantiate specialized accumulator based for the type
217        match (&data_type, args.return_field.data_type()) {
218            (Float64, Float64) => {
219                Ok(Box::new(AvgGroupsAccumulator::<Float64Type, _>::new(
220                    &data_type,
221                    args.return_field.data_type(),
222                    |sum: f64, count: u64| Ok(sum / count as f64),
223                )))
224            }
225            (
226                Decimal128(_sum_precision, sum_scale),
227                Decimal128(target_precision, target_scale),
228            ) => {
229                let decimal_averager = DecimalAverager::<Decimal128Type>::try_new(
230                    *sum_scale,
231                    *target_precision,
232                    *target_scale,
233                )?;
234
235                let avg_fn =
236                    move |sum: i128, count: u64| decimal_averager.avg(sum, count as i128);
237
238                Ok(Box::new(AvgGroupsAccumulator::<Decimal128Type, _>::new(
239                    &data_type,
240                    args.return_field.data_type(),
241                    avg_fn,
242                )))
243            }
244
245            (
246                Decimal256(_sum_precision, sum_scale),
247                Decimal256(target_precision, target_scale),
248            ) => {
249                let decimal_averager = DecimalAverager::<Decimal256Type>::try_new(
250                    *sum_scale,
251                    *target_precision,
252                    *target_scale,
253                )?;
254
255                let avg_fn = move |sum: i256, count: u64| {
256                    decimal_averager.avg(sum, i256::from_usize(count as usize).unwrap())
257                };
258
259                Ok(Box::new(AvgGroupsAccumulator::<Decimal256Type, _>::new(
260                    &data_type,
261                    args.return_field.data_type(),
262                    avg_fn,
263                )))
264            }
265
266            (Duration(time_unit), Duration(_result_unit)) => {
267                let avg_fn = move |sum: i64, count: u64| Ok(sum / count as i64);
268
269                match time_unit {
270                    TimeUnit::Second => Ok(Box::new(AvgGroupsAccumulator::<
271                        DurationSecondType,
272                        _,
273                    >::new(
274                        &data_type,
275                        args.return_type(),
276                        avg_fn,
277                    ))),
278                    TimeUnit::Millisecond => Ok(Box::new(AvgGroupsAccumulator::<
279                        DurationMillisecondType,
280                        _,
281                    >::new(
282                        &data_type,
283                        args.return_type(),
284                        avg_fn,
285                    ))),
286                    TimeUnit::Microsecond => Ok(Box::new(AvgGroupsAccumulator::<
287                        DurationMicrosecondType,
288                        _,
289                    >::new(
290                        &data_type,
291                        args.return_type(),
292                        avg_fn,
293                    ))),
294                    TimeUnit::Nanosecond => Ok(Box::new(AvgGroupsAccumulator::<
295                        DurationNanosecondType,
296                        _,
297                    >::new(
298                        &data_type,
299                        args.return_type(),
300                        avg_fn,
301                    ))),
302                }
303            }
304
305            _ => not_impl_err!(
306                "AvgGroupsAccumulator for ({} --> {})",
307                &data_type,
308                args.return_field.data_type()
309            ),
310        }
311    }
312
313    fn aliases(&self) -> &[String] {
314        &self.aliases
315    }
316
317    fn reverse_expr(&self) -> ReversedUDAF {
318        ReversedUDAF::Identical
319    }
320
321    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
322        let [args] = take_function_args(self.name(), arg_types)?;
323        coerce_avg_type(self.name(), std::slice::from_ref(args))
324    }
325
326    fn documentation(&self) -> Option<&Documentation> {
327        self.doc()
328    }
329}
330
331/// An accumulator to compute the average
332#[derive(Debug, Default)]
333pub struct AvgAccumulator {
334    sum: Option<f64>,
335    count: u64,
336}
337
338impl Accumulator for AvgAccumulator {
339    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
340        let values = values[0].as_primitive::<Float64Type>();
341        self.count += (values.len() - values.null_count()) as u64;
342        if let Some(x) = sum(values) {
343            let v = self.sum.get_or_insert(0.);
344            *v += x;
345        }
346        Ok(())
347    }
348
349    fn evaluate(&mut self) -> Result<ScalarValue> {
350        Ok(ScalarValue::Float64(
351            self.sum.map(|f| f / self.count as f64),
352        ))
353    }
354
355    fn size(&self) -> usize {
356        size_of_val(self)
357    }
358
359    fn state(&mut self) -> Result<Vec<ScalarValue>> {
360        Ok(vec![
361            ScalarValue::from(self.count),
362            ScalarValue::Float64(self.sum),
363        ])
364    }
365
366    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
367        // counts are summed
368        self.count += sum(states[0].as_primitive::<UInt64Type>()).unwrap_or_default();
369
370        // sums are summed
371        if let Some(x) = sum(states[1].as_primitive::<Float64Type>()) {
372            let v = self.sum.get_or_insert(0.);
373            *v += x;
374        }
375        Ok(())
376    }
377    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
378        let values = values[0].as_primitive::<Float64Type>();
379        self.count -= (values.len() - values.null_count()) as u64;
380        if let Some(x) = sum(values) {
381            self.sum = Some(self.sum.unwrap() - x);
382        }
383        Ok(())
384    }
385
386    fn supports_retract_batch(&self) -> bool {
387        true
388    }
389}
390
391/// An accumulator to compute the average for decimals
392#[derive(Debug)]
393struct DecimalAvgAccumulator<T: DecimalType + ArrowNumericType + Debug> {
394    sum: Option<T::Native>,
395    count: u64,
396    sum_scale: i8,
397    sum_precision: u8,
398    target_precision: u8,
399    target_scale: i8,
400}
401
402impl<T: DecimalType + ArrowNumericType + Debug> Accumulator for DecimalAvgAccumulator<T> {
403    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
404        let values = values[0].as_primitive::<T>();
405        self.count += (values.len() - values.null_count()) as u64;
406
407        if let Some(x) = sum(values) {
408            let v = self.sum.get_or_insert(T::Native::default());
409            self.sum = Some(v.add_wrapping(x));
410        }
411        Ok(())
412    }
413
414    fn evaluate(&mut self) -> Result<ScalarValue> {
415        let v = self
416            .sum
417            .map(|v| {
418                DecimalAverager::<T>::try_new(
419                    self.sum_scale,
420                    self.target_precision,
421                    self.target_scale,
422                )?
423                .avg(v, T::Native::from_usize(self.count as usize).unwrap())
424            })
425            .transpose()?;
426
427        ScalarValue::new_primitive::<T>(
428            v,
429            &T::TYPE_CONSTRUCTOR(self.target_precision, self.target_scale),
430        )
431    }
432
433    fn size(&self) -> usize {
434        size_of_val(self)
435    }
436
437    fn state(&mut self) -> Result<Vec<ScalarValue>> {
438        Ok(vec![
439            ScalarValue::from(self.count),
440            ScalarValue::new_primitive::<T>(
441                self.sum,
442                &T::TYPE_CONSTRUCTOR(self.sum_precision, self.sum_scale),
443            )?,
444        ])
445    }
446
447    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
448        // counts are summed
449        self.count += sum(states[0].as_primitive::<UInt64Type>()).unwrap_or_default();
450
451        // sums are summed
452        if let Some(x) = sum(states[1].as_primitive::<T>()) {
453            let v = self.sum.get_or_insert(T::Native::default());
454            self.sum = Some(v.add_wrapping(x));
455        }
456        Ok(())
457    }
458    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
459        let values = values[0].as_primitive::<T>();
460        self.count -= (values.len() - values.null_count()) as u64;
461        if let Some(x) = sum(values) {
462            self.sum = Some(self.sum.unwrap().sub_wrapping(x));
463        }
464        Ok(())
465    }
466
467    fn supports_retract_batch(&self) -> bool {
468        true
469    }
470}
471
472/// An accumulator to compute the average for duration values
473#[derive(Debug)]
474struct DurationAvgAccumulator {
475    sum: Option<i64>,
476    count: u64,
477    time_unit: TimeUnit,
478    result_unit: TimeUnit,
479}
480
481impl Accumulator for DurationAvgAccumulator {
482    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
483        let array = &values[0];
484        self.count += (array.len() - array.null_count()) as u64;
485
486        let sum_value = match self.time_unit {
487            TimeUnit::Second => sum(array.as_primitive::<DurationSecondType>()),
488            TimeUnit::Millisecond => sum(array.as_primitive::<DurationMillisecondType>()),
489            TimeUnit::Microsecond => sum(array.as_primitive::<DurationMicrosecondType>()),
490            TimeUnit::Nanosecond => sum(array.as_primitive::<DurationNanosecondType>()),
491        };
492
493        if let Some(x) = sum_value {
494            let v = self.sum.get_or_insert(0);
495            *v += x;
496        }
497        Ok(())
498    }
499
500    fn evaluate(&mut self) -> Result<ScalarValue> {
501        let avg = self.sum.map(|sum| sum / self.count as i64);
502
503        match self.result_unit {
504            TimeUnit::Second => Ok(ScalarValue::DurationSecond(avg)),
505            TimeUnit::Millisecond => Ok(ScalarValue::DurationMillisecond(avg)),
506            TimeUnit::Microsecond => Ok(ScalarValue::DurationMicrosecond(avg)),
507            TimeUnit::Nanosecond => Ok(ScalarValue::DurationNanosecond(avg)),
508        }
509    }
510
511    fn size(&self) -> usize {
512        size_of_val(self)
513    }
514
515    fn state(&mut self) -> Result<Vec<ScalarValue>> {
516        let duration_value = match self.time_unit {
517            TimeUnit::Second => ScalarValue::DurationSecond(self.sum),
518            TimeUnit::Millisecond => ScalarValue::DurationMillisecond(self.sum),
519            TimeUnit::Microsecond => ScalarValue::DurationMicrosecond(self.sum),
520            TimeUnit::Nanosecond => ScalarValue::DurationNanosecond(self.sum),
521        };
522
523        Ok(vec![ScalarValue::from(self.count), duration_value])
524    }
525
526    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
527        self.count += sum(states[0].as_primitive::<UInt64Type>()).unwrap_or_default();
528
529        let sum_value = match self.time_unit {
530            TimeUnit::Second => sum(states[1].as_primitive::<DurationSecondType>()),
531            TimeUnit::Millisecond => {
532                sum(states[1].as_primitive::<DurationMillisecondType>())
533            }
534            TimeUnit::Microsecond => {
535                sum(states[1].as_primitive::<DurationMicrosecondType>())
536            }
537            TimeUnit::Nanosecond => {
538                sum(states[1].as_primitive::<DurationNanosecondType>())
539            }
540        };
541
542        if let Some(x) = sum_value {
543            let v = self.sum.get_or_insert(0);
544            *v += x;
545        }
546        Ok(())
547    }
548
549    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
550        let array = &values[0];
551        self.count -= (array.len() - array.null_count()) as u64;
552
553        let sum_value = match self.time_unit {
554            TimeUnit::Second => sum(array.as_primitive::<DurationSecondType>()),
555            TimeUnit::Millisecond => sum(array.as_primitive::<DurationMillisecondType>()),
556            TimeUnit::Microsecond => sum(array.as_primitive::<DurationMicrosecondType>()),
557            TimeUnit::Nanosecond => sum(array.as_primitive::<DurationNanosecondType>()),
558        };
559
560        if let Some(x) = sum_value {
561            self.sum = Some(self.sum.unwrap() - x);
562        }
563        Ok(())
564    }
565
566    fn supports_retract_batch(&self) -> bool {
567        true
568    }
569}
570
571/// An accumulator to compute the average of `[PrimitiveArray<T>]`.
572/// Stores values as native types, and does overflow checking
573///
574/// F: Function that calculates the average value from a sum of
575/// T::Native and a total count
576#[derive(Debug)]
577struct AvgGroupsAccumulator<T, F>
578where
579    T: ArrowNumericType + Send,
580    F: Fn(T::Native, u64) -> Result<T::Native> + Send,
581{
582    /// The type of the internal sum
583    sum_data_type: DataType,
584
585    /// The type of the returned sum
586    return_data_type: DataType,
587
588    /// Count per group (use u64 to make UInt64Array)
589    counts: Vec<u64>,
590
591    /// Sums per group, stored as the native type
592    sums: Vec<T::Native>,
593
594    /// Track nulls in the input / filters
595    null_state: NullState,
596
597    /// Function that computes the final average (value / count)
598    avg_fn: F,
599}
600
601impl<T, F> AvgGroupsAccumulator<T, F>
602where
603    T: ArrowNumericType + Send,
604    F: Fn(T::Native, u64) -> Result<T::Native> + Send,
605{
606    pub fn new(sum_data_type: &DataType, return_data_type: &DataType, avg_fn: F) -> Self {
607        debug!(
608            "AvgGroupsAccumulator ({}, sum type: {sum_data_type:?}) --> {return_data_type:?}",
609            std::any::type_name::<T>()
610        );
611
612        Self {
613            return_data_type: return_data_type.clone(),
614            sum_data_type: sum_data_type.clone(),
615            counts: vec![],
616            sums: vec![],
617            null_state: NullState::new(),
618            avg_fn,
619        }
620    }
621}
622
623impl<T, F> GroupsAccumulator for AvgGroupsAccumulator<T, F>
624where
625    T: ArrowNumericType + Send,
626    F: Fn(T::Native, u64) -> Result<T::Native> + Send,
627{
628    fn update_batch(
629        &mut self,
630        values: &[ArrayRef],
631        group_indices: &[usize],
632        opt_filter: Option<&BooleanArray>,
633        total_num_groups: usize,
634    ) -> Result<()> {
635        assert_eq!(values.len(), 1, "single argument to update_batch");
636        let values = values[0].as_primitive::<T>();
637
638        // increment counts, update sums
639        self.counts.resize(total_num_groups, 0);
640        self.sums.resize(total_num_groups, T::default_value());
641        self.null_state.accumulate(
642            group_indices,
643            values,
644            opt_filter,
645            total_num_groups,
646            |group_index, new_value| {
647                let sum = &mut self.sums[group_index];
648                *sum = sum.add_wrapping(new_value);
649
650                self.counts[group_index] += 1;
651            },
652        );
653
654        Ok(())
655    }
656
657    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
658        let counts = emit_to.take_needed(&mut self.counts);
659        let sums = emit_to.take_needed(&mut self.sums);
660        let nulls = self.null_state.build(emit_to);
661
662        assert_eq!(nulls.len(), sums.len());
663        assert_eq!(counts.len(), sums.len());
664
665        // don't evaluate averages with null inputs to avoid errors on null values
666
667        let array: PrimitiveArray<T> = if nulls.null_count() > 0 {
668            let mut builder = PrimitiveBuilder::<T>::with_capacity(nulls.len())
669                .with_data_type(self.return_data_type.clone());
670            let iter = sums.into_iter().zip(counts).zip(nulls.iter());
671
672            for ((sum, count), is_valid) in iter {
673                if is_valid {
674                    builder.append_value((self.avg_fn)(sum, count)?)
675                } else {
676                    builder.append_null();
677                }
678            }
679            builder.finish()
680        } else {
681            let averages: Vec<T::Native> = sums
682                .into_iter()
683                .zip(counts.into_iter())
684                .map(|(sum, count)| (self.avg_fn)(sum, count))
685                .collect::<Result<Vec<_>>>()?;
686            PrimitiveArray::new(averages.into(), Some(nulls)) // no copy
687                .with_data_type(self.return_data_type.clone())
688        };
689
690        Ok(Arc::new(array))
691    }
692
693    // return arrays for sums and counts
694    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
695        let nulls = self.null_state.build(emit_to);
696        let nulls = Some(nulls);
697
698        let counts = emit_to.take_needed(&mut self.counts);
699        let counts = UInt64Array::new(counts.into(), nulls.clone()); // zero copy
700
701        let sums = emit_to.take_needed(&mut self.sums);
702        let sums = PrimitiveArray::<T>::new(sums.into(), nulls) // zero copy
703            .with_data_type(self.sum_data_type.clone());
704
705        Ok(vec![
706            Arc::new(counts) as ArrayRef,
707            Arc::new(sums) as ArrayRef,
708        ])
709    }
710
711    fn merge_batch(
712        &mut self,
713        values: &[ArrayRef],
714        group_indices: &[usize],
715        opt_filter: Option<&BooleanArray>,
716        total_num_groups: usize,
717    ) -> Result<()> {
718        assert_eq!(values.len(), 2, "two arguments to merge_batch");
719        // first batch is counts, second is partial sums
720        let partial_counts = values[0].as_primitive::<UInt64Type>();
721        let partial_sums = values[1].as_primitive::<T>();
722        // update counts with partial counts
723        self.counts.resize(total_num_groups, 0);
724        self.null_state.accumulate(
725            group_indices,
726            partial_counts,
727            opt_filter,
728            total_num_groups,
729            |group_index, partial_count| {
730                self.counts[group_index] += partial_count;
731            },
732        );
733
734        // update sums
735        self.sums.resize(total_num_groups, T::default_value());
736        self.null_state.accumulate(
737            group_indices,
738            partial_sums,
739            opt_filter,
740            total_num_groups,
741            |group_index, new_value: <T as ArrowPrimitiveType>::Native| {
742                let sum = &mut self.sums[group_index];
743                *sum = sum.add_wrapping(new_value);
744            },
745        );
746
747        Ok(())
748    }
749
750    fn convert_to_state(
751        &self,
752        values: &[ArrayRef],
753        opt_filter: Option<&BooleanArray>,
754    ) -> Result<Vec<ArrayRef>> {
755        let sums = values[0]
756            .as_primitive::<T>()
757            .clone()
758            .with_data_type(self.sum_data_type.clone());
759        let counts = UInt64Array::from_value(1, sums.len());
760
761        let nulls = filtered_null_mask(opt_filter, &sums);
762
763        // set nulls on the arrays
764        let counts = set_nulls(counts, nulls.clone());
765        let sums = set_nulls(sums, nulls);
766
767        Ok(vec![Arc::new(counts) as ArrayRef, Arc::new(sums)])
768    }
769
770    fn supports_convert_to_state(&self) -> bool {
771        true
772    }
773
774    fn size(&self) -> usize {
775        self.counts.capacity() * size_of::<u64>() + self.sums.capacity() * size_of::<T>()
776    }
777}