Skip to main content

datafusion_functions/math/
iszero.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{ArrowNativeTypeOp, AsArray, BooleanArray};
22use arrow::datatypes::DataType::{
23    Boolean, Decimal32, Decimal64, Decimal128, Decimal256, Float16, Float32, Float64,
24    Int8, Int16, Int32, Int64, Null, UInt8, UInt16, UInt32, UInt64,
25};
26use arrow::datatypes::{
27    DataType, Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type, Float16Type,
28    Float32Type, Float64Type, Int8Type, Int16Type, Int32Type, Int64Type, UInt8Type,
29    UInt16Type, UInt32Type, UInt64Type,
30};
31
32use datafusion_common::utils::take_function_args;
33use datafusion_common::{Result, ScalarValue, internal_err};
34use datafusion_expr::{Coercion, TypeSignatureClass};
35use datafusion_expr::{
36    ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
37    Volatility,
38};
39use datafusion_macros::user_doc;
40
41#[user_doc(
42    doc_section(label = "Math Functions"),
43    description = "Returns true if a given number is +0.0 or -0.0 otherwise returns false.",
44    syntax_example = "iszero(numeric_expression)",
45    sql_example = r#"```sql
46> SELECT iszero(0);
47+------------+
48| iszero(0)  |
49+------------+
50| true       |
51+------------+
52```"#,
53    standard_argument(name = "numeric_expression", prefix = "Numeric")
54)]
55#[derive(Debug, PartialEq, Eq, Hash)]
56pub struct IsZeroFunc {
57    signature: Signature,
58}
59
60impl Default for IsZeroFunc {
61    fn default() -> Self {
62        IsZeroFunc::new()
63    }
64}
65
66impl IsZeroFunc {
67    pub fn new() -> Self {
68        // Accept any numeric type (ints, uints, floats, decimals) without implicit casts.
69        let numeric = Coercion::new_exact(TypeSignatureClass::Numeric);
70        Self {
71            signature: Signature::coercible(vec![numeric], Volatility::Immutable),
72        }
73    }
74}
75
76impl ScalarUDFImpl for IsZeroFunc {
77    fn as_any(&self) -> &dyn Any {
78        self
79    }
80
81    fn name(&self) -> &str {
82        "iszero"
83    }
84
85    fn signature(&self) -> &Signature {
86        &self.signature
87    }
88
89    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
90        Ok(Boolean)
91    }
92
93    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
94        let [arg] = take_function_args(self.name(), args.args)?;
95
96        match arg {
97            ColumnarValue::Scalar(scalar) => {
98                if scalar.is_null() {
99                    return Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None)));
100                }
101
102                match scalar {
103                    ScalarValue::Float64(Some(v)) => {
104                        Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0.0))))
105                    }
106                    ScalarValue::Float32(Some(v)) => {
107                        Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0.0))))
108                    }
109                    ScalarValue::Float16(Some(v)) => Ok(ColumnarValue::Scalar(
110                        ScalarValue::Boolean(Some(v.is_zero())),
111                    )),
112
113                    ScalarValue::Int8(Some(v)) => {
114                        Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0))))
115                    }
116                    ScalarValue::Int16(Some(v)) => {
117                        Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0))))
118                    }
119                    ScalarValue::Int32(Some(v)) => {
120                        Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0))))
121                    }
122                    ScalarValue::Int64(Some(v)) => {
123                        Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0))))
124                    }
125                    ScalarValue::UInt8(Some(v)) => {
126                        Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0))))
127                    }
128                    ScalarValue::UInt16(Some(v)) => {
129                        Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0))))
130                    }
131                    ScalarValue::UInt32(Some(v)) => {
132                        Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0))))
133                    }
134                    ScalarValue::UInt64(Some(v)) => {
135                        Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0))))
136                    }
137
138                    ScalarValue::Decimal32(Some(v), ..) => {
139                        Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0))))
140                    }
141                    ScalarValue::Decimal64(Some(v), ..) => {
142                        Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0))))
143                    }
144                    ScalarValue::Decimal128(Some(v), ..) => {
145                        Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0))))
146                    }
147                    ScalarValue::Decimal256(Some(v), ..) => Ok(ColumnarValue::Scalar(
148                        ScalarValue::Boolean(Some(v.is_zero())),
149                    )),
150
151                    _ => {
152                        internal_err!(
153                            "Unexpected scalar type for iszero: {:?}",
154                            scalar.data_type()
155                        )
156                    }
157                }
158            }
159            ColumnarValue::Array(array) => match array.data_type() {
160                Null => Ok(ColumnarValue::Array(Arc::new(BooleanArray::new_null(
161                    array.len(),
162                )))),
163
164                Float64 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary(
165                    array.as_primitive::<Float64Type>(),
166                    |x| x == 0.0,
167                )))),
168                Float32 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary(
169                    array.as_primitive::<Float32Type>(),
170                    |x| x == 0.0,
171                )))),
172                Float16 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary(
173                    array.as_primitive::<Float16Type>(),
174                    |x| x.is_zero(),
175                )))),
176
177                Int8 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary(
178                    array.as_primitive::<Int8Type>(),
179                    |x| x == 0,
180                )))),
181                Int16 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary(
182                    array.as_primitive::<Int16Type>(),
183                    |x| x == 0,
184                )))),
185                Int32 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary(
186                    array.as_primitive::<Int32Type>(),
187                    |x| x == 0,
188                )))),
189                Int64 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary(
190                    array.as_primitive::<Int64Type>(),
191                    |x| x == 0,
192                )))),
193                UInt8 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary(
194                    array.as_primitive::<UInt8Type>(),
195                    |x| x == 0,
196                )))),
197                UInt16 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary(
198                    array.as_primitive::<UInt16Type>(),
199                    |x| x == 0,
200                )))),
201                UInt32 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary(
202                    array.as_primitive::<UInt32Type>(),
203                    |x| x == 0,
204                )))),
205                UInt64 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary(
206                    array.as_primitive::<UInt64Type>(),
207                    |x| x == 0,
208                )))),
209
210                Decimal32(_, _) => {
211                    Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary(
212                        array.as_primitive::<Decimal32Type>(),
213                        |x| x == 0,
214                    ))))
215                }
216                Decimal64(_, _) => {
217                    Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary(
218                        array.as_primitive::<Decimal64Type>(),
219                        |x| x == 0,
220                    ))))
221                }
222                Decimal128(_, _) => {
223                    Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary(
224                        array.as_primitive::<Decimal128Type>(),
225                        |x| x == 0,
226                    ))))
227                }
228                Decimal256(_, _) => {
229                    Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary(
230                        array.as_primitive::<Decimal256Type>(),
231                        |x| x.is_zero(),
232                    ))))
233                }
234
235                other => {
236                    internal_err!("Unexpected data type {other:?} for function iszero")
237                }
238            },
239        }
240    }
241
242    fn documentation(&self) -> Option<&Documentation> {
243        self.doc()
244    }
245}