datafusion_functions_nested/
array_has.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`ScalarUDFImpl`] definitions for array_has, array_has_all and array_has_any functions.
19
20use arrow::array::{
21    Array, ArrayRef, BooleanArray, Datum, GenericListArray, OffsetSizeTrait, Scalar,
22};
23use arrow::buffer::BooleanBuffer;
24use arrow::datatypes::DataType;
25use arrow::row::{RowConverter, Rows, SortField};
26use datafusion_common::cast::as_generic_list_array;
27use datafusion_common::utils::string_utils::string_array_to_vec;
28use datafusion_common::utils::take_function_args;
29use datafusion_common::{exec_err, Result, ScalarValue};
30use datafusion_expr::{
31    ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
32};
33use datafusion_macros::user_doc;
34use datafusion_physical_expr_common::datum::compare_with_eq;
35use itertools::Itertools;
36
37use crate::utils::make_scalar_function;
38
39use std::any::Any;
40use std::sync::Arc;
41
// Create static instances of ScalarUDFs for each function.
//
// Per the inline comments on each invocation, the macro arguments are, in
// order: the `ScalarUDFImpl` type, the public function name, the argument
// names, a doc string, and an internal function name.
// NOTE(review): the exact items generated are defined by the macro itself
// (not visible here) — confirm there before relying on their shape.
make_udf_expr_and_func!(ArrayHas,
    array_has,
    haystack_array element, // arg names
    "returns true, if the element appears in the first array, otherwise false.", // doc
    array_has_udf // internal function name
);
make_udf_expr_and_func!(ArrayHasAll,
    array_has_all,
    haystack_array needle_array, // arg names
    "returns true if each element of the second array appears in the first array; otherwise, it returns false.", // doc
    array_has_all_udf // internal function name
);
make_udf_expr_and_func!(ArrayHasAny,
    array_has_any,
    haystack_array needle_array, // arg names
    "returns true if at least one element of the second array appears in the first array; otherwise, it returns false.", // doc
    array_has_any_udf // internal function name
);
61
#[user_doc(
    doc_section(label = "Array Functions"),
    description = "Returns true if the array contains the element.",
    syntax_example = "array_has(array, element)",
    sql_example = r#"```sql
> select array_has([1, 2, 3], 2);
+-----------------------------+
| array_has(List([1,2,3]), 2) |
+-----------------------------+
| true                        |
+-----------------------------+
```"#,
    argument(
        name = "array",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    ),
    argument(
        name = "element",
        description = "Scalar or Array expression. Can be a constant, column, or function, and any combination of array operators."
    )
)]
#[derive(Debug)]
/// Scalar UDF implementation of `array_has(array, element)`.
pub struct ArrayHas {
    /// Argument signature: an array followed by an element (`Signature::array_and_element`).
    signature: Signature,
    /// Alternative SQL names this function is also registered under.
    aliases: Vec<String>,
}
88
89impl Default for ArrayHas {
90    fn default() -> Self {
91        Self::new()
92    }
93}
94
95impl ArrayHas {
96    pub fn new() -> Self {
97        Self {
98            signature: Signature::array_and_element(Volatility::Immutable),
99            aliases: vec![
100                String::from("list_has"),
101                String::from("array_contains"),
102                String::from("list_contains"),
103            ],
104        }
105    }
106}
107
108impl ScalarUDFImpl for ArrayHas {
109    fn as_any(&self) -> &dyn Any {
110        self
111    }
112    fn name(&self) -> &str {
113        "array_has"
114    }
115
116    fn signature(&self) -> &Signature {
117        &self.signature
118    }
119
120    fn return_type(&self, _: &[DataType]) -> Result<DataType> {
121        Ok(DataType::Boolean)
122    }
123
124    fn invoke_with_args(
125        &self,
126        args: datafusion_expr::ScalarFunctionArgs,
127    ) -> Result<ColumnarValue> {
128        let [first_arg, second_arg] = take_function_args(self.name(), &args.args)?;
129        match &second_arg {
130            ColumnarValue::Array(array_needle) => {
131                // the needle is already an array, convert the haystack to an array of the same length
132                let haystack = first_arg.to_array(array_needle.len())?;
133                let array = array_has_inner_for_array(&haystack, array_needle)?;
134                Ok(ColumnarValue::Array(array))
135            }
136            ColumnarValue::Scalar(scalar_needle) => {
137                // Always return null if the second argument is null
138                // i.e. array_has(array, null) -> null
139                if scalar_needle.is_null() {
140                    return Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None)));
141                }
142
143                // since the needle is a scalar, convert it to an array of size 1
144                let haystack = first_arg.to_array(1)?;
145                let needle = scalar_needle.to_array_of_size(1)?;
146                let needle = Scalar::new(needle);
147                let array = array_has_inner_for_scalar(&haystack, &needle)?;
148                if let ColumnarValue::Scalar(_) = &first_arg {
149                    // If both inputs are scalar, keeps output as scalar
150                    let scalar_value = ScalarValue::try_from_array(&array, 0)?;
151                    Ok(ColumnarValue::Scalar(scalar_value))
152                } else {
153                    Ok(ColumnarValue::Array(array))
154                }
155            }
156        }
157    }
158
159    fn aliases(&self) -> &[String] {
160        &self.aliases
161    }
162
163    fn documentation(&self) -> Option<&Documentation> {
164        self.doc()
165    }
166}
167
168fn array_has_inner_for_scalar(
169    haystack: &ArrayRef,
170    needle: &dyn Datum,
171) -> Result<ArrayRef> {
172    match haystack.data_type() {
173        DataType::List(_) => array_has_dispatch_for_scalar::<i32>(haystack, needle),
174        DataType::LargeList(_) => array_has_dispatch_for_scalar::<i64>(haystack, needle),
175        _ => exec_err!(
176            "array_has does not support type '{:?}'.",
177            haystack.data_type()
178        ),
179    }
180}
181
182fn array_has_inner_for_array(haystack: &ArrayRef, needle: &ArrayRef) -> Result<ArrayRef> {
183    match haystack.data_type() {
184        DataType::List(_) => array_has_dispatch_for_array::<i32>(haystack, needle),
185        DataType::LargeList(_) => array_has_dispatch_for_array::<i64>(haystack, needle),
186        _ => exec_err!(
187            "array_has does not support type '{:?}'.",
188            haystack.data_type()
189        ),
190    }
191}
192
193fn array_has_dispatch_for_array<O: OffsetSizeTrait>(
194    haystack: &ArrayRef,
195    needle: &ArrayRef,
196) -> Result<ArrayRef> {
197    let haystack = as_generic_list_array::<O>(haystack)?;
198    let mut boolean_builder = BooleanArray::builder(haystack.len());
199
200    for (i, arr) in haystack.iter().enumerate() {
201        if arr.is_none() || needle.is_null(i) {
202            boolean_builder.append_null();
203            continue;
204        }
205        let arr = arr.unwrap();
206        let is_nested = arr.data_type().is_nested();
207        let needle_row = Scalar::new(needle.slice(i, 1));
208        let eq_array = compare_with_eq(&arr, &needle_row, is_nested)?;
209        boolean_builder.append_value(eq_array.true_count() > 0);
210    }
211
212    Ok(Arc::new(boolean_builder.finish()))
213}
214
215fn array_has_dispatch_for_scalar<O: OffsetSizeTrait>(
216    haystack: &ArrayRef,
217    needle: &dyn Datum,
218) -> Result<ArrayRef> {
219    let haystack = as_generic_list_array::<O>(haystack)?;
220    let values = haystack.values();
221    let is_nested = values.data_type().is_nested();
222    let offsets = haystack.value_offsets();
223    // If first argument is empty list (second argument is non-null), return false
224    // i.e. array_has([], non-null element) -> false
225    if values.len() == 0 {
226        return Ok(Arc::new(BooleanArray::new(
227            BooleanBuffer::new_unset(haystack.len()),
228            None,
229        )));
230    }
231    let eq_array = compare_with_eq(values, needle, is_nested)?;
232    let mut final_contained = vec![None; haystack.len()];
233    for (i, offset) in offsets.windows(2).enumerate() {
234        let start = offset[0].to_usize().unwrap();
235        let end = offset[1].to_usize().unwrap();
236        let length = end - start;
237        // For non-nested list, length is 0 for null
238        if length == 0 {
239            continue;
240        }
241        let sliced_array = eq_array.slice(start, length);
242        final_contained[i] = Some(sliced_array.true_count() > 0);
243    }
244
245    Ok(Arc::new(BooleanArray::from(final_contained)))
246}
247
248fn array_has_all_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
249    match args[0].data_type() {
250        DataType::List(_) => {
251            array_has_all_and_any_dispatch::<i32>(&args[0], &args[1], ComparisonType::All)
252        }
253        DataType::LargeList(_) => {
254            array_has_all_and_any_dispatch::<i64>(&args[0], &args[1], ComparisonType::All)
255        }
256        _ => exec_err!(
257            "array_has does not support type '{:?}'.",
258            args[0].data_type()
259        ),
260    }
261}
262
263fn array_has_any_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
264    match args[0].data_type() {
265        DataType::List(_) => {
266            array_has_all_and_any_dispatch::<i32>(&args[0], &args[1], ComparisonType::Any)
267        }
268        DataType::LargeList(_) => {
269            array_has_all_and_any_dispatch::<i64>(&args[0], &args[1], ComparisonType::Any)
270        }
271        _ => exec_err!(
272            "array_has does not support type '{:?}'.",
273            args[0].data_type()
274        ),
275    }
276}
277
#[user_doc(
    doc_section(label = "Array Functions"),
    description = "Returns true if all elements of sub-array exist in array.",
    syntax_example = "array_has_all(array, sub-array)",
    sql_example = r#"```sql
> select array_has_all([1, 2, 3, 4], [2, 3]);
+--------------------------------------------+
| array_has_all(List([1,2,3,4]), List([2,3])) |
+--------------------------------------------+
| true                                       |
+--------------------------------------------+
```"#,
    argument(
        name = "array",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    ),
    argument(
        name = "sub-array",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    )
)]
#[derive(Debug)]
/// Scalar UDF implementation of `array_has_all(array, sub-array)`.
pub struct ArrayHasAll {
    /// Two-argument signature built with `Signature::any(2, ..)`.
    signature: Signature,
    /// Alternative SQL names this function is also registered under.
    aliases: Vec<String>,
}
304
305impl Default for ArrayHasAll {
306    fn default() -> Self {
307        Self::new()
308    }
309}
310
311impl ArrayHasAll {
312    pub fn new() -> Self {
313        Self {
314            signature: Signature::any(2, Volatility::Immutable),
315            aliases: vec![String::from("list_has_all")],
316        }
317    }
318}
319
320impl ScalarUDFImpl for ArrayHasAll {
321    fn as_any(&self) -> &dyn Any {
322        self
323    }
324    fn name(&self) -> &str {
325        "array_has_all"
326    }
327
328    fn signature(&self) -> &Signature {
329        &self.signature
330    }
331
332    fn return_type(&self, _: &[DataType]) -> Result<DataType> {
333        Ok(DataType::Boolean)
334    }
335
336    fn invoke_with_args(
337        &self,
338        args: datafusion_expr::ScalarFunctionArgs,
339    ) -> Result<ColumnarValue> {
340        make_scalar_function(array_has_all_inner)(&args.args)
341    }
342
343    fn aliases(&self) -> &[String] {
344        &self.aliases
345    }
346
347    fn documentation(&self) -> Option<&Documentation> {
348        self.doc()
349    }
350}
351
#[user_doc(
    doc_section(label = "Array Functions"),
    description = "Returns true if any elements exist in both arrays.",
    syntax_example = "array_has_any(array, sub-array)",
    sql_example = r#"```sql
> select array_has_any([1, 2, 3], [3, 4]);
+------------------------------------------+
| array_has_any(List([1,2,3]), List([3,4])) |
+------------------------------------------+
| true                                     |
+------------------------------------------+
```"#,
    argument(
        name = "array",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    ),
    argument(
        name = "sub-array",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    )
)]
#[derive(Debug)]
/// Scalar UDF implementation of `array_has_any(array, sub-array)`.
pub struct ArrayHasAny {
    /// Two-argument signature built with `Signature::any(2, ..)`.
    signature: Signature,
    /// Alternative SQL names this function is also registered under.
    aliases: Vec<String>,
}
378
379impl Default for ArrayHasAny {
380    fn default() -> Self {
381        Self::new()
382    }
383}
384
385impl ArrayHasAny {
386    pub fn new() -> Self {
387        Self {
388            signature: Signature::any(2, Volatility::Immutable),
389            aliases: vec![String::from("list_has_any"), String::from("arrays_overlap")],
390        }
391    }
392}
393
394impl ScalarUDFImpl for ArrayHasAny {
395    fn as_any(&self) -> &dyn Any {
396        self
397    }
398    fn name(&self) -> &str {
399        "array_has_any"
400    }
401
402    fn signature(&self) -> &Signature {
403        &self.signature
404    }
405
406    fn return_type(&self, _: &[DataType]) -> Result<DataType> {
407        Ok(DataType::Boolean)
408    }
409
410    fn invoke_with_args(
411        &self,
412        args: datafusion_expr::ScalarFunctionArgs,
413    ) -> Result<ColumnarValue> {
414        make_scalar_function(array_has_any_inner)(&args.args)
415    }
416
417    fn aliases(&self) -> &[String] {
418        &self.aliases
419    }
420
421    fn documentation(&self) -> Option<&Documentation> {
422        self.doc()
423    }
424}
425
/// Selects how `array_has_all` / `array_has_any` combine per-element
/// membership results.
#[derive(Clone, Copy, Debug, PartialEq)]
enum ComparisonType {
    /// `array_has_all`: every needle element must be found in the haystack.
    All,
    /// `array_has_any`: at least one needle element must be found.
    Any,
}
434
435fn array_has_all_and_any_dispatch<O: OffsetSizeTrait>(
436    haystack: &ArrayRef,
437    needle: &ArrayRef,
438    comparison_type: ComparisonType,
439) -> Result<ArrayRef> {
440    let haystack = as_generic_list_array::<O>(haystack)?;
441    let needle = as_generic_list_array::<O>(needle)?;
442    match needle.data_type() {
443        DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => {
444            array_has_all_and_any_string_internal::<O>(haystack, needle, comparison_type)
445        }
446        _ => general_array_has_for_all_and_any::<O>(haystack, needle, comparison_type),
447    }
448}
449
450// String comparison for array_has_all and array_has_any
451fn array_has_all_and_any_string_internal<O: OffsetSizeTrait>(
452    array: &GenericListArray<O>,
453    needle: &GenericListArray<O>,
454    comparison_type: ComparisonType,
455) -> Result<ArrayRef> {
456    let mut boolean_builder = BooleanArray::builder(array.len());
457    for (arr, sub_arr) in array.iter().zip(needle.iter()) {
458        match (arr, sub_arr) {
459            (Some(arr), Some(sub_arr)) => {
460                let haystack_array = string_array_to_vec(&arr);
461                let needle_array = string_array_to_vec(&sub_arr);
462                boolean_builder.append_value(array_has_string_kernel(
463                    haystack_array,
464                    needle_array,
465                    comparison_type,
466                ));
467            }
468            (_, _) => {
469                boolean_builder.append_null();
470            }
471        }
472    }
473
474    Ok(Arc::new(boolean_builder.finish()))
475}
476
477fn array_has_string_kernel(
478    haystack: Vec<Option<&str>>,
479    needle: Vec<Option<&str>>,
480    comparison_type: ComparisonType,
481) -> bool {
482    match comparison_type {
483        ComparisonType::All => needle
484            .iter()
485            .dedup()
486            .all(|x| haystack.iter().dedup().any(|y| y == x)),
487        ComparisonType::Any => needle
488            .iter()
489            .dedup()
490            .any(|x| haystack.iter().dedup().any(|y| y == x)),
491    }
492}
493
494// General row comparison for array_has_all and array_has_any
495fn general_array_has_for_all_and_any<O: OffsetSizeTrait>(
496    haystack: &GenericListArray<O>,
497    needle: &GenericListArray<O>,
498    comparison_type: ComparisonType,
499) -> Result<ArrayRef> {
500    let mut boolean_builder = BooleanArray::builder(haystack.len());
501    let converter = RowConverter::new(vec![SortField::new(haystack.value_type())])?;
502
503    for (arr, sub_arr) in haystack.iter().zip(needle.iter()) {
504        if let (Some(arr), Some(sub_arr)) = (arr, sub_arr) {
505            let arr_values = converter.convert_columns(&[arr])?;
506            let sub_arr_values = converter.convert_columns(&[sub_arr])?;
507            boolean_builder.append_value(general_array_has_all_and_any_kernel(
508                arr_values,
509                sub_arr_values,
510                comparison_type,
511            ));
512        } else {
513            boolean_builder.append_null();
514        }
515    }
516
517    Ok(Arc::new(boolean_builder.finish()))
518}
519
520fn general_array_has_all_and_any_kernel(
521    haystack_rows: Rows,
522    needle_rows: Rows,
523    comparison_type: ComparisonType,
524) -> bool {
525    match comparison_type {
526        ComparisonType::All => needle_rows.iter().all(|needle_row| {
527            haystack_rows
528                .iter()
529                .any(|haystack_row| haystack_row == needle_row)
530        }),
531        ComparisonType::Any => needle_rows.iter().any(|needle_row| {
532            haystack_rows
533                .iter()
534                .any(|haystack_row| haystack_row == needle_row)
535        }),
536    }
537}