datafusion_comet_spark_expr/math_funcs/
floor.rs1use crate::downcast_compute_op;
19use crate::math_funcs::utils::{get_precision_scale, make_decimal_array, make_decimal_scalar};
20use arrow::array::{Array, ArrowNativeTypeOp};
21use arrow::array::{Float32Array, Float64Array, Int64Array};
22use arrow::datatypes::DataType;
23use datafusion::common::{DataFusionError, ScalarValue};
24use datafusion::physical_plan::ColumnarValue;
25use num::integer::div_floor;
26use std::sync::Arc;
27
28pub fn spark_floor(
30 args: &[ColumnarValue],
31 data_type: &DataType,
32) -> Result<ColumnarValue, DataFusionError> {
33 let value = &args[0];
34 match value {
35 ColumnarValue::Array(array) => match array.data_type() {
36 DataType::Float32 => {
37 let result = downcast_compute_op!(array, "floor", floor, Float32Array, Int64Array);
38 Ok(ColumnarValue::Array(result?))
39 }
40 DataType::Float64 => {
41 let result = downcast_compute_op!(array, "floor", floor, Float64Array, Int64Array);
42 Ok(ColumnarValue::Array(result?))
43 }
44 DataType::Int64 => {
45 let result = array.as_any().downcast_ref::<Int64Array>().unwrap();
46 Ok(ColumnarValue::Array(Arc::new(result.clone())))
47 }
48 DataType::Decimal128(_, scale) if *scale > 0 => {
49 let f = decimal_floor_f(scale);
50 let (precision, scale) = get_precision_scale(data_type);
51 make_decimal_array(array, precision, scale, &f)
52 }
53 other => Err(DataFusionError::Internal(format!(
54 "Unsupported data type {other:?} for function floor",
55 ))),
56 },
57 ColumnarValue::Scalar(a) => match a {
58 ScalarValue::Float32(a) => Ok(ColumnarValue::Scalar(ScalarValue::Int64(
59 a.map(|x| x.floor() as i64),
60 ))),
61 ScalarValue::Float64(a) => Ok(ColumnarValue::Scalar(ScalarValue::Int64(
62 a.map(|x| x.floor() as i64),
63 ))),
64 ScalarValue::Int64(a) => Ok(ColumnarValue::Scalar(ScalarValue::Int64(a.map(|x| x)))),
65 ScalarValue::Decimal128(a, _, scale) if *scale > 0 => {
66 let f = decimal_floor_f(scale);
67 let (precision, scale) = get_precision_scale(data_type);
68 make_decimal_scalar(a, precision, scale, &f)
69 }
70 _ => Err(DataFusionError::Internal(format!(
71 "Unsupported data type {:?} for function floor",
72 value.data_type(),
73 ))),
74 },
75 }
76}
77
78#[inline]
79fn decimal_floor_f(scale: &i8) -> impl Fn(i128) -> i128 {
80 let div = 10_i128.pow_wrapping(*scale as u32);
81 move |x: i128| div_floor(x, div)
82}
83
#[cfg(test)]
mod test {
    use crate::spark_floor;
    use arrow::array::{Decimal128Array, Float32Array, Float64Array, Int64Array};
    use arrow::datatypes::DataType;
    use datafusion::common::cast::as_int64_array;
    use datafusion::common::{Result, ScalarValue};
    use datafusion::physical_plan::ColumnarValue;
    use std::sync::Arc;

    /// `Float32` arrays are floored into an `Int64` array; nulls stay null.
    #[test]
    fn test_floor_f32_array() -> Result<()> {
        let input = Float32Array::from(vec![
            Some(125.9345),
            Some(15.9999),
            Some(0.9),
            Some(-0.1),
            Some(-1.999),
            Some(123.0),
            None,
        ]);
        let args = vec![ColumnarValue::Array(Arc::new(input))];
        let ColumnarValue::Array(result) = spark_floor(&args, &DataType::Float32)? else {
            unreachable!()
        };
        let actual = as_int64_array(&result)?;
        let expected = Int64Array::from(vec![
            Some(125),
            Some(15),
            Some(0),
            Some(-1),
            Some(-2),
            Some(123),
            None,
        ]);
        assert_eq!(actual, &expected);
        Ok(())
    }

    /// `Float64` arrays are floored into an `Int64` array; nulls stay null.
    #[test]
    fn test_floor_f64_array() -> Result<()> {
        let input = Float64Array::from(vec![
            Some(125.9345),
            Some(15.9999),
            Some(0.9),
            Some(-0.1),
            Some(-1.999),
            Some(123.0),
            None,
        ]);
        let args = vec![ColumnarValue::Array(Arc::new(input))];
        let ColumnarValue::Array(result) = spark_floor(&args, &DataType::Float64)? else {
            unreachable!()
        };
        let actual = as_int64_array(&result)?;
        let expected = Int64Array::from(vec![
            Some(125),
            Some(15),
            Some(0),
            Some(-1),
            Some(-2),
            Some(123),
            None,
        ]);
        assert_eq!(actual, &expected);
        Ok(())
    }

    /// `Int64` arrays pass through unchanged.
    #[test]
    fn test_floor_i64_array() -> Result<()> {
        let input = Int64Array::from(vec![Some(-1), Some(0), Some(1), None]);
        let args = vec![ColumnarValue::Array(Arc::new(input))];
        let ColumnarValue::Array(result) = spark_floor(&args, &DataType::Int64)? else {
            unreachable!()
        };
        let actual = as_int64_array(&result)?;
        let expected = Int64Array::from(vec![Some(-1), Some(0), Some(1), None]);
        assert_eq!(actual, &expected);
        Ok(())
    }

    /// Decimal arrays: the unscaled value is floor-divided by `10^scale`.
    ///
    /// The divisor logic (`decimal_floor_f`) never multiplies back, so the
    /// result is a whole number that only reads correctly with a scale-0
    /// return type. Spark types `floor(decimal(p, s))` as
    /// `decimal(p - s + 1, 0)`, which is what this test passes.
    ///
    /// NOTE(review): this test used to be `#[ignore]`d while expecting values
    /// at the input scale (e.g. 12300). Re-enabling it assumes
    /// `make_decimal_array` applies the closure element-wise and only tags the
    /// result with the given precision/scale — confirm against
    /// `math_funcs::utils`.
    #[test]
    fn test_floor_decimal128_array() -> Result<()> {
        let array = Decimal128Array::from(vec![
            Some(12345),  // 123.45
            Some(12500),  // 125.00
            Some(-12999), // -129.99
            None,
        ])
        .with_precision_and_scale(5, 2)?;
        let args = vec![ColumnarValue::Array(Arc::new(array))];
        // floor(decimal(5, 2)) -> decimal(5 - 2 + 1, 0)
        let return_type = DataType::Decimal128(4, 0);
        let ColumnarValue::Array(result) = spark_floor(&args, &return_type)? else {
            unreachable!()
        };
        let expected = Decimal128Array::from(vec![
            Some(123),  // floor(123.45)
            Some(125),  // floor(125.00)
            Some(-130), // floor(-129.99)
            None,
        ])
        .with_precision_and_scale(4, 0)?;
        let actual = result.as_any().downcast_ref::<Decimal128Array>().unwrap();
        assert_eq!(actual, &expected);
        Ok(())
    }

    /// A `Float32` scalar is floored into an `Int64` scalar.
    #[test]
    fn test_floor_f32_scalar() -> Result<()> {
        let args = vec![ColumnarValue::Scalar(ScalarValue::Float32(Some(125.9345)))];
        let ColumnarValue::Scalar(ScalarValue::Int64(Some(result))) =
            spark_floor(&args, &DataType::Float32)?
        else {
            unreachable!()
        };
        assert_eq!(result, 125);
        Ok(())
    }

    /// A negative `Float64` scalar rounds toward negative infinity.
    #[test]
    fn test_floor_f64_scalar() -> Result<()> {
        let args = vec![ColumnarValue::Scalar(ScalarValue::Float64(Some(-1.999)))];
        let ColumnarValue::Scalar(ScalarValue::Int64(Some(result))) =
            spark_floor(&args, &DataType::Float64)?
        else {
            unreachable!()
        };
        assert_eq!(result, -2);
        Ok(())
    }

    /// An `Int64` scalar passes through unchanged.
    #[test]
    fn test_floor_i64_scalar() -> Result<()> {
        let args = vec![ColumnarValue::Scalar(ScalarValue::Int64(Some(48)))];
        let ColumnarValue::Scalar(ScalarValue::Int64(Some(result))) =
            spark_floor(&args, &DataType::Int64)?
        else {
            unreachable!()
        };
        assert_eq!(result, 48);
        Ok(())
    }

    /// Decimal scalar: 56.7 as `decimal(3, 1)` floors to 56 as `decimal(3, 0)`.
    ///
    /// NOTE(review): previously `#[ignore]`d while expecting 560 at the input
    /// scale. Re-enabling it assumes `make_decimal_scalar` applies the closure
    /// and tags the result with the given precision/scale — confirm against
    /// `math_funcs::utils`.
    #[test]
    fn test_floor_decimal128_scalar() -> Result<()> {
        let args = vec![ColumnarValue::Scalar(ScalarValue::Decimal128(
            Some(567),
            3,
            1,
        ))]; // 56.7
        // floor(decimal(3, 1)) -> decimal(3 - 1 + 1, 0)
        let ColumnarValue::Scalar(ScalarValue::Decimal128(Some(result), 3, 0)) =
            spark_floor(&args, &DataType::Decimal128(3, 0))?
        else {
            unreachable!()
        };
        assert_eq!(result, 56); // floor(56.7)
        Ok(())
    }
}