Skip to main content

datafusion_spark/function/aggregate/
avg.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use arrow::array::{
19    Array, ArrayRef, ArrowNativeTypeOp, ArrowNumericType, Int64Array, PrimitiveArray,
20    builder::PrimitiveBuilder,
21    cast::AsArray,
22    types::{Float64Type, Int64Type},
23};
24use arrow::compute::sum;
25use arrow::datatypes::{DataType, Field, FieldRef};
26use datafusion_common::types::{NativeType, logical_float64};
27use datafusion_common::{Result, ScalarValue, not_impl_err};
28use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
29use datafusion_expr::utils::format_state_name;
30use datafusion_expr::{
31    Accumulator, AggregateUDFImpl, Coercion, EmitTo, GroupsAccumulator, ReversedUDAF,
32    Signature, TypeSignatureClass, Volatility,
33};
34use std::{any::Any, sync::Arc};
35
36/// AVG aggregate expression
37/// Spark average aggregate expression. Differs from standard DataFusion average aggregate
38/// in that it uses an `i64` for the count (DataFusion version uses `u64`); also there is ANSI mode
39/// support planned in the future for Spark version.
40
41// TODO: see if can deduplicate with DF version
42//       https://github.com/apache/datafusion/issues/17964
43#[derive(Debug, Clone, PartialEq, Eq, Hash)]
44pub struct SparkAvg {
45    signature: Signature,
46}
47
48impl Default for SparkAvg {
49    fn default() -> Self {
50        Self::new()
51    }
52}
53
54impl SparkAvg {
55    /// Implement AVG aggregate function
56    pub fn new() -> Self {
57        Self {
58            signature: Signature::coercible(
59                vec![Coercion::new_implicit(
60                    TypeSignatureClass::Native(logical_float64()),
61                    vec![TypeSignatureClass::Numeric],
62                    NativeType::Float64,
63                )],
64                Volatility::Immutable,
65            ),
66        }
67    }
68}
69
70impl AggregateUDFImpl for SparkAvg {
71    fn as_any(&self) -> &dyn Any {
72        self
73    }
74
75    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
76        Ok(DataType::Float64)
77    }
78
79    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
80        if acc_args.is_distinct {
81            return not_impl_err!("DistinctAvgAccumulator");
82        }
83
84        let data_type = acc_args.exprs[0].data_type(acc_args.schema)?;
85
86        // instantiate specialized accumulator based for the type
87        match (&data_type, &acc_args.return_type()) {
88            (DataType::Float64, DataType::Float64) => {
89                Ok(Box::<AvgAccumulator>::default())
90            }
91            (dt, return_type) => {
92                not_impl_err!("AvgAccumulator for ({dt} --> {return_type})")
93            }
94        }
95    }
96
97    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
98        Ok(vec![
99            Arc::new(Field::new(
100                format_state_name(self.name(), "sum"),
101                args.input_fields[0].data_type().clone(),
102                true,
103            )),
104            Arc::new(Field::new(
105                format_state_name(self.name(), "count"),
106                DataType::Int64,
107                true,
108            )),
109        ])
110    }
111
112    fn name(&self) -> &str {
113        "avg"
114    }
115
116    fn reverse_expr(&self) -> ReversedUDAF {
117        ReversedUDAF::Identical
118    }
119
120    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
121        !args.is_distinct
122    }
123
124    fn create_groups_accumulator(
125        &self,
126        args: AccumulatorArgs,
127    ) -> Result<Box<dyn GroupsAccumulator>> {
128        let data_type = args.exprs[0].data_type(args.schema)?;
129
130        // instantiate specialized accumulator based for the type
131        match (&data_type, args.return_type()) {
132            (DataType::Float64, DataType::Float64) => {
133                Ok(Box::new(AvgGroupsAccumulator::<Float64Type, _>::new(
134                    args.return_field.data_type(),
135                    |sum: f64, count: i64| Ok(sum / count as f64),
136                )))
137            }
138            (dt, return_type) => {
139                not_impl_err!("AvgGroupsAccumulator for ({dt} --> {return_type})")
140            }
141        }
142    }
143
144    fn default_value(&self, _data_type: &DataType) -> Result<ScalarValue> {
145        Ok(ScalarValue::Float64(None))
146    }
147
148    fn signature(&self) -> &Signature {
149        &self.signature
150    }
151}
152
/// Row-wise (non-grouped) accumulator for Spark `avg` over `Float64` input.
#[derive(Debug, Default)]
pub struct AvgAccumulator {
    /// Running sum of the non-null inputs. Starts as `None`.
    sum: Option<f64>,
    /// Number of non-null inputs. Stored as `i64` (DataFusion's own average
    /// uses `u64`).
    count: i64,
}
159
160impl Accumulator for AvgAccumulator {
161    fn state(&mut self) -> Result<Vec<ScalarValue>> {
162        Ok(vec![
163            ScalarValue::Float64(self.sum),
164            ScalarValue::from(self.count),
165        ])
166    }
167
168    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
169        let values = values[0].as_primitive::<Float64Type>();
170        self.count += (values.len() - values.null_count()) as i64;
171        let v = self.sum.get_or_insert(0.);
172        if let Some(x) = sum(values) {
173            *v += x;
174        }
175        Ok(())
176    }
177
178    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
179        // counts are summed
180        self.count += sum(states[1].as_primitive::<Int64Type>()).unwrap_or_default();
181
182        // sums are summed
183        if let Some(x) = sum(states[0].as_primitive::<Float64Type>()) {
184            let v = self.sum.get_or_insert(0.);
185            *v += x;
186        }
187        Ok(())
188    }
189
190    fn evaluate(&mut self) -> Result<ScalarValue> {
191        if self.count == 0 {
192            // If all input are nulls, count will be 0 and we will get null after the division.
193            // This is consistent with Spark Average implementation.
194            Ok(ScalarValue::Float64(None))
195        } else {
196            Ok(ScalarValue::Float64(
197                self.sum.map(|f| f / self.count as f64),
198            ))
199        }
200    }
201
202    fn size(&self) -> usize {
203        size_of_val(self)
204    }
205}
206
207/// An accumulator to compute the average of `[PrimitiveArray<T>]`.
208/// Stores values as native types, and does overflow checking
209///
210/// F: Function that calculates the average value from a sum of
211/// T::Native and a total count
212#[derive(Debug)]
213struct AvgGroupsAccumulator<T, F>
214where
215    T: ArrowNumericType + Send,
216    F: Fn(T::Native, i64) -> Result<T::Native> + Send,
217{
218    /// The type of the returned average
219    return_data_type: DataType,
220
221    /// Count per group (use i64 to make Int64Array)
222    counts: Vec<i64>,
223
224    /// Sums per group, stored as the native type
225    sums: Vec<T::Native>,
226
227    /// Function that computes the final average (value / count)
228    avg_fn: F,
229}
230
231impl<T, F> AvgGroupsAccumulator<T, F>
232where
233    T: ArrowNumericType + Send,
234    F: Fn(T::Native, i64) -> Result<T::Native> + Send,
235{
236    pub fn new(return_data_type: &DataType, avg_fn: F) -> Self {
237        Self {
238            return_data_type: return_data_type.clone(),
239            counts: vec![],
240            sums: vec![],
241            avg_fn,
242        }
243    }
244}
245
246impl<T, F> GroupsAccumulator for AvgGroupsAccumulator<T, F>
247where
248    T: ArrowNumericType + Send,
249    F: Fn(T::Native, i64) -> Result<T::Native> + Send,
250{
251    fn update_batch(
252        &mut self,
253        values: &[ArrayRef],
254        group_indices: &[usize],
255        _opt_filter: Option<&arrow::array::BooleanArray>,
256        total_num_groups: usize,
257    ) -> Result<()> {
258        assert_eq!(values.len(), 1, "single argument to update_batch");
259        let values = values[0].as_primitive::<T>();
260        let data = values.values();
261
262        // increment counts, update sums
263        self.counts.resize(total_num_groups, 0);
264        self.sums.resize(total_num_groups, T::default_value());
265
266        let iter = group_indices.iter().zip(data.iter());
267        if values.null_count() == 0 {
268            for (&group_index, &value) in iter {
269                let sum = &mut self.sums[group_index];
270                *sum = (*sum).add_wrapping(value);
271                self.counts[group_index] += 1;
272            }
273        } else {
274            for (idx, (&group_index, &value)) in iter.enumerate() {
275                if values.is_null(idx) {
276                    continue;
277                }
278                let sum = &mut self.sums[group_index];
279                *sum = (*sum).add_wrapping(value);
280
281                self.counts[group_index] += 1;
282            }
283        }
284
285        Ok(())
286    }
287
288    fn merge_batch(
289        &mut self,
290        values: &[ArrayRef],
291        group_indices: &[usize],
292        _opt_filter: Option<&arrow::array::BooleanArray>,
293        total_num_groups: usize,
294    ) -> Result<()> {
295        assert_eq!(values.len(), 2, "two arguments to merge_batch");
296        // first batch is partial sums, second is counts
297        let partial_sums = values[0].as_primitive::<T>();
298        let partial_counts = values[1].as_primitive::<Int64Type>();
299        // update counts with partial counts
300        self.counts.resize(total_num_groups, 0);
301        let iter1 = group_indices.iter().zip(partial_counts.values().iter());
302        for (&group_index, &partial_count) in iter1 {
303            self.counts[group_index] += partial_count;
304        }
305
306        // update sums
307        self.sums.resize(total_num_groups, T::default_value());
308        let iter2 = group_indices.iter().zip(partial_sums.values().iter());
309        for (&group_index, &new_value) in iter2 {
310            let sum = &mut self.sums[group_index];
311            *sum = sum.add_wrapping(new_value);
312        }
313
314        Ok(())
315    }
316
317    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
318        let counts = emit_to.take_needed(&mut self.counts);
319        let sums = emit_to.take_needed(&mut self.sums);
320        let mut builder = PrimitiveBuilder::<T>::with_capacity(sums.len());
321        let iter = sums.into_iter().zip(counts);
322
323        for (sum, count) in iter {
324            if count != 0 {
325                builder.append_value((self.avg_fn)(sum, count)?)
326            } else {
327                builder.append_null();
328            }
329        }
330        let array: PrimitiveArray<T> = builder.finish();
331
332        Ok(Arc::new(array))
333    }
334
335    // return arrays for sums and counts
336    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
337        let counts = emit_to.take_needed(&mut self.counts);
338        let counts = Int64Array::new(counts.into(), None);
339
340        let sums = emit_to.take_needed(&mut self.sums);
341        let sums = PrimitiveArray::<T>::new(sums.into(), None)
342            .with_data_type(self.return_data_type.clone());
343
344        Ok(vec![
345            Arc::new(sums) as ArrayRef,
346            Arc::new(counts) as ArrayRef,
347        ])
348    }
349
350    fn size(&self) -> usize {
351        self.counts.capacity() * size_of::<i64>() + self.sums.capacity() * size_of::<T>()
352    }
353}