// datafusion_spark/function/math/abs.rs
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::array::*;
19use arrow::datatypes::{DataType, Field, FieldRef};
20use arrow::error::ArrowError;
21use datafusion_common::{DataFusionError, Result, ScalarValue, internal_err};
22use datafusion_expr::{
23    ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature,
24    Volatility,
25};
26use datafusion_functions::{
27    downcast_named_arg, make_abs_function, make_try_abs_function,
28    make_wrapping_abs_function,
29};
30use std::any::Any;
31use std::sync::Arc;
32
/// Spark-compatible `abs` expression
/// <https://spark.apache.org/docs/latest/api/sql/index.html#abs>
///
/// Returns the absolute value of input
/// Returns NULL if input is NULL, returns NaN if input is NaN.
///
/// Differences with DataFusion abs:
///  - Spark's ANSI-compliant dialect, when off (i.e. `spark.sql.ansi.enabled=false`), taking absolute value on the minimal value of a signed integer returns the value as is. DataFusion's abs throws "DataFusion error: Arrow error: Compute error" on arithmetic overflow
///
/// TODOs:
///  - Spark's abs also supports ANSI interval types: YearMonthIntervalType and DayTimeIntervalType. DataFusion's abs doesn't.
///
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkAbs {
    // Accepts exactly one numeric argument; see `SparkAbs::new`.
    signature: Signature,
}
49
50impl Default for SparkAbs {
51    fn default() -> Self {
52        Self::new()
53    }
54}
55
56impl SparkAbs {
57    pub fn new() -> Self {
58        Self {
59            signature: Signature::numeric(1, Volatility::Immutable),
60        }
61    }
62}
63
64impl ScalarUDFImpl for SparkAbs {
65    fn as_any(&self) -> &dyn Any {
66        self
67    }
68
69    fn name(&self) -> &str {
70        "abs"
71    }
72
73    fn signature(&self) -> &Signature {
74        &self.signature
75    }
76
77    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
78        internal_err!(
79            "SparkAbs: return_type() is not used; return_field_from_args() is implemented"
80        )
81    }
82
83    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
84        let input_field = &args.arg_fields[0];
85        let out_dt = input_field.data_type().clone();
86        let out_nullable = input_field.is_nullable();
87
88        Ok(Arc::new(Field::new(self.name(), out_dt, out_nullable)))
89    }
90
91    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
92        spark_abs(&args.args, args.config_options.execution.enable_ansi_mode)
93    }
94}
95
/// Computes `abs` for a signed scalar, selecting overflow behavior by mode:
/// ANSI mode uses `checked_abs` and surfaces overflow as an
/// `ArrowError::ComputeError`; legacy mode uses `wrapping_abs` (so
/// `abs(MIN) == MIN`, matching Spark with `spark.sql.ansi.enabled=false`).
///
/// Two arms:
///  - `(mode, input, Variant)` for plain integer `ScalarValue` variants;
///  - `(mode, input, precision, scale, Variant)` for decimal variants, which
///    carry precision/scale alongside the value.
macro_rules! scalar_compute_op {
    ($ENABLE_ANSI_MODE:expr, $INPUT:ident, $SCALAR_TYPE:ident) => {{
        let result = if $ENABLE_ANSI_MODE {
            // ANSI: overflow is an error, propagated via `?` and converted
            // to DataFusionError by the caller's return type.
            $INPUT.checked_abs().ok_or_else(|| {
                ArrowError::ComputeError(format!(
                    "{} overflow on abs({:?})",
                    stringify!($SCALAR_TYPE),
                    $INPUT
                ))
            })?
        } else {
            // Legacy: wrap on overflow, Spark-compatible.
            $INPUT.wrapping_abs()
        };
        Ok(ColumnarValue::Scalar(ScalarValue::$SCALAR_TYPE(Some(
            result,
        ))))
    }};
    ($ENABLE_ANSI_MODE:expr, $INPUT:ident, $PRECISION:expr, $SCALE:expr, $SCALAR_TYPE:ident) => {{
        let result = if $ENABLE_ANSI_MODE {
            $INPUT.checked_abs().ok_or_else(|| {
                ArrowError::ComputeError(format!(
                    "{} overflow on abs({:?})",
                    stringify!($SCALAR_TYPE),
                    $INPUT
                ))
            })?
        } else {
            $INPUT.wrapping_abs()
        };
        // Decimal result keeps the input's precision and scale unchanged.
        Ok(ColumnarValue::Scalar(ScalarValue::$SCALAR_TYPE(
            Some(result),
            $PRECISION,
            $SCALE,
        )))
    }};
}
132
133pub fn spark_abs(
134    args: &[ColumnarValue],
135    enable_ansi_mode: bool,
136) -> Result<ColumnarValue, DataFusionError> {
137    if args.len() != 1 {
138        return internal_err!("abs takes exactly 1 argument, but got: {}", args.len());
139    }
140
141    match &args[0] {
142        ColumnarValue::Array(array) => match array.data_type() {
143            DataType::Null
144            | DataType::UInt8
145            | DataType::UInt16
146            | DataType::UInt32
147            | DataType::UInt64 => Ok(args[0].clone()),
148            DataType::Int8 => {
149                let abs_fun = if enable_ansi_mode {
150                    make_try_abs_function!(Int8Array)
151                } else {
152                    make_wrapping_abs_function!(Int8Array)
153                };
154                abs_fun(array).map(ColumnarValue::Array)
155            }
156            DataType::Int16 => {
157                let abs_fun = if enable_ansi_mode {
158                    make_try_abs_function!(Int16Array)
159                } else {
160                    make_wrapping_abs_function!(Int16Array)
161                };
162                abs_fun(array).map(ColumnarValue::Array)
163            }
164            DataType::Int32 => {
165                let abs_fun = if enable_ansi_mode {
166                    make_try_abs_function!(Int32Array)
167                } else {
168                    make_wrapping_abs_function!(Int32Array)
169                };
170                abs_fun(array).map(ColumnarValue::Array)
171            }
172            DataType::Int64 => {
173                let abs_fun = if enable_ansi_mode {
174                    make_try_abs_function!(Int64Array)
175                } else {
176                    make_wrapping_abs_function!(Int64Array)
177                };
178                abs_fun(array).map(ColumnarValue::Array)
179            }
180            DataType::Float32 => {
181                let abs_fun = make_abs_function!(Float32Array);
182                abs_fun(array).map(ColumnarValue::Array)
183            }
184            DataType::Float64 => {
185                let abs_fun = make_abs_function!(Float64Array);
186                abs_fun(array).map(ColumnarValue::Array)
187            }
188            DataType::Decimal128(_, _) => {
189                let abs_fun = if enable_ansi_mode {
190                    make_try_abs_function!(Decimal128Array)
191                } else {
192                    make_wrapping_abs_function!(Decimal128Array)
193                };
194                abs_fun(array).map(ColumnarValue::Array)
195            }
196            DataType::Decimal256(_, _) => {
197                let abs_fun = if enable_ansi_mode {
198                    make_try_abs_function!(Decimal256Array)
199                } else {
200                    make_wrapping_abs_function!(Decimal256Array)
201                };
202                abs_fun(array).map(ColumnarValue::Array)
203            }
204            dt => internal_err!("Not supported datatype for Spark ABS: {dt}"),
205        },
206        ColumnarValue::Scalar(sv) => match sv {
207            ScalarValue::Null
208            | ScalarValue::UInt8(_)
209            | ScalarValue::UInt16(_)
210            | ScalarValue::UInt32(_)
211            | ScalarValue::UInt64(_) => Ok(args[0].clone()),
212            sv if sv.is_null() => Ok(args[0].clone()),
213            ScalarValue::Int8(Some(v)) => scalar_compute_op!(enable_ansi_mode, v, Int8),
214            ScalarValue::Int16(Some(v)) => scalar_compute_op!(enable_ansi_mode, v, Int16),
215            ScalarValue::Int32(Some(v)) => scalar_compute_op!(enable_ansi_mode, v, Int32),
216            ScalarValue::Int64(Some(v)) => scalar_compute_op!(enable_ansi_mode, v, Int64),
217            ScalarValue::Float32(Some(v)) => {
218                Ok(ColumnarValue::Scalar(ScalarValue::Float32(Some(v.abs()))))
219            }
220            ScalarValue::Float64(Some(v)) => {
221                Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(v.abs()))))
222            }
223            ScalarValue::Decimal128(Some(v), precision, scale) => {
224                scalar_compute_op!(enable_ansi_mode, v, *precision, *scale, Decimal128)
225            }
226            ScalarValue::Decimal256(Some(v), precision, scale) => {
227                scalar_compute_op!(enable_ansi_mode, v, *precision, *scale, Decimal256)
228            }
229            dt => internal_err!("Not supported datatype for Spark ABS: {dt}"),
230        },
231    }
232}
233
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::i256;

    // Runs `spark_abs` in legacy (non-ANSI) mode over `$INPUT` and asserts
    // the resulting array equals `$OUTPUT`. `$FUNC` is the
    // `datafusion_common::cast` downcast helper matching the array type.
    macro_rules! eval_array_legacy_mode {
        ($INPUT:expr, $OUTPUT:expr, $FUNC:ident) => {{
            let input = $INPUT;
            let args = ColumnarValue::Array(Arc::new(input));
            let expected = $OUTPUT;
            match spark_abs(&[args], false) {
                Ok(ColumnarValue::Array(result)) => {
                    let actual = datafusion_common::cast::$FUNC(&result).unwrap();
                    assert_eq!(actual, &expected);
                }
                _ => unreachable!(),
            }
        }};
    }

    #[test]
    fn test_abs_array_legacy_mode() {
        // Signed integers: legacy mode wraps, so abs(MIN) == MIN.
        eval_array_legacy_mode!(
            Int8Array::from(vec![Some(-1), Some(i8::MIN), Some(i8::MAX), None]),
            Int8Array::from(vec![Some(1), Some(i8::MIN), Some(i8::MAX), None]),
            as_int8_array
        );

        eval_array_legacy_mode!(
            Int16Array::from(vec![Some(-1), Some(i16::MIN), Some(i16::MAX), None]),
            Int16Array::from(vec![Some(1), Some(i16::MIN), Some(i16::MAX), None]),
            as_int16_array
        );

        eval_array_legacy_mode!(
            Int32Array::from(vec![Some(-1), Some(i32::MIN), Some(i32::MAX), None]),
            Int32Array::from(vec![Some(1), Some(i32::MIN), Some(i32::MAX), None]),
            as_int32_array
        );

        eval_array_legacy_mode!(
            Int64Array::from(vec![Some(-1), Some(i64::MIN), Some(i64::MAX), None]),
            Int64Array::from(vec![Some(1), Some(i64::MIN), Some(i64::MAX), None]),
            as_int64_array
        );

        // Floats: NaN stays NaN, -inf -> inf, -0.0 -> 0.0, null stays null.
        eval_array_legacy_mode!(
            Float32Array::from(vec![
                Some(-1f32),
                Some(f32::MIN),
                Some(f32::MAX),
                None,
                Some(f32::NAN),
                Some(f32::INFINITY),
                Some(f32::NEG_INFINITY),
                Some(0.0),
                Some(-0.0),
            ]),
            Float32Array::from(vec![
                Some(1f32),
                Some(f32::MAX),
                Some(f32::MAX),
                None,
                Some(f32::NAN),
                Some(f32::INFINITY),
                Some(f32::INFINITY),
                Some(0.0),
                Some(0.0),
            ]),
            as_float32_array
        );

        eval_array_legacy_mode!(
            Float64Array::from(vec![
                Some(-1f64),
                Some(f64::MIN),
                Some(f64::MAX),
                None,
                Some(f64::NAN),
                Some(f64::INFINITY),
                Some(f64::NEG_INFINITY),
                Some(0.0),
                Some(-0.0),
            ]),
            Float64Array::from(vec![
                Some(1f64),
                Some(f64::MAX),
                Some(f64::MAX),
                None,
                Some(f64::NAN),
                Some(f64::INFINITY),
                Some(f64::INFINITY),
                Some(0.0),
                Some(0.0),
            ]),
            as_float64_array
        );

        // Decimals: wrapping semantics apply to the underlying integer too.
        eval_array_legacy_mode!(
            Decimal128Array::from(vec![Some(i128::MIN), Some(i128::MIN + 1), None])
                .with_precision_and_scale(38, 37)
                .unwrap(),
            Decimal128Array::from(vec![Some(i128::MIN), Some(i128::MAX), None])
                .with_precision_and_scale(38, 37)
                .unwrap(),
            as_decimal128_array
        );

        eval_array_legacy_mode!(
            Decimal256Array::from(vec![
                Some(i256::MIN),
                Some(i256::MINUS_ONE),
                Some(i256::MIN + i256::from(1)),
                None
            ])
            .with_precision_and_scale(5, 2)
            .unwrap(),
            Decimal256Array::from(vec![
                Some(i256::MIN),
                Some(i256::ONE),
                Some(i256::MAX),
                None
            ])
            .with_precision_and_scale(5, 2)
            .unwrap(),
            as_decimal256_array
        );
    }

    // One-arg form: expects spark_abs in ANSI mode to fail with an
    // "overflow on abs" error. Three-arg form: expects success and asserts
    // the result equals `$OUTPUT` (same shape as eval_array_legacy_mode).
    macro_rules! eval_array_ansi_mode {
        ($INPUT:expr) => {{
            let input = $INPUT;
            let args = ColumnarValue::Array(Arc::new(input));
            match spark_abs(&[args], true) {
                Err(e) => {
                    assert!(
                        e.to_string().contains("overflow on abs"),
                        "Error message did not match. Actual message: {e}"
                    );
                }
                _ => unreachable!(),
            }
        }};
        ($INPUT:expr, $OUTPUT:expr, $FUNC:ident) => {{
            let input = $INPUT;
            let args = ColumnarValue::Array(Arc::new(input));
            let expected = $OUTPUT;
            match spark_abs(&[args], true) {
                Ok(ColumnarValue::Array(result)) => {
                    let actual = datafusion_common::cast::$FUNC(&result).unwrap();
                    assert_eq!(actual, &expected);
                }
                _ => unreachable!(),
            }
        }};
    }
    #[test]
    fn test_abs_array_ansi_mode() {
        // Unsigned types are pass-through even in ANSI mode.
        eval_array_ansi_mode!(
            UInt64Array::from(vec![Some(u64::MIN), Some(u64::MAX), None]),
            UInt64Array::from(vec![Some(u64::MIN), Some(u64::MAX), None]),
            as_uint64_array
        );

        // Signed integers containing MIN must error in ANSI mode.
        eval_array_ansi_mode!(Int8Array::from(vec![
            Some(-1),
            Some(i8::MIN),
            Some(i8::MAX),
            None
        ]));
        eval_array_ansi_mode!(Int16Array::from(vec![
            Some(-1),
            Some(i16::MIN),
            Some(i16::MAX),
            None
        ]));
        eval_array_ansi_mode!(Int32Array::from(vec![
            Some(-1),
            Some(i32::MIN),
            Some(i32::MAX),
            None
        ]));
        eval_array_ansi_mode!(Int64Array::from(vec![
            Some(-1),
            Some(i64::MIN),
            Some(i64::MAX),
            None
        ]));
        // Floats never overflow, so ANSI mode still succeeds.
        eval_array_ansi_mode!(
            Float32Array::from(vec![
                Some(-1f32),
                Some(f32::MIN),
                Some(f32::MAX),
                None,
                Some(f32::NAN),
                Some(f32::INFINITY),
                Some(f32::NEG_INFINITY),
                Some(0.0),
                Some(-0.0),
            ]),
            Float32Array::from(vec![
                Some(1f32),
                Some(f32::MAX),
                Some(f32::MAX),
                None,
                Some(f32::NAN),
                Some(f32::INFINITY),
                Some(f32::INFINITY),
                Some(0.0),
                Some(0.0),
            ]),
            as_float32_array
        );

        eval_array_ansi_mode!(
            Float64Array::from(vec![
                Some(-1f64),
                Some(f64::MIN),
                Some(f64::MAX),
                None,
                Some(f64::NAN),
                Some(f64::INFINITY),
                Some(f64::NEG_INFINITY),
                Some(0.0),
                Some(-0.0),
            ]),
            Float64Array::from(vec![
                Some(1f64),
                Some(f64::MAX),
                Some(f64::MAX),
                None,
                Some(f64::NAN),
                Some(f64::INFINITY),
                Some(f64::INFINITY),
                Some(0.0),
                Some(0.0),
            ]),
            as_float64_array
        );

        // decimal: no arithmetic overflow
        eval_array_ansi_mode!(
            Decimal128Array::from(vec![Some(-1), Some(-2), Some(i128::MIN + 1)])
                .with_precision_and_scale(38, 37)
                .unwrap(),
            Decimal128Array::from(vec![Some(1), Some(2), Some(i128::MAX)])
                .with_precision_and_scale(38, 37)
                .unwrap(),
            as_decimal128_array
        );

        eval_array_ansi_mode!(
            Decimal256Array::from(vec![
                Some(i256::MINUS_ONE),
                Some(i256::from(-2)),
                Some(i256::MIN + i256::from(1))
            ])
            .with_precision_and_scale(18, 7)
            .unwrap(),
            Decimal256Array::from(vec![
                Some(i256::ONE),
                Some(i256::from(2)),
                Some(i256::MAX)
            ])
            .with_precision_and_scale(18, 7)
            .unwrap(),
            as_decimal256_array
        );

        // decimal: arithmetic overflow
        eval_array_ansi_mode!(
            Decimal128Array::from(vec![Some(i128::MIN), None])
                .with_precision_and_scale(38, 37)
                .unwrap()
        );
        eval_array_ansi_mode!(
            Decimal256Array::from(vec![Some(i256::MIN), None])
                .with_precision_and_scale(5, 2)
                .unwrap()
        );
    }

    // Verifies return_field_from_args propagates both data type and
    // nullability of the input field unchanged.
    #[test]
    fn test_abs_nullability() {
        use arrow::datatypes::{DataType, Field};
        use datafusion_expr::ReturnFieldArgs;
        use std::sync::Arc;

        let abs = SparkAbs::new();

        // --- non-nullable Int32 input ---
        let non_nullable_i32 = Arc::new(Field::new("c", DataType::Int32, false));
        let out_non_null = abs
            .return_field_from_args(ReturnFieldArgs {
                arg_fields: &[Arc::clone(&non_nullable_i32)],
                scalar_arguments: &[None],
            })
            .unwrap();

        // result should be non-nullable and the same DataType as input
        assert!(!out_non_null.is_nullable());
        assert_eq!(out_non_null.data_type(), &DataType::Int32);

        // --- nullable Int32 input ---
        let nullable_i32 = Arc::new(Field::new("c", DataType::Int32, true));
        let out_nullable = abs
            .return_field_from_args(ReturnFieldArgs {
                arg_fields: &[Arc::clone(&nullable_i32)],
                scalar_arguments: &[None],
            })
            .unwrap();

        // result should be nullable and the same DataType as input
        assert!(out_nullable.is_nullable());
        assert_eq!(out_nullable.data_type(), &DataType::Int32);

        // --- non-nullable Float64 input ---
        let non_nullable_f64 = Arc::new(Field::new("c", DataType::Float64, false));
        let out_f64 = abs
            .return_field_from_args(ReturnFieldArgs {
                arg_fields: &[Arc::clone(&non_nullable_f64)],
                scalar_arguments: &[None],
            })
            .unwrap();

        assert!(!out_f64.is_nullable());
        assert_eq!(out_f64.data_type(), &DataType::Float64);

        // --- nullable Float64 input ---
        let nullable_f64 = Arc::new(Field::new("c", DataType::Float64, true));
        let out_f64_null = abs
            .return_field_from_args(ReturnFieldArgs {
                arg_fields: &[Arc::clone(&nullable_f64)],
                scalar_arguments: &[None],
            })
            .unwrap();

        assert!(out_f64_null.is_nullable());
        assert_eq!(out_f64_null.data_type(), &DataType::Float64);
    }
}