datafusion_comet_spark_expr/math_funcs/round.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

use crate::math_funcs::utils::{get_precision_scale, make_decimal_array, make_decimal_scalar};
use arrow::array::{Int16Array, Int32Array, Int64Array, Int8Array};
use arrow_array::{Array, ArrowNativeTypeOp};
use arrow_schema::DataType;
use datafusion::{functions::math::round::round, physical_plan::ColumnarValue};
use datafusion_common::{exec_err, internal_err, DataFusionError, ScalarValue};
use std::{cmp::min, sync::Arc};

// Round `$X` to the nearest multiple of `$DIV`, with ties rounding away from zero
// (half-up); `$HALF` must be `$DIV / 2`. Wrapping add/sub avoids overflow panics.
macro_rules! integer_round {
    ($X:expr, $DIV:expr, $HALF:expr) => {{
        let rem = $X % $DIV;
        if rem <= -$HALF {
            ($X - rem).sub_wrapping($DIV)
        } else if rem >= $HALF {
            ($X - rem).add_wrapping($DIV)
        } else {
            $X - rem
        }
    }};
}

// Round each element of an integer array to the nearest multiple of `10^(-$POINT)`
// (`$POINT` is negative here); if `10^(-$POINT)` overflows the native type, every
// element rounds to 0.
macro_rules! round_integer_array {
    ($ARRAY:expr, $POINT:expr, $TYPE:ty, $NATIVE:ty) => {{
        let array = $ARRAY.as_any().downcast_ref::<$TYPE>().unwrap();
        let ten: $NATIVE = 10;
        let result: $TYPE = if let Some(div) = ten.checked_pow((-(*$POINT)) as u32) {
            let half = div / 2;
            arrow::compute::kernels::arity::unary(array, |x| integer_round!(x, div, half))
        } else {
            arrow::compute::kernels::arity::unary(array, |_| 0)
        };
        Ok(ColumnarValue::Array(Arc::new(result)))
    }};
}

// Scalar counterpart of `round_integer_array!`: rounds a single optional integer to the
// nearest multiple of `10^(-$POINT)`, producing 0 when `10^(-$POINT)` overflows.
macro_rules! round_integer_scalar {
    ($SCALAR:expr, $POINT:expr, $TYPE:expr, $NATIVE:ty) => {{
        let ten: $NATIVE = 10;
        if let Some(div) = ten.checked_pow((-(*$POINT)) as u32) {
            let half = div / 2;
            Ok(ColumnarValue::Scalar($TYPE(
                $SCALAR.map(|x| integer_round!(x, div, half)),
            )))
        } else {
            Ok(ColumnarValue::Scalar($TYPE(Some(0))))
        }
    }};
}

/// `round` function that simulates the Spark `round` expression.
///
/// `args[0]` is the value to round; `args[1]` must be an `Int64` scalar giving the number of
/// decimal places (a negative value rounds to the left of the decimal point). For decimal
/// inputs, `data_type` supplies the precision and scale of the result.
pub fn spark_round(
    args: &[ColumnarValue],
    data_type: &DataType,
) -> Result<ColumnarValue, DataFusionError> {
    let value = &args[0];
    let point = &args[1];
    let ColumnarValue::Scalar(ScalarValue::Int64(Some(point))) = point else {
        return internal_err!("Invalid point argument for Round(): {:#?}", point);
    };
    match value {
        ColumnarValue::Array(array) => match array.data_type() {
            DataType::Int64 if *point < 0 => round_integer_array!(array, point, Int64Array, i64),
            DataType::Int32 if *point < 0 => round_integer_array!(array, point, Int32Array, i32),
            DataType::Int16 if *point < 0 => round_integer_array!(array, point, Int16Array, i16),
            DataType::Int8 if *point < 0 => round_integer_array!(array, point, Int8Array, i8),
            DataType::Decimal128(_, scale) if *scale >= 0 => {
                let f = decimal_round_f(scale, point);
                let (precision, scale) = get_precision_scale(data_type);
                make_decimal_array(array, precision, scale, &f)
            }
            DataType::Float32 | DataType::Float64 => Ok(ColumnarValue::Array(round(&[
                Arc::clone(array),
                args[1].to_array(array.len())?,
            ])?)),
            dt => exec_err!("Unsupported data type for ROUND: {dt}"),
        },
        ColumnarValue::Scalar(a) => match a {
            ScalarValue::Int64(a) if *point < 0 => {
                round_integer_scalar!(a, point, ScalarValue::Int64, i64)
            }
            ScalarValue::Int32(a) if *point < 0 => {
                round_integer_scalar!(a, point, ScalarValue::Int32, i32)
            }
            ScalarValue::Int16(a) if *point < 0 => {
                round_integer_scalar!(a, point, ScalarValue::Int16, i16)
            }
            ScalarValue::Int8(a) if *point < 0 => {
                round_integer_scalar!(a, point, ScalarValue::Int8, i8)
            }
            ScalarValue::Decimal128(a, _, scale) if *scale >= 0 => {
                let f = decimal_round_f(scale, point);
                let (precision, scale) = get_precision_scale(data_type);
                make_decimal_scalar(a, precision, scale, &f)
            }
            ScalarValue::Float32(_) | ScalarValue::Float64(_) => Ok(ColumnarValue::Scalar(
                ScalarValue::try_from_array(&round(&[a.to_array()?, args[1].to_array(1)?])?, 0)?,
            )),
            dt => exec_err!("Unsupported data type for ROUND: {dt}"),
        },
    }
}

// Spark uses BigDecimal (see the RoundBase implementation in Spark). Here we achieve the
// same result on the unscaled i128 value by 1) adding half of the divisor, 2) rounding
// down via integer division, and 3) adjusting the precision by multiplication.
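//
// An illustrative worked example (values chosen here, not taken from Spark's test suite):
// rounding the decimal 123.45 (unscaled value 12345, scale = 2) with point = -1 gives
// div = 10^(1 + 2) = 1000, half = 500 and mul = 10^1 = 10, so
// (12345 + 500) / 1000 * 10 = 120, matching Spark's round(123.45, -1) = 120. With
// point = 1 it gives div = 10^(2 - 1) = 10 and half = 5, so (12345 + 5) / 10 = 1235,
// i.e. 123.5 assuming a result scale of 1.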
#[inline]
fn decimal_round_f(scale: &i8, point: &i64) -> Box<dyn Fn(i128) -> i128> {
    if *point < 0 {
        if let Some(div) = 10_i128.checked_pow((-(*point) as u32) + (*scale as u32)) {
            let half = div / 2;
            let mul = 10_i128.pow_wrapping((-(*point)) as u32);
            // i128 can hold 39 digits of a base 10 number, adding half will not cause overflow
            Box::new(move |x: i128| (x + x.signum() * half) / div * mul)
        } else {
            Box::new(move |_: i128| 0)
        }
    } else {
        let div = 10_i128.pow_wrapping((*scale as u32) - min(*scale as u32, *point as u32));
        let half = div / 2;
        Box::new(move |x: i128| (x + x.signum() * half) / div)
    }
}

#[cfg(test)]
mod test {
    use std::sync::Arc;

    use crate::spark_round;

    use arrow::array::{Float32Array, Float64Array};
    use arrow_schema::DataType;
    use datafusion_common::cast::{as_float32_array, as_float64_array};
    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::ColumnarValue;

    #[test]
    fn test_round_f32_array() -> Result<()> {
        let args = vec![
            ColumnarValue::Array(Arc::new(Float32Array::from(vec![
                125.2345, 15.3455, 0.1234, 0.125, 0.785, 123.123,
            ]))),
            ColumnarValue::Scalar(ScalarValue::Int64(Some(2))),
        ];
        let ColumnarValue::Array(result) = spark_round(&args, &DataType::Float32)? else {
            unreachable!()
        };
        let floats = as_float32_array(&result)?;
        let expected = Float32Array::from(vec![125.23, 15.35, 0.12, 0.13, 0.79, 123.12]);
        assert_eq!(floats, &expected);
        Ok(())
    }

    #[test]
    fn test_round_f64_array() -> Result<()> {
        let args = vec![
            ColumnarValue::Array(Arc::new(Float64Array::from(vec![
                125.2345, 15.3455, 0.1234, 0.125, 0.785, 123.123,
            ]))),
            ColumnarValue::Scalar(ScalarValue::Int64(Some(2))),
        ];
        let ColumnarValue::Array(result) = spark_round(&args, &DataType::Float64)? else {
            unreachable!()
        };
        let floats = as_float64_array(&result)?;
        let expected = Float64Array::from(vec![125.23, 15.35, 0.12, 0.13, 0.79, 123.12]);
        assert_eq!(floats, &expected);
        Ok(())
    }

    #[test]
    fn test_round_f32_scalar() -> Result<()> {
        let args = vec![
            ColumnarValue::Scalar(ScalarValue::Float32(Some(125.2345))),
            ColumnarValue::Scalar(ScalarValue::Int64(Some(2))),
        ];
        let ColumnarValue::Scalar(ScalarValue::Float32(Some(result))) =
            spark_round(&args, &DataType::Float32)?
        else {
            unreachable!()
        };
        assert_eq!(result, 125.23);
        Ok(())
    }

    #[test]
    fn test_round_f64_scalar() -> Result<()> {
        let args = vec![
            ColumnarValue::Scalar(ScalarValue::Float64(Some(125.2345))),
            ColumnarValue::Scalar(ScalarValue::Int64(Some(2))),
        ];
        let ColumnarValue::Scalar(ScalarValue::Float64(Some(result))) =
            spark_round(&args, &DataType::Float64)?
        else {
            unreachable!()
        };
        assert_eq!(result, 125.23);
        Ok(())
    }
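
    // A hedged sketch of the negative-`point` integer path (the input values and expected
    // results below are illustrative, not taken from Spark's own test suite): with
    // point = -1 each value should round half-up (away from zero) to the nearest multiple
    // of 10, mirroring Spark's `round(col, -1)`.
    #[test]
    fn test_round_i64_array_negative_point() -> Result<()> {
        use arrow::array::Int64Array;
        use datafusion_common::cast::as_int64_array;

        let args = vec![
            ColumnarValue::Array(Arc::new(Int64Array::from(vec![125, 124, -125, 15, 0]))),
            ColumnarValue::Scalar(ScalarValue::Int64(Some(-1))),
        ];
        let ColumnarValue::Array(result) = spark_round(&args, &DataType::Int64)? else {
            unreachable!()
        };
        let ints = as_int64_array(&result)?;
        let expected = Int64Array::from(vec![130, 120, -130, 20, 0]);
        assert_eq!(ints, &expected);
        Ok(())
    }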
}