datafusion_functions/math/
round.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19
20use crate::utils::{calculate_binary_decimal_math, calculate_binary_math};
21
22use arrow::array::ArrayRef;
23use arrow::datatypes::DataType::{
24    Decimal32, Decimal64, Decimal128, Decimal256, Float32, Float64,
25};
26use arrow::datatypes::{
27    ArrowNativeTypeOp, DataType, Decimal32Type, Decimal64Type, Decimal128Type,
28    Decimal256Type, Float32Type, Float64Type, Int32Type,
29};
30use arrow::error::ArrowError;
31use datafusion_common::types::{
32    NativeType, logical_float32, logical_float64, logical_int32,
33};
34use datafusion_common::{Result, ScalarValue, exec_err};
35use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
36use datafusion_expr::{
37    Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
38    TypeSignature, TypeSignatureClass, Volatility,
39};
40use datafusion_macros::user_doc;
41
/// Scalar UDF implementation of the SQL `round` function.
///
/// The `#[user_doc]` attribute below holds the user-facing documentation
/// text for the function: section label, description, syntax, arguments and
/// a SQL example.
#[user_doc(
    doc_section(label = "Math Functions"),
    description = "Rounds a number to the nearest integer.",
    syntax_example = "round(numeric_expression[, decimal_places])",
    standard_argument(name = "numeric_expression", prefix = "Numeric"),
    argument(
        name = "decimal_places",
        description = "Optional. The number of decimal places to round to. Defaults to 0."
    ),
    sql_example = r#"```sql
> SELECT round(3.14159);
+--------------+
| round(3.14159)|
+--------------+
| 3.0          |
+--------------+
```"#
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct RoundFunc {
    /// Accepted argument signatures; populated by `RoundFunc::new` and
    /// handed out by the `signature()` trait method.
    signature: Signature,
}
64
65impl Default for RoundFunc {
66    fn default() -> Self {
67        RoundFunc::new()
68    }
69}
70
71impl RoundFunc {
72    pub fn new() -> Self {
73        let decimal = Coercion::new_exact(TypeSignatureClass::Decimal);
74        let decimal_places = Coercion::new_implicit(
75            TypeSignatureClass::Native(logical_int32()),
76            vec![TypeSignatureClass::Integer],
77            NativeType::Int32,
78        );
79        let float32 = Coercion::new_exact(TypeSignatureClass::Native(logical_float32()));
80        let float64 = Coercion::new_implicit(
81            TypeSignatureClass::Native(logical_float64()),
82            vec![TypeSignatureClass::Numeric],
83            NativeType::Float64,
84        );
85        Self {
86            signature: Signature::one_of(
87                vec![
88                    TypeSignature::Coercible(vec![
89                        decimal.clone(),
90                        decimal_places.clone(),
91                    ]),
92                    TypeSignature::Coercible(vec![decimal]),
93                    TypeSignature::Coercible(vec![
94                        float32.clone(),
95                        decimal_places.clone(),
96                    ]),
97                    TypeSignature::Coercible(vec![float32]),
98                    TypeSignature::Coercible(vec![float64.clone(), decimal_places]),
99                    TypeSignature::Coercible(vec![float64]),
100                ],
101                Volatility::Immutable,
102            ),
103        }
104    }
105}
106
107impl ScalarUDFImpl for RoundFunc {
108    fn as_any(&self) -> &dyn Any {
109        self
110    }
111
112    fn name(&self) -> &str {
113        "round"
114    }
115
116    fn signature(&self) -> &Signature {
117        &self.signature
118    }
119
120    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
121        Ok(match arg_types[0].clone() {
122            Float32 => Float32,
123            dt @ Decimal128(_, _)
124            | dt @ Decimal256(_, _)
125            | dt @ Decimal32(_, _)
126            | dt @ Decimal64(_, _) => dt,
127            _ => Float64,
128        })
129    }
130
131    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
132        if args.arg_fields.iter().any(|a| a.data_type().is_null()) {
133            return ColumnarValue::Scalar(ScalarValue::Null)
134                .cast_to(args.return_type(), None);
135        }
136
137        let default_decimal_places = ColumnarValue::Scalar(ScalarValue::Int32(Some(0)));
138        let decimal_places = if args.args.len() == 2 {
139            &args.args[1]
140        } else {
141            &default_decimal_places
142        };
143
144        round_columnar(&args.args[0], decimal_places, args.number_rows)
145    }
146
147    fn output_ordering(&self, input: &[ExprProperties]) -> Result<SortProperties> {
148        // round preserves the order of the first argument
149        let value = &input[0];
150        let precision = input.get(1);
151
152        if precision
153            .map(|r| r.sort_properties.eq(&SortProperties::Singleton))
154            .unwrap_or(true)
155        {
156            Ok(value.sort_properties)
157        } else {
158            Ok(SortProperties::Unordered)
159        }
160    }
161
162    fn documentation(&self) -> Option<&Documentation> {
163        self.doc()
164    }
165}
166
167fn round_columnar(
168    value: &ColumnarValue,
169    decimal_places: &ColumnarValue,
170    number_rows: usize,
171) -> Result<ColumnarValue> {
172    let value_array = value.to_array(number_rows)?;
173    let both_scalars = matches!(value, ColumnarValue::Scalar(_))
174        && matches!(decimal_places, ColumnarValue::Scalar(_));
175
176    let arr: ArrayRef = match value_array.data_type() {
177        Float64 => {
178            let result = calculate_binary_math::<Float64Type, Int32Type, Float64Type, _>(
179                value_array.as_ref(),
180                decimal_places,
181                round_float::<f64>,
182            )?;
183            result as _
184        }
185        Float32 => {
186            let result = calculate_binary_math::<Float32Type, Int32Type, Float32Type, _>(
187                value_array.as_ref(),
188                decimal_places,
189                round_float::<f32>,
190            )?;
191            result as _
192        }
193        Decimal32(precision, scale) => {
194            let result = calculate_binary_decimal_math::<
195                Decimal32Type,
196                Int32Type,
197                Decimal32Type,
198                _,
199            >(
200                value_array.as_ref(),
201                decimal_places,
202                |v, dp| round_decimal(v, *scale, dp),
203                *precision,
204                *scale,
205            )?;
206            result as _
207        }
208        Decimal64(precision, scale) => {
209            let result = calculate_binary_decimal_math::<
210                Decimal64Type,
211                Int32Type,
212                Decimal64Type,
213                _,
214            >(
215                value_array.as_ref(),
216                decimal_places,
217                |v, dp| round_decimal(v, *scale, dp),
218                *precision,
219                *scale,
220            )?;
221            result as _
222        }
223        Decimal128(precision, scale) => {
224            let result = calculate_binary_decimal_math::<
225                Decimal128Type,
226                Int32Type,
227                Decimal128Type,
228                _,
229            >(
230                value_array.as_ref(),
231                decimal_places,
232                |v, dp| round_decimal(v, *scale, dp),
233                *precision,
234                *scale,
235            )?;
236            result as _
237        }
238        Decimal256(precision, scale) => {
239            let result = calculate_binary_decimal_math::<
240                Decimal256Type,
241                Int32Type,
242                Decimal256Type,
243                _,
244            >(
245                value_array.as_ref(),
246                decimal_places,
247                |v, dp| round_decimal(v, *scale, dp),
248                *precision,
249                *scale,
250            )?;
251            result as _
252        }
253        other => exec_err!("Unsupported data type {other:?} for function round")?,
254    };
255
256    if both_scalars {
257        ScalarValue::try_from_array(&arr, 0).map(ColumnarValue::Scalar)
258    } else {
259        Ok(ColumnarValue::Array(arr))
260    }
261}
262
263fn round_float<T>(value: T, decimal_places: i32) -> Result<T, ArrowError>
264where
265    T: num_traits::Float,
266{
267    let factor = T::from(10_f64.powi(decimal_places)).ok_or_else(|| {
268        ArrowError::ComputeError(format!(
269            "Invalid value for decimal places: {decimal_places}"
270        ))
271    })?;
272    Ok((value * factor).round() / factor)
273}
274
275fn round_decimal<V: ArrowNativeTypeOp>(
276    value: V,
277    scale: i8,
278    decimal_places: i32,
279) -> Result<V, ArrowError> {
280    let diff = i64::from(scale) - i64::from(decimal_places);
281    if diff <= 0 {
282        return Ok(value);
283    }
284
285    let diff: u32 = diff.try_into().map_err(|e| {
286        ArrowError::ComputeError(format!(
287            "Invalid value for decimal places: {decimal_places}: {e}"
288        ))
289    })?;
290
291    let one = V::ONE;
292    let two = V::from_usize(2).ok_or_else(|| {
293        ArrowError::ComputeError("Internal error: could not create constant 2".into())
294    })?;
295    let ten = V::from_usize(10).ok_or_else(|| {
296        ArrowError::ComputeError("Internal error: could not create constant 10".into())
297    })?;
298
299    let factor = ten.pow_checked(diff).map_err(|_| {
300        ArrowError::ComputeError(format!(
301            "Overflow while rounding decimal with scale {scale} and decimal places {decimal_places}"
302        ))
303    })?;
304
305    let mut quotient = value.div_wrapping(factor);
306    let remainder = value.mod_wrapping(factor);
307
308    // `factor` is an even number (10^n, n > 0), so `factor / 2` is the tie threshold
309    let threshold = factor.div_wrapping(two);
310    if remainder >= threshold {
311        quotient = quotient.add_checked(one).map_err(|_| {
312            ArrowError::ComputeError("Overflow while rounding decimal".into())
313        })?;
314    } else if remainder <= threshold.neg_wrapping() {
315        quotient = quotient.sub_checked(one).map_err(|_| {
316            ArrowError::ComputeError("Overflow while rounding decimal".into())
317        })?;
318    }
319
320    quotient
321        .mul_checked(factor)
322        .map_err(|_| ArrowError::ComputeError("Overflow while rounding decimal".into()))
323}
324
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use arrow::array::{ArrayRef, Float32Array, Float64Array, Int64Array};
    use datafusion_common::DataFusionError;
    use datafusion_common::ScalarValue;
    use datafusion_common::cast::{as_float32_array, as_float64_array};
    use datafusion_expr::ColumnarValue;

    /// Calls `round_columnar` with an array value and, when given, an array of
    /// decimal places (a scalar `0` otherwise), and returns the result as an
    /// array.
    fn round_arrays(
        value: ArrayRef,
        decimal_places: Option<ArrayRef>,
    ) -> Result<ArrayRef, DataFusionError> {
        let number_rows = value.len();
        let places = match decimal_places {
            Some(array) => ColumnarValue::Array(array),
            None => ColumnarValue::Scalar(ScalarValue::Int32(Some(0))),
        };

        let rounded =
            super::round_columnar(&ColumnarValue::Array(value), &places, number_rows)?;
        match rounded {
            ColumnarValue::Array(array) => Ok(array),
            ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(1),
        }
    }

    #[test]
    fn test_round_f32() {
        let input: ArrayRef = Arc::new(Float32Array::from(vec![125.2345; 10]));
        let places: ArrayRef =
            Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4, 5, -1, -2, -3, -4]));

        let rounded = round_arrays(input, Some(places))
            .expect("failed to initialize function round");
        let actual =
            as_float32_array(&rounded).expect("failed to initialize function round");

        let expected = Float32Array::from(vec![
            125.0, 125.2, 125.23, 125.235, 125.2345, 125.2345, 130.0, 100.0, 0.0, 0.0,
        ]);

        assert_eq!(actual, &expected);
    }

    #[test]
    fn test_round_f64() {
        let input: ArrayRef = Arc::new(Float64Array::from(vec![125.2345; 10]));
        let places: ArrayRef =
            Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4, 5, -1, -2, -3, -4]));

        let rounded = round_arrays(input, Some(places))
            .expect("failed to initialize function round");
        let actual =
            as_float64_array(&rounded).expect("failed to initialize function round");

        let expected = Float64Array::from(vec![
            125.0, 125.2, 125.23, 125.235, 125.2345, 125.2345, 130.0, 100.0, 0.0, 0.0,
        ]);

        assert_eq!(actual, &expected);
    }

    #[test]
    fn test_round_f32_one_input() {
        let input: ArrayRef =
            Arc::new(Float32Array::from(vec![125.2345, 12.345, 1.234, 0.1234]));

        let rounded =
            round_arrays(input, None).expect("failed to initialize function round");
        let actual =
            as_float32_array(&rounded).expect("failed to initialize function round");

        let expected = Float32Array::from(vec![125.0, 12.0, 1.0, 0.0]);

        assert_eq!(actual, &expected);
    }

    #[test]
    fn test_round_f64_one_input() {
        let input: ArrayRef =
            Arc::new(Float64Array::from(vec![125.2345, 12.345, 1.234, 0.1234]));

        let rounded =
            round_arrays(input, None).expect("failed to initialize function round");
        let actual =
            as_float64_array(&rounded).expect("failed to initialize function round");

        let expected = Float64Array::from(vec![125.0, 12.0, 1.0, 0.0]);

        assert_eq!(actual, &expected);
    }

    #[test]
    fn test_round_f32_cast_fail() {
        // The decimal places value is i32::MAX + 1; the test only requires
        // that rounding reports an error for it.
        let input: ArrayRef = Arc::new(Float64Array::from(vec![125.2345]));
        let places: ArrayRef = Arc::new(Int64Array::from(vec![2147483648]));

        let outcome = round_arrays(input, Some(places));

        assert!(outcome.is_err());
        assert!(matches!(
            outcome,
            Err(DataFusionError::ArrowError(_, _)) | Err(DataFusionError::Execution(_))
        ));
    }
}
435        ));
436    }
437}