1use arrow::array::{ArrayRef, ArrowNumericType, AsArray, BooleanArray, PrimitiveArray};
19use arrow::datatypes::{
20 DECIMAL128_MAX_PRECISION, DataType, Decimal128Type, Field, FieldRef, Float64Type,
21 Int64Type,
22};
23use datafusion_common::{Result, ScalarValue, downcast_value, exec_err, not_impl_err};
24use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
25use datafusion_expr::utils::format_state_name;
26use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility};
27use std::any::Any;
28use std::fmt::{Debug, Formatter};
29use std::mem::size_of_val;
30
31#[derive(PartialEq, Eq, Hash)]
32pub struct SparkTrySum {
33 signature: Signature,
34}
35
36impl Default for SparkTrySum {
37 fn default() -> Self {
38 Self::new()
39 }
40}
41
42impl SparkTrySum {
43 pub fn new() -> Self {
44 Self {
45 signature: Signature::user_defined(Volatility::Immutable),
46 }
47 }
48}
49
50impl Debug for SparkTrySum {
51 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
52 f.debug_struct("SparkTrySum")
53 .field("signature", &self.signature)
54 .finish()
55 }
56}
57
58struct TrySumAccumulator<T: ArrowNumericType> {
60 sum: Option<T::Native>,
61 data_type: DataType,
62 failed: bool,
63 dec_precision: Option<u8>,
65}
66
67impl<T: ArrowNumericType> Debug for TrySumAccumulator<T> {
68 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
69 write!(f, "TrySumAccumulator({})", self.data_type)
70 }
71}
72
73impl<T: ArrowNumericType> TrySumAccumulator<T> {
74 fn new(data_type: DataType) -> Self {
75 let dec_precision = match &data_type {
76 DataType::Decimal128(p, _) => Some(*p),
77 _ => None,
78 };
79 Self {
80 sum: None,
81 data_type,
82 failed: false,
83 dec_precision,
84 }
85 }
86}
87
88impl<T: ArrowNumericType> Accumulator for TrySumAccumulator<T> {
89 fn state(&mut self) -> Result<Vec<ScalarValue>> {
90 Ok(vec![
91 self.evaluate()?,
92 ScalarValue::Boolean(Some(self.failed)),
93 ])
94 }
95
96 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
97 update_batch_internal(self, values)
98 }
99
100 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
101 if downcast_value!(states[1], BooleanArray)
103 .iter()
104 .flatten()
105 .any(|f| f)
106 {
107 self.failed = true;
108 return Ok(());
109 }
110
111 update_batch_internal(self, states)
113 }
114
115 fn evaluate(&mut self) -> Result<ScalarValue> {
116 evaluate_internal(self)
117 }
118
119 fn size(&self) -> usize {
120 size_of_val(self)
121 }
122}
123
124fn update_batch_internal<T: ArrowNumericType>(
127 acc: &mut TrySumAccumulator<T>,
128 values: &[ArrayRef],
129) -> Result<()> {
130 if values.is_empty() || acc.failed {
131 return Ok(());
132 }
133
134 let array: &PrimitiveArray<T> = values[0].as_primitive::<T>();
135
136 match acc.data_type {
137 DataType::Int64 => update_int64(acc, array),
138 DataType::Float64 => update_float64(acc, array),
139 DataType::Decimal128(_, _) => update_decimal128(acc, array),
140 _ => exec_err!(
141 "try_sum: unsupported type in update_batch: {:?}",
142 acc.data_type
143 ),
144 }
145}
146
147fn update_int64<T: ArrowNumericType>(
148 acc: &mut TrySumAccumulator<T>,
149 array: &PrimitiveArray<T>,
150) -> Result<()> {
151 for v in array.iter().flatten() {
152 let v_i64 = unsafe { std::mem::transmute_copy::<T::Native, i64>(&v) };
154 let sum_i64 = acc
155 .sum
156 .map(|s| unsafe { std::mem::transmute_copy::<T::Native, i64>(&s) });
157
158 let new_sum = match sum_i64 {
159 None => v_i64,
160 Some(s) => match s.checked_add(v_i64) {
161 Some(result) => result,
162 None => {
163 acc.failed = true;
164 return Ok(());
165 }
166 },
167 };
168
169 acc.sum = Some(unsafe { std::mem::transmute_copy::<i64, T::Native>(&new_sum) });
170 }
171 Ok(())
172}
173
174fn update_float64<T: ArrowNumericType>(
175 acc: &mut TrySumAccumulator<T>,
176 array: &PrimitiveArray<T>,
177) -> Result<()> {
178 for v in array.iter().flatten() {
179 let v_f64 = unsafe { std::mem::transmute_copy::<T::Native, f64>(&v) };
180 let sum_f64 = acc
181 .sum
182 .map(|s| unsafe { std::mem::transmute_copy::<T::Native, f64>(&s) })
183 .unwrap_or(0.0);
184 let new_sum = sum_f64 + v_f64;
185 acc.sum = Some(unsafe { std::mem::transmute_copy::<f64, T::Native>(&new_sum) });
186 }
187 Ok(())
188}
189
190fn update_decimal128<T: ArrowNumericType>(
191 acc: &mut TrySumAccumulator<T>,
192 array: &PrimitiveArray<T>,
193) -> Result<()> {
194 let precision = acc.dec_precision.unwrap_or(38);
195
196 for v in array.iter().flatten() {
197 let v_i128 = unsafe { std::mem::transmute_copy::<T::Native, i128>(&v) };
198 let sum_i128 = acc
199 .sum
200 .map(|s| unsafe { std::mem::transmute_copy::<T::Native, i128>(&s) });
201
202 let new_sum = match sum_i128 {
203 None => v_i128,
204 Some(s) => match s.checked_add(v_i128) {
205 Some(result) => result,
206 None => {
207 acc.failed = true;
208 return Ok(());
209 }
210 },
211 };
212
213 if exceeds_decimal128_precision(new_sum, precision) {
214 acc.failed = true;
215 return Ok(());
216 }
217
218 acc.sum = Some(unsafe { std::mem::transmute_copy::<i128, T::Native>(&new_sum) });
219 }
220 Ok(())
221}
222
223fn evaluate_internal<T: ArrowNumericType>(
224 acc: &mut TrySumAccumulator<T>,
225) -> Result<ScalarValue> {
226 if acc.failed {
227 return ScalarValue::new_primitive::<T>(None, &acc.data_type);
228 }
229 ScalarValue::new_primitive::<T>(acc.sum, &acc.data_type)
230}
231
/// Returns `10^p` as an `i128`, or `None` when it does not fit (i.e. `p > 38`).
fn pow10_i128(p: u8) -> Option<i128> {
    10_i128.checked_pow(u32::from(p))
}
240
241fn exceeds_decimal128_precision(sum: i128, p: u8) -> bool {
242 if let Some(max_plus_one) = pow10_i128(p) {
243 let max = max_plus_one - 1;
244 sum > max || sum < -max
245 } else {
246 true
247 }
248}
249
250impl AggregateUDFImpl for SparkTrySum {
251 fn as_any(&self) -> &dyn Any {
252 self
253 }
254
255 fn name(&self) -> &str {
256 "try_sum"
257 }
258
259 fn signature(&self) -> &Signature {
260 &self.signature
261 }
262
263 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
264 use DataType::*;
265
266 let dt = &arg_types[0];
267 let result_type = match dt {
268 Null => Float64,
269 Decimal128(p, s) => {
270 let new_precision = DECIMAL128_MAX_PRECISION.min(p + 10);
271 Decimal128(new_precision, *s)
272 }
273 Int8 | Int16 | Int32 | Int64 => Int64,
274 Float16 | Float32 | Float64 => Float64,
275
276 other => return exec_err!("try_sum: unsupported type: {other:?}"),
277 };
278
279 Ok(result_type)
280 }
281
282 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
283 macro_rules! helper {
284 ($t:ty, $dt:expr) => {
285 Ok(Box::new(TrySumAccumulator::<$t>::new($dt.clone())))
286 };
287 }
288
289 match acc_args.return_field.data_type() {
290 DataType::Int64 => helper!(Int64Type, acc_args.return_field.data_type()),
291 DataType::Float64 => helper!(Float64Type, acc_args.return_field.data_type()),
292 DataType::Decimal128(_, _) => {
293 helper!(Decimal128Type, acc_args.return_field.data_type())
294 }
295 _ => not_impl_err!(
296 "try_sum: unsupported type for accumulator: {}",
297 acc_args.return_field.data_type()
298 ),
299 }
300 }
301
302 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
303 let sum_dt = args.return_field.data_type().clone();
304 Ok(vec![
305 Field::new(format_state_name(args.name, "sum"), sum_dt, true).into(),
306 Field::new(
307 format_state_name(args.name, "failed"),
308 DataType::Boolean,
309 false,
310 )
311 .into(),
312 ])
313 }
314
315 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
316 use DataType::*;
317 if arg_types.len() != 1 {
318 return exec_err!(
319 "try_sum: exactly 1 argument expected, got {}",
320 arg_types.len()
321 );
322 }
323
324 let dt = &arg_types[0];
325 let coerced = match dt {
326 Null => Float64,
327 Decimal128(p, s) => Decimal128(*p, *s),
328 Int8 | Int16 | Int32 | Int64 => Int64,
329 Float16 | Float32 | Float64 => Float64,
330 other => return exec_err!("try_sum: unsupported type: {other:?}"),
331 };
332 Ok(vec![coerced])
333 }
334
335 fn default_value(&self, _data_type: &DataType) -> Result<ScalarValue> {
336 Ok(ScalarValue::Null)
337 }
338}
339
#[cfg(test)]
mod tests {
    use arrow::array::{BooleanArray, Decimal128Array, Float64Array, Int64Array};
    use datafusion_common::{DataFusionError, ScalarValue};
    use std::sync::Arc;

    use super::*;

    /// Builds an `Int64` array from optional values.
    fn int64(values: Vec<Option<i64>>) -> ArrayRef {
        Arc::new(Int64Array::from(values)) as ArrayRef
    }

    /// Builds a `Float64` array from optional values.
    fn f64(values: Vec<Option<f64>>) -> ArrayRef {
        Arc::new(Float64Array::from(values)) as ArrayRef
    }

    /// Builds a `Decimal128(p, s)` array from raw unscaled values.
    fn dec128(p: u8, s: i8, vals: Vec<Option<i128>>) -> Result<ArrayRef> {
        let base = Decimal128Array::from(vals);
        let arr = base.with_precision_and_scale(p, s).map_err(|e| {
            DataFusionError::Execution(format!("invalid precision/scale ({p},{s}): {e}"))
        })?;
        Ok(Arc::new(arr) as ArrayRef)
    }

    #[test]
    fn try_sum_int_basic() -> Result<()> {
        let mut acc = TrySumAccumulator::<Int64Type>::new(DataType::Int64);
        acc.update_batch(&[int64((0..10).map(Some).collect())])?;
        let out = acc.evaluate()?;
        assert_eq!(out, ScalarValue::Int64(Some(45)));
        Ok(())
    }

    #[test]
    fn try_sum_int_with_nulls() -> Result<()> {
        let mut acc = TrySumAccumulator::<Int64Type>::new(DataType::Int64);
        acc.update_batch(&[int64(vec![None, Some(2), Some(3), None, Some(5)])])?;
        let out = acc.evaluate()?;
        assert_eq!(out, ScalarValue::Int64(Some(10)));
        Ok(())
    }

    #[test]
    fn try_sum_float_basic() -> Result<()> {
        let mut acc = TrySumAccumulator::<Float64Type>::new(DataType::Float64);
        acc.update_batch(&[f64(vec![Some(1.5), Some(2.5), None, Some(3.0)])])?;
        let out = acc.evaluate()?;
        assert_eq!(out, ScalarValue::Float64(Some(7.0)));
        Ok(())
    }

    #[test]
    fn float_overflow_behaves_like_spark_sum_infinite() -> Result<()> {
        let mut acc = TrySumAccumulator::<Float64Type>::new(DataType::Float64);
        acc.update_batch(&[f64(vec![Some(1e308), Some(1e308)])])?;

        let out = acc.evaluate()?;
        assert!(
            matches!(out, ScalarValue::Float64(Some(v)) if v.is_infinite() && v.is_sign_positive()),
            "waiting +Infinity, got: {out:?}"
        );
        Ok(())
    }

    #[test]
    fn try_sum_float_negative_zero_normalizes_to_positive_zero() -> Result<()> {
        let mut acc = TrySumAccumulator::<Float64Type>::new(DataType::Float64);
        acc.update_batch(&[f64(vec![Some(-0.0), Some(0.0)])])?;
        let out = acc.evaluate()?;
        assert_eq!(out, ScalarValue::Float64(Some(0.0)));
        // `v == 0.0` also holds for -0.0, so the sign bit must be checked directly.
        match out {
            ScalarValue::Float64(Some(v)) => assert!(v.is_sign_positive(), "got {v:?}"),
            other => panic!("expected a non-null Float64, got {other:?}"),
        }

        // A lone -0.0 must also sum to +0.0.
        let mut acc = TrySumAccumulator::<Float64Type>::new(DataType::Float64);
        acc.update_batch(&[f64(vec![Some(-0.0)])])?;
        match acc.evaluate()? {
            ScalarValue::Float64(Some(v)) => assert!(v.is_sign_positive(), "got {v:?}"),
            other => panic!("expected a non-null Float64, got {other:?}"),
        }
        Ok(())
    }

    #[test]
    fn try_sum_decimal_basic() -> Result<()> {
        let p = 10u8;
        let s = 2i8;
        let mut acc =
            TrySumAccumulator::<Decimal128Type>::new(DataType::Decimal128(p, s));
        acc.update_batch(&[dec128(p, s, vec![Some(123), Some(477)])?])?;
        let out = acc.evaluate()?;
        assert_eq!(out, ScalarValue::Decimal128(Some(600), p, s));
        Ok(())
    }

    #[test]
    fn try_sum_decimal_with_nulls() -> Result<()> {
        let p = 10u8;
        let s = 2i8;
        let mut acc =
            TrySumAccumulator::<Decimal128Type>::new(DataType::Decimal128(p, s));
        acc.update_batch(&[dec128(p, s, vec![Some(150), None, Some(200)])?])?;
        let out = acc.evaluate()?;
        assert_eq!(out, ScalarValue::Decimal128(Some(350), p, s));
        Ok(())
    }

    #[test]
    fn try_sum_decimal_overflow_sets_failed() -> Result<()> {
        let p = 5u8;
        let s = 0i8;
        let mut acc =
            TrySumAccumulator::<Decimal128Type>::new(DataType::Decimal128(p, s));
        acc.update_batch(&[dec128(p, s, vec![Some(90_000), Some(20_000)])?])?;
        let out = acc.evaluate()?;
        assert_eq!(out, ScalarValue::Decimal128(None, p, s));
        assert!(acc.failed);
        Ok(())
    }

    #[test]
    fn try_sum_decimal_merge_ok_and_failure_propagation() -> Result<()> {
        let p = 10u8;
        let s = 2i8;

        let mut p_ok =
            TrySumAccumulator::<Decimal128Type>::new(DataType::Decimal128(p, s));
        p_ok.update_batch(&[dec128(p, s, vec![Some(100), Some(200)])?])?;
        let s_ok = p_ok
            .state()?
            .into_iter()
            .map(|sv| sv.to_array())
            .collect::<Result<Vec<_>>>()?;

        let mut p_fail =
            TrySumAccumulator::<Decimal128Type>::new(DataType::Decimal128(p, s));
        p_fail.update_batch(&[dec128(p, s, vec![Some(i128::MAX), Some(1)])?])?;
        let s_fail = p_fail
            .state()?
            .into_iter()
            .map(|sv| sv.to_array())
            .collect::<Result<Vec<_>>>()?;

        let mut final_acc =
            TrySumAccumulator::<Decimal128Type>::new(DataType::Decimal128(p, s));
        final_acc.merge_batch(&s_ok)?;
        final_acc.merge_batch(&s_fail)?;

        assert!(final_acc.failed);
        assert_eq!(final_acc.evaluate()?, ScalarValue::Decimal128(None, p, s));
        Ok(())
    }

    #[test]
    fn try_sum_int_overflow_sets_failed() -> Result<()> {
        let mut acc = TrySumAccumulator::<Int64Type>::new(DataType::Int64);
        acc.update_batch(&[int64(vec![Some(i64::MAX), Some(1)])])?;
        let out = acc.evaluate()?;
        assert_eq!(out, ScalarValue::Int64(None));
        assert!(acc.failed);
        Ok(())
    }

    #[test]
    fn try_sum_int_negative_overflow_sets_failed() -> Result<()> {
        let mut acc = TrySumAccumulator::<Int64Type>::new(DataType::Int64);
        acc.update_batch(&[int64(vec![Some(i64::MIN), Some(-1)])])?;
        assert_eq!(acc.evaluate()?, ScalarValue::Int64(None));
        assert!(acc.failed);
        Ok(())
    }

    #[test]
    fn try_sum_state_two_fields_and_merge_ok() -> Result<()> {
        let mut acc1 = TrySumAccumulator::<Int64Type>::new(DataType::Int64);
        acc1.update_batch(&[int64(vec![Some(10), Some(5)])])?;
        let state1 = acc1.state()?;
        assert_eq!(state1.len(), 2);

        let mut acc2 = TrySumAccumulator::<Int64Type>::new(DataType::Int64);
        acc2.update_batch(&[int64(vec![Some(20), None])])?;
        let state2 = acc2.state()?;
        assert_eq!(state2.len(), 2);

        let state1_arrays: Vec<ArrayRef> = state1
            .into_iter()
            .map(|sv| sv.to_array())
            .collect::<Result<_>>()?;

        let state2_arrays: Vec<ArrayRef> = state2
            .into_iter()
            .map(|sv| sv.to_array())
            .collect::<Result<_>>()?;

        let mut final_acc = TrySumAccumulator::<Int64Type>::new(DataType::Int64);

        final_acc.merge_batch(&state1_arrays)?;
        final_acc.merge_batch(&state2_arrays)?;

        assert!(!final_acc.failed);
        assert_eq!(final_acc.evaluate()?, ScalarValue::Int64(Some(35)));
        Ok(())
    }

    #[test]
    fn try_sum_merge_propagates_failure() -> Result<()> {
        let failed_sum = Arc::new(Int64Array::from(vec![None])) as ArrayRef;
        let failed_flag = Arc::new(BooleanArray::from(vec![Some(true)])) as ArrayRef;

        let mut acc = TrySumAccumulator::<Int64Type>::new(DataType::Int64);
        acc.merge_batch(&[failed_sum, failed_flag])?;

        assert!(acc.failed);
        assert_eq!(acc.evaluate()?, ScalarValue::Int64(None));
        Ok(())
    }

    #[test]
    fn try_sum_merge_empty_partition_is_not_failure() -> Result<()> {
        let empty_sum = Arc::new(Int64Array::from(vec![None])) as ArrayRef;
        let ok_flag = Arc::new(BooleanArray::from(vec![Some(false)])) as ArrayRef;

        let mut acc = TrySumAccumulator::<Int64Type>::new(DataType::Int64);
        acc.update_batch(&[int64(vec![Some(7), Some(8)])])?;
        acc.merge_batch(&[empty_sum, ok_flag])?;

        assert!(!acc.failed);
        assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(15)));
        Ok(())
    }

    #[test]
    fn try_sum_return_type_matches_input() -> Result<()> {
        let f = SparkTrySum::new();
        assert_eq!(f.return_type(&[DataType::Int64])?, DataType::Int64);
        assert_eq!(f.return_type(&[DataType::Float64])?, DataType::Float64);
        Ok(())
    }

    #[test]
    fn try_sum_state_and_evaluate_consistency() -> Result<()> {
        let mut acc = TrySumAccumulator::<Float64Type>::new(DataType::Float64);
        acc.update_batch(&[f64(vec![Some(1.0), Some(2.0)])])?;
        let eval = acc.evaluate()?;
        let state = acc.state()?;
        assert_eq!(state[0], eval);
        assert_eq!(state[1], ScalarValue::Boolean(Some(false)));
        Ok(())
    }

    #[test]
    fn decimal_10_2_sum_and_schema_widened() -> Result<()> {
        let f = SparkTrySum::new();
        assert_eq!(
            f.return_type(&[DataType::Decimal128(10, 2)])?,
            DataType::Decimal128(20, 2),
            "Spark needs +10 more digits of precision"
        );

        let mut acc =
            TrySumAccumulator::<Decimal128Type>::new(DataType::Decimal128(20, 2));
        acc.update_batch(&[dec128(10, 2, vec![Some(123), Some(477)])?])?;
        assert_eq!(acc.evaluate()?, ScalarValue::Decimal128(Some(600), 20, 2));
        Ok(())
    }

    #[test]
    fn decimal_5_0_fits_after_widening() -> Result<()> {
        let f = SparkTrySum::new();
        assert_eq!(
            f.return_type(&[DataType::Decimal128(5, 0)])?,
            DataType::Decimal128(15, 0)
        );

        let mut acc =
            TrySumAccumulator::<Decimal128Type>::new(DataType::Decimal128(15, 0));
        acc.update_batch(&[dec128(5, 0, vec![Some(90_000), Some(20_000)])?])?;
        assert_eq!(
            acc.evaluate()?,
            ScalarValue::Decimal128(Some(110_000), 15, 0)
        );
        Ok(())
    }

    #[test]
    fn decimal_38_0_max_precision_overflows_to_null() -> Result<()> {
        let f = SparkTrySum::new();
        assert_eq!(
            f.return_type(&[DataType::Decimal128(38, 0)])?,
            DataType::Decimal128(38, 0)
        );
        let ten_pow_38_minus_1 = {
            let p10 = pow10_i128(38)
                .ok_or_else(|| DataFusionError::Internal("10^38 overflow".into()))?;
            p10 - 1
        };
        let mut acc =
            TrySumAccumulator::<Decimal128Type>::new(DataType::Decimal128(38, 0));
        acc.update_batch(&[dec128(38, 0, vec![Some(ten_pow_38_minus_1), Some(1)])?])?;

        assert!(acc.failed, "need fail in overflow p=38");
        assert_eq!(acc.evaluate()?, ScalarValue::Decimal128(None, 38, 0));
        Ok(())
    }
}