1use arrow::array::{
21 Array, ArrayRef, ArrowNativeTypeOp, ArrowNumericType, ArrowPrimitiveType, AsArray,
22 BooleanArray, PrimitiveArray, PrimitiveBuilder, UInt64Array,
23};
24
25use arrow::compute::sum;
26use arrow::datatypes::{
27 i256, ArrowNativeType, DataType, Decimal128Type, Decimal256Type, DecimalType, Field,
28 Float64Type, UInt64Type,
29};
30use datafusion_common::{
31 exec_err, not_impl_err, utils::take_function_args, Result, ScalarValue,
32};
33use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
34use datafusion_expr::type_coercion::aggregates::{avg_return_type, coerce_avg_type};
35use datafusion_expr::utils::format_state_name;
36use datafusion_expr::Volatility::Immutable;
37use datafusion_expr::{
38 Accumulator, AggregateUDFImpl, Documentation, EmitTo, GroupsAccumulator,
39 ReversedUDAF, Signature,
40};
41
42use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::NullState;
43use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::{
44 filtered_null_mask, set_nulls,
45};
46
47use datafusion_functions_aggregate_common::utils::DecimalAverager;
48use datafusion_macros::user_doc;
49use log::debug;
50use std::any::Any;
51use std::fmt::Debug;
52use std::mem::{size_of, size_of_val};
53use std::sync::Arc;
54
// Invokes `make_udaf_expr_and_func!` (defined outside this file) for the `Avg`
// UDAF, passing the names `avg` and `avg_udaf`, the single argument name
// `expression`, and a one-line description.
// NOTE(review): the items the macro actually generates are not visible from
// this file — presumably an `avg(expression)` expression helper and an
// `avg_udaf()` accessor; confirm against the macro definition.
make_udaf_expr_and_func!(
    Avg,
    avg,
    expression,
    "Returns the avg of a group of values.",
    avg_udaf
);
62
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the average of numeric values in the specified column.",
    syntax_example = "avg(expression)",
    sql_example = r#"```sql
> SELECT avg(column_name) FROM table_name;
+---------------------------+
| avg(column_name) |
+---------------------------+
| 42.75 |
+---------------------------+
```"#,
    standard_argument(name = "expression",)
)]
#[derive(Debug)]
/// The `avg` aggregate function (also reachable through its `mean` alias).
///
/// The user-facing documentation is declared in the `#[user_doc]` attribute
/// above and is returned by `documentation()`.
pub struct Avg {
    /// Signature reported by `signature()`; built as a user-defined signature,
    /// so argument types are resolved through `coerce_types`.
    signature: Signature,
    /// Alternative names for the function, returned by `aliases()`.
    aliases: Vec<String>,
}
82
83impl Avg {
84 pub fn new() -> Self {
85 Self {
86 signature: Signature::user_defined(Immutable),
87 aliases: vec![String::from("mean")],
88 }
89 }
90}
91
92impl Default for Avg {
93 fn default() -> Self {
94 Self::new()
95 }
96}
97
98impl AggregateUDFImpl for Avg {
99 fn as_any(&self) -> &dyn Any {
100 self
101 }
102
103 fn name(&self) -> &str {
104 "avg"
105 }
106
107 fn signature(&self) -> &Signature {
108 &self.signature
109 }
110
111 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
112 avg_return_type(self.name(), &arg_types[0])
113 }
114
115 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
116 if acc_args.is_distinct {
117 return exec_err!("avg(DISTINCT) aggregations are not available");
118 }
119 use DataType::*;
120
121 let data_type = acc_args.exprs[0].data_type(acc_args.schema)?;
122 match (&data_type, acc_args.return_type) {
124 (Float64, Float64) => Ok(Box::<AvgAccumulator>::default()),
125 (
126 Decimal128(sum_precision, sum_scale),
127 Decimal128(target_precision, target_scale),
128 ) => Ok(Box::new(DecimalAvgAccumulator::<Decimal128Type> {
129 sum: None,
130 count: 0,
131 sum_scale: *sum_scale,
132 sum_precision: *sum_precision,
133 target_precision: *target_precision,
134 target_scale: *target_scale,
135 })),
136
137 (
138 Decimal256(sum_precision, sum_scale),
139 Decimal256(target_precision, target_scale),
140 ) => Ok(Box::new(DecimalAvgAccumulator::<Decimal256Type> {
141 sum: None,
142 count: 0,
143 sum_scale: *sum_scale,
144 sum_precision: *sum_precision,
145 target_precision: *target_precision,
146 target_scale: *target_scale,
147 })),
148 _ => exec_err!(
149 "AvgAccumulator for ({} --> {})",
150 &data_type,
151 acc_args.return_type
152 ),
153 }
154 }
155
156 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
157 Ok(vec![
158 Field::new(
159 format_state_name(args.name, "count"),
160 DataType::UInt64,
161 true,
162 ),
163 Field::new(
164 format_state_name(args.name, "sum"),
165 args.input_types[0].clone(),
166 true,
167 ),
168 ])
169 }
170
171 fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
172 matches!(
173 args.return_type,
174 DataType::Float64 | DataType::Decimal128(_, _)
175 )
176 }
177
178 fn create_groups_accumulator(
179 &self,
180 args: AccumulatorArgs,
181 ) -> Result<Box<dyn GroupsAccumulator>> {
182 use DataType::*;
183
184 let data_type = args.exprs[0].data_type(args.schema)?;
185 match (&data_type, args.return_type) {
187 (Float64, Float64) => {
188 Ok(Box::new(AvgGroupsAccumulator::<Float64Type, _>::new(
189 &data_type,
190 args.return_type,
191 |sum: f64, count: u64| Ok(sum / count as f64),
192 )))
193 }
194 (
195 Decimal128(_sum_precision, sum_scale),
196 Decimal128(target_precision, target_scale),
197 ) => {
198 let decimal_averager = DecimalAverager::<Decimal128Type>::try_new(
199 *sum_scale,
200 *target_precision,
201 *target_scale,
202 )?;
203
204 let avg_fn =
205 move |sum: i128, count: u64| decimal_averager.avg(sum, count as i128);
206
207 Ok(Box::new(AvgGroupsAccumulator::<Decimal128Type, _>::new(
208 &data_type,
209 args.return_type,
210 avg_fn,
211 )))
212 }
213
214 (
215 Decimal256(_sum_precision, sum_scale),
216 Decimal256(target_precision, target_scale),
217 ) => {
218 let decimal_averager = DecimalAverager::<Decimal256Type>::try_new(
219 *sum_scale,
220 *target_precision,
221 *target_scale,
222 )?;
223
224 let avg_fn = move |sum: i256, count: u64| {
225 decimal_averager.avg(sum, i256::from_usize(count as usize).unwrap())
226 };
227
228 Ok(Box::new(AvgGroupsAccumulator::<Decimal256Type, _>::new(
229 &data_type,
230 args.return_type,
231 avg_fn,
232 )))
233 }
234
235 _ => not_impl_err!(
236 "AvgGroupsAccumulator for ({} --> {})",
237 &data_type,
238 args.return_type
239 ),
240 }
241 }
242
243 fn aliases(&self) -> &[String] {
244 &self.aliases
245 }
246
247 fn reverse_expr(&self) -> ReversedUDAF {
248 ReversedUDAF::Identical
249 }
250
251 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
252 let [args] = take_function_args(self.name(), arg_types)?;
253 coerce_avg_type(self.name(), std::slice::from_ref(args))
254 }
255
256 fn documentation(&self) -> Option<&Documentation> {
257 self.doc()
258 }
259}
260
/// Row-at-a-time `avg` accumulator for `Float64` input.
#[derive(Debug, Default)]
pub struct AvgAccumulator {
    /// Running sum of the non-null values; `None` until a batch contributes a
    /// sum.
    sum: Option<f64>,
    /// Number of non-null values accumulated so far.
    count: u64,
}
267
268impl Accumulator for AvgAccumulator {
269 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
270 let values = values[0].as_primitive::<Float64Type>();
271 self.count += (values.len() - values.null_count()) as u64;
272 if let Some(x) = sum(values) {
273 let v = self.sum.get_or_insert(0.);
274 *v += x;
275 }
276 Ok(())
277 }
278
279 fn evaluate(&mut self) -> Result<ScalarValue> {
280 Ok(ScalarValue::Float64(
281 self.sum.map(|f| f / self.count as f64),
282 ))
283 }
284
285 fn size(&self) -> usize {
286 size_of_val(self)
287 }
288
289 fn state(&mut self) -> Result<Vec<ScalarValue>> {
290 Ok(vec![
291 ScalarValue::from(self.count),
292 ScalarValue::Float64(self.sum),
293 ])
294 }
295
296 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
297 self.count += sum(states[0].as_primitive::<UInt64Type>()).unwrap_or_default();
299
300 if let Some(x) = sum(states[1].as_primitive::<Float64Type>()) {
302 let v = self.sum.get_or_insert(0.);
303 *v += x;
304 }
305 Ok(())
306 }
307 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
308 let values = values[0].as_primitive::<Float64Type>();
309 self.count -= (values.len() - values.null_count()) as u64;
310 if let Some(x) = sum(values) {
311 self.sum = Some(self.sum.unwrap() - x);
312 }
313 Ok(())
314 }
315
316 fn supports_retract_batch(&self) -> bool {
317 true
318 }
319}
320
/// Row-at-a-time `avg` accumulator for decimal input of Arrow type `T`.
#[derive(Debug)]
struct DecimalAvgAccumulator<T: DecimalType + ArrowNumericType + Debug> {
    /// Running (wrapping) sum of the non-null values; `None` until a batch
    /// contributes a sum.
    sum: Option<T::Native>,
    /// Number of non-null values accumulated so far.
    count: u64,
    /// Scale of the sum; used for the emitted state and for averaging.
    sum_scale: i8,
    /// Precision of the sum; used for the emitted state.
    sum_precision: u8,
    /// Precision of the final average.
    target_precision: u8,
    /// Scale of the final average.
    target_scale: i8,
}
331
332impl<T: DecimalType + ArrowNumericType + Debug> Accumulator for DecimalAvgAccumulator<T> {
333 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
334 let values = values[0].as_primitive::<T>();
335 self.count += (values.len() - values.null_count()) as u64;
336
337 if let Some(x) = sum(values) {
338 let v = self.sum.get_or_insert(T::Native::default());
339 self.sum = Some(v.add_wrapping(x));
340 }
341 Ok(())
342 }
343
344 fn evaluate(&mut self) -> Result<ScalarValue> {
345 let v = self
346 .sum
347 .map(|v| {
348 DecimalAverager::<T>::try_new(
349 self.sum_scale,
350 self.target_precision,
351 self.target_scale,
352 )?
353 .avg(v, T::Native::from_usize(self.count as usize).unwrap())
354 })
355 .transpose()?;
356
357 ScalarValue::new_primitive::<T>(
358 v,
359 &T::TYPE_CONSTRUCTOR(self.target_precision, self.target_scale),
360 )
361 }
362
363 fn size(&self) -> usize {
364 size_of_val(self)
365 }
366
367 fn state(&mut self) -> Result<Vec<ScalarValue>> {
368 Ok(vec![
369 ScalarValue::from(self.count),
370 ScalarValue::new_primitive::<T>(
371 self.sum,
372 &T::TYPE_CONSTRUCTOR(self.sum_precision, self.sum_scale),
373 )?,
374 ])
375 }
376
377 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
378 self.count += sum(states[0].as_primitive::<UInt64Type>()).unwrap_or_default();
380
381 if let Some(x) = sum(states[1].as_primitive::<T>()) {
383 let v = self.sum.get_or_insert(T::Native::default());
384 self.sum = Some(v.add_wrapping(x));
385 }
386 Ok(())
387 }
388 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
389 let values = values[0].as_primitive::<T>();
390 self.count -= (values.len() - values.null_count()) as u64;
391 if let Some(x) = sum(values) {
392 self.sum = Some(self.sum.unwrap().sub_wrapping(x));
393 }
394 Ok(())
395 }
396
397 fn supports_retract_batch(&self) -> bool {
398 true
399 }
400}
401
/// Grouped `avg` accumulator: keeps one running sum and one count per group
/// index.
///
/// `T` is the Arrow numeric type of the sums; `F` turns a group's
/// `(sum, count)` pair into its average.
#[derive(Debug)]
struct AvgGroupsAccumulator<T, F>
where
    T: ArrowNumericType + Send,
    F: Fn(T::Native, u64) -> Result<T::Native> + Send,
{
    /// Data type applied to the sums array emitted as intermediate state.
    sum_data_type: DataType,

    /// Data type applied to the array of averages returned by `evaluate`.
    return_data_type: DataType,

    /// Count of accumulated values, indexed by group.
    counts: Vec<u64>,

    /// Running (wrapping) sum, indexed by group.
    sums: Vec<T::Native>,

    /// Null tracking for the groups; supplies the validity mask used when
    /// results and state are emitted.
    null_state: NullState,

    /// Computes a group's average from its sum and count.
    avg_fn: F,
}
431
432impl<T, F> AvgGroupsAccumulator<T, F>
433where
434 T: ArrowNumericType + Send,
435 F: Fn(T::Native, u64) -> Result<T::Native> + Send,
436{
437 pub fn new(sum_data_type: &DataType, return_data_type: &DataType, avg_fn: F) -> Self {
438 debug!(
439 "AvgGroupsAccumulator ({}, sum type: {sum_data_type:?}) --> {return_data_type:?}",
440 std::any::type_name::<T>()
441 );
442
443 Self {
444 return_data_type: return_data_type.clone(),
445 sum_data_type: sum_data_type.clone(),
446 counts: vec![],
447 sums: vec![],
448 null_state: NullState::new(),
449 avg_fn,
450 }
451 }
452}
453
454impl<T, F> GroupsAccumulator for AvgGroupsAccumulator<T, F>
455where
456 T: ArrowNumericType + Send,
457 F: Fn(T::Native, u64) -> Result<T::Native> + Send,
458{
459 fn update_batch(
460 &mut self,
461 values: &[ArrayRef],
462 group_indices: &[usize],
463 opt_filter: Option<&BooleanArray>,
464 total_num_groups: usize,
465 ) -> Result<()> {
466 assert_eq!(values.len(), 1, "single argument to update_batch");
467 let values = values[0].as_primitive::<T>();
468
469 self.counts.resize(total_num_groups, 0);
471 self.sums.resize(total_num_groups, T::default_value());
472 self.null_state.accumulate(
473 group_indices,
474 values,
475 opt_filter,
476 total_num_groups,
477 |group_index, new_value| {
478 let sum = &mut self.sums[group_index];
479 *sum = sum.add_wrapping(new_value);
480
481 self.counts[group_index] += 1;
482 },
483 );
484
485 Ok(())
486 }
487
488 fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
489 let counts = emit_to.take_needed(&mut self.counts);
490 let sums = emit_to.take_needed(&mut self.sums);
491 let nulls = self.null_state.build(emit_to);
492
493 assert_eq!(nulls.len(), sums.len());
494 assert_eq!(counts.len(), sums.len());
495
496 let array: PrimitiveArray<T> = if nulls.null_count() > 0 {
499 let mut builder = PrimitiveBuilder::<T>::with_capacity(nulls.len())
500 .with_data_type(self.return_data_type.clone());
501 let iter = sums.into_iter().zip(counts).zip(nulls.iter());
502
503 for ((sum, count), is_valid) in iter {
504 if is_valid {
505 builder.append_value((self.avg_fn)(sum, count)?)
506 } else {
507 builder.append_null();
508 }
509 }
510 builder.finish()
511 } else {
512 let averages: Vec<T::Native> = sums
513 .into_iter()
514 .zip(counts.into_iter())
515 .map(|(sum, count)| (self.avg_fn)(sum, count))
516 .collect::<Result<Vec<_>>>()?;
517 PrimitiveArray::new(averages.into(), Some(nulls)) .with_data_type(self.return_data_type.clone())
519 };
520
521 Ok(Arc::new(array))
522 }
523
524 fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
526 let nulls = self.null_state.build(emit_to);
527 let nulls = Some(nulls);
528
529 let counts = emit_to.take_needed(&mut self.counts);
530 let counts = UInt64Array::new(counts.into(), nulls.clone()); let sums = emit_to.take_needed(&mut self.sums);
533 let sums = PrimitiveArray::<T>::new(sums.into(), nulls) .with_data_type(self.sum_data_type.clone());
535
536 Ok(vec![
537 Arc::new(counts) as ArrayRef,
538 Arc::new(sums) as ArrayRef,
539 ])
540 }
541
542 fn merge_batch(
543 &mut self,
544 values: &[ArrayRef],
545 group_indices: &[usize],
546 opt_filter: Option<&BooleanArray>,
547 total_num_groups: usize,
548 ) -> Result<()> {
549 assert_eq!(values.len(), 2, "two arguments to merge_batch");
550 let partial_counts = values[0].as_primitive::<UInt64Type>();
552 let partial_sums = values[1].as_primitive::<T>();
553 self.counts.resize(total_num_groups, 0);
555 self.null_state.accumulate(
556 group_indices,
557 partial_counts,
558 opt_filter,
559 total_num_groups,
560 |group_index, partial_count| {
561 self.counts[group_index] += partial_count;
562 },
563 );
564
565 self.sums.resize(total_num_groups, T::default_value());
567 self.null_state.accumulate(
568 group_indices,
569 partial_sums,
570 opt_filter,
571 total_num_groups,
572 |group_index, new_value: <T as ArrowPrimitiveType>::Native| {
573 let sum = &mut self.sums[group_index];
574 *sum = sum.add_wrapping(new_value);
575 },
576 );
577
578 Ok(())
579 }
580
581 fn convert_to_state(
582 &self,
583 values: &[ArrayRef],
584 opt_filter: Option<&BooleanArray>,
585 ) -> Result<Vec<ArrayRef>> {
586 let sums = values[0]
587 .as_primitive::<T>()
588 .clone()
589 .with_data_type(self.sum_data_type.clone());
590 let counts = UInt64Array::from_value(1, sums.len());
591
592 let nulls = filtered_null_mask(opt_filter, &sums);
593
594 let counts = set_nulls(counts, nulls.clone());
596 let sums = set_nulls(sums, nulls);
597
598 Ok(vec![Arc::new(counts) as ArrayRef, Arc::new(sums)])
599 }
600
601 fn supports_convert_to_state(&self) -> bool {
602 true
603 }
604
605 fn size(&self) -> usize {
606 self.counts.capacity() * size_of::<u64>() + self.sums.capacity() * size_of::<T>()
607 }
608}