1use crate::utils::{is_valid_decimal_precision, unlikely};
19use arrow::{
20 array::BooleanBufferBuilder,
21 buffer::{BooleanBuffer, NullBuffer},
22};
23use arrow_array::{
24 cast::AsArray, types::Decimal128Type, Array, ArrayRef, BooleanArray, Decimal128Array,
25};
26use arrow_schema::{DataType, Field};
27use datafusion::logical_expr::{Accumulator, EmitTo, GroupsAccumulator};
28use datafusion_common::{DataFusionError, Result as DFResult, ScalarValue};
29use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
30use datafusion_expr::Volatility::Immutable;
31use datafusion_expr::{AggregateUDFImpl, ReversedUDAF, Signature};
32use std::{any::Any, ops::BitAnd, sync::Arc};
33
34#[derive(Debug)]
35pub struct SumDecimal {
36 signature: Signature,
38 result_type: DataType,
41 precision: u8,
43 scale: i8,
45}
46
47impl SumDecimal {
48 pub fn try_new(data_type: DataType) -> DFResult<Self> {
49 let (precision, scale) = match data_type {
51 DataType::Decimal128(p, s) => (p, s),
52 _ => {
53 return Err(DataFusionError::Internal(
54 "Invalid data type for SumDecimal".into(),
55 ))
56 }
57 };
58 Ok(Self {
59 signature: Signature::user_defined(Immutable),
60 result_type: data_type,
61 precision,
62 scale,
63 })
64 }
65}
66
67impl AggregateUDFImpl for SumDecimal {
68 fn as_any(&self) -> &dyn Any {
69 self
70 }
71
72 fn accumulator(&self, _args: AccumulatorArgs) -> DFResult<Box<dyn Accumulator>> {
73 Ok(Box::new(SumDecimalAccumulator::new(
74 self.precision,
75 self.scale,
76 )))
77 }
78
79 fn state_fields(&self, _args: StateFieldsArgs) -> DFResult<Vec<Field>> {
80 let fields = vec![
81 Field::new(self.name(), self.result_type.clone(), self.is_nullable()),
82 Field::new("is_empty", DataType::Boolean, false),
83 ];
84 Ok(fields)
85 }
86
87 fn name(&self) -> &str {
88 "sum"
89 }
90
91 fn signature(&self) -> &Signature {
92 &self.signature
93 }
94
95 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
96 Ok(self.result_type.clone())
97 }
98
99 fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
100 true
101 }
102
103 fn create_groups_accumulator(
104 &self,
105 _args: AccumulatorArgs,
106 ) -> DFResult<Box<dyn GroupsAccumulator>> {
107 Ok(Box::new(SumDecimalGroupsAccumulator::new(
108 self.result_type.clone(),
109 self.precision,
110 )))
111 }
112
113 fn default_value(&self, _data_type: &DataType) -> DFResult<ScalarValue> {
114 ScalarValue::new_primitive::<Decimal128Type>(
115 None,
116 &DataType::Decimal128(self.precision, self.scale),
117 )
118 }
119
120 fn reverse_expr(&self) -> ReversedUDAF {
121 ReversedUDAF::Identical
122 }
123
124 fn is_nullable(&self) -> bool {
125 true
127 }
128}
129
/// Single-group accumulator state for the decimal sum.
///
/// The combination `!is_empty && !is_not_null` is what the accumulator's
/// methods treat as "the sum has overflowed".
#[derive(Debug)]
struct SumDecimalAccumulator {
    /// Running sum of the values accepted so far.
    sum: i128,
    /// `true` until a batch containing at least one non-null value is seen.
    is_empty: bool,
    /// `false` once an addition overflowed or exceeded the precision.
    is_not_null: bool,

    /// Precision of the resulting decimal.
    precision: u8,
    /// Scale of the resulting decimal.
    scale: i8,
}
139
140impl SumDecimalAccumulator {
141 fn new(precision: u8, scale: i8) -> Self {
142 Self {
143 sum: 0,
144 is_empty: true,
145 is_not_null: true,
146 precision,
147 scale,
148 }
149 }
150
151 fn update_single(&mut self, values: &Decimal128Array, idx: usize) {
152 let v = unsafe { values.value_unchecked(idx) };
153 let (new_sum, is_overflow) = self.sum.overflowing_add(v);
154
155 if is_overflow || !is_valid_decimal_precision(new_sum, self.precision) {
156 self.is_not_null = false;
158 return;
159 }
160
161 self.sum = new_sum;
162 self.is_not_null = true;
163 }
164}
165
166impl Accumulator for SumDecimalAccumulator {
167 fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> {
168 assert_eq!(
169 values.len(),
170 1,
171 "Expect only one element in 'values' but found {}",
172 values.len()
173 );
174
175 if !self.is_empty && !self.is_not_null {
176 return Ok(());
179 }
180
181 let values = &values[0];
182 let data = values.as_primitive::<Decimal128Type>();
183
184 self.is_empty = self.is_empty && values.len() == values.null_count();
185
186 if values.null_count() == 0 {
187 for i in 0..data.len() {
188 self.update_single(data, i);
189 }
190 } else {
191 for i in 0..data.len() {
192 if data.is_null(i) {
193 continue;
194 }
195 self.update_single(data, i);
196 }
197 }
198
199 Ok(())
200 }
201
202 fn evaluate(&mut self) -> DFResult<ScalarValue> {
203 if self.is_empty || !self.is_not_null {
209 ScalarValue::new_primitive::<Decimal128Type>(
210 None,
211 &DataType::Decimal128(self.precision, self.scale),
212 )
213 } else {
214 ScalarValue::try_new_decimal128(self.sum, self.precision, self.scale)
215 }
216 }
217
218 fn size(&self) -> usize {
219 std::mem::size_of_val(self)
220 }
221
222 fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
223 let sum = if self.is_not_null {
224 ScalarValue::try_new_decimal128(self.sum, self.precision, self.scale)?
225 } else {
226 ScalarValue::new_primitive::<Decimal128Type>(
227 None,
228 &DataType::Decimal128(self.precision, self.scale),
229 )?
230 };
231 Ok(vec![sum, ScalarValue::from(self.is_empty)])
232 }
233
234 fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
235 assert_eq!(
236 states.len(),
237 2,
238 "Expect two element in 'states' but found {}",
239 states.len()
240 );
241 assert_eq!(states[0].len(), 1);
242 assert_eq!(states[1].len(), 1);
243
244 let that_sum = states[0].as_primitive::<Decimal128Type>();
245 let that_is_empty = states[1].as_any().downcast_ref::<BooleanArray>().unwrap();
246
247 let this_overflow = !self.is_empty && !self.is_not_null;
248 let that_overflow = !that_is_empty.value(0) && that_sum.is_null(0);
249
250 self.is_not_null = !this_overflow && !that_overflow;
251 self.is_empty = self.is_empty && that_is_empty.value(0);
252
253 if self.is_not_null {
254 self.sum += that_sum.value(0);
255 }
256
257 Ok(())
258 }
259}
260
261struct SumDecimalGroupsAccumulator {
262 is_not_null: BooleanBufferBuilder,
264 is_empty: BooleanBufferBuilder,
265 sum: Vec<i128>,
266 result_type: DataType,
267 precision: u8,
268}
269
270impl SumDecimalGroupsAccumulator {
271 fn new(result_type: DataType, precision: u8) -> Self {
272 Self {
273 is_not_null: BooleanBufferBuilder::new(0),
274 is_empty: BooleanBufferBuilder::new(0),
275 sum: Vec::new(),
276 result_type,
277 precision,
278 }
279 }
280
281 fn is_overflow(&self, index: usize) -> bool {
282 !self.is_empty.get_bit(index) && !self.is_not_null.get_bit(index)
283 }
284
285 fn update_single(&mut self, group_index: usize, value: i128) {
286 if unlikely(self.is_overflow(group_index)) {
287 return;
290 }
291
292 self.is_empty.set_bit(group_index, false);
293 let (new_sum, is_overflow) = self.sum[group_index].overflowing_add(value);
294
295 if is_overflow || !is_valid_decimal_precision(new_sum, self.precision) {
296 self.is_not_null.set_bit(group_index, false);
298 return;
299 }
300
301 self.sum[group_index] = new_sum;
302 self.is_not_null.set_bit(group_index, true)
303 }
304}
305
306fn ensure_bit_capacity(builder: &mut BooleanBufferBuilder, capacity: usize) {
307 if builder.len() < capacity {
308 let additional = capacity - builder.len();
309 builder.append_n(additional, true);
310 }
311}
312
313fn build_bool_state(state: &mut BooleanBufferBuilder, emit_to: &EmitTo) -> BooleanBuffer {
316 let bool_state: BooleanBuffer = state.finish();
317
318 match emit_to {
319 EmitTo::All => bool_state,
320 EmitTo::First(n) => {
321 let first_n_bools: BooleanBuffer = bool_state.iter().take(*n).collect();
323 for seen in bool_state.iter().skip(*n) {
325 state.append(seen);
326 }
327 first_n_bools
328 }
329 }
330}
331
332impl GroupsAccumulator for SumDecimalGroupsAccumulator {
333 fn update_batch(
334 &mut self,
335 values: &[ArrayRef],
336 group_indices: &[usize],
337 opt_filter: Option<&BooleanArray>,
338 total_num_groups: usize,
339 ) -> DFResult<()> {
340 assert!(opt_filter.is_none(), "opt_filter is not supported yet");
341 assert_eq!(values.len(), 1);
342 let values = values[0].as_primitive::<Decimal128Type>();
343 let data = values.values();
344
345 self.sum.resize(total_num_groups, 0);
347 ensure_bit_capacity(&mut self.is_empty, total_num_groups);
348 ensure_bit_capacity(&mut self.is_not_null, total_num_groups);
349
350 let iter = group_indices.iter().zip(data.iter());
351 if values.null_count() == 0 {
352 for (&group_index, &value) in iter {
353 self.update_single(group_index, value);
354 }
355 } else {
356 for (idx, (&group_index, &value)) in iter.enumerate() {
357 if values.is_null(idx) {
358 continue;
359 }
360 self.update_single(group_index, value);
361 }
362 }
363
364 Ok(())
365 }
366
367 fn evaluate(&mut self, emit_to: EmitTo) -> DFResult<ArrayRef> {
368 let nulls = build_bool_state(&mut self.is_not_null, &emit_to);
374 let is_empty = build_bool_state(&mut self.is_empty, &emit_to);
375 let x = (!&is_empty).bitand(&nulls);
376
377 let result = emit_to.take_needed(&mut self.sum);
378 let result = Decimal128Array::new(result.into(), Some(NullBuffer::new(x)))
379 .with_data_type(self.result_type.clone());
380
381 Ok(Arc::new(result))
382 }
383
384 fn state(&mut self, emit_to: EmitTo) -> DFResult<Vec<ArrayRef>> {
385 let nulls = build_bool_state(&mut self.is_not_null, &emit_to);
386 let nulls = Some(NullBuffer::new(nulls));
387
388 let sum = emit_to.take_needed(&mut self.sum);
389 let sum = Decimal128Array::new(sum.into(), nulls.clone())
390 .with_data_type(self.result_type.clone());
391
392 let is_empty = build_bool_state(&mut self.is_empty, &emit_to);
393 let is_empty = BooleanArray::new(is_empty, None);
394
395 Ok(vec![
396 Arc::new(sum) as ArrayRef,
397 Arc::new(is_empty) as ArrayRef,
398 ])
399 }
400
401 fn merge_batch(
402 &mut self,
403 values: &[ArrayRef],
404 group_indices: &[usize],
405 opt_filter: Option<&BooleanArray>,
406 total_num_groups: usize,
407 ) -> DFResult<()> {
408 assert_eq!(
409 values.len(),
410 2,
411 "Expected two arrays: 'sum' and 'is_empty', but found {}",
412 values.len()
413 );
414 assert!(opt_filter.is_none(), "opt_filter is not supported yet");
415
416 self.sum.resize(total_num_groups, 0);
418 ensure_bit_capacity(&mut self.is_empty, total_num_groups);
419 ensure_bit_capacity(&mut self.is_not_null, total_num_groups);
420
421 let that_sum = &values[0];
422 let that_sum = that_sum.as_primitive::<Decimal128Type>();
423 let that_is_empty = &values[1];
424 let that_is_empty = that_is_empty
425 .as_any()
426 .downcast_ref::<BooleanArray>()
427 .unwrap();
428
429 group_indices
430 .iter()
431 .enumerate()
432 .for_each(|(idx, &group_index)| unsafe {
433 let this_overflow = self.is_overflow(group_index);
434 let that_is_empty = that_is_empty.value_unchecked(idx);
435 let that_overflow = !that_is_empty && that_sum.is_null(idx);
436 let is_overflow = this_overflow || that_overflow;
437
438 self.is_not_null.set_bit(group_index, !is_overflow);
441 self.is_empty.set_bit(
442 group_index,
443 self.is_empty.get_bit(group_index) && that_is_empty,
444 );
445 if !is_overflow {
446 self.sum[group_index] += that_sum.value_unchecked(idx);
449 }
450 });
451
452 Ok(())
453 }
454
455 fn size(&self) -> usize {
456 self.sum.capacity() * std::mem::size_of::<i128>()
457 + self.is_empty.capacity() / 8
458 + self.is_not_null.capacity() / 8
459 }
460}
461
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::*;
    use arrow_array::builder::{Decimal128Builder, StringBuilder};
    use arrow_array::RecordBatch;
    use datafusion::datasource::memory::MemorySourceConfig;
    use datafusion::datasource::source::DataSourceExec;
    use datafusion::execution::TaskContext;
    use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy};
    use datafusion::physical_plan::ExecutionPlan;
    use datafusion_common::Result;
    use datafusion_expr::AggregateUDF;
    use datafusion_physical_expr::aggregate::AggregateExprBuilder;
    use datafusion_physical_expr::expressions::Column;
    use datafusion_physical_expr::PhysicalExpr;
    use futures::StreamExt;

    /// A non-decimal result type must be rejected by the constructor.
    #[test]
    fn invalid_data_type() {
        let outcome = SumDecimal::try_new(DataType::Int32);
        assert!(outcome.is_err());
    }

    /// Runs a partial grouped aggregation over ten identical batches and
    /// checks that every output batch is produced without error.
    #[tokio::test]
    async fn sum_no_overflow() -> Result<()> {
        let num_rows = 8192;
        let batches = vec![create_record_batch(num_rows); 10];
        let partitions = &[batches];

        let group_col: Arc<dyn PhysicalExpr> = Arc::new(Column::new("c0", 0));
        let value_col: Arc<dyn PhysicalExpr> = Arc::new(Column::new("c1", 1));

        let data_type = DataType::Decimal128(8, 2);
        let schema = Arc::clone(&partitions[0][0].schema());

        let source = MemorySourceConfig::try_new(partitions, Arc::clone(&schema), None).unwrap();
        let scan: Arc<dyn ExecutionPlan> = Arc::new(DataSourceExec::new(Arc::new(source)));

        let sum_impl = SumDecimal::try_new(data_type.clone())?;
        let aggregate_udf = Arc::new(AggregateUDF::new_from_impl(sum_impl));

        let aggr_expr = AggregateExprBuilder::new(aggregate_udf, vec![value_col])
            .schema(Arc::clone(&schema))
            .alias("sum")
            .with_ignore_nulls(false)
            .with_distinct(false)
            .build()?;

        let group_by = PhysicalGroupBy::new_single(vec![(group_col, "c0".to_string())]);
        let aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Partial,
            group_by,
            vec![aggr_expr.into()],
            vec![None],
            scan,
            Arc::clone(&schema),
        )?);

        let mut stream = aggregate
            .execute(0, Arc::new(TaskContext::default()))
            .unwrap();
        // Drain the stream; any failing batch aborts the test through `?`.
        while let Some(next) = stream.next().await {
            let _batch = next?;
        }

        Ok(())
    }

    /// Builds a two-column batch: `c0` holds strings cycling through 1024
    /// distinct values and `c1` holds the row index as a decimal.
    fn create_record_batch(num_rows: usize) -> RecordBatch {
        let mut decimal_builder = Decimal128Builder::with_capacity(num_rows);
        let mut string_builder = StringBuilder::with_capacity(num_rows, num_rows * 32);
        for row in 0..num_rows {
            decimal_builder.append_value(row as i128);
            string_builder.append_value(format!("this is string #{}", row % 1024));
        }

        let string_array = Arc::new(string_builder.finish());
        let decimal_array = Arc::new(decimal_builder.finish());

        let fields = vec![
            Field::new("c0", DataType::Utf8, false),
            Field::new("c1", DataType::Decimal128(38, 10), false),
        ];
        let columns: Vec<ArrayRef> = vec![string_array as ArrayRef, decimal_array as ArrayRef];

        RecordBatch::try_new(Arc::new(Schema::new(fields)), columns).unwrap()
    }
}