Skip to main content

datafusion_functions/core/
nullif.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::datatypes::DataType;
19use datafusion_expr::{ColumnarValue, Documentation, ScalarFunctionArgs};
20
21use arrow::compute::kernels::nullif::nullif;
22use datafusion_common::{Result, ScalarValue, utils::take_function_args};
23use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
24use datafusion_macros::user_doc;
25use datafusion_physical_expr_common::datum::compare_with_eq;
26
27#[user_doc(
28    doc_section(label = "Conditional Functions"),
29    description = "Returns _null_ if _expression1_ equals _expression2_; otherwise it returns _expression1_.
30This can be used to perform the inverse operation of [`coalesce`](#coalesce).",
31    syntax_example = "nullif(expression1, expression2)",
32    sql_example = r#"```sql
33> select nullif('datafusion', 'data');
34+-----------------------------------------+
35| nullif(Utf8("datafusion"),Utf8("data")) |
36+-----------------------------------------+
37| datafusion                              |
38+-----------------------------------------+
39> select nullif('datafusion', 'datafusion');
40+-----------------------------------------------+
41| nullif(Utf8("datafusion"),Utf8("datafusion")) |
42+-----------------------------------------------+
43|                                               |
44+-----------------------------------------------+
45```"#,
46    argument(
47        name = "expression1",
48        description = "Expression to compare and return if equal to expression2. Can be a constant, column, or function, and any combination of operators."
49    ),
50    argument(
51        name = "expression2",
52        description = "Expression to compare to expression1. Can be a constant, column, or function, and any combination of operators."
53    )
54)]
55#[derive(Debug, PartialEq, Eq, Hash)]
56pub struct NullIfFunc {
57    signature: Signature,
58}
59
60impl Default for NullIfFunc {
61    fn default() -> Self {
62        Self::new()
63    }
64}
65
66impl NullIfFunc {
67    pub fn new() -> Self {
68        Self {
69            // Documentation mentioned in Postgres,
70            // The result has the same type as the first argument — but there is a subtlety.
71            // What is actually returned is the first argument of the implied = operator,
72            // and in some cases that will have been promoted to match the second argument's type.
73            // For example, NULLIF(1, 2.2) yields numeric, because there is no integer = numeric operator, only numeric = numeric
74            //
75            // We don't strictly follow Postgres or DuckDB for **simplicity**.
76            // In this function, we will coerce arguments to the same data type for comparison need. Unlike DuckDB
77            // we don't return the **original** first argument type but return the final coerced type.
78            //
79            // In Postgres, nullif('2', 2) returns Null but nullif('2::varchar', 2) returns error.
80            // While in DuckDB both query returns Null. We follow DuckDB in this case since I think they are equivalent thing and should
81            // have the same result as well.
82            signature: Signature::comparable(2, Volatility::Immutable),
83        }
84    }
85}
86
87impl ScalarUDFImpl for NullIfFunc {
88    fn name(&self) -> &str {
89        "nullif"
90    }
91
92    fn signature(&self) -> &Signature {
93        &self.signature
94    }
95
96    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
97        Ok(arg_types[0].to_owned())
98    }
99
100    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
101        nullif_func(&args.args)
102    }
103
104    fn documentation(&self) -> Option<&Documentation> {
105        self.doc()
106    }
107}
108
109/// Implements NULLIF(expr1, expr2)
110/// Args: 0 - left expr is any array
111///       1 - if the left is equal to this expr2, then the result is NULL, otherwise left value is passed.
112fn nullif_func(args: &[ColumnarValue]) -> Result<ColumnarValue> {
113    let [lhs, rhs] = take_function_args("nullif", args)?;
114    let is_nested = lhs.data_type().is_nested();
115
116    match (lhs, rhs) {
117        (ColumnarValue::Array(lhs), ColumnarValue::Scalar(rhs)) => {
118            let rhs = rhs.to_scalar()?;
119            let eq_array = compare_with_eq(lhs, &rhs, is_nested)?;
120            let array = nullif(lhs, &eq_array)?;
121
122            Ok(ColumnarValue::Array(array))
123        }
124        (ColumnarValue::Array(lhs), ColumnarValue::Array(rhs)) => {
125            let eq_array = compare_with_eq(lhs, rhs, is_nested)?;
126            let array = nullif(lhs, &eq_array)?;
127            Ok(ColumnarValue::Array(array))
128        }
129        (ColumnarValue::Scalar(lhs), ColumnarValue::Array(rhs)) => {
130            let lhs_s = lhs.to_scalar()?;
131            let lhs_a = lhs.to_array_of_size(rhs.len())?;
132            let eq_array = compare_with_eq(&lhs_s, rhs, is_nested)?;
133            let array = nullif(
134                // nullif in arrow-select does not support Datum, so we need to convert to array
135                lhs_a.as_ref(),
136                &eq_array,
137            )?;
138            Ok(ColumnarValue::Array(array))
139        }
140        (ColumnarValue::Scalar(lhs), ColumnarValue::Scalar(rhs)) => {
141            let val: ScalarValue = match lhs.eq(rhs) {
142                true => lhs.data_type().try_into()?,
143                false => lhs.clone(),
144            };
145
146            Ok(ColumnarValue::Scalar(val))
147        }
148    }
149}
150
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::{
        array::*,
        buffer::NullBuffer,
        datatypes::{Field, Fields, Int64Type},
    };
    use datafusion_common::DataFusionError;

    use super::*;

    /// Runs `NULLIF(lhs, rhs)` and materializes the outcome as an array.
    ///
    /// `num_rows` only matters when the outcome is a scalar, which is then
    /// expanded to that many rows.
    fn eval_nullif(
        lhs: ColumnarValue,
        rhs: ColumnarValue,
        num_rows: usize,
    ) -> Result<ArrayRef> {
        let evaluated = nullif_func(&[lhs, rhs])?;
        Ok(evaluated
            .into_array(num_rows)
            .expect("Failed to convert to array"))
    }

    /// Wraps a concrete arrow array as an array-valued function argument.
    fn array_arg(array: impl Array + 'static) -> ColumnarValue {
        ColumnarValue::Array(Arc::new(array))
    }

    /// Builds a `List<Int64>` array from nested optional rows.
    fn int64_list(rows: Vec<Option<Vec<Option<i64>>>>) -> ListArray {
        ListArray::from_iter_primitive::<Int64Type, _, _>(rows)
    }

    #[test]
    fn nullif_int32() -> Result<()> {
        let input = array_arg(Int32Array::from(vec![
            Some(1),
            Some(2),
            None,
            None,
            Some(3),
            None,
            None,
            Some(4),
            Some(5),
        ]));
        let two = ColumnarValue::Scalar(ScalarValue::Int32(Some(2i32)));

        let actual = eval_nullif(input, two, 0)?;

        // The row equal to 2 becomes NULL; pre-existing NULLs stay NULL.
        let expected: ArrayRef = Arc::new(Int32Array::from(vec![
            Some(1),
            None,
            None,
            None,
            Some(3),
            None,
            None,
            Some(4),
            Some(5),
        ]));
        assert_eq!(expected.as_ref(), actual.as_ref());
        Ok(())
    }

    // Arrays without any NULLs must still be masked correctly.
    #[test]
    fn nullif_int32_non_nulls() -> Result<()> {
        let input = array_arg(Int32Array::from(vec![1, 3, 10, 7, 8, 1, 2, 4, 5]));
        let one = ColumnarValue::Scalar(ScalarValue::Int32(Some(1i32)));

        let actual = eval_nullif(input, one, 0)?;

        let expected: ArrayRef = Arc::new(Int32Array::from(vec![
            None,
            Some(3),
            Some(10),
            Some(7),
            Some(8),
            None,
            Some(2),
            Some(4),
            Some(5),
        ]));
        assert_eq!(expected.as_ref(), actual.as_ref());
        Ok(())
    }

    #[test]
    fn nullif_boolean() -> Result<()> {
        let input = array_arg(BooleanArray::from(vec![Some(true), Some(false), None]));
        let falsy = ColumnarValue::Scalar(ScalarValue::Boolean(Some(false)));

        let actual = eval_nullif(input, falsy, 0)?;

        let expected: ArrayRef =
            Arc::new(BooleanArray::from(vec![Some(true), None, None]));
        assert_eq!(expected.as_ref(), actual.as_ref());
        Ok(())
    }

    #[test]
    fn nullif_string() -> Result<()> {
        let input = array_arg(StringArray::from(vec![
            Some("foo"),
            Some("bar"),
            None,
            Some("baz"),
        ]));
        let bar = ColumnarValue::Scalar(ScalarValue::from("bar"));

        let actual = eval_nullif(input, bar, 0)?;

        let expected: ArrayRef = Arc::new(StringArray::from(vec![
            Some("foo"),
            None,
            None,
            Some("baz"),
        ]));
        assert_eq!(expected.as_ref(), actual.as_ref());
        Ok(())
    }

    #[test]
    fn nullif_struct() -> Result<()> {
        let fields = Fields::from(vec![
            Field::new("a", DataType::Int64, true),
            Field::new("b", DataType::Utf8, true),
        ]);

        // Both operands share column `b` and the struct-level validity
        // (third row NULL); they differ only in column `a`.
        let build = |a_values: Vec<Option<i64>>| {
            let a: ArrayRef = Arc::new(Int64Array::from(a_values));
            let b: ArrayRef =
                Arc::new(StringArray::from(vec![Some("1"), Some("2"), None]));
            let validity = NullBuffer::from(vec![true, true, false]);
            StructArray::new(fields.clone(), vec![a, b], Some(validity))
        };
        let lhs = array_arg(build(vec![Some(1), Some(2), None]));
        let rhs = array_arg(build(vec![Some(1), Some(9), None]));

        let actual = eval_nullif(lhs, rhs, 0)?;

        let expected: ArrayRef = Arc::new(StructArray::try_new(
            fields,
            vec![
                Arc::new(Int64Array::from(vec![None, Some(2), None])) as ArrayRef,
                Arc::new(StringArray::from(vec![None, Some("2"), None])) as ArrayRef,
            ],
            Some(NullBuffer::from(vec![false, true, false])),
        )?);
        assert_eq!(expected.as_ref(), actual.as_ref());
        Ok(())
    }

    #[test]
    fn nullif_list() -> Result<()> {
        let lhs = array_arg(int64_list(vec![
            Some(vec![Some(1), Some(2)]),
            Some(vec![Some(3)]),
            Some(vec![]),
            Some(vec![Some(5), Some(6), Some(7)]),
            None,
        ]));
        let rhs = ColumnarValue::Scalar(ScalarValue::List(Arc::new(int64_list(
            vec![Some(vec![Some(1), Some(2)])],
        ))));

        let actual = eval_nullif(lhs, rhs, 0)?;

        let expected: ArrayRef = Arc::new(int64_list(vec![
            None,
            Some(vec![Some(3)]),
            Some(vec![]),
            Some(vec![Some(5), Some(6), Some(7)]),
            None,
        ]));
        assert_eq!(expected.as_ref(), actual.as_ref());
        Ok(())
    }

    #[test]
    fn nullif_compare_nested_to_unnested() -> Result<()> {
        let lhs = array_arg(int64_list(vec![
            Some(vec![Some(1), Some(2)]),
            Some(vec![Some(3)]),
            Some(vec![]),
            Some(vec![Some(5), Some(6), Some(7)]),
            None,
        ]));
        let rhs = array_arg(Int64Array::from(vec![Some(1), Some(3), None, None, None]));

        let outcome = nullif_func(&[lhs, rhs]);

        assert!(matches!(outcome, Err(DataFusionError::ArrowError(_, _))));
        Ok(())
    }

    #[test]
    fn nullif_literal_first() -> Result<()> {
        let column = array_arg(Int32Array::from(vec![
            Some(1),
            Some(2),
            None,
            None,
            Some(3),
            Some(4),
        ]));
        let two = ColumnarValue::Scalar(ScalarValue::Int32(Some(2i32)));

        let actual = eval_nullif(two, column, 0)?;

        // The literal is broadcast and nulled only where the column equals 2.
        let expected: ArrayRef = Arc::new(Int32Array::from(vec![
            Some(2),
            None,
            Some(2),
            Some(2),
            Some(2),
            Some(2),
        ]));
        assert_eq!(expected.as_ref(), actual.as_ref());
        Ok(())
    }

    #[test]
    fn nullif_scalar() -> Result<()> {
        let int32 = |v: i32| ColumnarValue::Scalar(ScalarValue::Int32(Some(v)));

        let equal = eval_nullif(int32(2), int32(2), 1)?;
        let expected_equal: ArrayRef = Arc::new(Int32Array::from(vec![None]));
        assert_eq!(expected_equal.as_ref(), equal.as_ref());

        let not_equal = eval_nullif(int32(2), int32(1), 1)?;
        let expected_not_equal: ArrayRef =
            Arc::new(Int32Array::from(vec![Some(2i32)]));
        assert_eq!(expected_not_equal.as_ref(), not_equal.as_ref());

        Ok(())
    }
}