datafusion_comet_spark_expr/math_funcs/
round.rs1use crate::math_funcs::utils::{get_precision_scale, make_decimal_array, make_decimal_scalar};
19use arrow::array::{Int16Array, Int32Array, Int64Array, Int8Array};
20use arrow_array::{Array, ArrowNativeTypeOp};
21use arrow_schema::DataType;
22use datafusion::{functions::math::round::round, physical_plan::ColumnarValue};
23use datafusion_common::{exec_err, internal_err, DataFusionError, ScalarValue};
24use std::{cmp::min, sync::Arc};
25
/// Rounds the integer `$X` to a multiple of `$DIV`, rounding half away from zero.
///
/// `$DIV` is expected to be positive and `$HALF` to be `$DIV / 2`. Each argument is
/// evaluated exactly once. The final adjustment uses `sub_wrapping` / `add_wrapping`, so a
/// rounded result outside the integer type's range wraps around instead of panicking.
macro_rules! integer_round {
    ($X:expr, $DIV:expr, $HALF:expr) => {{
        // Bind once so side-effecting or expensive arguments are not re-evaluated.
        let value = $X;
        let div = $DIV;
        let half = $HALF;
        // `%` truncates toward zero, so `rem` carries the sign of `value` and
        // `value - rem` is `value` truncated toward zero to a multiple of `div`.
        let rem = value % div;
        let truncated = value - rem;
        if rem <= -half {
            truncated.sub_wrapping(div)
        } else if rem >= half {
            truncated.add_wrapping(div)
        } else {
            truncated
        }
    }};
}
38
/// Rounds every value of an integer Arrow array to `-point` digits left of the decimal
/// point, rounding half away from zero via `integer_round!`.
///
/// * `$ARRAY`  - the input array (any value exposing `as_any()`)
/// * `$POINT`  - `&i64` number of decimal places; callers only pass negative values
/// * `$TYPE`   - the concrete array type to downcast to, e.g. `Int64Array`
/// * `$NATIVE` - its native element type, e.g. `i64`
///
/// Evaluates to `Ok(ColumnarValue::Array(..))` of the same array type (`unary` keeps the
/// input's null buffer, so nulls stay null).
macro_rules! round_integer_array {
    ($ARRAY:expr, $POINT:expr, $TYPE:ty, $NATIVE:ty) => {{
        let array = $ARRAY
            .as_any()
            .downcast_ref::<$TYPE>()
            .expect("round_integer_array: array does not match the requested integer type");
        let ten: $NATIVE = 10;
        // `None` when `-point` does not fit in a `u32` (including `point == i64::MIN`, whose
        // negation overflows) or when `10^(-point)` overflows `$NATIVE`. A plain `as u32`
        // cast would silently truncate a huge exponent into a small one.
        let div = (*$POINT)
            .checked_neg()
            .and_then(|exp| u32::try_from(exp).ok())
            .and_then(|exp| ten.checked_pow(exp));
        let result: $TYPE = if let Some(div) = div {
            let half = div / 2;
            arrow::compute::kernels::arity::unary(array, |x| integer_round!(x, div, half))
        } else {
            // For i8/i16/i32 every value is below half of the first overflowing power of
            // ten, so 0 is the exact rounded result.
            // NOTE(review): for i64 with point == -19, values with |x| >= 5e18 round to
            // +/-10^19, which Spark's `BigDecimal.toLong` appears to wrap rather than map
            // to 0. Confirm whether this divergence matters.
            arrow::compute::kernels::arity::unary(array, |_| 0)
        };
        Ok(ColumnarValue::Array(Arc::new(result)))
    }};
}
52
/// Rounds an optional integer scalar to `-point` digits left of the decimal point,
/// rounding half away from zero via `integer_round!`.
///
/// * `$SCALAR` - the `Option` holding the value (`None` is SQL NULL)
/// * `$POINT`  - `&i64` number of decimal places; callers only pass negative values
/// * `$TYPE`   - the `ScalarValue` variant constructor that wraps the result
/// * `$NATIVE` - the native integer type of that variant
///
/// Evaluates to `Ok(ColumnarValue::Scalar(..))`. A NULL input always yields NULL.
macro_rules! round_integer_scalar {
    ($SCALAR:expr, $POINT:expr, $TYPE:expr, $NATIVE:ty) => {{
        let ten: $NATIVE = 10;
        // `None` when `-point` does not fit in a `u32` (including `point == i64::MIN`, whose
        // negation overflows) or when `10^(-point)` overflows `$NATIVE`. A plain `as u32`
        // cast would silently truncate a huge exponent into a small one.
        let div = (*$POINT)
            .checked_neg()
            .and_then(|exp| u32::try_from(exp).ok())
            .and_then(|exp| ten.checked_pow(exp));
        if let Some(div) = div {
            let half = div / 2;
            Ok(ColumnarValue::Scalar($TYPE(
                $SCALAR.map(|x| integer_round!(x, div, half)),
            )))
        } else {
            // The divisor exceeds the type's range, so non-null values become 0, but a NULL
            // input must stay NULL.
            // NOTE(review): for i64 with point == -19, Spark appears to wrap +/-10^19 rather
            // than return 0. Confirm whether this matters.
            Ok(ColumnarValue::Scalar($TYPE($SCALAR.map(|_| 0))))
        }
    }};
}
66
67pub fn spark_round(
69 args: &[ColumnarValue],
70 data_type: &DataType,
71) -> Result<ColumnarValue, DataFusionError> {
72 let value = &args[0];
73 let point = &args[1];
74 let ColumnarValue::Scalar(ScalarValue::Int64(Some(point))) = point else {
75 return internal_err!("Invalid point argument for Round(): {:#?}", point);
76 };
77 match value {
78 ColumnarValue::Array(array) => match array.data_type() {
79 DataType::Int64 if *point < 0 => round_integer_array!(array, point, Int64Array, i64),
80 DataType::Int32 if *point < 0 => round_integer_array!(array, point, Int32Array, i32),
81 DataType::Int16 if *point < 0 => round_integer_array!(array, point, Int16Array, i16),
82 DataType::Int8 if *point < 0 => round_integer_array!(array, point, Int8Array, i8),
83 DataType::Decimal128(_, scale) if *scale >= 0 => {
84 let f = decimal_round_f(scale, point);
85 let (precision, scale) = get_precision_scale(data_type);
86 make_decimal_array(array, precision, scale, &f)
87 }
88 DataType::Float32 | DataType::Float64 => Ok(ColumnarValue::Array(round(&[
89 Arc::clone(array),
90 args[1].to_array(array.len())?,
91 ])?)),
92 dt => exec_err!("Not supported datatype for ROUND: {dt}"),
93 },
94 ColumnarValue::Scalar(a) => match a {
95 ScalarValue::Int64(a) if *point < 0 => {
96 round_integer_scalar!(a, point, ScalarValue::Int64, i64)
97 }
98 ScalarValue::Int32(a) if *point < 0 => {
99 round_integer_scalar!(a, point, ScalarValue::Int32, i32)
100 }
101 ScalarValue::Int16(a) if *point < 0 => {
102 round_integer_scalar!(a, point, ScalarValue::Int16, i16)
103 }
104 ScalarValue::Int8(a) if *point < 0 => {
105 round_integer_scalar!(a, point, ScalarValue::Int8, i8)
106 }
107 ScalarValue::Decimal128(a, _, scale) if *scale >= 0 => {
108 let f = decimal_round_f(scale, point);
109 let (precision, scale) = get_precision_scale(data_type);
110 make_decimal_scalar(a, precision, scale, &f)
111 }
112 ScalarValue::Float32(_) | ScalarValue::Float64(_) => Ok(ColumnarValue::Scalar(
113 ScalarValue::try_from_array(&round(&[a.to_array()?, args[1].to_array(1)?])?, 0)?,
114 )),
115 dt => exec_err!("Not supported datatype for ROUND: {dt}"),
116 },
117 }
118}
119
/// Builds the per-value HALF_UP (ties away from zero) rounding kernel for `Decimal128`.
///
/// `scale` is the input decimal's scale and must be non-negative. `point` is the number of
/// decimal places to keep.
/// * For `point >= 0`, the closure maps an unscaled input value to an unscaled value with
///   `min(scale, point)` fractional digits.
/// * For `point < 0`, the closure maps it to an unscaled value with 0 fractional digits whose
///   lowest `-point` integer digits are zero.
///
/// Out-of-range `point` values are handled by checked conversion instead of truncating
/// `as u32` casts.
#[inline]
fn decimal_round_f(scale: &i8, point: &i64) -> Box<dyn Fn(i128) -> i128> {
    debug_assert!(*scale >= 0, "decimal_round_f requires a non-negative scale");
    let scale = *scale as u32;
    if *point < 0 {
        // Number of integer digits to clear. `None` when `-point` does not fit in a `u32`
        // (including `i64::MIN`, whose negation overflows).
        let digits = point.checked_neg().and_then(|d| u32::try_from(d).ok());
        // Divide by 10^(digits + scale) to drop all fractional digits plus `digits` integer
        // digits, then multiply by 10^digits to restore the magnitude at scale 0.
        let factors = digits.and_then(|d: u32| -> Option<(i128, i128)> {
            let div = 10_i128.checked_pow(d.checked_add(scale)?)?;
            // mul <= div, so it cannot overflow once `div` fits.
            Some((div, 10_i128.wrapping_pow(d)))
        });
        match factors {
            Some((div, mul)) => {
                let half = div / 2;
                Box::new(move |x: i128| (x + x.signum() * half) / div * mul)
            }
            // The divisor is at least 10^39, so every i128 is below half of it and rounds to 0.
            None => Box::new(|_: i128| 0),
        }
    } else {
        // Keep min(scale, point) fractional digits. A `point` too large for `u32` keeps all
        // of them (identity) instead of being truncated to a small value.
        let kept = u32::try_from(*point).map_or(scale, |p| min(scale, p));
        let div = 10_i128.wrapping_pow(scale - kept);
        let half = div / 2;
        Box::new(move |x: i128| (x + x.signum() * half) / div)
    }
}
139
#[cfg(test)]
mod test {
    //! Unit tests for `spark_round` on `Float32` / `Float64` inputs, all rounding to two
    //! decimal places.

    use std::sync::Arc;

    use crate::spark_round;

    use arrow::array::{Float32Array, Float64Array};
    use arrow_schema::DataType;
    use datafusion_common::cast::{as_float32_array, as_float64_array};
    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::ColumnarValue;

    /// The `point` argument shared by every test: keep two fractional digits.
    fn two_decimal_places() -> ColumnarValue {
        ColumnarValue::Scalar(ScalarValue::Int64(Some(2)))
    }

    #[test]
    fn test_round_f32_array() -> Result<()> {
        let input = Float32Array::from(vec![125.2345, 15.3455, 0.1234, 0.125, 0.785, 123.123]);
        let args = [ColumnarValue::Array(Arc::new(input)), two_decimal_places()];

        let rounded = match spark_round(&args, &DataType::Float32)? {
            ColumnarValue::Array(array) => array,
            ColumnarValue::Scalar(_) => unreachable!(),
        };

        let expected = Float32Array::from(vec![125.23, 15.35, 0.12, 0.13, 0.79, 123.12]);
        assert_eq!(as_float32_array(&rounded)?, &expected);
        Ok(())
    }

    #[test]
    fn test_round_f64_array() -> Result<()> {
        let input = Float64Array::from(vec![125.2345, 15.3455, 0.1234, 0.125, 0.785, 123.123]);
        let args = [ColumnarValue::Array(Arc::new(input)), two_decimal_places()];

        let rounded = match spark_round(&args, &DataType::Float64)? {
            ColumnarValue::Array(array) => array,
            ColumnarValue::Scalar(_) => unreachable!(),
        };

        let expected = Float64Array::from(vec![125.23, 15.35, 0.12, 0.13, 0.79, 123.12]);
        assert_eq!(as_float64_array(&rounded)?, &expected);
        Ok(())
    }

    #[test]
    fn test_round_f32_scalar() -> Result<()> {
        let args = [
            ColumnarValue::Scalar(ScalarValue::Float32(Some(125.2345))),
            two_decimal_places(),
        ];
        match spark_round(&args, &DataType::Float32)? {
            ColumnarValue::Scalar(ScalarValue::Float32(Some(rounded))) => {
                assert_eq!(rounded, 125.23)
            }
            _ => unreachable!(),
        }
        Ok(())
    }

    #[test]
    fn test_round_f64_scalar() -> Result<()> {
        let args = [
            ColumnarValue::Scalar(ScalarValue::Float64(Some(125.2345))),
            two_decimal_places(),
        ];
        match spark_round(&args, &DataType::Float64)? {
            ColumnarValue::Scalar(ScalarValue::Float64(Some(rounded))) => {
                assert_eq!(rounded, 125.23)
            }
            _ => unreachable!(),
        }
        Ok(())
    }
}