datafusion_comet_spark_expr/agg_funcs/
avg_decimal.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::{array::BooleanBufferBuilder, buffer::NullBuffer, compute::sum};
19use arrow_array::{
20    builder::PrimitiveBuilder,
21    cast::AsArray,
22    types::{Decimal128Type, Int64Type},
23    Array, ArrayRef, Decimal128Array, Int64Array, PrimitiveArray,
24};
25use arrow_schema::{DataType, Field};
26use datafusion::logical_expr::{Accumulator, EmitTo, GroupsAccumulator, Signature};
27use datafusion_common::{not_impl_err, Result, ScalarValue};
28use datafusion_physical_expr::expressions::format_state_name;
29use std::{any::Any, sync::Arc};
30
31use crate::utils::is_valid_decimal_precision;
32use arrow_array::ArrowNativeTypeOp;
33use arrow_data::decimal::{MAX_DECIMAL128_FOR_EACH_PRECISION, MIN_DECIMAL128_FOR_EACH_PRECISION};
34use datafusion::logical_expr::Volatility::Immutable;
35use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
36use datafusion_expr::type_coercion::aggregates::avg_return_type;
37use datafusion_expr::{AggregateUDFImpl, ReversedUDAF};
38use num::{integer::div_ceil, Integer};
39use DataType::*;
40
/// AVG aggregate expression over `Decimal128` input.
///
/// The intermediate sum type and the final result type are supplied by the
/// caller through [`AvgDecimal::new`] rather than derived here.
#[derive(Debug, Clone)]
pub struct AvgDecimal {
    /// UDAF signature; built with `Signature::user_defined(Immutable)` in `new`.
    signature: Signature,
    /// Type of the intermediate `sum` state field; the accumulators only support `Decimal128`.
    sum_data_type: DataType,
    /// Type of the final average; the accumulators only support `Decimal128`.
    result_data_type: DataType,
}
48
49impl AvgDecimal {
50    /// Create a new AVG aggregate function
51    pub fn new(result_type: DataType, sum_type: DataType) -> Self {
52        Self {
53            signature: Signature::user_defined(Immutable),
54            result_data_type: result_type,
55            sum_data_type: sum_type,
56        }
57    }
58}
59
60impl AggregateUDFImpl for AvgDecimal {
61    /// Return a reference to Any that can be used for downcasting
62    fn as_any(&self) -> &dyn Any {
63        self
64    }
65
66    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
67        match (&self.sum_data_type, &self.result_data_type) {
68            (Decimal128(sum_precision, sum_scale), Decimal128(target_precision, target_scale)) => {
69                Ok(Box::new(AvgDecimalAccumulator::new(
70                    *sum_scale,
71                    *sum_precision,
72                    *target_precision,
73                    *target_scale,
74                )))
75            }
76            _ => not_impl_err!(
77                "AvgDecimalAccumulator for ({} --> {})",
78                self.sum_data_type,
79                self.result_data_type
80            ),
81        }
82    }
83
84    fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
85        Ok(vec![
86            Field::new(
87                format_state_name(self.name(), "sum"),
88                self.sum_data_type.clone(),
89                true,
90            ),
91            Field::new(
92                format_state_name(self.name(), "count"),
93                DataType::Int64,
94                true,
95            ),
96        ])
97    }
98
99    fn name(&self) -> &str {
100        "avg"
101    }
102
103    fn reverse_expr(&self) -> ReversedUDAF {
104        ReversedUDAF::Identical
105    }
106
107    fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
108        true
109    }
110
111    fn create_groups_accumulator(
112        &self,
113        _args: AccumulatorArgs,
114    ) -> Result<Box<dyn GroupsAccumulator>> {
115        // instantiate specialized accumulator based for the type
116        match (&self.sum_data_type, &self.result_data_type) {
117            (Decimal128(sum_precision, sum_scale), Decimal128(target_precision, target_scale)) => {
118                Ok(Box::new(AvgDecimalGroupsAccumulator::new(
119                    &self.result_data_type,
120                    &self.sum_data_type,
121                    *target_precision,
122                    *target_scale,
123                    *sum_precision,
124                    *sum_scale,
125                )))
126            }
127            _ => not_impl_err!(
128                "AvgDecimalGroupsAccumulator for ({} --> {})",
129                self.sum_data_type,
130                self.result_data_type
131            ),
132        }
133    }
134
135    fn default_value(&self, _data_type: &DataType) -> Result<ScalarValue> {
136        match &self.result_data_type {
137            Decimal128(target_precision, target_scale) => {
138                Ok(make_decimal128(None, *target_precision, *target_scale))
139            }
140            _ => not_impl_err!(
141                "The result_data_type of AvgDecimal should be Decimal128 but got{}",
142                self.result_data_type
143            ),
144        }
145    }
146
147    fn signature(&self) -> &Signature {
148        &self.signature
149    }
150
151    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
152        avg_return_type(self.name(), &arg_types[0])
153    }
154}
155
/// An accumulator to compute the average for decimals (single group).
#[derive(Debug)]
struct AvgDecimalAccumulator {
    /// Running sum as an unscaled value at `sum_scale`; `None` until a value is
    /// added or a partial sum is merged.
    sum: Option<i128>,
    /// Number of values folded into `sum` (merged partial counts included).
    count: i64,
    /// True while every input row passed to `update_batch` has been null.
    is_empty: bool,
    /// Cleared when adding to the sum (or incrementing the count) overflows in
    /// `update_single`; intended to mark the buffer as null.
    is_not_null: bool,
    /// Scale of the running sum.
    sum_scale: i8,
    /// Precision the running sum must fit in.
    sum_precision: u8,
    /// Precision of the result decimal.
    target_precision: u8,
    /// Scale of the result decimal.
    target_scale: i8,
}
168
169impl AvgDecimalAccumulator {
170    pub fn new(sum_scale: i8, sum_precision: u8, target_precision: u8, target_scale: i8) -> Self {
171        Self {
172            sum: None,
173            count: 0,
174            is_empty: true,
175            is_not_null: true,
176            sum_scale,
177            sum_precision,
178            target_precision,
179            target_scale,
180        }
181    }
182
183    fn update_single(&mut self, values: &Decimal128Array, idx: usize) {
184        let v = unsafe { values.value_unchecked(idx) };
185        let (new_sum, is_overflow) = match self.sum {
186            Some(sum) => sum.overflowing_add(v),
187            None => (v, false),
188        };
189
190        if is_overflow || !is_valid_decimal_precision(new_sum, self.sum_precision) {
191            // Overflow: set buffer accumulator to null
192            self.is_not_null = false;
193            return;
194        }
195
196        self.sum = Some(new_sum);
197
198        if let Some(new_count) = self.count.checked_add(1) {
199            self.count = new_count;
200        } else {
201            self.is_not_null = false;
202            return;
203        }
204
205        self.is_not_null = true;
206    }
207}
208
209fn make_decimal128(value: Option<i128>, precision: u8, scale: i8) -> ScalarValue {
210    ScalarValue::Decimal128(value, precision, scale)
211}
212
213impl Accumulator for AvgDecimalAccumulator {
214    fn state(&mut self) -> Result<Vec<ScalarValue>> {
215        Ok(vec![
216            ScalarValue::Decimal128(self.sum, self.sum_precision, self.sum_scale),
217            ScalarValue::from(self.count),
218        ])
219    }
220
221    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
222        if !self.is_empty && !self.is_not_null {
223            // This means there's a overflow in decimal, so we will just skip the rest
224            // of the computation
225            return Ok(());
226        }
227
228        let values = &values[0];
229        let data = values.as_primitive::<Decimal128Type>();
230
231        self.is_empty = self.is_empty && values.len() == values.null_count();
232
233        if values.null_count() == 0 {
234            for i in 0..data.len() {
235                self.update_single(data, i);
236            }
237        } else {
238            for i in 0..data.len() {
239                if data.is_null(i) {
240                    continue;
241                }
242                self.update_single(data, i);
243            }
244        }
245        Ok(())
246    }
247
248    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
249        // counts are summed
250        self.count += sum(states[1].as_primitive::<Int64Type>()).unwrap_or_default();
251
252        // sums are summed
253        if let Some(x) = sum(states[0].as_primitive::<Decimal128Type>()) {
254            let v = self.sum.get_or_insert(0);
255            let (result, overflowed) = v.overflowing_add(x);
256            if overflowed {
257                // Set to None if overflow happens
258                self.sum = None;
259            } else {
260                *v = result;
261            }
262        }
263        Ok(())
264    }
265
266    fn evaluate(&mut self) -> Result<ScalarValue> {
267        let scaler = 10_i128.pow(self.target_scale.saturating_sub(self.sum_scale) as u32);
268        let target_min = MIN_DECIMAL128_FOR_EACH_PRECISION[self.target_precision as usize];
269        let target_max = MAX_DECIMAL128_FOR_EACH_PRECISION[self.target_precision as usize];
270
271        let result = self
272            .sum
273            .map(|v| avg(v, self.count as i128, target_min, target_max, scaler));
274
275        match result {
276            Some(value) => Ok(make_decimal128(
277                value,
278                self.target_precision,
279                self.target_scale,
280            )),
281            _ => Ok(make_decimal128(
282                None,
283                self.target_precision,
284                self.target_scale,
285            )),
286        }
287    }
288
289    fn size(&self) -> usize {
290        std::mem::size_of_val(self)
291    }
292}
293
/// Grouped accumulator computing the average of `Decimal128` values, one slot per group.
#[derive(Debug)]
struct AvgDecimalGroupsAccumulator {
    /// Per-group validity bit: new groups start with the bit set, and it is cleared
    /// when adding to the group's sum overflows. Used as the null mask of the
    /// emitted state.
    is_not_null: BooleanBufferBuilder,

    /// Per-group bit that stays set until `update_single` processes a value for the group.
    is_empty: BooleanBufferBuilder,

    /// Data type of the emitted average (the aggregate's result type).
    return_data_type: DataType,
    /// Precision of the result decimal.
    target_precision: u8,
    /// Scale of the result decimal.
    target_scale: i8,

    /// Count per group (use i64 to make Int64Array)
    counts: Vec<i64>,

    /// Sums per group, stored as i128
    sums: Vec<i128>,

    /// The type of the sum
    sum_data_type: DataType,
    /// This is input_precision + 10 to be consistent with Spark
    sum_precision: u8,
    /// Scale of the per-group sums.
    sum_scale: i8,
}
319
320impl AvgDecimalGroupsAccumulator {
321    pub fn new(
322        return_data_type: &DataType,
323        sum_data_type: &DataType,
324        target_precision: u8,
325        target_scale: i8,
326        sum_precision: u8,
327        sum_scale: i8,
328    ) -> Self {
329        Self {
330            is_not_null: BooleanBufferBuilder::new(0),
331            is_empty: BooleanBufferBuilder::new(0),
332            return_data_type: return_data_type.clone(),
333            target_precision,
334            target_scale,
335            sum_data_type: sum_data_type.clone(),
336            sum_precision,
337            sum_scale,
338            counts: vec![],
339            sums: vec![],
340        }
341    }
342
343    fn is_overflow(&self, index: usize) -> bool {
344        !self.is_empty.get_bit(index) && !self.is_not_null.get_bit(index)
345    }
346
347    fn update_single(&mut self, group_index: usize, value: i128) {
348        if self.is_overflow(group_index) {
349            // This means there's a overflow in decimal, so we will just skip the rest
350            // of the computation
351            return;
352        }
353
354        self.is_empty.set_bit(group_index, false);
355        let (new_sum, is_overflow) = self.sums[group_index].overflowing_add(value);
356        self.counts[group_index] += 1;
357
358        if is_overflow || !is_valid_decimal_precision(new_sum, self.sum_precision) {
359            // Overflow: set buffer accumulator to null
360            self.is_not_null.set_bit(group_index, false);
361            return;
362        }
363
364        self.sums[group_index] = new_sum;
365        self.is_not_null.set_bit(group_index, true)
366    }
367}
368
369fn ensure_bit_capacity(builder: &mut BooleanBufferBuilder, capacity: usize) {
370    if builder.len() < capacity {
371        let additional = capacity - builder.len();
372        builder.append_n(additional, true);
373    }
374}
375
376impl GroupsAccumulator for AvgDecimalGroupsAccumulator {
377    fn update_batch(
378        &mut self,
379        values: &[ArrayRef],
380        group_indices: &[usize],
381        _opt_filter: Option<&arrow_array::BooleanArray>,
382        total_num_groups: usize,
383    ) -> Result<()> {
384        assert_eq!(values.len(), 1, "single argument to update_batch");
385        let values = values[0].as_primitive::<Decimal128Type>();
386        let data = values.values();
387
388        // increment counts, update sums
389        self.counts.resize(total_num_groups, 0);
390        self.sums.resize(total_num_groups, 0);
391        ensure_bit_capacity(&mut self.is_empty, total_num_groups);
392        ensure_bit_capacity(&mut self.is_not_null, total_num_groups);
393
394        let iter = group_indices.iter().zip(data.iter());
395        if values.null_count() == 0 {
396            for (&group_index, &value) in iter {
397                self.update_single(group_index, value);
398            }
399        } else {
400            for (idx, (&group_index, &value)) in iter.enumerate() {
401                if values.is_null(idx) {
402                    continue;
403                }
404                self.update_single(group_index, value);
405            }
406        }
407        Ok(())
408    }
409
410    fn merge_batch(
411        &mut self,
412        values: &[ArrayRef],
413        group_indices: &[usize],
414        _opt_filter: Option<&arrow_array::BooleanArray>,
415        total_num_groups: usize,
416    ) -> Result<()> {
417        assert_eq!(values.len(), 2, "two arguments to merge_batch");
418        // first batch is partial sums, second is counts
419        let partial_sums = values[0].as_primitive::<Decimal128Type>();
420        let partial_counts = values[1].as_primitive::<Int64Type>();
421        // update counts with partial counts
422        self.counts.resize(total_num_groups, 0);
423        let iter1 = group_indices.iter().zip(partial_counts.values().iter());
424        for (&group_index, &partial_count) in iter1 {
425            self.counts[group_index] += partial_count;
426        }
427
428        // update sums
429        self.sums.resize(total_num_groups, 0);
430        let iter2 = group_indices.iter().zip(partial_sums.values().iter());
431        for (&group_index, &new_value) in iter2 {
432            let sum = &mut self.sums[group_index];
433            *sum = sum.add_wrapping(new_value);
434        }
435
436        Ok(())
437    }
438
439    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
440        let counts = emit_to.take_needed(&mut self.counts);
441        let sums = emit_to.take_needed(&mut self.sums);
442
443        let mut builder = PrimitiveBuilder::<Decimal128Type>::with_capacity(sums.len())
444            .with_data_type(self.return_data_type.clone());
445        let iter = sums.into_iter().zip(counts);
446
447        let scaler = 10_i128.pow(self.target_scale.saturating_sub(self.sum_scale) as u32);
448        let target_min = MIN_DECIMAL128_FOR_EACH_PRECISION[self.target_precision as usize];
449        let target_max = MAX_DECIMAL128_FOR_EACH_PRECISION[self.target_precision as usize];
450
451        for (sum, count) in iter {
452            if count != 0 {
453                match avg(sum, count as i128, target_min, target_max, scaler) {
454                    Some(value) => {
455                        builder.append_value(value);
456                    }
457                    _ => {
458                        builder.append_null();
459                    }
460                }
461            } else {
462                builder.append_null();
463            }
464        }
465        let array: PrimitiveArray<Decimal128Type> = builder.finish();
466
467        Ok(Arc::new(array))
468    }
469
470    // return arrays for sums and counts
471    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
472        let nulls = self.is_not_null.finish();
473        let nulls = Some(NullBuffer::new(nulls));
474
475        let counts = emit_to.take_needed(&mut self.counts);
476        let counts = Int64Array::new(counts.into(), nulls.clone());
477
478        let sums = emit_to.take_needed(&mut self.sums);
479        let sums =
480            Decimal128Array::new(sums.into(), nulls).with_data_type(self.sum_data_type.clone());
481
482        Ok(vec![
483            Arc::new(sums) as ArrayRef,
484            Arc::new(counts) as ArrayRef,
485        ])
486    }
487
488    fn size(&self) -> usize {
489        self.counts.capacity() * std::mem::size_of::<i64>()
490            + self.sums.capacity() * std::mem::size_of::<i128>()
491    }
492}
493
494/// Returns the `sum`/`count` as a i128 Decimal128 with
495/// target_scale and target_precision and return None if overflows.
496///
497/// * sum: The total sum value stored as Decimal128 with sum_scale
498/// * count: total count, stored as a i128 (*NOT* a Decimal128 value)
499/// * target_min: The minimum output value possible to represent with the target precision
500/// * target_max: The maximum output value possible to represent with the target precision
501/// * scaler: scale factor for avg
502#[inline(always)]
503fn avg(sum: i128, count: i128, target_min: i128, target_max: i128, scaler: i128) -> Option<i128> {
504    if let Some(value) = sum.checked_mul(scaler) {
505        // `sum / count` with ROUND_HALF_UP
506        let (div, rem) = value.div_rem(&count);
507        let half = div_ceil(count, 2);
508        let half_neg = half.neg_wrapping();
509        let new_value = match value >= 0 {
510            true if rem >= half => div.add_wrapping(1),
511            false if rem <= half_neg => div.sub_wrapping(1),
512            _ => div,
513        };
514        if new_value >= target_min && new_value <= target_max {
515            Some(new_value)
516        } else {
517            None
518        }
519    } else {
520        None
521    }
522}