1use crate::utils::{is_valid_decimal_precision, unlikely};
19use arrow::array::{
20 cast::AsArray, types::Decimal128Type, Array, ArrayRef, BooleanArray, Decimal128Array,
21};
22use arrow::datatypes::{DataType, Field, FieldRef};
23use arrow::{
24 array::BooleanBufferBuilder,
25 buffer::{BooleanBuffer, NullBuffer},
26};
27use datafusion::common::{DataFusionError, Result as DFResult, ScalarValue};
28use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs};
29use datafusion::logical_expr::Volatility::Immutable;
30use datafusion::logical_expr::{
31 Accumulator, AggregateUDFImpl, EmitTo, GroupsAccumulator, ReversedUDAF, Signature,
32};
33use std::{any::Any, ops::BitAnd, sync::Arc};
34
35#[derive(Debug)]
36pub struct SumDecimal {
37 signature: Signature,
39 result_type: DataType,
42 precision: u8,
44 scale: i8,
46}
47
48impl SumDecimal {
49 pub fn try_new(data_type: DataType) -> DFResult<Self> {
50 let (precision, scale) = match data_type {
52 DataType::Decimal128(p, s) => (p, s),
53 _ => {
54 return Err(DataFusionError::Internal(
55 "Invalid data type for SumDecimal".into(),
56 ))
57 }
58 };
59 Ok(Self {
60 signature: Signature::user_defined(Immutable),
61 result_type: data_type,
62 precision,
63 scale,
64 })
65 }
66}
67
68impl AggregateUDFImpl for SumDecimal {
69 fn as_any(&self) -> &dyn Any {
70 self
71 }
72
73 fn accumulator(&self, _args: AccumulatorArgs) -> DFResult<Box<dyn Accumulator>> {
74 Ok(Box::new(SumDecimalAccumulator::new(
75 self.precision,
76 self.scale,
77 )))
78 }
79
80 fn state_fields(&self, _args: StateFieldsArgs) -> DFResult<Vec<FieldRef>> {
81 let fields = vec![
82 Arc::new(Field::new(
83 self.name(),
84 self.result_type.clone(),
85 self.is_nullable(),
86 )),
87 Arc::new(Field::new("is_empty", DataType::Boolean, false)),
88 ];
89 Ok(fields)
90 }
91
92 fn name(&self) -> &str {
93 "sum"
94 }
95
96 fn signature(&self) -> &Signature {
97 &self.signature
98 }
99
100 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
101 Ok(self.result_type.clone())
102 }
103
104 fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
105 true
106 }
107
108 fn create_groups_accumulator(
109 &self,
110 _args: AccumulatorArgs,
111 ) -> DFResult<Box<dyn GroupsAccumulator>> {
112 Ok(Box::new(SumDecimalGroupsAccumulator::new(
113 self.result_type.clone(),
114 self.precision,
115 )))
116 }
117
118 fn default_value(&self, _data_type: &DataType) -> DFResult<ScalarValue> {
119 ScalarValue::new_primitive::<Decimal128Type>(
120 None,
121 &DataType::Decimal128(self.precision, self.scale),
122 )
123 }
124
125 fn reverse_expr(&self) -> ReversedUDAF {
126 ReversedUDAF::Identical
127 }
128
129 fn is_nullable(&self) -> bool {
130 true
132 }
133}
134
135#[derive(Debug)]
136struct SumDecimalAccumulator {
137 sum: i128,
138 is_empty: bool,
139 is_not_null: bool,
140
141 precision: u8,
142 scale: i8,
143}
144
145impl SumDecimalAccumulator {
146 fn new(precision: u8, scale: i8) -> Self {
147 Self {
148 sum: 0,
149 is_empty: true,
150 is_not_null: true,
151 precision,
152 scale,
153 }
154 }
155
156 fn update_single(&mut self, values: &Decimal128Array, idx: usize) {
157 let v = unsafe { values.value_unchecked(idx) };
158 let (new_sum, is_overflow) = self.sum.overflowing_add(v);
159
160 if is_overflow || !is_valid_decimal_precision(new_sum, self.precision) {
161 self.is_not_null = false;
163 return;
164 }
165
166 self.sum = new_sum;
167 self.is_not_null = true;
168 }
169}
170
171impl Accumulator for SumDecimalAccumulator {
172 fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> {
173 assert_eq!(
174 values.len(),
175 1,
176 "Expect only one element in 'values' but found {}",
177 values.len()
178 );
179
180 if !self.is_empty && !self.is_not_null {
181 return Ok(());
184 }
185
186 let values = &values[0];
187 let data = values.as_primitive::<Decimal128Type>();
188
189 self.is_empty = self.is_empty && values.len() == values.null_count();
190
191 if values.null_count() == 0 {
192 for i in 0..data.len() {
193 self.update_single(data, i);
194 }
195 } else {
196 for i in 0..data.len() {
197 if data.is_null(i) {
198 continue;
199 }
200 self.update_single(data, i);
201 }
202 }
203
204 Ok(())
205 }
206
207 fn evaluate(&mut self) -> DFResult<ScalarValue> {
208 if self.is_empty
214 || !self.is_not_null
215 || !is_valid_decimal_precision(self.sum, self.precision)
216 {
217 ScalarValue::new_primitive::<Decimal128Type>(
218 None,
219 &DataType::Decimal128(self.precision, self.scale),
220 )
221 } else {
222 ScalarValue::try_new_decimal128(self.sum, self.precision, self.scale)
223 }
224 }
225
226 fn size(&self) -> usize {
227 std::mem::size_of_val(self)
228 }
229
230 fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
231 let sum = if self.is_not_null {
232 ScalarValue::try_new_decimal128(self.sum, self.precision, self.scale)?
233 } else {
234 ScalarValue::new_primitive::<Decimal128Type>(
235 None,
236 &DataType::Decimal128(self.precision, self.scale),
237 )?
238 };
239 Ok(vec![sum, ScalarValue::from(self.is_empty)])
240 }
241
242 fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
243 assert_eq!(
244 states.len(),
245 2,
246 "Expect two element in 'states' but found {}",
247 states.len()
248 );
249 assert_eq!(states[0].len(), 1);
250 assert_eq!(states[1].len(), 1);
251
252 let that_sum = states[0].as_primitive::<Decimal128Type>();
253 let that_is_empty = states[1].as_any().downcast_ref::<BooleanArray>().unwrap();
254
255 let this_overflow = !self.is_empty && !self.is_not_null;
256 let that_overflow = !that_is_empty.value(0) && that_sum.is_null(0);
257
258 self.is_not_null = !this_overflow && !that_overflow;
259 self.is_empty = self.is_empty && that_is_empty.value(0);
260
261 if self.is_not_null {
262 self.sum += that_sum.value(0);
263 }
264
265 Ok(())
266 }
267}
268
269struct SumDecimalGroupsAccumulator {
270 is_not_null: BooleanBufferBuilder,
272 is_empty: BooleanBufferBuilder,
273 sum: Vec<i128>,
274 result_type: DataType,
275 precision: u8,
276}
277
278impl SumDecimalGroupsAccumulator {
279 fn new(result_type: DataType, precision: u8) -> Self {
280 Self {
281 is_not_null: BooleanBufferBuilder::new(0),
282 is_empty: BooleanBufferBuilder::new(0),
283 sum: Vec::new(),
284 result_type,
285 precision,
286 }
287 }
288
289 fn is_overflow(&self, index: usize) -> bool {
290 !self.is_empty.get_bit(index) && !self.is_not_null.get_bit(index)
291 }
292
293 fn update_single(&mut self, group_index: usize, value: i128) {
294 if unlikely(self.is_overflow(group_index)) {
295 return;
298 }
299
300 self.is_empty.set_bit(group_index, false);
301 let (new_sum, is_overflow) = self.sum[group_index].overflowing_add(value);
302
303 if is_overflow || !is_valid_decimal_precision(new_sum, self.precision) {
304 self.is_not_null.set_bit(group_index, false);
306 return;
307 }
308
309 self.sum[group_index] = new_sum;
310 self.is_not_null.set_bit(group_index, true)
311 }
312}
313
314fn ensure_bit_capacity(builder: &mut BooleanBufferBuilder, capacity: usize) {
315 if builder.len() < capacity {
316 let additional = capacity - builder.len();
317 builder.append_n(additional, true);
318 }
319}
320
321fn build_bool_state(state: &mut BooleanBufferBuilder, emit_to: &EmitTo) -> BooleanBuffer {
324 let bool_state: BooleanBuffer = state.finish();
325
326 match emit_to {
327 EmitTo::All => bool_state,
328 EmitTo::First(n) => {
329 let first_n_bools: BooleanBuffer = bool_state.iter().take(*n).collect();
331 for seen in bool_state.iter().skip(*n) {
333 state.append(seen);
334 }
335 first_n_bools
336 }
337 }
338}
339
340impl GroupsAccumulator for SumDecimalGroupsAccumulator {
341 fn update_batch(
342 &mut self,
343 values: &[ArrayRef],
344 group_indices: &[usize],
345 opt_filter: Option<&BooleanArray>,
346 total_num_groups: usize,
347 ) -> DFResult<()> {
348 assert!(opt_filter.is_none(), "opt_filter is not supported yet");
349 assert_eq!(values.len(), 1);
350 let values = values[0].as_primitive::<Decimal128Type>();
351 let data = values.values();
352
353 self.sum.resize(total_num_groups, 0);
355 ensure_bit_capacity(&mut self.is_empty, total_num_groups);
356 ensure_bit_capacity(&mut self.is_not_null, total_num_groups);
357
358 let iter = group_indices.iter().zip(data.iter());
359 if values.null_count() == 0 {
360 for (&group_index, &value) in iter {
361 self.update_single(group_index, value);
362 }
363 } else {
364 for (idx, (&group_index, &value)) in iter.enumerate() {
365 if values.is_null(idx) {
366 continue;
367 }
368 self.update_single(group_index, value);
369 }
370 }
371
372 Ok(())
373 }
374
375 fn evaluate(&mut self, emit_to: EmitTo) -> DFResult<ArrayRef> {
376 let result = emit_to.take_needed(&mut self.sum);
382 result.iter().enumerate().for_each(|(i, &v)| {
383 if !is_valid_decimal_precision(v, self.precision) {
384 self.is_not_null.set_bit(i, false);
385 }
386 });
387
388 let nulls = build_bool_state(&mut self.is_not_null, &emit_to);
389 let is_empty = build_bool_state(&mut self.is_empty, &emit_to);
390 let x = (!&is_empty).bitand(&nulls);
391
392 let result = Decimal128Array::new(result.into(), Some(NullBuffer::new(x)))
393 .with_data_type(self.result_type.clone());
394
395 Ok(Arc::new(result))
396 }
397
398 fn state(&mut self, emit_to: EmitTo) -> DFResult<Vec<ArrayRef>> {
399 let nulls = build_bool_state(&mut self.is_not_null, &emit_to);
400 let nulls = Some(NullBuffer::new(nulls));
401
402 let sum = emit_to.take_needed(&mut self.sum);
403 let sum = Decimal128Array::new(sum.into(), nulls.clone())
404 .with_data_type(self.result_type.clone());
405
406 let is_empty = build_bool_state(&mut self.is_empty, &emit_to);
407 let is_empty = BooleanArray::new(is_empty, None);
408
409 Ok(vec![
410 Arc::new(sum) as ArrayRef,
411 Arc::new(is_empty) as ArrayRef,
412 ])
413 }
414
415 fn merge_batch(
416 &mut self,
417 values: &[ArrayRef],
418 group_indices: &[usize],
419 opt_filter: Option<&BooleanArray>,
420 total_num_groups: usize,
421 ) -> DFResult<()> {
422 assert_eq!(
423 values.len(),
424 2,
425 "Expected two arrays: 'sum' and 'is_empty', but found {}",
426 values.len()
427 );
428 assert!(opt_filter.is_none(), "opt_filter is not supported yet");
429
430 self.sum.resize(total_num_groups, 0);
432 ensure_bit_capacity(&mut self.is_empty, total_num_groups);
433 ensure_bit_capacity(&mut self.is_not_null, total_num_groups);
434
435 let that_sum = &values[0];
436 let that_sum = that_sum.as_primitive::<Decimal128Type>();
437 let that_is_empty = &values[1];
438 let that_is_empty = that_is_empty
439 .as_any()
440 .downcast_ref::<BooleanArray>()
441 .unwrap();
442
443 group_indices
444 .iter()
445 .enumerate()
446 .for_each(|(idx, &group_index)| unsafe {
447 let this_overflow = self.is_overflow(group_index);
448 let that_is_empty = that_is_empty.value_unchecked(idx);
449 let that_overflow = !that_is_empty && that_sum.is_null(idx);
450 let is_overflow = this_overflow || that_overflow;
451
452 self.is_not_null.set_bit(group_index, !is_overflow);
455 self.is_empty.set_bit(
456 group_index,
457 self.is_empty.get_bit(group_index) && that_is_empty,
458 );
459 if !is_overflow {
460 self.sum[group_index] += that_sum.value_unchecked(idx);
463 }
464 });
465
466 Ok(())
467 }
468
469 fn size(&self) -> usize {
470 self.sum.capacity() * std::mem::size_of::<i128>()
471 + self.is_empty.capacity() / 8
472 + self.is_not_null.capacity() / 8
473 }
474}
475
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::builder::{Decimal128Builder, StringBuilder};
    use arrow::array::RecordBatch;
    use arrow::datatypes::*;
    use datafusion::common::Result;
    use datafusion::datasource::memory::MemorySourceConfig;
    use datafusion::datasource::source::DataSourceExec;
    use datafusion::execution::TaskContext;
    use datafusion::logical_expr::AggregateUDF;
    use datafusion::physical_expr::aggregate::AggregateExprBuilder;
    use datafusion::physical_expr::expressions::Column;
    use datafusion::physical_expr::PhysicalExpr;
    use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy};
    use datafusion::physical_plan::ExecutionPlan;
    use futures::StreamExt;

    /// `SumDecimal::try_new` rejects non-`Decimal128` types.
    #[test]
    fn invalid_data_type() {
        assert!(SumDecimal::try_new(DataType::Int32).is_err());
    }

    /// Runs a `Partial` grouped aggregation (`sum(c1) GROUP BY c0`) over ten copies of the
    /// same 8192-row batch and drains the output stream.
    ///
    /// NOTE(review): this only checks that execution completes without error; it never
    /// inspects the produced sums. It also declares the aggregate as `Decimal128(8, 2)`
    /// while column `c1` is `Decimal128(38, 10)`. Confirm the mismatch is intentional.
    #[tokio::test]
    async fn sum_no_overflow() -> Result<()> {
        let num_rows = 8192;
        let batch = create_record_batch(num_rows);
        let mut batches = Vec::new();
        for _ in 0..10 {
            batches.push(batch.clone());
        }
        // A single partition containing all ten batches.
        let partitions = &[batches];
        let c0: Arc<dyn PhysicalExpr> = Arc::new(Column::new("c0", 0));
        let c1: Arc<dyn PhysicalExpr> = Arc::new(Column::new("c1", 1));

        let data_type = DataType::Decimal128(8, 2);
        let schema = Arc::clone(&partitions[0][0].schema());
        let scan: Arc<dyn ExecutionPlan> = Arc::new(DataSourceExec::new(Arc::new(
            MemorySourceConfig::try_new(partitions, Arc::clone(&schema), None).unwrap(),
        )));

        let aggregate_udf = Arc::new(AggregateUDF::new_from_impl(SumDecimal::try_new(
            data_type.clone(),
        )?));

        let aggr_expr = AggregateExprBuilder::new(aggregate_udf, vec![c1])
            .schema(Arc::clone(&schema))
            .alias("sum")
            .with_ignore_nulls(false)
            .with_distinct(false)
            .build()?;

        let aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Partial,
            PhysicalGroupBy::new_single(vec![(c0, "c0".to_string())]),
            vec![aggr_expr.into()],
            // No per-aggregate FILTER clause.
            vec![None],
            scan,
            Arc::clone(&schema),
        )?);

        let mut stream = aggregate
            .execute(0, Arc::new(TaskContext::default()))
            .unwrap();
        // Drain the stream; any execution error is propagated by `?`.
        while let Some(batch) = stream.next().await {
            let _batch = batch?;
        }

        Ok(())
    }

    /// Builds a batch with `c0`: non-null Utf8 group keys (`i % 1024`, so 1024 distinct
    /// keys) and `c1`: non-null `Decimal128(38, 10)` with unscaled values `0..num_rows`.
    fn create_record_batch(num_rows: usize) -> RecordBatch {
        let mut decimal_builder = Decimal128Builder::with_capacity(num_rows);
        let mut string_builder = StringBuilder::with_capacity(num_rows, num_rows * 32);
        for i in 0..num_rows {
            decimal_builder.append_value(i as i128);
            string_builder.append_value(format!("this is string #{}", i % 1024));
        }
        let decimal_array = Arc::new(decimal_builder.finish());
        let string_array = Arc::new(string_builder.finish());

        let mut fields = vec![];
        let mut columns: Vec<ArrayRef> = vec![];

        // Column 0: grouping key.
        fields.push(Field::new("c0", DataType::Utf8, false));
        columns.push(string_array);

        // Column 1: value to sum.
        fields.push(Field::new("c1", DataType::Decimal128(38, 10), false));
        columns.push(decimal_array);

        let schema = Schema::new(fields);
        RecordBatch::try_new(Arc::new(schema), columns).unwrap()
    }
}