datafusion_spark/function/aggregate/
avg.rs1use arrow::array::{
19 Array, ArrayRef, ArrowNativeTypeOp, ArrowNumericType, Int64Array, PrimitiveArray,
20 builder::PrimitiveBuilder,
21 cast::AsArray,
22 types::{Float64Type, Int64Type},
23};
24use arrow::compute::sum;
25use arrow::datatypes::{DataType, Field, FieldRef};
26use datafusion_common::types::{NativeType, logical_float64};
27use datafusion_common::{Result, ScalarValue, not_impl_err};
28use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
29use datafusion_expr::utils::format_state_name;
30use datafusion_expr::{
31 Accumulator, AggregateUDFImpl, Coercion, EmitTo, GroupsAccumulator, ReversedUDAF,
32 Signature, TypeSignatureClass, Volatility,
33};
34use std::{any::Any, sync::Arc};
35
36#[derive(Debug, Clone, PartialEq, Eq, Hash)]
44pub struct SparkAvg {
45 signature: Signature,
46}
47
48impl Default for SparkAvg {
49 fn default() -> Self {
50 Self::new()
51 }
52}
53
54impl SparkAvg {
55 pub fn new() -> Self {
57 Self {
58 signature: Signature::coercible(
59 vec![Coercion::new_implicit(
60 TypeSignatureClass::Native(logical_float64()),
61 vec![TypeSignatureClass::Numeric],
62 NativeType::Float64,
63 )],
64 Volatility::Immutable,
65 ),
66 }
67 }
68}
69
70impl AggregateUDFImpl for SparkAvg {
71 fn as_any(&self) -> &dyn Any {
72 self
73 }
74
75 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
76 Ok(DataType::Float64)
77 }
78
79 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
80 if acc_args.is_distinct {
81 return not_impl_err!("DistinctAvgAccumulator");
82 }
83
84 let data_type = acc_args.exprs[0].data_type(acc_args.schema)?;
85
86 match (&data_type, &acc_args.return_type()) {
88 (DataType::Float64, DataType::Float64) => {
89 Ok(Box::<AvgAccumulator>::default())
90 }
91 (dt, return_type) => {
92 not_impl_err!("AvgAccumulator for ({dt} --> {return_type})")
93 }
94 }
95 }
96
97 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
98 Ok(vec![
99 Arc::new(Field::new(
100 format_state_name(self.name(), "sum"),
101 args.input_fields[0].data_type().clone(),
102 true,
103 )),
104 Arc::new(Field::new(
105 format_state_name(self.name(), "count"),
106 DataType::Int64,
107 true,
108 )),
109 ])
110 }
111
112 fn name(&self) -> &str {
113 "avg"
114 }
115
116 fn reverse_expr(&self) -> ReversedUDAF {
117 ReversedUDAF::Identical
118 }
119
120 fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
121 !args.is_distinct
122 }
123
124 fn create_groups_accumulator(
125 &self,
126 args: AccumulatorArgs,
127 ) -> Result<Box<dyn GroupsAccumulator>> {
128 let data_type = args.exprs[0].data_type(args.schema)?;
129
130 match (&data_type, args.return_type()) {
132 (DataType::Float64, DataType::Float64) => {
133 Ok(Box::new(AvgGroupsAccumulator::<Float64Type, _>::new(
134 args.return_field.data_type(),
135 |sum: f64, count: i64| Ok(sum / count as f64),
136 )))
137 }
138 (dt, return_type) => {
139 not_impl_err!("AvgGroupsAccumulator for ({dt} --> {return_type})")
140 }
141 }
142 }
143
144 fn default_value(&self, _data_type: &DataType) -> Result<ScalarValue> {
145 Ok(ScalarValue::Float64(None))
146 }
147
148 fn signature(&self) -> &Signature {
149 &self.signature
150 }
151}
152
/// Row-at-a-time accumulator for `SparkAvg` over `Float64` input.
#[derive(Debug, Default)]
pub struct AvgAccumulator {
    /// Running sum of the values folded in so far (`None` before any update).
    sum: Option<f64>,
    /// Number of non-null values folded in so far.
    count: i64,
}
159
160impl Accumulator for AvgAccumulator {
161 fn state(&mut self) -> Result<Vec<ScalarValue>> {
162 Ok(vec![
163 ScalarValue::Float64(self.sum),
164 ScalarValue::from(self.count),
165 ])
166 }
167
168 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
169 let values = values[0].as_primitive::<Float64Type>();
170 self.count += (values.len() - values.null_count()) as i64;
171 let v = self.sum.get_or_insert(0.);
172 if let Some(x) = sum(values) {
173 *v += x;
174 }
175 Ok(())
176 }
177
178 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
179 self.count += sum(states[1].as_primitive::<Int64Type>()).unwrap_or_default();
181
182 if let Some(x) = sum(states[0].as_primitive::<Float64Type>()) {
184 let v = self.sum.get_or_insert(0.);
185 *v += x;
186 }
187 Ok(())
188 }
189
190 fn evaluate(&mut self) -> Result<ScalarValue> {
191 if self.count == 0 {
192 Ok(ScalarValue::Float64(None))
195 } else {
196 Ok(ScalarValue::Float64(
197 self.sum.map(|f| f / self.count as f64),
198 ))
199 }
200 }
201
202 fn size(&self) -> usize {
203 size_of_val(self)
204 }
205}
206
207#[derive(Debug)]
213struct AvgGroupsAccumulator<T, F>
214where
215 T: ArrowNumericType + Send,
216 F: Fn(T::Native, i64) -> Result<T::Native> + Send,
217{
218 return_data_type: DataType,
220
221 counts: Vec<i64>,
223
224 sums: Vec<T::Native>,
226
227 avg_fn: F,
229}
230
231impl<T, F> AvgGroupsAccumulator<T, F>
232where
233 T: ArrowNumericType + Send,
234 F: Fn(T::Native, i64) -> Result<T::Native> + Send,
235{
236 pub fn new(return_data_type: &DataType, avg_fn: F) -> Self {
237 Self {
238 return_data_type: return_data_type.clone(),
239 counts: vec![],
240 sums: vec![],
241 avg_fn,
242 }
243 }
244}
245
246impl<T, F> GroupsAccumulator for AvgGroupsAccumulator<T, F>
247where
248 T: ArrowNumericType + Send,
249 F: Fn(T::Native, i64) -> Result<T::Native> + Send,
250{
251 fn update_batch(
252 &mut self,
253 values: &[ArrayRef],
254 group_indices: &[usize],
255 _opt_filter: Option<&arrow::array::BooleanArray>,
256 total_num_groups: usize,
257 ) -> Result<()> {
258 assert_eq!(values.len(), 1, "single argument to update_batch");
259 let values = values[0].as_primitive::<T>();
260 let data = values.values();
261
262 self.counts.resize(total_num_groups, 0);
264 self.sums.resize(total_num_groups, T::default_value());
265
266 let iter = group_indices.iter().zip(data.iter());
267 if values.null_count() == 0 {
268 for (&group_index, &value) in iter {
269 let sum = &mut self.sums[group_index];
270 *sum = (*sum).add_wrapping(value);
271 self.counts[group_index] += 1;
272 }
273 } else {
274 for (idx, (&group_index, &value)) in iter.enumerate() {
275 if values.is_null(idx) {
276 continue;
277 }
278 let sum = &mut self.sums[group_index];
279 *sum = (*sum).add_wrapping(value);
280
281 self.counts[group_index] += 1;
282 }
283 }
284
285 Ok(())
286 }
287
288 fn merge_batch(
289 &mut self,
290 values: &[ArrayRef],
291 group_indices: &[usize],
292 _opt_filter: Option<&arrow::array::BooleanArray>,
293 total_num_groups: usize,
294 ) -> Result<()> {
295 assert_eq!(values.len(), 2, "two arguments to merge_batch");
296 let partial_sums = values[0].as_primitive::<T>();
298 let partial_counts = values[1].as_primitive::<Int64Type>();
299 self.counts.resize(total_num_groups, 0);
301 let iter1 = group_indices.iter().zip(partial_counts.values().iter());
302 for (&group_index, &partial_count) in iter1 {
303 self.counts[group_index] += partial_count;
304 }
305
306 self.sums.resize(total_num_groups, T::default_value());
308 let iter2 = group_indices.iter().zip(partial_sums.values().iter());
309 for (&group_index, &new_value) in iter2 {
310 let sum = &mut self.sums[group_index];
311 *sum = sum.add_wrapping(new_value);
312 }
313
314 Ok(())
315 }
316
317 fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
318 let counts = emit_to.take_needed(&mut self.counts);
319 let sums = emit_to.take_needed(&mut self.sums);
320 let mut builder = PrimitiveBuilder::<T>::with_capacity(sums.len());
321 let iter = sums.into_iter().zip(counts);
322
323 for (sum, count) in iter {
324 if count != 0 {
325 builder.append_value((self.avg_fn)(sum, count)?)
326 } else {
327 builder.append_null();
328 }
329 }
330 let array: PrimitiveArray<T> = builder.finish();
331
332 Ok(Arc::new(array))
333 }
334
335 fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
337 let counts = emit_to.take_needed(&mut self.counts);
338 let counts = Int64Array::new(counts.into(), None);
339
340 let sums = emit_to.take_needed(&mut self.sums);
341 let sums = PrimitiveArray::<T>::new(sums.into(), None)
342 .with_data_type(self.return_data_type.clone());
343
344 Ok(vec![
345 Arc::new(sums) as ArrayRef,
346 Arc::new(counts) as ArrayRef,
347 ])
348 }
349
350 fn size(&self) -> usize {
351 self.counts.capacity() * size_of::<i64>() + self.sums.capacity() * size_of::<T>()
352 }
353}