1use arrow::array::*;
19use arrow::datatypes::{DataType, Field, FieldRef};
20use datafusion_common::{DataFusionError, Result, ScalarValue, internal_err};
21use datafusion_expr::{
22 ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature,
23 Volatility,
24};
25use datafusion_functions::{
26 downcast_named_arg, make_abs_function, make_wrapping_abs_function,
27};
28use std::any::Any;
29use std::sync::Arc;
30
31#[derive(Debug, PartialEq, Eq, Hash)]
42pub struct SparkAbs {
43 signature: Signature,
44}
45
46impl Default for SparkAbs {
47 fn default() -> Self {
48 Self::new()
49 }
50}
51
52impl SparkAbs {
53 pub fn new() -> Self {
54 Self {
55 signature: Signature::numeric(1, Volatility::Immutable),
56 }
57 }
58}
59
60impl ScalarUDFImpl for SparkAbs {
61 fn as_any(&self) -> &dyn Any {
62 self
63 }
64
65 fn name(&self) -> &str {
66 "abs"
67 }
68
69 fn signature(&self) -> &Signature {
70 &self.signature
71 }
72
73 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
74 internal_err!(
75 "SparkAbs: return_type() is not used; return_field_from_args() is implemented"
76 )
77 }
78
79 fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
80 let input_field = &args.arg_fields[0];
81 let out_dt = input_field.data_type().clone();
82 let out_nullable = input_field.is_nullable();
83
84 Ok(Arc::new(Field::new(self.name(), out_dt, out_nullable)))
85 }
86
87 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
88 spark_abs(&args.args)
89 }
90}
91
/// Wraps the wrapping absolute value of `$value` in
/// `Ok(ColumnarValue::Scalar(ScalarValue::$variant(..)))`.
///
/// The two-argument form is for plain integer variants. The four-argument
/// form is for decimal variants and passes `$precision` and `$scale` through
/// unchanged.
macro_rules! scalar_compute_op {
    ($value:ident, $variant:ident) => {{
        Ok(ColumnarValue::Scalar(ScalarValue::$variant(Some(
            $value.wrapping_abs(),
        ))))
    }};
    ($value:ident, $precision:expr, $scale:expr, $variant:ident) => {{
        let abs_value = $value.wrapping_abs();
        let scalar = ScalarValue::$variant(Some(abs_value), $precision, $scale);
        Ok(ColumnarValue::Scalar(scalar))
    }};
}
108
109pub fn spark_abs(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
110 if args.len() != 1 {
111 return internal_err!("abs takes exactly 1 argument, but got: {}", args.len());
112 }
113
114 match &args[0] {
115 ColumnarValue::Array(array) => match array.data_type() {
116 DataType::Null
117 | DataType::UInt8
118 | DataType::UInt16
119 | DataType::UInt32
120 | DataType::UInt64 => Ok(args[0].clone()),
121 DataType::Int8 => {
122 let abs_fun = make_wrapping_abs_function!(Int8Array);
123 abs_fun(array).map(ColumnarValue::Array)
124 }
125 DataType::Int16 => {
126 let abs_fun = make_wrapping_abs_function!(Int16Array);
127 abs_fun(array).map(ColumnarValue::Array)
128 }
129 DataType::Int32 => {
130 let abs_fun = make_wrapping_abs_function!(Int32Array);
131 abs_fun(array).map(ColumnarValue::Array)
132 }
133 DataType::Int64 => {
134 let abs_fun = make_wrapping_abs_function!(Int64Array);
135 abs_fun(array).map(ColumnarValue::Array)
136 }
137 DataType::Float32 => {
138 let abs_fun = make_abs_function!(Float32Array);
139 abs_fun(array).map(ColumnarValue::Array)
140 }
141 DataType::Float64 => {
142 let abs_fun = make_abs_function!(Float64Array);
143 abs_fun(array).map(ColumnarValue::Array)
144 }
145 DataType::Decimal128(_, _) => {
146 let abs_fun = make_wrapping_abs_function!(Decimal128Array);
147 abs_fun(array).map(ColumnarValue::Array)
148 }
149 DataType::Decimal256(_, _) => {
150 let abs_fun = make_wrapping_abs_function!(Decimal256Array);
151 abs_fun(array).map(ColumnarValue::Array)
152 }
153 dt => internal_err!("Not supported datatype for Spark ABS: {dt}"),
154 },
155 ColumnarValue::Scalar(sv) => match sv {
156 ScalarValue::Null
157 | ScalarValue::UInt8(_)
158 | ScalarValue::UInt16(_)
159 | ScalarValue::UInt32(_)
160 | ScalarValue::UInt64(_) => Ok(args[0].clone()),
161 sv if sv.is_null() => Ok(args[0].clone()),
162 ScalarValue::Int8(Some(v)) => scalar_compute_op!(v, Int8),
163 ScalarValue::Int16(Some(v)) => scalar_compute_op!(v, Int16),
164 ScalarValue::Int32(Some(v)) => scalar_compute_op!(v, Int32),
165 ScalarValue::Int64(Some(v)) => scalar_compute_op!(v, Int64),
166 ScalarValue::Float32(Some(v)) => {
167 Ok(ColumnarValue::Scalar(ScalarValue::Float32(Some(v.abs()))))
168 }
169 ScalarValue::Float64(Some(v)) => {
170 Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(v.abs()))))
171 }
172 ScalarValue::Decimal128(Some(v), precision, scale) => {
173 scalar_compute_op!(v, *precision, *scale, Decimal128)
174 }
175 ScalarValue::Decimal256(Some(v), precision, scale) => {
176 scalar_compute_op!(v, *precision, *scale, Decimal256)
177 }
178 dt => internal_err!("Not supported datatype for Spark ABS: {dt}"),
179 },
180 }
181}
182
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::i256;

    /// Runs `spark_abs` on a scalar and checks the result. Forms:
    /// `(TYPE, VAL)` expects `VAL` back, `(TYPE, VAL, RESULT)` expects `RESULT`,
    /// and the decimal forms additionally check that precision and scale are
    /// preserved.
    macro_rules! eval_legacy_mode {
        ($TYPE:ident, $VAL:expr) => {{
            let args = ColumnarValue::Scalar(ScalarValue::$TYPE(Some($VAL)));
            match spark_abs(&[args]) {
                Ok(ColumnarValue::Scalar(ScalarValue::$TYPE(Some(result)))) => {
                    assert_eq!(result, $VAL);
                }
                _ => unreachable!(),
            }
        }};
        ($TYPE:ident, $VAL:expr, $RESULT:expr) => {{
            let args = ColumnarValue::Scalar(ScalarValue::$TYPE(Some($VAL)));
            match spark_abs(&[args]) {
                Ok(ColumnarValue::Scalar(ScalarValue::$TYPE(Some(result)))) => {
                    assert_eq!(result, $RESULT);
                }
                _ => unreachable!(),
            }
        }};
        ($TYPE:ident, $VAL:expr, $PRECISION:expr, $SCALE:expr) => {{
            let args =
                ColumnarValue::Scalar(ScalarValue::$TYPE(Some($VAL), $PRECISION, $SCALE));
            match spark_abs(&[args]) {
                Ok(ColumnarValue::Scalar(ScalarValue::$TYPE(
                    Some(result),
                    precision,
                    scale,
                ))) => {
                    assert_eq!(result, $VAL);
                    assert_eq!(precision, $PRECISION);
                    assert_eq!(scale, $SCALE);
                }
                _ => unreachable!(),
            }
        }};
        ($TYPE:ident, $VAL:expr, $PRECISION:expr, $SCALE:expr, $RESULT:expr) => {{
            let args =
                ColumnarValue::Scalar(ScalarValue::$TYPE(Some($VAL), $PRECISION, $SCALE));
            match spark_abs(&[args]) {
                Ok(ColumnarValue::Scalar(ScalarValue::$TYPE(
                    Some(result),
                    precision,
                    scale,
                ))) => {
                    assert_eq!(result, $RESULT);
                    assert_eq!(precision, $PRECISION);
                    assert_eq!(scale, $SCALE);
                }
                _ => unreachable!(),
            }
        }};
    }

    #[test]
    fn test_abs_scalar_legacy_mode() {
        eval_legacy_mode!(UInt8, u8::MIN);
        eval_legacy_mode!(UInt16, u16::MIN);
        eval_legacy_mode!(UInt32, u32::MIN);
        eval_legacy_mode!(UInt64, u64::MIN);
        eval_legacy_mode!(Int8, i8::MIN);
        eval_legacy_mode!(Int16, i16::MIN);
        eval_legacy_mode!(Int32, i32::MIN);
        eval_legacy_mode!(Int64, i64::MIN);
        eval_legacy_mode!(Float32, f32::MIN, f32::MAX);
        eval_legacy_mode!(Float64, f64::MIN, f64::MAX);
        eval_legacy_mode!(Decimal128, i128::MIN, 18, 10);
        eval_legacy_mode!(Decimal256, i256::MIN, 10, 2);

        eval_legacy_mode!(Int8, -1i8, 1i8);
        eval_legacy_mode!(Int16, -1i16, 1i16);
        eval_legacy_mode!(Int32, -1i32, 1i32);
        eval_legacy_mode!(Int64, -1i64, 1i64);
        eval_legacy_mode!(Decimal128, -1i128, 18, 10, 1i128);
        eval_legacy_mode!(Decimal256, i256::from(-1i8), 10, 2, i256::from(1i8));

        eval_legacy_mode!(Float32, f32::NEG_INFINITY, f32::INFINITY);
        eval_legacy_mode!(Float32, f32::INFINITY, f32::INFINITY);
        eval_legacy_mode!(Float32, 0.0f32, 0.0f32);
        eval_legacy_mode!(Float32, -0.0f32, 0.0f32);
        eval_legacy_mode!(Float64, f64::NEG_INFINITY, f64::INFINITY);
        eval_legacy_mode!(Float64, f64::INFINITY, f64::INFINITY);
        eval_legacy_mode!(Float64, 0.0f64, 0.0f64);
        eval_legacy_mode!(Float64, -0.0f64, 0.0f64);
    }

    #[test]
    fn test_abs_scalar_signed_zero_and_nan() {
        // `-0.0 == 0.0` holds, so `eval_legacy_mode!` cannot tell whether the
        // sign bit was cleared; compare raw bits instead.
        match spark_abs(&[ColumnarValue::Scalar(ScalarValue::Float32(Some(-0.0)))]) {
            Ok(ColumnarValue::Scalar(ScalarValue::Float32(Some(result)))) => {
                assert_eq!(result.to_bits(), 0.0f32.to_bits());
            }
            _ => unreachable!(),
        }
        match spark_abs(&[ColumnarValue::Scalar(ScalarValue::Float64(Some(-0.0)))]) {
            Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(result)))) => {
                assert_eq!(result.to_bits(), 0.0f64.to_bits());
            }
            _ => unreachable!(),
        }

        // NaN != NaN, so `eval_legacy_mode!` cannot be used here either.
        match spark_abs(&[ColumnarValue::Scalar(ScalarValue::Float32(Some(f32::NAN)))]) {
            Ok(ColumnarValue::Scalar(ScalarValue::Float32(Some(result)))) => {
                assert!(result.is_nan());
            }
            _ => unreachable!(),
        }
        match spark_abs(&[ColumnarValue::Scalar(ScalarValue::Float64(Some(f64::NAN)))]) {
            Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(result)))) => {
                assert!(result.is_nan());
            }
            _ => unreachable!(),
        }
    }

    #[test]
    fn test_abs_scalar_null() {
        match spark_abs(&[ColumnarValue::Scalar(ScalarValue::Int32(None))]) {
            Ok(ColumnarValue::Scalar(ScalarValue::Int32(None))) => {}
            _ => unreachable!(),
        }
        match spark_abs(&[ColumnarValue::Scalar(ScalarValue::Float64(None))]) {
            Ok(ColumnarValue::Scalar(ScalarValue::Float64(None))) => {}
            _ => unreachable!(),
        }
        match spark_abs(&[ColumnarValue::Scalar(ScalarValue::Null)]) {
            Ok(ColumnarValue::Scalar(ScalarValue::Null)) => {}
            _ => unreachable!(),
        }
    }

    #[test]
    fn test_abs_invalid_arguments() {
        assert!(spark_abs(&[]).is_err());

        let minus_one = ColumnarValue::Scalar(ScalarValue::Int32(Some(-1)));
        assert!(spark_abs(&[minus_one.clone(), minus_one]).is_err());

        let utf8_scalar = ColumnarValue::Scalar(ScalarValue::Utf8(Some("a".to_string())));
        assert!(spark_abs(&[utf8_scalar]).is_err());

        let utf8_array = ColumnarValue::Array(Arc::new(StringArray::from(vec!["a"])));
        assert!(spark_abs(&[utf8_array]).is_err());
    }

    /// Runs `spark_abs` on an array and compares it (values, nulls and data
    /// type) against `$OUTPUT`, downcasting the result with
    /// `datafusion_common::cast::$FUNC`.
    macro_rules! eval_array_legacy_mode {
        ($INPUT:expr, $OUTPUT:expr, $FUNC:ident) => {{
            let input = $INPUT;
            let args = ColumnarValue::Array(Arc::new(input));
            let expected = $OUTPUT;
            match spark_abs(&[args]) {
                Ok(ColumnarValue::Array(result)) => {
                    let actual = datafusion_common::cast::$FUNC(&result).unwrap();
                    assert_eq!(actual, &expected);
                }
                _ => unreachable!(),
            }
        }};
    }

    #[test]
    fn test_abs_array_legacy_mode() {
        eval_array_legacy_mode!(
            Int8Array::from(vec![Some(-1), Some(i8::MIN), Some(i8::MAX), None]),
            Int8Array::from(vec![Some(1), Some(i8::MIN), Some(i8::MAX), None]),
            as_int8_array
        );

        eval_array_legacy_mode!(
            Int16Array::from(vec![Some(-1), Some(i16::MIN), Some(i16::MAX), None]),
            Int16Array::from(vec![Some(1), Some(i16::MIN), Some(i16::MAX), None]),
            as_int16_array
        );

        eval_array_legacy_mode!(
            Int32Array::from(vec![Some(-1), Some(i32::MIN), Some(i32::MAX), None]),
            Int32Array::from(vec![Some(1), Some(i32::MIN), Some(i32::MAX), None]),
            as_int32_array
        );

        eval_array_legacy_mode!(
            Int64Array::from(vec![Some(-1), Some(i64::MIN), Some(i64::MAX), None]),
            Int64Array::from(vec![Some(1), Some(i64::MIN), Some(i64::MAX), None]),
            as_int64_array
        );

        eval_array_legacy_mode!(
            Float32Array::from(vec![
                Some(-1f32),
                Some(f32::MIN),
                Some(f32::MAX),
                None,
                Some(f32::NAN),
                Some(f32::INFINITY),
                Some(f32::NEG_INFINITY),
                Some(0.0),
                Some(-0.0),
            ]),
            Float32Array::from(vec![
                Some(1f32),
                Some(f32::MAX),
                Some(f32::MAX),
                None,
                Some(f32::NAN),
                Some(f32::INFINITY),
                Some(f32::INFINITY),
                Some(0.0),
                Some(0.0),
            ]),
            as_float32_array
        );

        eval_array_legacy_mode!(
            Float64Array::from(vec![
                Some(-1f64),
                Some(f64::MIN),
                Some(f64::MAX),
                None,
                Some(f64::NAN),
                Some(f64::INFINITY),
                Some(f64::NEG_INFINITY),
                Some(0.0),
                Some(-0.0),
            ]),
            Float64Array::from(vec![
                Some(1f64),
                Some(f64::MAX),
                Some(f64::MAX),
                None,
                Some(f64::NAN),
                Some(f64::INFINITY),
                Some(f64::INFINITY),
                Some(0.0),
                Some(0.0),
            ]),
            as_float64_array
        );

        eval_array_legacy_mode!(
            Decimal128Array::from(vec![Some(i128::MIN), None])
                .with_precision_and_scale(38, 37)
                .unwrap(),
            Decimal128Array::from(vec![Some(i128::MIN), None])
                .with_precision_and_scale(38, 37)
                .unwrap(),
            as_decimal128_array
        );

        eval_array_legacy_mode!(
            Decimal256Array::from(vec![Some(i256::MIN), None])
                .with_precision_and_scale(5, 2)
                .unwrap(),
            Decimal256Array::from(vec![Some(i256::MIN), None])
                .with_precision_and_scale(5, 2)
                .unwrap(),
            as_decimal256_array
        );
    }

    #[test]
    fn test_abs_nullability() {
        use arrow::datatypes::{DataType, Field};
        use datafusion_expr::ReturnFieldArgs;
        use std::sync::Arc;

        let abs = SparkAbs::new();

        let non_nullable_i32 = Arc::new(Field::new("c", DataType::Int32, false));
        let out_non_null = abs
            .return_field_from_args(ReturnFieldArgs {
                arg_fields: &[Arc::clone(&non_nullable_i32)],
                scalar_arguments: &[None],
            })
            .unwrap();

        assert!(!out_non_null.is_nullable());
        assert_eq!(out_non_null.data_type(), &DataType::Int32);

        let nullable_i32 = Arc::new(Field::new("c", DataType::Int32, true));
        let out_nullable = abs
            .return_field_from_args(ReturnFieldArgs {
                arg_fields: &[Arc::clone(&nullable_i32)],
                scalar_arguments: &[None],
            })
            .unwrap();

        assert!(out_nullable.is_nullable());
        assert_eq!(out_nullable.data_type(), &DataType::Int32);

        let non_nullable_f64 = Arc::new(Field::new("c", DataType::Float64, false));
        let out_f64 = abs
            .return_field_from_args(ReturnFieldArgs {
                arg_fields: &[Arc::clone(&non_nullable_f64)],
                scalar_arguments: &[None],
            })
            .unwrap();

        assert!(!out_f64.is_nullable());
        assert_eq!(out_f64.data_type(), &DataType::Float64);

        let nullable_f64 = Arc::new(Field::new("c", DataType::Float64, true));
        let out_f64_null = abs
            .return_field_from_args(ReturnFieldArgs {
                arg_fields: &[Arc::clone(&nullable_f64)],
                scalar_arguments: &[None],
            })
            .unwrap();

        assert!(out_f64_null.is_nullable());
        assert_eq!(out_f64_null.data_type(), &DataType::Float64);
    }
}