1use arrow::{array::BooleanBufferBuilder, buffer::NullBuffer, compute::sum};
19use arrow_array::{
20 builder::PrimitiveBuilder,
21 cast::AsArray,
22 types::{Decimal128Type, Int64Type},
23 Array, ArrayRef, Decimal128Array, Int64Array, PrimitiveArray,
24};
25use arrow_schema::{DataType, Field};
26use datafusion::logical_expr::{Accumulator, EmitTo, GroupsAccumulator, Signature};
27use datafusion_common::{not_impl_err, Result, ScalarValue};
28use datafusion_physical_expr::expressions::format_state_name;
29use std::{any::Any, sync::Arc};
30
31use crate::utils::is_valid_decimal_precision;
32use arrow_array::ArrowNativeTypeOp;
33use arrow_data::decimal::{MAX_DECIMAL128_FOR_EACH_PRECISION, MIN_DECIMAL128_FOR_EACH_PRECISION};
34use datafusion::logical_expr::Volatility::Immutable;
35use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
36use datafusion_expr::type_coercion::aggregates::avg_return_type;
37use datafusion_expr::{AggregateUDFImpl, ReversedUDAF};
38use num::{integer::div_ceil, Integer};
39use DataType::*;
40
41#[derive(Debug, Clone)]
43pub struct AvgDecimal {
44 signature: Signature,
45 sum_data_type: DataType,
46 result_data_type: DataType,
47}
48
49impl AvgDecimal {
50 pub fn new(result_type: DataType, sum_type: DataType) -> Self {
52 Self {
53 signature: Signature::user_defined(Immutable),
54 result_data_type: result_type,
55 sum_data_type: sum_type,
56 }
57 }
58}
59
60impl AggregateUDFImpl for AvgDecimal {
61 fn as_any(&self) -> &dyn Any {
63 self
64 }
65
66 fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
67 match (&self.sum_data_type, &self.result_data_type) {
68 (Decimal128(sum_precision, sum_scale), Decimal128(target_precision, target_scale)) => {
69 Ok(Box::new(AvgDecimalAccumulator::new(
70 *sum_scale,
71 *sum_precision,
72 *target_precision,
73 *target_scale,
74 )))
75 }
76 _ => not_impl_err!(
77 "AvgDecimalAccumulator for ({} --> {})",
78 self.sum_data_type,
79 self.result_data_type
80 ),
81 }
82 }
83
84 fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
85 Ok(vec![
86 Field::new(
87 format_state_name(self.name(), "sum"),
88 self.sum_data_type.clone(),
89 true,
90 ),
91 Field::new(
92 format_state_name(self.name(), "count"),
93 DataType::Int64,
94 true,
95 ),
96 ])
97 }
98
99 fn name(&self) -> &str {
100 "avg"
101 }
102
103 fn reverse_expr(&self) -> ReversedUDAF {
104 ReversedUDAF::Identical
105 }
106
107 fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
108 true
109 }
110
111 fn create_groups_accumulator(
112 &self,
113 _args: AccumulatorArgs,
114 ) -> Result<Box<dyn GroupsAccumulator>> {
115 match (&self.sum_data_type, &self.result_data_type) {
117 (Decimal128(sum_precision, sum_scale), Decimal128(target_precision, target_scale)) => {
118 Ok(Box::new(AvgDecimalGroupsAccumulator::new(
119 &self.result_data_type,
120 &self.sum_data_type,
121 *target_precision,
122 *target_scale,
123 *sum_precision,
124 *sum_scale,
125 )))
126 }
127 _ => not_impl_err!(
128 "AvgDecimalGroupsAccumulator for ({} --> {})",
129 self.sum_data_type,
130 self.result_data_type
131 ),
132 }
133 }
134
135 fn default_value(&self, _data_type: &DataType) -> Result<ScalarValue> {
136 match &self.result_data_type {
137 Decimal128(target_precision, target_scale) => {
138 Ok(make_decimal128(None, *target_precision, *target_scale))
139 }
140 _ => not_impl_err!(
141 "The result_data_type of AvgDecimal should be Decimal128 but got{}",
142 self.result_data_type
143 ),
144 }
145 }
146
147 fn signature(&self) -> &Signature {
148 &self.signature
149 }
150
151 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
152 avg_return_type(self.name(), &arg_types[0])
153 }
154}
155
/// Single-group accumulator behind [`AvgDecimal`]: a running decimal sum plus a row count.
#[derive(Debug)]
struct AvgDecimalAccumulator {
    /// Running sum; `None` until a value has been accumulated.
    sum: Option<i128>,
    /// Number of values folded into `sum`.
    count: i64,
    /// `true` while every value seen by `update_batch` has been NULL (or none was seen).
    is_empty: bool,
    /// Cleared when adding a value overflowed the sum or the count.
    is_not_null: bool,
    /// Scale of the running sum.
    sum_scale: i8,
    /// Precision the running sum must stay within.
    sum_precision: u8,
    /// Precision of the evaluated average.
    target_precision: u8,
    /// Scale of the evaluated average.
    target_scale: i8,
}
168
169impl AvgDecimalAccumulator {
170 pub fn new(sum_scale: i8, sum_precision: u8, target_precision: u8, target_scale: i8) -> Self {
171 Self {
172 sum: None,
173 count: 0,
174 is_empty: true,
175 is_not_null: true,
176 sum_scale,
177 sum_precision,
178 target_precision,
179 target_scale,
180 }
181 }
182
183 fn update_single(&mut self, values: &Decimal128Array, idx: usize) {
184 let v = unsafe { values.value_unchecked(idx) };
185 let (new_sum, is_overflow) = match self.sum {
186 Some(sum) => sum.overflowing_add(v),
187 None => (v, false),
188 };
189
190 if is_overflow || !is_valid_decimal_precision(new_sum, self.sum_precision) {
191 self.is_not_null = false;
193 return;
194 }
195
196 self.sum = Some(new_sum);
197
198 if let Some(new_count) = self.count.checked_add(1) {
199 self.count = new_count;
200 } else {
201 self.is_not_null = false;
202 return;
203 }
204
205 self.is_not_null = true;
206 }
207}
208
209fn make_decimal128(value: Option<i128>, precision: u8, scale: i8) -> ScalarValue {
210 ScalarValue::Decimal128(value, precision, scale)
211}
212
213impl Accumulator for AvgDecimalAccumulator {
214 fn state(&mut self) -> Result<Vec<ScalarValue>> {
215 Ok(vec![
216 ScalarValue::Decimal128(self.sum, self.sum_precision, self.sum_scale),
217 ScalarValue::from(self.count),
218 ])
219 }
220
221 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
222 if !self.is_empty && !self.is_not_null {
223 return Ok(());
226 }
227
228 let values = &values[0];
229 let data = values.as_primitive::<Decimal128Type>();
230
231 self.is_empty = self.is_empty && values.len() == values.null_count();
232
233 if values.null_count() == 0 {
234 for i in 0..data.len() {
235 self.update_single(data, i);
236 }
237 } else {
238 for i in 0..data.len() {
239 if data.is_null(i) {
240 continue;
241 }
242 self.update_single(data, i);
243 }
244 }
245 Ok(())
246 }
247
248 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
249 self.count += sum(states[1].as_primitive::<Int64Type>()).unwrap_or_default();
251
252 if let Some(x) = sum(states[0].as_primitive::<Decimal128Type>()) {
254 let v = self.sum.get_or_insert(0);
255 let (result, overflowed) = v.overflowing_add(x);
256 if overflowed {
257 self.sum = None;
259 } else {
260 *v = result;
261 }
262 }
263 Ok(())
264 }
265
266 fn evaluate(&mut self) -> Result<ScalarValue> {
267 let scaler = 10_i128.pow(self.target_scale.saturating_sub(self.sum_scale) as u32);
268 let target_min = MIN_DECIMAL128_FOR_EACH_PRECISION[self.target_precision as usize];
269 let target_max = MAX_DECIMAL128_FOR_EACH_PRECISION[self.target_precision as usize];
270
271 let result = self
272 .sum
273 .map(|v| avg(v, self.count as i128, target_min, target_max, scaler));
274
275 match result {
276 Some(value) => Ok(make_decimal128(
277 value,
278 self.target_precision,
279 self.target_scale,
280 )),
281 _ => Ok(make_decimal128(
282 None,
283 self.target_precision,
284 self.target_scale,
285 )),
286 }
287 }
288
289 fn size(&self) -> usize {
290 std::mem::size_of_val(self)
291 }
292}
293
294#[derive(Debug)]
295struct AvgDecimalGroupsAccumulator {
296 is_not_null: BooleanBufferBuilder,
298
299 is_empty: BooleanBufferBuilder,
301
302 return_data_type: DataType,
304 target_precision: u8,
305 target_scale: i8,
306
307 counts: Vec<i64>,
309
310 sums: Vec<i128>,
312
313 sum_data_type: DataType,
315 sum_precision: u8,
317 sum_scale: i8,
318}
319
320impl AvgDecimalGroupsAccumulator {
321 pub fn new(
322 return_data_type: &DataType,
323 sum_data_type: &DataType,
324 target_precision: u8,
325 target_scale: i8,
326 sum_precision: u8,
327 sum_scale: i8,
328 ) -> Self {
329 Self {
330 is_not_null: BooleanBufferBuilder::new(0),
331 is_empty: BooleanBufferBuilder::new(0),
332 return_data_type: return_data_type.clone(),
333 target_precision,
334 target_scale,
335 sum_data_type: sum_data_type.clone(),
336 sum_precision,
337 sum_scale,
338 counts: vec![],
339 sums: vec![],
340 }
341 }
342
343 fn is_overflow(&self, index: usize) -> bool {
344 !self.is_empty.get_bit(index) && !self.is_not_null.get_bit(index)
345 }
346
347 fn update_single(&mut self, group_index: usize, value: i128) {
348 if self.is_overflow(group_index) {
349 return;
352 }
353
354 self.is_empty.set_bit(group_index, false);
355 let (new_sum, is_overflow) = self.sums[group_index].overflowing_add(value);
356 self.counts[group_index] += 1;
357
358 if is_overflow || !is_valid_decimal_precision(new_sum, self.sum_precision) {
359 self.is_not_null.set_bit(group_index, false);
361 return;
362 }
363
364 self.sums[group_index] = new_sum;
365 self.is_not_null.set_bit(group_index, true)
366 }
367}
368
369fn ensure_bit_capacity(builder: &mut BooleanBufferBuilder, capacity: usize) {
370 if builder.len() < capacity {
371 let additional = capacity - builder.len();
372 builder.append_n(additional, true);
373 }
374}
375
376impl GroupsAccumulator for AvgDecimalGroupsAccumulator {
377 fn update_batch(
378 &mut self,
379 values: &[ArrayRef],
380 group_indices: &[usize],
381 _opt_filter: Option<&arrow_array::BooleanArray>,
382 total_num_groups: usize,
383 ) -> Result<()> {
384 assert_eq!(values.len(), 1, "single argument to update_batch");
385 let values = values[0].as_primitive::<Decimal128Type>();
386 let data = values.values();
387
388 self.counts.resize(total_num_groups, 0);
390 self.sums.resize(total_num_groups, 0);
391 ensure_bit_capacity(&mut self.is_empty, total_num_groups);
392 ensure_bit_capacity(&mut self.is_not_null, total_num_groups);
393
394 let iter = group_indices.iter().zip(data.iter());
395 if values.null_count() == 0 {
396 for (&group_index, &value) in iter {
397 self.update_single(group_index, value);
398 }
399 } else {
400 for (idx, (&group_index, &value)) in iter.enumerate() {
401 if values.is_null(idx) {
402 continue;
403 }
404 self.update_single(group_index, value);
405 }
406 }
407 Ok(())
408 }
409
410 fn merge_batch(
411 &mut self,
412 values: &[ArrayRef],
413 group_indices: &[usize],
414 _opt_filter: Option<&arrow_array::BooleanArray>,
415 total_num_groups: usize,
416 ) -> Result<()> {
417 assert_eq!(values.len(), 2, "two arguments to merge_batch");
418 let partial_sums = values[0].as_primitive::<Decimal128Type>();
420 let partial_counts = values[1].as_primitive::<Int64Type>();
421 self.counts.resize(total_num_groups, 0);
423 let iter1 = group_indices.iter().zip(partial_counts.values().iter());
424 for (&group_index, &partial_count) in iter1 {
425 self.counts[group_index] += partial_count;
426 }
427
428 self.sums.resize(total_num_groups, 0);
430 let iter2 = group_indices.iter().zip(partial_sums.values().iter());
431 for (&group_index, &new_value) in iter2 {
432 let sum = &mut self.sums[group_index];
433 *sum = sum.add_wrapping(new_value);
434 }
435
436 Ok(())
437 }
438
439 fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
440 let counts = emit_to.take_needed(&mut self.counts);
441 let sums = emit_to.take_needed(&mut self.sums);
442
443 let mut builder = PrimitiveBuilder::<Decimal128Type>::with_capacity(sums.len())
444 .with_data_type(self.return_data_type.clone());
445 let iter = sums.into_iter().zip(counts);
446
447 let scaler = 10_i128.pow(self.target_scale.saturating_sub(self.sum_scale) as u32);
448 let target_min = MIN_DECIMAL128_FOR_EACH_PRECISION[self.target_precision as usize];
449 let target_max = MAX_DECIMAL128_FOR_EACH_PRECISION[self.target_precision as usize];
450
451 for (sum, count) in iter {
452 if count != 0 {
453 match avg(sum, count as i128, target_min, target_max, scaler) {
454 Some(value) => {
455 builder.append_value(value);
456 }
457 _ => {
458 builder.append_null();
459 }
460 }
461 } else {
462 builder.append_null();
463 }
464 }
465 let array: PrimitiveArray<Decimal128Type> = builder.finish();
466
467 Ok(Arc::new(array))
468 }
469
470 fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
472 let nulls = self.is_not_null.finish();
473 let nulls = Some(NullBuffer::new(nulls));
474
475 let counts = emit_to.take_needed(&mut self.counts);
476 let counts = Int64Array::new(counts.into(), nulls.clone());
477
478 let sums = emit_to.take_needed(&mut self.sums);
479 let sums =
480 Decimal128Array::new(sums.into(), nulls).with_data_type(self.sum_data_type.clone());
481
482 Ok(vec![
483 Arc::new(sums) as ArrayRef,
484 Arc::new(counts) as ArrayRef,
485 ])
486 }
487
488 fn size(&self) -> usize {
489 self.counts.capacity() * std::mem::size_of::<i64>()
490 + self.sums.capacity() * std::mem::size_of::<i128>()
491 }
492}
493
/// Computes `sum * scaler / count`, rounded half away from zero, as an unscaled decimal.
///
/// * `sum` - unscaled total at the scale of the sum type
/// * `count` - number of values in `sum`; a plain integer, not a decimal
/// * `target_min` / `target_max` - inclusive range the result must fall into
/// * `scaler` - factor that lifts `sum` to the scale of the result
///
/// Returns `None` when `count` is not positive (instead of panicking on a division by
/// zero), when `sum * scaler` overflows `i128`, or when the rounded result lies outside
/// `target_min..=target_max`.
#[inline(always)]
fn avg(sum: i128, count: i128, target_min: i128, target_max: i128, scaler: i128) -> Option<i128> {
    if count <= 0 {
        return None;
    }

    let value = sum.checked_mul(scaler)?;

    // Truncating division; `rem` carries the sign of `value`.
    let div = value / count;
    let rem = value % count;
    // ceil(count / 2) for a positive count: the smallest remainder that rounds away.
    let half = count / 2 + count % 2;

    let rounded = if value >= 0 && rem >= half {
        div.wrapping_add(1)
    } else if value < 0 && rem <= -half {
        div.wrapping_sub(1)
    } else {
        div
    };

    if rounded >= target_min && rounded <= target_max {
        Some(rounded)
    } else {
        None
    }
}