datafusion_comet_spark_expr/agg_funcs/
sum_decimal.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use crate::utils::{is_valid_decimal_precision, unlikely};
19use arrow::array::{
20    cast::AsArray, types::Decimal128Type, Array, ArrayRef, BooleanArray, Decimal128Array,
21};
22use arrow::datatypes::{DataType, Field, FieldRef};
23use arrow::{
24    array::BooleanBufferBuilder,
25    buffer::{BooleanBuffer, NullBuffer},
26};
27use datafusion::common::{DataFusionError, Result as DFResult, ScalarValue};
28use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs};
29use datafusion::logical_expr::Volatility::Immutable;
30use datafusion::logical_expr::{
31    Accumulator, AggregateUDFImpl, EmitTo, GroupsAccumulator, ReversedUDAF, Signature,
32};
33use std::{any::Any, ops::BitAnd, sync::Arc};
34
35#[derive(Debug)]
36pub struct SumDecimal {
37    /// Aggregate function signature
38    signature: Signature,
39    /// The data type of the SUM result. This will always be a decimal type
40    /// with the same precision and scale as specified in this struct
41    result_type: DataType,
42    /// Decimal precision
43    precision: u8,
44    /// Decimal scale
45    scale: i8,
46}
47
48impl SumDecimal {
49    pub fn try_new(data_type: DataType) -> DFResult<Self> {
50        // The `data_type` is the SUM result type passed from Spark side
51        let (precision, scale) = match data_type {
52            DataType::Decimal128(p, s) => (p, s),
53            _ => {
54                return Err(DataFusionError::Internal(
55                    "Invalid data type for SumDecimal".into(),
56                ))
57            }
58        };
59        Ok(Self {
60            signature: Signature::user_defined(Immutable),
61            result_type: data_type,
62            precision,
63            scale,
64        })
65    }
66}
67
68impl AggregateUDFImpl for SumDecimal {
69    fn as_any(&self) -> &dyn Any {
70        self
71    }
72
73    fn accumulator(&self, _args: AccumulatorArgs) -> DFResult<Box<dyn Accumulator>> {
74        Ok(Box::new(SumDecimalAccumulator::new(
75            self.precision,
76            self.scale,
77        )))
78    }
79
80    fn state_fields(&self, _args: StateFieldsArgs) -> DFResult<Vec<FieldRef>> {
81        let fields = vec![
82            Arc::new(Field::new(
83                self.name(),
84                self.result_type.clone(),
85                self.is_nullable(),
86            )),
87            Arc::new(Field::new("is_empty", DataType::Boolean, false)),
88        ];
89        Ok(fields)
90    }
91
92    fn name(&self) -> &str {
93        "sum"
94    }
95
96    fn signature(&self) -> &Signature {
97        &self.signature
98    }
99
100    fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
101        Ok(self.result_type.clone())
102    }
103
104    fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
105        true
106    }
107
108    fn create_groups_accumulator(
109        &self,
110        _args: AccumulatorArgs,
111    ) -> DFResult<Box<dyn GroupsAccumulator>> {
112        Ok(Box::new(SumDecimalGroupsAccumulator::new(
113            self.result_type.clone(),
114            self.precision,
115        )))
116    }
117
118    fn default_value(&self, _data_type: &DataType) -> DFResult<ScalarValue> {
119        ScalarValue::new_primitive::<Decimal128Type>(
120            None,
121            &DataType::Decimal128(self.precision, self.scale),
122        )
123    }
124
125    fn reverse_expr(&self) -> ReversedUDAF {
126        ReversedUDAF::Identical
127    }
128
129    fn is_nullable(&self) -> bool {
130        // SumDecimal is always nullable because overflows can cause null values
131        true
132    }
133}
134
/// Accumulator state for a decimal sum over a single group.
#[derive(Debug)]
struct SumDecimalAccumulator {
    /// Running sum as an unscaled `i128`.
    sum: i128,
    /// True until a batch containing at least one non-null value has been seen.
    is_empty: bool,
    /// False once an addition overflowed; the sum is then reported as null.
    is_not_null: bool,

    /// Decimal precision that the running sum is validated against.
    precision: u8,
    /// Decimal scale used when materializing scalar values.
    scale: i8,
}
144
145impl SumDecimalAccumulator {
146    fn new(precision: u8, scale: i8) -> Self {
147        Self {
148            sum: 0,
149            is_empty: true,
150            is_not_null: true,
151            precision,
152            scale,
153        }
154    }
155
156    fn update_single(&mut self, values: &Decimal128Array, idx: usize) {
157        let v = unsafe { values.value_unchecked(idx) };
158        let (new_sum, is_overflow) = self.sum.overflowing_add(v);
159
160        if is_overflow || !is_valid_decimal_precision(new_sum, self.precision) {
161            // Overflow: set buffer accumulator to null
162            self.is_not_null = false;
163            return;
164        }
165
166        self.sum = new_sum;
167        self.is_not_null = true;
168    }
169}
170
171impl Accumulator for SumDecimalAccumulator {
172    fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> {
173        assert_eq!(
174            values.len(),
175            1,
176            "Expect only one element in 'values' but found {}",
177            values.len()
178        );
179
180        if !self.is_empty && !self.is_not_null {
181            // This means there's a overflow in decimal, so we will just skip the rest
182            // of the computation
183            return Ok(());
184        }
185
186        let values = &values[0];
187        let data = values.as_primitive::<Decimal128Type>();
188
189        self.is_empty = self.is_empty && values.len() == values.null_count();
190
191        if values.null_count() == 0 {
192            for i in 0..data.len() {
193                self.update_single(data, i);
194            }
195        } else {
196            for i in 0..data.len() {
197                if data.is_null(i) {
198                    continue;
199                }
200                self.update_single(data, i);
201            }
202        }
203
204        Ok(())
205    }
206
207    fn evaluate(&mut self) -> DFResult<ScalarValue> {
208        // For each group:
209        //   1. if `is_empty` is true, it means either there is no value or all values for the group
210        //      are null, in this case we'll return null
211        //   2. if `is_empty` is false, but `null_state` is true, it means there's an overflow. In
212        //      non-ANSI mode Spark returns null.
213        if self.is_empty
214            || !self.is_not_null
215            || !is_valid_decimal_precision(self.sum, self.precision)
216        {
217            ScalarValue::new_primitive::<Decimal128Type>(
218                None,
219                &DataType::Decimal128(self.precision, self.scale),
220            )
221        } else {
222            ScalarValue::try_new_decimal128(self.sum, self.precision, self.scale)
223        }
224    }
225
226    fn size(&self) -> usize {
227        std::mem::size_of_val(self)
228    }
229
230    fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
231        let sum = if self.is_not_null {
232            ScalarValue::try_new_decimal128(self.sum, self.precision, self.scale)?
233        } else {
234            ScalarValue::new_primitive::<Decimal128Type>(
235                None,
236                &DataType::Decimal128(self.precision, self.scale),
237            )?
238        };
239        Ok(vec![sum, ScalarValue::from(self.is_empty)])
240    }
241
242    fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
243        assert_eq!(
244            states.len(),
245            2,
246            "Expect two element in 'states' but found {}",
247            states.len()
248        );
249        assert_eq!(states[0].len(), 1);
250        assert_eq!(states[1].len(), 1);
251
252        let that_sum = states[0].as_primitive::<Decimal128Type>();
253        let that_is_empty = states[1].as_any().downcast_ref::<BooleanArray>().unwrap();
254
255        let this_overflow = !self.is_empty && !self.is_not_null;
256        let that_overflow = !that_is_empty.value(0) && that_sum.is_null(0);
257
258        self.is_not_null = !this_overflow && !that_overflow;
259        self.is_empty = self.is_empty && that_is_empty.value(0);
260
261        if self.is_not_null {
262            self.sum += that_sum.value(0);
263        }
264
265        Ok(())
266    }
267}
268
269struct SumDecimalGroupsAccumulator {
270    // Whether aggregate buffer for a particular group is null. True indicates it is not null.
271    is_not_null: BooleanBufferBuilder,
272    is_empty: BooleanBufferBuilder,
273    sum: Vec<i128>,
274    result_type: DataType,
275    precision: u8,
276}
277
278impl SumDecimalGroupsAccumulator {
279    fn new(result_type: DataType, precision: u8) -> Self {
280        Self {
281            is_not_null: BooleanBufferBuilder::new(0),
282            is_empty: BooleanBufferBuilder::new(0),
283            sum: Vec::new(),
284            result_type,
285            precision,
286        }
287    }
288
289    fn is_overflow(&self, index: usize) -> bool {
290        !self.is_empty.get_bit(index) && !self.is_not_null.get_bit(index)
291    }
292
293    fn update_single(&mut self, group_index: usize, value: i128) {
294        if unlikely(self.is_overflow(group_index)) {
295            // This means there's a overflow in decimal, so we will just skip the rest
296            // of the computation
297            return;
298        }
299
300        self.is_empty.set_bit(group_index, false);
301        let (new_sum, is_overflow) = self.sum[group_index].overflowing_add(value);
302
303        if is_overflow || !is_valid_decimal_precision(new_sum, self.precision) {
304            // Overflow: set buffer accumulator to null
305            self.is_not_null.set_bit(group_index, false);
306            return;
307        }
308
309        self.sum[group_index] = new_sum;
310        self.is_not_null.set_bit(group_index, true)
311    }
312}
313
314fn ensure_bit_capacity(builder: &mut BooleanBufferBuilder, capacity: usize) {
315    if builder.len() < capacity {
316        let additional = capacity - builder.len();
317        builder.append_n(additional, true);
318    }
319}
320
321/// Build a boolean buffer from the state and reset the state, based on the emit_to
322/// strategy.
323fn build_bool_state(state: &mut BooleanBufferBuilder, emit_to: &EmitTo) -> BooleanBuffer {
324    let bool_state: BooleanBuffer = state.finish();
325
326    match emit_to {
327        EmitTo::All => bool_state,
328        EmitTo::First(n) => {
329            // split off the first N values in bool_state
330            let first_n_bools: BooleanBuffer = bool_state.iter().take(*n).collect();
331            // reset the existing seen buffer
332            for seen in bool_state.iter().skip(*n) {
333                state.append(seen);
334            }
335            first_n_bools
336        }
337    }
338}
339
340impl GroupsAccumulator for SumDecimalGroupsAccumulator {
341    fn update_batch(
342        &mut self,
343        values: &[ArrayRef],
344        group_indices: &[usize],
345        opt_filter: Option<&BooleanArray>,
346        total_num_groups: usize,
347    ) -> DFResult<()> {
348        assert!(opt_filter.is_none(), "opt_filter is not supported yet");
349        assert_eq!(values.len(), 1);
350        let values = values[0].as_primitive::<Decimal128Type>();
351        let data = values.values();
352
353        // Update size for the accumulate states
354        self.sum.resize(total_num_groups, 0);
355        ensure_bit_capacity(&mut self.is_empty, total_num_groups);
356        ensure_bit_capacity(&mut self.is_not_null, total_num_groups);
357
358        let iter = group_indices.iter().zip(data.iter());
359        if values.null_count() == 0 {
360            for (&group_index, &value) in iter {
361                self.update_single(group_index, value);
362            }
363        } else {
364            for (idx, (&group_index, &value)) in iter.enumerate() {
365                if values.is_null(idx) {
366                    continue;
367                }
368                self.update_single(group_index, value);
369            }
370        }
371
372        Ok(())
373    }
374
375    fn evaluate(&mut self, emit_to: EmitTo) -> DFResult<ArrayRef> {
376        // For each group:
377        //   1. if `is_empty` is true, it means either there is no value or all values for the group
378        //      are null, in this case we'll return null
379        //   2. if `is_empty` is false, but `null_state` is true, it means there's an overflow. In
380        //      non-ANSI mode Spark returns null.
381        let result = emit_to.take_needed(&mut self.sum);
382        result.iter().enumerate().for_each(|(i, &v)| {
383            if !is_valid_decimal_precision(v, self.precision) {
384                self.is_not_null.set_bit(i, false);
385            }
386        });
387
388        let nulls = build_bool_state(&mut self.is_not_null, &emit_to);
389        let is_empty = build_bool_state(&mut self.is_empty, &emit_to);
390        let x = (!&is_empty).bitand(&nulls);
391
392        let result = Decimal128Array::new(result.into(), Some(NullBuffer::new(x)))
393            .with_data_type(self.result_type.clone());
394
395        Ok(Arc::new(result))
396    }
397
398    fn state(&mut self, emit_to: EmitTo) -> DFResult<Vec<ArrayRef>> {
399        let nulls = build_bool_state(&mut self.is_not_null, &emit_to);
400        let nulls = Some(NullBuffer::new(nulls));
401
402        let sum = emit_to.take_needed(&mut self.sum);
403        let sum = Decimal128Array::new(sum.into(), nulls.clone())
404            .with_data_type(self.result_type.clone());
405
406        let is_empty = build_bool_state(&mut self.is_empty, &emit_to);
407        let is_empty = BooleanArray::new(is_empty, None);
408
409        Ok(vec![
410            Arc::new(sum) as ArrayRef,
411            Arc::new(is_empty) as ArrayRef,
412        ])
413    }
414
415    fn merge_batch(
416        &mut self,
417        values: &[ArrayRef],
418        group_indices: &[usize],
419        opt_filter: Option<&BooleanArray>,
420        total_num_groups: usize,
421    ) -> DFResult<()> {
422        assert_eq!(
423            values.len(),
424            2,
425            "Expected two arrays: 'sum' and 'is_empty', but found {}",
426            values.len()
427        );
428        assert!(opt_filter.is_none(), "opt_filter is not supported yet");
429
430        // Make sure we have enough capacity for the additional groups
431        self.sum.resize(total_num_groups, 0);
432        ensure_bit_capacity(&mut self.is_empty, total_num_groups);
433        ensure_bit_capacity(&mut self.is_not_null, total_num_groups);
434
435        let that_sum = &values[0];
436        let that_sum = that_sum.as_primitive::<Decimal128Type>();
437        let that_is_empty = &values[1];
438        let that_is_empty = that_is_empty
439            .as_any()
440            .downcast_ref::<BooleanArray>()
441            .unwrap();
442
443        group_indices
444            .iter()
445            .enumerate()
446            .for_each(|(idx, &group_index)| unsafe {
447                let this_overflow = self.is_overflow(group_index);
448                let that_is_empty = that_is_empty.value_unchecked(idx);
449                let that_overflow = !that_is_empty && that_sum.is_null(idx);
450                let is_overflow = this_overflow || that_overflow;
451
452                // This part follows the logic in Spark:
453                //   `org.apache.spark.sql.catalyst.expressions.aggregate.Sum`
454                self.is_not_null.set_bit(group_index, !is_overflow);
455                self.is_empty.set_bit(
456                    group_index,
457                    self.is_empty.get_bit(group_index) && that_is_empty,
458                );
459                if !is_overflow {
460                    // .. otherwise, the sum value for this particular index must not be null,
461                    // and thus we merge both values and update this sum.
462                    self.sum[group_index] += that_sum.value_unchecked(idx);
463                }
464            });
465
466        Ok(())
467    }
468
469    fn size(&self) -> usize {
470        self.sum.capacity() * std::mem::size_of::<i128>()
471            + self.is_empty.capacity() / 8
472            + self.is_not_null.capacity() / 8
473    }
474}
475
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::builder::{Decimal128Builder, StringBuilder};
    use arrow::array::RecordBatch;
    use arrow::datatypes::*;
    use datafusion::common::Result;
    use datafusion::datasource::memory::MemorySourceConfig;
    use datafusion::datasource::source::DataSourceExec;
    use datafusion::execution::TaskContext;
    use datafusion::logical_expr::AggregateUDF;
    use datafusion::physical_expr::aggregate::AggregateExprBuilder;
    use datafusion::physical_expr::expressions::Column;
    use datafusion::physical_expr::PhysicalExpr;
    use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy};
    use datafusion::physical_plan::ExecutionPlan;
    use futures::StreamExt;

    /// A non-decimal result type must be rejected by the constructor.
    #[test]
    fn invalid_data_type() {
        let result = SumDecimal::try_new(DataType::Int32);
        assert!(result.is_err());
    }

    /// Runs a partial grouped aggregation over ten identical batches and checks that the
    /// plan executes to completion without error.
    #[tokio::test]
    async fn sum_no_overflow() -> Result<()> {
        let num_rows = 8192;
        let batches = vec![create_record_batch(num_rows); 10];
        let partitions = &[batches];
        let schema = partitions[0][0].schema();

        let group_column: Arc<dyn PhysicalExpr> = Arc::new(Column::new("c0", 0));
        let value_column: Arc<dyn PhysicalExpr> = Arc::new(Column::new("c1", 1));

        let source = MemorySourceConfig::try_new(partitions, Arc::clone(&schema), None).unwrap();
        let scan: Arc<dyn ExecutionPlan> = Arc::new(DataSourceExec::new(Arc::new(source)));

        let data_type = DataType::Decimal128(8, 2);
        let sum_decimal = SumDecimal::try_new(data_type.clone())?;
        let aggregate_udf = Arc::new(AggregateUDF::new_from_impl(sum_decimal));

        let aggr_expr = AggregateExprBuilder::new(aggregate_udf, vec![value_column])
            .schema(Arc::clone(&schema))
            .alias("sum")
            .with_ignore_nulls(false)
            .with_distinct(false)
            .build()?;

        let group_by = PhysicalGroupBy::new_single(vec![(group_column, "c0".to_string())]);
        let aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Partial,
            group_by,
            vec![aggr_expr.into()],
            vec![None], // no filter expressions
            scan,
            Arc::clone(&schema),
        )?);

        let mut stream = aggregate
            .execute(0, Arc::new(TaskContext::default()))
            .unwrap();
        // Drain the stream; the first error aborts the test.
        while let Some(_batch) = stream.next().await.transpose()? {}

        Ok(())
    }

    /// Builds a batch with a string column `c0` (1024 distinct values) and a decimal column
    /// `c1` holding `0..num_rows`.
    fn create_record_batch(num_rows: usize) -> RecordBatch {
        let mut decimals = Decimal128Builder::with_capacity(num_rows);
        let mut strings = StringBuilder::with_capacity(num_rows, num_rows * 32);
        for i in 0..num_rows {
            decimals.append_value(i as i128);
            strings.append_value(format!("this is string #{}", i % 1024));
        }
        let decimal_array: ArrayRef = Arc::new(decimals.finish());
        let string_array: ArrayRef = Arc::new(strings.finish());

        let schema = Schema::new(vec![
            // string column
            Field::new("c0", DataType::Utf8, false),
            // decimal column
            Field::new("c1", DataType::Decimal128(38, 10), false),
        ]);

        RecordBatch::try_new(Arc::new(schema), vec![string_array, decimal_array]).unwrap()
    }
}