1use std::sync::Arc;
19
20use arrow::array::*;
21use arrow::datatypes::{
22 ArrowNativeTypeOp, DataType, Decimal32Type, Decimal64Type, Decimal128Type,
23 Decimal256Type, Float16Type, Float32Type, Float64Type, Int8Type, Int16Type,
24 Int32Type, Int64Type, UInt8Type, UInt16Type, UInt32Type, UInt64Type,
25};
26use datafusion_common::types::{
27 NativeType, logical_float32, logical_float64, logical_int32,
28};
29use datafusion_common::{Result, ScalarValue, exec_err, not_impl_err};
30use datafusion_expr::{
31 Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
32 TypeSignatureClass, Volatility,
33};
34
35#[derive(Debug, PartialEq, Eq, Hash)]
49pub struct SparkRound {
50 signature: Signature,
51}
52
53impl Default for SparkRound {
54 fn default() -> Self {
55 Self::new()
56 }
57}
58
59impl SparkRound {
60 pub fn new() -> Self {
61 let decimal = Coercion::new_exact(TypeSignatureClass::Decimal);
62 let integer = Coercion::new_exact(TypeSignatureClass::Integer);
63 let decimal_places = Coercion::new_implicit(
64 TypeSignatureClass::Native(logical_int32()),
65 vec![TypeSignatureClass::Integer],
66 NativeType::Int32,
67 );
68 let float32 = Coercion::new_exact(TypeSignatureClass::Native(logical_float32()));
69 let float64 = Coercion::new_implicit(
70 TypeSignatureClass::Native(logical_float64()),
71 vec![TypeSignatureClass::Numeric],
72 NativeType::Float64,
73 );
74 Self {
75 signature: Signature::one_of(
76 vec![
77 TypeSignature::Coercible(vec![
79 decimal.clone(),
80 decimal_places.clone(),
81 ]),
82 TypeSignature::Coercible(vec![decimal]),
84 TypeSignature::Coercible(vec![
86 integer.clone(),
87 decimal_places.clone(),
88 ]),
89 TypeSignature::Coercible(vec![integer]),
91 TypeSignature::Coercible(vec![
93 float32.clone(),
94 decimal_places.clone(),
95 ]),
96 TypeSignature::Coercible(vec![float32]),
98 TypeSignature::Coercible(vec![float64.clone(), decimal_places]),
100 TypeSignature::Coercible(vec![float64]),
102 ],
103 Volatility::Immutable,
104 ),
105 }
106 }
107}
108
109impl ScalarUDFImpl for SparkRound {
110 fn name(&self) -> &str {
111 "round"
112 }
113
114 fn signature(&self) -> &Signature {
115 &self.signature
116 }
117
118 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
119 Ok(arg_types[0].clone())
120 }
121
122 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
123 spark_round(&args.args, args.config_options.execution.enable_ansi_mode)
124 }
125}
126
127fn get_scale(args: &[ColumnarValue]) -> Result<Option<i32>> {
131 if args.len() < 2 {
132 return Ok(Some(0));
133 }
134
135 match &args[1] {
136 ColumnarValue::Scalar(ScalarValue::Int8(Some(v))) => Ok(Some(i32::from(*v))),
137 ColumnarValue::Scalar(ScalarValue::Int16(Some(v))) => Ok(Some(i32::from(*v))),
138 ColumnarValue::Scalar(ScalarValue::Int32(Some(v))) => Ok(Some(*v)),
139 ColumnarValue::Scalar(ScalarValue::Int64(Some(v))) => {
140 i32::try_from(*v).map(Some).map_err(|_| {
141 (exec_err!("round scale {v} is out of supported i32 range")
142 as Result<(), _>)
143 .unwrap_err()
144 })
145 }
146 ColumnarValue::Scalar(ScalarValue::UInt8(Some(v))) => Ok(Some(i32::from(*v))),
147 ColumnarValue::Scalar(ScalarValue::UInt16(Some(v))) => Ok(Some(i32::from(*v))),
148 ColumnarValue::Scalar(ScalarValue::UInt32(Some(v))) => {
149 i32::try_from(*v).map(Some).map_err(|_| {
150 (exec_err!("round scale {v} is out of supported i32 range")
151 as Result<(), _>)
152 .unwrap_err()
153 })
154 }
155 ColumnarValue::Scalar(ScalarValue::UInt64(Some(v))) => {
156 i32::try_from(*v).map(Some).map_err(|_| {
157 (exec_err!("round scale {v} is out of supported i32 range")
158 as Result<(), _>)
159 .unwrap_err()
160 })
161 }
162 ColumnarValue::Scalar(sv) if sv.is_null() => Ok(None),
163 other => exec_err!("Unsupported type for round scale: {}", other.data_type()),
164 }
165}
166
167fn round_float<T: num_traits::Float>(value: T, scale: i32) -> T {
190 if scale >= 0 {
191 let factor = T::from(10.0f64.powi(scale)).unwrap_or_else(T::infinity);
192 if factor.is_infinite() {
193 return value;
195 }
196 (value * factor).round() / factor
197 } else {
198 let factor = T::from(10.0f64.powi(-scale)).unwrap_or_else(T::infinity);
199 if factor.is_infinite() {
200 return T::zero();
202 }
203 (value / factor).round() * factor
204 }
205}
206
207fn round_integer(value: i64, scale: i32, enable_ansi_mode: bool) -> Result<i64> {
237 if scale >= 0 {
238 return Ok(value);
239 }
240 let abs_scale = (-scale) as u32;
241 let Some(factor) = 10_i64.checked_pow(abs_scale) else {
242 return Ok(0);
243 };
244 let remainder = value % factor;
245 let threshold = factor / 2;
246 let result = if remainder >= threshold {
247 if enable_ansi_mode {
248 value
249 .checked_sub(remainder)
250 .and_then(|v| v.checked_add(factor))
251 .ok_or_else(|| {
252 (exec_err!("Int64 overflow on round({value}, {scale})")
253 as Result<(), _>)
254 .unwrap_err()
255 })?
256 } else {
257 value.wrapping_sub(remainder).wrapping_add(factor)
258 }
259 } else if remainder <= -threshold {
260 if enable_ansi_mode {
261 value
262 .checked_sub(remainder)
263 .and_then(|v| v.checked_sub(factor))
264 .ok_or_else(|| {
265 (exec_err!("Int64 overflow on round({value}, {scale})")
266 as Result<(), _>)
267 .unwrap_err()
268 })?
269 } else {
270 value.wrapping_sub(remainder).wrapping_sub(factor)
271 }
272 } else {
273 value - remainder
274 };
275 Ok(result)
276}
277
278fn round_decimal<V: ArrowNativeTypeOp>(
318 value: V,
319 input_scale: i8,
320 decimal_places: i32,
321) -> Result<V> {
322 let diff = i64::from(input_scale) - i64::from(decimal_places);
323 if diff <= 0 {
324 return Ok(value);
327 }
328
329 let diff = diff as u32;
330
331 let one = V::ONE;
332 let two = V::from_usize(2).ok_or_else(|| {
333 (exec_err!("Internal error: could not create constant 2") as Result<(), _>)
334 .unwrap_err()
335 })?;
336 let ten = V::from_usize(10).ok_or_else(|| {
337 (exec_err!("Internal error: could not create constant 10") as Result<(), _>)
338 .unwrap_err()
339 })?;
340
341 let Ok(factor) = ten.pow_checked(diff) else {
342 return Ok(V::ZERO);
347 };
348
349 let mut quotient = value.div_wrapping(factor);
350 let remainder = value.mod_wrapping(factor);
351
352 let threshold = factor.div_wrapping(two);
354 if remainder >= threshold {
355 quotient = quotient.add_checked(one).map_err(|_| {
356 (exec_err!("Overflow while rounding decimal") as Result<(), _>).unwrap_err()
357 })?;
358 } else if remainder <= threshold.neg_wrapping() {
359 quotient = quotient.sub_checked(one).map_err(|_| {
360 (exec_err!("Overflow while rounding decimal") as Result<(), _>).unwrap_err()
361 })?;
362 }
363
364 quotient.mul_checked(factor).map_err(|_| {
368 (exec_err!("Overflow while rounding decimal") as Result<(), _>).unwrap_err()
369 })
370}
371
/// Rounds every element of an integer array through `round_integer` and
/// evaluates to `Ok(ColumnarValue::Array(..))` of the same Arrow type.
///
/// Arguments: the array expression, its Arrow primitive type, the scale and
/// the ANSI flag. Elements are widened to `i64` for rounding.
///
/// * ANSI mode: propagates errors from `round_integer` and reports an error
///   when the rounded value no longer fits the native type.
/// * Otherwise: the rounded value is cast back with `as`, which wraps.
///
/// Uses `?`, so it must be expanded inside a function returning `Result`.
macro_rules! impl_integer_array_round {
    ($array:expr, $arrow_type:ty, $scale:expr, $enable_ansi_mode:expr) => {{
        type Native = <$arrow_type as arrow::datatypes::ArrowPrimitiveType>::Native;
        let input = $array.as_primitive::<$arrow_type>();
        let rounded: PrimitiveArray<$arrow_type> = if $enable_ansi_mode {
            input.try_unary(|x| {
                round_integer(x as i64, $scale, true).and_then(|wide| {
                    Native::try_from(wide).map_err(|_| {
                        (exec_err!(
                            "{} overflow on round({x}, {})",
                            stringify!($arrow_type),
                            $scale
                        ) as Result<(), _>)
                        .unwrap_err()
                    })
                })
            })?
        } else {
            input.unary(|x| {
                let wide = round_integer(x as i64, $scale, false).unwrap();
                wide as Native
            })
        };
        Ok(ColumnarValue::Array(Arc::new(rounded)))
    }};
}
398
/// Applies `round_float` to every element of a floating-point array and
/// evaluates to `Ok(ColumnarValue::Array(..))` of the same Arrow type.
///
/// Arguments: the array expression, its Arrow primitive type and the scale.
macro_rules! impl_float_array_round {
    ($array:expr, $arrow_type:ty, $scale:expr) => {{
        let rounded: PrimitiveArray<$arrow_type> = $array
            .as_primitive::<$arrow_type>()
            .unary(|element| round_float(element, $scale));
        Ok(ColumnarValue::Array(Arc::new(rounded)))
    }};
}
406
/// Applies `round_decimal` to every element of a decimal array and evaluates
/// to `Ok(ColumnarValue::Array(..))`.
///
/// Arguments: the array expression, its Arrow decimal type, the scale the
/// values are stored at, and the number of decimal places to keep. The input's
/// data type is re-applied to the result so its precision and scale are
/// unchanged.
///
/// Uses `?`, so it must be expanded inside a function returning `Result`.
macro_rules! impl_decimal_array_round {
    ($array:expr, $arrow_type:ty, $input_scale:expr, $scale:expr) => {{
        let data_type = $array.data_type().clone();
        let rounded: PrimitiveArray<$arrow_type> = $array
            .as_primitive::<$arrow_type>()
            .try_unary(|unscaled| round_decimal(unscaled, $input_scale, $scale))?
            .with_data_type(data_type);
        Ok(ColumnarValue::Array(Arc::new(rounded)))
    }};
}
416
417fn spark_round(args: &[ColumnarValue], enable_ansi_mode: bool) -> Result<ColumnarValue> {
422 if args.is_empty() || args.len() > 2 {
423 return exec_err!("round requires 1 or 2 arguments, got {}", args.len());
424 }
425
426 let scale = match get_scale(args)? {
427 Some(s) => s,
428 None => {
429 return Ok(ColumnarValue::Scalar(ScalarValue::try_from(
431 args[0].data_type(),
432 )?));
433 }
434 };
435
436 match &args[0] {
437 ColumnarValue::Array(array) => match array.data_type() {
438 DataType::Null => Ok(args[0].clone()),
439
440 DataType::Int8 => {
442 impl_integer_array_round!(array, Int8Type, scale, enable_ansi_mode)
443 }
444 DataType::Int16 => {
445 impl_integer_array_round!(array, Int16Type, scale, enable_ansi_mode)
446 }
447 DataType::Int32 => {
448 impl_integer_array_round!(array, Int32Type, scale, enable_ansi_mode)
449 }
450 DataType::Int64 => {
451 impl_integer_array_round!(array, Int64Type, scale, enable_ansi_mode)
452 }
453
454 DataType::UInt8 => {
456 impl_integer_array_round!(array, UInt8Type, scale, enable_ansi_mode)
457 }
458 DataType::UInt16 => {
459 impl_integer_array_round!(array, UInt16Type, scale, enable_ansi_mode)
460 }
461 DataType::UInt32 => {
462 impl_integer_array_round!(array, UInt32Type, scale, enable_ansi_mode)
463 }
464 DataType::UInt64 => {
465 let array = array.as_primitive::<UInt64Type>();
466 let result: PrimitiveArray<UInt64Type> = array.try_unary(|x| {
467 let v_i64 = i64::try_from(x).map_err(|_| {
468 (exec_err!(
469 "round: UInt64 value {x} exceeds i64::MAX and cannot be rounded"
470 ) as Result<(), _>)
471 .unwrap_err()
472 })?;
473 round_integer(v_i64, scale, enable_ansi_mode)
474 .map(|v| v as u64)
475 })?;
476 Ok(ColumnarValue::Array(Arc::new(result)))
477 }
478
479 DataType::Float16 => impl_float_array_round!(array, Float16Type, scale),
481 DataType::Float32 => impl_float_array_round!(array, Float32Type, scale),
482 DataType::Float64 => impl_float_array_round!(array, Float64Type, scale),
483
484 DataType::Decimal32(_, input_scale) => {
486 impl_decimal_array_round!(array, Decimal32Type, *input_scale, scale)
487 }
488 DataType::Decimal64(_, input_scale) => {
489 impl_decimal_array_round!(array, Decimal64Type, *input_scale, scale)
490 }
491 DataType::Decimal128(_, input_scale) => {
492 impl_decimal_array_round!(array, Decimal128Type, *input_scale, scale)
493 }
494 DataType::Decimal256(_, input_scale) => {
495 impl_decimal_array_round!(array, Decimal256Type, *input_scale, scale)
496 }
497
498 dt => not_impl_err!("Unsupported data type for Spark round(): {dt}"),
499 },
500
501 ColumnarValue::Scalar(sv) => match sv {
502 ScalarValue::Null => Ok(args[0].clone()),
503 _ if sv.is_null() => Ok(args[0].clone()),
504
505 ScalarValue::Int8(Some(v)) => {
507 let r = round_integer(i64::from(*v), scale, enable_ansi_mode)?;
508 let result = if enable_ansi_mode {
509 i8::try_from(r).map_err(|_| {
510 (exec_err!("Int8 overflow on round({v}, {scale})")
511 as Result<(), _>)
512 .unwrap_err()
513 })?
514 } else {
515 r as i8
516 };
517 Ok(ColumnarValue::Scalar(ScalarValue::Int8(Some(result))))
518 }
519 ScalarValue::Int16(Some(v)) => {
520 let r = round_integer(i64::from(*v), scale, enable_ansi_mode)?;
521 let result = if enable_ansi_mode {
522 i16::try_from(r).map_err(|_| {
523 (exec_err!("Int16 overflow on round({v}, {scale})")
524 as Result<(), _>)
525 .unwrap_err()
526 })?
527 } else {
528 r as i16
529 };
530 Ok(ColumnarValue::Scalar(ScalarValue::Int16(Some(result))))
531 }
532 ScalarValue::Int32(Some(v)) => {
533 let r = round_integer(i64::from(*v), scale, enable_ansi_mode)?;
534 let result = if enable_ansi_mode {
535 i32::try_from(r).map_err(|_| {
536 (exec_err!("Int32 overflow on round({v}, {scale})")
537 as Result<(), _>)
538 .unwrap_err()
539 })?
540 } else {
541 r as i32
542 };
543 Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(result))))
544 }
545 ScalarValue::Int64(Some(v)) => {
546 let result = round_integer(*v, scale, enable_ansi_mode)?;
547 Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(result))))
548 }
549
550 ScalarValue::UInt8(Some(v)) => {
552 let r = round_integer(i64::from(*v), scale, enable_ansi_mode)?;
553 let result = if enable_ansi_mode {
554 u8::try_from(r).map_err(|_| {
555 (exec_err!("UInt8 overflow on round({v}, {scale})")
556 as Result<(), _>)
557 .unwrap_err()
558 })?
559 } else {
560 r as u8
561 };
562 Ok(ColumnarValue::Scalar(ScalarValue::UInt8(Some(result))))
563 }
564 ScalarValue::UInt16(Some(v)) => {
565 let r = round_integer(i64::from(*v), scale, enable_ansi_mode)?;
566 let result = if enable_ansi_mode {
567 u16::try_from(r).map_err(|_| {
568 (exec_err!("UInt16 overflow on round({v}, {scale})")
569 as Result<(), _>)
570 .unwrap_err()
571 })?
572 } else {
573 r as u16
574 };
575 Ok(ColumnarValue::Scalar(ScalarValue::UInt16(Some(result))))
576 }
577 ScalarValue::UInt32(Some(v)) => {
578 let r = round_integer(i64::from(*v), scale, enable_ansi_mode)?;
579 let result = if enable_ansi_mode {
580 u32::try_from(r).map_err(|_| {
581 (exec_err!("UInt32 overflow on round({v}, {scale})")
582 as Result<(), _>)
583 .unwrap_err()
584 })?
585 } else {
586 r as u32
587 };
588 Ok(ColumnarValue::Scalar(ScalarValue::UInt32(Some(result))))
589 }
590 ScalarValue::UInt64(Some(v)) => {
591 let v_i64 = i64::try_from(*v).map_err(|_| {
592 (exec_err!(
593 "round: UInt64 value {v} exceeds i64::MAX and cannot be rounded"
594 ) as Result<(), _>)
595 .unwrap_err()
596 })?;
597 let result = round_integer(v_i64, scale, enable_ansi_mode)?;
598 Ok(ColumnarValue::Scalar(ScalarValue::UInt64(Some(
599 result as u64,
600 ))))
601 }
602
603 ScalarValue::Float16(Some(v)) => {
605 let result = round_float(*v, scale);
606 Ok(ColumnarValue::Scalar(ScalarValue::Float16(Some(result))))
607 }
608 ScalarValue::Float32(Some(v)) => {
609 let result = round_float(*v, scale);
610 Ok(ColumnarValue::Scalar(ScalarValue::Float32(Some(result))))
611 }
612 ScalarValue::Float64(Some(v)) => {
613 let result = round_float(*v, scale);
614 Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(result))))
615 }
616
617 ScalarValue::Decimal32(Some(v), precision, input_scale) => {
619 let rounded = round_decimal(*v, *input_scale, scale)?;
620 Ok(ColumnarValue::Scalar(ScalarValue::Decimal32(
621 Some(rounded),
622 *precision,
623 *input_scale,
624 )))
625 }
626 ScalarValue::Decimal64(Some(v), precision, input_scale) => {
627 let rounded = round_decimal(*v, *input_scale, scale)?;
628 Ok(ColumnarValue::Scalar(ScalarValue::Decimal64(
629 Some(rounded),
630 *precision,
631 *input_scale,
632 )))
633 }
634 ScalarValue::Decimal128(Some(v), precision, input_scale) => {
635 let rounded = round_decimal(*v, *input_scale, scale)?;
636 Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(
637 Some(rounded),
638 *precision,
639 *input_scale,
640 )))
641 }
642 ScalarValue::Decimal256(Some(v), precision, input_scale) => {
643 let rounded = round_decimal(*v, *input_scale, scale)?;
644 Ok(ColumnarValue::Scalar(ScalarValue::Decimal256(
645 Some(rounded),
646 *precision,
647 *input_scale,
648 )))
649 }
650
651 dt => not_impl_err!("Unsupported data type for Spark round(): {dt}"),
652 },
653 }
654}