1use arrow::array::{
21 Array, ArrayRef, ArrowNativeTypeOp, ArrowNumericType, ArrowPrimitiveType, AsArray,
22 BooleanArray, PrimitiveArray, PrimitiveBuilder, UInt64Array,
23};
24
25use arrow::compute::sum;
26use arrow::datatypes::{
27 ArrowNativeType, DECIMAL32_MAX_PRECISION, DECIMAL32_MAX_SCALE,
28 DECIMAL64_MAX_PRECISION, DECIMAL64_MAX_SCALE, DECIMAL128_MAX_PRECISION,
29 DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, DataType,
30 Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type, DecimalType,
31 DurationMicrosecondType, DurationMillisecondType, DurationNanosecondType,
32 DurationSecondType, Field, FieldRef, Float64Type, TimeUnit, UInt64Type, i256,
33};
34use datafusion_common::types::{NativeType, logical_float64};
35use datafusion_common::{Result, ScalarValue, exec_err, not_impl_err};
36use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
37use datafusion_expr::utils::format_state_name;
38use datafusion_expr::{
39 Accumulator, AggregateUDFImpl, Coercion, Documentation, EmitTo, Expr,
40 GroupsAccumulator, ReversedUDAF, Signature, TypeSignature, TypeSignatureClass,
41 Volatility,
42};
43use datafusion_functions_aggregate_common::aggregate::avg_distinct::{
44 DecimalDistinctAvgAccumulator, Float64DistinctAvgAccumulator,
45};
46use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::NullState;
47use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::{
48 filtered_null_mask, set_nulls,
49};
50use datafusion_functions_aggregate_common::utils::DecimalAverager;
51use datafusion_macros::user_doc;
52use log::debug;
53use std::fmt::Debug;
54use std::mem::{size_of, size_of_val};
55use std::sync::Arc;
56
57make_udaf_expr_and_func!(
58 Avg,
59 avg,
60 expression,
61 "Returns the avg of a group of values.",
62 avg_udaf
63);
64
65pub fn avg_distinct(expr: Expr) -> Expr {
66 Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
67 avg_udaf(),
68 vec![expr],
69 true,
70 None,
71 vec![],
72 None,
73 ))
74}
75
76#[user_doc(
77 doc_section(label = "General Functions"),
78 description = "Returns the average of numeric values in the specified column.",
79 syntax_example = "avg(expression)",
80 sql_example = r#"```sql
81> SELECT avg(column_name) FROM table_name;
82+---------------------------+
83| avg(column_name) |
84+---------------------------+
85| 42.75 |
86+---------------------------+
87```"#,
88 standard_argument(name = "expression",)
89)]
90#[derive(Debug, PartialEq, Eq, Hash)]
91pub struct Avg {
92 signature: Signature,
93 aliases: Vec<String>,
94}
95
96impl Avg {
97 pub fn new() -> Self {
98 Self {
99 signature: Signature::one_of(
102 vec![
103 TypeSignature::Coercible(vec![Coercion::new_exact(
104 TypeSignatureClass::Decimal,
105 )]),
106 TypeSignature::Coercible(vec![Coercion::new_exact(
107 TypeSignatureClass::Duration,
108 )]),
109 TypeSignature::Coercible(vec![Coercion::new_implicit(
110 TypeSignatureClass::Native(logical_float64()),
111 vec![TypeSignatureClass::Integer, TypeSignatureClass::Float],
112 NativeType::Float64,
113 )]),
114 ],
115 Volatility::Immutable,
116 ),
117 aliases: vec![String::from("mean")],
118 }
119 }
120}
121
122impl Default for Avg {
123 fn default() -> Self {
124 Self::new()
125 }
126}
127
128impl AggregateUDFImpl for Avg {
129 fn name(&self) -> &str {
130 "avg"
131 }
132
133 fn signature(&self) -> &Signature {
134 &self.signature
135 }
136
137 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
138 match &arg_types[0] {
139 DataType::Decimal32(precision, scale) => {
140 let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 4);
143 let new_scale = DECIMAL32_MAX_SCALE.min(*scale + 4);
144 Ok(DataType::Decimal32(new_precision, new_scale))
145 }
146 DataType::Decimal64(precision, scale) => {
147 let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 4);
150 let new_scale = DECIMAL64_MAX_SCALE.min(*scale + 4);
151 Ok(DataType::Decimal64(new_precision, new_scale))
152 }
153 DataType::Decimal128(precision, scale) => {
154 let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 4);
157 let new_scale = DECIMAL128_MAX_SCALE.min(*scale + 4);
158 Ok(DataType::Decimal128(new_precision, new_scale))
159 }
160 DataType::Decimal256(precision, scale) => {
161 let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 4);
164 let new_scale = DECIMAL256_MAX_SCALE.min(*scale + 4);
165 Ok(DataType::Decimal256(new_precision, new_scale))
166 }
167 DataType::Duration(time_unit) => Ok(DataType::Duration(*time_unit)),
168 _ => Ok(DataType::Float64),
169 }
170 }
171
172 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
173 let data_type = acc_args.expr_fields[0].data_type();
174 use DataType::*;
175
176 if acc_args.is_distinct {
178 match (data_type, acc_args.return_type()) {
179 (Float64, _) => Ok(Box::new(Float64DistinctAvgAccumulator::default())),
181
182 (
183 Decimal32(_, scale),
184 Decimal32(target_precision, target_scale),
185 ) => Ok(Box::new(DecimalDistinctAvgAccumulator::<Decimal32Type>::with_decimal_params(
186 *scale,
187 *target_precision,
188 *target_scale,
189 ))),
190 (
191 Decimal64(_, scale),
192 Decimal64(target_precision, target_scale),
193 ) => Ok(Box::new(DecimalDistinctAvgAccumulator::<Decimal64Type>::with_decimal_params(
194 *scale,
195 *target_precision,
196 *target_scale,
197 ))),
198 (
199 Decimal128(_, scale),
200 Decimal128(target_precision, target_scale),
201 ) => Ok(Box::new(DecimalDistinctAvgAccumulator::<Decimal128Type>::with_decimal_params(
202 *scale,
203 *target_precision,
204 *target_scale,
205 ))),
206
207 (
208 Decimal256(_, scale),
209 Decimal256(target_precision, target_scale),
210 ) => Ok(Box::new(DecimalDistinctAvgAccumulator::<Decimal256Type>::with_decimal_params(
211 *scale,
212 *target_precision,
213 *target_scale,
214 ))),
215
216 (dt, return_type) => exec_err!(
217 "AVG(DISTINCT) for ({} --> {}) not supported",
218 dt,
219 return_type
220 ),
221 }
222 } else {
223 match (&data_type, acc_args.return_type()) {
224 (Float64, Float64) => Ok(Box::<AvgAccumulator>::default()),
225 (
226 Decimal32(sum_precision, sum_scale),
227 Decimal32(target_precision, target_scale),
228 ) => Ok(Box::new(DecimalAvgAccumulator::<Decimal32Type> {
229 sum: None,
230 count: 0,
231 sum_scale: *sum_scale,
232 sum_precision: *sum_precision,
233 target_precision: *target_precision,
234 target_scale: *target_scale,
235 })),
236 (
237 Decimal64(sum_precision, sum_scale),
238 Decimal64(target_precision, target_scale),
239 ) => Ok(Box::new(DecimalAvgAccumulator::<Decimal64Type> {
240 sum: None,
241 count: 0,
242 sum_scale: *sum_scale,
243 sum_precision: *sum_precision,
244 target_precision: *target_precision,
245 target_scale: *target_scale,
246 })),
247 (
248 Decimal128(sum_precision, sum_scale),
249 Decimal128(target_precision, target_scale),
250 ) => Ok(Box::new(DecimalAvgAccumulator::<Decimal128Type> {
251 sum: None,
252 count: 0,
253 sum_scale: *sum_scale,
254 sum_precision: *sum_precision,
255 target_precision: *target_precision,
256 target_scale: *target_scale,
257 })),
258
259 (
260 Decimal256(sum_precision, sum_scale),
261 Decimal256(target_precision, target_scale),
262 ) => Ok(Box::new(DecimalAvgAccumulator::<Decimal256Type> {
263 sum: None,
264 count: 0,
265 sum_scale: *sum_scale,
266 sum_precision: *sum_precision,
267 target_precision: *target_precision,
268 target_scale: *target_scale,
269 })),
270
271 (Duration(time_unit), Duration(result_unit)) => {
272 Ok(Box::new(DurationAvgAccumulator {
273 sum: None,
274 count: 0,
275 time_unit: *time_unit,
276 result_unit: *result_unit,
277 }))
278 }
279
280 (dt, return_type) => {
281 exec_err!("AvgAccumulator for ({} --> {})", dt, return_type)
282 }
283 }
284 }
285 }
286
287 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
288 if args.is_distinct {
289 let dt = match args.input_fields[0].data_type() {
292 DataType::Decimal32(_, scale) => {
293 DataType::Decimal32(DECIMAL32_MAX_PRECISION, *scale)
294 }
295 DataType::Decimal64(_, scale) => {
296 DataType::Decimal64(DECIMAL64_MAX_PRECISION, *scale)
297 }
298 DataType::Decimal128(_, scale) => {
299 DataType::Decimal128(DECIMAL128_MAX_PRECISION, *scale)
300 }
301 DataType::Decimal256(_, scale) => {
302 DataType::Decimal256(DECIMAL256_MAX_PRECISION, *scale)
303 }
304 _ => args.return_type().clone(),
305 };
306 Ok(vec![
309 Field::new_list(
310 format_state_name(args.name, "avg distinct"),
311 Field::new_list_field(dt, true),
312 false,
313 )
314 .into(),
315 ])
316 } else {
317 Ok(vec![
318 Field::new(
319 format_state_name(args.name, "count"),
320 DataType::UInt64,
321 true,
322 ),
323 Field::new(
324 format_state_name(args.name, "sum"),
325 args.input_fields[0].data_type().clone(),
326 true,
327 ),
328 ]
329 .into_iter()
330 .map(Arc::new)
331 .collect())
332 }
333 }
334
335 fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
336 matches!(
337 args.return_field.data_type(),
338 DataType::Float64
339 | DataType::Decimal32(_, _)
340 | DataType::Decimal64(_, _)
341 | DataType::Decimal128(_, _)
342 | DataType::Decimal256(_, _)
343 | DataType::Duration(_)
344 ) && !args.is_distinct
345 }
346
347 fn create_groups_accumulator(
348 &self,
349 args: AccumulatorArgs,
350 ) -> Result<Box<dyn GroupsAccumulator>> {
351 use DataType::*;
352
353 let data_type = args.expr_fields[0].data_type();
354
355 match (data_type, args.return_field.data_type()) {
357 (Float64, Float64) => {
358 Ok(Box::new(AvgGroupsAccumulator::<Float64Type, _>::new(
359 data_type,
360 args.return_field.data_type(),
361 |sum: f64, count: u64| Ok(sum / count as f64),
362 )))
363 }
364 (
365 Decimal32(_sum_precision, sum_scale),
366 Decimal32(target_precision, target_scale),
367 ) => {
368 let decimal_averager = DecimalAverager::<Decimal32Type>::try_new(
369 *sum_scale,
370 *target_precision,
371 *target_scale,
372 )?;
373
374 let avg_fn =
375 move |sum: i32, count: u64| decimal_averager.avg(sum, count as i32);
376
377 Ok(Box::new(AvgGroupsAccumulator::<Decimal32Type, _>::new(
378 data_type,
379 args.return_field.data_type(),
380 avg_fn,
381 )))
382 }
383 (
384 Decimal64(_sum_precision, sum_scale),
385 Decimal64(target_precision, target_scale),
386 ) => {
387 let decimal_averager = DecimalAverager::<Decimal64Type>::try_new(
388 *sum_scale,
389 *target_precision,
390 *target_scale,
391 )?;
392
393 let avg_fn =
394 move |sum: i64, count: u64| decimal_averager.avg(sum, count as i64);
395
396 Ok(Box::new(AvgGroupsAccumulator::<Decimal64Type, _>::new(
397 data_type,
398 args.return_field.data_type(),
399 avg_fn,
400 )))
401 }
402 (
403 Decimal128(_sum_precision, sum_scale),
404 Decimal128(target_precision, target_scale),
405 ) => {
406 let decimal_averager = DecimalAverager::<Decimal128Type>::try_new(
407 *sum_scale,
408 *target_precision,
409 *target_scale,
410 )?;
411
412 let avg_fn =
413 move |sum: i128, count: u64| decimal_averager.avg(sum, count as i128);
414
415 Ok(Box::new(AvgGroupsAccumulator::<Decimal128Type, _>::new(
416 data_type,
417 args.return_field.data_type(),
418 avg_fn,
419 )))
420 }
421
422 (
423 Decimal256(_sum_precision, sum_scale),
424 Decimal256(target_precision, target_scale),
425 ) => {
426 let decimal_averager = DecimalAverager::<Decimal256Type>::try_new(
427 *sum_scale,
428 *target_precision,
429 *target_scale,
430 )?;
431
432 let avg_fn = move |sum: i256, count: u64| {
433 decimal_averager.avg(sum, i256::from_usize(count as usize).unwrap())
434 };
435
436 Ok(Box::new(AvgGroupsAccumulator::<Decimal256Type, _>::new(
437 data_type,
438 args.return_field.data_type(),
439 avg_fn,
440 )))
441 }
442
443 (Duration(time_unit), Duration(_result_unit)) => {
444 let avg_fn = move |sum: i64, count: u64| Ok(sum / count as i64);
445
446 match time_unit {
447 TimeUnit::Second => Ok(Box::new(AvgGroupsAccumulator::<
448 DurationSecondType,
449 _,
450 >::new(
451 data_type,
452 args.return_type(),
453 avg_fn,
454 ))),
455 TimeUnit::Millisecond => Ok(Box::new(AvgGroupsAccumulator::<
456 DurationMillisecondType,
457 _,
458 >::new(
459 data_type,
460 args.return_type(),
461 avg_fn,
462 ))),
463 TimeUnit::Microsecond => Ok(Box::new(AvgGroupsAccumulator::<
464 DurationMicrosecondType,
465 _,
466 >::new(
467 data_type,
468 args.return_type(),
469 avg_fn,
470 ))),
471 TimeUnit::Nanosecond => Ok(Box::new(AvgGroupsAccumulator::<
472 DurationNanosecondType,
473 _,
474 >::new(
475 data_type,
476 args.return_type(),
477 avg_fn,
478 ))),
479 }
480 }
481
482 _ => not_impl_err!(
483 "AvgGroupsAccumulator for ({} --> {})",
484 &data_type,
485 args.return_field.data_type()
486 ),
487 }
488 }
489
490 fn aliases(&self) -> &[String] {
491 &self.aliases
492 }
493
494 fn reverse_expr(&self) -> ReversedUDAF {
495 ReversedUDAF::Identical
496 }
497
498 fn documentation(&self) -> Option<&Documentation> {
499 self.doc()
500 }
501}
502
/// Accumulator computing the average of `Float64` values.
///
/// Tracks the running sum and the number of non-null values seen.
#[derive(Default, Debug)]
pub struct AvgAccumulator {
    /// Running sum; `None` until a batch contributes a sum.
    sum: Option<f64>,
    /// Number of non-null values accumulated so far.
    count: u64,
}
509
510impl Accumulator for AvgAccumulator {
511 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
512 let values = values[0].as_primitive::<Float64Type>();
513 self.count += (values.len() - values.null_count()) as u64;
514 if let Some(x) = sum(values) {
515 let v = self.sum.get_or_insert(0.);
516 *v += x;
517 }
518 Ok(())
519 }
520
521 fn evaluate(&mut self) -> Result<ScalarValue> {
522 let avg = if self.count == 0 {
527 None
528 } else {
529 self.sum.map(|f| f / self.count as f64)
530 };
531 Ok(ScalarValue::Float64(avg))
532 }
533
534 fn size(&self) -> usize {
535 size_of_val(self)
536 }
537
538 fn state(&mut self) -> Result<Vec<ScalarValue>> {
539 Ok(vec![
540 ScalarValue::from(self.count),
541 ScalarValue::Float64(self.sum),
542 ])
543 }
544
545 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
546 self.count += sum(states[0].as_primitive::<UInt64Type>()).unwrap_or_default();
548
549 if let Some(x) = sum(states[1].as_primitive::<Float64Type>()) {
551 let v = self.sum.get_or_insert(0.);
552 *v += x;
553 }
554 Ok(())
555 }
556 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
557 let values = values[0].as_primitive::<Float64Type>();
558 self.count -= (values.len() - values.null_count()) as u64;
559 if let Some(x) = sum(values) {
560 self.sum = Some(self.sum.unwrap() - x);
561 }
562 Ok(())
563 }
564
565 fn supports_retract_batch(&self) -> bool {
566 true
567 }
568}
569
570#[derive(Debug)]
572struct DecimalAvgAccumulator<T: DecimalType + ArrowNumericType + Debug> {
573 sum: Option<T::Native>,
574 count: u64,
575 sum_scale: i8,
576 sum_precision: u8,
577 target_precision: u8,
578 target_scale: i8,
579}
580
581impl<T: DecimalType + ArrowNumericType + Debug> Accumulator for DecimalAvgAccumulator<T> {
582 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
583 let values = values[0].as_primitive::<T>();
584 self.count += (values.len() - values.null_count()) as u64;
585
586 if let Some(x) = sum(values) {
587 let v = self.sum.get_or_insert_with(T::Native::default);
588 self.sum = Some(v.add_wrapping(x));
589 }
590 Ok(())
591 }
592
593 fn evaluate(&mut self) -> Result<ScalarValue> {
594 let v = if self.count == 0 {
598 None
599 } else {
600 self.sum
601 .map(|v| {
602 DecimalAverager::<T>::try_new(
603 self.sum_scale,
604 self.target_precision,
605 self.target_scale,
606 )?
607 .avg(v, T::Native::from_usize(self.count as usize).unwrap())
608 })
609 .transpose()?
610 };
611
612 ScalarValue::new_primitive::<T>(
613 v,
614 &T::TYPE_CONSTRUCTOR(self.target_precision, self.target_scale),
615 )
616 }
617
618 fn size(&self) -> usize {
619 size_of_val(self)
620 }
621
622 fn state(&mut self) -> Result<Vec<ScalarValue>> {
623 Ok(vec![
624 ScalarValue::from(self.count),
625 ScalarValue::new_primitive::<T>(
626 self.sum,
627 &T::TYPE_CONSTRUCTOR(self.sum_precision, self.sum_scale),
628 )?,
629 ])
630 }
631
632 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
633 self.count += sum(states[0].as_primitive::<UInt64Type>()).unwrap_or_default();
635
636 if let Some(x) = sum(states[1].as_primitive::<T>()) {
638 let v = self.sum.get_or_insert_with(T::Native::default);
639 self.sum = Some(v.add_wrapping(x));
640 }
641 Ok(())
642 }
643 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
644 let values = values[0].as_primitive::<T>();
645 self.count -= (values.len() - values.null_count()) as u64;
646 if let Some(x) = sum(values) {
647 self.sum = Some(self.sum.unwrap().sub_wrapping(x));
648 }
649 Ok(())
650 }
651
652 fn supports_retract_batch(&self) -> bool {
653 true
654 }
655}
656
657#[derive(Debug)]
659struct DurationAvgAccumulator {
660 sum: Option<i64>,
661 count: u64,
662 time_unit: TimeUnit,
663 result_unit: TimeUnit,
664}
665
666impl Accumulator for DurationAvgAccumulator {
667 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
668 let array = &values[0];
669 self.count += (array.len() - array.null_count()) as u64;
670
671 let sum_value = match self.time_unit {
672 TimeUnit::Second => sum(array.as_primitive::<DurationSecondType>()),
673 TimeUnit::Millisecond => sum(array.as_primitive::<DurationMillisecondType>()),
674 TimeUnit::Microsecond => sum(array.as_primitive::<DurationMicrosecondType>()),
675 TimeUnit::Nanosecond => sum(array.as_primitive::<DurationNanosecondType>()),
676 };
677
678 if let Some(x) = sum_value {
679 let v = self.sum.get_or_insert(0);
680 *v += x;
681 }
682 Ok(())
683 }
684
685 fn evaluate(&mut self) -> Result<ScalarValue> {
686 let avg = if self.count == 0 {
690 None
691 } else {
692 self.sum.map(|sum| sum / self.count as i64)
693 };
694
695 match self.result_unit {
696 TimeUnit::Second => Ok(ScalarValue::DurationSecond(avg)),
697 TimeUnit::Millisecond => Ok(ScalarValue::DurationMillisecond(avg)),
698 TimeUnit::Microsecond => Ok(ScalarValue::DurationMicrosecond(avg)),
699 TimeUnit::Nanosecond => Ok(ScalarValue::DurationNanosecond(avg)),
700 }
701 }
702
703 fn size(&self) -> usize {
704 size_of_val(self)
705 }
706
707 fn state(&mut self) -> Result<Vec<ScalarValue>> {
708 let duration_value = match self.time_unit {
709 TimeUnit::Second => ScalarValue::DurationSecond(self.sum),
710 TimeUnit::Millisecond => ScalarValue::DurationMillisecond(self.sum),
711 TimeUnit::Microsecond => ScalarValue::DurationMicrosecond(self.sum),
712 TimeUnit::Nanosecond => ScalarValue::DurationNanosecond(self.sum),
713 };
714
715 Ok(vec![ScalarValue::from(self.count), duration_value])
716 }
717
718 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
719 self.count += sum(states[0].as_primitive::<UInt64Type>()).unwrap_or_default();
720
721 let sum_value = match self.time_unit {
722 TimeUnit::Second => sum(states[1].as_primitive::<DurationSecondType>()),
723 TimeUnit::Millisecond => {
724 sum(states[1].as_primitive::<DurationMillisecondType>())
725 }
726 TimeUnit::Microsecond => {
727 sum(states[1].as_primitive::<DurationMicrosecondType>())
728 }
729 TimeUnit::Nanosecond => {
730 sum(states[1].as_primitive::<DurationNanosecondType>())
731 }
732 };
733
734 if let Some(x) = sum_value {
735 let v = self.sum.get_or_insert(0);
736 *v += x;
737 }
738 Ok(())
739 }
740
741 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
742 let array = &values[0];
743 self.count -= (array.len() - array.null_count()) as u64;
744
745 let sum_value = match self.time_unit {
746 TimeUnit::Second => sum(array.as_primitive::<DurationSecondType>()),
747 TimeUnit::Millisecond => sum(array.as_primitive::<DurationMillisecondType>()),
748 TimeUnit::Microsecond => sum(array.as_primitive::<DurationMicrosecondType>()),
749 TimeUnit::Nanosecond => sum(array.as_primitive::<DurationNanosecondType>()),
750 };
751
752 if let Some(x) = sum_value {
753 self.sum = Some(self.sum.unwrap() - x);
754 }
755 Ok(())
756 }
757
758 fn supports_retract_batch(&self) -> bool {
759 true
760 }
761}
762
763#[derive(Debug)]
769struct AvgGroupsAccumulator<T, F>
770where
771 T: ArrowNumericType + Send,
772 F: Fn(T::Native, u64) -> Result<T::Native> + Send + 'static,
773{
774 sum_data_type: DataType,
776
777 return_data_type: DataType,
779
780 counts: Vec<u64>,
782
783 sums: Vec<T::Native>,
785
786 null_state: NullState,
788
789 avg_fn: F,
791}
792
793impl<T, F> AvgGroupsAccumulator<T, F>
794where
795 T: ArrowNumericType + Send,
796 F: Fn(T::Native, u64) -> Result<T::Native> + Send + 'static,
797{
798 pub fn new(sum_data_type: &DataType, return_data_type: &DataType, avg_fn: F) -> Self {
799 debug!(
800 "AvgGroupsAccumulator ({}, sum type: {sum_data_type}) --> {return_data_type}",
801 std::any::type_name::<T>()
802 );
803
804 Self {
805 return_data_type: return_data_type.clone(),
806 sum_data_type: sum_data_type.clone(),
807 counts: vec![],
808 sums: vec![],
809 null_state: NullState::new(),
810 avg_fn,
811 }
812 }
813}
814
815impl<T, F> GroupsAccumulator for AvgGroupsAccumulator<T, F>
816where
817 T: ArrowNumericType + Send,
818 F: Fn(T::Native, u64) -> Result<T::Native> + Send + 'static,
819{
820 fn update_batch(
821 &mut self,
822 values: &[ArrayRef],
823 group_indices: &[usize],
824 opt_filter: Option<&BooleanArray>,
825 total_num_groups: usize,
826 ) -> Result<()> {
827 assert_eq!(values.len(), 1, "single argument to update_batch");
828 let values = values[0].as_primitive::<T>();
829
830 self.counts.resize(total_num_groups, 0);
832 self.sums.resize(total_num_groups, T::default_value());
833 self.null_state.accumulate(
834 group_indices,
835 values,
836 opt_filter,
837 total_num_groups,
838 |group_index, new_value| {
839 let sum = unsafe { self.sums.get_unchecked_mut(group_index) };
841 *sum = sum.add_wrapping(new_value);
842
843 self.counts[group_index] += 1;
844 },
845 );
846
847 Ok(())
848 }
849
850 fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
851 let counts = emit_to.take_needed(&mut self.counts);
852 let sums = emit_to.take_needed(&mut self.sums);
853 let nulls = self.null_state.build(emit_to);
854
855 if let Some(nulls) = &nulls {
856 assert_eq!(nulls.len(), sums.len());
857 }
858 assert_eq!(counts.len(), sums.len());
859
860 let array: PrimitiveArray<T> = if let Some(nulls) = &nulls
863 && nulls.null_count() > 0
864 {
865 let mut builder = PrimitiveBuilder::<T>::with_capacity(nulls.len())
866 .with_data_type(self.return_data_type.clone());
867 let iter = sums.into_iter().zip(counts).zip(nulls.iter());
868
869 for ((sum, count), is_valid) in iter {
870 if is_valid {
871 builder.append_value((self.avg_fn)(sum, count)?)
872 } else {
873 builder.append_null();
874 }
875 }
876 builder.finish()
877 } else {
878 let averages: Vec<T::Native> = sums
879 .into_iter()
880 .zip(counts)
881 .map(|(sum, count)| (self.avg_fn)(sum, count))
882 .collect::<Result<Vec<_>>>()?;
883 PrimitiveArray::new(averages.into(), nulls) .with_data_type(self.return_data_type.clone())
885 };
886
887 Ok(Arc::new(array))
888 }
889
890 fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
892 let nulls = self.null_state.build(emit_to);
893
894 let counts = emit_to.take_needed(&mut self.counts);
895 let counts = UInt64Array::new(counts.into(), nulls.clone()); let sums = emit_to.take_needed(&mut self.sums);
898 let sums = PrimitiveArray::<T>::new(sums.into(), nulls) .with_data_type(self.sum_data_type.clone());
900
901 Ok(vec![
902 Arc::new(counts) as ArrayRef,
903 Arc::new(sums) as ArrayRef,
904 ])
905 }
906
907 fn merge_batch(
908 &mut self,
909 values: &[ArrayRef],
910 group_indices: &[usize],
911 opt_filter: Option<&BooleanArray>,
912 total_num_groups: usize,
913 ) -> Result<()> {
914 assert_eq!(values.len(), 2, "two arguments to merge_batch");
915 let partial_counts = values[0].as_primitive::<UInt64Type>();
917 let partial_sums = values[1].as_primitive::<T>();
918 self.counts.resize(total_num_groups, 0);
920 self.null_state.accumulate(
921 group_indices,
922 partial_counts,
923 opt_filter,
924 total_num_groups,
925 |group_index, partial_count| {
926 let count = unsafe { self.counts.get_unchecked_mut(group_index) };
928 *count += partial_count;
929 },
930 );
931
932 self.sums.resize(total_num_groups, T::default_value());
934 self.null_state.accumulate(
935 group_indices,
936 partial_sums,
937 opt_filter,
938 total_num_groups,
939 |group_index, new_value: <T as ArrowPrimitiveType>::Native| {
940 let sum = unsafe { self.sums.get_unchecked_mut(group_index) };
942 *sum = sum.add_wrapping(new_value);
943 },
944 );
945
946 Ok(())
947 }
948
949 fn convert_to_state(
950 &self,
951 values: &[ArrayRef],
952 opt_filter: Option<&BooleanArray>,
953 ) -> Result<Vec<ArrayRef>> {
954 let sums = values[0]
955 .as_primitive::<T>()
956 .clone()
957 .with_data_type(self.sum_data_type.clone());
958 let counts = UInt64Array::from_value(1, sums.len());
959
960 let nulls = filtered_null_mask(opt_filter, &sums);
961
962 let counts = set_nulls(counts, nulls.clone());
964 let sums = set_nulls(sums, nulls);
965
966 Ok(vec![Arc::new(counts) as ArrayRef, Arc::new(sums)])
967 }
968
969 fn supports_convert_to_state(&self) -> bool {
970 true
971 }
972
973 fn size(&self) -> usize {
974 self.counts.capacity() * size_of::<u64>() + self.sums.capacity() * size_of::<T>()
975 }
976}