datafusion_comet_spark_expr/math_funcs/
round.rs1use crate::math_funcs::utils::{get_precision_scale, make_decimal_array, make_decimal_scalar};
19use arrow::array::{Array, ArrowNativeTypeOp};
20use arrow::array::{Int16Array, Int32Array, Int64Array, Int8Array};
21use arrow::datatypes::DataType;
22use datafusion::common::{exec_err, internal_err, DataFusionError, ScalarValue};
23use datafusion::{functions::math::round::round, physical_plan::ColumnarValue};
24use std::{cmp::min, sync::Arc};
25
/// Rounds the integer `$X` HALF_UP (ties away from zero) to a multiple of `$DIV`.
///
/// Callers pass `$HALF = $DIV / 2`. The step away from zero uses `sub_wrapping` /
/// `add_wrapping` (from arrow's `ArrowNativeTypeOp`, which must be in scope at the call
/// site), so a result outside the native type's range wraps instead of panicking.
/// `$X` and `$DIV` are evaluated more than once, so pass plain identifiers or literals.
macro_rules! integer_round {
    ($X:expr, $DIV:expr, $HALF:expr) => {{
        let remainder = $X % $DIV;
        // Moving toward zero by the remainder can never overflow.
        let truncated = $X - remainder;
        match remainder {
            r if r <= -$HALF => truncated.sub_wrapping($DIV),
            r if r >= $HALF => truncated.add_wrapping($DIV),
            _ => truncated,
        }
    }};
}
38
/// Rounds every value of an integer array HALF_UP to a multiple of `10^(-point)`.
///
/// * `$ARRAY` - an `ArrayRef` whose concrete type is `$TYPE`. It is downcast with `unwrap`,
///   so the caller must have matched on its `data_type()` first.
/// * `$POINT` - a `&i64`; callers only use this macro when it is negative.
/// * `$TYPE` / `$NATIVE` - the Arrow array type and its native integer type.
///
/// When `10^(-point)` does not fit in `$NATIVE`, every non-null value rounds to `0`.
/// Nulls are carried over by `unary`. Evaluates to `Ok(ColumnarValue::Array(..))`.
macro_rules! round_integer_array {
    ($ARRAY:expr, $POINT:expr, $TYPE:ty, $NATIVE:ty) => {{
        let array = $ARRAY.as_any().downcast_ref::<$TYPE>().unwrap();
        let ten: $NATIVE = 10;
        // `unsigned_abs` + `try_from` instead of `(-point) as u32`: the latter overflows on
        // `i64::MIN` and truncates magnitudes >= 2^32 (e.g. to 0, i.e. a divisor of 1).
        let divisor = u32::try_from((*$POINT).unsigned_abs())
            .ok()
            .and_then(|exp| ten.checked_pow(exp));
        let result: $TYPE = if let Some(div) = divisor {
            let half = div / 2;
            arrow::compute::kernels::arity::unary(array, |x| integer_round!(x, div, half))
        } else {
            // The divisor exceeds the type's range, so every value rounds to zero.
            arrow::compute::kernels::arity::unary(array, |_| 0)
        };
        Ok(ColumnarValue::Array(Arc::new(result)))
    }};
}
52
/// Rounds a scalar integer HALF_UP to a multiple of `10^(-point)`.
///
/// * `$SCALAR` - a `&Option<$NATIVE>`; `None` (SQL NULL) always stays `None`.
/// * `$POINT` - a `&i64`; callers only use this macro when it is negative.
/// * `$TYPE` - the constructor wrapping the result, e.g. `ScalarValue::Int64`.
/// * `$NATIVE` - the native integer type of the scalar.
///
/// When `10^(-point)` does not fit in `$NATIVE`, a non-null value rounds to `0`.
/// Evaluates to `Ok(ColumnarValue::Scalar(..))`.
macro_rules! round_integer_scalar {
    ($SCALAR:expr, $POINT:expr, $TYPE:expr, $NATIVE:ty) => {{
        let ten: $NATIVE = 10;
        // `unsigned_abs` + `try_from` instead of `(-point) as u32`: the latter overflows on
        // `i64::MIN` and truncates magnitudes >= 2^32 (e.g. to 0, i.e. a divisor of 1).
        let divisor = u32::try_from((*$POINT).unsigned_abs())
            .ok()
            .and_then(|exp| ten.checked_pow(exp));
        if let Some(div) = divisor {
            let half = div / 2;
            Ok(ColumnarValue::Scalar($TYPE(
                $SCALAR.map(|x| integer_round!(x, div, half)),
            )))
        } else {
            // Non-null values round to zero; NULL must stay NULL, like the array path.
            Ok(ColumnarValue::Scalar($TYPE($SCALAR.map(|_| 0))))
        }
    }};
}
66
67pub fn spark_round(
69 args: &[ColumnarValue],
70 data_type: &DataType,
71) -> Result<ColumnarValue, DataFusionError> {
72 let value = &args[0];
73 let point = &args[1];
74 let ColumnarValue::Scalar(ScalarValue::Int64(Some(point))) = point else {
75 return internal_err!("Invalid point argument for Round(): {:#?}", point);
76 };
77 match value {
78 ColumnarValue::Array(array) => match array.data_type() {
79 DataType::Int64 if *point < 0 => round_integer_array!(array, point, Int64Array, i64),
80 DataType::Int32 if *point < 0 => round_integer_array!(array, point, Int32Array, i32),
81 DataType::Int16 if *point < 0 => round_integer_array!(array, point, Int16Array, i16),
82 DataType::Int8 if *point < 0 => round_integer_array!(array, point, Int8Array, i8),
83 DataType::Decimal128(_, scale) if *scale >= 0 => {
84 let f = decimal_round_f(scale, point);
85 let (precision, scale) = get_precision_scale(data_type);
86 make_decimal_array(array, precision, scale, &f)
87 }
88 DataType::Float32 | DataType::Float64 => Ok(ColumnarValue::Array(round(&[
89 Arc::clone(array),
90 args[1].to_array(array.len())?,
91 ])?)),
92 dt => exec_err!("Not supported datatype for ROUND: {dt}"),
93 },
94 ColumnarValue::Scalar(a) => match a {
95 ScalarValue::Int64(a) if *point < 0 => {
96 round_integer_scalar!(a, point, ScalarValue::Int64, i64)
97 }
98 ScalarValue::Int32(a) if *point < 0 => {
99 round_integer_scalar!(a, point, ScalarValue::Int32, i32)
100 }
101 ScalarValue::Int16(a) if *point < 0 => {
102 round_integer_scalar!(a, point, ScalarValue::Int16, i16)
103 }
104 ScalarValue::Int8(a) if *point < 0 => {
105 round_integer_scalar!(a, point, ScalarValue::Int8, i8)
106 }
107 ScalarValue::Decimal128(a, _, scale) if *scale >= 0 => {
108 let f = decimal_round_f(scale, point);
109 let (precision, scale) = get_precision_scale(data_type);
110 make_decimal_scalar(a, precision, scale, &f)
111 }
112 ScalarValue::Float32(_) | ScalarValue::Float64(_) => Ok(ColumnarValue::Scalar(
113 ScalarValue::try_from_array(&round(&[a.to_array()?, args[1].to_array(1)?])?, 0)?,
114 )),
115 dt => exec_err!("Not supported datatype for ROUND: {dt}"),
116 },
117 }
118}
119
/// Builds the per-value rounding closure for a `Decimal128` with a non-negative `scale`.
///
/// The closure takes the unscaled `i128` of a decimal with `scale` fractional digits and
/// rounds it HALF_UP (ties away from zero) to `point` fractional digits:
///
/// * `point < 0`: rounds to a multiple of `10^(-point)` and returns an unscaled value with
///   scale 0, i.e. it divides by `10^(-point + scale)` and multiplies back by `10^(-point)`.
///   If that divisor does not fit in an `i128`, every value rounds to `0`.
/// * `point >= 0`: drops `scale - min(scale, point)` trailing digits, so the result has
///   `min(scale, point)` fractional digits. It is unchanged when `point >= scale`.
///
/// NOTE(review): callers take the output precision/scale from the expected result type;
/// confirm it agrees with the scale produced here.
#[inline]
fn decimal_round_f(scale: &i8, point: &i64) -> Box<dyn Fn(i128) -> i128> {
    // Callers only pass non-negative scales, so this conversion is lossless.
    let scale = *scale as u32;
    if *point < 0 {
        // `unsigned_abs` cannot overflow (unlike `-i64::MIN`), and `try_from`/`checked_add`
        // reject exponents that `as u32` / `+` would silently truncate or wrap.
        let factors = u32::try_from(point.unsigned_abs()).ok().and_then(|digits| {
            let div = 10_i128.checked_pow(digits.checked_add(scale)?)?;
            // `10^digits <= div`, so this cannot overflow.
            Some((div, 10_i128.pow(digits)))
        });
        match factors {
            Some((div, mul)) => {
                let half = div / 2;
                Box::new(move |x: i128| (x + x.signum() * half) / div * mul)
            }
            // The rounding unit exceeds any representable decimal: everything becomes 0.
            None => Box::new(|_: i128| 0_i128),
        }
    } else {
        // Clamp huge `point` values instead of truncating them (2^32 as u32 == 0).
        let keep = u32::try_from(*point).unwrap_or(u32::MAX);
        let div = 10_i128.wrapping_pow(scale - min(scale, keep));
        let half = div / 2;
        Box::new(move |x: i128| (x + x.signum() * half) / div)
    }
}
139
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use crate::spark_round;

    use arrow::array::{Float32Array, Float64Array};
    use arrow::datatypes::DataType;
    use datafusion::common::cast::{as_float32_array, as_float64_array};
    use datafusion::common::{Result, ScalarValue};
    use datafusion::physical_plan::ColumnarValue;

    /// The `point` argument shared by every test below: keep two decimal places.
    fn two_places() -> ColumnarValue {
        ColumnarValue::Scalar(ScalarValue::Int64(Some(2)))
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_round_f32_array() -> Result<()> {
        let input = Float32Array::from(vec![125.2345, 15.3455, 0.1234, 0.125, 0.785, 123.123]);
        let args = [ColumnarValue::Array(Arc::new(input)), two_places()];
        let ColumnarValue::Array(rounded) = spark_round(&args, &DataType::Float32)? else {
            unreachable!()
        };
        let want = Float32Array::from(vec![125.23, 15.35, 0.12, 0.13, 0.79, 123.12]);
        assert_eq!(as_float32_array(&rounded)?, &want);
        Ok(())
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_round_f64_array() -> Result<()> {
        let input = Float64Array::from(vec![125.2345, 15.3455, 0.1234, 0.125, 0.785, 123.123]);
        let args = [ColumnarValue::Array(Arc::new(input)), two_places()];
        let ColumnarValue::Array(rounded) = spark_round(&args, &DataType::Float64)? else {
            unreachable!()
        };
        let want = Float64Array::from(vec![125.23, 15.35, 0.12, 0.13, 0.79, 123.12]);
        assert_eq!(as_float64_array(&rounded)?, &want);
        Ok(())
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_round_f32_scalar() -> Result<()> {
        let args = [ColumnarValue::Scalar(ScalarValue::Float32(Some(125.2345))), two_places()];
        let ColumnarValue::Scalar(ScalarValue::Float32(Some(rounded))) =
            spark_round(&args, &DataType::Float32)?
        else {
            unreachable!()
        };
        assert_eq!(rounded, 125.23);
        Ok(())
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_round_f64_scalar() -> Result<()> {
        let args = [ColumnarValue::Scalar(ScalarValue::Float64(Some(125.2345))), two_places()];
        let ColumnarValue::Scalar(ScalarValue::Float64(Some(rounded))) =
            spark_round(&args, &DataType::Float64)?
        else {
            unreachable!()
        };
        assert_eq!(rounded, 125.23);
        Ok(())
    }
}