datafusion_comet_spark_expr/math_funcs/floor.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::downcast_compute_op;
19use crate::math_funcs::utils::{get_precision_scale, make_decimal_array, make_decimal_scalar};
20use arrow::array::{Array, ArrowNativeTypeOp};
21use arrow::array::{Float32Array, Float64Array, Int64Array};
22use arrow::datatypes::DataType;
23use datafusion::common::{DataFusionError, ScalarValue};
24use datafusion::physical_plan::ColumnarValue;
25use num::integer::div_floor;
26use std::sync::Arc;
27
28/// `floor` function that simulates Spark `floor` expression
29pub fn spark_floor(
30    args: &[ColumnarValue],
31    data_type: &DataType,
32) -> Result<ColumnarValue, DataFusionError> {
33    let value = &args[0];
34    match value {
35        ColumnarValue::Array(array) => match array.data_type() {
36            DataType::Float32 => {
37                let result = downcast_compute_op!(array, "floor", floor, Float32Array, Int64Array);
38                Ok(ColumnarValue::Array(result?))
39            }
40            DataType::Float64 => {
41                let result = downcast_compute_op!(array, "floor", floor, Float64Array, Int64Array);
42                Ok(ColumnarValue::Array(result?))
43            }
44            DataType::Int64 => {
45                let result = array.as_any().downcast_ref::<Int64Array>().unwrap();
46                Ok(ColumnarValue::Array(Arc::new(result.clone())))
47            }
48            DataType::Decimal128(_, scale) if *scale > 0 => {
49                let f = decimal_floor_f(scale);
50                let (precision, scale) = get_precision_scale(data_type);
51                make_decimal_array(array, precision, scale, &f)
52            }
53            other => Err(DataFusionError::Internal(format!(
54                "Unsupported data type {other:?} for function floor",
55            ))),
56        },
57        ColumnarValue::Scalar(a) => match a {
58            ScalarValue::Float32(a) => Ok(ColumnarValue::Scalar(ScalarValue::Int64(
59                a.map(|x| x.floor() as i64),
60            ))),
61            ScalarValue::Float64(a) => Ok(ColumnarValue::Scalar(ScalarValue::Int64(
62                a.map(|x| x.floor() as i64),
63            ))),
64            ScalarValue::Int64(a) => Ok(ColumnarValue::Scalar(ScalarValue::Int64(a.map(|x| x)))),
65            ScalarValue::Decimal128(a, _, scale) if *scale > 0 => {
66                let f = decimal_floor_f(scale);
67                let (precision, scale) = get_precision_scale(data_type);
68                make_decimal_scalar(a, precision, scale, &f)
69            }
70            _ => Err(DataFusionError::Internal(format!(
71                "Unsupported data type {:?} for function floor",
72                value.data_type(),
73            ))),
74        },
75    }
76}
77
78#[inline]
79fn decimal_floor_f(scale: &i8) -> impl Fn(i128) -> i128 {
80    let div = 10_i128.pow_wrapping(*scale as u32);
81    move |x: i128| div_floor(x, div)
82}
83
#[cfg(test)]
mod test {
    use crate::spark_floor;
    use arrow::array::{Decimal128Array, Float32Array, Float64Array, Int64Array};
    use arrow::datatypes::DataType;
    use datafusion::common::cast::as_int64_array;
    use datafusion::common::{Result, ScalarValue};
    use datafusion::physical_plan::ColumnarValue;
    use std::sync::Arc;

    #[test]
    fn test_floor_f32_array() -> Result<()> {
        let input = Float32Array::from(vec![
            Some(125.9345),
            Some(15.9999),
            Some(0.9),
            Some(-0.1),
            Some(-1.999),
            Some(123.0),
            None,
        ]);
        let args = vec![ColumnarValue::Array(Arc::new(input))];
        let ColumnarValue::Array(result) = spark_floor(&args, &DataType::Float32)? else {
            unreachable!()
        };
        let actual = as_int64_array(&result)?;
        let expected = Int64Array::from(vec![
            Some(125),
            Some(15),
            Some(0),
            Some(-1),
            Some(-2),
            Some(123),
            None,
        ]);
        assert_eq!(actual, &expected);
        Ok(())
    }

    #[test]
    fn test_floor_f64_array() -> Result<()> {
        let input = Float64Array::from(vec![
            Some(125.9345),
            Some(15.9999),
            Some(0.9),
            Some(-0.1),
            Some(-1.999),
            Some(123.0),
            None,
        ]);
        let args = vec![ColumnarValue::Array(Arc::new(input))];
        let ColumnarValue::Array(result) = spark_floor(&args, &DataType::Float64)? else {
            unreachable!()
        };
        let actual = as_int64_array(&result)?;
        let expected = Int64Array::from(vec![
            Some(125),
            Some(15),
            Some(0),
            Some(-1),
            Some(-2),
            Some(123),
            None,
        ]);
        assert_eq!(actual, &expected);
        Ok(())
    }

    #[test]
    fn test_floor_i64_array() -> Result<()> {
        let input = Int64Array::from(vec![Some(-1), Some(0), Some(1), None]);
        let args = vec![ColumnarValue::Array(Arc::new(input))];
        let ColumnarValue::Array(result) = spark_floor(&args, &DataType::Int64)? else {
            unreachable!()
        };
        let actual = as_int64_array(&result)?;
        let expected = Int64Array::from(vec![Some(-1), Some(0), Some(1), None]);
        assert_eq!(actual, &expected);
        Ok(())
    }

    // Spark types `floor(decimal(p, s))` as `decimal(p - s + 1, 0)`, and the caller
    // passes that result type as `data_type`. For decimal(5, 2) the result type is
    // decimal(4, 0), so the expected values have scale 0.
    // https://github.com/apache/datafusion-comet/issues/1729
    #[test]
    fn test_floor_decimal128_array() -> Result<()> {
        let array = Decimal128Array::from(vec![
            Some(12345),  // 123.45
            Some(12500),  // 125.00
            Some(-12999), // -129.99
            None,
        ])
        .with_precision_and_scale(5, 2)?;
        let args = vec![ColumnarValue::Array(Arc::new(array))];
        let ColumnarValue::Array(result) = spark_floor(&args, &DataType::Decimal128(4, 0))? else {
            unreachable!()
        };
        let expected = Decimal128Array::from(vec![
            Some(123),  // 123
            Some(125),  // 125
            Some(-130), // -130
            None,
        ])
        .with_precision_and_scale(4, 0)?;
        let actual = result.as_any().downcast_ref::<Decimal128Array>().unwrap();
        assert_eq!(actual, &expected);
        Ok(())
    }

    #[test]
    fn test_floor_f32_scalar() -> Result<()> {
        let args = vec![ColumnarValue::Scalar(ScalarValue::Float32(Some(125.9345)))];
        let ColumnarValue::Scalar(ScalarValue::Int64(Some(result))) =
            spark_floor(&args, &DataType::Float32)?
        else {
            unreachable!()
        };
        assert_eq!(result, 125);
        Ok(())
    }

    #[test]
    fn test_floor_f64_scalar() -> Result<()> {
        let args = vec![ColumnarValue::Scalar(ScalarValue::Float64(Some(-1.999)))];
        let ColumnarValue::Scalar(ScalarValue::Int64(Some(result))) =
            spark_floor(&args, &DataType::Float64)?
        else {
            unreachable!()
        };
        assert_eq!(result, -2);
        Ok(())
    }

    #[test]
    fn test_floor_null_scalar() -> Result<()> {
        let args = vec![ColumnarValue::Scalar(ScalarValue::Float64(None))];
        let ColumnarValue::Scalar(ScalarValue::Int64(result)) =
            spark_floor(&args, &DataType::Float64)?
        else {
            unreachable!()
        };
        assert_eq!(result, None);
        Ok(())
    }

    #[test]
    fn test_floor_i64_scalar() -> Result<()> {
        let args = vec![ColumnarValue::Scalar(ScalarValue::Int64(Some(48)))];
        let ColumnarValue::Scalar(ScalarValue::Int64(Some(result))) =
            spark_floor(&args, &DataType::Int64)?
        else {
            unreachable!()
        };
        assert_eq!(result, 48);
        Ok(())
    }

    // decimal(3, 1) floors to decimal(3 - 1 + 1, 0) = decimal(3, 0).
    // https://github.com/apache/datafusion-comet/issues/1729
    #[test]
    fn test_floor_decimal128_scalar() -> Result<()> {
        let args = vec![ColumnarValue::Scalar(ScalarValue::Decimal128(
            Some(567),
            3,
            1,
        ))]; // 56.7
        let ColumnarValue::Scalar(ScalarValue::Decimal128(Some(result), 3, 0)) =
            spark_floor(&args, &DataType::Decimal128(3, 0))?
        else {
            unreachable!()
        };
        assert_eq!(result, 56); // 56
        Ok(())
    }

    #[test]
    fn test_floor_negative_decimal128_scalar() -> Result<()> {
        let args = vec![ColumnarValue::Scalar(ScalarValue::Decimal128(
            Some(-567),
            3,
            1,
        ))]; // -56.7
        let ColumnarValue::Scalar(ScalarValue::Decimal128(Some(result), 3, 0)) =
            spark_floor(&args, &DataType::Decimal128(3, 0))?
        else {
            unreachable!()
        };
        assert_eq!(result, -57); // rounds toward negative infinity
        Ok(())
    }
}