Skip to main content

datafusion_functions/math/
round.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19
20use crate::utils::{calculate_binary_decimal_math, calculate_binary_math};
21
22use arrow::array::ArrayRef;
23use arrow::datatypes::DataType::{
24    Decimal32, Decimal64, Decimal128, Decimal256, Float32, Float64,
25};
26use arrow::datatypes::{
27    ArrowNativeTypeOp, DataType, Decimal32Type, Decimal64Type, Decimal128Type,
28    Decimal256Type, DecimalType, Float32Type, Float64Type, Int32Type,
29};
30use arrow::datatypes::{Field, FieldRef};
31use arrow::error::ArrowError;
32use datafusion_common::types::{
33    NativeType, logical_float32, logical_float64, logical_int32,
34};
35use datafusion_common::{Result, ScalarValue, exec_err, internal_err};
36use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
37use datafusion_expr::{
38    Coercion, ColumnarValue, Documentation, ReturnFieldArgs, ScalarFunctionArgs,
39    ScalarUDFImpl, Signature, TypeSignature, TypeSignatureClass, Volatility,
40};
41use datafusion_macros::user_doc;
42use std::sync::Arc;
43
/// Picks the scale of the decimal type produced by `round`.
///
/// * Non-negative `input_scale`: the result scale is `decimal_places` clamped
///   into `[0, input_scale]` — rounding never adds fractional digits and a
///   negative `decimal_places` cannot push the scale below zero.
/// * Negative `input_scale`: the result scale may become even more negative to
///   follow a negative `decimal_places` (e.g. scale -2 rounded to -3 yields
///   scale -3), but it is floored at `-precision`.
fn output_scale_for_decimal(precision: u8, input_scale: i8, decimal_places: i32) -> i8 {
    let current_scale = i32::from(input_scale);

    let target_scale = if input_scale < 0 {
        // Rounding may coarsen a negative scale further, but never past `-precision`.
        let floor = -i32::from(precision);
        current_scale.min(decimal_places).max(floor)
    } else {
        // Negative `decimal_places` are treated as 0 here; the upper bound is the input scale.
        current_scale.min(decimal_places.max(0))
    };

    // For valid decimal precisions the value is bounded by `input_scale` above and by
    // `-precision` (or 0) below, so the narrowing cast does not truncate.
    target_scale as i8
}
63
/// Decides whether a rounding position is within reach of a decimal type.
///
/// Returns `Some(decimal_places)` when rounding has to be computed and `None`
/// when the rounding position lies beyond every digit the type can hold, in
/// which case the caller treats the rounded result as zero.
///
/// Non-negative `decimal_places` are always passed through unchanged.
fn normalize_decimal_places_for_decimal(
    decimal_places: i32,
    precision: u8,
    scale: i8,
) -> Option<i32> {
    if decimal_places >= 0 {
        return Some(decimal_places);
    }

    // A decimal(precision, scale) value is strictly smaller in magnitude than
    // 10^(precision - scale). Rounding at a position past that always yields 0, and
    // bailing out here also keeps the later 10^n computations from overflowing.
    let integer_digits = i64::from(precision) - i64::from(scale);
    // `unsigned_abs` avoids overflow for `i32::MIN`.
    let rounding_digits = i64::from(decimal_places.unsigned_abs());

    let within_reach = integer_digits > 0 && rounding_digits <= integer_digits;
    within_reach.then_some(decimal_places)
}
84
85fn validate_decimal_precision<T: DecimalType>(
86    value: T::Native,
87    precision: u8,
88    scale: i8,
89) -> Result<T::Native, ArrowError> {
90    T::validate_decimal_precision(value, precision, scale).map_err(|e| {
91        ArrowError::ComputeError(format!(
92            "Decimal overflow: rounded value exceeds precision {precision}: {e}"
93        ))
94    })?;
95    Ok(value)
96}
97
98fn calculate_new_precision_scale<T: DecimalType>(
99    precision: u8,
100    scale: i8,
101    decimal_places: Option<i32>,
102) -> Result<DataType> {
103    if let Some(decimal_places) = decimal_places {
104        let new_scale = output_scale_for_decimal(precision, scale, decimal_places);
105
106        // When rounding an integer decimal (scale == 0) to a negative `decimal_places`, a carry can
107        // add an extra digit to the integer part (e.g. 99 -> 100 when rounding to -1). This can
108        // only happen when the rounding position is within the existing precision.
109        let abs_decimal_places = decimal_places.unsigned_abs();
110        let new_precision = if scale == 0
111            && decimal_places < 0
112            && abs_decimal_places <= u32::from(precision)
113        {
114            precision.saturating_add(1).min(T::MAX_PRECISION)
115        } else {
116            precision
117        };
118        Ok(T::TYPE_CONSTRUCTOR(new_precision, new_scale))
119    } else {
120        let new_precision = precision.saturating_add(1).min(T::MAX_PRECISION);
121        Ok(T::TYPE_CONSTRUCTOR(new_precision, scale))
122    }
123}
124
125fn decimal_places_from_scalar(scalar: &ScalarValue) -> Result<i32> {
126    let out_of_range = |value: String| {
127        datafusion_common::DataFusionError::Execution(format!(
128            "round decimal_places {value} is out of supported i32 range"
129        ))
130    };
131    match scalar {
132        ScalarValue::Int8(Some(v)) => Ok(i32::from(*v)),
133        ScalarValue::Int16(Some(v)) => Ok(i32::from(*v)),
134        ScalarValue::Int32(Some(v)) => Ok(*v),
135        ScalarValue::Int64(Some(v)) => {
136            i32::try_from(*v).map_err(|_| out_of_range(v.to_string()))
137        }
138        ScalarValue::UInt8(Some(v)) => Ok(i32::from(*v)),
139        ScalarValue::UInt16(Some(v)) => Ok(i32::from(*v)),
140        ScalarValue::UInt32(Some(v)) => {
141            i32::try_from(*v).map_err(|_| out_of_range(v.to_string()))
142        }
143        ScalarValue::UInt64(Some(v)) => {
144            i32::try_from(*v).map_err(|_| out_of_range(v.to_string()))
145        }
146        other => exec_err!(
147            "Unexpected datatype for decimal_places: {}",
148            other.data_type()
149        ),
150    }
151}
152
#[user_doc(
    doc_section(label = "Math Functions"),
    description = "Rounds a number to the nearest integer.",
    syntax_example = "round(numeric_expression[, decimal_places])",
    standard_argument(name = "numeric_expression", prefix = "Numeric"),
    argument(
        name = "decimal_places",
        description = "Optional. The number of decimal places to round to. Defaults to 0."
    ),
    sql_example = r#"```sql
> SELECT round(3.14159);
+--------------+
| round(3.14159)|
+--------------+
| 3.0          |
+--------------+
```"#
)]
#[derive(Debug, PartialEq, Eq, Hash)]
/// Scalar function `round(numeric_expression[, decimal_places])`.
///
/// The accepted argument types are described by `signature`, which is built
/// in `RoundFunc::new`; the user-facing documentation comes from the
/// `user_doc` attribute above.
pub struct RoundFunc {
    /// Argument type signature returned by `ScalarUDFImpl::signature`.
    signature: Signature,
}
175
176impl Default for RoundFunc {
177    fn default() -> Self {
178        RoundFunc::new()
179    }
180}
181
182impl RoundFunc {
183    pub fn new() -> Self {
184        let decimal = Coercion::new_exact(TypeSignatureClass::Decimal);
185        let decimal_places = Coercion::new_implicit(
186            TypeSignatureClass::Native(logical_int32()),
187            vec![TypeSignatureClass::Integer],
188            NativeType::Int32,
189        );
190        let float32 = Coercion::new_exact(TypeSignatureClass::Native(logical_float32()));
191        let float64 = Coercion::new_implicit(
192            TypeSignatureClass::Native(logical_float64()),
193            vec![TypeSignatureClass::Numeric],
194            NativeType::Float64,
195        );
196        Self {
197            signature: Signature::one_of(
198                vec![
199                    TypeSignature::Coercible(vec![
200                        decimal.clone(),
201                        decimal_places.clone(),
202                    ]),
203                    TypeSignature::Coercible(vec![decimal]),
204                    TypeSignature::Coercible(vec![
205                        float32.clone(),
206                        decimal_places.clone(),
207                    ]),
208                    TypeSignature::Coercible(vec![float32]),
209                    TypeSignature::Coercible(vec![float64.clone(), decimal_places]),
210                    TypeSignature::Coercible(vec![float64]),
211                ],
212                Volatility::Immutable,
213            ),
214        }
215    }
216}
217
218impl ScalarUDFImpl for RoundFunc {
219    fn as_any(&self) -> &dyn Any {
220        self
221    }
222
223    fn name(&self) -> &str {
224        "round"
225    }
226
227    fn signature(&self) -> &Signature {
228        &self.signature
229    }
230
231    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
232        let input_field = &args.arg_fields[0];
233        let input_type = input_field.data_type();
234
235        // If decimal_places is a scalar literal, we can incorporate it into the output type
236        // (scale reduction). Otherwise, keep the input scale as we can't pick a per-row scale.
237        //
238        // Note: `scalar_arguments` contains the original literal values (pre-coercion), so
239        // integer literals may appear as Int64 even though the signature coerces them to Int32.
240        let decimal_places: Option<i32> = match args.scalar_arguments.get(1) {
241            None => Some(0),    // No dp argument means default to 0
242            Some(None) => None, // dp is not a literal (e.g. column)
243            Some(Some(scalar)) if scalar.is_null() => Some(0), // null dp => default to 0
244            Some(Some(scalar)) => Some(decimal_places_from_scalar(scalar)?),
245        };
246
247        // Calculate return type based on input type
248        // For decimals: reduce scale to decimal_places (reclaims precision for integer part)
249        // This matches Spark/DuckDB behavior where ROUND adjusts the scale
250        // BUT only if dp is a scalar literal - otherwise keep original scale and add
251        // extra precision to accommodate potential carry-over.
252        let return_type =
253            match input_type {
254                Float32 => Float32,
255                Decimal32(precision, scale) => calculate_new_precision_scale::<
256                    Decimal32Type,
257                >(
258                    *precision, *scale, decimal_places
259                )?,
260                Decimal64(precision, scale) => calculate_new_precision_scale::<
261                    Decimal64Type,
262                >(
263                    *precision, *scale, decimal_places
264                )?,
265                Decimal128(precision, scale) => calculate_new_precision_scale::<
266                    Decimal128Type,
267                >(
268                    *precision, *scale, decimal_places
269                )?,
270                Decimal256(precision, scale) => calculate_new_precision_scale::<
271                    Decimal256Type,
272                >(
273                    *precision, *scale, decimal_places
274                )?,
275                _ => Float64,
276            };
277
278        let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
279        Ok(Arc::new(Field::new(self.name(), return_type, nullable)))
280    }
281
282    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
283        internal_err!("use return_field_from_args instead")
284    }
285
286    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
287        if args.arg_fields.iter().any(|a| a.data_type().is_null()) {
288            return ColumnarValue::Scalar(ScalarValue::Null)
289                .cast_to(args.return_type(), None);
290        }
291
292        let default_decimal_places = ColumnarValue::Scalar(ScalarValue::Int32(Some(0)));
293        let decimal_places = if args.args.len() == 2 {
294            &args.args[1]
295        } else {
296            &default_decimal_places
297        };
298
299        if let (ColumnarValue::Scalar(value_scalar), ColumnarValue::Scalar(dp_scalar)) =
300            (&args.args[0], decimal_places)
301        {
302            if value_scalar.is_null() || dp_scalar.is_null() {
303                return ColumnarValue::Scalar(ScalarValue::Null)
304                    .cast_to(args.return_type(), None);
305            }
306
307            let dp = if let ScalarValue::Int32(Some(dp)) = dp_scalar {
308                *dp
309            } else {
310                return internal_err!(
311                    "Unexpected datatype for decimal_places: {}",
312                    dp_scalar.data_type()
313                );
314            };
315
316            match (value_scalar, args.return_type()) {
317                (ScalarValue::Float32(Some(v)), _) => {
318                    let rounded = round_float(*v, dp)?;
319                    Ok(ColumnarValue::Scalar(ScalarValue::from(rounded)))
320                }
321                (ScalarValue::Float64(Some(v)), _) => {
322                    let rounded = round_float(*v, dp)?;
323                    Ok(ColumnarValue::Scalar(ScalarValue::from(rounded)))
324                }
325                (
326                    ScalarValue::Decimal32(Some(v), in_precision, scale),
327                    Decimal32(out_precision, out_scale),
328                ) => {
329                    let rounded =
330                        round_decimal_or_zero(*v, *in_precision, *scale, *out_scale, dp)?;
331                    let rounded = if *out_precision == Decimal32Type::MAX_PRECISION
332                        && *scale == 0
333                        && dp < 0
334                    {
335                        // With scale == 0 and negative dp, rounding can carry into an additional
336                        // digit (e.g. 99 -> 100). If we're already at max precision we can't widen
337                        // the type, so validate and error rather than producing an invalid decimal.
338                        validate_decimal_precision::<Decimal32Type>(
339                            rounded,
340                            *out_precision,
341                            *out_scale,
342                        )
343                    } else {
344                        Ok(rounded)
345                    }?;
346                    let scalar =
347                        ScalarValue::Decimal32(Some(rounded), *out_precision, *out_scale);
348                    Ok(ColumnarValue::Scalar(scalar))
349                }
350                (
351                    ScalarValue::Decimal64(Some(v), in_precision, scale),
352                    Decimal64(out_precision, out_scale),
353                ) => {
354                    let rounded =
355                        round_decimal_or_zero(*v, *in_precision, *scale, *out_scale, dp)?;
356                    let rounded = if *out_precision == Decimal64Type::MAX_PRECISION
357                        && *scale == 0
358                        && dp < 0
359                    {
360                        // See Decimal32 branch for details.
361                        validate_decimal_precision::<Decimal64Type>(
362                            rounded,
363                            *out_precision,
364                            *out_scale,
365                        )
366                    } else {
367                        Ok(rounded)
368                    }?;
369                    let scalar =
370                        ScalarValue::Decimal64(Some(rounded), *out_precision, *out_scale);
371                    Ok(ColumnarValue::Scalar(scalar))
372                }
373                (
374                    ScalarValue::Decimal128(Some(v), in_precision, scale),
375                    Decimal128(out_precision, out_scale),
376                ) => {
377                    let rounded =
378                        round_decimal_or_zero(*v, *in_precision, *scale, *out_scale, dp)?;
379                    let rounded = if *out_precision == Decimal128Type::MAX_PRECISION
380                        && *scale == 0
381                        && dp < 0
382                    {
383                        // See Decimal32 branch for details.
384                        validate_decimal_precision::<Decimal128Type>(
385                            rounded,
386                            *out_precision,
387                            *out_scale,
388                        )
389                    } else {
390                        Ok(rounded)
391                    }?;
392                    let scalar = ScalarValue::Decimal128(
393                        Some(rounded),
394                        *out_precision,
395                        *out_scale,
396                    );
397                    Ok(ColumnarValue::Scalar(scalar))
398                }
399                (
400                    ScalarValue::Decimal256(Some(v), in_precision, scale),
401                    Decimal256(out_precision, out_scale),
402                ) => {
403                    let rounded =
404                        round_decimal_or_zero(*v, *in_precision, *scale, *out_scale, dp)?;
405                    let rounded = if *out_precision == Decimal256Type::MAX_PRECISION
406                        && *scale == 0
407                        && dp < 0
408                    {
409                        // See Decimal32 branch for details.
410                        validate_decimal_precision::<Decimal256Type>(
411                            rounded,
412                            *out_precision,
413                            *out_scale,
414                        )
415                    } else {
416                        Ok(rounded)
417                    }?;
418                    let scalar = ScalarValue::Decimal256(
419                        Some(rounded),
420                        *out_precision,
421                        *out_scale,
422                    );
423                    Ok(ColumnarValue::Scalar(scalar))
424                }
425                (ScalarValue::Null, _) => ColumnarValue::Scalar(ScalarValue::Null)
426                    .cast_to(args.return_type(), None),
427                (value_scalar, return_type) => {
428                    internal_err!(
429                        "Unexpected datatype for round(value, decimal_places): value {}, return type {}",
430                        value_scalar.data_type(),
431                        return_type
432                    )
433                }
434            }
435        } else {
436            round_columnar(
437                &args.args[0],
438                decimal_places,
439                args.number_rows,
440                args.return_type(),
441            )
442        }
443    }
444
445    fn output_ordering(&self, input: &[ExprProperties]) -> Result<SortProperties> {
446        // round preserves the order of the first argument
447        let value = &input[0];
448        let precision = input.get(1);
449
450        if precision
451            .map(|r| r.sort_properties.eq(&SortProperties::Singleton))
452            .unwrap_or(true)
453        {
454            Ok(value.sort_properties)
455        } else {
456            Ok(SortProperties::Unordered)
457        }
458    }
459
460    fn documentation(&self) -> Option<&Documentation> {
461        self.doc()
462    }
463}
464
465fn round_columnar(
466    value: &ColumnarValue,
467    decimal_places: &ColumnarValue,
468    number_rows: usize,
469    return_type: &DataType,
470) -> Result<ColumnarValue> {
471    let value_array = value.to_array(number_rows)?;
472    let both_scalars = matches!(value, ColumnarValue::Scalar(_))
473        && matches!(decimal_places, ColumnarValue::Scalar(_));
474    let decimal_places_is_array = matches!(decimal_places, ColumnarValue::Array(_));
475
476    let arr: ArrayRef = match (value_array.data_type(), return_type) {
477        (Float64, _) => {
478            let result = calculate_binary_math::<Float64Type, Int32Type, Float64Type, _>(
479                value_array.as_ref(),
480                decimal_places,
481                round_float::<f64>,
482            )?;
483            result as _
484        }
485        (Float32, _) => {
486            let result = calculate_binary_math::<Float32Type, Int32Type, Float32Type, _>(
487                value_array.as_ref(),
488                decimal_places,
489                round_float::<f32>,
490            )?;
491            result as _
492        }
493        (Decimal32(input_precision, scale), Decimal32(precision, new_scale)) => {
494            // reduce scale to reclaim integer precision
495            let result = calculate_binary_decimal_math::<
496                Decimal32Type,
497                Int32Type,
498                Decimal32Type,
499                _,
500            >(
501                value_array.as_ref(),
502                decimal_places,
503                |v, dp| {
504                    let rounded = round_decimal_or_zero(
505                        v,
506                        *input_precision,
507                        *scale,
508                        *new_scale,
509                        dp,
510                    )?;
511                    if *precision == Decimal32Type::MAX_PRECISION
512                        && (decimal_places_is_array || (*scale == 0 && dp < 0))
513                    {
514                        // If we're already at max precision, we can't widen the result type. For
515                        // dp arrays, or for scale == 0 with negative dp, rounding can overflow the
516                        // fixed-precision type. Validate per-row and return an error instead of
517                        // producing an invalid decimal that Arrow may display incorrectly.
518                        validate_decimal_precision::<Decimal32Type>(
519                            rounded, *precision, *new_scale,
520                        )
521                    } else {
522                        Ok(rounded)
523                    }
524                },
525                *precision,
526                *new_scale,
527            )?;
528            result as _
529        }
530        (Decimal64(input_precision, scale), Decimal64(precision, new_scale)) => {
531            let result = calculate_binary_decimal_math::<
532                Decimal64Type,
533                Int32Type,
534                Decimal64Type,
535                _,
536            >(
537                value_array.as_ref(),
538                decimal_places,
539                |v, dp| {
540                    let rounded = round_decimal_or_zero(
541                        v,
542                        *input_precision,
543                        *scale,
544                        *new_scale,
545                        dp,
546                    )?;
547                    if *precision == Decimal64Type::MAX_PRECISION
548                        && (decimal_places_is_array || (*scale == 0 && dp < 0))
549                    {
550                        // See Decimal32 branch for details.
551                        validate_decimal_precision::<Decimal64Type>(
552                            rounded, *precision, *new_scale,
553                        )
554                    } else {
555                        Ok(rounded)
556                    }
557                },
558                *precision,
559                *new_scale,
560            )?;
561            result as _
562        }
563        (Decimal128(input_precision, scale), Decimal128(precision, new_scale)) => {
564            let result = calculate_binary_decimal_math::<
565                Decimal128Type,
566                Int32Type,
567                Decimal128Type,
568                _,
569            >(
570                value_array.as_ref(),
571                decimal_places,
572                |v, dp| {
573                    let rounded = round_decimal_or_zero(
574                        v,
575                        *input_precision,
576                        *scale,
577                        *new_scale,
578                        dp,
579                    )?;
580                    if *precision == Decimal128Type::MAX_PRECISION
581                        && (decimal_places_is_array || (*scale == 0 && dp < 0))
582                    {
583                        // See Decimal32 branch for details.
584                        validate_decimal_precision::<Decimal128Type>(
585                            rounded, *precision, *new_scale,
586                        )
587                    } else {
588                        Ok(rounded)
589                    }
590                },
591                *precision,
592                *new_scale,
593            )?;
594            result as _
595        }
596        (Decimal256(input_precision, scale), Decimal256(precision, new_scale)) => {
597            let result = calculate_binary_decimal_math::<
598                Decimal256Type,
599                Int32Type,
600                Decimal256Type,
601                _,
602            >(
603                value_array.as_ref(),
604                decimal_places,
605                |v, dp| {
606                    let rounded = round_decimal_or_zero(
607                        v,
608                        *input_precision,
609                        *scale,
610                        *new_scale,
611                        dp,
612                    )?;
613                    if *precision == Decimal256Type::MAX_PRECISION
614                        && (decimal_places_is_array || (*scale == 0 && dp < 0))
615                    {
616                        // See Decimal32 branch for details.
617                        validate_decimal_precision::<Decimal256Type>(
618                            rounded, *precision, *new_scale,
619                        )
620                    } else {
621                        Ok(rounded)
622                    }
623                },
624                *precision,
625                *new_scale,
626            )?;
627            result as _
628        }
629        (other, _) => exec_err!("Unsupported data type {other:?} for function round")?,
630    };
631
632    if both_scalars {
633        ScalarValue::try_from_array(&arr, 0).map(ColumnarValue::Scalar)
634    } else {
635        Ok(ColumnarValue::Array(arr))
636    }
637}
638
639fn round_float<T>(value: T, decimal_places: i32) -> Result<T, ArrowError>
640where
641    T: num_traits::Float,
642{
643    let factor = T::from(10_f64.powi(decimal_places)).ok_or_else(|| {
644        ArrowError::ComputeError(format!(
645            "Invalid value for decimal places: {decimal_places}"
646        ))
647    })?;
648    Ok((value * factor).round() / factor)
649}
650
651fn round_decimal<V: ArrowNativeTypeOp>(
652    value: V,
653    input_scale: i8,
654    output_scale: i8,
655    decimal_places: i32,
656) -> Result<V, ArrowError> {
657    let diff = i64::from(input_scale) - i64::from(decimal_places);
658    if diff <= 0 {
659        return Ok(value);
660    }
661
662    debug_assert!(diff <= i64::from(u32::MAX));
663    let diff = diff as u32;
664
665    let one = V::ONE;
666    let two = V::from_usize(2).ok_or_else(|| {
667        ArrowError::ComputeError("Internal error: could not create constant 2".into())
668    })?;
669    let ten = V::from_usize(10).ok_or_else(|| {
670        ArrowError::ComputeError("Internal error: could not create constant 10".into())
671    })?;
672
673    let factor = ten.pow_checked(diff).map_err(|_| {
674        ArrowError::ComputeError(format!(
675            "Overflow while rounding decimal with scale {input_scale} and decimal places {decimal_places}"
676        ))
677    })?;
678
679    let mut quotient = value.div_wrapping(factor);
680    let remainder = value.mod_wrapping(factor);
681
682    // `factor` is an even number (10^n, n > 0), so `factor / 2` is the tie threshold
683    let threshold = factor.div_wrapping(two);
684    if remainder >= threshold {
685        quotient = quotient.add_checked(one).map_err(|_| {
686            ArrowError::ComputeError("Overflow while rounding decimal".into())
687        })?;
688    } else if remainder <= threshold.neg_wrapping() {
689        quotient = quotient.sub_checked(one).map_err(|_| {
690            ArrowError::ComputeError("Overflow while rounding decimal".into())
691        })?;
692    }
693
694    // `quotient` is the rounded value at scale `decimal_places`. Rescale to the desired
695    // `output_scale` (which is always >= `decimal_places` in cases where diff > 0).
696    let scale_shift = i64::from(output_scale) - i64::from(decimal_places);
697    if scale_shift == 0 {
698        return Ok(quotient);
699    }
700
701    debug_assert!(scale_shift > 0);
702    debug_assert!(scale_shift <= i64::from(u32::MAX));
703    let scale_shift = scale_shift as u32;
704    let shift_factor = ten.pow_checked(scale_shift).map_err(|_| {
705        ArrowError::ComputeError(format!(
706            "Overflow while rounding decimal with scale {input_scale} and decimal places {decimal_places}"
707        ))
708    })?;
709    quotient
710        .mul_checked(shift_factor)
711        .map_err(|_| ArrowError::ComputeError("Overflow while rounding decimal".into()))
712}
713
714fn round_decimal_or_zero<V: ArrowNativeTypeOp>(
715    value: V,
716    precision: u8,
717    input_scale: i8,
718    output_scale: i8,
719    decimal_places: i32,
720) -> Result<V, ArrowError> {
721    if let Some(dp) =
722        normalize_decimal_places_for_decimal(decimal_places, precision, input_scale)
723    {
724        round_decimal(value, input_scale, output_scale, dp)
725    } else {
726        V::from_usize(0).ok_or_else(|| {
727            ArrowError::ComputeError("Internal error: could not create constant 0".into())
728        })
729    }
730}
731
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use arrow::array::{ArrayRef, Float32Array, Float64Array, Int64Array};
    use datafusion_common::DataFusionError;
    use datafusion_common::ScalarValue;
    use datafusion_common::cast::{as_float32_array, as_float64_array};
    use datafusion_expr::ColumnarValue;

    // Runs `round_columnar` on an array of values and an optional array of decimal
    // places (defaulting to the scalar 0), always handing back an array.
    fn round_arrays(
        value: ArrayRef,
        decimal_places: Option<ArrayRef>,
    ) -> Result<ArrayRef, DataFusionError> {
        let number_rows = value.len();
        // NOTE: For decimal inputs, the actual ROUND return type can differ from the
        // input type (scale reduction for literal `decimal_places`). These unit tests
        // only exercise Float32/Float64 behavior.
        let return_type = value.data_type().clone();
        let value = ColumnarValue::Array(value);
        let decimal_places = match decimal_places {
            Some(array) => ColumnarValue::Array(array),
            None => ColumnarValue::Scalar(ScalarValue::Int32(Some(0))),
        };

        match super::round_columnar(&value, &decimal_places, number_rows, &return_type)? {
            ColumnarValue::Array(array) => Ok(array),
            ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(1),
        }
    }

    #[test]
    fn test_round_f32() {
        let values: ArrayRef = Arc::new(Float32Array::from(vec![125.2345; 10]));
        let decimal_places: ArrayRef =
            Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4, 5, -1, -2, -3, -4]));

        let result = round_arrays(values, Some(decimal_places))
            .expect("failed to initialize function round");
        let floats =
            as_float32_array(&result).expect("failed to initialize function round");

        let expected = Float32Array::from(vec![
            125.0, 125.2, 125.23, 125.235, 125.2345, 125.2345, 130.0, 100.0, 0.0, 0.0,
        ]);

        assert_eq!(floats, &expected);
    }

    #[test]
    fn test_round_f64() {
        let values: ArrayRef = Arc::new(Float64Array::from(vec![125.2345; 10]));
        let decimal_places: ArrayRef =
            Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4, 5, -1, -2, -3, -4]));

        let result = round_arrays(values, Some(decimal_places))
            .expect("failed to initialize function round");
        let floats =
            as_float64_array(&result).expect("failed to initialize function round");

        let expected = Float64Array::from(vec![
            125.0, 125.2, 125.23, 125.235, 125.2345, 125.2345, 130.0, 100.0, 0.0, 0.0,
        ]);

        assert_eq!(floats, &expected);
    }

    #[test]
    fn test_round_f32_one_input() {
        let values: ArrayRef =
            Arc::new(Float32Array::from(vec![125.2345, 12.345, 1.234, 0.1234]));

        let result =
            round_arrays(values, None).expect("failed to initialize function round");
        let floats =
            as_float32_array(&result).expect("failed to initialize function round");

        let expected = Float32Array::from(vec![125.0, 12.0, 1.0, 0.0]);

        assert_eq!(floats, &expected);
    }

    #[test]
    fn test_round_f64_one_input() {
        let values: ArrayRef =
            Arc::new(Float64Array::from(vec![125.2345, 12.345, 1.234, 0.1234]));

        let result =
            round_arrays(values, None).expect("failed to initialize function round");
        let floats =
            as_float64_array(&result).expect("failed to initialize function round");

        let expected = Float64Array::from(vec![125.0, 12.0, 1.0, 0.0]);

        assert_eq!(floats, &expected);
    }

    #[test]
    fn test_round_f32_cast_fail() {
        // 2147483648 is one past `i32::MAX`, so it cannot be used as decimal_places.
        let values: ArrayRef = Arc::new(Float64Array::from(vec![125.2345]));
        let decimal_places: ArrayRef = Arc::new(Int64Array::from(vec![2147483648]));

        let result = round_arrays(values, Some(decimal_places));

        assert!(result.is_err());
        assert!(matches!(
            result,
            Err(DataFusionError::ArrowError(_, _) | DataFusionError::Execution(_))
        ));
    }
}