datafusion_comet_spark_expr/agg_funcs/
avg_decimal.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::array::{
19    builder::PrimitiveBuilder,
20    cast::AsArray,
21    types::{Decimal128Type, Int64Type},
22    Array, ArrayRef, Decimal128Array, Int64Array, PrimitiveArray,
23};
24use arrow::datatypes::{DataType, Field, FieldRef};
25use arrow::{array::BooleanBufferBuilder, buffer::NullBuffer, compute::sum};
26use datafusion::common::{not_impl_err, Result, ScalarValue};
27use datafusion::logical_expr::{
28    Accumulator, AggregateUDFImpl, EmitTo, GroupsAccumulator, ReversedUDAF, Signature,
29};
30use datafusion::physical_expr::expressions::format_state_name;
31use std::{any::Any, sync::Arc};
32
33use crate::utils::is_valid_decimal_precision;
34use arrow::array::ArrowNativeTypeOp;
35use arrow::datatypes::{MAX_DECIMAL128_FOR_EACH_PRECISION, MIN_DECIMAL128_FOR_EACH_PRECISION};
36use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs};
37use datafusion::logical_expr::type_coercion::aggregates::avg_return_type;
38use datafusion::logical_expr::Volatility::Immutable;
39use num::{integer::div_ceil, Integer};
40use DataType::*;
41
/// AVG aggregate expression
///
/// Bundles the UDF signature with the two data types supplied to [`AvgDecimal::new`]: the type
/// of the intermediate sum and the type of the final result.
#[derive(Debug, Clone)]
pub struct AvgDecimal {
    /// Signature handed out by `AggregateUDFImpl::signature`.
    signature: Signature,
    /// Type of the intermediate sum; also used as the type of the "sum" state field.
    sum_data_type: DataType,
    /// Type of the values produced by the accumulators this expression creates.
    result_data_type: DataType,
}
49
50impl AvgDecimal {
51    /// Create a new AVG aggregate function
52    pub fn new(result_type: DataType, sum_type: DataType) -> Self {
53        Self {
54            signature: Signature::user_defined(Immutable),
55            result_data_type: result_type,
56            sum_data_type: sum_type,
57        }
58    }
59}
60
61impl AggregateUDFImpl for AvgDecimal {
62    /// Return a reference to Any that can be used for downcasting
63    fn as_any(&self) -> &dyn Any {
64        self
65    }
66
67    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
68        match (&self.sum_data_type, &self.result_data_type) {
69            (Decimal128(sum_precision, sum_scale), Decimal128(target_precision, target_scale)) => {
70                Ok(Box::new(AvgDecimalAccumulator::new(
71                    *sum_scale,
72                    *sum_precision,
73                    *target_precision,
74                    *target_scale,
75                )))
76            }
77            _ => not_impl_err!(
78                "AvgDecimalAccumulator for ({} --> {})",
79                self.sum_data_type,
80                self.result_data_type
81            ),
82        }
83    }
84
85    fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
86        Ok(vec![
87            Arc::new(Field::new(
88                format_state_name(self.name(), "sum"),
89                self.sum_data_type.clone(),
90                true,
91            )),
92            Arc::new(Field::new(
93                format_state_name(self.name(), "count"),
94                DataType::Int64,
95                true,
96            )),
97        ])
98    }
99
100    fn name(&self) -> &str {
101        "avg"
102    }
103
104    fn reverse_expr(&self) -> ReversedUDAF {
105        ReversedUDAF::Identical
106    }
107
108    fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
109        true
110    }
111
112    fn create_groups_accumulator(
113        &self,
114        _args: AccumulatorArgs,
115    ) -> Result<Box<dyn GroupsAccumulator>> {
116        // instantiate specialized accumulator based for the type
117        match (&self.sum_data_type, &self.result_data_type) {
118            (Decimal128(sum_precision, sum_scale), Decimal128(target_precision, target_scale)) => {
119                Ok(Box::new(AvgDecimalGroupsAccumulator::new(
120                    &self.result_data_type,
121                    &self.sum_data_type,
122                    *target_precision,
123                    *target_scale,
124                    *sum_precision,
125                    *sum_scale,
126                )))
127            }
128            _ => not_impl_err!(
129                "AvgDecimalGroupsAccumulator for ({} --> {})",
130                self.sum_data_type,
131                self.result_data_type
132            ),
133        }
134    }
135
136    fn default_value(&self, _data_type: &DataType) -> Result<ScalarValue> {
137        match &self.result_data_type {
138            Decimal128(target_precision, target_scale) => {
139                Ok(make_decimal128(None, *target_precision, *target_scale))
140            }
141            _ => not_impl_err!(
142                "The result_data_type of AvgDecimal should be Decimal128 but got{}",
143                self.result_data_type
144            ),
145        }
146    }
147
148    fn signature(&self) -> &Signature {
149        &self.signature
150    }
151
152    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
153        avg_return_type(self.name(), &arg_types[0])
154    }
155}
156
/// An accumulator to compute the average for decimals
#[derive(Debug)]
struct AvgDecimalAccumulator {
    /// Running sum of the accumulated values. `None` until a value has been added;
    /// `merge_batch` also resets it to `None` when its addition overflows `i128`.
    sum: Option<i128>,
    /// Number of values added by `update_single` plus the counts folded in by `merge_batch`.
    count: i64,
    /// Starts as `true`; `update_batch` clears it once a batch contains a non-null value.
    is_empty: bool,
    /// Starts as `true`; `update_single` sets it to `false` when the sum or the count
    /// overflows and back to `true` after a successful addition.
    is_not_null: bool,
    /// Scale of the intermediate sum (used for the emitted state and for rescaling).
    sum_scale: i8,
    /// Precision the intermediate sum is validated against.
    sum_precision: u8,
    /// Precision of the final average.
    target_precision: u8,
    /// Scale of the final average.
    target_scale: i8,
}
169
170impl AvgDecimalAccumulator {
171    pub fn new(sum_scale: i8, sum_precision: u8, target_precision: u8, target_scale: i8) -> Self {
172        Self {
173            sum: None,
174            count: 0,
175            is_empty: true,
176            is_not_null: true,
177            sum_scale,
178            sum_precision,
179            target_precision,
180            target_scale,
181        }
182    }
183
184    fn update_single(&mut self, values: &Decimal128Array, idx: usize) {
185        let v = unsafe { values.value_unchecked(idx) };
186        let (new_sum, is_overflow) = match self.sum {
187            Some(sum) => sum.overflowing_add(v),
188            None => (v, false),
189        };
190
191        if is_overflow || !is_valid_decimal_precision(new_sum, self.sum_precision) {
192            // Overflow: set buffer accumulator to null
193            self.is_not_null = false;
194            return;
195        }
196
197        self.sum = Some(new_sum);
198
199        if let Some(new_count) = self.count.checked_add(1) {
200            self.count = new_count;
201        } else {
202            self.is_not_null = false;
203            return;
204        }
205
206        self.is_not_null = true;
207    }
208}
209
/// Wraps an optional raw `i128` into a `ScalarValue::Decimal128` with the given precision and
/// scale; `None` yields a NULL decimal of that type.
fn make_decimal128(value: Option<i128>, precision: u8, scale: i8) -> ScalarValue {
    ScalarValue::Decimal128(value, precision, scale)
}
213
214impl Accumulator for AvgDecimalAccumulator {
215    fn state(&mut self) -> Result<Vec<ScalarValue>> {
216        Ok(vec![
217            ScalarValue::Decimal128(self.sum, self.sum_precision, self.sum_scale),
218            ScalarValue::from(self.count),
219        ])
220    }
221
222    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
223        if !self.is_empty && !self.is_not_null {
224            // This means there's a overflow in decimal, so we will just skip the rest
225            // of the computation
226            return Ok(());
227        }
228
229        let values = &values[0];
230        let data = values.as_primitive::<Decimal128Type>();
231
232        self.is_empty = self.is_empty && values.len() == values.null_count();
233
234        if values.null_count() == 0 {
235            for i in 0..data.len() {
236                self.update_single(data, i);
237            }
238        } else {
239            for i in 0..data.len() {
240                if data.is_null(i) {
241                    continue;
242                }
243                self.update_single(data, i);
244            }
245        }
246        Ok(())
247    }
248
249    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
250        // counts are summed
251        self.count += sum(states[1].as_primitive::<Int64Type>()).unwrap_or_default();
252
253        // sums are summed
254        if let Some(x) = sum(states[0].as_primitive::<Decimal128Type>()) {
255            let v = self.sum.get_or_insert(0);
256            let (result, overflowed) = v.overflowing_add(x);
257            if overflowed {
258                // Set to None if overflow happens
259                self.sum = None;
260            } else {
261                *v = result;
262            }
263        }
264        Ok(())
265    }
266
267    fn evaluate(&mut self) -> Result<ScalarValue> {
268        let scaler = 10_i128.pow(self.target_scale.saturating_sub(self.sum_scale) as u32);
269        let target_min = MIN_DECIMAL128_FOR_EACH_PRECISION[self.target_precision as usize];
270        let target_max = MAX_DECIMAL128_FOR_EACH_PRECISION[self.target_precision as usize];
271
272        let result = self
273            .sum
274            .map(|v| avg(v, self.count as i128, target_min, target_max, scaler));
275
276        match result {
277            Some(value) => Ok(make_decimal128(
278                value,
279                self.target_precision,
280                self.target_scale,
281            )),
282            _ => Ok(make_decimal128(
283                None,
284                self.target_precision,
285                self.target_scale,
286            )),
287        }
288    }
289
290    fn size(&self) -> usize {
291        std::mem::size_of_val(self)
292    }
293}
294
/// State for averaging decimals per group: one sum, one count and two flag bits are kept for
/// every group index.
#[derive(Debug)]
struct AvgDecimalGroupsAccumulator {
    /// Tracks if the value is null
    ///
    /// One bit per group. New bits are appended as `true`; `update_single` clears the bit when
    /// the group's sum overflows and sets it after a successful addition.
    is_not_null: BooleanBufferBuilder,

    /// Tracks if the value is empty
    ///
    /// One bit per group. New bits are appended as `true`; `update_single` clears the bit as
    /// soon as the group receives a value.
    is_empty: BooleanBufferBuilder,

    /// The type of the avg return type
    return_data_type: DataType,
    /// Precision of the emitted averages.
    target_precision: u8,
    /// Scale of the emitted averages.
    target_scale: i8,

    /// Count per group (use i64 to make Int64Array)
    counts: Vec<i64>,

    /// Sums per group, stored as i128
    sums: Vec<i128>,

    /// The type of the sum
    sum_data_type: DataType,
    /// This is input_precision + 10 to be consistent with Spark
    sum_precision: u8,
    /// Scale of the per-group sums.
    sum_scale: i8,
}
320
321impl AvgDecimalGroupsAccumulator {
322    pub fn new(
323        return_data_type: &DataType,
324        sum_data_type: &DataType,
325        target_precision: u8,
326        target_scale: i8,
327        sum_precision: u8,
328        sum_scale: i8,
329    ) -> Self {
330        Self {
331            is_not_null: BooleanBufferBuilder::new(0),
332            is_empty: BooleanBufferBuilder::new(0),
333            return_data_type: return_data_type.clone(),
334            target_precision,
335            target_scale,
336            sum_data_type: sum_data_type.clone(),
337            sum_precision,
338            sum_scale,
339            counts: vec![],
340            sums: vec![],
341        }
342    }
343
344    fn is_overflow(&self, index: usize) -> bool {
345        !self.is_empty.get_bit(index) && !self.is_not_null.get_bit(index)
346    }
347
348    fn update_single(&mut self, group_index: usize, value: i128) {
349        if self.is_overflow(group_index) {
350            // This means there's a overflow in decimal, so we will just skip the rest
351            // of the computation
352            return;
353        }
354
355        self.is_empty.set_bit(group_index, false);
356        let (new_sum, is_overflow) = self.sums[group_index].overflowing_add(value);
357        self.counts[group_index] += 1;
358
359        if is_overflow || !is_valid_decimal_precision(new_sum, self.sum_precision) {
360            // Overflow: set buffer accumulator to null
361            self.is_not_null.set_bit(group_index, false);
362            return;
363        }
364
365        self.sums[group_index] = new_sum;
366        self.is_not_null.set_bit(group_index, true)
367    }
368}
369
370fn ensure_bit_capacity(builder: &mut BooleanBufferBuilder, capacity: usize) {
371    if builder.len() < capacity {
372        let additional = capacity - builder.len();
373        builder.append_n(additional, true);
374    }
375}
376
377impl GroupsAccumulator for AvgDecimalGroupsAccumulator {
378    fn update_batch(
379        &mut self,
380        values: &[ArrayRef],
381        group_indices: &[usize],
382        _opt_filter: Option<&arrow::array::BooleanArray>,
383        total_num_groups: usize,
384    ) -> Result<()> {
385        assert_eq!(values.len(), 1, "single argument to update_batch");
386        let values = values[0].as_primitive::<Decimal128Type>();
387        let data = values.values();
388
389        // increment counts, update sums
390        self.counts.resize(total_num_groups, 0);
391        self.sums.resize(total_num_groups, 0);
392        ensure_bit_capacity(&mut self.is_empty, total_num_groups);
393        ensure_bit_capacity(&mut self.is_not_null, total_num_groups);
394
395        let iter = group_indices.iter().zip(data.iter());
396        if values.null_count() == 0 {
397            for (&group_index, &value) in iter {
398                self.update_single(group_index, value);
399            }
400        } else {
401            for (idx, (&group_index, &value)) in iter.enumerate() {
402                if values.is_null(idx) {
403                    continue;
404                }
405                self.update_single(group_index, value);
406            }
407        }
408        Ok(())
409    }
410
411    fn merge_batch(
412        &mut self,
413        values: &[ArrayRef],
414        group_indices: &[usize],
415        _opt_filter: Option<&arrow::array::BooleanArray>,
416        total_num_groups: usize,
417    ) -> Result<()> {
418        assert_eq!(values.len(), 2, "two arguments to merge_batch");
419        // first batch is partial sums, second is counts
420        let partial_sums = values[0].as_primitive::<Decimal128Type>();
421        let partial_counts = values[1].as_primitive::<Int64Type>();
422        // update counts with partial counts
423        self.counts.resize(total_num_groups, 0);
424        let iter1 = group_indices.iter().zip(partial_counts.values().iter());
425        for (&group_index, &partial_count) in iter1 {
426            self.counts[group_index] += partial_count;
427        }
428
429        // update sums
430        self.sums.resize(total_num_groups, 0);
431        let iter2 = group_indices.iter().zip(partial_sums.values().iter());
432        for (&group_index, &new_value) in iter2 {
433            let sum = &mut self.sums[group_index];
434            *sum = sum.add_wrapping(new_value);
435        }
436
437        Ok(())
438    }
439
440    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
441        let counts = emit_to.take_needed(&mut self.counts);
442        let sums = emit_to.take_needed(&mut self.sums);
443
444        let mut builder = PrimitiveBuilder::<Decimal128Type>::with_capacity(sums.len())
445            .with_data_type(self.return_data_type.clone());
446        let iter = sums.into_iter().zip(counts);
447
448        let scaler = 10_i128.pow(self.target_scale.saturating_sub(self.sum_scale) as u32);
449        let target_min = MIN_DECIMAL128_FOR_EACH_PRECISION[self.target_precision as usize];
450        let target_max = MAX_DECIMAL128_FOR_EACH_PRECISION[self.target_precision as usize];
451
452        for (sum, count) in iter {
453            if count != 0 {
454                match avg(sum, count as i128, target_min, target_max, scaler) {
455                    Some(value) => {
456                        builder.append_value(value);
457                    }
458                    _ => {
459                        builder.append_null();
460                    }
461                }
462            } else {
463                builder.append_null();
464            }
465        }
466        let array: PrimitiveArray<Decimal128Type> = builder.finish();
467
468        Ok(Arc::new(array))
469    }
470
471    // return arrays for sums and counts
472    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
473        let nulls = self.is_not_null.finish();
474        let nulls = Some(NullBuffer::new(nulls));
475
476        let counts = emit_to.take_needed(&mut self.counts);
477        let counts = Int64Array::new(counts.into(), nulls.clone());
478
479        let sums = emit_to.take_needed(&mut self.sums);
480        let sums =
481            Decimal128Array::new(sums.into(), nulls).with_data_type(self.sum_data_type.clone());
482
483        Ok(vec![
484            Arc::new(sums) as ArrayRef,
485            Arc::new(counts) as ArrayRef,
486        ])
487    }
488
489    fn size(&self) -> usize {
490        self.counts.capacity() * std::mem::size_of::<i64>()
491            + self.sums.capacity() * std::mem::size_of::<i128>()
492    }
493}
494
495/// Returns the `sum`/`count` as a i128 Decimal128 with
496/// target_scale and target_precision and return None if overflows.
497///
498/// * sum: The total sum value stored as Decimal128 with sum_scale
499/// * count: total count, stored as a i128 (*NOT* a Decimal128 value)
500/// * target_min: The minimum output value possible to represent with the target precision
501/// * target_max: The maximum output value possible to represent with the target precision
502/// * scaler: scale factor for avg
503#[inline(always)]
504fn avg(sum: i128, count: i128, target_min: i128, target_max: i128, scaler: i128) -> Option<i128> {
505    if let Some(value) = sum.checked_mul(scaler) {
506        // `sum / count` with ROUND_HALF_UP
507        let (div, rem) = value.div_rem(&count);
508        let half = div_ceil(count, 2);
509        let half_neg = half.neg_wrapping();
510        let new_value = match value >= 0 {
511            true if rem >= half => div.add_wrapping(1),
512            false if rem <= half_neg => div.sub_wrapping(1),
513            _ => div,
514        };
515        if new_value >= target_min && new_value <= target_max {
516            Some(new_value)
517        } else {
518            None
519        }
520    } else {
521        None
522    }
523}