Skip to main content

datafusion_functions/math/
abs.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! math expressions
19
20use std::sync::Arc;
21
22use arrow::array::{
23    ArrayRef, Decimal32Array, Decimal64Array, Decimal128Array, Decimal256Array,
24    Float16Array, Float32Array, Float64Array, Int8Array, Int16Array, Int32Array,
25    Int64Array,
26};
27use arrow::datatypes::DataType;
28use arrow::error::ArrowError;
29use datafusion_common::{Result, not_impl_err, utils::take_function_args};
30use datafusion_expr::interval_arithmetic::Interval;
31use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
32use datafusion_expr::{
33    ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
34    Volatility,
35};
36use datafusion_macros::user_doc;
37use num_traits::sign::Signed;
38
39type MathArrayFunction = fn(&ArrayRef) -> Result<ArrayRef>;
40
/// Builds an `abs` kernel for `$ARRAY_TYPE` that applies `abs()` to every value
/// with no overflow check (used for floating-point arrays).
#[macro_export]
macro_rules! make_abs_function {
    ($ARRAY_TYPE:ident) => {{
        |input: &ArrayRef| {
            let values = downcast_named_arg!(&input, "abs arg", $ARRAY_TYPE);
            let magnitudes: $ARRAY_TYPE = values.unary(|v| v.abs());
            let out: ArrayRef = Arc::new(magnitudes);
            Ok(out)
        }
    }};
}
51
/// Builds a checked `abs` kernel for `$ARRAY_TYPE`.
///
/// Any value whose absolute value is not representable (for example `abs(-128_i8)`)
/// makes the whole call fail with an `ArrowError::ComputeError` of the form
/// `"<ArrayType> overflow on abs(<value>)"`.
#[macro_export]
macro_rules! make_try_abs_function {
    ($ARRAY_TYPE:ident) => {{
        |input: &ArrayRef| {
            let array = downcast_named_arg!(&input, "abs arg", $ARRAY_TYPE);
            let res: $ARRAY_TYPE = array
                .try_unary(|x| {
                    // `checked_abs` yields `None` when the result would overflow.
                    x.checked_abs().ok_or_else(|| {
                        ArrowError::ComputeError(format!(
                            "{} overflow on abs({})",
                            stringify!($ARRAY_TYPE),
                            x
                        ))
                    })
                })
                // Re-apply the input's exact data type so any type parameters
                // (e.g. decimal precision/scale) are carried over to the result.
                .map(|v| v.with_data_type(input.data_type().clone()))?;
            Ok(Arc::new(res) as ArrayRef)
        }
    }};
}
71
/// Builds a wrapping `abs` kernel for `$ARRAY_TYPE` (used for decimal arrays).
///
/// `wrapping_abs` never fails, so this kernel never returns an overflow error; the
/// result keeps the input's data type (and therefore its precision and scale).
#[macro_export]
macro_rules! make_wrapping_abs_function {
    ($ARRAY_TYPE:ident) => {{
        |input: &ArrayRef| {
            let decimals = downcast_named_arg!(&input, "abs arg", $ARRAY_TYPE);
            let magnitudes: $ARRAY_TYPE = decimals.unary(|d| d.wrapping_abs());
            // Carry the input's exact data type over to the result.
            let magnitudes = magnitudes.with_data_type(input.data_type().clone());
            Ok(Arc::new(magnitudes) as ArrayRef)
        }
    }};
}
84
85/// Abs SQL function
86/// Return different implementations based on input datatype to reduce branches during execution
87fn create_abs_function(input_data_type: &DataType) -> Result<MathArrayFunction> {
88    match input_data_type {
89        DataType::Float16 => Ok(make_abs_function!(Float16Array)),
90        DataType::Float32 => Ok(make_abs_function!(Float32Array)),
91        DataType::Float64 => Ok(make_abs_function!(Float64Array)),
92
93        // Types that may overflow, such as abs(-128_i8).
94        DataType::Int8 => Ok(make_try_abs_function!(Int8Array)),
95        DataType::Int16 => Ok(make_try_abs_function!(Int16Array)),
96        DataType::Int32 => Ok(make_try_abs_function!(Int32Array)),
97        DataType::Int64 => Ok(make_try_abs_function!(Int64Array)),
98
99        // Types of results are the same as the input.
100        DataType::Null
101        | DataType::UInt8
102        | DataType::UInt16
103        | DataType::UInt32
104        | DataType::UInt64 => Ok(|input: &ArrayRef| Ok(Arc::clone(input))),
105
106        // Decimal types
107        DataType::Decimal32(_, _) => Ok(make_wrapping_abs_function!(Decimal32Array)),
108        DataType::Decimal64(_, _) => Ok(make_wrapping_abs_function!(Decimal64Array)),
109        DataType::Decimal128(_, _) => Ok(make_wrapping_abs_function!(Decimal128Array)),
110        DataType::Decimal256(_, _) => Ok(make_wrapping_abs_function!(Decimal256Array)),
111
112        other => not_impl_err!("Unsupported data type {other:?} for function abs"),
113    }
114}
115#[user_doc(
116    doc_section(label = "Math Functions"),
117    description = "Returns the absolute value of a number.",
118    syntax_example = "abs(numeric_expression)",
119    sql_example = r#"```sql
120> SELECT abs(-5);
121+----------+
122| abs(-5)  |
123+----------+
124| 5        |
125+----------+
126```"#,
127    standard_argument(name = "numeric_expression", prefix = "Numeric")
128)]
129#[derive(Debug, PartialEq, Eq, Hash)]
130pub struct AbsFunc {
131    signature: Signature,
132}
133
134impl Default for AbsFunc {
135    fn default() -> Self {
136        Self::new()
137    }
138}
139
140impl AbsFunc {
141    pub fn new() -> Self {
142        Self {
143            signature: Signature::numeric(1, Volatility::Immutable),
144        }
145    }
146}
147
148impl ScalarUDFImpl for AbsFunc {
149    fn name(&self) -> &str {
150        "abs"
151    }
152
153    fn signature(&self) -> &Signature {
154        &self.signature
155    }
156
157    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
158        Ok(arg_types[0].clone())
159    }
160
161    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
162        let args = ColumnarValue::values_to_arrays(&args.args)?;
163        let [input] = take_function_args(self.name(), args)?;
164
165        let input_data_type = input.data_type();
166        let abs_fun = create_abs_function(input_data_type)?;
167
168        abs_fun(&input).map(ColumnarValue::Array)
169    }
170
171    fn output_ordering(&self, input: &[ExprProperties]) -> Result<SortProperties> {
172        // Non-decreasing for x ≥ 0 and symmetrically non-increasing for x ≤ 0.
173        let arg = &input[0];
174        let range = &arg.range;
175        let zero_point = Interval::make_zero(&range.lower().data_type())?;
176
177        if range.gt_eq(&zero_point)? == Interval::TRUE {
178            Ok(arg.sort_properties)
179        } else if range.lt_eq(&zero_point)? == Interval::TRUE {
180            Ok(-arg.sort_properties)
181        } else {
182            Ok(SortProperties::Unordered)
183        }
184    }
185
186    fn documentation(&self) -> Option<&Documentation> {
187        self.doc()
188    }
189}