datafusion_functions/math/
signum.rs1use std::sync::Arc;
19
20use arrow::array::AsArray;
21use arrow::datatypes::DataType::{Float32, Float64};
22use arrow::datatypes::{DataType, Float32Type, Float64Type};
23
24use datafusion_common::utils::take_function_args;
25use datafusion_common::{Result, ScalarValue, internal_err};
26use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
27use datafusion_expr::{
28 ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
29 Volatility,
30};
31use datafusion_macros::user_doc;
32
33#[user_doc(
34 doc_section(label = "Math Functions"),
35 description = r#"Returns the sign of a number.
36Negative numbers return `-1`.
37Zero and positive numbers return `1`."#,
38 syntax_example = "signum(numeric_expression)",
39 standard_argument(name = "numeric_expression", prefix = "Numeric"),
40 sql_example = r#"```sql
41> SELECT signum(-42);
42+-------------+
43| signum(-42) |
44+-------------+
45| -1 |
46+-------------+
47```"#
48)]
49#[derive(Debug, PartialEq, Eq, Hash)]
50pub struct SignumFunc {
51 signature: Signature,
52}
53
54impl Default for SignumFunc {
55 fn default() -> Self {
56 SignumFunc::new()
57 }
58}
59
60impl SignumFunc {
61 pub fn new() -> Self {
62 use DataType::*;
63 Self {
64 signature: Signature::uniform(
65 1,
66 vec![Float64, Float32],
67 Volatility::Immutable,
68 ),
69 }
70 }
71}
72
73impl ScalarUDFImpl for SignumFunc {
74 fn name(&self) -> &str {
75 "signum"
76 }
77
78 fn signature(&self) -> &Signature {
79 &self.signature
80 }
81
82 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
83 match &arg_types[0] {
84 Float32 => Ok(Float32),
85 _ => Ok(Float64),
86 }
87 }
88
89 fn output_ordering(&self, input: &[ExprProperties]) -> Result<SortProperties> {
90 Ok(input[0].sort_properties)
92 }
93
94 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
95 let return_type = args.return_type().clone();
96 let [arg] = take_function_args(self.name(), args.args)?;
97
98 match arg {
99 ColumnarValue::Scalar(scalar) => {
100 if scalar.is_null() {
101 return ColumnarValue::Scalar(ScalarValue::Null)
102 .cast_to(&return_type, None);
103 }
104
105 match scalar {
106 ScalarValue::Float64(Some(v)) => {
107 let result = if v == 0.0 { 0.0 } else { v.signum() };
108 Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(result))))
109 }
110 ScalarValue::Float32(Some(v)) => {
111 let result = if v == 0.0 { 0.0 } else { v.signum() };
112 Ok(ColumnarValue::Scalar(ScalarValue::Float32(Some(result))))
113 }
114 _ => {
115 internal_err!(
116 "Unexpected scalar type for signum: {:?}",
117 scalar.data_type()
118 )
119 }
120 }
121 }
122 ColumnarValue::Array(array) => match array.data_type() {
123 Float64 => Ok(ColumnarValue::Array(Arc::new(
124 array.as_primitive::<Float64Type>().unary::<_, Float64Type>(
125 |x: f64| {
126 if x == 0.0 { 0.0 } else { x.signum() }
127 },
128 ),
129 ))),
130 Float32 => Ok(ColumnarValue::Array(Arc::new(
131 array.as_primitive::<Float32Type>().unary::<_, Float32Type>(
132 |x: f32| {
133 if x == 0.0 { 0.0 } else { x.signum() }
134 },
135 ),
136 ))),
137 other => {
138 internal_err!("Unsupported data type {other:?} for function signum")
139 }
140 },
141 }
142 }
143
144 fn documentation(&self) -> Option<&Documentation> {
145 self.doc()
146 }
147}
148
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use arrow::array::{ArrayRef, Float32Array, Float64Array};
    use arrow::datatypes::{DataType, Field};
    use datafusion_common::ScalarValue;
    use datafusion_common::cast::{as_float32_array, as_float64_array};
    use datafusion_common::config::ConfigOptions;
    use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};

    use crate::math::signum::SignumFunc;

    /// Invokes `signum` on one scalar argument and returns the scalar result.
    /// The return field uses the input's type, matching `return_type`.
    fn invoke_scalar(value: ScalarValue) -> ScalarValue {
        let data_type = value.data_type();
        let args = ScalarFunctionArgs {
            args: vec![ColumnarValue::Scalar(value)],
            arg_fields: vec![Field::new("a", data_type.clone(), true).into()],
            number_rows: 1,
            return_field: Field::new("f", data_type, true).into(),
            config_options: Arc::new(ConfigOptions::default()),
        };
        match SignumFunc::new()
            .invoke_with_args(args)
            .expect("failed to invoke function signum")
        {
            ColumnarValue::Scalar(result) => result,
            ColumnarValue::Array(_) => panic!("Expected a scalar value"),
        }
    }

    #[test]
    fn test_signum_f32() {
        let array = Arc::new(Float32Array::from(vec![
            -1.0,
            -0.0,
            0.0,
            1.0,
            -0.01,
            0.01,
            f32::NAN,
            f32::INFINITY,
            f32::NEG_INFINITY,
        ]));
        let arg_fields = vec![Field::new("a", DataType::Float32, false).into()];
        let args = ScalarFunctionArgs {
            args: vec![ColumnarValue::Array(Arc::clone(&array) as ArrayRef)],
            arg_fields,
            number_rows: array.len(),
            return_field: Field::new("f", DataType::Float32, true).into(),
            config_options: Arc::new(ConfigOptions::default()),
        };
        let result = SignumFunc::new()
            .invoke_with_args(args)
            .expect("failed to invoke function signum");

        match result {
            ColumnarValue::Array(arr) => {
                let floats = as_float32_array(&arr)
                    .expect("failed to convert result to a Float32Array");

                assert_eq!(floats.len(), 9);
                assert_eq!(floats.value(0), -1.0);
                assert_eq!(floats.value(1), 0.0);
                assert_eq!(floats.value(2), 0.0);
                // `==` treats -0.0 and 0.0 as equal, so check that both zeros
                // were normalised to +0.0.
                assert!(floats.value(1).is_sign_positive());
                assert!(floats.value(2).is_sign_positive());
                assert_eq!(floats.value(3), 1.0);
                assert_eq!(floats.value(4), -1.0);
                assert_eq!(floats.value(5), 1.0);
                assert!(floats.value(6).is_nan());
                assert_eq!(floats.value(7), 1.0);
                assert_eq!(floats.value(8), -1.0);
            }
            ColumnarValue::Scalar(_) => {
                panic!("Expected an array value")
            }
        }
    }

    #[test]
    fn test_signum_f64() {
        let array = Arc::new(Float64Array::from(vec![
            -1.0,
            -0.0,
            0.0,
            1.0,
            -0.01,
            0.01,
            f64::NAN,
            f64::INFINITY,
            f64::NEG_INFINITY,
        ]));
        let arg_fields = vec![Field::new("a", DataType::Float64, false).into()];
        let args = ScalarFunctionArgs {
            args: vec![ColumnarValue::Array(Arc::clone(&array) as ArrayRef)],
            arg_fields,
            number_rows: array.len(),
            return_field: Field::new("f", DataType::Float64, true).into(),
            config_options: Arc::new(ConfigOptions::default()),
        };
        let result = SignumFunc::new()
            .invoke_with_args(args)
            .expect("failed to invoke function signum");

        match result {
            ColumnarValue::Array(arr) => {
                let floats = as_float64_array(&arr)
                    .expect("failed to convert result to a Float64Array");

                assert_eq!(floats.len(), 9);
                assert_eq!(floats.value(0), -1.0);
                assert_eq!(floats.value(1), 0.0);
                assert_eq!(floats.value(2), 0.0);
                // `==` treats -0.0 and 0.0 as equal, so check that both zeros
                // were normalised to +0.0.
                assert!(floats.value(1).is_sign_positive());
                assert!(floats.value(2).is_sign_positive());
                assert_eq!(floats.value(3), 1.0);
                assert_eq!(floats.value(4), -1.0);
                assert_eq!(floats.value(5), 1.0);
                assert!(floats.value(6).is_nan());
                assert_eq!(floats.value(7), 1.0);
                assert_eq!(floats.value(8), -1.0);
            }
            ColumnarValue::Scalar(_) => {
                panic!("Expected an array value")
            }
        }
    }

    #[test]
    fn test_signum_scalar_f64() {
        let cases = [
            (-42.0_f64, -1.0_f64),
            (-0.0, 0.0),
            (0.0, 0.0),
            (0.01, 1.0),
            (f64::INFINITY, 1.0),
            (f64::NEG_INFINITY, -1.0),
        ];
        for (input, expected) in cases {
            match invoke_scalar(ScalarValue::Float64(Some(input))) {
                ScalarValue::Float64(Some(v)) => {
                    assert_eq!(v, expected, "signum({input})");
                    // `==` treats -0.0 and 0.0 as equal; compare sign bits too.
                    assert_eq!(
                        v.is_sign_negative(),
                        expected.is_sign_negative(),
                        "sign of signum({input})"
                    );
                }
                other => panic!("Expected a Float64 scalar, got {other:?}"),
            }
        }

        match invoke_scalar(ScalarValue::Float64(Some(f64::NAN))) {
            ScalarValue::Float64(Some(v)) => assert!(v.is_nan()),
            other => panic!("Expected a Float64 scalar, got {other:?}"),
        }
    }

    #[test]
    fn test_signum_scalar_f32() {
        let cases = [
            (-42.0_f32, -1.0_f32),
            (-0.0, 0.0),
            (0.0, 0.0),
            (0.01, 1.0),
            (f32::INFINITY, 1.0),
            (f32::NEG_INFINITY, -1.0),
        ];
        for (input, expected) in cases {
            match invoke_scalar(ScalarValue::Float32(Some(input))) {
                ScalarValue::Float32(Some(v)) => {
                    assert_eq!(v, expected, "signum({input})");
                    // `==` treats -0.0 and 0.0 as equal; compare sign bits too.
                    assert_eq!(
                        v.is_sign_negative(),
                        expected.is_sign_negative(),
                        "sign of signum({input})"
                    );
                }
                other => panic!("Expected a Float32 scalar, got {other:?}"),
            }
        }

        match invoke_scalar(ScalarValue::Float32(Some(f32::NAN))) {
            ScalarValue::Float32(Some(v)) => assert!(v.is_nan()),
            other => panic!("Expected a Float32 scalar, got {other:?}"),
        }
    }

    #[test]
    fn test_signum_scalar_null() {
        assert_eq!(
            invoke_scalar(ScalarValue::Float64(None)),
            ScalarValue::Float64(None)
        );
        assert_eq!(
            invoke_scalar(ScalarValue::Float32(None)),
            ScalarValue::Float32(None)
        );
    }
}