Skip to main content

datafusion_functions_aggregate/
average.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Defines `Avg` & `Mean` aggregate & accumulators
19
20use arrow::array::{
21    Array, ArrayRef, ArrowNativeTypeOp, ArrowNumericType, ArrowPrimitiveType, AsArray,
22    BooleanArray, PrimitiveArray, PrimitiveBuilder, UInt64Array,
23};
24
25use arrow::compute::sum;
26use arrow::datatypes::{
27    ArrowNativeType, DECIMAL32_MAX_PRECISION, DECIMAL32_MAX_SCALE,
28    DECIMAL64_MAX_PRECISION, DECIMAL64_MAX_SCALE, DECIMAL128_MAX_PRECISION,
29    DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, DataType,
30    Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type, DecimalType,
31    DurationMicrosecondType, DurationMillisecondType, DurationNanosecondType,
32    DurationSecondType, Field, FieldRef, Float64Type, TimeUnit, UInt64Type, i256,
33};
34use datafusion_common::types::{NativeType, logical_float64};
35use datafusion_common::{Result, ScalarValue, exec_err, not_impl_err};
36use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
37use datafusion_expr::utils::format_state_name;
38use datafusion_expr::{
39    Accumulator, AggregateUDFImpl, Coercion, Documentation, EmitTo, Expr,
40    GroupsAccumulator, ReversedUDAF, Signature, TypeSignature, TypeSignatureClass,
41    Volatility,
42};
43use datafusion_functions_aggregate_common::aggregate::avg_distinct::{
44    DecimalDistinctAvgAccumulator, Float64DistinctAvgAccumulator,
45};
46use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::NullState;
47use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::{
48    filtered_null_mask, set_nulls,
49};
50use datafusion_functions_aggregate_common::utils::DecimalAverager;
51use datafusion_macros::user_doc;
52use log::debug;
53use std::fmt::Debug;
54use std::mem::{size_of, size_of_val};
55use std::sync::Arc;
56
// Declares the `avg` expression helper and the `avg_udaf()` accessor
// (called by `avg_distinct` below) for the `Avg` aggregate UDF.
make_udaf_expr_and_func!(
    Avg,
    avg,
    expression,
    "Returns the avg of a group of values.",
    avg_udaf
);
64
65pub fn avg_distinct(expr: Expr) -> Expr {
66    Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
67        avg_udaf(),
68        vec![expr],
69        true,
70        None,
71        vec![],
72        None,
73    ))
74}
75
/// The `avg` aggregate UDF (also registered under the alias `mean`).
///
/// The accepted argument types are defined by the signature built in
/// `Avg::new`.
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the average of numeric values in the specified column.",
    syntax_example = "avg(expression)",
    sql_example = r#"```sql
> SELECT avg(column_name) FROM table_name;
+---------------------------+
| avg(column_name)           |
+---------------------------+
| 42.75                      |
+---------------------------+
```"#,
    standard_argument(name = "expression",)
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct Avg {
    // Accepted argument types; constructed in `Avg::new`.
    signature: Signature,
    // Alternate SQL names; `Avg::new` sets this to `["mean"]`.
    aliases: Vec<String>,
}
95
96impl Avg {
97    pub fn new() -> Self {
98        Self {
99            // Supported types smallint, int, bigint, real, double precision, decimal, or interval
100            // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc
101            signature: Signature::one_of(
102                vec![
103                    TypeSignature::Coercible(vec![Coercion::new_exact(
104                        TypeSignatureClass::Decimal,
105                    )]),
106                    TypeSignature::Coercible(vec![Coercion::new_exact(
107                        TypeSignatureClass::Duration,
108                    )]),
109                    TypeSignature::Coercible(vec![Coercion::new_implicit(
110                        TypeSignatureClass::Native(logical_float64()),
111                        vec![TypeSignatureClass::Integer, TypeSignatureClass::Float],
112                        NativeType::Float64,
113                    )]),
114                ],
115                Volatility::Immutable,
116            ),
117            aliases: vec![String::from("mean")],
118        }
119    }
120}
121
122impl Default for Avg {
123    fn default() -> Self {
124        Self::new()
125    }
126}
127
128impl AggregateUDFImpl for Avg {
129    fn name(&self) -> &str {
130        "avg"
131    }
132
133    fn signature(&self) -> &Signature {
134        &self.signature
135    }
136
137    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
138        match &arg_types[0] {
139            DataType::Decimal32(precision, scale) => {
140                // In the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)).
141                // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
142                let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 4);
143                let new_scale = DECIMAL32_MAX_SCALE.min(*scale + 4);
144                Ok(DataType::Decimal32(new_precision, new_scale))
145            }
146            DataType::Decimal64(precision, scale) => {
147                // In the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)).
148                // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
149                let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 4);
150                let new_scale = DECIMAL64_MAX_SCALE.min(*scale + 4);
151                Ok(DataType::Decimal64(new_precision, new_scale))
152            }
153            DataType::Decimal128(precision, scale) => {
154                // In the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)).
155                // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
156                let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 4);
157                let new_scale = DECIMAL128_MAX_SCALE.min(*scale + 4);
158                Ok(DataType::Decimal128(new_precision, new_scale))
159            }
160            DataType::Decimal256(precision, scale) => {
161                // In the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)).
162                // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
163                let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 4);
164                let new_scale = DECIMAL256_MAX_SCALE.min(*scale + 4);
165                Ok(DataType::Decimal256(new_precision, new_scale))
166            }
167            DataType::Duration(time_unit) => Ok(DataType::Duration(*time_unit)),
168            _ => Ok(DataType::Float64),
169        }
170    }
171
172    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
173        let data_type = acc_args.expr_fields[0].data_type();
174        use DataType::*;
175
176        // instantiate specialized accumulator based for the type
177        if acc_args.is_distinct {
178            match (data_type, acc_args.return_type()) {
179                // Numeric types are converted to Float64 via `coerce_avg_type` during logical plan creation
180                (Float64, _) => Ok(Box::new(Float64DistinctAvgAccumulator::default())),
181
182                (
183                    Decimal32(_, scale),
184                    Decimal32(target_precision, target_scale),
185                ) => Ok(Box::new(DecimalDistinctAvgAccumulator::<Decimal32Type>::with_decimal_params(
186                    *scale,
187                    *target_precision,
188                    *target_scale,
189                ))),
190                                (
191                    Decimal64(_, scale),
192                    Decimal64(target_precision, target_scale),
193                ) => Ok(Box::new(DecimalDistinctAvgAccumulator::<Decimal64Type>::with_decimal_params(
194                    *scale,
195                    *target_precision,
196                    *target_scale,
197                ))),
198                (
199                    Decimal128(_, scale),
200                    Decimal128(target_precision, target_scale),
201                ) => Ok(Box::new(DecimalDistinctAvgAccumulator::<Decimal128Type>::with_decimal_params(
202                    *scale,
203                    *target_precision,
204                    *target_scale,
205                ))),
206
207                (
208                    Decimal256(_, scale),
209                    Decimal256(target_precision, target_scale),
210                ) => Ok(Box::new(DecimalDistinctAvgAccumulator::<Decimal256Type>::with_decimal_params(
211                    *scale,
212                    *target_precision,
213                    *target_scale,
214                ))),
215
216                (dt, return_type) => exec_err!(
217                    "AVG(DISTINCT) for ({} --> {}) not supported",
218                    dt,
219                    return_type
220                ),
221            }
222        } else {
223            match (&data_type, acc_args.return_type()) {
224                (Float64, Float64) => Ok(Box::<AvgAccumulator>::default()),
225                (
226                    Decimal32(sum_precision, sum_scale),
227                    Decimal32(target_precision, target_scale),
228                ) => Ok(Box::new(DecimalAvgAccumulator::<Decimal32Type> {
229                    sum: None,
230                    count: 0,
231                    sum_scale: *sum_scale,
232                    sum_precision: *sum_precision,
233                    target_precision: *target_precision,
234                    target_scale: *target_scale,
235                })),
236                (
237                    Decimal64(sum_precision, sum_scale),
238                    Decimal64(target_precision, target_scale),
239                ) => Ok(Box::new(DecimalAvgAccumulator::<Decimal64Type> {
240                    sum: None,
241                    count: 0,
242                    sum_scale: *sum_scale,
243                    sum_precision: *sum_precision,
244                    target_precision: *target_precision,
245                    target_scale: *target_scale,
246                })),
247                (
248                    Decimal128(sum_precision, sum_scale),
249                    Decimal128(target_precision, target_scale),
250                ) => Ok(Box::new(DecimalAvgAccumulator::<Decimal128Type> {
251                    sum: None,
252                    count: 0,
253                    sum_scale: *sum_scale,
254                    sum_precision: *sum_precision,
255                    target_precision: *target_precision,
256                    target_scale: *target_scale,
257                })),
258
259                (
260                    Decimal256(sum_precision, sum_scale),
261                    Decimal256(target_precision, target_scale),
262                ) => Ok(Box::new(DecimalAvgAccumulator::<Decimal256Type> {
263                    sum: None,
264                    count: 0,
265                    sum_scale: *sum_scale,
266                    sum_precision: *sum_precision,
267                    target_precision: *target_precision,
268                    target_scale: *target_scale,
269                })),
270
271                (Duration(time_unit), Duration(result_unit)) => {
272                    Ok(Box::new(DurationAvgAccumulator {
273                        sum: None,
274                        count: 0,
275                        time_unit: *time_unit,
276                        result_unit: *result_unit,
277                    }))
278                }
279
280                (dt, return_type) => {
281                    exec_err!("AvgAccumulator for ({} --> {})", dt, return_type)
282                }
283            }
284        }
285    }
286
287    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
288        if args.is_distinct {
289            // Decimal accumulator actually uses a different precision during accumulation,
290            // see DecimalDistinctAvgAccumulator::with_decimal_params
291            let dt = match args.input_fields[0].data_type() {
292                DataType::Decimal32(_, scale) => {
293                    DataType::Decimal32(DECIMAL32_MAX_PRECISION, *scale)
294                }
295                DataType::Decimal64(_, scale) => {
296                    DataType::Decimal64(DECIMAL64_MAX_PRECISION, *scale)
297                }
298                DataType::Decimal128(_, scale) => {
299                    DataType::Decimal128(DECIMAL128_MAX_PRECISION, *scale)
300                }
301                DataType::Decimal256(_, scale) => {
302                    DataType::Decimal256(DECIMAL256_MAX_PRECISION, *scale)
303                }
304                _ => args.return_type().clone(),
305            };
306            // Similar to datafusion_functions_aggregate::sum::Sum::state_fields
307            // since the accumulator uses DistinctSumAccumulator internally.
308            Ok(vec![
309                Field::new_list(
310                    format_state_name(args.name, "avg distinct"),
311                    Field::new_list_field(dt, true),
312                    false,
313                )
314                .into(),
315            ])
316        } else {
317            Ok(vec![
318                Field::new(
319                    format_state_name(args.name, "count"),
320                    DataType::UInt64,
321                    true,
322                ),
323                Field::new(
324                    format_state_name(args.name, "sum"),
325                    args.input_fields[0].data_type().clone(),
326                    true,
327                ),
328            ]
329            .into_iter()
330            .map(Arc::new)
331            .collect())
332        }
333    }
334
335    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
336        matches!(
337            args.return_field.data_type(),
338            DataType::Float64
339                | DataType::Decimal32(_, _)
340                | DataType::Decimal64(_, _)
341                | DataType::Decimal128(_, _)
342                | DataType::Decimal256(_, _)
343                | DataType::Duration(_)
344        ) && !args.is_distinct
345    }
346
347    fn create_groups_accumulator(
348        &self,
349        args: AccumulatorArgs,
350    ) -> Result<Box<dyn GroupsAccumulator>> {
351        use DataType::*;
352
353        let data_type = args.expr_fields[0].data_type();
354
355        // instantiate specialized accumulator based for the type
356        match (data_type, args.return_field.data_type()) {
357            (Float64, Float64) => {
358                Ok(Box::new(AvgGroupsAccumulator::<Float64Type, _>::new(
359                    data_type,
360                    args.return_field.data_type(),
361                    |sum: f64, count: u64| Ok(sum / count as f64),
362                )))
363            }
364            (
365                Decimal32(_sum_precision, sum_scale),
366                Decimal32(target_precision, target_scale),
367            ) => {
368                let decimal_averager = DecimalAverager::<Decimal32Type>::try_new(
369                    *sum_scale,
370                    *target_precision,
371                    *target_scale,
372                )?;
373
374                let avg_fn =
375                    move |sum: i32, count: u64| decimal_averager.avg(sum, count as i32);
376
377                Ok(Box::new(AvgGroupsAccumulator::<Decimal32Type, _>::new(
378                    data_type,
379                    args.return_field.data_type(),
380                    avg_fn,
381                )))
382            }
383            (
384                Decimal64(_sum_precision, sum_scale),
385                Decimal64(target_precision, target_scale),
386            ) => {
387                let decimal_averager = DecimalAverager::<Decimal64Type>::try_new(
388                    *sum_scale,
389                    *target_precision,
390                    *target_scale,
391                )?;
392
393                let avg_fn =
394                    move |sum: i64, count: u64| decimal_averager.avg(sum, count as i64);
395
396                Ok(Box::new(AvgGroupsAccumulator::<Decimal64Type, _>::new(
397                    data_type,
398                    args.return_field.data_type(),
399                    avg_fn,
400                )))
401            }
402            (
403                Decimal128(_sum_precision, sum_scale),
404                Decimal128(target_precision, target_scale),
405            ) => {
406                let decimal_averager = DecimalAverager::<Decimal128Type>::try_new(
407                    *sum_scale,
408                    *target_precision,
409                    *target_scale,
410                )?;
411
412                let avg_fn =
413                    move |sum: i128, count: u64| decimal_averager.avg(sum, count as i128);
414
415                Ok(Box::new(AvgGroupsAccumulator::<Decimal128Type, _>::new(
416                    data_type,
417                    args.return_field.data_type(),
418                    avg_fn,
419                )))
420            }
421
422            (
423                Decimal256(_sum_precision, sum_scale),
424                Decimal256(target_precision, target_scale),
425            ) => {
426                let decimal_averager = DecimalAverager::<Decimal256Type>::try_new(
427                    *sum_scale,
428                    *target_precision,
429                    *target_scale,
430                )?;
431
432                let avg_fn = move |sum: i256, count: u64| {
433                    decimal_averager.avg(sum, i256::from_usize(count as usize).unwrap())
434                };
435
436                Ok(Box::new(AvgGroupsAccumulator::<Decimal256Type, _>::new(
437                    data_type,
438                    args.return_field.data_type(),
439                    avg_fn,
440                )))
441            }
442
443            (Duration(time_unit), Duration(_result_unit)) => {
444                let avg_fn = move |sum: i64, count: u64| Ok(sum / count as i64);
445
446                match time_unit {
447                    TimeUnit::Second => Ok(Box::new(AvgGroupsAccumulator::<
448                        DurationSecondType,
449                        _,
450                    >::new(
451                        data_type,
452                        args.return_type(),
453                        avg_fn,
454                    ))),
455                    TimeUnit::Millisecond => Ok(Box::new(AvgGroupsAccumulator::<
456                        DurationMillisecondType,
457                        _,
458                    >::new(
459                        data_type,
460                        args.return_type(),
461                        avg_fn,
462                    ))),
463                    TimeUnit::Microsecond => Ok(Box::new(AvgGroupsAccumulator::<
464                        DurationMicrosecondType,
465                        _,
466                    >::new(
467                        data_type,
468                        args.return_type(),
469                        avg_fn,
470                    ))),
471                    TimeUnit::Nanosecond => Ok(Box::new(AvgGroupsAccumulator::<
472                        DurationNanosecondType,
473                        _,
474                    >::new(
475                        data_type,
476                        args.return_type(),
477                        avg_fn,
478                    ))),
479                }
480            }
481
482            _ => not_impl_err!(
483                "AvgGroupsAccumulator for ({} --> {})",
484                &data_type,
485                args.return_field.data_type()
486            ),
487        }
488    }
489
490    fn aliases(&self) -> &[String] {
491        &self.aliases
492    }
493
494    fn reverse_expr(&self) -> ReversedUDAF {
495        ReversedUDAF::Identical
496    }
497
498    fn documentation(&self) -> Option<&Documentation> {
499        self.doc()
500    }
501}
502
/// An accumulator to compute the average
///
/// Used by `Avg::accumulator` for non-distinct `Float64` input.
#[derive(Debug, Default)]
pub struct AvgAccumulator {
    // Running sum; `None` until a batch containing a non-null value is added.
    sum: Option<f64>,
    // Number of non-null values accumulated (minus any retracted).
    count: u64,
}
509
510impl Accumulator for AvgAccumulator {
511    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
512        let values = values[0].as_primitive::<Float64Type>();
513        self.count += (values.len() - values.null_count()) as u64;
514        if let Some(x) = sum(values) {
515            let v = self.sum.get_or_insert(0.);
516            *v += x;
517        }
518        Ok(())
519    }
520
521    fn evaluate(&mut self) -> Result<ScalarValue> {
522        // In sliding-window mode `retract_batch` can bring `count` back to 0
523        // while `sum` remains `Some(..)` (possibly zero or a floating-point
524        // residual). Guard against that so the frame with no non-NULL values
525        // yields NULL rather than NaN / ±Inf.
526        let avg = if self.count == 0 {
527            None
528        } else {
529            self.sum.map(|f| f / self.count as f64)
530        };
531        Ok(ScalarValue::Float64(avg))
532    }
533
534    fn size(&self) -> usize {
535        size_of_val(self)
536    }
537
538    fn state(&mut self) -> Result<Vec<ScalarValue>> {
539        Ok(vec![
540            ScalarValue::from(self.count),
541            ScalarValue::Float64(self.sum),
542        ])
543    }
544
545    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
546        // counts are summed
547        self.count += sum(states[0].as_primitive::<UInt64Type>()).unwrap_or_default();
548
549        // sums are summed
550        if let Some(x) = sum(states[1].as_primitive::<Float64Type>()) {
551            let v = self.sum.get_or_insert(0.);
552            *v += x;
553        }
554        Ok(())
555    }
556    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
557        let values = values[0].as_primitive::<Float64Type>();
558        self.count -= (values.len() - values.null_count()) as u64;
559        if let Some(x) = sum(values) {
560            self.sum = Some(self.sum.unwrap() - x);
561        }
562        Ok(())
563    }
564
565    fn supports_retract_batch(&self) -> bool {
566        true
567    }
568}
569
/// An accumulator to compute the average for decimals
///
/// Created by `Avg::accumulator` for non-distinct decimal input whose return
/// type is a decimal of the same width.
#[derive(Debug)]
struct DecimalAvgAccumulator<T: DecimalType + ArrowNumericType + Debug> {
    // Running sum in the input's native representation; `None` until a batch
    // containing a non-null value is added.
    sum: Option<T::Native>,
    // Number of non-null values accumulated (minus any retracted).
    count: u64,
    // Scale of the input / running sum (also the scale of the `sum` state).
    sum_scale: i8,
    // Precision of the input, used for the type of the `sum` state.
    sum_precision: u8,
    // Precision of the result produced by `evaluate`.
    target_precision: u8,
    // Scale of the result produced by `evaluate`.
    target_scale: i8,
}
580
581impl<T: DecimalType + ArrowNumericType + Debug> Accumulator for DecimalAvgAccumulator<T> {
582    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
583        let values = values[0].as_primitive::<T>();
584        self.count += (values.len() - values.null_count()) as u64;
585
586        if let Some(x) = sum(values) {
587            let v = self.sum.get_or_insert_with(T::Native::default);
588            self.sum = Some(v.add_wrapping(x));
589        }
590        Ok(())
591    }
592
593    fn evaluate(&mut self) -> Result<ScalarValue> {
594        // `count == 0` can occur in sliding-window mode after `retract_batch`
595        // removes every contributing value. Return NULL rather than dividing
596        // by zero (which would panic for integer decimal types).
597        let v = if self.count == 0 {
598            None
599        } else {
600            self.sum
601                .map(|v| {
602                    DecimalAverager::<T>::try_new(
603                        self.sum_scale,
604                        self.target_precision,
605                        self.target_scale,
606                    )?
607                    .avg(v, T::Native::from_usize(self.count as usize).unwrap())
608                })
609                .transpose()?
610        };
611
612        ScalarValue::new_primitive::<T>(
613            v,
614            &T::TYPE_CONSTRUCTOR(self.target_precision, self.target_scale),
615        )
616    }
617
618    fn size(&self) -> usize {
619        size_of_val(self)
620    }
621
622    fn state(&mut self) -> Result<Vec<ScalarValue>> {
623        Ok(vec![
624            ScalarValue::from(self.count),
625            ScalarValue::new_primitive::<T>(
626                self.sum,
627                &T::TYPE_CONSTRUCTOR(self.sum_precision, self.sum_scale),
628            )?,
629        ])
630    }
631
632    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
633        // counts are summed
634        self.count += sum(states[0].as_primitive::<UInt64Type>()).unwrap_or_default();
635
636        // sums are summed
637        if let Some(x) = sum(states[1].as_primitive::<T>()) {
638            let v = self.sum.get_or_insert_with(T::Native::default);
639            self.sum = Some(v.add_wrapping(x));
640        }
641        Ok(())
642    }
643    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
644        let values = values[0].as_primitive::<T>();
645        self.count -= (values.len() - values.null_count()) as u64;
646        if let Some(x) = sum(values) {
647            self.sum = Some(self.sum.unwrap().sub_wrapping(x));
648        }
649        Ok(())
650    }
651
652    fn supports_retract_batch(&self) -> bool {
653        true
654    }
655}
656
/// An accumulator to compute the average for duration values
#[derive(Debug)]
struct DurationAvgAccumulator {
    // Running sum of the raw `i64` values in `time_unit`; `None` until a batch
    // containing a non-null value is added.
    sum: Option<i64>,
    // Number of non-null values accumulated (minus any retracted).
    count: u64,
    // Unit of the input values; selects which duration array type is read.
    time_unit: TimeUnit,
    // Unit used to label the average in `evaluate`. No conversion is applied;
    // `Avg::return_type` returns the input unit, so the two normally match.
    result_unit: TimeUnit,
}
665
666impl Accumulator for DurationAvgAccumulator {
667    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
668        let array = &values[0];
669        self.count += (array.len() - array.null_count()) as u64;
670
671        let sum_value = match self.time_unit {
672            TimeUnit::Second => sum(array.as_primitive::<DurationSecondType>()),
673            TimeUnit::Millisecond => sum(array.as_primitive::<DurationMillisecondType>()),
674            TimeUnit::Microsecond => sum(array.as_primitive::<DurationMicrosecondType>()),
675            TimeUnit::Nanosecond => sum(array.as_primitive::<DurationNanosecondType>()),
676        };
677
678        if let Some(x) = sum_value {
679            let v = self.sum.get_or_insert(0);
680            *v += x;
681        }
682        Ok(())
683    }
684
685    fn evaluate(&mut self) -> Result<ScalarValue> {
686        // Guard against `count == 0` which can happen in sliding-window mode
687        // after every contributing value has been retracted. Without this
688        // check we would integer-divide by zero.
689        let avg = if self.count == 0 {
690            None
691        } else {
692            self.sum.map(|sum| sum / self.count as i64)
693        };
694
695        match self.result_unit {
696            TimeUnit::Second => Ok(ScalarValue::DurationSecond(avg)),
697            TimeUnit::Millisecond => Ok(ScalarValue::DurationMillisecond(avg)),
698            TimeUnit::Microsecond => Ok(ScalarValue::DurationMicrosecond(avg)),
699            TimeUnit::Nanosecond => Ok(ScalarValue::DurationNanosecond(avg)),
700        }
701    }
702
703    fn size(&self) -> usize {
704        size_of_val(self)
705    }
706
707    fn state(&mut self) -> Result<Vec<ScalarValue>> {
708        let duration_value = match self.time_unit {
709            TimeUnit::Second => ScalarValue::DurationSecond(self.sum),
710            TimeUnit::Millisecond => ScalarValue::DurationMillisecond(self.sum),
711            TimeUnit::Microsecond => ScalarValue::DurationMicrosecond(self.sum),
712            TimeUnit::Nanosecond => ScalarValue::DurationNanosecond(self.sum),
713        };
714
715        Ok(vec![ScalarValue::from(self.count), duration_value])
716    }
717
718    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
719        self.count += sum(states[0].as_primitive::<UInt64Type>()).unwrap_or_default();
720
721        let sum_value = match self.time_unit {
722            TimeUnit::Second => sum(states[1].as_primitive::<DurationSecondType>()),
723            TimeUnit::Millisecond => {
724                sum(states[1].as_primitive::<DurationMillisecondType>())
725            }
726            TimeUnit::Microsecond => {
727                sum(states[1].as_primitive::<DurationMicrosecondType>())
728            }
729            TimeUnit::Nanosecond => {
730                sum(states[1].as_primitive::<DurationNanosecondType>())
731            }
732        };
733
734        if let Some(x) = sum_value {
735            let v = self.sum.get_or_insert(0);
736            *v += x;
737        }
738        Ok(())
739    }
740
741    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
742        let array = &values[0];
743        self.count -= (array.len() - array.null_count()) as u64;
744
745        let sum_value = match self.time_unit {
746            TimeUnit::Second => sum(array.as_primitive::<DurationSecondType>()),
747            TimeUnit::Millisecond => sum(array.as_primitive::<DurationMillisecondType>()),
748            TimeUnit::Microsecond => sum(array.as_primitive::<DurationMicrosecondType>()),
749            TimeUnit::Nanosecond => sum(array.as_primitive::<DurationNanosecondType>()),
750        };
751
752        if let Some(x) = sum_value {
753            self.sum = Some(self.sum.unwrap() - x);
754        }
755        Ok(())
756    }
757
758    fn supports_retract_batch(&self) -> bool {
759        true
760    }
761}
762
/// An accumulator to compute the average of `[PrimitiveArray<T>]`.
/// Stores values as native types. `update_batch` adds into the per-group sums
/// with `add_wrapping`, so an overflowing sum wraps rather than erroring.
///
/// F: Function that calculates the average value from a sum of
/// T::Native and a total count
#[derive(Debug)]
struct AvgGroupsAccumulator<T, F>
where
    T: ArrowNumericType + Send,
    F: Fn(T::Native, u64) -> Result<T::Native> + Send + 'static,
{
    /// The type of the internal sum
    sum_data_type: DataType,

    /// The type of the returned average
    return_data_type: DataType,

    /// Count per group (use u64 to make UInt64Array)
    counts: Vec<u64>,

    /// Sums per group, stored as the native type
    sums: Vec<T::Native>,

    /// Track nulls in the input / filters
    null_state: NullState,

    /// Function that computes the final average (value / count)
    avg_fn: F,
}
792
793impl<T, F> AvgGroupsAccumulator<T, F>
794where
795    T: ArrowNumericType + Send,
796    F: Fn(T::Native, u64) -> Result<T::Native> + Send + 'static,
797{
798    pub fn new(sum_data_type: &DataType, return_data_type: &DataType, avg_fn: F) -> Self {
799        debug!(
800            "AvgGroupsAccumulator ({}, sum type: {sum_data_type}) --> {return_data_type}",
801            std::any::type_name::<T>()
802        );
803
804        Self {
805            return_data_type: return_data_type.clone(),
806            sum_data_type: sum_data_type.clone(),
807            counts: vec![],
808            sums: vec![],
809            null_state: NullState::new(),
810            avg_fn,
811        }
812    }
813}
814
815impl<T, F> GroupsAccumulator for AvgGroupsAccumulator<T, F>
816where
817    T: ArrowNumericType + Send,
818    F: Fn(T::Native, u64) -> Result<T::Native> + Send + 'static,
819{
820    fn update_batch(
821        &mut self,
822        values: &[ArrayRef],
823        group_indices: &[usize],
824        opt_filter: Option<&BooleanArray>,
825        total_num_groups: usize,
826    ) -> Result<()> {
827        assert_eq!(values.len(), 1, "single argument to update_batch");
828        let values = values[0].as_primitive::<T>();
829
830        // increment counts, update sums
831        self.counts.resize(total_num_groups, 0);
832        self.sums.resize(total_num_groups, T::default_value());
833        self.null_state.accumulate(
834            group_indices,
835            values,
836            opt_filter,
837            total_num_groups,
838            |group_index, new_value| {
839                // SAFETY: group_index is guaranteed to be in bounds
840                let sum = unsafe { self.sums.get_unchecked_mut(group_index) };
841                *sum = sum.add_wrapping(new_value);
842
843                self.counts[group_index] += 1;
844            },
845        );
846
847        Ok(())
848    }
849
850    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
851        let counts = emit_to.take_needed(&mut self.counts);
852        let sums = emit_to.take_needed(&mut self.sums);
853        let nulls = self.null_state.build(emit_to);
854
855        if let Some(nulls) = &nulls {
856            assert_eq!(nulls.len(), sums.len());
857        }
858        assert_eq!(counts.len(), sums.len());
859
860        // don't evaluate averages with null inputs to avoid errors on null values
861
862        let array: PrimitiveArray<T> = if let Some(nulls) = &nulls
863            && nulls.null_count() > 0
864        {
865            let mut builder = PrimitiveBuilder::<T>::with_capacity(nulls.len())
866                .with_data_type(self.return_data_type.clone());
867            let iter = sums.into_iter().zip(counts).zip(nulls.iter());
868
869            for ((sum, count), is_valid) in iter {
870                if is_valid {
871                    builder.append_value((self.avg_fn)(sum, count)?)
872                } else {
873                    builder.append_null();
874                }
875            }
876            builder.finish()
877        } else {
878            let averages: Vec<T::Native> = sums
879                .into_iter()
880                .zip(counts)
881                .map(|(sum, count)| (self.avg_fn)(sum, count))
882                .collect::<Result<Vec<_>>>()?;
883            PrimitiveArray::new(averages.into(), nulls) // no copy
884                .with_data_type(self.return_data_type.clone())
885        };
886
887        Ok(Arc::new(array))
888    }
889
890    // return arrays for sums and counts
891    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
892        let nulls = self.null_state.build(emit_to);
893
894        let counts = emit_to.take_needed(&mut self.counts);
895        let counts = UInt64Array::new(counts.into(), nulls.clone()); // zero copy
896
897        let sums = emit_to.take_needed(&mut self.sums);
898        let sums = PrimitiveArray::<T>::new(sums.into(), nulls) // zero copy
899            .with_data_type(self.sum_data_type.clone());
900
901        Ok(vec![
902            Arc::new(counts) as ArrayRef,
903            Arc::new(sums) as ArrayRef,
904        ])
905    }
906
907    fn merge_batch(
908        &mut self,
909        values: &[ArrayRef],
910        group_indices: &[usize],
911        opt_filter: Option<&BooleanArray>,
912        total_num_groups: usize,
913    ) -> Result<()> {
914        assert_eq!(values.len(), 2, "two arguments to merge_batch");
915        // first batch is counts, second is partial sums
916        let partial_counts = values[0].as_primitive::<UInt64Type>();
917        let partial_sums = values[1].as_primitive::<T>();
918        // update counts with partial counts
919        self.counts.resize(total_num_groups, 0);
920        self.null_state.accumulate(
921            group_indices,
922            partial_counts,
923            opt_filter,
924            total_num_groups,
925            |group_index, partial_count| {
926                // SAFETY: group_index is guaranteed to be in bounds
927                let count = unsafe { self.counts.get_unchecked_mut(group_index) };
928                *count += partial_count;
929            },
930        );
931
932        // update sums
933        self.sums.resize(total_num_groups, T::default_value());
934        self.null_state.accumulate(
935            group_indices,
936            partial_sums,
937            opt_filter,
938            total_num_groups,
939            |group_index, new_value: <T as ArrowPrimitiveType>::Native| {
940                // SAFETY: group_index is guaranteed to be in bounds
941                let sum = unsafe { self.sums.get_unchecked_mut(group_index) };
942                *sum = sum.add_wrapping(new_value);
943            },
944        );
945
946        Ok(())
947    }
948
949    fn convert_to_state(
950        &self,
951        values: &[ArrayRef],
952        opt_filter: Option<&BooleanArray>,
953    ) -> Result<Vec<ArrayRef>> {
954        let sums = values[0]
955            .as_primitive::<T>()
956            .clone()
957            .with_data_type(self.sum_data_type.clone());
958        let counts = UInt64Array::from_value(1, sums.len());
959
960        let nulls = filtered_null_mask(opt_filter, &sums);
961
962        // set nulls on the arrays
963        let counts = set_nulls(counts, nulls.clone());
964        let sums = set_nulls(sums, nulls);
965
966        Ok(vec![Arc::new(counts) as ArrayRef, Arc::new(sums)])
967    }
968
969    fn supports_convert_to_state(&self) -> bool {
970        true
971    }
972
973    fn size(&self) -> usize {
974        self.counts.capacity() * size_of::<u64>() + self.sums.capacity() * size_of::<T>()
975    }
976}