datafusion_comet_spark_expr/agg_funcs/
avg.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::array::{
19    builder::PrimitiveBuilder,
20    cast::AsArray,
21    types::{Float64Type, Int64Type},
22    Array, ArrayRef, ArrowNumericType, Int64Array, PrimitiveArray,
23};
24use arrow::compute::sum;
25use arrow::datatypes::{DataType, Field, FieldRef};
26use datafusion::common::{not_impl_err, Result, ScalarValue};
27use datafusion::logical_expr::{
28    type_coercion::aggregates::avg_return_type, Accumulator, AggregateUDFImpl, EmitTo,
29    GroupsAccumulator, ReversedUDAF, Signature,
30};
31use datafusion::physical_expr::expressions::format_state_name;
32use std::{any::Any, sync::Arc};
33
34use arrow::array::ArrowNativeTypeOp;
35use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs};
36use datafusion::logical_expr::Volatility::Immutable;
37use DataType::*;
38
/// AVG aggregate expression
///
/// Carries the function name, its signature, and the input/result data types
/// that decide which accumulator implementation gets instantiated.
#[derive(Debug, Clone)]
pub struct Avg {
    /// Name reported by `name()` and used as the prefix of the state field names
    name: String,
    /// Signature handed out by `signature()`
    signature: Signature,
    /// Data type of the values being averaged; also the type of the `sum` state field
    input_data_type: DataType,
    /// Data type of the average, computed once in `new` via `avg_return_type`
    result_data_type: DataType,
}
48
49impl Avg {
50    /// Create a new AVG aggregate function
51    pub fn new(name: impl Into<String>, data_type: DataType) -> Self {
52        let result_data_type = avg_return_type("avg", &data_type).unwrap();
53
54        Self {
55            name: name.into(),
56            signature: Signature::user_defined(Immutable),
57            input_data_type: data_type,
58            result_data_type,
59        }
60    }
61}
62
63impl AggregateUDFImpl for Avg {
64    /// Return a reference to Any that can be used for downcasting
65    fn as_any(&self) -> &dyn Any {
66        self
67    }
68
69    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
70        // instantiate specialized accumulator based for the type
71        match (&self.input_data_type, &self.result_data_type) {
72            (Float64, Float64) => Ok(Box::<AvgAccumulator>::default()),
73            _ => not_impl_err!(
74                "AvgAccumulator for ({} --> {})",
75                self.input_data_type,
76                self.result_data_type
77            ),
78        }
79    }
80
81    fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
82        Ok(vec![
83            Arc::new(Field::new(
84                format_state_name(&self.name, "sum"),
85                self.input_data_type.clone(),
86                true,
87            )),
88            Arc::new(Field::new(
89                format_state_name(&self.name, "count"),
90                DataType::Int64,
91                true,
92            )),
93        ])
94    }
95
96    fn name(&self) -> &str {
97        &self.name
98    }
99
100    fn reverse_expr(&self) -> ReversedUDAF {
101        ReversedUDAF::Identical
102    }
103
104    fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
105        true
106    }
107
108    fn create_groups_accumulator(
109        &self,
110        _args: AccumulatorArgs,
111    ) -> Result<Box<dyn GroupsAccumulator>> {
112        // instantiate specialized accumulator based for the type
113        match (&self.input_data_type, &self.result_data_type) {
114            (Float64, Float64) => Ok(Box::new(AvgGroupsAccumulator::<Float64Type, _>::new(
115                &self.input_data_type,
116                |sum: f64, count: i64| Ok(sum / count as f64),
117            ))),
118
119            _ => not_impl_err!(
120                "AvgGroupsAccumulator for ({} --> {})",
121                self.input_data_type,
122                self.result_data_type
123            ),
124        }
125    }
126
127    fn default_value(&self, _data_type: &DataType) -> Result<ScalarValue> {
128        Ok(ScalarValue::Float64(None))
129    }
130
131    fn signature(&self) -> &Signature {
132        &self.signature
133    }
134
135    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
136        avg_return_type(self.name(), &arg_types[0])
137    }
138}
139
/// An accumulator to compute the average of `Float64` values
#[derive(Debug, Default)]
pub struct AvgAccumulator {
    /// Running sum of the non-null inputs; stays `None` until the first
    /// `update_batch` call or a merge that carries a non-null partial sum
    sum: Option<f64>,
    /// Number of non-null inputs accumulated so far, including merged partial counts
    count: i64,
}
146
147impl Accumulator for AvgAccumulator {
148    fn state(&mut self) -> Result<Vec<ScalarValue>> {
149        Ok(vec![
150            ScalarValue::Float64(self.sum),
151            ScalarValue::from(self.count),
152        ])
153    }
154
155    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
156        let values = values[0].as_primitive::<Float64Type>();
157        self.count += (values.len() - values.null_count()) as i64;
158        let v = self.sum.get_or_insert(0.);
159        if let Some(x) = sum(values) {
160            *v += x;
161        }
162        Ok(())
163    }
164
165    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
166        // counts are summed
167        self.count += sum(states[1].as_primitive::<Int64Type>()).unwrap_or_default();
168
169        // sums are summed
170        if let Some(x) = sum(states[0].as_primitive::<Float64Type>()) {
171            let v = self.sum.get_or_insert(0.);
172            *v += x;
173        }
174        Ok(())
175    }
176
177    fn evaluate(&mut self) -> Result<ScalarValue> {
178        if self.count == 0 {
179            // If all input are nulls, count will be 0 and we will get null after the division.
180            // This is consistent with Spark Average implementation.
181            Ok(ScalarValue::Float64(None))
182        } else {
183            Ok(ScalarValue::Float64(
184                self.sum.map(|f| f / self.count as f64),
185            ))
186        }
187    }
188
189    fn size(&self) -> usize {
190        std::mem::size_of_val(self)
191    }
192}
193
/// An accumulator to compute the average of `[PrimitiveArray<T>]` for many
/// groups at once.
///
/// Stores sums as native types and accumulates them with wrapping addition
/// (`add_wrapping`), so overflow is not detected.
///
/// F: Function that calculates the average value from a sum of
/// T::Native and a total count
#[derive(Debug)]
struct AvgGroupsAccumulator<T, F>
where
    T: ArrowNumericType + Send,
    F: Fn(T::Native, i64) -> Result<T::Native> + Send,
{
    /// Data type applied to the sums array emitted by `state`.
    /// NOTE(review): despite the name, the only visible caller passes the
    /// input data type here — confirm the intended meaning.
    return_data_type: DataType,

    /// Count per group (use i64 to make Int64Array)
    counts: Vec<i64>,

    /// Sums per group, stored as the native type
    sums: Vec<T::Native>,

    /// Function that computes the final average (value / count)
    avg_fn: F,
}
217
218impl<T, F> AvgGroupsAccumulator<T, F>
219where
220    T: ArrowNumericType + Send,
221    F: Fn(T::Native, i64) -> Result<T::Native> + Send,
222{
223    pub fn new(return_data_type: &DataType, avg_fn: F) -> Self {
224        Self {
225            return_data_type: return_data_type.clone(),
226            counts: vec![],
227            sums: vec![],
228            avg_fn,
229        }
230    }
231}
232
233impl<T, F> GroupsAccumulator for AvgGroupsAccumulator<T, F>
234where
235    T: ArrowNumericType + Send,
236    F: Fn(T::Native, i64) -> Result<T::Native> + Send,
237{
238    fn update_batch(
239        &mut self,
240        values: &[ArrayRef],
241        group_indices: &[usize],
242        _opt_filter: Option<&arrow::array::BooleanArray>,
243        total_num_groups: usize,
244    ) -> Result<()> {
245        assert_eq!(values.len(), 1, "single argument to update_batch");
246        let values = values[0].as_primitive::<T>();
247        let data = values.values();
248
249        // increment counts, update sums
250        self.counts.resize(total_num_groups, 0);
251        self.sums.resize(total_num_groups, T::default_value());
252
253        let iter = group_indices.iter().zip(data.iter());
254        if values.null_count() == 0 {
255            for (&group_index, &value) in iter {
256                let sum = &mut self.sums[group_index];
257                *sum = (*sum).add_wrapping(value);
258                self.counts[group_index] += 1;
259            }
260        } else {
261            for (idx, (&group_index, &value)) in iter.enumerate() {
262                if values.is_null(idx) {
263                    continue;
264                }
265                let sum = &mut self.sums[group_index];
266                *sum = (*sum).add_wrapping(value);
267
268                self.counts[group_index] += 1;
269            }
270        }
271
272        Ok(())
273    }
274
275    fn merge_batch(
276        &mut self,
277        values: &[ArrayRef],
278        group_indices: &[usize],
279        _opt_filter: Option<&arrow::array::BooleanArray>,
280        total_num_groups: usize,
281    ) -> Result<()> {
282        assert_eq!(values.len(), 2, "two arguments to merge_batch");
283        // first batch is partial sums, second is counts
284        let partial_sums = values[0].as_primitive::<T>();
285        let partial_counts = values[1].as_primitive::<Int64Type>();
286        // update counts with partial counts
287        self.counts.resize(total_num_groups, 0);
288        let iter1 = group_indices.iter().zip(partial_counts.values().iter());
289        for (&group_index, &partial_count) in iter1 {
290            self.counts[group_index] += partial_count;
291        }
292
293        // update sums
294        self.sums.resize(total_num_groups, T::default_value());
295        let iter2 = group_indices.iter().zip(partial_sums.values().iter());
296        for (&group_index, &new_value) in iter2 {
297            let sum = &mut self.sums[group_index];
298            *sum = sum.add_wrapping(new_value);
299        }
300
301        Ok(())
302    }
303
304    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
305        let counts = emit_to.take_needed(&mut self.counts);
306        let sums = emit_to.take_needed(&mut self.sums);
307        let mut builder = PrimitiveBuilder::<T>::with_capacity(sums.len());
308        let iter = sums.into_iter().zip(counts);
309
310        for (sum, count) in iter {
311            if count != 0 {
312                builder.append_value((self.avg_fn)(sum, count)?)
313            } else {
314                builder.append_null();
315            }
316        }
317        let array: PrimitiveArray<T> = builder.finish();
318
319        Ok(Arc::new(array))
320    }
321
322    // return arrays for sums and counts
323    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
324        let counts = emit_to.take_needed(&mut self.counts);
325        let counts = Int64Array::new(counts.into(), None);
326
327        let sums = emit_to.take_needed(&mut self.sums);
328        let sums = PrimitiveArray::<T>::new(sums.into(), None)
329            .with_data_type(self.return_data_type.clone());
330
331        Ok(vec![
332            Arc::new(sums) as ArrayRef,
333            Arc::new(counts) as ArrayRef,
334        ])
335    }
336
337    fn size(&self) -> usize {
338        self.counts.capacity() * std::mem::size_of::<i64>()
339            + self.sums.capacity() * std::mem::size_of::<T>()
340    }
341}