datafusion_spark/function/math/
ceil.rs1use std::sync::Arc;
19
20use arrow::array::{ArrowNativeTypeOp, AsArray, Decimal128Array};
21use arrow::datatypes::{DataType, Decimal128Type, Float32Type, Float64Type, Int64Type};
22use datafusion_common::utils::take_function_args;
23use datafusion_common::{Result, ScalarValue, exec_err};
24use datafusion_expr::{
25 ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
26};
27
28#[derive(Debug, PartialEq, Eq, Hash)]
42pub struct SparkCeil {
43 signature: Signature,
44 aliases: Vec<String>,
45}
46
47impl Default for SparkCeil {
48 fn default() -> Self {
49 Self::new()
50 }
51}
52
53impl SparkCeil {
54 pub fn new() -> Self {
55 Self {
56 signature: Signature::numeric(1, Volatility::Immutable),
57 aliases: vec!["ceiling".to_string()],
58 }
59 }
60}
61
62impl ScalarUDFImpl for SparkCeil {
63 fn name(&self) -> &str {
64 "ceil"
65 }
66
67 fn signature(&self) -> &Signature {
68 &self.signature
69 }
70
71 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
72 match &arg_types[0] {
73 DataType::Decimal128(p, s) => {
74 if *s > 0 {
75 Ok(DataType::Decimal128(decimal128_ceil_precision(*p, *s), 0))
76 } else {
77 Ok(DataType::Decimal128(*p, *s))
80 }
81 }
82 dt if matches!(dt, DataType::Float32 | DataType::Float64)
83 || dt.is_integer() =>
84 {
85 Ok(DataType::Int64)
86 }
87 other => exec_err!("Unsupported data type {other:?} for function ceil"),
88 }
89 }
90
91 fn aliases(&self) -> &[String] {
92 &self.aliases
93 }
94
95 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
96 spark_ceil(&args.args)
97 }
98}
99
100fn spark_ceil(args: &[ColumnarValue]) -> Result<ColumnarValue> {
101 let [input] = take_function_args("ceil", args)?;
102
103 match input {
104 ColumnarValue::Scalar(value) => spark_ceil_scalar(value),
105 ColumnarValue::Array(input) => spark_ceil_array(input),
106 }
107}
108
109#[inline]
111fn decimal128_ceil(value: i128, scale: u32) -> i128 {
112 let div = 10_i128.pow_wrapping(scale);
113 let d = value / div;
114 let r = value % div;
115 if r > 0 { d + 1 } else { d }
116}
117
/// Precision of the scale-0 decimal produced by `ceil` for an input of
/// `Decimal128(precision, scale)`.
///
/// The result is `precision - scale + 1`, kept inside the `1..=38` range.
/// The arithmetic is done in `i64` so the subtraction cannot wrap.
#[inline]
fn decimal128_ceil_precision(precision: u8, scale: i8) -> u8 {
    let widened = i64::from(precision) - i64::from(scale) + 1;
    widened.clamp(1, 38) as u8
}
123
124fn spark_ceil_scalar(value: &ScalarValue) -> Result<ColumnarValue> {
125 let result = match value {
126 ScalarValue::Float32(v) => ScalarValue::Int64(v.map(|x| x.ceil() as i64)),
127 ScalarValue::Float64(v) => ScalarValue::Int64(v.map(|x| x.ceil() as i64)),
128 v if v.data_type().is_integer() => v.cast_to(&DataType::Int64)?,
129 ScalarValue::Decimal128(v, p, s) if *s > 0 => {
130 let new_p = decimal128_ceil_precision(*p, *s);
131 ScalarValue::Decimal128(v.map(|x| decimal128_ceil(x, *s as u32)), new_p, 0)
132 }
133 ScalarValue::Decimal128(_, _, _) => value.clone(),
134 other => {
135 return exec_err!(
136 "Unsupported data type {:?} for function ceil",
137 other.data_type()
138 );
139 }
140 };
141 Ok(ColumnarValue::Scalar(result))
142}
143
144fn spark_ceil_array(input: &Arc<dyn arrow::array::Array>) -> Result<ColumnarValue> {
145 let result = match input.data_type() {
146 DataType::Float32 => Arc::new(
147 input
148 .as_primitive::<Float32Type>()
149 .unary::<_, Int64Type>(|x| x.ceil() as i64),
150 ) as _,
151 DataType::Float64 => Arc::new(
152 input
153 .as_primitive::<Float64Type>()
154 .unary::<_, Int64Type>(|x| x.ceil() as i64),
155 ) as _,
156 dt if dt.is_integer() => arrow::compute::cast(input, &DataType::Int64)?,
157 DataType::Decimal128(p, s) if *s > 0 => {
158 let new_p = decimal128_ceil_precision(*p, *s);
159 let result: Decimal128Array = input
160 .as_primitive::<Decimal128Type>()
161 .unary(|x| decimal128_ceil(x, *s as u32));
162 Arc::new(result.with_data_type(DataType::Decimal128(new_p, 0)))
163 }
164 DataType::Decimal128(_, _) => Arc::clone(input),
165 other => return exec_err!("Unsupported data type {other:?} for function ceil"),
166 };
167
168 Ok(ColumnarValue::Array(result))
169}
170
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Decimal128Array, Float32Array, Float64Array, Int64Array};
    use datafusion_common::ScalarValue;

    /// Runs `spark_ceil` on one array argument and returns the output array,
    /// panicking if the result is not an array.
    fn ceil_of_array(
        input: impl arrow::array::Array + 'static,
    ) -> Arc<dyn arrow::array::Array> {
        let args = vec![ColumnarValue::Array(Arc::new(input))];
        match spark_ceil(&args).unwrap() {
            ColumnarValue::Array(array) => array,
            _ => panic!("Expected array"),
        }
    }

    /// Runs `spark_ceil` on one scalar argument and returns the output
    /// scalar, panicking if the result is not a scalar.
    fn ceil_of_scalar(input: ScalarValue) -> ScalarValue {
        let args = vec![ColumnarValue::Scalar(input)];
        match spark_ceil(&args).unwrap() {
            ColumnarValue::Scalar(scalar) => scalar,
            _ => panic!("Expected scalar"),
        }
    }

    #[test]
    fn test_ceil_float64() {
        let output = ceil_of_array(Float64Array::from(vec![
            Some(125.2345),
            Some(15.0001),
            Some(0.1),
            Some(-0.9),
            Some(-1.1),
            Some(123.0),
            None,
        ]));
        let expected = Int64Array::from(vec![
            Some(126),
            Some(16),
            Some(1),
            Some(0),
            Some(-1),
            Some(123),
            None,
        ]);
        assert_eq!(output.as_primitive::<Int64Type>(), &expected);
    }

    #[test]
    fn test_ceil_float32() {
        let output = ceil_of_array(Float32Array::from(vec![
            Some(125.2345f32),
            Some(15.0001f32),
            Some(0.1f32),
            Some(-0.9f32),
            Some(-1.1f32),
            Some(123.0f32),
            None,
        ]));
        let expected = Int64Array::from(vec![
            Some(126),
            Some(16),
            Some(1),
            Some(0),
            Some(-1),
            Some(123),
            None,
        ]);
        assert_eq!(output.as_primitive::<Int64Type>(), &expected);
    }

    #[test]
    fn test_ceil_int64() {
        let output = ceil_of_array(Int64Array::from(vec![Some(1), Some(-1), None]));
        let expected = Int64Array::from(vec![Some(1), Some(-1), None]);
        assert_eq!(output.as_primitive::<Int64Type>(), &expected);
    }

    #[test]
    fn test_ceil_decimal128() {
        // Unscaled 150 / -150 / 100 at scale 2 are 1.50 / -1.50 / 1.00.
        let input = Decimal128Array::from(vec![Some(150), Some(-150), Some(100), None])
            .with_data_type(DataType::Decimal128(10, 2));
        let output = ceil_of_array(input);
        let expected = Decimal128Array::from(vec![Some(2), Some(-1), Some(1), None])
            .with_data_type(DataType::Decimal128(9, 0));
        assert_eq!(output.as_primitive::<Decimal128Type>(), &expected);
    }

    #[test]
    fn test_ceil_float64_scalar() {
        let output = ceil_of_scalar(ScalarValue::Float64(Some(-1.1)));
        assert_eq!(output, ScalarValue::Int64(Some(-1)));
    }

    #[test]
    fn test_ceil_float32_scalar() {
        let output = ceil_of_scalar(ScalarValue::Float32(Some(125.2345f32)));
        assert_eq!(output, ScalarValue::Int64(Some(126)));
    }

    #[test]
    fn test_ceil_int64_scalar() {
        let output = ceil_of_scalar(ScalarValue::Int64(Some(48)));
        assert_eq!(output, ScalarValue::Int64(Some(48)));
    }
}