1use arrow::array::*;
19use arrow::datatypes::{DataType, Field, FieldRef};
20use arrow::error::ArrowError;
21use datafusion_common::{DataFusionError, Result, ScalarValue, internal_err};
22use datafusion_expr::{
23 ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature,
24 Volatility,
25};
26use datafusion_functions::{
27 downcast_named_arg, make_abs_function, make_try_abs_function,
28 make_wrapping_abs_function,
29};
30use std::sync::Arc;
31
32#[derive(Debug, PartialEq, Eq, Hash)]
45pub struct SparkAbs {
46 signature: Signature,
47}
48
49impl Default for SparkAbs {
50 fn default() -> Self {
51 Self::new()
52 }
53}
54
55impl SparkAbs {
56 pub fn new() -> Self {
57 Self {
58 signature: Signature::numeric(1, Volatility::Immutable),
59 }
60 }
61}
62
63impl ScalarUDFImpl for SparkAbs {
64 fn name(&self) -> &str {
65 "abs"
66 }
67
68 fn signature(&self) -> &Signature {
69 &self.signature
70 }
71
72 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
73 internal_err!(
74 "SparkAbs: return_type() is not used; return_field_from_args() is implemented"
75 )
76 }
77
78 fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
79 let input_field = &args.arg_fields[0];
80 let out_dt = input_field.data_type().clone();
81 let out_nullable = input_field.is_nullable();
82
83 Ok(Arc::new(Field::new(self.name(), out_dt, out_nullable)))
84 }
85
86 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
87 spark_abs(&args.args, args.config_options.execution.enable_ansi_mode)
88 }
89}
90
/// Computes the absolute value of a non-null scalar and wraps it back into a
/// `ColumnarValue::Scalar` of the given `ScalarValue` variant.
///
/// Forms:
/// * `(ansi, input, Variant)` for variants carrying only a value.
/// * `(ansi, input, precision, scale, Variant)` for variants that also carry
///   precision and scale.
///
/// When the ANSI flag is true, `checked_abs()` is used and a `None` result is
/// turned into an `ArrowError::ComputeError`, which is propagated with `?`
/// from the enclosing function. Otherwise `wrapping_abs()` is used.
macro_rules! scalar_compute_op {
    // Internal rule shared by both public forms: yields the absolute value.
    (@abs $ENABLE_ANSI_MODE:expr, $INPUT:ident, $SCALAR_TYPE:ident) => {
        if $ENABLE_ANSI_MODE {
            $INPUT.checked_abs().ok_or_else(|| {
                ArrowError::ComputeError(format!(
                    "{} overflow on abs({:?})",
                    stringify!($SCALAR_TYPE),
                    $INPUT
                ))
            })?
        } else {
            $INPUT.wrapping_abs()
        }
    };
    ($ENABLE_ANSI_MODE:expr, $INPUT:ident, $SCALAR_TYPE:ident) => {{
        let magnitude =
            scalar_compute_op!(@abs $ENABLE_ANSI_MODE, $INPUT, $SCALAR_TYPE);
        Ok(ColumnarValue::Scalar(ScalarValue::$SCALAR_TYPE(Some(
            magnitude,
        ))))
    }};
    ($ENABLE_ANSI_MODE:expr, $INPUT:ident, $PRECISION:expr, $SCALE:expr, $SCALAR_TYPE:ident) => {{
        let magnitude =
            scalar_compute_op!(@abs $ENABLE_ANSI_MODE, $INPUT, $SCALAR_TYPE);
        Ok(ColumnarValue::Scalar(ScalarValue::$SCALAR_TYPE(
            Some(magnitude),
            $PRECISION,
            $SCALE,
        )))
    }};
}
127
128pub fn spark_abs(
129 args: &[ColumnarValue],
130 enable_ansi_mode: bool,
131) -> Result<ColumnarValue, DataFusionError> {
132 if args.len() != 1 {
133 return internal_err!("abs takes exactly 1 argument, but got: {}", args.len());
134 }
135
136 match &args[0] {
137 ColumnarValue::Array(array) => match array.data_type() {
138 DataType::Null
139 | DataType::UInt8
140 | DataType::UInt16
141 | DataType::UInt32
142 | DataType::UInt64 => Ok(args[0].clone()),
143 DataType::Int8 => {
144 let abs_fun = if enable_ansi_mode {
145 make_try_abs_function!(Int8Array)
146 } else {
147 make_wrapping_abs_function!(Int8Array)
148 };
149 abs_fun(array).map(ColumnarValue::Array)
150 }
151 DataType::Int16 => {
152 let abs_fun = if enable_ansi_mode {
153 make_try_abs_function!(Int16Array)
154 } else {
155 make_wrapping_abs_function!(Int16Array)
156 };
157 abs_fun(array).map(ColumnarValue::Array)
158 }
159 DataType::Int32 => {
160 let abs_fun = if enable_ansi_mode {
161 make_try_abs_function!(Int32Array)
162 } else {
163 make_wrapping_abs_function!(Int32Array)
164 };
165 abs_fun(array).map(ColumnarValue::Array)
166 }
167 DataType::Int64 => {
168 let abs_fun = if enable_ansi_mode {
169 make_try_abs_function!(Int64Array)
170 } else {
171 make_wrapping_abs_function!(Int64Array)
172 };
173 abs_fun(array).map(ColumnarValue::Array)
174 }
175 DataType::Float32 => {
176 let abs_fun = make_abs_function!(Float32Array);
177 abs_fun(array).map(ColumnarValue::Array)
178 }
179 DataType::Float64 => {
180 let abs_fun = make_abs_function!(Float64Array);
181 abs_fun(array).map(ColumnarValue::Array)
182 }
183 DataType::Decimal128(_, _) => {
184 let abs_fun = if enable_ansi_mode {
185 make_try_abs_function!(Decimal128Array)
186 } else {
187 make_wrapping_abs_function!(Decimal128Array)
188 };
189 abs_fun(array).map(ColumnarValue::Array)
190 }
191 DataType::Decimal256(_, _) => {
192 let abs_fun = if enable_ansi_mode {
193 make_try_abs_function!(Decimal256Array)
194 } else {
195 make_wrapping_abs_function!(Decimal256Array)
196 };
197 abs_fun(array).map(ColumnarValue::Array)
198 }
199 dt => internal_err!("Not supported datatype for Spark ABS: {dt}"),
200 },
201 ColumnarValue::Scalar(sv) => match sv {
202 ScalarValue::Null
203 | ScalarValue::UInt8(_)
204 | ScalarValue::UInt16(_)
205 | ScalarValue::UInt32(_)
206 | ScalarValue::UInt64(_) => Ok(args[0].clone()),
207 sv if sv.is_null() => Ok(args[0].clone()),
208 ScalarValue::Int8(Some(v)) => scalar_compute_op!(enable_ansi_mode, v, Int8),
209 ScalarValue::Int16(Some(v)) => scalar_compute_op!(enable_ansi_mode, v, Int16),
210 ScalarValue::Int32(Some(v)) => scalar_compute_op!(enable_ansi_mode, v, Int32),
211 ScalarValue::Int64(Some(v)) => scalar_compute_op!(enable_ansi_mode, v, Int64),
212 ScalarValue::Float32(Some(v)) => {
213 Ok(ColumnarValue::Scalar(ScalarValue::Float32(Some(v.abs()))))
214 }
215 ScalarValue::Float64(Some(v)) => {
216 Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(v.abs()))))
217 }
218 ScalarValue::Decimal128(Some(v), precision, scale) => {
219 scalar_compute_op!(enable_ansi_mode, v, *precision, *scale, Decimal128)
220 }
221 ScalarValue::Decimal256(Some(v), precision, scale) => {
222 scalar_compute_op!(enable_ansi_mode, v, *precision, *scale, Decimal256)
223 }
224 dt => internal_err!("Not supported datatype for Spark ABS: {dt}"),
225 },
226 }
227}
228
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::i256;

    /// Runs `spark_abs` on an array with ANSI mode disabled and compares the
    /// result, downcast through `datafusion_common::cast::$FUNC`, with the
    /// expected array.
    macro_rules! eval_array_legacy_mode {
        ($INPUT:expr, $OUTPUT:expr, $FUNC:ident) => {{
            let input = $INPUT;
            let args = ColumnarValue::Array(Arc::new(input));
            let expected = $OUTPUT;
            match spark_abs(&[args], false) {
                Ok(ColumnarValue::Array(result)) => {
                    let actual = datafusion_common::cast::$FUNC(&result).unwrap();
                    assert_eq!(actual, &expected);
                }
                // Show what actually came back instead of a bare `unreachable!`.
                other => panic!("expected an array result, got: {other:?}"),
            }
        }};
    }

    #[test]
    fn test_abs_array_legacy_mode() {
        // Signed integers: MIN wraps back to itself when ANSI mode is off.
        eval_array_legacy_mode!(
            Int8Array::from(vec![Some(-1), Some(i8::MIN), Some(i8::MAX), None]),
            Int8Array::from(vec![Some(1), Some(i8::MIN), Some(i8::MAX), None]),
            as_int8_array
        );

        eval_array_legacy_mode!(
            Int16Array::from(vec![Some(-1), Some(i16::MIN), Some(i16::MAX), None]),
            Int16Array::from(vec![Some(1), Some(i16::MIN), Some(i16::MAX), None]),
            as_int16_array
        );

        eval_array_legacy_mode!(
            Int32Array::from(vec![Some(-1), Some(i32::MIN), Some(i32::MAX), None]),
            Int32Array::from(vec![Some(1), Some(i32::MIN), Some(i32::MAX), None]),
            as_int32_array
        );

        eval_array_legacy_mode!(
            Int64Array::from(vec![Some(-1), Some(i64::MIN), Some(i64::MAX), None]),
            Int64Array::from(vec![Some(1), Some(i64::MIN), Some(i64::MAX), None]),
            as_int64_array
        );

        eval_array_legacy_mode!(
            Float32Array::from(vec![
                Some(-1f32),
                Some(f32::MIN),
                Some(f32::MAX),
                None,
                Some(f32::NAN),
                Some(f32::INFINITY),
                Some(f32::NEG_INFINITY),
                Some(0.0),
                Some(-0.0),
            ]),
            Float32Array::from(vec![
                Some(1f32),
                Some(f32::MAX),
                Some(f32::MAX),
                None,
                Some(f32::NAN),
                Some(f32::INFINITY),
                Some(f32::INFINITY),
                Some(0.0),
                Some(0.0),
            ]),
            as_float32_array
        );

        eval_array_legacy_mode!(
            Float64Array::from(vec![
                Some(-1f64),
                Some(f64::MIN),
                Some(f64::MAX),
                None,
                Some(f64::NAN),
                Some(f64::INFINITY),
                Some(f64::NEG_INFINITY),
                Some(0.0),
                Some(-0.0),
            ]),
            Float64Array::from(vec![
                Some(1f64),
                Some(f64::MAX),
                Some(f64::MAX),
                None,
                Some(f64::NAN),
                Some(f64::INFINITY),
                Some(f64::INFINITY),
                Some(0.0),
                Some(0.0),
            ]),
            as_float64_array
        );

        eval_array_legacy_mode!(
            Decimal128Array::from(vec![Some(i128::MIN), Some(i128::MIN + 1), None])
                .with_precision_and_scale(38, 37)
                .unwrap(),
            Decimal128Array::from(vec![Some(i128::MIN), Some(i128::MAX), None])
                .with_precision_and_scale(38, 37)
                .unwrap(),
            as_decimal128_array
        );

        eval_array_legacy_mode!(
            Decimal256Array::from(vec![
                Some(i256::MIN),
                Some(i256::MINUS_ONE),
                Some(i256::MIN + i256::from(1)),
                None
            ])
            .with_precision_and_scale(5, 2)
            .unwrap(),
            Decimal256Array::from(vec![
                Some(i256::MIN),
                Some(i256::ONE),
                Some(i256::MAX),
                None
            ])
            .with_precision_and_scale(5, 2)
            .unwrap(),
            as_decimal256_array
        );
    }

    /// Runs `spark_abs` on an array with ANSI mode enabled.
    ///
    /// * One argument: the call must fail with an "overflow on abs" error.
    /// * Three arguments: the call must succeed and equal the expected array.
    macro_rules! eval_array_ansi_mode {
        ($INPUT:expr) => {{
            let input = $INPUT;
            let args = ColumnarValue::Array(Arc::new(input));
            match spark_abs(&[args], true) {
                Err(e) => {
                    assert!(
                        e.to_string().contains("overflow on abs"),
                        "Error message did not match. Actual message: {e}"
                    );
                }
                other => panic!("expected an overflow error, got: {other:?}"),
            }
        }};
        ($INPUT:expr, $OUTPUT:expr, $FUNC:ident) => {{
            let input = $INPUT;
            let args = ColumnarValue::Array(Arc::new(input));
            let expected = $OUTPUT;
            match spark_abs(&[args], true) {
                Ok(ColumnarValue::Array(result)) => {
                    let actual = datafusion_common::cast::$FUNC(&result).unwrap();
                    assert_eq!(actual, &expected);
                }
                other => panic!("expected an array result, got: {other:?}"),
            }
        }};
    }

    #[test]
    fn test_abs_array_ansi_mode() {
        // Unsigned input is passed through untouched.
        eval_array_ansi_mode!(
            UInt64Array::from(vec![Some(u64::MIN), Some(u64::MAX), None]),
            UInt64Array::from(vec![Some(u64::MIN), Some(u64::MAX), None]),
            as_uint64_array
        );

        // Signed integers containing MIN must be rejected.
        eval_array_ansi_mode!(Int8Array::from(vec![
            Some(-1),
            Some(i8::MIN),
            Some(i8::MAX),
            None
        ]));
        eval_array_ansi_mode!(Int16Array::from(vec![
            Some(-1),
            Some(i16::MIN),
            Some(i16::MAX),
            None
        ]));
        eval_array_ansi_mode!(Int32Array::from(vec![
            Some(-1),
            Some(i32::MIN),
            Some(i32::MAX),
            None
        ]));
        eval_array_ansi_mode!(Int64Array::from(vec![
            Some(-1),
            Some(i64::MIN),
            Some(i64::MAX),
            None
        ]));

        eval_array_ansi_mode!(
            Float32Array::from(vec![
                Some(-1f32),
                Some(f32::MIN),
                Some(f32::MAX),
                None,
                Some(f32::NAN),
                Some(f32::INFINITY),
                Some(f32::NEG_INFINITY),
                Some(0.0),
                Some(-0.0),
            ]),
            Float32Array::from(vec![
                Some(1f32),
                Some(f32::MAX),
                Some(f32::MAX),
                None,
                Some(f32::NAN),
                Some(f32::INFINITY),
                Some(f32::INFINITY),
                Some(0.0),
                Some(0.0),
            ]),
            as_float32_array
        );

        eval_array_ansi_mode!(
            Float64Array::from(vec![
                Some(-1f64),
                Some(f64::MIN),
                Some(f64::MAX),
                None,
                Some(f64::NAN),
                Some(f64::INFINITY),
                Some(f64::NEG_INFINITY),
                Some(0.0),
                Some(-0.0),
            ]),
            Float64Array::from(vec![
                Some(1f64),
                Some(f64::MAX),
                Some(f64::MAX),
                None,
                Some(f64::NAN),
                Some(f64::INFINITY),
                Some(f64::INFINITY),
                Some(0.0),
                Some(0.0),
            ]),
            as_float64_array
        );

        // Decimals whose magnitude is representable succeed.
        eval_array_ansi_mode!(
            Decimal128Array::from(vec![Some(-1), Some(-2), Some(i128::MIN + 1)])
                .with_precision_and_scale(38, 37)
                .unwrap(),
            Decimal128Array::from(vec![Some(1), Some(2), Some(i128::MAX)])
                .with_precision_and_scale(38, 37)
                .unwrap(),
            as_decimal128_array
        );

        eval_array_ansi_mode!(
            Decimal256Array::from(vec![
                Some(i256::MINUS_ONE),
                Some(i256::from(-2)),
                Some(i256::MIN + i256::from(1))
            ])
            .with_precision_and_scale(18, 7)
            .unwrap(),
            Decimal256Array::from(vec![
                Some(i256::ONE),
                Some(i256::from(2)),
                Some(i256::MAX)
            ])
            .with_precision_and_scale(18, 7)
            .unwrap(),
            as_decimal256_array
        );

        // Decimals holding MIN must be rejected.
        eval_array_ansi_mode!(
            Decimal128Array::from(vec![Some(i128::MIN), None])
                .with_precision_and_scale(38, 37)
                .unwrap()
        );
        eval_array_ansi_mode!(
            Decimal256Array::from(vec![Some(i256::MIN), None])
                .with_precision_and_scale(5, 2)
                .unwrap()
        );
    }

    #[test]
    fn test_abs_scalar_legacy_mode() {
        // (input, expected) pairs exercising the scalar branch with ANSI off.
        let cases = vec![
            (ScalarValue::Int32(Some(-1)), ScalarValue::Int32(Some(1))),
            // Wrapping: MIN maps to itself.
            (
                ScalarValue::Int32(Some(i32::MIN)),
                ScalarValue::Int32(Some(i32::MIN)),
            ),
            (
                ScalarValue::Float64(Some(-1.5)),
                ScalarValue::Float64(Some(1.5)),
            ),
            (
                ScalarValue::Decimal128(Some(-5), 10, 2),
                ScalarValue::Decimal128(Some(5), 10, 2),
            ),
            // Unsigned and typed-null scalars are returned unchanged.
            (ScalarValue::UInt64(Some(7)), ScalarValue::UInt64(Some(7))),
            (ScalarValue::Int64(None), ScalarValue::Int64(None)),
        ];

        for (input, expected) in cases {
            match spark_abs(&[ColumnarValue::Scalar(input)], false) {
                Ok(ColumnarValue::Scalar(actual)) => assert_eq!(actual, expected),
                other => panic!("expected a scalar result, got: {other:?}"),
            }
        }
    }

    #[test]
    fn test_abs_scalar_ansi_mode() {
        // A representable magnitude succeeds in ANSI mode.
        match spark_abs(&[ColumnarValue::Scalar(ScalarValue::Int32(Some(-1)))], true) {
            Ok(ColumnarValue::Scalar(actual)) => {
                assert_eq!(actual, ScalarValue::Int32(Some(1)));
            }
            other => panic!("expected a scalar result, got: {other:?}"),
        }

        // MIN has no representable magnitude, so checked_abs fails.
        let overflowing = vec![
            ScalarValue::Int8(Some(i8::MIN)),
            ScalarValue::Int16(Some(i16::MIN)),
            ScalarValue::Int32(Some(i32::MIN)),
            ScalarValue::Int64(Some(i64::MIN)),
            ScalarValue::Decimal128(Some(i128::MIN), 38, 37),
        ];

        for input in overflowing {
            match spark_abs(&[ColumnarValue::Scalar(input)], true) {
                Err(e) => {
                    assert!(
                        e.to_string().contains("overflow on abs"),
                        "Error message did not match. Actual message: {e}"
                    );
                }
                other => panic!("expected an overflow error, got: {other:?}"),
            }
        }
    }

    #[test]
    fn test_abs_nullability() {
        let abs = SparkAbs::new();

        // The output field must mirror the input's data type and nullability.
        let cases = vec![
            (DataType::Int32, false),
            (DataType::Int32, true),
            (DataType::Float64, false),
            (DataType::Float64, true),
        ];

        for (data_type, nullable) in cases {
            let input_field = Arc::new(Field::new("c", data_type.clone(), nullable));
            let output_field = abs
                .return_field_from_args(ReturnFieldArgs {
                    arg_fields: &[Arc::clone(&input_field)],
                    scalar_arguments: &[None],
                })
                .unwrap();

            assert_eq!(output_field.is_nullable(), nullable);
            assert_eq!(output_field.data_type(), &data_type);
        }
    }
}