// datafusion_spark/function/math/abs.rs
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::array::*;
19use arrow::datatypes::{DataType, Field, FieldRef};
20use datafusion_common::{DataFusionError, Result, ScalarValue, internal_err};
21use datafusion_expr::{
22    ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature,
23    Volatility,
24};
25use datafusion_functions::{
26    downcast_named_arg, make_abs_function, make_wrapping_abs_function,
27};
28use std::any::Any;
29use std::sync::Arc;
30
/// Spark-compatible `abs` expression
/// <https://spark.apache.org/docs/latest/api/sql/index.html#abs>
///
/// Returns the absolute value of input
/// Returns NULL if input is NULL, returns NaN if input is NaN.
///
/// TODOs:
///  - Spark's ANSI-compliant dialect, when off (i.e. `spark.sql.ansi.enabled=false`), taking absolute value on the minimal value of a signed integer returns the value as is. DataFusion's abs throws "DataFusion error: Arrow error: Compute error" on arithmetic overflow
///  - Spark's abs also supports ANSI interval types: YearMonthIntervalType and DayTimeIntervalType. DataFusion's abs doesn't.
///
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkAbs {
    // One numeric argument, Immutable volatility (same input always yields the
    // same output), built in `SparkAbs::new`.
    signature: Signature,
}
45
46impl Default for SparkAbs {
47    fn default() -> Self {
48        Self::new()
49    }
50}
51
52impl SparkAbs {
53    pub fn new() -> Self {
54        Self {
55            signature: Signature::numeric(1, Volatility::Immutable),
56        }
57    }
58}
59
60impl ScalarUDFImpl for SparkAbs {
61    fn as_any(&self) -> &dyn Any {
62        self
63    }
64
65    fn name(&self) -> &str {
66        "abs"
67    }
68
69    fn signature(&self) -> &Signature {
70        &self.signature
71    }
72
73    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
74        internal_err!(
75            "SparkAbs: return_type() is not used; return_field_from_args() is implemented"
76        )
77    }
78
79    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
80        let input_field = &args.arg_fields[0];
81        let out_dt = input_field.data_type().clone();
82        let out_nullable = input_field.is_nullable();
83
84        Ok(Arc::new(Field::new(self.name(), out_dt, out_nullable)))
85    }
86
87    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
88        spark_abs(&args.args)
89    }
90}
91
// Computes `wrapping_abs` on a scalar and rewraps it in the matching
// `ScalarValue` variant. Wrapping semantics intentionally mimic Spark's
// legacy (non-ANSI) mode: abs(i32::MIN) == i32::MIN instead of overflowing.
macro_rules! scalar_compute_op {
    // Integer variants: ScalarValue::$SCALAR_TYPE(Option<T>)
    ($INPUT:ident, $SCALAR_TYPE:ident) => {{
        let result = $INPUT.wrapping_abs();
        Ok(ColumnarValue::Scalar(ScalarValue::$SCALAR_TYPE(Some(
            result,
        ))))
    }};
    // Decimal variants: ScalarValue::$SCALAR_TYPE(Option<T>, precision, scale);
    // precision and scale are carried through unchanged.
    ($INPUT:ident, $PRECISION:expr, $SCALE:expr, $SCALAR_TYPE:ident) => {{
        let result = $INPUT.wrapping_abs();
        Ok(ColumnarValue::Scalar(ScalarValue::$SCALAR_TYPE(
            Some(result),
            $PRECISION,
            $SCALE,
        )))
    }};
}
108
109pub fn spark_abs(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
110    if args.len() != 1 {
111        return internal_err!("abs takes exactly 1 argument, but got: {}", args.len());
112    }
113
114    match &args[0] {
115        ColumnarValue::Array(array) => match array.data_type() {
116            DataType::Null
117            | DataType::UInt8
118            | DataType::UInt16
119            | DataType::UInt32
120            | DataType::UInt64 => Ok(args[0].clone()),
121            DataType::Int8 => {
122                let abs_fun = make_wrapping_abs_function!(Int8Array);
123                abs_fun(array).map(ColumnarValue::Array)
124            }
125            DataType::Int16 => {
126                let abs_fun = make_wrapping_abs_function!(Int16Array);
127                abs_fun(array).map(ColumnarValue::Array)
128            }
129            DataType::Int32 => {
130                let abs_fun = make_wrapping_abs_function!(Int32Array);
131                abs_fun(array).map(ColumnarValue::Array)
132            }
133            DataType::Int64 => {
134                let abs_fun = make_wrapping_abs_function!(Int64Array);
135                abs_fun(array).map(ColumnarValue::Array)
136            }
137            DataType::Float32 => {
138                let abs_fun = make_abs_function!(Float32Array);
139                abs_fun(array).map(ColumnarValue::Array)
140            }
141            DataType::Float64 => {
142                let abs_fun = make_abs_function!(Float64Array);
143                abs_fun(array).map(ColumnarValue::Array)
144            }
145            DataType::Decimal128(_, _) => {
146                let abs_fun = make_wrapping_abs_function!(Decimal128Array);
147                abs_fun(array).map(ColumnarValue::Array)
148            }
149            DataType::Decimal256(_, _) => {
150                let abs_fun = make_wrapping_abs_function!(Decimal256Array);
151                abs_fun(array).map(ColumnarValue::Array)
152            }
153            dt => internal_err!("Not supported datatype for Spark ABS: {dt}"),
154        },
155        ColumnarValue::Scalar(sv) => match sv {
156            ScalarValue::Null
157            | ScalarValue::UInt8(_)
158            | ScalarValue::UInt16(_)
159            | ScalarValue::UInt32(_)
160            | ScalarValue::UInt64(_) => Ok(args[0].clone()),
161            sv if sv.is_null() => Ok(args[0].clone()),
162            ScalarValue::Int8(Some(v)) => scalar_compute_op!(v, Int8),
163            ScalarValue::Int16(Some(v)) => scalar_compute_op!(v, Int16),
164            ScalarValue::Int32(Some(v)) => scalar_compute_op!(v, Int32),
165            ScalarValue::Int64(Some(v)) => scalar_compute_op!(v, Int64),
166            ScalarValue::Float32(Some(v)) => {
167                Ok(ColumnarValue::Scalar(ScalarValue::Float32(Some(v.abs()))))
168            }
169            ScalarValue::Float64(Some(v)) => {
170                Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(v.abs()))))
171            }
172            ScalarValue::Decimal128(Some(v), precision, scale) => {
173                scalar_compute_op!(v, *precision, *scale, Decimal128)
174            }
175            ScalarValue::Decimal256(Some(v), precision, scale) => {
176                scalar_compute_op!(v, *precision, *scale, Decimal256)
177            }
178            dt => internal_err!("Not supported datatype for Spark ABS: {dt}"),
179        },
180    }
181}
182
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::i256;

    // Evaluates `spark_abs` on a single scalar in Spark's legacy (non-ANSI)
    // mode and checks the result. Four arities:
    //   ($TYPE, $VAL)                               -> result expected to equal $VAL
    //   ($TYPE, $VAL, $RESULT)                      -> result expected to equal $RESULT
    //   ($TYPE, $VAL, $PRECISION, $SCALE)           -> decimal; result == $VAL, precision/scale preserved
    //   ($TYPE, $VAL, $PRECISION, $SCALE, $RESULT)  -> decimal; result == $RESULT, precision/scale preserved
    macro_rules! eval_legacy_mode {
        ($TYPE:ident, $VAL:expr) => {{
            let args = ColumnarValue::Scalar(ScalarValue::$TYPE(Some($VAL)));
            match spark_abs(&[args]) {
                Ok(ColumnarValue::Scalar(ScalarValue::$TYPE(Some(result)))) => {
                    assert_eq!(result, $VAL);
                }
                _ => unreachable!(),
            }
        }};
        ($TYPE:ident, $VAL:expr, $RESULT:expr) => {{
            let args = ColumnarValue::Scalar(ScalarValue::$TYPE(Some($VAL)));
            match spark_abs(&[args]) {
                Ok(ColumnarValue::Scalar(ScalarValue::$TYPE(Some(result)))) => {
                    assert_eq!(result, $RESULT);
                }
                _ => unreachable!(),
            }
        }};
        ($TYPE:ident, $VAL:expr, $PRECISION:expr, $SCALE:expr) => {{
            let args =
                ColumnarValue::Scalar(ScalarValue::$TYPE(Some($VAL), $PRECISION, $SCALE));
            match spark_abs(&[args]) {
                Ok(ColumnarValue::Scalar(ScalarValue::$TYPE(
                    Some(result),
                    precision,
                    scale,
                ))) => {
                    assert_eq!(result, $VAL);
                    assert_eq!(precision, $PRECISION);
                    assert_eq!(scale, $SCALE);
                }
                _ => unreachable!(),
            }
        }};
        ($TYPE:ident, $VAL:expr, $PRECISION:expr, $SCALE:expr, $RESULT:expr) => {{
            let args =
                ColumnarValue::Scalar(ScalarValue::$TYPE(Some($VAL), $PRECISION, $SCALE));
            match spark_abs(&[args]) {
                Ok(ColumnarValue::Scalar(ScalarValue::$TYPE(
                    Some(result),
                    precision,
                    scale,
                ))) => {
                    assert_eq!(result, $RESULT);
                    assert_eq!(precision, $PRECISION);
                    assert_eq!(scale, $SCALE);
                }
                _ => unreachable!(),
            }
        }};
    }

    #[test]
    fn test_abs_scalar_legacy_mode() {
        // NumericType MIN: with wrapping semantics, abs of the minimum signed
        // value returns the value unchanged (no overflow error); unsigned
        // MIN (0) is trivially its own absolute value.
        eval_legacy_mode!(UInt8, u8::MIN);
        eval_legacy_mode!(UInt16, u16::MIN);
        eval_legacy_mode!(UInt32, u32::MIN);
        eval_legacy_mode!(UInt64, u64::MIN);
        eval_legacy_mode!(Int8, i8::MIN);
        eval_legacy_mode!(Int16, i16::MIN);
        eval_legacy_mode!(Int32, i32::MIN);
        eval_legacy_mode!(Int64, i64::MIN);
        // Float MIN is finite (most negative), so abs yields MAX.
        eval_legacy_mode!(Float32, f32::MIN, f32::MAX);
        eval_legacy_mode!(Float64, f64::MIN, f64::MAX);
        eval_legacy_mode!(Decimal128, i128::MIN, 18, 10);
        eval_legacy_mode!(Decimal256, i256::MIN, 10, 2);

        // NumericType not MIN: ordinary negation applies.
        eval_legacy_mode!(Int8, -1i8, 1i8);
        eval_legacy_mode!(Int16, -1i16, 1i16);
        eval_legacy_mode!(Int32, -1i32, 1i32);
        eval_legacy_mode!(Int64, -1i64, 1i64);
        eval_legacy_mode!(Decimal128, -1i128, 18, 10, 1i128);
        eval_legacy_mode!(Decimal256, i256::from(-1i8), 10, 2, i256::from(1i8));

        // Float32, Float64: infinities and signed zero are normalized by abs.
        eval_legacy_mode!(Float32, f32::NEG_INFINITY, f32::INFINITY);
        eval_legacy_mode!(Float32, f32::INFINITY, f32::INFINITY);
        eval_legacy_mode!(Float32, 0.0f32, 0.0f32);
        eval_legacy_mode!(Float32, -0.0f32, 0.0f32);
        eval_legacy_mode!(Float64, f64::NEG_INFINITY, f64::INFINITY);
        eval_legacy_mode!(Float64, f64::INFINITY, f64::INFINITY);
        eval_legacy_mode!(Float64, 0.0f64, 0.0f64);
        eval_legacy_mode!(Float64, -0.0f64, 0.0f64);
    }

    // Evaluates `spark_abs` on an array input and compares element-wise
    // against the expected array. $FUNC is the datafusion downcast helper
    // (e.g. `as_int8_array`) matching the concrete array type.
    macro_rules! eval_array_legacy_mode {
        ($INPUT:expr, $OUTPUT:expr, $FUNC:ident) => {{
            let input = $INPUT;
            let args = ColumnarValue::Array(Arc::new(input));
            let expected = $OUTPUT;
            match spark_abs(&[args]) {
                Ok(ColumnarValue::Array(result)) => {
                    let actual = datafusion_common::cast::$FUNC(&result).unwrap();
                    assert_eq!(actual, &expected);
                }
                _ => unreachable!(),
            }
        }};
    }

    #[test]
    fn test_abs_array_legacy_mode() {
        // Signed integers: MIN wraps to itself, nulls propagate.
        eval_array_legacy_mode!(
            Int8Array::from(vec![Some(-1), Some(i8::MIN), Some(i8::MAX), None]),
            Int8Array::from(vec![Some(1), Some(i8::MIN), Some(i8::MAX), None]),
            as_int8_array
        );

        eval_array_legacy_mode!(
            Int16Array::from(vec![Some(-1), Some(i16::MIN), Some(i16::MAX), None]),
            Int16Array::from(vec![Some(1), Some(i16::MIN), Some(i16::MAX), None]),
            as_int16_array
        );

        eval_array_legacy_mode!(
            Int32Array::from(vec![Some(-1), Some(i32::MIN), Some(i32::MAX), None]),
            Int32Array::from(vec![Some(1), Some(i32::MIN), Some(i32::MAX), None]),
            as_int32_array
        );

        eval_array_legacy_mode!(
            Int64Array::from(vec![Some(-1), Some(i64::MIN), Some(i64::MAX), None]),
            Int64Array::from(vec![Some(1), Some(i64::MIN), Some(i64::MAX), None]),
            as_int64_array
        );

        // Floats: NaN stays NaN, -inf -> +inf, -0.0 -> 0.0, nulls propagate.
        // NOTE: the NaN equality here relies on arrow's array comparison
        // treating NaN slots as equal (total order), unlike `f32 == f32`.
        eval_array_legacy_mode!(
            Float32Array::from(vec![
                Some(-1f32),
                Some(f32::MIN),
                Some(f32::MAX),
                None,
                Some(f32::NAN),
                Some(f32::INFINITY),
                Some(f32::NEG_INFINITY),
                Some(0.0),
                Some(-0.0),
            ]),
            Float32Array::from(vec![
                Some(1f32),
                Some(f32::MAX),
                Some(f32::MAX),
                None,
                Some(f32::NAN),
                Some(f32::INFINITY),
                Some(f32::INFINITY),
                Some(0.0),
                Some(0.0),
            ]),
            as_float32_array
        );

        eval_array_legacy_mode!(
            Float64Array::from(vec![
                Some(-1f64),
                Some(f64::MIN),
                Some(f64::MAX),
                None,
                Some(f64::NAN),
                Some(f64::INFINITY),
                Some(f64::NEG_INFINITY),
                Some(0.0),
                Some(-0.0),
            ]),
            Float64Array::from(vec![
                Some(1f64),
                Some(f64::MAX),
                Some(f64::MAX),
                None,
                Some(f64::NAN),
                Some(f64::INFINITY),
                Some(f64::INFINITY),
                Some(0.0),
                Some(0.0),
            ]),
            as_float64_array
        );

        // Decimals: wrapping abs of i128::MIN / i256::MIN is the identity;
        // precision and scale metadata must be preserved.
        eval_array_legacy_mode!(
            Decimal128Array::from(vec![Some(i128::MIN), None])
                .with_precision_and_scale(38, 37)
                .unwrap(),
            Decimal128Array::from(vec![Some(i128::MIN), None])
                .with_precision_and_scale(38, 37)
                .unwrap(),
            as_decimal128_array
        );

        eval_array_legacy_mode!(
            Decimal256Array::from(vec![Some(i256::MIN), None])
                .with_precision_and_scale(5, 2)
                .unwrap(),
            Decimal256Array::from(vec![Some(i256::MIN), None])
                .with_precision_and_scale(5, 2)
                .unwrap(),
            as_decimal256_array
        );
    }

    // Verifies return_field_from_args mirrors the input field's datatype and
    // nullability for both nullable and non-nullable inputs.
    #[test]
    fn test_abs_nullability() {
        use arrow::datatypes::{DataType, Field};
        use datafusion_expr::ReturnFieldArgs;
        use std::sync::Arc;

        let abs = SparkAbs::new();

        // --- non-nullable Int32 input ---
        let non_nullable_i32 = Arc::new(Field::new("c", DataType::Int32, false));
        let out_non_null = abs
            .return_field_from_args(ReturnFieldArgs {
                arg_fields: &[Arc::clone(&non_nullable_i32)],
                scalar_arguments: &[None],
            })
            .unwrap();

        // result should be non-nullable and the same DataType as input
        assert!(!out_non_null.is_nullable());
        assert_eq!(out_non_null.data_type(), &DataType::Int32);

        // --- nullable Int32 input ---
        let nullable_i32 = Arc::new(Field::new("c", DataType::Int32, true));
        let out_nullable = abs
            .return_field_from_args(ReturnFieldArgs {
                arg_fields: &[Arc::clone(&nullable_i32)],
                scalar_arguments: &[None],
            })
            .unwrap();

        // result should be nullable and the same DataType as input
        assert!(out_nullable.is_nullable());
        assert_eq!(out_nullable.data_type(), &DataType::Int32);

        // --- non-nullable Float64 input ---
        let non_nullable_f64 = Arc::new(Field::new("c", DataType::Float64, false));
        let out_f64 = abs
            .return_field_from_args(ReturnFieldArgs {
                arg_fields: &[Arc::clone(&non_nullable_f64)],
                scalar_arguments: &[None],
            })
            .unwrap();

        assert!(!out_f64.is_nullable());
        assert_eq!(out_f64.data_type(), &DataType::Float64);

        // --- nullable Float64 input ---
        let nullable_f64 = Arc::new(Field::new("c", DataType::Float64, true));
        let out_f64_null = abs
            .return_field_from_args(ReturnFieldArgs {
                arg_fields: &[Arc::clone(&nullable_f64)],
                scalar_arguments: &[None],
            })
            .unwrap();

        assert!(out_f64_null.is_nullable());
        assert_eq!(out_f64_null.data_type(), &DataType::Float64);
    }
}