Skip to main content

datafusion_spark/function/math/
ceil.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use std::sync::Arc;
19
20use arrow::array::{ArrowNativeTypeOp, AsArray, Decimal128Array};
21use arrow::datatypes::{DataType, Decimal128Type, Float32Type, Float64Type, Int64Type};
22use datafusion_common::utils::take_function_args;
23use datafusion_common::{Result, ScalarValue, exec_err};
24use datafusion_expr::{
25    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
26};
27
28/// Spark-compatible `ceil` expression
29/// <https://spark.apache.org/docs/latest/api/sql/index.html#ceil>
30///
31/// Differences with DataFusion ceil:
32///  - Spark's ceil returns Int64 for float inputs; DataFusion preserves
33///    the input type (Float32→Float32, Float64→Float64)
34///  - Spark's ceil on Decimal128(p, s) returns Decimal128(p−s+1, 0), reducing scale
35///    to 0; DataFusion preserves the original precision and scale
36///  - Spark only supports Decimal128; DataFusion also supports Decimal32/64/256
37///  - Spark does not check for decimal overflow; DataFusion errors on overflow
38///
39/// 2-argument ceil(value, scale) is not yet implemented
40/// <https://github.com/apache/datafusion/issues/21560>
41#[derive(Debug, PartialEq, Eq, Hash)]
42pub struct SparkCeil {
43    signature: Signature,
44    aliases: Vec<String>,
45}
46
47impl Default for SparkCeil {
48    fn default() -> Self {
49        Self::new()
50    }
51}
52
53impl SparkCeil {
54    pub fn new() -> Self {
55        Self {
56            signature: Signature::numeric(1, Volatility::Immutable),
57            aliases: vec!["ceiling".to_string()],
58        }
59    }
60}
61
62impl ScalarUDFImpl for SparkCeil {
63    fn name(&self) -> &str {
64        "ceil"
65    }
66
67    fn signature(&self) -> &Signature {
68        &self.signature
69    }
70
71    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
72        match &arg_types[0] {
73            DataType::Decimal128(p, s) => {
74                if *s > 0 {
75                    Ok(DataType::Decimal128(decimal128_ceil_precision(*p, *s), 0))
76                } else {
77                    // scale <= 0 means the value is already a whole number
78                    // (or represents multiples of 10^(-scale)), so ceil is a no-op
79                    Ok(DataType::Decimal128(*p, *s))
80                }
81            }
82            dt if matches!(dt, DataType::Float32 | DataType::Float64)
83                || dt.is_integer() =>
84            {
85                Ok(DataType::Int64)
86            }
87            other => exec_err!("Unsupported data type {other:?} for function ceil"),
88        }
89    }
90
91    fn aliases(&self) -> &[String] {
92        &self.aliases
93    }
94
95    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
96        spark_ceil(&args.args)
97    }
98}
99
100fn spark_ceil(args: &[ColumnarValue]) -> Result<ColumnarValue> {
101    let [input] = take_function_args("ceil", args)?;
102
103    match input {
104        ColumnarValue::Scalar(value) => spark_ceil_scalar(value),
105        ColumnarValue::Array(input) => spark_ceil_array(input),
106    }
107}
108
109/// Compute ceil for a single decimal128 value with the given scale.
110#[inline]
111fn decimal128_ceil(value: i128, scale: u32) -> i128 {
112    let div = 10_i128.pow_wrapping(scale);
113    let d = value / div;
114    let r = value % div;
115    if r > 0 { d + 1 } else { d }
116}
117
/// Result precision of `ceil` on Decimal128(precision, scale): the integer
/// digits (`precision - scale`) plus one extra digit for a possible carry
/// (e.g. 9.5 → 10), clamped to the range 1..=38.
#[inline]
fn decimal128_ceil_precision(precision: u8, scale: i8) -> u8 {
    let integer_digits = i16::from(precision) - i16::from(scale);
    (integer_digits + 1).clamp(1, 38) as u8
}
123
124fn spark_ceil_scalar(value: &ScalarValue) -> Result<ColumnarValue> {
125    let result = match value {
126        ScalarValue::Float32(v) => ScalarValue::Int64(v.map(|x| x.ceil() as i64)),
127        ScalarValue::Float64(v) => ScalarValue::Int64(v.map(|x| x.ceil() as i64)),
128        v if v.data_type().is_integer() => v.cast_to(&DataType::Int64)?,
129        ScalarValue::Decimal128(v, p, s) if *s > 0 => {
130            let new_p = decimal128_ceil_precision(*p, *s);
131            ScalarValue::Decimal128(v.map(|x| decimal128_ceil(x, *s as u32)), new_p, 0)
132        }
133        ScalarValue::Decimal128(_, _, _) => value.clone(),
134        other => {
135            return exec_err!(
136                "Unsupported data type {:?} for function ceil",
137                other.data_type()
138            );
139        }
140    };
141    Ok(ColumnarValue::Scalar(result))
142}
143
144fn spark_ceil_array(input: &Arc<dyn arrow::array::Array>) -> Result<ColumnarValue> {
145    let result = match input.data_type() {
146        DataType::Float32 => Arc::new(
147            input
148                .as_primitive::<Float32Type>()
149                .unary::<_, Int64Type>(|x| x.ceil() as i64),
150        ) as _,
151        DataType::Float64 => Arc::new(
152            input
153                .as_primitive::<Float64Type>()
154                .unary::<_, Int64Type>(|x| x.ceil() as i64),
155        ) as _,
156        dt if dt.is_integer() => arrow::compute::cast(input, &DataType::Int64)?,
157        DataType::Decimal128(p, s) if *s > 0 => {
158            let new_p = decimal128_ceil_precision(*p, *s);
159            let result: Decimal128Array = input
160                .as_primitive::<Decimal128Type>()
161                .unary(|x| decimal128_ceil(x, *s as u32));
162            Arc::new(result.with_data_type(DataType::Decimal128(new_p, 0)))
163        }
164        DataType::Decimal128(_, _) => Arc::clone(input),
165        other => return exec_err!("Unsupported data type {other:?} for function ceil"),
166    };
167
168    Ok(ColumnarValue::Array(result))
169}
170
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{ArrayRef, Decimal128Array, Float32Array, Float64Array, Int64Array};
    use datafusion_common::ScalarValue;

    /// Invokes `spark_ceil` on a single array argument and returns the output array.
    fn ceil_array(input: ArrayRef) -> ArrayRef {
        match spark_ceil(&[ColumnarValue::Array(input)]).unwrap() {
            ColumnarValue::Array(arr) => arr,
            _ => panic!("Expected array"),
        }
    }

    /// Invokes `spark_ceil` on a single scalar argument and returns the output scalar.
    fn ceil_scalar(input: ScalarValue) -> ScalarValue {
        match spark_ceil(&[ColumnarValue::Scalar(input)]).unwrap() {
            ColumnarValue::Scalar(v) => v,
            _ => panic!("Expected scalar"),
        }
    }

    #[test]
    fn test_ceil_float64() {
        let input = Float64Array::from(vec![
            Some(125.2345),
            Some(15.0001),
            Some(0.1),
            Some(-0.9),
            Some(-1.1),
            Some(123.0),
            None,
        ]);
        let result = ceil_array(Arc::new(input));
        assert_eq!(
            result.as_primitive::<Int64Type>(),
            &Int64Array::from(vec![
                Some(126),
                Some(16),
                Some(1),
                Some(0),
                Some(-1),
                Some(123),
                None,
            ])
        );
    }

    #[test]
    fn test_ceil_float32() {
        let input = Float32Array::from(vec![
            Some(125.2345f32),
            Some(15.0001f32),
            Some(0.1f32),
            Some(-0.9f32),
            Some(-1.1f32),
            Some(123.0f32),
            None,
        ]);
        let result = ceil_array(Arc::new(input));
        assert_eq!(
            result.as_primitive::<Int64Type>(),
            &Int64Array::from(vec![
                Some(126),
                Some(16),
                Some(1),
                Some(0),
                Some(-1),
                Some(123),
                None,
            ])
        );
    }

    #[test]
    fn test_ceil_float64_special_values() {
        // Float → Int64 conversion saturates: NaN → 0, ±inf → i64::MAX / i64::MIN.
        let input = Float64Array::from(vec![
            Some(f64::NAN),
            Some(f64::INFINITY),
            Some(f64::NEG_INFINITY),
        ]);
        let result = ceil_array(Arc::new(input));
        assert_eq!(
            result.as_primitive::<Int64Type>(),
            &Int64Array::from(vec![Some(0), Some(i64::MAX), Some(i64::MIN)])
        );
    }

    #[test]
    fn test_ceil_int64() {
        let input = Int64Array::from(vec![Some(1), Some(-1), None]);
        let result = ceil_array(Arc::new(input));
        assert_eq!(
            result.as_primitive::<Int64Type>(),
            &Int64Array::from(vec![Some(1), Some(-1), None])
        );
    }

    #[test]
    fn test_ceil_decimal128() {
        // Decimal128(10, 2): 150 = 1.50, -150 = -1.50, 100 = 1.00
        let return_type = DataType::Decimal128(9, 0);
        let input = Decimal128Array::from(vec![Some(150), Some(-150), Some(100), None])
            .with_data_type(DataType::Decimal128(10, 2));
        let result = ceil_array(Arc::new(input));
        let expected = Decimal128Array::from(vec![Some(2), Some(-1), Some(1), None])
            .with_data_type(return_type);
        assert_eq!(result.as_primitive::<Decimal128Type>(), &expected);
    }

    #[test]
    fn test_ceil_decimal128_scale_zero_is_passthrough() {
        let input = Decimal128Array::from(vec![Some(15), Some(-15), None])
            .with_data_type(DataType::Decimal128(5, 0));
        let result = ceil_array(Arc::new(input.clone()));
        assert_eq!(result.data_type(), &DataType::Decimal128(5, 0));
        assert_eq!(result.as_primitive::<Decimal128Type>(), &input);
    }

    #[test]
    fn test_ceil_float64_scalar() {
        assert_eq!(
            ceil_scalar(ScalarValue::Float64(Some(-1.1))),
            ScalarValue::Int64(Some(-1))
        );
        assert_eq!(ceil_scalar(ScalarValue::Float64(None)), ScalarValue::Int64(None));
    }

    #[test]
    fn test_ceil_float32_scalar() {
        assert_eq!(
            ceil_scalar(ScalarValue::Float32(Some(125.2345f32))),
            ScalarValue::Int64(Some(126))
        );
    }

    #[test]
    fn test_ceil_int64_scalar() {
        assert_eq!(
            ceil_scalar(ScalarValue::Int64(Some(48))),
            ScalarValue::Int64(Some(48))
        );
    }

    #[test]
    fn test_ceil_int32_scalar_widens_to_int64() {
        assert_eq!(
            ceil_scalar(ScalarValue::Int32(Some(-7))),
            ScalarValue::Int64(Some(-7))
        );
    }

    #[test]
    fn test_ceil_decimal128_scalar() {
        assert_eq!(
            ceil_scalar(ScalarValue::Decimal128(Some(-150), 10, 2)),
            ScalarValue::Decimal128(Some(-1), 9, 0)
        );
        assert_eq!(
            ceil_scalar(ScalarValue::Decimal128(Some(151), 10, 2)),
            ScalarValue::Decimal128(Some(2), 9, 0)
        );
        // Maximum scale: 10^38 is still representable as the divisor.
        assert_eq!(
            ceil_scalar(ScalarValue::Decimal128(Some(1), 38, 38)),
            ScalarValue::Decimal128(Some(1), 1, 0)
        );
        assert_eq!(
            ceil_scalar(ScalarValue::Decimal128(None, 10, 2)),
            ScalarValue::Decimal128(None, 9, 0)
        );
        // Scale 0 is returned unchanged.
        assert_eq!(
            ceil_scalar(ScalarValue::Decimal128(Some(42), 5, 0)),
            ScalarValue::Decimal128(Some(42), 5, 0)
        );
    }

    #[test]
    fn test_ceil_return_type() {
        let udf = SparkCeil::new();
        assert_eq!(
            udf.return_type(&[DataType::Decimal128(10, 2)]).unwrap(),
            DataType::Decimal128(9, 0)
        );
        assert_eq!(
            udf.return_type(&[DataType::Decimal128(38, 1)]).unwrap(),
            DataType::Decimal128(38, 0)
        );
        assert_eq!(
            udf.return_type(&[DataType::Decimal128(5, 0)]).unwrap(),
            DataType::Decimal128(5, 0)
        );
        assert_eq!(udf.return_type(&[DataType::Float32]).unwrap(), DataType::Int64);
        assert_eq!(udf.return_type(&[DataType::Float64]).unwrap(), DataType::Int64);
        assert_eq!(udf.return_type(&[DataType::Int32]).unwrap(), DataType::Int64);
        assert!(udf.return_type(&[DataType::Utf8]).is_err());
    }
}