1use arrow::array::{
21 Array, ArrayRef, ArrowNativeTypeOp, ArrowNumericType, ArrowPrimitiveType, AsArray,
22 BooleanArray, PrimitiveArray, PrimitiveBuilder, UInt64Array,
23};
24
25use arrow::compute::sum;
26use arrow::datatypes::{
27 i256, ArrowNativeType, DataType, Decimal128Type, Decimal256Type, DecimalType,
28 DurationMicrosecondType, DurationMillisecondType, DurationNanosecondType,
29 DurationSecondType, Field, FieldRef, Float64Type, TimeUnit, UInt64Type,
30};
31use datafusion_common::{
32 exec_err, not_impl_err, utils::take_function_args, Result, ScalarValue,
33};
34use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
35use datafusion_expr::type_coercion::aggregates::{avg_return_type, coerce_avg_type};
36use datafusion_expr::utils::format_state_name;
37use datafusion_expr::Volatility::Immutable;
38use datafusion_expr::{
39 Accumulator, AggregateUDFImpl, Documentation, EmitTo, GroupsAccumulator,
40 ReversedUDAF, Signature,
41};
42
43use datafusion_functions_aggregate_common::aggregate::avg_distinct::Float64DistinctAvgAccumulator;
44use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::NullState;
45use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::{
46 filtered_null_mask, set_nulls,
47};
48
49use datafusion_functions_aggregate_common::utils::DecimalAverager;
50use datafusion_macros::user_doc;
51use log::debug;
52use std::any::Any;
53use std::fmt::Debug;
54use std::mem::{size_of, size_of_val};
55use std::sync::Arc;
56
// Registers `Avg` through the `make_udaf_expr_and_func!` macro.
// NOTE(review): presumably this generates the `avg(expression)` expression
// helper and the `avg_udaf()` accessor — confirm against the macro definition.
make_udaf_expr_and_func!(
    Avg,
    avg,
    expression,
    "Returns the avg of a group of values.",
    avg_udaf
);
64
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the average of numeric values in the specified column.",
    syntax_example = "avg(expression)",
    sql_example = r#"```sql
> SELECT avg(column_name) FROM table_name;
+---------------------------+
| avg(column_name) |
+---------------------------+
| 42.75 |
+---------------------------+
```"#,
    standard_argument(name = "expression",)
)]
/// The `avg` aggregate function.
///
/// Holds the function signature and the list of alternative names under
/// which the function is exposed.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct Avg {
    // Signature returned by `AggregateUDFImpl::signature`.
    signature: Signature,
    // Alternative names returned by `AggregateUDFImpl::aliases`.
    aliases: Vec<String>,
}
84
85impl Avg {
86 pub fn new() -> Self {
87 Self {
88 signature: Signature::user_defined(Immutable),
89 aliases: vec![String::from("mean")],
90 }
91 }
92}
93
94impl Default for Avg {
95 fn default() -> Self {
96 Self::new()
97 }
98}
99
100impl AggregateUDFImpl for Avg {
    /// Returns this function as `&dyn Any`.
    fn as_any(&self) -> &dyn Any {
        self
    }
104
    /// Returns the function name, `"avg"`.
    fn name(&self) -> &str {
        "avg"
    }
108
    /// Returns the signature stored on this instance.
    fn signature(&self) -> &Signature {
        &self.signature
    }
112
113 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
114 avg_return_type(self.name(), &arg_types[0])
115 }
116
117 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
118 let data_type = acc_args.exprs[0].data_type(acc_args.schema)?;
119 use DataType::*;
120
121 if acc_args.is_distinct {
123 match &data_type {
124 Float64 => Ok(Box::new(Float64DistinctAvgAccumulator::default())),
126 _ => exec_err!("AVG(DISTINCT) for {} not supported", data_type),
127 }
128 } else {
129 match (&data_type, acc_args.return_field.data_type()) {
130 (Float64, Float64) => Ok(Box::<AvgAccumulator>::default()),
131 (
132 Decimal128(sum_precision, sum_scale),
133 Decimal128(target_precision, target_scale),
134 ) => Ok(Box::new(DecimalAvgAccumulator::<Decimal128Type> {
135 sum: None,
136 count: 0,
137 sum_scale: *sum_scale,
138 sum_precision: *sum_precision,
139 target_precision: *target_precision,
140 target_scale: *target_scale,
141 })),
142
143 (
144 Decimal256(sum_precision, sum_scale),
145 Decimal256(target_precision, target_scale),
146 ) => Ok(Box::new(DecimalAvgAccumulator::<Decimal256Type> {
147 sum: None,
148 count: 0,
149 sum_scale: *sum_scale,
150 sum_precision: *sum_precision,
151 target_precision: *target_precision,
152 target_scale: *target_scale,
153 })),
154
155 (Duration(time_unit), Duration(result_unit)) => {
156 Ok(Box::new(DurationAvgAccumulator {
157 sum: None,
158 count: 0,
159 time_unit: *time_unit,
160 result_unit: *result_unit,
161 }))
162 }
163
164 _ => exec_err!(
165 "AvgAccumulator for ({} --> {})",
166 &data_type,
167 acc_args.return_field.data_type()
168 ),
169 }
170 }
171 }
172
173 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
174 if args.is_distinct {
175 Ok(vec![Field::new_list(
178 format_state_name(args.name, "avg distinct"),
179 Field::new_list_field(args.return_type().clone(), true),
180 false,
181 )
182 .into()])
183 } else {
184 Ok(vec![
185 Field::new(
186 format_state_name(args.name, "count"),
187 DataType::UInt64,
188 true,
189 ),
190 Field::new(
191 format_state_name(args.name, "sum"),
192 args.input_fields[0].data_type().clone(),
193 true,
194 ),
195 ]
196 .into_iter()
197 .map(Arc::new)
198 .collect())
199 }
200 }
201
202 fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
203 matches!(
204 args.return_field.data_type(),
205 DataType::Float64 | DataType::Decimal128(_, _) | DataType::Duration(_)
206 ) && !args.is_distinct
207 }
208
209 fn create_groups_accumulator(
210 &self,
211 args: AccumulatorArgs,
212 ) -> Result<Box<dyn GroupsAccumulator>> {
213 use DataType::*;
214
215 let data_type = args.exprs[0].data_type(args.schema)?;
216 match (&data_type, args.return_field.data_type()) {
218 (Float64, Float64) => {
219 Ok(Box::new(AvgGroupsAccumulator::<Float64Type, _>::new(
220 &data_type,
221 args.return_field.data_type(),
222 |sum: f64, count: u64| Ok(sum / count as f64),
223 )))
224 }
225 (
226 Decimal128(_sum_precision, sum_scale),
227 Decimal128(target_precision, target_scale),
228 ) => {
229 let decimal_averager = DecimalAverager::<Decimal128Type>::try_new(
230 *sum_scale,
231 *target_precision,
232 *target_scale,
233 )?;
234
235 let avg_fn =
236 move |sum: i128, count: u64| decimal_averager.avg(sum, count as i128);
237
238 Ok(Box::new(AvgGroupsAccumulator::<Decimal128Type, _>::new(
239 &data_type,
240 args.return_field.data_type(),
241 avg_fn,
242 )))
243 }
244
245 (
246 Decimal256(_sum_precision, sum_scale),
247 Decimal256(target_precision, target_scale),
248 ) => {
249 let decimal_averager = DecimalAverager::<Decimal256Type>::try_new(
250 *sum_scale,
251 *target_precision,
252 *target_scale,
253 )?;
254
255 let avg_fn = move |sum: i256, count: u64| {
256 decimal_averager.avg(sum, i256::from_usize(count as usize).unwrap())
257 };
258
259 Ok(Box::new(AvgGroupsAccumulator::<Decimal256Type, _>::new(
260 &data_type,
261 args.return_field.data_type(),
262 avg_fn,
263 )))
264 }
265
266 (Duration(time_unit), Duration(_result_unit)) => {
267 let avg_fn = move |sum: i64, count: u64| Ok(sum / count as i64);
268
269 match time_unit {
270 TimeUnit::Second => Ok(Box::new(AvgGroupsAccumulator::<
271 DurationSecondType,
272 _,
273 >::new(
274 &data_type,
275 args.return_type(),
276 avg_fn,
277 ))),
278 TimeUnit::Millisecond => Ok(Box::new(AvgGroupsAccumulator::<
279 DurationMillisecondType,
280 _,
281 >::new(
282 &data_type,
283 args.return_type(),
284 avg_fn,
285 ))),
286 TimeUnit::Microsecond => Ok(Box::new(AvgGroupsAccumulator::<
287 DurationMicrosecondType,
288 _,
289 >::new(
290 &data_type,
291 args.return_type(),
292 avg_fn,
293 ))),
294 TimeUnit::Nanosecond => Ok(Box::new(AvgGroupsAccumulator::<
295 DurationNanosecondType,
296 _,
297 >::new(
298 &data_type,
299 args.return_type(),
300 avg_fn,
301 ))),
302 }
303 }
304
305 _ => not_impl_err!(
306 "AvgGroupsAccumulator for ({} --> {})",
307 &data_type,
308 args.return_field.data_type()
309 ),
310 }
311 }
312
    /// Returns the alternative names stored on this instance.
    fn aliases(&self) -> &[String] {
        &self.aliases
    }
316
    /// Reports the reversed form of this aggregate as
    /// [`ReversedUDAF::Identical`].
    fn reverse_expr(&self) -> ReversedUDAF {
        ReversedUDAF::Identical
    }
320
321 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
322 let [args] = take_function_args(self.name(), arg_types)?;
323 coerce_avg_type(self.name(), std::slice::from_ref(args))
324 }
325
    /// Returns the documentation provided by `self.doc()`.
    /// NOTE(review): `doc()` is presumably generated by the `#[user_doc]`
    /// attribute on `Avg` — confirm.
    fn documentation(&self) -> Option<&Documentation> {
        self.doc()
    }
329}
330
331#[derive(Debug, Default)]
333pub struct AvgAccumulator {
334 sum: Option<f64>,
335 count: u64,
336}
337
338impl Accumulator for AvgAccumulator {
339 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
340 let values = values[0].as_primitive::<Float64Type>();
341 self.count += (values.len() - values.null_count()) as u64;
342 if let Some(x) = sum(values) {
343 let v = self.sum.get_or_insert(0.);
344 *v += x;
345 }
346 Ok(())
347 }
348
349 fn evaluate(&mut self) -> Result<ScalarValue> {
350 Ok(ScalarValue::Float64(
351 self.sum.map(|f| f / self.count as f64),
352 ))
353 }
354
355 fn size(&self) -> usize {
356 size_of_val(self)
357 }
358
359 fn state(&mut self) -> Result<Vec<ScalarValue>> {
360 Ok(vec![
361 ScalarValue::from(self.count),
362 ScalarValue::Float64(self.sum),
363 ])
364 }
365
366 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
367 self.count += sum(states[0].as_primitive::<UInt64Type>()).unwrap_or_default();
369
370 if let Some(x) = sum(states[1].as_primitive::<Float64Type>()) {
372 let v = self.sum.get_or_insert(0.);
373 *v += x;
374 }
375 Ok(())
376 }
377 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
378 let values = values[0].as_primitive::<Float64Type>();
379 self.count -= (values.len() - values.null_count()) as u64;
380 if let Some(x) = sum(values) {
381 self.sum = Some(self.sum.unwrap() - x);
382 }
383 Ok(())
384 }
385
386 fn supports_retract_batch(&self) -> bool {
387 true
388 }
389}
390
391#[derive(Debug)]
393struct DecimalAvgAccumulator<T: DecimalType + ArrowNumericType + Debug> {
394 sum: Option<T::Native>,
395 count: u64,
396 sum_scale: i8,
397 sum_precision: u8,
398 target_precision: u8,
399 target_scale: i8,
400}
401
402impl<T: DecimalType + ArrowNumericType + Debug> Accumulator for DecimalAvgAccumulator<T> {
403 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
404 let values = values[0].as_primitive::<T>();
405 self.count += (values.len() - values.null_count()) as u64;
406
407 if let Some(x) = sum(values) {
408 let v = self.sum.get_or_insert(T::Native::default());
409 self.sum = Some(v.add_wrapping(x));
410 }
411 Ok(())
412 }
413
414 fn evaluate(&mut self) -> Result<ScalarValue> {
415 let v = self
416 .sum
417 .map(|v| {
418 DecimalAverager::<T>::try_new(
419 self.sum_scale,
420 self.target_precision,
421 self.target_scale,
422 )?
423 .avg(v, T::Native::from_usize(self.count as usize).unwrap())
424 })
425 .transpose()?;
426
427 ScalarValue::new_primitive::<T>(
428 v,
429 &T::TYPE_CONSTRUCTOR(self.target_precision, self.target_scale),
430 )
431 }
432
433 fn size(&self) -> usize {
434 size_of_val(self)
435 }
436
437 fn state(&mut self) -> Result<Vec<ScalarValue>> {
438 Ok(vec![
439 ScalarValue::from(self.count),
440 ScalarValue::new_primitive::<T>(
441 self.sum,
442 &T::TYPE_CONSTRUCTOR(self.sum_precision, self.sum_scale),
443 )?,
444 ])
445 }
446
447 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
448 self.count += sum(states[0].as_primitive::<UInt64Type>()).unwrap_or_default();
450
451 if let Some(x) = sum(states[1].as_primitive::<T>()) {
453 let v = self.sum.get_or_insert(T::Native::default());
454 self.sum = Some(v.add_wrapping(x));
455 }
456 Ok(())
457 }
458 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
459 let values = values[0].as_primitive::<T>();
460 self.count -= (values.len() - values.null_count()) as u64;
461 if let Some(x) = sum(values) {
462 self.sum = Some(self.sum.unwrap().sub_wrapping(x));
463 }
464 Ok(())
465 }
466
467 fn supports_retract_batch(&self) -> bool {
468 true
469 }
470}
471
472#[derive(Debug)]
474struct DurationAvgAccumulator {
475 sum: Option<i64>,
476 count: u64,
477 time_unit: TimeUnit,
478 result_unit: TimeUnit,
479}
480
481impl Accumulator for DurationAvgAccumulator {
482 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
483 let array = &values[0];
484 self.count += (array.len() - array.null_count()) as u64;
485
486 let sum_value = match self.time_unit {
487 TimeUnit::Second => sum(array.as_primitive::<DurationSecondType>()),
488 TimeUnit::Millisecond => sum(array.as_primitive::<DurationMillisecondType>()),
489 TimeUnit::Microsecond => sum(array.as_primitive::<DurationMicrosecondType>()),
490 TimeUnit::Nanosecond => sum(array.as_primitive::<DurationNanosecondType>()),
491 };
492
493 if let Some(x) = sum_value {
494 let v = self.sum.get_or_insert(0);
495 *v += x;
496 }
497 Ok(())
498 }
499
500 fn evaluate(&mut self) -> Result<ScalarValue> {
501 let avg = self.sum.map(|sum| sum / self.count as i64);
502
503 match self.result_unit {
504 TimeUnit::Second => Ok(ScalarValue::DurationSecond(avg)),
505 TimeUnit::Millisecond => Ok(ScalarValue::DurationMillisecond(avg)),
506 TimeUnit::Microsecond => Ok(ScalarValue::DurationMicrosecond(avg)),
507 TimeUnit::Nanosecond => Ok(ScalarValue::DurationNanosecond(avg)),
508 }
509 }
510
511 fn size(&self) -> usize {
512 size_of_val(self)
513 }
514
515 fn state(&mut self) -> Result<Vec<ScalarValue>> {
516 let duration_value = match self.time_unit {
517 TimeUnit::Second => ScalarValue::DurationSecond(self.sum),
518 TimeUnit::Millisecond => ScalarValue::DurationMillisecond(self.sum),
519 TimeUnit::Microsecond => ScalarValue::DurationMicrosecond(self.sum),
520 TimeUnit::Nanosecond => ScalarValue::DurationNanosecond(self.sum),
521 };
522
523 Ok(vec![ScalarValue::from(self.count), duration_value])
524 }
525
526 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
527 self.count += sum(states[0].as_primitive::<UInt64Type>()).unwrap_or_default();
528
529 let sum_value = match self.time_unit {
530 TimeUnit::Second => sum(states[1].as_primitive::<DurationSecondType>()),
531 TimeUnit::Millisecond => {
532 sum(states[1].as_primitive::<DurationMillisecondType>())
533 }
534 TimeUnit::Microsecond => {
535 sum(states[1].as_primitive::<DurationMicrosecondType>())
536 }
537 TimeUnit::Nanosecond => {
538 sum(states[1].as_primitive::<DurationNanosecondType>())
539 }
540 };
541
542 if let Some(x) = sum_value {
543 let v = self.sum.get_or_insert(0);
544 *v += x;
545 }
546 Ok(())
547 }
548
549 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
550 let array = &values[0];
551 self.count -= (array.len() - array.null_count()) as u64;
552
553 let sum_value = match self.time_unit {
554 TimeUnit::Second => sum(array.as_primitive::<DurationSecondType>()),
555 TimeUnit::Millisecond => sum(array.as_primitive::<DurationMillisecondType>()),
556 TimeUnit::Microsecond => sum(array.as_primitive::<DurationMicrosecondType>()),
557 TimeUnit::Nanosecond => sum(array.as_primitive::<DurationNanosecondType>()),
558 };
559
560 if let Some(x) = sum_value {
561 self.sum = Some(self.sum.unwrap() - x);
562 }
563 Ok(())
564 }
565
566 fn supports_retract_batch(&self) -> bool {
567 true
568 }
569}
570
571#[derive(Debug)]
577struct AvgGroupsAccumulator<T, F>
578where
579 T: ArrowNumericType + Send,
580 F: Fn(T::Native, u64) -> Result<T::Native> + Send,
581{
582 sum_data_type: DataType,
584
585 return_data_type: DataType,
587
588 counts: Vec<u64>,
590
591 sums: Vec<T::Native>,
593
594 null_state: NullState,
596
597 avg_fn: F,
599}
600
601impl<T, F> AvgGroupsAccumulator<T, F>
602where
603 T: ArrowNumericType + Send,
604 F: Fn(T::Native, u64) -> Result<T::Native> + Send,
605{
606 pub fn new(sum_data_type: &DataType, return_data_type: &DataType, avg_fn: F) -> Self {
607 debug!(
608 "AvgGroupsAccumulator ({}, sum type: {sum_data_type:?}) --> {return_data_type:?}",
609 std::any::type_name::<T>()
610 );
611
612 Self {
613 return_data_type: return_data_type.clone(),
614 sum_data_type: sum_data_type.clone(),
615 counts: vec![],
616 sums: vec![],
617 null_state: NullState::new(),
618 avg_fn,
619 }
620 }
621}
622
623impl<T, F> GroupsAccumulator for AvgGroupsAccumulator<T, F>
624where
625 T: ArrowNumericType + Send,
626 F: Fn(T::Native, u64) -> Result<T::Native> + Send,
627{
628 fn update_batch(
629 &mut self,
630 values: &[ArrayRef],
631 group_indices: &[usize],
632 opt_filter: Option<&BooleanArray>,
633 total_num_groups: usize,
634 ) -> Result<()> {
635 assert_eq!(values.len(), 1, "single argument to update_batch");
636 let values = values[0].as_primitive::<T>();
637
638 self.counts.resize(total_num_groups, 0);
640 self.sums.resize(total_num_groups, T::default_value());
641 self.null_state.accumulate(
642 group_indices,
643 values,
644 opt_filter,
645 total_num_groups,
646 |group_index, new_value| {
647 let sum = &mut self.sums[group_index];
648 *sum = sum.add_wrapping(new_value);
649
650 self.counts[group_index] += 1;
651 },
652 );
653
654 Ok(())
655 }
656
657 fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
658 let counts = emit_to.take_needed(&mut self.counts);
659 let sums = emit_to.take_needed(&mut self.sums);
660 let nulls = self.null_state.build(emit_to);
661
662 assert_eq!(nulls.len(), sums.len());
663 assert_eq!(counts.len(), sums.len());
664
665 let array: PrimitiveArray<T> = if nulls.null_count() > 0 {
668 let mut builder = PrimitiveBuilder::<T>::with_capacity(nulls.len())
669 .with_data_type(self.return_data_type.clone());
670 let iter = sums.into_iter().zip(counts).zip(nulls.iter());
671
672 for ((sum, count), is_valid) in iter {
673 if is_valid {
674 builder.append_value((self.avg_fn)(sum, count)?)
675 } else {
676 builder.append_null();
677 }
678 }
679 builder.finish()
680 } else {
681 let averages: Vec<T::Native> = sums
682 .into_iter()
683 .zip(counts.into_iter())
684 .map(|(sum, count)| (self.avg_fn)(sum, count))
685 .collect::<Result<Vec<_>>>()?;
686 PrimitiveArray::new(averages.into(), Some(nulls)) .with_data_type(self.return_data_type.clone())
688 };
689
690 Ok(Arc::new(array))
691 }
692
693 fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
695 let nulls = self.null_state.build(emit_to);
696 let nulls = Some(nulls);
697
698 let counts = emit_to.take_needed(&mut self.counts);
699 let counts = UInt64Array::new(counts.into(), nulls.clone()); let sums = emit_to.take_needed(&mut self.sums);
702 let sums = PrimitiveArray::<T>::new(sums.into(), nulls) .with_data_type(self.sum_data_type.clone());
704
705 Ok(vec![
706 Arc::new(counts) as ArrayRef,
707 Arc::new(sums) as ArrayRef,
708 ])
709 }
710
711 fn merge_batch(
712 &mut self,
713 values: &[ArrayRef],
714 group_indices: &[usize],
715 opt_filter: Option<&BooleanArray>,
716 total_num_groups: usize,
717 ) -> Result<()> {
718 assert_eq!(values.len(), 2, "two arguments to merge_batch");
719 let partial_counts = values[0].as_primitive::<UInt64Type>();
721 let partial_sums = values[1].as_primitive::<T>();
722 self.counts.resize(total_num_groups, 0);
724 self.null_state.accumulate(
725 group_indices,
726 partial_counts,
727 opt_filter,
728 total_num_groups,
729 |group_index, partial_count| {
730 self.counts[group_index] += partial_count;
731 },
732 );
733
734 self.sums.resize(total_num_groups, T::default_value());
736 self.null_state.accumulate(
737 group_indices,
738 partial_sums,
739 opt_filter,
740 total_num_groups,
741 |group_index, new_value: <T as ArrowPrimitiveType>::Native| {
742 let sum = &mut self.sums[group_index];
743 *sum = sum.add_wrapping(new_value);
744 },
745 );
746
747 Ok(())
748 }
749
750 fn convert_to_state(
751 &self,
752 values: &[ArrayRef],
753 opt_filter: Option<&BooleanArray>,
754 ) -> Result<Vec<ArrayRef>> {
755 let sums = values[0]
756 .as_primitive::<T>()
757 .clone()
758 .with_data_type(self.sum_data_type.clone());
759 let counts = UInt64Array::from_value(1, sums.len());
760
761 let nulls = filtered_null_mask(opt_filter, &sums);
762
763 let counts = set_nulls(counts, nulls.clone());
765 let sums = set_nulls(sums, nulls);
766
767 Ok(vec![Arc::new(counts) as ArrayRef, Arc::new(sums)])
768 }
769
    /// Always `true`: this accumulator provides `convert_to_state`.
    fn supports_convert_to_state(&self) -> bool {
        true
    }
773
774 fn size(&self) -> usize {
775 self.counts.capacity() * size_of::<u64>() + self.sums.capacity() * size_of::<T>()
776 }
777}