// datafusion_functions/math/signum.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

18use std::sync::Arc;
19
20use arrow::array::AsArray;
21use arrow::datatypes::DataType::{Float32, Float64};
22use arrow::datatypes::{DataType, Float32Type, Float64Type};
23
24use datafusion_common::utils::take_function_args;
25use datafusion_common::{Result, ScalarValue, internal_err};
26use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
27use datafusion_expr::{
28    ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
29    Volatility,
30};
31use datafusion_macros::user_doc;
32
33#[user_doc(
34    doc_section(label = "Math Functions"),
35    description = r#"Returns the sign of a number.
36Negative numbers return `-1`.
37Zero and positive numbers return `1`."#,
38    syntax_example = "signum(numeric_expression)",
39    standard_argument(name = "numeric_expression", prefix = "Numeric"),
40    sql_example = r#"```sql
41> SELECT signum(-42);
42+-------------+
43| signum(-42) |
44+-------------+
45| -1          |
46+-------------+
47```"#
48)]
49#[derive(Debug, PartialEq, Eq, Hash)]
50pub struct SignumFunc {
51    signature: Signature,
52}
53
54impl Default for SignumFunc {
55    fn default() -> Self {
56        SignumFunc::new()
57    }
58}
59
60impl SignumFunc {
61    pub fn new() -> Self {
62        use DataType::*;
63        Self {
64            signature: Signature::uniform(
65                1,
66                vec![Float64, Float32],
67                Volatility::Immutable,
68            ),
69        }
70    }
71}
72
73impl ScalarUDFImpl for SignumFunc {
74    fn name(&self) -> &str {
75        "signum"
76    }
77
78    fn signature(&self) -> &Signature {
79        &self.signature
80    }
81
82    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
83        match &arg_types[0] {
84            Float32 => Ok(Float32),
85            _ => Ok(Float64),
86        }
87    }
88
89    fn output_ordering(&self, input: &[ExprProperties]) -> Result<SortProperties> {
90        // Non-decreasing for all real numbers x.
91        Ok(input[0].sort_properties)
92    }
93
94    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
95        let return_type = args.return_type().clone();
96        let [arg] = take_function_args(self.name(), args.args)?;
97
98        match arg {
99            ColumnarValue::Scalar(scalar) => {
100                if scalar.is_null() {
101                    return ColumnarValue::Scalar(ScalarValue::Null)
102                        .cast_to(&return_type, None);
103                }
104
105                match scalar {
106                    ScalarValue::Float64(Some(v)) => {
107                        let result = if v == 0.0 { 0.0 } else { v.signum() };
108                        Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(result))))
109                    }
110                    ScalarValue::Float32(Some(v)) => {
111                        let result = if v == 0.0 { 0.0 } else { v.signum() };
112                        Ok(ColumnarValue::Scalar(ScalarValue::Float32(Some(result))))
113                    }
114                    _ => {
115                        internal_err!(
116                            "Unexpected scalar type for signum: {:?}",
117                            scalar.data_type()
118                        )
119                    }
120                }
121            }
122            ColumnarValue::Array(array) => match array.data_type() {
123                Float64 => Ok(ColumnarValue::Array(Arc::new(
124                    array.as_primitive::<Float64Type>().unary::<_, Float64Type>(
125                        |x: f64| {
126                            if x == 0.0 { 0.0 } else { x.signum() }
127                        },
128                    ),
129                ))),
130                Float32 => Ok(ColumnarValue::Array(Arc::new(
131                    array.as_primitive::<Float32Type>().unary::<_, Float32Type>(
132                        |x: f32| {
133                            if x == 0.0 { 0.0 } else { x.signum() }
134                        },
135                    ),
136                ))),
137                other => {
138                    internal_err!("Unsupported data type {other:?} for function signum")
139                }
140            },
141        }
142    }
143
144    fn documentation(&self) -> Option<&Documentation> {
145        self.doc()
146    }
147}
148
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use arrow::array::{Array, ArrayRef, Float32Array, Float64Array};
    use arrow::datatypes::{DataType, Field};
    use datafusion_common::ScalarValue;
    use datafusion_common::cast::{as_float32_array, as_float64_array};
    use datafusion_common::config::ConfigOptions;
    use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};

    use crate::math::signum::SignumFunc;

    /// Invokes `signum` on a single argument, using `data_type` as both the
    /// argument type and the planned return type.
    fn invoke_signum(
        arg: ColumnarValue,
        data_type: DataType,
        number_rows: usize,
    ) -> ColumnarValue {
        let args = ScalarFunctionArgs {
            args: vec![arg],
            arg_fields: vec![Field::new("a", data_type.clone(), true).into()],
            number_rows,
            return_field: Field::new("f", data_type, true).into(),
            config_options: Arc::new(ConfigOptions::default()),
        };
        SignumFunc::new()
            .invoke_with_args(args)
            .expect("failed to invoke function signum")
    }

    #[test]
    fn test_signum_f32() {
        let array: ArrayRef = Arc::new(Float32Array::from(vec![
            -1.0,
            -0.0,
            0.0,
            1.0,
            -0.01,
            0.01,
            f32::NAN,
            f32::INFINITY,
            f32::NEG_INFINITY,
        ]));
        let result = invoke_signum(
            ColumnarValue::Array(Arc::clone(&array)),
            DataType::Float32,
            array.len(),
        );
        let ColumnarValue::Array(arr) = result else {
            panic!("Expected an array value")
        };
        let floats =
            as_float32_array(&arr).expect("failed to convert result to a Float32Array");

        assert_eq!(floats.len(), 9);
        assert_eq!(floats.value(0), -1.0);
        assert_eq!(floats.value(1), 0.0);
        assert_eq!(floats.value(2), 0.0);
        // Both zeros must come back as +0.0 (assert_eq! alone cannot tell).
        assert!(floats.value(1).is_sign_positive());
        assert!(floats.value(2).is_sign_positive());
        assert_eq!(floats.value(3), 1.0);
        assert_eq!(floats.value(4), -1.0);
        assert_eq!(floats.value(5), 1.0);
        assert!(floats.value(6).is_nan());
        assert_eq!(floats.value(7), 1.0);
        assert_eq!(floats.value(8), -1.0);
    }

    #[test]
    fn test_signum_f64() {
        let array: ArrayRef = Arc::new(Float64Array::from(vec![
            -1.0,
            -0.0,
            0.0,
            1.0,
            -0.01,
            0.01,
            f64::NAN,
            f64::INFINITY,
            f64::NEG_INFINITY,
        ]));
        let result = invoke_signum(
            ColumnarValue::Array(Arc::clone(&array)),
            DataType::Float64,
            array.len(),
        );
        let ColumnarValue::Array(arr) = result else {
            panic!("Expected an array value")
        };
        let floats =
            as_float64_array(&arr).expect("failed to convert result to a Float64Array");

        assert_eq!(floats.len(), 9);
        assert_eq!(floats.value(0), -1.0);
        assert_eq!(floats.value(1), 0.0);
        assert_eq!(floats.value(2), 0.0);
        assert!(floats.value(1).is_sign_positive());
        assert!(floats.value(2).is_sign_positive());
        assert_eq!(floats.value(3), 1.0);
        assert_eq!(floats.value(4), -1.0);
        assert_eq!(floats.value(5), 1.0);
        assert!(floats.value(6).is_nan());
        assert_eq!(floats.value(7), 1.0);
        assert_eq!(floats.value(8), -1.0);
    }

    #[test]
    fn test_signum_array_preserves_nulls() {
        let array: ArrayRef =
            Arc::new(Float64Array::from(vec![Some(-2.5), None, Some(3.0)]));
        let result = invoke_signum(
            ColumnarValue::Array(Arc::clone(&array)),
            DataType::Float64,
            array.len(),
        );
        let ColumnarValue::Array(arr) = result else {
            panic!("Expected an array value")
        };
        let floats =
            as_float64_array(&arr).expect("failed to convert result to a Float64Array");

        assert_eq!(floats.len(), 3);
        assert_eq!(floats.value(0), -1.0);
        assert!(floats.is_null(1));
        assert_eq!(floats.value(2), 1.0);
    }

    #[test]
    fn test_signum_scalar_f64() {
        for (input, expected) in [(-42.0, -1.0), (-0.0, 0.0), (0.0, 0.0), (0.5, 1.0)] {
            let result = invoke_signum(
                ColumnarValue::Scalar(ScalarValue::Float64(Some(input))),
                DataType::Float64,
                1,
            );
            match result {
                ColumnarValue::Scalar(ScalarValue::Float64(Some(v))) => {
                    assert_eq!(v, expected, "signum({input})");
                    if expected == 0.0 {
                        assert!(v.is_sign_positive(), "signum({input}) must be +0.0");
                    }
                }
                other => panic!("Expected a Float64 scalar, got {other:?}"),
            }
        }
    }

    #[test]
    fn test_signum_scalar_f32() {
        let result = invoke_signum(
            ColumnarValue::Scalar(ScalarValue::Float32(Some(-3.0))),
            DataType::Float32,
            1,
        );
        match result {
            ColumnarValue::Scalar(ScalarValue::Float32(Some(v))) => assert_eq!(v, -1.0),
            other => panic!("Expected a Float32 scalar, got {other:?}"),
        }
    }

    #[test]
    fn test_signum_null_scalar() {
        let result = invoke_signum(
            ColumnarValue::Scalar(ScalarValue::Float64(None)),
            DataType::Float64,
            1,
        );
        assert!(
            matches!(result, ColumnarValue::Scalar(ScalarValue::Float64(None))),
            "Expected a Float64 NULL scalar, got {result:?}"
        );

        let result = invoke_signum(
            ColumnarValue::Scalar(ScalarValue::Float32(None)),
            DataType::Float32,
            1,
        );
        assert!(
            matches!(result, ColumnarValue::Scalar(ScalarValue::Float32(None))),
            "Expected a Float32 NULL scalar, got {result:?}"
        );
    }
}