1use arrow::array::{
19 builder::PrimitiveBuilder,
20 cast::AsArray,
21 types::{Decimal128Type, Int64Type},
22 Array, ArrayRef, Decimal128Array, Int64Array, PrimitiveArray,
23};
24use arrow::datatypes::{DataType, Field, FieldRef};
25use arrow::{array::BooleanBufferBuilder, buffer::NullBuffer, compute::sum};
26use datafusion::common::{not_impl_err, Result, ScalarValue};
27use datafusion::logical_expr::{
28 Accumulator, AggregateUDFImpl, EmitTo, GroupsAccumulator, ReversedUDAF, Signature,
29};
30use datafusion::physical_expr::expressions::format_state_name;
31use std::{any::Any, sync::Arc};
32
33use crate::utils::is_valid_decimal_precision;
34use arrow::array::ArrowNativeTypeOp;
35use arrow::datatypes::{MAX_DECIMAL128_FOR_EACH_PRECISION, MIN_DECIMAL128_FOR_EACH_PRECISION};
36use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs};
37use datafusion::logical_expr::type_coercion::aggregates::avg_return_type;
38use datafusion::logical_expr::Volatility::Immutable;
39use num::{integer::div_ceil, Integer};
40use DataType::*;
41
/// Aggregate UDF implementation that averages `Decimal128` values.
///
/// The type of the running sum and the type of the final result are not derived here;
/// both are handed to [`AvgDecimal::new`] and stored as-is.
#[derive(Debug, Clone)]
pub struct AvgDecimal {
    /// Signature reported through `AggregateUDFImpl::signature`.
    signature: Signature,
    /// Type of the intermediate sum; also the type of the "sum" state field.
    sum_data_type: DataType,
    /// Type of the final average produced by the accumulators.
    result_data_type: DataType,
}
49
50impl AvgDecimal {
51 pub fn new(result_type: DataType, sum_type: DataType) -> Self {
53 Self {
54 signature: Signature::user_defined(Immutable),
55 result_data_type: result_type,
56 sum_data_type: sum_type,
57 }
58 }
59}
60
61impl AggregateUDFImpl for AvgDecimal {
62 fn as_any(&self) -> &dyn Any {
64 self
65 }
66
67 fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
68 match (&self.sum_data_type, &self.result_data_type) {
69 (Decimal128(sum_precision, sum_scale), Decimal128(target_precision, target_scale)) => {
70 Ok(Box::new(AvgDecimalAccumulator::new(
71 *sum_scale,
72 *sum_precision,
73 *target_precision,
74 *target_scale,
75 )))
76 }
77 _ => not_impl_err!(
78 "AvgDecimalAccumulator for ({} --> {})",
79 self.sum_data_type,
80 self.result_data_type
81 ),
82 }
83 }
84
85 fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
86 Ok(vec![
87 Arc::new(Field::new(
88 format_state_name(self.name(), "sum"),
89 self.sum_data_type.clone(),
90 true,
91 )),
92 Arc::new(Field::new(
93 format_state_name(self.name(), "count"),
94 DataType::Int64,
95 true,
96 )),
97 ])
98 }
99
100 fn name(&self) -> &str {
101 "avg"
102 }
103
104 fn reverse_expr(&self) -> ReversedUDAF {
105 ReversedUDAF::Identical
106 }
107
108 fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
109 true
110 }
111
112 fn create_groups_accumulator(
113 &self,
114 _args: AccumulatorArgs,
115 ) -> Result<Box<dyn GroupsAccumulator>> {
116 match (&self.sum_data_type, &self.result_data_type) {
118 (Decimal128(sum_precision, sum_scale), Decimal128(target_precision, target_scale)) => {
119 Ok(Box::new(AvgDecimalGroupsAccumulator::new(
120 &self.result_data_type,
121 &self.sum_data_type,
122 *target_precision,
123 *target_scale,
124 *sum_precision,
125 *sum_scale,
126 )))
127 }
128 _ => not_impl_err!(
129 "AvgDecimalGroupsAccumulator for ({} --> {})",
130 self.sum_data_type,
131 self.result_data_type
132 ),
133 }
134 }
135
136 fn default_value(&self, _data_type: &DataType) -> Result<ScalarValue> {
137 match &self.result_data_type {
138 Decimal128(target_precision, target_scale) => {
139 Ok(make_decimal128(None, *target_precision, *target_scale))
140 }
141 _ => not_impl_err!(
142 "The result_data_type of AvgDecimal should be Decimal128 but got{}",
143 self.result_data_type
144 ),
145 }
146 }
147
148 fn signature(&self) -> &Signature {
149 &self.signature
150 }
151
152 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
153 avg_return_type(self.name(), &arg_types[0])
154 }
155}
156
/// Accumulator state for averaging `Decimal128` values without grouping.
#[derive(Debug)]
struct AvgDecimalAccumulator {
    /// Running sum; `None` until a value has been added.
    sum: Option<i128>,
    /// Number of values that have been added to `sum`.
    count: i64,
    /// Stays `true` while every batch passed to `update_batch` contained only nulls.
    is_empty: bool,
    /// Set to `false` by `update_single` when adding a value overflows `i128`, exceeds
    /// `sum_precision`, or overflows `count`.
    is_not_null: bool,
    /// Scale of the intermediate sum.
    sum_scale: i8,
    /// Precision the intermediate sum must fit in.
    sum_precision: u8,
    /// Precision of the final average.
    target_precision: u8,
    /// Scale of the final average.
    target_scale: i8,
}
169
170impl AvgDecimalAccumulator {
171 pub fn new(sum_scale: i8, sum_precision: u8, target_precision: u8, target_scale: i8) -> Self {
172 Self {
173 sum: None,
174 count: 0,
175 is_empty: true,
176 is_not_null: true,
177 sum_scale,
178 sum_precision,
179 target_precision,
180 target_scale,
181 }
182 }
183
184 fn update_single(&mut self, values: &Decimal128Array, idx: usize) {
185 let v = unsafe { values.value_unchecked(idx) };
186 let (new_sum, is_overflow) = match self.sum {
187 Some(sum) => sum.overflowing_add(v),
188 None => (v, false),
189 };
190
191 if is_overflow || !is_valid_decimal_precision(new_sum, self.sum_precision) {
192 self.is_not_null = false;
194 return;
195 }
196
197 self.sum = Some(new_sum);
198
199 if let Some(new_count) = self.count.checked_add(1) {
200 self.count = new_count;
201 } else {
202 self.is_not_null = false;
203 return;
204 }
205
206 self.is_not_null = true;
207 }
208}
209
/// Wraps an optional `i128` into a `ScalarValue::Decimal128` with the given precision and
/// scale; `None` produces a NULL decimal.
fn make_decimal128(value: Option<i128>, precision: u8, scale: i8) -> ScalarValue {
    ScalarValue::Decimal128(value, precision, scale)
}
213
214impl Accumulator for AvgDecimalAccumulator {
215 fn state(&mut self) -> Result<Vec<ScalarValue>> {
216 Ok(vec![
217 ScalarValue::Decimal128(self.sum, self.sum_precision, self.sum_scale),
218 ScalarValue::from(self.count),
219 ])
220 }
221
222 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
223 if !self.is_empty && !self.is_not_null {
224 return Ok(());
227 }
228
229 let values = &values[0];
230 let data = values.as_primitive::<Decimal128Type>();
231
232 self.is_empty = self.is_empty && values.len() == values.null_count();
233
234 if values.null_count() == 0 {
235 for i in 0..data.len() {
236 self.update_single(data, i);
237 }
238 } else {
239 for i in 0..data.len() {
240 if data.is_null(i) {
241 continue;
242 }
243 self.update_single(data, i);
244 }
245 }
246 Ok(())
247 }
248
249 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
250 self.count += sum(states[1].as_primitive::<Int64Type>()).unwrap_or_default();
252
253 if let Some(x) = sum(states[0].as_primitive::<Decimal128Type>()) {
255 let v = self.sum.get_or_insert(0);
256 let (result, overflowed) = v.overflowing_add(x);
257 if overflowed {
258 self.sum = None;
260 } else {
261 *v = result;
262 }
263 }
264 Ok(())
265 }
266
267 fn evaluate(&mut self) -> Result<ScalarValue> {
268 let scaler = 10_i128.pow(self.target_scale.saturating_sub(self.sum_scale) as u32);
269 let target_min = MIN_DECIMAL128_FOR_EACH_PRECISION[self.target_precision as usize];
270 let target_max = MAX_DECIMAL128_FOR_EACH_PRECISION[self.target_precision as usize];
271
272 let result = self
273 .sum
274 .map(|v| avg(v, self.count as i128, target_min, target_max, scaler));
275
276 match result {
277 Some(value) => Ok(make_decimal128(
278 value,
279 self.target_precision,
280 self.target_scale,
281 )),
282 _ => Ok(make_decimal128(
283 None,
284 self.target_precision,
285 self.target_scale,
286 )),
287 }
288 }
289
290 fn size(&self) -> usize {
291 std::mem::size_of_val(self)
292 }
293}
294
/// Grouped accumulator for averaging `Decimal128` values; all per-group state is indexed by
/// group index.
#[derive(Debug)]
struct AvgDecimalGroupsAccumulator {
    /// Per-group flag, `true` by default; cleared by `update_single` when adding a value
    /// overflows. `state` uses it as the validity bitmap of the emitted arrays.
    is_not_null: BooleanBufferBuilder,

    /// Per-group flag, `true` by default; cleared by `update_single` when the group receives
    /// a value.
    is_empty: BooleanBufferBuilder,

    /// Data type attached to the array returned by `evaluate`.
    return_data_type: DataType,
    /// Precision of the final average.
    target_precision: u8,
    /// Scale of the final average.
    target_scale: i8,

    /// Per-group number of values.
    counts: Vec<i64>,

    /// Per-group running sum.
    sums: Vec<i128>,

    /// Data type attached to the sum array returned by `state`.
    sum_data_type: DataType,
    /// Precision the running sums must fit in.
    sum_precision: u8,
    /// Scale of the running sums.
    sum_scale: i8,
}
320
321impl AvgDecimalGroupsAccumulator {
322 pub fn new(
323 return_data_type: &DataType,
324 sum_data_type: &DataType,
325 target_precision: u8,
326 target_scale: i8,
327 sum_precision: u8,
328 sum_scale: i8,
329 ) -> Self {
330 Self {
331 is_not_null: BooleanBufferBuilder::new(0),
332 is_empty: BooleanBufferBuilder::new(0),
333 return_data_type: return_data_type.clone(),
334 target_precision,
335 target_scale,
336 sum_data_type: sum_data_type.clone(),
337 sum_precision,
338 sum_scale,
339 counts: vec![],
340 sums: vec![],
341 }
342 }
343
344 fn is_overflow(&self, index: usize) -> bool {
345 !self.is_empty.get_bit(index) && !self.is_not_null.get_bit(index)
346 }
347
348 fn update_single(&mut self, group_index: usize, value: i128) {
349 if self.is_overflow(group_index) {
350 return;
353 }
354
355 self.is_empty.set_bit(group_index, false);
356 let (new_sum, is_overflow) = self.sums[group_index].overflowing_add(value);
357 self.counts[group_index] += 1;
358
359 if is_overflow || !is_valid_decimal_precision(new_sum, self.sum_precision) {
360 self.is_not_null.set_bit(group_index, false);
362 return;
363 }
364
365 self.sums[group_index] = new_sum;
366 self.is_not_null.set_bit(group_index, true)
367 }
368}
369
370fn ensure_bit_capacity(builder: &mut BooleanBufferBuilder, capacity: usize) {
371 if builder.len() < capacity {
372 let additional = capacity - builder.len();
373 builder.append_n(additional, true);
374 }
375}
376
377impl GroupsAccumulator for AvgDecimalGroupsAccumulator {
378 fn update_batch(
379 &mut self,
380 values: &[ArrayRef],
381 group_indices: &[usize],
382 _opt_filter: Option<&arrow::array::BooleanArray>,
383 total_num_groups: usize,
384 ) -> Result<()> {
385 assert_eq!(values.len(), 1, "single argument to update_batch");
386 let values = values[0].as_primitive::<Decimal128Type>();
387 let data = values.values();
388
389 self.counts.resize(total_num_groups, 0);
391 self.sums.resize(total_num_groups, 0);
392 ensure_bit_capacity(&mut self.is_empty, total_num_groups);
393 ensure_bit_capacity(&mut self.is_not_null, total_num_groups);
394
395 let iter = group_indices.iter().zip(data.iter());
396 if values.null_count() == 0 {
397 for (&group_index, &value) in iter {
398 self.update_single(group_index, value);
399 }
400 } else {
401 for (idx, (&group_index, &value)) in iter.enumerate() {
402 if values.is_null(idx) {
403 continue;
404 }
405 self.update_single(group_index, value);
406 }
407 }
408 Ok(())
409 }
410
411 fn merge_batch(
412 &mut self,
413 values: &[ArrayRef],
414 group_indices: &[usize],
415 _opt_filter: Option<&arrow::array::BooleanArray>,
416 total_num_groups: usize,
417 ) -> Result<()> {
418 assert_eq!(values.len(), 2, "two arguments to merge_batch");
419 let partial_sums = values[0].as_primitive::<Decimal128Type>();
421 let partial_counts = values[1].as_primitive::<Int64Type>();
422 self.counts.resize(total_num_groups, 0);
424 let iter1 = group_indices.iter().zip(partial_counts.values().iter());
425 for (&group_index, &partial_count) in iter1 {
426 self.counts[group_index] += partial_count;
427 }
428
429 self.sums.resize(total_num_groups, 0);
431 let iter2 = group_indices.iter().zip(partial_sums.values().iter());
432 for (&group_index, &new_value) in iter2 {
433 let sum = &mut self.sums[group_index];
434 *sum = sum.add_wrapping(new_value);
435 }
436
437 Ok(())
438 }
439
440 fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
441 let counts = emit_to.take_needed(&mut self.counts);
442 let sums = emit_to.take_needed(&mut self.sums);
443
444 let mut builder = PrimitiveBuilder::<Decimal128Type>::with_capacity(sums.len())
445 .with_data_type(self.return_data_type.clone());
446 let iter = sums.into_iter().zip(counts);
447
448 let scaler = 10_i128.pow(self.target_scale.saturating_sub(self.sum_scale) as u32);
449 let target_min = MIN_DECIMAL128_FOR_EACH_PRECISION[self.target_precision as usize];
450 let target_max = MAX_DECIMAL128_FOR_EACH_PRECISION[self.target_precision as usize];
451
452 for (sum, count) in iter {
453 if count != 0 {
454 match avg(sum, count as i128, target_min, target_max, scaler) {
455 Some(value) => {
456 builder.append_value(value);
457 }
458 _ => {
459 builder.append_null();
460 }
461 }
462 } else {
463 builder.append_null();
464 }
465 }
466 let array: PrimitiveArray<Decimal128Type> = builder.finish();
467
468 Ok(Arc::new(array))
469 }
470
471 fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
473 let nulls = self.is_not_null.finish();
474 let nulls = Some(NullBuffer::new(nulls));
475
476 let counts = emit_to.take_needed(&mut self.counts);
477 let counts = Int64Array::new(counts.into(), nulls.clone());
478
479 let sums = emit_to.take_needed(&mut self.sums);
480 let sums =
481 Decimal128Array::new(sums.into(), nulls).with_data_type(self.sum_data_type.clone());
482
483 Ok(vec![
484 Arc::new(sums) as ArrayRef,
485 Arc::new(counts) as ArrayRef,
486 ])
487 }
488
489 fn size(&self) -> usize {
490 self.counts.capacity() * std::mem::size_of::<i64>()
491 + self.sums.capacity() * std::mem::size_of::<i128>()
492 }
493}
494
495#[inline(always)]
504fn avg(sum: i128, count: i128, target_min: i128, target_max: i128, scaler: i128) -> Option<i128> {
505 if let Some(value) = sum.checked_mul(scaler) {
506 let (div, rem) = value.div_rem(&count);
508 let half = div_ceil(count, 2);
509 let half_neg = half.neg_wrapping();
510 let new_value = match value >= 0 {
511 true if rem >= half => div.add_wrapping(1),
512 false if rem <= half_neg => div.sub_wrapping(1),
513 _ => div,
514 };
515 if new_value >= target_min && new_value <= target_max {
516 Some(new_value)
517 } else {
518 None
519 }
520 } else {
521 None
522 }
523}