1use arrow::array::{
21 Array, ArrayRef, ArrowNativeTypeOp, ArrowNumericType, ArrowPrimitiveType, AsArray,
22 BooleanArray, PrimitiveArray, PrimitiveBuilder, UInt64Array,
23};
24
25use arrow::compute::sum;
26use arrow::datatypes::{
27 i256, ArrowNativeType, DataType, Decimal128Type, Decimal256Type, DecimalType,
28 DurationMicrosecondType, DurationMillisecondType, DurationNanosecondType,
29 DurationSecondType, Field, FieldRef, Float64Type, TimeUnit, UInt64Type,
30};
31use datafusion_common::{
32 exec_err, not_impl_err, utils::take_function_args, Result, ScalarValue,
33};
34use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
35use datafusion_expr::type_coercion::aggregates::{avg_return_type, coerce_avg_type};
36use datafusion_expr::utils::format_state_name;
37use datafusion_expr::Volatility::Immutable;
38use datafusion_expr::{
39 Accumulator, AggregateUDFImpl, Documentation, EmitTo, GroupsAccumulator,
40 ReversedUDAF, Signature,
41};
42
43use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::NullState;
44use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::{
45 filtered_null_mask, set_nulls,
46};
47
48use datafusion_functions_aggregate_common::utils::DecimalAverager;
49use datafusion_macros::user_doc;
50use log::debug;
51use std::any::Any;
52use std::fmt::Debug;
53use std::mem::{size_of, size_of_val};
54use std::sync::Arc;
55
// Registers the `Avg` aggregate under the identifiers `avg` / `avg_udaf`,
// with a single argument named `expression` and the description string below.
// NOTE(review): the macro is defined outside this file; presumably it emits an
// `avg(expression)` expression builder and an `avg_udaf()` accessor — confirm
// against the macro definition.
make_udaf_expr_and_func!(
    Avg,
    avg,
    expression,
    "Returns the avg of a group of values.",
    avg_udaf
);
63
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the average of numeric values in the specified column.",
    syntax_example = "avg(expression)",
    sql_example = r#"```sql
> SELECT avg(column_name) FROM table_name;
+---------------------------+
| avg(column_name) |
+---------------------------+
| 42.75 |
+---------------------------+
```"#,
    standard_argument(name = "expression",)
)]
/// The `avg` aggregate function (also reachable through the alias `mean`).
///
/// The type itself only carries the function's signature and aliases; the
/// actual averaging is done by the accumulators created in its
/// `AggregateUDFImpl` implementation.
#[derive(Debug)]
pub struct Avg {
    /// Signature handed out by `signature()`.
    signature: Signature,
    /// Alternative names handed out by `aliases()`.
    aliases: Vec<String>,
}
83
84impl Avg {
85 pub fn new() -> Self {
86 Self {
87 signature: Signature::user_defined(Immutable),
88 aliases: vec![String::from("mean")],
89 }
90 }
91}
92
93impl Default for Avg {
94 fn default() -> Self {
95 Self::new()
96 }
97}
98
99impl AggregateUDFImpl for Avg {
100 fn as_any(&self) -> &dyn Any {
101 self
102 }
103
104 fn name(&self) -> &str {
105 "avg"
106 }
107
108 fn signature(&self) -> &Signature {
109 &self.signature
110 }
111
112 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
113 avg_return_type(self.name(), &arg_types[0])
114 }
115
116 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
117 if acc_args.is_distinct {
118 return exec_err!("avg(DISTINCT) aggregations are not available");
119 }
120 use DataType::*;
121
122 let data_type = acc_args.exprs[0].data_type(acc_args.schema)?;
123 match (&data_type, acc_args.return_field.data_type()) {
125 (Float64, Float64) => Ok(Box::<AvgAccumulator>::default()),
126 (
127 Decimal128(sum_precision, sum_scale),
128 Decimal128(target_precision, target_scale),
129 ) => Ok(Box::new(DecimalAvgAccumulator::<Decimal128Type> {
130 sum: None,
131 count: 0,
132 sum_scale: *sum_scale,
133 sum_precision: *sum_precision,
134 target_precision: *target_precision,
135 target_scale: *target_scale,
136 })),
137
138 (
139 Decimal256(sum_precision, sum_scale),
140 Decimal256(target_precision, target_scale),
141 ) => Ok(Box::new(DecimalAvgAccumulator::<Decimal256Type> {
142 sum: None,
143 count: 0,
144 sum_scale: *sum_scale,
145 sum_precision: *sum_precision,
146 target_precision: *target_precision,
147 target_scale: *target_scale,
148 })),
149
150 (Duration(time_unit), Duration(result_unit)) => {
151 Ok(Box::new(DurationAvgAccumulator {
152 sum: None,
153 count: 0,
154 time_unit: *time_unit,
155 result_unit: *result_unit,
156 }))
157 }
158
159 _ => exec_err!(
160 "AvgAccumulator for ({} --> {})",
161 &data_type,
162 acc_args.return_field.data_type()
163 ),
164 }
165 }
166
167 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
168 Ok(vec![
169 Field::new(
170 format_state_name(args.name, "count"),
171 DataType::UInt64,
172 true,
173 ),
174 Field::new(
175 format_state_name(args.name, "sum"),
176 args.input_fields[0].data_type().clone(),
177 true,
178 ),
179 ]
180 .into_iter()
181 .map(Arc::new)
182 .collect())
183 }
184
185 fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
186 matches!(
187 args.return_field.data_type(),
188 DataType::Float64 | DataType::Decimal128(_, _) | DataType::Duration(_)
189 )
190 }
191
192 fn create_groups_accumulator(
193 &self,
194 args: AccumulatorArgs,
195 ) -> Result<Box<dyn GroupsAccumulator>> {
196 use DataType::*;
197
198 let data_type = args.exprs[0].data_type(args.schema)?;
199 match (&data_type, args.return_field.data_type()) {
201 (Float64, Float64) => {
202 Ok(Box::new(AvgGroupsAccumulator::<Float64Type, _>::new(
203 &data_type,
204 args.return_field.data_type(),
205 |sum: f64, count: u64| Ok(sum / count as f64),
206 )))
207 }
208 (
209 Decimal128(_sum_precision, sum_scale),
210 Decimal128(target_precision, target_scale),
211 ) => {
212 let decimal_averager = DecimalAverager::<Decimal128Type>::try_new(
213 *sum_scale,
214 *target_precision,
215 *target_scale,
216 )?;
217
218 let avg_fn =
219 move |sum: i128, count: u64| decimal_averager.avg(sum, count as i128);
220
221 Ok(Box::new(AvgGroupsAccumulator::<Decimal128Type, _>::new(
222 &data_type,
223 args.return_field.data_type(),
224 avg_fn,
225 )))
226 }
227
228 (
229 Decimal256(_sum_precision, sum_scale),
230 Decimal256(target_precision, target_scale),
231 ) => {
232 let decimal_averager = DecimalAverager::<Decimal256Type>::try_new(
233 *sum_scale,
234 *target_precision,
235 *target_scale,
236 )?;
237
238 let avg_fn = move |sum: i256, count: u64| {
239 decimal_averager.avg(sum, i256::from_usize(count as usize).unwrap())
240 };
241
242 Ok(Box::new(AvgGroupsAccumulator::<Decimal256Type, _>::new(
243 &data_type,
244 args.return_field.data_type(),
245 avg_fn,
246 )))
247 }
248
249 (Duration(time_unit), Duration(_result_unit)) => {
250 let avg_fn = move |sum: i64, count: u64| Ok(sum / count as i64);
251
252 match time_unit {
253 TimeUnit::Second => Ok(Box::new(AvgGroupsAccumulator::<
254 DurationSecondType,
255 _,
256 >::new(
257 &data_type,
258 args.return_type(),
259 avg_fn,
260 ))),
261 TimeUnit::Millisecond => Ok(Box::new(AvgGroupsAccumulator::<
262 DurationMillisecondType,
263 _,
264 >::new(
265 &data_type,
266 args.return_type(),
267 avg_fn,
268 ))),
269 TimeUnit::Microsecond => Ok(Box::new(AvgGroupsAccumulator::<
270 DurationMicrosecondType,
271 _,
272 >::new(
273 &data_type,
274 args.return_type(),
275 avg_fn,
276 ))),
277 TimeUnit::Nanosecond => Ok(Box::new(AvgGroupsAccumulator::<
278 DurationNanosecondType,
279 _,
280 >::new(
281 &data_type,
282 args.return_type(),
283 avg_fn,
284 ))),
285 }
286 }
287
288 _ => not_impl_err!(
289 "AvgGroupsAccumulator for ({} --> {})",
290 &data_type,
291 args.return_field.data_type()
292 ),
293 }
294 }
295
296 fn aliases(&self) -> &[String] {
297 &self.aliases
298 }
299
300 fn reverse_expr(&self) -> ReversedUDAF {
301 ReversedUDAF::Identical
302 }
303
304 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
305 let [args] = take_function_args(self.name(), arg_types)?;
306 coerce_avg_type(self.name(), std::slice::from_ref(args))
307 }
308
309 fn documentation(&self) -> Option<&Documentation> {
310 self.doc()
311 }
312}
313
/// Row-oriented accumulator that averages `Float64` values.
#[derive(Debug, Default)]
pub struct AvgAccumulator {
    /// Running sum; stays `None` until a batch contributes a sum.
    sum: Option<f64>,
    /// Number of non-null values accumulated so far.
    count: u64,
}
320
321impl Accumulator for AvgAccumulator {
322 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
323 let values = values[0].as_primitive::<Float64Type>();
324 self.count += (values.len() - values.null_count()) as u64;
325 if let Some(x) = sum(values) {
326 let v = self.sum.get_or_insert(0.);
327 *v += x;
328 }
329 Ok(())
330 }
331
332 fn evaluate(&mut self) -> Result<ScalarValue> {
333 Ok(ScalarValue::Float64(
334 self.sum.map(|f| f / self.count as f64),
335 ))
336 }
337
338 fn size(&self) -> usize {
339 size_of_val(self)
340 }
341
342 fn state(&mut self) -> Result<Vec<ScalarValue>> {
343 Ok(vec![
344 ScalarValue::from(self.count),
345 ScalarValue::Float64(self.sum),
346 ])
347 }
348
349 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
350 self.count += sum(states[0].as_primitive::<UInt64Type>()).unwrap_or_default();
352
353 if let Some(x) = sum(states[1].as_primitive::<Float64Type>()) {
355 let v = self.sum.get_or_insert(0.);
356 *v += x;
357 }
358 Ok(())
359 }
360 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
361 let values = values[0].as_primitive::<Float64Type>();
362 self.count -= (values.len() - values.null_count()) as u64;
363 if let Some(x) = sum(values) {
364 self.sum = Some(self.sum.unwrap() - x);
365 }
366 Ok(())
367 }
368
369 fn supports_retract_batch(&self) -> bool {
370 true
371 }
372}
373
/// Row-oriented accumulator that averages decimal values of Arrow type `T`.
#[derive(Debug)]
struct DecimalAvgAccumulator<T: DecimalType + ArrowNumericType + Debug> {
    /// Running sum; stays `None` until a batch contributes a sum.
    sum: Option<T::Native>,
    /// Number of non-null values accumulated so far.
    count: u64,
    /// Scale of the sum (used for the state value and by the averager).
    sum_scale: i8,
    /// Precision of the sum (used for the state value).
    sum_precision: u8,
    /// Precision of the final average.
    target_precision: u8,
    /// Scale of the final average.
    target_scale: i8,
}
384
385impl<T: DecimalType + ArrowNumericType + Debug> Accumulator for DecimalAvgAccumulator<T> {
386 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
387 let values = values[0].as_primitive::<T>();
388 self.count += (values.len() - values.null_count()) as u64;
389
390 if let Some(x) = sum(values) {
391 let v = self.sum.get_or_insert(T::Native::default());
392 self.sum = Some(v.add_wrapping(x));
393 }
394 Ok(())
395 }
396
397 fn evaluate(&mut self) -> Result<ScalarValue> {
398 let v = self
399 .sum
400 .map(|v| {
401 DecimalAverager::<T>::try_new(
402 self.sum_scale,
403 self.target_precision,
404 self.target_scale,
405 )?
406 .avg(v, T::Native::from_usize(self.count as usize).unwrap())
407 })
408 .transpose()?;
409
410 ScalarValue::new_primitive::<T>(
411 v,
412 &T::TYPE_CONSTRUCTOR(self.target_precision, self.target_scale),
413 )
414 }
415
416 fn size(&self) -> usize {
417 size_of_val(self)
418 }
419
420 fn state(&mut self) -> Result<Vec<ScalarValue>> {
421 Ok(vec![
422 ScalarValue::from(self.count),
423 ScalarValue::new_primitive::<T>(
424 self.sum,
425 &T::TYPE_CONSTRUCTOR(self.sum_precision, self.sum_scale),
426 )?,
427 ])
428 }
429
430 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
431 self.count += sum(states[0].as_primitive::<UInt64Type>()).unwrap_or_default();
433
434 if let Some(x) = sum(states[1].as_primitive::<T>()) {
436 let v = self.sum.get_or_insert(T::Native::default());
437 self.sum = Some(v.add_wrapping(x));
438 }
439 Ok(())
440 }
441 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
442 let values = values[0].as_primitive::<T>();
443 self.count -= (values.len() - values.null_count()) as u64;
444 if let Some(x) = sum(values) {
445 self.sum = Some(self.sum.unwrap().sub_wrapping(x));
446 }
447 Ok(())
448 }
449
450 fn supports_retract_batch(&self) -> bool {
451 true
452 }
453}
454
/// Row-oriented accumulator that averages duration values.
#[derive(Debug)]
struct DurationAvgAccumulator {
    /// Running sum in `time_unit`; stays `None` until a batch contributes a sum.
    sum: Option<i64>,
    /// Number of non-null values accumulated so far.
    count: u64,
    /// Unit of the input arrays and of the sum stored in the state.
    time_unit: TimeUnit,
    /// Unit used to label the value returned by `evaluate`.
    result_unit: TimeUnit,
}
463
464impl Accumulator for DurationAvgAccumulator {
465 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
466 let array = &values[0];
467 self.count += (array.len() - array.null_count()) as u64;
468
469 let sum_value = match self.time_unit {
470 TimeUnit::Second => sum(array.as_primitive::<DurationSecondType>()),
471 TimeUnit::Millisecond => sum(array.as_primitive::<DurationMillisecondType>()),
472 TimeUnit::Microsecond => sum(array.as_primitive::<DurationMicrosecondType>()),
473 TimeUnit::Nanosecond => sum(array.as_primitive::<DurationNanosecondType>()),
474 };
475
476 if let Some(x) = sum_value {
477 let v = self.sum.get_or_insert(0);
478 *v += x;
479 }
480 Ok(())
481 }
482
483 fn evaluate(&mut self) -> Result<ScalarValue> {
484 let avg = self.sum.map(|sum| sum / self.count as i64);
485
486 match self.result_unit {
487 TimeUnit::Second => Ok(ScalarValue::DurationSecond(avg)),
488 TimeUnit::Millisecond => Ok(ScalarValue::DurationMillisecond(avg)),
489 TimeUnit::Microsecond => Ok(ScalarValue::DurationMicrosecond(avg)),
490 TimeUnit::Nanosecond => Ok(ScalarValue::DurationNanosecond(avg)),
491 }
492 }
493
494 fn size(&self) -> usize {
495 size_of_val(self)
496 }
497
498 fn state(&mut self) -> Result<Vec<ScalarValue>> {
499 let duration_value = match self.time_unit {
500 TimeUnit::Second => ScalarValue::DurationSecond(self.sum),
501 TimeUnit::Millisecond => ScalarValue::DurationMillisecond(self.sum),
502 TimeUnit::Microsecond => ScalarValue::DurationMicrosecond(self.sum),
503 TimeUnit::Nanosecond => ScalarValue::DurationNanosecond(self.sum),
504 };
505
506 Ok(vec![ScalarValue::from(self.count), duration_value])
507 }
508
509 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
510 self.count += sum(states[0].as_primitive::<UInt64Type>()).unwrap_or_default();
511
512 let sum_value = match self.time_unit {
513 TimeUnit::Second => sum(states[1].as_primitive::<DurationSecondType>()),
514 TimeUnit::Millisecond => {
515 sum(states[1].as_primitive::<DurationMillisecondType>())
516 }
517 TimeUnit::Microsecond => {
518 sum(states[1].as_primitive::<DurationMicrosecondType>())
519 }
520 TimeUnit::Nanosecond => {
521 sum(states[1].as_primitive::<DurationNanosecondType>())
522 }
523 };
524
525 if let Some(x) = sum_value {
526 let v = self.sum.get_or_insert(0);
527 *v += x;
528 }
529 Ok(())
530 }
531
532 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
533 let array = &values[0];
534 self.count -= (array.len() - array.null_count()) as u64;
535
536 let sum_value = match self.time_unit {
537 TimeUnit::Second => sum(array.as_primitive::<DurationSecondType>()),
538 TimeUnit::Millisecond => sum(array.as_primitive::<DurationMillisecondType>()),
539 TimeUnit::Microsecond => sum(array.as_primitive::<DurationMicrosecondType>()),
540 TimeUnit::Nanosecond => sum(array.as_primitive::<DurationNanosecondType>()),
541 };
542
543 if let Some(x) = sum_value {
544 self.sum = Some(self.sum.unwrap() - x);
545 }
546 Ok(())
547 }
548
549 fn supports_retract_batch(&self) -> bool {
550 true
551 }
552}
553
/// Vectorized accumulator that keeps one running sum and count per group.
///
/// `T` is the Arrow primitive type of the sums; `F` turns a group's
/// `(sum, count)` into its average.
#[derive(Debug)]
struct AvgGroupsAccumulator<T, F>
where
    T: ArrowNumericType + Send,
    F: Fn(T::Native, u64) -> Result<T::Native> + Send,
{
    /// Data type attached to the sums array emitted as intermediate state.
    sum_data_type: DataType,

    /// Data type attached to the array returned by `evaluate`.
    return_data_type: DataType,

    /// Per-group count of accumulated values, indexed by group index.
    counts: Vec<u64>,

    /// Per-group running sum, indexed by group index.
    sums: Vec<T::Native>,

    /// Supplies the validity (null) buffer for the emitted arrays.
    null_state: NullState,

    /// Computes a group's average from its sum and count.
    avg_fn: F,
}
583
584impl<T, F> AvgGroupsAccumulator<T, F>
585where
586 T: ArrowNumericType + Send,
587 F: Fn(T::Native, u64) -> Result<T::Native> + Send,
588{
589 pub fn new(sum_data_type: &DataType, return_data_type: &DataType, avg_fn: F) -> Self {
590 debug!(
591 "AvgGroupsAccumulator ({}, sum type: {sum_data_type:?}) --> {return_data_type:?}",
592 std::any::type_name::<T>()
593 );
594
595 Self {
596 return_data_type: return_data_type.clone(),
597 sum_data_type: sum_data_type.clone(),
598 counts: vec![],
599 sums: vec![],
600 null_state: NullState::new(),
601 avg_fn,
602 }
603 }
604}
605
606impl<T, F> GroupsAccumulator for AvgGroupsAccumulator<T, F>
607where
608 T: ArrowNumericType + Send,
609 F: Fn(T::Native, u64) -> Result<T::Native> + Send,
610{
611 fn update_batch(
612 &mut self,
613 values: &[ArrayRef],
614 group_indices: &[usize],
615 opt_filter: Option<&BooleanArray>,
616 total_num_groups: usize,
617 ) -> Result<()> {
618 assert_eq!(values.len(), 1, "single argument to update_batch");
619 let values = values[0].as_primitive::<T>();
620
621 self.counts.resize(total_num_groups, 0);
623 self.sums.resize(total_num_groups, T::default_value());
624 self.null_state.accumulate(
625 group_indices,
626 values,
627 opt_filter,
628 total_num_groups,
629 |group_index, new_value| {
630 let sum = &mut self.sums[group_index];
631 *sum = sum.add_wrapping(new_value);
632
633 self.counts[group_index] += 1;
634 },
635 );
636
637 Ok(())
638 }
639
640 fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
641 let counts = emit_to.take_needed(&mut self.counts);
642 let sums = emit_to.take_needed(&mut self.sums);
643 let nulls = self.null_state.build(emit_to);
644
645 assert_eq!(nulls.len(), sums.len());
646 assert_eq!(counts.len(), sums.len());
647
648 let array: PrimitiveArray<T> = if nulls.null_count() > 0 {
651 let mut builder = PrimitiveBuilder::<T>::with_capacity(nulls.len())
652 .with_data_type(self.return_data_type.clone());
653 let iter = sums.into_iter().zip(counts).zip(nulls.iter());
654
655 for ((sum, count), is_valid) in iter {
656 if is_valid {
657 builder.append_value((self.avg_fn)(sum, count)?)
658 } else {
659 builder.append_null();
660 }
661 }
662 builder.finish()
663 } else {
664 let averages: Vec<T::Native> = sums
665 .into_iter()
666 .zip(counts.into_iter())
667 .map(|(sum, count)| (self.avg_fn)(sum, count))
668 .collect::<Result<Vec<_>>>()?;
669 PrimitiveArray::new(averages.into(), Some(nulls)) .with_data_type(self.return_data_type.clone())
671 };
672
673 Ok(Arc::new(array))
674 }
675
676 fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
678 let nulls = self.null_state.build(emit_to);
679 let nulls = Some(nulls);
680
681 let counts = emit_to.take_needed(&mut self.counts);
682 let counts = UInt64Array::new(counts.into(), nulls.clone()); let sums = emit_to.take_needed(&mut self.sums);
685 let sums = PrimitiveArray::<T>::new(sums.into(), nulls) .with_data_type(self.sum_data_type.clone());
687
688 Ok(vec![
689 Arc::new(counts) as ArrayRef,
690 Arc::new(sums) as ArrayRef,
691 ])
692 }
693
694 fn merge_batch(
695 &mut self,
696 values: &[ArrayRef],
697 group_indices: &[usize],
698 opt_filter: Option<&BooleanArray>,
699 total_num_groups: usize,
700 ) -> Result<()> {
701 assert_eq!(values.len(), 2, "two arguments to merge_batch");
702 let partial_counts = values[0].as_primitive::<UInt64Type>();
704 let partial_sums = values[1].as_primitive::<T>();
705 self.counts.resize(total_num_groups, 0);
707 self.null_state.accumulate(
708 group_indices,
709 partial_counts,
710 opt_filter,
711 total_num_groups,
712 |group_index, partial_count| {
713 self.counts[group_index] += partial_count;
714 },
715 );
716
717 self.sums.resize(total_num_groups, T::default_value());
719 self.null_state.accumulate(
720 group_indices,
721 partial_sums,
722 opt_filter,
723 total_num_groups,
724 |group_index, new_value: <T as ArrowPrimitiveType>::Native| {
725 let sum = &mut self.sums[group_index];
726 *sum = sum.add_wrapping(new_value);
727 },
728 );
729
730 Ok(())
731 }
732
733 fn convert_to_state(
734 &self,
735 values: &[ArrayRef],
736 opt_filter: Option<&BooleanArray>,
737 ) -> Result<Vec<ArrayRef>> {
738 let sums = values[0]
739 .as_primitive::<T>()
740 .clone()
741 .with_data_type(self.sum_data_type.clone());
742 let counts = UInt64Array::from_value(1, sums.len());
743
744 let nulls = filtered_null_mask(opt_filter, &sums);
745
746 let counts = set_nulls(counts, nulls.clone());
748 let sums = set_nulls(sums, nulls);
749
750 Ok(vec![Arc::new(counts) as ArrayRef, Arc::new(sums)])
751 }
752
753 fn supports_convert_to_state(&self) -> bool {
754 true
755 }
756
757 fn size(&self) -> usize {
758 self.counts.capacity() * size_of::<u64>() + self.sums.capacity() * size_of::<T>()
759 }
760}