datafusion_functions/math/
signum.rs1use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::AsArray;
22use arrow::datatypes::DataType::{Float32, Float64};
23use arrow::datatypes::{DataType, Float32Type, Float64Type};
24
25use datafusion_common::utils::take_function_args;
26use datafusion_common::{Result, ScalarValue, internal_err};
27use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
28use datafusion_expr::{
29 ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
30 Volatility,
31};
32use datafusion_macros::user_doc;
33
34#[user_doc(
35 doc_section(label = "Math Functions"),
36 description = r#"Returns the sign of a number.
37Negative numbers return `-1`.
38Zero and positive numbers return `1`."#,
39 syntax_example = "signum(numeric_expression)",
40 standard_argument(name = "numeric_expression", prefix = "Numeric"),
41 sql_example = r#"```sql
42> SELECT signum(-42);
43+-------------+
44| signum(-42) |
45+-------------+
46| -1 |
47+-------------+
48```"#
49)]
50#[derive(Debug, PartialEq, Eq, Hash)]
51pub struct SignumFunc {
52 signature: Signature,
53}
54
55impl Default for SignumFunc {
56 fn default() -> Self {
57 SignumFunc::new()
58 }
59}
60
61impl SignumFunc {
62 pub fn new() -> Self {
63 use DataType::*;
64 Self {
65 signature: Signature::uniform(
66 1,
67 vec![Float64, Float32],
68 Volatility::Immutable,
69 ),
70 }
71 }
72}
73
74impl ScalarUDFImpl for SignumFunc {
75 fn as_any(&self) -> &dyn Any {
76 self
77 }
78
79 fn name(&self) -> &str {
80 "signum"
81 }
82
83 fn signature(&self) -> &Signature {
84 &self.signature
85 }
86
87 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
88 match &arg_types[0] {
89 Float32 => Ok(Float32),
90 _ => Ok(Float64),
91 }
92 }
93
94 fn output_ordering(&self, input: &[ExprProperties]) -> Result<SortProperties> {
95 Ok(input[0].sort_properties)
97 }
98
99 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
100 let return_type = args.return_type().clone();
101 let [arg] = take_function_args(self.name(), args.args)?;
102
103 match arg {
104 ColumnarValue::Scalar(scalar) => {
105 if scalar.is_null() {
106 return ColumnarValue::Scalar(ScalarValue::Null)
107 .cast_to(&return_type, None);
108 }
109
110 match scalar {
111 ScalarValue::Float64(Some(v)) => {
112 let result = if v == 0.0 { 0.0 } else { v.signum() };
113 Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(result))))
114 }
115 ScalarValue::Float32(Some(v)) => {
116 let result = if v == 0.0 { 0.0 } else { v.signum() };
117 Ok(ColumnarValue::Scalar(ScalarValue::Float32(Some(result))))
118 }
119 _ => {
120 internal_err!(
121 "Unexpected scalar type for signum: {:?}",
122 scalar.data_type()
123 )
124 }
125 }
126 }
127 ColumnarValue::Array(array) => match array.data_type() {
128 Float64 => Ok(ColumnarValue::Array(Arc::new(
129 array.as_primitive::<Float64Type>().unary::<_, Float64Type>(
130 |x: f64| {
131 if x == 0.0 { 0.0 } else { x.signum() }
132 },
133 ),
134 ))),
135 Float32 => Ok(ColumnarValue::Array(Arc::new(
136 array.as_primitive::<Float32Type>().unary::<_, Float32Type>(
137 |x: f32| {
138 if x == 0.0 { 0.0 } else { x.signum() }
139 },
140 ),
141 ))),
142 other => {
143 internal_err!("Unsupported data type {other:?} for function signum")
144 }
145 },
146 }
147 }
148
149 fn documentation(&self) -> Option<&Documentation> {
150 self.doc()
151 }
152}
153
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use arrow::array::{Array, ArrayRef, Float32Array, Float64Array};
    use arrow::datatypes::{DataType, Field};
    use datafusion_common::ScalarValue;
    use datafusion_common::cast::{as_float32_array, as_float64_array};
    use datafusion_common::config::ConfigOptions;
    use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};

    use crate::math::signum::SignumFunc;

    /// Invokes `signum` on one argument of `data_type`, declaring that same
    /// type as the return type, and panics if the invocation fails.
    fn invoke_signum(
        arg: ColumnarValue,
        data_type: DataType,
        number_rows: usize,
    ) -> ColumnarValue {
        let args = ScalarFunctionArgs {
            args: vec![arg],
            arg_fields: vec![Field::new("a", data_type.clone(), true).into()],
            number_rows,
            return_field: Field::new("f", data_type, true).into(),
            config_options: Arc::new(ConfigOptions::default()),
        };
        SignumFunc::new()
            .invoke_with_args(args)
            .expect("failed to invoke function signum")
    }

    #[test]
    fn test_signum_f32() {
        let array = Arc::new(Float32Array::from(vec![
            -1.0,
            -0.0,
            0.0,
            1.0,
            -0.01,
            0.01,
            f32::NAN,
            f32::INFINITY,
            f32::NEG_INFINITY,
        ]));
        let result = invoke_signum(
            ColumnarValue::Array(Arc::clone(&array) as ArrayRef),
            DataType::Float32,
            array.len(),
        );
        let ColumnarValue::Array(arr) = result else {
            panic!("Expected an array value")
        };
        let floats =
            as_float32_array(&arr).expect("failed to convert result to a Float32Array");

        assert_eq!(floats.len(), 9);
        assert_eq!(floats.value(0), -1.0);
        assert_eq!(floats.value(1), 0.0);
        // Signed zeros are normalized to +0.0.
        assert!(floats.value(1).is_sign_positive());
        assert_eq!(floats.value(2), 0.0);
        assert_eq!(floats.value(3), 1.0);
        assert_eq!(floats.value(4), -1.0);
        assert_eq!(floats.value(5), 1.0);
        assert!(floats.value(6).is_nan());
        assert_eq!(floats.value(7), 1.0);
        assert_eq!(floats.value(8), -1.0);
    }

    #[test]
    fn test_signum_f64() {
        let array = Arc::new(Float64Array::from(vec![
            -1.0,
            -0.0,
            0.0,
            1.0,
            -0.01,
            0.01,
            f64::NAN,
            f64::INFINITY,
            f64::NEG_INFINITY,
        ]));
        let result = invoke_signum(
            ColumnarValue::Array(Arc::clone(&array) as ArrayRef),
            DataType::Float64,
            array.len(),
        );
        let ColumnarValue::Array(arr) = result else {
            panic!("Expected an array value")
        };
        let floats =
            as_float64_array(&arr).expect("failed to convert result to a Float64Array");

        assert_eq!(floats.len(), 9);
        assert_eq!(floats.value(0), -1.0);
        assert_eq!(floats.value(1), 0.0);
        // Signed zeros are normalized to +0.0.
        assert!(floats.value(1).is_sign_positive());
        assert_eq!(floats.value(2), 0.0);
        assert_eq!(floats.value(3), 1.0);
        assert_eq!(floats.value(4), -1.0);
        assert_eq!(floats.value(5), 1.0);
        assert!(floats.value(6).is_nan());
        assert_eq!(floats.value(7), 1.0);
        assert_eq!(floats.value(8), -1.0);
    }

    #[test]
    fn test_signum_array_with_nulls() {
        let array = Arc::new(Float64Array::from(vec![Some(-2.5), None, Some(3.0)]));
        let result = invoke_signum(
            ColumnarValue::Array(Arc::clone(&array) as ArrayRef),
            DataType::Float64,
            array.len(),
        );
        let ColumnarValue::Array(arr) = result else {
            panic!("Expected an array value")
        };
        let floats =
            as_float64_array(&arr).expect("failed to convert result to a Float64Array");

        assert_eq!(floats.len(), 3);
        assert_eq!(floats.value(0), -1.0);
        assert!(floats.is_null(1));
        assert_eq!(floats.value(2), 1.0);
    }

    #[test]
    fn test_signum_scalar() {
        let cases = [
            (ScalarValue::Float64(Some(-42.0)), ScalarValue::Float64(Some(-1.0))),
            (ScalarValue::Float64(Some(0.5)), ScalarValue::Float64(Some(1.0))),
            (ScalarValue::Float64(Some(-0.0)), ScalarValue::Float64(Some(0.0))),
            (ScalarValue::Float32(Some(-3.0)), ScalarValue::Float32(Some(-1.0))),
            (ScalarValue::Float32(Some(0.0)), ScalarValue::Float32(Some(0.0))),
            (ScalarValue::Float64(None), ScalarValue::Float64(None)),
            (ScalarValue::Float32(None), ScalarValue::Float32(None)),
        ];

        for (input, expected) in cases {
            let data_type = input.data_type();
            match invoke_signum(ColumnarValue::Scalar(input), data_type, 1) {
                ColumnarValue::Scalar(actual) => assert_eq!(actual, expected),
                ColumnarValue::Array(_) => panic!("Expected a scalar value"),
            }
        }
    }
}