datafusion_spark/function/aggregate/
avg.rs1use arrow::array::ArrowNativeTypeOp;
19use arrow::array::{
20 builder::PrimitiveBuilder,
21 cast::AsArray,
22 types::{Float64Type, Int64Type},
23 Array, ArrayRef, ArrowNumericType, Int64Array, PrimitiveArray,
24};
25use arrow::compute::sum;
26use arrow::datatypes::{DataType, Field, FieldRef};
27use datafusion_common::utils::take_function_args;
28use datafusion_common::{not_impl_err, plan_err, Result, ScalarValue};
29use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
30use datafusion_expr::utils::format_state_name;
31use datafusion_expr::Volatility::Immutable;
32use datafusion_expr::{
33 Accumulator, AggregateUDFImpl, EmitTo, GroupsAccumulator, ReversedUDAF, Signature,
34};
35use std::{any::Any, sync::Arc};
36
37#[derive(Debug, Clone, PartialEq, Eq, Hash)]
45pub struct SparkAvg {
46 signature: Signature,
47}
48
49impl Default for SparkAvg {
50 fn default() -> Self {
51 Self::new()
52 }
53}
54
55impl SparkAvg {
56 pub fn new() -> Self {
58 Self {
59 signature: Signature::user_defined(Immutable),
60 }
61 }
62}
63
64impl AggregateUDFImpl for SparkAvg {
65 fn as_any(&self) -> &dyn Any {
66 self
67 }
68
69 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
70 let [args] = take_function_args(self.name(), arg_types)?;
71
72 fn coerced_type(data_type: &DataType) -> Result<DataType> {
73 match &data_type {
74 d if d.is_numeric() => Ok(DataType::Float64),
75 DataType::Dictionary(_, v) => coerced_type(v.as_ref()),
76 _ => {
77 plan_err!("Avg does not support inputs of type {data_type}.")
78 }
79 }
80 }
81 Ok(vec![coerced_type(args)?])
82 }
83
84 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
85 Ok(DataType::Float64)
86 }
87
88 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
89 if acc_args.is_distinct {
90 return not_impl_err!("DistinctAvgAccumulator");
91 }
92
93 let data_type = acc_args.exprs[0].data_type(acc_args.schema)?;
94
95 match (&data_type, &acc_args.return_type()) {
97 (DataType::Float64, DataType::Float64) => {
98 Ok(Box::<AvgAccumulator>::default())
99 }
100 (dt, return_type) => {
101 not_impl_err!("AvgAccumulator for ({dt} --> {return_type})")
102 }
103 }
104 }
105
106 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
107 Ok(vec![
108 Arc::new(Field::new(
109 format_state_name(self.name(), "sum"),
110 args.input_fields[0].data_type().clone(),
111 true,
112 )),
113 Arc::new(Field::new(
114 format_state_name(self.name(), "count"),
115 DataType::Int64,
116 true,
117 )),
118 ])
119 }
120
121 fn name(&self) -> &str {
122 "avg"
123 }
124
125 fn reverse_expr(&self) -> ReversedUDAF {
126 ReversedUDAF::Identical
127 }
128
129 fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
130 !args.is_distinct
131 }
132
133 fn create_groups_accumulator(
134 &self,
135 args: AccumulatorArgs,
136 ) -> Result<Box<dyn GroupsAccumulator>> {
137 let data_type = args.exprs[0].data_type(args.schema)?;
138
139 match (&data_type, args.return_type()) {
141 (DataType::Float64, DataType::Float64) => {
142 Ok(Box::new(AvgGroupsAccumulator::<Float64Type, _>::new(
143 args.return_field.data_type(),
144 |sum: f64, count: i64| Ok(sum / count as f64),
145 )))
146 }
147 (dt, return_type) => {
148 not_impl_err!("AvgGroupsAccumulator for ({dt} --> {return_type})")
149 }
150 }
151 }
152
153 fn default_value(&self, _data_type: &DataType) -> Result<ScalarValue> {
154 Ok(ScalarValue::Float64(None))
155 }
156
157 fn signature(&self) -> &Signature {
158 &self.signature
159 }
160}
161
/// Row-oriented accumulator for `avg` over `Float64` input.
#[derive(Debug, Default)]
pub struct AvgAccumulator {
    /// Running total; stays `None` until the first `update_batch` call or the
    /// first merge that carries a non-null partial sum.
    sum: Option<f64>,
    /// Number of non-null input values folded in so far.
    count: i64,
}
168
169impl Accumulator for AvgAccumulator {
170 fn state(&mut self) -> Result<Vec<ScalarValue>> {
171 Ok(vec![
172 ScalarValue::Float64(self.sum),
173 ScalarValue::from(self.count),
174 ])
175 }
176
177 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
178 let values = values[0].as_primitive::<Float64Type>();
179 self.count += (values.len() - values.null_count()) as i64;
180 let v = self.sum.get_or_insert(0.);
181 if let Some(x) = sum(values) {
182 *v += x;
183 }
184 Ok(())
185 }
186
187 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
188 self.count += sum(states[1].as_primitive::<Int64Type>()).unwrap_or_default();
190
191 if let Some(x) = sum(states[0].as_primitive::<Float64Type>()) {
193 let v = self.sum.get_or_insert(0.);
194 *v += x;
195 }
196 Ok(())
197 }
198
199 fn evaluate(&mut self) -> Result<ScalarValue> {
200 if self.count == 0 {
201 Ok(ScalarValue::Float64(None))
204 } else {
205 Ok(ScalarValue::Float64(
206 self.sum.map(|f| f / self.count as f64),
207 ))
208 }
209 }
210
211 fn size(&self) -> usize {
212 size_of_val(self)
213 }
214}
215
216#[derive(Debug)]
222struct AvgGroupsAccumulator<T, F>
223where
224 T: ArrowNumericType + Send,
225 F: Fn(T::Native, i64) -> Result<T::Native> + Send,
226{
227 return_data_type: DataType,
229
230 counts: Vec<i64>,
232
233 sums: Vec<T::Native>,
235
236 avg_fn: F,
238}
239
240impl<T, F> AvgGroupsAccumulator<T, F>
241where
242 T: ArrowNumericType + Send,
243 F: Fn(T::Native, i64) -> Result<T::Native> + Send,
244{
245 pub fn new(return_data_type: &DataType, avg_fn: F) -> Self {
246 Self {
247 return_data_type: return_data_type.clone(),
248 counts: vec![],
249 sums: vec![],
250 avg_fn,
251 }
252 }
253}
254
255impl<T, F> GroupsAccumulator for AvgGroupsAccumulator<T, F>
256where
257 T: ArrowNumericType + Send,
258 F: Fn(T::Native, i64) -> Result<T::Native> + Send,
259{
260 fn update_batch(
261 &mut self,
262 values: &[ArrayRef],
263 group_indices: &[usize],
264 _opt_filter: Option<&arrow::array::BooleanArray>,
265 total_num_groups: usize,
266 ) -> Result<()> {
267 assert_eq!(values.len(), 1, "single argument to update_batch");
268 let values = values[0].as_primitive::<T>();
269 let data = values.values();
270
271 self.counts.resize(total_num_groups, 0);
273 self.sums.resize(total_num_groups, T::default_value());
274
275 let iter = group_indices.iter().zip(data.iter());
276 if values.null_count() == 0 {
277 for (&group_index, &value) in iter {
278 let sum = &mut self.sums[group_index];
279 *sum = (*sum).add_wrapping(value);
280 self.counts[group_index] += 1;
281 }
282 } else {
283 for (idx, (&group_index, &value)) in iter.enumerate() {
284 if values.is_null(idx) {
285 continue;
286 }
287 let sum = &mut self.sums[group_index];
288 *sum = (*sum).add_wrapping(value);
289
290 self.counts[group_index] += 1;
291 }
292 }
293
294 Ok(())
295 }
296
297 fn merge_batch(
298 &mut self,
299 values: &[ArrayRef],
300 group_indices: &[usize],
301 _opt_filter: Option<&arrow::array::BooleanArray>,
302 total_num_groups: usize,
303 ) -> Result<()> {
304 assert_eq!(values.len(), 2, "two arguments to merge_batch");
305 let partial_sums = values[0].as_primitive::<T>();
307 let partial_counts = values[1].as_primitive::<Int64Type>();
308 self.counts.resize(total_num_groups, 0);
310 let iter1 = group_indices.iter().zip(partial_counts.values().iter());
311 for (&group_index, &partial_count) in iter1 {
312 self.counts[group_index] += partial_count;
313 }
314
315 self.sums.resize(total_num_groups, T::default_value());
317 let iter2 = group_indices.iter().zip(partial_sums.values().iter());
318 for (&group_index, &new_value) in iter2 {
319 let sum = &mut self.sums[group_index];
320 *sum = sum.add_wrapping(new_value);
321 }
322
323 Ok(())
324 }
325
326 fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
327 let counts = emit_to.take_needed(&mut self.counts);
328 let sums = emit_to.take_needed(&mut self.sums);
329 let mut builder = PrimitiveBuilder::<T>::with_capacity(sums.len());
330 let iter = sums.into_iter().zip(counts);
331
332 for (sum, count) in iter {
333 if count != 0 {
334 builder.append_value((self.avg_fn)(sum, count)?)
335 } else {
336 builder.append_null();
337 }
338 }
339 let array: PrimitiveArray<T> = builder.finish();
340
341 Ok(Arc::new(array))
342 }
343
344 fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
346 let counts = emit_to.take_needed(&mut self.counts);
347 let counts = Int64Array::new(counts.into(), None);
348
349 let sums = emit_to.take_needed(&mut self.sums);
350 let sums = PrimitiveArray::<T>::new(sums.into(), None)
351 .with_data_type(self.return_data_type.clone());
352
353 Ok(vec![
354 Arc::new(sums) as ArrayRef,
355 Arc::new(counts) as ArrayRef,
356 ])
357 }
358
359 fn size(&self) -> usize {
360 self.counts.capacity() * size_of::<i64>() + self.sums.capacity() * size_of::<T>()
361 }
362}