Skip to main content

datafusion_spark/function/math/
abs.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::array::*;
19use arrow::datatypes::{DataType, Field, FieldRef};
20use arrow::error::ArrowError;
21use datafusion_common::{DataFusionError, Result, ScalarValue, internal_err};
22use datafusion_expr::{
23    ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature,
24    Volatility,
25};
26use datafusion_functions::{
27    downcast_named_arg, make_abs_function, make_try_abs_function,
28    make_wrapping_abs_function,
29};
30use std::sync::Arc;
31
32/// Spark-compatible `abs` expression
33/// <https://spark.apache.org/docs/latest/api/sql/index.html#abs>
34///
35/// Returns the absolute value of input
36/// Returns NULL if input is NULL, returns NaN if input is NaN.
37///
38/// Differences with DataFusion abs:
39///  - Spark's ANSI-compliant dialect, when off (i.e. `spark.sql.ansi.enabled=false`), taking absolute value on the minimal value of a signed integer returns the value as is. DataFusion's abs throws "DataFusion error: Arrow error: Compute error" on arithmetic overflow
40///
41/// TODOs:
42///  - Spark's abs also supports ANSI interval types: YearMonthIntervalType and DayTimeIntervalType. DataFusion's abs doesn't.
43///
44#[derive(Debug, PartialEq, Eq, Hash)]
45pub struct SparkAbs {
46    signature: Signature,
47}
48
49impl Default for SparkAbs {
50    fn default() -> Self {
51        Self::new()
52    }
53}
54
55impl SparkAbs {
56    pub fn new() -> Self {
57        Self {
58            signature: Signature::numeric(1, Volatility::Immutable),
59        }
60    }
61}
62
63impl ScalarUDFImpl for SparkAbs {
64    fn name(&self) -> &str {
65        "abs"
66    }
67
68    fn signature(&self) -> &Signature {
69        &self.signature
70    }
71
72    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
73        internal_err!(
74            "SparkAbs: return_type() is not used; return_field_from_args() is implemented"
75        )
76    }
77
78    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
79        let input_field = &args.arg_fields[0];
80        let out_dt = input_field.data_type().clone();
81        let out_nullable = input_field.is_nullable();
82
83        Ok(Arc::new(Field::new(self.name(), out_dt, out_nullable)))
84    }
85
86    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
87        spark_abs(&args.args, args.config_options.execution.enable_ansi_mode)
88    }
89}
90
/// Computes `abs` of a non-null scalar and wraps the result in a
/// `ColumnarValue::Scalar` of the `ScalarValue` variant `$SCALAR_TYPE`.
///
/// When `$ENABLE_ANSI_MODE` is true the checked form is used and an overflow
/// is propagated with `?` as an `ArrowError::ComputeError`; otherwise the
/// wrapping form is used. The five-argument form is for variants that also
/// carry a precision and a scale.
///
/// Must be expanded inside a function whose error type can be built from
/// `ArrowError`, because the expansion contains `?`.
macro_rules! scalar_compute_op {
    // Internal rule: yields the bare absolute value (shared by both public forms).
    (@abs $ENABLE_ANSI_MODE:expr, $INPUT:ident, $SCALAR_TYPE:ident) => {
        if $ENABLE_ANSI_MODE {
            $INPUT.checked_abs().ok_or_else(|| {
                ArrowError::ComputeError(format!(
                    "{} overflow on abs({:?})",
                    stringify!($SCALAR_TYPE),
                    $INPUT
                ))
            })?
        } else {
            $INPUT.wrapping_abs()
        }
    };
    ($ENABLE_ANSI_MODE:expr, $INPUT:ident, $SCALAR_TYPE:ident) => {{
        let abs_value =
            scalar_compute_op!(@abs $ENABLE_ANSI_MODE, $INPUT, $SCALAR_TYPE);
        Ok(ColumnarValue::Scalar(ScalarValue::$SCALAR_TYPE(Some(
            abs_value,
        ))))
    }};
    ($ENABLE_ANSI_MODE:expr, $INPUT:ident, $PRECISION:expr, $SCALE:expr, $SCALAR_TYPE:ident) => {{
        let abs_value =
            scalar_compute_op!(@abs $ENABLE_ANSI_MODE, $INPUT, $SCALAR_TYPE);
        Ok(ColumnarValue::Scalar(ScalarValue::$SCALAR_TYPE(
            Some(abs_value),
            $PRECISION,
            $SCALE,
        )))
    }};
}
127
128pub fn spark_abs(
129    args: &[ColumnarValue],
130    enable_ansi_mode: bool,
131) -> Result<ColumnarValue, DataFusionError> {
132    if args.len() != 1 {
133        return internal_err!("abs takes exactly 1 argument, but got: {}", args.len());
134    }
135
136    match &args[0] {
137        ColumnarValue::Array(array) => match array.data_type() {
138            DataType::Null
139            | DataType::UInt8
140            | DataType::UInt16
141            | DataType::UInt32
142            | DataType::UInt64 => Ok(args[0].clone()),
143            DataType::Int8 => {
144                let abs_fun = if enable_ansi_mode {
145                    make_try_abs_function!(Int8Array)
146                } else {
147                    make_wrapping_abs_function!(Int8Array)
148                };
149                abs_fun(array).map(ColumnarValue::Array)
150            }
151            DataType::Int16 => {
152                let abs_fun = if enable_ansi_mode {
153                    make_try_abs_function!(Int16Array)
154                } else {
155                    make_wrapping_abs_function!(Int16Array)
156                };
157                abs_fun(array).map(ColumnarValue::Array)
158            }
159            DataType::Int32 => {
160                let abs_fun = if enable_ansi_mode {
161                    make_try_abs_function!(Int32Array)
162                } else {
163                    make_wrapping_abs_function!(Int32Array)
164                };
165                abs_fun(array).map(ColumnarValue::Array)
166            }
167            DataType::Int64 => {
168                let abs_fun = if enable_ansi_mode {
169                    make_try_abs_function!(Int64Array)
170                } else {
171                    make_wrapping_abs_function!(Int64Array)
172                };
173                abs_fun(array).map(ColumnarValue::Array)
174            }
175            DataType::Float32 => {
176                let abs_fun = make_abs_function!(Float32Array);
177                abs_fun(array).map(ColumnarValue::Array)
178            }
179            DataType::Float64 => {
180                let abs_fun = make_abs_function!(Float64Array);
181                abs_fun(array).map(ColumnarValue::Array)
182            }
183            DataType::Decimal128(_, _) => {
184                let abs_fun = if enable_ansi_mode {
185                    make_try_abs_function!(Decimal128Array)
186                } else {
187                    make_wrapping_abs_function!(Decimal128Array)
188                };
189                abs_fun(array).map(ColumnarValue::Array)
190            }
191            DataType::Decimal256(_, _) => {
192                let abs_fun = if enable_ansi_mode {
193                    make_try_abs_function!(Decimal256Array)
194                } else {
195                    make_wrapping_abs_function!(Decimal256Array)
196                };
197                abs_fun(array).map(ColumnarValue::Array)
198            }
199            dt => internal_err!("Not supported datatype for Spark ABS: {dt}"),
200        },
201        ColumnarValue::Scalar(sv) => match sv {
202            ScalarValue::Null
203            | ScalarValue::UInt8(_)
204            | ScalarValue::UInt16(_)
205            | ScalarValue::UInt32(_)
206            | ScalarValue::UInt64(_) => Ok(args[0].clone()),
207            sv if sv.is_null() => Ok(args[0].clone()),
208            ScalarValue::Int8(Some(v)) => scalar_compute_op!(enable_ansi_mode, v, Int8),
209            ScalarValue::Int16(Some(v)) => scalar_compute_op!(enable_ansi_mode, v, Int16),
210            ScalarValue::Int32(Some(v)) => scalar_compute_op!(enable_ansi_mode, v, Int32),
211            ScalarValue::Int64(Some(v)) => scalar_compute_op!(enable_ansi_mode, v, Int64),
212            ScalarValue::Float32(Some(v)) => {
213                Ok(ColumnarValue::Scalar(ScalarValue::Float32(Some(v.abs()))))
214            }
215            ScalarValue::Float64(Some(v)) => {
216                Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(v.abs()))))
217            }
218            ScalarValue::Decimal128(Some(v), precision, scale) => {
219                scalar_compute_op!(enable_ansi_mode, v, *precision, *scale, Decimal128)
220            }
221            ScalarValue::Decimal256(Some(v), precision, scale) => {
222                scalar_compute_op!(enable_ansi_mode, v, *precision, *scale, Decimal256)
223            }
224            dt => internal_err!("Not supported datatype for Spark ABS: {dt}"),
225        },
226    }
227}
228
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::i256;

    /// Evaluates `spark_abs` on a single array argument and unwraps the
    /// array result.
    fn abs_array(input: ArrayRef, enable_ansi_mode: bool) -> Result<ArrayRef> {
        match spark_abs(&[ColumnarValue::Array(input)], enable_ansi_mode)? {
            ColumnarValue::Array(result) => Ok(result),
            other => panic!("expected an array result, got: {other:?}"),
        }
    }

    /// Evaluates `spark_abs` on a single scalar argument and unwraps the
    /// scalar result.
    fn abs_scalar(input: ScalarValue, enable_ansi_mode: bool) -> Result<ScalarValue> {
        match spark_abs(&[ColumnarValue::Scalar(input)], enable_ansi_mode)? {
            ColumnarValue::Scalar(result) => Ok(result),
            other => panic!("expected a scalar result, got: {other:?}"),
        }
    }

    /// Asserts that `abs(input)` succeeds and equals `expected`.
    fn assert_abs_array<A>(input: A, expected: A, enable_ansi_mode: bool)
    where
        A: Array + PartialEq + std::fmt::Debug + 'static,
    {
        let result = abs_array(Arc::new(input), enable_ansi_mode)
            .expect("abs should succeed for this input");
        let actual = result
            .as_any()
            .downcast_ref::<A>()
            .expect("result array should have the same array type as the input");
        assert_eq!(actual, &expected);
    }

    /// Asserts that `abs(input)` fails with an overflow error in ANSI mode.
    fn assert_abs_array_overflows<A>(input: A)
    where
        A: Array + 'static,
    {
        match abs_array(Arc::new(input), true) {
            Err(e) => assert!(
                e.to_string().contains("overflow on abs"),
                "Error message did not match. Actual message: {e}"
            ),
            Ok(result) => panic!("expected an overflow error, got: {result:?}"),
        }
    }

    /// Float32 input and expected output; floats behave the same in both modes.
    fn float32_cases() -> (Float32Array, Float32Array) {
        let input = Float32Array::from(vec![
            Some(-1f32),
            Some(f32::MIN),
            Some(f32::MAX),
            None,
            Some(f32::NAN),
            Some(f32::INFINITY),
            Some(f32::NEG_INFINITY),
            Some(0.0),
            Some(-0.0),
        ]);
        let expected = Float32Array::from(vec![
            Some(1f32),
            Some(f32::MAX),
            Some(f32::MAX),
            None,
            Some(f32::NAN),
            Some(f32::INFINITY),
            Some(f32::INFINITY),
            Some(0.0),
            Some(0.0),
        ]);
        (input, expected)
    }

    /// Float64 input and expected output; floats behave the same in both modes.
    fn float64_cases() -> (Float64Array, Float64Array) {
        let input = Float64Array::from(vec![
            Some(-1f64),
            Some(f64::MIN),
            Some(f64::MAX),
            None,
            Some(f64::NAN),
            Some(f64::INFINITY),
            Some(f64::NEG_INFINITY),
            Some(0.0),
            Some(-0.0),
        ]);
        let expected = Float64Array::from(vec![
            Some(1f64),
            Some(f64::MAX),
            Some(f64::MAX),
            None,
            Some(f64::NAN),
            Some(f64::INFINITY),
            Some(f64::INFINITY),
            Some(0.0),
            Some(0.0),
        ]);
        (input, expected)
    }

    #[test]
    fn test_abs_array_legacy_mode() {
        // Signed integers: the minimum value wraps around to itself.
        assert_abs_array(
            Int8Array::from(vec![Some(-1), Some(i8::MIN), Some(i8::MAX), None]),
            Int8Array::from(vec![Some(1), Some(i8::MIN), Some(i8::MAX), None]),
            false,
        );
        assert_abs_array(
            Int16Array::from(vec![Some(-1), Some(i16::MIN), Some(i16::MAX), None]),
            Int16Array::from(vec![Some(1), Some(i16::MIN), Some(i16::MAX), None]),
            false,
        );
        assert_abs_array(
            Int32Array::from(vec![Some(-1), Some(i32::MIN), Some(i32::MAX), None]),
            Int32Array::from(vec![Some(1), Some(i32::MIN), Some(i32::MAX), None]),
            false,
        );
        assert_abs_array(
            Int64Array::from(vec![Some(-1), Some(i64::MIN), Some(i64::MAX), None]),
            Int64Array::from(vec![Some(1), Some(i64::MIN), Some(i64::MAX), None]),
            false,
        );

        let (input, expected) = float32_cases();
        assert_abs_array(input, expected, false);

        let (input, expected) = float64_cases();
        assert_abs_array(input, expected, false);

        assert_abs_array(
            Decimal128Array::from(vec![Some(i128::MIN), Some(i128::MIN + 1), None])
                .with_precision_and_scale(38, 37)
                .unwrap(),
            Decimal128Array::from(vec![Some(i128::MIN), Some(i128::MAX), None])
                .with_precision_and_scale(38, 37)
                .unwrap(),
            false,
        );

        assert_abs_array(
            Decimal256Array::from(vec![
                Some(i256::MIN),
                Some(i256::MINUS_ONE),
                Some(i256::MIN + i256::from(1)),
                None,
            ])
            .with_precision_and_scale(5, 2)
            .unwrap(),
            Decimal256Array::from(vec![
                Some(i256::MIN),
                Some(i256::ONE),
                Some(i256::MAX),
                None,
            ])
            .with_precision_and_scale(5, 2)
            .unwrap(),
            false,
        );
    }

    #[test]
    fn test_abs_array_ansi_mode() {
        // Unsigned input is passed through untouched.
        assert_abs_array(
            UInt64Array::from(vec![Some(u64::MIN), Some(u64::MAX), None]),
            UInt64Array::from(vec![Some(u64::MIN), Some(u64::MAX), None]),
            true,
        );

        // Signed integers: the minimum value overflows.
        assert_abs_array_overflows(Int8Array::from(vec![
            Some(-1),
            Some(i8::MIN),
            Some(i8::MAX),
            None,
        ]));
        assert_abs_array_overflows(Int16Array::from(vec![
            Some(-1),
            Some(i16::MIN),
            Some(i16::MAX),
            None,
        ]));
        assert_abs_array_overflows(Int32Array::from(vec![
            Some(-1),
            Some(i32::MIN),
            Some(i32::MAX),
            None,
        ]));
        assert_abs_array_overflows(Int64Array::from(vec![
            Some(-1),
            Some(i64::MIN),
            Some(i64::MAX),
            None,
        ]));

        let (input, expected) = float32_cases();
        assert_abs_array(input, expected, true);

        let (input, expected) = float64_cases();
        assert_abs_array(input, expected, true);

        // decimal: no arithmetic overflow
        assert_abs_array(
            Decimal128Array::from(vec![Some(-1), Some(-2), Some(i128::MIN + 1)])
                .with_precision_and_scale(38, 37)
                .unwrap(),
            Decimal128Array::from(vec![Some(1), Some(2), Some(i128::MAX)])
                .with_precision_and_scale(38, 37)
                .unwrap(),
            true,
        );

        assert_abs_array(
            Decimal256Array::from(vec![
                Some(i256::MINUS_ONE),
                Some(i256::from(-2)),
                Some(i256::MIN + i256::from(1)),
            ])
            .with_precision_and_scale(18, 7)
            .unwrap(),
            Decimal256Array::from(vec![
                Some(i256::ONE),
                Some(i256::from(2)),
                Some(i256::MAX),
            ])
            .with_precision_and_scale(18, 7)
            .unwrap(),
            true,
        );

        // decimal: arithmetic overflow
        assert_abs_array_overflows(
            Decimal128Array::from(vec![Some(i128::MIN), None])
                .with_precision_and_scale(38, 37)
                .unwrap(),
        );
        assert_abs_array_overflows(
            Decimal256Array::from(vec![Some(i256::MIN), None])
                .with_precision_and_scale(5, 2)
                .unwrap(),
        );
    }

    #[test]
    fn test_abs_scalar_legacy_mode() {
        assert_eq!(
            abs_scalar(ScalarValue::Int32(Some(-7)), false).unwrap(),
            ScalarValue::Int32(Some(7))
        );
        // The minimum value wraps around to itself.
        assert_eq!(
            abs_scalar(ScalarValue::Int32(Some(i32::MIN)), false).unwrap(),
            ScalarValue::Int32(Some(i32::MIN))
        );
        assert_eq!(
            abs_scalar(ScalarValue::Float64(Some(-1.5)), false).unwrap(),
            ScalarValue::Float64(Some(1.5))
        );
        assert_eq!(
            abs_scalar(ScalarValue::Decimal128(Some(i128::MIN), 38, 37), false).unwrap(),
            ScalarValue::Decimal128(Some(i128::MIN), 38, 37)
        );
    }

    #[test]
    fn test_abs_scalar_ansi_mode() {
        assert_eq!(
            abs_scalar(ScalarValue::Int64(Some(-7)), true).unwrap(),
            ScalarValue::Int64(Some(7))
        );
        // Precision and scale are carried over to the result.
        assert_eq!(
            abs_scalar(ScalarValue::Decimal128(Some(-5), 10, 2), true).unwrap(),
            ScalarValue::Decimal128(Some(5), 10, 2)
        );
        // Typed NULL and unsigned scalars are passed through untouched.
        assert_eq!(
            abs_scalar(ScalarValue::Int64(None), true).unwrap(),
            ScalarValue::Int64(None)
        );
        assert_eq!(
            abs_scalar(ScalarValue::UInt8(Some(u8::MAX)), true).unwrap(),
            ScalarValue::UInt8(Some(u8::MAX))
        );

        // The minimum value overflows.
        let err = abs_scalar(ScalarValue::Int32(Some(i32::MIN)), true).unwrap_err();
        assert!(
            err.to_string().contains("overflow on abs"),
            "Error message did not match. Actual message: {err}"
        );
    }

    /// Output field that `SparkAbs` reports for one input field of the given
    /// type and nullability.
    fn return_field_for(data_type: DataType, nullable: bool) -> FieldRef {
        let input = Arc::new(Field::new("c", data_type, nullable));
        SparkAbs::new()
            .return_field_from_args(ReturnFieldArgs {
                arg_fields: &[input],
                scalar_arguments: &[None],
            })
            .unwrap()
    }

    #[test]
    fn test_abs_nullability() {
        // The result keeps both the nullability and the DataType of the input.
        for data_type in [DataType::Int32, DataType::Float64] {
            for nullable in [false, true] {
                let out = return_field_for(data_type.clone(), nullable);
                assert_eq!(out.is_nullable(), nullable);
                assert_eq!(out.data_type(), &data_type);
            }
        }
    }
}