datafusion_functions_nested/
array_has.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`ScalarUDFImpl`] definitions for array_has, array_has_all and array_has_any functions.
19
20use arrow::array::{
21    Array, ArrayRef, BooleanArray, Datum, GenericListArray, OffsetSizeTrait, Scalar,
22};
23use arrow::buffer::BooleanBuffer;
24use arrow::datatypes::DataType;
25use arrow::row::{RowConverter, Rows, SortField};
26use datafusion_common::cast::as_generic_list_array;
27use datafusion_common::utils::string_utils::string_array_to_vec;
28use datafusion_common::utils::take_function_args;
29use datafusion_common::{exec_err, Result, ScalarValue};
30use datafusion_expr::expr::{InList, ScalarFunction};
31use datafusion_expr::simplify::ExprSimplifyResult;
32use datafusion_expr::{
33    ColumnarValue, Documentation, Expr, ScalarUDFImpl, Signature, Volatility,
34};
35use datafusion_macros::user_doc;
36use datafusion_physical_expr_common::datum::compare_with_eq;
37use itertools::Itertools;
38
39use crate::make_array::make_array_udf;
40use crate::utils::make_scalar_function;
41
42use std::any::Any;
43use std::sync::Arc;
44
// Create static instances of ScalarUDFs for each function.
// Each invocation passes, in order: the UDF struct, the expression-builder
// function name, the argument names, the doc string, and the internal name of
// the UDF accessor function.
make_udf_expr_and_func!(ArrayHas,
    array_has,
    haystack_array element, // arg names
    "returns true, if the element appears in the first array, otherwise false.", // doc
    array_has_udf // internal function name
);
make_udf_expr_and_func!(ArrayHasAll,
    array_has_all,
    haystack_array needle_array, // arg names
    "returns true if each element of the second array appears in the first array; otherwise, it returns false.", // doc
    array_has_all_udf // internal function name
);
make_udf_expr_and_func!(ArrayHasAny,
    array_has_any,
    haystack_array needle_array, // arg names
    "returns true if at least one element of the second array appears in the first array; otherwise, it returns false.", // doc
    array_has_any_udf // internal function name
);
64
// `array_has(array, element)`: the scalar UDF that tests whether a list
// contains a given element. User-facing documentation is declared through the
// `user_doc` attribute below.
#[user_doc(
    doc_section(label = "Array Functions"),
    description = "Returns true if the array contains the element.",
    syntax_example = "array_has(array, element)",
    sql_example = r#"```sql
> select array_has([1, 2, 3], 2);
+-----------------------------+
| array_has(List([1,2,3]), 2) |
+-----------------------------+
| true                        |
+-----------------------------+
```"#,
    argument(
        name = "array",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    ),
    argument(
        name = "element",
        description = "Scalar or Array expression. Can be a constant, column, or function, and any combination of array operators."
    )
)]
#[derive(Debug)]
pub struct ArrayHas {
    /// Argument signature reported through `ScalarUDFImpl::signature`.
    signature: Signature,
    /// Alternative names reported through `ScalarUDFImpl::aliases`.
    aliases: Vec<String>,
}
91
92impl Default for ArrayHas {
93    fn default() -> Self {
94        Self::new()
95    }
96}
97
98impl ArrayHas {
99    pub fn new() -> Self {
100        Self {
101            signature: Signature::array_and_element(Volatility::Immutable),
102            aliases: vec![
103                String::from("list_has"),
104                String::from("array_contains"),
105                String::from("list_contains"),
106            ],
107        }
108    }
109}
110
111impl ScalarUDFImpl for ArrayHas {
112    fn as_any(&self) -> &dyn Any {
113        self
114    }
115    fn name(&self) -> &str {
116        "array_has"
117    }
118
119    fn signature(&self) -> &Signature {
120        &self.signature
121    }
122
123    fn return_type(&self, _: &[DataType]) -> Result<DataType> {
124        Ok(DataType::Boolean)
125    }
126
127    fn simplify(
128        &self,
129        mut args: Vec<Expr>,
130        _info: &dyn datafusion_expr::simplify::SimplifyInfo,
131    ) -> Result<ExprSimplifyResult> {
132        let [haystack, needle] = take_function_args(self.name(), &mut args)?;
133
134        // if the haystack is a constant list, we can use an inlist expression which is more
135        // efficient because the haystack is not varying per-row
136        if let Expr::Literal(ScalarValue::List(array)) = haystack {
137            // TODO: support LargeList
138            // (not supported by `convert_array_to_scalar_vec`)
139            // (FixedSizeList not supported either, but seems to have worked fine when attempting to
140            // build a reproducer)
141
142            assert_eq!(array.len(), 1); // guarantee of ScalarValue
143            if let Ok(scalar_values) =
144                ScalarValue::convert_array_to_scalar_vec(array.as_ref())
145            {
146                assert_eq!(scalar_values.len(), 1);
147                let list = scalar_values
148                    .into_iter()
149                    .flatten()
150                    .map(Expr::Literal)
151                    .collect();
152
153                return Ok(ExprSimplifyResult::Simplified(Expr::InList(InList {
154                    expr: Box::new(std::mem::take(needle)),
155                    list,
156                    negated: false,
157                })));
158            }
159        } else if let Expr::ScalarFunction(ScalarFunction { func, args }) = haystack {
160            // make_array has a static set of arguments, so we can pull the arguments out from it
161            if func == &make_array_udf() {
162                return Ok(ExprSimplifyResult::Simplified(Expr::InList(InList {
163                    expr: Box::new(std::mem::take(needle)),
164                    list: std::mem::take(args),
165                    negated: false,
166                })));
167            }
168        }
169
170        Ok(ExprSimplifyResult::Original(args))
171    }
172
173    fn invoke_with_args(
174        &self,
175        args: datafusion_expr::ScalarFunctionArgs,
176    ) -> Result<ColumnarValue> {
177        let [first_arg, second_arg] = take_function_args(self.name(), &args.args)?;
178        match &second_arg {
179            ColumnarValue::Array(array_needle) => {
180                // the needle is already an array, convert the haystack to an array of the same length
181                let haystack = first_arg.to_array(array_needle.len())?;
182                let array = array_has_inner_for_array(&haystack, array_needle)?;
183                Ok(ColumnarValue::Array(array))
184            }
185            ColumnarValue::Scalar(scalar_needle) => {
186                // Always return null if the second argument is null
187                // i.e. array_has(array, null) -> null
188                if scalar_needle.is_null() {
189                    return Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None)));
190                }
191
192                // since the needle is a scalar, convert it to an array of size 1
193                let haystack = first_arg.to_array(1)?;
194                let needle = scalar_needle.to_array_of_size(1)?;
195                let needle = Scalar::new(needle);
196                let array = array_has_inner_for_scalar(&haystack, &needle)?;
197                if let ColumnarValue::Scalar(_) = &first_arg {
198                    // If both inputs are scalar, keeps output as scalar
199                    let scalar_value = ScalarValue::try_from_array(&array, 0)?;
200                    Ok(ColumnarValue::Scalar(scalar_value))
201                } else {
202                    Ok(ColumnarValue::Array(array))
203                }
204            }
205        }
206    }
207
208    fn aliases(&self) -> &[String] {
209        &self.aliases
210    }
211
212    fn documentation(&self) -> Option<&Documentation> {
213        self.doc()
214    }
215}
216
217fn array_has_inner_for_scalar(
218    haystack: &ArrayRef,
219    needle: &dyn Datum,
220) -> Result<ArrayRef> {
221    match haystack.data_type() {
222        DataType::List(_) => array_has_dispatch_for_scalar::<i32>(haystack, needle),
223        DataType::LargeList(_) => array_has_dispatch_for_scalar::<i64>(haystack, needle),
224        _ => exec_err!(
225            "array_has does not support type '{:?}'.",
226            haystack.data_type()
227        ),
228    }
229}
230
231fn array_has_inner_for_array(haystack: &ArrayRef, needle: &ArrayRef) -> Result<ArrayRef> {
232    match haystack.data_type() {
233        DataType::List(_) => array_has_dispatch_for_array::<i32>(haystack, needle),
234        DataType::LargeList(_) => array_has_dispatch_for_array::<i64>(haystack, needle),
235        _ => exec_err!(
236            "array_has does not support type '{:?}'.",
237            haystack.data_type()
238        ),
239    }
240}
241
242fn array_has_dispatch_for_array<O: OffsetSizeTrait>(
243    haystack: &ArrayRef,
244    needle: &ArrayRef,
245) -> Result<ArrayRef> {
246    let haystack = as_generic_list_array::<O>(haystack)?;
247    let mut boolean_builder = BooleanArray::builder(haystack.len());
248
249    for (i, arr) in haystack.iter().enumerate() {
250        if arr.is_none() || needle.is_null(i) {
251            boolean_builder.append_null();
252            continue;
253        }
254        let arr = arr.unwrap();
255        let is_nested = arr.data_type().is_nested();
256        let needle_row = Scalar::new(needle.slice(i, 1));
257        let eq_array = compare_with_eq(&arr, &needle_row, is_nested)?;
258        boolean_builder.append_value(eq_array.true_count() > 0);
259    }
260
261    Ok(Arc::new(boolean_builder.finish()))
262}
263
264fn array_has_dispatch_for_scalar<O: OffsetSizeTrait>(
265    haystack: &ArrayRef,
266    needle: &dyn Datum,
267) -> Result<ArrayRef> {
268    let haystack = as_generic_list_array::<O>(haystack)?;
269    let values = haystack.values();
270    let is_nested = values.data_type().is_nested();
271    let offsets = haystack.value_offsets();
272    // If first argument is empty list (second argument is non-null), return false
273    // i.e. array_has([], non-null element) -> false
274    if values.is_empty() {
275        return Ok(Arc::new(BooleanArray::new(
276            BooleanBuffer::new_unset(haystack.len()),
277            None,
278        )));
279    }
280    let eq_array = compare_with_eq(values, needle, is_nested)?;
281    let mut final_contained = vec![None; haystack.len()];
282    for (i, offset) in offsets.windows(2).enumerate() {
283        let start = offset[0].to_usize().unwrap();
284        let end = offset[1].to_usize().unwrap();
285        let length = end - start;
286        // For non-nested list, length is 0 for null
287        if length == 0 {
288            continue;
289        }
290        let sliced_array = eq_array.slice(start, length);
291        final_contained[i] = Some(sliced_array.true_count() > 0);
292    }
293
294    Ok(Arc::new(BooleanArray::from(final_contained)))
295}
296
297fn array_has_all_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
298    match args[0].data_type() {
299        DataType::List(_) => {
300            array_has_all_and_any_dispatch::<i32>(&args[0], &args[1], ComparisonType::All)
301        }
302        DataType::LargeList(_) => {
303            array_has_all_and_any_dispatch::<i64>(&args[0], &args[1], ComparisonType::All)
304        }
305        _ => exec_err!(
306            "array_has does not support type '{:?}'.",
307            args[0].data_type()
308        ),
309    }
310}
311
312fn array_has_any_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
313    match args[0].data_type() {
314        DataType::List(_) => {
315            array_has_all_and_any_dispatch::<i32>(&args[0], &args[1], ComparisonType::Any)
316        }
317        DataType::LargeList(_) => {
318            array_has_all_and_any_dispatch::<i64>(&args[0], &args[1], ComparisonType::Any)
319        }
320        _ => exec_err!(
321            "array_has does not support type '{:?}'.",
322            args[0].data_type()
323        ),
324    }
325}
326
// `array_has_all(array, sub-array)`: the scalar UDF that tests whether every
// element of the second list occurs in the first. User-facing documentation
// is declared through the `user_doc` attribute below.
#[user_doc(
    doc_section(label = "Array Functions"),
    description = "Returns true if all elements of sub-array exist in array.",
    syntax_example = "array_has_all(array, sub-array)",
    sql_example = r#"```sql
> select array_has_all([1, 2, 3, 4], [2, 3]);
+--------------------------------------------+
| array_has_all(List([1,2,3,4]), List([2,3])) |
+--------------------------------------------+
| true                                       |
+--------------------------------------------+
```"#,
    argument(
        name = "array",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    ),
    argument(
        name = "sub-array",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    )
)]
#[derive(Debug)]
pub struct ArrayHasAll {
    /// Argument signature reported through `ScalarUDFImpl::signature`.
    signature: Signature,
    /// Alternative names reported through `ScalarUDFImpl::aliases`.
    aliases: Vec<String>,
}
353
354impl Default for ArrayHasAll {
355    fn default() -> Self {
356        Self::new()
357    }
358}
359
360impl ArrayHasAll {
361    pub fn new() -> Self {
362        Self {
363            signature: Signature::any(2, Volatility::Immutable),
364            aliases: vec![String::from("list_has_all")],
365        }
366    }
367}
368
369impl ScalarUDFImpl for ArrayHasAll {
370    fn as_any(&self) -> &dyn Any {
371        self
372    }
373    fn name(&self) -> &str {
374        "array_has_all"
375    }
376
377    fn signature(&self) -> &Signature {
378        &self.signature
379    }
380
381    fn return_type(&self, _: &[DataType]) -> Result<DataType> {
382        Ok(DataType::Boolean)
383    }
384
385    fn invoke_with_args(
386        &self,
387        args: datafusion_expr::ScalarFunctionArgs,
388    ) -> Result<ColumnarValue> {
389        make_scalar_function(array_has_all_inner)(&args.args)
390    }
391
392    fn aliases(&self) -> &[String] {
393        &self.aliases
394    }
395
396    fn documentation(&self) -> Option<&Documentation> {
397        self.doc()
398    }
399}
400
// `array_has_any(array, sub-array)`: the scalar UDF that tests whether the
// two lists share at least one element. User-facing documentation is declared
// through the `user_doc` attribute below.
#[user_doc(
    doc_section(label = "Array Functions"),
    description = "Returns true if any elements exist in both arrays.",
    syntax_example = "array_has_any(array, sub-array)",
    sql_example = r#"```sql
> select array_has_any([1, 2, 3], [3, 4]);
+------------------------------------------+
| array_has_any(List([1,2,3]), List([3,4])) |
+------------------------------------------+
| true                                     |
+------------------------------------------+
```"#,
    argument(
        name = "array",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    ),
    argument(
        name = "sub-array",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    )
)]
#[derive(Debug)]
pub struct ArrayHasAny {
    /// Argument signature reported through `ScalarUDFImpl::signature`.
    signature: Signature,
    /// Alternative names reported through `ScalarUDFImpl::aliases`.
    aliases: Vec<String>,
}
427
428impl Default for ArrayHasAny {
429    fn default() -> Self {
430        Self::new()
431    }
432}
433
434impl ArrayHasAny {
435    pub fn new() -> Self {
436        Self {
437            signature: Signature::any(2, Volatility::Immutable),
438            aliases: vec![String::from("list_has_any"), String::from("arrays_overlap")],
439        }
440    }
441}
442
443impl ScalarUDFImpl for ArrayHasAny {
444    fn as_any(&self) -> &dyn Any {
445        self
446    }
447    fn name(&self) -> &str {
448        "array_has_any"
449    }
450
451    fn signature(&self) -> &Signature {
452        &self.signature
453    }
454
455    fn return_type(&self, _: &[DataType]) -> Result<DataType> {
456        Ok(DataType::Boolean)
457    }
458
459    fn invoke_with_args(
460        &self,
461        args: datafusion_expr::ScalarFunctionArgs,
462    ) -> Result<ColumnarValue> {
463        make_scalar_function(array_has_any_inner)(&args.args)
464    }
465
466    fn aliases(&self) -> &[String] {
467        &self.aliases
468    }
469
470    fn documentation(&self) -> Option<&Documentation> {
471        self.doc()
472    }
473}
474
/// Represents the type of comparison for array_has.
///
/// Selects between the `array_has_all` and `array_has_any` semantics in the
/// kernels that both functions share.
#[derive(Debug, PartialEq, Clone, Copy)]
enum ComparisonType {
    /// `array_has_all`: every needle element must be found in the haystack.
    All,
    /// `array_has_any`: at least one needle element must be found in the haystack.
    Any,
}
483
484fn array_has_all_and_any_dispatch<O: OffsetSizeTrait>(
485    haystack: &ArrayRef,
486    needle: &ArrayRef,
487    comparison_type: ComparisonType,
488) -> Result<ArrayRef> {
489    let haystack = as_generic_list_array::<O>(haystack)?;
490    let needle = as_generic_list_array::<O>(needle)?;
491    if needle.values().is_empty() {
492        let buffer = match comparison_type {
493            ComparisonType::All => BooleanBuffer::new_set(haystack.len()),
494            ComparisonType::Any => BooleanBuffer::new_unset(haystack.len()),
495        };
496        return Ok(Arc::new(BooleanArray::from(buffer)));
497    }
498    match needle.data_type() {
499        DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => {
500            array_has_all_and_any_string_internal::<O>(haystack, needle, comparison_type)
501        }
502        _ => general_array_has_for_all_and_any::<O>(haystack, needle, comparison_type),
503    }
504}
505
506// String comparison for array_has_all and array_has_any
507fn array_has_all_and_any_string_internal<O: OffsetSizeTrait>(
508    array: &GenericListArray<O>,
509    needle: &GenericListArray<O>,
510    comparison_type: ComparisonType,
511) -> Result<ArrayRef> {
512    let mut boolean_builder = BooleanArray::builder(array.len());
513    for (arr, sub_arr) in array.iter().zip(needle.iter()) {
514        match (arr, sub_arr) {
515            (Some(arr), Some(sub_arr)) => {
516                let haystack_array = string_array_to_vec(&arr);
517                let needle_array = string_array_to_vec(&sub_arr);
518                boolean_builder.append_value(array_has_string_kernel(
519                    haystack_array,
520                    needle_array,
521                    comparison_type,
522                ));
523            }
524            (_, _) => {
525                boolean_builder.append_null();
526            }
527        }
528    }
529
530    Ok(Arc::new(boolean_builder.finish()))
531}
532
533fn array_has_string_kernel(
534    haystack: Vec<Option<&str>>,
535    needle: Vec<Option<&str>>,
536    comparison_type: ComparisonType,
537) -> bool {
538    match comparison_type {
539        ComparisonType::All => needle
540            .iter()
541            .dedup()
542            .all(|x| haystack.iter().dedup().any(|y| y == x)),
543        ComparisonType::Any => needle
544            .iter()
545            .dedup()
546            .any(|x| haystack.iter().dedup().any(|y| y == x)),
547    }
548}
549
550// General row comparison for array_has_all and array_has_any
551fn general_array_has_for_all_and_any<O: OffsetSizeTrait>(
552    haystack: &GenericListArray<O>,
553    needle: &GenericListArray<O>,
554    comparison_type: ComparisonType,
555) -> Result<ArrayRef> {
556    let mut boolean_builder = BooleanArray::builder(haystack.len());
557    let converter = RowConverter::new(vec![SortField::new(haystack.value_type())])?;
558
559    for (arr, sub_arr) in haystack.iter().zip(needle.iter()) {
560        if let (Some(arr), Some(sub_arr)) = (arr, sub_arr) {
561            let arr_values = converter.convert_columns(&[arr])?;
562            let sub_arr_values = converter.convert_columns(&[sub_arr])?;
563            boolean_builder.append_value(general_array_has_all_and_any_kernel(
564                arr_values,
565                sub_arr_values,
566                comparison_type,
567            ));
568        } else {
569            boolean_builder.append_null();
570        }
571    }
572
573    Ok(Arc::new(boolean_builder.finish()))
574}
575
576fn general_array_has_all_and_any_kernel(
577    haystack_rows: Rows,
578    needle_rows: Rows,
579    comparison_type: ComparisonType,
580) -> bool {
581    match comparison_type {
582        ComparisonType::All => needle_rows.iter().all(|needle_row| {
583            haystack_rows
584                .iter()
585                .any(|haystack_row| haystack_row == needle_row)
586        }),
587        ComparisonType::Any => needle_rows.iter().any(|needle_row| {
588            haystack_rows
589                .iter()
590                .any(|haystack_row| haystack_row == needle_row)
591        }),
592    }
593}
594
#[cfg(test)]
mod tests {
    use arrow::array::create_array;
    use datafusion_common::utils::SingleRowListArrayBuilder;
    use datafusion_expr::{
        col, execution_props::ExecutionProps, lit, simplify::ExprSimplifyResult, Expr,
        ScalarUDFImpl,
    };

    use crate::expr_fn::make_array;

    use super::ArrayHas;

    /// A constant list haystack is rewritten to `needle IN (1, 2, 3)`.
    #[test]
    fn test_simplify_array_has_to_in_list() {
        let haystack = lit(SingleRowListArrayBuilder::new(create_array!(
            Int32,
            [1, 2, 3]
        ))
        .build_list_scalar());
        let needle = col("c");

        let props = ExecutionProps::new();
        let context = datafusion_expr::simplify::SimplifyContext::new(&props);

        let Ok(ExprSimplifyResult::Simplified(Expr::InList(in_list))) =
            ArrayHas::new().simplify(vec![haystack, needle.clone()], &context)
        else {
            panic!("Expected simplified expression");
        };

        assert_eq!(
            in_list,
            datafusion_expr::expr::InList {
                expr: Box::new(needle),
                list: vec![lit(1), lit(2), lit(3)],
                negated: false,
            }
        );
    }

    /// A `make_array(...)` haystack is rewritten to an IN-list over its arguments.
    #[test]
    fn test_simplify_array_has_with_make_array_to_in_list() {
        let haystack = make_array(vec![lit(1), lit(2), lit(3)]);
        let needle = col("c");

        let props = ExecutionProps::new();
        let context = datafusion_expr::simplify::SimplifyContext::new(&props);

        let Ok(ExprSimplifyResult::Simplified(Expr::InList(in_list))) =
            ArrayHas::new().simplify(vec![haystack, needle.clone()], &context)
        else {
            panic!("Expected simplified expression");
        };

        assert_eq!(
            in_list,
            datafusion_expr::expr::InList {
                expr: Box::new(needle),
                list: vec![lit(1), lit(2), lit(3)],
                negated: false,
            }
        );
    }

    /// A haystack that is neither a list literal nor `make_array` (here a
    /// plain column) is left unsimplified and the arguments are returned as is.
    #[test]
    fn test_array_has_complex_list_not_simplified() {
        let haystack = col("c1");
        let needle = col("c2");

        let props = ExecutionProps::new();
        let context = datafusion_expr::simplify::SimplifyContext::new(&props);

        let Ok(ExprSimplifyResult::Original(args)) =
            ArrayHas::new().simplify(vec![haystack, needle], &context)
        else {
            panic!("Expected original (unsimplified) arguments");
        };

        assert_eq!(args, vec![col("c1"), col("c2")]);
    }
}