datafusion_spark/function/aggregate/avg.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow::array::ArrowNativeTypeOp;
use arrow::array::{
    builder::PrimitiveBuilder,
    cast::AsArray,
    types::{Float64Type, Int64Type},
    Array, ArrayRef, ArrowNumericType, Int64Array, PrimitiveArray,
};
use arrow::compute::sum;
use arrow::datatypes::{DataType, Field, FieldRef};
use datafusion_common::utils::take_function_args;
use datafusion_common::{not_impl_err, plan_err, Result, ScalarValue};
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion_expr::utils::format_state_name;
use datafusion_expr::Volatility::Immutable;
use datafusion_expr::{
    Accumulator, AggregateUDFImpl, EmitTo, GroupsAccumulator, ReversedUDAF, Signature,
};
use std::{any::Any, sync::Arc};

/// Spark-compatible `AVG` aggregate expression.
///
/// Differs from the standard DataFusion average aggregate in that it uses an
/// `i64` for the count (the DataFusion version uses `u64`); ANSI mode support
/// for the Spark version is planned for the future.
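///
/// # Example
///
/// A minimal usage sketch (marked `ignore`, so it is not compiled as a
/// doctest): wrapping `SparkAvg` in a `datafusion_expr::AggregateUDF` so a
/// caller can register it with an engine such as DataFusion's
/// `SessionContext` (assumed to be available to the caller; it is not a
/// dependency of this module).
///
/// ```ignore
/// use datafusion_expr::AggregateUDF;
///
/// // `SparkAvg` is assumed to be in scope.
/// let spark_avg = AggregateUDF::from(SparkAvg::new());
/// // e.g. ctx.register_udaf(spark_avg) on a DataFusion `SessionContext`
/// ```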

// TODO: see if this can be deduplicated with the DF version
//       https://github.com/apache/datafusion/issues/17964
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct SparkAvg {
    signature: Signature,
}

impl Default for SparkAvg {
    fn default() -> Self {
        Self::new()
    }
}

impl SparkAvg {
    /// Create a new AVG aggregate function
    pub fn new() -> Self {
        Self {
            signature: Signature::user_defined(Immutable),
        }
    }
}

impl AggregateUDFImpl for SparkAvg {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
        let [args] = take_function_args(self.name(), arg_types)?;

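        // Coerce any numeric input (including dictionary-encoded numerics) to
        // Float64, matching the Float64 return type of this function.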
        fn coerced_type(data_type: &DataType) -> Result<DataType> {
            match &data_type {
                d if d.is_numeric() => Ok(DataType::Float64),
                DataType::Dictionary(_, v) => coerced_type(v.as_ref()),
                _ => {
                    plan_err!("Avg does not support inputs of type {data_type}.")
                }
            }
        }
        Ok(vec![coerced_type(args)?])
    }

    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(DataType::Float64)
    }

    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        if acc_args.is_distinct {
            return not_impl_err!("DistinctAvgAccumulator");
        }

        let data_type = acc_args.exprs[0].data_type(acc_args.schema)?;

        // instantiate a specialized accumulator based on the type
        match (&data_type, &acc_args.return_type()) {
            (DataType::Float64, DataType::Float64) => {
                Ok(Box::<AvgAccumulator>::default())
            }
            (dt, return_type) => {
                not_impl_err!("AvgAccumulator for ({dt} --> {return_type})")
            }
        }
    }

    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
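        // The intermediate state layout is (sum, count); `merge_batch` in both
        // accumulators relies on this column order.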
        Ok(vec![
            Arc::new(Field::new(
                format_state_name(self.name(), "sum"),
                args.input_fields[0].data_type().clone(),
                true,
            )),
            Arc::new(Field::new(
                format_state_name(self.name(), "count"),
                DataType::Int64,
                true,
            )),
        ])
    }

    fn name(&self) -> &str {
        "avg"
    }

    fn reverse_expr(&self) -> ReversedUDAF {
        ReversedUDAF::Identical
    }

    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
        !args.is_distinct
    }

    fn create_groups_accumulator(
        &self,
        args: AccumulatorArgs,
    ) -> Result<Box<dyn GroupsAccumulator>> {
        let data_type = args.exprs[0].data_type(args.schema)?;

        // instantiate a specialized accumulator based on the type
        match (&data_type, args.return_type()) {
            (DataType::Float64, DataType::Float64) => {
                Ok(Box::new(AvgGroupsAccumulator::<Float64Type, _>::new(
                    args.return_field.data_type(),
                    |sum: f64, count: i64| Ok(sum / count as f64),
                )))
            }
            (dt, return_type) => {
                not_impl_err!("AvgGroupsAccumulator for ({dt} --> {return_type})")
            }
        }
    }

    fn default_value(&self, _data_type: &DataType) -> Result<ScalarValue> {
        Ok(ScalarValue::Float64(None))
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }
}

/// An accumulator to compute the average
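///
/// # Example
///
/// A minimal sketch (marked `ignore`, not run as a doctest) of how a batch is
/// folded into the accumulator: nulls count towards neither the sum nor the
/// count, so `[1.0, 2.0, NULL]` averages to `1.5`, and an all-null input
/// evaluates to null.
///
/// ```ignore
/// use std::sync::Arc;
/// use arrow::array::{ArrayRef, Float64Array};
/// use datafusion_common::ScalarValue;
/// use datafusion_expr::Accumulator;
///
/// let values: ArrayRef = Arc::new(Float64Array::from(vec![Some(1.0), Some(2.0), None]));
/// let mut acc = AvgAccumulator::default();
/// acc.update_batch(&[values]).unwrap();
/// assert_eq!(acc.evaluate().unwrap(), ScalarValue::Float64(Some(1.5)));
/// ```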
#[derive(Debug, Default)]
pub struct AvgAccumulator {
    sum: Option<f64>,
    count: i64,
}

impl Accumulator for AvgAccumulator {
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        Ok(vec![
            ScalarValue::Float64(self.sum),
            ScalarValue::from(self.count),
        ])
    }

    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        let values = values[0].as_primitive::<Float64Type>();
        self.count += (values.len() - values.null_count()) as i64;
        let v = self.sum.get_or_insert(0.);
        if let Some(x) = sum(values) {
            *v += x;
        }
        Ok(())
    }

    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        // counts are summed
        self.count += sum(states[1].as_primitive::<Int64Type>()).unwrap_or_default();

        // sums are summed
        if let Some(x) = sum(states[0].as_primitive::<Float64Type>()) {
            let v = self.sum.get_or_insert(0.);
            *v += x;
        }
        Ok(())
    }

    fn evaluate(&mut self) -> Result<ScalarValue> {
        if self.count == 0 {
            // If all inputs are null, count will be 0 and we return null.
            // This is consistent with Spark's Average implementation.
            Ok(ScalarValue::Float64(None))
        } else {
            Ok(ScalarValue::Float64(
                self.sum.map(|f| f / self.count as f64),
            ))
        }
    }

    fn size(&self) -> usize {
        size_of_val(self)
    }
}

/// An accumulator to compute the average of `[PrimitiveArray<T>]`.
/// Stores values as native types and accumulates sums with wrapping addition.
///
/// F: Function that calculates the average value from a sum of
/// T::Native and a total count
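///
/// Internally, `counts` and `sums` are flat vectors indexed by group index and
/// resized as `total_num_groups` grows; groups with a zero count (all inputs
/// null) evaluate to null, matching [`AvgAccumulator`].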
#[derive(Debug)]
struct AvgGroupsAccumulator<T, F>
where
    T: ArrowNumericType + Send,
    F: Fn(T::Native, i64) -> Result<T::Native> + Send,
{
    /// The type of the returned average
    return_data_type: DataType,

    /// Count per group (use i64 to make Int64Array)
    counts: Vec<i64>,

    /// Sums per group, stored as the native type
    sums: Vec<T::Native>,

    /// Function that computes the final average (value / count)
    avg_fn: F,
}

impl<T, F> AvgGroupsAccumulator<T, F>
where
    T: ArrowNumericType + Send,
    F: Fn(T::Native, i64) -> Result<T::Native> + Send,
{
    pub fn new(return_data_type: &DataType, avg_fn: F) -> Self {
        Self {
            return_data_type: return_data_type.clone(),
            counts: vec![],
            sums: vec![],
            avg_fn,
        }
    }
}

impl<T, F> GroupsAccumulator for AvgGroupsAccumulator<T, F>
where
    T: ArrowNumericType + Send,
    F: Fn(T::Native, i64) -> Result<T::Native> + Send,
{
    fn update_batch(
        &mut self,
        values: &[ArrayRef],
        group_indices: &[usize],
        _opt_filter: Option<&arrow::array::BooleanArray>,
        total_num_groups: usize,
    ) -> Result<()> {
        assert_eq!(values.len(), 1, "single argument to update_batch");
        let values = values[0].as_primitive::<T>();
        let data = values.values();

        // increment counts, update sums
        self.counts.resize(total_num_groups, 0);
        self.sums.resize(total_num_groups, T::default_value());

        let iter = group_indices.iter().zip(data.iter());
        if values.null_count() == 0 {
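            // Fast path: no nulls, so every row contributes to its group.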
            for (&group_index, &value) in iter {
                let sum = &mut self.sums[group_index];
                *sum = (*sum).add_wrapping(value);
                self.counts[group_index] += 1;
            }
        } else {
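            // Slow path: skip null slots so they contribute to neither sum nor count.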
            for (idx, (&group_index, &value)) in iter.enumerate() {
                if values.is_null(idx) {
                    continue;
                }
                let sum = &mut self.sums[group_index];
                *sum = (*sum).add_wrapping(value);

                self.counts[group_index] += 1;
            }
        }

        Ok(())
    }

    fn merge_batch(
        &mut self,
        values: &[ArrayRef],
        group_indices: &[usize],
        _opt_filter: Option<&arrow::array::BooleanArray>,
        total_num_groups: usize,
    ) -> Result<()> {
        assert_eq!(values.len(), 2, "two arguments to merge_batch");
        // first batch is partial sums, second is counts
        let partial_sums = values[0].as_primitive::<T>();
        let partial_counts = values[1].as_primitive::<Int64Type>();
        // update counts with partial counts
        self.counts.resize(total_num_groups, 0);
        let iter1 = group_indices.iter().zip(partial_counts.values().iter());
        for (&group_index, &partial_count) in iter1 {
            self.counts[group_index] += partial_count;
        }

        // update sums
        self.sums.resize(total_num_groups, T::default_value());
        let iter2 = group_indices.iter().zip(partial_sums.values().iter());
        for (&group_index, &new_value) in iter2 {
            let sum = &mut self.sums[group_index];
            *sum = sum.add_wrapping(new_value);
        }

        Ok(())
    }

    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
        let counts = emit_to.take_needed(&mut self.counts);
        let sums = emit_to.take_needed(&mut self.sums);
        let mut builder = PrimitiveBuilder::<T>::with_capacity(sums.len());
        let iter = sums.into_iter().zip(counts);

        for (sum, count) in iter {
            if count != 0 {
                builder.append_value((self.avg_fn)(sum, count)?)
            } else {
                builder.append_null();
            }
        }
        let array: PrimitiveArray<T> = builder.finish();

        Ok(Arc::new(array))
    }

    // return arrays for sums and counts
    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
        let counts = emit_to.take_needed(&mut self.counts);
        let counts = Int64Array::new(counts.into(), None);

        let sums = emit_to.take_needed(&mut self.sums);
        let sums = PrimitiveArray::<T>::new(sums.into(), None)
            .with_data_type(self.return_data_type.clone());

        Ok(vec![
            Arc::new(sums) as ArrayRef,
            Arc::new(counts) as ArrayRef,
        ])
    }

    fn size(&self) -> usize {
        self.counts.capacity() * size_of::<i64>()
            + self.sums.capacity() * size_of::<T::Native>()
    }
}