1use arrow::array::{ArrayRef, ArrowNumericType, AsArray, BooleanArray, PrimitiveArray};
19use arrow::datatypes::{
20 DECIMAL128_MAX_PRECISION, DataType, Decimal128Type, Field, FieldRef, Float64Type,
21 Int64Type,
22};
23use datafusion_common::{Result, ScalarValue, downcast_value, exec_err, not_impl_err};
24use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
25use datafusion_expr::utils::format_state_name;
26use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility};
27use std::fmt::{Debug, Formatter};
28use std::mem::size_of_val;
29
/// Spark-compatible `try_sum` aggregate function.
///
/// Sums a single numeric argument. Instead of raising an error when the
/// running total overflows (checked `i64` addition for integers, precision /
/// `i128` overflow for Decimal128), the result becomes NULL. Float64 sums use
/// plain IEEE addition and may become infinite.
#[derive(PartialEq, Eq, Hash)]
pub struct SparkTrySum {
    /// User-defined signature; accepted argument types are decided by
    /// `coerce_types`.
    signature: Signature,
}
34
35impl Default for SparkTrySum {
36 fn default() -> Self {
37 Self::new()
38 }
39}
40
41impl SparkTrySum {
42 pub fn new() -> Self {
43 Self {
44 signature: Signature::user_defined(Volatility::Immutable),
45 }
46 }
47}
48
49impl Debug for SparkTrySum {
50 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
51 f.debug_struct("SparkTrySum")
52 .field("signature", &self.signature)
53 .finish()
54 }
55}
56
/// Running state of a `try_sum` aggregation over Arrow primitive type `T`.
///
/// `T` is expected to agree with `data_type` (Int64Type / Float64Type /
/// Decimal128Type); `SparkTrySum::accumulator` constructs it that way.
struct TrySumAccumulator<T: ArrowNumericType> {
    /// Sum of the non-null inputs seen so far; `None` until the first one.
    sum: Option<T::Native>,
    /// Result type; selects the update routine and the output scalar type.
    data_type: DataType,
    /// Set once an overflow is detected. Later updates are ignored and the
    /// result is NULL.
    failed: bool,
    /// Precision of `data_type` when it is Decimal128, otherwise `None`.
    dec_precision: Option<u8>,
}
65
66impl<T: ArrowNumericType> Debug for TrySumAccumulator<T> {
67 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
68 write!(f, "TrySumAccumulator({})", self.data_type)
69 }
70}
71
72impl<T: ArrowNumericType> TrySumAccumulator<T> {
73 fn new(data_type: DataType) -> Self {
74 let dec_precision = match &data_type {
75 DataType::Decimal128(p, _) => Some(*p),
76 _ => None,
77 };
78 Self {
79 sum: None,
80 data_type,
81 failed: false,
82 dec_precision,
83 }
84 }
85}
86
87impl<T: ArrowNumericType> Accumulator for TrySumAccumulator<T> {
88 fn state(&mut self) -> Result<Vec<ScalarValue>> {
89 Ok(vec![
90 self.evaluate()?,
91 ScalarValue::Boolean(Some(self.failed)),
92 ])
93 }
94
95 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
96 update_batch_internal(self, values)
97 }
98
99 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
100 if downcast_value!(states[1], BooleanArray)
102 .iter()
103 .flatten()
104 .any(|f| f)
105 {
106 self.failed = true;
107 return Ok(());
108 }
109
110 update_batch_internal(self, states)
112 }
113
114 fn evaluate(&mut self) -> Result<ScalarValue> {
115 evaluate_internal(self)
116 }
117
118 fn size(&self) -> usize {
119 size_of_val(self)
120 }
121}
122
123fn update_batch_internal<T: ArrowNumericType>(
126 acc: &mut TrySumAccumulator<T>,
127 values: &[ArrayRef],
128) -> Result<()> {
129 if values.is_empty() || acc.failed {
130 return Ok(());
131 }
132
133 let array: &PrimitiveArray<T> = values[0].as_primitive::<T>();
134
135 match acc.data_type {
136 DataType::Int64 => update_int64(acc, array),
137 DataType::Float64 => update_float64(acc, array),
138 DataType::Decimal128(_, _) => update_decimal128(acc, array),
139 _ => exec_err!(
140 "try_sum: unsupported type in update_batch: {:?}",
141 acc.data_type
142 ),
143 }
144}
145
146fn update_int64<T: ArrowNumericType>(
147 acc: &mut TrySumAccumulator<T>,
148 array: &PrimitiveArray<T>,
149) -> Result<()> {
150 for v in array.iter().flatten() {
151 let v_i64 = unsafe { std::mem::transmute_copy::<T::Native, i64>(&v) };
153 let sum_i64 = acc
154 .sum
155 .map(|s| unsafe { std::mem::transmute_copy::<T::Native, i64>(&s) });
156
157 let new_sum = match sum_i64 {
158 None => v_i64,
159 Some(s) => match s.checked_add(v_i64) {
160 Some(result) => result,
161 None => {
162 acc.failed = true;
163 return Ok(());
164 }
165 },
166 };
167
168 acc.sum = Some(unsafe { std::mem::transmute_copy::<i64, T::Native>(&new_sum) });
169 }
170 Ok(())
171}
172
173fn update_float64<T: ArrowNumericType>(
174 acc: &mut TrySumAccumulator<T>,
175 array: &PrimitiveArray<T>,
176) -> Result<()> {
177 for v in array.iter().flatten() {
178 let v_f64 = unsafe { std::mem::transmute_copy::<T::Native, f64>(&v) };
179 let sum_f64 = acc
180 .sum
181 .map(|s| unsafe { std::mem::transmute_copy::<T::Native, f64>(&s) })
182 .unwrap_or(0.0);
183 let new_sum = sum_f64 + v_f64;
184 acc.sum = Some(unsafe { std::mem::transmute_copy::<f64, T::Native>(&new_sum) });
185 }
186 Ok(())
187}
188
189fn update_decimal128<T: ArrowNumericType>(
190 acc: &mut TrySumAccumulator<T>,
191 array: &PrimitiveArray<T>,
192) -> Result<()> {
193 let precision = acc.dec_precision.unwrap_or(38);
194
195 for v in array.iter().flatten() {
196 let v_i128 = unsafe { std::mem::transmute_copy::<T::Native, i128>(&v) };
197 let sum_i128 = acc
198 .sum
199 .map(|s| unsafe { std::mem::transmute_copy::<T::Native, i128>(&s) });
200
201 let new_sum = match sum_i128 {
202 None => v_i128,
203 Some(s) => match s.checked_add(v_i128) {
204 Some(result) => result,
205 None => {
206 acc.failed = true;
207 return Ok(());
208 }
209 },
210 };
211
212 if exceeds_decimal128_precision(new_sum, precision) {
213 acc.failed = true;
214 return Ok(());
215 }
216
217 acc.sum = Some(unsafe { std::mem::transmute_copy::<i128, T::Native>(&new_sum) });
218 }
219 Ok(())
220}
221
222fn evaluate_internal<T: ArrowNumericType>(
223 acc: &mut TrySumAccumulator<T>,
224) -> Result<ScalarValue> {
225 if acc.failed {
226 return ScalarValue::new_primitive::<T>(None, &acc.data_type);
227 }
228 ScalarValue::new_primitive::<T>(acc.sum, &acc.data_type)
229}
230
/// Returns `10^p` as an `i128`, or `None` when it does not fit (any `p > 38`).
fn pow10_i128(p: u8) -> Option<i128> {
    10i128.checked_pow(u32::from(p))
}
239
240fn exceeds_decimal128_precision(sum: i128, p: u8) -> bool {
241 if let Some(max_plus_one) = pow10_i128(p) {
242 let max = max_plus_one - 1;
243 sum > max || sum < -max
244 } else {
245 true
246 }
247}
248
249impl AggregateUDFImpl for SparkTrySum {
250 fn name(&self) -> &str {
251 "try_sum"
252 }
253
254 fn signature(&self) -> &Signature {
255 &self.signature
256 }
257
258 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
259 use DataType::*;
260
261 let dt = &arg_types[0];
262 let result_type = match dt {
263 Null => Float64,
264 Decimal128(p, s) => {
265 let new_precision = DECIMAL128_MAX_PRECISION.min(p + 10);
266 Decimal128(new_precision, *s)
267 }
268 Int8 | Int16 | Int32 | Int64 => Int64,
269 Float16 | Float32 | Float64 => Float64,
270
271 other => return exec_err!("try_sum: unsupported type: {other:?}"),
272 };
273
274 Ok(result_type)
275 }
276
277 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
278 macro_rules! helper {
279 ($t:ty, $dt:expr) => {
280 Ok(Box::new(TrySumAccumulator::<$t>::new($dt.clone())))
281 };
282 }
283
284 match acc_args.return_field.data_type() {
285 DataType::Int64 => helper!(Int64Type, acc_args.return_field.data_type()),
286 DataType::Float64 => helper!(Float64Type, acc_args.return_field.data_type()),
287 DataType::Decimal128(_, _) => {
288 helper!(Decimal128Type, acc_args.return_field.data_type())
289 }
290 _ => not_impl_err!(
291 "try_sum: unsupported type for accumulator: {}",
292 acc_args.return_field.data_type()
293 ),
294 }
295 }
296
297 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
298 let sum_dt = args.return_field.data_type().clone();
299 Ok(vec![
300 Field::new(format_state_name(args.name, "sum"), sum_dt, true).into(),
301 Field::new(
302 format_state_name(args.name, "failed"),
303 DataType::Boolean,
304 false,
305 )
306 .into(),
307 ])
308 }
309
310 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
311 use DataType::*;
312 if arg_types.len() != 1 {
313 return exec_err!(
314 "try_sum: exactly 1 argument expected, got {}",
315 arg_types.len()
316 );
317 }
318
319 let dt = &arg_types[0];
320 let coerced = match dt {
321 Null => Float64,
322 Decimal128(p, s) => Decimal128(*p, *s),
323 Int8 | Int16 | Int32 | Int64 => Int64,
324 Float16 | Float32 | Float64 => Float64,
325 other => return exec_err!("try_sum: unsupported type: {other:?}"),
326 };
327 Ok(vec![coerced])
328 }
329
330 fn default_value(&self, _data_type: &DataType) -> Result<ScalarValue> {
331 Ok(ScalarValue::Null)
332 }
333}
334
#[cfg(test)]
mod tests {
    use arrow::array::{Decimal128Array, Float64Array, Int64Array};
    use datafusion_common::DataFusionError;
    use std::sync::Arc;

    use super::*;

    /// Builds a nullable `Int64Array` as an `ArrayRef`.
    fn int64(values: Vec<Option<i64>>) -> ArrayRef {
        Arc::new(Int64Array::from(values)) as ArrayRef
    }

    /// Builds a nullable `Float64Array` as an `ArrayRef`.
    ///
    /// Named `float64` rather than `f64` so it does not read like the primitive type.
    fn float64(values: Vec<Option<f64>>) -> ArrayRef {
        Arc::new(Float64Array::from(values)) as ArrayRef
    }

    /// Builds a `Decimal128Array` from raw unscaled values with the given
    /// precision and scale.
    fn dec128(p: u8, s: i8, vals: Vec<Option<i128>>) -> Result<ArrayRef> {
        let base = Decimal128Array::from(vals);
        let arr = base.with_precision_and_scale(p, s).map_err(|e| {
            DataFusionError::Execution(format!("invalid precision/scale ({p},{s}): {e}"))
        })?;
        Ok(Arc::new(arr) as ArrayRef)
    }

    /// Converts an accumulator's `state()` into arrays, as a partial
    /// aggregation would before handing them to `merge_batch`.
    fn state_arrays(acc: &mut impl Accumulator) -> Result<Vec<ArrayRef>> {
        acc.state()?.into_iter().map(|sv| sv.to_array()).collect()
    }

    #[test]
    fn try_sum_int_basic() -> Result<()> {
        let mut acc = TrySumAccumulator::<Int64Type>::new(DataType::Int64);
        acc.update_batch(&[int64((0..10).map(Some).collect())])?;
        let out = acc.evaluate()?;
        assert_eq!(out, ScalarValue::Int64(Some(45)));
        Ok(())
    }

    #[test]
    fn try_sum_int_with_nulls() -> Result<()> {
        let mut acc = TrySumAccumulator::<Int64Type>::new(DataType::Int64);
        acc.update_batch(&[int64(vec![None, Some(2), Some(3), None, Some(5)])])?;
        let out = acc.evaluate()?;
        assert_eq!(out, ScalarValue::Int64(Some(10)));
        Ok(())
    }

    #[test]
    fn try_sum_float_basic() -> Result<()> {
        let mut acc = TrySumAccumulator::<Float64Type>::new(DataType::Float64);
        acc.update_batch(&[float64(vec![Some(1.5), Some(2.5), None, Some(3.0)])])?;
        let out = acc.evaluate()?;
        assert_eq!(out, ScalarValue::Float64(Some(7.0)));
        Ok(())
    }

    #[test]
    fn float_overflow_behaves_like_spark_sum_infinite() -> Result<()> {
        let mut acc = TrySumAccumulator::<Float64Type>::new(DataType::Float64);
        acc.update_batch(&[float64(vec![Some(1e308), Some(1e308)])])?;

        let out = acc.evaluate()?;
        assert!(
            matches!(out, ScalarValue::Float64(Some(v)) if v.is_infinite() && v.is_sign_positive()),
            "expected +Infinity, got: {out:?}"
        );
        Ok(())
    }

    #[test]
    fn try_sum_float_negative_zero_normalizes_to_positive_zero() -> Result<()> {
        let mut acc = TrySumAccumulator::<Float64Type>::new(DataType::Float64);
        acc.update_batch(&[float64(vec![Some(-0.0), Some(0.0)])])?;

        let out = acc.evaluate()?;
        assert_eq!(out, ScalarValue::Float64(Some(0.0)));
        let ScalarValue::Float64(Some(v)) = out else {
            panic!("expected a non-null Float64, got: {out:?}");
        };
        // `-0.0 == 0.0`, so the sign has to be checked explicitly.
        assert!(v.is_sign_positive(), "expected +0.0, got: {v:?}");
        Ok(())
    }

    #[test]
    fn try_sum_decimal_basic() -> Result<()> {
        let p = 10u8;
        let s = 2i8;
        let mut acc =
            TrySumAccumulator::<Decimal128Type>::new(DataType::Decimal128(p, s));
        acc.update_batch(&[dec128(p, s, vec![Some(123), Some(477)])?])?;
        let out = acc.evaluate()?;
        assert_eq!(out, ScalarValue::Decimal128(Some(600), p, s));
        Ok(())
    }

    #[test]
    fn try_sum_decimal_with_nulls() -> Result<()> {
        let p = 10u8;
        let s = 2i8;
        let mut acc =
            TrySumAccumulator::<Decimal128Type>::new(DataType::Decimal128(p, s));
        acc.update_batch(&[dec128(p, s, vec![Some(150), None, Some(200)])?])?;
        let out = acc.evaluate()?;
        assert_eq!(out, ScalarValue::Decimal128(Some(350), p, s));
        Ok(())
    }

    #[test]
    fn try_sum_decimal_overflow_sets_failed() -> Result<()> {
        let p = 5u8;
        let s = 0i8;
        let mut acc =
            TrySumAccumulator::<Decimal128Type>::new(DataType::Decimal128(p, s));
        acc.update_batch(&[dec128(p, s, vec![Some(90_000), Some(20_000)])?])?;
        let out = acc.evaluate()?;
        assert_eq!(out, ScalarValue::Decimal128(None, p, s));
        assert!(acc.failed);
        Ok(())
    }

    #[test]
    fn try_sum_decimal_merge_ok_and_failure_propagation() -> Result<()> {
        let p = 10u8;
        let s = 2i8;

        let mut p_ok =
            TrySumAccumulator::<Decimal128Type>::new(DataType::Decimal128(p, s));
        p_ok.update_batch(&[dec128(p, s, vec![Some(100), Some(200)])?])?;
        let s_ok = state_arrays(&mut p_ok)?;

        let mut p_fail =
            TrySumAccumulator::<Decimal128Type>::new(DataType::Decimal128(p, s));
        p_fail.update_batch(&[dec128(p, s, vec![Some(i128::MAX), Some(1)])?])?;
        let s_fail = state_arrays(&mut p_fail)?;

        let mut final_acc =
            TrySumAccumulator::<Decimal128Type>::new(DataType::Decimal128(p, s));
        final_acc.merge_batch(&s_ok)?;
        final_acc.merge_batch(&s_fail)?;

        assert!(final_acc.failed);
        assert_eq!(final_acc.evaluate()?, ScalarValue::Decimal128(None, p, s));
        Ok(())
    }

    #[test]
    fn try_sum_int_overflow_sets_failed() -> Result<()> {
        let mut acc = TrySumAccumulator::<Int64Type>::new(DataType::Int64);
        acc.update_batch(&[int64(vec![Some(i64::MAX), Some(1)])])?;

        let out = acc.evaluate()?;
        assert_eq!(out, ScalarValue::Int64(None));
        assert!(acc.failed);
        Ok(())
    }

    #[test]
    fn try_sum_int_negative_overflow_sets_failed() -> Result<()> {
        let mut acc = TrySumAccumulator::<Int64Type>::new(DataType::Int64);
        acc.update_batch(&[int64(vec![Some(i64::MIN), Some(-1)])])?;

        assert_eq!(acc.evaluate()?, ScalarValue::Int64(None));
        assert!(acc.failed);
        Ok(())
    }

    #[test]
    fn try_sum_state_two_fields_and_merge_ok() -> Result<()> {
        let mut acc1 = TrySumAccumulator::<Int64Type>::new(DataType::Int64);
        acc1.update_batch(&[int64(vec![Some(10), Some(5)])])?;
        let state1 = acc1.state()?;
        assert_eq!(state1.len(), 2);

        let mut acc2 = TrySumAccumulator::<Int64Type>::new(DataType::Int64);
        acc2.update_batch(&[int64(vec![Some(20), None])])?;
        let state2 = acc2.state()?;

        let state1_arrays: Vec<ArrayRef> = state1
            .into_iter()
            .map(|sv| sv.to_array())
            .collect::<Result<_>>()?;
        let state2_arrays: Vec<ArrayRef> = state2
            .into_iter()
            .map(|sv| sv.to_array())
            .collect::<Result<_>>()?;

        let mut final_acc = TrySumAccumulator::<Int64Type>::new(DataType::Int64);
        final_acc.merge_batch(&state1_arrays)?;
        final_acc.merge_batch(&state2_arrays)?;

        assert!(!final_acc.failed);
        assert_eq!(final_acc.evaluate()?, ScalarValue::Int64(Some(35)));
        Ok(())
    }

    #[test]
    fn try_sum_merge_propagates_failure() -> Result<()> {
        let failed_sum = Arc::new(Int64Array::from(vec![None])) as ArrayRef;
        let failed_flag = Arc::new(BooleanArray::from(vec![Some(true)])) as ArrayRef;

        let mut acc = TrySumAccumulator::<Int64Type>::new(DataType::Int64);
        acc.merge_batch(&[failed_sum, failed_flag])?;

        assert!(acc.failed);
        assert_eq!(acc.evaluate()?, ScalarValue::Int64(None));
        Ok(())
    }

    #[test]
    fn try_sum_merge_empty_partition_is_not_failure() -> Result<()> {
        let empty_sum = Arc::new(Int64Array::from(vec![None])) as ArrayRef;
        let ok_flag = Arc::new(BooleanArray::from(vec![Some(false)])) as ArrayRef;

        let mut acc = TrySumAccumulator::<Int64Type>::new(DataType::Int64);
        acc.update_batch(&[int64(vec![Some(7), Some(8)])])?;
        acc.merge_batch(&[empty_sum, ok_flag])?;

        assert!(!acc.failed);
        assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(15)));
        Ok(())
    }

    #[test]
    fn try_sum_return_type_matches_input() -> Result<()> {
        let f = SparkTrySum::new();
        assert_eq!(f.return_type(&[DataType::Int64])?, DataType::Int64);
        assert_eq!(f.return_type(&[DataType::Float64])?, DataType::Float64);
        Ok(())
    }

    #[test]
    fn try_sum_state_and_evaluate_consistency() -> Result<()> {
        let mut acc = TrySumAccumulator::<Float64Type>::new(DataType::Float64);
        acc.update_batch(&[float64(vec![Some(1.0), Some(2.0)])])?;
        let eval = acc.evaluate()?;
        let state = acc.state()?;
        assert_eq!(state[0], eval);
        assert_eq!(state[1], ScalarValue::Boolean(Some(false)));
        Ok(())
    }

    #[test]
    fn decimal_10_2_sum_and_schema_widened() -> Result<()> {
        let f = SparkTrySum::new();
        assert_eq!(
            f.return_type(&[DataType::Decimal128(10, 2)])?,
            DataType::Decimal128(20, 2),
            "Spark needs +10 more digits of precision"
        );

        let mut acc =
            TrySumAccumulator::<Decimal128Type>::new(DataType::Decimal128(20, 2));
        acc.update_batch(&[dec128(10, 2, vec![Some(123), Some(477)])?])?;
        assert_eq!(acc.evaluate()?, ScalarValue::Decimal128(Some(600), 20, 2));
        Ok(())
    }

    #[test]
    fn decimal_5_0_fits_after_widening() -> Result<()> {
        let f = SparkTrySum::new();
        assert_eq!(
            f.return_type(&[DataType::Decimal128(5, 0)])?,
            DataType::Decimal128(15, 0)
        );

        let mut acc =
            TrySumAccumulator::<Decimal128Type>::new(DataType::Decimal128(15, 0));
        acc.update_batch(&[dec128(5, 0, vec![Some(90_000), Some(20_000)])?])?;
        assert_eq!(
            acc.evaluate()?,
            ScalarValue::Decimal128(Some(110_000), 15, 0)
        );
        Ok(())
    }

    #[test]
    fn decimal_38_0_max_precision_overflows_to_null() -> Result<()> {
        let f = SparkTrySum::new();
        assert_eq!(
            f.return_type(&[DataType::Decimal128(38, 0)])?,
            DataType::Decimal128(38, 0)
        );
        let ten_pow_38_minus_1 = {
            let p10 = pow10_i128(38)
                .ok_or_else(|| DataFusionError::Internal("10^38 overflow".into()))?;
            p10 - 1
        };
        let mut acc =
            TrySumAccumulator::<Decimal128Type>::new(DataType::Decimal128(38, 0));
        acc.update_batch(&[dec128(38, 0, vec![Some(ten_pow_38_minus_1), Some(1)])?])?;

        assert!(acc.failed, "need fail in overflow p=38");
        assert_eq!(acc.evaluate()?, ScalarValue::Decimal128(None, 38, 0));
        Ok(())
    }
}