datafusion_comet_spark_expr/math_funcs/round.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use crate::math_funcs::utils::{get_precision_scale, make_decimal_array, make_decimal_scalar};
19use arrow::array::{Array, ArrowNativeTypeOp};
20use arrow::array::{Int16Array, Int32Array, Int64Array, Int8Array};
21use arrow::datatypes::DataType;
22use datafusion::common::{exec_err, internal_err, DataFusionError, ScalarValue};
23use datafusion::{functions::math::round::round, physical_plan::ColumnarValue};
24use std::{cmp::min, sync::Arc};
25
// Rounds the integer `$X` half away from zero to the nearest multiple of `$DIV`.
//
// `$HALF` is expected to be `$DIV / 2`. The value is first truncated toward zero to a multiple
// of `$DIV` (Rust's `%` keeps the sign of the dividend); when the discarded remainder is at
// least half a step in magnitude, one more step of `$DIV` is taken away from zero. That last
// step uses the `ArrowNativeTypeOp` wrapping helpers, so a result outside the native type's
// range wraps around instead of panicking.
macro_rules! integer_round {
    ($X:expr, $DIV:expr, $HALF:expr) => {{
        let value = $X;
        match value % $DIV {
            // Negative remainder of at least half a step: step further below zero.
            rem if rem <= -$HALF => (value - rem).sub_wrapping($DIV),
            // Positive remainder of at least half a step: step further above zero.
            rem if rem >= $HALF => (value - rem).add_wrapping($DIV),
            // Less than half a step: plain truncation toward zero.
            rem => value - rem,
        }
    }};
}
38
// Rounds every element of an integer array to a multiple of 10^|`$POINT`|, half away from zero.
//
// * `$ARRAY`  - an `ArrayRef`-like value whose concrete type is `$TYPE`; callers invoke this
//               only after matching the array's `DataType`, so the downcast is expected to
//               succeed.
// * `$POINT`  - a `&i64` holding the (negative) number of decimal places.
// * `$TYPE`   - the concrete Arrow array type, e.g. `Int32Array`.
// * `$NATIVE` - its native element type, e.g. `i32`.
//
// Evaluates to `Ok(ColumnarValue::Array(..))`.
macro_rules! round_integer_array {
    ($ARRAY:expr, $POINT:expr, $TYPE:ty, $NATIVE:ty) => {{
        let array = $ARRAY.as_any().downcast_ref::<$TYPE>().unwrap();
        let ten: $NATIVE = 10;
        // Use `unsigned_abs` + `try_from` rather than `(-point) as u32`: negating `i64::MIN`
        // overflows, and `as u32` silently truncates huge exponents (2^32 would become 0, i.e.
        // a divisor of 1 and a half-step of 0, which corrupts every value).
        let divisor = u32::try_from((*$POINT).unsigned_abs())
            .ok()
            .and_then(|exp| ten.checked_pow(exp));
        let result: $TYPE = if let Some(div) = divisor {
            let half = div / 2;
            arrow::compute::kernels::arity::unary(array, |x| integer_round!(x, div, half))
        } else {
            // 10^|point| does not fit in `$NATIVE`; every element is mapped to 0.
            // NOTE(review): for i64 and point = -19, values >= 5e18 would mathematically round
            // to 1e19 (not representable) - confirm 0 matches Spark for that corner.
            arrow::compute::kernels::arity::unary(array, |_| 0)
        };
        Ok(ColumnarValue::Array(Arc::new(result)))
    }};
}
52
// Rounds an optional integer scalar to a multiple of 10^|`$POINT`|, half away from zero.
//
// * `$SCALAR` - the `&Option<$NATIVE>` payload of the scalar (`None` is SQL NULL).
// * `$POINT`  - a `&i64` holding the (negative) number of decimal places.
// * `$TYPE`   - the `ScalarValue` constructor to wrap the result in, e.g. `ScalarValue::Int32`.
// * `$NATIVE` - the native integer type, e.g. `i32`.
//
// Evaluates to `Ok(ColumnarValue::Scalar(..))`. A NULL input always yields a NULL output.
macro_rules! round_integer_scalar {
    ($SCALAR:expr, $POINT:expr, $TYPE:expr, $NATIVE:ty) => {{
        let ten: $NATIVE = 10;
        // Use `unsigned_abs` + `try_from` rather than `(-point) as u32`: negating `i64::MIN`
        // overflows, and `as u32` silently truncates huge exponents to small ones.
        let divisor = u32::try_from((*$POINT).unsigned_abs())
            .ok()
            .and_then(|exp| ten.checked_pow(exp));
        let rounded = if let Some(div) = divisor {
            let half = div / 2;
            $SCALAR.map(|x| integer_round!(x, div, half))
        } else {
            // 10^|point| does not fit in `$NATIVE`: a non-null value becomes 0, while NULL
            // must stay NULL (consistent with the array path, which maps only valid slots).
            $SCALAR.map(|_| 0)
        };
        Ok(ColumnarValue::Scalar($TYPE(rounded)))
    }};
}
66
67/// `round` function that simulates Spark `round` expression
68pub fn spark_round(
69    args: &[ColumnarValue],
70    data_type: &DataType,
71) -> Result<ColumnarValue, DataFusionError> {
72    let value = &args[0];
73    let point = &args[1];
74    let ColumnarValue::Scalar(ScalarValue::Int64(Some(point))) = point else {
75        return internal_err!("Invalid point argument for Round(): {:#?}", point);
76    };
77    match value {
78        ColumnarValue::Array(array) => match array.data_type() {
79            DataType::Int64 if *point < 0 => round_integer_array!(array, point, Int64Array, i64),
80            DataType::Int32 if *point < 0 => round_integer_array!(array, point, Int32Array, i32),
81            DataType::Int16 if *point < 0 => round_integer_array!(array, point, Int16Array, i16),
82            DataType::Int8 if *point < 0 => round_integer_array!(array, point, Int8Array, i8),
83            DataType::Decimal128(_, scale) if *scale >= 0 => {
84                let f = decimal_round_f(scale, point);
85                let (precision, scale) = get_precision_scale(data_type);
86                make_decimal_array(array, precision, scale, &f)
87            }
88            DataType::Float32 | DataType::Float64 => Ok(ColumnarValue::Array(round(&[
89                Arc::clone(array),
90                args[1].to_array(array.len())?,
91            ])?)),
92            dt => exec_err!("Not supported datatype for ROUND: {dt}"),
93        },
94        ColumnarValue::Scalar(a) => match a {
95            ScalarValue::Int64(a) if *point < 0 => {
96                round_integer_scalar!(a, point, ScalarValue::Int64, i64)
97            }
98            ScalarValue::Int32(a) if *point < 0 => {
99                round_integer_scalar!(a, point, ScalarValue::Int32, i32)
100            }
101            ScalarValue::Int16(a) if *point < 0 => {
102                round_integer_scalar!(a, point, ScalarValue::Int16, i16)
103            }
104            ScalarValue::Int8(a) if *point < 0 => {
105                round_integer_scalar!(a, point, ScalarValue::Int8, i8)
106            }
107            ScalarValue::Decimal128(a, _, scale) if *scale >= 0 => {
108                let f = decimal_round_f(scale, point);
109                let (precision, scale) = get_precision_scale(data_type);
110                make_decimal_scalar(a, precision, scale, &f)
111            }
112            ScalarValue::Float32(_) | ScalarValue::Float64(_) => Ok(ColumnarValue::Scalar(
113                ScalarValue::try_from_array(&round(&[a.to_array()?, args[1].to_array(1)?])?, 0)?,
114            )),
115            dt => exec_err!("Not supported datatype for ROUND: {dt}"),
116        },
117    }
118}
119
// Builds the closure that rounds an unscaled Decimal128 value (an `i128` carrying `scale`
// fractional digits) to `point` decimal places, rounding half away from zero.
//
// Spark uses BigDecimal for this (see `RoundBase` in Spark). Instead we work on the unscaled
// integer directly: 1) add half of the divisor (with the sign of the value), 2) truncate by
// integer division, and 3) for a negative `point`, multiply back up by 10^-point.
//
// The closure's output has scale `min(scale, point)` when `point >= 0`, and scale 0 when
// `point < 0`. Callers only pass a non-negative `scale`.
#[inline]
fn decimal_round_f(scale: &i8, point: &i64) -> Box<dyn Fn(i128) -> i128> {
    let scale = i64::from(*scale);
    if *point < 0 {
        // Digits to drop from the unscaled value: all `scale` fractional digits plus `-point`
        // integral ones. Checked i64 -> u32 arithmetic instead of `as` casts, so an extreme
        // `point` (e.g. `i64::MIN` or `-(1 << 32)`) cannot overflow or wrap to a small exponent.
        let divisor = scale
            .checked_sub(*point)
            .and_then(|digits| u32::try_from(digits).ok())
            .and_then(|digits| 10_i128.checked_pow(digits));
        if let Some(div) = divisor {
            let half = div / 2;
            // `div` fits in i128, so `scale - point <= 38`; hence `-point` is small and the
            // cast below is exact.
            let mul = 10_i128.wrapping_pow(point.unsigned_abs() as u32);
            // Decimal128 holds at most 38 digits and `half <= 10^38 / 2`, so adding `half`
            // cannot overflow i128.
            Box::new(move |x: i128| (x + x.signum() * half) / div * mul)
        } else {
            // 10^digits exceeds i128: every Decimal128 value is below half the rounding unit.
            Box::new(move |_: i128| 0)
        }
    } else {
        // `point >= 0` here, so `0 <= scale - min(scale, point) <= scale` and the cast is exact.
        // Keeping `point >= scale` digits drops nothing (div = 1, half = 0: identity).
        let digits = (scale - min(scale, *point)) as u32;
        let div = 10_i128.wrapping_pow(digits);
        let half = div / 2;
        Box::new(move |x: i128| (x + x.signum() * half) / div)
    }
}
139
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use crate::spark_round;

    use arrow::array::{Float32Array, Float64Array};
    use arrow::datatypes::DataType;
    use datafusion::common::cast::{as_float32_array, as_float64_array};
    use datafusion::common::{Result, ScalarValue};
    use datafusion::physical_plan::ColumnarValue;

    /// Pairs `value` with a literal `point` of 2, i.e. "round to two decimal places".
    fn two_decimal_places(value: ColumnarValue) -> Vec<ColumnarValue> {
        vec![value, ColumnarValue::Scalar(ScalarValue::Int64(Some(2)))]
    }

    // The array inputs include values whose third decimal digit is 5 (15.3455, 0.125, 0.785);
    // the expected outputs round those up.

    #[test]
    #[cfg_attr(miri, ignore)] // rounding does not work when miri enabled
    fn test_round_f32_array() -> Result<()> {
        let input = Float32Array::from(vec![125.2345, 15.3455, 0.1234, 0.125, 0.785, 123.123]);
        let args = two_decimal_places(ColumnarValue::Array(Arc::new(input)));
        let ColumnarValue::Array(rounded) = spark_round(&args, &DataType::Float32)? else {
            unreachable!()
        };
        let want = Float32Array::from(vec![125.23, 15.35, 0.12, 0.13, 0.79, 123.12]);
        assert_eq!(as_float32_array(&rounded)?, &want);
        Ok(())
    }

    #[test]
    #[cfg_attr(miri, ignore)] // rounding does not work when miri enabled
    fn test_round_f64_array() -> Result<()> {
        let input = Float64Array::from(vec![125.2345, 15.3455, 0.1234, 0.125, 0.785, 123.123]);
        let args = two_decimal_places(ColumnarValue::Array(Arc::new(input)));
        let ColumnarValue::Array(rounded) = spark_round(&args, &DataType::Float64)? else {
            unreachable!()
        };
        let want = Float64Array::from(vec![125.23, 15.35, 0.12, 0.13, 0.79, 123.12]);
        assert_eq!(as_float64_array(&rounded)?, &want);
        Ok(())
    }

    #[test]
    #[cfg_attr(miri, ignore)] // rounding does not work when miri enabled
    fn test_round_f32_scalar() -> Result<()> {
        let args = two_decimal_places(ColumnarValue::Scalar(ScalarValue::Float32(Some(125.2345))));
        match spark_round(&args, &DataType::Float32)? {
            ColumnarValue::Scalar(ScalarValue::Float32(Some(rounded))) => {
                assert_eq!(rounded, 125.23)
            }
            _ => unreachable!(),
        }
        Ok(())
    }

    #[test]
    #[cfg_attr(miri, ignore)] // rounding does not work when miri enabled
    fn test_round_f64_scalar() -> Result<()> {
        let args = two_decimal_places(ColumnarValue::Scalar(ScalarValue::Float64(Some(125.2345))));
        match spark_round(&args, &DataType::Float64)? {
            ColumnarValue::Scalar(ScalarValue::Float64(Some(rounded))) => {
                assert_eq!(rounded, 125.23)
            }
            _ => unreachable!(),
        }
        Ok(())
    }
}