1use crate::utils::{calculate_binary_decimal_math, calculate_binary_math};
19
20use arrow::array::ArrayRef;
21use arrow::datatypes::DataType::{
22 Decimal32, Decimal64, Decimal128, Decimal256, Float32, Float64,
23};
24use arrow::datatypes::{
25 ArrowNativeTypeOp, DataType, Decimal32Type, Decimal64Type, Decimal128Type,
26 Decimal256Type, DecimalType, Float32Type, Float64Type, Int32Type,
27};
28use arrow::datatypes::{Field, FieldRef};
29use arrow::error::ArrowError;
30use datafusion_common::types::{
31 NativeType, logical_float32, logical_float64, logical_int32,
32};
33use datafusion_common::{Result, ScalarValue, exec_err, internal_err};
34use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
35use datafusion_expr::{
36 Coercion, ColumnarValue, Documentation, ReturnFieldArgs, ScalarFunctionArgs,
37 ScalarUDFImpl, Signature, TypeSignature, TypeSignatureClass, Volatility,
38};
39use datafusion_macros::user_doc;
40use std::sync::Arc;
41
/// Computes the scale of the decimal produced by rounding to `decimal_places`.
///
/// * Non-negative `input_scale`: the result is `min(input_scale, max(decimal_places, 0))`,
///   so rounding never adds fractional digits and never yields a negative scale.
/// * Negative `input_scale`: the result is `min(input_scale, decimal_places)`, clamped from
///   below at `-precision`.
fn output_scale_for_decimal(precision: u8, input_scale: i8, decimal_places: i32) -> i8 {
    let current = i32::from(input_scale);
    let target = if input_scale < 0 {
        current
            .min(decimal_places)
            .max(-i32::from(precision))
    } else {
        current.min(decimal_places.max(0))
    };
    target as i8
}
61
/// Decides whether rounding a decimal of the given `precision`/`scale` to
/// `decimal_places` can produce a non-zero result.
///
/// Non-negative `decimal_places` are always returned unchanged. A negative value is kept
/// only when its magnitude does not exceed the number of integer digits the type can hold
/// (`precision - scale`); otherwise `None` is returned, meaning every value rounds to zero.
fn normalize_decimal_places_for_decimal(
    decimal_places: i32,
    precision: u8,
    scale: i8,
) -> Option<i32> {
    if decimal_places < 0 {
        let integer_digits = i64::from(precision) - i64::from(scale);
        let shift = i64::from(decimal_places.unsigned_abs());
        if integer_digits > 0 && shift <= integer_digits {
            Some(decimal_places)
        } else {
            None
        }
    } else {
        Some(decimal_places)
    }
}
82
83fn validate_decimal_precision<T: DecimalType>(
84 value: T::Native,
85 precision: u8,
86 scale: i8,
87) -> Result<T::Native, ArrowError> {
88 T::validate_decimal_precision(value, precision, scale).map_err(|e| {
89 ArrowError::ComputeError(format!(
90 "Decimal overflow: rounded value exceeds precision {precision}: {e}"
91 ))
92 })?;
93 Ok(value)
94}
95
96fn calculate_new_precision_scale<T: DecimalType>(
97 precision: u8,
98 scale: i8,
99 decimal_places: Option<i32>,
100) -> Result<DataType> {
101 if let Some(decimal_places) = decimal_places {
102 let new_scale = output_scale_for_decimal(precision, scale, decimal_places);
103
104 let abs_decimal_places = decimal_places.unsigned_abs();
108 let new_precision = if scale == 0
109 && decimal_places < 0
110 && abs_decimal_places <= u32::from(precision)
111 {
112 precision.saturating_add(1).min(T::MAX_PRECISION)
113 } else {
114 precision
115 };
116 Ok(T::TYPE_CONSTRUCTOR(new_precision, new_scale))
117 } else {
118 let new_precision = precision.saturating_add(1).min(T::MAX_PRECISION);
119 Ok(T::TYPE_CONSTRUCTOR(new_precision, scale))
120 }
121}
122
123fn decimal_places_from_scalar(scalar: &ScalarValue) -> Result<i32> {
124 let out_of_range = |value: String| {
125 datafusion_common::DataFusionError::Execution(format!(
126 "round decimal_places {value} is out of supported i32 range"
127 ))
128 };
129 match scalar {
130 ScalarValue::Int8(Some(v)) => Ok(i32::from(*v)),
131 ScalarValue::Int16(Some(v)) => Ok(i32::from(*v)),
132 ScalarValue::Int32(Some(v)) => Ok(*v),
133 ScalarValue::Int64(Some(v)) => {
134 i32::try_from(*v).map_err(|_| out_of_range(v.to_string()))
135 }
136 ScalarValue::UInt8(Some(v)) => Ok(i32::from(*v)),
137 ScalarValue::UInt16(Some(v)) => Ok(i32::from(*v)),
138 ScalarValue::UInt32(Some(v)) => {
139 i32::try_from(*v).map_err(|_| out_of_range(v.to_string()))
140 }
141 ScalarValue::UInt64(Some(v)) => {
142 i32::try_from(*v).map_err(|_| out_of_range(v.to_string()))
143 }
144 other => exec_err!(
145 "Unexpected datatype for decimal_places: {}",
146 other.data_type()
147 ),
148 }
149}
150
151#[user_doc(
152 doc_section(label = "Math Functions"),
153 description = "Rounds a number to the nearest integer.",
154 syntax_example = "round(numeric_expression[, decimal_places])",
155 standard_argument(name = "numeric_expression", prefix = "Numeric"),
156 argument(
157 name = "decimal_places",
158 description = "Optional. The number of decimal places to round to. Defaults to 0."
159 ),
160 sql_example = r#"```sql
161> SELECT round(3.14159);
162+--------------+
163| round(3.14159)|
164+--------------+
165| 3.0 |
166+--------------+
167```"#
168)]
169#[derive(Debug, PartialEq, Eq, Hash)]
170pub struct RoundFunc {
171 signature: Signature,
172}
173
174impl Default for RoundFunc {
175 fn default() -> Self {
176 RoundFunc::new()
177 }
178}
179
180impl RoundFunc {
181 pub fn new() -> Self {
182 let decimal = Coercion::new_exact(TypeSignatureClass::Decimal);
183 let decimal_places = Coercion::new_implicit(
184 TypeSignatureClass::Native(logical_int32()),
185 vec![TypeSignatureClass::Integer],
186 NativeType::Int32,
187 );
188 let float32 = Coercion::new_exact(TypeSignatureClass::Native(logical_float32()));
189 let float64 = Coercion::new_implicit(
190 TypeSignatureClass::Native(logical_float64()),
191 vec![TypeSignatureClass::Numeric],
192 NativeType::Float64,
193 );
194 Self {
195 signature: Signature::one_of(
196 vec![
197 TypeSignature::Coercible(vec![
198 decimal.clone(),
199 decimal_places.clone(),
200 ]),
201 TypeSignature::Coercible(vec![decimal]),
202 TypeSignature::Coercible(vec![
203 float32.clone(),
204 decimal_places.clone(),
205 ]),
206 TypeSignature::Coercible(vec![float32]),
207 TypeSignature::Coercible(vec![float64.clone(), decimal_places]),
208 TypeSignature::Coercible(vec![float64]),
209 ],
210 Volatility::Immutable,
211 ),
212 }
213 }
214}
215
216impl ScalarUDFImpl for RoundFunc {
217 fn name(&self) -> &str {
218 "round"
219 }
220
221 fn signature(&self) -> &Signature {
222 &self.signature
223 }
224
225 fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
226 let input_field = &args.arg_fields[0];
227 let input_type = input_field.data_type();
228
229 let decimal_places: Option<i32> = match args.scalar_arguments.get(1) {
235 None => Some(0), Some(None) => None, Some(Some(scalar)) if scalar.is_null() => Some(0), Some(Some(scalar)) => Some(decimal_places_from_scalar(scalar)?),
239 };
240
241 let return_type =
247 match input_type {
248 Float32 => Float32,
249 Decimal32(precision, scale) => calculate_new_precision_scale::<
250 Decimal32Type,
251 >(
252 *precision, *scale, decimal_places
253 )?,
254 Decimal64(precision, scale) => calculate_new_precision_scale::<
255 Decimal64Type,
256 >(
257 *precision, *scale, decimal_places
258 )?,
259 Decimal128(precision, scale) => calculate_new_precision_scale::<
260 Decimal128Type,
261 >(
262 *precision, *scale, decimal_places
263 )?,
264 Decimal256(precision, scale) => calculate_new_precision_scale::<
265 Decimal256Type,
266 >(
267 *precision, *scale, decimal_places
268 )?,
269 _ => Float64,
270 };
271
272 let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
273 Ok(Arc::new(Field::new(self.name(), return_type, nullable)))
274 }
275
276 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
277 internal_err!("use return_field_from_args instead")
278 }
279
280 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
281 if args.arg_fields.iter().any(|a| a.data_type().is_null()) {
282 return ColumnarValue::Scalar(ScalarValue::Null)
283 .cast_to(args.return_type(), None);
284 }
285
286 let default_decimal_places = ColumnarValue::Scalar(ScalarValue::Int32(Some(0)));
287 let decimal_places = if args.args.len() == 2 {
288 &args.args[1]
289 } else {
290 &default_decimal_places
291 };
292
293 if let (ColumnarValue::Scalar(value_scalar), ColumnarValue::Scalar(dp_scalar)) =
294 (&args.args[0], decimal_places)
295 {
296 if value_scalar.is_null() || dp_scalar.is_null() {
297 return ColumnarValue::Scalar(ScalarValue::Null)
298 .cast_to(args.return_type(), None);
299 }
300
301 let dp = if let ScalarValue::Int32(Some(dp)) = dp_scalar {
302 *dp
303 } else {
304 return internal_err!(
305 "Unexpected datatype for decimal_places: {}",
306 dp_scalar.data_type()
307 );
308 };
309
310 match (value_scalar, args.return_type()) {
311 (ScalarValue::Float32(Some(v)), _) => {
312 let rounded = round_float(*v, dp)?;
313 Ok(ColumnarValue::Scalar(ScalarValue::from(rounded)))
314 }
315 (ScalarValue::Float64(Some(v)), _) => {
316 let rounded = round_float(*v, dp)?;
317 Ok(ColumnarValue::Scalar(ScalarValue::from(rounded)))
318 }
319 (
320 ScalarValue::Decimal32(Some(v), in_precision, scale),
321 Decimal32(out_precision, out_scale),
322 ) => {
323 let rounded =
324 round_decimal_or_zero(*v, *in_precision, *scale, *out_scale, dp)?;
325 let rounded = if *out_precision == Decimal32Type::MAX_PRECISION
326 && *scale == 0
327 && dp < 0
328 {
329 validate_decimal_precision::<Decimal32Type>(
333 rounded,
334 *out_precision,
335 *out_scale,
336 )
337 } else {
338 Ok(rounded)
339 }?;
340 let scalar =
341 ScalarValue::Decimal32(Some(rounded), *out_precision, *out_scale);
342 Ok(ColumnarValue::Scalar(scalar))
343 }
344 (
345 ScalarValue::Decimal64(Some(v), in_precision, scale),
346 Decimal64(out_precision, out_scale),
347 ) => {
348 let rounded =
349 round_decimal_or_zero(*v, *in_precision, *scale, *out_scale, dp)?;
350 let rounded = if *out_precision == Decimal64Type::MAX_PRECISION
351 && *scale == 0
352 && dp < 0
353 {
354 validate_decimal_precision::<Decimal64Type>(
356 rounded,
357 *out_precision,
358 *out_scale,
359 )
360 } else {
361 Ok(rounded)
362 }?;
363 let scalar =
364 ScalarValue::Decimal64(Some(rounded), *out_precision, *out_scale);
365 Ok(ColumnarValue::Scalar(scalar))
366 }
367 (
368 ScalarValue::Decimal128(Some(v), in_precision, scale),
369 Decimal128(out_precision, out_scale),
370 ) => {
371 let rounded =
372 round_decimal_or_zero(*v, *in_precision, *scale, *out_scale, dp)?;
373 let rounded = if *out_precision == Decimal128Type::MAX_PRECISION
374 && *scale == 0
375 && dp < 0
376 {
377 validate_decimal_precision::<Decimal128Type>(
379 rounded,
380 *out_precision,
381 *out_scale,
382 )
383 } else {
384 Ok(rounded)
385 }?;
386 let scalar = ScalarValue::Decimal128(
387 Some(rounded),
388 *out_precision,
389 *out_scale,
390 );
391 Ok(ColumnarValue::Scalar(scalar))
392 }
393 (
394 ScalarValue::Decimal256(Some(v), in_precision, scale),
395 Decimal256(out_precision, out_scale),
396 ) => {
397 let rounded =
398 round_decimal_or_zero(*v, *in_precision, *scale, *out_scale, dp)?;
399 let rounded = if *out_precision == Decimal256Type::MAX_PRECISION
400 && *scale == 0
401 && dp < 0
402 {
403 validate_decimal_precision::<Decimal256Type>(
405 rounded,
406 *out_precision,
407 *out_scale,
408 )
409 } else {
410 Ok(rounded)
411 }?;
412 let scalar = ScalarValue::Decimal256(
413 Some(rounded),
414 *out_precision,
415 *out_scale,
416 );
417 Ok(ColumnarValue::Scalar(scalar))
418 }
419 (ScalarValue::Null, _) => ColumnarValue::Scalar(ScalarValue::Null)
420 .cast_to(args.return_type(), None),
421 (value_scalar, return_type) => {
422 internal_err!(
423 "Unexpected datatype for round(value, decimal_places): value {}, return type {}",
424 value_scalar.data_type(),
425 return_type
426 )
427 }
428 }
429 } else {
430 round_columnar(
431 &args.args[0],
432 decimal_places,
433 args.number_rows,
434 args.return_type(),
435 )
436 }
437 }
438
439 fn output_ordering(&self, input: &[ExprProperties]) -> Result<SortProperties> {
440 let value = &input[0];
442 let precision = input.get(1);
443
444 if precision
445 .map(|r| r.sort_properties.eq(&SortProperties::Singleton))
446 .unwrap_or(true)
447 {
448 Ok(value.sort_properties)
449 } else {
450 Ok(SortProperties::Unordered)
451 }
452 }
453
454 fn documentation(&self) -> Option<&Documentation> {
455 self.doc()
456 }
457}
458
459fn round_columnar(
460 value: &ColumnarValue,
461 decimal_places: &ColumnarValue,
462 number_rows: usize,
463 return_type: &DataType,
464) -> Result<ColumnarValue> {
465 let value_array = value.to_array(number_rows)?;
466 let both_scalars = matches!(value, ColumnarValue::Scalar(_))
467 && matches!(decimal_places, ColumnarValue::Scalar(_));
468 let decimal_places_is_array = matches!(decimal_places, ColumnarValue::Array(_));
469
470 let arr: ArrayRef = match (value_array.data_type(), return_type) {
471 (Float64, _) => {
472 let result = calculate_binary_math::<Float64Type, Int32Type, Float64Type, _>(
473 value_array.as_ref(),
474 decimal_places,
475 round_float::<f64>,
476 )?;
477 result as _
478 }
479 (Float32, _) => {
480 let result = calculate_binary_math::<Float32Type, Int32Type, Float32Type, _>(
481 value_array.as_ref(),
482 decimal_places,
483 round_float::<f32>,
484 )?;
485 result as _
486 }
487 (Decimal32(input_precision, scale), Decimal32(precision, new_scale)) => {
488 let result = calculate_binary_decimal_math::<
490 Decimal32Type,
491 Int32Type,
492 Decimal32Type,
493 _,
494 >(
495 value_array.as_ref(),
496 decimal_places,
497 |v, dp| {
498 let rounded = round_decimal_or_zero(
499 v,
500 *input_precision,
501 *scale,
502 *new_scale,
503 dp,
504 )?;
505 if *precision == Decimal32Type::MAX_PRECISION
506 && (decimal_places_is_array || (*scale == 0 && dp < 0))
507 {
508 validate_decimal_precision::<Decimal32Type>(
513 rounded, *precision, *new_scale,
514 )
515 } else {
516 Ok(rounded)
517 }
518 },
519 *precision,
520 *new_scale,
521 )?;
522 result as _
523 }
524 (Decimal64(input_precision, scale), Decimal64(precision, new_scale)) => {
525 let result = calculate_binary_decimal_math::<
526 Decimal64Type,
527 Int32Type,
528 Decimal64Type,
529 _,
530 >(
531 value_array.as_ref(),
532 decimal_places,
533 |v, dp| {
534 let rounded = round_decimal_or_zero(
535 v,
536 *input_precision,
537 *scale,
538 *new_scale,
539 dp,
540 )?;
541 if *precision == Decimal64Type::MAX_PRECISION
542 && (decimal_places_is_array || (*scale == 0 && dp < 0))
543 {
544 validate_decimal_precision::<Decimal64Type>(
546 rounded, *precision, *new_scale,
547 )
548 } else {
549 Ok(rounded)
550 }
551 },
552 *precision,
553 *new_scale,
554 )?;
555 result as _
556 }
557 (Decimal128(input_precision, scale), Decimal128(precision, new_scale)) => {
558 let result = calculate_binary_decimal_math::<
559 Decimal128Type,
560 Int32Type,
561 Decimal128Type,
562 _,
563 >(
564 value_array.as_ref(),
565 decimal_places,
566 |v, dp| {
567 let rounded = round_decimal_or_zero(
568 v,
569 *input_precision,
570 *scale,
571 *new_scale,
572 dp,
573 )?;
574 if *precision == Decimal128Type::MAX_PRECISION
575 && (decimal_places_is_array || (*scale == 0 && dp < 0))
576 {
577 validate_decimal_precision::<Decimal128Type>(
579 rounded, *precision, *new_scale,
580 )
581 } else {
582 Ok(rounded)
583 }
584 },
585 *precision,
586 *new_scale,
587 )?;
588 result as _
589 }
590 (Decimal256(input_precision, scale), Decimal256(precision, new_scale)) => {
591 let result = calculate_binary_decimal_math::<
592 Decimal256Type,
593 Int32Type,
594 Decimal256Type,
595 _,
596 >(
597 value_array.as_ref(),
598 decimal_places,
599 |v, dp| {
600 let rounded = round_decimal_or_zero(
601 v,
602 *input_precision,
603 *scale,
604 *new_scale,
605 dp,
606 )?;
607 if *precision == Decimal256Type::MAX_PRECISION
608 && (decimal_places_is_array || (*scale == 0 && dp < 0))
609 {
610 validate_decimal_precision::<Decimal256Type>(
612 rounded, *precision, *new_scale,
613 )
614 } else {
615 Ok(rounded)
616 }
617 },
618 *precision,
619 *new_scale,
620 )?;
621 result as _
622 }
623 (other, _) => exec_err!("Unsupported data type {other:?} for function round")?,
624 };
625
626 if both_scalars {
627 ScalarValue::try_from_array(&arr, 0).map(ColumnarValue::Scalar)
628 } else {
629 Ok(ColumnarValue::Array(arr))
630 }
631}
632
633fn round_float<T>(value: T, decimal_places: i32) -> Result<T, ArrowError>
634where
635 T: num_traits::Float,
636{
637 let factor = T::from(10_f64.powi(decimal_places)).ok_or_else(|| {
638 ArrowError::ComputeError(format!(
639 "Invalid value for decimal places: {decimal_places}"
640 ))
641 })?;
642 Ok((value * factor).round() / factor)
643}
644
645fn round_decimal<V: ArrowNativeTypeOp>(
646 value: V,
647 input_scale: i8,
648 output_scale: i8,
649 decimal_places: i32,
650) -> Result<V, ArrowError> {
651 let diff = i64::from(input_scale) - i64::from(decimal_places);
652 if diff <= 0 {
653 return Ok(value);
654 }
655
656 debug_assert!(diff <= i64::from(u32::MAX));
657 let diff = diff as u32;
658
659 let one = V::ONE;
660 let two = V::from_usize(2).ok_or_else(|| {
661 ArrowError::ComputeError("Internal error: could not create constant 2".into())
662 })?;
663 let ten = V::from_usize(10).ok_or_else(|| {
664 ArrowError::ComputeError("Internal error: could not create constant 10".into())
665 })?;
666
667 let factor = ten.pow_checked(diff).map_err(|_| {
668 ArrowError::ComputeError(format!(
669 "Overflow while rounding decimal with scale {input_scale} and decimal places {decimal_places}"
670 ))
671 })?;
672
673 let mut quotient = value.div_wrapping(factor);
674 let remainder = value.mod_wrapping(factor);
675
676 let threshold = factor.div_wrapping(two);
678 if remainder >= threshold {
679 quotient = quotient.add_checked(one).map_err(|_| {
680 ArrowError::ComputeError("Overflow while rounding decimal".into())
681 })?;
682 } else if remainder <= threshold.neg_wrapping() {
683 quotient = quotient.sub_checked(one).map_err(|_| {
684 ArrowError::ComputeError("Overflow while rounding decimal".into())
685 })?;
686 }
687
688 let scale_shift = i64::from(output_scale) - i64::from(decimal_places);
691 if scale_shift == 0 {
692 return Ok(quotient);
693 }
694
695 debug_assert!(scale_shift > 0);
696 debug_assert!(scale_shift <= i64::from(u32::MAX));
697 let scale_shift = scale_shift as u32;
698 let shift_factor = ten.pow_checked(scale_shift).map_err(|_| {
699 ArrowError::ComputeError(format!(
700 "Overflow while rounding decimal with scale {input_scale} and decimal places {decimal_places}"
701 ))
702 })?;
703 quotient
704 .mul_checked(shift_factor)
705 .map_err(|_| ArrowError::ComputeError("Overflow while rounding decimal".into()))
706}
707
708fn round_decimal_or_zero<V: ArrowNativeTypeOp>(
709 value: V,
710 precision: u8,
711 input_scale: i8,
712 output_scale: i8,
713 decimal_places: i32,
714) -> Result<V, ArrowError> {
715 if let Some(dp) =
716 normalize_decimal_places_for_decimal(decimal_places, precision, input_scale)
717 {
718 round_decimal(value, input_scale, output_scale, dp)
719 } else {
720 V::from_usize(0).ok_or_else(|| {
721 ArrowError::ComputeError("Internal error: could not create constant 0".into())
722 })
723 }
724}
725
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use arrow::array::{ArrayRef, Decimal128Array, Float32Array, Float64Array, Int64Array};
    use arrow::datatypes::DataType;
    use datafusion_common::DataFusionError;
    use datafusion_common::ScalarValue;
    use datafusion_common::cast::{
        as_decimal128_array, as_float32_array, as_float64_array,
    };
    use datafusion_expr::ColumnarValue;

    /// Runs `round_columnar` on array inputs; the return type equals the value type.
    fn round_arrays(
        value: ArrayRef,
        decimal_places: Option<ArrayRef>,
    ) -> Result<ArrayRef, DataFusionError> {
        let number_rows = value.len();
        let return_type = value.data_type().clone();
        let value = ColumnarValue::Array(value);
        let decimal_places = decimal_places
            .map(ColumnarValue::Array)
            .unwrap_or_else(|| ColumnarValue::Scalar(ScalarValue::Int32(Some(0))));

        let result =
            super::round_columnar(&value, &decimal_places, number_rows, &return_type)?;
        match result {
            ColumnarValue::Array(array) => Ok(array),
            ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(1),
        }
    }

    /// Builds a `Decimal128Array` with the given precision and scale.
    fn decimal128_array(values: Vec<i128>, precision: u8, scale: i8) -> ArrayRef {
        Arc::new(
            Decimal128Array::from(values)
                .with_precision_and_scale(precision, scale)
                .expect("valid decimal precision and scale"),
        )
    }

    /// Rounds a decimal array by a scalar `decimal_places` into `return_type` and returns
    /// the unscaled result values.
    fn round_decimal_array(
        value: ArrayRef,
        decimal_places: i32,
        return_type: DataType,
    ) -> Result<Vec<Option<i128>>, DataFusionError> {
        let number_rows = value.len();
        let result = super::round_columnar(
            &ColumnarValue::Array(value),
            &ColumnarValue::Scalar(ScalarValue::Int32(Some(decimal_places))),
            number_rows,
            &return_type,
        )?;
        let ColumnarValue::Array(array) = result else {
            panic!("expected an array result");
        };
        let values = as_decimal128_array(&array)?.iter().collect();
        Ok(values)
    }

    #[test]
    fn test_round_f32() {
        let args: Vec<ArrayRef> = vec![
            Arc::new(Float32Array::from(vec![125.2345; 10])),
            Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4, 5, -1, -2, -3, -4])),
        ];

        let result = round_arrays(Arc::clone(&args[0]), Some(Arc::clone(&args[1])))
            .expect("failed to initialize function round");
        let floats =
            as_float32_array(&result).expect("failed to initialize function round");

        let expected = Float32Array::from(vec![
            125.0, 125.2, 125.23, 125.235, 125.2345, 125.2345, 130.0, 100.0, 0.0, 0.0,
        ]);

        assert_eq!(floats, &expected);
    }

    #[test]
    fn test_round_f64() {
        let args: Vec<ArrayRef> = vec![
            Arc::new(Float64Array::from(vec![125.2345; 10])),
            Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4, 5, -1, -2, -3, -4])),
        ];

        let result = round_arrays(Arc::clone(&args[0]), Some(Arc::clone(&args[1])))
            .expect("failed to initialize function round");
        let floats =
            as_float64_array(&result).expect("failed to initialize function round");

        let expected = Float64Array::from(vec![
            125.0, 125.2, 125.23, 125.235, 125.2345, 125.2345, 130.0, 100.0, 0.0, 0.0,
        ]);

        assert_eq!(floats, &expected);
    }

    #[test]
    fn test_round_f32_one_input() {
        let args: Vec<ArrayRef> = vec![Arc::new(Float32Array::from(vec![
            125.2345, 12.345, 1.234, 0.1234,
        ]))];

        let result = round_arrays(Arc::clone(&args[0]), None)
            .expect("failed to initialize function round");
        let floats =
            as_float32_array(&result).expect("failed to initialize function round");

        let expected = Float32Array::from(vec![125.0, 12.0, 1.0, 0.0]);

        assert_eq!(floats, &expected);
    }

    #[test]
    fn test_round_f64_one_input() {
        let args: Vec<ArrayRef> = vec![Arc::new(Float64Array::from(vec![
            125.2345, 12.345, 1.234, 0.1234,
        ]))];

        let result = round_arrays(Arc::clone(&args[0]), None)
            .expect("failed to initialize function round");
        let floats =
            as_float64_array(&result).expect("failed to initialize function round");

        let expected = Float64Array::from(vec![125.0, 12.0, 1.0, 0.0]);

        assert_eq!(floats, &expected);
    }

    #[test]
    fn test_round_f32_cast_fail() {
        // decimal_places that does not fit in an i32 must be rejected.
        let args: Vec<ArrayRef> = vec![
            Arc::new(Float32Array::from(vec![125.2345])),
            Arc::new(Int64Array::from(vec![2147483648])),
        ];

        let result = round_arrays(Arc::clone(&args[0]), Some(Arc::clone(&args[1])));

        assert!(result.is_err());
        assert!(matches!(
            result,
            Err(DataFusionError::ArrowError(_, _)) | Err(DataFusionError::Execution(_))
        ));
    }

    #[test]
    fn test_round_f64_cast_fail() {
        let args: Vec<ArrayRef> = vec![
            Arc::new(Float64Array::from(vec![125.2345])),
            Arc::new(Int64Array::from(vec![2147483648])),
        ];

        let result = round_arrays(Arc::clone(&args[0]), Some(Arc::clone(&args[1])));

        assert!(result.is_err());
        assert!(matches!(
            result,
            Err(DataFusionError::ArrowError(_, _)) | Err(DataFusionError::Execution(_))
        ));
    }

    #[test]
    fn test_round_decimal128_half_away_from_zero() {
        // 12.345, -12.345, 12.344, -12.344 rounded to two decimal places.
        let value = decimal128_array(vec![12345, -12345, 12344, -12344], 6, 3);
        let result = round_decimal_array(value, 2, DataType::Decimal128(6, 2))
            .expect("failed to round decimals");
        assert_eq!(
            result,
            vec![Some(1235), Some(-1235), Some(1234), Some(-1234)]
        );
    }

    #[test]
    fn test_round_decimal128_negative_decimal_places() {
        // 1250, -1250, 1249 rounded to the nearest hundred.
        let value = decimal128_array(vec![1250, -1250, 1249], 6, 0);
        let result = round_decimal_array(value, -2, DataType::Decimal128(7, 0))
            .expect("failed to round decimals");
        assert_eq!(result, vec![Some(1300), Some(-1300), Some(1200)]);
    }

    #[test]
    fn test_round_decimal128_beyond_precision_is_zero() {
        // A DECIMAL(3, 0) value is always below 5000, so rounding to 10^4 yields zero.
        let value = decimal128_array(vec![999, -999], 3, 0);
        let result = round_decimal_array(value, -4, DataType::Decimal128(3, 0))
            .expect("failed to round decimals");
        assert_eq!(result, vec![Some(0), Some(0)]);
    }

    #[test]
    fn test_round_decimal128_carry_overflows_max_precision() {
        // 38 nines rounded to the nearest ten needs 39 digits.
        let max = 10_i128.pow(38) - 1;
        let value = decimal128_array(vec![max], 38, 0);
        assert!(round_decimal_array(value, -1, DataType::Decimal128(38, 0)).is_err());
    }

    #[test]
    fn test_output_scale_for_decimal() {
        assert_eq!(super::output_scale_for_decimal(10, 4, 2), 2);
        assert_eq!(super::output_scale_for_decimal(10, 4, -3), 0);
        assert_eq!(super::output_scale_for_decimal(5, -2, -3), -3);
        assert_eq!(super::output_scale_for_decimal(3, -2, -5), -3);
    }

    #[test]
    fn test_normalize_decimal_places_for_decimal() {
        assert_eq!(super::normalize_decimal_places_for_decimal(2, 5, 2), Some(2));
        assert_eq!(super::normalize_decimal_places_for_decimal(-3, 5, 2), Some(-3));
        assert_eq!(super::normalize_decimal_places_for_decimal(-4, 5, 2), None);
        assert_eq!(super::normalize_decimal_places_for_decimal(-1, 2, 2), None);
    }
}