Skip to main content

datafusion_functions/math/
round.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::utils::{calculate_binary_decimal_math, calculate_binary_math};
19
20use arrow::array::ArrayRef;
21use arrow::datatypes::DataType::{
22    Decimal32, Decimal64, Decimal128, Decimal256, Float32, Float64,
23};
24use arrow::datatypes::{
25    ArrowNativeTypeOp, DataType, Decimal32Type, Decimal64Type, Decimal128Type,
26    Decimal256Type, DecimalType, Float32Type, Float64Type, Int32Type,
27};
28use arrow::datatypes::{Field, FieldRef};
29use arrow::error::ArrowError;
30use datafusion_common::types::{
31    NativeType, logical_float32, logical_float64, logical_int32,
32};
33use datafusion_common::{Result, ScalarValue, exec_err, internal_err};
34use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
35use datafusion_expr::{
36    Coercion, ColumnarValue, Documentation, ReturnFieldArgs, ScalarFunctionArgs,
37    ScalarUDFImpl, Signature, TypeSignature, TypeSignatureClass, Volatility,
38};
39use datafusion_macros::user_doc;
40use std::sync::Arc;
41
/// Chooses the output scale of `round` for a decimal with the given `precision` and
/// `input_scale` when rounding to `decimal_places`.
///
/// The result is never larger than `input_scale`:
/// * for a non-negative input scale it is `decimal_places` clamped into `[0, input_scale]`;
/// * for a negative input scale the scale may move further negative to follow a negative
///   `decimal_places` (e.g. scale -2 rounded to -3 becomes scale -3), which keeps fixed precision
///   by representing the rounded result at a coarser scale. That move is capped at `-precision`,
///   because rounding beyond it always yields 0. The cap is applied to `decimal_places` *before*
///   taking the minimum with `input_scale`. Otherwise an input scale below `-precision` would be
///   "rounded" to a larger scale without rescaling its stored value.
fn output_scale_for_decimal(precision: u8, input_scale: i8, decimal_places: i32) -> i8 {
    let input_scale = i32::from(input_scale);

    let new_scale = if input_scale < 0 {
        let min_scale = -i32::from(precision);
        decimal_places.max(min_scale).min(input_scale)
    } else {
        decimal_places.max(0).min(input_scale)
    };

    // `new_scale` lies between min(-precision, input_scale) and input_scale. `precision` is a
    // decimal precision (at most 76) and `input_scale` came from an i8, so the cast cannot
    // truncate.
    new_scale as i8
}
61
/// Returns `decimal_places` unchanged when rounding a `Decimal(precision, scale)` to it can
/// produce a non-zero result, or `None` when the rounding position lies beyond every digit the
/// value can hold (the rounded result is then always 0).
///
/// Non-negative `decimal_places` are always returned as-is.
fn normalize_decimal_places_for_decimal(
    decimal_places: i32,
    precision: u8,
    scale: i8,
) -> Option<i32> {
    // Rounding at or to the right of the decimal point never runs past the value's digits.
    if decimal_places >= 0 {
        return Some(decimal_places);
    }

    // For fixed-precision decimals |value| < 10^(precision - scale), so `integer_digits` is the
    // number of digit positions left of the decimal point a value can occupy.
    let integer_digits = i64::from(precision) - i64::from(scale);
    // i64 holds |i32::MIN|, so `abs` cannot overflow here.
    let rounding_distance = i64::from(decimal_places).abs();

    // Rounding past every occupied position always gives 0; reporting `None` also lets callers
    // skip the 10^n intermediates that would otherwise overflow.
    match integer_digits {
        n if n > 0 && rounding_distance <= n => Some(decimal_places),
        _ => None,
    }
}
82
83fn validate_decimal_precision<T: DecimalType>(
84    value: T::Native,
85    precision: u8,
86    scale: i8,
87) -> Result<T::Native, ArrowError> {
88    T::validate_decimal_precision(value, precision, scale).map_err(|e| {
89        ArrowError::ComputeError(format!(
90            "Decimal overflow: rounded value exceeds precision {precision}: {e}"
91        ))
92    })?;
93    Ok(value)
94}
95
96fn calculate_new_precision_scale<T: DecimalType>(
97    precision: u8,
98    scale: i8,
99    decimal_places: Option<i32>,
100) -> Result<DataType> {
101    if let Some(decimal_places) = decimal_places {
102        let new_scale = output_scale_for_decimal(precision, scale, decimal_places);
103
104        // When rounding an integer decimal (scale == 0) to a negative `decimal_places`, a carry can
105        // add an extra digit to the integer part (e.g. 99 -> 100 when rounding to -1). This can
106        // only happen when the rounding position is within the existing precision.
107        let abs_decimal_places = decimal_places.unsigned_abs();
108        let new_precision = if scale == 0
109            && decimal_places < 0
110            && abs_decimal_places <= u32::from(precision)
111        {
112            precision.saturating_add(1).min(T::MAX_PRECISION)
113        } else {
114            precision
115        };
116        Ok(T::TYPE_CONSTRUCTOR(new_precision, new_scale))
117    } else {
118        let new_precision = precision.saturating_add(1).min(T::MAX_PRECISION);
119        Ok(T::TYPE_CONSTRUCTOR(new_precision, scale))
120    }
121}
122
123fn decimal_places_from_scalar(scalar: &ScalarValue) -> Result<i32> {
124    let out_of_range = |value: String| {
125        datafusion_common::DataFusionError::Execution(format!(
126            "round decimal_places {value} is out of supported i32 range"
127        ))
128    };
129    match scalar {
130        ScalarValue::Int8(Some(v)) => Ok(i32::from(*v)),
131        ScalarValue::Int16(Some(v)) => Ok(i32::from(*v)),
132        ScalarValue::Int32(Some(v)) => Ok(*v),
133        ScalarValue::Int64(Some(v)) => {
134            i32::try_from(*v).map_err(|_| out_of_range(v.to_string()))
135        }
136        ScalarValue::UInt8(Some(v)) => Ok(i32::from(*v)),
137        ScalarValue::UInt16(Some(v)) => Ok(i32::from(*v)),
138        ScalarValue::UInt32(Some(v)) => {
139            i32::try_from(*v).map_err(|_| out_of_range(v.to_string()))
140        }
141        ScalarValue::UInt64(Some(v)) => {
142            i32::try_from(*v).map_err(|_| out_of_range(v.to_string()))
143        }
144        other => exec_err!(
145            "Unexpected datatype for decimal_places: {}",
146            other.data_type()
147        ),
148    }
149}
150
/// Scalar UDF implementing SQL `round(numeric_expression[, decimal_places])` for Float32,
/// Float64 (the fallback for other numeric inputs) and decimal values.
#[user_doc(
    doc_section(label = "Math Functions"),
    description = "Rounds a number to the nearest integer.",
    syntax_example = "round(numeric_expression[, decimal_places])",
    standard_argument(name = "numeric_expression", prefix = "Numeric"),
    argument(
        name = "decimal_places",
        description = "Optional. The number of decimal places to round to. Defaults to 0."
    ),
    sql_example = r#"```sql
> SELECT round(3.14159);
+--------------+
| round(3.14159)|
+--------------+
| 3.0          |
+--------------+
```"#
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct RoundFunc {
    // Accepted argument signatures, built in `RoundFunc::new`.
    signature: Signature,
}
173
174impl Default for RoundFunc {
175    fn default() -> Self {
176        RoundFunc::new()
177    }
178}
179
180impl RoundFunc {
181    pub fn new() -> Self {
182        let decimal = Coercion::new_exact(TypeSignatureClass::Decimal);
183        let decimal_places = Coercion::new_implicit(
184            TypeSignatureClass::Native(logical_int32()),
185            vec![TypeSignatureClass::Integer],
186            NativeType::Int32,
187        );
188        let float32 = Coercion::new_exact(TypeSignatureClass::Native(logical_float32()));
189        let float64 = Coercion::new_implicit(
190            TypeSignatureClass::Native(logical_float64()),
191            vec![TypeSignatureClass::Numeric],
192            NativeType::Float64,
193        );
194        Self {
195            signature: Signature::one_of(
196                vec![
197                    TypeSignature::Coercible(vec![
198                        decimal.clone(),
199                        decimal_places.clone(),
200                    ]),
201                    TypeSignature::Coercible(vec![decimal]),
202                    TypeSignature::Coercible(vec![
203                        float32.clone(),
204                        decimal_places.clone(),
205                    ]),
206                    TypeSignature::Coercible(vec![float32]),
207                    TypeSignature::Coercible(vec![float64.clone(), decimal_places]),
208                    TypeSignature::Coercible(vec![float64]),
209                ],
210                Volatility::Immutable,
211            ),
212        }
213    }
214}
215
216impl ScalarUDFImpl for RoundFunc {
217    fn name(&self) -> &str {
218        "round"
219    }
220
221    fn signature(&self) -> &Signature {
222        &self.signature
223    }
224
225    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
226        let input_field = &args.arg_fields[0];
227        let input_type = input_field.data_type();
228
229        // If decimal_places is a scalar literal, we can incorporate it into the output type
230        // (scale reduction). Otherwise, keep the input scale as we can't pick a per-row scale.
231        //
232        // Note: `scalar_arguments` contains the original literal values (pre-coercion), so
233        // integer literals may appear as Int64 even though the signature coerces them to Int32.
234        let decimal_places: Option<i32> = match args.scalar_arguments.get(1) {
235            None => Some(0),    // No dp argument means default to 0
236            Some(None) => None, // dp is not a literal (e.g. column)
237            Some(Some(scalar)) if scalar.is_null() => Some(0), // null dp => default to 0
238            Some(Some(scalar)) => Some(decimal_places_from_scalar(scalar)?),
239        };
240
241        // Calculate return type based on input type
242        // For decimals: reduce scale to decimal_places (reclaims precision for integer part)
243        // This matches Spark/DuckDB behavior where ROUND adjusts the scale
244        // BUT only if dp is a scalar literal - otherwise keep original scale and add
245        // extra precision to accommodate potential carry-over.
246        let return_type =
247            match input_type {
248                Float32 => Float32,
249                Decimal32(precision, scale) => calculate_new_precision_scale::<
250                    Decimal32Type,
251                >(
252                    *precision, *scale, decimal_places
253                )?,
254                Decimal64(precision, scale) => calculate_new_precision_scale::<
255                    Decimal64Type,
256                >(
257                    *precision, *scale, decimal_places
258                )?,
259                Decimal128(precision, scale) => calculate_new_precision_scale::<
260                    Decimal128Type,
261                >(
262                    *precision, *scale, decimal_places
263                )?,
264                Decimal256(precision, scale) => calculate_new_precision_scale::<
265                    Decimal256Type,
266                >(
267                    *precision, *scale, decimal_places
268                )?,
269                _ => Float64,
270            };
271
272        let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
273        Ok(Arc::new(Field::new(self.name(), return_type, nullable)))
274    }
275
276    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
277        internal_err!("use return_field_from_args instead")
278    }
279
280    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
281        if args.arg_fields.iter().any(|a| a.data_type().is_null()) {
282            return ColumnarValue::Scalar(ScalarValue::Null)
283                .cast_to(args.return_type(), None);
284        }
285
286        let default_decimal_places = ColumnarValue::Scalar(ScalarValue::Int32(Some(0)));
287        let decimal_places = if args.args.len() == 2 {
288            &args.args[1]
289        } else {
290            &default_decimal_places
291        };
292
293        if let (ColumnarValue::Scalar(value_scalar), ColumnarValue::Scalar(dp_scalar)) =
294            (&args.args[0], decimal_places)
295        {
296            if value_scalar.is_null() || dp_scalar.is_null() {
297                return ColumnarValue::Scalar(ScalarValue::Null)
298                    .cast_to(args.return_type(), None);
299            }
300
301            let dp = if let ScalarValue::Int32(Some(dp)) = dp_scalar {
302                *dp
303            } else {
304                return internal_err!(
305                    "Unexpected datatype for decimal_places: {}",
306                    dp_scalar.data_type()
307                );
308            };
309
310            match (value_scalar, args.return_type()) {
311                (ScalarValue::Float32(Some(v)), _) => {
312                    let rounded = round_float(*v, dp)?;
313                    Ok(ColumnarValue::Scalar(ScalarValue::from(rounded)))
314                }
315                (ScalarValue::Float64(Some(v)), _) => {
316                    let rounded = round_float(*v, dp)?;
317                    Ok(ColumnarValue::Scalar(ScalarValue::from(rounded)))
318                }
319                (
320                    ScalarValue::Decimal32(Some(v), in_precision, scale),
321                    Decimal32(out_precision, out_scale),
322                ) => {
323                    let rounded =
324                        round_decimal_or_zero(*v, *in_precision, *scale, *out_scale, dp)?;
325                    let rounded = if *out_precision == Decimal32Type::MAX_PRECISION
326                        && *scale == 0
327                        && dp < 0
328                    {
329                        // With scale == 0 and negative dp, rounding can carry into an additional
330                        // digit (e.g. 99 -> 100). If we're already at max precision we can't widen
331                        // the type, so validate and error rather than producing an invalid decimal.
332                        validate_decimal_precision::<Decimal32Type>(
333                            rounded,
334                            *out_precision,
335                            *out_scale,
336                        )
337                    } else {
338                        Ok(rounded)
339                    }?;
340                    let scalar =
341                        ScalarValue::Decimal32(Some(rounded), *out_precision, *out_scale);
342                    Ok(ColumnarValue::Scalar(scalar))
343                }
344                (
345                    ScalarValue::Decimal64(Some(v), in_precision, scale),
346                    Decimal64(out_precision, out_scale),
347                ) => {
348                    let rounded =
349                        round_decimal_or_zero(*v, *in_precision, *scale, *out_scale, dp)?;
350                    let rounded = if *out_precision == Decimal64Type::MAX_PRECISION
351                        && *scale == 0
352                        && dp < 0
353                    {
354                        // See Decimal32 branch for details.
355                        validate_decimal_precision::<Decimal64Type>(
356                            rounded,
357                            *out_precision,
358                            *out_scale,
359                        )
360                    } else {
361                        Ok(rounded)
362                    }?;
363                    let scalar =
364                        ScalarValue::Decimal64(Some(rounded), *out_precision, *out_scale);
365                    Ok(ColumnarValue::Scalar(scalar))
366                }
367                (
368                    ScalarValue::Decimal128(Some(v), in_precision, scale),
369                    Decimal128(out_precision, out_scale),
370                ) => {
371                    let rounded =
372                        round_decimal_or_zero(*v, *in_precision, *scale, *out_scale, dp)?;
373                    let rounded = if *out_precision == Decimal128Type::MAX_PRECISION
374                        && *scale == 0
375                        && dp < 0
376                    {
377                        // See Decimal32 branch for details.
378                        validate_decimal_precision::<Decimal128Type>(
379                            rounded,
380                            *out_precision,
381                            *out_scale,
382                        )
383                    } else {
384                        Ok(rounded)
385                    }?;
386                    let scalar = ScalarValue::Decimal128(
387                        Some(rounded),
388                        *out_precision,
389                        *out_scale,
390                    );
391                    Ok(ColumnarValue::Scalar(scalar))
392                }
393                (
394                    ScalarValue::Decimal256(Some(v), in_precision, scale),
395                    Decimal256(out_precision, out_scale),
396                ) => {
397                    let rounded =
398                        round_decimal_or_zero(*v, *in_precision, *scale, *out_scale, dp)?;
399                    let rounded = if *out_precision == Decimal256Type::MAX_PRECISION
400                        && *scale == 0
401                        && dp < 0
402                    {
403                        // See Decimal32 branch for details.
404                        validate_decimal_precision::<Decimal256Type>(
405                            rounded,
406                            *out_precision,
407                            *out_scale,
408                        )
409                    } else {
410                        Ok(rounded)
411                    }?;
412                    let scalar = ScalarValue::Decimal256(
413                        Some(rounded),
414                        *out_precision,
415                        *out_scale,
416                    );
417                    Ok(ColumnarValue::Scalar(scalar))
418                }
419                (ScalarValue::Null, _) => ColumnarValue::Scalar(ScalarValue::Null)
420                    .cast_to(args.return_type(), None),
421                (value_scalar, return_type) => {
422                    internal_err!(
423                        "Unexpected datatype for round(value, decimal_places): value {}, return type {}",
424                        value_scalar.data_type(),
425                        return_type
426                    )
427                }
428            }
429        } else {
430            round_columnar(
431                &args.args[0],
432                decimal_places,
433                args.number_rows,
434                args.return_type(),
435            )
436        }
437    }
438
439    fn output_ordering(&self, input: &[ExprProperties]) -> Result<SortProperties> {
440        // round preserves the order of the first argument
441        let value = &input[0];
442        let precision = input.get(1);
443
444        if precision
445            .map(|r| r.sort_properties.eq(&SortProperties::Singleton))
446            .unwrap_or(true)
447        {
448            Ok(value.sort_properties)
449        } else {
450            Ok(SortProperties::Unordered)
451        }
452    }
453
454    fn documentation(&self) -> Option<&Documentation> {
455        self.doc()
456    }
457}
458
459fn round_columnar(
460    value: &ColumnarValue,
461    decimal_places: &ColumnarValue,
462    number_rows: usize,
463    return_type: &DataType,
464) -> Result<ColumnarValue> {
465    let value_array = value.to_array(number_rows)?;
466    let both_scalars = matches!(value, ColumnarValue::Scalar(_))
467        && matches!(decimal_places, ColumnarValue::Scalar(_));
468    let decimal_places_is_array = matches!(decimal_places, ColumnarValue::Array(_));
469
470    let arr: ArrayRef = match (value_array.data_type(), return_type) {
471        (Float64, _) => {
472            let result = calculate_binary_math::<Float64Type, Int32Type, Float64Type, _>(
473                value_array.as_ref(),
474                decimal_places,
475                round_float::<f64>,
476            )?;
477            result as _
478        }
479        (Float32, _) => {
480            let result = calculate_binary_math::<Float32Type, Int32Type, Float32Type, _>(
481                value_array.as_ref(),
482                decimal_places,
483                round_float::<f32>,
484            )?;
485            result as _
486        }
487        (Decimal32(input_precision, scale), Decimal32(precision, new_scale)) => {
488            // reduce scale to reclaim integer precision
489            let result = calculate_binary_decimal_math::<
490                Decimal32Type,
491                Int32Type,
492                Decimal32Type,
493                _,
494            >(
495                value_array.as_ref(),
496                decimal_places,
497                |v, dp| {
498                    let rounded = round_decimal_or_zero(
499                        v,
500                        *input_precision,
501                        *scale,
502                        *new_scale,
503                        dp,
504                    )?;
505                    if *precision == Decimal32Type::MAX_PRECISION
506                        && (decimal_places_is_array || (*scale == 0 && dp < 0))
507                    {
508                        // If we're already at max precision, we can't widen the result type. For
509                        // dp arrays, or for scale == 0 with negative dp, rounding can overflow the
510                        // fixed-precision type. Validate per-row and return an error instead of
511                        // producing an invalid decimal that Arrow may display incorrectly.
512                        validate_decimal_precision::<Decimal32Type>(
513                            rounded, *precision, *new_scale,
514                        )
515                    } else {
516                        Ok(rounded)
517                    }
518                },
519                *precision,
520                *new_scale,
521            )?;
522            result as _
523        }
524        (Decimal64(input_precision, scale), Decimal64(precision, new_scale)) => {
525            let result = calculate_binary_decimal_math::<
526                Decimal64Type,
527                Int32Type,
528                Decimal64Type,
529                _,
530            >(
531                value_array.as_ref(),
532                decimal_places,
533                |v, dp| {
534                    let rounded = round_decimal_or_zero(
535                        v,
536                        *input_precision,
537                        *scale,
538                        *new_scale,
539                        dp,
540                    )?;
541                    if *precision == Decimal64Type::MAX_PRECISION
542                        && (decimal_places_is_array || (*scale == 0 && dp < 0))
543                    {
544                        // See Decimal32 branch for details.
545                        validate_decimal_precision::<Decimal64Type>(
546                            rounded, *precision, *new_scale,
547                        )
548                    } else {
549                        Ok(rounded)
550                    }
551                },
552                *precision,
553                *new_scale,
554            )?;
555            result as _
556        }
557        (Decimal128(input_precision, scale), Decimal128(precision, new_scale)) => {
558            let result = calculate_binary_decimal_math::<
559                Decimal128Type,
560                Int32Type,
561                Decimal128Type,
562                _,
563            >(
564                value_array.as_ref(),
565                decimal_places,
566                |v, dp| {
567                    let rounded = round_decimal_or_zero(
568                        v,
569                        *input_precision,
570                        *scale,
571                        *new_scale,
572                        dp,
573                    )?;
574                    if *precision == Decimal128Type::MAX_PRECISION
575                        && (decimal_places_is_array || (*scale == 0 && dp < 0))
576                    {
577                        // See Decimal32 branch for details.
578                        validate_decimal_precision::<Decimal128Type>(
579                            rounded, *precision, *new_scale,
580                        )
581                    } else {
582                        Ok(rounded)
583                    }
584                },
585                *precision,
586                *new_scale,
587            )?;
588            result as _
589        }
590        (Decimal256(input_precision, scale), Decimal256(precision, new_scale)) => {
591            let result = calculate_binary_decimal_math::<
592                Decimal256Type,
593                Int32Type,
594                Decimal256Type,
595                _,
596            >(
597                value_array.as_ref(),
598                decimal_places,
599                |v, dp| {
600                    let rounded = round_decimal_or_zero(
601                        v,
602                        *input_precision,
603                        *scale,
604                        *new_scale,
605                        dp,
606                    )?;
607                    if *precision == Decimal256Type::MAX_PRECISION
608                        && (decimal_places_is_array || (*scale == 0 && dp < 0))
609                    {
610                        // See Decimal32 branch for details.
611                        validate_decimal_precision::<Decimal256Type>(
612                            rounded, *precision, *new_scale,
613                        )
614                    } else {
615                        Ok(rounded)
616                    }
617                },
618                *precision,
619                *new_scale,
620            )?;
621            result as _
622        }
623        (other, _) => exec_err!("Unsupported data type {other:?} for function round")?,
624    };
625
626    if both_scalars {
627        ScalarValue::try_from_array(&arr, 0).map(ColumnarValue::Scalar)
628    } else {
629        Ok(ColumnarValue::Array(arr))
630    }
631}
632
633fn round_float<T>(value: T, decimal_places: i32) -> Result<T, ArrowError>
634where
635    T: num_traits::Float,
636{
637    let factor = T::from(10_f64.powi(decimal_places)).ok_or_else(|| {
638        ArrowError::ComputeError(format!(
639            "Invalid value for decimal places: {decimal_places}"
640        ))
641    })?;
642    Ok((value * factor).round() / factor)
643}
644
645fn round_decimal<V: ArrowNativeTypeOp>(
646    value: V,
647    input_scale: i8,
648    output_scale: i8,
649    decimal_places: i32,
650) -> Result<V, ArrowError> {
651    let diff = i64::from(input_scale) - i64::from(decimal_places);
652    if diff <= 0 {
653        return Ok(value);
654    }
655
656    debug_assert!(diff <= i64::from(u32::MAX));
657    let diff = diff as u32;
658
659    let one = V::ONE;
660    let two = V::from_usize(2).ok_or_else(|| {
661        ArrowError::ComputeError("Internal error: could not create constant 2".into())
662    })?;
663    let ten = V::from_usize(10).ok_or_else(|| {
664        ArrowError::ComputeError("Internal error: could not create constant 10".into())
665    })?;
666
667    let factor = ten.pow_checked(diff).map_err(|_| {
668        ArrowError::ComputeError(format!(
669            "Overflow while rounding decimal with scale {input_scale} and decimal places {decimal_places}"
670        ))
671    })?;
672
673    let mut quotient = value.div_wrapping(factor);
674    let remainder = value.mod_wrapping(factor);
675
676    // `factor` is an even number (10^n, n > 0), so `factor / 2` is the tie threshold
677    let threshold = factor.div_wrapping(two);
678    if remainder >= threshold {
679        quotient = quotient.add_checked(one).map_err(|_| {
680            ArrowError::ComputeError("Overflow while rounding decimal".into())
681        })?;
682    } else if remainder <= threshold.neg_wrapping() {
683        quotient = quotient.sub_checked(one).map_err(|_| {
684            ArrowError::ComputeError("Overflow while rounding decimal".into())
685        })?;
686    }
687
688    // `quotient` is the rounded value at scale `decimal_places`. Rescale to the desired
689    // `output_scale` (which is always >= `decimal_places` in cases where diff > 0).
690    let scale_shift = i64::from(output_scale) - i64::from(decimal_places);
691    if scale_shift == 0 {
692        return Ok(quotient);
693    }
694
695    debug_assert!(scale_shift > 0);
696    debug_assert!(scale_shift <= i64::from(u32::MAX));
697    let scale_shift = scale_shift as u32;
698    let shift_factor = ten.pow_checked(scale_shift).map_err(|_| {
699        ArrowError::ComputeError(format!(
700            "Overflow while rounding decimal with scale {input_scale} and decimal places {decimal_places}"
701        ))
702    })?;
703    quotient
704        .mul_checked(shift_factor)
705        .map_err(|_| ArrowError::ComputeError("Overflow while rounding decimal".into()))
706}
707
708fn round_decimal_or_zero<V: ArrowNativeTypeOp>(
709    value: V,
710    precision: u8,
711    input_scale: i8,
712    output_scale: i8,
713    decimal_places: i32,
714) -> Result<V, ArrowError> {
715    if let Some(dp) =
716        normalize_decimal_places_for_decimal(decimal_places, precision, input_scale)
717    {
718        round_decimal(value, input_scale, output_scale, dp)
719    } else {
720        V::from_usize(0).ok_or_else(|| {
721            ArrowError::ComputeError("Internal error: could not create constant 0".into())
722        })
723    }
724}
725
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use arrow::array::{ArrayRef, Float32Array, Float64Array, Int64Array};
    use datafusion_common::DataFusionError;
    use datafusion_common::ScalarValue;
    use datafusion_common::cast::{as_float32_array, as_float64_array};
    use datafusion_expr::ColumnarValue;

    /// Runs `round_columnar` over `value`. Rows are rounded to the per-row `decimal_places`, or
    /// to 0 when none are given. The result is always returned as an array.
    fn round_arrays(
        value: ArrayRef,
        decimal_places: Option<ArrayRef>,
    ) -> Result<ArrayRef, DataFusionError> {
        let rows = value.len();
        // Only Float32/Float64 inputs are exercised here, so the return type equals the input
        // type. For decimals the real return type can differ (a literal `decimal_places` shrinks
        // the scale), which this helper does not model.
        let return_type = value.data_type().clone();
        let dp = match decimal_places {
            Some(array) => ColumnarValue::Array(array),
            None => ColumnarValue::Scalar(ScalarValue::Int32(Some(0))),
        };
        let value = ColumnarValue::Array(value);

        match super::round_columnar(&value, &dp, rows, &return_type)? {
            ColumnarValue::Array(array) => Ok(array),
            ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(1),
        }
    }

    #[test]
    fn test_round_f32() {
        let input: ArrayRef = Arc::new(Float32Array::from(vec![125.2345; 10]));
        let decimal_places: ArrayRef =
            Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4, 5, -1, -2, -3, -4]));

        let result = round_arrays(input, Some(decimal_places))
            .expect("failed to initialize function round");
        let floats =
            as_float32_array(&result).expect("failed to initialize function round");

        let expected = Float32Array::from(vec![
            125.0, 125.2, 125.23, 125.235, 125.2345, 125.2345, 130.0, 100.0, 0.0, 0.0,
        ]);
        assert_eq!(floats, &expected);
    }

    #[test]
    fn test_round_f64() {
        let input: ArrayRef = Arc::new(Float64Array::from(vec![125.2345; 10]));
        let decimal_places: ArrayRef =
            Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4, 5, -1, -2, -3, -4]));

        let result = round_arrays(input, Some(decimal_places))
            .expect("failed to initialize function round");
        let floats =
            as_float64_array(&result).expect("failed to initialize function round");

        let expected = Float64Array::from(vec![
            125.0, 125.2, 125.23, 125.235, 125.2345, 125.2345, 130.0, 100.0, 0.0, 0.0,
        ]);
        assert_eq!(floats, &expected);
    }

    #[test]
    fn test_round_f32_one_input() {
        let input: ArrayRef =
            Arc::new(Float32Array::from(vec![125.2345, 12.345, 1.234, 0.1234]));

        let result = round_arrays(input, None).expect("failed to initialize function round");
        let floats =
            as_float32_array(&result).expect("failed to initialize function round");

        assert_eq!(floats, &Float32Array::from(vec![125.0, 12.0, 1.0, 0.0]));
    }

    #[test]
    fn test_round_f64_one_input() {
        let input: ArrayRef =
            Arc::new(Float64Array::from(vec![125.2345, 12.345, 1.234, 0.1234]));

        let result = round_arrays(input, None).expect("failed to initialize function round");
        let floats =
            as_float64_array(&result).expect("failed to initialize function round");

        assert_eq!(floats, &Float64Array::from(vec![125.0, 12.0, 1.0, 0.0]));
    }

    /// A `decimal_places` value that does not fit in i32 must produce an error rather than a
    /// result. Note: despite the name, this exercises a Float64 input.
    #[test]
    fn test_round_f32_cast_fail() {
        let input: ArrayRef = Arc::new(Float64Array::from(vec![125.2345]));
        let decimal_places: ArrayRef = Arc::new(Int64Array::from(vec![2147483648]));

        let result = round_arrays(input, Some(decimal_places));

        assert!(result.is_err());
        assert!(matches!(
            result,
            Err(DataFusionError::ArrowError(_, _)) | Err(DataFusionError::Execution(_))
        ));
    }
}