1use std::any::Any;
19
20use crate::utils::{calculate_binary_decimal_math, calculate_binary_math};
21
22use arrow::array::ArrayRef;
23use arrow::datatypes::DataType::{
24 Decimal32, Decimal64, Decimal128, Decimal256, Float32, Float64,
25};
26use arrow::datatypes::{
27 ArrowNativeTypeOp, DataType, Decimal32Type, Decimal64Type, Decimal128Type,
28 Decimal256Type, DecimalType, Float32Type, Float64Type, Int32Type,
29};
30use arrow::datatypes::{Field, FieldRef};
31use arrow::error::ArrowError;
32use datafusion_common::types::{
33 NativeType, logical_float32, logical_float64, logical_int32,
34};
35use datafusion_common::{Result, ScalarValue, exec_err, internal_err};
36use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
37use datafusion_expr::{
38 Coercion, ColumnarValue, Documentation, ReturnFieldArgs, ScalarFunctionArgs,
39 ScalarUDFImpl, Signature, TypeSignature, TypeSignatureClass, Volatility,
40};
41use datafusion_macros::user_doc;
42use std::sync::Arc;
43
/// Returns the scale of `round`'s output for a decimal input with the given
/// `precision` and `input_scale` when rounding to `decimal_places`.
///
/// * Non-negative `input_scale`: the scale shrinks to `decimal_places`, but never
///   below 0 and never above `input_scale`.
/// * Negative `input_scale`: the scale may go further negative, down to
///   `decimal_places`, but is clamped at `-precision`. The clamp never lifts the
///   scale above `input_scale`. `round_decimal` returns the unscaled value untouched
///   when no digits are dropped, so a larger output scale would silently change the
///   represented value.
fn output_scale_for_decimal(precision: u8, input_scale: i8, decimal_places: i32) -> i8 {
    let input_scale = i32::from(input_scale);
    let new_scale = if input_scale < 0 {
        // Arrow allows scales below -precision; in that case keep the input scale
        // as the lower bound instead of raising the scale to -precision.
        let min_scale = (-i32::from(precision)).min(input_scale);
        input_scale.min(decimal_places).max(min_scale)
    } else {
        input_scale.min(decimal_places.max(0))
    };
    // new_scale lies between a bound >= i8::MIN and the (i8) input scale, so it fits.
    new_scale as i8
}
63
/// Validates `decimal_places` against a decimal of the given `precision` and `scale`.
///
/// Non-negative values are returned unchanged. A negative value is kept only if its
/// magnitude does not exceed the number of integer digits (`precision - scale`).
/// Otherwise `None` is returned, meaning the rounding position lies beyond every
/// digit the type can hold.
fn normalize_decimal_places_for_decimal(
    decimal_places: i32,
    precision: u8,
    scale: i8,
) -> Option<i32> {
    if !decimal_places.is_negative() {
        return Some(decimal_places);
    }

    let integer_digits = i64::from(precision) - i64::from(scale);
    // `unsigned_abs` avoids overflow for i32::MIN.
    let magnitude = i64::from(decimal_places.unsigned_abs());
    if integer_digits > 0 && magnitude <= integer_digits {
        Some(decimal_places)
    } else {
        None
    }
}
84
85fn validate_decimal_precision<T: DecimalType>(
86 value: T::Native,
87 precision: u8,
88 scale: i8,
89) -> Result<T::Native, ArrowError> {
90 T::validate_decimal_precision(value, precision, scale).map_err(|e| {
91 ArrowError::ComputeError(format!(
92 "Decimal overflow: rounded value exceeds precision {precision}: {e}"
93 ))
94 })?;
95 Ok(value)
96}
97
98fn calculate_new_precision_scale<T: DecimalType>(
99 precision: u8,
100 scale: i8,
101 decimal_places: Option<i32>,
102) -> Result<DataType> {
103 if let Some(decimal_places) = decimal_places {
104 let new_scale = output_scale_for_decimal(precision, scale, decimal_places);
105
106 let abs_decimal_places = decimal_places.unsigned_abs();
110 let new_precision = if scale == 0
111 && decimal_places < 0
112 && abs_decimal_places <= u32::from(precision)
113 {
114 precision.saturating_add(1).min(T::MAX_PRECISION)
115 } else {
116 precision
117 };
118 Ok(T::TYPE_CONSTRUCTOR(new_precision, new_scale))
119 } else {
120 let new_precision = precision.saturating_add(1).min(T::MAX_PRECISION);
121 Ok(T::TYPE_CONSTRUCTOR(new_precision, scale))
122 }
123}
124
125fn decimal_places_from_scalar(scalar: &ScalarValue) -> Result<i32> {
126 let out_of_range = |value: String| {
127 datafusion_common::DataFusionError::Execution(format!(
128 "round decimal_places {value} is out of supported i32 range"
129 ))
130 };
131 match scalar {
132 ScalarValue::Int8(Some(v)) => Ok(i32::from(*v)),
133 ScalarValue::Int16(Some(v)) => Ok(i32::from(*v)),
134 ScalarValue::Int32(Some(v)) => Ok(*v),
135 ScalarValue::Int64(Some(v)) => {
136 i32::try_from(*v).map_err(|_| out_of_range(v.to_string()))
137 }
138 ScalarValue::UInt8(Some(v)) => Ok(i32::from(*v)),
139 ScalarValue::UInt16(Some(v)) => Ok(i32::from(*v)),
140 ScalarValue::UInt32(Some(v)) => {
141 i32::try_from(*v).map_err(|_| out_of_range(v.to_string()))
142 }
143 ScalarValue::UInt64(Some(v)) => {
144 i32::try_from(*v).map_err(|_| out_of_range(v.to_string()))
145 }
146 other => exec_err!(
147 "Unexpected datatype for decimal_places: {}",
148 other.data_type()
149 ),
150 }
151}
152
153#[user_doc(
154 doc_section(label = "Math Functions"),
155 description = "Rounds a number to the nearest integer.",
156 syntax_example = "round(numeric_expression[, decimal_places])",
157 standard_argument(name = "numeric_expression", prefix = "Numeric"),
158 argument(
159 name = "decimal_places",
160 description = "Optional. The number of decimal places to round to. Defaults to 0."
161 ),
162 sql_example = r#"```sql
163> SELECT round(3.14159);
164+--------------+
165| round(3.14159)|
166+--------------+
167| 3.0 |
168+--------------+
169```"#
170)]
171#[derive(Debug, PartialEq, Eq, Hash)]
172pub struct RoundFunc {
173 signature: Signature,
174}
175
176impl Default for RoundFunc {
177 fn default() -> Self {
178 RoundFunc::new()
179 }
180}
181
182impl RoundFunc {
183 pub fn new() -> Self {
184 let decimal = Coercion::new_exact(TypeSignatureClass::Decimal);
185 let decimal_places = Coercion::new_implicit(
186 TypeSignatureClass::Native(logical_int32()),
187 vec![TypeSignatureClass::Integer],
188 NativeType::Int32,
189 );
190 let float32 = Coercion::new_exact(TypeSignatureClass::Native(logical_float32()));
191 let float64 = Coercion::new_implicit(
192 TypeSignatureClass::Native(logical_float64()),
193 vec![TypeSignatureClass::Numeric],
194 NativeType::Float64,
195 );
196 Self {
197 signature: Signature::one_of(
198 vec![
199 TypeSignature::Coercible(vec![
200 decimal.clone(),
201 decimal_places.clone(),
202 ]),
203 TypeSignature::Coercible(vec![decimal]),
204 TypeSignature::Coercible(vec![
205 float32.clone(),
206 decimal_places.clone(),
207 ]),
208 TypeSignature::Coercible(vec![float32]),
209 TypeSignature::Coercible(vec![float64.clone(), decimal_places]),
210 TypeSignature::Coercible(vec![float64]),
211 ],
212 Volatility::Immutable,
213 ),
214 }
215 }
216}
217
218impl ScalarUDFImpl for RoundFunc {
219 fn as_any(&self) -> &dyn Any {
220 self
221 }
222
223 fn name(&self) -> &str {
224 "round"
225 }
226
227 fn signature(&self) -> &Signature {
228 &self.signature
229 }
230
231 fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
232 let input_field = &args.arg_fields[0];
233 let input_type = input_field.data_type();
234
235 let decimal_places: Option<i32> = match args.scalar_arguments.get(1) {
241 None => Some(0), Some(None) => None, Some(Some(scalar)) if scalar.is_null() => Some(0), Some(Some(scalar)) => Some(decimal_places_from_scalar(scalar)?),
245 };
246
247 let return_type =
253 match input_type {
254 Float32 => Float32,
255 Decimal32(precision, scale) => calculate_new_precision_scale::<
256 Decimal32Type,
257 >(
258 *precision, *scale, decimal_places
259 )?,
260 Decimal64(precision, scale) => calculate_new_precision_scale::<
261 Decimal64Type,
262 >(
263 *precision, *scale, decimal_places
264 )?,
265 Decimal128(precision, scale) => calculate_new_precision_scale::<
266 Decimal128Type,
267 >(
268 *precision, *scale, decimal_places
269 )?,
270 Decimal256(precision, scale) => calculate_new_precision_scale::<
271 Decimal256Type,
272 >(
273 *precision, *scale, decimal_places
274 )?,
275 _ => Float64,
276 };
277
278 let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
279 Ok(Arc::new(Field::new(self.name(), return_type, nullable)))
280 }
281
282 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
283 internal_err!("use return_field_from_args instead")
284 }
285
286 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
287 if args.arg_fields.iter().any(|a| a.data_type().is_null()) {
288 return ColumnarValue::Scalar(ScalarValue::Null)
289 .cast_to(args.return_type(), None);
290 }
291
292 let default_decimal_places = ColumnarValue::Scalar(ScalarValue::Int32(Some(0)));
293 let decimal_places = if args.args.len() == 2 {
294 &args.args[1]
295 } else {
296 &default_decimal_places
297 };
298
299 if let (ColumnarValue::Scalar(value_scalar), ColumnarValue::Scalar(dp_scalar)) =
300 (&args.args[0], decimal_places)
301 {
302 if value_scalar.is_null() || dp_scalar.is_null() {
303 return ColumnarValue::Scalar(ScalarValue::Null)
304 .cast_to(args.return_type(), None);
305 }
306
307 let dp = if let ScalarValue::Int32(Some(dp)) = dp_scalar {
308 *dp
309 } else {
310 return internal_err!(
311 "Unexpected datatype for decimal_places: {}",
312 dp_scalar.data_type()
313 );
314 };
315
316 match (value_scalar, args.return_type()) {
317 (ScalarValue::Float32(Some(v)), _) => {
318 let rounded = round_float(*v, dp)?;
319 Ok(ColumnarValue::Scalar(ScalarValue::from(rounded)))
320 }
321 (ScalarValue::Float64(Some(v)), _) => {
322 let rounded = round_float(*v, dp)?;
323 Ok(ColumnarValue::Scalar(ScalarValue::from(rounded)))
324 }
325 (
326 ScalarValue::Decimal32(Some(v), in_precision, scale),
327 Decimal32(out_precision, out_scale),
328 ) => {
329 let rounded =
330 round_decimal_or_zero(*v, *in_precision, *scale, *out_scale, dp)?;
331 let rounded = if *out_precision == Decimal32Type::MAX_PRECISION
332 && *scale == 0
333 && dp < 0
334 {
335 validate_decimal_precision::<Decimal32Type>(
339 rounded,
340 *out_precision,
341 *out_scale,
342 )
343 } else {
344 Ok(rounded)
345 }?;
346 let scalar =
347 ScalarValue::Decimal32(Some(rounded), *out_precision, *out_scale);
348 Ok(ColumnarValue::Scalar(scalar))
349 }
350 (
351 ScalarValue::Decimal64(Some(v), in_precision, scale),
352 Decimal64(out_precision, out_scale),
353 ) => {
354 let rounded =
355 round_decimal_or_zero(*v, *in_precision, *scale, *out_scale, dp)?;
356 let rounded = if *out_precision == Decimal64Type::MAX_PRECISION
357 && *scale == 0
358 && dp < 0
359 {
360 validate_decimal_precision::<Decimal64Type>(
362 rounded,
363 *out_precision,
364 *out_scale,
365 )
366 } else {
367 Ok(rounded)
368 }?;
369 let scalar =
370 ScalarValue::Decimal64(Some(rounded), *out_precision, *out_scale);
371 Ok(ColumnarValue::Scalar(scalar))
372 }
373 (
374 ScalarValue::Decimal128(Some(v), in_precision, scale),
375 Decimal128(out_precision, out_scale),
376 ) => {
377 let rounded =
378 round_decimal_or_zero(*v, *in_precision, *scale, *out_scale, dp)?;
379 let rounded = if *out_precision == Decimal128Type::MAX_PRECISION
380 && *scale == 0
381 && dp < 0
382 {
383 validate_decimal_precision::<Decimal128Type>(
385 rounded,
386 *out_precision,
387 *out_scale,
388 )
389 } else {
390 Ok(rounded)
391 }?;
392 let scalar = ScalarValue::Decimal128(
393 Some(rounded),
394 *out_precision,
395 *out_scale,
396 );
397 Ok(ColumnarValue::Scalar(scalar))
398 }
399 (
400 ScalarValue::Decimal256(Some(v), in_precision, scale),
401 Decimal256(out_precision, out_scale),
402 ) => {
403 let rounded =
404 round_decimal_or_zero(*v, *in_precision, *scale, *out_scale, dp)?;
405 let rounded = if *out_precision == Decimal256Type::MAX_PRECISION
406 && *scale == 0
407 && dp < 0
408 {
409 validate_decimal_precision::<Decimal256Type>(
411 rounded,
412 *out_precision,
413 *out_scale,
414 )
415 } else {
416 Ok(rounded)
417 }?;
418 let scalar = ScalarValue::Decimal256(
419 Some(rounded),
420 *out_precision,
421 *out_scale,
422 );
423 Ok(ColumnarValue::Scalar(scalar))
424 }
425 (ScalarValue::Null, _) => ColumnarValue::Scalar(ScalarValue::Null)
426 .cast_to(args.return_type(), None),
427 (value_scalar, return_type) => {
428 internal_err!(
429 "Unexpected datatype for round(value, decimal_places): value {}, return type {}",
430 value_scalar.data_type(),
431 return_type
432 )
433 }
434 }
435 } else {
436 round_columnar(
437 &args.args[0],
438 decimal_places,
439 args.number_rows,
440 args.return_type(),
441 )
442 }
443 }
444
445 fn output_ordering(&self, input: &[ExprProperties]) -> Result<SortProperties> {
446 let value = &input[0];
448 let precision = input.get(1);
449
450 if precision
451 .map(|r| r.sort_properties.eq(&SortProperties::Singleton))
452 .unwrap_or(true)
453 {
454 Ok(value.sort_properties)
455 } else {
456 Ok(SortProperties::Unordered)
457 }
458 }
459
460 fn documentation(&self) -> Option<&Documentation> {
461 self.doc()
462 }
463}
464
465fn round_columnar(
466 value: &ColumnarValue,
467 decimal_places: &ColumnarValue,
468 number_rows: usize,
469 return_type: &DataType,
470) -> Result<ColumnarValue> {
471 let value_array = value.to_array(number_rows)?;
472 let both_scalars = matches!(value, ColumnarValue::Scalar(_))
473 && matches!(decimal_places, ColumnarValue::Scalar(_));
474 let decimal_places_is_array = matches!(decimal_places, ColumnarValue::Array(_));
475
476 let arr: ArrayRef = match (value_array.data_type(), return_type) {
477 (Float64, _) => {
478 let result = calculate_binary_math::<Float64Type, Int32Type, Float64Type, _>(
479 value_array.as_ref(),
480 decimal_places,
481 round_float::<f64>,
482 )?;
483 result as _
484 }
485 (Float32, _) => {
486 let result = calculate_binary_math::<Float32Type, Int32Type, Float32Type, _>(
487 value_array.as_ref(),
488 decimal_places,
489 round_float::<f32>,
490 )?;
491 result as _
492 }
493 (Decimal32(input_precision, scale), Decimal32(precision, new_scale)) => {
494 let result = calculate_binary_decimal_math::<
496 Decimal32Type,
497 Int32Type,
498 Decimal32Type,
499 _,
500 >(
501 value_array.as_ref(),
502 decimal_places,
503 |v, dp| {
504 let rounded = round_decimal_or_zero(
505 v,
506 *input_precision,
507 *scale,
508 *new_scale,
509 dp,
510 )?;
511 if *precision == Decimal32Type::MAX_PRECISION
512 && (decimal_places_is_array || (*scale == 0 && dp < 0))
513 {
514 validate_decimal_precision::<Decimal32Type>(
519 rounded, *precision, *new_scale,
520 )
521 } else {
522 Ok(rounded)
523 }
524 },
525 *precision,
526 *new_scale,
527 )?;
528 result as _
529 }
530 (Decimal64(input_precision, scale), Decimal64(precision, new_scale)) => {
531 let result = calculate_binary_decimal_math::<
532 Decimal64Type,
533 Int32Type,
534 Decimal64Type,
535 _,
536 >(
537 value_array.as_ref(),
538 decimal_places,
539 |v, dp| {
540 let rounded = round_decimal_or_zero(
541 v,
542 *input_precision,
543 *scale,
544 *new_scale,
545 dp,
546 )?;
547 if *precision == Decimal64Type::MAX_PRECISION
548 && (decimal_places_is_array || (*scale == 0 && dp < 0))
549 {
550 validate_decimal_precision::<Decimal64Type>(
552 rounded, *precision, *new_scale,
553 )
554 } else {
555 Ok(rounded)
556 }
557 },
558 *precision,
559 *new_scale,
560 )?;
561 result as _
562 }
563 (Decimal128(input_precision, scale), Decimal128(precision, new_scale)) => {
564 let result = calculate_binary_decimal_math::<
565 Decimal128Type,
566 Int32Type,
567 Decimal128Type,
568 _,
569 >(
570 value_array.as_ref(),
571 decimal_places,
572 |v, dp| {
573 let rounded = round_decimal_or_zero(
574 v,
575 *input_precision,
576 *scale,
577 *new_scale,
578 dp,
579 )?;
580 if *precision == Decimal128Type::MAX_PRECISION
581 && (decimal_places_is_array || (*scale == 0 && dp < 0))
582 {
583 validate_decimal_precision::<Decimal128Type>(
585 rounded, *precision, *new_scale,
586 )
587 } else {
588 Ok(rounded)
589 }
590 },
591 *precision,
592 *new_scale,
593 )?;
594 result as _
595 }
596 (Decimal256(input_precision, scale), Decimal256(precision, new_scale)) => {
597 let result = calculate_binary_decimal_math::<
598 Decimal256Type,
599 Int32Type,
600 Decimal256Type,
601 _,
602 >(
603 value_array.as_ref(),
604 decimal_places,
605 |v, dp| {
606 let rounded = round_decimal_or_zero(
607 v,
608 *input_precision,
609 *scale,
610 *new_scale,
611 dp,
612 )?;
613 if *precision == Decimal256Type::MAX_PRECISION
614 && (decimal_places_is_array || (*scale == 0 && dp < 0))
615 {
616 validate_decimal_precision::<Decimal256Type>(
618 rounded, *precision, *new_scale,
619 )
620 } else {
621 Ok(rounded)
622 }
623 },
624 *precision,
625 *new_scale,
626 )?;
627 result as _
628 }
629 (other, _) => exec_err!("Unsupported data type {other:?} for function round")?,
630 };
631
632 if both_scalars {
633 ScalarValue::try_from_array(&arr, 0).map(ColumnarValue::Scalar)
634 } else {
635 Ok(ColumnarValue::Array(arr))
636 }
637}
638
639fn round_float<T>(value: T, decimal_places: i32) -> Result<T, ArrowError>
640where
641 T: num_traits::Float,
642{
643 let factor = T::from(10_f64.powi(decimal_places)).ok_or_else(|| {
644 ArrowError::ComputeError(format!(
645 "Invalid value for decimal places: {decimal_places}"
646 ))
647 })?;
648 Ok((value * factor).round() / factor)
649}
650
651fn round_decimal<V: ArrowNativeTypeOp>(
652 value: V,
653 input_scale: i8,
654 output_scale: i8,
655 decimal_places: i32,
656) -> Result<V, ArrowError> {
657 let diff = i64::from(input_scale) - i64::from(decimal_places);
658 if diff <= 0 {
659 return Ok(value);
660 }
661
662 debug_assert!(diff <= i64::from(u32::MAX));
663 let diff = diff as u32;
664
665 let one = V::ONE;
666 let two = V::from_usize(2).ok_or_else(|| {
667 ArrowError::ComputeError("Internal error: could not create constant 2".into())
668 })?;
669 let ten = V::from_usize(10).ok_or_else(|| {
670 ArrowError::ComputeError("Internal error: could not create constant 10".into())
671 })?;
672
673 let factor = ten.pow_checked(diff).map_err(|_| {
674 ArrowError::ComputeError(format!(
675 "Overflow while rounding decimal with scale {input_scale} and decimal places {decimal_places}"
676 ))
677 })?;
678
679 let mut quotient = value.div_wrapping(factor);
680 let remainder = value.mod_wrapping(factor);
681
682 let threshold = factor.div_wrapping(two);
684 if remainder >= threshold {
685 quotient = quotient.add_checked(one).map_err(|_| {
686 ArrowError::ComputeError("Overflow while rounding decimal".into())
687 })?;
688 } else if remainder <= threshold.neg_wrapping() {
689 quotient = quotient.sub_checked(one).map_err(|_| {
690 ArrowError::ComputeError("Overflow while rounding decimal".into())
691 })?;
692 }
693
694 let scale_shift = i64::from(output_scale) - i64::from(decimal_places);
697 if scale_shift == 0 {
698 return Ok(quotient);
699 }
700
701 debug_assert!(scale_shift > 0);
702 debug_assert!(scale_shift <= i64::from(u32::MAX));
703 let scale_shift = scale_shift as u32;
704 let shift_factor = ten.pow_checked(scale_shift).map_err(|_| {
705 ArrowError::ComputeError(format!(
706 "Overflow while rounding decimal with scale {input_scale} and decimal places {decimal_places}"
707 ))
708 })?;
709 quotient
710 .mul_checked(shift_factor)
711 .map_err(|_| ArrowError::ComputeError("Overflow while rounding decimal".into()))
712}
713
714fn round_decimal_or_zero<V: ArrowNativeTypeOp>(
715 value: V,
716 precision: u8,
717 input_scale: i8,
718 output_scale: i8,
719 decimal_places: i32,
720) -> Result<V, ArrowError> {
721 if let Some(dp) =
722 normalize_decimal_places_for_decimal(decimal_places, precision, input_scale)
723 {
724 round_decimal(value, input_scale, output_scale, dp)
725 } else {
726 V::from_usize(0).ok_or_else(|| {
727 ArrowError::ComputeError("Internal error: could not create constant 0".into())
728 })
729 }
730}
731
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use arrow::array::{ArrayRef, Float32Array, Float64Array, Int64Array};
    use datafusion_common::DataFusionError;
    use datafusion_common::ScalarValue;
    use datafusion_common::cast::{as_float32_array, as_float64_array};
    use datafusion_expr::ColumnarValue;

    /// Calls `round_columnar` with an array `value`, using `value`'s own type as the
    /// return type. Without `decimal_places`, a scalar `Int32(0)` is used.
    fn round_arrays(
        value: ArrayRef,
        decimal_places: Option<ArrayRef>,
    ) -> Result<ArrayRef, DataFusionError> {
        let number_rows = value.len();
        let return_type = value.data_type().clone();
        let places = match decimal_places {
            Some(array) => ColumnarValue::Array(array),
            None => ColumnarValue::Scalar(ScalarValue::Int32(Some(0))),
        };

        let output = super::round_columnar(
            &ColumnarValue::Array(value),
            &places,
            number_rows,
            &return_type,
        )?;
        match output {
            ColumnarValue::Array(array) => Ok(array),
            ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(1),
        }
    }

    #[test]
    fn test_round_f32() {
        let values: ArrayRef = Arc::new(Float32Array::from(vec![125.2345; 10]));
        let places: ArrayRef =
            Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4, 5, -1, -2, -3, -4]));

        let result = round_arrays(values, Some(places))
            .expect("failed to initialize function round");
        let floats =
            as_float32_array(&result).expect("failed to initialize function round");

        let expected = Float32Array::from(vec![
            125.0, 125.2, 125.23, 125.235, 125.2345, 125.2345, 130.0, 100.0, 0.0, 0.0,
        ]);
        assert_eq!(floats, &expected);
    }

    #[test]
    fn test_round_f64() {
        let values: ArrayRef = Arc::new(Float64Array::from(vec![125.2345; 10]));
        let places: ArrayRef =
            Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4, 5, -1, -2, -3, -4]));

        let result = round_arrays(values, Some(places))
            .expect("failed to initialize function round");
        let floats =
            as_float64_array(&result).expect("failed to initialize function round");

        let expected = Float64Array::from(vec![
            125.0, 125.2, 125.23, 125.235, 125.2345, 125.2345, 130.0, 100.0, 0.0, 0.0,
        ]);
        assert_eq!(floats, &expected);
    }

    #[test]
    fn test_round_f32_one_input() {
        let values: ArrayRef =
            Arc::new(Float32Array::from(vec![125.2345, 12.345, 1.234, 0.1234]));

        let result =
            round_arrays(values, None).expect("failed to initialize function round");
        let floats =
            as_float32_array(&result).expect("failed to initialize function round");

        assert_eq!(floats, &Float32Array::from(vec![125.0, 12.0, 1.0, 0.0]));
    }

    #[test]
    fn test_round_f64_one_input() {
        let values: ArrayRef =
            Arc::new(Float64Array::from(vec![125.2345, 12.345, 1.234, 0.1234]));

        let result =
            round_arrays(values, None).expect("failed to initialize function round");
        let floats =
            as_float64_array(&result).expect("failed to initialize function round");

        assert_eq!(floats, &Float64Array::from(vec![125.0, 12.0, 1.0, 0.0]));
    }

    #[test]
    fn test_round_f32_cast_fail() {
        // 2147483648 does not fit in the Int32 that decimal_places is cast to.
        let values: ArrayRef = Arc::new(Float64Array::from(vec![125.2345]));
        let places: ArrayRef = Arc::new(Int64Array::from(vec![2147483648]));

        let result = round_arrays(values, Some(places));

        assert!(result.is_err());
        assert!(matches!(
            result,
            Err(DataFusionError::ArrowError(_, _)) | Err(DataFusionError::Execution(_))
        ));
    }
}