datafusion_functions_aggregate/
average.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Defines `Avg` & `Mean` aggregate & accumulators
19
20use arrow::array::{
21    Array, ArrayRef, ArrowNativeTypeOp, ArrowNumericType, ArrowPrimitiveType, AsArray,
22    BooleanArray, PrimitiveArray, PrimitiveBuilder, UInt64Array,
23};
24
25use arrow::compute::sum;
26use arrow::datatypes::{
27    i256, ArrowNativeType, DataType, Decimal128Type, Decimal256Type, Decimal32Type,
28    Decimal64Type, DecimalType, DurationMicrosecondType, DurationMillisecondType,
29    DurationNanosecondType, DurationSecondType, Field, FieldRef, Float64Type, TimeUnit,
30    UInt64Type, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION,
31    DECIMAL256_MAX_SCALE, DECIMAL32_MAX_PRECISION, DECIMAL32_MAX_SCALE,
32    DECIMAL64_MAX_PRECISION, DECIMAL64_MAX_SCALE,
33};
34use datafusion_common::plan_err;
35use datafusion_common::{
36    exec_err, not_impl_err, utils::take_function_args, Result, ScalarValue,
37};
38use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
39use datafusion_expr::utils::format_state_name;
40use datafusion_expr::Volatility::Immutable;
41use datafusion_expr::{
42    Accumulator, AggregateUDFImpl, Documentation, EmitTo, Expr, GroupsAccumulator,
43    ReversedUDAF, Signature,
44};
45
46use datafusion_functions_aggregate_common::aggregate::avg_distinct::{
47    DecimalDistinctAvgAccumulator, Float64DistinctAvgAccumulator,
48};
49use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::NullState;
50use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::{
51    filtered_null_mask, set_nulls,
52};
53
54use datafusion_functions_aggregate_common::utils::DecimalAverager;
55use datafusion_macros::user_doc;
56use log::debug;
57use std::any::Any;
58use std::fmt::Debug;
59use std::mem::{size_of, size_of_val};
60use std::sync::Arc;
61
62make_udaf_expr_and_func!(
63    Avg,
64    avg,
65    expression,
66    "Returns the avg of a group of values.",
67    avg_udaf
68);
69
70pub fn avg_distinct(expr: Expr) -> Expr {
71    Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
72        avg_udaf(),
73        vec![expr],
74        true,
75        None,
76        vec![],
77        None,
78    ))
79}
80
81#[user_doc(
82    doc_section(label = "General Functions"),
83    description = "Returns the average of numeric values in the specified column.",
84    syntax_example = "avg(expression)",
85    sql_example = r#"```sql
86> SELECT avg(column_name) FROM table_name;
87+---------------------------+
88| avg(column_name)           |
89+---------------------------+
90| 42.75                      |
91+---------------------------+
92```"#,
93    standard_argument(name = "expression",)
94)]
95#[derive(Debug, PartialEq, Eq, Hash)]
96pub struct Avg {
97    signature: Signature,
98    aliases: Vec<String>,
99}
100
101impl Avg {
102    pub fn new() -> Self {
103        Self {
104            signature: Signature::user_defined(Immutable),
105            aliases: vec![String::from("mean")],
106        }
107    }
108}
109
110impl Default for Avg {
111    fn default() -> Self {
112        Self::new()
113    }
114}
115
116impl AggregateUDFImpl for Avg {
117    fn as_any(&self) -> &dyn Any {
118        self
119    }
120
121    fn name(&self) -> &str {
122        "avg"
123    }
124
125    fn signature(&self) -> &Signature {
126        &self.signature
127    }
128
129    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
130        let [args] = take_function_args(self.name(), arg_types)?;
131
132        // Supported types smallint, int, bigint, real, double precision, decimal, or interval
133        // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc
134        fn coerced_type(data_type: &DataType) -> Result<DataType> {
135            match &data_type {
136                DataType::Decimal32(p, s) => Ok(DataType::Decimal32(*p, *s)),
137                DataType::Decimal64(p, s) => Ok(DataType::Decimal64(*p, *s)),
138                DataType::Decimal128(p, s) => Ok(DataType::Decimal128(*p, *s)),
139                DataType::Decimal256(p, s) => Ok(DataType::Decimal256(*p, *s)),
140                d if d.is_numeric() => Ok(DataType::Float64),
141                DataType::Duration(time_unit) => Ok(DataType::Duration(*time_unit)),
142                DataType::Dictionary(_, v) => coerced_type(v.as_ref()),
143                _ => {
144                    plan_err!("Avg does not support inputs of type {data_type}.")
145                }
146            }
147        }
148        Ok(vec![coerced_type(args)?])
149    }
150
151    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
152        match &arg_types[0] {
153            DataType::Decimal32(precision, scale) => {
154                // In the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)).
155                // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
156                let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 4);
157                let new_scale = DECIMAL32_MAX_SCALE.min(*scale + 4);
158                Ok(DataType::Decimal32(new_precision, new_scale))
159            }
160            DataType::Decimal64(precision, scale) => {
161                // In the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)).
162                // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
163                let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 4);
164                let new_scale = DECIMAL64_MAX_SCALE.min(*scale + 4);
165                Ok(DataType::Decimal64(new_precision, new_scale))
166            }
167            DataType::Decimal128(precision, scale) => {
168                // In the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)).
169                // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
170                let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 4);
171                let new_scale = DECIMAL128_MAX_SCALE.min(*scale + 4);
172                Ok(DataType::Decimal128(new_precision, new_scale))
173            }
174            DataType::Decimal256(precision, scale) => {
175                // In the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)).
176                // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
177                let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 4);
178                let new_scale = DECIMAL256_MAX_SCALE.min(*scale + 4);
179                Ok(DataType::Decimal256(new_precision, new_scale))
180            }
181            DataType::Duration(time_unit) => Ok(DataType::Duration(*time_unit)),
182            _ => Ok(DataType::Float64),
183        }
184    }
185
186    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
187        let data_type = acc_args.expr_fields[0].data_type();
188        use DataType::*;
189
190        // instantiate specialized accumulator based for the type
191        if acc_args.is_distinct {
192            match (data_type, acc_args.return_type()) {
193                // Numeric types are converted to Float64 via `coerce_avg_type` during logical plan creation
194                (Float64, _) => Ok(Box::new(Float64DistinctAvgAccumulator::default())),
195
196                (
197                    Decimal32(_, scale),
198                    Decimal32(target_precision, target_scale),
199                ) => Ok(Box::new(DecimalDistinctAvgAccumulator::<Decimal32Type>::with_decimal_params(
200                    *scale,
201                    *target_precision,
202                    *target_scale,
203                ))),
204                                (
205                    Decimal64(_, scale),
206                    Decimal64(target_precision, target_scale),
207                ) => Ok(Box::new(DecimalDistinctAvgAccumulator::<Decimal64Type>::with_decimal_params(
208                    *scale,
209                    *target_precision,
210                    *target_scale,
211                ))),
212                (
213                    Decimal128(_, scale),
214                    Decimal128(target_precision, target_scale),
215                ) => Ok(Box::new(DecimalDistinctAvgAccumulator::<Decimal128Type>::with_decimal_params(
216                    *scale,
217                    *target_precision,
218                    *target_scale,
219                ))),
220
221                (
222                    Decimal256(_, scale),
223                    Decimal256(target_precision, target_scale),
224                ) => Ok(Box::new(DecimalDistinctAvgAccumulator::<Decimal256Type>::with_decimal_params(
225                    *scale,
226                    *target_precision,
227                    *target_scale,
228                ))),
229
230                (dt, return_type) => exec_err!(
231                    "AVG(DISTINCT) for ({} --> {}) not supported",
232                    dt,
233                    return_type
234                ),
235            }
236        } else {
237            match (&data_type, acc_args.return_type()) {
238                (Float64, Float64) => Ok(Box::<AvgAccumulator>::default()),
239                (
240                    Decimal32(sum_precision, sum_scale),
241                    Decimal32(target_precision, target_scale),
242                ) => Ok(Box::new(DecimalAvgAccumulator::<Decimal32Type> {
243                    sum: None,
244                    count: 0,
245                    sum_scale: *sum_scale,
246                    sum_precision: *sum_precision,
247                    target_precision: *target_precision,
248                    target_scale: *target_scale,
249                })),
250                (
251                    Decimal64(sum_precision, sum_scale),
252                    Decimal64(target_precision, target_scale),
253                ) => Ok(Box::new(DecimalAvgAccumulator::<Decimal64Type> {
254                    sum: None,
255                    count: 0,
256                    sum_scale: *sum_scale,
257                    sum_precision: *sum_precision,
258                    target_precision: *target_precision,
259                    target_scale: *target_scale,
260                })),
261                (
262                    Decimal128(sum_precision, sum_scale),
263                    Decimal128(target_precision, target_scale),
264                ) => Ok(Box::new(DecimalAvgAccumulator::<Decimal128Type> {
265                    sum: None,
266                    count: 0,
267                    sum_scale: *sum_scale,
268                    sum_precision: *sum_precision,
269                    target_precision: *target_precision,
270                    target_scale: *target_scale,
271                })),
272
273                (
274                    Decimal256(sum_precision, sum_scale),
275                    Decimal256(target_precision, target_scale),
276                ) => Ok(Box::new(DecimalAvgAccumulator::<Decimal256Type> {
277                    sum: None,
278                    count: 0,
279                    sum_scale: *sum_scale,
280                    sum_precision: *sum_precision,
281                    target_precision: *target_precision,
282                    target_scale: *target_scale,
283                })),
284
285                (Duration(time_unit), Duration(result_unit)) => {
286                    Ok(Box::new(DurationAvgAccumulator {
287                        sum: None,
288                        count: 0,
289                        time_unit: *time_unit,
290                        result_unit: *result_unit,
291                    }))
292                }
293
294                (dt, return_type) => {
295                    exec_err!("AvgAccumulator for ({} --> {})", dt, return_type)
296                }
297            }
298        }
299    }
300
301    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
302        if args.is_distinct {
303            // Decimal accumulator actually uses a different precision during accumulation,
304            // see DecimalDistinctAvgAccumulator::with_decimal_params
305            let dt = match args.input_fields[0].data_type() {
306                DataType::Decimal32(_, scale) => {
307                    DataType::Decimal32(DECIMAL32_MAX_PRECISION, *scale)
308                }
309                DataType::Decimal64(_, scale) => {
310                    DataType::Decimal64(DECIMAL64_MAX_PRECISION, *scale)
311                }
312                DataType::Decimal128(_, scale) => {
313                    DataType::Decimal128(DECIMAL128_MAX_PRECISION, *scale)
314                }
315                DataType::Decimal256(_, scale) => {
316                    DataType::Decimal256(DECIMAL256_MAX_PRECISION, *scale)
317                }
318                _ => args.return_type().clone(),
319            };
320            // Similar to datafusion_functions_aggregate::sum::Sum::state_fields
321            // since the accumulator uses DistinctSumAccumulator internally.
322            Ok(vec![Field::new_list(
323                format_state_name(args.name, "avg distinct"),
324                Field::new_list_field(dt, true),
325                false,
326            )
327            .into()])
328        } else {
329            Ok(vec![
330                Field::new(
331                    format_state_name(args.name, "count"),
332                    DataType::UInt64,
333                    true,
334                ),
335                Field::new(
336                    format_state_name(args.name, "sum"),
337                    args.input_fields[0].data_type().clone(),
338                    true,
339                ),
340            ]
341            .into_iter()
342            .map(Arc::new)
343            .collect())
344        }
345    }
346
347    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
348        matches!(
349            args.return_field.data_type(),
350            DataType::Float64
351                | DataType::Decimal32(_, _)
352                | DataType::Decimal64(_, _)
353                | DataType::Decimal128(_, _)
354                | DataType::Decimal256(_, _)
355                | DataType::Duration(_)
356        ) && !args.is_distinct
357    }
358
359    fn create_groups_accumulator(
360        &self,
361        args: AccumulatorArgs,
362    ) -> Result<Box<dyn GroupsAccumulator>> {
363        use DataType::*;
364
365        let data_type = args.expr_fields[0].data_type();
366
367        // instantiate specialized accumulator based for the type
368        match (data_type, args.return_field.data_type()) {
369            (Float64, Float64) => {
370                Ok(Box::new(AvgGroupsAccumulator::<Float64Type, _>::new(
371                    data_type,
372                    args.return_field.data_type(),
373                    |sum: f64, count: u64| Ok(sum / count as f64),
374                )))
375            }
376            (
377                Decimal32(_sum_precision, sum_scale),
378                Decimal32(target_precision, target_scale),
379            ) => {
380                let decimal_averager = DecimalAverager::<Decimal32Type>::try_new(
381                    *sum_scale,
382                    *target_precision,
383                    *target_scale,
384                )?;
385
386                let avg_fn =
387                    move |sum: i32, count: u64| decimal_averager.avg(sum, count as i32);
388
389                Ok(Box::new(AvgGroupsAccumulator::<Decimal32Type, _>::new(
390                    data_type,
391                    args.return_field.data_type(),
392                    avg_fn,
393                )))
394            }
395            (
396                Decimal64(_sum_precision, sum_scale),
397                Decimal64(target_precision, target_scale),
398            ) => {
399                let decimal_averager = DecimalAverager::<Decimal64Type>::try_new(
400                    *sum_scale,
401                    *target_precision,
402                    *target_scale,
403                )?;
404
405                let avg_fn =
406                    move |sum: i64, count: u64| decimal_averager.avg(sum, count as i64);
407
408                Ok(Box::new(AvgGroupsAccumulator::<Decimal64Type, _>::new(
409                    data_type,
410                    args.return_field.data_type(),
411                    avg_fn,
412                )))
413            }
414            (
415                Decimal128(_sum_precision, sum_scale),
416                Decimal128(target_precision, target_scale),
417            ) => {
418                let decimal_averager = DecimalAverager::<Decimal128Type>::try_new(
419                    *sum_scale,
420                    *target_precision,
421                    *target_scale,
422                )?;
423
424                let avg_fn =
425                    move |sum: i128, count: u64| decimal_averager.avg(sum, count as i128);
426
427                Ok(Box::new(AvgGroupsAccumulator::<Decimal128Type, _>::new(
428                    data_type,
429                    args.return_field.data_type(),
430                    avg_fn,
431                )))
432            }
433
434            (
435                Decimal256(_sum_precision, sum_scale),
436                Decimal256(target_precision, target_scale),
437            ) => {
438                let decimal_averager = DecimalAverager::<Decimal256Type>::try_new(
439                    *sum_scale,
440                    *target_precision,
441                    *target_scale,
442                )?;
443
444                let avg_fn = move |sum: i256, count: u64| {
445                    decimal_averager.avg(sum, i256::from_usize(count as usize).unwrap())
446                };
447
448                Ok(Box::new(AvgGroupsAccumulator::<Decimal256Type, _>::new(
449                    data_type,
450                    args.return_field.data_type(),
451                    avg_fn,
452                )))
453            }
454
455            (Duration(time_unit), Duration(_result_unit)) => {
456                let avg_fn = move |sum: i64, count: u64| Ok(sum / count as i64);
457
458                match time_unit {
459                    TimeUnit::Second => Ok(Box::new(AvgGroupsAccumulator::<
460                        DurationSecondType,
461                        _,
462                    >::new(
463                        data_type,
464                        args.return_type(),
465                        avg_fn,
466                    ))),
467                    TimeUnit::Millisecond => Ok(Box::new(AvgGroupsAccumulator::<
468                        DurationMillisecondType,
469                        _,
470                    >::new(
471                        data_type,
472                        args.return_type(),
473                        avg_fn,
474                    ))),
475                    TimeUnit::Microsecond => Ok(Box::new(AvgGroupsAccumulator::<
476                        DurationMicrosecondType,
477                        _,
478                    >::new(
479                        data_type,
480                        args.return_type(),
481                        avg_fn,
482                    ))),
483                    TimeUnit::Nanosecond => Ok(Box::new(AvgGroupsAccumulator::<
484                        DurationNanosecondType,
485                        _,
486                    >::new(
487                        data_type,
488                        args.return_type(),
489                        avg_fn,
490                    ))),
491                }
492            }
493
494            _ => not_impl_err!(
495                "AvgGroupsAccumulator for ({} --> {})",
496                &data_type,
497                args.return_field.data_type()
498            ),
499        }
500    }
501
502    fn aliases(&self) -> &[String] {
503        &self.aliases
504    }
505
506    fn reverse_expr(&self) -> ReversedUDAF {
507        ReversedUDAF::Identical
508    }
509
510    fn documentation(&self) -> Option<&Documentation> {
511        self.doc()
512    }
513}
514
515/// An accumulator to compute the average
516#[derive(Debug, Default)]
517pub struct AvgAccumulator {
518    sum: Option<f64>,
519    count: u64,
520}
521
522impl Accumulator for AvgAccumulator {
523    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
524        let values = values[0].as_primitive::<Float64Type>();
525        self.count += (values.len() - values.null_count()) as u64;
526        if let Some(x) = sum(values) {
527            let v = self.sum.get_or_insert(0.);
528            *v += x;
529        }
530        Ok(())
531    }
532
533    fn evaluate(&mut self) -> Result<ScalarValue> {
534        Ok(ScalarValue::Float64(
535            self.sum.map(|f| f / self.count as f64),
536        ))
537    }
538
539    fn size(&self) -> usize {
540        size_of_val(self)
541    }
542
543    fn state(&mut self) -> Result<Vec<ScalarValue>> {
544        Ok(vec![
545            ScalarValue::from(self.count),
546            ScalarValue::Float64(self.sum),
547        ])
548    }
549
550    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
551        // counts are summed
552        self.count += sum(states[0].as_primitive::<UInt64Type>()).unwrap_or_default();
553
554        // sums are summed
555        if let Some(x) = sum(states[1].as_primitive::<Float64Type>()) {
556            let v = self.sum.get_or_insert(0.);
557            *v += x;
558        }
559        Ok(())
560    }
561    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
562        let values = values[0].as_primitive::<Float64Type>();
563        self.count -= (values.len() - values.null_count()) as u64;
564        if let Some(x) = sum(values) {
565            self.sum = Some(self.sum.unwrap() - x);
566        }
567        Ok(())
568    }
569
570    fn supports_retract_batch(&self) -> bool {
571        true
572    }
573}
574
575/// An accumulator to compute the average for decimals
576#[derive(Debug)]
577struct DecimalAvgAccumulator<T: DecimalType + ArrowNumericType + Debug> {
578    sum: Option<T::Native>,
579    count: u64,
580    sum_scale: i8,
581    sum_precision: u8,
582    target_precision: u8,
583    target_scale: i8,
584}
585
586impl<T: DecimalType + ArrowNumericType + Debug> Accumulator for DecimalAvgAccumulator<T> {
587    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
588        let values = values[0].as_primitive::<T>();
589        self.count += (values.len() - values.null_count()) as u64;
590
591        if let Some(x) = sum(values) {
592            let v = self.sum.get_or_insert_with(T::Native::default);
593            self.sum = Some(v.add_wrapping(x));
594        }
595        Ok(())
596    }
597
598    fn evaluate(&mut self) -> Result<ScalarValue> {
599        let v = self
600            .sum
601            .map(|v| {
602                DecimalAverager::<T>::try_new(
603                    self.sum_scale,
604                    self.target_precision,
605                    self.target_scale,
606                )?
607                .avg(v, T::Native::from_usize(self.count as usize).unwrap())
608            })
609            .transpose()?;
610
611        ScalarValue::new_primitive::<T>(
612            v,
613            &T::TYPE_CONSTRUCTOR(self.target_precision, self.target_scale),
614        )
615    }
616
617    fn size(&self) -> usize {
618        size_of_val(self)
619    }
620
621    fn state(&mut self) -> Result<Vec<ScalarValue>> {
622        Ok(vec![
623            ScalarValue::from(self.count),
624            ScalarValue::new_primitive::<T>(
625                self.sum,
626                &T::TYPE_CONSTRUCTOR(self.sum_precision, self.sum_scale),
627            )?,
628        ])
629    }
630
631    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
632        // counts are summed
633        self.count += sum(states[0].as_primitive::<UInt64Type>()).unwrap_or_default();
634
635        // sums are summed
636        if let Some(x) = sum(states[1].as_primitive::<T>()) {
637            let v = self.sum.get_or_insert_with(T::Native::default);
638            self.sum = Some(v.add_wrapping(x));
639        }
640        Ok(())
641    }
642    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
643        let values = values[0].as_primitive::<T>();
644        self.count -= (values.len() - values.null_count()) as u64;
645        if let Some(x) = sum(values) {
646            self.sum = Some(self.sum.unwrap().sub_wrapping(x));
647        }
648        Ok(())
649    }
650
651    fn supports_retract_batch(&self) -> bool {
652        true
653    }
654}
655
656/// An accumulator to compute the average for duration values
657#[derive(Debug)]
658struct DurationAvgAccumulator {
659    sum: Option<i64>,
660    count: u64,
661    time_unit: TimeUnit,
662    result_unit: TimeUnit,
663}
664
665impl Accumulator for DurationAvgAccumulator {
666    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
667        let array = &values[0];
668        self.count += (array.len() - array.null_count()) as u64;
669
670        let sum_value = match self.time_unit {
671            TimeUnit::Second => sum(array.as_primitive::<DurationSecondType>()),
672            TimeUnit::Millisecond => sum(array.as_primitive::<DurationMillisecondType>()),
673            TimeUnit::Microsecond => sum(array.as_primitive::<DurationMicrosecondType>()),
674            TimeUnit::Nanosecond => sum(array.as_primitive::<DurationNanosecondType>()),
675        };
676
677        if let Some(x) = sum_value {
678            let v = self.sum.get_or_insert(0);
679            *v += x;
680        }
681        Ok(())
682    }
683
684    fn evaluate(&mut self) -> Result<ScalarValue> {
685        let avg = self.sum.map(|sum| sum / self.count as i64);
686
687        match self.result_unit {
688            TimeUnit::Second => Ok(ScalarValue::DurationSecond(avg)),
689            TimeUnit::Millisecond => Ok(ScalarValue::DurationMillisecond(avg)),
690            TimeUnit::Microsecond => Ok(ScalarValue::DurationMicrosecond(avg)),
691            TimeUnit::Nanosecond => Ok(ScalarValue::DurationNanosecond(avg)),
692        }
693    }
694
695    fn size(&self) -> usize {
696        size_of_val(self)
697    }
698
699    fn state(&mut self) -> Result<Vec<ScalarValue>> {
700        let duration_value = match self.time_unit {
701            TimeUnit::Second => ScalarValue::DurationSecond(self.sum),
702            TimeUnit::Millisecond => ScalarValue::DurationMillisecond(self.sum),
703            TimeUnit::Microsecond => ScalarValue::DurationMicrosecond(self.sum),
704            TimeUnit::Nanosecond => ScalarValue::DurationNanosecond(self.sum),
705        };
706
707        Ok(vec![ScalarValue::from(self.count), duration_value])
708    }
709
710    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
711        self.count += sum(states[0].as_primitive::<UInt64Type>()).unwrap_or_default();
712
713        let sum_value = match self.time_unit {
714            TimeUnit::Second => sum(states[1].as_primitive::<DurationSecondType>()),
715            TimeUnit::Millisecond => {
716                sum(states[1].as_primitive::<DurationMillisecondType>())
717            }
718            TimeUnit::Microsecond => {
719                sum(states[1].as_primitive::<DurationMicrosecondType>())
720            }
721            TimeUnit::Nanosecond => {
722                sum(states[1].as_primitive::<DurationNanosecondType>())
723            }
724        };
725
726        if let Some(x) = sum_value {
727            let v = self.sum.get_or_insert(0);
728            *v += x;
729        }
730        Ok(())
731    }
732
733    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
734        let array = &values[0];
735        self.count -= (array.len() - array.null_count()) as u64;
736
737        let sum_value = match self.time_unit {
738            TimeUnit::Second => sum(array.as_primitive::<DurationSecondType>()),
739            TimeUnit::Millisecond => sum(array.as_primitive::<DurationMillisecondType>()),
740            TimeUnit::Microsecond => sum(array.as_primitive::<DurationMicrosecondType>()),
741            TimeUnit::Nanosecond => sum(array.as_primitive::<DurationNanosecondType>()),
742        };
743
744        if let Some(x) = sum_value {
745            self.sum = Some(self.sum.unwrap() - x);
746        }
747        Ok(())
748    }
749
750    fn supports_retract_batch(&self) -> bool {
751        true
752    }
753}
754
755/// An accumulator to compute the average of `[PrimitiveArray<T>]`.
756/// Stores values as native types, and does overflow checking
757///
758/// F: Function that calculates the average value from a sum of
759/// T::Native and a total count
760#[derive(Debug)]
761struct AvgGroupsAccumulator<T, F>
762where
763    T: ArrowNumericType + Send,
764    F: Fn(T::Native, u64) -> Result<T::Native> + Send,
765{
766    /// The type of the internal sum
767    sum_data_type: DataType,
768
769    /// The type of the returned sum
770    return_data_type: DataType,
771
772    /// Count per group (use u64 to make UInt64Array)
773    counts: Vec<u64>,
774
775    /// Sums per group, stored as the native type
776    sums: Vec<T::Native>,
777
778    /// Track nulls in the input / filters
779    null_state: NullState,
780
781    /// Function that computes the final average (value / count)
782    avg_fn: F,
783}
784
785impl<T, F> AvgGroupsAccumulator<T, F>
786where
787    T: ArrowNumericType + Send,
788    F: Fn(T::Native, u64) -> Result<T::Native> + Send,
789{
790    pub fn new(sum_data_type: &DataType, return_data_type: &DataType, avg_fn: F) -> Self {
791        debug!(
792            "AvgGroupsAccumulator ({}, sum type: {sum_data_type}) --> {return_data_type}",
793            std::any::type_name::<T>()
794        );
795
796        Self {
797            return_data_type: return_data_type.clone(),
798            sum_data_type: sum_data_type.clone(),
799            counts: vec![],
800            sums: vec![],
801            null_state: NullState::new(),
802            avg_fn,
803        }
804    }
805}
806
807impl<T, F> GroupsAccumulator for AvgGroupsAccumulator<T, F>
808where
809    T: ArrowNumericType + Send,
810    F: Fn(T::Native, u64) -> Result<T::Native> + Send,
811{
812    fn update_batch(
813        &mut self,
814        values: &[ArrayRef],
815        group_indices: &[usize],
816        opt_filter: Option<&BooleanArray>,
817        total_num_groups: usize,
818    ) -> Result<()> {
819        assert_eq!(values.len(), 1, "single argument to update_batch");
820        let values = values[0].as_primitive::<T>();
821
822        // increment counts, update sums
823        self.counts.resize(total_num_groups, 0);
824        self.sums.resize(total_num_groups, T::default_value());
825        self.null_state.accumulate(
826            group_indices,
827            values,
828            opt_filter,
829            total_num_groups,
830            |group_index, new_value| {
831                let sum = &mut self.sums[group_index];
832                *sum = sum.add_wrapping(new_value);
833
834                self.counts[group_index] += 1;
835            },
836        );
837
838        Ok(())
839    }
840
841    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
842        let counts = emit_to.take_needed(&mut self.counts);
843        let sums = emit_to.take_needed(&mut self.sums);
844        let nulls = self.null_state.build(emit_to);
845
846        assert_eq!(nulls.len(), sums.len());
847        assert_eq!(counts.len(), sums.len());
848
849        // don't evaluate averages with null inputs to avoid errors on null values
850
851        let array: PrimitiveArray<T> = if nulls.null_count() > 0 {
852            let mut builder = PrimitiveBuilder::<T>::with_capacity(nulls.len())
853                .with_data_type(self.return_data_type.clone());
854            let iter = sums.into_iter().zip(counts).zip(nulls.iter());
855
856            for ((sum, count), is_valid) in iter {
857                if is_valid {
858                    builder.append_value((self.avg_fn)(sum, count)?)
859                } else {
860                    builder.append_null();
861                }
862            }
863            builder.finish()
864        } else {
865            let averages: Vec<T::Native> = sums
866                .into_iter()
867                .zip(counts.into_iter())
868                .map(|(sum, count)| (self.avg_fn)(sum, count))
869                .collect::<Result<Vec<_>>>()?;
870            PrimitiveArray::new(averages.into(), Some(nulls)) // no copy
871                .with_data_type(self.return_data_type.clone())
872        };
873
874        Ok(Arc::new(array))
875    }
876
877    // return arrays for sums and counts
878    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
879        let nulls = self.null_state.build(emit_to);
880        let nulls = Some(nulls);
881
882        let counts = emit_to.take_needed(&mut self.counts);
883        let counts = UInt64Array::new(counts.into(), nulls.clone()); // zero copy
884
885        let sums = emit_to.take_needed(&mut self.sums);
886        let sums = PrimitiveArray::<T>::new(sums.into(), nulls) // zero copy
887            .with_data_type(self.sum_data_type.clone());
888
889        Ok(vec![
890            Arc::new(counts) as ArrayRef,
891            Arc::new(sums) as ArrayRef,
892        ])
893    }
894
895    fn merge_batch(
896        &mut self,
897        values: &[ArrayRef],
898        group_indices: &[usize],
899        opt_filter: Option<&BooleanArray>,
900        total_num_groups: usize,
901    ) -> Result<()> {
902        assert_eq!(values.len(), 2, "two arguments to merge_batch");
903        // first batch is counts, second is partial sums
904        let partial_counts = values[0].as_primitive::<UInt64Type>();
905        let partial_sums = values[1].as_primitive::<T>();
906        // update counts with partial counts
907        self.counts.resize(total_num_groups, 0);
908        self.null_state.accumulate(
909            group_indices,
910            partial_counts,
911            opt_filter,
912            total_num_groups,
913            |group_index, partial_count| {
914                self.counts[group_index] += partial_count;
915            },
916        );
917
918        // update sums
919        self.sums.resize(total_num_groups, T::default_value());
920        self.null_state.accumulate(
921            group_indices,
922            partial_sums,
923            opt_filter,
924            total_num_groups,
925            |group_index, new_value: <T as ArrowPrimitiveType>::Native| {
926                let sum = &mut self.sums[group_index];
927                *sum = sum.add_wrapping(new_value);
928            },
929        );
930
931        Ok(())
932    }
933
934    fn convert_to_state(
935        &self,
936        values: &[ArrayRef],
937        opt_filter: Option<&BooleanArray>,
938    ) -> Result<Vec<ArrayRef>> {
939        let sums = values[0]
940            .as_primitive::<T>()
941            .clone()
942            .with_data_type(self.sum_data_type.clone());
943        let counts = UInt64Array::from_value(1, sums.len());
944
945        let nulls = filtered_null_mask(opt_filter, &sums);
946
947        // set nulls on the arrays
948        let counts = set_nulls(counts, nulls.clone());
949        let sums = set_nulls(sums, nulls);
950
951        Ok(vec![Arc::new(counts) as ArrayRef, Arc::new(sums)])
952    }
953
954    fn supports_convert_to_state(&self) -> bool {
955        true
956    }
957
958    fn size(&self) -> usize {
959        self.counts.capacity() * size_of::<u64>() + self.sums.capacity() * size_of::<T>()
960    }
961}