1use arrow::array::{
21 Array, ArrayRef, ArrowNativeTypeOp, ArrowNumericType, ArrowPrimitiveType, AsArray,
22 BooleanArray, PrimitiveArray, PrimitiveBuilder, UInt64Array,
23};
24
25use arrow::compute::sum;
26use arrow::datatypes::{
27 i256, ArrowNativeType, DataType, Decimal128Type, Decimal256Type, Decimal32Type,
28 Decimal64Type, DecimalType, DurationMicrosecondType, DurationMillisecondType,
29 DurationNanosecondType, DurationSecondType, Field, FieldRef, Float64Type, TimeUnit,
30 UInt64Type, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION,
31 DECIMAL256_MAX_SCALE, DECIMAL32_MAX_PRECISION, DECIMAL32_MAX_SCALE,
32 DECIMAL64_MAX_PRECISION, DECIMAL64_MAX_SCALE,
33};
34use datafusion_common::plan_err;
35use datafusion_common::{
36 exec_err, not_impl_err, utils::take_function_args, Result, ScalarValue,
37};
38use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
39use datafusion_expr::utils::format_state_name;
40use datafusion_expr::Volatility::Immutable;
41use datafusion_expr::{
42 Accumulator, AggregateUDFImpl, Documentation, EmitTo, Expr, GroupsAccumulator,
43 ReversedUDAF, Signature,
44};
45
46use datafusion_functions_aggregate_common::aggregate::avg_distinct::{
47 DecimalDistinctAvgAccumulator, Float64DistinctAvgAccumulator,
48};
49use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::NullState;
50use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::{
51 filtered_null_mask, set_nulls,
52};
53
54use datafusion_functions_aggregate_common::utils::DecimalAverager;
55use datafusion_macros::user_doc;
56use log::debug;
57use std::any::Any;
58use std::fmt::Debug;
59use std::mem::{size_of, size_of_val};
60use std::sync::Arc;
61
// Invokes the project macro for the `Avg` UDAF with the expression-helper name
// `avg`, its single argument name, a description string and the accessor name
// `avg_udaf`. Presumably this generates `avg(expression)` and `avg_udaf()` (the
// latter is called by `avg_distinct` below) -- confirm against the macro
// definition.
make_udaf_expr_and_func!(
    Avg,
    avg,
    expression,
    "Returns the avg of a group of values.",
    avg_udaf
);
69
70pub fn avg_distinct(expr: Expr) -> Expr {
71 Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
72 avg_udaf(),
73 vec![expr],
74 true,
75 None,
76 vec![],
77 None,
78 ))
79}
80
81#[user_doc(
82 doc_section(label = "General Functions"),
83 description = "Returns the average of numeric values in the specified column.",
84 syntax_example = "avg(expression)",
85 sql_example = r#"```sql
86> SELECT avg(column_name) FROM table_name;
87+---------------------------+
88| avg(column_name) |
89+---------------------------+
90| 42.75 |
91+---------------------------+
92```"#,
93 standard_argument(name = "expression",)
94)]
/// Aggregate UDF implementation for `avg` (see the `AggregateUDFImpl` impl).
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct Avg {
    /// Signature handed out by `AggregateUDFImpl::signature`.
    signature: Signature,
    /// Alternative names handed out by `AggregateUDFImpl::aliases`.
    aliases: Vec<String>,
}
100
101impl Avg {
102 pub fn new() -> Self {
103 Self {
104 signature: Signature::user_defined(Immutable),
105 aliases: vec![String::from("mean")],
106 }
107 }
108}
109
110impl Default for Avg {
111 fn default() -> Self {
112 Self::new()
113 }
114}
115
116impl AggregateUDFImpl for Avg {
    /// Returns `self` as `&dyn Any` so callers can downcast to `Avg`.
    fn as_any(&self) -> &dyn Any {
        self
    }
120
    /// Primary name of this aggregate: `"avg"`.
    fn name(&self) -> &str {
        "avg"
    }
124
    /// Returns the signature built in `Avg::new`.
    fn signature(&self) -> &Signature {
        &self.signature
    }
128
129 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
130 let [args] = take_function_args(self.name(), arg_types)?;
131
132 fn coerced_type(data_type: &DataType) -> Result<DataType> {
135 match &data_type {
136 DataType::Decimal32(p, s) => Ok(DataType::Decimal32(*p, *s)),
137 DataType::Decimal64(p, s) => Ok(DataType::Decimal64(*p, *s)),
138 DataType::Decimal128(p, s) => Ok(DataType::Decimal128(*p, *s)),
139 DataType::Decimal256(p, s) => Ok(DataType::Decimal256(*p, *s)),
140 d if d.is_numeric() => Ok(DataType::Float64),
141 DataType::Duration(time_unit) => Ok(DataType::Duration(*time_unit)),
142 DataType::Dictionary(_, v) => coerced_type(v.as_ref()),
143 _ => {
144 plan_err!("Avg does not support inputs of type {data_type}.")
145 }
146 }
147 }
148 Ok(vec![coerced_type(args)?])
149 }
150
151 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
152 match &arg_types[0] {
153 DataType::Decimal32(precision, scale) => {
154 let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 4);
157 let new_scale = DECIMAL32_MAX_SCALE.min(*scale + 4);
158 Ok(DataType::Decimal32(new_precision, new_scale))
159 }
160 DataType::Decimal64(precision, scale) => {
161 let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 4);
164 let new_scale = DECIMAL64_MAX_SCALE.min(*scale + 4);
165 Ok(DataType::Decimal64(new_precision, new_scale))
166 }
167 DataType::Decimal128(precision, scale) => {
168 let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 4);
171 let new_scale = DECIMAL128_MAX_SCALE.min(*scale + 4);
172 Ok(DataType::Decimal128(new_precision, new_scale))
173 }
174 DataType::Decimal256(precision, scale) => {
175 let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 4);
178 let new_scale = DECIMAL256_MAX_SCALE.min(*scale + 4);
179 Ok(DataType::Decimal256(new_precision, new_scale))
180 }
181 DataType::Duration(time_unit) => Ok(DataType::Duration(*time_unit)),
182 _ => Ok(DataType::Float64),
183 }
184 }
185
186 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
187 let data_type = acc_args.expr_fields[0].data_type();
188 use DataType::*;
189
190 if acc_args.is_distinct {
192 match (data_type, acc_args.return_type()) {
193 (Float64, _) => Ok(Box::new(Float64DistinctAvgAccumulator::default())),
195
196 (
197 Decimal32(_, scale),
198 Decimal32(target_precision, target_scale),
199 ) => Ok(Box::new(DecimalDistinctAvgAccumulator::<Decimal32Type>::with_decimal_params(
200 *scale,
201 *target_precision,
202 *target_scale,
203 ))),
204 (
205 Decimal64(_, scale),
206 Decimal64(target_precision, target_scale),
207 ) => Ok(Box::new(DecimalDistinctAvgAccumulator::<Decimal64Type>::with_decimal_params(
208 *scale,
209 *target_precision,
210 *target_scale,
211 ))),
212 (
213 Decimal128(_, scale),
214 Decimal128(target_precision, target_scale),
215 ) => Ok(Box::new(DecimalDistinctAvgAccumulator::<Decimal128Type>::with_decimal_params(
216 *scale,
217 *target_precision,
218 *target_scale,
219 ))),
220
221 (
222 Decimal256(_, scale),
223 Decimal256(target_precision, target_scale),
224 ) => Ok(Box::new(DecimalDistinctAvgAccumulator::<Decimal256Type>::with_decimal_params(
225 *scale,
226 *target_precision,
227 *target_scale,
228 ))),
229
230 (dt, return_type) => exec_err!(
231 "AVG(DISTINCT) for ({} --> {}) not supported",
232 dt,
233 return_type
234 ),
235 }
236 } else {
237 match (&data_type, acc_args.return_type()) {
238 (Float64, Float64) => Ok(Box::<AvgAccumulator>::default()),
239 (
240 Decimal32(sum_precision, sum_scale),
241 Decimal32(target_precision, target_scale),
242 ) => Ok(Box::new(DecimalAvgAccumulator::<Decimal32Type> {
243 sum: None,
244 count: 0,
245 sum_scale: *sum_scale,
246 sum_precision: *sum_precision,
247 target_precision: *target_precision,
248 target_scale: *target_scale,
249 })),
250 (
251 Decimal64(sum_precision, sum_scale),
252 Decimal64(target_precision, target_scale),
253 ) => Ok(Box::new(DecimalAvgAccumulator::<Decimal64Type> {
254 sum: None,
255 count: 0,
256 sum_scale: *sum_scale,
257 sum_precision: *sum_precision,
258 target_precision: *target_precision,
259 target_scale: *target_scale,
260 })),
261 (
262 Decimal128(sum_precision, sum_scale),
263 Decimal128(target_precision, target_scale),
264 ) => Ok(Box::new(DecimalAvgAccumulator::<Decimal128Type> {
265 sum: None,
266 count: 0,
267 sum_scale: *sum_scale,
268 sum_precision: *sum_precision,
269 target_precision: *target_precision,
270 target_scale: *target_scale,
271 })),
272
273 (
274 Decimal256(sum_precision, sum_scale),
275 Decimal256(target_precision, target_scale),
276 ) => Ok(Box::new(DecimalAvgAccumulator::<Decimal256Type> {
277 sum: None,
278 count: 0,
279 sum_scale: *sum_scale,
280 sum_precision: *sum_precision,
281 target_precision: *target_precision,
282 target_scale: *target_scale,
283 })),
284
285 (Duration(time_unit), Duration(result_unit)) => {
286 Ok(Box::new(DurationAvgAccumulator {
287 sum: None,
288 count: 0,
289 time_unit: *time_unit,
290 result_unit: *result_unit,
291 }))
292 }
293
294 (dt, return_type) => {
295 exec_err!("AvgAccumulator for ({} --> {})", dt, return_type)
296 }
297 }
298 }
299 }
300
301 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
302 if args.is_distinct {
303 let dt = match args.input_fields[0].data_type() {
306 DataType::Decimal32(_, scale) => {
307 DataType::Decimal32(DECIMAL32_MAX_PRECISION, *scale)
308 }
309 DataType::Decimal64(_, scale) => {
310 DataType::Decimal64(DECIMAL64_MAX_PRECISION, *scale)
311 }
312 DataType::Decimal128(_, scale) => {
313 DataType::Decimal128(DECIMAL128_MAX_PRECISION, *scale)
314 }
315 DataType::Decimal256(_, scale) => {
316 DataType::Decimal256(DECIMAL256_MAX_PRECISION, *scale)
317 }
318 _ => args.return_type().clone(),
319 };
320 Ok(vec![Field::new_list(
323 format_state_name(args.name, "avg distinct"),
324 Field::new_list_field(dt, true),
325 false,
326 )
327 .into()])
328 } else {
329 Ok(vec![
330 Field::new(
331 format_state_name(args.name, "count"),
332 DataType::UInt64,
333 true,
334 ),
335 Field::new(
336 format_state_name(args.name, "sum"),
337 args.input_fields[0].data_type().clone(),
338 true,
339 ),
340 ]
341 .into_iter()
342 .map(Arc::new)
343 .collect())
344 }
345 }
346
347 fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
348 matches!(
349 args.return_field.data_type(),
350 DataType::Float64
351 | DataType::Decimal32(_, _)
352 | DataType::Decimal64(_, _)
353 | DataType::Decimal128(_, _)
354 | DataType::Decimal256(_, _)
355 | DataType::Duration(_)
356 ) && !args.is_distinct
357 }
358
    /// Creates the vectorized, per-group accumulator for the given input and
    /// return types.
    ///
    /// Every arm instantiates `AvgGroupsAccumulator` with the arrow primitive
    /// type of the input and a closure that turns a group's `(sum, count)`
    /// into its average.
    ///
    /// # Errors
    /// Returns a not-implemented error for unsupported (input, return) type
    /// pairs, and propagates errors from `DecimalAverager::try_new`.
    fn create_groups_accumulator(
        &self,
        args: AccumulatorArgs,
    ) -> Result<Box<dyn GroupsAccumulator>> {
        use DataType::*;

        let data_type = args.expr_fields[0].data_type();

        match (data_type, args.return_field.data_type()) {
            // Floating point: plain division of the sum by the count.
            (Float64, Float64) => {
                Ok(Box::new(AvgGroupsAccumulator::<Float64Type, _>::new(
                    data_type,
                    args.return_field.data_type(),
                    |sum: f64, count: u64| Ok(sum / count as f64),
                )))
            }
            // Decimals: the division is delegated to a `DecimalAverager` built
            // from the sum scale and the target precision/scale.
            (
                Decimal32(_sum_precision, sum_scale),
                Decimal32(target_precision, target_scale),
            ) => {
                let decimal_averager = DecimalAverager::<Decimal32Type>::try_new(
                    *sum_scale,
                    *target_precision,
                    *target_scale,
                )?;

                // NOTE(review): `count as i32` truncates for counts above
                // `i32::MAX` -- confirm such group sizes cannot occur here.
                let avg_fn =
                    move |sum: i32, count: u64| decimal_averager.avg(sum, count as i32);

                Ok(Box::new(AvgGroupsAccumulator::<Decimal32Type, _>::new(
                    data_type,
                    args.return_field.data_type(),
                    avg_fn,
                )))
            }
            (
                Decimal64(_sum_precision, sum_scale),
                Decimal64(target_precision, target_scale),
            ) => {
                let decimal_averager = DecimalAverager::<Decimal64Type>::try_new(
                    *sum_scale,
                    *target_precision,
                    *target_scale,
                )?;

                let avg_fn =
                    move |sum: i64, count: u64| decimal_averager.avg(sum, count as i64);

                Ok(Box::new(AvgGroupsAccumulator::<Decimal64Type, _>::new(
                    data_type,
                    args.return_field.data_type(),
                    avg_fn,
                )))
            }
            (
                Decimal128(_sum_precision, sum_scale),
                Decimal128(target_precision, target_scale),
            ) => {
                let decimal_averager = DecimalAverager::<Decimal128Type>::try_new(
                    *sum_scale,
                    *target_precision,
                    *target_scale,
                )?;

                let avg_fn =
                    move |sum: i128, count: u64| decimal_averager.avg(sum, count as i128);

                Ok(Box::new(AvgGroupsAccumulator::<Decimal128Type, _>::new(
                    data_type,
                    args.return_field.data_type(),
                    avg_fn,
                )))
            }

            (
                Decimal256(_sum_precision, sum_scale),
                Decimal256(target_precision, target_scale),
            ) => {
                let decimal_averager = DecimalAverager::<Decimal256Type>::try_new(
                    *sum_scale,
                    *target_precision,
                    *target_scale,
                )?;

                // `i256` has no `as` cast from `u64`, so the count is converted
                // through `from_usize`.
                let avg_fn = move |sum: i256, count: u64| {
                    decimal_averager.avg(sum, i256::from_usize(count as usize).unwrap())
                };

                Ok(Box::new(AvgGroupsAccumulator::<Decimal256Type, _>::new(
                    data_type,
                    args.return_field.data_type(),
                    avg_fn,
                )))
            }

            // Durations: integer division of the sum by the count. The
            // accumulator's primitive type follows the input time unit.
            (Duration(time_unit), Duration(_result_unit)) => {
                let avg_fn = move |sum: i64, count: u64| Ok(sum / count as i64);

                match time_unit {
                    TimeUnit::Second => Ok(Box::new(AvgGroupsAccumulator::<
                        DurationSecondType,
                        _,
                    >::new(
                        data_type,
                        args.return_type(),
                        avg_fn,
                    ))),
                    TimeUnit::Millisecond => Ok(Box::new(AvgGroupsAccumulator::<
                        DurationMillisecondType,
                        _,
                    >::new(
                        data_type,
                        args.return_type(),
                        avg_fn,
                    ))),
                    TimeUnit::Microsecond => Ok(Box::new(AvgGroupsAccumulator::<
                        DurationMicrosecondType,
                        _,
                    >::new(
                        data_type,
                        args.return_type(),
                        avg_fn,
                    ))),
                    TimeUnit::Nanosecond => Ok(Box::new(AvgGroupsAccumulator::<
                        DurationNanosecondType,
                        _,
                    >::new(
                        data_type,
                        args.return_type(),
                        avg_fn,
                    ))),
                }
            }

            _ => not_impl_err!(
                "AvgGroupsAccumulator for ({} --> {})",
                &data_type,
                args.return_field.data_type()
            ),
        }
    }
501
    /// Alternative names for this aggregate (set to `["mean"]` by `Avg::new`).
    fn aliases(&self) -> &[String] {
        &self.aliases
    }
505
    /// Declares the reversed form of this aggregate as
    /// `ReversedUDAF::Identical`.
    fn reverse_expr(&self) -> ReversedUDAF {
        ReversedUDAF::Identical
    }
509
    /// Returns the documentation obtained from `self.doc()` (presumably
    /// generated by the `#[user_doc]` attribute on `Avg` -- confirm).
    fn documentation(&self) -> Option<&Documentation> {
        self.doc()
    }
513}
514
/// Accumulator computing the average of `Float64` values from a running sum
/// and count.
#[derive(Debug, Default)]
pub struct AvgAccumulator {
    /// Running sum of the inputs; `None` until a batch with a non-null sum
    /// has been folded in.
    sum: Option<f64>,
    /// Number of non-null inputs currently accounted for.
    count: u64,
}
521
522impl Accumulator for AvgAccumulator {
523 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
524 let values = values[0].as_primitive::<Float64Type>();
525 self.count += (values.len() - values.null_count()) as u64;
526 if let Some(x) = sum(values) {
527 let v = self.sum.get_or_insert(0.);
528 *v += x;
529 }
530 Ok(())
531 }
532
533 fn evaluate(&mut self) -> Result<ScalarValue> {
534 Ok(ScalarValue::Float64(
535 self.sum.map(|f| f / self.count as f64),
536 ))
537 }
538
539 fn size(&self) -> usize {
540 size_of_val(self)
541 }
542
543 fn state(&mut self) -> Result<Vec<ScalarValue>> {
544 Ok(vec![
545 ScalarValue::from(self.count),
546 ScalarValue::Float64(self.sum),
547 ])
548 }
549
550 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
551 self.count += sum(states[0].as_primitive::<UInt64Type>()).unwrap_or_default();
553
554 if let Some(x) = sum(states[1].as_primitive::<Float64Type>()) {
556 let v = self.sum.get_or_insert(0.);
557 *v += x;
558 }
559 Ok(())
560 }
561 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
562 let values = values[0].as_primitive::<Float64Type>();
563 self.count -= (values.len() - values.null_count()) as u64;
564 if let Some(x) = sum(values) {
565 self.sum = Some(self.sum.unwrap() - x);
566 }
567 Ok(())
568 }
569
570 fn supports_retract_batch(&self) -> bool {
571 true
572 }
573}
574
/// Accumulator computing the average of decimal values of arrow type `T` from
/// a running sum and count.
#[derive(Debug)]
struct DecimalAvgAccumulator<T: DecimalType + ArrowNumericType + Debug> {
    /// Running (wrapping) sum of the inputs; `None` until a batch with a
    /// non-null sum has been folded in.
    sum: Option<T::Native>,
    /// Number of non-null inputs currently accounted for.
    count: u64,
    /// Scale of the sum; used for the state type and by the averager.
    sum_scale: i8,
    /// Precision of the sum; used for the state type.
    sum_precision: u8,
    /// Precision of the evaluated average.
    target_precision: u8,
    /// Scale of the evaluated average.
    target_scale: i8,
}
585
586impl<T: DecimalType + ArrowNumericType + Debug> Accumulator for DecimalAvgAccumulator<T> {
587 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
588 let values = values[0].as_primitive::<T>();
589 self.count += (values.len() - values.null_count()) as u64;
590
591 if let Some(x) = sum(values) {
592 let v = self.sum.get_or_insert_with(T::Native::default);
593 self.sum = Some(v.add_wrapping(x));
594 }
595 Ok(())
596 }
597
598 fn evaluate(&mut self) -> Result<ScalarValue> {
599 let v = self
600 .sum
601 .map(|v| {
602 DecimalAverager::<T>::try_new(
603 self.sum_scale,
604 self.target_precision,
605 self.target_scale,
606 )?
607 .avg(v, T::Native::from_usize(self.count as usize).unwrap())
608 })
609 .transpose()?;
610
611 ScalarValue::new_primitive::<T>(
612 v,
613 &T::TYPE_CONSTRUCTOR(self.target_precision, self.target_scale),
614 )
615 }
616
617 fn size(&self) -> usize {
618 size_of_val(self)
619 }
620
621 fn state(&mut self) -> Result<Vec<ScalarValue>> {
622 Ok(vec![
623 ScalarValue::from(self.count),
624 ScalarValue::new_primitive::<T>(
625 self.sum,
626 &T::TYPE_CONSTRUCTOR(self.sum_precision, self.sum_scale),
627 )?,
628 ])
629 }
630
631 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
632 self.count += sum(states[0].as_primitive::<UInt64Type>()).unwrap_or_default();
634
635 if let Some(x) = sum(states[1].as_primitive::<T>()) {
637 let v = self.sum.get_or_insert_with(T::Native::default);
638 self.sum = Some(v.add_wrapping(x));
639 }
640 Ok(())
641 }
642 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
643 let values = values[0].as_primitive::<T>();
644 self.count -= (values.len() - values.null_count()) as u64;
645 if let Some(x) = sum(values) {
646 self.sum = Some(self.sum.unwrap().sub_wrapping(x));
647 }
648 Ok(())
649 }
650
651 fn supports_retract_batch(&self) -> bool {
652 true
653 }
654}
655
/// Accumulator computing the average of duration values from a running sum
/// and count.
#[derive(Debug)]
struct DurationAvgAccumulator {
    /// Running sum of the inputs; `None` until a batch with a non-null sum
    /// has been folded in.
    sum: Option<i64>,
    /// Number of non-null inputs currently accounted for.
    count: u64,
    /// Time unit used to read input and state arrays and to emit the state.
    time_unit: TimeUnit,
    /// Time unit used to wrap the evaluated average.
    result_unit: TimeUnit,
}
664
665impl Accumulator for DurationAvgAccumulator {
666 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
667 let array = &values[0];
668 self.count += (array.len() - array.null_count()) as u64;
669
670 let sum_value = match self.time_unit {
671 TimeUnit::Second => sum(array.as_primitive::<DurationSecondType>()),
672 TimeUnit::Millisecond => sum(array.as_primitive::<DurationMillisecondType>()),
673 TimeUnit::Microsecond => sum(array.as_primitive::<DurationMicrosecondType>()),
674 TimeUnit::Nanosecond => sum(array.as_primitive::<DurationNanosecondType>()),
675 };
676
677 if let Some(x) = sum_value {
678 let v = self.sum.get_or_insert(0);
679 *v += x;
680 }
681 Ok(())
682 }
683
684 fn evaluate(&mut self) -> Result<ScalarValue> {
685 let avg = self.sum.map(|sum| sum / self.count as i64);
686
687 match self.result_unit {
688 TimeUnit::Second => Ok(ScalarValue::DurationSecond(avg)),
689 TimeUnit::Millisecond => Ok(ScalarValue::DurationMillisecond(avg)),
690 TimeUnit::Microsecond => Ok(ScalarValue::DurationMicrosecond(avg)),
691 TimeUnit::Nanosecond => Ok(ScalarValue::DurationNanosecond(avg)),
692 }
693 }
694
695 fn size(&self) -> usize {
696 size_of_val(self)
697 }
698
699 fn state(&mut self) -> Result<Vec<ScalarValue>> {
700 let duration_value = match self.time_unit {
701 TimeUnit::Second => ScalarValue::DurationSecond(self.sum),
702 TimeUnit::Millisecond => ScalarValue::DurationMillisecond(self.sum),
703 TimeUnit::Microsecond => ScalarValue::DurationMicrosecond(self.sum),
704 TimeUnit::Nanosecond => ScalarValue::DurationNanosecond(self.sum),
705 };
706
707 Ok(vec![ScalarValue::from(self.count), duration_value])
708 }
709
710 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
711 self.count += sum(states[0].as_primitive::<UInt64Type>()).unwrap_or_default();
712
713 let sum_value = match self.time_unit {
714 TimeUnit::Second => sum(states[1].as_primitive::<DurationSecondType>()),
715 TimeUnit::Millisecond => {
716 sum(states[1].as_primitive::<DurationMillisecondType>())
717 }
718 TimeUnit::Microsecond => {
719 sum(states[1].as_primitive::<DurationMicrosecondType>())
720 }
721 TimeUnit::Nanosecond => {
722 sum(states[1].as_primitive::<DurationNanosecondType>())
723 }
724 };
725
726 if let Some(x) = sum_value {
727 let v = self.sum.get_or_insert(0);
728 *v += x;
729 }
730 Ok(())
731 }
732
733 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
734 let array = &values[0];
735 self.count -= (array.len() - array.null_count()) as u64;
736
737 let sum_value = match self.time_unit {
738 TimeUnit::Second => sum(array.as_primitive::<DurationSecondType>()),
739 TimeUnit::Millisecond => sum(array.as_primitive::<DurationMillisecondType>()),
740 TimeUnit::Microsecond => sum(array.as_primitive::<DurationMicrosecondType>()),
741 TimeUnit::Nanosecond => sum(array.as_primitive::<DurationNanosecondType>()),
742 };
743
744 if let Some(x) = sum_value {
745 self.sum = Some(self.sum.unwrap() - x);
746 }
747 Ok(())
748 }
749
750 fn supports_retract_batch(&self) -> bool {
751 true
752 }
753}
754
/// Per-group average accumulator.
///
/// Keeps one running sum and one count per group index and turns each pair
/// into an average with `avg_fn` on evaluation.
///
/// * `T`: arrow primitive type of the sums and of the result array.
/// * `F`: function computing the average from a `(sum, count)` pair.
#[derive(Debug)]
struct AvgGroupsAccumulator<T, F>
where
    T: ArrowNumericType + Send,
    F: Fn(T::Native, u64) -> Result<T::Native> + Send,
{
    /// Data type applied to the sum array emitted as intermediate state.
    sum_data_type: DataType,

    /// Data type applied to the array of evaluated averages.
    return_data_type: DataType,

    /// Count per group, indexed by group index.
    counts: Vec<u64>,

    /// Sum per group, indexed by group index.
    sums: Vec<T::Native>,

    /// Null tracking for the groups; builds the validity of emitted arrays.
    null_state: NullState,

    /// Function that computes the final average (value / count).
    avg_fn: F,
}
784
785impl<T, F> AvgGroupsAccumulator<T, F>
786where
787 T: ArrowNumericType + Send,
788 F: Fn(T::Native, u64) -> Result<T::Native> + Send,
789{
790 pub fn new(sum_data_type: &DataType, return_data_type: &DataType, avg_fn: F) -> Self {
791 debug!(
792 "AvgGroupsAccumulator ({}, sum type: {sum_data_type}) --> {return_data_type}",
793 std::any::type_name::<T>()
794 );
795
796 Self {
797 return_data_type: return_data_type.clone(),
798 sum_data_type: sum_data_type.clone(),
799 counts: vec![],
800 sums: vec![],
801 null_state: NullState::new(),
802 avg_fn,
803 }
804 }
805}
806
807impl<T, F> GroupsAccumulator for AvgGroupsAccumulator<T, F>
808where
809 T: ArrowNumericType + Send,
810 F: Fn(T::Native, u64) -> Result<T::Native> + Send,
811{
812 fn update_batch(
813 &mut self,
814 values: &[ArrayRef],
815 group_indices: &[usize],
816 opt_filter: Option<&BooleanArray>,
817 total_num_groups: usize,
818 ) -> Result<()> {
819 assert_eq!(values.len(), 1, "single argument to update_batch");
820 let values = values[0].as_primitive::<T>();
821
822 self.counts.resize(total_num_groups, 0);
824 self.sums.resize(total_num_groups, T::default_value());
825 self.null_state.accumulate(
826 group_indices,
827 values,
828 opt_filter,
829 total_num_groups,
830 |group_index, new_value| {
831 let sum = &mut self.sums[group_index];
832 *sum = sum.add_wrapping(new_value);
833
834 self.counts[group_index] += 1;
835 },
836 );
837
838 Ok(())
839 }
840
841 fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
842 let counts = emit_to.take_needed(&mut self.counts);
843 let sums = emit_to.take_needed(&mut self.sums);
844 let nulls = self.null_state.build(emit_to);
845
846 assert_eq!(nulls.len(), sums.len());
847 assert_eq!(counts.len(), sums.len());
848
849 let array: PrimitiveArray<T> = if nulls.null_count() > 0 {
852 let mut builder = PrimitiveBuilder::<T>::with_capacity(nulls.len())
853 .with_data_type(self.return_data_type.clone());
854 let iter = sums.into_iter().zip(counts).zip(nulls.iter());
855
856 for ((sum, count), is_valid) in iter {
857 if is_valid {
858 builder.append_value((self.avg_fn)(sum, count)?)
859 } else {
860 builder.append_null();
861 }
862 }
863 builder.finish()
864 } else {
865 let averages: Vec<T::Native> = sums
866 .into_iter()
867 .zip(counts.into_iter())
868 .map(|(sum, count)| (self.avg_fn)(sum, count))
869 .collect::<Result<Vec<_>>>()?;
870 PrimitiveArray::new(averages.into(), Some(nulls)) .with_data_type(self.return_data_type.clone())
872 };
873
874 Ok(Arc::new(array))
875 }
876
877 fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
879 let nulls = self.null_state.build(emit_to);
880 let nulls = Some(nulls);
881
882 let counts = emit_to.take_needed(&mut self.counts);
883 let counts = UInt64Array::new(counts.into(), nulls.clone()); let sums = emit_to.take_needed(&mut self.sums);
886 let sums = PrimitiveArray::<T>::new(sums.into(), nulls) .with_data_type(self.sum_data_type.clone());
888
889 Ok(vec![
890 Arc::new(counts) as ArrayRef,
891 Arc::new(sums) as ArrayRef,
892 ])
893 }
894
895 fn merge_batch(
896 &mut self,
897 values: &[ArrayRef],
898 group_indices: &[usize],
899 opt_filter: Option<&BooleanArray>,
900 total_num_groups: usize,
901 ) -> Result<()> {
902 assert_eq!(values.len(), 2, "two arguments to merge_batch");
903 let partial_counts = values[0].as_primitive::<UInt64Type>();
905 let partial_sums = values[1].as_primitive::<T>();
906 self.counts.resize(total_num_groups, 0);
908 self.null_state.accumulate(
909 group_indices,
910 partial_counts,
911 opt_filter,
912 total_num_groups,
913 |group_index, partial_count| {
914 self.counts[group_index] += partial_count;
915 },
916 );
917
918 self.sums.resize(total_num_groups, T::default_value());
920 self.null_state.accumulate(
921 group_indices,
922 partial_sums,
923 opt_filter,
924 total_num_groups,
925 |group_index, new_value: <T as ArrowPrimitiveType>::Native| {
926 let sum = &mut self.sums[group_index];
927 *sum = sum.add_wrapping(new_value);
928 },
929 );
930
931 Ok(())
932 }
933
934 fn convert_to_state(
935 &self,
936 values: &[ArrayRef],
937 opt_filter: Option<&BooleanArray>,
938 ) -> Result<Vec<ArrayRef>> {
939 let sums = values[0]
940 .as_primitive::<T>()
941 .clone()
942 .with_data_type(self.sum_data_type.clone());
943 let counts = UInt64Array::from_value(1, sums.len());
944
945 let nulls = filtered_null_mask(opt_filter, &sums);
946
947 let counts = set_nulls(counts, nulls.clone());
949 let sums = set_nulls(sums, nulls);
950
951 Ok(vec![Arc::new(counts) as ArrayRef, Arc::new(sums)])
952 }
953
    /// This accumulator implements `convert_to_state`.
    fn supports_convert_to_state(&self) -> bool {
        true
    }
957
958 fn size(&self) -> usize {
959 self.counts.capacity() * size_of::<u64>() + self.sums.capacity() * size_of::<T>()
960 }
961}