datafusion_comet_spark_expr/agg_funcs/
sum_decimal.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::utils::{is_valid_decimal_precision, unlikely};
19use arrow::{
20    array::BooleanBufferBuilder,
21    buffer::{BooleanBuffer, NullBuffer},
22};
23use arrow_array::{
24    cast::AsArray, types::Decimal128Type, Array, ArrayRef, BooleanArray, Decimal128Array,
25};
26use arrow_schema::{DataType, Field};
27use datafusion::logical_expr::{Accumulator, EmitTo, GroupsAccumulator};
28use datafusion_common::{DataFusionError, Result as DFResult, ScalarValue};
29use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
30use datafusion_expr::Volatility::Immutable;
31use datafusion_expr::{AggregateUDFImpl, ReversedUDAF, Signature};
32use std::{any::Any, ops::BitAnd, sync::Arc};
33
34#[derive(Debug)]
35pub struct SumDecimal {
36    /// Aggregate function signature
37    signature: Signature,
38    /// The data type of the SUM result. This will always be a decimal type
39    /// with the same precision and scale as specified in this struct
40    result_type: DataType,
41    /// Decimal precision
42    precision: u8,
43    /// Decimal scale
44    scale: i8,
45}
46
47impl SumDecimal {
48    pub fn try_new(data_type: DataType) -> DFResult<Self> {
49        // The `data_type` is the SUM result type passed from Spark side
50        let (precision, scale) = match data_type {
51            DataType::Decimal128(p, s) => (p, s),
52            _ => {
53                return Err(DataFusionError::Internal(
54                    "Invalid data type for SumDecimal".into(),
55                ))
56            }
57        };
58        Ok(Self {
59            signature: Signature::user_defined(Immutable),
60            result_type: data_type,
61            precision,
62            scale,
63        })
64    }
65}
66
67impl AggregateUDFImpl for SumDecimal {
68    fn as_any(&self) -> &dyn Any {
69        self
70    }
71
72    fn accumulator(&self, _args: AccumulatorArgs) -> DFResult<Box<dyn Accumulator>> {
73        Ok(Box::new(SumDecimalAccumulator::new(
74            self.precision,
75            self.scale,
76        )))
77    }
78
79    fn state_fields(&self, _args: StateFieldsArgs) -> DFResult<Vec<Field>> {
80        let fields = vec![
81            Field::new(self.name(), self.result_type.clone(), self.is_nullable()),
82            Field::new("is_empty", DataType::Boolean, false),
83        ];
84        Ok(fields)
85    }
86
87    fn name(&self) -> &str {
88        "sum"
89    }
90
91    fn signature(&self) -> &Signature {
92        &self.signature
93    }
94
95    fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
96        Ok(self.result_type.clone())
97    }
98
99    fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
100        true
101    }
102
103    fn create_groups_accumulator(
104        &self,
105        _args: AccumulatorArgs,
106    ) -> DFResult<Box<dyn GroupsAccumulator>> {
107        Ok(Box::new(SumDecimalGroupsAccumulator::new(
108            self.result_type.clone(),
109            self.precision,
110        )))
111    }
112
113    fn default_value(&self, _data_type: &DataType) -> DFResult<ScalarValue> {
114        ScalarValue::new_primitive::<Decimal128Type>(
115            None,
116            &DataType::Decimal128(self.precision, self.scale),
117        )
118    }
119
120    fn reverse_expr(&self) -> ReversedUDAF {
121        ReversedUDAF::Identical
122    }
123
124    fn is_nullable(&self) -> bool {
125        // SumDecimal is always nullable because overflows can cause null values
126        true
127    }
128}
129
/// Accumulator state for a single (ungrouped) decimal sum.
#[derive(Debug)]
struct SumDecimalAccumulator {
    /// Running sum as an unscaled `i128`.
    sum: i128,
    /// `true` until a batch containing at least one non-null value is seen.
    is_empty: bool,
    /// `false` once the sum has overflowed; the result is then null.
    is_not_null: bool,

    /// Precision the running sum must fit in.
    precision: u8,
    /// Scale used when building the output scalar.
    scale: i8,
}
139
140impl SumDecimalAccumulator {
141    fn new(precision: u8, scale: i8) -> Self {
142        Self {
143            sum: 0,
144            is_empty: true,
145            is_not_null: true,
146            precision,
147            scale,
148        }
149    }
150
151    fn update_single(&mut self, values: &Decimal128Array, idx: usize) {
152        let v = unsafe { values.value_unchecked(idx) };
153        let (new_sum, is_overflow) = self.sum.overflowing_add(v);
154
155        if is_overflow || !is_valid_decimal_precision(new_sum, self.precision) {
156            // Overflow: set buffer accumulator to null
157            self.is_not_null = false;
158            return;
159        }
160
161        self.sum = new_sum;
162        self.is_not_null = true;
163    }
164}
165
166impl Accumulator for SumDecimalAccumulator {
167    fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> {
168        assert_eq!(
169            values.len(),
170            1,
171            "Expect only one element in 'values' but found {}",
172            values.len()
173        );
174
175        if !self.is_empty && !self.is_not_null {
176            // This means there's a overflow in decimal, so we will just skip the rest
177            // of the computation
178            return Ok(());
179        }
180
181        let values = &values[0];
182        let data = values.as_primitive::<Decimal128Type>();
183
184        self.is_empty = self.is_empty && values.len() == values.null_count();
185
186        if values.null_count() == 0 {
187            for i in 0..data.len() {
188                self.update_single(data, i);
189            }
190        } else {
191            for i in 0..data.len() {
192                if data.is_null(i) {
193                    continue;
194                }
195                self.update_single(data, i);
196            }
197        }
198
199        Ok(())
200    }
201
202    fn evaluate(&mut self) -> DFResult<ScalarValue> {
203        // For each group:
204        //   1. if `is_empty` is true, it means either there is no value or all values for the group
205        //      are null, in this case we'll return null
206        //   2. if `is_empty` is false, but `null_state` is true, it means there's an overflow. In
207        //      non-ANSI mode Spark returns null.
208        if self.is_empty || !self.is_not_null {
209            ScalarValue::new_primitive::<Decimal128Type>(
210                None,
211                &DataType::Decimal128(self.precision, self.scale),
212            )
213        } else {
214            ScalarValue::try_new_decimal128(self.sum, self.precision, self.scale)
215        }
216    }
217
218    fn size(&self) -> usize {
219        std::mem::size_of_val(self)
220    }
221
222    fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
223        let sum = if self.is_not_null {
224            ScalarValue::try_new_decimal128(self.sum, self.precision, self.scale)?
225        } else {
226            ScalarValue::new_primitive::<Decimal128Type>(
227                None,
228                &DataType::Decimal128(self.precision, self.scale),
229            )?
230        };
231        Ok(vec![sum, ScalarValue::from(self.is_empty)])
232    }
233
234    fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
235        assert_eq!(
236            states.len(),
237            2,
238            "Expect two element in 'states' but found {}",
239            states.len()
240        );
241        assert_eq!(states[0].len(), 1);
242        assert_eq!(states[1].len(), 1);
243
244        let that_sum = states[0].as_primitive::<Decimal128Type>();
245        let that_is_empty = states[1].as_any().downcast_ref::<BooleanArray>().unwrap();
246
247        let this_overflow = !self.is_empty && !self.is_not_null;
248        let that_overflow = !that_is_empty.value(0) && that_sum.is_null(0);
249
250        self.is_not_null = !this_overflow && !that_overflow;
251        self.is_empty = self.is_empty && that_is_empty.value(0);
252
253        if self.is_not_null {
254            self.sum += that_sum.value(0);
255        }
256
257        Ok(())
258    }
259}
260
261struct SumDecimalGroupsAccumulator {
262    // Whether aggregate buffer for a particular group is null. True indicates it is not null.
263    is_not_null: BooleanBufferBuilder,
264    is_empty: BooleanBufferBuilder,
265    sum: Vec<i128>,
266    result_type: DataType,
267    precision: u8,
268}
269
270impl SumDecimalGroupsAccumulator {
271    fn new(result_type: DataType, precision: u8) -> Self {
272        Self {
273            is_not_null: BooleanBufferBuilder::new(0),
274            is_empty: BooleanBufferBuilder::new(0),
275            sum: Vec::new(),
276            result_type,
277            precision,
278        }
279    }
280
281    fn is_overflow(&self, index: usize) -> bool {
282        !self.is_empty.get_bit(index) && !self.is_not_null.get_bit(index)
283    }
284
285    fn update_single(&mut self, group_index: usize, value: i128) {
286        if unlikely(self.is_overflow(group_index)) {
287            // This means there's a overflow in decimal, so we will just skip the rest
288            // of the computation
289            return;
290        }
291
292        self.is_empty.set_bit(group_index, false);
293        let (new_sum, is_overflow) = self.sum[group_index].overflowing_add(value);
294
295        if is_overflow || !is_valid_decimal_precision(new_sum, self.precision) {
296            // Overflow: set buffer accumulator to null
297            self.is_not_null.set_bit(group_index, false);
298            return;
299        }
300
301        self.sum[group_index] = new_sum;
302        self.is_not_null.set_bit(group_index, true)
303    }
304}
305
306fn ensure_bit_capacity(builder: &mut BooleanBufferBuilder, capacity: usize) {
307    if builder.len() < capacity {
308        let additional = capacity - builder.len();
309        builder.append_n(additional, true);
310    }
311}
312
313/// Build a boolean buffer from the state and reset the state, based on the emit_to
314/// strategy.
315fn build_bool_state(state: &mut BooleanBufferBuilder, emit_to: &EmitTo) -> BooleanBuffer {
316    let bool_state: BooleanBuffer = state.finish();
317
318    match emit_to {
319        EmitTo::All => bool_state,
320        EmitTo::First(n) => {
321            // split off the first N values in bool_state
322            let first_n_bools: BooleanBuffer = bool_state.iter().take(*n).collect();
323            // reset the existing seen buffer
324            for seen in bool_state.iter().skip(*n) {
325                state.append(seen);
326            }
327            first_n_bools
328        }
329    }
330}
331
332impl GroupsAccumulator for SumDecimalGroupsAccumulator {
333    fn update_batch(
334        &mut self,
335        values: &[ArrayRef],
336        group_indices: &[usize],
337        opt_filter: Option<&BooleanArray>,
338        total_num_groups: usize,
339    ) -> DFResult<()> {
340        assert!(opt_filter.is_none(), "opt_filter is not supported yet");
341        assert_eq!(values.len(), 1);
342        let values = values[0].as_primitive::<Decimal128Type>();
343        let data = values.values();
344
345        // Update size for the accumulate states
346        self.sum.resize(total_num_groups, 0);
347        ensure_bit_capacity(&mut self.is_empty, total_num_groups);
348        ensure_bit_capacity(&mut self.is_not_null, total_num_groups);
349
350        let iter = group_indices.iter().zip(data.iter());
351        if values.null_count() == 0 {
352            for (&group_index, &value) in iter {
353                self.update_single(group_index, value);
354            }
355        } else {
356            for (idx, (&group_index, &value)) in iter.enumerate() {
357                if values.is_null(idx) {
358                    continue;
359                }
360                self.update_single(group_index, value);
361            }
362        }
363
364        Ok(())
365    }
366
367    fn evaluate(&mut self, emit_to: EmitTo) -> DFResult<ArrayRef> {
368        // For each group:
369        //   1. if `is_empty` is true, it means either there is no value or all values for the group
370        //      are null, in this case we'll return null
371        //   2. if `is_empty` is false, but `null_state` is true, it means there's an overflow. In
372        //      non-ANSI mode Spark returns null.
373        let nulls = build_bool_state(&mut self.is_not_null, &emit_to);
374        let is_empty = build_bool_state(&mut self.is_empty, &emit_to);
375        let x = (!&is_empty).bitand(&nulls);
376
377        let result = emit_to.take_needed(&mut self.sum);
378        let result = Decimal128Array::new(result.into(), Some(NullBuffer::new(x)))
379            .with_data_type(self.result_type.clone());
380
381        Ok(Arc::new(result))
382    }
383
384    fn state(&mut self, emit_to: EmitTo) -> DFResult<Vec<ArrayRef>> {
385        let nulls = build_bool_state(&mut self.is_not_null, &emit_to);
386        let nulls = Some(NullBuffer::new(nulls));
387
388        let sum = emit_to.take_needed(&mut self.sum);
389        let sum = Decimal128Array::new(sum.into(), nulls.clone())
390            .with_data_type(self.result_type.clone());
391
392        let is_empty = build_bool_state(&mut self.is_empty, &emit_to);
393        let is_empty = BooleanArray::new(is_empty, None);
394
395        Ok(vec![
396            Arc::new(sum) as ArrayRef,
397            Arc::new(is_empty) as ArrayRef,
398        ])
399    }
400
401    fn merge_batch(
402        &mut self,
403        values: &[ArrayRef],
404        group_indices: &[usize],
405        opt_filter: Option<&BooleanArray>,
406        total_num_groups: usize,
407    ) -> DFResult<()> {
408        assert_eq!(
409            values.len(),
410            2,
411            "Expected two arrays: 'sum' and 'is_empty', but found {}",
412            values.len()
413        );
414        assert!(opt_filter.is_none(), "opt_filter is not supported yet");
415
416        // Make sure we have enough capacity for the additional groups
417        self.sum.resize(total_num_groups, 0);
418        ensure_bit_capacity(&mut self.is_empty, total_num_groups);
419        ensure_bit_capacity(&mut self.is_not_null, total_num_groups);
420
421        let that_sum = &values[0];
422        let that_sum = that_sum.as_primitive::<Decimal128Type>();
423        let that_is_empty = &values[1];
424        let that_is_empty = that_is_empty
425            .as_any()
426            .downcast_ref::<BooleanArray>()
427            .unwrap();
428
429        group_indices
430            .iter()
431            .enumerate()
432            .for_each(|(idx, &group_index)| unsafe {
433                let this_overflow = self.is_overflow(group_index);
434                let that_is_empty = that_is_empty.value_unchecked(idx);
435                let that_overflow = !that_is_empty && that_sum.is_null(idx);
436                let is_overflow = this_overflow || that_overflow;
437
438                // This part follows the logic in Spark:
439                //   `org.apache.spark.sql.catalyst.expressions.aggregate.Sum`
440                self.is_not_null.set_bit(group_index, !is_overflow);
441                self.is_empty.set_bit(
442                    group_index,
443                    self.is_empty.get_bit(group_index) && that_is_empty,
444                );
445                if !is_overflow {
446                    // .. otherwise, the sum value for this particular index must not be null,
447                    // and thus we merge both values and update this sum.
448                    self.sum[group_index] += that_sum.value_unchecked(idx);
449                }
450            });
451
452        Ok(())
453    }
454
455    fn size(&self) -> usize {
456        self.sum.capacity() * std::mem::size_of::<i128>()
457            + self.is_empty.capacity() / 8
458            + self.is_not_null.capacity() / 8
459    }
460}
461
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::*;
    use arrow_array::builder::{Decimal128Builder, StringBuilder};
    use arrow_array::RecordBatch;
    use datafusion::datasource::memory::MemorySourceConfig;
    use datafusion::datasource::source::DataSourceExec;
    use datafusion::execution::TaskContext;
    use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy};
    use datafusion::physical_plan::ExecutionPlan;
    use datafusion_common::Result;
    use datafusion_expr::AggregateUDF;
    use datafusion_physical_expr::aggregate::AggregateExprBuilder;
    use datafusion_physical_expr::expressions::Column;
    use datafusion_physical_expr::PhysicalExpr;
    use futures::StreamExt;

    /// A non-decimal result type must be rejected.
    #[test]
    fn invalid_data_type() {
        let outcome = SumDecimal::try_new(DataType::Int32);
        assert!(outcome.is_err());
    }

    /// Runs a partial grouped aggregation over ten copies of the same batch
    /// and checks that every output batch is produced without error.
    #[tokio::test]
    async fn sum_no_overflow() -> Result<()> {
        let num_rows = 8192;
        let batches = vec![create_record_batch(num_rows); 10];
        let partitions = &[batches];
        let schema = partitions[0][0].schema();

        let group_col: Arc<dyn PhysicalExpr> = Arc::new(Column::new("c0", 0));
        let value_col: Arc<dyn PhysicalExpr> = Arc::new(Column::new("c1", 1));

        let source = MemorySourceConfig::try_new(partitions, Arc::clone(&schema), None).unwrap();
        let scan: Arc<dyn ExecutionPlan> = Arc::new(DataSourceExec::new(Arc::new(source)));

        let sum_decimal = SumDecimal::try_new(DataType::Decimal128(8, 2))?;
        let aggregate_udf = Arc::new(AggregateUDF::new_from_impl(sum_decimal));

        let aggr_expr = AggregateExprBuilder::new(aggregate_udf, vec![value_col])
            .schema(Arc::clone(&schema))
            .alias("sum")
            .with_ignore_nulls(false)
            .with_distinct(false)
            .build()?;

        let group_by = PhysicalGroupBy::new_single(vec![(group_col, "c0".to_string())]);
        let aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Partial,
            group_by,
            vec![aggr_expr.into()],
            vec![None], // no filter expressions
            scan,
            Arc::clone(&schema),
        )?);

        let mut stream = aggregate
            .execute(0, Arc::new(TaskContext::default()))
            .unwrap();
        // Drain the stream; any batch-level error fails the test.
        while let Some(batch) = stream.next().await {
            let _batch = batch?;
        }

        Ok(())
    }

    /// Builds a two-column batch: `c0` is a string key cycling through 1024
    /// distinct values and `c1` is a decimal holding the row number.
    fn create_record_batch(num_rows: usize) -> RecordBatch {
        let mut decimals = Decimal128Builder::with_capacity(num_rows);
        let mut strings = StringBuilder::with_capacity(num_rows, num_rows * 32);
        for row in 0..num_rows {
            decimals.append_value(row as i128);
            strings.append_value(format!("this is string #{}", row % 1024));
        }

        // Column order: string key first, decimal value second.
        let columns: Vec<ArrayRef> = vec![
            Arc::new(strings.finish()) as ArrayRef,
            Arc::new(decimals.finish()) as ArrayRef,
        ];
        let schema = Schema::new(vec![
            Field::new("c0", DataType::Utf8, false),
            Field::new("c1", DataType::Decimal128(38, 10), false),
        ]);

        RecordBatch::try_new(Arc::new(schema), columns).unwrap()
    }
}
557}