datafusion_functions_aggregate/
average.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Defines `Avg` & `Mean` aggregate & accumulators
19
20use arrow::array::{
21    Array, ArrayRef, ArrowNativeTypeOp, ArrowNumericType, ArrowPrimitiveType, AsArray,
22    BooleanArray, PrimitiveArray, PrimitiveBuilder, UInt64Array,
23};
24
25use arrow::compute::sum;
26use arrow::datatypes::{
27    ArrowNativeType, DECIMAL32_MAX_PRECISION, DECIMAL32_MAX_SCALE,
28    DECIMAL64_MAX_PRECISION, DECIMAL64_MAX_SCALE, DECIMAL128_MAX_PRECISION,
29    DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, DataType,
30    Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type, DecimalType,
31    DurationMicrosecondType, DurationMillisecondType, DurationNanosecondType,
32    DurationSecondType, Field, FieldRef, Float64Type, TimeUnit, UInt64Type, i256,
33};
34use datafusion_common::types::{NativeType, logical_float64};
35use datafusion_common::{Result, ScalarValue, exec_err, not_impl_err};
36use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
37use datafusion_expr::utils::format_state_name;
38use datafusion_expr::{
39    Accumulator, AggregateUDFImpl, Coercion, Documentation, EmitTo, Expr,
40    GroupsAccumulator, ReversedUDAF, Signature, TypeSignature, TypeSignatureClass,
41    Volatility,
42};
43use datafusion_functions_aggregate_common::aggregate::avg_distinct::{
44    DecimalDistinctAvgAccumulator, Float64DistinctAvgAccumulator,
45};
46use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::NullState;
47use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::{
48    filtered_null_mask, set_nulls,
49};
50use datafusion_functions_aggregate_common::utils::DecimalAverager;
51use datafusion_macros::user_doc;
52use log::debug;
53use std::any::Any;
54use std::fmt::Debug;
55use std::mem::{size_of, size_of_val};
56use std::sync::Arc;
57
// Registers the `Avg` UDAF via `make_udaf_expr_and_func!`, passing the `avg`
// function name, its single `expression` argument, the user-facing
// description and the `avg_udaf` accessor name (`avg_udaf()` is called by
// `avg_distinct` in this file).
// NOTE(review): the exact items the macro generates are defined elsewhere —
// confirm against the macro definition.
make_udaf_expr_and_func!(
    Avg,
    avg,
    expression,
    "Returns the avg of a group of values.",
    avg_udaf
);
65
66pub fn avg_distinct(expr: Expr) -> Expr {
67    Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
68        avg_udaf(),
69        vec![expr],
70        true,
71        None,
72        vec![],
73        None,
74    ))
75}
76
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the average of numeric values in the specified column.",
    syntax_example = "avg(expression)",
    sql_example = r#"```sql
> SELECT avg(column_name) FROM table_name;
+---------------------------+
| avg(column_name)           |
+---------------------------+
| 42.75                      |
+---------------------------+
```"#,
    standard_argument(name = "expression",)
)]
#[derive(Debug, PartialEq, Eq, Hash)]
/// The `avg` aggregate function (also reachable through the `mean` alias
/// registered in [`Avg::new`]).
pub struct Avg {
    /// Accepted argument types, built in [`Avg::new`].
    signature: Signature,
    /// Alternative names under which the function is exposed.
    aliases: Vec<String>,
}
96
97impl Avg {
98    pub fn new() -> Self {
99        Self {
100            // Supported types smallint, int, bigint, real, double precision, decimal, or interval
101            // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc
102            signature: Signature::one_of(
103                vec![
104                    TypeSignature::Coercible(vec![Coercion::new_exact(
105                        TypeSignatureClass::Decimal,
106                    )]),
107                    TypeSignature::Coercible(vec![Coercion::new_exact(
108                        TypeSignatureClass::Duration,
109                    )]),
110                    TypeSignature::Coercible(vec![Coercion::new_implicit(
111                        TypeSignatureClass::Native(logical_float64()),
112                        vec![TypeSignatureClass::Integer, TypeSignatureClass::Float],
113                        NativeType::Float64,
114                    )]),
115                ],
116                Volatility::Immutable,
117            ),
118            aliases: vec![String::from("mean")],
119        }
120    }
121}
122
123impl Default for Avg {
124    fn default() -> Self {
125        Self::new()
126    }
127}
128
129impl AggregateUDFImpl for Avg {
    /// Returns `self` as `&dyn Any`.
    fn as_any(&self) -> &dyn Any {
        self
    }
133
    /// Returns the function name, `"avg"`.
    fn name(&self) -> &str {
        "avg"
    }
137
    /// Returns the signature stored on this instance (built in `Avg::new`).
    fn signature(&self) -> &Signature {
        &self.signature
    }
141
142    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
143        match &arg_types[0] {
144            DataType::Decimal32(precision, scale) => {
145                // In the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)).
146                // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
147                let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 4);
148                let new_scale = DECIMAL32_MAX_SCALE.min(*scale + 4);
149                Ok(DataType::Decimal32(new_precision, new_scale))
150            }
151            DataType::Decimal64(precision, scale) => {
152                // In the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)).
153                // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
154                let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 4);
155                let new_scale = DECIMAL64_MAX_SCALE.min(*scale + 4);
156                Ok(DataType::Decimal64(new_precision, new_scale))
157            }
158            DataType::Decimal128(precision, scale) => {
159                // In the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)).
160                // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
161                let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 4);
162                let new_scale = DECIMAL128_MAX_SCALE.min(*scale + 4);
163                Ok(DataType::Decimal128(new_precision, new_scale))
164            }
165            DataType::Decimal256(precision, scale) => {
166                // In the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)).
167                // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
168                let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 4);
169                let new_scale = DECIMAL256_MAX_SCALE.min(*scale + 4);
170                Ok(DataType::Decimal256(new_precision, new_scale))
171            }
172            DataType::Duration(time_unit) => Ok(DataType::Duration(*time_unit)),
173            _ => Ok(DataType::Float64),
174        }
175    }
176
177    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
178        let data_type = acc_args.expr_fields[0].data_type();
179        use DataType::*;
180
181        // instantiate specialized accumulator based for the type
182        if acc_args.is_distinct {
183            match (data_type, acc_args.return_type()) {
184                // Numeric types are converted to Float64 via `coerce_avg_type` during logical plan creation
185                (Float64, _) => Ok(Box::new(Float64DistinctAvgAccumulator::default())),
186
187                (
188                    Decimal32(_, scale),
189                    Decimal32(target_precision, target_scale),
190                ) => Ok(Box::new(DecimalDistinctAvgAccumulator::<Decimal32Type>::with_decimal_params(
191                    *scale,
192                    *target_precision,
193                    *target_scale,
194                ))),
195                                (
196                    Decimal64(_, scale),
197                    Decimal64(target_precision, target_scale),
198                ) => Ok(Box::new(DecimalDistinctAvgAccumulator::<Decimal64Type>::with_decimal_params(
199                    *scale,
200                    *target_precision,
201                    *target_scale,
202                ))),
203                (
204                    Decimal128(_, scale),
205                    Decimal128(target_precision, target_scale),
206                ) => Ok(Box::new(DecimalDistinctAvgAccumulator::<Decimal128Type>::with_decimal_params(
207                    *scale,
208                    *target_precision,
209                    *target_scale,
210                ))),
211
212                (
213                    Decimal256(_, scale),
214                    Decimal256(target_precision, target_scale),
215                ) => Ok(Box::new(DecimalDistinctAvgAccumulator::<Decimal256Type>::with_decimal_params(
216                    *scale,
217                    *target_precision,
218                    *target_scale,
219                ))),
220
221                (dt, return_type) => exec_err!(
222                    "AVG(DISTINCT) for ({} --> {}) not supported",
223                    dt,
224                    return_type
225                ),
226            }
227        } else {
228            match (&data_type, acc_args.return_type()) {
229                (Float64, Float64) => Ok(Box::<AvgAccumulator>::default()),
230                (
231                    Decimal32(sum_precision, sum_scale),
232                    Decimal32(target_precision, target_scale),
233                ) => Ok(Box::new(DecimalAvgAccumulator::<Decimal32Type> {
234                    sum: None,
235                    count: 0,
236                    sum_scale: *sum_scale,
237                    sum_precision: *sum_precision,
238                    target_precision: *target_precision,
239                    target_scale: *target_scale,
240                })),
241                (
242                    Decimal64(sum_precision, sum_scale),
243                    Decimal64(target_precision, target_scale),
244                ) => Ok(Box::new(DecimalAvgAccumulator::<Decimal64Type> {
245                    sum: None,
246                    count: 0,
247                    sum_scale: *sum_scale,
248                    sum_precision: *sum_precision,
249                    target_precision: *target_precision,
250                    target_scale: *target_scale,
251                })),
252                (
253                    Decimal128(sum_precision, sum_scale),
254                    Decimal128(target_precision, target_scale),
255                ) => Ok(Box::new(DecimalAvgAccumulator::<Decimal128Type> {
256                    sum: None,
257                    count: 0,
258                    sum_scale: *sum_scale,
259                    sum_precision: *sum_precision,
260                    target_precision: *target_precision,
261                    target_scale: *target_scale,
262                })),
263
264                (
265                    Decimal256(sum_precision, sum_scale),
266                    Decimal256(target_precision, target_scale),
267                ) => Ok(Box::new(DecimalAvgAccumulator::<Decimal256Type> {
268                    sum: None,
269                    count: 0,
270                    sum_scale: *sum_scale,
271                    sum_precision: *sum_precision,
272                    target_precision: *target_precision,
273                    target_scale: *target_scale,
274                })),
275
276                (Duration(time_unit), Duration(result_unit)) => {
277                    Ok(Box::new(DurationAvgAccumulator {
278                        sum: None,
279                        count: 0,
280                        time_unit: *time_unit,
281                        result_unit: *result_unit,
282                    }))
283                }
284
285                (dt, return_type) => {
286                    exec_err!("AvgAccumulator for ({} --> {})", dt, return_type)
287                }
288            }
289        }
290    }
291
292    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
293        if args.is_distinct {
294            // Decimal accumulator actually uses a different precision during accumulation,
295            // see DecimalDistinctAvgAccumulator::with_decimal_params
296            let dt = match args.input_fields[0].data_type() {
297                DataType::Decimal32(_, scale) => {
298                    DataType::Decimal32(DECIMAL32_MAX_PRECISION, *scale)
299                }
300                DataType::Decimal64(_, scale) => {
301                    DataType::Decimal64(DECIMAL64_MAX_PRECISION, *scale)
302                }
303                DataType::Decimal128(_, scale) => {
304                    DataType::Decimal128(DECIMAL128_MAX_PRECISION, *scale)
305                }
306                DataType::Decimal256(_, scale) => {
307                    DataType::Decimal256(DECIMAL256_MAX_PRECISION, *scale)
308                }
309                _ => args.return_type().clone(),
310            };
311            // Similar to datafusion_functions_aggregate::sum::Sum::state_fields
312            // since the accumulator uses DistinctSumAccumulator internally.
313            Ok(vec![
314                Field::new_list(
315                    format_state_name(args.name, "avg distinct"),
316                    Field::new_list_field(dt, true),
317                    false,
318                )
319                .into(),
320            ])
321        } else {
322            Ok(vec![
323                Field::new(
324                    format_state_name(args.name, "count"),
325                    DataType::UInt64,
326                    true,
327                ),
328                Field::new(
329                    format_state_name(args.name, "sum"),
330                    args.input_fields[0].data_type().clone(),
331                    true,
332                ),
333            ]
334            .into_iter()
335            .map(Arc::new)
336            .collect())
337        }
338    }
339
340    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
341        matches!(
342            args.return_field.data_type(),
343            DataType::Float64
344                | DataType::Decimal32(_, _)
345                | DataType::Decimal64(_, _)
346                | DataType::Decimal128(_, _)
347                | DataType::Decimal256(_, _)
348                | DataType::Duration(_)
349        ) && !args.is_distinct
350    }
351
352    fn create_groups_accumulator(
353        &self,
354        args: AccumulatorArgs,
355    ) -> Result<Box<dyn GroupsAccumulator>> {
356        use DataType::*;
357
358        let data_type = args.expr_fields[0].data_type();
359
360        // instantiate specialized accumulator based for the type
361        match (data_type, args.return_field.data_type()) {
362            (Float64, Float64) => {
363                Ok(Box::new(AvgGroupsAccumulator::<Float64Type, _>::new(
364                    data_type,
365                    args.return_field.data_type(),
366                    |sum: f64, count: u64| Ok(sum / count as f64),
367                )))
368            }
369            (
370                Decimal32(_sum_precision, sum_scale),
371                Decimal32(target_precision, target_scale),
372            ) => {
373                let decimal_averager = DecimalAverager::<Decimal32Type>::try_new(
374                    *sum_scale,
375                    *target_precision,
376                    *target_scale,
377                )?;
378
379                let avg_fn =
380                    move |sum: i32, count: u64| decimal_averager.avg(sum, count as i32);
381
382                Ok(Box::new(AvgGroupsAccumulator::<Decimal32Type, _>::new(
383                    data_type,
384                    args.return_field.data_type(),
385                    avg_fn,
386                )))
387            }
388            (
389                Decimal64(_sum_precision, sum_scale),
390                Decimal64(target_precision, target_scale),
391            ) => {
392                let decimal_averager = DecimalAverager::<Decimal64Type>::try_new(
393                    *sum_scale,
394                    *target_precision,
395                    *target_scale,
396                )?;
397
398                let avg_fn =
399                    move |sum: i64, count: u64| decimal_averager.avg(sum, count as i64);
400
401                Ok(Box::new(AvgGroupsAccumulator::<Decimal64Type, _>::new(
402                    data_type,
403                    args.return_field.data_type(),
404                    avg_fn,
405                )))
406            }
407            (
408                Decimal128(_sum_precision, sum_scale),
409                Decimal128(target_precision, target_scale),
410            ) => {
411                let decimal_averager = DecimalAverager::<Decimal128Type>::try_new(
412                    *sum_scale,
413                    *target_precision,
414                    *target_scale,
415                )?;
416
417                let avg_fn =
418                    move |sum: i128, count: u64| decimal_averager.avg(sum, count as i128);
419
420                Ok(Box::new(AvgGroupsAccumulator::<Decimal128Type, _>::new(
421                    data_type,
422                    args.return_field.data_type(),
423                    avg_fn,
424                )))
425            }
426
427            (
428                Decimal256(_sum_precision, sum_scale),
429                Decimal256(target_precision, target_scale),
430            ) => {
431                let decimal_averager = DecimalAverager::<Decimal256Type>::try_new(
432                    *sum_scale,
433                    *target_precision,
434                    *target_scale,
435                )?;
436
437                let avg_fn = move |sum: i256, count: u64| {
438                    decimal_averager.avg(sum, i256::from_usize(count as usize).unwrap())
439                };
440
441                Ok(Box::new(AvgGroupsAccumulator::<Decimal256Type, _>::new(
442                    data_type,
443                    args.return_field.data_type(),
444                    avg_fn,
445                )))
446            }
447
448            (Duration(time_unit), Duration(_result_unit)) => {
449                let avg_fn = move |sum: i64, count: u64| Ok(sum / count as i64);
450
451                match time_unit {
452                    TimeUnit::Second => Ok(Box::new(AvgGroupsAccumulator::<
453                        DurationSecondType,
454                        _,
455                    >::new(
456                        data_type,
457                        args.return_type(),
458                        avg_fn,
459                    ))),
460                    TimeUnit::Millisecond => Ok(Box::new(AvgGroupsAccumulator::<
461                        DurationMillisecondType,
462                        _,
463                    >::new(
464                        data_type,
465                        args.return_type(),
466                        avg_fn,
467                    ))),
468                    TimeUnit::Microsecond => Ok(Box::new(AvgGroupsAccumulator::<
469                        DurationMicrosecondType,
470                        _,
471                    >::new(
472                        data_type,
473                        args.return_type(),
474                        avg_fn,
475                    ))),
476                    TimeUnit::Nanosecond => Ok(Box::new(AvgGroupsAccumulator::<
477                        DurationNanosecondType,
478                        _,
479                    >::new(
480                        data_type,
481                        args.return_type(),
482                        avg_fn,
483                    ))),
484                }
485            }
486
487            _ => not_impl_err!(
488                "AvgGroupsAccumulator for ({} --> {})",
489                &data_type,
490                args.return_field.data_type()
491            ),
492        }
493    }
494
    /// Returns the alias list stored on this instance (`Avg::new` sets it to
    /// `["mean"]`).
    fn aliases(&self) -> &[String] {
        &self.aliases
    }
498
    /// Always returns `ReversedUDAF::Identical`.
    // NOTE(review): presumably this tells the planner that `avg` gives the same
    // result when its input order is reversed — confirm against `ReversedUDAF`.
    fn reverse_expr(&self) -> ReversedUDAF {
        ReversedUDAF::Identical
    }
502
    /// Returns the documentation obtained from `self.doc()`.
    // NOTE(review): `doc` is not defined in this file's visible code; it is
    // presumably generated by the `#[user_doc]` attribute on `Avg` — confirm.
    fn documentation(&self) -> Option<&Documentation> {
        self.doc()
    }
506}
507
508/// An accumulator to compute the average
509#[derive(Debug, Default)]
510pub struct AvgAccumulator {
511    sum: Option<f64>,
512    count: u64,
513}
514
515impl Accumulator for AvgAccumulator {
516    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
517        let values = values[0].as_primitive::<Float64Type>();
518        self.count += (values.len() - values.null_count()) as u64;
519        if let Some(x) = sum(values) {
520            let v = self.sum.get_or_insert(0.);
521            *v += x;
522        }
523        Ok(())
524    }
525
526    fn evaluate(&mut self) -> Result<ScalarValue> {
527        Ok(ScalarValue::Float64(
528            self.sum.map(|f| f / self.count as f64),
529        ))
530    }
531
532    fn size(&self) -> usize {
533        size_of_val(self)
534    }
535
536    fn state(&mut self) -> Result<Vec<ScalarValue>> {
537        Ok(vec![
538            ScalarValue::from(self.count),
539            ScalarValue::Float64(self.sum),
540        ])
541    }
542
543    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
544        // counts are summed
545        self.count += sum(states[0].as_primitive::<UInt64Type>()).unwrap_or_default();
546
547        // sums are summed
548        if let Some(x) = sum(states[1].as_primitive::<Float64Type>()) {
549            let v = self.sum.get_or_insert(0.);
550            *v += x;
551        }
552        Ok(())
553    }
554    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
555        let values = values[0].as_primitive::<Float64Type>();
556        self.count -= (values.len() - values.null_count()) as u64;
557        if let Some(x) = sum(values) {
558            self.sum = Some(self.sum.unwrap() - x);
559        }
560        Ok(())
561    }
562
563    fn supports_retract_batch(&self) -> bool {
564        true
565    }
566}
567
568/// An accumulator to compute the average for decimals
569#[derive(Debug)]
570struct DecimalAvgAccumulator<T: DecimalType + ArrowNumericType + Debug> {
571    sum: Option<T::Native>,
572    count: u64,
573    sum_scale: i8,
574    sum_precision: u8,
575    target_precision: u8,
576    target_scale: i8,
577}
578
579impl<T: DecimalType + ArrowNumericType + Debug> Accumulator for DecimalAvgAccumulator<T> {
580    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
581        let values = values[0].as_primitive::<T>();
582        self.count += (values.len() - values.null_count()) as u64;
583
584        if let Some(x) = sum(values) {
585            let v = self.sum.get_or_insert_with(T::Native::default);
586            self.sum = Some(v.add_wrapping(x));
587        }
588        Ok(())
589    }
590
591    fn evaluate(&mut self) -> Result<ScalarValue> {
592        let v = self
593            .sum
594            .map(|v| {
595                DecimalAverager::<T>::try_new(
596                    self.sum_scale,
597                    self.target_precision,
598                    self.target_scale,
599                )?
600                .avg(v, T::Native::from_usize(self.count as usize).unwrap())
601            })
602            .transpose()?;
603
604        ScalarValue::new_primitive::<T>(
605            v,
606            &T::TYPE_CONSTRUCTOR(self.target_precision, self.target_scale),
607        )
608    }
609
610    fn size(&self) -> usize {
611        size_of_val(self)
612    }
613
614    fn state(&mut self) -> Result<Vec<ScalarValue>> {
615        Ok(vec![
616            ScalarValue::from(self.count),
617            ScalarValue::new_primitive::<T>(
618                self.sum,
619                &T::TYPE_CONSTRUCTOR(self.sum_precision, self.sum_scale),
620            )?,
621        ])
622    }
623
624    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
625        // counts are summed
626        self.count += sum(states[0].as_primitive::<UInt64Type>()).unwrap_or_default();
627
628        // sums are summed
629        if let Some(x) = sum(states[1].as_primitive::<T>()) {
630            let v = self.sum.get_or_insert_with(T::Native::default);
631            self.sum = Some(v.add_wrapping(x));
632        }
633        Ok(())
634    }
635    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
636        let values = values[0].as_primitive::<T>();
637        self.count -= (values.len() - values.null_count()) as u64;
638        if let Some(x) = sum(values) {
639            self.sum = Some(self.sum.unwrap().sub_wrapping(x));
640        }
641        Ok(())
642    }
643
644    fn supports_retract_batch(&self) -> bool {
645        true
646    }
647}
648
649/// An accumulator to compute the average for duration values
650#[derive(Debug)]
651struct DurationAvgAccumulator {
652    sum: Option<i64>,
653    count: u64,
654    time_unit: TimeUnit,
655    result_unit: TimeUnit,
656}
657
658impl Accumulator for DurationAvgAccumulator {
659    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
660        let array = &values[0];
661        self.count += (array.len() - array.null_count()) as u64;
662
663        let sum_value = match self.time_unit {
664            TimeUnit::Second => sum(array.as_primitive::<DurationSecondType>()),
665            TimeUnit::Millisecond => sum(array.as_primitive::<DurationMillisecondType>()),
666            TimeUnit::Microsecond => sum(array.as_primitive::<DurationMicrosecondType>()),
667            TimeUnit::Nanosecond => sum(array.as_primitive::<DurationNanosecondType>()),
668        };
669
670        if let Some(x) = sum_value {
671            let v = self.sum.get_or_insert(0);
672            *v += x;
673        }
674        Ok(())
675    }
676
677    fn evaluate(&mut self) -> Result<ScalarValue> {
678        let avg = self.sum.map(|sum| sum / self.count as i64);
679
680        match self.result_unit {
681            TimeUnit::Second => Ok(ScalarValue::DurationSecond(avg)),
682            TimeUnit::Millisecond => Ok(ScalarValue::DurationMillisecond(avg)),
683            TimeUnit::Microsecond => Ok(ScalarValue::DurationMicrosecond(avg)),
684            TimeUnit::Nanosecond => Ok(ScalarValue::DurationNanosecond(avg)),
685        }
686    }
687
688    fn size(&self) -> usize {
689        size_of_val(self)
690    }
691
692    fn state(&mut self) -> Result<Vec<ScalarValue>> {
693        let duration_value = match self.time_unit {
694            TimeUnit::Second => ScalarValue::DurationSecond(self.sum),
695            TimeUnit::Millisecond => ScalarValue::DurationMillisecond(self.sum),
696            TimeUnit::Microsecond => ScalarValue::DurationMicrosecond(self.sum),
697            TimeUnit::Nanosecond => ScalarValue::DurationNanosecond(self.sum),
698        };
699
700        Ok(vec![ScalarValue::from(self.count), duration_value])
701    }
702
703    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
704        self.count += sum(states[0].as_primitive::<UInt64Type>()).unwrap_or_default();
705
706        let sum_value = match self.time_unit {
707            TimeUnit::Second => sum(states[1].as_primitive::<DurationSecondType>()),
708            TimeUnit::Millisecond => {
709                sum(states[1].as_primitive::<DurationMillisecondType>())
710            }
711            TimeUnit::Microsecond => {
712                sum(states[1].as_primitive::<DurationMicrosecondType>())
713            }
714            TimeUnit::Nanosecond => {
715                sum(states[1].as_primitive::<DurationNanosecondType>())
716            }
717        };
718
719        if let Some(x) = sum_value {
720            let v = self.sum.get_or_insert(0);
721            *v += x;
722        }
723        Ok(())
724    }
725
726    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
727        let array = &values[0];
728        self.count -= (array.len() - array.null_count()) as u64;
729
730        let sum_value = match self.time_unit {
731            TimeUnit::Second => sum(array.as_primitive::<DurationSecondType>()),
732            TimeUnit::Millisecond => sum(array.as_primitive::<DurationMillisecondType>()),
733            TimeUnit::Microsecond => sum(array.as_primitive::<DurationMicrosecondType>()),
734            TimeUnit::Nanosecond => sum(array.as_primitive::<DurationNanosecondType>()),
735        };
736
737        if let Some(x) = sum_value {
738            self.sum = Some(self.sum.unwrap() - x);
739        }
740        Ok(())
741    }
742
743    fn supports_retract_batch(&self) -> bool {
744        true
745    }
746}
747
/// A grouped accumulator to compute the average of a [`PrimitiveArray<T>`],
/// keeping one running sum and one count per group.
///
/// Sums are stored as the native type of `T`. The `update_batch`
/// implementation in this file adds to them with `add_wrapping`, i.e. the
/// sums wrap on overflow rather than being overflow-checked.
///
/// F: Function that calculates the average value from a sum of
/// T::Native and a total count
#[derive(Debug)]
struct AvgGroupsAccumulator<T, F>
where
    T: ArrowNumericType + Send,
    F: Fn(T::Native, u64) -> Result<T::Native> + Send,
{
    /// The type of the internal sum
    sum_data_type: DataType,

    /// The type of the returned average (applied to the array built in `evaluate`)
    return_data_type: DataType,

    /// Count per group (use u64 to make UInt64Array)
    counts: Vec<u64>,

    /// Sums per group, stored as the native type
    sums: Vec<T::Native>,

    /// Track nulls in the input / filters
    null_state: NullState,

    /// Function that computes the final average (value / count)
    avg_fn: F,
}
777
778impl<T, F> AvgGroupsAccumulator<T, F>
779where
780    T: ArrowNumericType + Send,
781    F: Fn(T::Native, u64) -> Result<T::Native> + Send,
782{
783    pub fn new(sum_data_type: &DataType, return_data_type: &DataType, avg_fn: F) -> Self {
784        debug!(
785            "AvgGroupsAccumulator ({}, sum type: {sum_data_type}) --> {return_data_type}",
786            std::any::type_name::<T>()
787        );
788
789        Self {
790            return_data_type: return_data_type.clone(),
791            sum_data_type: sum_data_type.clone(),
792            counts: vec![],
793            sums: vec![],
794            null_state: NullState::new(),
795            avg_fn,
796        }
797    }
798}
799
800impl<T, F> GroupsAccumulator for AvgGroupsAccumulator<T, F>
801where
802    T: ArrowNumericType + Send,
803    F: Fn(T::Native, u64) -> Result<T::Native> + Send,
804{
805    fn update_batch(
806        &mut self,
807        values: &[ArrayRef],
808        group_indices: &[usize],
809        opt_filter: Option<&BooleanArray>,
810        total_num_groups: usize,
811    ) -> Result<()> {
812        assert_eq!(values.len(), 1, "single argument to update_batch");
813        let values = values[0].as_primitive::<T>();
814
815        // increment counts, update sums
816        self.counts.resize(total_num_groups, 0);
817        self.sums.resize(total_num_groups, T::default_value());
818        self.null_state.accumulate(
819            group_indices,
820            values,
821            opt_filter,
822            total_num_groups,
823            |group_index, new_value| {
824                let sum = &mut self.sums[group_index];
825                *sum = sum.add_wrapping(new_value);
826
827                self.counts[group_index] += 1;
828            },
829        );
830
831        Ok(())
832    }
833
834    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
835        let counts = emit_to.take_needed(&mut self.counts);
836        let sums = emit_to.take_needed(&mut self.sums);
837        let nulls = self.null_state.build(emit_to);
838
839        assert_eq!(nulls.len(), sums.len());
840        assert_eq!(counts.len(), sums.len());
841
842        // don't evaluate averages with null inputs to avoid errors on null values
843
844        let array: PrimitiveArray<T> = if nulls.null_count() > 0 {
845            let mut builder = PrimitiveBuilder::<T>::with_capacity(nulls.len())
846                .with_data_type(self.return_data_type.clone());
847            let iter = sums.into_iter().zip(counts).zip(nulls.iter());
848
849            for ((sum, count), is_valid) in iter {
850                if is_valid {
851                    builder.append_value((self.avg_fn)(sum, count)?)
852                } else {
853                    builder.append_null();
854                }
855            }
856            builder.finish()
857        } else {
858            let averages: Vec<T::Native> = sums
859                .into_iter()
860                .zip(counts.into_iter())
861                .map(|(sum, count)| (self.avg_fn)(sum, count))
862                .collect::<Result<Vec<_>>>()?;
863            PrimitiveArray::new(averages.into(), Some(nulls)) // no copy
864                .with_data_type(self.return_data_type.clone())
865        };
866
867        Ok(Arc::new(array))
868    }
869
870    // return arrays for sums and counts
871    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
872        let nulls = self.null_state.build(emit_to);
873        let nulls = Some(nulls);
874
875        let counts = emit_to.take_needed(&mut self.counts);
876        let counts = UInt64Array::new(counts.into(), nulls.clone()); // zero copy
877
878        let sums = emit_to.take_needed(&mut self.sums);
879        let sums = PrimitiveArray::<T>::new(sums.into(), nulls) // zero copy
880            .with_data_type(self.sum_data_type.clone());
881
882        Ok(vec![
883            Arc::new(counts) as ArrayRef,
884            Arc::new(sums) as ArrayRef,
885        ])
886    }
887
888    fn merge_batch(
889        &mut self,
890        values: &[ArrayRef],
891        group_indices: &[usize],
892        opt_filter: Option<&BooleanArray>,
893        total_num_groups: usize,
894    ) -> Result<()> {
895        assert_eq!(values.len(), 2, "two arguments to merge_batch");
896        // first batch is counts, second is partial sums
897        let partial_counts = values[0].as_primitive::<UInt64Type>();
898        let partial_sums = values[1].as_primitive::<T>();
899        // update counts with partial counts
900        self.counts.resize(total_num_groups, 0);
901        self.null_state.accumulate(
902            group_indices,
903            partial_counts,
904            opt_filter,
905            total_num_groups,
906            |group_index, partial_count| {
907                self.counts[group_index] += partial_count;
908            },
909        );
910
911        // update sums
912        self.sums.resize(total_num_groups, T::default_value());
913        self.null_state.accumulate(
914            group_indices,
915            partial_sums,
916            opt_filter,
917            total_num_groups,
918            |group_index, new_value: <T as ArrowPrimitiveType>::Native| {
919                let sum = &mut self.sums[group_index];
920                *sum = sum.add_wrapping(new_value);
921            },
922        );
923
924        Ok(())
925    }
926
927    fn convert_to_state(
928        &self,
929        values: &[ArrayRef],
930        opt_filter: Option<&BooleanArray>,
931    ) -> Result<Vec<ArrayRef>> {
932        let sums = values[0]
933            .as_primitive::<T>()
934            .clone()
935            .with_data_type(self.sum_data_type.clone());
936        let counts = UInt64Array::from_value(1, sums.len());
937
938        let nulls = filtered_null_mask(opt_filter, &sums);
939
940        // set nulls on the arrays
941        let counts = set_nulls(counts, nulls.clone());
942        let sums = set_nulls(sums, nulls);
943
944        Ok(vec![Arc::new(counts) as ArrayRef, Arc::new(sums)])
945    }
946
947    fn supports_convert_to_state(&self) -> bool {
948        true
949    }
950
951    fn size(&self) -> usize {
952        self.counts.capacity() * size_of::<u64>() + self.sums.capacity() * size_of::<T>()
953    }
954}