1use arrow::array::{
21 Array, ArrayRef, ArrowNativeTypeOp, ArrowNumericType, ArrowPrimitiveType, AsArray,
22 BooleanArray, PrimitiveArray, PrimitiveBuilder, UInt64Array,
23};
24
25use arrow::compute::sum;
26use arrow::datatypes::{
27 ArrowNativeType, DECIMAL32_MAX_PRECISION, DECIMAL32_MAX_SCALE,
28 DECIMAL64_MAX_PRECISION, DECIMAL64_MAX_SCALE, DECIMAL128_MAX_PRECISION,
29 DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, DataType,
30 Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type, DecimalType,
31 DurationMicrosecondType, DurationMillisecondType, DurationNanosecondType,
32 DurationSecondType, Field, FieldRef, Float64Type, TimeUnit, UInt64Type, i256,
33};
34use datafusion_common::types::{NativeType, logical_float64};
35use datafusion_common::{Result, ScalarValue, exec_err, not_impl_err};
36use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
37use datafusion_expr::utils::format_state_name;
38use datafusion_expr::{
39 Accumulator, AggregateUDFImpl, Coercion, Documentation, EmitTo, Expr,
40 GroupsAccumulator, ReversedUDAF, Signature, TypeSignature, TypeSignatureClass,
41 Volatility,
42};
43use datafusion_functions_aggregate_common::aggregate::avg_distinct::{
44 DecimalDistinctAvgAccumulator, Float64DistinctAvgAccumulator,
45};
46use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::NullState;
47use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::{
48 filtered_null_mask, set_nulls,
49};
50use datafusion_functions_aggregate_common::utils::DecimalAverager;
51use datafusion_macros::user_doc;
52use log::debug;
53use std::any::Any;
54use std::fmt::Debug;
55use std::mem::{size_of, size_of_val};
56use std::sync::Arc;
57
// Generates the `avg(expression)` expression helper and the `avg_udaf()`
// accessor (used by `avg_distinct` below) for the `Avg` UDAF. The exact
// expansion lives in the `make_udaf_expr_and_func!` macro definition.
make_udaf_expr_and_func!(
    Avg,
    avg,
    expression,
    "Returns the avg of a group of values.",
    avg_udaf
);
65
66pub fn avg_distinct(expr: Expr) -> Expr {
67 Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
68 avg_udaf(),
69 vec![expr],
70 true,
71 None,
72 vec![],
73 None,
74 ))
75}
76
// The `avg` aggregate UDF (alias `mean`, set in `Avg::new`).
// The `user_doc` attribute carries the user-facing documentation; it presumably
// generates the `doc()` method that `AggregateUDFImpl::documentation` returns.
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the average of numeric values in the specified column.",
    syntax_example = "avg(expression)",
    sql_example = r#"```sql
> SELECT avg(column_name) FROM table_name;
+---------------------------+
| avg(column_name) |
+---------------------------+
| 42.75 |
+---------------------------+
```"#,
    standard_argument(name = "expression",)
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct Avg {
    // Accepted argument types; constructed in `Avg::new`.
    signature: Signature,
    // Alternative SQL names for this function (`Avg::new` sets `["mean"]`).
    aliases: Vec<String>,
}
96
97impl Avg {
98 pub fn new() -> Self {
99 Self {
100 signature: Signature::one_of(
103 vec![
104 TypeSignature::Coercible(vec![Coercion::new_exact(
105 TypeSignatureClass::Decimal,
106 )]),
107 TypeSignature::Coercible(vec![Coercion::new_exact(
108 TypeSignatureClass::Duration,
109 )]),
110 TypeSignature::Coercible(vec![Coercion::new_implicit(
111 TypeSignatureClass::Native(logical_float64()),
112 vec![TypeSignatureClass::Integer, TypeSignatureClass::Float],
113 NativeType::Float64,
114 )]),
115 ],
116 Volatility::Immutable,
117 ),
118 aliases: vec![String::from("mean")],
119 }
120 }
121}
122
123impl Default for Avg {
124 fn default() -> Self {
125 Self::new()
126 }
127}
128
129impl AggregateUDFImpl for Avg {
130 fn as_any(&self) -> &dyn Any {
131 self
132 }
133
134 fn name(&self) -> &str {
135 "avg"
136 }
137
138 fn signature(&self) -> &Signature {
139 &self.signature
140 }
141
142 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
143 match &arg_types[0] {
144 DataType::Decimal32(precision, scale) => {
145 let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 4);
148 let new_scale = DECIMAL32_MAX_SCALE.min(*scale + 4);
149 Ok(DataType::Decimal32(new_precision, new_scale))
150 }
151 DataType::Decimal64(precision, scale) => {
152 let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 4);
155 let new_scale = DECIMAL64_MAX_SCALE.min(*scale + 4);
156 Ok(DataType::Decimal64(new_precision, new_scale))
157 }
158 DataType::Decimal128(precision, scale) => {
159 let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 4);
162 let new_scale = DECIMAL128_MAX_SCALE.min(*scale + 4);
163 Ok(DataType::Decimal128(new_precision, new_scale))
164 }
165 DataType::Decimal256(precision, scale) => {
166 let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 4);
169 let new_scale = DECIMAL256_MAX_SCALE.min(*scale + 4);
170 Ok(DataType::Decimal256(new_precision, new_scale))
171 }
172 DataType::Duration(time_unit) => Ok(DataType::Duration(*time_unit)),
173 _ => Ok(DataType::Float64),
174 }
175 }
176
177 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
178 let data_type = acc_args.expr_fields[0].data_type();
179 use DataType::*;
180
181 if acc_args.is_distinct {
183 match (data_type, acc_args.return_type()) {
184 (Float64, _) => Ok(Box::new(Float64DistinctAvgAccumulator::default())),
186
187 (
188 Decimal32(_, scale),
189 Decimal32(target_precision, target_scale),
190 ) => Ok(Box::new(DecimalDistinctAvgAccumulator::<Decimal32Type>::with_decimal_params(
191 *scale,
192 *target_precision,
193 *target_scale,
194 ))),
195 (
196 Decimal64(_, scale),
197 Decimal64(target_precision, target_scale),
198 ) => Ok(Box::new(DecimalDistinctAvgAccumulator::<Decimal64Type>::with_decimal_params(
199 *scale,
200 *target_precision,
201 *target_scale,
202 ))),
203 (
204 Decimal128(_, scale),
205 Decimal128(target_precision, target_scale),
206 ) => Ok(Box::new(DecimalDistinctAvgAccumulator::<Decimal128Type>::with_decimal_params(
207 *scale,
208 *target_precision,
209 *target_scale,
210 ))),
211
212 (
213 Decimal256(_, scale),
214 Decimal256(target_precision, target_scale),
215 ) => Ok(Box::new(DecimalDistinctAvgAccumulator::<Decimal256Type>::with_decimal_params(
216 *scale,
217 *target_precision,
218 *target_scale,
219 ))),
220
221 (dt, return_type) => exec_err!(
222 "AVG(DISTINCT) for ({} --> {}) not supported",
223 dt,
224 return_type
225 ),
226 }
227 } else {
228 match (&data_type, acc_args.return_type()) {
229 (Float64, Float64) => Ok(Box::<AvgAccumulator>::default()),
230 (
231 Decimal32(sum_precision, sum_scale),
232 Decimal32(target_precision, target_scale),
233 ) => Ok(Box::new(DecimalAvgAccumulator::<Decimal32Type> {
234 sum: None,
235 count: 0,
236 sum_scale: *sum_scale,
237 sum_precision: *sum_precision,
238 target_precision: *target_precision,
239 target_scale: *target_scale,
240 })),
241 (
242 Decimal64(sum_precision, sum_scale),
243 Decimal64(target_precision, target_scale),
244 ) => Ok(Box::new(DecimalAvgAccumulator::<Decimal64Type> {
245 sum: None,
246 count: 0,
247 sum_scale: *sum_scale,
248 sum_precision: *sum_precision,
249 target_precision: *target_precision,
250 target_scale: *target_scale,
251 })),
252 (
253 Decimal128(sum_precision, sum_scale),
254 Decimal128(target_precision, target_scale),
255 ) => Ok(Box::new(DecimalAvgAccumulator::<Decimal128Type> {
256 sum: None,
257 count: 0,
258 sum_scale: *sum_scale,
259 sum_precision: *sum_precision,
260 target_precision: *target_precision,
261 target_scale: *target_scale,
262 })),
263
264 (
265 Decimal256(sum_precision, sum_scale),
266 Decimal256(target_precision, target_scale),
267 ) => Ok(Box::new(DecimalAvgAccumulator::<Decimal256Type> {
268 sum: None,
269 count: 0,
270 sum_scale: *sum_scale,
271 sum_precision: *sum_precision,
272 target_precision: *target_precision,
273 target_scale: *target_scale,
274 })),
275
276 (Duration(time_unit), Duration(result_unit)) => {
277 Ok(Box::new(DurationAvgAccumulator {
278 sum: None,
279 count: 0,
280 time_unit: *time_unit,
281 result_unit: *result_unit,
282 }))
283 }
284
285 (dt, return_type) => {
286 exec_err!("AvgAccumulator for ({} --> {})", dt, return_type)
287 }
288 }
289 }
290 }
291
292 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
293 if args.is_distinct {
294 let dt = match args.input_fields[0].data_type() {
297 DataType::Decimal32(_, scale) => {
298 DataType::Decimal32(DECIMAL32_MAX_PRECISION, *scale)
299 }
300 DataType::Decimal64(_, scale) => {
301 DataType::Decimal64(DECIMAL64_MAX_PRECISION, *scale)
302 }
303 DataType::Decimal128(_, scale) => {
304 DataType::Decimal128(DECIMAL128_MAX_PRECISION, *scale)
305 }
306 DataType::Decimal256(_, scale) => {
307 DataType::Decimal256(DECIMAL256_MAX_PRECISION, *scale)
308 }
309 _ => args.return_type().clone(),
310 };
311 Ok(vec![
314 Field::new_list(
315 format_state_name(args.name, "avg distinct"),
316 Field::new_list_field(dt, true),
317 false,
318 )
319 .into(),
320 ])
321 } else {
322 Ok(vec![
323 Field::new(
324 format_state_name(args.name, "count"),
325 DataType::UInt64,
326 true,
327 ),
328 Field::new(
329 format_state_name(args.name, "sum"),
330 args.input_fields[0].data_type().clone(),
331 true,
332 ),
333 ]
334 .into_iter()
335 .map(Arc::new)
336 .collect())
337 }
338 }
339
340 fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
341 matches!(
342 args.return_field.data_type(),
343 DataType::Float64
344 | DataType::Decimal32(_, _)
345 | DataType::Decimal64(_, _)
346 | DataType::Decimal128(_, _)
347 | DataType::Decimal256(_, _)
348 | DataType::Duration(_)
349 ) && !args.is_distinct
350 }
351
352 fn create_groups_accumulator(
353 &self,
354 args: AccumulatorArgs,
355 ) -> Result<Box<dyn GroupsAccumulator>> {
356 use DataType::*;
357
358 let data_type = args.expr_fields[0].data_type();
359
360 match (data_type, args.return_field.data_type()) {
362 (Float64, Float64) => {
363 Ok(Box::new(AvgGroupsAccumulator::<Float64Type, _>::new(
364 data_type,
365 args.return_field.data_type(),
366 |sum: f64, count: u64| Ok(sum / count as f64),
367 )))
368 }
369 (
370 Decimal32(_sum_precision, sum_scale),
371 Decimal32(target_precision, target_scale),
372 ) => {
373 let decimal_averager = DecimalAverager::<Decimal32Type>::try_new(
374 *sum_scale,
375 *target_precision,
376 *target_scale,
377 )?;
378
379 let avg_fn =
380 move |sum: i32, count: u64| decimal_averager.avg(sum, count as i32);
381
382 Ok(Box::new(AvgGroupsAccumulator::<Decimal32Type, _>::new(
383 data_type,
384 args.return_field.data_type(),
385 avg_fn,
386 )))
387 }
388 (
389 Decimal64(_sum_precision, sum_scale),
390 Decimal64(target_precision, target_scale),
391 ) => {
392 let decimal_averager = DecimalAverager::<Decimal64Type>::try_new(
393 *sum_scale,
394 *target_precision,
395 *target_scale,
396 )?;
397
398 let avg_fn =
399 move |sum: i64, count: u64| decimal_averager.avg(sum, count as i64);
400
401 Ok(Box::new(AvgGroupsAccumulator::<Decimal64Type, _>::new(
402 data_type,
403 args.return_field.data_type(),
404 avg_fn,
405 )))
406 }
407 (
408 Decimal128(_sum_precision, sum_scale),
409 Decimal128(target_precision, target_scale),
410 ) => {
411 let decimal_averager = DecimalAverager::<Decimal128Type>::try_new(
412 *sum_scale,
413 *target_precision,
414 *target_scale,
415 )?;
416
417 let avg_fn =
418 move |sum: i128, count: u64| decimal_averager.avg(sum, count as i128);
419
420 Ok(Box::new(AvgGroupsAccumulator::<Decimal128Type, _>::new(
421 data_type,
422 args.return_field.data_type(),
423 avg_fn,
424 )))
425 }
426
427 (
428 Decimal256(_sum_precision, sum_scale),
429 Decimal256(target_precision, target_scale),
430 ) => {
431 let decimal_averager = DecimalAverager::<Decimal256Type>::try_new(
432 *sum_scale,
433 *target_precision,
434 *target_scale,
435 )?;
436
437 let avg_fn = move |sum: i256, count: u64| {
438 decimal_averager.avg(sum, i256::from_usize(count as usize).unwrap())
439 };
440
441 Ok(Box::new(AvgGroupsAccumulator::<Decimal256Type, _>::new(
442 data_type,
443 args.return_field.data_type(),
444 avg_fn,
445 )))
446 }
447
448 (Duration(time_unit), Duration(_result_unit)) => {
449 let avg_fn = move |sum: i64, count: u64| Ok(sum / count as i64);
450
451 match time_unit {
452 TimeUnit::Second => Ok(Box::new(AvgGroupsAccumulator::<
453 DurationSecondType,
454 _,
455 >::new(
456 data_type,
457 args.return_type(),
458 avg_fn,
459 ))),
460 TimeUnit::Millisecond => Ok(Box::new(AvgGroupsAccumulator::<
461 DurationMillisecondType,
462 _,
463 >::new(
464 data_type,
465 args.return_type(),
466 avg_fn,
467 ))),
468 TimeUnit::Microsecond => Ok(Box::new(AvgGroupsAccumulator::<
469 DurationMicrosecondType,
470 _,
471 >::new(
472 data_type,
473 args.return_type(),
474 avg_fn,
475 ))),
476 TimeUnit::Nanosecond => Ok(Box::new(AvgGroupsAccumulator::<
477 DurationNanosecondType,
478 _,
479 >::new(
480 data_type,
481 args.return_type(),
482 avg_fn,
483 ))),
484 }
485 }
486
487 _ => not_impl_err!(
488 "AvgGroupsAccumulator for ({} --> {})",
489 &data_type,
490 args.return_field.data_type()
491 ),
492 }
493 }
494
495 fn aliases(&self) -> &[String] {
496 &self.aliases
497 }
498
499 fn reverse_expr(&self) -> ReversedUDAF {
500 ReversedUDAF::Identical
501 }
502
503 fn documentation(&self) -> Option<&Documentation> {
504 self.doc()
505 }
506}
507
508#[derive(Debug, Default)]
510pub struct AvgAccumulator {
511 sum: Option<f64>,
512 count: u64,
513}
514
515impl Accumulator for AvgAccumulator {
516 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
517 let values = values[0].as_primitive::<Float64Type>();
518 self.count += (values.len() - values.null_count()) as u64;
519 if let Some(x) = sum(values) {
520 let v = self.sum.get_or_insert(0.);
521 *v += x;
522 }
523 Ok(())
524 }
525
526 fn evaluate(&mut self) -> Result<ScalarValue> {
527 Ok(ScalarValue::Float64(
528 self.sum.map(|f| f / self.count as f64),
529 ))
530 }
531
532 fn size(&self) -> usize {
533 size_of_val(self)
534 }
535
536 fn state(&mut self) -> Result<Vec<ScalarValue>> {
537 Ok(vec![
538 ScalarValue::from(self.count),
539 ScalarValue::Float64(self.sum),
540 ])
541 }
542
543 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
544 self.count += sum(states[0].as_primitive::<UInt64Type>()).unwrap_or_default();
546
547 if let Some(x) = sum(states[1].as_primitive::<Float64Type>()) {
549 let v = self.sum.get_or_insert(0.);
550 *v += x;
551 }
552 Ok(())
553 }
554 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
555 let values = values[0].as_primitive::<Float64Type>();
556 self.count -= (values.len() - values.null_count()) as u64;
557 if let Some(x) = sum(values) {
558 self.sum = Some(self.sum.unwrap() - x);
559 }
560 Ok(())
561 }
562
563 fn supports_retract_batch(&self) -> bool {
564 true
565 }
566}
567
568#[derive(Debug)]
570struct DecimalAvgAccumulator<T: DecimalType + ArrowNumericType + Debug> {
571 sum: Option<T::Native>,
572 count: u64,
573 sum_scale: i8,
574 sum_precision: u8,
575 target_precision: u8,
576 target_scale: i8,
577}
578
579impl<T: DecimalType + ArrowNumericType + Debug> Accumulator for DecimalAvgAccumulator<T> {
580 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
581 let values = values[0].as_primitive::<T>();
582 self.count += (values.len() - values.null_count()) as u64;
583
584 if let Some(x) = sum(values) {
585 let v = self.sum.get_or_insert_with(T::Native::default);
586 self.sum = Some(v.add_wrapping(x));
587 }
588 Ok(())
589 }
590
591 fn evaluate(&mut self) -> Result<ScalarValue> {
592 let v = self
593 .sum
594 .map(|v| {
595 DecimalAverager::<T>::try_new(
596 self.sum_scale,
597 self.target_precision,
598 self.target_scale,
599 )?
600 .avg(v, T::Native::from_usize(self.count as usize).unwrap())
601 })
602 .transpose()?;
603
604 ScalarValue::new_primitive::<T>(
605 v,
606 &T::TYPE_CONSTRUCTOR(self.target_precision, self.target_scale),
607 )
608 }
609
610 fn size(&self) -> usize {
611 size_of_val(self)
612 }
613
614 fn state(&mut self) -> Result<Vec<ScalarValue>> {
615 Ok(vec![
616 ScalarValue::from(self.count),
617 ScalarValue::new_primitive::<T>(
618 self.sum,
619 &T::TYPE_CONSTRUCTOR(self.sum_precision, self.sum_scale),
620 )?,
621 ])
622 }
623
624 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
625 self.count += sum(states[0].as_primitive::<UInt64Type>()).unwrap_or_default();
627
628 if let Some(x) = sum(states[1].as_primitive::<T>()) {
630 let v = self.sum.get_or_insert_with(T::Native::default);
631 self.sum = Some(v.add_wrapping(x));
632 }
633 Ok(())
634 }
635 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
636 let values = values[0].as_primitive::<T>();
637 self.count -= (values.len() - values.null_count()) as u64;
638 if let Some(x) = sum(values) {
639 self.sum = Some(self.sum.unwrap().sub_wrapping(x));
640 }
641 Ok(())
642 }
643
644 fn supports_retract_batch(&self) -> bool {
645 true
646 }
647}
648
649#[derive(Debug)]
651struct DurationAvgAccumulator {
652 sum: Option<i64>,
653 count: u64,
654 time_unit: TimeUnit,
655 result_unit: TimeUnit,
656}
657
658impl Accumulator for DurationAvgAccumulator {
659 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
660 let array = &values[0];
661 self.count += (array.len() - array.null_count()) as u64;
662
663 let sum_value = match self.time_unit {
664 TimeUnit::Second => sum(array.as_primitive::<DurationSecondType>()),
665 TimeUnit::Millisecond => sum(array.as_primitive::<DurationMillisecondType>()),
666 TimeUnit::Microsecond => sum(array.as_primitive::<DurationMicrosecondType>()),
667 TimeUnit::Nanosecond => sum(array.as_primitive::<DurationNanosecondType>()),
668 };
669
670 if let Some(x) = sum_value {
671 let v = self.sum.get_or_insert(0);
672 *v += x;
673 }
674 Ok(())
675 }
676
677 fn evaluate(&mut self) -> Result<ScalarValue> {
678 let avg = self.sum.map(|sum| sum / self.count as i64);
679
680 match self.result_unit {
681 TimeUnit::Second => Ok(ScalarValue::DurationSecond(avg)),
682 TimeUnit::Millisecond => Ok(ScalarValue::DurationMillisecond(avg)),
683 TimeUnit::Microsecond => Ok(ScalarValue::DurationMicrosecond(avg)),
684 TimeUnit::Nanosecond => Ok(ScalarValue::DurationNanosecond(avg)),
685 }
686 }
687
688 fn size(&self) -> usize {
689 size_of_val(self)
690 }
691
692 fn state(&mut self) -> Result<Vec<ScalarValue>> {
693 let duration_value = match self.time_unit {
694 TimeUnit::Second => ScalarValue::DurationSecond(self.sum),
695 TimeUnit::Millisecond => ScalarValue::DurationMillisecond(self.sum),
696 TimeUnit::Microsecond => ScalarValue::DurationMicrosecond(self.sum),
697 TimeUnit::Nanosecond => ScalarValue::DurationNanosecond(self.sum),
698 };
699
700 Ok(vec![ScalarValue::from(self.count), duration_value])
701 }
702
703 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
704 self.count += sum(states[0].as_primitive::<UInt64Type>()).unwrap_or_default();
705
706 let sum_value = match self.time_unit {
707 TimeUnit::Second => sum(states[1].as_primitive::<DurationSecondType>()),
708 TimeUnit::Millisecond => {
709 sum(states[1].as_primitive::<DurationMillisecondType>())
710 }
711 TimeUnit::Microsecond => {
712 sum(states[1].as_primitive::<DurationMicrosecondType>())
713 }
714 TimeUnit::Nanosecond => {
715 sum(states[1].as_primitive::<DurationNanosecondType>())
716 }
717 };
718
719 if let Some(x) = sum_value {
720 let v = self.sum.get_or_insert(0);
721 *v += x;
722 }
723 Ok(())
724 }
725
726 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
727 let array = &values[0];
728 self.count -= (array.len() - array.null_count()) as u64;
729
730 let sum_value = match self.time_unit {
731 TimeUnit::Second => sum(array.as_primitive::<DurationSecondType>()),
732 TimeUnit::Millisecond => sum(array.as_primitive::<DurationMillisecondType>()),
733 TimeUnit::Microsecond => sum(array.as_primitive::<DurationMicrosecondType>()),
734 TimeUnit::Nanosecond => sum(array.as_primitive::<DurationNanosecondType>()),
735 };
736
737 if let Some(x) = sum_value {
738 self.sum = Some(self.sum.unwrap() - x);
739 }
740 Ok(())
741 }
742
743 fn supports_retract_batch(&self) -> bool {
744 true
745 }
746}
747
748#[derive(Debug)]
754struct AvgGroupsAccumulator<T, F>
755where
756 T: ArrowNumericType + Send,
757 F: Fn(T::Native, u64) -> Result<T::Native> + Send,
758{
759 sum_data_type: DataType,
761
762 return_data_type: DataType,
764
765 counts: Vec<u64>,
767
768 sums: Vec<T::Native>,
770
771 null_state: NullState,
773
774 avg_fn: F,
776}
777
778impl<T, F> AvgGroupsAccumulator<T, F>
779where
780 T: ArrowNumericType + Send,
781 F: Fn(T::Native, u64) -> Result<T::Native> + Send,
782{
783 pub fn new(sum_data_type: &DataType, return_data_type: &DataType, avg_fn: F) -> Self {
784 debug!(
785 "AvgGroupsAccumulator ({}, sum type: {sum_data_type}) --> {return_data_type}",
786 std::any::type_name::<T>()
787 );
788
789 Self {
790 return_data_type: return_data_type.clone(),
791 sum_data_type: sum_data_type.clone(),
792 counts: vec![],
793 sums: vec![],
794 null_state: NullState::new(),
795 avg_fn,
796 }
797 }
798}
799
800impl<T, F> GroupsAccumulator for AvgGroupsAccumulator<T, F>
801where
802 T: ArrowNumericType + Send,
803 F: Fn(T::Native, u64) -> Result<T::Native> + Send,
804{
805 fn update_batch(
806 &mut self,
807 values: &[ArrayRef],
808 group_indices: &[usize],
809 opt_filter: Option<&BooleanArray>,
810 total_num_groups: usize,
811 ) -> Result<()> {
812 assert_eq!(values.len(), 1, "single argument to update_batch");
813 let values = values[0].as_primitive::<T>();
814
815 self.counts.resize(total_num_groups, 0);
817 self.sums.resize(total_num_groups, T::default_value());
818 self.null_state.accumulate(
819 group_indices,
820 values,
821 opt_filter,
822 total_num_groups,
823 |group_index, new_value| {
824 let sum = &mut self.sums[group_index];
825 *sum = sum.add_wrapping(new_value);
826
827 self.counts[group_index] += 1;
828 },
829 );
830
831 Ok(())
832 }
833
834 fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
835 let counts = emit_to.take_needed(&mut self.counts);
836 let sums = emit_to.take_needed(&mut self.sums);
837 let nulls = self.null_state.build(emit_to);
838
839 assert_eq!(nulls.len(), sums.len());
840 assert_eq!(counts.len(), sums.len());
841
842 let array: PrimitiveArray<T> = if nulls.null_count() > 0 {
845 let mut builder = PrimitiveBuilder::<T>::with_capacity(nulls.len())
846 .with_data_type(self.return_data_type.clone());
847 let iter = sums.into_iter().zip(counts).zip(nulls.iter());
848
849 for ((sum, count), is_valid) in iter {
850 if is_valid {
851 builder.append_value((self.avg_fn)(sum, count)?)
852 } else {
853 builder.append_null();
854 }
855 }
856 builder.finish()
857 } else {
858 let averages: Vec<T::Native> = sums
859 .into_iter()
860 .zip(counts.into_iter())
861 .map(|(sum, count)| (self.avg_fn)(sum, count))
862 .collect::<Result<Vec<_>>>()?;
863 PrimitiveArray::new(averages.into(), Some(nulls)) .with_data_type(self.return_data_type.clone())
865 };
866
867 Ok(Arc::new(array))
868 }
869
870 fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
872 let nulls = self.null_state.build(emit_to);
873 let nulls = Some(nulls);
874
875 let counts = emit_to.take_needed(&mut self.counts);
876 let counts = UInt64Array::new(counts.into(), nulls.clone()); let sums = emit_to.take_needed(&mut self.sums);
879 let sums = PrimitiveArray::<T>::new(sums.into(), nulls) .with_data_type(self.sum_data_type.clone());
881
882 Ok(vec![
883 Arc::new(counts) as ArrayRef,
884 Arc::new(sums) as ArrayRef,
885 ])
886 }
887
888 fn merge_batch(
889 &mut self,
890 values: &[ArrayRef],
891 group_indices: &[usize],
892 opt_filter: Option<&BooleanArray>,
893 total_num_groups: usize,
894 ) -> Result<()> {
895 assert_eq!(values.len(), 2, "two arguments to merge_batch");
896 let partial_counts = values[0].as_primitive::<UInt64Type>();
898 let partial_sums = values[1].as_primitive::<T>();
899 self.counts.resize(total_num_groups, 0);
901 self.null_state.accumulate(
902 group_indices,
903 partial_counts,
904 opt_filter,
905 total_num_groups,
906 |group_index, partial_count| {
907 self.counts[group_index] += partial_count;
908 },
909 );
910
911 self.sums.resize(total_num_groups, T::default_value());
913 self.null_state.accumulate(
914 group_indices,
915 partial_sums,
916 opt_filter,
917 total_num_groups,
918 |group_index, new_value: <T as ArrowPrimitiveType>::Native| {
919 let sum = &mut self.sums[group_index];
920 *sum = sum.add_wrapping(new_value);
921 },
922 );
923
924 Ok(())
925 }
926
927 fn convert_to_state(
928 &self,
929 values: &[ArrayRef],
930 opt_filter: Option<&BooleanArray>,
931 ) -> Result<Vec<ArrayRef>> {
932 let sums = values[0]
933 .as_primitive::<T>()
934 .clone()
935 .with_data_type(self.sum_data_type.clone());
936 let counts = UInt64Array::from_value(1, sums.len());
937
938 let nulls = filtered_null_mask(opt_filter, &sums);
939
940 let counts = set_nulls(counts, nulls.clone());
942 let sums = set_nulls(sums, nulls);
943
944 Ok(vec![Arc::new(counts) as ArrayRef, Arc::new(sums)])
945 }
946
947 fn supports_convert_to_state(&self) -> bool {
948 true
949 }
950
951 fn size(&self) -> usize {
952 self.counts.capacity() * size_of::<u64>() + self.sums.capacity() * size_of::<T>()
953 }
954}