datafusion_functions_aggregate/
average.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Defines `Avg` & `Mean` aggregate & accumulators
19
20use arrow::array::{
21    Array, ArrayRef, ArrowNativeTypeOp, ArrowNumericType, ArrowPrimitiveType, AsArray,
22    BooleanArray, PrimitiveArray, PrimitiveBuilder, UInt64Array,
23};
24
25use arrow::compute::sum;
26use arrow::datatypes::{
27    i256, ArrowNativeType, DataType, Decimal128Type, Decimal256Type, DecimalType, Field,
28    Float64Type, UInt64Type,
29};
30use datafusion_common::{
31    exec_err, not_impl_err, utils::take_function_args, Result, ScalarValue,
32};
33use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
34use datafusion_expr::type_coercion::aggregates::{avg_return_type, coerce_avg_type};
35use datafusion_expr::utils::format_state_name;
36use datafusion_expr::Volatility::Immutable;
37use datafusion_expr::{
38    Accumulator, AggregateUDFImpl, Documentation, EmitTo, GroupsAccumulator,
39    ReversedUDAF, Signature,
40};
41
42use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::NullState;
43use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::{
44    filtered_null_mask, set_nulls,
45};
46
47use datafusion_functions_aggregate_common::utils::DecimalAverager;
48use datafusion_macros::user_doc;
49use log::debug;
50use std::any::Any;
51use std::fmt::Debug;
52use std::mem::{size_of, size_of_val};
53use std::sync::Arc;
54
55make_udaf_expr_and_func!(
56    Avg,
57    avg,
58    expression,
59    "Returns the avg of a group of values.",
60    avg_udaf
61);
62
63#[user_doc(
64    doc_section(label = "General Functions"),
65    description = "Returns the average of numeric values in the specified column.",
66    syntax_example = "avg(expression)",
67    sql_example = r#"```sql
68> SELECT avg(column_name) FROM table_name;
69+---------------------------+
70| avg(column_name)           |
71+---------------------------+
72| 42.75                      |
73+---------------------------+
74```"#,
75    standard_argument(name = "expression",)
76)]
77#[derive(Debug)]
78pub struct Avg {
79    signature: Signature,
80    aliases: Vec<String>,
81}
82
83impl Avg {
84    pub fn new() -> Self {
85        Self {
86            signature: Signature::user_defined(Immutable),
87            aliases: vec![String::from("mean")],
88        }
89    }
90}
91
92impl Default for Avg {
93    fn default() -> Self {
94        Self::new()
95    }
96}
97
98impl AggregateUDFImpl for Avg {
99    fn as_any(&self) -> &dyn Any {
100        self
101    }
102
103    fn name(&self) -> &str {
104        "avg"
105    }
106
107    fn signature(&self) -> &Signature {
108        &self.signature
109    }
110
111    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
112        avg_return_type(self.name(), &arg_types[0])
113    }
114
115    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
116        if acc_args.is_distinct {
117            return exec_err!("avg(DISTINCT) aggregations are not available");
118        }
119        use DataType::*;
120
121        let data_type = acc_args.exprs[0].data_type(acc_args.schema)?;
122        // instantiate specialized accumulator based for the type
123        match (&data_type, acc_args.return_type) {
124            (Float64, Float64) => Ok(Box::<AvgAccumulator>::default()),
125            (
126                Decimal128(sum_precision, sum_scale),
127                Decimal128(target_precision, target_scale),
128            ) => Ok(Box::new(DecimalAvgAccumulator::<Decimal128Type> {
129                sum: None,
130                count: 0,
131                sum_scale: *sum_scale,
132                sum_precision: *sum_precision,
133                target_precision: *target_precision,
134                target_scale: *target_scale,
135            })),
136
137            (
138                Decimal256(sum_precision, sum_scale),
139                Decimal256(target_precision, target_scale),
140            ) => Ok(Box::new(DecimalAvgAccumulator::<Decimal256Type> {
141                sum: None,
142                count: 0,
143                sum_scale: *sum_scale,
144                sum_precision: *sum_precision,
145                target_precision: *target_precision,
146                target_scale: *target_scale,
147            })),
148            _ => exec_err!(
149                "AvgAccumulator for ({} --> {})",
150                &data_type,
151                acc_args.return_type
152            ),
153        }
154    }
155
156    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
157        Ok(vec![
158            Field::new(
159                format_state_name(args.name, "count"),
160                DataType::UInt64,
161                true,
162            ),
163            Field::new(
164                format_state_name(args.name, "sum"),
165                args.input_types[0].clone(),
166                true,
167            ),
168        ])
169    }
170
171    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
172        matches!(
173            args.return_type,
174            DataType::Float64 | DataType::Decimal128(_, _)
175        )
176    }
177
178    fn create_groups_accumulator(
179        &self,
180        args: AccumulatorArgs,
181    ) -> Result<Box<dyn GroupsAccumulator>> {
182        use DataType::*;
183
184        let data_type = args.exprs[0].data_type(args.schema)?;
185        // instantiate specialized accumulator based for the type
186        match (&data_type, args.return_type) {
187            (Float64, Float64) => {
188                Ok(Box::new(AvgGroupsAccumulator::<Float64Type, _>::new(
189                    &data_type,
190                    args.return_type,
191                    |sum: f64, count: u64| Ok(sum / count as f64),
192                )))
193            }
194            (
195                Decimal128(_sum_precision, sum_scale),
196                Decimal128(target_precision, target_scale),
197            ) => {
198                let decimal_averager = DecimalAverager::<Decimal128Type>::try_new(
199                    *sum_scale,
200                    *target_precision,
201                    *target_scale,
202                )?;
203
204                let avg_fn =
205                    move |sum: i128, count: u64| decimal_averager.avg(sum, count as i128);
206
207                Ok(Box::new(AvgGroupsAccumulator::<Decimal128Type, _>::new(
208                    &data_type,
209                    args.return_type,
210                    avg_fn,
211                )))
212            }
213
214            (
215                Decimal256(_sum_precision, sum_scale),
216                Decimal256(target_precision, target_scale),
217            ) => {
218                let decimal_averager = DecimalAverager::<Decimal256Type>::try_new(
219                    *sum_scale,
220                    *target_precision,
221                    *target_scale,
222                )?;
223
224                let avg_fn = move |sum: i256, count: u64| {
225                    decimal_averager.avg(sum, i256::from_usize(count as usize).unwrap())
226                };
227
228                Ok(Box::new(AvgGroupsAccumulator::<Decimal256Type, _>::new(
229                    &data_type,
230                    args.return_type,
231                    avg_fn,
232                )))
233            }
234
235            _ => not_impl_err!(
236                "AvgGroupsAccumulator for ({} --> {})",
237                &data_type,
238                args.return_type
239            ),
240        }
241    }
242
243    fn aliases(&self) -> &[String] {
244        &self.aliases
245    }
246
247    fn reverse_expr(&self) -> ReversedUDAF {
248        ReversedUDAF::Identical
249    }
250
251    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
252        let [args] = take_function_args(self.name(), arg_types)?;
253        coerce_avg_type(self.name(), std::slice::from_ref(args))
254    }
255
256    fn documentation(&self) -> Option<&Documentation> {
257        self.doc()
258    }
259}
260
261/// An accumulator to compute the average
262#[derive(Debug, Default)]
263pub struct AvgAccumulator {
264    sum: Option<f64>,
265    count: u64,
266}
267
268impl Accumulator for AvgAccumulator {
269    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
270        let values = values[0].as_primitive::<Float64Type>();
271        self.count += (values.len() - values.null_count()) as u64;
272        if let Some(x) = sum(values) {
273            let v = self.sum.get_or_insert(0.);
274            *v += x;
275        }
276        Ok(())
277    }
278
279    fn evaluate(&mut self) -> Result<ScalarValue> {
280        Ok(ScalarValue::Float64(
281            self.sum.map(|f| f / self.count as f64),
282        ))
283    }
284
285    fn size(&self) -> usize {
286        size_of_val(self)
287    }
288
289    fn state(&mut self) -> Result<Vec<ScalarValue>> {
290        Ok(vec![
291            ScalarValue::from(self.count),
292            ScalarValue::Float64(self.sum),
293        ])
294    }
295
296    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
297        // counts are summed
298        self.count += sum(states[0].as_primitive::<UInt64Type>()).unwrap_or_default();
299
300        // sums are summed
301        if let Some(x) = sum(states[1].as_primitive::<Float64Type>()) {
302            let v = self.sum.get_or_insert(0.);
303            *v += x;
304        }
305        Ok(())
306    }
307    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
308        let values = values[0].as_primitive::<Float64Type>();
309        self.count -= (values.len() - values.null_count()) as u64;
310        if let Some(x) = sum(values) {
311            self.sum = Some(self.sum.unwrap() - x);
312        }
313        Ok(())
314    }
315
316    fn supports_retract_batch(&self) -> bool {
317        true
318    }
319}
320
321/// An accumulator to compute the average for decimals
322#[derive(Debug)]
323struct DecimalAvgAccumulator<T: DecimalType + ArrowNumericType + Debug> {
324    sum: Option<T::Native>,
325    count: u64,
326    sum_scale: i8,
327    sum_precision: u8,
328    target_precision: u8,
329    target_scale: i8,
330}
331
332impl<T: DecimalType + ArrowNumericType + Debug> Accumulator for DecimalAvgAccumulator<T> {
333    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
334        let values = values[0].as_primitive::<T>();
335        self.count += (values.len() - values.null_count()) as u64;
336
337        if let Some(x) = sum(values) {
338            let v = self.sum.get_or_insert(T::Native::default());
339            self.sum = Some(v.add_wrapping(x));
340        }
341        Ok(())
342    }
343
344    fn evaluate(&mut self) -> Result<ScalarValue> {
345        let v = self
346            .sum
347            .map(|v| {
348                DecimalAverager::<T>::try_new(
349                    self.sum_scale,
350                    self.target_precision,
351                    self.target_scale,
352                )?
353                .avg(v, T::Native::from_usize(self.count as usize).unwrap())
354            })
355            .transpose()?;
356
357        ScalarValue::new_primitive::<T>(
358            v,
359            &T::TYPE_CONSTRUCTOR(self.target_precision, self.target_scale),
360        )
361    }
362
363    fn size(&self) -> usize {
364        size_of_val(self)
365    }
366
367    fn state(&mut self) -> Result<Vec<ScalarValue>> {
368        Ok(vec![
369            ScalarValue::from(self.count),
370            ScalarValue::new_primitive::<T>(
371                self.sum,
372                &T::TYPE_CONSTRUCTOR(self.sum_precision, self.sum_scale),
373            )?,
374        ])
375    }
376
377    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
378        // counts are summed
379        self.count += sum(states[0].as_primitive::<UInt64Type>()).unwrap_or_default();
380
381        // sums are summed
382        if let Some(x) = sum(states[1].as_primitive::<T>()) {
383            let v = self.sum.get_or_insert(T::Native::default());
384            self.sum = Some(v.add_wrapping(x));
385        }
386        Ok(())
387    }
388    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
389        let values = values[0].as_primitive::<T>();
390        self.count -= (values.len() - values.null_count()) as u64;
391        if let Some(x) = sum(values) {
392            self.sum = Some(self.sum.unwrap().sub_wrapping(x));
393        }
394        Ok(())
395    }
396
397    fn supports_retract_batch(&self) -> bool {
398        true
399    }
400}
401
402/// An accumulator to compute the average of `[PrimitiveArray<T>]`.
403/// Stores values as native types, and does overflow checking
404///
405/// F: Function that calculates the average value from a sum of
406/// T::Native and a total count
407#[derive(Debug)]
408struct AvgGroupsAccumulator<T, F>
409where
410    T: ArrowNumericType + Send,
411    F: Fn(T::Native, u64) -> Result<T::Native> + Send,
412{
413    /// The type of the internal sum
414    sum_data_type: DataType,
415
416    /// The type of the returned sum
417    return_data_type: DataType,
418
419    /// Count per group (use u64 to make UInt64Array)
420    counts: Vec<u64>,
421
422    /// Sums per group, stored as the native type
423    sums: Vec<T::Native>,
424
425    /// Track nulls in the input / filters
426    null_state: NullState,
427
428    /// Function that computes the final average (value / count)
429    avg_fn: F,
430}
431
432impl<T, F> AvgGroupsAccumulator<T, F>
433where
434    T: ArrowNumericType + Send,
435    F: Fn(T::Native, u64) -> Result<T::Native> + Send,
436{
437    pub fn new(sum_data_type: &DataType, return_data_type: &DataType, avg_fn: F) -> Self {
438        debug!(
439            "AvgGroupsAccumulator ({}, sum type: {sum_data_type:?}) --> {return_data_type:?}",
440            std::any::type_name::<T>()
441        );
442
443        Self {
444            return_data_type: return_data_type.clone(),
445            sum_data_type: sum_data_type.clone(),
446            counts: vec![],
447            sums: vec![],
448            null_state: NullState::new(),
449            avg_fn,
450        }
451    }
452}
453
454impl<T, F> GroupsAccumulator for AvgGroupsAccumulator<T, F>
455where
456    T: ArrowNumericType + Send,
457    F: Fn(T::Native, u64) -> Result<T::Native> + Send,
458{
459    fn update_batch(
460        &mut self,
461        values: &[ArrayRef],
462        group_indices: &[usize],
463        opt_filter: Option<&BooleanArray>,
464        total_num_groups: usize,
465    ) -> Result<()> {
466        assert_eq!(values.len(), 1, "single argument to update_batch");
467        let values = values[0].as_primitive::<T>();
468
469        // increment counts, update sums
470        self.counts.resize(total_num_groups, 0);
471        self.sums.resize(total_num_groups, T::default_value());
472        self.null_state.accumulate(
473            group_indices,
474            values,
475            opt_filter,
476            total_num_groups,
477            |group_index, new_value| {
478                let sum = &mut self.sums[group_index];
479                *sum = sum.add_wrapping(new_value);
480
481                self.counts[group_index] += 1;
482            },
483        );
484
485        Ok(())
486    }
487
488    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
489        let counts = emit_to.take_needed(&mut self.counts);
490        let sums = emit_to.take_needed(&mut self.sums);
491        let nulls = self.null_state.build(emit_to);
492
493        assert_eq!(nulls.len(), sums.len());
494        assert_eq!(counts.len(), sums.len());
495
496        // don't evaluate averages with null inputs to avoid errors on null values
497
498        let array: PrimitiveArray<T> = if nulls.null_count() > 0 {
499            let mut builder = PrimitiveBuilder::<T>::with_capacity(nulls.len())
500                .with_data_type(self.return_data_type.clone());
501            let iter = sums.into_iter().zip(counts).zip(nulls.iter());
502
503            for ((sum, count), is_valid) in iter {
504                if is_valid {
505                    builder.append_value((self.avg_fn)(sum, count)?)
506                } else {
507                    builder.append_null();
508                }
509            }
510            builder.finish()
511        } else {
512            let averages: Vec<T::Native> = sums
513                .into_iter()
514                .zip(counts.into_iter())
515                .map(|(sum, count)| (self.avg_fn)(sum, count))
516                .collect::<Result<Vec<_>>>()?;
517            PrimitiveArray::new(averages.into(), Some(nulls)) // no copy
518                .with_data_type(self.return_data_type.clone())
519        };
520
521        Ok(Arc::new(array))
522    }
523
524    // return arrays for sums and counts
525    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
526        let nulls = self.null_state.build(emit_to);
527        let nulls = Some(nulls);
528
529        let counts = emit_to.take_needed(&mut self.counts);
530        let counts = UInt64Array::new(counts.into(), nulls.clone()); // zero copy
531
532        let sums = emit_to.take_needed(&mut self.sums);
533        let sums = PrimitiveArray::<T>::new(sums.into(), nulls) // zero copy
534            .with_data_type(self.sum_data_type.clone());
535
536        Ok(vec![
537            Arc::new(counts) as ArrayRef,
538            Arc::new(sums) as ArrayRef,
539        ])
540    }
541
542    fn merge_batch(
543        &mut self,
544        values: &[ArrayRef],
545        group_indices: &[usize],
546        opt_filter: Option<&BooleanArray>,
547        total_num_groups: usize,
548    ) -> Result<()> {
549        assert_eq!(values.len(), 2, "two arguments to merge_batch");
550        // first batch is counts, second is partial sums
551        let partial_counts = values[0].as_primitive::<UInt64Type>();
552        let partial_sums = values[1].as_primitive::<T>();
553        // update counts with partial counts
554        self.counts.resize(total_num_groups, 0);
555        self.null_state.accumulate(
556            group_indices,
557            partial_counts,
558            opt_filter,
559            total_num_groups,
560            |group_index, partial_count| {
561                self.counts[group_index] += partial_count;
562            },
563        );
564
565        // update sums
566        self.sums.resize(total_num_groups, T::default_value());
567        self.null_state.accumulate(
568            group_indices,
569            partial_sums,
570            opt_filter,
571            total_num_groups,
572            |group_index, new_value: <T as ArrowPrimitiveType>::Native| {
573                let sum = &mut self.sums[group_index];
574                *sum = sum.add_wrapping(new_value);
575            },
576        );
577
578        Ok(())
579    }
580
581    fn convert_to_state(
582        &self,
583        values: &[ArrayRef],
584        opt_filter: Option<&BooleanArray>,
585    ) -> Result<Vec<ArrayRef>> {
586        let sums = values[0]
587            .as_primitive::<T>()
588            .clone()
589            .with_data_type(self.sum_data_type.clone());
590        let counts = UInt64Array::from_value(1, sums.len());
591
592        let nulls = filtered_null_mask(opt_filter, &sums);
593
594        // set nulls on the arrays
595        let counts = set_nulls(counts, nulls.clone());
596        let sums = set_nulls(sums, nulls);
597
598        Ok(vec![Arc::new(counts) as ArrayRef, Arc::new(sums)])
599    }
600
601    fn supports_convert_to_state(&self) -> bool {
602        true
603    }
604
605    fn size(&self) -> usize {
606        self.counts.capacity() * size_of::<u64>() + self.sums.capacity() * size_of::<T>()
607    }
608}