Skip to main content

datafusion_functions/math/
iszero.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::sync::Arc;
19
20use arrow::array::{ArrowNativeTypeOp, AsArray, BooleanArray};
21use arrow::datatypes::DataType::{
22    Boolean, Decimal32, Decimal64, Decimal128, Decimal256, Float16, Float32, Float64,
23    Int8, Int16, Int32, Int64, Null, UInt8, UInt16, UInt32, UInt64,
24};
25use arrow::datatypes::{
26    DataType, Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type, Float16Type,
27    Float32Type, Float64Type, Int8Type, Int16Type, Int32Type, Int64Type, UInt8Type,
28    UInt16Type, UInt32Type, UInt64Type,
29};
30
31use datafusion_common::utils::take_function_args;
32use datafusion_common::{Result, ScalarValue, internal_err};
33use datafusion_expr::{Coercion, TypeSignatureClass};
34use datafusion_expr::{
35    ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
36    Volatility,
37};
38use datafusion_macros::user_doc;
39
#[user_doc(
    doc_section(label = "Math Functions"),
    description = "Returns true if a given number is +0.0 or -0.0 otherwise returns false.",
    syntax_example = "iszero(numeric_expression)",
    sql_example = r#"```sql
> SELECT iszero(0);
+------------+
| iszero(0)  |
+------------+
| true       |
+------------+
```"#,
    standard_argument(name = "numeric_expression", prefix = "Numeric")
)]
/// Scalar UDF implementing `iszero(numeric_expression)`, which reports whether
/// a numeric value equals zero.
///
/// The SQL-facing description, syntax and example shown to users come from the
/// `#[user_doc]` attribute on this struct.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct IsZeroFunc {
    /// Accepted argument types and volatility; constructed in `IsZeroFunc::new`.
    signature: Signature,
}
58
59impl Default for IsZeroFunc {
60    fn default() -> Self {
61        IsZeroFunc::new()
62    }
63}
64
65impl IsZeroFunc {
66    pub fn new() -> Self {
67        // Accept any numeric type (ints, uints, floats, decimals) without implicit casts.
68        let numeric = Coercion::new_exact(TypeSignatureClass::Numeric);
69        Self {
70            signature: Signature::coercible(vec![numeric], Volatility::Immutable),
71        }
72    }
73}
74
75impl ScalarUDFImpl for IsZeroFunc {
76    fn name(&self) -> &str {
77        "iszero"
78    }
79
80    fn signature(&self) -> &Signature {
81        &self.signature
82    }
83
84    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
85        Ok(Boolean)
86    }
87
88    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
89        let [arg] = take_function_args(self.name(), args.args)?;
90
91        match arg {
92            ColumnarValue::Scalar(scalar) => {
93                if scalar.is_null() {
94                    return Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None)));
95                }
96
97                match scalar {
98                    ScalarValue::Float64(Some(v)) => {
99                        Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0.0))))
100                    }
101                    ScalarValue::Float32(Some(v)) => {
102                        Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0.0))))
103                    }
104                    ScalarValue::Float16(Some(v)) => Ok(ColumnarValue::Scalar(
105                        ScalarValue::Boolean(Some(v.is_zero())),
106                    )),
107
108                    ScalarValue::Int8(Some(v)) => {
109                        Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0))))
110                    }
111                    ScalarValue::Int16(Some(v)) => {
112                        Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0))))
113                    }
114                    ScalarValue::Int32(Some(v)) => {
115                        Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0))))
116                    }
117                    ScalarValue::Int64(Some(v)) => {
118                        Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0))))
119                    }
120                    ScalarValue::UInt8(Some(v)) => {
121                        Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0))))
122                    }
123                    ScalarValue::UInt16(Some(v)) => {
124                        Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0))))
125                    }
126                    ScalarValue::UInt32(Some(v)) => {
127                        Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0))))
128                    }
129                    ScalarValue::UInt64(Some(v)) => {
130                        Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0))))
131                    }
132
133                    ScalarValue::Decimal32(Some(v), ..) => {
134                        Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0))))
135                    }
136                    ScalarValue::Decimal64(Some(v), ..) => {
137                        Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0))))
138                    }
139                    ScalarValue::Decimal128(Some(v), ..) => {
140                        Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0))))
141                    }
142                    ScalarValue::Decimal256(Some(v), ..) => Ok(ColumnarValue::Scalar(
143                        ScalarValue::Boolean(Some(v.is_zero())),
144                    )),
145
146                    _ => {
147                        internal_err!(
148                            "Unexpected scalar type for iszero: {:?}",
149                            scalar.data_type()
150                        )
151                    }
152                }
153            }
154            ColumnarValue::Array(array) => match array.data_type() {
155                Null => Ok(ColumnarValue::Array(Arc::new(BooleanArray::new_null(
156                    array.len(),
157                )))),
158
159                Float64 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary(
160                    array.as_primitive::<Float64Type>(),
161                    |x| x == 0.0,
162                )))),
163                Float32 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary(
164                    array.as_primitive::<Float32Type>(),
165                    |x| x == 0.0,
166                )))),
167                Float16 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary(
168                    array.as_primitive::<Float16Type>(),
169                    |x| x.is_zero(),
170                )))),
171
172                Int8 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary(
173                    array.as_primitive::<Int8Type>(),
174                    |x| x == 0,
175                )))),
176                Int16 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary(
177                    array.as_primitive::<Int16Type>(),
178                    |x| x == 0,
179                )))),
180                Int32 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary(
181                    array.as_primitive::<Int32Type>(),
182                    |x| x == 0,
183                )))),
184                Int64 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary(
185                    array.as_primitive::<Int64Type>(),
186                    |x| x == 0,
187                )))),
188                UInt8 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary(
189                    array.as_primitive::<UInt8Type>(),
190                    |x| x == 0,
191                )))),
192                UInt16 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary(
193                    array.as_primitive::<UInt16Type>(),
194                    |x| x == 0,
195                )))),
196                UInt32 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary(
197                    array.as_primitive::<UInt32Type>(),
198                    |x| x == 0,
199                )))),
200                UInt64 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary(
201                    array.as_primitive::<UInt64Type>(),
202                    |x| x == 0,
203                )))),
204
205                Decimal32(_, _) => {
206                    Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary(
207                        array.as_primitive::<Decimal32Type>(),
208                        |x| x == 0,
209                    ))))
210                }
211                Decimal64(_, _) => {
212                    Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary(
213                        array.as_primitive::<Decimal64Type>(),
214                        |x| x == 0,
215                    ))))
216                }
217                Decimal128(_, _) => {
218                    Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary(
219                        array.as_primitive::<Decimal128Type>(),
220                        |x| x == 0,
221                    ))))
222                }
223                Decimal256(_, _) => {
224                    Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary(
225                        array.as_primitive::<Decimal256Type>(),
226                        |x| x.is_zero(),
227                    ))))
228                }
229
230                other => {
231                    internal_err!("Unexpected data type {other:?} for function iszero")
232                }
233            },
234        }
235    }
236
237    fn documentation(&self) -> Option<&Documentation> {
238        self.doc()
239    }
240}