1use arrow::array::*;
19use arrow::datatypes::{DataType, Field, FieldRef};
20use arrow::error::ArrowError;
21use datafusion_common::{DataFusionError, Result, ScalarValue, internal_err};
22use datafusion_expr::{
23 ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature,
24 Volatility,
25};
26use datafusion_functions::{
27 downcast_named_arg, make_abs_function, make_try_abs_function,
28 make_wrapping_abs_function,
29};
30use std::any::Any;
31use std::sync::Arc;
32
33#[derive(Debug, PartialEq, Eq, Hash)]
46pub struct SparkAbs {
47 signature: Signature,
48}
49
50impl Default for SparkAbs {
51 fn default() -> Self {
52 Self::new()
53 }
54}
55
56impl SparkAbs {
57 pub fn new() -> Self {
58 Self {
59 signature: Signature::numeric(1, Volatility::Immutable),
60 }
61 }
62}
63
64impl ScalarUDFImpl for SparkAbs {
65 fn as_any(&self) -> &dyn Any {
66 self
67 }
68
69 fn name(&self) -> &str {
70 "abs"
71 }
72
73 fn signature(&self) -> &Signature {
74 &self.signature
75 }
76
77 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
78 internal_err!(
79 "SparkAbs: return_type() is not used; return_field_from_args() is implemented"
80 )
81 }
82
83 fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
84 let input_field = &args.arg_fields[0];
85 let out_dt = input_field.data_type().clone();
86 let out_nullable = input_field.is_nullable();
87
88 Ok(Arc::new(Field::new(self.name(), out_dt, out_nullable)))
89 }
90
91 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
92 spark_abs(&args.args, args.config_options.execution.enable_ansi_mode)
93 }
94}
95
/// Computes the absolute value of one non-null scalar and wraps it back into
/// a `ColumnarValue::Scalar` of the same `ScalarValue` variant.
///
/// * If `$ENABLE_ANSI_MODE` is true, a value with no representable absolute
///   value (e.g. `i8::MIN`) makes the enclosing function return an
///   `ArrowError::ComputeError`, converted into its error type.
/// * Otherwise the result wraps, so `abs(MIN) == MIN`.
///
/// Public forms:
/// * `(ansi, value, Variant)` for integer variants such as `Int8`;
/// * `(ansi, value, precision, scale, Variant)` for decimal variants. The
///   precision and scale are copied to the output unchanged.
macro_rules! scalar_compute_op {
    // Internal rule shared by both public forms: yields the absolute value or
    // early-returns the overflow error.
    (@abs $ENABLE_ANSI_MODE:expr, $INPUT:ident, $SCALAR_TYPE:ident) => {{
        match ($INPUT.checked_abs(), $ENABLE_ANSI_MODE) {
            // Representable: identical to the wrapping result.
            (Some(magnitude), _) => magnitude,
            // Legacy mode: overflow silently wraps back to the input.
            (None, false) => $INPUT.wrapping_abs(),
            // ANSI mode: overflow is an error.
            (None, true) => {
                return Err(ArrowError::ComputeError(format!(
                    "{} overflow on abs({:?})",
                    stringify!($SCALAR_TYPE),
                    $INPUT
                ))
                .into());
            }
        }
    }};
    ($ENABLE_ANSI_MODE:expr, $INPUT:ident, $SCALAR_TYPE:ident) => {{
        let magnitude = scalar_compute_op!(@abs $ENABLE_ANSI_MODE, $INPUT, $SCALAR_TYPE);
        Ok(ColumnarValue::Scalar(ScalarValue::$SCALAR_TYPE(Some(magnitude))))
    }};
    ($ENABLE_ANSI_MODE:expr, $INPUT:ident, $PRECISION:expr, $SCALE:expr, $SCALAR_TYPE:ident) => {{
        let magnitude = scalar_compute_op!(@abs $ENABLE_ANSI_MODE, $INPUT, $SCALAR_TYPE);
        Ok(ColumnarValue::Scalar(ScalarValue::$SCALAR_TYPE(
            Some(magnitude),
            $PRECISION,
            $SCALE,
        )))
    }};
}
132
133pub fn spark_abs(
134 args: &[ColumnarValue],
135 enable_ansi_mode: bool,
136) -> Result<ColumnarValue, DataFusionError> {
137 if args.len() != 1 {
138 return internal_err!("abs takes exactly 1 argument, but got: {}", args.len());
139 }
140
141 match &args[0] {
142 ColumnarValue::Array(array) => match array.data_type() {
143 DataType::Null
144 | DataType::UInt8
145 | DataType::UInt16
146 | DataType::UInt32
147 | DataType::UInt64 => Ok(args[0].clone()),
148 DataType::Int8 => {
149 let abs_fun = if enable_ansi_mode {
150 make_try_abs_function!(Int8Array)
151 } else {
152 make_wrapping_abs_function!(Int8Array)
153 };
154 abs_fun(array).map(ColumnarValue::Array)
155 }
156 DataType::Int16 => {
157 let abs_fun = if enable_ansi_mode {
158 make_try_abs_function!(Int16Array)
159 } else {
160 make_wrapping_abs_function!(Int16Array)
161 };
162 abs_fun(array).map(ColumnarValue::Array)
163 }
164 DataType::Int32 => {
165 let abs_fun = if enable_ansi_mode {
166 make_try_abs_function!(Int32Array)
167 } else {
168 make_wrapping_abs_function!(Int32Array)
169 };
170 abs_fun(array).map(ColumnarValue::Array)
171 }
172 DataType::Int64 => {
173 let abs_fun = if enable_ansi_mode {
174 make_try_abs_function!(Int64Array)
175 } else {
176 make_wrapping_abs_function!(Int64Array)
177 };
178 abs_fun(array).map(ColumnarValue::Array)
179 }
180 DataType::Float32 => {
181 let abs_fun = make_abs_function!(Float32Array);
182 abs_fun(array).map(ColumnarValue::Array)
183 }
184 DataType::Float64 => {
185 let abs_fun = make_abs_function!(Float64Array);
186 abs_fun(array).map(ColumnarValue::Array)
187 }
188 DataType::Decimal128(_, _) => {
189 let abs_fun = if enable_ansi_mode {
190 make_try_abs_function!(Decimal128Array)
191 } else {
192 make_wrapping_abs_function!(Decimal128Array)
193 };
194 abs_fun(array).map(ColumnarValue::Array)
195 }
196 DataType::Decimal256(_, _) => {
197 let abs_fun = if enable_ansi_mode {
198 make_try_abs_function!(Decimal256Array)
199 } else {
200 make_wrapping_abs_function!(Decimal256Array)
201 };
202 abs_fun(array).map(ColumnarValue::Array)
203 }
204 dt => internal_err!("Not supported datatype for Spark ABS: {dt}"),
205 },
206 ColumnarValue::Scalar(sv) => match sv {
207 ScalarValue::Null
208 | ScalarValue::UInt8(_)
209 | ScalarValue::UInt16(_)
210 | ScalarValue::UInt32(_)
211 | ScalarValue::UInt64(_) => Ok(args[0].clone()),
212 sv if sv.is_null() => Ok(args[0].clone()),
213 ScalarValue::Int8(Some(v)) => scalar_compute_op!(enable_ansi_mode, v, Int8),
214 ScalarValue::Int16(Some(v)) => scalar_compute_op!(enable_ansi_mode, v, Int16),
215 ScalarValue::Int32(Some(v)) => scalar_compute_op!(enable_ansi_mode, v, Int32),
216 ScalarValue::Int64(Some(v)) => scalar_compute_op!(enable_ansi_mode, v, Int64),
217 ScalarValue::Float32(Some(v)) => {
218 Ok(ColumnarValue::Scalar(ScalarValue::Float32(Some(v.abs()))))
219 }
220 ScalarValue::Float64(Some(v)) => {
221 Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(v.abs()))))
222 }
223 ScalarValue::Decimal128(Some(v), precision, scale) => {
224 scalar_compute_op!(enable_ansi_mode, v, *precision, *scale, Decimal128)
225 }
226 ScalarValue::Decimal256(Some(v), precision, scale) => {
227 scalar_compute_op!(enable_ansi_mode, v, *precision, *scale, Decimal256)
228 }
229 dt => internal_err!("Not supported datatype for Spark ABS: {dt}"),
230 },
231 }
232}
233
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::i256;

    // Runs `spark_abs` in legacy (non-ANSI) mode on an array and compares the
    // downcast result (via `datafusion_common::cast::$FUNC`) with `$OUTPUT`.
    macro_rules! eval_array_legacy_mode {
        ($INPUT:expr, $OUTPUT:expr, $FUNC:ident) => {{
            let input = $INPUT;
            let args = ColumnarValue::Array(Arc::new(input));
            let expected = $OUTPUT;
            match spark_abs(&[args], false) {
                Ok(ColumnarValue::Array(result)) => {
                    let actual = datafusion_common::cast::$FUNC(&result).unwrap();
                    assert_eq!(actual, &expected);
                }
                _ => unreachable!(),
            }
        }};
    }

    #[test]
    fn test_abs_array_legacy_mode() {
        eval_array_legacy_mode!(
            Int8Array::from(vec![Some(-1), Some(i8::MIN), Some(i8::MAX), None]),
            Int8Array::from(vec![Some(1), Some(i8::MIN), Some(i8::MAX), None]),
            as_int8_array
        );

        eval_array_legacy_mode!(
            Int16Array::from(vec![Some(-1), Some(i16::MIN), Some(i16::MAX), None]),
            Int16Array::from(vec![Some(1), Some(i16::MIN), Some(i16::MAX), None]),
            as_int16_array
        );

        eval_array_legacy_mode!(
            Int32Array::from(vec![Some(-1), Some(i32::MIN), Some(i32::MAX), None]),
            Int32Array::from(vec![Some(1), Some(i32::MIN), Some(i32::MAX), None]),
            as_int32_array
        );

        eval_array_legacy_mode!(
            Int64Array::from(vec![Some(-1), Some(i64::MIN), Some(i64::MAX), None]),
            Int64Array::from(vec![Some(1), Some(i64::MIN), Some(i64::MAX), None]),
            as_int64_array
        );

        eval_array_legacy_mode!(
            Float32Array::from(vec![
                Some(-1f32),
                Some(f32::MIN),
                Some(f32::MAX),
                None,
                Some(f32::NAN),
                Some(f32::INFINITY),
                Some(f32::NEG_INFINITY),
                Some(0.0),
                Some(-0.0),
            ]),
            Float32Array::from(vec![
                Some(1f32),
                Some(f32::MAX),
                Some(f32::MAX),
                None,
                Some(f32::NAN),
                Some(f32::INFINITY),
                Some(f32::INFINITY),
                Some(0.0),
                Some(0.0),
            ]),
            as_float32_array
        );

        eval_array_legacy_mode!(
            Float64Array::from(vec![
                Some(-1f64),
                Some(f64::MIN),
                Some(f64::MAX),
                None,
                Some(f64::NAN),
                Some(f64::INFINITY),
                Some(f64::NEG_INFINITY),
                Some(0.0),
                Some(-0.0),
            ]),
            Float64Array::from(vec![
                Some(1f64),
                Some(f64::MAX),
                Some(f64::MAX),
                None,
                Some(f64::NAN),
                Some(f64::INFINITY),
                Some(f64::INFINITY),
                Some(0.0),
                Some(0.0),
            ]),
            as_float64_array
        );

        eval_array_legacy_mode!(
            Decimal128Array::from(vec![Some(i128::MIN), Some(i128::MIN + 1), None])
                .with_precision_and_scale(38, 37)
                .unwrap(),
            Decimal128Array::from(vec![Some(i128::MIN), Some(i128::MAX), None])
                .with_precision_and_scale(38, 37)
                .unwrap(),
            as_decimal128_array
        );

        eval_array_legacy_mode!(
            Decimal256Array::from(vec![
                Some(i256::MIN),
                Some(i256::MINUS_ONE),
                Some(i256::MIN + i256::from(1)),
                None
            ])
            .with_precision_and_scale(5, 2)
            .unwrap(),
            Decimal256Array::from(vec![
                Some(i256::MIN),
                Some(i256::ONE),
                Some(i256::MAX),
                None
            ])
            .with_precision_and_scale(5, 2)
            .unwrap(),
            as_decimal256_array
        );
    }

    // Runs `spark_abs` in ANSI mode on an array.
    // * One argument: expects an overflow error.
    // * Three arguments: expects success and compares the downcast result with
    //   `$OUTPUT`.
    macro_rules! eval_array_ansi_mode {
        ($INPUT:expr) => {{
            let input = $INPUT;
            let args = ColumnarValue::Array(Arc::new(input));
            match spark_abs(&[args], true) {
                Err(e) => {
                    assert!(
                        e.to_string().contains("overflow on abs"),
                        "Error message did not match. Actual message: {e}"
                    );
                }
                _ => unreachable!(),
            }
        }};
        ($INPUT:expr, $OUTPUT:expr, $FUNC:ident) => {{
            let input = $INPUT;
            let args = ColumnarValue::Array(Arc::new(input));
            let expected = $OUTPUT;
            match spark_abs(&[args], true) {
                Ok(ColumnarValue::Array(result)) => {
                    let actual = datafusion_common::cast::$FUNC(&result).unwrap();
                    assert_eq!(actual, &expected);
                }
                _ => unreachable!(),
            }
        }};
    }

    #[test]
    fn test_abs_array_ansi_mode() {
        eval_array_ansi_mode!(
            UInt64Array::from(vec![Some(u64::MIN), Some(u64::MAX), None]),
            UInt64Array::from(vec![Some(u64::MIN), Some(u64::MAX), None]),
            as_uint64_array
        );

        eval_array_ansi_mode!(Int8Array::from(vec![
            Some(-1),
            Some(i8::MIN),
            Some(i8::MAX),
            None
        ]));
        eval_array_ansi_mode!(Int16Array::from(vec![
            Some(-1),
            Some(i16::MIN),
            Some(i16::MAX),
            None
        ]));
        eval_array_ansi_mode!(Int32Array::from(vec![
            Some(-1),
            Some(i32::MIN),
            Some(i32::MAX),
            None
        ]));
        eval_array_ansi_mode!(Int64Array::from(vec![
            Some(-1),
            Some(i64::MIN),
            Some(i64::MAX),
            None
        ]));
        eval_array_ansi_mode!(
            Float32Array::from(vec![
                Some(-1f32),
                Some(f32::MIN),
                Some(f32::MAX),
                None,
                Some(f32::NAN),
                Some(f32::INFINITY),
                Some(f32::NEG_INFINITY),
                Some(0.0),
                Some(-0.0),
            ]),
            Float32Array::from(vec![
                Some(1f32),
                Some(f32::MAX),
                Some(f32::MAX),
                None,
                Some(f32::NAN),
                Some(f32::INFINITY),
                Some(f32::INFINITY),
                Some(0.0),
                Some(0.0),
            ]),
            as_float32_array
        );

        eval_array_ansi_mode!(
            Float64Array::from(vec![
                Some(-1f64),
                Some(f64::MIN),
                Some(f64::MAX),
                None,
                Some(f64::NAN),
                Some(f64::INFINITY),
                Some(f64::NEG_INFINITY),
                Some(0.0),
                Some(-0.0),
            ]),
            Float64Array::from(vec![
                Some(1f64),
                Some(f64::MAX),
                Some(f64::MAX),
                None,
                Some(f64::NAN),
                Some(f64::INFINITY),
                Some(f64::INFINITY),
                Some(0.0),
                Some(0.0),
            ]),
            as_float64_array
        );

        eval_array_ansi_mode!(
            Decimal128Array::from(vec![Some(-1), Some(-2), Some(i128::MIN + 1)])
                .with_precision_and_scale(38, 37)
                .unwrap(),
            Decimal128Array::from(vec![Some(1), Some(2), Some(i128::MAX)])
                .with_precision_and_scale(38, 37)
                .unwrap(),
            as_decimal128_array
        );

        eval_array_ansi_mode!(
            Decimal256Array::from(vec![
                Some(i256::MINUS_ONE),
                Some(i256::from(-2)),
                Some(i256::MIN + i256::from(1))
            ])
            .with_precision_and_scale(18, 7)
            .unwrap(),
            Decimal256Array::from(vec![
                Some(i256::ONE),
                Some(i256::from(2)),
                Some(i256::MAX)
            ])
            .with_precision_and_scale(18, 7)
            .unwrap(),
            as_decimal256_array
        );

        eval_array_ansi_mode!(
            Decimal128Array::from(vec![Some(i128::MIN), None])
                .with_precision_and_scale(38, 37)
                .unwrap()
        );
        eval_array_ansi_mode!(
            Decimal256Array::from(vec![Some(i256::MIN), None])
                .with_precision_and_scale(5, 2)
                .unwrap()
        );
    }

    /// Runs `spark_abs` on a single scalar and unwraps the scalar result.
    fn eval_scalar(input: ScalarValue, enable_ansi_mode: bool) -> Result<ScalarValue> {
        match spark_abs(&[ColumnarValue::Scalar(input)], enable_ansi_mode)? {
            ColumnarValue::Scalar(sv) => Ok(sv),
            ColumnarValue::Array(_) => unreachable!("scalar input must produce a scalar"),
        }
    }

    #[test]
    fn test_abs_scalar_legacy_mode() {
        let cases = [
            (ScalarValue::Int8(Some(-1)), ScalarValue::Int8(Some(1))),
            (ScalarValue::Int8(Some(i8::MIN)), ScalarValue::Int8(Some(i8::MIN))),
            (ScalarValue::Int16(Some(i16::MIN)), ScalarValue::Int16(Some(i16::MIN))),
            (ScalarValue::Int32(Some(i32::MIN)), ScalarValue::Int32(Some(i32::MIN))),
            (ScalarValue::Int64(Some(i64::MIN)), ScalarValue::Int64(Some(i64::MIN))),
            (ScalarValue::Float32(Some(-1.5)), ScalarValue::Float32(Some(1.5))),
            (
                ScalarValue::Float64(Some(f64::NEG_INFINITY)),
                ScalarValue::Float64(Some(f64::INFINITY)),
            ),
            (
                ScalarValue::Decimal128(Some(i128::MIN), 38, 37),
                ScalarValue::Decimal128(Some(i128::MIN), 38, 37),
            ),
            (
                ScalarValue::Decimal256(Some(i256::from(-5)), 5, 2),
                ScalarValue::Decimal256(Some(i256::from(5)), 5, 2),
            ),
            (ScalarValue::UInt64(Some(u64::MAX)), ScalarValue::UInt64(Some(u64::MAX))),
        ];
        for (input, expected) in cases {
            assert_eq!(eval_scalar(input, false).unwrap(), expected);
        }

        // Nulls pass through unchanged and keep their type.
        for input in [
            ScalarValue::Null,
            ScalarValue::Int32(None),
            ScalarValue::Decimal128(None, 10, 2),
        ] {
            let out = eval_scalar(input.clone(), false).unwrap();
            assert!(out.is_null());
            assert_eq!(out.data_type(), input.data_type());
        }
    }

    #[test]
    fn test_abs_scalar_ansi_mode() {
        let overflowing = [
            ScalarValue::Int8(Some(i8::MIN)),
            ScalarValue::Int16(Some(i16::MIN)),
            ScalarValue::Int32(Some(i32::MIN)),
            ScalarValue::Int64(Some(i64::MIN)),
            ScalarValue::Decimal128(Some(i128::MIN), 38, 37),
            ScalarValue::Decimal256(Some(i256::MIN), 5, 2),
        ];
        for input in overflowing {
            let err = eval_scalar(input, true).unwrap_err();
            assert!(
                err.to_string().contains("overflow on abs"),
                "Error message did not match. Actual message: {err}"
            );
        }

        let cases = [
            (ScalarValue::Int8(Some(-1)), ScalarValue::Int8(Some(1))),
            (ScalarValue::Int64(Some(i64::MIN + 1)), ScalarValue::Int64(Some(i64::MAX))),
            (ScalarValue::Float64(Some(-2.5)), ScalarValue::Float64(Some(2.5))),
            (
                ScalarValue::Decimal128(Some(-2), 38, 37),
                ScalarValue::Decimal128(Some(2), 38, 37),
            ),
            (ScalarValue::UInt8(Some(u8::MAX)), ScalarValue::UInt8(Some(u8::MAX))),
        ];
        for (input, expected) in cases {
            assert_eq!(eval_scalar(input, true).unwrap(), expected);
        }
    }

    #[test]
    fn test_abs_wrong_arity() {
        assert!(spark_abs(&[], false).is_err());
        let one = ColumnarValue::Scalar(ScalarValue::Int32(Some(-1)));
        assert!(spark_abs(&[one.clone(), one], true).is_err());
    }

    #[test]
    fn test_abs_nullability() {
        let abs = SparkAbs::new();

        let non_nullable_i32 = Arc::new(Field::new("c", DataType::Int32, false));
        let out_non_null = abs
            .return_field_from_args(ReturnFieldArgs {
                arg_fields: &[Arc::clone(&non_nullable_i32)],
                scalar_arguments: &[None],
            })
            .unwrap();

        assert!(!out_non_null.is_nullable());
        assert_eq!(out_non_null.data_type(), &DataType::Int32);

        let nullable_i32 = Arc::new(Field::new("c", DataType::Int32, true));
        let out_nullable = abs
            .return_field_from_args(ReturnFieldArgs {
                arg_fields: &[Arc::clone(&nullable_i32)],
                scalar_arguments: &[None],
            })
            .unwrap();

        assert!(out_nullable.is_nullable());
        assert_eq!(out_nullable.data_type(), &DataType::Int32);

        let non_nullable_f64 = Arc::new(Field::new("c", DataType::Float64, false));
        let out_f64 = abs
            .return_field_from_args(ReturnFieldArgs {
                arg_fields: &[Arc::clone(&non_nullable_f64)],
                scalar_arguments: &[None],
            })
            .unwrap();

        assert!(!out_f64.is_nullable());
        assert_eq!(out_f64.data_type(), &DataType::Float64);

        let nullable_f64 = Arc::new(Field::new("c", DataType::Float64, true));
        let out_f64_null = abs
            .return_field_from_args(ReturnFieldArgs {
                arg_fields: &[Arc::clone(&nullable_f64)],
                scalar_arguments: &[None],
            })
            .unwrap();

        assert!(out_f64_null.is_nullable());
        assert_eq!(out_f64_null.data_type(), &DataType::Float64);
    }
}