datafusion_functions_aggregate/
average.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! Defines the `Avg` aggregate function (also available under the alias `mean`) and its accumulators
19
20use arrow::array::{
21    Array, ArrayRef, ArrowNativeTypeOp, ArrowNumericType, ArrowPrimitiveType, AsArray,
22    BooleanArray, PrimitiveArray, PrimitiveBuilder, UInt64Array,
23};
24
25use arrow::compute::sum;
26use arrow::datatypes::{
27    i256, ArrowNativeType, DataType, Decimal128Type, Decimal256Type, DecimalType,
28    DurationMicrosecondType, DurationMillisecondType, DurationNanosecondType,
29    DurationSecondType, Field, FieldRef, Float64Type, TimeUnit, UInt64Type,
30    DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION,
31};
32use datafusion_common::{
33    exec_err, not_impl_err, utils::take_function_args, Result, ScalarValue,
34};
35use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
36use datafusion_expr::type_coercion::aggregates::{avg_return_type, coerce_avg_type};
37use datafusion_expr::utils::format_state_name;
38use datafusion_expr::Volatility::Immutable;
39use datafusion_expr::{
40    Accumulator, AggregateUDFImpl, Documentation, EmitTo, GroupsAccumulator,
41    ReversedUDAF, Signature,
42};
43
44use datafusion_functions_aggregate_common::aggregate::avg_distinct::{
45    DecimalDistinctAvgAccumulator, Float64DistinctAvgAccumulator,
46};
47use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::NullState;
48use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::{
49    filtered_null_mask, set_nulls,
50};
51
52use datafusion_functions_aggregate_common::utils::DecimalAverager;
53use datafusion_macros::user_doc;
54use log::debug;
55use std::any::Any;
56use std::fmt::Debug;
57use std::mem::{size_of, size_of_val};
58use std::sync::Arc;
59
// Registers `Avg` through the shared `make_udaf_expr_and_func!` macro.
// NOTE(review): the macro is defined outside this file; presumably it
// generates the `avg(expression)` expression-builder function and the
// `avg_udaf` accessor for the shared UDAF instance — confirm against the
// macro definition.
make_udaf_expr_and_func!(
    Avg,
    avg,
    expression,
    "Returns the avg of a group of values.",
    avg_udaf
);
67
/// The `avg` aggregate function, also invocable as `mean`.
///
/// Argument types are coerced by `coerce_types` (via `coerce_avg_type`), and
/// the result type is produced by `avg_return_type`.
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the average of numeric values in the specified column.",
    syntax_example = "avg(expression)",
    sql_example = r#"```sql
> SELECT avg(column_name) FROM table_name;
+---------------------------+
| avg(column_name)           |
+---------------------------+
| 42.75                      |
+---------------------------+
```"#,
    standard_argument(name = "expression",)
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct Avg {
    /// User-defined, immutable signature; argument types are resolved by `coerce_types`.
    signature: Signature,
    /// Alternative names for this function (built as `["mean"]` in `Avg::new`).
    aliases: Vec<String>,
}
87
88impl Avg {
89    pub fn new() -> Self {
90        Self {
91            signature: Signature::user_defined(Immutable),
92            aliases: vec![String::from("mean")],
93        }
94    }
95}
96
97impl Default for Avg {
98    fn default() -> Self {
99        Self::new()
100    }
101}
102
103impl AggregateUDFImpl for Avg {
104    fn as_any(&self) -> &dyn Any {
105        self
106    }
107
108    fn name(&self) -> &str {
109        "avg"
110    }
111
112    fn signature(&self) -> &Signature {
113        &self.signature
114    }
115
116    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
117        avg_return_type(self.name(), &arg_types[0])
118    }
119
120    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
121        let data_type = acc_args.exprs[0].data_type(acc_args.schema)?;
122        use DataType::*;
123
124        // instantiate specialized accumulator based for the type
125        if acc_args.is_distinct {
126            match (&data_type, acc_args.return_type()) {
127                // Numeric types are converted to Float64 via `coerce_avg_type` during logical plan creation
128                (Float64, _) => Ok(Box::new(Float64DistinctAvgAccumulator::default())),
129
130                (
131                    Decimal128(_, scale),
132                    Decimal128(target_precision, target_scale),
133                ) => Ok(Box::new(DecimalDistinctAvgAccumulator::<Decimal128Type>::with_decimal_params(
134                    *scale,
135                    *target_precision,
136                    *target_scale,
137                ))),
138
139                (
140                    Decimal256(_, scale),
141                    Decimal256(target_precision, target_scale),
142                ) => Ok(Box::new(DecimalDistinctAvgAccumulator::<Decimal256Type>::with_decimal_params(
143                    *scale,
144                    *target_precision,
145                    *target_scale,
146                ))),
147
148                (dt, return_type) => exec_err!(
149                    "AVG(DISTINCT) for ({} --> {}) not supported",
150                    dt,
151                    return_type
152                ),
153            }
154        } else {
155            match (&data_type, acc_args.return_type()) {
156                (Float64, Float64) => Ok(Box::<AvgAccumulator>::default()),
157                (
158                    Decimal128(sum_precision, sum_scale),
159                    Decimal128(target_precision, target_scale),
160                ) => Ok(Box::new(DecimalAvgAccumulator::<Decimal128Type> {
161                    sum: None,
162                    count: 0,
163                    sum_scale: *sum_scale,
164                    sum_precision: *sum_precision,
165                    target_precision: *target_precision,
166                    target_scale: *target_scale,
167                })),
168
169                (
170                    Decimal256(sum_precision, sum_scale),
171                    Decimal256(target_precision, target_scale),
172                ) => Ok(Box::new(DecimalAvgAccumulator::<Decimal256Type> {
173                    sum: None,
174                    count: 0,
175                    sum_scale: *sum_scale,
176                    sum_precision: *sum_precision,
177                    target_precision: *target_precision,
178                    target_scale: *target_scale,
179                })),
180
181                (Duration(time_unit), Duration(result_unit)) => {
182                    Ok(Box::new(DurationAvgAccumulator {
183                        sum: None,
184                        count: 0,
185                        time_unit: *time_unit,
186                        result_unit: *result_unit,
187                    }))
188                }
189
190                (dt, return_type) => {
191                    exec_err!("AvgAccumulator for ({} --> {})", dt, return_type)
192                }
193            }
194        }
195    }
196
197    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
198        if args.is_distinct {
199            // Decimal accumulator actually uses a different precision during accumulation,
200            // see DecimalDistinctAvgAccumulator::with_decimal_params
201            let dt = match args.input_fields[0].data_type() {
202                DataType::Decimal128(_, scale) => {
203                    DataType::Decimal128(DECIMAL128_MAX_PRECISION, *scale)
204                }
205                DataType::Decimal256(_, scale) => {
206                    DataType::Decimal256(DECIMAL256_MAX_PRECISION, *scale)
207                }
208                _ => args.return_type().clone(),
209            };
210            // Similar to datafusion_functions_aggregate::sum::Sum::state_fields
211            // since the accumulator uses DistinctSumAccumulator internally.
212            Ok(vec![Field::new_list(
213                format_state_name(args.name, "avg distinct"),
214                Field::new_list_field(dt, true),
215                false,
216            )
217            .into()])
218        } else {
219            Ok(vec![
220                Field::new(
221                    format_state_name(args.name, "count"),
222                    DataType::UInt64,
223                    true,
224                ),
225                Field::new(
226                    format_state_name(args.name, "sum"),
227                    args.input_fields[0].data_type().clone(),
228                    true,
229                ),
230            ]
231            .into_iter()
232            .map(Arc::new)
233            .collect())
234        }
235    }
236
237    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
238        matches!(
239            args.return_field.data_type(),
240            DataType::Float64 | DataType::Decimal128(_, _) | DataType::Duration(_)
241        ) && !args.is_distinct
242    }
243
244    fn create_groups_accumulator(
245        &self,
246        args: AccumulatorArgs,
247    ) -> Result<Box<dyn GroupsAccumulator>> {
248        use DataType::*;
249
250        let data_type = args.exprs[0].data_type(args.schema)?;
251        // instantiate specialized accumulator based for the type
252        match (&data_type, args.return_field.data_type()) {
253            (Float64, Float64) => {
254                Ok(Box::new(AvgGroupsAccumulator::<Float64Type, _>::new(
255                    &data_type,
256                    args.return_field.data_type(),
257                    |sum: f64, count: u64| Ok(sum / count as f64),
258                )))
259            }
260            (
261                Decimal128(_sum_precision, sum_scale),
262                Decimal128(target_precision, target_scale),
263            ) => {
264                let decimal_averager = DecimalAverager::<Decimal128Type>::try_new(
265                    *sum_scale,
266                    *target_precision,
267                    *target_scale,
268                )?;
269
270                let avg_fn =
271                    move |sum: i128, count: u64| decimal_averager.avg(sum, count as i128);
272
273                Ok(Box::new(AvgGroupsAccumulator::<Decimal128Type, _>::new(
274                    &data_type,
275                    args.return_field.data_type(),
276                    avg_fn,
277                )))
278            }
279
280            (
281                Decimal256(_sum_precision, sum_scale),
282                Decimal256(target_precision, target_scale),
283            ) => {
284                let decimal_averager = DecimalAverager::<Decimal256Type>::try_new(
285                    *sum_scale,
286                    *target_precision,
287                    *target_scale,
288                )?;
289
290                let avg_fn = move |sum: i256, count: u64| {
291                    decimal_averager.avg(sum, i256::from_usize(count as usize).unwrap())
292                };
293
294                Ok(Box::new(AvgGroupsAccumulator::<Decimal256Type, _>::new(
295                    &data_type,
296                    args.return_field.data_type(),
297                    avg_fn,
298                )))
299            }
300
301            (Duration(time_unit), Duration(_result_unit)) => {
302                let avg_fn = move |sum: i64, count: u64| Ok(sum / count as i64);
303
304                match time_unit {
305                    TimeUnit::Second => Ok(Box::new(AvgGroupsAccumulator::<
306                        DurationSecondType,
307                        _,
308                    >::new(
309                        &data_type,
310                        args.return_type(),
311                        avg_fn,
312                    ))),
313                    TimeUnit::Millisecond => Ok(Box::new(AvgGroupsAccumulator::<
314                        DurationMillisecondType,
315                        _,
316                    >::new(
317                        &data_type,
318                        args.return_type(),
319                        avg_fn,
320                    ))),
321                    TimeUnit::Microsecond => Ok(Box::new(AvgGroupsAccumulator::<
322                        DurationMicrosecondType,
323                        _,
324                    >::new(
325                        &data_type,
326                        args.return_type(),
327                        avg_fn,
328                    ))),
329                    TimeUnit::Nanosecond => Ok(Box::new(AvgGroupsAccumulator::<
330                        DurationNanosecondType,
331                        _,
332                    >::new(
333                        &data_type,
334                        args.return_type(),
335                        avg_fn,
336                    ))),
337                }
338            }
339
340            _ => not_impl_err!(
341                "AvgGroupsAccumulator for ({} --> {})",
342                &data_type,
343                args.return_field.data_type()
344            ),
345        }
346    }
347
348    fn aliases(&self) -> &[String] {
349        &self.aliases
350    }
351
352    fn reverse_expr(&self) -> ReversedUDAF {
353        ReversedUDAF::Identical
354    }
355
356    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
357        let [args] = take_function_args(self.name(), arg_types)?;
358        coerce_avg_type(self.name(), std::slice::from_ref(args))
359    }
360
361    fn documentation(&self) -> Option<&Documentation> {
362        self.doc()
363    }
364}
365
/// An accumulator to compute the average of `Float64` values.
///
/// Keeps a running sum and a count of non-null inputs; the division happens
/// only in `evaluate`.
#[derive(Debug, Default)]
pub struct AvgAccumulator {
    /// Running sum; `None` until some input or merged batch produced a sum.
    sum: Option<f64>,
    /// Number of non-null inputs accumulated, minus any retracted.
    count: u64,
}
372
373impl Accumulator for AvgAccumulator {
374    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
375        let values = values[0].as_primitive::<Float64Type>();
376        self.count += (values.len() - values.null_count()) as u64;
377        if let Some(x) = sum(values) {
378            let v = self.sum.get_or_insert(0.);
379            *v += x;
380        }
381        Ok(())
382    }
383
384    fn evaluate(&mut self) -> Result<ScalarValue> {
385        Ok(ScalarValue::Float64(
386            self.sum.map(|f| f / self.count as f64),
387        ))
388    }
389
390    fn size(&self) -> usize {
391        size_of_val(self)
392    }
393
394    fn state(&mut self) -> Result<Vec<ScalarValue>> {
395        Ok(vec![
396            ScalarValue::from(self.count),
397            ScalarValue::Float64(self.sum),
398        ])
399    }
400
401    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
402        // counts are summed
403        self.count += sum(states[0].as_primitive::<UInt64Type>()).unwrap_or_default();
404
405        // sums are summed
406        if let Some(x) = sum(states[1].as_primitive::<Float64Type>()) {
407            let v = self.sum.get_or_insert(0.);
408            *v += x;
409        }
410        Ok(())
411    }
412    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
413        let values = values[0].as_primitive::<Float64Type>();
414        self.count -= (values.len() - values.null_count()) as u64;
415        if let Some(x) = sum(values) {
416            self.sum = Some(self.sum.unwrap() - x);
417        }
418        Ok(())
419    }
420
421    fn supports_retract_batch(&self) -> bool {
422        true
423    }
424}
425
/// An accumulator to compute the average for decimals
///
/// The running sum is kept in the input decimal type (`sum_precision`,
/// `sum_scale`). In `evaluate`, a `DecimalAverager` built from `sum_scale`,
/// `target_precision` and `target_scale` turns the sum and count into a value
/// of the result type (`target_precision`, `target_scale`).
#[derive(Debug)]
struct DecimalAvgAccumulator<T: DecimalType + ArrowNumericType + Debug> {
    /// Running sum (wrapping arithmetic); `None` until some batch produced a sum.
    sum: Option<T::Native>,
    /// Number of non-null inputs accumulated, minus any retracted.
    count: u64,
    /// Scale of the input / partial-sum decimal type.
    sum_scale: i8,
    /// Precision of the input / partial-sum decimal type (used for the state value).
    sum_precision: u8,
    /// Precision of the result decimal type.
    target_precision: u8,
    /// Scale of the result decimal type.
    target_scale: i8,
}
436
437impl<T: DecimalType + ArrowNumericType + Debug> Accumulator for DecimalAvgAccumulator<T> {
438    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
439        let values = values[0].as_primitive::<T>();
440        self.count += (values.len() - values.null_count()) as u64;
441
442        if let Some(x) = sum(values) {
443            let v = self.sum.get_or_insert(T::Native::default());
444            self.sum = Some(v.add_wrapping(x));
445        }
446        Ok(())
447    }
448
449    fn evaluate(&mut self) -> Result<ScalarValue> {
450        let v = self
451            .sum
452            .map(|v| {
453                DecimalAverager::<T>::try_new(
454                    self.sum_scale,
455                    self.target_precision,
456                    self.target_scale,
457                )?
458                .avg(v, T::Native::from_usize(self.count as usize).unwrap())
459            })
460            .transpose()?;
461
462        ScalarValue::new_primitive::<T>(
463            v,
464            &T::TYPE_CONSTRUCTOR(self.target_precision, self.target_scale),
465        )
466    }
467
468    fn size(&self) -> usize {
469        size_of_val(self)
470    }
471
472    fn state(&mut self) -> Result<Vec<ScalarValue>> {
473        Ok(vec![
474            ScalarValue::from(self.count),
475            ScalarValue::new_primitive::<T>(
476                self.sum,
477                &T::TYPE_CONSTRUCTOR(self.sum_precision, self.sum_scale),
478            )?,
479        ])
480    }
481
482    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
483        // counts are summed
484        self.count += sum(states[0].as_primitive::<UInt64Type>()).unwrap_or_default();
485
486        // sums are summed
487        if let Some(x) = sum(states[1].as_primitive::<T>()) {
488            let v = self.sum.get_or_insert(T::Native::default());
489            self.sum = Some(v.add_wrapping(x));
490        }
491        Ok(())
492    }
493    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
494        let values = values[0].as_primitive::<T>();
495        self.count -= (values.len() - values.null_count()) as u64;
496        if let Some(x) = sum(values) {
497            self.sum = Some(self.sum.unwrap().sub_wrapping(x));
498        }
499        Ok(())
500    }
501
502    fn supports_retract_batch(&self) -> bool {
503        true
504    }
505}
506
/// An accumulator to compute the average for duration values
///
/// The sum and count are kept as `i64` / `u64` in `time_unit`. `evaluate`
/// divides them and wraps the integer result in a scalar of `result_unit`.
#[derive(Debug)]
struct DurationAvgAccumulator {
    /// Running sum of non-null inputs, in `time_unit`; `None` until some batch produced a sum.
    sum: Option<i64>,
    /// Number of non-null inputs accumulated, minus any retracted.
    count: u64,
    /// Unit of the input values and of the partial sum emitted by `state`.
    time_unit: TimeUnit,
    /// Unit used to label the result scalar. NOTE(review): `evaluate` does
    /// not convert between units, so this presumably always equals
    /// `time_unit` — confirm against `avg_return_type`.
    result_unit: TimeUnit,
}
515
516impl Accumulator for DurationAvgAccumulator {
517    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
518        let array = &values[0];
519        self.count += (array.len() - array.null_count()) as u64;
520
521        let sum_value = match self.time_unit {
522            TimeUnit::Second => sum(array.as_primitive::<DurationSecondType>()),
523            TimeUnit::Millisecond => sum(array.as_primitive::<DurationMillisecondType>()),
524            TimeUnit::Microsecond => sum(array.as_primitive::<DurationMicrosecondType>()),
525            TimeUnit::Nanosecond => sum(array.as_primitive::<DurationNanosecondType>()),
526        };
527
528        if let Some(x) = sum_value {
529            let v = self.sum.get_or_insert(0);
530            *v += x;
531        }
532        Ok(())
533    }
534
535    fn evaluate(&mut self) -> Result<ScalarValue> {
536        let avg = self.sum.map(|sum| sum / self.count as i64);
537
538        match self.result_unit {
539            TimeUnit::Second => Ok(ScalarValue::DurationSecond(avg)),
540            TimeUnit::Millisecond => Ok(ScalarValue::DurationMillisecond(avg)),
541            TimeUnit::Microsecond => Ok(ScalarValue::DurationMicrosecond(avg)),
542            TimeUnit::Nanosecond => Ok(ScalarValue::DurationNanosecond(avg)),
543        }
544    }
545
546    fn size(&self) -> usize {
547        size_of_val(self)
548    }
549
550    fn state(&mut self) -> Result<Vec<ScalarValue>> {
551        let duration_value = match self.time_unit {
552            TimeUnit::Second => ScalarValue::DurationSecond(self.sum),
553            TimeUnit::Millisecond => ScalarValue::DurationMillisecond(self.sum),
554            TimeUnit::Microsecond => ScalarValue::DurationMicrosecond(self.sum),
555            TimeUnit::Nanosecond => ScalarValue::DurationNanosecond(self.sum),
556        };
557
558        Ok(vec![ScalarValue::from(self.count), duration_value])
559    }
560
561    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
562        self.count += sum(states[0].as_primitive::<UInt64Type>()).unwrap_or_default();
563
564        let sum_value = match self.time_unit {
565            TimeUnit::Second => sum(states[1].as_primitive::<DurationSecondType>()),
566            TimeUnit::Millisecond => {
567                sum(states[1].as_primitive::<DurationMillisecondType>())
568            }
569            TimeUnit::Microsecond => {
570                sum(states[1].as_primitive::<DurationMicrosecondType>())
571            }
572            TimeUnit::Nanosecond => {
573                sum(states[1].as_primitive::<DurationNanosecondType>())
574            }
575        };
576
577        if let Some(x) = sum_value {
578            let v = self.sum.get_or_insert(0);
579            *v += x;
580        }
581        Ok(())
582    }
583
584    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
585        let array = &values[0];
586        self.count -= (array.len() - array.null_count()) as u64;
587
588        let sum_value = match self.time_unit {
589            TimeUnit::Second => sum(array.as_primitive::<DurationSecondType>()),
590            TimeUnit::Millisecond => sum(array.as_primitive::<DurationMillisecondType>()),
591            TimeUnit::Microsecond => sum(array.as_primitive::<DurationMicrosecondType>()),
592            TimeUnit::Nanosecond => sum(array.as_primitive::<DurationNanosecondType>()),
593        };
594
595        if let Some(x) = sum_value {
596            self.sum = Some(self.sum.unwrap() - x);
597        }
598        Ok(())
599    }
600
601    fn supports_retract_batch(&self) -> bool {
602        true
603    }
604}
605
/// An accumulator to compute the average of `[PrimitiveArray<T>]`.
/// Stores per-group sums as native types. The sums are updated with wrapping
/// arithmetic (`add_wrapping`), so overflow is NOT detected while summing.
///
/// F: Function that calculates the average value from a sum of
/// T::Native and a total count
#[derive(Debug)]
struct AvgGroupsAccumulator<T, F>
where
    T: ArrowNumericType + Send,
    F: Fn(T::Native, u64) -> Result<T::Native> + Send,
{
    /// The type of the internal sum (also the type of the partial-sum state column)
    sum_data_type: DataType,

    /// The data type of the emitted averages
    return_data_type: DataType,

    /// Count per group (use u64 to make UInt64Array)
    counts: Vec<u64>,

    /// Sums per group, stored as the native type
    sums: Vec<T::Native>,

    /// Track nulls in the input / filters
    null_state: NullState,

    /// Function that computes the final average (value / count)
    avg_fn: F,
}
635
636impl<T, F> AvgGroupsAccumulator<T, F>
637where
638    T: ArrowNumericType + Send,
639    F: Fn(T::Native, u64) -> Result<T::Native> + Send,
640{
641    pub fn new(sum_data_type: &DataType, return_data_type: &DataType, avg_fn: F) -> Self {
642        debug!(
643            "AvgGroupsAccumulator ({}, sum type: {sum_data_type:?}) --> {return_data_type:?}",
644            std::any::type_name::<T>()
645        );
646
647        Self {
648            return_data_type: return_data_type.clone(),
649            sum_data_type: sum_data_type.clone(),
650            counts: vec![],
651            sums: vec![],
652            null_state: NullState::new(),
653            avg_fn,
654        }
655    }
656}
657
658impl<T, F> GroupsAccumulator for AvgGroupsAccumulator<T, F>
659where
660    T: ArrowNumericType + Send,
661    F: Fn(T::Native, u64) -> Result<T::Native> + Send,
662{
663    fn update_batch(
664        &mut self,
665        values: &[ArrayRef],
666        group_indices: &[usize],
667        opt_filter: Option<&BooleanArray>,
668        total_num_groups: usize,
669    ) -> Result<()> {
670        assert_eq!(values.len(), 1, "single argument to update_batch");
671        let values = values[0].as_primitive::<T>();
672
673        // increment counts, update sums
674        self.counts.resize(total_num_groups, 0);
675        self.sums.resize(total_num_groups, T::default_value());
676        self.null_state.accumulate(
677            group_indices,
678            values,
679            opt_filter,
680            total_num_groups,
681            |group_index, new_value| {
682                let sum = &mut self.sums[group_index];
683                *sum = sum.add_wrapping(new_value);
684
685                self.counts[group_index] += 1;
686            },
687        );
688
689        Ok(())
690    }
691
692    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
693        let counts = emit_to.take_needed(&mut self.counts);
694        let sums = emit_to.take_needed(&mut self.sums);
695        let nulls = self.null_state.build(emit_to);
696
697        assert_eq!(nulls.len(), sums.len());
698        assert_eq!(counts.len(), sums.len());
699
700        // don't evaluate averages with null inputs to avoid errors on null values
701
702        let array: PrimitiveArray<T> = if nulls.null_count() > 0 {
703            let mut builder = PrimitiveBuilder::<T>::with_capacity(nulls.len())
704                .with_data_type(self.return_data_type.clone());
705            let iter = sums.into_iter().zip(counts).zip(nulls.iter());
706
707            for ((sum, count), is_valid) in iter {
708                if is_valid {
709                    builder.append_value((self.avg_fn)(sum, count)?)
710                } else {
711                    builder.append_null();
712                }
713            }
714            builder.finish()
715        } else {
716            let averages: Vec<T::Native> = sums
717                .into_iter()
718                .zip(counts.into_iter())
719                .map(|(sum, count)| (self.avg_fn)(sum, count))
720                .collect::<Result<Vec<_>>>()?;
721            PrimitiveArray::new(averages.into(), Some(nulls)) // no copy
722                .with_data_type(self.return_data_type.clone())
723        };
724
725        Ok(Arc::new(array))
726    }
727
728    // return arrays for sums and counts
729    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
730        let nulls = self.null_state.build(emit_to);
731        let nulls = Some(nulls);
732
733        let counts = emit_to.take_needed(&mut self.counts);
734        let counts = UInt64Array::new(counts.into(), nulls.clone()); // zero copy
735
736        let sums = emit_to.take_needed(&mut self.sums);
737        let sums = PrimitiveArray::<T>::new(sums.into(), nulls) // zero copy
738            .with_data_type(self.sum_data_type.clone());
739
740        Ok(vec![
741            Arc::new(counts) as ArrayRef,
742            Arc::new(sums) as ArrayRef,
743        ])
744    }
745
746    fn merge_batch(
747        &mut self,
748        values: &[ArrayRef],
749        group_indices: &[usize],
750        opt_filter: Option<&BooleanArray>,
751        total_num_groups: usize,
752    ) -> Result<()> {
753        assert_eq!(values.len(), 2, "two arguments to merge_batch");
754        // first batch is counts, second is partial sums
755        let partial_counts = values[0].as_primitive::<UInt64Type>();
756        let partial_sums = values[1].as_primitive::<T>();
757        // update counts with partial counts
758        self.counts.resize(total_num_groups, 0);
759        self.null_state.accumulate(
760            group_indices,
761            partial_counts,
762            opt_filter,
763            total_num_groups,
764            |group_index, partial_count| {
765                self.counts[group_index] += partial_count;
766            },
767        );
768
769        // update sums
770        self.sums.resize(total_num_groups, T::default_value());
771        self.null_state.accumulate(
772            group_indices,
773            partial_sums,
774            opt_filter,
775            total_num_groups,
776            |group_index, new_value: <T as ArrowPrimitiveType>::Native| {
777                let sum = &mut self.sums[group_index];
778                *sum = sum.add_wrapping(new_value);
779            },
780        );
781
782        Ok(())
783    }
784
785    fn convert_to_state(
786        &self,
787        values: &[ArrayRef],
788        opt_filter: Option<&BooleanArray>,
789    ) -> Result<Vec<ArrayRef>> {
790        let sums = values[0]
791            .as_primitive::<T>()
792            .clone()
793            .with_data_type(self.sum_data_type.clone());
794        let counts = UInt64Array::from_value(1, sums.len());
795
796        let nulls = filtered_null_mask(opt_filter, &sums);
797
798        // set nulls on the arrays
799        let counts = set_nulls(counts, nulls.clone());
800        let sums = set_nulls(sums, nulls);
801
802        Ok(vec![Arc::new(counts) as ArrayRef, Arc::new(sums)])
803    }
804
805    fn supports_convert_to_state(&self) -> bool {
806        true
807    }
808
809    fn size(&self) -> usize {
810        self.counts.capacity() * size_of::<u64>() + self.sums.capacity() * size_of::<T>()
811    }
812}